diff --git a/examples/train_single_standalone.ipynb b/examples/train_single_standalone.ipynb index 6a739f2..d2ba03e 100644 --- a/examples/train_single_standalone.ipynb +++ b/examples/train_single_standalone.ipynb @@ -10,10 +10,744 @@ "" ] }, + { + "cell_type": "markdown", + "id": "b8217958-2182-4a0d-8f95-4892c861ec1b", + "metadata": {}, + "source": [ + "# Single-task PPO " + ] + }, + { + "cell_type": "markdown", + "id": "2edb8481-2240-4eb1-a451-3769e85fc3ca", + "metadata": {}, + "source": [ + "> ⚠️ Ensure you select a GPU from `Runtime > Change runtime type`. ⚠️\n", + "\n", + "> 🔥 Instances with multiple T4 gpus are available on Kaggle for free! Multi-gpu can speed up training with `pmap`. 🔥" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "281e95ae-6e39-45fd-bb70-89bef0ca336e", + "metadata": {}, + "outputs": [], + "source": [ + "# jax is already installed on the colab, uncomment only if needed\n", + "# !pip install --upgrade \"jax[cuda11_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "\n", + "# !pip install 'xminigrid[baselines]'" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b012cfa6-acd6-4692-96d0-cd203aedd029", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import math\n", + "from typing import TypedDict\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.tree_util as jtu\n", + "import flax\n", + "import flax.linen as nn\n", + "import distrax\n", + "import optax\n", + "import imageio\n", + "import wandb\n", + "\n", + "from flax import struct\n", + "from flax.linen.initializers import glorot_normal, orthogonal, zeros_init\n", + "from flax.training.train_state import TrainState\n", + "from flax.jax_utils import replicate, unreplicate\n", + "from dataclasses import asdict, dataclass\n", + "from functools import partial\n", + "\n", + "import xminigrid\n", + "from xminigrid.environment import Environment, EnvParams\n", + "from xminigrid.wrappers import GymAutoResetWrapper" + ] + }, + { + "cell_type": "markdown", + "id": "96c673ae-78c5-4c4e-b2a9-ab77a042797a", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Networks" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3d649792-f0b9-407b-ae3b-6c1214ba0881", + "metadata": {}, + "outputs": [], + "source": [ + "# Model adapted from minigrid baselines:\n", + "# https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n", + "\n", + "# custom RNN cell, which is more convenient that default in flax\n", + "class GRU(nn.Module):\n", + " hidden_dim: int\n", + "\n", + " @nn.compact\n", + " def __call__(self, xs, init_state):\n", + " seq_len, input_dim = xs.shape\n", + " # this init might not be optimal, for example bias for reset gate should be -1 (for now ok)\n", + " Wi = self.param(\"Wi\", glorot_normal(in_axis=1, out_axis=0), (self.hidden_dim * 3, input_dim))\n", + " Wh = self.param(\"Wh\", orthogonal(column_axis=0), (self.hidden_dim * 3, self.hidden_dim))\n", + " bi = self.param(\"bi\", zeros_init(), (self.hidden_dim * 3,))\n", + " bn = self.param(\"bn\", zeros_init(), (self.hidden_dim,))\n", + "\n", + " def _step_fn(h, x):\n", + " igates = jnp.split(Wi @ x + bi, 3)\n", + " hgates = jnp.split(Wh @ h, 3)\n", + "\n", + " reset = nn.sigmoid(igates[0] + hgates[0])\n", + " update = nn.sigmoid(igates[1] + hgates[1])\n", + " new = nn.tanh(igates[2] + reset * (hgates[2] + bn))\n", + " next_h = (1 - update) * new + update * h\n", + "\n", + " return next_h, next_h\n", + "\n", + " last_state, all_states = jax.lax.scan(_step_fn, init=init_state, xs=xs)\n", + " return all_states, last_state\n", + "\n", + "class RNNModel(nn.Module):\n", + " hidden_dim: int\n", + " num_layers: int\n", + "\n", + " @nn.compact\n", + " def __call__(self, xs, init_state):\n", + " # xs: [seq_len, input_dim]\n", + " # init_state: [num_layers, hidden_dim]\n", + " outs, states = [], []\n", + " for layer in range(self.num_layers):\n", + " xs, state = GRU(hidden_dim=self.hidden_dim)(xs, init_state[layer])\n", + " outs.append(xs)\n", + " states.append(state)\n", + "\n", + " # sum outputs from all layers, kinda like in ResNet\n", + " return jnp.array(outs).sum(0), jnp.array(states)\n", + "\n", + "BatchedRNNModel = flax.linen.vmap(\n", + " RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n", + ")\n", + "\n", + "class MaxPool2d(nn.Module):\n", + " kernel_size: tuple[int, int]\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n", + "\n", + "class ActorCriticInput(TypedDict):\n", + " observation: jax.Array\n", + " prev_action: jax.Array\n", + " prev_reward: jax.Array\n", + "\n", + "class ActorCriticRNN(nn.Module):\n", + " num_actions: int\n", + " action_emb_dim: int = 16\n", + " rnn_hidden_dim: int = 64\n", + " rnn_num_layers: int = 1\n", + " head_hidden_dim: int = 64\n", + "\n", + " @nn.compact\n", + " def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n", + " B, S = inputs[\"observation\"].shape[:2]\n", + " # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n", + " img_encoder = nn.Sequential(\n", + " [\n", + " nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " MaxPool2d((2, 2)),\n", + " nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " ]\n", + " )\n", + " action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n", + "\n", + " rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n", + " actor = nn.Sequential(\n", + " [\n", + " nn.Dense(self.head_hidden_dim, kernel_init=orthogonal(2)),\n", + " nn.tanh,\n", + " nn.Dense(self.num_actions, kernel_init=orthogonal(0.01)),\n", + " ]\n", + " )\n", + " critic = nn.Sequential(\n", + " [\n", + " nn.Dense(self.head_hidden_dim, kernel_init=orthogonal(2)),\n", + " nn.tanh,\n", + " nn.Dense(1, kernel_init=orthogonal(1.0)),\n", + " ]\n", + " )\n", + "\n", + " # [batch_size, seq_len, ...]\n", + " obs_emb = img_encoder(inputs[\"observation\"]).reshape(B, S, -1)\n", + " act_emb = action_encoder(inputs[\"prev_action\"])\n", + " # [batch_size, seq_len, hidden_dim + act_emb_dim + 1]\n", + " out = jnp.concatenate([obs_emb, act_emb, inputs[\"prev_reward\"][..., None]], axis=-1)\n", + " # core networks\n", + " out, new_hidden = rnn_core(out, hidden)\n", + " dist = distrax.Categorical(logits=actor(out))\n", + " values = critic(out)\n", + "\n", + " return dist, jnp.squeeze(values, axis=-1), new_hidden\n", + "\n", + " def initialize_carry(self, batch_size):\n", + " return jnp.zeros((batch_size, self.rnn_num_layers, self.rnn_hidden_dim))" + ] + }, + { + "cell_type": "markdown", + "id": "7477aca3-cc54-40bb-9c4e-2e175d847718", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Utils" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f7b37e8a-7054-4e99-b0d5-70a81cf8d6ff", + "metadata": {}, + "outputs": [], + "source": [ + "# Training stuff\n", + "class Transition(struct.PyTreeNode):\n", + " done: jax.Array\n", + " action: jax.Array\n", + " value: jax.Array\n", + " reward: jax.Array\n", + " log_prob: jax.Array\n", + " obs: jax.Array\n", + " # for rnn policy\n", + " prev_action: jax.Array\n", + " prev_reward: jax.Array\n", + "\n", + "\n", + "def calculate_gae(\n", + " transitions: Transition,\n", + " last_val: jax.Array,\n", + " gamma: float,\n", + " gae_lambda: float,\n", + ") -> tuple[jax.Array, jax.Array]:\n", + " # single iteration for the loop\n", + " def _get_advantages(gae_and_next_value, transition):\n", + " gae, next_value = gae_and_next_value\n", + " delta = transition.reward + gamma * next_value * (1 - transition.done) - transition.value\n", + " gae = delta + gamma * gae_lambda * (1 - transition.done) * gae\n", + " return (gae, transition.value), gae\n", + "\n", + " _, advantages = jax.lax.scan(\n", + " _get_advantages,\n", + " (jnp.zeros_like(last_val), last_val),\n", + " transitions,\n", + " reverse=True,\n", + " )\n", + " # advantages and values (Q)\n", + " return advantages, advantages + transitions.value\n", + "\n", + "\n", + "def ppo_update_networks(\n", + " train_state: TrainState,\n", + " transitions: Transition,\n", + " init_hstate: jax.Array,\n", + " advantages: jax.Array,\n", + " targets: jax.Array,\n", + " clip_eps: float,\n", + " vf_coef: float,\n", + " ent_coef: float,\n", + "):\n", + " # NORMALIZE ADVANTAGES\n", + " advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)\n", + "\n", + " def _loss_fn(params):\n", + " # RERUN NETWORK\n", + " dist, value, _ = train_state.apply_fn(\n", + " params,\n", + " {\n", + " # [batch_size, seq_len, ...]\n", + " \"observation\": transitions.obs,\n", + " \"prev_action\": transitions.prev_action,\n", + " \"prev_reward\": transitions.prev_reward,\n", + " },\n", + " init_hstate,\n", + " )\n", + " log_prob = dist.log_prob(transitions.action)\n", + "\n", + " # CALCULATE VALUE LOSS\n", + " value_pred_clipped = transitions.value + (value - transitions.value).clip(-clip_eps, clip_eps)\n", + " value_loss = jnp.square(value - targets)\n", + " value_loss_clipped = jnp.square(value_pred_clipped - targets)\n", + " value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean()\n", + " # TODO: ablate this!\n", + " # value_loss = jnp.square(value - targets).mean()\n", + "\n", + " # CALCULATE ACTOR LOSS\n", + " ratio = jnp.exp(log_prob - transitions.log_prob)\n", + " actor_loss1 = advantages * ratio\n", + " actor_loss2 = advantages * jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)\n", + " actor_loss = -jnp.minimum(actor_loss1, actor_loss2).mean()\n", + " entropy = dist.entropy().mean()\n", + "\n", + " total_loss = actor_loss + vf_coef * value_loss - ent_coef * entropy\n", + " return total_loss, (value_loss, actor_loss, entropy)\n", + "\n", + " (loss, (vloss, aloss, entropy)), grads = jax.value_and_grad(_loss_fn, has_aux=True)(train_state.params)\n", + " (loss, vloss, aloss, entropy, grads) = jax.lax.pmean((loss, vloss, aloss, entropy, grads), axis_name=\"devices\")\n", + " train_state = train_state.apply_gradients(grads=grads)\n", + " update_info = {\n", + " \"total_loss\": loss,\n", + " \"value_loss\": vloss,\n", + " \"actor_loss\": aloss,\n", + " \"entropy\": entropy,\n", + " }\n", + " return train_state, update_info\n", + "\n", + "\n", + "# for evaluation (evaluate for N consecutive episodes, sum rewards)\n", + "# N=1 single task, N>1 for meta-RL\n", + "class RolloutStats(struct.PyTreeNode):\n", + " reward: jax.Array = jnp.asarray(0.0)\n", + " length: jax.Array = jnp.asarray(0)\n", + " episodes: jax.Array = jnp.asarray(0)\n", + "\n", + "\n", + "def rollout(\n", + " rng: jax.Array,\n", + " env: Environment,\n", + " env_params: EnvParams,\n", + " train_state: TrainState,\n", + " init_hstate: jax.Array,\n", + " num_consecutive_episodes: int = 1,\n", + ") -> RolloutStats:\n", + " def _cond_fn(carry):\n", + " rng, stats, timestep, prev_action, prev_reward, hstate = carry\n", + " return jnp.less(stats.episodes, num_consecutive_episodes)\n", + "\n", + " def _body_fn(carry):\n", + " rng, stats, timestep, prev_action, prev_reward, hstate = carry\n", + "\n", + " rng, _rng = jax.random.split(rng)\n", + " dist, _, hstate = train_state.apply_fn(\n", + " train_state.params,\n", + " {\n", + " \"observation\": timestep.observation[None, None, ...],\n", + " \"prev_action\": prev_action[None, None, ...],\n", + " \"prev_reward\": prev_reward[None, None, ...],\n", + " },\n", + " hstate,\n", + " )\n", + " action = dist.sample(seed=_rng).squeeze()\n", + " timestep = env.step(env_params, timestep, action)\n", + "\n", + " stats = stats.replace(\n", + " reward=stats.reward + timestep.reward,\n", + " length=stats.length + 1,\n", + " episodes=stats.episodes + timestep.last(),\n", + " )\n", + " carry = (rng, stats, timestep, action, timestep.reward, hstate)\n", + " return carry\n", + "\n", + " timestep = env.reset(env_params, rng)\n", + " prev_action = jnp.asarray(0)\n", + " prev_reward = jnp.asarray(0)\n", + " init_carry = (rng, RolloutStats(), timestep, prev_action, prev_reward, init_hstate)\n", + "\n", + " final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)\n", + " return final_carry[1]" + ] + }, + { + "cell_type": "markdown", + "id": "283a98b7-8d8e-413e-90f0-6219338ded62", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "46ce2bb4-ae08-42aa-9baa-ff2c13b16fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class TrainConfig:\n", + " env_id: str = \"MiniGrid-Empty-6x6\"\n", + " # agent\n", + " action_emb_dim: int = 16\n", + " rnn_hidden_dim: int = 1024\n", + " rnn_num_layers: int = 1\n", + " head_hidden_dim: int = 256\n", + " # training\n", + " num_envs: int = 8192\n", + " num_steps: int = 16\n", + " update_epochs: int = 1\n", + " num_minibatches: int = 16\n", + " total_timesteps: int = 5_000_000\n", + " lr: float = 0.001\n", + " clip_eps: float = 0.2\n", + " gamma: float = 0.99\n", + " gae_lambda: float = 0.95\n", + " ent_coef: float = 0.01\n", + " vf_coef: float = 0.5\n", + " max_grad_norm: float = 0.5\n", + " eval_episodes: int = 40\n", + " seed: int = 42\n", + "\n", + " def __post_init__(self):\n", + " num_devices = jax.local_device_count()\n", + " # splitting computation across all available devices\n", + " self.num_envs_per_device = self.num_envs // num_devices\n", + " self.total_timesteps_per_device = self.total_timesteps // num_devices\n", + " self.eval_episodes_per_device = self.eval_episodes // num_devices\n", + " assert self.num_envs % num_devices == 0\n", + " self.num_updates = self.total_timesteps_per_device // self.num_steps // self.num_envs_per_device\n", + " print(f\"Num devices: {num_devices}, Num updates: {self.num_updates}\")\n", + "\n", + "\n", + "def make_states(config: TrainConfig):\n", + " # for learning rate scheduling\n", + " def linear_schedule(count):\n", + " frac = 1.0 - (count // (config.num_minibatches * config.update_epochs)) / config.num_updates\n", + " return config.lr * frac\n", + "\n", + " # setup environment\n", + " if \"XLand-MiniGrid\" in config.env_id:\n", + " raise ValueError(\"Only single-task environments are supported.\")\n", + "\n", + " env, env_params = xminigrid.make(config.env_id)\n", + " env = GymAutoResetWrapper(env)\n", + "\n", + " # setup training state\n", + " rng = jax.random.PRNGKey(config.seed)\n", + " rng, _rng = jax.random.split(rng)\n", + "\n", + " network = ActorCriticRNN(\n", + " num_actions=env.num_actions(env_params),\n", + " action_emb_dim=config.action_emb_dim,\n", + " rnn_hidden_dim=config.rnn_hidden_dim,\n", + " rnn_num_layers=config.rnn_num_layers,\n", + " head_hidden_dim=config.head_hidden_dim,\n", + " )\n", + " # [batch_size, seq_len, ...]\n", + " init_obs = {\n", + " \"observation\": jnp.zeros((config.num_envs_per_device, 1, *env.observation_shape(env_params))),\n", + " \"prev_action\": jnp.zeros((config.num_envs_per_device, 1), dtype=jnp.int32),\n", + " \"prev_reward\": jnp.zeros((config.num_envs_per_device, 1)),\n", + " }\n", + " init_hstate = network.initialize_carry(batch_size=config.num_envs_per_device)\n", + "\n", + " network_params = network.init(_rng, init_obs, init_hstate)\n", + " tx = optax.chain(\n", + " optax.clip_by_global_norm(config.max_grad_norm),\n", + " optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-8), # eps=1e-5\n", + " )\n", + " train_state = TrainState.create(apply_fn=network.apply, params=network_params, tx=tx)\n", + "\n", + " return rng, env, env_params, init_hstate, train_state\n", + "\n", + "\n", + "def make_train(\n", + " env: Environment,\n", + " env_params: EnvParams,\n", + " config: TrainConfig,\n", + "):\n", + " @partial(jax.pmap, axis_name=\"devices\")\n", + " def train(\n", + " rng: jax.Array,\n", + " train_state: TrainState,\n", + " init_hstate: jax.Array,\n", + " ):\n", + " # INIT ENV\n", + " rng, _rng = jax.random.split(rng)\n", + " reset_rng = jax.random.split(_rng, config.num_envs_per_device)\n", + "\n", + " timestep = jax.vmap(env.reset, in_axes=(None, 0))(env_params, reset_rng)\n", + " prev_action = jnp.zeros(config.num_envs_per_device, dtype=jnp.int32)\n", + " prev_reward = jnp.zeros(config.num_envs_per_device)\n", + "\n", + " # TRAIN LOOP\n", + " def _update_step(runner_state, _):\n", + " # COLLECT TRAJECTORIES\n", + " def _env_step(runner_state, _):\n", + " rng, train_state, prev_timestep, prev_action, prev_reward, prev_hstate = runner_state\n", + "\n", + " # SELECT ACTION\n", + " rng, _rng = jax.random.split(rng)\n", + " dist, value, hstate = train_state.apply_fn(\n", + " train_state.params,\n", + " {\n", + " # [batch_size, seq_len=1, ...]\n", + " \"observation\": prev_timestep.observation[:, None],\n", + " \"prev_action\": prev_action[:, None],\n", + " \"prev_reward\": prev_reward[:, None],\n", + " },\n", + " prev_hstate,\n", + " )\n", + " action, log_prob = dist.sample_and_log_prob(seed=_rng)\n", + " # squeeze seq_len where possible\n", + " action, value, log_prob = action.squeeze(1), value.squeeze(1), log_prob.squeeze(1)\n", + "\n", + " # STEP ENV\n", + " timestep = jax.vmap(env.step, in_axes=(None, 0, 0))(env_params, prev_timestep, action)\n", + " transition = Transition(\n", + " done=timestep.last(),\n", + " action=action,\n", + " value=value,\n", + " reward=timestep.reward,\n", + " log_prob=log_prob,\n", + " obs=prev_timestep.observation,\n", + " prev_action=prev_action,\n", + " prev_reward=prev_reward,\n", + " )\n", + " runner_state = (rng, train_state, timestep, action, timestep.reward, hstate)\n", + " return runner_state, transition\n", + "\n", + " initial_hstate = runner_state[-1]\n", + " # transitions: [seq_len, batch_size, ...]\n", + " runner_state, transitions = jax.lax.scan(_env_step, runner_state, None, config.num_steps)\n", + "\n", + " # CALCULATE ADVANTAGE\n", + " rng, train_state, timestep, prev_action, prev_reward, hstate = runner_state\n", + " # calculate value of the last step for bootstrapping\n", + " _, last_val, _ = train_state.apply_fn(\n", + " train_state.params,\n", + " {\n", + " \"observation\": timestep.observation[:, None],\n", + " \"prev_action\": prev_action[:, None],\n", + " \"prev_reward\": prev_reward[:, None],\n", + " },\n", + " hstate,\n", + " )\n", + " advantages, targets = calculate_gae(transitions, last_val.squeeze(1), config.gamma, config.gae_lambda)\n", + "\n", + " # UPDATE NETWORK\n", + " def _update_epoch(update_state, _):\n", + " def _update_minbatch(train_state, batch_info):\n", + " init_hstate, transitions, advantages, targets = batch_info\n", + " new_train_state, update_info = ppo_update_networks(\n", + " train_state=train_state,\n", + " transitions=transitions,\n", + " init_hstate=init_hstate.squeeze(1),\n", + " advantages=advantages,\n", + " targets=targets,\n", + " clip_eps=config.clip_eps,\n", + " vf_coef=config.vf_coef,\n", + " ent_coef=config.ent_coef,\n", + " )\n", + " return new_train_state, update_info\n", + "\n", + " rng, train_state, init_hstate, transitions, advantages, targets = update_state\n", + "\n", + " # MINIBATCHES PREPARATION\n", + " rng, _rng = jax.random.split(rng)\n", + " permutation = jax.random.permutation(_rng, config.num_envs_per_device)\n", + " # [seq_len, batch_size, ...]\n", + " batch = (init_hstate, transitions, advantages, targets)\n", + " # [batch_size, seq_len, ...], as our model assumes\n", + " batch = jtu.tree_map(lambda x: x.swapaxes(0, 1), batch)\n", + "\n", + " shuffled_batch = jtu.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)\n", + " # [num_minibatches, minibatch_size, ...]\n", + " minibatches = jtu.tree_map(\n", + " lambda x: jnp.reshape(x, (config.num_minibatches, -1) + x.shape[1:]), shuffled_batch\n", + " )\n", + " train_state, update_info = jax.lax.scan(_update_minbatch, train_state, minibatches)\n", + "\n", + " update_state = (rng, train_state, init_hstate, transitions, advantages, targets)\n", + " return update_state, update_info\n", + "\n", + " # [seq_len, batch_size, num_layers, hidden_dim]\n", + " init_hstate = initial_hstate[None, :]\n", + " update_state = (rng, train_state, init_hstate, transitions, advantages, targets)\n", + " update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config.update_epochs)\n", + "\n", + " # averaging over minibatches then over epochs\n", + " loss_info = jtu.tree_map(lambda x: x.mean(-1).mean(-1), loss_info)\n", + "\n", + " rng, train_state = update_state[:2]\n", + " # EVALUATE AGENT\n", + " rng, _rng = jax.random.split(rng)\n", + " eval_rng = jax.random.split(_rng, num=config.eval_episodes_per_device)\n", + "\n", + " # vmap only on rngs\n", + " eval_stats = jax.vmap(rollout, in_axes=(0, None, None, None, None, None))(\n", + " eval_rng,\n", + " env,\n", + " env_params,\n", + " train_state,\n", + " # TODO: make this as a static method mb?\n", + " jnp.zeros((1, config.rnn_num_layers, config.rnn_hidden_dim)),\n", + " 1,\n", + " )\n", + " eval_stats = jax.lax.pmean(eval_stats, axis_name=\"devices\")\n", + " loss_info.update(\n", + " {\n", + " \"eval/returns\": eval_stats.reward.mean(0),\n", + " \"eval/lengths\": eval_stats.length.mean(0),\n", + " \"lr\": train_state.opt_state[-1].hyperparams[\"learning_rate\"],\n", + " }\n", + " )\n", + " runner_state = (rng, train_state, timestep, prev_action, prev_reward, hstate)\n", + " return runner_state, loss_info\n", + "\n", + " runner_state = (rng, train_state, timestep, prev_action, prev_reward, init_hstate)\n", + " runner_state, loss_info = jax.lax.scan(_update_step, runner_state, None, config.num_updates)\n", + " return {\"params\": train_state.params, \"runner_state\": runner_state, \"loss_info\": loss_info}\n", + "\n", + " return train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "492347a2-374b-4377-a782-bff3778f5b3d", + "metadata": {}, + "outputs": [], + "source": [ + "config = TrainConfig(env_id=\"MiniGrid-FourRooms\")\n", + "\n", + "rng, env, env_params, init_hstate, train_state = make_states(config)\n", + "# replicating args across devices\n", + "rng = jax.random.split(rng, num=jax.local_device_count())\n", + "train_state = replicate(train_state, jax.local_devices())\n", + "init_hstate = replicate(init_hstate, jax.local_devices())\n", + "\n", + "print(\"Compiling...\")\n", + "t = time.time()\n", + "train_fn = make_train(env, env_params, config)\n", + "train_fn = train_fn.lower(rng, train_state, init_hstate).compile()\n", + "elapsed_time = time.time() - t\n", + "print(f\"Done in {elapsed_time:.2f}s.\")\n", + "\n", + "print(\"Training...\")\n", + "t = time.time()\n", + "train_info = jax.block_until_ready(train_fn(rng, train_state, init_hstate))\n", + "elapsed_time = time.time() - t\n", + "print(f\"Done in {elapsed_time:.2f}s\")\n", + "\n", + "# unreplicating from multiple devices\n", + "train_info = unreplicate(train_info)\n", + "\n", + "print(\"Final return: \", float(train_info[\"loss_info\"][\"eval/returns\"][-1]))" + ] + }, + { + "cell_type": "markdown", + "id": "84441082-a819-43c6-982f-6b68ab6e04cd", + "metadata": {}, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe55fded-f780-4197-afc3-b291045f1b42", + "metadata": {}, + "outputs": [], + "source": [ + "env, env_params = xminigrid.make(conifg.env_id)\n", + "env = GymAutoResetWrapper(env)\n", + "\n", + "# you can use train_state from the final runner_state also\n", + "# we just demo here how to do it if you loaded params from the checkpoint\n", + "params = train_info[\"params\"]\n", + "model = ActorCriticRNN(\n", + " num_actions=env.num_actions(env_params),\n", + " action_emb_dim=config.action_emb_dim,\n", + " rnn_hidden_dim=config.rnn_hidden_dim,\n", + " rnn_num_layers=config.rnn_num_layers,\n", + " head_hidden_dim=config.head_hidden_dim,\n", + ")\n", + "# jitting all functions\n", + "apply_fn, reset_fn, step_fn = jax.jit(model.apply), jax.jit(env.reset), jax.jit(env.step)\n", + "\n", + "# initial inputs\n", + "hidden = model.initialize_carry(1)\n", + "prev_reward = jnp.asarray(0)\n", + "prev_action = jnp.asarray(0, dtype=jnp.uint32)\n", + "obs = jnp.zeros(env.observation_shape(env_params))\n", + "\n", + "# for logging\n", + "total_reward, num_episodes = 0, 0\n", + "rendered_imgs = []\n", + "\n", + "rng = jax.random.PRNGKey(0)\n", + "rng, _rng = jax.random.split(rng)\n", + "\n", + "timestep = reset_fn(env_params, _rng)\n", + "rendered_imgs.append(env.render(env_params, timestep))\n", + "\n", + "while num_episodes < TOTAL_EPISODES:\n", + " rng, _rng = jax.random.split(rng)\n", + " dist, value, hidden = apply_fn(\n", + " params,\n", + " {\n", + " \"observation\": obs[None, None, ...],\n", + " \"prev_action\": prev_action[None, None, ...],\n", + " \"prev_reward\": prev_reward[None, None, ...],\n", + " },\n", + " hidden,\n", + " )\n", + " action = dist.sample(seed=_rng).squeeze()\n", + "\n", + " timestep = step_fn(env_params, timestep, action)\n", + " prev_action = action\n", + " prev_reward = timestep.reward\n", + " obs = timestep.observation\n", + "\n", + " total_reward += timestep.reward.item()\n", + " num_episodes += int(timestep.last().item())\n", + " # if not bool(timestep.last().item()):\n", + " rendered_imgs.append(env.render(env_params, timestep))\n", + "\n", + "print(\"Total reward:\", total_reward)\n", + "imageio.mimsave(\"eval_rollout.mp4\", rendered_imgs, fps=16, format=\"mp4\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71582fe6-5f36-43d2-8b93-eca16fe91173", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Video\n", + "\n", + "Video(\"eval_rollout.mp4\", embed=True)" + ] + }, + { + "cell_type": "markdown", + "id": "f59acf08-e0ce-4566-9fde-3e79b1f50b3e", + "metadata": {}, + "source": [ + "## Training" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "386b3d71-1325-4046-804b-fb8d510ff241", + "id": "aa37ef46-d3d3-4f38-b706-60fa5c08deba", "metadata": {}, "outputs": [], "source": [] diff --git a/examples/walkthrough.ipynb b/examples/walkthrough.ipynb index 13b2d41..6dc1df1 100644 --- a/examples/walkthrough.ipynb +++ b/examples/walkthrough.ipynb @@ -21,7 +21,11 @@ "\n", "Welcome to the walkthrough the XLand-MiniGrid library. This notebook will showcase our environments and bechmarks APIs, explaning the details and our motivations. It will also provide vectorization and multi-device parallelization examples. For full baselines training demo see our notebooks with standalone PPO implementations.\n", "\n", - "> ⚠️ Ensure you select a GPU from `Runtime > Change runtime type` ⚠️" + "> ⚠️ Ensure you select a GPU from `Runtime > Change runtime type`. ⚠️\n", + "\n", + "> ⚠️ Colab TPU runtime is not compatible with new JAX versions (> 0.4.0). Please, use kaggle notebooks if you want to use TPUs. There is no quick way to open notebook from github in kaggle (like colab badge above), so you will need to manually upload it. ⚠️\n", + "\n", + "> 🔥 Instances with multiple T4 gpus are available on Kaggle for free! Multi-gpu can speed up examples with `pmap`. 🔥" ] }, { @@ -53,32 +57,6 @@ "!pip install xminigrid" ] }, - { - "cell_type": "markdown", - "id": "4c2c028d-96fb-4413-80f5-f504285d8f42", - "metadata": {}, - "source": [ - "> ⚠️ Colab TPU runtime is not compatible with new JAX versions (> 0.4.0). Please, use kaggle notebooks if you want to use TPUs. There is no quick way to open notebook from github in kaggle (like colab badge above), so you will need to manually upload it. ⚠️\n", - "\n", - "If you have chosen the TPU runtime, some additional actions are needed. Please, uncomment these and setup the TPU through JAX. Check that devices are `TpuDevice`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6608c270-c2fa-44b9-8a9a-29b10928e6d1", - "metadata": {}, - "outputs": [], - "source": [ - "# import jax.tools.colab_tpu\n", - "# jax.tools.colab_tpu.setup_tpu()\n", - "\n", - "# import jax\n", - "# from jax.lib import xla_bridge\n", - "# print(xla_bridge.get_backend().platform)\n", - "# jax.devices()" - ] - }, { "cell_type": "code", "execution_count": null,