Non root Docker, Unpin Jax req, small IPPO qual of life improvements, MaBrax ReadMe note #97

Merged: 20 commits, Jun 13, 2024
15 changes: 12 additions & 3 deletions Dockerfile
@@ -1,16 +1,25 @@
FROM nvcr.io/nvidia/jax:23.10-py3

# Create user
ARG UID
ARG MYUSER
RUN useradd -u $UID --create-home ${MYUSER}
USER ${MYUSER}

# default workdir
WORKDIR /home/workdir
COPY . .
WORKDIR /home/${MYUSER}/
COPY --chown=${MYUSER} --chmod=765 . .

# install jaxmarl from source, along with all its requirements
USER root
RUN pip install -e .

# install tmux
RUN apt-get update && \
apt-get install -y tmux

USER ${MYUSER}

#disabling preallocation
RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false
#safety measures
@@ -23,4 +32,4 @@ RUN export TF_FORCE_GPU_ALLOW_GROWTH=true
#for secrets and debug
ENV WANDB_API_KEY=""
ENV WANDB_ENTITY=""
RUN git config --global --add safe.directory /home/workdir
RUN git config --global --add safe.directory /home/${MYUSER}
8 changes: 5 additions & 3 deletions Makefile
@@ -8,17 +8,19 @@ endif


# Set flag for docker run command
BASE_FLAGS=-it --rm -v ${PWD}:/home/workdir --shm-size 20G
MYUSER=myuser
BASE_FLAGS=-it --rm -v ${PWD}:/home/$(MYUSER) --shm-size 20G
RUN_FLAGS=$(GPUS) $(BASE_FLAGS)

DOCKER_IMAGE_NAME = jaxmarl
IMAGE = $(DOCKER_IMAGE_NAME):latest
DOCKER_RUN=docker run $(RUN_FLAGS) $(IMAGE)
USE_CUDA = $(if $(GPUS),true,false)
ID = $(shell id -u)

# make file commands
build:
DOCKER_BUILDKIT=1 docker build --build-arg USE_CUDA=$(USE_CUDA) --tag $(IMAGE) --progress=plain ${PWD}/.
DOCKER_BUILDKIT=1 docker build --build-arg USE_CUDA=$(USE_CUDA) --build-arg MYUSER=$(MYUSER) --build-arg UID=$(ID) --tag $(IMAGE) --progress=plain ${PWD}/.

run:
$(DOCKER_RUN) /bin/bash
@@ -28,4 +30,4 @@ test:

workflow-test:
# without -it flag
docker run --rm -v ${PWD}:/home/workdir --shm-size 20G $(IMAGE) /bin/bash -c "pytest ./tests/"
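With this change the workflow is unchanged from the user's side; `make build` now forwards the host UID and the `myuser` name as build args so the container runs non-root. A sketch of the intended usage (assuming the repo root as working directory):

```bash
# Build the image as a non-root user matching your host UID,
# then drop into an interactive shell inside the container.
make build
make run

# Roughly what `make build` expands to (USE_CUDA flag omitted):
DOCKER_BUILDKIT=1 docker build --build-arg MYUSER=myuser \
    --build-arg UID=$(id -u) --tag jaxmarl:latest .
```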
2 changes: 1 addition & 1 deletion baselines/IPPO/config/ippo_ff_mabrax.yaml
@@ -20,5 +20,5 @@

# WandB Params
"ENTITY": ""
"PROJECT": "jaxmarl-mabrax"
"PROJECT": ""
"WANDB_MODE": "disabled"
115 changes: 67 additions & 48 deletions baselines/IPPO/ippo_ff_mabrax.py
@@ -66,14 +66,19 @@ class Transition(NamedTuple):
info: jnp.ndarray

def batchify(x: dict, agent_list, num_actors):
x = jnp.stack([x[a] for a in agent_list])
max_dim = max([x[a].shape[-1] for a in agent_list])
print('max_dim', max_dim)
def pad(z):
return jnp.concatenate([z, jnp.zeros(z.shape[:-1] + (max_dim - z.shape[-1],))], -1)

x = jnp.stack([x[a] if x[a].shape[-1] == max_dim else pad(x[a]) for a in agent_list])
return x.reshape((num_actors, -1))

def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
x = x.reshape((num_actors, num_envs, -1))
return {a: x[i] for i, a in enumerate(agent_list)}

def make_train(config):
def make_train(config, rng_init):
env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
config["NUM_UPDATES"] = (
@@ -89,27 +94,26 @@ def linear_schedule(count):
frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
return config["LR"] * frac

def train(rng):
# INIT NETWORK
network = ActorCritic(env.action_space(env.agents[0]).shape[0], activation=config["ACTIVATION"])
max_dim = jnp.argmax(jnp.array([env.observation_space(a).shape[-1] for a in env.agents]))
init_x = jnp.zeros(env.observation_space(env.agents[max_dim]).shape)
network_params = network.init(rng_init, init_x)
if config["ANNEAL_LR"]:
tx = optax.chain(
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))

# INIT NETWORK
# TODO doesn't work for non-homogenous agents
network = ActorCritic(env.action_space(env.agents[0]).shape[0], activation=config["ACTIVATION"])
rng, _rng = jax.random.split(rng)
init_x = jnp.zeros(env.observation_space(env.agents[0]).shape)
network_params = network.init(_rng, init_x)
if config["ANNEAL_LR"]:
tx = optax.chain(
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)

train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)
def train(rng):

# INIT ENV
rng, _rng = jax.random.split(rng)
@@ -121,7 +125,7 @@ def train(rng):
def _update_step(runner_state, unused):
# COLLECT TRAJECTORIES
def _env_step(runner_state, unused):
train_state, env_state, last_obs, rng = runner_state
train_state, env_state, last_obs, update_count, rng = runner_state

obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
# SELECT ACTION
@@ -149,15 +153,15 @@ def _env_step(runner_state, unused):
obs_batch,
info,
)
runner_state = (train_state, env_state, obsv, rng)
runner_state = (train_state, env_state, obsv, update_count, rng)
return runner_state, transition

runner_state, traj_batch = jax.lax.scan(
_env_step, runner_state, None, config["NUM_STEPS"]
)

print('traj_batch', traj_batch)
# CALCULATE ADVANTAGE
train_state, env_state, last_obs, rng = runner_state
train_state, env_state, last_obs, update_count, rng = runner_state
last_obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
_, last_val = network.apply(train_state.params, last_obs_batch)

@@ -228,14 +232,23 @@ def _loss_fn(params, traj_batch, gae, targets):
+ config["VF_COEF"] * value_loss
- config["ENT_COEF"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)
return total_loss, (value_loss, loss_actor, entropy, ratio)

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
train_state.params, traj_batch, advantages, targets
)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss

loss_info = {
"total_loss": total_loss[0],
"actor_loss": total_loss[1][1],
"critic_loss": total_loss[1][0],
"entropy": total_loss[1][2],
"ratio": total_loss[1][3],
}

return train_state, loss_info

train_state, traj_batch, advantages, targets, rng = update_state
rng, _rng = jax.random.split(rng)
@@ -257,11 +270,17 @@ def _loss_fn(params, traj_batch, gae, targets):
),
shuffled_batch,
)
train_state, total_loss = jax.lax.scan(
train_state, loss_info = jax.lax.scan(
_update_minbatch, train_state, minibatches
)
update_state = (train_state, traj_batch, advantages, targets, rng)
return update_state, total_loss
return update_state, loss_info

def callback(metric):
wandb.log(
metric,
step=metric["update_step"],
)

update_state = (train_state, traj_batch, advantages, targets, rng)
update_state, loss_info = jax.lax.scan(
@@ -270,12 +289,20 @@ def _loss_fn(params, traj_batch, gae, targets):
train_state = update_state[0]
metric = traj_batch.info
rng = update_state[-1]

runner_state = (train_state, env_state, last_obs, rng)

update_count = update_count + 1
r0 = {"ratio0": loss_info["ratio"][0,0].mean()}
loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
metric = jax.tree_map(lambda x: x.mean(), metric)
metric["update_step"] = update_count
metric["env_step"] = update_count * config["NUM_STEPS"] * config["NUM_ENVS"]
metric = {**metric, **loss_info, **r0}
jax.experimental.io_callback(callback, None, metric)
runner_state = (train_state, env_state, last_obs, update_count, rng)
return runner_state, metric

rng, _rng = jax.random.split(rng)
runner_state = (train_state, env_state, obsv, _rng)
runner_state = (train_state, env_state, obsv, 0, _rng)
runner_state, metric = jax.lax.scan(
_update_step, runner_state, None, config["NUM_UPDATES"]
)
@@ -298,25 +325,17 @@ def main(config):

rng = jax.random.PRNGKey(config["SEED"])
with jax.disable_jit(config["DISABLE_JIT"]):
train_jit = jax.jit(make_train(config), device=jax.devices()[config["DEVICE"]])
rng, _rng = jax.random.split(rng)
train_jit = jax.jit(make_train(config, _rng), device=jax.devices()[config["DEVICE"]])
out = train_jit(rng)

'''updates_x = jnp.arange(out["metrics"]["returned_episode_returns"].squeeze().shape[0])
print('updates x', updates_x.shape)
print('metrics shape', out["metrics"]["returned_episode_returns"].shape)
returns_table = jnp.stack([updates_x, out["metrics"]["returned_episode_returns"].mean(-1).squeeze()], axis=1)
returns_table = wandb.Table(data=returns_table.tolist(), columns=["updates", "returns"])
wandb.log({
"returns_plot": wandb.plot.line(returns_table, "updates", "returns", title="returns_vs_updates"),
"returns": out["metrics"]["returned_episode_returns"].mean()
})'''

mean_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)
x = np.arange(len(mean_returns)) * config["NUM_ACTORS"]
plt.plot(x, mean_returns)
plt.xlabel("Timestep")
plt.ylabel("Return")
plt.savefig(f'mabrax_ippo_ret.png')
# mean_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)
# x = np.arange(len(mean_returns)) * config["NUM_ACTORS"]
# plt.plot(x, mean_returns)
# plt.xlabel("Timestep")
# plt.ylabel("Return")
# plt.savefig(f'mabrax_ippo_ret.png')

# import pdb; pdb.set_trace()

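The padded `batchify` above is what enables heterogeneous agents: observations with a smaller trailing dimension are zero-padded up to the largest agent's dimension before stacking. A standalone sketch of the same logic (debug print dropped; the example shapes are hypothetical):

```python
import jax.numpy as jnp

def batchify(x: dict, agent_list, num_actors):
    # Zero-pad each agent's trailing (feature) dimension up to the largest
    # one, so heterogeneous observations stack into a single array.
    max_dim = max(x[a].shape[-1] for a in agent_list)
    def pad(z):
        return jnp.concatenate([z, jnp.zeros(z.shape[:-1] + (max_dim - z.shape[-1],))], -1)
    x = jnp.stack([x[a] if x[a].shape[-1] == max_dim else pad(x[a]) for a in agent_list])
    return x.reshape((num_actors, -1))

# Two agents, four envs each, observation dims 3 and 5.
obs = {"agent_0": jnp.ones((4, 3)), "agent_1": jnp.ones((4, 5))}
print(batchify(obs, ["agent_0", "agent_1"], num_actors=8).shape)
# (8, 5): agent_0 rows are zero-padded from 3 to 5
```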
2 changes: 1 addition & 1 deletion baselines/IPPO/ippo_ff_mpe.py
@@ -287,7 +287,7 @@ def callback(metric):
rng = update_state[-1]

r0 = {"ratio0": loss_info["ratio"][0,0].mean()}
jax.debug.print('ratio0 {x}', x=r0["ratio0"])
# jax.debug.print('ratio0 {x}', x=r0["ratio0"])
loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
metric = jax.tree_map(lambda x: x.mean(), metric)
metric = {**metric, **loss_info, **r0}
7 changes: 2 additions & 5 deletions baselines/IPPO/ippo_rnn_hanabi.py
@@ -14,18 +14,14 @@
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Dict
from flax.training.train_state import TrainState
import orbax.checkpoint
from flax.training import orbax_utils
import distrax
import jaxmarl
from jaxmarl.wrappers.baselines import LogWrapper
import wandb
import functools
import matplotlib.pyplot as plt
import hydra
from omegaconf import OmegaConf
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'

class ScannedRNN(nn.Module):
@functools.partial(
@@ -367,7 +363,8 @@ def callback(metric):
* config["NUM_ENVS"]
* config["NUM_STEPS"],
**metric["loss"],
}
},
step=metric["update_steps"],
)
metric["update_steps"] = update_steps
jax.experimental.io_callback(callback, None, metric)
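Adding `step=metric["update_steps"]` pins each W&B log to the update counter computed inside the jitted loop instead of W&B's auto-incremented step, which keeps the x-axis consistent even though `io_callback` delivers logs asynchronously. A minimal standalone sketch of the pattern (hypothetical metric values; assumes `wandb.init()` has already been called):

```python
import jax
import jax.numpy as jnp
import wandb
from jax.experimental import io_callback

def callback(metric):
    # Runs on the host: pin the x-axis to the in-loop update counter.
    wandb.log({"loss": float(metric["loss"])}, step=int(metric["update_steps"]))

@jax.jit
def update_step(step):
    metric = {"loss": jnp.float32(0.0), "update_steps": step}  # hypothetical
    io_callback(callback, None, metric)
    return step + 1
```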
6 changes: 6 additions & 0 deletions jaxmarl/environments/mabrax/README.md
@@ -21,6 +21,12 @@ Each agent's observation vector is composed of the local state of the joints it
## Action Spaces
Each agent's action space is the input torques to the joints it controls. All environments have continuous actions in the range [-1.0, 1.0], except for `humanoid_9|8` where the range is [-0.4, 0.4].

Note: the two agents in `humanoid_9|8` have different action space sizes. To pad the action spaces to the same size, pass `"homogenisation_method":"max"` to the environment. If using our config files, this would be done as:
```yaml
"ENV_NAME": "humanoid_9|8"
"ENV_KWARGS": {"homogenisation_method":"max"}
```
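Equivalently, when constructing the environment directly in Python, the keyword is forwarded through `jaxmarl.make` (a sketch; the baselines pass it via `**config["ENV_KWARGS"]`):

```python
import jaxmarl

# Pad both agents' spaces up to the larger ("max") size.
env = jaxmarl.make("humanoid_9|8", homogenisation_method="max")
```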


## Visualisation
To visualise a trajectory in a Jupyter notebook, given a list of states, you can use the following code snippet:
7 changes: 4 additions & 3 deletions requirements/requirements.txt
@@ -1,13 +1,14 @@
# requirements are aligned with the nvcr.io/nvidia/jax:23.10-py3 image
jax==0.4.17.*
jaxlib==0.4.17.*
jax>=0.4.16.0,<=0.4.25
jaxlib>=0.4.16.0,<=0.4.25
flax==0.7.4
chex==0.1.84
optax==0.1.7
dotmap==1.3.30
evosax==0.1.5
distrax==0.1.5
brax==0.10.0
brax==0.10.3
mujoco==3.1.3
gymnax==0.0.6
safetensors==0.4.2
flashbax==0.1.0
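With jax and jaxlib now pinned to a range rather than an exact version, a quick post-install sanity check (a sketch, run from the repo root):

```bash
# Install the range-pinned requirements and confirm the resolved jax
# version falls inside the 0.4.16–0.4.25 window.
pip install -r requirements/requirements.txt
python -c "import jax; print(jax.__version__)"
```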