|
8 | 8 | from invertedai.api.light import light
|
9 | 9 | from invertedai.error import InvalidRequestError
|
10 | 10 |
|
| 11 | +def recurrent_states_helper(states_to_extend): |
| 12 | + result = [0.0] * 128 |
| 13 | + result.extend(states_to_extend) |
| 14 | + return result |
| 15 | + |
| 16 | + |
11 | 17 | positive_tests = [
|
12 | 18 | ("carla:Town04",
|
13 | 19 | None,
|
|
20 | 26 | dict(agent_type='car'),
|
21 | 27 | dict()],
|
22 | 28 | False, 5),
|
23 |
| - # ("canada:drake_street_and_pacific_blvd", |
24 |
| - # None, |
25 |
| - # [dict(agent_type='car'), |
26 |
| - # dict(), |
27 |
| - # dict(agent_type='pedestrian'), |
28 |
| - # dict()], |
29 |
| - # False, 5), |
30 |
| - # ("canada:drake_street_and_pacific_blvd", |
31 |
| - # None, |
32 |
| - # [dict(agent_type='pedestrian'), |
33 |
| - # dict(agent_type='pedestrian'), |
34 |
| - # dict(agent_type='pedestrian'), |
35 |
| - # dict(agent_type='pedestrian')], |
36 |
| - # False, 5), |
| 29 | + ("canada:drake_street_and_pacific_blvd", |
| 30 | + None, |
| 31 | + [dict(agent_type='car'), |
| 32 | + dict(), |
| 33 | + dict(agent_type='pedestrian'), |
| 34 | + dict()], |
| 35 | + False, 5), |
| 36 | + ("canada:drake_street_and_pacific_blvd", |
| 37 | + None, |
| 38 | + [dict(agent_type='pedestrian'), |
| 39 | + dict(agent_type='pedestrian'), |
| 40 | + dict(agent_type='pedestrian'), |
| 41 | + dict(agent_type='pedestrian')], |
| 42 | + False, 5), |
37 | 43 | ("canada:ubc_roundabout",
|
38 | 44 | [[dict(center=dict(x=60.82, y=1.22), orientation=0.63, speed=11.43),
|
39 | 45 | dict(center=dict(x=-36.88, y=-33.93), orientation=-2.64, speed=9.43)]],
|
|
53 | 59 | False, None),
|
54 | 60 | ]
|
55 | 61 | negative_tests = [
|
56 |
| - ("canada:drake_street_and_pacific_blvd", |
57 |
| - None, |
58 |
| - [dict(agent_type='car'), |
59 |
| - dict(), |
60 |
| - dict(agent_type='bicycle'), |
61 |
| - dict()], |
62 |
| - False, 5), |
| 62 | + ("carla:Town03", |
| 63 | + [dict(center=dict(x=-21.2, y=-17.11), orientation=4.54, speed=1.8), |
| 64 | + dict(center=dict(x=-5.81, y=-49.47), orientation=1.62, speed=11.4)], |
| 65 | + [dict(length=0.97, agent_type="pedestrian"), |
| 66 | + dict(length=4.86, width=2.12, rear_axis_offset=1.85, agent_type='car')], |
| 67 | + [dict(packed=recurrent_states_helper( |
| 68 | + [21.203039169311523, -17.10862159729004, -1.3971855640411377, 1.7982733249664307])), |
| 69 | + dict(packed=recurrent_states_helper( |
| 70 | + [5.810295104980469, -49.47068786621094, 1.5232856273651123, 11.404326438903809]))], |
| 71 | + False), |
| 72 | + ("carla:Town03", |
| 73 | + [dict(center=dict(x=-21.2, y=-17.11), orientation=4.54, speed=1.8), |
| 74 | + dict(center=dict(x=-5.81, y=-49.47), orientation=1.62, speed=11.4)], |
| 75 | + [dict(length=0.97, width=1.06, rear_axis_offset=None, agent_type="pedestrian"), |
| 76 | + dict(length=4.86, width=2.12, agent_type='car')], |
| 77 | + [dict(packed=recurrent_states_helper( |
| 78 | + [21.203039169311523, -17.10862159729004, -1.3971855640411377, 1.7982733249664307])), |
| 79 | + dict(packed=recurrent_states_helper( |
| 80 | + [5.810295104980469, -49.47068786621094, 1.5232856273651123, 11.404326438903809]))], |
| 81 | + False), |
| 82 | + ("carla:Town03", |
| 83 | + [dict(center=dict(x=-21.2, y=-17.11), orientation=4.54, speed=1.8), |
| 84 | + dict(center=dict(x=-5.81, y=-49.47), orientation=1.62, speed=11.4)], |
| 85 | + [dict(length=0.97, width=1.06, agent_type="pedestrian"), |
| 86 | + dict(width=2.12, rear_axis_offset=1.85)], |
| 87 | + [dict(packed=recurrent_states_helper( |
| 88 | + [21.203039169311523, -17.10862159729004, -1.3971855640411377, 1.7982733249664307])), |
| 89 | + dict(packed=recurrent_states_helper( |
| 90 | + [5.810295104980469, -49.47068786621094, 1.5232856273651123, 11.404326438903809]))], |
| 91 | + False), |
63 | 92 | ]
|
64 | 93 |
|
65 |
| -def run_drive(location, states_history, agent_attributes, get_infractions, agent_count, simulation_length: int = 20): |
| 94 | +def run_initialize_drive_flow(location, states_history, agent_attributes, get_infractions, agent_count, |
| 95 | + simulation_length: int = 20): |
66 | 96 | location_info_response = location_info(location=location, rendering_fov=200)
|
67 | 97 | if any(actor.agent_type == "traffic-light" for actor in location_info_response.static_actors):
|
68 | 98 | scene_has_lights = True
|
@@ -94,13 +124,28 @@ def run_drive(location, states_history, agent_attributes, get_infractions, agent
|
94 | 124 | assert isinstance(updated_state,
|
95 | 125 | DriveResponse) and updated_state.agent_states is not None and updated_state.recurrent_states is not None
|
96 | 126 |
|
| 127 | + |
| 128 | +def run_direct_drive(location, agent_states, agent_attributes, recurrent_states, get_infractions): |
| 129 | + drive_response = drive( |
| 130 | + agent_attributes=agent_attributes, |
| 131 | + agent_states=agent_states, |
| 132 | + recurrent_states=recurrent_states, |
| 133 | + traffic_lights_states=None, |
| 134 | + get_birdview=False, |
| 135 | + location=location, |
| 136 | + get_infractions=get_infractions |
| 137 | + ) |
| 138 | + assert isinstance(drive_response, |
| 139 | + DriveResponse) and drive_response.agent_states is not None and drive_response.recurrent_states is not None |
| 140 | + |
| 141 | + |
97 | 142 | @pytest.mark.parametrize("location, states_history, agent_attributes, get_infractions, agent_count", negative_tests)
|
98 | 143 | def test_negative(location, states_history, agent_attributes, get_infractions, agent_count,
|
99 | 144 | simulation_length: int = 20):
|
100 | 145 | with pytest.raises(InvalidRequestError):
|
101 |
| - run_drive(location, states_history, agent_attributes, get_infractions, agent_count) |
| 146 | + run_direct_drive(location, states_history, agent_attributes, get_infractions, agent_count) |
102 | 147 |
|
103 | 148 | @pytest.mark.parametrize("location, states_history, agent_attributes, get_infractions, agent_count", positive_tests)
|
104 | 149 | def test_postivie(location, states_history, agent_attributes, get_infractions, agent_count,
|
105 | 150 | simulation_length: int = 20):
|
106 |
| - run_drive(location, states_history, agent_attributes, get_infractions, agent_count) |
| 151 | + run_initialize_drive_flow(location, states_history, agent_attributes, get_infractions, agent_count) |
0 commit comments