# Patch summary (reviewer notes; text before the first "diff --git" header is
# skipped by `git apply` / `patch`):
#   - requirements.txt: bump darwin torchvision 0.18 -> 0.21; relax protobuf to >4.25.3.
#   - src/main.py: register/add the model under Vision.API instead of Vision.SUBTYPE.
#   - src/properties.py, src/torchvision_module.py, tests: rename the Properties fields
#     to classifications_supported / detections_supported / object_point_clouds_supported.
#   - src/torchvision_module.py: capture_all_from_camera reads the image through
#     self.camera.get_image(); wrap_detections takes image_shape and delegates to the
#     new make_detections(), which fills the *_normalized box fields when both image
#     dimensions are non-zero.
#   - src/utils.py: decode_image returns np.array(im) in the branch that opens
#     image.data with PIL.
#   - tests: mock camera.get_image directly; add test_get_detections.
#
# NOTE(review): the stored copy of this patch had every newline and all indentation
# collapsed, and two '-' markers (the implements_detection=False lines) were lost.
# Line breaks below follow the hunk headers' line counts; indentation and the exact
# position of blank context lines are inferred and must be checked against the real
# files before applying.
#
# NOTE(review): issues visible in the hunks, deliberately not changed here:
#   - get_classifications does `return NotImplementedError` where the sibling
#     methods `raise` it.
#   - get_object_point_clouds returns the int 1 despite its List[PointCloudObject]
#     annotation.
#   - make_detections repeats the Detection(...) construction in both branches.
#   - protobuf>4.25.3 has no upper bound; darwin torchvision 0.21 and linux
#     torchvision 0.17.2+cpu now diverge -- confirm this is intended.
#   - the newly imported viam.components.camera.Camera is not referenced in any
#     visible test hunk.
diff --git a/requirements.txt b/requirements.txt
index b397302..ba6998f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,11 +6,11 @@ torch; sys_platform == 'darwin'
 torch==2.2.2+cpu; sys_platform == 'linux'
 -f https://download.pytorch.org/whl/cpu/torch_stable.html
 
-torchvision == 0.18; sys_platform == 'darwin'
+torchvision == 0.21; sys_platform == 'darwin'
 torchvision==0.17.2+cpu;sys_platform == 'linux'
 -f https://download.pytorch.org/whl/cpu/torch_stable.html
 
-protobuf==4.25.3
+protobuf>4.25.3
 pytest
 pylint
 pytest-asyncio
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 7f4ccce..da93f9c 100644
--- a/src/main.py
+++ b/src/main.py
@@ -15,7 +15,7 @@ async def main():
     registry before the module adds the resource model.
     """
     Registry.register_resource_creator(
-        Vision.SUBTYPE,
+        Vision.API,
         TorchVisionService.MODEL,
         ResourceCreatorRegistration(
             TorchVisionService.new_service,
@@ -24,7 +24,7 @@ async def main():
     )
 
     module = Module.from_args()
-    module.add_model_from_registry(Vision.SUBTYPE, TorchVisionService.MODEL)
+    module.add_model_from_registry(Vision.API, TorchVisionService.MODEL)
     await module.start()
 
 
diff --git a/src/properties.py b/src/properties.py
index 0c892b0..ed25b6a 100644
--- a/src/properties.py
+++ b/src/properties.py
@@ -4,6 +4,6 @@
 # pylint: disable=missing-class-docstring
 @dataclass
 class Properties:
-    implements_classification: bool = False
-    implements_detection: bool = False
-    implements_get_object_pcd: bool = False
+    classifications_supported: bool = False
+    detections_supported: bool = False
+    object_point_clouds_supported: bool = False
diff --git a/src/torchvision_module.py b/src/torchvision_module.py
index ba152e3..9a2f5b7 100644
--- a/src/torchvision_module.py
+++ b/src/torchvision_module.py
@@ -99,13 +99,13 @@ def get_attribute_from_config(attribute_name: str, default, of_type=None):
 
         model_name = get_attribute_from_config("model_name", None, str)
         self.properties = Properties(
-            implements_classification=True,
-            implements_detection=False,
-            implements_get_object_pcd=False,
+            classifications_supported=True,
+            detections_supported=False,
+            object_point_clouds_supported=False,
         )
         if model_name in DETECTION_MODELS:
-            self.properties.implements_classification = False
-            self.properties.implements_detection = True
+            self.properties.classifications_supported = False
+            self.properties.detections_supported = True
 
         weights = get_attribute_from_config("weights", "DEFAULT")
         try:
@@ -197,8 +197,7 @@ async def capture_all_from_camera(
                 "is not the configured 'camera_name'",
                 self.camera_name,
             )
-        image = await self.get_image_from_dependency(camera_name)
-
+        image = await self.camera.get_image(mime_type=CameraMimeType.JPEG)
         if return_image:
             result.image = image
         if return_classifications:
@@ -227,7 +226,7 @@ async def get_object_point_clouds(
         timeout: Optional[float] = None,
         **kwargs,
     ) -> List[PointCloudObject]:
-        if not self.properties.implements_get_object_pcd:
+        if not self.properties.object_point_clouds_supported:
             raise NotImplementedError
         return 1
 
@@ -239,14 +238,14 @@ async def get_detections(
         timeout: float,
     ) -> List[Detection]:
         """Get detections from an image"""
-        if not self.properties.implements_detection:
+        if not self.properties.detections_supported:
             raise NotImplementedError
         LOGGER.info(f"input image is: {type(image)}")
         image = decode_image(image)
         input_tensor = self.preprocessor(image)
         with torch.no_grad():
             prediction: Tensor = self.model(input_tensor)[0]
-        return self.wrap_detections(prediction)
+        return self.wrap_detections(prediction, image.shape)
 
     async def get_classifications(
         self,
@@ -257,7 +256,7 @@ async def get_classifications(
         timeout: Optional[float] = None,
     ) -> List[Classification]:
         """Get classifications from image"""
-        if not self.properties.implements_classification:
+        if not self.properties.classifications_supported:
             return NotImplementedError
         image = decode_image(image)
         input_tensor = self.preprocessor(image)
@@ -276,7 +275,7 @@ async def get_classifications_from_camera(
         timeout: Optional[float] = None,
     ) -> List[Classification]:
         """Gets classifications from a camera dependency"""
-        if not self.properties.implements_classification:
+        if not self.properties.classifications_supported:
             raise NotImplementedError
 
         if camera_name not in (self.camera_name, ""):
@@ -296,7 +295,7 @@ async def get_detections_from_camera(
         self, camera_name: str, *, extra: Mapping[str, Any] = None, timeout: float = None,
     ) -> List[Detection]:
         """Gets detections from a camera dependency"""
-        if not self.properties.implements_detection:
+        if not self.properties.detections_supported:
             raise NotImplementedError
         if camera_name not in (self.camera_name, ""):
             raise ValueError(
@@ -310,7 +309,7 @@ async def get_detections_from_camera(
         input_tensor = self.preprocessor(image)
         with torch.no_grad():
             prediction: Tensor = self.model(input_tensor)[0]
-        return self.wrap_detections(prediction)
+        return self.wrap_detections(prediction, image.shape)
 
     async def get_properties(
         self,
@@ -366,7 +365,7 @@ async def get_image_from_dependency(self, camera_name: str):
         im = await self.camera.get_image(mime_type=CameraMimeType.JPEG)
         return decode_image(im)
 
-    def wrap_detections(self, prediction: dict):
+    def wrap_detections(self, prediction: dict, image_shape: Sequence[int]):
         """_summary_
         converts prediction output tensor from torchvision model for viam API
 
@@ -379,17 +378,7 @@ def wrap_detections(self, prediction: dict):
         labels = [self.weights.meta["categories"][i] for i in prediction["labels"]]
         scores = prediction["scores"]
         boxes = prediction["boxes"].to(torch.int64).tolist()
-        res = [
-            Detection(
-                x_min=x_min,
-                y_min=y_min,
-                x_max=x_max,
-                y_max=y_max,
-                confidence=score,
-                class_name=label,
-            )
-            for (x_min, y_min, x_max, y_max), score, label in zip(boxes, scores, labels)
-        ]
+        res = self.make_detections(boxes, scores, labels, image_shape)
         res = self.filter_output(res)
         return res
 
@@ -414,3 +403,32 @@ def wrap_classifications(self, prediction, count):
         ]
         res = self.filter_output(res)
         return res
+
+    def make_detections(self, boxes, scores, labels, image_shape):
+        if image_shape[0] == 0 or image_shape[1] == 0:
+            return [
+                Detection(
+                    x_min=x_min,
+                    y_min=y_min,
+                    x_max=x_max,
+                    y_max=y_max,
+                    confidence=score,
+                    class_name=label,
+                )
+                for (x_min, y_min, x_max, y_max), score, label in zip(boxes, scores, labels)
+            ]
+        return [
+            Detection(
+                x_min=x_min,
+                y_min=y_min,
+                x_max=x_max,
+                y_max=y_max,
+                x_min_normalized=x_min / image_shape[1],
+                y_min_normalized=y_min / image_shape[0],
+                x_max_normalized=x_max / image_shape[1],
+                y_max_normalized=y_max / image_shape[0],
+                confidence=score,
+                class_name=label,
+            )
+            for (x_min, y_min, x_max, y_max), score, label in zip(boxes, scores, labels)
+        ]
diff --git a/src/utils.py b/src/utils.py
index aa4309c..6d6d1a4 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -35,7 +35,7 @@ def decode_image(image: Union[Image.Image, ViamImage, np.ndarray]) -> np.ndarray
         im = Image.open(BytesIO(image.data), formats=LIBRARY_SUPPORTED_FORMATS).convert(
             "RGB"
         )  # convert in RGB png openened in RGBA
-        return im
+        return np.array(im)
     res = image.convert("RGB")
     rgb = np.array(res)
     return rgb
diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py
index 81b667b..2218ee0 100644
--- a/tests/test_torchvision.py
+++ b/tests/test_torchvision.py
@@ -11,6 +11,8 @@
 
 from PIL import Image
 
+from viam.components.camera import Camera
+from viam.media.video import CameraMimeType
 from viam.media.utils.pil import pil_to_viam_image
 from viam.proto.app.robot import ComponentConfig
 
@@ -20,7 +22,8 @@
 from google.protobuf.struct_pb2 import Struct
 
 path_to_input_image = "tests/grasshopper.jpg"
-input_image = np.array(Image.open(path_to_input_image))
+image = Image.open(path_to_input_image)
+input_image = np.array(image)
 cfg = ServiceConfig(
     attributes=dict_to_struct(
         {
@@ -33,9 +36,9 @@
 VISION_SERVICE_NAME = "vision1"
 DETECTIONS = []
 PROPERTIES = Properties(
-    implements_classification=True,
-    implements_detection=False,
-    implements_get_object_pcd=False
+    classifications_supported=True,
+    detections_supported=False,
+    object_point_clouds_supported=False
 )
 
 def make_component_config(dictionary: Mapping[str, Any]) -> ComponentConfig:
@@ -80,15 +83,14 @@ async def test_get_properties(self, get_image_from_dependency, fake_cam):
         assert response == PROPERTIES
 
     @patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
-    @patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock)
-    def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam):
+    def test_capture_all_from_camera(self, fake_cam):
         camera = TorchVisionService(
             name='tvs'
         )
         camera.camera_name = "fake_cam"
-        get_image_from_dependency.return_value = input_image
 
         camera.reconfigure(cfg, dependencies={"fake_cam": Mock()})
+        camera.camera.get_image = AsyncMock(return_value=pil_to_viam_image(image, mime_type=CameraMimeType.JPEG))
 
         # without point clouds = True
         result = asyncio.run(camera.capture_all_from_camera(
@@ -118,16 +120,14 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam):
         # mock_get_detections.assert_called_once_with('test_image', timeout=None)
 
     @patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
-    @patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock)
-    def test_default_camera_behavior(self, get_image_from_dependency, fake_cam):
+    def test_default_camera_behavior(self, fake_cam):
         vs = TorchVisionService(
             name='tvs'
         )
-        get_image_from_dependency.return_value = input_image
-        # vs.camera_name = "fake_cam"
 
         vs.reconfigure(cfg, dependencies={"fake_cam": Mock()})
-
+        vs.camera.get_image = AsyncMock(return_value=pil_to_viam_image(image, mime_type=CameraMimeType.JPEG))
+        
         result = vs.get_classifications_from_camera(
             "",
             count=1,
@@ -155,3 +155,32 @@ def test_default_camera_behavior(self, get_image_from_dependency, fake_cam):
             )))
         assert 'not_cam' in str(excinfo.value)
 
+    @patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
+    @patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock)
+    def test_get_detections(self, get_image_from_dependency, fake_cam):
+        vision = TorchVisionService(
+            name='tvs'
+        )
+        vision.camera_name = "fake_cam"
+        get_image_from_dependency.return_value = input_image
+
+        cfg = ServiceConfig(
+            attributes=dict_to_struct(
+                {
+                    "model_name": "fasterrcnn_resnet50_fpn",
+                    "camera_name": "fake_cam",
+                }
+            )
+        )
+
+        vision.reconfigure(cfg, dependencies={"fake_cam": Mock()})
+        response = asyncio.run(vision.get_detections_from_camera("fake_cam"))
+        assert response is not None
+        assert response[0].x_min is not None
+        assert response[0].y_min is not None
+        assert response[0].x_max is not None
+        assert response[0].y_max is not None
+        assert response[0].x_min_normalized == response[0].x_min / input_image.shape[1]
+        assert response[0].y_min_normalized == response[0].y_min / input_image.shape[0]
+        assert response[0].x_max_normalized == response[0].x_max / input_image.shape[1]
+        assert response[0].y_max_normalized == response[0].y_max / input_image.shape[0]