-
Notifications
You must be signed in to change notification settings - Fork 0
/
tl_detector.py
executable file
·212 lines (167 loc) · 7.3 KB
/
tl_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#!/usr/bin/env python
import rospy
import tf
import yaml
from cv_bridge import CvBridge
from geometry_msgs.msg import PoseStamped
from scipy.spatial import KDTree
from sensor_msgs.msg import Image
from std_msgs.msg import Int32
from styx_msgs.msg import Lane
from styx_msgs.msg import TrafficLightArray, TrafficLight
from light_classification.tl_classifier import TLClassifier
# Debounce threshold: TLDetector compares its count of consecutive identical
# light states against this value before it starts publishing a new state.
STATE_COUNT_THRESHOLD = 3
class TLDetector(object):
    """ROS node that publishes the waypoint index of the next red traffic light.

    The node listens to the vehicle pose, the base waypoints, the traffic
    light array and the camera image.  On every camera image it looks up the
    nearest stop line ahead of the car and publishes that stop line's waypoint
    index on ``/traffic_waypoint`` when the light is red, or ``-1`` otherwise.
    """

    def __init__(self):
        """Initialise node state, register publishers/subscribers and spin."""
        rospy.init_node('tl_detector')

        # State written by the subscriber callbacks.
        self.pose = None
        self.waypoints = None
        self.waypoints_2d = None
        self.waypoint_tree = None
        self.camera_image = None
        self.has_image = False
        self.lights = []

        # Debounce state used when publishing.
        self.state = TrafficLight.UNKNOWN
        self.last_state = TrafficLight.UNKNOWN
        self.last_wp = -1
        self.state_count = 0

        config_string = rospy.get_param("/traffic_light_config")
        # NOTE(security): the original used yaml.load() without a Loader, which
        # can construct arbitrary Python objects and is rejected by newer
        # PyYAML releases.  safe_load() is sufficient for plain config data.
        self.config = yaml.safe_load(config_string)

        self.upcoming_red_light_pub = rospy.Publisher(
            '/traffic_waypoint', Int32, queue_size=1)

        self.bridge = CvBridge()
        self.light_classifier = TLClassifier()
        self.listener = tf.TransformListener()

        # Subscribe LAST: callbacks may fire as soon as a Subscriber exists,
        # and image_cb needs self.config, the publisher and the debounce state
        # to be in place already.
        #
        # /vehicle/traffic_lights provides the location of each traffic light
        # in 3D map space and, in the simulator, the current colour state of
        # all lights (a ground-truth source for the classifier).  When testing
        # on the vehicle the colour state will not be available; it has to be
        # predicted from the light position and the camera image.
        self._subscribers = [
            rospy.Subscriber('/current_pose', PoseStamped,
                             self.pose_cb, queue_size=1),
            rospy.Subscriber('/base_waypoints', Lane,
                             self.waypoints_cb, queue_size=1),
            rospy.Subscriber('/vehicle/traffic_lights', TrafficLightArray,
                             self.traffic_cb),
            rospy.Subscriber('/image_color', Image, self.image_cb),
        ]

        rospy.spin()
        # self.loop()

    def loop(self):
        """Alternative to rospy.spin(): process traffic lights at a fixed 5 Hz.

        Not called by default (see the commented-out call in ``__init__``).
        """
        rate = rospy.Rate(5)  # Hz
        while not rospy.is_shutdown():
            if (self.pose is not None and self.waypoints is not None
                    and self.has_image):
                light_wp, state = self.process_traffic_lights()
                self._publish_debounced(light_wp, state)
            rate.sleep()

    def _publish_debounced(self, light_wp, state):
        """Publish the upcoming red-light waypoint with state debouncing.

        A changed state only resets the counter.  Once the same state has been
        counted ``STATE_COUNT_THRESHOLD`` times it becomes the stable state and
        its waypoint (or -1 when it is not red) is published.  Until then the
        previously stable waypoint is re-published.

        Args:
            light_wp (int): waypoint index of the upcoming stop line
            state (int): light colour ID (specified in styx_msgs/TrafficLight)
        """
        if self.state != state:
            self.state_count = 0
            self.state = state
        elif self.state_count >= STATE_COUNT_THRESHOLD:
            self.last_state = self.state
            light_wp = light_wp if state == TrafficLight.RED else -1
            self.last_wp = light_wp
            self.upcoming_red_light_pub.publish(Int32(light_wp))
        else:
            self.upcoming_red_light_pub.publish(Int32(self.last_wp))
        self.state_count += 1

    def pose_cb(self, msg):
        """Store the latest vehicle pose message."""
        self.pose = msg

    def waypoints_cb(self, waypoints):
        """Store the base waypoints and build the 2D lookup tree once.

        Args:
            waypoints (Lane): message whose ``waypoints`` list is indexed
        """
        self.waypoints = waypoints
        if not self.waypoints_2d:
            self.waypoints_2d = [
                [point.pose.pose.position.x, point.pose.pose.position.y]
                for point in waypoints.waypoints
            ]
            # Skip the tree for an empty lane; process_traffic_lights treats a
            # missing tree as "no waypoints yet".
            if self.waypoints_2d:
                self.waypoint_tree = KDTree(self.waypoints_2d)

    def traffic_cb(self, msg):
        """Store the latest list of traffic lights."""
        self.lights = msg.lights

    def image_cb(self, msg):
        """Identifies red lights in the incoming camera image and publishes the
        index of the waypoint closest to the red light's stop line to
        /traffic_waypoint

        Args:
            msg (Image): image from car-mounted camera

        """
        self.has_image = True
        self.camera_image = msg
        light_wp, state = self.process_traffic_lights()
        # Publish upcoming red lights at camera frequency, debounced.
        self._publish_debounced(light_wp, state)

    def get_closest_waypoint(self, x, y):
        """Identifies the closest path waypoint to the given position
            https://en.wikipedia.org/wiki/Closest_pair_of_points_problem

        Args:
            x (float): x coordinate of the position to match
            y (float): y coordinate of the position to match

        Returns:
            int: index of the closest waypoint in self.waypoints

        """
        closest_idx = self.waypoint_tree.query([x, y], 1)[1]
        # query() yields a numpy integer; hand callers a plain int.
        return int(closest_idx)

    def get_light_state(self, light):
        """Determines the current color of the traffic light

        Args:
            light (TrafficLight): light to classify

        Returns:
            int: ID of traffic light color (specified in styx_msgs/TrafficLight)

        """
        # Classifier path, currently disabled:
        # if(not self.has_image):
        #     self.prev_light_loc = None
        #     return False
        #
        # cv_image = self.bridge.imgmsg_to_cv2(self.camera_image, "bgr8")
        #
        # #Get classification
        # return self.light_classifier.get_classification(cv_image)

        # Using this for testing: trust the state carried by the message.
        return light.state

    def process_traffic_lights(self):
        """Finds closest visible traffic light, if one exists, and determines
        its location and color

        Returns:
            int: index of waypoint closest to the upcoming stop line for a
                 traffic light (-1 if none exists)
            int: ID of traffic light color (specified in styx_msgs/TrafficLight)

        """
        # Nothing can be located until both the pose and the waypoints (and
        # therefore the lookup tree) have been received.
        if (self.pose is None or self.waypoints is None
                or self.waypoint_tree is None):
            return -1, TrafficLight.UNKNOWN

        # List of positions that correspond to the line to stop in front of
        # for a given intersection.
        stop_line_positions = self.config['stop_line_positions']

        position = self.pose.pose.position
        car_wp_idx = self.get_closest_waypoint(position.x, position.y)

        closest_light = None
        line_wp_idx = None
        # Start with the largest possible gap so any light ahead is closer.
        diff = len(self.waypoints.waypoints)
        # zip() pairs each light with its stop line and tolerates the two
        # lists having different lengths (indexing raised IndexError).
        for light, line in zip(self.lights, stop_line_positions):
            temp_wp_idx = self.get_closest_waypoint(line[0], line[1])
            # Only stop lines at or ahead of the car are candidates.
            d = temp_wp_idx - car_wp_idx
            if 0 <= d < diff:
                diff = d
                closest_light = light
                line_wp_idx = temp_wp_idx

        if closest_light is None:
            return -1, TrafficLight.UNKNOWN

        state = self.get_light_state(closest_light)
        return line_wp_idx, state
if __name__ == '__main__':
    # Script entry point: constructing the detector initialises the node and
    # blocks inside its constructor until ROS shuts down.
    try:
        TLDetector()
    except rospy.ROSInterruptException:
        rospy.logerr('Could not start traffic node.')