397 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			397 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2020 The MediaPipe Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #      http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| """Tests for mediapipe.python.solution_base."""
 | |
| 
 | |
| from absl.testing import absltest
 | |
| from absl.testing import parameterized
 | |
| import numpy as np
 | |
| 
 | |
| from google.protobuf import text_format
 | |
| from mediapipe.framework import calculator_pb2
 | |
| from mediapipe.framework.formats import detection_pb2
 | |
| from mediapipe.python import solution_base
 | |
| from mediapipe.python.solution_base import PacketDataType
 | |
| 
 | |
# Text-format CalculatorGraphConfig with a single node named
# 'ImageTransformation'. The node's calculator options are given twice: once
# through the proto2 `options` extension and once through the proto3
# `node_options` field. The calculator-options tests parse this config and then
# clear one field, both fields, or neither, to exercise each configuration
# path.
CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG = """
  input_stream: 'image_in'
  output_stream: 'image_out'
  node {
    name: 'ImageTransformation'
    calculator: 'ImageTransformationCalculator'
    input_stream: 'IMAGE:image_in'
    output_stream: 'IMAGE:image_out'
    options: {
      [mediapipe.ImageTransformationCalculatorOptions.ext] {
         output_width: 10
         output_height: 10
      }
    }
    node_options: {
      [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
         output_width: 10
         output_height: 10
      }
    }
  }
"""
 | |
| 
 | |
| 
 | |
class SolutionBaseTest(parameterized.TestCase):
  """Tests for `solution_base.SolutionBase`.

  Covers constructor argument validation, graph config validation, input data
  type checks, calculator option overrides, side inputs, `reset()` and stream
  type hints.
  """

  def test_invalid_initialization_arguments(self):
    """Passing neither or both graph sources raises ValueError."""
    expected_message = (
        'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.')
    # Neither a binary graph path nor a graph config.
    with self.assertRaisesRegex(ValueError, expected_message):
      solution_base.SolutionBase()
    # Both a graph config and a binary graph path.
    with self.assertRaisesRegex(ValueError, expected_message):
      solution_base.SolutionBase(
          graph_config=calculator_pb2.CalculatorGraphConfig(),
          binary_graph_path='/tmp/no_such.binarypb')

  @parameterized.named_parameters(
      ('no_graph_input_output_stream', """
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
      """, RuntimeError, 'does not have a corresponding output stream.'),
      ('calculator_io_mismatch', """
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        input_stream: 'in2'
        output_stream: 'out'
      }
      """, ValueError, 'must use matching tags and indexes.'),
      ('unknown_registered_stream_type_name', """
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
      """, RuntimeError, 'Unable to find the type for stream \"in\".'))
  def test_invalid_config(self, text_config, error_type, error_message):
    """Constructing from a malformed graph config raises the expected error."""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(error_type, error_message):
      solution_base.SolutionBase(graph_config=config_proto)

  def test_valid_input_data_type_proto(self):
    """A DetectionList proto input is processed and ids are assigned."""
    text_config = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTION_LIST:input_detections'
        output_stream: 'DETECTION_LIST:output_detections'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      input_detections = detection_pb2.DetectionList()
      detection_1 = input_detections.detection.add()
      text_format.Parse('score: 0.5', detection_1)
      detection_2 = input_detections.detection.add()
      text_format.Parse('score: 0.8', detection_2)
      results = solution.process({'input_detections': input_detections})
      self.assertTrue(hasattr(results, 'output_detections'))
      self.assertLen(results.output_detections.detection, 2)
      expected_detection_1 = detection_pb2.Detection()
      text_format.Parse('score: 0.5, detection_id: 1', expected_detection_1)
      expected_detection_2 = detection_pb2.Detection()
      text_format.Parse('score: 0.8, detection_id: 2', expected_detection_2)
      self.assertEqual(results.output_detections.detection[0],
                       expected_detection_1)
      self.assertEqual(results.output_detections.detection[1],
                       expected_detection_2)

  def test_invalid_input_data_type_proto_vector(self):
    """Feeding a stream whose type is a proto list raises NotImplementedError."""
    text_config = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTIONS:input_detections'
        output_stream: 'DETECTIONS:output_detections'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      detection = detection_pb2.Detection()
      text_format.Parse('score: 0.5', detection)
      with self.assertRaisesRegex(
          NotImplementedError,
          'SolutionBase can only process non-audio and non-proto-list data. '
          'PROTO_LIST type is not supported.'):
        solution.process({'input_detections': detection})

  def test_invalid_input_image_data(self):
    """A 4-channel image is rejected with ValueError."""
    text_config = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      with self.assertRaisesRegex(
          ValueError, 'Input image must contain three channel rgb data.'):
        solution.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4))

  @parameterized.named_parameters(('graph_without_side_packets', """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
      """, None), ('graph_with_side_packets', """
      input_stream: 'image_in'
      input_side_packet: 'allow_signal'
      input_side_packet: 'rotation_degrees'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:image_out'
      }""", {
          'allow_signal': True,
          'rotation_degrees': 0
      }))
  def test_solution_process(self, text_config, side_inputs):
    """Processing an image, with or without side inputs, returns it unchanged."""
    self._process_and_verify(
        config_proto=text_format.Parse(text_config,
                                       calculator_pb2.CalculatorGraphConfig()),
        side_inputs=side_inputs)

  def test_invalid_calculator_options(self):
    """Overriding options on a node whose options cannot be modified fails."""
    text_config = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        name: 'SignalGate'
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        output_stream: 'IMAGE:image_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(
        ValueError,
        'Modifying the calculator options of SignalGate is not supported.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={'SignalGate.invalid_field': 'I am invalid'})

  def test_calculator_has_both_options_and_node_options(self):
    """Overriding params on a node with both option fields raises ValueError."""
    config_proto = self._options_test_config()
    with self.assertRaisesRegex(ValueError,
                                'has both options and node_options fields.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={
              'ImageTransformation.output_width': 0,
              'ImageTransformation.output_height': 0
          })

  def test_modifying_calculator_proto2_options(self):
    """Params can override the proto2 `options` extension."""
    # To test proto2 options only, remove the proto3 node_options field from the
    # graph config.
    self._process_and_verify(
        config_proto=self._options_test_config('node_options'),
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_modifying_calculator_proto3_node_options(self):
    """Params can override the proto3 `node_options` field."""
    # To test proto3 node options only, remove the proto2 options field from the
    # graph config.
    self._process_and_verify(
        config_proto=self._options_test_config('options'),
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_adding_calculator_options(self):
    """Params can be applied to a node that has no options at all."""
    # To test a calculator with no options field, remove both proto2 options and
    # proto3 node_options fields from the graph config.
    self._process_and_verify(
        config_proto=self._options_test_config('options', 'node_options'),
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  @parameterized.named_parameters(('graph_without_side_packets', """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
      """, None), ('graph_with_side_packets', """
      input_stream: 'image_in'
      input_side_packet: 'allow_signal'
      input_side_packet: 'rotation_degrees'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:image_out'
      }""", {
          'allow_signal': True,
          'rotation_degrees': 0
      }))
  def test_solution_reset(self, text_config, side_inputs):
    """Repeated process()/reset() cycles keep producing the input image."""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto, side_inputs=side_inputs) as solution:
      for _ in range(20):
        outputs = solution.process(input_image)
        np.testing.assert_array_equal(outputs.image_out, input_image)
        solution.reset()

  def test_solution_stream_type_hints(self):
    """Both IMAGE and IMAGE_FRAME hints work for a union-typed input stream."""
    text_config = """
      input_stream: 'union_type_image_in'
      output_stream: 'image_type_out'
      node {
        calculator: 'ToImageCalculator'
        input_stream: 'IMAGE:union_type_image_in'
        output_stream: 'IMAGE:image_type_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    for type_hint in (PacketDataType.IMAGE, PacketDataType.IMAGE_FRAME):
      # Each hint gets a fresh SolutionBase; subTest reports which hint failed.
      with self.subTest(type_hint=type_hint):
        with solution_base.SolutionBase(
            graph_config=config_proto,
            stream_type_hints={'union_type_image_in': type_hint}) as solution:
          for _ in range(20):
            outputs = solution.process(input_image)
            np.testing.assert_array_equal(outputs.image_type_out, input_image)

  def _options_test_config(self, *fields_to_clear):
    """Parses CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG and clears node fields.

    Args:
      *fields_to_clear: Names of fields ('options' and/or 'node_options') to
        clear on the 'ImageTransformation' node.

    Returns:
      The parsed `CalculatorGraphConfig` with the requested fields cleared.
    """
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    for field_name in fields_to_clear:
      config_proto.node[0].ClearField(field_name)
    return config_proto

  def _process_and_verify(self,
                          config_proto,
                          side_inputs=None,
                          calculator_params=None):
    """Runs a 3x3 RGB image through the graph and checks it is unchanged.

    The image is fed twice: once as a bare array and once in a dict keyed by
    the 'image_in' stream name. The 'image_out' output of each call must equal
    the input.

    Args:
      config_proto: The `CalculatorGraphConfig` to build the solution from.
      side_inputs: Optional side-packet values passed to `SolutionBase`.
      calculator_params: Optional calculator option overrides passed to
        `SolutionBase`.
    """
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto,
        side_inputs=side_inputs,
        calculator_params=calculator_params) as solution:
      outputs = solution.process(input_image)
      outputs2 = solution.process({'image_in': input_image})
    np.testing.assert_array_equal(outputs.image_out, input_image)
    np.testing.assert_array_equal(outputs2.image_out, input_image)
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
  # Run the test cases defined in this module through absltest's runner.
  absltest.main()
 |