Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log replay #215

Merged
merged 10 commits into from
Aug 16, 2024
Merged
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,4 @@ bazel-*
examples/*.png
examples/output/
examples/*.csv
examples/*.json
188 changes: 188 additions & 0 deletions examples/scenario_log_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import invertedai as iai
from invertedai.utils import get_default_agent_properties

import os
from random import randint
import matplotlib.pyplot as plt

LOCATION = "canada:drake_street_and_pacific_blvd" # select one of available locations
SIMULATION_LENGTH = 100
SIMULATION_LENGTH_EXTEND = 100
SIMULATION_BEGIN_NEW_ROLLOUT = 50

######################################################################################
# Produce a log and write it
print("Producing log...")
location_info_response = iai.location_info(location=LOCATION)

# initialize the simulation by spawning NPCs
response = iai.initialize(
location=LOCATION, # select one of available locations
agent_properties=get_default_agent_properties({"car":5}), # number of NPCs to spawn
)
agent_properties = response.agent_properties # get dimension and other attributes of NPCs

log_writer = iai.LogWriter()
log_writer.initialize(
location=LOCATION,
location_info_response=location_info_response,
init_response=response
)

print("Stepping through simulation...")
for _ in range(SIMULATION_LENGTH):
# query the API for subsequent NPC predictions
response = iai.drive(
location=LOCATION,
agent_properties=agent_properties,
agent_states=response.agent_states,
recurrent_states=response.recurrent_states,
traffic_lights_states=response.traffic_lights_states,
random_seed=randint(1,100000)
)

log_writer.drive(drive_response=response)


log_path = os.path.join(os.getcwd(),f"scenario_log_example.json")
log_writer.write_scenario_log_to_json(log_path=log_path)
gif_path_original = os.path.join(os.getcwd(),f"scenario_log_example_original.gif")
log_writer.visualize(
gif_path=gif_path_original,
fov = 200,
resolution = (2048,2048),
dpi = 300,
map_center = None,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)

######################################################################################
# Replay original log
print("Reading log...")

log_reader = iai.LogReader(log_path)
gif_path_replay = os.path.join(os.getcwd(),f"scenario_log_example_replay.gif")
log_reader.visualize(
gif_path=gif_path_replay,
fov = 200,
resolution = (2048,2048),
dpi = 300,
map_center = None,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)

print("Extending read log...")

location_info_response_replay = log_reader.location_info_response
log_reader.initialize()
agent_properties = log_reader.agent_properties

rendered_static_map = location_info_response_replay.birdview_image.decode()
scene_plotter = iai.utils.ScenePlotter(
rendered_static_map,
location_info_response_replay.map_fov,
(location_info_response_replay.map_center.x, location_info_response_replay.map_center.y),
location_info_response_replay.static_actors
)
scene_plotter.initialize_recording(
agent_states=log_reader.agent_states,
agent_properties=agent_properties
)

print("Stepping through simulation...")
while True: # Log reader will return None when it has run out of simulation data
is_timestep_populated = log_reader.drive()
if not is_timestep_populated:
break
scene_plotter.record_step(log_reader.agent_states,log_reader.traffic_lights_states)

agent_states = log_reader.agent_states
recurrent_states = log_reader.recurrent_states
traffic_lights_states = log_reader.traffic_lights_states
for _ in range(SIMULATION_LENGTH_EXTEND):
response = iai.drive(
location=log_reader.scenario_log.location,
agent_properties=agent_properties,
agent_states=agent_states,
recurrent_states=recurrent_states,
traffic_lights_states=traffic_lights_states
)

agent_states = response.agent_states
recurrent_states = response.recurrent_states
traffic_lights_states = response.traffic_lights_states

scene_plotter.record_step(agent_states,traffic_lights_states)

gif_path_extended = os.path.join(os.getcwd(),f"scenario_log_example_extended.gif")
fig, ax = plt.subplots(constrained_layout=True, figsize=(50, 50))
plt.axis('off')
scene_plotter.animate_scene(
output_name=gif_path_extended,
ax=ax,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)


######################################################################################
# Re-read the log and choose an earlier timestep from which to branch off
print("Re-reading the log...")
log_reader.reset_log()

location_info_response_replay = log_reader.location_info_response
log_reader.initialize()
agent_properties = log_reader.agent_properties

rendered_static_map = location_info_response_replay.birdview_image.decode()
scene_plotter_new = iai.utils.ScenePlotter(
rendered_static_map,
location_info_response_replay.map_fov,
(location_info_response_replay.map_center.x, location_info_response_replay.map_center.y),
location_info_response_replay.static_actors
)
scene_plotter_new.initialize_recording(
agent_states=log_reader.agent_states,
agent_properties=agent_properties
)

print("Stepping through simulation...")
for _ in range(SIMULATION_BEGIN_NEW_ROLLOUT):
log_reader.drive()
scene_plotter_new.record_step(log_reader.agent_states,log_reader.traffic_lights_states)

agent_states = log_reader.agent_states
recurrent_states = log_reader.recurrent_states
traffic_lights_states = log_reader.traffic_lights_states
for _ in range(SIMULATION_LENGTH-SIMULATION_BEGIN_NEW_ROLLOUT):
response = iai.drive(
location=log_reader.scenario_log.location,
agent_properties=agent_properties,
agent_states=agent_states,
recurrent_states=recurrent_states,
traffic_lights_states=traffic_lights_states,
random_seed=randint(1,100000)
)

agent_states = response.agent_states
recurrent_states = response.recurrent_states
traffic_lights_states = response.traffic_lights_states

scene_plotter_new.record_step(agent_states,traffic_lights_states)

gif_path_branched = os.path.join(os.getcwd(),f"scenario_log_example_branched.gif")
fig_new, ax_new = plt.subplots(constrained_layout=True, figsize=(50, 50))
plt.axis('off')
scene_plotter_new.animate_scene(
output_name=gif_path_branched,
ax=ax_new,
direction_vec = False,
velocity_vec = False,
plot_frame_number = True
)

1 change: 1 addition & 0 deletions invertedai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from invertedai.utils import Jupyter_Render, IAILogger, Session
from invertedai.large.initialize import get_regions_in_grid, get_number_of_agents_per_region_by_drivable_area, insert_agents_into_nearest_region, get_regions_default, large_initialize
from invertedai.large.drive import large_drive
from invertedai.logs.common import LogWriter, LogReader

dev = strtobool(os.environ.get("IAI_DEV", "false"))
if dev:
Expand Down
26 changes: 8 additions & 18 deletions invertedai/api/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,22 @@ class DriveResponse(BaseModel):
Response returned from an API call to :func:`iai.drive`.
"""

agent_states: List[
AgentState
] #: Predicted states for all agents at the next time step.
recurrent_states: List[
RecurrentState
] #: To pass to :func:`iai.drive` at the subsequent time step.
birdview: Optional[
Image
] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[
List[InfractionIndicators]
] #: If `get_infractions` was set, they are returned here.
is_inside_supported_area: List[
bool
] #: For each agent, indicates whether the predicted state is inside supported area.
agent_states: List[AgentState] #: Predicted states for all agents at the next time step.
recurrent_states: List[RecurrentState] #: To pass to :func:`iai.drive` at the subsequent time step.
birdview: Optional[Image] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[List[InfractionIndicators]] #: If `get_infractions` was set, they are returned here.
is_inside_supported_area: List[bool] #: For each agent, indicates whether the predicted state is inside supported area.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, as seen by the agents before they performed their actions resulting in the returned state. Each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map, each element corresponds to one light group. Pass this to the next call of :func:`iai.drive` for the server to realistically update the traffic light states.
api_model_version: str # Model version used for this API call
api_model_version: str # Model version used for this API call


@validate_call
def drive(
location: str,
agent_states: List[AgentState],
agent_attributes: Optional[List[AgentAttributes]]=None,
agent_properties: Optional[List[AgentProperties]]=None,
agent_attributes: Optional[List[AgentAttributes]] = None,
agent_properties: Optional[List[AgentProperties]] = None,
recurrent_states: Optional[List[RecurrentState]] = None,
traffic_lights_states: Optional[TrafficLightStatesDict] = None,
light_recurrent_states: Optional[LightRecurrentStates] = None,
Expand Down
22 changes: 7 additions & 15 deletions invertedai/api/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,15 @@ class InitializeResponse(BaseModel):
Response returned from an API call to :func:`iai.initialize`.
"""

recurrent_states: List[
Optional[RecurrentState]
] #: To pass to :func:`iai.drive` at the first time step.
agent_states: List[Optional[AgentState]] #: Initial states of all initialized agents.
agent_attributes: List[
Optional[AgentAttributes]
] #: Static attributes of all initialized agents.
agent_states: List[AgentState] #: Initial states of all initialized agents.
recurrent_states: List[Optional[RecurrentState]] #: To pass to :func:`iai.drive` at the first time step.
agent_attributes: List[Optional[AgentAttributes]] #: Static attributes of all initialized agents.
agent_properties: List[AgentProperties] #: Static agent properties of all initialized agents.
birdview: Optional[
Image
] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[
List[InfractionIndicators]
] #: If `get_infractions` was set, they are returned here.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
birdview: Optional[Image] #: If `get_birdview` was set, this contains the resulting image.
infractions: Optional[List[InfractionIndicators]] #: If `get_infractions` was set, they are returned here.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map. Pass this to :func:`iai.drive` at the first time step to let the server generate a realistic continuation of the traffic light state sequence. This does not work correctly if any specific light states were specified as input to `initialize`.
api_model_version: str # Model version used for this API call
api_model_version: str #: Model version used for this API call


@validate_call
Expand Down
4 changes: 3 additions & 1 deletion invertedai/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple
from enum import Enum
from pydantic import BaseModel, model_validator
import math
from PIL import Image as PImage
import numpy as np
import io
import json

import invertedai as iai
from invertedai.error import InvalidInputType, InvalidInput

Expand Down
Loading
Loading