Skip to content

Commit 4275e94

Browse files
authored
Merge pull request #129 from inverted-ai/develop
Develop
2 parents 611c189 + 786f703 commit 4275e94

File tree

14 files changed

+699
-4
lines changed

14 files changed

+699
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
[pypi-link]: https://pypi.org/project/invertedai/
33
[colab-badge]: https://colab.research.google.com/assets/colab-badge.svg
44
[colab-link]: https://colab.research.google.com/github/inverted-ai/invertedai/blob/develop/examples/IAI_demo.ipynb
5-
[rest-link]: https://app.swaggerhub.com/apis/swaggerhub59/Inverted-AI
5+
[rest-link]: https://app.swaggerhub.com/apis-docs/InvertedAI/InvertedAI
66
[examples-link]: https://github.com/inverted-ai/invertedai/tree/master/examples
77

88
[![Documentation Status](https://readthedocs.org/projects/inverted-ai/badge/?version=latest)](https://inverted-ai.readthedocs.io/en/latest/?badge=latest)

examples/blame_example.ipynb

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"id": "uYrB4Yk3aizw"
7+
},
8+
"source": [
9+
"<img src=\"https://raw.githubusercontent.com/inverted-ai/invertedai/master/docs/images/banner-small.png\" alt=\"InvertedAI\" width=\"200\"/>\n"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {
16+
"id": "b7l_n8sULmAX"
17+
},
18+
"outputs": [],
19+
"source": [
20+
"import IPython\n",
21+
"from IPython.display import display, Image, clear_output\n",
22+
"from ipywidgets import interact\n",
23+
"from IPython.utils import io\n",
24+
"\n",
25+
"import matplotlib.pyplot as plt\n",
26+
"import imageio\n",
27+
"import numpy as np\n",
28+
"import cv2\n",
29+
"import invertedai as iai\n",
30+
"\n",
31+
"from shapely.geometry import Polygon\n",
32+
"from shapely.errors import GEOSException\n",
33+
"\n",
34+
"from dataclasses import dataclass\n",
35+
"from typing import Tuple"
36+
]
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": null,
41+
"metadata": {
42+
"id": "SFu9oOcDIQGs"
43+
},
44+
"outputs": [],
45+
"source": [
46+
"# API key:\n",
47+
"iai.add_apikey(\"\")"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {
54+
"id": "PiySgisPG6mG"
55+
},
56+
"outputs": [],
57+
"source": [
58+
"# pick a location (4 way, signalized intgersection)\n",
59+
"location = \"iai:drake_street_and_pacific_blvd\""
60+
]
61+
},
62+
{
63+
"cell_type": "code",
64+
"execution_count": null,
65+
"metadata": {
66+
"id": "kT3m0-zoNdvW"
67+
},
68+
"outputs": [],
69+
"source": [
70+
"location_info_response = iai.location_info(location=location)\n",
71+
"rendered_static_map = location_info_response.birdview_image.decode()\n",
72+
"scene_plotter = iai.utils.ScenePlotter(rendered_static_map,\n",
73+
" location_info_response.map_fov,\n",
74+
" (location_info_response.map_center.x, location_info_response.map_center.y),\n",
75+
" location_info_response.static_actors)"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"metadata": {
82+
"id": "L0ufxMO8NdvX",
83+
"scrolled": true
84+
},
85+
"outputs": [],
86+
"source": [
87+
"@dataclass\n",
88+
"class LogCollision:\n",
89+
" collision_agents: Tuple[int, int]\n",
90+
" start_time: int \n",
91+
" end_time: int\n",
92+
"\n",
93+
"def transform_all_agent_vertices_into_world_frame(agent_states,agent_attributes):\n",
94+
" \"\"\"\n",
95+
" Transform the vertices of all agents into points within the world frame of the map environment\n",
96+
" Args:\n",
97+
" agent_states: List[AgentState] #List of all current agent states including x and y coordinates and angle.\n",
98+
" agent_attributes: List[AgentAttributes] #List of static attributes of all agents including agent length and width.\n",
99+
" Returns:\n",
100+
" List[Polygon] #List of Polygon data types containing a list of vertices for each agent\n",
101+
" \"\"\"\n",
102+
" polygons = [None]*len(agent_states)\n",
103+
" for i, (state, attributes) in enumerate(zip(agent_states,agent_attributes)):\n",
104+
" dx = attributes.length/2\n",
105+
" dy = attributes.width/2\n",
106+
"\n",
107+
" vehicle_origin = np.array([state.center.x, state.center.y])\n",
108+
"\n",
109+
" vehicle_orientation = state.orientation\n",
110+
" c, s = np.cos(vehicle_orientation), np.sin(vehicle_orientation)\n",
111+
"\n",
112+
" rotation_matrix = np.array([[c, s],\n",
113+
" [-s, c]])\n",
114+
" stacked_vertices = np.array([[dx,dy],[dx,-dy],[-dx,-dy],[-dx,dy]]) #Formatted in a continuous sequence\n",
115+
" rotated_vertices = np.matmul(stacked_vertices,rotation_matrix)\n",
116+
" \n",
117+
" polygons[i] = vehicle_origin + rotated_vertices\n",
118+
"\n",
119+
" return [Polygon(p) for p in polygons]\n",
120+
"\n",
121+
"def check_agent_pairwise_intersections(polygons):\n",
122+
" \"\"\"\n",
123+
" Check all pairs of polygons in a list for intersections in their area.\n",
124+
" Args:\n",
125+
" polygons: List[Polygon] #List of polygons representing agents in an environment\n",
126+
" Returns:\n",
127+
" List[Tuple[int,int]] #List of agent ID pair tuples indicating collisions\n",
128+
" \"\"\"\n",
129+
" \n",
130+
" detected_overlap_agent_pairs = []\n",
131+
" num_agents = len(polygons)\n",
132+
" for j in range(num_agents):\n",
133+
" for k in range(j+1,num_agents):\n",
134+
" try:\n",
135+
" if polygons[j].intersection(polygons[k]).area:\n",
136+
" detected_overlap_agent_pairs.append((j,k))\n",
137+
" except GEOSException as e:\n",
138+
" print(f\"Collision candidates {j} and {k} failed with error {e}.\")\n",
139+
" pass\n",
140+
" \n",
141+
" return detected_overlap_agent_pairs\n",
142+
" \n",
143+
"def compute_pairwise_collisions(agent_states_history,agent_attributes):\n",
144+
" \"\"\"\n",
145+
" Use polygon intersections to check each agent combination whether there is a collision.\n",
146+
" Args:\n",
147+
" agent_states: List[List[AgentState]] #At all time steps, list of all current agent states including x and y coordinates and angle.\n",
148+
" agent_attributes: List[AgentAttributes] #List of static attributes of all agents including agent length and width.\n",
149+
" Returns:\n",
150+
" List[LogCollision] #List of collisions logs containing information about the colliding agent pairs IDs and the time period of the collision\n",
151+
" \"\"\"\n",
152+
" \n",
153+
" collisions_ongoing = {}\n",
154+
" collisions_all = []\n",
155+
" \n",
156+
" for t, agent_states in enumerate(agent_states_history):\n",
157+
" if len(agent_states) != len(agent_attributes):\n",
158+
" raise Exception(\"Incorrect number of agents or agent attributes.\")\n",
159+
"\n",
160+
" polygons = transform_all_agent_vertices_into_world_frame(agent_states,agent_attributes)\n",
161+
" detected_agent_pairs = check_agent_pairwise_intersections(polygons) \n",
162+
" \n",
163+
" for agent_tuple in detected_agent_pairs:\n",
164+
" if agent_tuple not in collisions_ongoing:\n",
165+
" collisions_ongoing[agent_tuple] = LogCollision(\n",
166+
" collision_agents=agent_tuple,\n",
167+
" start_time=t,\n",
168+
" end_time=None\n",
169+
" )\n",
170+
" untracked_agent_pairs = []\n",
171+
" for agent_tuple, collision in collisions_ongoing.items():\n",
172+
" if agent_tuple not in detected_agent_pairs:\n",
173+
" #The previous time step is the last in which the collision was observed\n",
174+
" collisions_ongoing[agent_tuple].end_time = t-1\n",
175+
" elif t >= SIMULATION_LENGTH-1:\n",
176+
" #The collision has not necessarily ended at this time step but it is the last time step it was observed to occur\n",
177+
" collisions_ongoing[agent_tuple].end_time = t\n",
178+
" \n",
179+
" if collisions_ongoing[agent_tuple].end_time is not None:\n",
180+
" collisions_all.append(collisions_ongoing[agent_tuple])\n",
181+
" untracked_agent_pairs.append(agent_tuple)\n",
182+
" \n",
183+
" collisions_ongoing = {k:v for k, v in collisions_ongoing.items() if not k in untracked_agent_pairs}\n",
184+
" \n",
185+
" return collisions_all\n",
186+
"\n",
187+
"# Simulate with `initialize`, `drive` and `light` until there are collisions.\n",
188+
"for _ in range(20): #Attempt 20 simulations looking for a collision\n",
189+
" light_response = iai.light(location=location)\n",
190+
"\n",
191+
" response = iai.initialize(\n",
192+
" location=location,\n",
193+
" agent_count=15,\n",
194+
" get_birdview=True,\n",
195+
" traffic_light_state_history=[light_response.traffic_lights_states]\n",
196+
" )\n",
197+
" agent_attributes = response.agent_attributes\n",
198+
" scene_plotter.initialize_recording(\n",
199+
" response.agent_states,\n",
200+
" agent_attributes=agent_attributes,\n",
201+
" traffic_light_states=light_response.traffic_lights_states\n",
202+
" )\n",
203+
"\n",
204+
" agent_state_history = []\n",
205+
" traffic_light_state_history = []\n",
206+
"\n",
207+
" # 10-second scene\n",
208+
" SIMULATION_LENGTH = 100\n",
209+
" for t in range(SIMULATION_LENGTH):\n",
210+
" light_response = iai.light(\n",
211+
" location=location, \n",
212+
" recurrent_states=light_response.recurrent_states\n",
213+
" )\n",
214+
" response = iai.drive(\n",
215+
" location=location,\n",
216+
" agent_attributes=agent_attributes,\n",
217+
" agent_states=response.agent_states,\n",
218+
" recurrent_states=response.recurrent_states,\n",
219+
" get_birdview=False,\n",
220+
" traffic_lights_states=light_response.traffic_lights_states,\n",
221+
" get_infractions=True,\n",
222+
" random_seed=1\n",
223+
" )\n",
224+
" scene_plotter.record_step(\n",
225+
" response.agent_states, \n",
226+
" traffic_light_states=light_response.traffic_lights_states\n",
227+
" )\n",
228+
" agent_state_history.append(response.agent_states)\n",
229+
" traffic_light_state_history.append(light_response.traffic_lights_states)\n",
230+
" \n",
231+
" print(f\"Attempted collision simulation number {_} iteration number {t}.\")\n",
232+
" clear_output(wait=True)\n",
233+
" \n",
234+
" collisions = compute_pairwise_collisions(agent_state_history,agent_attributes)\n",
235+
" if collisions: \n",
236+
" #If a collision is detected, cease generating more simulations\n",
237+
" break\n",
238+
"\n",
239+
"print(collisions)"
240+
]
241+
},
242+
{
243+
"cell_type": "code",
244+
"execution_count": null,
245+
"metadata": {
246+
"id": "YQ4dXaKQNdvZ"
247+
},
248+
"outputs": [],
249+
"source": [
250+
"blame_responses = []\n",
251+
"all_collision_agents = []\n",
252+
"for collision_data in collisions:\n",
253+
" all_collision_agents.extend(list(collision_data.collision_agents))\n",
254+
" blame_response = iai.blame(\n",
255+
" location=location,\n",
256+
" colliding_agents=collision_data.collision_agents,\n",
257+
" agent_state_history=agent_state_history[:collision_data.start_time],\n",
258+
" traffic_light_state_history=traffic_light_state_history[:collision_data.start_time],\n",
259+
" agent_attributes=agent_attributes,\n",
260+
" get_reasons=True,\n",
261+
" get_confidence_score=True,\n",
262+
" get_birdviews=False\n",
263+
" )\n",
264+
" print(blame_response.agents_at_fault)\n",
265+
" blame_responses.append(blame_response)"
266+
]
267+
},
268+
{
269+
"cell_type": "code",
270+
"execution_count": null,
271+
"metadata": {
272+
"id": "jRRaiLInNdvZ"
273+
},
274+
"outputs": [],
275+
"source": [
276+
"for response in blame_responses:\n",
277+
" print(response.reasons)"
278+
]
279+
},
280+
{
281+
"cell_type": "code",
282+
"execution_count": null,
283+
"metadata": {
284+
"id": "oZ79k112Ndva"
285+
},
286+
"outputs": [],
287+
"source": [
288+
"for response in blame_responses:\n",
289+
" print(response.confidence_score)"
290+
]
291+
},
292+
{
293+
"cell_type": "code",
294+
"execution_count": null,
295+
"metadata": {
296+
"id": "QzbqXvS7Ndva"
297+
},
298+
"outputs": [],
299+
"source": [
300+
"%%capture\n",
301+
"fig, ax = plt.subplots(constrained_layout=True, figsize=(50, 50))\n",
302+
"gif_name = 'blame-example.gif'\n",
303+
"scene_plotter.animate_scene(\n",
304+
" output_name=gif_name,\n",
305+
" ax=ax,\n",
306+
" numbers=all_collision_agents,\n",
307+
" direction_vec=False,\n",
308+
" velocity_vec=False,\n",
309+
" plot_frame_number=True\n",
310+
")"
311+
]
312+
},
313+
{
314+
"cell_type": "code",
315+
"execution_count": null,
316+
"metadata": {
317+
"id": "aUYCAHfsNdva"
318+
},
319+
"outputs": [],
320+
"source": [
321+
"Image(gif_name, width=1000, height=800)"
322+
]
323+
},
324+
{
325+
"cell_type": "code",
326+
"execution_count": null,
327+
"metadata": {},
328+
"outputs": [],
329+
"source": []
330+
}
331+
],
332+
"metadata": {
333+
"colab": {
334+
"provenance": []
335+
},
336+
"kernelspec": {
337+
"display_name": "Python 3 (ipykernel)",
338+
"language": "python",
339+
"name": "python3"
340+
},
341+
"language_info": {
342+
"codemirror_mode": {
343+
"name": "ipython",
344+
"version": 3
345+
},
346+
"file_extension": ".py",
347+
"mimetype": "text/x-python",
348+
"name": "python",
349+
"nbconvert_exporter": "python",
350+
"pygments_lexer": "ipython3",
351+
"version": "3.8.10"
352+
}
353+
},
354+
"nbformat": 4,
355+
"nbformat_minor": 1
356+
}

0 commit comments

Comments
 (0)