diff --git a/src/preprocess.py b/src/preprocess.py index 6737c2a..a07dd64 100644 --- a/src/preprocess.py +++ b/src/preprocess.py @@ -35,7 +35,7 @@ def __repr__(self): # pylint: disable=too-few-public-methods class Preprocessor: """Main wrapper class that performs a series of preprocessing steps""" - # pylint: disable=too-many-arguments + # pylint: disable=too-many-arguments,too-many-positional-arguments def __init__( self, weights_transform=None, diff --git a/src/torchvision_module.py b/src/torchvision_module.py index 04583e2..ba152e3 100644 --- a/src/torchvision_module.py +++ b/src/torchvision_module.py @@ -142,7 +142,7 @@ def get_attribute_from_config(attribute_name: str, default, of_type=None): self.default_minimum_confidence = get_attribute_from_config( "default_minimum_confidence", 0 ) - #pylint: disable=too-many-arguments + #pylint: disable=too-many-arguments,too-many-positional-arguments async def capture_all_from_camera( self, camera_name: str, @@ -189,6 +189,14 @@ async def capture_all_from_camera( classifications, and objects, as well as any extra info the model may provide. """ result = CaptureAllResult() + + if camera_name not in (self.camera_name, ""): + raise ValueError( + "Camera name passed to method:", + camera_name, + "is not the configured 'camera_name'", + self.camera_name, + ) image = await self.get_image_from_dependency(camera_name) if return_image: @@ -271,6 +279,13 @@ async def get_classifications_from_camera( if not self.properties.implements_classification: raise NotImplementedError + if camera_name not in (self.camera_name, ""): + raise ValueError( + "Camera name passed to method:", + camera_name, + "is not the configured 'camera_name'", + self.camera_name, + ) image = await self.get_image_from_dependency(camera_name) input_tensor = self.preprocessor(image) with torch.no_grad(): @@ -278,11 +293,18 @@ async def get_classifications_from_camera( return self.wrap_classifications(prediction, count) async def get_detections_from_camera( - self, camera_name: str, *, extra: Mapping[str, Any], timeout: float + self, camera_name: str, *, extra: Mapping[str, Any] = None, timeout: float = None, ) -> List[Detection]: """Gets detections from a camera dependency""" if not self.properties.implements_detection: raise NotImplementedError + if camera_name not in (self.camera_name, ""): + raise ValueError( + "Camera name passed to method:", + camera_name, + "is not the configured 'camera_name'", + self.camera_name, + ) image = await self.get_image_from_dependency(camera_name) LOGGER.info(f"input image is: {type(image)}") input_tensor = self.preprocessor(image) diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py index a5229a0..81b667b 100644 --- a/tests/test_torchvision.py +++ b/tests/test_torchvision.py @@ -92,7 +92,7 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam): # without point clouds = True result = asyncio.run(camera.capture_all_from_camera( - 'camera1', + 'fake_cam', return_image=True, return_classifications=True, return_detections=True @@ -103,7 +103,7 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam): assert result.detections is None result = asyncio.run(camera.capture_all_from_camera( - 'camera1', + 'fake_cam', return_image=True, return_classifications=True, return_detections=True, @@ -116,3 +116,42 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam): assert result.objects is None # mock_get_classifications.assert_called_once_with('test_image', 1) # mock_get_detections.assert_called_once_with('test_image', timeout=None) + + @patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam") + @patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock) + def test_default_camera_behavior(self, get_image_from_dependency, fake_cam): + vs = TorchVisionService( + name='tvs' + ) + get_image_from_dependency.return_value = input_image + + # vs.camera_name = "fake_cam" + vs.reconfigure(cfg, dependencies={"fake_cam": Mock()}) + + result = vs.get_classifications_from_camera( + "", + count=1, + ) + assert result is not None + + result = asyncio.run(vs.capture_all_from_camera( + "", + return_classifications=True, + )) + assert result is not None + assert result.classifications is not None + + with pytest.raises(ValueError) as excinfo: + asyncio.run(vs.get_classifications_from_camera( + "not_cam", + count=1, + )) + assert 'not_cam' in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + asyncio.run(asyncio.run(vs.capture_all_from_camera( + "not_cam", + return_classifications=True, + return_detections=True, + ))) + assert 'not_cam' in str(excinfo.value) +