diff --git a/README.md b/README.md index f4ee4f2..156f229 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,25 @@ # edgeimpulse_ros -ROS2 wrapper for Edge Impulse +ROS2 wrapper for Edge Impulse on Linux. -## How to install + +## 1. Topics + +- `/detection/input/image`, image topic to analyze +- `/detection/output/image`, image with bounding boxes +- `/detection/output/info`, VisionInfo message +- `/detection/output/results`, results as text + +## 2. Parameters + +- `frame_id` (**string**), _"base_link"_, frame id of output topics +- `model.filepath` (**string**), _""_, absolute filepath to .eim file +- `show.overlay` (**bool**), _true_, show bounding boxes on output image +- `show.labels` (**bool**), _true_, show labels on bounding boxes, +- `show.classification_info` (**bool**), _true_, show the attendibility (0-1) of the prediction + + +## 3. How to install 1. install edge_impulse_linux:
`pip3 install edge_impulse_linux` @@ -30,12 +47,23 @@ ROS2 wrapper for Edge Impulse `source install/setup.bash`
-## How to run +## 4. How to run Launch the node:
`ros2 run edgeimpulse_ros image_classification --ros-args -p model.filepath:="" -r /detection/input/image:="/your_image_topic"` `
+## 5. Models + +Here you find some prebuilt models: [https://github.com/gbr1/edgeimpulse_example_models](https://github.com/gbr1/edgeimpulse_example_models) + +## 6. Known issues + +- this wrapper works on foxy, galactic and humble are coming soon (incompatibility on vision msgs by ros-perception) +- if you use a classification model, topic results is empty +- you cannot change color of bounding boxes (coming soon) +- other types (imu and sound based ml) are unavailable + ***Copyright © 2022 Giovanni di Dio Bruno - gbr1.github.io*** diff --git a/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc b/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc index 93d905d..af7ee0b 100644 Binary files a/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc and b/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc differ diff --git a/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc b/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc index e1ecdbe..59650d8 100644 Binary files a/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc and b/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc differ diff --git a/edgeimpulse_ros/image_classification.py b/edgeimpulse_ros/image_classification.py index ab750fb..424103c 100644 --- a/edgeimpulse_ros/image_classification.py +++ b/edgeimpulse_ros/image_classification.py @@ -13,6 +13,7 @@ # limitations under the License. +from distutils.log import info from unittest import result from .submodules import device_patches import cv2 @@ -24,8 +25,10 @@ from sensor_msgs.msg import Image -#from vision_msgs.msg import BoundingBox2DArray -#from vision_msgs.msg import VisionInfo +from vision_msgs.msg import Detection2DArray +from vision_msgs.msg import Detection2D +from vision_msgs.msg import ObjectHypothesisWithPose +from vision_msgs.msg import VisionInfo import os import time @@ -35,31 +38,46 @@ + class EI_Image_node(Node): + def __init__(self): self.occupied = False self.img = None + self.info_msg = VisionInfo() self.cv_bridge = CvBridge() super().__init__('ei_image_classifier_node') self.init_parameters() self.ei_classifier = self.EI_Classifier(self.modelfile, self.get_logger()) - #self.publisher = self.create_publisher(BoundingBox2DArray,'/edge_impulse/detection',1) + + self.info_msg.header.frame_id = self.frame_id + self.info_msg.method = self.ei_classifier.model_info['model_parameters']['model_type'] + self.info_msg.database_location = self.ei_classifier.model_info['project']['name']+' / '+self.ei_classifier.model_info['project']['owner'] + self.info_msg.database_version = self.ei_classifier.model_info['project']['deploy_version'] self.timer_parameter = self.create_timer(2,self.parameters_callback) self.image_publisher = self.create_publisher(Image,'/detection/output/image',1) + self.results_publisher = self.create_publisher(Detection2DArray,'/detection/output/results',1) + self.info_publisher = self.create_publisher(VisionInfo, '/detection/output/info',1) + self.timer_classify = self.create_timer(0.01,self.classify_callback) self.timer_classify.cancel() self.subscription = self.create_subscription(Image,'/detection/input/image',self.listener_callback,1) self.subscription - + + + def init_parameters(self): self.declare_parameter('model.filepath','') self.modelfile= self.get_parameter('model.filepath').get_parameter_value().string_value + self.declare_parameter('frame_id','base_link') + self.frame_id= self.get_parameter('frame_id').get_parameter_value().string_value + self.declare_parameter('show.overlay', True) self.show_overlay = self.get_parameter('show.overlay').get_parameter_value().bool_value @@ -72,9 +90,6 @@ def init_parameters(self): - - - def parameters_callback(self): self.show_labels_on_image = self.get_parameter('show.labels').get_parameter_value().bool_value self.show_extra_classification_info = self.get_parameter('show.classification_info').get_parameter_value().bool_value @@ -91,13 +106,25 @@ def listener_callback(self, msg): self.img = current_frame self.timer_classify.reset() + + + + def classify_callback(self): self.occupied = True + + # vision msgs + results_msg = Detection2DArray() + time_now = self.get_clock().now().to_msg() + results_msg.header.stamp = time_now + results_msg.header.frame_id = self.frame_id + # classify features, cropped, res = self.ei_classifier.classify(self.img) - #prepare output + + #p repare output if "classification" in res["result"].keys(): if self.show_extra_classification_info: self.get_logger().info('Result (%d ms.) ' % (res['timing']['dsp'] + res['timing']['classification']), end='') @@ -112,6 +139,28 @@ def classify_callback(self): self.get_logger().info('Found %d bounding boxes (%d ms.)' % (len(res["result"]["bounding_boxes"]), res['timing']['dsp'] + res['timing']['classification'])) for bb in res["result"]["bounding_boxes"]: + result_msg = Detection2D() + result_msg.header.stamp = time_now + result_msg.header.frame_id = self.frame_id + + # object with hypthothesis + obj_hyp = ObjectHypothesisWithPose() + obj_hyp.id = bb['label'] #str(self.ei_classifier.labels.index(bb['label'])) + obj_hyp.score = bb['value'] + obj_hyp.pose.pose.position.x = float(bb['x']) + obj_hyp.pose.pose.position.y = float(bb['y']) + result_msg.results.append(obj_hyp) + + # bounding box + result_msg.bbox.center.x = float(bb['x']) + result_msg.bbox.center.y = float(bb['y']) + result_msg.bbox.size_x = float(bb['width']) + result_msg.bbox.size_y = float(bb['height']) + + + results_msg.detections.append(result_msg) + + # image if self.show_extra_classification_info: self.get_logger().info('%s (%.2f): x=%d y=%d w=%d h=%d' % (bb['label'], bb['value'], bb['x'], bb['y'], bb['width'], bb['height'])) if self.show_overlay: @@ -124,6 +173,9 @@ def classify_callback(self): # publish message self.image_publisher.publish(self.cv_bridge.cv2_to_imgmsg(cropped,"bgr8")) + self.results_publisher.publish(results_msg) + self.info_msg.header.stamp = time_now + self.info_publisher.publish(self.info_msg) self.occupied= False self.timer_classify.cancel() @@ -166,6 +218,8 @@ def classify(self, img): self.logger.error('Error on classification') + + def main(): rclpy.init() node = EI_Image_node() @@ -174,7 +228,6 @@ def main(): node.destroy_node() rclpy.shutdown() - if __name__ == "__main__": main() diff --git a/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc b/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc index 9346394..4c735e9 100644 Binary files a/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc and b/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc differ diff --git a/package.xml b/package.xml index 92eafb6..79ae856 100644 --- a/package.xml +++ b/package.xml @@ -14,7 +14,7 @@ ament_pep257 python3-pytest - + vision_msgs sensor_msgs ros2launch diff --git a/setup.py b/setup.py index a843408..4d5aa3d 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ submodules = 'edgeimpulse_ros/submodules' setup( name=package_name, - version='0.0.1', + version='0.0.2', packages=[package_name, submodules], data_files=[ ('share/ament_index/resource_index/packages',