
Commit

starting curr
mttga committed Jan 24, 2025
1 parent 9f69751 commit 3ffc3ff
Showing 3 changed files with 146 additions and 97 deletions.
48 changes: 22 additions & 26 deletions baselines/MAPPO/config/mappo_homogenous_transf_utracking.yaml
@@ -1,66 +1,62 @@
"LR": 0.0005
"NUM_ENVS": 128
"NUM_STEPS": 64
"TOTAL_TIMESTEPS": 5e7
"HIDDEN_DIM": 32
"AGENT_NUM_LAYERS": 2
"AGENT_NUM_HEADS": 4
"AGENT_FF_DIM": 128
"CRITIC_NUM_LAYERS": 2
"CRITIC_NUM_HEADS": 4
"CRITIC_FF_DIM": 128
"NUM_ENVS": 1024
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 1e8
"HIDDEN_DIM": 64
"NUM_LAYERS": 2
"NUM_HEADS": 8
"FF_DIM": 128
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
"NUM_MINIBATCHES": 8
"GAMMA": 0.99
"GAE_LAMBDA": 0.95
"CLIP_EPS": 0.2
"SCALE_CLIP_EPS": False
"ENT_COEF": 0.01
"VF_COEF": 0.5
"KL_COEF": 0.1
"MAX_GRAD_NORM": 0.5
"ACTIVATION": "relu"
"ANNEAL_LR": False

# just for the
"FC_DIM_SIZE": 128
"GRU_HIDDEN_DIM": 128

# ENV
"ENV_NAME": "utracking"
"ENV_KWARGS": {
"num_agents": 2,
"num_landmarks": 2,
"num_agents": 1,
"num_landmarks": 1,
"max_steps": 512,
"dt": 30,
"prop_range_landmark": [0, 5, 10, 15], # possible propulsor velocities of the landmarks
"rew_type": "follow",
"penalty_failed_episode": False,
"rew_pred_thr": 10,
"rew_type": "tracking_error",
"penalty_failed_episode": True,
"rew_pred_thr": 20,
"min_valid_distance": 5,
"min_init_distance": 30,
"max_init_distance": 200,
"max_init_distance": 100,
"pre_init_pos_len": 1000000,
"max_range_dist": 600,
"tracking_method": "pf",
"matrix_obs": True, # essential for transformer
"matrix_state": True,
"state_as_edges": False,
"pf_num_particles": 1000,
"pf_num_particles": 2000,
}

# EXP
"SEED": 0
"NUM_SEEDS": 1
"TUNE": False
"SAVE_PATH": "models"
"ALG_NAME": "mappo_transformer"
"ANIMATION_LOG_INTERVAL": #0.33 # percentage of total update steps. animating will slow down training and use more memory
"SAVE_PATH": "models/utracking_curr"
"LOAD_PATH": "models/utracking_curr/utracking_1v1/mappo_transforme_curr_utracking_1_vs_1_seed0_vmap0.safetensors"
"LOAD_CRITIC": True
"ALG_NAME": "mappo_transformer_curr_tracking_error_from_follow"
"ANIMATION_LOG_INTERVAL": 0.3 # percentage of total update steps. animating will slow down training and use more memory
"ANIMATION_MAX_STEPS": 256 # should be the same of the env

# WANDB
"WANDB_MODE": "online"
"ENTITY": "mttga"
"PROJECT": "utracking_new"
"PROJECT": "utracking_curr"
"WANDB_LOG_ALL_SEEDS": False


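The new LOAD_PATH and LOAD_CRITIC keys drive the curriculum warm-start added in mappo_transformer_utracking.py below: the actor (and, when LOAD_CRITIC is true, the critic) is restored from the checkpoint of the previous 1-vs-1 stage instead of being initialized from scratch. A minimal sketch of how this config and checkpoint could be consumed follows; yaml.safe_load and the standalone layout are assumptions, while load_params and the MODEL_PARAMS["actor"]/["critic"] structure mirror the code in the diff.

import yaml
from jaxmarl.wrappers.baselines import load_params

# Read the curriculum config shown above (path as listed in this commit)
with open("baselines/MAPPO/config/mappo_homogenous_transf_utracking.yaml") as f:
    config = yaml.safe_load(f)

# Warm-start from the previous curriculum stage, mirroring the logic added in make_train
if config.get("LOAD_PATH") is not None:
    config["MODEL_PARAMS"] = load_params(config["LOAD_PATH"])
    actor_params = config["MODEL_PARAMS"]["actor"]        # always restored
    if config.get("LOAD_CRITIC"):
        critic_params = config["MODEL_PARAMS"]["critic"]  # reused only when LOAD_CRITIC is true
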
142 changes: 90 additions & 52 deletions baselines/MAPPO/mappo_transformer_utracking.py
@@ -18,6 +18,7 @@
from jaxmarl import make
from jaxmarl.wrappers.baselines import LogWrapper, SMAXLogWrapper, JaxMARLWrapper
from jaxmarl.viz.utracking_visualizer import animate_from_infos
from jaxmarl.wrappers.baselines import save_params, load_params


class EncoderBlock(nn.Module):
@@ -55,7 +56,7 @@ def setup(self):
def __call__(self, x, mask=None, deterministic=True):

# Attention part
if mask is not None: # masking is not compatible with fast self attention
if mask is not None:
mask = jnp.repeat(
nn.make_attention_mask(mask, mask), self.num_heads, axis=-3
)
@@ -123,22 +124,24 @@ def __call__(self, carry, x):
hs = carry
embeddings, mask, done = x

# reset hidden state and add
# reset hidden state and add
hs = jnp.where(
done[:, np.newaxis], # batch_wize, 1,
self.initialize_carry(*done.shape, self.hidden_dim), # batch_size, hidden_dim
hs, # batch size, hidden_dim
done[:, np.newaxis], # batch_wize, 1,
self.initialize_carry(
*done.shape, self.hidden_dim
), # batch_size, hidden_dim
hs, # batch size, hidden_dim
)
embeddings = jnp.concatenate(
(
hs[..., np.newaxis, :], # batch size, 1, hidden_dim
hs[..., np.newaxis, :], # batch size, 1, hidden_dim
embeddings,
),
axis=-2,
)
for layer in self.encoders:
embeddings = layer(embeddings, mask=mask, deterministic=self.deterministic)
hs = embeddings[..., 0, :] # batch size, hidden_dim
hs = embeddings[..., 0, :] # batch size, hidden_dim

# as y return the entire embeddings if required (i.e. transformer mixer), otherwise only agents' hs embeddings
if self.return_embeddings:
@@ -160,16 +163,16 @@ def __call__(self, hs, x, return_all_hs=False):

ins, resets, avail_actions = x
embeddings = Embedder(
self.config['HIDDEN_DIM'],
self.config["HIDDEN_DIM"],
)(ins)

print('actor embeddings shape:', embeddings.shape)
print("actor embeddings shape:", embeddings.shape)

last_hs, hidden_states = ScannedTransformer(
hidden_dim=self.config['HIDDEN_DIM'],
transf_num_layers=self.config['AGENT_NUM_LAYERS'],
transf_num_heads=self.config['AGENT_NUM_HEADS'],
transf_dim_feedforward=self.config['AGENT_FF_DIM'],
hidden_dim=self.config["HIDDEN_DIM"],
transf_num_layers=self.config["NUM_LAYERS"],
transf_num_heads=self.config["NUM_HEADS"],
transf_dim_feedforward=self.config["FF_DIM"],
deterministic=True,
return_embeddings=False,
)(hs, (embeddings, None, resets))
@@ -196,17 +199,17 @@ def __call__(self, hs, x):

world_state, resets = x

embeddings= Embedder(
self.config['HIDDEN_DIM'],
embeddings = Embedder(
self.config["HIDDEN_DIM"],
)(world_state)

print('critic embeddings shape:', embeddings.shape)
print("critic embeddings shape:", embeddings.shape)

last_hs, hidden_states = ScannedTransformer(
hidden_dim=self.config['HIDDEN_DIM'],
transf_num_layers=self.config['CRITIC_NUM_LAYERS'],
transf_num_heads=self.config['CRITIC_NUM_HEADS'],
transf_dim_feedforward=self.config['CRITIC_FF_DIM'],
hidden_dim=self.config["HIDDEN_DIM"],
transf_num_layers=self.config["NUM_LAYERS"],
transf_num_heads=self.config["NUM_HEADS"],
transf_dim_feedforward=self.config["FF_DIM"],
deterministic=True,
return_embeddings=False,
)(hs, (embeddings, None, resets))
@@ -245,6 +248,7 @@ def batchify_transformer(x: dict, agent_list, num_actors):
x = x.reshape((num_actors, num_entities, num_feats))
return x


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
x = x.reshape((num_actors, num_envs, -1))
return {a: x[i] for i, a in enumerate(agent_list)}
@@ -273,6 +277,11 @@ def make_train(config):

env = LogWrapper(env)

if config["LOAD_PATH"] is not None:
config["MODEL_PARAMS"] = load_params(config["LOAD_PATH"])
print("loaded model from", config["LOAD_PATH"])


def linear_schedule(count):
frac = (
1.0
@@ -284,36 +293,51 @@ def linear_schedule(count):
def train(rng):
original_seed = rng[0]
# INIT NETWORK
actor_network = TransformerAgent(env.action_space(env.agents[0]).n, config=config)
actor_network = TransformerAgent(
env.action_space(env.agents[0]).n, config=config
)
critic_network = TransformerCritic(config=config)
rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
ac_init_x = (
jnp.zeros(
(1, config["NUM_ENVS"], *env.observation_space(env.agents[0]).shape)
), # (time_step, batch_size, n_entities, obs_size)
jnp.zeros((1, config["NUM_ENVS"])), # (time_step, batch_size)
jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n)), # (time_step, batch_size, num_actions)
)
ac_init_hstate = ScannedTransformer.initialize_carry(
config["NUM_ENVS"], config["HIDDEN_DIM"], # (batch_size, hidden_dim)
)
actor_network_params = actor_network.init(_rng_actor, ac_init_hstate, ac_init_x)
cr_init_x = (
jnp.zeros(
(
1,
config["NUM_ENVS"],
*env.world_state_space.shape,
)
),
jnp.zeros((1, config["NUM_ENVS"])),
)
cr_init_hstate = ScannedTransformer.initialize_carry(
config["NUM_ENVS"], config["HIDDEN_DIM"], # (batch_size, hidden_dim)
)
critic_network_params = critic_network.init(
_rng_critic, cr_init_hstate, cr_init_x
)

if config["LOAD_PATH"] is not None:
actor_network_params = config["MODEL_PARAMS"]["actor"]
critic_network_params = config["MODEL_PARAMS"]["critic"]
else:
ac_init_x = (
jnp.zeros(
(1, config["NUM_ENVS"], *env.observation_space(env.agents[0]).shape)
), # (time_step, batch_size, n_entities, obs_size)
jnp.zeros((1, config["NUM_ENVS"])), # (time_step, batch_size)
jnp.zeros(
(1, config["NUM_ENVS"], env.action_space(env.agents[0]).n)
), # (time_step, batch_size, num_actions)
)
ac_init_hstate = ScannedTransformer.initialize_carry(
config["NUM_ENVS"],
config["HIDDEN_DIM"], # (batch_size, hidden_dim)
)
actor_network_params = actor_network.init(
_rng_actor, ac_init_hstate, ac_init_x
)

if config["LOAD_PATH"] is not None and not config["LOAD_CRITIC"]:
cr_init_x = (
jnp.zeros(
(
1,
config["NUM_ENVS"],
*env.world_state_space.shape,
)
),
jnp.zeros((1, config["NUM_ENVS"])),
)
cr_init_hstate = ScannedTransformer.initialize_carry(
config["NUM_ENVS"],
config["HIDDEN_DIM"], # (batch_size, hidden_dim)
)
critic_network_params = critic_network.init(
_rng_critic, cr_init_hstate, cr_init_x
)

if config["ANNEAL_LR"]:
actor_tx = optax.chain(
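The optimizer construction that follows is collapsed in this view. As a reference, a common pattern in these baselines pairs global-norm gradient clipping with Adam, switching to the linear_schedule above when ANNEAL_LR is set; the snippet below is a sketch under those assumptions (the helper name and the 1e-5 Adam epsilon are illustrative), not the elided code.

import optax

# Clip gradients by global norm, then apply Adam; lr may be a float or a schedule callable.
def make_optimizer(lr, max_grad_norm=0.5):       # "MAX_GRAD_NORM": 0.5 in the config
    return optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adam(learning_rate=lr, eps=1e-5),  # eps value is an assumption
    )

actor_tx = make_optimizer(lr=5e-4)   # "LR": 0.0005; pass linear_schedule instead when ANNEAL_LR is true
critic_tx = make_optimizer(lr=5e-4)
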
@@ -371,7 +395,9 @@ def _env_step(runner_state, unused):
avail_actions = jax.lax.stop_gradient(
batchify(avail_actions, env.agents, config["NUM_ACTORS"])
)
obs_batch = batchify_transformer(last_obs, env.agents, config["NUM_ACTORS"])
obs_batch = batchify_transformer(
last_obs, env.agents, config["NUM_ACTORS"]
)
print("obs shape:", obs_batch.shape)
ac_in = (
obs_batch[np.newaxis, :],
@@ -449,7 +475,11 @@ def _env_step(runner_state, unused):
last_world_state, env.num_agents, axis=0
) # repeat world_state for each agent
last_world_state = last_world_state.reshape(
(config["NUM_ACTORS"], last_world_state.shape[-2], last_world_state.shape[-1])
(
config["NUM_ACTORS"],
last_world_state.shape[-2],
last_world_state.shape[-1],
)
) # (num_actors, world_state_size)

cr_in = (
@@ -486,6 +516,8 @@ def _get_advantages(gae_and_next_value, transition):
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)
# standardization should go here
# advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# UPDATE NETWORK
def _update_epoch(update_state, unused):
@@ -525,7 +557,11 @@ def _actor_loss_fn(actor_params, init_hstate, traj_batch, gae):
approx_kl = ((ratio - 1) - logratio).mean()
clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])

actor_loss = loss_actor - config["ENT_COEF"] * entropy
actor_loss = (
loss_actor
- config["ENT_COEF"] * entropy
+ config["KL_COEF"] * approx_kl
)

return actor_loss, (
loss_actor,
@@ -725,7 +761,9 @@ def get_complete_rollout(rng, params):
def step_agent(rng, hstate, obsv, last_done, env_state):

avail_actions = env.get_avail_actions(env_state.env_state)
obs_batch = batchify_transformer(obsv, env.agents, env.num_agents)
obs_batch = batchify_transformer(
obsv, env.agents, env.num_agents
)
avail_actions = batchify(
avail_actions, env.agents, env.num_agents
)
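
Taken together, the training-loss change in this file adds a KL penalty, weighted by the new KL_COEF key, on top of the clipped PPO surrogate and entropy bonus computed in _actor_loss_fn. The sketch below is self-contained with dummy inputs; the clipped-surrogate construction and the example values are assumptions, while the approx_kl estimator and the final combination mirror the diff.

import jax.numpy as jnp

def ppo_kl_actor_loss(log_prob, old_log_prob, gae, entropy,
                      clip_eps=0.2, ent_coef=0.01, kl_coef=0.1):
    # Clipped PPO surrogate (standard form; not shown explicitly in the diff)
    logratio = log_prob - old_log_prob
    ratio = jnp.exp(logratio)
    loss_actor = -jnp.minimum(
        ratio * gae,
        jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae,
    ).mean()
    # Approximate KL between old and new policy, same estimator as in the diff
    approx_kl = ((ratio - 1) - logratio).mean()
    # New objective: entropy bonus plus KL penalty weighted by KL_COEF
    return loss_actor - ent_coef * entropy.mean() + kl_coef * approx_kl

# Illustrative call with dummy data
gae = jnp.array([0.5, -0.2, 1.3])
gae = (gae - gae.mean()) / (gae.std() + 1e-8)  # the standardization the diff's comment refers to
print(ppo_kl_actor_loss(
    log_prob=jnp.array([-1.0, -0.5, -2.0]),
    old_log_prob=jnp.array([-1.1, -0.6, -1.8]),
    gae=gae,
    entropy=jnp.array([1.2, 1.1, 1.3]),
))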
