From 949ea251f4018f553e79e1a8dec66763865e8f6c Mon Sep 17 00:00:00 2001 From: mttga Date: Mon, 12 Feb 2024 18:09:47 +0100 Subject: [PATCH] remove utracking --- jaxmarl/environments/utracking/__init__.py | 1 - jaxmarl/environments/utracking/animator.py | 293 ---------- .../traj_linear_models-checkpoint.json | 86 --- .../utracking/traj_models/__init__.py | 8 - .../traj_models/traj_linear_models.json | 98 ---- jaxmarl/environments/utracking/utracking.py | 529 ------------------ 6 files changed, 1015 deletions(-) delete mode 100644 jaxmarl/environments/utracking/__init__.py delete mode 100644 jaxmarl/environments/utracking/animator.py delete mode 100644 jaxmarl/environments/utracking/traj_models/.ipynb_checkpoints/traj_linear_models-checkpoint.json delete mode 100644 jaxmarl/environments/utracking/traj_models/__init__.py delete mode 100644 jaxmarl/environments/utracking/traj_models/traj_linear_models.json delete mode 100644 jaxmarl/environments/utracking/utracking.py diff --git a/jaxmarl/environments/utracking/__init__.py b/jaxmarl/environments/utracking/__init__.py deleted file mode 100644 index 28b246e6..00000000 --- a/jaxmarl/environments/utracking/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utracking import UTracking \ No newline at end of file diff --git a/jaxmarl/environments/utracking/animator.py b/jaxmarl/environments/utracking/animator.py deleted file mode 100644 index aceb38d1..00000000 --- a/jaxmarl/environments/utracking/animator.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Temporary visualizer helper that works only with parameters sharing' qlearning agents""" - -from functools import partial -import jax -from jax import numpy as jnp -import numpy as np - -import numpy as np -import logging -logging.getLogger('matplotlib').setLevel(logging.CRITICAL) -from matplotlib import pyplot as plt -from matplotlib import cm -import matplotlib.animation as animation -from matplotlib.ticker import MaxNLocator -import contextlib - -class UTrackingQLearningViz: - - def __init__(self, env, agent, agent_params, hidden_dim=64, max_steps=200): - self.env = env - self.agent = agent - self.agent_params = agent_params - self.hidden_dim = hidden_dim - self.max_steps=max_steps - - @partial(jax.jit, static_argnums=0) - def get_rollout(self, rng): - key, key_r, key_a = jax.random.split(rng, 3) - - init_x = ( - jnp.zeros((1, 1, self.env.obs_size)), # (time_step, batch_size, obs_size) - jnp.zeros((1, 1)) # (time_step, batch size) - ) - init_hstate = jnp.zeros((1, self.hidden_dim)) - _ = self.agent.init(key_a, init_hstate, init_x) - - init_dones = {agent:jnp.zeros(1, dtype=bool) for agent in self.env.agents+['__all__']} - - hstate = jnp.zeros((1*self.env.env.num_agents, self.hidden_dim)) - init_obs, env_state = self.env.batch_reset(key_r) - - def homogeneous_pass(params, hidden_state, obs, dones): - # concatenate agents and parallel envs to process them in one batch - agents, flatten_agents_obs = zip(*obs.items()) - original_shape = flatten_agents_obs[0].shape # assumes obs shape is the same for all agents - batched_input = ( - jnp.concatenate(flatten_agents_obs, axis=1), # (time_step, n_agents*n_envs, obs_size) - jnp.concatenate([dones[agent] for agent in agents], axis=1), # ensure to not pass other keys (like __all__) - ) - hidden_state, q_vals = self.agent.apply(params, hidden_state, batched_input) - q_vals = q_vals.reshape(original_shape[0], len(agents), *original_shape[1:-1], -1) # (time_steps, n_agents, n_envs, action_dim) - q_vals = {a:q_vals[:,i] for i,a in enumerate(agents)} - return hidden_state, q_vals 
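# --- Illustrative sketch (not authoritative): the shape bookkeeping behind homogeneous_pass above,
# assuming 2 agents, 3 parallel envs, obs_size=4 and 5 actions. A single linear layer stands in for
# the shared recurrent agent network; only the stack -> shared forward pass -> split logic is meant
# to mirror the function above.
import jax.numpy as jnp

n_agents, n_envs, obs_size, n_actions = 2, 3, 4, 5
obs = {f"agent_{i}": jnp.zeros((1, n_envs, obs_size)) for i in range(n_agents)}  # (time, n_envs, obs_size) per agent

agents, flat_obs = zip(*obs.items())
batched = jnp.concatenate(flat_obs, axis=1)        # (1, n_agents*n_envs, obs_size): one batch for the shared params
w = jnp.ones((obs_size, n_actions))                # stand-in for the shared agent network apply
q = batched @ w                                    # (1, n_agents*n_envs, n_actions)
q = q.reshape(1, n_agents, n_envs, n_actions)      # recover the agent axis (same order as the concatenation)
q_vals = {a: q[:, i] for i, a in enumerate(agents)}  # (1, n_envs, n_actions) per agent, as homogeneous_pass returns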
- - def _env_step(step_state, unused): - params, env_state, last_obs, last_dones, hstate, rng = step_state - - rng, key_a, key_s = jax.random.split(rng, 3) - - obs_ = {a:last_obs[a] for a in self.env.agents} - obs_ = jax.tree_map(lambda x: x[np.newaxis, :], obs_) - dones_ = jax.tree_map(lambda x: x[np.newaxis, :], last_dones) - - hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_) - valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, self.env.valid_actions) - # greedy actions - actions = jax.tree_util.tree_map(lambda q: jnp.argmax(q, axis=-1), valid_q_vals) - - # step - obs, env_state, rewards, dones, info = self.env.batch_step(key_s, env_state, actions) - info['pos'] = env_state.pos - info['done'] = dones['__all__'] - - step_state = (params, env_state, obs, dones, hstate, rng) - return step_state, info - - - step_state = ( - self.agent_params, - env_state, - init_obs, - init_dones, - hstate, - key, - ) - - step_state, infos = jax.lax.scan( - _env_step, step_state, None, self.max_steps - ) - - return infos - - def get_animation(self, rng, save_path='./tmp_animation'): - - infos = self.get_rollout(rng) - - #preprocess - x = jax.tree_map(lambda x: x[:,0], infos) - dones = x['done'] - first_done = jax.lax.select((jnp.argmax(dones)==0)&(dones[0]!=True), dones.size, jnp.argmax(dones)) - first_episode_mask = jnp.where(jnp.arange(dones.size) <= first_done, True, False) - x = jax.tree_map(lambda x: x[first_episode_mask], x) - - viz = UTrackingAnimator( - agent_positions = jnp.swapaxes(x['pos'], 0, 1)[:self.env.env.num_agents, :, :2], - landmark_positions = jnp.swapaxes(x['pos'], 0, 1)[self.env.env.num_agents:, :, :2], - landmark_predictions = jnp.swapaxes(x['tracking_pred'], 0, 1), - episode_rewards = x['rew'], - episode_errors = jnp.swapaxes(x['tracking_error'], 0, 1), - ) - viz.save_animation(save_path) - - class UTrackingAnimator(animation.TimedAnimation): - - def __init__(self, - agent_positions, - landmark_positions, - landmark_predictions, - episode_rewards, - episode_errors, - lags=None): - - - # general parameters - self.frames = (agent_positions.shape[1]) - self.n_agents = len(agent_positions) - self.n_landmarks = len(landmark_positions) - - self.agent_positions = agent_positions - self.landmark_positions = landmark_positions - self.landmark_predictions = landmark_predictions - self.episode_rewards = episode_rewards - self.episode_errors = episode_errors - if lags is None: - self.lags = self.frames - else: - self.lags = lags # number of trailing frames to draw for each trajectory - - # create the subplots - self.fig = plt.figure(figsize=(20, 10), dpi=120) - self.ax_episode = self.fig.add_subplot(1, 2, 1) - self.ax_reward = self.fig.add_subplot(2, 2, 2) - self.ax_error = self.fig.add_subplot(2, 2, 4) - - self.ax_episode.set_title('Episode') - self.ax_reward.set_title('Reward') - self.ax_error.set_title('Prediction Error') - - # colors - self.agent_colors = cm.Dark2.colors - self.landmark_colors = [cm.summer(l*10) for l in range(self.n_landmarks)] # pastel greens - self.prediction_colors = [cm.PiYG(l*10) for l in range(self.n_landmarks)] # pinks - - # init the lines - self.lines_episode = self._init_episode_animation(self.ax_episode) - self.lines_reward = self._init_reward_animation(self.ax_reward) - self.lines_error = self._init_error_animation(self.ax_error) - - - animation.TimedAnimation.__init__(self, self.fig, interval=100, blit=True) - - def save_animation(self, savepath='episode'): - with contextlib.redirect_stdout(None): - self.save(savepath+'.gif') -
self.fig.savefig(savepath+'.png') - - - def _episode_update(self, data, line, frame, lags, name=None): - line.set_data(data[max(0,frame-lags):frame, 0], data[max(0,frame-lags):frame, 1]) - if name is not None: - line.set_label(name) - - def _frameline_update(self, data, line, frame, name=None): - line.set_data(np.arange(1,frame+1), data[:frame]) - if name is not None: - line.set_label(name) - - def _draw_frame(self, frame): - - # Update the episode subplot - line_episode = 0 - # update agents heads - for n in range(self.n_agents): - self._episode_update(self.agent_positions[n], self.lines_episode[line_episode], frame, 1, f'Agent_{n+1}') - line_episode += 1 - - # update agents trajectories - for n in range(self.n_agents): - self._episode_update(self.agent_positions[n], self.lines_episode[line_episode], max(0,frame-1), self.lags) - line_episode += 1 - - # landmark real positions - for n in range(self.n_landmarks): - self._episode_update(self.landmark_positions[n], self.lines_episode[line_episode], frame, self.lags, f'Landmark_{n+1}_real') - line_episode += 1 - - # landmark predictions - for n in range(self.n_landmarks): - self._episode_update(self.landmark_predictions[n], self.lines_episode[line_episode], frame, self.lags, f'Landmark_{n+1}_predictions') - line_episode += 1 - - self.ax_episode.legend() - - # Update the reward subplot - self._frameline_update(self.episode_rewards, self.lines_reward[0], frame) - - # Update the error subplot - for n in range(self.n_landmarks): - self._frameline_update(self.episode_errors[n], self.lines_error[n], frame, f'Landmark_{n+1}_error') - self.ax_error.legend() - - self._drawn_artists = self.lines_episode + self.lines_reward + self.lines_error - - - def _init_episode_animation(self, ax): - # retrieve the episode dimensions - x_max = max(self.agent_positions[:,:,0].max(), - self.landmark_positions[:,:,0].max()) - - x_min = min(self.agent_positions[:,:,0].min(), - self.landmark_positions[:,:,0].min()) - - y_max = max(self.agent_positions[:,:,1].max(), - self.landmark_positions[:,:,1].max()) - - y_min = min(self.agent_positions[:,:,1].min(), - self.landmark_positions[:,:,1].min()) - - abs_min = min(x_min, y_min) - abs_max = max(x_max, y_max) - - ax.set_xlim(abs_min-1, abs_max+1) - ax.set_ylim(abs_min-1,abs_max+1) - ax.set_ylabel('Y Position') - - # remove frame - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.spines['bottom'].set_visible(False) - ax.spines['left'].set_visible(False) - - # lines: - # 1. agent head - # 2. agent trajectory - # 3. landmark real - # 4. 
landmark prediction - lines = [ax.plot([],[],'o',color=self.agent_colors[a], alpha=0.8,markersize=8)[0] for a in range(self.n_agents)] + \ - [ax.plot([],[],'o',color=self.agent_colors[a], alpha=0.2,markersize=4)[0] for a in range(self.n_agents)] + \ - [ax.plot([],[],'s',color=self.landmark_colors[l], alpha=0.8,markersize=8)[0] for l in range(self.n_landmarks)] + \ - [ax.plot([],[],'s',color=self.prediction_colors[l], alpha=0.2,markersize=4)[0] for l in range(self.n_landmarks)] - - return lines - - def _init_reward_animation(self, ax): - ax.set_xlim(0, self.frames) - ax.set_ylim(self.episode_rewards.min(), self.episode_rewards.max()+1) - ax.set_xlabel('Timestep') - ax.set_ylabel('Reward') - ax.yaxis.set_major_locator(MaxNLocator(integer=True)) # force integer ticks - - # remove frame - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.spines['bottom'].set_visible(False) - ax.spines['left'].set_visible(False) - - lines = [ax.plot([],[], color='green')[0]] - return lines - - def _init_error_animation(self, ax): - ax.set_xlim(0, self.frames) - ax.set_ylim(self.episode_errors.min(), self.episode_errors.max()) - ax.set_xlabel('Timestep') - ax.set_ylabel('Prediction error') - - # remove frame - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.spines['bottom'].set_visible(False) - ax.spines['left'].set_visible(False) - - lines = [ax.plot([],[], color=self.prediction_colors[l])[0] for l in range(self.n_landmarks)] - return lines - - - def new_frame_seq(self): - return iter(range(self.frames)) - - def _init_draw(self): - lines = self.lines_episode + self.lines_reward + self.lines_error - for l in lines: - l.set_data([], []) \ No newline at end of file diff --git a/jaxmarl/environments/utracking/traj_models/.ipynb_checkpoints/traj_linear_models-checkpoint.json b/jaxmarl/environments/utracking/traj_models/.ipynb_checkpoints/traj_linear_models-checkpoint.json deleted file mode 100644 index d2549bd1..00000000 --- a/jaxmarl/environments/utracking/traj_models/.ipynb_checkpoints/traj_linear_models-checkpoint.json +++ /dev/null @@ -1,86 +0,0 @@ -{ - "vel": { - "coeff": 0.03290523576558816, - "intercept": -0.0008086262823894241 - }, - "angle": { - "dt_20": { - "prop_5": { - "coeff": -0.298714126374131, - "intercept": 0.000646663441020369 - }, - "prop_10": { - "coeff": -0.9892416568761737, - "intercept": 0.0029585906834277316 - }, - "prop_15": { - "coeff": -1.6992941668226123, - "intercept": -0.0005613088575260692 - }, - "prop_20": { - "coeff": -2.3877033380196195, - "intercept": 0.0004052716861677151 - }, - "prop_25": { - "coeff": -3.073353205170192, - "intercept": -0.001302143442857626 - }, - "prop_30": { - "coeff": -3.756881273475711, - "intercept": 0.0012445423833910441 - } - }, - "dt_30": { - "prop_5": { - "coeff": -0.6830755395872998, - "intercept": 0.0007760945402775334 - }, - "prop_10": { - "coeff": -1.7820564760840034, - "intercept": 0.0006003293552501779 - }, - "prop_15": { - "coeff": -2.8661005376653805, - "intercept": -0.0005866079121583283 - }, - "prop_20": { - "coeff": -3.9346842431351043, - "intercept": 0.0009530802429342303 - }, - "prop_25": { - "coeff": -4.993159246327514, - "intercept": -0.002065341676689019 - }, - "prop_30": { - "coeff": -6.052076684643244, - "intercept": -0.0005170343450260071 - } - }, - "dt_60": { - "prop_5": { - "coeff": -1.855053477567057, - "intercept": 0.00038509020371100926 - }, - "prop_10": { - "coeff": -4.086121985489588, - "intercept": 0.0009101616401817544 - }, - "prop_15": { - "coeff": 
-6.302017913874597, - "intercept": 0.0009829407292114394 - }, - "prop_20": { - "coeff": -8.50055812087025, - "intercept": -0.001227075490035038 - }, - "prop_25": { - "coeff": -10.682973863766284, - "intercept": 9.047023092948261e-05 - }, - "prop_30": { - "coeff": -12.860367695837134, - "intercept": -0.0016042385386930771 - } - } - } -} \ No newline at end of file diff --git a/jaxmarl/environments/utracking/traj_models/__init__.py b/jaxmarl/environments/utracking/traj_models/__init__.py deleted file mode 100644 index e00d1ecc..00000000 --- a/jaxmarl/environments/utracking/traj_models/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -import json -import os - -module_dir = os.path.dirname(__file__) -json_file_path = os.path.join(module_dir, 'traj_linear_models.json') - -with open(json_file_path,'r') as f: - traj_models = json.load(f) \ No newline at end of file diff --git a/jaxmarl/environments/utracking/traj_models/traj_linear_models.json b/jaxmarl/environments/utracking/traj_models/traj_linear_models.json deleted file mode 100644 index c1ee34dc..00000000 --- a/jaxmarl/environments/utracking/traj_models/traj_linear_models.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "vel": { - "coeff": 0.03290523576558816, - "intercept": -0.0008086262823894241 - }, - "angle": { - "dt_20": { - "prop_0": { - "coeff": 0.0, - "intercept": 0.0 - }, - "prop_5": { - "coeff": -0.298714126374131, - "intercept": 0.000646663441020369 - }, - "prop_10": { - "coeff": -0.9892416568761737, - "intercept": 0.0029585906834277316 - }, - "prop_15": { - "coeff": -1.6992941668226123, - "intercept": -0.0005613088575260692 - }, - "prop_20": { - "coeff": -2.3877033380196195, - "intercept": 0.0004052716861677151 - }, - "prop_25": { - "coeff": -3.073353205170192, - "intercept": -0.001302143442857626 - }, - "prop_30": { - "coeff": -3.756881273475711, - "intercept": 0.0012445423833910441 - } - }, - "dt_30": { - "prop_0": { - "coeff": 0.0, - "intercept": 0.0 - }, - "prop_5": { - "coeff": -0.6830755395872998, - "intercept": 0.0007760945402775334 - }, - "prop_10": { - "coeff": -1.7820564760840034, - "intercept": 0.0006003293552501779 - }, - "prop_15": { - "coeff": -2.8661005376653805, - "intercept": -0.0005866079121583283 - }, - "prop_20": { - "coeff": -3.9346842431351043, - "intercept": 0.0009530802429342303 - }, - "prop_25": { - "coeff": -4.993159246327514, - "intercept": -0.002065341676689019 - }, - "prop_30": { - "coeff": -6.052076684643244, - "intercept": -0.0005170343450260071 - } - }, - "dt_60": { - "prop_0": { - "coeff": 0.0, - "intercept": 0.0 - }, - "prop_5": { - "coeff": -1.855053477567057, - "intercept": 0.00038509020371100926 - }, - "prop_10": { - "coeff": -4.086121985489588, - "intercept": 0.0009101616401817544 - }, - "prop_15": { - "coeff": -6.302017913874597, - "intercept": 0.0009829407292114394 - }, - "prop_20": { - "coeff": -8.50055812087025, - "intercept": -0.001227075490035038 - }, - "prop_25": { - "coeff": -10.682973863766284, - "intercept": 9.047023092948261e-05 - }, - "prop_30": { - "coeff": -12.860367695837134, - "intercept": -0.0016042385386930771 - } - } - } -} \ No newline at end of file diff --git a/jaxmarl/environments/utracking/utracking.py b/jaxmarl/environments/utracking/utracking.py deleted file mode 100644 index 25b194a8..00000000 --- a/jaxmarl/environments/utracking/utracking.py +++ /dev/null @@ -1,529 +0,0 @@ -import jax -from jax import numpy as jnp -import chex -from flax import struct -import numpy as np - -from functools import partial -from typing import Tuple - -from .traj_models import traj_models -from 
jaxmarl.environments.spaces import Box, Discrete -from jaxmarl.environments.multi_agent_env import MultiAgentEnv - -@jax.jit -def fill_diagonal_zeros(arr): - # at the moment I haven't found a better way to fill the diagonal with 0s of unsquarred matrices - return arr - arr * (jnp.eye(arr.shape[0], arr.shape[1])) - -@jax.jit -def batched_least_squares(pos_x, pos_y, pos_xy, z): - """ - Predicts in a single batch the position of multiple landmarks in respect to multiple observers and observations - """ - N = jnp.identity(3)[:-1] - A = jnp.full((*z.shape, 3), -1, dtype=float) - A = A.at[..., 0].set(pos_x * 2) - A = A.at[..., 1].set(pos_y * 2) - - weights = jnp.where(z != 0, 1, 0)[..., None] # Set the weights of missing values to 0. - - b = (jnp.einsum('...ij,...ij->...j', pos_xy, pos_xy) - (z*z))[..., None] - A_aux = jnp.linalg.inv(jnp.einsum('...ij,...ik->...jk', A*weights, A*weights+1e-6)) - A_aux = jnp.einsum('ij,...kj->...ik', N, A_aux) - A_aux = jnp.einsum('...ij,...kj->...ik', A_aux, A*weights) - pred = jnp.einsum('...ij,...jk->...i', A_aux, b*weights) - return pred - - -@struct.dataclass -class State: - pos: chex.Array # [x,y,z,angle]*num_entities, physical state of entities - vel: chex.Array # [float]*num_entities, velocity of entities - traj_coeffs: chex.Array # [float]*num_entities, coefficient of linear trajectory models - traj_intercepts: chex.Array # [float]*num_entities, intercept of linear trajectory models - land_pred_pos: chex.Array # [num_agents, num_landmarks, xyz], current tracking state of each agent for each landmark - range_buffer: chex.Array # [num_agents, num_landmarks, (observer_xy, observed_range), len(buffer)], tracking buffer for each agent-landmark pair - steps_next_land_action: chex.Array # [int]*num_landmarks, step until when the landmarks are gonna change directions - range_buffer_head: int # head iterator of the tracking buffer - t: int # step - -class UTracking(MultiAgentEnv): - - traj_models = traj_models - discrete_actions_mapping = jnp.array([-0.24, -0.12, 0, 0.12, 0.24]) - - def __init__( - self, - num_agents:int, - num_landmarks:int, - dt:int=30, - max_steps:int=400, - render:bool=False, - discrete_actions:bool=True, - agent_depth:Tuple[float, float]=(0., 0.), # defines the range of depth for spawning agents - landmark_depth:Tuple[float, float]=(5., 20.), # defines the range of depth for spawinng landmarks - min_valid_distance:float=5., # under this distance it's considered a crash - min_init_distance:float=100., # minimum initial distance between vehicles - max_init_distance:float=300., # maximum initial distance between vehicles - max_range_dist:float=800., # above this distance a landmark is lost - prop_agent:int=30, # rpm of agent's propulsors, defines the speeds for agents (30rpm is ~1m/s) - prop_range_landmark:Tuple[int]=(0, 5, 10, 15, 20), # defines the possible (propulsor) speeds for landmarks - rudder_range_landmark:Tuple[float, float]=(0.05, 0.15), # defines the angle of movement change for landmarks - dirchange_time_range_landmark:Tuple[int, int]=(5, 15), # defines how many random steps to wait for changing the landmark directions - tracking_buffer_len:int=32, # maximum number of range observations keepen for predicting the landmark positions - range_noise_std:float=10., # standard deviation of the gaussian noise added to range measurments - lost_comm_prob=0.1, # probability of loosing communications - min_steps_ls:int=2, # minimum steps for collecting data and start predicting landmarks positions with least squares - rew_pred_thr:float=10., # 
tracking error threshold for trackig reward - cont_rew:bool=True, # if false, reward becomes sparse(r) (only based on thresholds), otherwise proportional to tracking error and landmark distance - continuous_actions:bool=False, # if false, discrete actions are defined by the discrete_actions_mapping array - pre_init_pos:bool=True, # computing the initial positions can be expensive if done on the go; to reduce the reset (and therefore step) time, precompute a bunch of possible options - rng_init_pos:chex.PRNGKey=jax.random.PRNGKey(0), # random seed for precomputing initial distance - pre_init_pos_len:int=100000, # how many initial positions preocompute - debug_obs:bool=False, - ): - assert f'dt_{dt}' in traj_models['angle'].keys(), f"dt must be in {traj_models['angle'].keys()}" - self.dt = dt - self.traj_model = traj_models['angle'][f'dt_{dt}'] - assert f'prop_{prop_agent}' in self.traj_model.keys(), \ - f"the propulsor velocity for agents must be in {self.traj_model.keys()}" - assert all(f'prop_{prop}' in self.traj_model.keys() for prop in prop_range_landmark), \ - f"the propulsor choices for landmarks must be in {self.traj_model.keys()}" - - self.max_steps = max_steps - self.num_agents = num_agents - self.num_landmarks = num_landmarks - self.num_entities = num_agents + num_landmarks - self.agents = [f'agent_{i}' for i in range(1, num_agents+1)] - self.landmarks = [f'landmark_{i}' for i in range(1, num_landmarks+1)] - self.entities = self.agents + self.landmarks - - self.discrete_actions = discrete_actions - self.agent_depth = agent_depth - self.landmark_depth = landmark_depth - self.min_valid_distance = min_valid_distance - self.min_init_distance = min_init_distance - self.max_init_distance = max_init_distance - self.max_range_dist = max_range_dist - self.prop_agent = prop_agent - self.prop_range_landmark = prop_range_landmark - self.rudder_range_landmark = np.array(rudder_range_landmark) - self.dirchange_time_range_landmark = dirchange_time_range_landmark - self.tracking_buffer_len = tracking_buffer_len - self.range_noise_std = range_noise_std - self.lost_comm_prob = lost_comm_prob - self.min_steps_ls = min_steps_ls - self.rew_pred_thr = rew_pred_thr - self.cont_rew = cont_rew - self.continuous_actions = continuous_actions - self.pre_init_pos = pre_init_pos - self.pre_init_pos_len = pre_init_pos_len - self.debug_obs = debug_obs - - # action and obs spaces - if self.continuous_actions: - self.action_spaces = {i: Box(-0.24, 0.24, (1,)) for i in self.agents} - else: - self.action_spaces = {i: Discrete(len(self.discrete_actions_mapping)) for i in self.agents} - self.observation_spaces = {i: Box(-jnp.inf, jnp.inf, (6*self.num_entities,)) for i in self.agents} - - # preprocess the traj models - self.traj_model_prop = jnp.array([int(k.split('_')[1]) for k in self.traj_model]) - self.traj_model_coeffs = jnp.array([v['coeff'] for v in self.traj_model.values()]) - self.traj_model_intercepts = jnp.array([v['intercept'] for v in self.traj_model.values()]) - - # trajectory model for agents - self.vel_model = lambda prop: jnp.where(prop==0, 0, prop*traj_models['vel']['coeff']+traj_models['vel']['intercept']) - self.traj_coeffs_agent = jnp.repeat(self.traj_model[f'prop_{self.prop_agent}']['coeff'], self.num_agents) - self.traj_intercepts_agent = jnp.repeat(self.traj_model[f'prop_{self.prop_agent}']['intercept'], self.num_agents) - self.vel_agent = jnp.repeat(self.vel_model(self.prop_agent), self.num_agents) - self.min_agent_dist = self.vel_agent[0]*self.dt # safe distance that agents should keep between 
them - - # index of the trajectory models valid for landmarks - self.idx_valid_traj_model_landmarks = jnp.array([ - i for i,p in enumerate(self.traj_model_prop) - if p in self.prop_range_landmark - ]) - - # precompute a batch of initial positions if required - if self.pre_init_pos: - rngs = jax.random.split(rng_init_pos, self.pre_init_pos_len) - self.pre_init_xy = jax.jit(jax.vmap(self.get_init_pos, in_axes=(0, None, None)))(rngs, self.min_init_distance, self.max_init_distance) - self.pre_init_choice = jnp.arange(self.pre_init_pos_len) - - @partial(jax.jit, static_argnums=0) - def reset(self, rng): - - # velocity and trajectory models - rng, _rng = jax.random.split(rng) - idx_traj_model_landmarks = jax.random.choice(_rng, self.idx_valid_traj_model_landmarks, shape=(self.num_landmarks,)) - traj_coeffs = jnp.concatenate(( - self.traj_coeffs_agent, # traj model for agents is costant and precomputed - self.traj_model_coeffs[idx_traj_model_landmarks] # sample correct coeff for each landmark - )) - traj_intercepts = jnp.concatenate(( - self.traj_intercepts_agent, - self.traj_model_intercepts[idx_traj_model_landmarks] # sample intercept coeff for each landmark - )) - vel = jnp.concatenate(( - self.vel_agent, # vel of agents is costant and precomputed - self.vel_model(self.traj_model_prop[idx_traj_model_landmarks]) - )) - - # init positions - rng, key_pos, key_agent_depth, key_land_depth, key_dir, key_dir_change = jax.random.split(rng, 6) - if self.pre_init_pos: - xy_pos = self.pre_init_xy[jax.random.choice(key_pos, self.pre_init_choice)] - else: - xy_pos = self.get_init_pos(key_pos, self.min_init_distance, self.max_init_distance) - z = jnp.concatenate(( - jax.random.uniform(key_agent_depth, shape=(self.num_agents,), minval=self.agent_depth[0], maxval=self.agent_depth[1]), - jax.random.uniform(key_land_depth, shape=(self.num_landmarks,), minval=self.landmark_depth[0], maxval=self.landmark_depth[1]), - )) - dir = jax.random.uniform(key_dir, shape=(self.num_entities,), minval=0, maxval=2*np.pi) - pos = jnp.concatenate((xy_pos, z[:, np.newaxis], dir[:, np.newaxis]),axis=1) - steps_next_land_action = jax.random.randint(key_dir_change, (self.num_landmarks,), *self.dirchange_time_range_landmark) - - # init tracking buffer variables - land_pred_pos = jnp.zeros((self.num_agents, self.num_landmarks, 3)) - range_buffer = jnp.zeros((self.num_agents, self.num_landmarks, 3, self.tracking_buffer_len)) # num_agents, num_landmarks, xy range, len(buffer) - range_buffer_head = 0 - t = 0 - - # first communication - rng, key_ranges, key_comm = jax.random.split(rng, 3) - delta_xyz, ranges_real_2d, ranges_real, ranges = self.get_ranges(key_ranges, pos) - range_buffer, range_buffer_head, comm_drop = self.communicate( - key_comm, - ranges, - pos, - range_buffer, - range_buffer_head - ) - land_pred_pos = self.update_predictions(t, range_buffer, pos, ranges) - - # first observation - obs = self.get_obs(delta_xyz, ranges, comm_drop, pos, land_pred_pos) - if self.debug_obs: - obs = {a:1e-3*jnp.concatenate((pos[i], delta_xyz[i].ravel(), ranges_real_2d[i].ravel())) for i, a in enumerate(self.agents)} - obs['__all__'] = self.get_global_state(pos, vel) - - # env state - state = State( - pos=pos, - vel=vel, - traj_coeffs=traj_coeffs, - traj_intercepts=traj_intercepts, - land_pred_pos=land_pred_pos, - range_buffer=range_buffer, - steps_next_land_action=steps_next_land_action, - range_buffer_head=range_buffer_head, - t=t - ) - return obs, state - - @partial(jax.jit, static_argnums=0) - def world_step(self, 
rudder_actions:chex.Array, pos:chex.Array, vel:chex.Array, traj_coeffs:chex.Array, traj_intercepts:chex.Array): - # update the angle - angle_change = rudder_actions*traj_coeffs+traj_intercepts - # update the x-y position (depth remains constant) - pos = pos.at[:, -1].add(angle_change) - pos = pos.at[:, 0].add(jnp.cos(pos[:, -1])*vel*self.dt) - pos = pos.at[:, 1].add(jnp.sin(pos[:, -1])*vel*self.dt) - return pos - - @partial(jax.jit, static_argnums=0) - def step_env(self, rng:chex.PRNGKey, state:State, actions:dict): - - # preprocess actions - agent_actions = jnp.array([actions[a] for a in self.agents]) - agent_actions = self.preprocess_actions(agent_actions) - landmark_actions, steps_next_land_action = self.get_landmarks_actions(rng, state.steps_next_land_action, state.t) - - # update physical positions - pos = self.world_step( - jnp.concatenate((agent_actions, landmark_actions)), - state.pos, - state.vel, - state.traj_coeffs, - state.traj_intercepts, - ) - - # update tracking - rng, key_ranges, key_comm = jax.random.split(rng, 3) - delta_xyz, ranges_real_2d, ranges_real, ranges = self.get_ranges(key_ranges, pos) - range_buffer, range_buffer_head, comm_drop = self.communicate( - key_comm, - ranges, - pos, - state.range_buffer, - state.range_buffer_head - ) - land_pred_pos = self.update_predictions(state.t, range_buffer, pos, ranges) - - # get global reward, done, info - reward, done, info = self.get_rew_done_info(state.t, pos, ranges, ranges_real_2d, land_pred_pos) - reward = {agent:reward for agent in self.agents} - done = {agent:done for agent in self.agents+['__all__']} - - # agents obs and global state - obs = self.get_obs(delta_xyz, ranges, comm_drop, pos, land_pred_pos) - if self.debug_obs: - obs = {a:1e-4*jnp.concatenate((pos[i], delta_xyz[i].ravel(), ranges_real_2d[i].ravel())) for i, a in enumerate(self.agents)} - obs['__all__'] = self.get_global_state(pos, state.vel) - - state = state.replace( - pos=pos, - land_pred_pos=land_pred_pos, - steps_next_land_action=steps_next_land_action, - range_buffer=range_buffer, - range_buffer_head=range_buffer_head, - t=state.t+1 - ) - return obs, state, reward, done, info - - @partial(jax.jit, static_argnums=0) - def get_obs(self, delta_xyz, ranges, comm_drop, pos, land_pred_pos): - # first a matrix with all the observations is created, composed by - # the position of the agent or the relative position of other agents (comunication) and landmarks (tracking) - # the absolute distance (ranges) is_agent, is_self features - # [pos_x, pos_y, pos_z, dist, is_agent, is_self]*n_entities - other_agents_dist = jnp.where(comm_drop[:, :, None], 0, delta_xyz[:, :self.num_agents]) # 0 for communication drop - self_mask = jnp.arange(self.num_agents) == np.arange(self.num_agents)[:, np.newaxis] - agents_rel_pos = jnp.where(self_mask[:, :, None], pos[:self.num_agents, [0,1,3]], other_agents_dist) # for self use pos_x, pos_y, angle - lands_rel_pos = land_pred_pos - pos[:self.num_agents, None, :3] # relative distance from predicted positions - pos_feats = jnp.concatenate((agents_rel_pos, lands_rel_pos), axis=1) - is_agent_feat = jnp.tile(jnp.concatenate((jnp.ones(self.num_agents), jnp.zeros(self.num_landmarks))), (self.num_agents, 1)) - is_self_feat = (jnp.arange(self.num_entities) == jnp.arange(self.num_agents)[:, np.newaxis]) - # the distance based feats are rescaled to hudreds of meters (better for NNs) - feats = jnp.concatenate((pos_feats*1e-4 , ranges[:, :, None]*1e-4, is_agent_feat[:, :, None], is_self_feat[:, :, None]), axis=2) - - # than it is assigned to 
each agent its obs - return { - a:feats[i].ravel() - for i, a in enumerate(self.agents) - } - - - @partial(jax.jit, static_argnums=0) - def get_global_state(self, pos, vel): - # state is obs, vel, is_agent for each entity - #pos = pos.at[:, :3].multiply(1e-3) # scale to hundreds of meters - is_agent = jnp.concatenate((jnp.ones(self.num_agents), jnp.zeros(self.num_landmarks))) - return jnp.concatenate(( - pos*1e-4, - vel[:, None], - is_agent[:, None] - ), axis=-1).ravel() - - @partial(jax.jit, static_argnums=0) - def get_rew_done_info(self, t, pos, ranges, ranges_real_2d, land_pred_pos): - #aggregated because reward, done and info computations share similar computations - - # get the prediction error per each landmark from the agent that recived the smaller range - land_2d_pos = pos[self.num_agents:, :2] - land_pred_2d_pos = land_pred_pos[..., :2] - land_ranges = ranges[:, self.num_agents:] - - # best prediction - land_closest_pred_agent = jnp.argmin(jnp.where(land_ranges==0, jnp.inf, land_ranges), axis=0) - best_2d_pred = land_pred_2d_pos[land_closest_pred_agent, jnp.arange(self.num_landmarks)] - pred_2d_err = jnp.sqrt(jnp.sum( - (land_2d_pos - best_2d_pred)** 2, - axis=1)) - - # set self-distance to inf and get the smallest distances - distances_2d = jnp.where(jnp.eye(ranges_real_2d.shape[0], ranges_real_2d.shape[1], dtype=bool), jnp.inf, ranges_real_2d) - min_land_dist = distances_2d[:, self.num_agents:].min(axis=0) # distance between landmarks and their closest agent - - # rewards - rew_good_pred = jnp.where( - pred_2d_err<=self.rew_pred_thr, - 0.1, # 0.1 when the prediction is good enough - 0 - ).sum() # reward for tracking beeing under threshold - - rew_land_distance = jnp.where( - (min_land_dist <= self.min_agent_dist*2), # agent is close enought to the landmark - jnp.where( - (min_land_dist >= self.min_agent_dist), # agent is respecting the safe distance - 1, # 1 if the agent close to the landmark is respecting the safe distance - -1, # -1 if the agent is not respecting the safe distance - ), - jnp.maximum(-1, -1e-3*min_land_dist), # penalty proportional to the distance of the landmark, with a max of -10 - ).sum() # landmark-distance-based reward enhance the agents to stay close to landmarks while respecting safety distance - - - # TODO: reward for keeping a good distance with the other agents - # agent_dist = distances_2d[:, :self.num_agents].min(axis=1) # distance between agents and all other entities - # pen_crash = jnp.where(min_ohter_dist < self.min_valid_distance, -1, 0).min() # penalty for crashing - - rew = rew_land_distance + rew_good_pred - - #rew = jnp.where(rew==self.num_landmarks*2, rew*10, rew) # bonus for following and tracing all the landmarks - - done = ( - (t == self.max_steps) # maximum steps reached - #|((min_ohter_dist < self.min_valid_distance).any()) # or crash - #|((min_land_dist > self.max_range_dist).any()) # or lost landmark - ) - - info = { - 'rew':rew, - 'tracking_pred': best_2d_pred, - 'tracking_error': pred_2d_err, - 'land_dist': min_land_dist, - 'tracking_error_mean': pred_2d_err.mean(), - 'land_dist_mean': min_land_dist.mean(), - } - - return rew, done, info - - @partial(jax.jit, static_argnums=0) - def preprocess_actions(self, actions): - if self.continuous_actions: - return jnp.clip(actions, a_min=-0.24, a_max=0.24).squeeze() - else: - return self.discrete_actions_mapping[actions] - - @partial(jax.jit, static_argnums=0) - def get_landmarks_actions(self, rng, steps_next_land_action, t): - # range of change of direction is 0 unti 
steps_next_land_action hasn't reached t - rng, key_action, key_sign, key_next_action = jax.random.split(rng, 4) - action_range = jnp.where( - steps_next_land_action[:, None]==t, - self.rudder_range_landmark, - jnp.zeros(2) - ) - actions = jax.random.uniform( - key_action, - shape=(self.num_landmarks,), - minval=action_range[:,0], - maxval=action_range[:,1] - ) * jax.random.choice(key_sign, shape=(self.num_landmarks,), a=jnp.array([-1,1])) # random sign - # sample the next step of direction change for landmarks that changed direction - steps_next_land_action = jnp.where( - steps_next_land_action==t, - t + jax.random.randint(key_next_action, (self.num_landmarks,), *self.dirchange_time_range_landmark), - steps_next_land_action - ) - return actions, steps_next_land_action - - @partial(jax.jit, static_argnums=0) - def get_ranges(self, rng, pos): - # computes the real 3d and 2d ranges and defines the observed range - rng, key_noise, key_lost = jax.random.split(rng, 3) - delta_xyz = pos[:self.num_agents, np.newaxis, :3] - pos[:, :3] - ranges_real_2d = jnp.sqrt(jnp.sum( - (delta_xyz[..., :2])** 2, - axis=2)) # euclidean distances between agents and all other entities in 2d space - ranges_real = jnp.sqrt(jnp.sum( - (pos[:self.num_agents, np.newaxis, :3] - pos[:, :3])** 2, - axis=2)) # euclidean distances between agents and all other entities - ranges = ranges_real + jax.random.normal(key_noise, shape=ranges_real.shape)*self.range_noise_std # add noise - ranges = jnp.where( - (jax.random.uniform(key_lost, shape=ranges.shape)>self.lost_comm_prob)|(ranges_real>self.max_range_dist), # lost communication or landmark too far - ranges, - 0 - ) # lost communications - ranges = fill_diagonal_zeros(ranges) # reset to 0s the self-ranges - return delta_xyz, ranges_real_2d, ranges_real, ranges - - @partial(jax.jit, static_argnums=0) - def communicate(self, rng, ranges, pos, range_buffer, range_buffer_head): - - rng, key_comm = jax.random.split(rng) - - # comm_drop is a bool mask that defines which agent-to-agent communications are dropped - comm_drop = jax.random.uniform(key_comm, shape=(self.num_agents, self.num_agents))=self.min_steps_ls, - true_fun=_update, - false_fun=_dummy_pred, - operand=None - ) - - # avoid nans for singular matrix in ls computation - land_pred_pos = jnp.where(jnp.isnan(land_pred_pos), _dummy_pred(None), land_pred_pos) - - return land_pred_pos - - @partial(jax.jit, static_argnums=0) - def estimate_depth(self, pred_xy, pos, ranges): - # bad depth estimation using pitagora - pos_xy = pos[:self.num_agents, :2] - ranges = ranges[:, self.num_agents:] # ranges between agents and landmarks - delta_xy = (pred_xy - pos_xy[:, np.newaxis])**2 - to_square = ranges**2 - delta_xy[:, :, 0] - delta_xy[:, :, 1] - z = pos[:self.num_agents, -1:] + jnp.sqrt(jnp.where(to_square>0, to_square, 0)) - return z - - @partial(jax.jit, static_argnums=0) - def get_init_pos(self, rng, min_init_distance, max_init_distance): - def generate_points(carry, _): - rng, points, i = carry - mask = jnp.arange(self.num_entities) >= i # consider a priori valid the distances with non-done points - def generate_point(while_state): - rng, _ = while_state - rng, _rng = jax.random.split(rng) - new_point = jax.random.uniform(_rng, (2,), minval=-max_init_distance, maxval=max_init_distance) - return rng, new_point - def is_valid_point(while_state): - _, point = while_state - distances = jnp.sqrt(jnp.sum((points - point)**2, axis=-1)) - return ~ jnp.all(mask | ((distances >= min_init_distance) & (distances <= max_init_distance))) - 
init_point = generate_point((rng, 0)) - rng, new_point = jax.lax.while_loop( - cond_fun = is_valid_point, - body_fun = generate_point, - init_val = init_point - ) - points = points.at[i].set(new_point) - carry = (rng, points, i+1) - return carry, new_point - rng, _rng = jax.random.split(rng) - pos = jnp.zeros((self.num_entities, 2)) - pos = pos.at[0].set(jax.random.uniform(_rng, (2,), minval=-max_init_distance, maxval=max_init_distance)) # first point - (rng, pos, i), _ = jax.lax.scan(generate_points, (rng, pos, 1), None, self.num_entities-1) - return pos \ No newline at end of file
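# --- Illustrative sketch (not part of the original module): the rejection-sampling idea behind
# get_init_pos above, written as a plain Python loop instead of jax.lax.scan/while_loop so the
# acceptance test is easier to follow. Distance bounds follow the defaults above
# (min_init_distance=100, max_init_distance=300); the function name and loop structure are assumptions.
import jax
import jax.numpy as jnp

def sample_init_positions(rng, n_points, d_min=100.0, d_max=300.0):
    rng, key = jax.random.split(rng)
    points = [jax.random.uniform(key, (2,), minval=-d_max, maxval=d_max)]  # first point is unconstrained
    for _ in range(n_points - 1):
        while True:  # resample until the candidate respects both bounds w.r.t. every accepted point
            rng, key = jax.random.split(rng)
            cand = jax.random.uniform(key, (2,), minval=-d_max, maxval=d_max)
            dists = jnp.sqrt(jnp.sum((jnp.stack(points) - cand) ** 2, axis=-1))
            if bool(jnp.all((dists >= d_min) & (dists <= d_max))):
                points.append(cand)
                break
    return jnp.stack(points)  # (n_points, 2), analogous to the xy positions returned by get_init_pos

# e.g. sample_init_positions(jax.random.PRNGKey(0), n_points=4)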