Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ torch; sys_platform == 'darwin'
torch==2.2.2+cpu; sys_platform == 'linux'
-f https://download.pytorch.org/whl/cpu/torch_stable.html

torchvision == 0.18; sys_platform == 'darwin'
torchvision == 0.21; sys_platform == 'darwin'
torchvision==0.17.2+cpu;sys_platform == 'linux'
-f https://download.pytorch.org/whl/cpu/torch_stable.html

protobuf==4.25.3
protobuf>4.25.3
pytest
pylint
pytest-asyncio
4 changes: 2 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def main():
registry before the module adds the resource model.
"""
Registry.register_resource_creator(
Vision.SUBTYPE,
Vision.API,
TorchVisionService.MODEL,
ResourceCreatorRegistration(
TorchVisionService.new_service,
Expand All @@ -24,7 +24,7 @@ async def main():
)
module = Module.from_args()

module.add_model_from_registry(Vision.SUBTYPE, TorchVisionService.MODEL)
module.add_model_from_registry(Vision.API, TorchVisionService.MODEL)
await module.start()


Expand Down
6 changes: 3 additions & 3 deletions src/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
# pylint: disable=missing-class-docstring
# Boolean capability flags; every field below defaults to False.
# NOTE(review): two naming schemes appear here (`implements_*` and
# `*_supported`). This looks like a diff rendering in which the old and new
# field names are interleaved -- confirm against the real file which set
# should remain before relying on either.
@dataclass
class Properties:
implements_classification: bool = False
implements_detection: bool = False
implements_get_object_pcd: bool = False
classifications_supported: bool = False
detections_supported: bool = False
object_point_clouds_supported: bool = False
70 changes: 44 additions & 26 deletions src/torchvision_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def get_attribute_from_config(attribute_name: str, default, of_type=None):

model_name = get_attribute_from_config("model_name", None, str)
self.properties = Properties(
implements_classification=True,
implements_detection=False,
implements_get_object_pcd=False,
classifications_supported=True,
detections_supported=False,
object_point_clouds_supported=False,
)
if model_name in DETECTION_MODELS:
self.properties.implements_classification = False
self.properties.implements_detection = True
self.properties.classifications_supported = False
self.properties.detections_supported = True

weights = get_attribute_from_config("weights", "DEFAULT")
try:
Expand Down Expand Up @@ -197,8 +197,7 @@ async def capture_all_from_camera(
"is not the configured 'camera_name'",
self.camera_name,
)
image = await self.get_image_from_dependency(camera_name)

image = await self.camera.get_image(mime_type=CameraMimeType.JPEG)
if return_image:
result.image = image
if return_classifications:
Expand Down Expand Up @@ -227,7 +226,7 @@ async def get_object_point_clouds(
timeout: Optional[float] = None,
**kwargs,
) -> List[PointCloudObject]:
if not self.properties.implements_get_object_pcd:
if not self.properties.object_point_clouds_supported:
raise NotImplementedError
return 1

Expand All @@ -239,14 +238,14 @@ async def get_detections(
timeout: float,
) -> List[Detection]:
"""Get detections from an image"""
if not self.properties.implements_detection:
if not self.properties.detections_supported:
raise NotImplementedError
LOGGER.info(f"input image is: {type(image)}")
image = decode_image(image)
input_tensor = self.preprocessor(image)
with torch.no_grad():
prediction: Tensor = self.model(input_tensor)[0]
return self.wrap_detections(prediction)
return self.wrap_detections(prediction, image.shape)

async def get_classifications(
self,
Expand All @@ -257,7 +256,7 @@ async def get_classifications(
timeout: Optional[float] = None,
) -> List[Classification]:
"""Get classifications from image"""
if not self.properties.implements_classification:
if not self.properties.classifications_supported:
return NotImplementedError
image = decode_image(image)
input_tensor = self.preprocessor(image)
Expand All @@ -276,7 +275,7 @@ async def get_classifications_from_camera(
timeout: Optional[float] = None,
) -> List[Classification]:
"""Gets classifications from a camera dependency"""
if not self.properties.implements_classification:
if not self.properties.classifications_supported:
raise NotImplementedError

if camera_name not in (self.camera_name, ""):
Expand All @@ -296,7 +295,7 @@ async def get_detections_from_camera(
self, camera_name: str, *, extra: Mapping[str, Any] = None, timeout: float = None,
) -> List[Detection]:
"""Gets detections from a camera dependency"""
if not self.properties.implements_detection:
if not self.properties.detections_supported:
raise NotImplementedError
if camera_name not in (self.camera_name, ""):
raise ValueError(
Expand All @@ -310,7 +309,7 @@ async def get_detections_from_camera(
input_tensor = self.preprocessor(image)
with torch.no_grad():
prediction: Tensor = self.model(input_tensor)[0]
return self.wrap_detections(prediction)
return self.wrap_detections(prediction, image.shape)

async def get_properties(
self,
Expand Down Expand Up @@ -366,7 +365,7 @@ async def get_image_from_dependency(self, camera_name: str):
im = await self.camera.get_image(mime_type=CameraMimeType.JPEG)
return decode_image(im)

def wrap_detections(self, prediction: dict):
def wrap_detections(self, prediction: dict, image_shape: Sequence[int]):
"""_summary_
converts prediction output tensor from torchvision model
for viam API
Expand All @@ -379,17 +378,7 @@ def wrap_detections(self, prediction: dict):
labels = [self.weights.meta["categories"][i] for i in prediction["labels"]]
scores = prediction["scores"]
boxes = prediction["boxes"].to(torch.int64).tolist()
res = [
Detection(
x_min=x_min,
y_min=y_min,
x_max=x_max,
y_max=y_max,
confidence=score,
class_name=label,
)
for (x_min, y_min, x_max, y_max), score, label in zip(boxes, scores, labels)
]
res = self.make_detections(boxes, scores, labels, image_shape)
res = self.filter_output(res)
return res

Expand All @@ -414,3 +403,32 @@ def wrap_classifications(self, prediction, count):
]
res = self.filter_output(res)
return res

def make_detections(self, boxes, scores, labels, image_shape):
    """Build one ``Detection`` per (box, score, label) triple.

    Args:
        boxes: iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        scores: iterable of confidences, aligned with ``boxes``.
        labels: iterable of class names, aligned with ``boxes``.
        image_shape: indexable shape of the source image. Index 0 divides
            the y coordinates and index 1 divides the x coordinates.

    Returns:
        A list of ``Detection`` objects. The inputs are combined with
        ``zip``, so the result is as long as the shortest of ``boxes``,
        ``scores`` and ``labels``. The ``*_normalized`` fields are populated
        only when both ``image_shape[0]`` and ``image_shape[1]`` are
        non-zero; otherwise they are left unset.
    """
    # Keep the short-circuit: image_shape[1] is not read when
    # image_shape[0] is already zero.
    normalize = not (image_shape[0] == 0 or image_shape[1] == 0)

    detections = []
    for (x_min, y_min, x_max, y_max), score, label in zip(boxes, scores, labels):
        fields = {
            "x_min": x_min,
            "y_min": y_min,
            "x_max": x_max,
            "y_max": y_max,
            "confidence": score,
            "class_name": label,
        }
        if normalize:
            # Dividing by a zero dimension is impossible, which is why the
            # normalized fields are skipped entirely in that case.
            fields["x_min_normalized"] = x_min / image_shape[1]
            fields["y_min_normalized"] = y_min / image_shape[0]
            fields["x_max_normalized"] = x_max / image_shape[1]
            fields["y_max_normalized"] = y_max / image_shape[0]
        detections.append(Detection(**fields))
    return detections
2 changes: 1 addition & 1 deletion src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def decode_image(image: Union[Image.Image, ViamImage, np.ndarray]) -> np.ndarray
im = Image.open(BytesIO(image.data), formats=LIBRARY_SUPPORTED_FORMATS).convert(
"RGB"
) # convert to RGB (a PNG may be opened as RGBA)
return im
return np.array(im)
res = image.convert("RGB")
rgb = np.array(res)
return rgb
53 changes: 41 additions & 12 deletions tests/test_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from PIL import Image

from viam.components.camera import Camera
from viam.media.video import CameraMimeType
from viam.media.utils.pil import pil_to_viam_image
from viam.proto.app.robot import ComponentConfig

Expand All @@ -20,7 +22,8 @@
from google.protobuf.struct_pb2 import Struct

path_to_input_image = "tests/grasshopper.jpg"
input_image = np.array(Image.open(path_to_input_image))
image = Image.open(path_to_input_image)
input_image = np.array(image)
cfg = ServiceConfig(
attributes=dict_to_struct(
{
Expand All @@ -33,9 +36,9 @@
VISION_SERVICE_NAME = "vision1"
DETECTIONS = []
PROPERTIES = Properties(
implements_classification=True,
implements_detection=False,
implements_get_object_pcd=False
classifications_supported=True,
detections_supported=False,
object_point_clouds_supported=False
)

def make_component_config(dictionary: Mapping[str, Any]) -> ComponentConfig:
Expand Down Expand Up @@ -80,15 +83,14 @@ async def test_get_properties(self, get_image_from_dependency, fake_cam):
assert response == PROPERTIES

@patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
@patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock)
def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam):
def test_capture_all_from_camera(self, fake_cam):
camera = TorchVisionService(
name='tvs'
)
camera.camera_name = "fake_cam"
get_image_from_dependency.return_value = input_image

camera.reconfigure(cfg, dependencies={"fake_cam": Mock()})
camera.camera.get_image = AsyncMock(return_value=pil_to_viam_image(image, mime_type=CameraMimeType.JPEG))

# without point clouds = True
result = asyncio.run(camera.capture_all_from_camera(
Expand Down Expand Up @@ -118,16 +120,14 @@ def test_capture_all_from_camera(self, get_image_from_dependency, fake_cam):
# mock_get_detections.assert_called_once_with('test_image', timeout=None)

@patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
@patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock)
def test_default_camera_behavior(self, get_image_from_dependency, fake_cam):
def test_default_camera_behavior(self, fake_cam):
vs = TorchVisionService(
name='tvs'
)
get_image_from_dependency.return_value = input_image

# vs.camera_name = "fake_cam"
vs.reconfigure(cfg, dependencies={"fake_cam": Mock()})

vs.camera.get_image = AsyncMock(return_value=pil_to_viam_image(image, mime_type=CameraMimeType.JPEG))

result = vs.get_classifications_from_camera(
"",
count=1,
Expand Down Expand Up @@ -155,3 +155,32 @@ def test_default_camera_behavior(self, get_image_from_dependency, fake_cam):
)))
assert 'not_cam' in str(excinfo.value)

@patch('viam.components.camera.Camera.get_resource_name', return_value="fake_cam")
@patch.object(TorchVisionService, 'get_image_from_dependency', new_callable=AsyncMock)
def test_get_detections(self, get_image_from_dependency, fake_cam):
    """Detections from the camera expose pixel and normalized coordinates.

    The camera dependency is mocked: ``get_image_from_dependency`` returns
    the module-level ``input_image`` array, and the normalized coordinates
    of the first detection are checked against that array's shape.
    """
    vision = TorchVisionService(
        name='tvs'
    )
    vision.camera_name = "fake_cam"
    get_image_from_dependency.return_value = input_image

    # Named differently from the module-level ``cfg`` so it is not shadowed.
    detection_cfg = ServiceConfig(
        attributes=dict_to_struct(
            {
                "model_name": "fasterrcnn_resnet50_fpn",
                "camera_name": "fake_cam",
            }
        )
    )

    vision.reconfigure(detection_cfg, dependencies={"fake_cam": Mock()})
    response = asyncio.run(vision.get_detections_from_camera("fake_cam"))
    assert response is not None
    # Fail with a readable assertion instead of an IndexError when the
    # model returns no detections.
    assert len(response) > 0, "expected at least one detection"

    first = response[0]
    assert first.x_min is not None
    assert first.y_min is not None
    assert first.x_max is not None
    assert first.y_max is not None
    # x coordinates are divided by shape[1], y coordinates by shape[0].
    assert first.x_min_normalized == first.x_min / input_image.shape[1]
    assert first.y_min_normalized == first.y_min / input_image.shape[0]
    assert first.x_max_normalized == first.x_max / input_image.shape[1]
    assert first.y_max_normalized == first.y_max / input_image.shape[0]