From e26465170eee749a19a8b2fc63d6f2804becfb8d Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 28 May 2024 16:42:47 +0100 Subject: [PATCH 01/20] updated docker and make files2 --- Dockerfile | 15 ++++++++++++--- Makefile | 4 +++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index cbc02a27..7ca0eee8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,16 +1,25 @@ FROM nvcr.io/nvidia/jax:23.10-py3 +# Create user +ARG UID +ARG MYUSER +RUN useradd -u $UID --create-home ${MYUSER} +USER ${MYUSER} + # default workdir -WORKDIR /home/workdir -COPY . . +WORKDIR /home/${MYUSER}/ +COPY --chown=${MYUSER} --chmod=765 . . #jaxmarl from source if needed, all the requirements +USER root RUN pip install -e . # install tmux RUN apt-get update && \ apt-get install -y tmux +USER ${MYUSER} + #disabling preallocation RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false #safety measures @@ -23,4 +32,4 @@ RUN export TF_FORCE_GPU_ALLOW_GROWTH=true #for secrets and debug ENV WANDB_API_KEY="" ENV WANDB_ENTITY="" -RUN git config --global --add safe.directory /home/workdir +RUN git config --global --add safe.directory /home/${MYUSER} diff --git a/Makefile b/Makefile index c54b1f98..a56106a8 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ endif # Set flag for docker run command +MYUSER=alexr BASE_FLAGS=-it --rm -v ${PWD}:/home/workdir --shm-size 20G RUN_FLAGS=$(GPUS) $(BASE_FLAGS) @@ -15,10 +16,11 @@ DOCKER_IMAGE_NAME = jaxmarl IMAGE = $(DOCKER_IMAGE_NAME):latest DOCKER_RUN=docker run $(RUN_FLAGS) $(IMAGE) USE_CUDA = $(if $(GPUS),true,false) +ID = $(shell id -u) # make file commands build: - DOCKER_BUILDKIT=1 docker build --build-arg USE_CUDA=$(USE_CUDA) --tag $(IMAGE) --progress=plain ${PWD}/. + DOCKER_BUILDKIT=1 docker build --build-arg USE_CUDA=$(USE_CUDA) --build-arg MYUSER=$(MYUSER) --build-arg UID=$(ID) --tag $(IMAGE) --progress=plain ${PWD}/. run: $(DOCKER_RUN) /bin/bash From c42e838f893b2e16afc87a22ecdb9c0110285dcf Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 28 May 2024 17:55:26 +0100 Subject: [PATCH 02/20] mabrax sweep --- baselines/IPPO/config/ippo_ff_mabrax.yaml | 6 +-- baselines/IPPO/ippo_ff_mabrax.py | 50 ++++++++++++++--------- sweep_mabrax.yaml | 24 +++++++++++ 3 files changed, 57 insertions(+), 23 deletions(-) create mode 100644 sweep_mabrax.yaml diff --git a/baselines/IPPO/config/ippo_ff_mabrax.yaml b/baselines/IPPO/config/ippo_ff_mabrax.yaml index f4f41d64..be194dea 100644 --- a/baselines/IPPO/config/ippo_ff_mabrax.yaml +++ b/baselines/IPPO/config/ippo_ff_mabrax.yaml @@ -1,7 +1,7 @@ "LR": 1e-3 "NUM_ENVS": 64 "NUM_STEPS": 300 -"TOTAL_TIMESTEPS": 1e7 +"TOTAL_TIMESTEPS": 5e7 "UPDATE_EPOCHS": 4 "NUM_MINIBATCHES": 4 "GAMMA": 0.99 @@ -19,6 +19,6 @@ "DISABLE_JIT": False # WandB Params -"ENTITY": "" +"ENTITY": "alex-plus" "PROJECT": "jaxmarl-mabrax" -"WANDB_MODE": "disabled" \ No newline at end of file +"WANDB_MODE": "online" \ No newline at end of file diff --git a/baselines/IPPO/ippo_ff_mabrax.py b/baselines/IPPO/ippo_ff_mabrax.py index c057128e..41e415ab 100644 --- a/baselines/IPPO/ippo_ff_mabrax.py +++ b/baselines/IPPO/ippo_ff_mabrax.py @@ -228,14 +228,23 @@ def _loss_fn(params, traj_batch, gae, targets): + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy ) - return total_loss, (value_loss, loss_actor, entropy) + return total_loss, (value_loss, loss_actor, entropy, ratio) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) total_loss, grads = grad_fn( train_state.params, traj_batch, advantages, targets ) train_state = train_state.apply_gradients(grads=grads) - return train_state, total_loss + + loss_info = { + "total_loss": total_loss[0], + "actor_loss": total_loss[1][1], + "critic_loss": total_loss[1][0], + "entropy": total_loss[1][2], + "ratio": total_loss[1][3], + } + + return train_state, loss_info train_state, traj_batch, advantages, targets, rng = update_state rng, _rng = jax.random.split(rng) @@ -257,11 +266,16 @@ def _loss_fn(params, traj_batch, gae, targets): ), shuffled_batch, ) - train_state, total_loss = jax.lax.scan( + train_state, loss_info = jax.lax.scan( _update_minbatch, train_state, minibatches ) update_state = (train_state, traj_batch, advantages, targets, rng) - return update_state, total_loss + return update_state, loss_info + + def callback(metric): + wandb.log( + metric + ) update_state = (train_state, traj_batch, advantages, targets, rng) update_state, loss_info = jax.lax.scan( @@ -270,7 +284,12 @@ def _loss_fn(params, traj_batch, gae, targets): train_state = update_state[0] metric = traj_batch.info rng = update_state[-1] - + + r0 = {"ratio0": loss_info["ratio"][0,0].mean()} + loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + metric = jax.tree_map(lambda x: x.mean(), metric) + metric = {**metric, **loss_info, **r0} + jax.experimental.io_callback(callback, None, metric) runner_state = (train_state, env_state, last_obs, rng) return runner_state, metric @@ -301,22 +320,13 @@ def main(config): train_jit = jax.jit(make_train(config), device=jax.devices()[config["DEVICE"]]) out = train_jit(rng) - '''updates_x = jnp.arange(out["metrics"]["returned_episode_returns"].squeeze().shape[0]) - print('updates x', updates_x.shape) - print('metrics shape', out["metrics"]["returned_episode_returns"].shape) - returns_table = jnp.stack([updates_x, out["metrics"]["returned_episode_returns"].mean(-1).squeeze()], axis=1) - returns_table = wandb.Table(data=returns_table.tolist(), columns=["updates", "returns"]) - wandb.log({ - "returns_plot": wandb.plot.line(returns_table, "updates", "returns", title="returns_vs_updates"), - "returns": out["metrics"]["returned_episode_returns"].mean() - })''' - mean_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1) - x = np.arange(len(mean_returns)) * config["NUM_ACTORS"] - plt.plot(x, mean_returns) - plt.xlabel("Timestep") - plt.ylabel("Return") - plt.savefig(f'mabrax_ippo_ret.png') + # mean_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1) + # x = np.arange(len(mean_returns)) * config["NUM_ACTORS"] + # plt.plot(x, mean_returns) + # plt.xlabel("Timestep") + # plt.ylabel("Return") + # plt.savefig(f'mabrax_ippo_ret.png') # import pdb; pdb.set_trace() diff --git a/sweep_mabrax.yaml b/sweep_mabrax.yaml new file mode 100644 index 00000000..818eb84c --- /dev/null +++ b/sweep_mabrax.yaml @@ -0,0 +1,24 @@ +command: + - python3 + - ${program} + - ${args_no_hyphens} +entity: amacrutherford +method: grid +parameters: + SEED: + values: + - 42 + - 43 + - 44 + - 45 + - 46 + - 47 + - 48 + - 49 + - 50 + - 51 + WANDB_MODE: + values: + - online +program: baselines/IPPO/ippo_ff_mabrax.py +project: jaxmarl-hanabi \ No newline at end of file From 6bca5d0ae89b2a25d4cb1bbc3befe8427f14fde7 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 28 May 2024 19:26:17 +0100 Subject: [PATCH 03/20] non hetrogenous agent support --- baselines/IPPO/ippo_ff_mabrax.py | 68 +++++++++++++++++++------------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/baselines/IPPO/ippo_ff_mabrax.py b/baselines/IPPO/ippo_ff_mabrax.py index 41e415ab..c0641b0a 100644 --- a/baselines/IPPO/ippo_ff_mabrax.py +++ b/baselines/IPPO/ippo_ff_mabrax.py @@ -65,15 +65,24 @@ class Transition(NamedTuple): obs: jnp.ndarray info: jnp.ndarray +# def batchify(x: dict, agent_list, num_actors): +# x = jnp.stack([x[a] for a in agent_list]) +# return x.reshape((num_actors, -1)) + def batchify(x: dict, agent_list, num_actors): - x = jnp.stack([x[a] for a in agent_list]) + max_dim = max([x[a].shape[-1] for a in agent_list]) + print('max_dim', max_dim) + def pad(z): + return jnp.concatenate([z, jnp.zeros(z.shape[:-1] + (max_dim - z.shape[-1],))], -1) + + x = jnp.stack([x[a] if x[a].shape[-1] == max_dim else pad(x[a]) for a in agent_list]) return x.reshape((num_actors, -1)) def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors): x = x.reshape((num_actors, num_envs, -1)) return {a: x[i] for i, a in enumerate(agent_list)} -def make_train(config): +def make_train(config, rng_init): env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"]) config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"] config["NUM_UPDATES"] = ( @@ -89,27 +98,27 @@ def linear_schedule(count): frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"] return config["LR"] * frac - def train(rng): + # INIT NETWORK + network = ActorCritic(env.action_space(env.agents[0]).shape[0], activation=config["ACTIVATION"]) + # rng, _rng = jax.random.split(rng_init) + max_dim = jnp.argmax(jnp.array([env.observation_space(a).shape[-1] for a in env.agents])) + init_x = jnp.zeros(env.observation_space(env.agents[max_dim]).shape) + network_params = network.init(rng_init, init_x) + if config["ANNEAL_LR"]: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5)) - # INIT NETWORK - # TODO doesn't work for non-homogenous agents - network = ActorCritic(env.action_space(env.agents[0]).shape[0], activation=config["ACTIVATION"]) - rng, _rng = jax.random.split(rng) - init_x = jnp.zeros(env.observation_space(env.agents[0]).shape) - network_params = network.init(_rng, init_x) - if config["ANNEAL_LR"]: - tx = optax.chain( - optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), - optax.adam(learning_rate=linear_schedule, eps=1e-5), - ) - else: - tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5)) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) - train_state = TrainState.create( - apply_fn=network.apply, - params=network_params, - tx=tx, - ) + def train(rng): # INIT ENV rng, _rng = jax.random.split(rng) @@ -121,7 +130,7 @@ def train(rng): def _update_step(runner_state, unused): # COLLECT TRAJECTORIES def _env_step(runner_state, unused): - train_state, env_state, last_obs, rng = runner_state + train_state, env_state, last_obs, update_count, rng = runner_state obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"]) # SELECT ACTION @@ -149,15 +158,15 @@ def _env_step(runner_state, unused): obs_batch, info, ) - runner_state = (train_state, env_state, obsv, rng) + runner_state = (train_state, env_state, obsv, update_count, rng) return runner_state, transition runner_state, traj_batch = jax.lax.scan( _env_step, runner_state, None, config["NUM_STEPS"] ) - + print('traj_batch', traj_batch) # CALCULATE ADVANTAGE - train_state, env_state, last_obs, rng = runner_state + train_state, env_state, last_obs, update_count, rng = runner_state last_obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"]) _, last_val = network.apply(train_state.params, last_obs_batch) @@ -285,16 +294,18 @@ def callback(metric): metric = traj_batch.info rng = update_state[-1] + update_count = update_count + 1 r0 = {"ratio0": loss_info["ratio"][0,0].mean()} loss_info = jax.tree_map(lambda x: x.mean(), loss_info) metric = jax.tree_map(lambda x: x.mean(), metric) + metric["env_step"] = update_count * config["NUM_STEPS"] * config["NUM_ENVS"] metric = {**metric, **loss_info, **r0} jax.experimental.io_callback(callback, None, metric) - runner_state = (train_state, env_state, last_obs, rng) + runner_state = (train_state, env_state, last_obs, update_count, rng) return runner_state, metric rng, _rng = jax.random.split(rng) - runner_state = (train_state, env_state, obsv, _rng) + runner_state = (train_state, env_state, obsv, 0, _rng) runner_state, metric = jax.lax.scan( _update_step, runner_state, None, config["NUM_UPDATES"] ) @@ -317,7 +328,8 @@ def main(config): rng = jax.random.PRNGKey(config["SEED"]) with jax.disable_jit(config["DISABLE_JIT"]): - train_jit = jax.jit(make_train(config), device=jax.devices()[config["DEVICE"]]) + rng, _rng = jax.random.split(rng) + train_jit = jax.jit(make_train(config, _rng), device=jax.devices()[config["DEVICE"]]) out = train_jit(rng) From fb27ff4cbc8e3b1b579ac26769458249cfe365fe Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:16:22 +0100 Subject: [PATCH 04/20] log with update step --- baselines/IPPO/ippo_ff_mabrax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/baselines/IPPO/ippo_ff_mabrax.py b/baselines/IPPO/ippo_ff_mabrax.py index c0641b0a..b45b5adb 100644 --- a/baselines/IPPO/ippo_ff_mabrax.py +++ b/baselines/IPPO/ippo_ff_mabrax.py @@ -283,7 +283,8 @@ def _loss_fn(params, traj_batch, gae, targets): def callback(metric): wandb.log( - metric + metric, + step=metric["update_step"], ) update_state = (train_state, traj_batch, advantages, targets, rng) @@ -298,6 +299,7 @@ def callback(metric): r0 = {"ratio0": loss_info["ratio"][0,0].mean()} loss_info = jax.tree_map(lambda x: x.mean(), loss_info) metric = jax.tree_map(lambda x: x.mean(), metric) + metric["update_step"] = update_count metric["env_step"] = update_count * config["NUM_STEPS"] * config["NUM_ENVS"] metric = {**metric, **loss_info, **r0} jax.experimental.io_callback(callback, None, metric) From 24328dc5b003a3f6978c614bdc902240da7b5954 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:26:24 +0100 Subject: [PATCH 05/20] update brax req --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 19427146..6773d453 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -7,7 +7,7 @@ optax==0.1.7 dotmap==1.3.30 evosax==0.1.5 distrax==0.1.5 -brax==0.10.0 +brax==0.10.3 gymnax==0.0.6 safetensors==0.4.2 flashbax==0.1.0 From a63d8ea2ac72b5c50f5285b48ee3b402a5e53b31 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:26:47 +0100 Subject: [PATCH 06/20] update proj --- sweep_mabrax.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sweep_mabrax.yaml b/sweep_mabrax.yaml index 818eb84c..a82aff13 100644 --- a/sweep_mabrax.yaml +++ b/sweep_mabrax.yaml @@ -21,4 +21,4 @@ parameters: values: - online program: baselines/IPPO/ippo_ff_mabrax.py -project: jaxmarl-hanabi \ No newline at end of file +project: jaxmarl-mabrax \ No newline at end of file From c8ee4627dc5419de7adce985980bf3dc6cdd2e16 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:31:40 +0100 Subject: [PATCH 07/20] pin mujoco --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 6773d453..a9ce3284 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,6 +8,7 @@ dotmap==1.3.30 evosax==0.1.5 distrax==0.1.5 brax==0.10.3 +mujoco==3.1.3 gymnax==0.0.6 safetensors==0.4.2 flashbax==0.1.0 From 44b4def6c75328bb480557413108bd71ae61ccfb Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:38:55 +0100 Subject: [PATCH 08/20] correct -v --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a56106a8..2fbca96d 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ endif # Set flag for docker run command MYUSER=alexr -BASE_FLAGS=-it --rm -v ${PWD}:/home/workdir --shm-size 20G +BASE_FLAGS=-it --rm -v ${PWD}:/home/$(MYUSER) --shm-size 20G RUN_FLAGS=$(GPUS) $(BASE_FLAGS) DOCKER_IMAGE_NAME = jaxmarl From 7c310d3ee482eb70856b76e5087eed025f04e349 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:44:41 +0100 Subject: [PATCH 09/20] hanabi ippo --- baselines/IPPO/config/ippo_rnn_hanabi.yaml | 6 +++--- baselines/IPPO/ippo_rnn_hanabi.py | 7 ++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/baselines/IPPO/config/ippo_rnn_hanabi.yaml b/baselines/IPPO/config/ippo_rnn_hanabi.yaml index 121ebeda..44dea301 100644 --- a/baselines/IPPO/config/ippo_rnn_hanabi.yaml +++ b/baselines/IPPO/config/ippo_rnn_hanabi.yaml @@ -19,6 +19,6 @@ "SEED": 30 # WandB Params -"WANDB_MODE": "disabled" -"ENTITY": "" -"PROJECT": "" +"WANDB_MODE": "online" +"ENTITY": "alex-plus" +"PROJECT": "jaxmarl-hanabi" diff --git a/baselines/IPPO/ippo_rnn_hanabi.py b/baselines/IPPO/ippo_rnn_hanabi.py index 43191093..0d5f382c 100644 --- a/baselines/IPPO/ippo_rnn_hanabi.py +++ b/baselines/IPPO/ippo_rnn_hanabi.py @@ -14,18 +14,14 @@ from flax.linen.initializers import constant, orthogonal from typing import Sequence, NamedTuple, Any, Dict from flax.training.train_state import TrainState -import orbax.checkpoint -from flax.training import orbax_utils import distrax import jaxmarl from jaxmarl.wrappers.baselines import LogWrapper import wandb import functools -import matplotlib.pyplot as plt import hydra from omegaconf import OmegaConf import os -os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false' class ScannedRNN(nn.Module): @functools.partial( @@ -367,7 +363,8 @@ def callback(metric): * config["NUM_ENVS"] * config["NUM_STEPS"], **metric["loss"], - } + }, + step=metric["update_steps"], ) metric["update_steps"] = update_steps jax.experimental.io_callback(callback, None, metric) From ebcae02b1914d2d7c6da4b1d5bb1e1973c317513 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:49:12 +0100 Subject: [PATCH 10/20] hanabi sweep --- sweep_hanabi.yaml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 sweep_hanabi.yaml diff --git a/sweep_hanabi.yaml b/sweep_hanabi.yaml new file mode 100644 index 00000000..24b40a36 --- /dev/null +++ b/sweep_hanabi.yaml @@ -0,0 +1,24 @@ +command: + - python3 + - ${program} + - ${args_no_hyphens} +entity: alex-plus +method: grid +parameters: + SEED: + values: + - 42 + - 43 + - 44 + - 45 + - 46 + - 47 + - 48 + - 49 + - 50 + - 51 + WANDB_MODE: + values: + - online +program: baselines/MAPPO/ippo_rnn_hanabi.py +project: jaxmarl-hanabi \ No newline at end of file From aa31f2f7a7edf7558c70905096ef7b7493f862a6 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 30 May 2024 13:52:24 +0100 Subject: [PATCH 11/20] crrectio --- sweep_hanabi.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sweep_hanabi.yaml b/sweep_hanabi.yaml index 24b40a36..7c7c49a4 100644 --- a/sweep_hanabi.yaml +++ b/sweep_hanabi.yaml @@ -20,5 +20,5 @@ parameters: WANDB_MODE: values: - online -program: baselines/MAPPO/ippo_rnn_hanabi.py +program: baselines/IPPO/ippo_rnn_hanabi.py project: jaxmarl-hanabi \ No newline at end of file From 1938b03eb29a6562b8cab4d1a1127014914a340b Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 15:55:25 +0100 Subject: [PATCH 12/20] tidying --- baselines/IPPO/ippo_ff_mabrax.py | 5 ----- baselines/IPPO/ippo_ff_mpe.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/baselines/IPPO/ippo_ff_mabrax.py b/baselines/IPPO/ippo_ff_mabrax.py index b45b5adb..f0b05f06 100644 --- a/baselines/IPPO/ippo_ff_mabrax.py +++ b/baselines/IPPO/ippo_ff_mabrax.py @@ -65,10 +65,6 @@ class Transition(NamedTuple): obs: jnp.ndarray info: jnp.ndarray -# def batchify(x: dict, agent_list, num_actors): -# x = jnp.stack([x[a] for a in agent_list]) -# return x.reshape((num_actors, -1)) - def batchify(x: dict, agent_list, num_actors): max_dim = max([x[a].shape[-1] for a in agent_list]) print('max_dim', max_dim) @@ -100,7 +96,6 @@ def linear_schedule(count): # INIT NETWORK network = ActorCritic(env.action_space(env.agents[0]).shape[0], activation=config["ACTIVATION"]) - # rng, _rng = jax.random.split(rng_init) max_dim = jnp.argmax(jnp.array([env.observation_space(a).shape[-1] for a in env.agents])) init_x = jnp.zeros(env.observation_space(env.agents[max_dim]).shape) network_params = network.init(rng_init, init_x) diff --git a/baselines/IPPO/ippo_ff_mpe.py b/baselines/IPPO/ippo_ff_mpe.py index c9ef4678..fd86072e 100644 --- a/baselines/IPPO/ippo_ff_mpe.py +++ b/baselines/IPPO/ippo_ff_mpe.py @@ -287,7 +287,7 @@ def callback(metric): rng = update_state[-1] r0 = {"ratio0": loss_info["ratio"][0,0].mean()} - jax.debug.print('ratio0 {x}', x=r0["ratio0"]) + # jax.debug.print('ratio0 {x}', x=r0["ratio0"]) loss_info = jax.tree_map(lambda x: x.mean(), loss_info) metric = jax.tree_map(lambda x: x.mean(), metric) metric = {**metric, **loss_info, **r0} From 6998b83d8b0e9eb9350e2d6cb1c913144c7952e8 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 15:59:25 +0100 Subject: [PATCH 13/20] tidy --- baselines/IPPO/config/ippo_ff_mabrax.yaml | 6 +++--- baselines/IPPO/config/ippo_rnn_hanabi.yaml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/baselines/IPPO/config/ippo_ff_mabrax.yaml b/baselines/IPPO/config/ippo_ff_mabrax.yaml index be194dea..15853339 100644 --- a/baselines/IPPO/config/ippo_ff_mabrax.yaml +++ b/baselines/IPPO/config/ippo_ff_mabrax.yaml @@ -19,6 +19,6 @@ "DISABLE_JIT": False # WandB Params -"ENTITY": "alex-plus" -"PROJECT": "jaxmarl-mabrax" -"WANDB_MODE": "online" \ No newline at end of file +"ENTITY": "" +"PROJECT": "" +"WANDB_MODE": "disabled" \ No newline at end of file diff --git a/baselines/IPPO/config/ippo_rnn_hanabi.yaml b/baselines/IPPO/config/ippo_rnn_hanabi.yaml index 44dea301..121ebeda 100644 --- a/baselines/IPPO/config/ippo_rnn_hanabi.yaml +++ b/baselines/IPPO/config/ippo_rnn_hanabi.yaml @@ -19,6 +19,6 @@ "SEED": 30 # WandB Params -"WANDB_MODE": "online" -"ENTITY": "alex-plus" -"PROJECT": "jaxmarl-hanabi" +"WANDB_MODE": "disabled" +"ENTITY": "" +"PROJECT": "" From ea0dcad7f2f36abab126b4ec81a64711d2627e5f Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 15:59:35 +0100 Subject: [PATCH 14/20] action padding note --- jaxmarl/environments/mabrax/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jaxmarl/environments/mabrax/README.md b/jaxmarl/environments/mabrax/README.md index 5edec87f..ae3448de 100644 --- a/jaxmarl/environments/mabrax/README.md +++ b/jaxmarl/environments/mabrax/README.md @@ -21,6 +21,12 @@ Each agent's observation vector is composed of the local state of the joints it ## Action Spaces Each agent's action space is the input torques to the joints it controls. All environments have continuous actions in the range [-1.0, 1.0], except for `humanoid_9|8` where the range is [-0.4, 0.4]. +Note: the two agents in `humanoid_9|8` have different action space sizes. To pad the action spaces to be the same size pass `"homogenisation_method":"max"` to the envrionment. If using our config files, this would done as: +```yaml +"ENV_NAME": "humanoid_9|8" +"ENV_KWARGS": {"homogenisation_method":"max"} +``` + ## Visualisation To visualise a trajectory in a Jupyter notebook, given a list of states, you can use the following code snippet: From aee081e7003e5e137fe917fbc95f11d701a4562a Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 16:03:00 +0100 Subject: [PATCH 15/20] unpin jax --- requirements/requirements.txt | 4 ++-- sweep_hanabi.yaml | 24 ------------------------ sweep_mabrax.yaml | 24 ------------------------ 3 files changed, 2 insertions(+), 50 deletions(-) delete mode 100644 sweep_hanabi.yaml delete mode 100644 sweep_mabrax.yaml diff --git a/requirements/requirements.txt b/requirements/requirements.txt index a9ce3284..9a40a966 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ # requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image -jax==0.4.17.* -jaxlib==0.4.17.* +jax>=0.4.17.* +jaxlib>=0.4.17.* flax==0.7.4 chex==0.1.84 optax==0.1.7 diff --git a/sweep_hanabi.yaml b/sweep_hanabi.yaml deleted file mode 100644 index 7c7c49a4..00000000 --- a/sweep_hanabi.yaml +++ /dev/null @@ -1,24 +0,0 @@ -command: - - python3 - - ${program} - - ${args_no_hyphens} -entity: alex-plus -method: grid -parameters: - SEED: - values: - - 42 - - 43 - - 44 - - 45 - - 46 - - 47 - - 48 - - 49 - - 50 - - 51 - WANDB_MODE: - values: - - online -program: baselines/IPPO/ippo_rnn_hanabi.py -project: jaxmarl-hanabi \ No newline at end of file diff --git a/sweep_mabrax.yaml b/sweep_mabrax.yaml deleted file mode 100644 index a82aff13..00000000 --- a/sweep_mabrax.yaml +++ /dev/null @@ -1,24 +0,0 @@ -command: - - python3 - - ${program} - - ${args_no_hyphens} -entity: amacrutherford -method: grid -parameters: - SEED: - values: - - 42 - - 43 - - 44 - - 45 - - 46 - - 47 - - 48 - - 49 - - 50 - - 51 - WANDB_MODE: - values: - - online -program: baselines/IPPO/ippo_ff_mabrax.py -project: jaxmarl-mabrax \ No newline at end of file From 30412ebb5bbd7580d46e7f901a2c400cedd9dbb4 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 16:03:29 +0100 Subject: [PATCH 16/20] ready for PR --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 2fbca96d..2190472a 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ endif # Set flag for docker run command -MYUSER=alexr +MYUSER=myuser BASE_FLAGS=-it --rm -v ${PWD}:/home/$(MYUSER) --shm-size 20G RUN_FLAGS=$(GPUS) $(BASE_FLAGS) From 42fa7001a42c0dab8373b06e750a4908c55c4e4f Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 16:05:27 +0100 Subject: [PATCH 17/20] correct timetsep count --- baselines/IPPO/config/ippo_ff_mabrax.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baselines/IPPO/config/ippo_ff_mabrax.yaml b/baselines/IPPO/config/ippo_ff_mabrax.yaml index 15853339..98e0f945 100644 --- a/baselines/IPPO/config/ippo_ff_mabrax.yaml +++ b/baselines/IPPO/config/ippo_ff_mabrax.yaml @@ -1,7 +1,7 @@ "LR": 1e-3 "NUM_ENVS": 64 "NUM_STEPS": 300 -"TOTAL_TIMESTEPS": 5e7 +"TOTAL_TIMESTEPS": 1e7 "UPDATE_EPOCHS": 4 "NUM_MINIBATCHES": 4 "GAMMA": 0.99 From 257de3ceff7c225c11ae941b7c6762c2d407ff65 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 16:12:38 +0100 Subject: [PATCH 18/20] bug fix --- Makefile | 4 ---- requirements/requirements.txt | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 2190472a..c8991dc1 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,3 @@ run: test: $(DOCKER_RUN) /bin/bash -c "pytest ./tests/" - -workflow-test: - # without -it flag - docker run --rm -v ${PWD}:/home/workdir --shm-size 20G $(IMAGE) /bin/bash -c "pytest ./tests/" diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9a40a966..0f769bf5 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ # requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image -jax>=0.4.17.* -jaxlib>=0.4.17.* +jax>=0.4.16.0 +jaxlib>=0.4.16.0 flax==0.7.4 chex==0.1.84 optax==0.1.7 From 78014815a9f64a8f25c8df94ba860fbdba74d4e5 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 17:46:57 +0100 Subject: [PATCH 19/20] workflow test back in --- Makefile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Makefile b/Makefile index c8991dc1..d9d89082 100644 --- a/Makefile +++ b/Makefile @@ -27,3 +27,7 @@ run: test: $(DOCKER_RUN) /bin/bash -c "pytest ./tests/" + +workflow-test: + # without -it flag + docker run --rm -v ${PWD}:/home/workdir --shm-size 20G $(IMAGE) /bin/bash -c "pytest ./tests/" \ No newline at end of file From 105ad8267e2c1f51e2945344fc03021ddf3ce1b4 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 11 Jun 2024 19:56:15 +0100 Subject: [PATCH 20/20] upper bound jax --- requirements/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0f769bf5..9f9c12ae 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ # requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image -jax>=0.4.16.0 -jaxlib>=0.4.16.0 +jax>=0.4.16.0,<=0.4.25 +jaxlib>=0.4.16.0,<=0.4.25 flax==0.7.4 chex==0.1.84 optax==0.1.7