# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mediapipe.python.solutions.pose."""

import json
import os
# pylint: disable=unused-import
import tempfile
# pylint: enable=unused-import
from typing import NamedTuple

from absl.testing import absltest
from absl.testing import parameterized
import cv2
import numpy as np
import numpy.testing as npt
from PIL import Image

# resources dependency
# undeclared dependency
from mediapipe.python.solutions import drawing_styles
from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose

TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
DIFF_THRESHOLD = 15  # pixels
EXPECTED_POSE_LANDMARKS = np.array([[460, 283], [467, 273], [471, 273],
                                    [474, 273], [465, 273], [465, 273],
                                    [466, 273], [491, 277], [480, 277],
                                    [470, 294], [465, 294], [545, 319],
                                    [453, 329], [622, 323], [375, 316],
                                    [696, 316], [299, 307], [719, 316],
                                    [278, 306], [721, 311], [274, 304],
                                    [713, 313], [283, 306], [520, 476],
                                    [467, 471], [612, 550], [358, 490],
                                    [701, 613], [349, 611], [709, 624],
                                    [363, 630], [730, 633], [303, 628]])
WORLD_DIFF_THRESHOLD = 0.2  # meters
EXPECTED_POSE_WORLD_LANDMARKS = np.array([
    [-0.11, -0.59, -0.15], [-0.09, -0.64, -0.16], [-0.09, -0.64, -0.16],
    [-0.09, -0.64, -0.16], [-0.11, -0.64, -0.14], [-0.11, -0.64, -0.14],
    [-0.11, -0.64, -0.14], [0.01, -0.65, -0.15], [-0.06, -0.64, -0.05],
    [-0.07, -0.57, -0.15], [-0.09, -0.57, -0.12], [0.18, -0.49, -0.09],
    [-0.14, -0.5, -0.03], [0.41, -0.48, -0.11], [-0.42, -0.5, -0.02],
    [0.64, -0.49, -0.17], [-0.63, -0.51, -0.13], [0.7, -0.5, -0.19],
    [-0.71, -0.53, -0.15], [0.72, -0.51, -0.23], [-0.69, -0.54, -0.19],
    [0.66, -0.49, -0.19], [-0.64, -0.52, -0.15], [0.09, 0., -0.04],
    [-0.09, -0., 0.03], [0.41, 0.23, -0.09], [-0.43, 0.1, -0.11],
    [0.69, 0.49, -0.04], [-0.48, 0.47, -0.02], [0.72, 0.52, -0.04],
    [-0.48, 0.51, -0.02], [0.8, 0.5, -0.14], [-0.59, 0.52, -0.11],
])
IOU_THRESHOLD = 0.85  # IoU, as a fraction in [0, 1]


class PoseTest(parameterized.TestCase):

  def _landmarks_list_to_array(self, landmark_list, image_shape):
    rows, cols, _ = image_shape
    return np.asarray([(lmk.x * cols, lmk.y * rows, lmk.z * cols)
                       for lmk in landmark_list.landmark])

  def _world_landmarks_list_to_array(self, landmark_list):
    return np.asarray([(lmk.x, lmk.y, lmk.z)
                       for lmk in landmark_list.landmark])

  def _assert_diff_less(self, array1, array2, threshold):
    npt.assert_array_less(np.abs(array1 - array2), threshold)

  def _get_output_path(self, name):
    return os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + name)

  def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int):
    mp_drawing.draw_landmarks(
        frame,
        results.pose_landmarks,
        mp_pose.POSE_CONNECTIONS,
        landmark_drawing_spec=drawing_styles.get_default_pose_landmarks_style())
    path = self._get_output_path('_frame_{}.png'.format(idx))
    cv2.imwrite(path, frame)

  def _annotate_segmentation(self, segmentation, expected_segmentation,
                             idx: int):
    path = self._get_output_path('_segmentation_{}.png'.format(idx))
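    # Render the binary masks back to RGB and save them so that failing runs
    # can be inspected visually.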
    self._segmentation_to_rgb(segmentation).save(path)
    path = self._get_output_path('_segmentation_diff_{}.png'.format(idx))
    self._segmentation_diff_to_rgb(expected_segmentation,
                                   segmentation).save(path)

  def _rgb_to_segmentation(self, img, back_color=(255, 0, 0),
                           front_color=(0, 0, 255)):
    """Converts an RGB mask image to a binary {0, 1} segmentation array."""
    img = np.array(img)
    # Check that every pixel is either the front or the back color.
    is_back = (img == back_color).all(axis=2)
    is_front = (img == front_color).all(axis=2)
    assert np.logical_or(is_back, is_front).all()
    segm = np.zeros(img.shape[:2], dtype=np.uint8)
    segm[is_front] = 1
    return segm

  def _segmentation_to_rgb(self, segm, back_color=(255, 0, 0),
                           front_color=(0, 0, 255)):
    """Converts a binary segmentation array to an RGB PIL image."""
    height, width = segm.shape
    img = np.zeros((height, width, 3), dtype=np.uint8)
    img[:, :] = back_color
    img[segm == 1] = front_color
    return Image.fromarray(img)

  def _segmentation_iou(self, segm_expected, segm_actual):
    """Computes the intersection over union of two binary masks."""
    intersection = segm_expected * segm_actual
    expected_dot = segm_expected * segm_expected
    actual_dot = segm_actual * segm_actual
    eps = np.finfo(np.float32).eps
    result = intersection.sum() / (expected_dot.sum() + actual_dot.sum() -
                                   intersection.sum() + eps)
    return result

  def _segmentation_diff_to_rgb(self, segm_expected, segm_actual,
                                expected_color=(0, 255, 0),
                                actual_color=(255, 0, 0)):
    """Renders pixels where the two masks disagree, color-coded by side."""
    height, width = segm_expected.shape
    img = np.zeros((height, width, 3), dtype=np.uint8)
    img[np.logical_and(segm_expected == 1, segm_actual == 0)] = expected_color
    img[np.logical_and(segm_expected == 0, segm_actual == 1)] = actual_color
    return Image.fromarray(img)

  def test_invalid_image_shape(self):
    with mp_pose.Pose() as pose:
      with self.assertRaisesRegex(
          ValueError, 'Input image must contain three channel rgb data.'):
        pose.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4))

  def test_blank_image(self):
    with mp_pose.Pose(enable_segmentation=True) as pose:
      image = np.zeros([100, 100, 3], dtype=np.uint8)
      image.fill(255)
      results = pose.process(image)
      self.assertIsNone(results.pose_landmarks)
      self.assertIsNone(results.segmentation_mask)

  @parameterized.named_parameters(('static_lite', True, 0, 3),
                                  ('static_full', True, 1, 3),
                                  ('static_heavy', True, 2, 3),
                                  ('video_lite', False, 0, 3),
                                  ('video_full', False, 1, 3),
                                  ('video_heavy', False, 2, 3))
  def test_on_image(self, static_image_mode, model_complexity, num_frames):
    image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg')
    expected_segmentation_path = os.path.join(os.path.dirname(__file__),
                                              'testdata/pose_segmentation.png')
    image = cv2.imread(image_path)
    expected_segmentation = self._rgb_to_segmentation(
        Image.open(expected_segmentation_path).convert('RGB'))

    with mp_pose.Pose(static_image_mode=static_image_mode,
                      model_complexity=model_complexity,
                      enable_segmentation=True) as pose:
      for idx in range(num_frames):
        results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        segmentation = results.segmentation_mask.round().astype(np.uint8)

        # TODO: Add rendering of world 3D when supported.
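        # Dump the annotated frame and segmentation images to the test output
        # directory for manual inspection.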
        self._annotate(image.copy(), results, idx)
        self._annotate_segmentation(segmentation, expected_segmentation, idx)

        self._assert_diff_less(
            self._landmarks_list_to_array(results.pose_landmarks,
                                          image.shape)[:, :2],
            EXPECTED_POSE_LANDMARKS, DIFF_THRESHOLD)
        self._assert_diff_less(
            self._world_landmarks_list_to_array(results.pose_world_landmarks),
            EXPECTED_POSE_WORLD_LANDMARKS, WORLD_DIFF_THRESHOLD)
        self.assertGreaterEqual(
            self._segmentation_iou(expected_segmentation, segmentation),
            IOU_THRESHOLD)

  @parameterized.named_parameters(('full', 1, 'pose_squats.full.npz'))
  def test_on_video(self, model_complexity, expected_name):
    """Tests pose models on a video."""
    # Thresholds for comparing actual and expected predictions: pixels for
    # image-space landmarks, meters for world landmarks.
    diff_threshold = 15
    world_diff_threshold = 0.1

    video_path = os.path.join(os.path.dirname(__file__),
                              'testdata/pose_squats.mp4')
    expected_path = os.path.join(os.path.dirname(__file__),
                                 'testdata/{}'.format(expected_name))

    # Predict pose landmarks for each frame.
    video_cap = cv2.VideoCapture(video_path)
    actual_per_frame = []
    actual_world_per_frame = []
    frame_idx = 0
    with mp_pose.Pose(static_image_mode=False,
                      model_complexity=model_complexity) as pose:
      while True:
        # Get the next frame of the video.
        success, input_frame = video_cap.read()
        if not success:
          break

        # Run pose tracker.
        input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
        result = pose.process(image=input_frame)
        pose_landmarks = self._landmarks_list_to_array(result.pose_landmarks,
                                                       input_frame.shape)
        pose_world_landmarks = self._world_landmarks_list_to_array(
            result.pose_world_landmarks)

        actual_per_frame.append(pose_landmarks)
        actual_world_per_frame.append(pose_world_landmarks)

        input_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
        self._annotate(input_frame, result, frame_idx)
        frame_idx += 1
    video_cap.release()

    actual = np.array(actual_per_frame)
    actual_world = np.array(actual_world_per_frame)

    # Dump actual .npz.
    npz_path = self._get_output_path(expected_name)
    np.savez(npz_path, predictions=actual, predictions_world=actual_world)

    # Dump actual JSON.
    json_path = self._get_output_path(expected_name.replace('.npz', '.json'))
    with open(json_path, 'w') as fl:
      dump_data = {
          'predictions': np.around(actual, 3).tolist(),
          'predictions_world': np.around(actual_world, 3).tolist()
      }
      fl.write(json.dumps(dump_data, indent=2, separators=(',', ': ')))

    # Validate actual vs. expected landmarks.
    expected = np.load(expected_path)['predictions']
    assert actual.shape == expected.shape, (
        'Unexpected shape of predictions: {} instead of {}'.format(
            actual.shape, expected.shape))
    self._assert_diff_less(
        actual[..., :2], expected[..., :2], threshold=diff_threshold)

    # Validate actual vs. expected world landmarks.
    expected_world = np.load(expected_path)['predictions_world']
    assert actual_world.shape == expected_world.shape, (
        'Unexpected shape of world predictions: {} instead of {}'.format(
            actual_world.shape, expected_world.shape))
    self._assert_diff_less(
        actual_world, expected_world, threshold=world_diff_threshold)


if __name__ == '__main__':
  absltest.main()