Skip to content

Commit 61b7018

Browse files
committed
Move new test cases fromdraft PR #158
1 parent b0b2556 commit 61b7018

File tree

2 files changed

+78
-34
lines changed

2 files changed

+78
-34
lines changed

tests/test_drive.py

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from invertedai.api.light import light
99
from invertedai.error import InvalidRequestError
1010

11+
def recurrent_states_helper(states_to_extend):
12+
result = [0.0] * 128
13+
result.extend(states_to_extend)
14+
return result
15+
16+
1117
positive_tests = [
1218
("carla:Town04",
1319
None,
@@ -20,20 +26,20 @@
2026
dict(agent_type='car'),
2127
dict()],
2228
False, 5),
23-
# ("canada:drake_street_and_pacific_blvd",
24-
# None,
25-
# [dict(agent_type='car'),
26-
# dict(),
27-
# dict(agent_type='pedestrian'),
28-
# dict()],
29-
# False, 5),
30-
# ("canada:drake_street_and_pacific_blvd",
31-
# None,
32-
# [dict(agent_type='pedestrian'),
33-
# dict(agent_type='pedestrian'),
34-
# dict(agent_type='pedestrian'),
35-
# dict(agent_type='pedestrian')],
36-
# False, 5),
29+
("canada:drake_street_and_pacific_blvd",
30+
None,
31+
[dict(agent_type='car'),
32+
dict(),
33+
dict(agent_type='pedestrian'),
34+
dict()],
35+
False, 5),
36+
("canada:drake_street_and_pacific_blvd",
37+
None,
38+
[dict(agent_type='pedestrian'),
39+
dict(agent_type='pedestrian'),
40+
dict(agent_type='pedestrian'),
41+
dict(agent_type='pedestrian')],
42+
False, 5),
3743
("canada:ubc_roundabout",
3844
[[dict(center=dict(x=60.82, y=1.22), orientation=0.63, speed=11.43),
3945
dict(center=dict(x=-36.88, y=-33.93), orientation=-2.64, speed=9.43)]],
@@ -53,16 +59,40 @@
5359
False, None),
5460
]
5561
negative_tests = [
56-
("canada:drake_street_and_pacific_blvd",
57-
None,
58-
[dict(agent_type='car'),
59-
dict(),
60-
dict(agent_type='bicycle'),
61-
dict()],
62-
False, 5),
62+
("carla:Town03",
63+
[dict(center=dict(x=-21.2, y=-17.11), orientation=4.54, speed=1.8),
64+
dict(center=dict(x=-5.81, y=-49.47), orientation=1.62, speed=11.4)],
65+
[dict(length=0.97, agent_type="pedestrian"),
66+
dict(length=4.86, width=2.12, rear_axis_offset=1.85, agent_type='car')],
67+
[dict(packed=recurrent_states_helper(
68+
[21.203039169311523, -17.10862159729004, -1.3971855640411377, 1.7982733249664307])),
69+
dict(packed=recurrent_states_helper(
70+
[5.810295104980469, -49.47068786621094, 1.5232856273651123, 11.404326438903809]))],
71+
False),
72+
("carla:Town03",
73+
[dict(center=dict(x=-21.2, y=-17.11), orientation=4.54, speed=1.8),
74+
dict(center=dict(x=-5.81, y=-49.47), orientation=1.62, speed=11.4)],
75+
[dict(length=0.97, width=1.06, rear_axis_offset=None, agent_type="pedestrian"),
76+
dict(length=4.86, width=2.12, agent_type='car')],
77+
[dict(packed=recurrent_states_helper(
78+
[21.203039169311523, -17.10862159729004, -1.3971855640411377, 1.7982733249664307])),
79+
dict(packed=recurrent_states_helper(
80+
[5.810295104980469, -49.47068786621094, 1.5232856273651123, 11.404326438903809]))],
81+
False),
82+
("carla:Town03",
83+
[dict(center=dict(x=-21.2, y=-17.11), orientation=4.54, speed=1.8),
84+
dict(center=dict(x=-5.81, y=-49.47), orientation=1.62, speed=11.4)],
85+
[dict(length=0.97, width=1.06, agent_type="pedestrian"),
86+
dict(width=2.12, rear_axis_offset=1.85)],
87+
[dict(packed=recurrent_states_helper(
88+
[21.203039169311523, -17.10862159729004, -1.3971855640411377, 1.7982733249664307])),
89+
dict(packed=recurrent_states_helper(
90+
[5.810295104980469, -49.47068786621094, 1.5232856273651123, 11.404326438903809]))],
91+
False),
6392
]
6493

65-
def run_drive(location, states_history, agent_attributes, get_infractions, agent_count, simulation_length: int = 20):
94+
def run_initialize_drive_flow(location, states_history, agent_attributes, get_infractions, agent_count,
95+
simulation_length: int = 20):
6696
location_info_response = location_info(location=location, rendering_fov=200)
6797
if any(actor.agent_type == "traffic-light" for actor in location_info_response.static_actors):
6898
scene_has_lights = True
@@ -94,13 +124,28 @@ def run_drive(location, states_history, agent_attributes, get_infractions, agent
94124
assert isinstance(updated_state,
95125
DriveResponse) and updated_state.agent_states is not None and updated_state.recurrent_states is not None
96126

127+
128+
def run_direct_drive(location, agent_states, agent_attributes, recurrent_states, get_infractions):
129+
drive_response = drive(
130+
agent_attributes=agent_attributes,
131+
agent_states=agent_states,
132+
recurrent_states=recurrent_states,
133+
traffic_lights_states=None,
134+
get_birdview=False,
135+
location=location,
136+
get_infractions=get_infractions
137+
)
138+
assert isinstance(drive_response,
139+
DriveResponse) and drive_response.agent_states is not None and drive_response.recurrent_states is not None
140+
141+
97142
@pytest.mark.parametrize("location, states_history, agent_attributes, get_infractions, agent_count", negative_tests)
98143
def test_negative(location, states_history, agent_attributes, get_infractions, agent_count,
99144
simulation_length: int = 20):
100145
with pytest.raises(InvalidRequestError):
101-
run_drive(location, states_history, agent_attributes, get_infractions, agent_count)
146+
run_direct_drive(location, states_history, agent_attributes, get_infractions, agent_count)
102147

103148
@pytest.mark.parametrize("location, states_history, agent_attributes, get_infractions, agent_count", positive_tests)
104149
def test_postivie(location, states_history, agent_attributes, get_infractions, agent_count,
105150
simulation_length: int = 20):
106-
run_drive(location, states_history, agent_attributes, get_infractions, agent_count)
151+
run_initialize_drive_flow(location, states_history, agent_attributes, get_infractions, agent_count)

tests/test_initialize.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
dict(center=dict(x=-46.62, y=-25.02), orientation=0.04, speed=1.09)],
2121
[dict(center=dict(x=-31.1, y=-23.36), orientation=2.21, speed=0.11),
2222
dict(center=dict(x=-47.62, y=-23.02), orientation=0.04, speed=1.09)]],
23-
[dict(length=1.39, width=1.78, rear_axis_offset=0.0, agent_type='pedestrian'),
23+
[dict(length=1.39, width=1.78, agent_type='pedestrian'),
2424
dict(length=1.37, width=1.98, rear_axis_offset=0.0, agent_type='pedestrian'),
2525
dict(agent_type='pedestrian'),
2626
dict(agent_type='car')],
@@ -31,7 +31,7 @@
3131
[dict(center=dict(x=-31.1, y=-23.36), orientation=2.21, speed=0.11),
3232
dict(center=dict(x=-47.62, y=-23.02), orientation=0.04, speed=1.09)]],
3333
[dict(length=1.39, width=1.78, rear_axis_offset=0.0, agent_type='pedestrian'),
34-
dict(length=1.37, width=1.98, rear_axis_offset=0.0, agent_type='pedestrian'),
34+
dict(length=1.37, width=1.98, agent_type='pedestrian'),
3535
dict(agent_type='car'),
3636
dict()],
3737
False, None),
@@ -40,8 +40,8 @@
4040
dict(center=dict(x=-46.62, y=-25.02), orientation=0.04, speed=1.09)],
4141
[dict(center=dict(x=-31.1, y=-23.36), orientation=2.21, speed=0.11),
4242
dict(center=dict(x=-47.62, y=-23.02), orientation=0.04, speed=1.09)]],
43-
[dict(length=1.39, width=1.78, rear_axis_offset=0.0, agent_type='pedestrian'),
44-
dict(length=1.37, width=1.98, rear_axis_offset=0.0, agent_type='pedestrian'),
43+
[dict(length=1.39, width=1.78, agent_type='pedestrian'),
44+
dict(length=1.37, width=1.98, rear_axis_offset=None, agent_type='pedestrian'),
4545
dict(agent_type='car'),
4646
dict()],
4747
False, 5),
@@ -78,8 +78,8 @@
7878
dict(center=dict(x=-46.62, y=-25.02), orientation=0.04, speed=1.09)],
7979
[dict(center=dict(x=-31.1, y=-23.36), orientation=2.21, speed=0.11),
8080
dict(center=dict(x=-47.62, y=-23.02), orientation=0.04, speed=1.09)]],
81-
[dict(length=1.39, width=1.78, rear_axis_offset=0.0, agent_type='pedestrian'),
82-
dict(length=1.37, width=1.98, rear_axis_offset=0.0, agent_type='pedestrian'),
81+
[dict(length=1.39, width=1.78, agent_type='pedestrian'),
82+
dict(length=1.37, width=1.98, agent_type='pedestrian'),
8383
dict(agent_type='pedestrian'),
8484
dict(agent_type='car')],
8585
False, 6),
@@ -100,7 +100,7 @@
100100
dict(center=dict(x=-46.62, y=-25.02), orientation=0.04, speed=1.09)],
101101
[dict(center=dict(x=-31.1, y=-23.36), orientation=2.21, speed=0.11),
102102
dict(center=dict(x=-47.62, y=-23.02), orientation=0.04, speed=1.09)]],
103-
[dict(length=1.39, rear_axis_offset=0.0, agent_type='pedestrian'),
103+
[dict(length=1.39, agent_type='pedestrian'),
104104
dict(width=1.98, rear_axis_offset=0.0, agent_type='pedestrian'),
105105
dict(agent_type='pedestrian'),
106106
dict(agent_type='car')],
@@ -110,8 +110,8 @@
110110
dict(center=dict(x=-46.62, y=-25.02), orientation=0.04, speed=1.09)],
111111
[dict(center=dict(x=-31.1, y=-23.36), orientation=2.21, speed=0.11),
112112
dict(center=dict(x=-47.62, y=-23.02), orientation=0.04, speed=1.09)]],
113-
[dict(length=1.39, rear_axis_offset=0.0, agent_type='pedestrian'),
114-
dict(length=1.37, width=1.98, agent_type='pedestrian'),
113+
[dict(length=1.39, width=1.20, agent_type='pedestrian'),
114+
dict(length=1.37, width=1.98, agent_type='car'),
115115
dict(agent_type='pedestrian'),
116116
dict(agent_type='car')],
117117
False, None),
@@ -137,7 +137,6 @@
137137
False, 1),
138138
]
139139

140-
141140
def run_initialize(location, states_history, agent_attributes, get_infractions, agent_count):
142141
location_info_response = location_info(location=location, rendering_fov=200)
143142
if any(actor.agent_type == "traffic-light" for actor in location_info_response.static_actors):

0 commit comments

Comments
 (0)