Skip to content

Commit 5cc6004

Browse files
authored
Bodypix update (#243)
1 parent 87efc93 commit 5cc6004

File tree

17 files changed

+152
-147
lines changed

17 files changed

+152
-147
lines changed

common/navigation/lasr_person_following/src/lasr_person_following/person_following.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
if not self._transcribe_speech_client_available:
125125
rospy.logwarn("Transcribe speech client not available")
126126

127-
self._detect_wave = rospy.ServiceProxy("/detect_wave", DetectWave)
127+
self._detect_wave = rospy.ServiceProxy("/bodypix/detect_wave", DetectWave)
128128
if not self._detect_wave.wait_for_service(rospy.Duration.from_sec(10.0)):
129129
rospy.logwarn("Detect wave service not available")
130130

common/vision/lasr_vision_bodypix/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,7 @@ include_directories(
160160
## Mark executable scripts (Python etc.) for installation
161161
## in contrast to setup.py, you can choose the destination
162162
catkin_install_python(PROGRAMS
163-
nodes/mask_service.py
164-
nodes/keypoint_service.py
163+
nodes/bodypix_services.py
165164
examples/mask_relay.py
166165
examples/keypoint_relay.py
167166
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
<launch>
2+
<description>Start BodyPix services</description>
3+
<usage doc="BodyPix service"></usage>
4+
<usage doc="Preload models and enable debug topic">debug:=true preload:=['resnet50', 'mobilenet50']</usage>
5+
6+
<arg name="preload" default="['resnet50']" doc="Array of models to preload when starting the service" />
7+
8+
9+
<node name="bodypix_services" pkg="lasr_vision_bodypix" type="bodypix_services.py" output="screen">
10+
<param name="preload" type="yaml" value="$(arg preload)" />
11+
</node>
12+
13+
</launch>

common/vision/lasr_vision_bodypix/launch/camera_keypoint.launch

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
<arg name="model" default="resnet50" doc="Model to use for the demo" />
88

99
<!-- BodyPix service -->
10-
<include file="$(find lasr_vision_bodypix)/launch/keypoint_service.launch">
11-
<arg name="debug" value="true" />
10+
<include file="$(find lasr_vision_bodypix)/launch/bodypix.launch">
1211
<arg name="preload" value="['$(arg model)']" />
1312
</include>
1413

common/vision/lasr_vision_bodypix/launch/camera_mask.launch

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
<arg name="model" default="resnet50" doc="Model to use for the demo" />
88

99
<!-- BodyPix service -->
10-
<include file="$(find lasr_vision_bodypix)/launch/mask_service.launch">
11-
<arg name="debug" value="true" />
10+
<include file="$(find lasr_vision_bodypix)/launch/bodypix.launch">
1211
<arg name="preload" value="['$(arg model)']" />
1312
</include>
14-
13+
1514
<!-- show debug topic -->
1615
<node name="image_view" pkg="rqt_image_view" type="rqt_image_view" respawn="false" output="screen" args="/bodypix/debug/$(arg model)" />
1716

@@ -22,4 +21,5 @@
2221
<include file="$(find video_stream_opencv)/launch/camera.launch">
2322
<arg name="visualize" value="true" />
2423
</include>
24+
2525
</launch>

common/vision/lasr_vision_bodypix/launch/gesture_service.launch

Lines changed: 0 additions & 13 deletions
This file was deleted.

common/vision/lasr_vision_bodypix/launch/keypoint_service.launch

Lines changed: 0 additions & 13 deletions
This file was deleted.

common/vision/lasr_vision_bodypix/launch/mask_service.launch

Lines changed: 0 additions & 13 deletions
This file was deleted.

common/vision/lasr_vision_bodypix/nodes/gesture_service.py renamed to common/vision/lasr_vision_bodypix/nodes/bodypix_services.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
1-
#!/usr/bin/env python3.9
2-
1+
#!/usr/bin/env python3
32
import rospy
4-
from typing import List, Union
5-
from sensor_msgs.msg import Image
63
import lasr_vision_bodypix as bodypix
7-
import cv2
8-
import cv2_img
9-
import ros_numpy as rnp
10-
from geometry_msgs.msg import PointStamped, Point
11-
from visualization_msgs.msg import Marker
12-
from markers import create_and_publish_marker
13-
from cv2_pcl import pcl_to_img_msg
14-
154
from lasr_vision_msgs.srv import (
5+
BodyPixMaskDetection,
6+
BodyPixMaskDetectionRequest,
7+
BodyPixMaskDetectionResponse,
168
BodyPixKeypointDetection,
179
BodyPixKeypointDetectionRequest,
1810
BodyPixKeypointDetectionResponse,
@@ -21,13 +13,38 @@
2113
DetectWaveResponse,
2214
)
2315

16+
from typing import Union
17+
from sensor_msgs.msg import Image
18+
import ros_numpy as rnp
19+
from geometry_msgs.msg import PointStamped, Point
2420
from std_msgs.msg import Header
2521
import numpy as np
22+
from cv2_pcl import pcl_to_img_msg
2623

27-
rospy.init_node("detect_wave_service")
24+
# Initialise rospy
25+
rospy.init_node("bodypix_mask_service")
2826

29-
DEBUG = rospy.get_param("~debug", True)
30-
marker_pub = rospy.Publisher("waving_person", Marker, queue_size=1)
27+
# Determine variables
28+
PRELOAD = rospy.get_param("~preload", []) # List of models to preload
29+
30+
for model in PRELOAD:
31+
bodypix.load_model_cached(model)
32+
33+
34+
def detect_masks(request: BodyPixMaskDetectionRequest) -> BodyPixMaskDetectionResponse:
35+
"""
36+
Hand off detection request to bodypix library
37+
"""
38+
return bodypix.detect_masks(request)
39+
40+
41+
def detect_keypoints(
42+
request: BodyPixKeypointDetectionRequest,
43+
) -> BodyPixKeypointDetectionResponse:
44+
"""
45+
Hand off detection request to bodypix library
46+
"""
47+
return bodypix.detect_keypoints(request)
3148

3249

3350
def detect_wave(
@@ -126,7 +143,8 @@ def detect_wave(
126143
)
127144

128145

129-
# rospy.Service("/detect_wave", DetectWave, lambda req: detect_wave(req, rospy.Publisher("debug_waving", Image, queue_size=1)))
130-
rospy.Service("/detect_wave", DetectWave, detect_wave)
131-
rospy.loginfo("Detect wave service started")
146+
rospy.Service("/bodypix/mask_detection", BodyPixMaskDetection, detect_masks)
147+
rospy.Service("/bodypix/keypoint_detection", BodyPixKeypointDetection, detect_keypoints)
148+
rospy.Service("/bodypix/detect_wave", DetectWave, detect_wave)
149+
rospy.loginfo("BodyPix service started")
132150
rospy.spin()

common/vision/lasr_vision_bodypix/nodes/keypoint_service.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

common/vision/lasr_vision_bodypix/nodes/mask_service.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
tf-bodypix==0.4.2
2-
tensorflow==2.14.0
32
opencv-python==4.8.1.78
43
Pillow==10.1.0
54
matplotlib==3.8.1
65

76
# The following was manually added and freezed into requirements.txt:
8-
# tfjs-graph-converter==1.6.3
7+
tfjs-graph-converter==1.6.3
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
absl-py==2.1.0 # via chex, keras, optax, orbax-checkpoint, tensorboard, tensorflow, tensorflow-decision-forests, ydf
2+
astunparse==1.6.3 # via tensorflow
3+
certifi==2024.7.4 # via requests
4+
charset-normalizer==3.3.2 # via requests
5+
chex==0.1.86 # via optax
6+
contourpy==1.2.1 # via matplotlib
7+
cycler==0.12.1 # via matplotlib
8+
etils[epath,epy]==1.5.2 # via orbax-checkpoint
9+
flatbuffers==24.3.25 # via tensorflow
10+
flax==0.8.5 # via tensorflowjs
11+
fonttools==4.53.1 # via matplotlib
12+
fsspec==2024.6.1 # via etils
13+
gast==0.6.0 # via tensorflow
14+
google-pasta==0.2.0 # via tensorflow
15+
grpcio==1.64.1 # via tensorboard, tensorflow
16+
h5py==3.11.0 # via keras, tensorflow
17+
idna==3.7 # via requests
18+
importlib-metadata==8.0.0 # via jax, markdown
19+
importlib-resources==6.4.0 # via etils, matplotlib, tensorflowjs
20+
jax==0.4.30 # via chex, flax, optax, orbax-checkpoint, tensorflowjs
21+
jaxlib==0.4.30 # via chex, jax, optax, orbax-checkpoint, tensorflowjs
22+
keras==3.4.1 # via tensorflow
23+
kiwisolver==1.4.5 # via matplotlib
24+
libclang==18.1.1 # via tensorflow
25+
markdown==3.6 # via tensorboard
26+
markdown-it-py==3.0.0 # via rich
27+
markupsafe==2.1.5 # via werkzeug
28+
matplotlib==3.8.1 # via -r requirements.in
29+
mdurl==0.1.2 # via markdown-it-py
30+
ml-dtypes==0.3.2 # via jax, jaxlib, keras, tensorflow, tensorstore
31+
msgpack==1.0.8 # via flax, orbax-checkpoint
32+
namex==0.0.8 # via keras
33+
nest-asyncio==1.6.0 # via orbax-checkpoint
34+
numpy==1.26.4 # via chex, contourpy, flax, h5py, jax, jaxlib, keras, matplotlib, ml-dtypes, opencv-python, opt-einsum, optax, orbax-checkpoint, pandas, scipy, tensorboard, tensorflow, tensorflow-decision-forests, tensorflow-hub, tensorstore, ydf
35+
opencv-python==4.8.1.78 # via -r requirements.in
36+
opt-einsum==3.3.0 # via jax, tensorflow
37+
optax==0.2.2 # via flax
38+
optree==0.12.1 # via keras
39+
orbax-checkpoint==0.5.20 # via flax
40+
packaging==23.2 # via keras, matplotlib, tensorflow, tensorflowjs
41+
pandas==2.2.2 # via tensorflow-decision-forests
42+
pillow==10.1.0 # via -r requirements.in, matplotlib
43+
protobuf==4.25.3 # via orbax-checkpoint, tensorboard, tensorflow, tensorflow-hub, ydf
44+
pygments==2.18.0 # via rich
45+
pyparsing==3.1.2 # via matplotlib
46+
python-dateutil==2.9.0.post0 # via matplotlib, pandas
47+
pytz==2024.1 # via pandas
48+
pyyaml==6.0.1 # via flax, orbax-checkpoint
49+
requests==2.32.3 # via tensorflow, tf-bodypix
50+
rich==13.7.1 # via flax, keras
51+
scipy==1.13.1 # via jax, jaxlib
52+
six==1.16.0 # via astunparse, google-pasta, python-dateutil, tensorboard, tensorflow, tensorflow-decision-forests, tensorflowjs
53+
tensorboard==2.16.2 # via tensorflow
54+
tensorboard-data-server==0.7.2 # via tensorboard
55+
tensorflow==2.16.2 # via tensorflow-decision-forests, tensorflowjs, tf-keras
56+
tensorflow-decision-forests==1.9.1 # via tensorflowjs
57+
tensorflow-hub==0.16.1 # via tensorflowjs
58+
tensorflow-io-gcs-filesystem==0.37.1 # via tensorflow
59+
tensorflowjs==4.20.0 # via tfjs-graph-converter
60+
tensorstore==0.1.63 # via flax, orbax-checkpoint
61+
termcolor==2.4.0 # via tensorflow
62+
tf-bodypix==0.4.2 # via -r requirements.in
63+
tf-keras==2.16.0 # via tensorflow-decision-forests, tensorflow-hub, tensorflowjs
64+
tfjs-graph-converter==1.6.3 # via -r requirements.in
65+
toolz==0.12.1 # via chex
66+
typing-extensions==4.12.2 # via chex, etils, flax, optree, orbax-checkpoint, tensorflow
67+
tzdata==2024.1 # via pandas
68+
urllib3==2.2.2 # via requests
69+
werkzeug==3.0.3 # via tensorboard
70+
wheel==0.43.0 # via astunparse, tensorflow-decision-forests
71+
wrapt==1.16.0 # via tensorflow
72+
wurlitzer==3.1.1 # via tensorflow-decision-forests
73+
ydf==0.5.0 # via tensorflow-decision-forests
74+
zipp==3.19.2 # via etils, importlib-metadata, importlib-resources
75+
76+
# The following packages are considered to be unsafe in a requirements file:
77+
# setuptools

common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@
2424
import rospkg
2525

2626
# model cache
27-
# preload resnet 50 model so that it won't waste the time
28-
# doing that in the middle of the task.
29-
loaded_models = {
30-
"resnet50": load_model(download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16))
31-
}
27+
loaded_models = {}
3228
r = rospkg.RosPack()
3329

3430

@@ -47,20 +43,27 @@ def load_model_cached(dataset: str):
4743
"""
4844
model = None
4945
if dataset in loaded_models:
46+
rospy.loginfo(f"Using cached {dataset} model")
5047
model = loaded_models[dataset]
5148
else:
5249
if dataset == "resnet50":
50+
rospy.loginfo("Downloading resnet50 model")
5351
name = download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16)
52+
rospy.loginfo("Loading resnet50 model")
5453
model = load_model(name)
5554
elif dataset == "mobilenet50":
55+
rospy.loginfo("Downloading mobilenet50 model")
5656
name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_8)
57+
rospy.loginfo("Loading mobilenet50 model")
5758
model = load_model(name)
5859
elif dataset == "mobilenet100":
60+
rospy.loginfo("Downloading mobilenet100 model")
5961
name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_100_STRIDE_8)
62+
rospy.loginfo("Loading mobilenet100 model")
6063
model = load_model(name)
6164
else:
6265
model = load_model(dataset)
63-
rospy.loginfo(f"Loaded {dataset} model")
66+
rospy.loginfo(f"Loaded {dataset} model into cache")
6467
loaded_models[dataset] = model
6568
return model
6669

skills/launch/unit_test_describe_people.launch

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,9 @@
44
<arg name="preload" value="['yolov8n-seg.pt']" />
55
</include>
66

7-
<node name="bodypix_keypoint_service" pkg="lasr_vision_bodypix" type="keypoint_service.py" output="screen">
8-
<param name="debug" type="bool" value="true" />
7+
<include file="$(find lasr_vision_bodypix)/launch/bodypix_services.launch">
98
<param name="preload" type="yaml" value='resnet50' />
10-
</node>
11-
12-
<node name="bodypix_mask_service" pkg="lasr_vision_bodypix" type="mask_service.py" output="screen">
13-
<param name="debug" type="bool" value="true" />
14-
<param name="preload" type="yaml" value='resnet50' />
15-
</node>
9+
</include>
1610

1711
<node pkg="lasr_vision_feature_extraction" type="service" name="torch_service" output="screen"/>
1812

0 commit comments

Comments
 (0)