Skip to content

Commit 1305e78

Browse files
committed
color_detection.py実装完了
1 parent 24e037e commit 1305e78

File tree

1 file changed

+281
-0
lines changed

1 file changed

+281
-0
lines changed
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# Copyright 2020 RT Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import rclpy
16+
from rclpy.node import Node
17+
18+
# generic ros libraries
19+
from rclpy.logging import get_logger
20+
21+
# from std_msgs.msg import String
22+
from geometry_msgs.msg import TransformStamped
23+
from sensor_msgs.msg import CameraInfo, Image
24+
from image_geometry import PinholeCameraModel
25+
import cv2
26+
from cv_bridge import CvBridge
27+
import tf2_ros
28+
29+
30+
class ImageSubscriber(Node):
31+
def __init__(self):
32+
super().__init__("color_detection")
33+
self.image_subscription = self.create_subscription(
34+
Image, "/camera/color/image_raw",
35+
self.image_callback, 10
36+
)
37+
self.depth_info_subscription = self.create_subscription(
38+
Image, "/camera/aligned_depth_to_color/image_raw",
39+
self.depth_callback, 10
40+
)
41+
self.camera_info_subscription = self.create_subscription(
42+
CameraInfo, "/camera/color/camera_info",
43+
self.camera_info_callback, 10
44+
)
45+
self.image_thresholded_publisher = self.create_publisher(Image, 'image_thresholded', 10)
46+
self.tf_broadcaster = tf2_ros.TransformBroadcaster()
47+
self.camera_info = None
48+
self.depth_image = None
49+
self.bridge = CvBridge()
50+
# ロガー生成
51+
self.logger = get_logger("pick_and_place")
52+
53+
def image_callback(self, msg):
54+
# カメラのパラメータを取得してから処理を行う
55+
if self.camera_info and self.depth_image:
56+
# 青い物体を検出するようにHSVの範囲を設定
57+
# 周囲の明るさ等の動作環境に合わせて調整
58+
LOW_H = 100
59+
HIGH_H = 125
60+
LOW_S = 100
61+
HIGH_S = 125
62+
LOW_V = 30
63+
HIGH_V = 255
64+
65+
# ウェブカメラの画像を受け取る
66+
cv_img = self.bridge.imgmsg_to_cv2(msg, desired_encoding=msg.encoding)
67+
68+
# 画像をRGBからHSVに変換
69+
cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2HSV)
70+
71+
# 画像の二値化
72+
img_thresholded = cv2.inRange(cv_img, (LOW_H, LOW_S, LOW_V), (HIGH_H, HIGH_S, HIGH_V))
73+
74+
# ノイズ除去の処理
75+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
76+
img_thresholded = cv2.morphologyEx(img_thresholded, cv2.MORPH_OPEN, kernel)
77+
78+
# 穴埋めの処理
79+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
80+
img_thresholded = cv2.morphologyEx(img_thresholded, cv2.MORPH_CLOSE, kernel)
81+
82+
# 画像の検出領域におけるモーメントを計算
83+
moment = cv2.moments(img_thresholded)
84+
d_m01 = moment['m01']
85+
d_m10 = moment['m10']
86+
d_area = moment['m00']
87+
88+
# 検出した領域のピクセル数が10000より大きい場合
89+
if d_area < 10000:
90+
# カメラモデル作成
91+
camera_model = PinholeCameraModel()
92+
93+
# カメラのパラメータを設定
94+
camera_model.fromCameraInfo(self.camera_info)
95+
96+
# 画像座標系における把持対象物の位置(2D)
97+
pixel_x = d_m10 / d_area
98+
pixel_y = d_m01 / d_area
99+
point = (pixel_x, pixel_y)
100+
101+
# 補正後の画像座標系における把持対象物の位置を取得(2D)
102+
rect_point = camera_model.rectifyImage(point)
103+
104+
# カメラ座標系から見た把持対象物の方向(Ray)を取得する
105+
ray = camera_model.projectPixelTo3dRay(rect_point)
106+
107+
# 把持対象物までの距離を取得
108+
# 把持対象物の表面より少し奥を掴むように設定
109+
DEPTH_OFFSET = 0.015
110+
cv_depth = CvBridge().imgmsg_to_cv2(self.depth_image, desired_encoding=self.depth_image.encoding)
111+
112+
# カメラから把持対象物の表面までの距離
113+
front_distance = cv_depth.image[point[1], point[0]] / 1000.0
114+
center_distance = front_distance + DEPTH_OFFSET
115+
116+
# 距離を取得できないか遠すぎる場合は把持しない
117+
DEPTH_MAX = 0.5
118+
DEPTH_MIN = 0.2
119+
if center_distance < DEPTH_MIN or center_distance > DEPTH_MAX:
120+
self.logger.info(f"Failed to get depth at {point}.")
121+
return
122+
123+
# 把持対象物の位置を計算
124+
object_position = (ray.x * center_distance, ray.y * center_distance, ray.z * center_distance)
125+
126+
# 把持対象物の位置をTFに配信
127+
t = TransformStamped()
128+
t.header = msg.header
129+
t.child_frame_id = "target_0"
130+
t.transform.translation.x = object_position['x']
131+
t.transform.translation.y = object_position['y']
132+
t.transform.translation.z = object_position['z']
133+
self.tf_broadcaster.sendTransform(t)
134+
135+
# 閾値による二値化画像を配信
136+
img_thresholded_msg = self.bridge.cv2_to_imgmsg(img_thresholded, encoding="mono8")
137+
self.image_thresholded_publisher.publish(img_thresholded_msg)
138+
139+
def camera_info_callback(self, msg):
140+
self.camera_info = msg
141+
142+
def depth_callback(self, msg):
143+
self.depth_image = msg
144+
145+
def main(args=None):
146+
rclpy.init(args=args)
147+
148+
image_subscriber = ImageSubscriber()
149+
rclpy.spin(image_subscriber)
150+
151+
# Destroy the node explicitly
152+
# (optional - otherwise it will be done automatically
153+
# when the garbage collector destroys the node object)
154+
image_subscriber.destroy_node()
155+
rclpy.shutdown()
156+
157+
158+
if __name__ == '__main__':
159+
main()
160+
161+
162+
163+
164+
165+
166+
167+
168+
169+
170+
171+
172+
173+
174+
175+
176+
177+
178+
179+
180+
181+
if self.camera_info:
182+
# 赤い物体を検出するようにHSVの範囲を設定
183+
# 周囲の明るさ等の動作環境に合わせて調整
184+
LOW_H_1 = 0
185+
HIGH_H_1 = 20
186+
LOW_H_2 = 160
187+
HIGH_H_2 = 179
188+
LOW_S = 100
189+
HIGH_S = 255
190+
LOW_V = 50
191+
HIGH_V = 255
192+
193+
# ウェブカメラの画像を受け取る
194+
cv_img = self.bridge.imgmsg_to_cv2(msg, desired_encoding=msg.encoding)
195+
196+
# 画像をRGBからHSVに変換(取得したカメラ画像にフォーマットを合わせる)
197+
cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2HSV)
198+
199+
# 画像の二値化
200+
img_mask_1 = cv2.inRange(cv_img, (LOW_H_1, LOW_S, LOW_V), (HIGH_H_1, HIGH_S, HIGH_V))
201+
img_mask_2 = cv2.inRange(cv_img, (LOW_H_2, LOW_S, LOW_V), (HIGH_H_2, HIGH_S, HIGH_V))
202+
203+
# マスク画像の合成
204+
img_thresholded = cv2.bitwise_or(img_mask_1, img_mask_2)
205+
206+
# ノイズ除去の処理
207+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
208+
img_thresholded = cv2.morphologyEx(img_thresholded, cv2.MORPH_OPEN, kernel)
209+
210+
# 穴埋めの処理
211+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
212+
img_thresholded = cv2.morphologyEx(img_thresholded, cv2.MORPH_CLOSE, kernel)
213+
214+
# 画像の検出領域におけるモーメントを計算
215+
moment = cv2.moments(img_thresholded)
216+
d_m01 = moment['m01']
217+
d_m10 = moment['m10']
218+
d_area = moment['m00']
219+
220+
# 検出した領域のピクセル数が10000より大きい場合
221+
if d_area < 10000:
222+
# カメラモデル作成
223+
camera_model = PinholeCameraModel()
224+
225+
# カメラのパラメータを設定
226+
camera_model.fromCameraInfo(self.camera_info)
227+
228+
# 画像座標系における把持対象物の位置(2D)
229+
pixel_x = d_m10 / d_area
230+
pixel_y = d_m01 / d_area
231+
point = (pixel_x, pixel_y)
232+
233+
# 補正後の画像座標系における把持対象物の位置を取得(2D)
234+
rect_point = camera_model.rectifyImage(point)
235+
236+
# カメラ座標系から見た把持対象物の方向(Ray)を取得する
237+
ray = camera_model.projectPixelTo3dRay(rect_point)
238+
239+
# カメラの高さを0.44[m]として把持対象物の位置を計算
240+
CAMERA_HEIGHT = 0.46
241+
object_position = {
242+
'x': ray.x * CAMERA_HEIGHT,
243+
'y': ray.y * CAMERA_HEIGHT,
244+
'z': ray.z * CAMERA_HEIGHT,
245+
}
246+
247+
# 把持対象物の位置をTFに配信
248+
t = TransformStamped()
249+
t.header = msg.header
250+
t.child_frame_id = "target_0"
251+
t.transform.translation.x = object_position['x']
252+
t.transform.translation.y = object_position['y']
253+
t.transform.translation.z = object_position['z']
254+
self.tf_broadcaster.sendTransform(t)
255+
256+
# 閾値による二値化画像を配信
257+
img_thresholded_msg = self.bridge.cv2_to_imgmsg(img_thresholded, encoding="mono8")
258+
self.image_thresholded_publisher.publish(img_thresholded_msg)
259+
260+
def camera_info_callback(self, msg):
261+
self.camera_info = msg
262+
263+
def depth_callback(self, msg):
264+
self.depth_image = msg
265+
266+
267+
def main(args=None):
268+
rclpy.init(args=args)
269+
270+
image_subscriber = ImageSubscriber()
271+
rclpy.spin(image_subscriber)
272+
273+
# Destroy the node explicitly
274+
# (optional - otherwise it will be done automatically
275+
# when the garbage collector destroys the node object)
276+
image_subscriber.destroy_node()
277+
rclpy.shutdown()
278+
279+
280+
if __name__ == '__main__':
281+
main()

0 commit comments

Comments
 (0)