Skip to content

Commit

Permalink
Merge pull request #213 from inverted-ai/hotfix_agent_properties_refa…
Browse files Browse the repository at this point in the history
…ctor

Updated all large area tools to use agent properties by default while…
  • Loading branch information
KieranRatcliffeInvertedAI authored Aug 1, 2024
2 parents a421ee9 + 77e7874 commit df9f451
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 69 deletions.
9 changes: 6 additions & 3 deletions examples/large_map_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import invertedai as iai
from invertedai.large.common import Region
from invertedai.common import AgentAttributes
from invertedai.utils import get_default_agent_properties

import argparse
from tqdm import tqdm
Expand Down Expand Up @@ -51,20 +54,20 @@ def main(args):
)
scene_plotter.initialize_recording(
agent_states=response.agent_states,
agent_attributes=response.agent_attributes,
agent_properties=response.agent_properties,
traffic_light_states=response.traffic_lights_states
)

total_num_agents = len(response.agent_states)
print(f"Number of agents in simulation: {total_num_agents}")

print(f"Begin stepping through simulation.")
agent_attributes = response.agent_attributes
agent_properties = response.agent_properties
for _ in tqdm(range(args.sim_length)):
response = iai.large_drive(
location = args.location,
agent_states = response.agent_states,
agent_attributes = agent_attributes,
agent_properties = agent_properties,
recurrent_states = response.recurrent_states,
light_recurrent_states = response.light_recurrent_states,
random_seed = drive_seed,
Expand Down
2 changes: 1 addition & 1 deletion invertedai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from invertedai.api.blame import blame, async_blame
from invertedai.cosimulation import BasicCosimulation
from invertedai.utils import Jupyter_Render, IAILogger, Session
from invertedai.large.initialize import get_regions_in_grid, get_number_of_agents_per_region_by_drivable_area, get_regions_default, large_initialize
from invertedai.large.initialize import get_regions_in_grid, get_number_of_agents_per_region_by_drivable_area, insert_agents_into_nearest_region, get_regions_default, large_initialize
from invertedai.large.drive import large_drive

dev = strtobool(os.environ.get("IAI_DEV", "false"))
Expand Down
12 changes: 6 additions & 6 deletions invertedai/large/_quadtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import invertedai as iai
from invertedai.large.common import Region
from invertedai.common import Point, AgentState, AgentAttributes, RecurrentState
from invertedai.common import Point, AgentState, AgentProperties, RecurrentState

BUFFER_FOV = 35
QUADTREE_SIZE_BUFFER = 1
Expand All @@ -14,22 +14,22 @@ class QuadTreeAgentInfo(BaseModel):
See Also
--------
AgentState
AgentAttributes
AgentProperties
RecurrentState
"""

agent_state: AgentState
agent_attributes: AgentAttributes
agent_properties: AgentProperties
recurrent_state: RecurrentState
agent_id: int

def tolist(self):
return [self.agent_state, self.agent_attributes, self.recurrent_state, self.agent_id]
return [self.agent_state, self.agent_properties, self.recurrent_state, self.agent_id]

@classmethod
def fromlist(cls, l):
agent_state, agent_attributes, recurrent_state, agent_id = l
return cls(agent_state=agent_state, agent_attributes=agent_attributes, recurrent_state=recurrent_state, agent_id=agent_id)
agent_state, agent_properties, recurrent_state, agent_id = l
return cls(agent_state=agent_state, agent_properties=agent_properties, recurrent_state=recurrent_state, agent_id=agent_id)

class QuadTree:
def __init__(
Expand Down
32 changes: 22 additions & 10 deletions invertedai/large/common.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel

from invertedai.common import AgentAttributes, AgentState, RecurrentState, Point
from invertedai.common import AgentAttributes, AgentProperties, AgentState, RecurrentState, Point
from invertedai.utils import convert_attributes_to_properties

class Region(BaseModel):
"""
A region in a map used to divide a large simulation into smaller parts.
See Also
--------
AgentAttributes
AgentProperties
"""

center: Point #: The center of the region if such a concept is relevant (e.g. center of a square, center of a rectangle)
size: float #: Side length of the region for the default interpretation of a region as a square
agent_states: Optional[List[AgentState]] = [] #: A list of existing agents within the region
agent_attributes: Optional[List[AgentAttributes]] = [] #: The attributes of agents that exist within the region or that will be initialized within the region
agent_properties: Optional[List[AgentProperties]] = [] #: The static parameters of agents that exist within the region or that will be initialized within the region
recurrent_states: Optional[List[RecurrentState]] = [] #: Recurrent states of the agents eixsting within the region

@classmethod
Expand All @@ -24,19 +25,28 @@ def create_square_region(
center: Point,
size: Optional[float] = 100,
agent_states: Optional[List[AgentState]] = [],
agent_attributes: Optional[List[AgentAttributes]] = [],
agent_properties: Optional[List[Union[AgentAttributes,AgentProperties]]] = [],
recurrent_states: Optional[List[RecurrentState]] = []
):
cls.center = center
cls.size = size
for agent in agent_states:
assert cls.is_inside(cls,agent.center), f"Existing agent states at position {agent.center} must be located within the region."


agent_properties_new = []
for properties in agent_properties:
properties_new = properties
if isinstance(properties,AgentAttributes):
properties_new = convert_attributes_to_properties(properties_new)
agent_properties_new.append(properties_new)
agent_properties = agent_properties_new

return cls(
center=center,
size=size,
agent_states=agent_states,
agent_attributes=agent_attributes,
agent_properties=agent_properties,
recurrent_states=recurrent_states
)

Expand All @@ -50,25 +60,27 @@ def copy(
center=region.center,
size=region.size,
agent_states=region.agent_states,
agent_attributes=region.agent_attributes,
agent_properties=region.agent_properties,
recurrent_states=region.recurrent_states
)

def clear_agents(self):

self.agent_states = []
self.agent_attributes = []
self.agent_properties = []
self.recurrent_states = []

def insert_all_agent_details(
self,
agent_state: AgentState,
agent_attributes: AgentAttributes,
agent_properties: Union[AgentAttributes,AgentProperties],
recurrent_state: RecurrentState
):
if isinstance(agent_properties,AgentAttributes):
agent_properties = convert_attributes_to_properties(agent_properties)

self.agent_states.append(agent_state)
self.agent_attributes.append(agent_attributes)
self.agent_properties.append(agent_properties)
self.recurrent_states.append(recurrent_state)

def is_inside(
Expand Down
28 changes: 20 additions & 8 deletions invertedai/large/drive.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Tuple, Optional, List
from typing import Tuple, Optional, List, Union
from pydantic import BaseModel, validate_call
from math import ceil
import asyncio

import invertedai as iai
from invertedai.large.common import Region
from invertedai.common import Point, AgentState, AgentAttributes, RecurrentState, TrafficLightStatesDict, LightRecurrentState
from invertedai.common import Point, AgentState, AgentAttributes, AgentProperties, RecurrentState, TrafficLightStatesDict, LightRecurrentState
from invertedai.api.drive import DriveResponse
from invertedai.utils import convert_attributes_to_properties
from invertedai.error import InvertedAIError, InvalidRequestError
from ._quadtree import QuadTreeAgentInfo, QuadTree, _flatten_and_sort, QUADTREE_SIZE_BUFFER

Expand All @@ -20,7 +21,7 @@ async def async_drive_all(async_input_params):
def large_drive(
location: str,
agent_states: List[AgentState],
agent_attributes: List[AgentAttributes],
agent_properties: List[Union[AgentAttributes,AgentProperties]],
recurrent_states: List[RecurrentState],
traffic_lights_states: Optional[TrafficLightStatesDict] = None,
light_recurrent_states: Optional[List[LightRecurrentState]] = None,
Expand All @@ -45,7 +46,7 @@ def large_drive(
agent_states:
Please refer to the documentation of :func:`drive` for information on this parameter.
agent_attributes:
agent_properties:
Please refer to the documentation of :func:`drive` for information on this parameter.
recurrent_states:
Expand Down Expand Up @@ -86,8 +87,19 @@ def large_drive(
if single_call_agent_limit > DRIVE_MAXIMUM_NUM_AGENTS:
single_call_agent_limit = DRIVE_MAXIMUM_NUM_AGENTS
iai.logger.warning(f"Single Call Agent Limit cannot be more than {DRIVE_MAXIMUM_NUM_AGENTS}, limiting this value to {DRIVE_MAXIMUM_NUM_AGENTS} and proceeding.")
if not (len(agent_states) == len(agent_attributes) == len(recurrent_states)):
if not (len(agent_states) == len(agent_properties) == len(recurrent_states)):
raise InvalidRequestError(message="Input lists are not of equal size.")
if not len(agent_states) > 0:
raise InvalidRequestError(message="Valid call must contain at least 1 agent.")

# Convert any AgentAttributes to AgentProperties for backwards compatibility
agent_properties_new = []
for properties in agent_properties:
properties_new = properties
if isinstance(properties,AgentAttributes):
properties_new = convert_attributes_to_properties(properties)
agent_properties_new.append(properties_new)
agent_properties = agent_properties_new

# Generate quadtree
agent_x = [agent.center.x for agent in agent_states]
Expand All @@ -103,7 +115,7 @@ def large_drive(
size=region_size
),
)
for i, (agent, attrs, recurr_state) in enumerate(zip(agent_states,agent_attributes,recurrent_states)):
for i, (agent, attrs, recurr_state) in enumerate(zip(agent_states,agent_properties,recurrent_states)):
agent_info = QuadTreeAgentInfo.fromlist([agent, attrs, recurr_state, i])
is_inserted = quadtree.insert(agent_info)

Expand All @@ -128,8 +140,8 @@ def large_drive(
agent_id_order.extend(region_agents_ids)
input_params = {
"location":location,
"agent_attributes":region.agent_attributes+region_buffer.agent_attributes,
"agent_states":region.agent_states+region_buffer.agent_states,
"agent_properties":region.agent_properties+region_buffer.agent_properties,
"recurrent_states":region.recurrent_states+region_buffer.recurrent_states,
"light_recurrent_states":light_recurrent_states,
"traffic_lights_states":traffic_lights_states,
Expand Down Expand Up @@ -164,7 +176,7 @@ def large_drive(
response = iai.drive(
location = location,
agent_states = agent_states,
agent_attributes = agent_attributes,
agent_properties = agent_properties,
recurrent_states = recurrent_states,
traffic_lights_states = traffic_lights_states,
light_recurrent_states = light_recurrent_states,
Expand Down
Loading

0 comments on commit df9f451

Please sign in to comment.