Merge pull request #56 from alexunderch/fix/rnn_hidden_dim
started a pr about rnn hidsize
amacrutherford authored Jan 29, 2024
2 parents a784a4a + c70a335 commit 0f7ee87
Showing 4 changed files with 219 additions and 207 deletions.
17 changes: 13 additions & 4 deletions Dockerfile
@@ -3,8 +3,9 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
# install python
ARG DEBIAN_FRONTEND=noninteractive
ARG PYTHON_VERSION=3.10
# setting language and locale
ENV LANG="C.UTF-8" LC_ALL="C.UTF-8"

# ENV TZ=Europe/London

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
@@ -33,18 +34,26 @@ RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python get-pip.py

# default workdir
WORKDIR /home/workdir

# dev: install from source
COPY . .

# install jaxmarl from source if needed, with all the requirements
RUN pip install --ignore-installed -e '.[qlearning, dev]'

# install jax to enable cuda
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# disabling preallocation
ARG XLA_PYTHON_CLIENT_PREALLOCATE=false
RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false
# safety measures
RUN export XLA_PYTHON_CLIENT_MEM_FRACTION=0.25
RUN export TF_FORCE_GPU_ALLOW_GROWTH=true

# for jupyter
EXPOSE 9999

# for secrets and debug
ENV WANDB_API_KEY=""
ENV WANDB_ENTITY=""
RUN git config --global --add safe.directory /home/workdir

CMD ["/bin/bash"]
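
One caveat in the hunk above: RUN export VAR=value sets the variable only for that single build step, so the three RUN export lines do not persist into the final image, and ARG XLA_PYTHON_CLIENT_PREALLOCATE applies at build time only. A sketch of a persistent alternative using Docker's ENV instruction (not part of this diff):

# persist the XLA/TF flags into the image so they apply at container runtime
ENV XLA_PYTHON_CLIENT_PREALLOCATE=false
ENV XLA_PYTHON_CLIENT_MEM_FRACTION=0.25
ENV TF_FORCE_GPU_ALLOW_GROWTH=true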
3 changes: 2 additions & 1 deletion baselines/IPPO/config/ippo_rnn_smax.yaml
@@ -1,6 +1,7 @@
"LR": 0.004
"NUM_ENVS": 128
"NUM_STEPS": 128
"GRU_HIDDEN_DIM": 256
"TOTAL_TIMESTEPS": 1e7
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
@@ -22,6 +23,6 @@
"ANNEAL_LR": True

# WandB Params
"ENTITY": ""
"ENTITY": ${oc.env:WANDB_ENTITY}
"PROJECT": "jaxmarl-smax"
"WANDB_MODE" : "disabled"
17 changes: 8 additions & 9 deletions baselines/IPPO/ippo_rnn_smax.py
@@ -37,7 +37,7 @@ def __call__(self, carry, x):
ins, resets = x
rnn_state = jnp.where(
resets[:, np.newaxis],
self.initialize_carry(ins.shape[0], ins.shape[1]),
self.initialize_carry(*rnn_state.shape),
rnn_state,
)
new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
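
For context, this hunk sits inside JaxMARL's scanned GRU wrapper. A rough sketch of the whole module under the usual Flax pattern (only the jnp.where lines are verbatim from the diff; the decorator and initialize_carry follow the standard recipe and are assumptions here). The old reset rebuilt the carry from ins.shape, which silently required the hidden size to equal the input width; taking *rnn_state.shape keeps the reset correct for any configured hidden size:

import functools
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class ScannedRNN(nn.Module):
    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        rnn_state = carry
        ins, resets = x
        # where an episode just ended, swap in a fresh carry; its shape is
        # taken from the carry itself rather than from the input batch
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(*rnn_state.shape),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        # zeros carry of shape (batch_size, hidden_size); the key is unused by the default init
        return nn.GRUCell(features=hidden_size).initialize_carry(
            jax.random.PRNGKey(0), (batch_size, hidden_size)
        )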
@@ -65,7 +65,7 @@ def __call__(self, hidden, x):
rnn_in = (embedding, dones)
hidden, embedding = ScannedRNN()(hidden, rnn_in)

actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
embedding
)
actor_mean = nn.relu(actor_mean)
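
Same fix, second site: the dense layer after the RNN was hard-coded to 128 wide, so once GRU_HIDDEN_DIM moved to 256 the actor head stopped tracking the configured size. A small runnable check of the config-driven width (the standalone setup and sizes are assumptions; the initializers match the ones in the diff):

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal

config = {"GRU_HIDDEN_DIM": 256}  # mirrors the value added to the yaml
embedding = jnp.zeros((1, 640, config["GRU_HIDDEN_DIM"]))  # (time, actors, hidden)

dense = nn.Dense(
    config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0)
)
params = dense.init(jax.random.PRNGKey(0), embedding)
actor_mean = nn.relu(dense.apply(params, embedding))
assert actor_mean.shape == embedding.shape  # head width now follows the config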
@@ -147,7 +147,7 @@ def train(rng):
jnp.zeros((1, config["NUM_ENVS"])),
jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n)),
)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
network_params = network.init(_rng, init_hstate, init_x)
if config["ANNEAL_LR"]:
tx = optax.chain(
@@ -169,8 +169,7 @@
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)

init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])
# TRAIN LOOP
def _update_step(update_runner_state, unused):
# COLLECT TRAJECTORIES
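
Both carries here are batched per actor, i.e. one hidden state for every agent in every parallel environment, and their width now follows the config rather than the old literal 128. A shape sketch with assumed sizes (the agent count depends on the SMAX map):

# relies on the ScannedRNN sketch above
num_envs, num_agents = 128, 5        # NUM_ENVS from the yaml; num_agents is map-dependent
hidden = 256                         # GRU_HIDDEN_DIM
num_actors = num_agents * num_envs   # the usual NUM_ACTORS derivation

init_hstate = ScannedRNN.initialize_carry(num_actors, hidden)
assert init_hstate.shape == (num_actors, hidden)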
@@ -275,7 +274,7 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
# RERUN NETWORK
_, pi, value = network.apply(
params,
init_hstate.transpose(),
init_hstate.squeeze(),
(traj_batch.obs, traj_batch.done, traj_batch.avail_actions),
)
log_prob = pi.log_prob(traj_batch.action)
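
Rather than transposing the carry to fit the minibatch layout and back, the updated code gives the carry a singleton leading dimension for shuffling and squeezes it off again just before network.apply. A shape sketch with assumed sizes:

import jax.numpy as jnp

h0 = jnp.zeros((640, 256))   # (NUM_ACTORS, GRU_HIDDEN_DIM) initial carry
fake = h0[None, ...]         # (1, 640, 256): the "fake" leading dim that the
                             # jnp.reshape(init_hstate, (1, NUM_ACTORS, -1)) below produces
restored = fake.squeeze()    # back to (640, 256) right before network.apply
assert restored.shape == h0.shape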
@@ -330,8 +329,9 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
) = update_state
rng, _rng = jax.random.split(rng)

# adding an additional "fake" leading dimension to perform minibatching correctly
init_hstate = jnp.reshape(
init_hstate, (config["NUM_STEPS"], config["NUM_ACTORS"])
init_hstate, (1, config["NUM_ACTORS"], -1)
)
batch = (
init_hstate,
@@ -363,15 +363,14 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
)
update_state = (
train_state,
init_hstate,
init_hstate.squeeze(),
traj_batch,
advantages,
targets,
rng,
)
return update_state, total_loss

init_hstate = initial_hstate[None, :].squeeze().transpose()
update_state = (
train_state,
init_hstate,
389 changes: 196 additions & 193 deletions jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb

Large diffs are not rendered by default.
