From 7eed269449aa04844489e0e53cc97b5369667d7d Mon Sep 17 00:00:00 2001 From: Matt Barker <105945282+m-barker@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:21:55 +0100 Subject: [PATCH] Add required keypoint detection (#226) Co-authored-by: Jared Swift --- .../vision/lasr_vision_bodypix/CMakeLists.txt | 6 +- .../examples/keypoint_relay.py | 70 +++++++++ .../examples/{relay => mask_relay.py} | 30 ++-- .../launch/camera_keypoint.launch | 25 +++ .../{camera.launch => camera_mask.launch} | 4 +- .../launch/keypoint_service.launch | 13 ++ .../{service.launch => mask_service.launch} | 2 +- .../nodes/keypoint_service.py | 31 ++++ .../lasr_vision_bodypix/nodes/mask_service.py | 29 ++++ .../vision/lasr_vision_bodypix/nodes/service | 46 ------ .../src/lasr_vision_bodypix/__init__.py | 2 +- .../src/lasr_vision_bodypix/bodypix.py | 147 +++++++++++++----- common/vision/lasr_vision_msgs/CMakeLists.txt | 5 +- .../lasr_vision_msgs/msg/BodyPixKeypoint.msg | 12 +- .../lasr_vision_msgs/msg/BodyPixMask.msg | 2 + .../msg/BodyPixMaskRequest.msg | 5 - .../lasr_vision_msgs/msg/BodyPixPose.msg | 1 - .../srv/BodyPixKeypointDetection.srv | 12 ++ ...Detection.srv => BodyPixMaskDetection.srv} | 9 +- skills/src/lasr_skills/describe_people.py | 43 ++--- skills/src/lasr_skills/detect_gesture.py | 142 +++++++++-------- skills/src/lasr_skills/validate_keypoints.py | 98 ++++++++++++ 22 files changed, 514 insertions(+), 220 deletions(-) create mode 100644 common/vision/lasr_vision_bodypix/examples/keypoint_relay.py rename common/vision/lasr_vision_bodypix/examples/{relay => mask_relay.py} (65%) create mode 100644 common/vision/lasr_vision_bodypix/launch/camera_keypoint.launch rename common/vision/lasr_vision_bodypix/launch/{camera.launch => camera_mask.launch} (78%) create mode 100644 common/vision/lasr_vision_bodypix/launch/keypoint_service.launch rename common/vision/lasr_vision_bodypix/launch/{service.launch => mask_service.launch} (84%) create mode 100644 common/vision/lasr_vision_bodypix/nodes/keypoint_service.py create mode 100644 common/vision/lasr_vision_bodypix/nodes/mask_service.py delete mode 100644 common/vision/lasr_vision_bodypix/nodes/service delete mode 100644 common/vision/lasr_vision_msgs/msg/BodyPixMaskRequest.msg delete mode 100644 common/vision/lasr_vision_msgs/msg/BodyPixPose.msg create mode 100644 common/vision/lasr_vision_msgs/srv/BodyPixKeypointDetection.srv rename common/vision/lasr_vision_msgs/srv/{BodyPixDetection.srv => BodyPixMaskDetection.srv} (53%) create mode 100755 skills/src/lasr_skills/validate_keypoints.py diff --git a/common/vision/lasr_vision_bodypix/CMakeLists.txt b/common/vision/lasr_vision_bodypix/CMakeLists.txt index eacbc75c1..93249ed30 100644 --- a/common/vision/lasr_vision_bodypix/CMakeLists.txt +++ b/common/vision/lasr_vision_bodypix/CMakeLists.txt @@ -160,8 +160,10 @@ include_directories( ## Mark executable scripts (Python etc.) for installation ## in contrast to setup.py, you can choose the destination catkin_install_python(PROGRAMS - nodes/service - examples/relay + nodes/mask_service.py + nodes/keypoint_service.py + examples/mask_relay.py + examples/keypoint_relay.py DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} ) diff --git a/common/vision/lasr_vision_bodypix/examples/keypoint_relay.py b/common/vision/lasr_vision_bodypix/examples/keypoint_relay.py new file mode 100644 index 000000000..20994f3d4 --- /dev/null +++ b/common/vision/lasr_vision_bodypix/examples/keypoint_relay.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import sys +import rospy +import threading + +from sensor_msgs.msg import Image +from lasr_vision_msgs.srv import ( + BodyPixKeypointDetection, + BodyPixKeypointDetectionRequest, +) + +if len(sys.argv) < 2: + print( + "Usage: rosrun lasr_vision_bodypix keypoint_relay.py [resnet50|mobilenet50|...]" + ) + exit() + +# figure out what we are listening to +listen_topic = sys.argv[1] + +# figure out what model we are using +if len(sys.argv) >= 3: + model = sys.argv[2] +else: + model = "resnet50" + +processing = False + + +def detect(image): + global processing + processing = True + rospy.loginfo("Received image message") + + try: + detect_service = rospy.ServiceProxy( + "/bodypix/keypoint_detection", BodyPixKeypointDetection + ) + req = BodyPixKeypointDetectionRequest() + req.image_raw = image + req.dataset = model + req.confidence = 0.7 + + resp = detect_service(req) + print(resp) + except rospy.ServiceException as e: + rospy.logerr("Service call failed: %s" % e) + finally: + processing = False + + +def image_callback(image): + global processing + if processing: + return + + t = threading.Thread(target=detect, args=(image,)) + t.start() + + +def listener(): + rospy.init_node("image_listener", anonymous=True) + rospy.wait_for_service("/bodypix/keypoint_detection") + rospy.Subscriber(listen_topic, Image, image_callback) + rospy.spin() + + +if __name__ == "__main__": + listener() diff --git a/common/vision/lasr_vision_bodypix/examples/relay b/common/vision/lasr_vision_bodypix/examples/mask_relay.py similarity index 65% rename from common/vision/lasr_vision_bodypix/examples/relay rename to common/vision/lasr_vision_bodypix/examples/mask_relay.py index 0937308ae..6513beba3 100644 --- a/common/vision/lasr_vision_bodypix/examples/relay +++ b/common/vision/lasr_vision_bodypix/examples/mask_relay.py @@ -5,11 +5,12 @@ import threading from sensor_msgs.msg import Image -from lasr_vision_msgs.msg import BodyPixMaskRequest -from lasr_vision_msgs.srv import BodyPixDetection, BodyPixDetectionRequest +from lasr_vision_msgs.srv import BodyPixMaskDetection, BodyPixMaskDetectionRequest if len(sys.argv) < 2: - print('Usage: rosrun lasr_vision_bodypix relay [resnet50|mobilenet50|...]') + print( + "Usage: rosrun lasr_vision_bodypix mask_relay.py [resnet50|mobilenet50|...]" + ) exit() # figure out what we are listening to @@ -23,34 +24,35 @@ processing = False + def detect(image): global processing processing = True rospy.loginfo("Received image message") try: - detect_service = rospy.ServiceProxy('/bodypix/detect', BodyPixDetection) - req = BodyPixDetectionRequest() + detect_service = rospy.ServiceProxy( + "/bodypix/mask_detection", BodyPixMaskDetection + ) + req = BodyPixMaskDetectionRequest() req.image_raw = image req.dataset = model req.confidence = 0.7 - - mask = BodyPixMaskRequest() - mask.parts = ['left_face', 'right_face'] - req.masks = [mask] + req.parts = ["left_face", "right_face"] resp = detect_service(req) # don't print the whole mask for mask in resp.masks: mask.mask = [True, False, True, False] - + print(resp) except rospy.ServiceException as e: rospy.logerr("Service call failed: %s" % e) finally: processing = False + def image_callback(image): global processing if processing: @@ -59,11 +61,13 @@ def image_callback(image): t = threading.Thread(target=detect, args=(image,)) t.start() + def listener(): - rospy.init_node('image_listener', anonymous=True) - rospy.wait_for_service('/bodypix/detect') + rospy.init_node("image_listener", anonymous=True) + rospy.wait_for_service("/bodypix/mask_detection") rospy.Subscriber(listen_topic, Image, image_callback) rospy.spin() -if __name__ == '__main__': + +if __name__ == "__main__": listener() diff --git a/common/vision/lasr_vision_bodypix/launch/camera_keypoint.launch b/common/vision/lasr_vision_bodypix/launch/camera_keypoint.launch new file mode 100644 index 000000000..f8f86d46b --- /dev/null +++ b/common/vision/lasr_vision_bodypix/launch/camera_keypoint.launch @@ -0,0 +1,25 @@ + + Run a BodyPix model using the camera + + model:=mobilenet50 + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/common/vision/lasr_vision_bodypix/launch/camera.launch b/common/vision/lasr_vision_bodypix/launch/camera_mask.launch similarity index 78% rename from common/vision/lasr_vision_bodypix/launch/camera.launch rename to common/vision/lasr_vision_bodypix/launch/camera_mask.launch index 1c793c235..7884781b4 100644 --- a/common/vision/lasr_vision_bodypix/launch/camera.launch +++ b/common/vision/lasr_vision_bodypix/launch/camera_mask.launch @@ -7,7 +7,7 @@ - + @@ -16,7 +16,7 @@ - + diff --git a/common/vision/lasr_vision_bodypix/launch/keypoint_service.launch b/common/vision/lasr_vision_bodypix/launch/keypoint_service.launch new file mode 100644 index 000000000..dfc5e1eab --- /dev/null +++ b/common/vision/lasr_vision_bodypix/launch/keypoint_service.launch @@ -0,0 +1,13 @@ + + Start the BodyPix service + + debug:=true preload:=['resnet50', 'mobilenet50'] + + + + + + + + + \ No newline at end of file diff --git a/common/vision/lasr_vision_bodypix/launch/service.launch b/common/vision/lasr_vision_bodypix/launch/mask_service.launch similarity index 84% rename from common/vision/lasr_vision_bodypix/launch/service.launch rename to common/vision/lasr_vision_bodypix/launch/mask_service.launch index 009c457cd..a0a03e8fc 100644 --- a/common/vision/lasr_vision_bodypix/launch/service.launch +++ b/common/vision/lasr_vision_bodypix/launch/mask_service.launch @@ -6,7 +6,7 @@ - + diff --git a/common/vision/lasr_vision_bodypix/nodes/keypoint_service.py b/common/vision/lasr_vision_bodypix/nodes/keypoint_service.py new file mode 100644 index 000000000..7602032bf --- /dev/null +++ b/common/vision/lasr_vision_bodypix/nodes/keypoint_service.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +import rospy +import lasr_vision_bodypix as bodypix +from lasr_vision_msgs.srv import ( + BodyPixKeypointDetection, + BodyPixKeypointDetectionRequest, + BodyPixKeypointDetectionResponse, +) + +# Initialise rospy +rospy.init_node("bodypix_keypoint_service") + +# Determine variables +PRELOAD = rospy.get_param("~preload", []) # resnet50 or mobilenet50 + +for model in PRELOAD: + pass + + +def detect( + request: BodyPixKeypointDetectionRequest, +) -> BodyPixKeypointDetectionResponse: + """ + Hand off detection request to bodypix library + """ + return bodypix.detect_keypoints(request) + + +rospy.Service("/bodypix/keypoint_detection", BodyPixKeypointDetection, detect) +rospy.loginfo("BodyPix keypoint service started") +rospy.spin() diff --git a/common/vision/lasr_vision_bodypix/nodes/mask_service.py b/common/vision/lasr_vision_bodypix/nodes/mask_service.py new file mode 100644 index 000000000..1ed72d163 --- /dev/null +++ b/common/vision/lasr_vision_bodypix/nodes/mask_service.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +import rospy +import lasr_vision_bodypix as bodypix +from lasr_vision_msgs.srv import ( + BodyPixMaskDetection, + BodyPixMaskDetectionRequest, + BodyPixMaskDetectionResponse, +) + +# Initialise rospy +rospy.init_node("bodypix_mask_service") + +# Determine variables +PRELOAD = rospy.get_param("~preload", []) # resnet50 or mobilenet50 + +for model in PRELOAD: + pass + + +def detect(request: BodyPixMaskDetectionRequest) -> BodyPixMaskDetectionResponse: + """ + Hand off detection request to bodypix library + """ + return bodypix.detect_masks(request) + + +rospy.Service("/bodypix/mask_detection", BodyPixMaskDetection, detect) +rospy.loginfo("BodyPix service started") +rospy.spin() diff --git a/common/vision/lasr_vision_bodypix/nodes/service b/common/vision/lasr_vision_bodypix/nodes/service deleted file mode 100644 index 025c813a9..000000000 --- a/common/vision/lasr_vision_bodypix/nodes/service +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 - -import re -import rospy -import lasr_vision_bodypix as bodypix - -from sensor_msgs.msg import Image -from lasr_vision_msgs.srv import ( - BodyPixDetection, - BodyPixDetectionRequest, - BodyPixDetectionResponse, -) - -# Initialise rospy -rospy.init_node("bodypix_service") - -# Determine variables -DEBUG = rospy.get_param("~debug", False) -PRELOAD = rospy.get_param("~preload", []) # resnet50 or mobilenet50 - -for model in PRELOAD: - pass - -# Keep track of publishers -debug_publishers = {} - - -def detect(request: BodyPixDetectionRequest) -> BodyPixDetectionResponse: - """ - Hand off detection request to bodypix library - """ - debug_publisher = None - if DEBUG: - if request.dataset in debug_publishers: - debug_publisher = debug_publishers[request.dataset] - else: - topic_name = re.sub(r"[\W_]+", "", request.dataset) - debug_publisher = rospy.Publisher( - f"/bodypix/debug/{topic_name}", Image, queue_size=1 - ) - return bodypix.detect(request, debug_publisher) - - -rospy.Service("/bodypix/detect", BodyPixDetection, detect) -rospy.loginfo("BodyPix service started") -rospy.spin() diff --git a/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/__init__.py b/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/__init__.py index f9ca4ac2c..6a207a9ac 100644 --- a/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/__init__.py +++ b/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/__init__.py @@ -1 +1 @@ -from .bodypix import detect, load_model_cached +from .bodypix import detect_masks, detect_keypoints, load_model_cached diff --git a/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py b/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py index b6afa0f6a..db4c1cd60 100644 --- a/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py +++ b/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py @@ -1,9 +1,10 @@ from __future__ import annotations +from typing import List import rospy +import cv2 import cv2_img import numpy as np - from PIL import Image import re import tensorflow as tf @@ -11,8 +12,13 @@ from sensor_msgs.msg import Image as SensorImage -from lasr_vision_msgs.msg import BodyPixMask, BodyPixPose, BodyPixKeypoint -from lasr_vision_msgs.srv import BodyPixDetectionRequest, BodyPixDetectionResponse +from lasr_vision_msgs.msg import BodyPixMask, BodyPixKeypoint +from lasr_vision_msgs.srv import ( + BodyPixMaskDetectionRequest, + BodyPixMaskDetectionResponse, + BodyPixKeypointDetectionRequest, + BodyPixKeypointDetectionResponse, +) import rospkg @@ -30,7 +36,7 @@ def camel_to_snake(name): return re.sub(r"(? None: +def load_model_cached(dataset: str): """ Load a model into cache """ @@ -42,7 +48,10 @@ def load_model_cached(dataset: str) -> None: name = download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16) model = load_model(name) elif dataset == "mobilenet50": - name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16) + name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_8) + model = load_model(name) + elif dataset == "mobilenet100": + name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_100_STRIDE_8) model = load_model(name) else: model = load_model(dataset) @@ -51,47 +60,57 @@ def load_model_cached(dataset: str) -> None: return model -def detect( - request: BodyPixDetectionRequest, - debug_publisher: rospy.Publisher = rospy.Publisher( - "/bodypix/debug", SensorImage, queue_size=1 - ), -) -> BodyPixDetectionResponse: - """ - Run BodyPix inference on given detection request - """ - +def run_inference(dataset: str, confidence: float, img: Image): # decode the image rospy.loginfo("Decoding") - img = cv2_img.msg_to_cv2_img(request.image_raw) + img = cv2_img.msg_to_cv2_img(img) # load model rospy.loginfo("Loading model") - model = load_model_cached(request.dataset) + model = load_model_cached(dataset) # run inference rospy.loginfo("Running inference") result = model.predict_single(img) - mask = result.get_mask(threshold=request.confidence) + mask = result.get_mask(threshold=confidence) rospy.loginfo("Inference complete") + return result, mask + + +def detect_masks( + request: BodyPixMaskDetectionRequest, + debug_publisher: rospy.Publisher = rospy.Publisher( + "/bodypix/mask_debug", SensorImage, queue_size=1 + ), +) -> BodyPixMaskDetectionResponse: + """ + Run BodyPix inference on given mask detection request + """ + + result, mask = run_inference(request.dataset, request.confidence, request.image_raw) # construct masks response masks = [] - for mask_request in request.masks: + + # This uses this list of parts: + # https://github.com/de-code/python-tf-bodypix/blob/develop/tf_bodypix/bodypix_js_utils/part_channels.py#L5 + + for part_name in request.parts: part_mask = result.get_part_mask( - mask=tf.identity(mask), part_names=mask_request.parts + mask=tf.identity(mask), part_names=[part_name] ).squeeze() - bodypix_mask = BodyPixMask() + if np.max(part_mask) == 0: + rospy.logwarn(f"No masks found for part {part_name}") + continue + bodypix_mask = BodyPixMask() bodypix_mask.mask = part_mask.flatten().astype(bool).tolist() bodypix_mask.shape = list(part_mask.shape) + bodypix_mask.part_name = part_name masks.append(bodypix_mask) - # construct poses response and neck coordinates - poses = result.get_poses() - # publish to debug topic if debug_publisher is not None: # create coloured mask with poses @@ -107,24 +126,66 @@ def detect( ) debug_publisher.publish(cv2_img.cv2_img_to_msg(coloured_mask)) - output_poses = [] - - for pose in poses: - pose_msg = BodyPixPose() - keypoints_msg = [] - - for i, keypoint in pose.keypoints.items(): - if camel_to_snake(keypoint.part) in request.masks[0].parts: - keypoint_msg = BodyPixKeypoint() - keypoint_msg.xy = [int(keypoint.position.x), int(keypoint.position.y)] - keypoint_msg.score = keypoint.score - keypoint_msg.part = keypoint.part - keypoints_msg.append(keypoint_msg) - - pose_msg.keypoints = keypoints_msg - output_poses.append(pose_msg) - - response = BodyPixDetectionResponse() - response.poses = output_poses + response = BodyPixMaskDetectionResponse() response.masks = masks return response + + +def detect_keypoints( + request: BodyPixKeypointDetectionRequest, + debug_publisher: rospy.Publisher = rospy.Publisher( + "/bodypix/keypoint_debug", SensorImage, queue_size=1 + ), +) -> BodyPixKeypointDetectionResponse: + + result, mask = run_inference(request.dataset, request.confidence, request.image_raw) + + poses = result.get_poses() + + detected_keypoints: List[BodyPixKeypoint] = [] + + for pose in poses: + for keypoint in pose.keypoints.values(): + # Check if keypoint is in the mask + x = int(keypoint.position.x) + y = int(keypoint.position.y) + try: + if mask[y, x] == 0: + continue + # Throws an error if the keypoint is out of bounds + # but not clear what type (some TF stuff) + except: + continue + rospy.loginfo(f"Keypoint {keypoint.part} at {x}, {y} is in mask") + detected_keypoints.append( + BodyPixKeypoint(keypoint_name=keypoint.part, x=x, y=y) + ) + + # publish to debug topic + if debug_publisher is not None: + # create coloured mask with poses + from tf_bodypix.draw import draw_poses + + coloured_mask = result.get_colored_part_mask(mask).astype(np.uint8) + coloured_mask = draw_poses( + coloured_mask.copy(), + poses, + keypoints_color=(255, 100, 100), + skeleton_color=(100, 100, 255), + ) + + # Add text of keypoints to image + for keypoint in detected_keypoints: + cv2.putText( + coloured_mask, + f"{keypoint.keypoint_name}", + (keypoint.x, keypoint.y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 2, + cv2.LINE_AA, + ) + debug_publisher.publish(cv2_img.cv2_img_to_msg(coloured_mask)) + + return BodyPixKeypointDetectionResponse(keypoints=detected_keypoints) diff --git a/common/vision/lasr_vision_msgs/CMakeLists.txt b/common/vision/lasr_vision_msgs/CMakeLists.txt index 3cde22083..1ef5676be 100644 --- a/common/vision/lasr_vision_msgs/CMakeLists.txt +++ b/common/vision/lasr_vision_msgs/CMakeLists.txt @@ -47,10 +47,8 @@ add_message_files( FILES Detection.msg Detection3D.msg - BodyPixPose.msg BodyPixKeypoint.msg BodyPixMask.msg - BodyPixMaskRequest.msg ) ## Generate services in the 'srv' folder @@ -58,7 +56,8 @@ add_service_files( FILES YoloDetection.srv YoloDetection3D.srv - BodyPixDetection.srv + BodyPixMaskDetection.srv + BodyPixKeypointDetection.srv Recognise.srv LearnFace.srv Vqa.srv diff --git a/common/vision/lasr_vision_msgs/msg/BodyPixKeypoint.msg b/common/vision/lasr_vision_msgs/msg/BodyPixKeypoint.msg index e875679af..b07bfc873 100644 --- a/common/vision/lasr_vision_msgs/msg/BodyPixKeypoint.msg +++ b/common/vision/lasr_vision_msgs/msg/BodyPixKeypoint.msg @@ -1,12 +1,8 @@ # Keypoint.msg -# int number of the parts following -# https://github.com/de-code/python-tf-bodypix/blob/develop/tf_bodypix/bodypix_js_utils/part_channels.py#L5 - -string part - -# the score of the body part -float64 score +# name of the keypoint +string keypoint_name # the x and y coordinates of the body part -int32[] xy +int32 x +int32 y diff --git a/common/vision/lasr_vision_msgs/msg/BodyPixMask.msg b/common/vision/lasr_vision_msgs/msg/BodyPixMask.msg index ed24a68c5..aec9b6128 100644 --- a/common/vision/lasr_vision_msgs/msg/BodyPixMask.msg +++ b/common/vision/lasr_vision_msgs/msg/BodyPixMask.msg @@ -5,3 +5,5 @@ bool[] mask # # Use in mask.reshape(...shape) to get back 2D array of mask uint32[] shape + +string part_name diff --git a/common/vision/lasr_vision_msgs/msg/BodyPixMaskRequest.msg b/common/vision/lasr_vision_msgs/msg/BodyPixMaskRequest.msg deleted file mode 100644 index 9ea1acf78..000000000 --- a/common/vision/lasr_vision_msgs/msg/BodyPixMaskRequest.msg +++ /dev/null @@ -1,5 +0,0 @@ -# List of parts -# -# A full list is available here: -# https://github.com/de-code/python-tf-bodypix/blob/develop/tf_bodypix/bodypix_js_utils/part_channels.py#L5 -string[] parts diff --git a/common/vision/lasr_vision_msgs/msg/BodyPixPose.msg b/common/vision/lasr_vision_msgs/msg/BodyPixPose.msg deleted file mode 100644 index 0d416eeae..000000000 --- a/common/vision/lasr_vision_msgs/msg/BodyPixPose.msg +++ /dev/null @@ -1 +0,0 @@ -BodyPixKeypoint[] keypoints \ No newline at end of file diff --git a/common/vision/lasr_vision_msgs/srv/BodyPixKeypointDetection.srv b/common/vision/lasr_vision_msgs/srv/BodyPixKeypointDetection.srv new file mode 100644 index 000000000..ca056b268 --- /dev/null +++ b/common/vision/lasr_vision_msgs/srv/BodyPixKeypointDetection.srv @@ -0,0 +1,12 @@ +# Image to run inference on +sensor_msgs/Image image_raw + +# BodyPix model to use +string dataset + +# How certain the detection should be to include +float32 confidence + +--- +# keypoints +lasr_vision_msgs/BodyPixKeypoint[] keypoints diff --git a/common/vision/lasr_vision_msgs/srv/BodyPixDetection.srv b/common/vision/lasr_vision_msgs/srv/BodyPixMaskDetection.srv similarity index 53% rename from common/vision/lasr_vision_msgs/srv/BodyPixDetection.srv rename to common/vision/lasr_vision_msgs/srv/BodyPixMaskDetection.srv index ba13e3271..913714495 100644 --- a/common/vision/lasr_vision_msgs/srv/BodyPixDetection.srv +++ b/common/vision/lasr_vision_msgs/srv/BodyPixMaskDetection.srv @@ -7,11 +7,10 @@ string dataset # How certain the detection should be to include float32 confidence -# The masks that should be generated -lasr_vision_msgs/BodyPixMaskRequest[] masks +# Name of parts to get the masks for +# A full list is available here: +# https://github.com/de-code/python-tf-bodypix/blob/develop/tf_bodypix/bodypix_js_utils/part_channels.py#L5 +string[] parts --- # Generated masks lasr_vision_msgs/BodyPixMask[] masks - -# Pose information -lasr_vision_msgs/BodyPixPose[] poses diff --git a/skills/src/lasr_skills/describe_people.py b/skills/src/lasr_skills/describe_people.py index 4f452b1aa..5538721cf 100755 --- a/skills/src/lasr_skills/describe_people.py +++ b/skills/src/lasr_skills/describe_people.py @@ -6,10 +6,10 @@ import cv2_img import numpy as np from lasr_skills import Say -from lasr_vision_msgs.msg import BodyPixMaskRequest from lasr_vision_msgs.srv import ( YoloDetection, - BodyPixDetection, + BodyPixMaskDetection, + BodyPixMaskDetectionRequest, TorchFaceFeatureDetectionDescription, ) from numpy2message import numpy2message @@ -147,16 +147,18 @@ def __init__(self): ], output_keys=["bodypix_masks"], ) - self.bodypix = rospy.ServiceProxy("/bodypix/detect", BodyPixDetection) + self.bodypix = rospy.ServiceProxy( + "/bodypix/mask_detection", BodyPixMaskDetection + ) def execute(self, userdata): try: - torso = BodyPixMaskRequest() - torso.parts = ["torso_front", "torso_back"] - head = BodyPixMaskRequest() - head.parts = ["left_face", "right_face"] - masks = [torso, head] - result = self.bodypix(userdata.img_msg, "resnet50", 0.7, masks) + request = BodyPixMaskDetectionRequest() + request.image_raw = userdata.img_msg + request.dataset = "resnet50" + request.confidence = 0.2 + request.parts = ["torso_front", "torso_back", "left_face", "right_face"] + result = self.bodypix(request) userdata.bodypix_masks = result.masks return "succeeded" except rospy.ServiceException as e: @@ -204,21 +206,20 @@ def execute(self, userdata): cv2.fillPoly( mask_image, pts=np.int32([contours]), color=(255, 255, 255) ) - mask_bin = mask_image > 128 - + mask_bin = mask_image > 0 + torso_mask = np.zeros((height, width), np.uint8) + head_mask = np.zeros((height, width), np.uint8) # process part masks - for bodypix_mask, part in zip( - userdata.bodypix_masks, ["torso", "head"] - ): - part_mask = np.array(bodypix_mask.mask).reshape( - bodypix_mask.shape[0], bodypix_mask.shape[1] + for part in userdata.bodypix_masks: + part_mask = np.array(part.mask).reshape( + part.shape[0], part.shape[1] ) # filter out part for current person segmentation try: part_mask[mask_bin == 0] = 0 except Exception: - rospy.logdebug("|> Failed to check {part} is visible") + rospy.logdebug(f"|> Failed to check {part} is visible") continue if part_mask.any(): @@ -227,10 +228,10 @@ def execute(self, userdata): rospy.logdebug(f"|> Person does not have {part} visible") continue - if part == "torso": - torso_mask = part_mask - elif part == "head": - head_mask = part_mask + if part.name == "torso_front" or part.name == "torso_back": + torso_mask = np.logical_or(torso_mask, part_mask) + elif part.name == "left_face" or part.name == "right_face": + head_mask = np.logical_or(head_mask, part_mask) torso_mask_data, torso_mask_shape, torso_mask_dtype = numpy2message( torso_mask diff --git a/skills/src/lasr_skills/detect_gesture.py b/skills/src/lasr_skills/detect_gesture.py index 438c83e9e..576bc9c11 100755 --- a/skills/src/lasr_skills/detect_gesture.py +++ b/skills/src/lasr_skills/detect_gesture.py @@ -1,13 +1,14 @@ #!/usr/bin/env python3 -from typing import Optional +from typing import Optional, List import smach import rospy import cv2 import cv2_img -from lasr_skills.vision import GetCroppedImage -from lasr_skills import PlayMotion -from lasr_vision_msgs.srv import BodyPixDetection, BodyPixDetectionRequest -from lasr_vision_msgs.msg import BodyPixMaskRequest +from lasr_skills.vision import GetCroppedImage, GetImage +from lasr_vision_msgs.srv import ( + BodyPixKeypointDetection, + BodyPixKeypointDetectionRequest, +) from sensor_msgs.msg import Image @@ -19,87 +20,99 @@ class DetectGesture(smach.State): def __init__( self, gesture_to_detect: Optional[str] = None, + bodypix_model: str = "resnet50", + bodypix_confidence: float = 0.1, buffer_width: int = 50, debug_publisher: str = "/skills/gesture_detection/debug", ): """Optionally stores the gesture to detect. If None, it will infer the gesture from the keypoints.""" smach.State.__init__( self, - outcomes=["succeeded", "failed"], + outcomes=["succeeded", "missing_keypoints", "failed"], input_keys=["img_msg"], output_keys=["gesture_detected"], ) self.gesture_to_detect = gesture_to_detect - self.body_pix_client = rospy.ServiceProxy("/bodypix/detect", BodyPixDetection) + self.bodypix_client = rospy.ServiceProxy( + "/bodypix/keypoint_detection", BodyPixKeypointDetection + ) + self.bodypix_model = bodypix_model + self.bodypix_confidence = bodypix_confidence self.debug_publisher = rospy.Publisher(debug_publisher, Image, queue_size=1) self.buffer_width = buffer_width + self.required_keypoints = [ + "leftWrist", + "leftShoulder", + "rightWrist", + "rightShoulder", + ] def execute(self, userdata): - body_pix_masks = BodyPixMaskRequest() - body_pix_masks.parts = [ - "left_shoulder", - "right_shoulder", - "left_wrist", - "right_wrist", - ] - masks = [body_pix_masks] - - req = BodyPixDetectionRequest() + req = BodyPixKeypointDetectionRequest() req.image_raw = userdata.img_msg - req.masks = masks - req.dataset = "resnet50" - req.confidence = 0.7 + req.dataset = self.bodypix_model + req.confidence = self.bodypix_confidence try: - res = self.body_pix_client(req) + res = self.bodypix_client(req) except Exception as e: print(e) return "failed" - part_info = {} - poses = res.poses - for pose in poses: - for keypoint in pose.keypoints: - part_info[keypoint.part] = { - "x": keypoint.xy[0], - "y": keypoint.xy[1], - "score": keypoint.score, - } - if ( - self.gesture_to_detect == "raising_left_arm" - or self.gesture_to_detect is None - ): - if part_info["leftWrist"]["y"] < part_info["leftShoulder"]["y"]: - self.gesture_to_detect = "raising_left_arm" - if ( - self.gesture_to_detect == "raising_right_arm" - or self.gesture_to_detect is None - ): - if part_info["rightWrist"]["y"] < part_info["rightShoulder"]["y"]: - self.gesture_to_detect = "raising_right_arm" - if ( - self.gesture_to_detect == "pointing_to_the_left" - or self.gesture_to_detect is None - ): + detected_keypoints = res.keypoints + + keypoint_info = { + keypoint.keypoint_name: {"x": keypoint.x, "y": keypoint.y} + for keypoint in detected_keypoints + } + + if "leftShoulder" in keypoint_info and "leftWrist" in keypoint_info: if ( - part_info["leftWrist"]["x"] - self.buffer_width - > part_info["leftShoulder"]["x"] + self.gesture_to_detect == "raising_left_arm" + or self.gesture_to_detect is None ): - self.gesture_to_detect = "pointing_to_the_left" + if keypoint_info["leftWrist"]["y"] < keypoint_info["leftShoulder"]["y"]: + self.gesture_to_detect = "raising_left_arm" + if ( + self.gesture_to_detect == "pointing_to_the_left" + or self.gesture_to_detect is None + ): + if ( + keypoint_info["leftWrist"]["x"] - self.buffer_width + > keypoint_info["leftShoulder"]["x"] + ): + self.gesture_to_detect = "pointing_to_the_left" + if ( - self.gesture_to_detect == "pointing_to_the_right" - or self.gesture_to_detect is None + "rightShoulder" in keypoint_info + and "rightWrist" in keypoint_info + and self.gesture_to_detect is None ): + print(keypoint_info["rightShoulder"]["x"], keypoint_info["rightWrist"]["x"]) + if ( + self.gesture_to_detect == "raising_right_arm" + or self.gesture_to_detect is None + ): + if ( + keypoint_info["rightWrist"]["y"] + < keypoint_info["rightShoulder"]["y"] + ): + self.gesture_to_detect = "raising_right_arm" if ( - part_info["rightShoulder"]["x"] - self.buffer_width - > part_info["rightWrist"]["x"] + self.gesture_to_detect == "pointing_to_the_right" + or self.gesture_to_detect is None ): - self.gesture_to_detect = "pointing_to_the_right" + if ( + keypoint_info["rightShoulder"]["x"] - self.buffer_width + > keypoint_info["rightWrist"]["x"] + ): + self.gesture_to_detect = "pointing_to_the_right" if self.gesture_to_detect is None: self.gesture_to_detect = "none" + rospy.loginfo(f"Detected gesture: {self.gesture_to_detect}") userdata.gesture_detected = self.gesture_to_detect cv2_gesture_img = cv2_img.msg_to_cv2_img(userdata.img_msg) @@ -132,14 +145,18 @@ def __init__(self, gesture_to_detect: Optional[str] = None): with self: smach.StateMachine.add( "GET_IMAGE", - GetCroppedImage("person", "nearest"), + GetCroppedImage("person", "centered", rgb_topic="/usb_cam/image_raw"), transitions={"succeeded": "BODY_PIX_DETECTION", "failed": "failed"}, ) smach.StateMachine.add( "BODY_PIX_DETECTION", DetectGesture(gesture_to_detect=self.gesture_to_detect), - transitions={"succeeded": "succeeded", "failed": "failed"}, + transitions={ + "succeeded": "succeeded", + "failed": "failed", + "missing_keypoints": "failed", + }, ) @@ -149,18 +166,5 @@ def __init__(self, gesture_to_detect: Optional[str] = None): while not rospy.is_shutdown(): sm = GestureDetectionSM() sm.execute() - gesture_state = PlayMotion(motion_name=sm.userdata.gesture_detected) - gesture_sm = smach.StateMachine(outcomes=["succeeded", "failed"]) - with gesture_sm: - smach.StateMachine.add( - "GESTURE_STATE", - gesture_state, - transitions={ - "succeeded": "succeeded", - "aborted": "failed", - "preempted": "failed", - }, - ) - gesture_sm.execute() rospy.spin() diff --git a/skills/src/lasr_skills/validate_keypoints.py b/skills/src/lasr_skills/validate_keypoints.py new file mode 100755 index 000000000..c7dc492d5 --- /dev/null +++ b/skills/src/lasr_skills/validate_keypoints.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""This skill checks whether a set of required bodypix keypoints can be detected in a given image.""" +from typing import List +import smach +import rospy +from lasr_vision_msgs.srv import ( + BodyPixKeypointDetection, + BodyPixKeypointDetectionRequest, +) +from lasr_skills.vision import GetCroppedImage + + +class ValidateKeypoints(smach.State): + + def __init__( + self, + keypoints_to_detect: List[str], + bodypix_model: str = "resnet50", + bodypix_confidence: float = 0.7, + ): + """Takes a list of keypoints to check for in the image. If any are missing, this will be returned + in the userdata. + + Args: + keypoints_to_detect (list[str]): List of keypoints to check for in the image. + + bodypix_model (str, optional): The bodypix model to use. Defaults to "resnet50". + + bodypix_confidence (float, optional): The confidence threshold for bodypix. Defaults to 0.7. + + + """ + smach.State.__init__( + self, + outcomes=["succeeded", "failed"], + input_keys=["img_msg"], + output_keys=["missing_keypoints"], + ) + self._keypoints_to_detect = keypoints_to_detect + self._bodypix_model = bodypix_model + self._bodypix_confidence = bodypix_confidence + self._bodypix_client = rospy.ServiceProxy( + "/bodypix/keypoint_detection", BodyPixKeypointDetection + ) + + def execute(self, userdata): + req = BodyPixKeypointDetectionRequest() + req.image_raw = userdata.img_msg + req.dataset = self._bodypix_model + req.confidence = self._bodypix_confidence + + try: + res = self._bodypix_client(req) + except Exception as e: + print(e) + return "failed" + + detected_keypoints = res.keypoints + + rospy.loginfo(f"Detected keypoints: {detected_keypoints}") + + keypoint_names = [keypoint.keypoint_name for keypoint in detected_keypoints] + + missing_keypoints = [ + keypoint + for keypoint in self._keypoints_to_detect + if keypoint not in keypoint_names + ] + + if missing_keypoints: + rospy.logwarn(f"Missing keypoints: {missing_keypoints}") + userdata.missing_keypoints = missing_keypoints + return "failed" + return "succeeded" + + +if __name__ == "__main__": + rospy.init_node("validate_keypoints") + while not rospy.is_shutdown(): + get_cropped_image = GetCroppedImage( + "person", + crop_method="centered", + rgb_topic="/usb_cam/image_raw", + ) + sm = smach.StateMachine(outcomes=["succeeded", "failed"]) + with sm: + smach.StateMachine.add( + "GET_CROPPED_IMAGE", + get_cropped_image, + transitions={"succeeded": "VALIDATE_KEYPOINTS", "failed": "failed"}, + ) + smach.StateMachine.add( + "VALIDATE_KEYPOINTS", + ValidateKeypoints(["nose"]), + transitions={"succeeded": "succeeded", "failed": "failed"}, + ) + sm.execute() + input("Press Enter to continue...")