Skip to content

Commit

Permalink
Merge pull request #24 from Emerge-Lab/hr_rl_fix_invalid_verhicles
Browse files Browse the repository at this point in the history
Remove vehicles with invalid positions and invalid goals from controlled vehicles
  • Loading branch information
daphne-cornelisse authored Jan 20, 2024
2 parents 4d6d3fd + 111cd0f commit 2c0cb9b
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 57 deletions.
2 changes: 1 addition & 1 deletion configs/env_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ ego_state_feat_max:
curr_steering: 3
curr_head_angle: 0.00001 # Not used at the moment

vis_obs_max: 100 # The maximum value across visible state elements
vis_obs_max: 500 # The maximum value across visible state elements
vis_obs_min: -10 # The minimum value across visible state elements

# # # # Agent settings # # # #
Expand Down
143 changes: 87 additions & 56 deletions nocturne/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,21 @@

_MAX_NUM_TRIES_TO_FIND_VALID_VEHICLE = 1_000

logging.getLogger('__name__')
logging.getLogger("__name__")

ActType = TypeVar("ActType") # pylint: disable=invalid-name
ObsType = TypeVar("ObsType") # pylint: disable=invalid-name
RenderType = TypeVar("RenderType") # pylint: disable=invalid-name


class CollisionType(Enum):
"""Enum for collision types."""

NONE = 0
VEHICLE_VEHICLE = 1
VEHICLE_EDGE = 2


class BaseEnv(Env): # pylint: disable=too-many-instance-attributes
"""Nocturne base Gym environment."""

Expand Down Expand Up @@ -70,8 +73,6 @@ def __init__( # pylint: disable=too-many-arguments
"padding": padding,
}
self.seed(self.config.seed)
self.count_invalid = 0
self.count_total = 0

# Load the list of valid files
with open(self.config.data_path / "valid_files.json", encoding="utf-8") as file:
Expand All @@ -82,7 +83,7 @@ def __init__( # pylint: disable=too-many-arguments
else:
files = list(self.valid_veh_dict.keys())
random.shuffle(files)

# Select files
if self.config.num_files != -1:
self.files = files[: self.config.num_files]
Expand All @@ -91,14 +92,21 @@ def __init__( # pylint: disable=too-many-arguments

# Set observation space
obs_dim = self._get_obs_space_dim(self.config)
self.observation_space = Box(low=-np.inf, high=np.inf, shape=obs_dim,)
self.observation_space = Box(
low=-np.inf,
high=np.inf,
shape=obs_dim,
)

# Set action space
if self.config.discretize_actions:
self._set_discrete_action_space()
else:
self._set_continuous_action_space()

# Count total and invalid samples
self.invalid_samples = 0
self.total_samples = 0

def apply_actions(self, action_dict: Dict[int, ActType]) -> None:
"""Apply a dict of actions to the vehicle objects.
Expand Down Expand Up @@ -150,6 +158,15 @@ def step( # pylint: disable=arguments-renamed,too-many-locals,too-many-branches
veh_id = veh_obj.getID()
if veh_id in self.done_ids:
continue

# Remove vehicle from the scene if position is invalid (but not collided or goal achieved?)
if np.isclose(veh_obj.position.x, self.config.scenario.invalid_position):
self.invalid_samples += 1
logging.debug(f"(IN STEP) t = {self.step_num} | {self.file}")
logging.debug(f"veh_id = {veh_obj.id} | pos: {veh_obj.position.x}")
logging.debug(f"controlled_vehs: {[veh.id for veh in self.controlled_vehicles]} \n")

# Get vehicle observation
self.context_dict[veh_id].append(self.get_observation(veh_obj))
if self.config.subscriber.n_frames_stacked > 1:
veh_deque = self.context_dict[veh_id]
Expand Down Expand Up @@ -270,6 +287,8 @@ def step( # pylint: disable=arguments-renamed,too-many-locals,too-many-branches

done_dict["__all__"] = all(done_dict.values())

self.total_samples += len(obs_dict.keys())

return obs_dict, rew_dict, done_dict, info_dict

def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches,too-many-statements
Expand All @@ -293,21 +312,20 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches
# we don't want to initialize scenes with 0 actors after satisfying
# all the conditions on a scene that we have
for _ in range(_MAX_NUM_TRIES_TO_FIND_VALID_VEHICLE):

# Sample new traffic scene
if filename is not None:
# Reset to a specific scene name
self.file = filename
self.file = filename
elif self.config.sample_file_method == "no_replacement":
# Random uniformly without replacement
self.file = self.files.pop()
elif psr_dict is not None:
# Prioritized scene replay: sample according to probabilities
probs = [item['prob'] for item in psr_dict.values()]
probs = [item["prob"] for item in psr_dict.values()]
self.file = np.random.choice(self.files, p=probs)
else: # Random uniformly with replacement (default)
else: # Random uniformly with replacement (default)
self.file = np.random.choice(self.files)

self.simulation = Simulation(str(self.config.data_path / self.file), config=self.config.scenario)
self.scenario = self.simulation.getScenario()

Expand Down Expand Up @@ -345,6 +363,7 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches
for veh_obj in self.simulation.getScenario().getObjectsThatMoved():
obj_pos = _position_as_array(veh_obj.getPosition())
goal_pos = _position_as_array(veh_obj.getGoalPosition())

############################################
# Remove vehicles at goal
############################################
Expand All @@ -363,20 +382,40 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches
temp_vehicles = np.random.permutation(self.scenario.getObjectsThatMoved())
curr_index = 0
self.controlled_vehicles = []
logging.debug(f"(IN RESET) selecting vehicles to control, current list: {self.controlled_vehicles}")

for vehicle in temp_vehicles:
# This vehicle was invalid at the end of the 1 second context
# step so we need to remove it
if np.isclose(vehicle.position.x, self.config.scenario.invalid_position):
logging.debug(f"looking at veh_id {vehicle.id}...")

# Remove vehicles that have invalid positions
veh_at_invalid_pos = np.isclose(
vehicle.position.x,
self.config.scenario.invalid_position,
)

# Exclude vehicles with invalid goal positions
veh_has_invalid_goal_pos = np.isclose(
vehicle.getGoalPosition().x, self.config.scenario.invalid_position
) or np.isclose(vehicle.getGoalPosition().y, self.config.scenario.invalid_position)

if veh_at_invalid_pos or veh_has_invalid_goal_pos:
self.scenario.removeVehicle(vehicle)
# We don't want to include vehicles that had unachievable goals
# as controlled vehicles
elif not vehicle.expert_control and curr_index < self.config.max_num_vehicles:
logging.debug(f"veh_id {vehicle.id} is INVALID!")

# Otherwise the vehicle is valid and we add it to the list of controlled vehicles
if (
not vehicle.expert_control
and not veh_at_invalid_pos
and not veh_has_invalid_goal_pos
and curr_index < self.config.max_num_vehicles
):
self.controlled_vehicles.append(vehicle)
logging.debug(f"updated self.controlled_vehicles: {[veh.id for veh in self.controlled_vehicles]}")
curr_index += 1
else:
vehicle.expert_control = True

self.all_vehicle_ids = [veh.getID() for veh in self.controlled_vehicles]
self.all_vehicle_ids = {veh.getID(): veh for veh in self.controlled_vehicles}

# check that we have at least one vehicle or if we have just one file, exit anyways
# or else we might be stuck in an infinite loop
Expand Down Expand Up @@ -420,6 +459,16 @@ def reset( # pylint: disable=arguments-differ,too-many-locals,too-many-branches

self.done_ids = []

# Sanity check: Check if any vehicle is at an invalid position
for veh_id in obs_dict.keys():
veh_obj = self.all_vehicle_ids[veh_id]
if np.isclose(veh_obj.position.x, self.config.scenario.invalid_position):
logging.debug(f"obs_dict contains invalid vehicle! veh_id: {veh_id} at t = {self.step_num}")
logging.debug(f"obs_max: {obs_dict[veh_id].max()}")
self.invalid_samples += 1

self.total_samples += len(obs_dict.keys())

return obs_dict

def get_observation(self, veh_obj: Vehicle) -> np.ndarray:
Expand All @@ -433,42 +482,31 @@ def get_observation(self, veh_obj: Vehicle) -> np.ndarray:
-------
np.ndarray: Observation for the vehicle.
"""
self.count_total += 1

cur_position = []
if self.config.subscriber.use_current_position:
cur_position = _position_as_array(veh_obj.getPosition())
speed = np.array([veh_obj.getSpeed()])
steer = np.array([veh_obj.steering])
if self.config.normalize_state:
cur_position = cur_position / np.linalg.norm(cur_position)

cur_position = np.concatenate([cur_position, speed, steer])

ego_state = []
if self.config.subscriber.use_ego_state:
ego_state = self.scenario.ego_state(veh_obj)

if self.config.normalize_state:
ego_state = self.normalize_ego_state_by_cat(ego_state)
tmp = self.scenario.ego_state(veh_obj)
if ego_state.max() > 1.5:
self.count_invalid += 1
logging.debug(f'-- veh_id: {veh_obj.id} in scene: {self.file} --')
logging.debug(f'ego_before_norm (speed): {tmp[2]:.2f} (dist_to_goal): {tmp[3]:.2f}')
logging.debug(f'ego_after_norm (speed) : {ego_state[2]:.2f} (dist_to_goal): {ego_state[3]:.2f} \n')


visible_state = []
if self.config.subscriber.use_observations:
visible_state = self.scenario.flattened_visible_state(
veh_obj, self.config.subscriber.view_dist, self.config.subscriber.view_angle
)
veh_obj, self.config.subscriber.view_dist, self.config.subscriber.view_angle
)
if self.config.normalize_state:
visible_state = self.normalize_obs_by_cat(visible_state)

if visible_state.max() > 1.5:
logging.debug(f'visible_after: {visible_state.min():.3f} | {visible_state.max():.3f}')

# Concatenate
obs = np.concatenate((ego_state, visible_state, cur_position))

Expand Down Expand Up @@ -500,17 +538,11 @@ def _get_obs_space_dim(self, config, base=0):
self.tl_dim = self.tl_feat * self.config.scenario.max_visible_traffic_lights
self.ss_dim = self.stop_sign_feat * self.config.scenario.max_visible_stop_signs

obs_space_dim += (
base +
self.ro_dim +
self.rg_dim +
self.tl_dim +
self.ss_dim
)
obs_space_dim += base + self.ro_dim + self.rg_dim + self.tl_dim + self.ss_dim

# Multiply by memory to get the final dimension
obs_space_dim = obs_space_dim * self.config.subscriber.n_frames_stacked

return (obs_space_dim,)

def normalize_ego_state_by_cat(self, state):
Expand All @@ -533,8 +565,10 @@ def render(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint
Optional[RenderType]: Rendered image.
"""
return self.scenario.getImage(**self._render_settings)

env.scenario.getImage(**video_config.render)

env.scenario.getImage(
**self._render_settings,
)

def render_ego(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint: disable=unused-argument
"""Render the ego vehicles.
Expand Down Expand Up @@ -628,23 +662,24 @@ def _set_continuous_action_space(self) -> None:
self.idx_to_actions = None

def unflatten_obs(self, obs_flat):
"Unsqueeeze the flattened object."""
"Unsqueeeze the flattened object." ""

# OBS FLAT ORDER: road_objects, road_points, traffic_lights, stop_signs
# Find the ends of each section
ROAD_OBJECTS_END = 13 * self.config.scenario.max_visible_objects
ROAD_POINTS_END = ROAD_OBJECTS_END + (13 * self.config.scenario.max_visible_road_points)
TL_END = ROAD_POINTS_END + (12 * self.config.scenario.max_visible_traffic_lights)
STOP_SIGN_END = TL_END + (3 * self.config.scenario.max_visible_stop_signs)

# Unflatten
road_objects = obs_flat[:ROAD_OBJECTS_END]
road_points = obs_flat[ROAD_OBJECTS_END:ROAD_POINTS_END]
traffic_lights = obs_flat[ROAD_POINTS_END:TL_END]
stop_signs = obs_flat[TL_END:STOP_SIGN_END]

return road_objects, road_points, traffic_lights, stop_signs


def _angle_sub(current_angle: float, target_angle: float) -> float:
"""Subtract two angles to find the minimum angle between them.
Expand Down Expand Up @@ -701,6 +736,7 @@ def _apply_action_to_vehicle(
veh_obj.acceleration = accel
veh_obj.steering = steer


def _position_as_array(position: Vector2D) -> np.ndarray:
"""Convert a position to an array.
Expand All @@ -716,26 +752,20 @@ def _position_as_array(position: Vector2D) -> np.ndarray:


if __name__ == "__main__":

logging.basicConfig(level=logging.DEBUG)

# Load environment variables and config
env_config = load_config("env_config")

env_config.num_files = 1000

# Initialize an environment
env = BaseEnv(config=env_config)

# Reset
obs_dict = env.reset()

# Get info
agent_ids = [agent_id for agent_id in obs_dict.keys()]
veh_objects = {agent.id: agent for agent in env.controlled_vehicles}
dead_agent_ids = []

num_total = 50_000
num_total = 10_000
for step in range(num_total):
# Sample actions
action_dict = {agent_id: env.action_space.sample() for agent_id in agent_ids if agent_id not in dead_agent_ids}
Expand All @@ -744,7 +774,8 @@ def _position_as_array(position: Vector2D) -> np.ndarray:
for obj in env.controlled_vehicles:
obj.expert_control = True

obs_dict, rew_dict, done_dict, info_dict = env.step(action_dict)
obs_dict, rew_dict, done_dict, info_dict = env.step(action_dict)

# Update dead agents
for agent_id, is_done in done_dict.items():
if is_done and agent_id not in dead_agent_ids:
Expand All @@ -755,7 +786,7 @@ def _position_as_array(position: Vector2D) -> np.ndarray:
obs_dict = env.reset()
dead_agent_ids = []

print(f'{env.count_invalid} / {env.count_total} invalid = {(env.count_invalid/env.count_total)*100}')
logging.info(f"INVALID_SAMPLES: {(env.invalid_samples/env.total_samples)*100:.2f}")

# Close environment
env.close()

0 comments on commit 2c0cb9b

Please sign in to comment.