Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove vehicles with invalid positions and invalid goals from controlled vehicles #24

Merged
merged 2 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/env_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ ego_state_feat_max:
curr_steering: 3
curr_head_angle: 0.00001 # Not used at the moment

vis_obs_max: 100 # The maximum value across visible state elements
vis_obs_max: 500 # The maximum value across visible state elements
vis_obs_min: -10 # The minimum value across visible state elements

# # # # Agent settings # # # #
Expand Down
143 changes: 87 additions & 56 deletions nocturne/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,21 @@

_MAX_NUM_TRIES_TO_FIND_VALID_VEHICLE = 1_000

logging.getLogger('__name__')
logging.getLogger("__name__")

ActType = TypeVar("ActType") # pylint: disable=invalid-name
ObsType = TypeVar("ObsType") # pylint: disable=invalid-name
RenderType = TypeVar("RenderType") # pylint: disable=invalid-name


class CollisionType(Enum):
"""Enum for collision types."""

NONE = 0
VEHICLE_VEHICLE = 1
VEHICLE_EDGE = 2


class BaseEnv(Env): # pylint: disable=too-many-instance-attributes
"""Nocturne base Gym environment."""

Expand Down Expand Up @@ -70,8 +73,6 @@ def __init__( # pylint: disable=too-many-arguments
"padding": padding,
}
self.seed(self.config.seed)
self.count_invalid = 0
self.count_total = 0

# Load the list of valid files
with open(self.config.data_path / "valid_files.json", encoding="utf-8") as file:
Expand All @@ -82,7 +83,7 @@ def __init__( # pylint: disable=too-many-arguments
else:
files = list(self.valid_veh_dict.keys())
random.shuffle(files)

# Select files
if self.config.num_files != -1:
self.files = files[: self.config.num_files]
Expand All @@ -91,14 +92,21 @@ def __init__( # pylint: disable=too-many-arguments

# Set observation space
obs_dim = self._get_obs_space_dim(self.config)
self.observation_space = Box(low=-np.inf, high=np.inf, shape=obs_dim,)
self.observation_space = Box(
low=-np.inf,
high=np.inf,
shape=obs_dim,
)

# Set action space
if self.config.discretize_actions:
self._set_discrete_action_space()
else:
self._set_continuous_action_space()

# Count total and invalid samples
self.invalid_samples = 0
self.total_samples = 0

def apply_actions(self, action_dict: Dict[int, ActType]) -> None:
"""Apply a dict of actions to the vehicle objects.
Expand Down Expand Up @@ -150,6 +158,15 @@ def step( # pylint: disable=arguments-renamed,too-many-locals,too-many-branches
veh_id = veh_obj.getID()
if veh_id in self.done_ids:
continue

# Remove vehicle from the scene if position is invalid (but not collided or goal achieved?)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing that this condition, which is post reset, is only getting triggered for vehicles that are expert controlled. The expert controlled vehicles will occasionally pop to the invalid position

if np.isclose(veh_obj.position.x, self.config.scenario.invalid_position):
self.invalid_samples += 1
logging.debug(f"(IN STEP) t = {self.step_num} | {self.file}")
logging.debug(f"veh_id = {veh_obj.id} | pos: {veh_obj.position.x}")
logging.debug(f"controlled_vehs: {[veh.id for veh in self.controlled_vehicles]} \n")

# Get vehicle observation
self.context_dict[veh_id].append(self.get_observation(veh_obj))
if self.config.subscriber.n_frames_stacked > 1:
veh_deque = self.context_dict[veh_id]
Expand Down Expand Up @@ -270,6 +287,8 @@ def step( # pylint: disable=arguments-renamed,too-many-locals,too-many-branches

done_dict["__all__"] = all(done_dict.values())

self.total_samples += len(obs_dict.keys())

return obs_dict, rew_dict, done_dict, info_dict

def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches,too-many-statements
Expand All @@ -293,21 +312,20 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches
# we don't want to initialize scenes with 0 actors after satisfying
# all the conditions on a scene that we have
for _ in range(_MAX_NUM_TRIES_TO_FIND_VALID_VEHICLE):

# Sample new traffic scene
if filename is not None:
# Reset to a specific scene name
self.file = filename
self.file = filename
elif self.config.sample_file_method == "no_replacement":
# Random uniformly without replacement
self.file = self.files.pop()
elif psr_dict is not None:
# Prioritized scene replay: sample according to probabilities
probs = [item['prob'] for item in psr_dict.values()]
probs = [item["prob"] for item in psr_dict.values()]
self.file = np.random.choice(self.files, p=probs)
else: # Random uniformly with replacement (default)
else: # Random uniformly with replacement (default)
self.file = np.random.choice(self.files)

self.simulation = Simulation(str(self.config.data_path / self.file), config=self.config.scenario)
self.scenario = self.simulation.getScenario()

Expand Down Expand Up @@ -345,6 +363,7 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches
for veh_obj in self.simulation.getScenario().getObjectsThatMoved():
obj_pos = _position_as_array(veh_obj.getPosition())
goal_pos = _position_as_array(veh_obj.getGoalPosition())

############################################
# Remove vehicles at goal
############################################
Expand All @@ -363,20 +382,40 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches
temp_vehicles = np.random.permutation(self.scenario.getObjectsThatMoved())
curr_index = 0
self.controlled_vehicles = []
logging.debug(f"(IN RESET) selecting vehicles to control, current list: {self.controlled_vehicles}")

for vehicle in temp_vehicles:
# This vehicle was invalid at the end of the 1 second context
# step so we need to remove it
if np.isclose(vehicle.position.x, self.config.scenario.invalid_position):
logging.debug(f"looking at veh_id {vehicle.id}...")

# Remove vehicles that have invalid positions
veh_at_invalid_pos = np.isclose(
vehicle.position.x,
self.config.scenario.invalid_position,
)

# Exclude vehicles with invalid goal positions
veh_has_invalid_goal_pos = np.isclose(
vehicle.getGoalPosition().x, self.config.scenario.invalid_position
) or np.isclose(vehicle.getGoalPosition().y, self.config.scenario.invalid_position)

if veh_at_invalid_pos or veh_has_invalid_goal_pos:
self.scenario.removeVehicle(vehicle)
# We don't want to include vehicles that had unachievable goals
# as controlled vehicles
elif not vehicle.expert_control and curr_index < self.config.max_num_vehicles:
logging.debug(f"veh_id {vehicle.id} is INVALID!")

# Otherwise the vehicle is valid and we add it to the list of controlled vehicles
if (
not vehicle.expert_control
and not veh_at_invalid_pos
and not veh_has_invalid_goal_pos
and curr_index < self.config.max_num_vehicles
):
self.controlled_vehicles.append(vehicle)
logging.debug(f"updated self.controlled_vehicles: {[veh.id for veh in self.controlled_vehicles]}")
curr_index += 1
else:
vehicle.expert_control = True

self.all_vehicle_ids = [veh.getID() for veh in self.controlled_vehicles]
self.all_vehicle_ids = {veh.getID(): veh for veh in self.controlled_vehicles}

# check that we have at least one vehicle or if we have just one file, exit anyways
# or else we might be stuck in an infinite loop
Expand Down Expand Up @@ -420,6 +459,16 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches

self.done_ids = []

# Sanity check: Check if any vehicle is at an invalid position
for veh_id in obs_dict.keys():
veh_obj = self.all_vehicle_ids[veh_id]
if np.isclose(veh_obj.position.x, self.config.scenario.invalid_position):
logging.debug(f"obs_dict contains invalid vehicle! veh_id: {veh_id} at t = {self.step_num}")
logging.debug(f"obs_max: {obs_dict[veh_id].max()}")
self.invalid_samples += 1

self.total_samples += len(obs_dict.keys())

return obs_dict

def get_observation(self, veh_obj: Vehicle) -> np.ndarray:
Expand All @@ -433,42 +482,31 @@ def get_observation(self, veh_obj: Vehicle) -> np.ndarray:
-------
np.ndarray: Observation for the vehicle.
"""
self.count_total += 1

cur_position = []
if self.config.subscriber.use_current_position:
cur_position = _position_as_array(veh_obj.getPosition())
speed = np.array([veh_obj.getSpeed()])
steer = np.array([veh_obj.steering])
if self.config.normalize_state:
cur_position = cur_position / np.linalg.norm(cur_position)

cur_position = np.concatenate([cur_position, speed, steer])

ego_state = []
if self.config.subscriber.use_ego_state:
ego_state = self.scenario.ego_state(veh_obj)

if self.config.normalize_state:
ego_state = self.normalize_ego_state_by_cat(ego_state)
tmp = self.scenario.ego_state(veh_obj)
if ego_state.max() > 1.5:
self.count_invalid += 1
logging.debug(f'-- veh_id: {veh_obj.id} in scene: {self.file} --')
logging.debug(f'ego_before_norm (speed): {tmp[2]:.2f} (dist_to_goal): {tmp[3]:.2f}')
logging.debug(f'ego_after_norm (speed) : {ego_state[2]:.2f} (dist_to_goal): {ego_state[3]:.2f} \n')


visible_state = []
if self.config.subscriber.use_observations:
visible_state = self.scenario.flattened_visible_state(
veh_obj, self.config.subscriber.view_dist, self.config.subscriber.view_angle
)
veh_obj, self.config.subscriber.view_dist, self.config.subscriber.view_angle
)
if self.config.normalize_state:
visible_state = self.normalize_obs_by_cat(visible_state)

if visible_state.max() > 1.5:
logging.debug(f'visible_after: {visible_state.min():.3f} | {visible_state.max():.3f}')

# Concatenate
obs = np.concatenate((ego_state, visible_state, cur_position))

Expand Down Expand Up @@ -500,17 +538,11 @@ def _get_obs_space_dim(self, config, base=0):
self.tl_dim = self.tl_feat * self.config.scenario.max_visible_traffic_lights
self.ss_dim = self.stop_sign_feat * self.config.scenario.max_visible_stop_signs

obs_space_dim += (
base +
self.ro_dim +
self.rg_dim +
self.tl_dim +
self.ss_dim
)
obs_space_dim += base + self.ro_dim + self.rg_dim + self.tl_dim + self.ss_dim

# Multiply by memory to get the final dimension
obs_space_dim = obs_space_dim * self.config.subscriber.n_frames_stacked

return (obs_space_dim,)

def normalize_ego_state_by_cat(self, state):
Expand All @@ -533,8 +565,10 @@ def render(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint
Optional[RenderType]: Rendered image.
"""
return self.scenario.getImage(**self._render_settings)

env.scenario.getImage(**video_config.render)

env.scenario.getImage(
**self._render_settings,
)

def render_ego(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint: disable=unused-argument
"""Render the ego vehicles.
Expand Down Expand Up @@ -628,23 +662,24 @@ def _set_continuous_action_space(self) -> None:
self.idx_to_actions = None

def unflatten_obs(self, obs_flat):
"Unsqueeeze the flattened object."""
"Unsqueeeze the flattened object." ""

# OBS FLAT ORDER: road_objects, road_points, traffic_lights, stop_signs
# Find the ends of each section
ROAD_OBJECTS_END = 13 * self.config.scenario.max_visible_objects
ROAD_POINTS_END = ROAD_OBJECTS_END + (13 * self.config.scenario.max_visible_road_points)
TL_END = ROAD_POINTS_END + (12 * self.config.scenario.max_visible_traffic_lights)
STOP_SIGN_END = TL_END + (3 * self.config.scenario.max_visible_stop_signs)

# Unflatten
road_objects = obs_flat[:ROAD_OBJECTS_END]
road_points = obs_flat[ROAD_OBJECTS_END:ROAD_POINTS_END]
traffic_lights = obs_flat[ROAD_POINTS_END:TL_END]
stop_signs = obs_flat[TL_END:STOP_SIGN_END]

return road_objects, road_points, traffic_lights, stop_signs


def _angle_sub(current_angle: float, target_angle: float) -> float:
"""Subtract two angles to find the minimum angle between them.

Expand Down Expand Up @@ -701,6 +736,7 @@ def _apply_action_to_vehicle(
veh_obj.acceleration = accel
veh_obj.steering = steer


def _position_as_array(position: Vector2D) -> np.ndarray:
"""Convert a position to an array.

Expand All @@ -716,26 +752,20 @@ def _position_as_array(position: Vector2D) -> np.ndarray:


if __name__ == "__main__":

logging.basicConfig(level=logging.DEBUG)

# Load environment variables and config
env_config = load_config("env_config")

env_config.num_files = 1000

# Initialize an environment
env = BaseEnv(config=env_config)

# Reset
obs_dict = env.reset()

# Get info
agent_ids = [agent_id for agent_id in obs_dict.keys()]
veh_objects = {agent.id: agent for agent in env.controlled_vehicles}
dead_agent_ids = []

num_total = 50_000
num_total = 10_000
for step in range(num_total):
# Sample actions
action_dict = {agent_id: env.action_space.sample() for agent_id in agent_ids if agent_id not in dead_agent_ids}
Expand All @@ -744,7 +774,8 @@ def _position_as_array(position: Vector2D) -> np.ndarray:
for obj in env.controlled_vehicles:
obj.expert_control = True

obs_dict, rew_dict, done_dict, info_dict = env.step(action_dict)
obs_dict, rew_dict, done_dict, info_dict = env.step(action_dict)

# Update dead agents
for agent_id, is_done in done_dict.items():
if is_done and agent_id not in dead_agent_ids:
Expand All @@ -755,7 +786,7 @@ def _position_as_array(position: Vector2D) -> np.ndarray:
obs_dict = env.reset()
dead_agent_ids = []

print(f'{env.count_invalid} / {env.count_total} invalid = {(env.count_invalid/env.count_total)*100}')
logging.info(f"INVALID_SAMPLES: {(env.invalid_samples/env.total_samples)*100:.2f}")

# Close environment
env.close()
Loading