Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Empty file modified build.sh
100644 → 100755
Empty file.
24 changes: 14 additions & 10 deletions src/torchvision_module.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Module that defines the Vision Service that wraps torchvision functionality"""

from typing import ClassVar, List, Mapping, Sequence, Any, Dict, Optional, Union
from typing import ClassVar, List, Mapping, Sequence, Any, Dict, Optional, Union, Tuple
from typing_extensions import Self
from viam.components.camera import Camera
from viam.media.video import ViamImage, CameraMimeType
from viam.media.video import ViamImage
from viam.proto.service.vision import Classification, Detection
from viam.services.vision import Vision, CaptureAllResult
from viam.module.types import Reconfigurable
Expand Down Expand Up @@ -47,7 +47,7 @@ def new_service(
return service

@classmethod
def validate_config(cls, config: ServiceConfig) -> Sequence[str]:
def validate_config(cls, config: ServiceConfig) -> Tuple[Sequence[str], Sequence[str]]:
"""Validates JSON Configuration"""
model_name = config.attributes.fields["model_name"].string_value
camera_name = config.attributes.fields["camera_name"].string_value
Expand All @@ -60,7 +60,7 @@ def validate_config(cls, config: ServiceConfig) -> Sequence[str]:
raise Exception(
"A camera name is required for this vision service module."
)
return [camera_name]
return [camera_name], []

def reconfigure(
self, config: ServiceConfig, dependencies: Mapping[ResourceName, ResourceBase]
Expand Down Expand Up @@ -197,19 +197,21 @@ async def capture_all_from_camera(
"is not the configured 'camera_name'",
self.camera_name,
)
image = await self.camera.get_image(mime_type=CameraMimeType.JPEG)
images, _ = await self.camera.get_images()
if images is None or len(images) == 0 and (return_image or return_classifications or return_detections):
raise ValueError("No images returned by get_images")
if return_image:
result.image = image
result.image = images[0]
if return_classifications:
try:
classifications = await self.get_classifications(image, 1)
classifications = await self.get_classifications(images[0], 1)
result.classifications = classifications
# pylint: disable=broad-exception-caught
except Exception as e:
LOGGER.info(f"getClassifications failed: {e}")
if return_detections:
try:
detections = await self.get_detections(image, timeout=timeout, extra=None)
detections = await self.get_detections(images[0], timeout=timeout, extra=None)
result.detections = detections
# pylint: disable=broad-exception-caught
except Exception as e:
Expand Down Expand Up @@ -362,8 +364,10 @@ def filter_output(

async def get_image_from_dependency(self, camera_name: str):
# cam = self.dependencies[Camera.get_resource_name("")]
im = await self.camera.get_image(mime_type=CameraMimeType.JPEG)
return decode_image(im)
imgs, _ = await self.camera.get_images()
if imgs is None or len(imgs) == 0:
raise ValueError("No images returned by get_images")
return decode_image(imgs[0])

def wrap_detections(self, prediction: dict, image_shape: Sequence[int]):
"""_summary_
Expand Down
Binary file not shown.
8 changes: 4 additions & 4 deletions tests/test_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ class TestVision:
@pytest.mark.asyncio
async def test_validate(self):
with pytest.raises(Exception):
response = TorchVisionService.validate_config(config=config)
response = TorchVisionService.validate_config(config=config2)
response, _ = TorchVisionService.validate_config(config=config)
response, _ = TorchVisionService.validate_config(config=config2)

@pytest.mark.asyncio
@patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
Expand All @@ -90,7 +90,7 @@ def test_capture_all_from_camera(self, fake_cam):
camera.camera_name = "fake_cam"

camera.reconfigure(cfg, dependencies={"fake_cam": Mock()})
camera.camera.get_image = AsyncMock(return_value=pil_to_viam_image(image, mime_type=CameraMimeType.JPEG))
camera.camera.get_images = AsyncMock(return_value=[[pil_to_viam_image(image, mime_type=CameraMimeType.JPEG)], None])

# without point clouds = True
result = asyncio.run(camera.capture_all_from_camera(
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_default_camera_behavior(self, fake_cam):
)

vs.reconfigure(cfg, dependencies={"fake_cam": Mock()})
vs.camera.get_image = AsyncMock(return_value=pil_to_viam_image(image, mime_type=CameraMimeType.JPEG))
vs.camera.get_images = AsyncMock(return_value=[[pil_to_viam_image(image, mime_type=CameraMimeType.JPEG)], None])

result = vs.get_classifications_from_camera(
"",
Expand Down