Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __repr__(self):
# pylint: disable=too-few-public-methods
class Preprocessor:
"""Main wrapper class that performs a series of preprocessing steps"""
# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
weights_transform=None,
Expand Down
26 changes: 24 additions & 2 deletions src/torchvision_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_attribute_from_config(attribute_name: str, default, of_type=None):
self.default_minimum_confidence = get_attribute_from_config(
"default_minimum_confidence", 0
)
#pylint: disable=too-many-arguments
#pylint: disable=too-many-arguments,too-many-positional-arguments
async def capture_all_from_camera(
self,
camera_name: str,
Expand Down Expand Up @@ -189,6 +189,14 @@ async def capture_all_from_camera(
classifications, and objects, as well as any extra info the model may provide.
"""
result = CaptureAllResult()

if camera_name not in (self.camera_name, ""):
raise ValueError(
"Camera name passed to method:",
camera_name,
"is not the configured 'camera_name'",
self.camera_name,
)
image = await self.get_image_from_dependency(camera_name)

if return_image:
Expand Down Expand Up @@ -271,18 +279,32 @@ async def get_classifications_from_camera(
if not self.properties.implements_classification:
raise NotImplementedError

if camera_name not in (self.camera_name, ""):
raise ValueError(
"Camera name passed to method:",
camera_name,
"is not the configured 'camera_name'",
self.camera_name,
)
image = await self.get_image_from_dependency(camera_name)
input_tensor = self.preprocessor(image)
with torch.no_grad():
prediction: Tensor = self.model(input_tensor)
return self.wrap_classifications(prediction, count)

async def get_detections_from_camera(
self, camera_name: str, *, extra: Mapping[str, Any], timeout: float
self, camera_name: str, *, extra: Mapping[str, Any] = None, timeout: float = None,
) -> List[Detection]:
"""Gets detections from a camera dependency"""
if not self.properties.implements_detection:
raise NotImplementedError
if camera_name not in (self.camera_name, ""):
raise ValueError(
"Camera name passed to method:",
camera_name,
"is not the configured 'camera_name'",
self.camera_name,
)
image = await self.get_image_from_dependency(camera_name)
LOGGER.info(f"input image is: {type(image)}")
input_tensor = self.preprocessor(image)
Expand Down
43 changes: 41 additions & 2 deletions tests/test_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam):

# without point clouds = True
result = asyncio.run(camera.capture_all_from_camera(
'camera1',
'fake_cam',
return_image=True,
return_classifications=True,
return_detections=True
Expand All @@ -103,7 +103,7 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam):
assert result.detections is None

result = asyncio.run(camera.capture_all_from_camera(
'camera1',
'fake_cam',
return_image=True,
return_classifications=True,
return_detections=True,
Expand All @@ -116,3 +116,42 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam):
assert result.objects is None
# mock_get_classifications.assert_called_once_with('test_image', 1)
# mock_get_detections.assert_called_once_with('test_image', timeout=None)

@patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
@patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock)
def test_default_camera_behavior(self, get_image_from_dependency, fake_cam):
vs = TorchVisionService(
name='tvs'
)
get_image_from_dependency.return_value = input_image

# vs.camera_name = "fake_cam"
vs.reconfigure(cfg, dependencies={"fake_cam": Mock()})

result = vs.get_classifications_from_camera(
"",
count=1,
)
assert result is not None

result = asyncio.run(vs.capture_all_from_camera(
"",
return_classifications=True,
))
assert result is not None
assert result.classifications is not None

with pytest.raises(ValueError) as excinfo:
asyncio.run(vs.get_classifications_from_camera(
"not_cam",
count=1,
))
assert 'not_cam' in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
asyncio.run(asyncio.run(vs.capture_all_from_camera(
"not_cam",
return_classifications=True,
return_detections=True,
)))
assert 'not_cam' in str(excinfo.value)