Skip to content

Commit

Permalink
Merge pull request #215 from inverted-ai/log_replay
Browse files Browse the repository at this point in the history
Log replay
  • Loading branch information
KieranRatcliffeInvertedAI authored Aug 16, 2024
2 parents ab4cfe3 + d9d9ff7 commit 921961b
Show file tree
Hide file tree
Showing 8 changed files with 745 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,4 @@ bazel-*
examples/*.png
examples/output/
examples/*.csv
examples/*.json
188 changes: 188 additions & 0 deletions examples/scenario_log_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import invertedai as iai
from invertedai.utils import get_default_agent_properties

import os
from random import randint
import matplotlib.pyplot as plt

LOCATION = "canada:drake_street_and_pacific_blvd" # select one of available locations
SIMULATION_LENGTH = 100
SIMULATION_LENGTH_EXTEND = 100
SIMULATION_BEGIN_NEW_ROLLOUT = 50

######################################################################################
# Produce a log and write it
print("Producing log...")
location_info_response = iai.location_info(location=LOCATION)

# initialize the simulation by spawning NPCs
response = iai.initialize(
location=LOCATION, # select one of available locations
agent_properties=get_default_agent_properties({"car":5}), # number of NPCs to spawn
)
agent_properties = response.agent_properties # get dimension and other attributes of NPCs

log_writer = iai.LogWriter()
log_writer.initialize(
location=LOCATION,
location_info_response=location_info_response,
init_response=response
)

print("Stepping through simulation...")
for _ in range(SIMULATION_LENGTH):
# query the API for subsequent NPC predictions
response = iai.drive(
location=LOCATION,
agent_properties=agent_properties,
agent_states=response.agent_states,
recurrent_states=response.recurrent_states,
traffic_lights_states=response.traffic_lights_states,
random_seed=randint(1,100000)
)

log_writer.drive(drive_response=response)


log_path = os.path.join(os.getcwd(),f"scenario_log_example.json")
log_writer.export_to_file(log_path=log_path)
gif_path_original = os.path.join(os.getcwd(),f"scenario_log_example_original.gif")
log_writer.visualize(
gif_path=gif_path_original,
fov = 200,
resolution = (2048,2048),
dpi = 300,
map_center = None,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)

######################################################################################
# Replay original log
print("Reading log...")

log_reader = iai.LogReader(log_path)
gif_path_replay = os.path.join(os.getcwd(),f"scenario_log_example_replay.gif")
log_reader.visualize(
gif_path=gif_path_replay,
fov = 200,
resolution = (2048,2048),
dpi = 300,
map_center = None,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)

print("Extending read log...")

location_info_response_replay = log_reader.location_info_response
log_reader.initialize()
agent_properties = log_reader.agent_properties

rendered_static_map = location_info_response_replay.birdview_image.decode()
scene_plotter = iai.utils.ScenePlotter(
rendered_static_map,
location_info_response_replay.map_fov,
(location_info_response_replay.map_center.x, location_info_response_replay.map_center.y),
location_info_response_replay.static_actors
)
scene_plotter.initialize_recording(
agent_states=log_reader.agent_states,
agent_properties=agent_properties
)

print("Stepping through simulation...")
while True: # Log reader will return None when it has run out of simulation data
is_timestep_populated = log_reader.drive()
if not is_timestep_populated:
break
scene_plotter.record_step(log_reader.agent_states,log_reader.traffic_lights_states)

agent_states = log_reader.agent_states
recurrent_states = log_reader.recurrent_states
traffic_lights_states = log_reader.traffic_lights_states
for _ in range(SIMULATION_LENGTH_EXTEND):
response = iai.drive(
location=log_reader.location,
agent_properties=agent_properties,
agent_states=agent_states,
recurrent_states=recurrent_states,
traffic_lights_states=traffic_lights_states
)

agent_states = response.agent_states
recurrent_states = response.recurrent_states
traffic_lights_states = response.traffic_lights_states

scene_plotter.record_step(agent_states,traffic_lights_states)

gif_path_extended = os.path.join(os.getcwd(),f"scenario_log_example_extended.gif")
fig, ax = plt.subplots(constrained_layout=True, figsize=(50, 50))
plt.axis('off')
scene_plotter.animate_scene(
output_name=gif_path_extended,
ax=ax,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)


######################################################################################
# Re-read the log and choose an earlier timestep from which to branch off
print("Re-reading the log...")
log_reader.reset_log()

location_info_response_replay = log_reader.location_info_response
log_reader.initialize()
agent_properties = log_reader.agent_properties

rendered_static_map = location_info_response_replay.birdview_image.decode()
scene_plotter_new = iai.utils.ScenePlotter(
rendered_static_map,
location_info_response_replay.map_fov,
(location_info_response_replay.map_center.x, location_info_response_replay.map_center.y),
location_info_response_replay.static_actors
)
scene_plotter_new.initialize_recording(
agent_states=log_reader.agent_states,
agent_properties=agent_properties
)

print("Stepping through simulation...")
for _ in range(SIMULATION_BEGIN_NEW_ROLLOUT):
log_reader.drive()
scene_plotter_new.record_step(log_reader.agent_states,log_reader.traffic_lights_states)

agent_states = log_reader.agent_states
recurrent_states = log_reader.recurrent_states
traffic_lights_states = log_reader.traffic_lights_states
for _ in range(SIMULATION_LENGTH-SIMULATION_BEGIN_NEW_ROLLOUT):
response = iai.drive(
location=log_reader.location,
agent_properties=agent_properties,
agent_states=agent_states,
recurrent_states=recurrent_states,
traffic_lights_states=traffic_lights_states,
random_seed=randint(1,100000)
)

agent_states = response.agent_states
recurrent_states = response.recurrent_states
traffic_lights_states = response.traffic_lights_states

scene_plotter_new.record_step(agent_states,traffic_lights_states)

gif_path_branched = os.path.join(os.getcwd(),f"scenario_log_example_branched.gif")
fig_new, ax_new = plt.subplots(constrained_layout=True, figsize=(50, 50))
plt.axis('off')
scene_plotter_new.animate_scene(
output_name=gif_path_branched,
ax=ax_new,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)

1 change: 1 addition & 0 deletions invertedai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from invertedai.utils import Jupyter_Render, IAILogger, Session
from invertedai.large.initialize import get_regions_in_grid, get_number_of_agents_per_region_by_drivable_area, insert_agents_into_nearest_region, get_regions_default, large_initialize
from invertedai.large.drive import large_drive
from invertedai.logs.logger import LogWriter, LogReader

dev = strtobool(os.environ.get("IAI_DEV", "false"))
if dev:
Expand Down
26 changes: 8 additions & 18 deletions invertedai/api/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,22 @@ class DriveResponse(BaseModel):
Response returned from an API call to :func:`iai.drive`.
"""

agent_states: List[
AgentState
] #: Predicted states for all agents at the next time step.
recurrent_states: List[
RecurrentState
] #: To pass to :func:`iai.drive` at the subsequent time step.
birdview: Optional[
Image
] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[
List[InfractionIndicators]
] #: If `get_infractions` was set, they are returned here.
is_inside_supported_area: List[
bool
] #: For each agent, indicates whether the predicted state is inside supported area.
agent_states: List[AgentState] #: Predicted states for all agents at the next time step.
recurrent_states: List[RecurrentState] #: To pass to :func:`iai.drive` at the subsequent time step.
birdview: Optional[Image] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[List[InfractionIndicators]] #: If `get_infractions` was set, they are returned here.
is_inside_supported_area: List[bool] #: For each agent, indicates whether the predicted state is inside supported area.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, as seen by the agents before they performed their actions resulting in the returned state. Each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map, each element corresponds to one light group. Pass this to the next call of :func:`iai.drive` for the server to realistically update the traffic light states.
api_model_version: str # Model version used for this API call
api_model_version: str # Model version used for this API call


@validate_call
def drive(
location: str,
agent_states: List[AgentState],
agent_attributes: Optional[List[AgentAttributes]]=None,
agent_properties: Optional[List[AgentProperties]]=None,
agent_attributes: Optional[List[AgentAttributes]] = None,
agent_properties: Optional[List[AgentProperties]] = None,
recurrent_states: Optional[List[RecurrentState]] = None,
traffic_lights_states: Optional[TrafficLightStatesDict] = None,
light_recurrent_states: Optional[LightRecurrentStates] = None,
Expand Down
22 changes: 7 additions & 15 deletions invertedai/api/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,15 @@ class InitializeResponse(BaseModel):
Response returned from an API call to :func:`iai.initialize`.
"""

recurrent_states: List[
Optional[RecurrentState]
] #: To pass to :func:`iai.drive` at the first time step.
agent_states: List[Optional[AgentState]] #: Initial states of all initialized agents.
agent_attributes: List[
Optional[AgentAttributes]
] #: Static attributes of all initialized agents.
agent_states: List[AgentState] #: Initial states of all initialized agents.
recurrent_states: List[Optional[RecurrentState]] #: To pass to :func:`iai.drive` at the first time step.
agent_attributes: List[Optional[AgentAttributes]] #: Static attributes of all initialized agents.
agent_properties: List[AgentProperties] #: Static agent properties of all initialized agents.
birdview: Optional[
Image
] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[
List[InfractionIndicators]
] #: If `get_infractions` was set, they are returned here.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
birdview: Optional[Image] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[List[InfractionIndicators]] #: If `get_infractions` was set, they are returned here.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map. Pass this to :func:`iai.drive` at the first time step to let the server generate a realistic continuation of the traffic light state sequence. This does not work correctly if any specific light states were specified as input to `initialize`.
api_model_version: str # Model version used for this API call
api_model_version: str #: Model version used for this API call


@validate_call
Expand Down
4 changes: 3 additions & 1 deletion invertedai/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple
from enum import Enum
from pydantic import BaseModel, model_validator
import math
from PIL import Image as PImage
import numpy as np
import io
import json

import invertedai as iai
from invertedai.error import InvalidInputType, InvalidInput

Expand Down
Loading

0 comments on commit 921961b

Please sign in to comment.