From 94004a70eca95b510c3886940a8a29f909c6a055 Mon Sep 17 00:00:00 2001
From: Howuhh
Date: Sun, 3 Dec 2023 18:56:48 +0300
Subject: [PATCH] tested single task on colab

---
 examples/train_single_standalone.ipynb | 40 +++++++++++++-------------
 examples/walkthrough.ipynb             |  3 +-
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/examples/train_single_standalone.ipynb b/examples/train_single_standalone.ipynb
index d2ba03e..54f101b 100644
--- a/examples/train_single_standalone.ipynb
+++ b/examples/train_single_standalone.ipynb
@@ -38,7 +38,8 @@
     "# jax is already installed on the colab, uncomment only if needed\n",
     "# !pip install --upgrade \"jax[cuda11_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
     "\n",
-    "# !pip install 'xminigrid[baselines]'"
+    "# !pip install 'xminigrid[baselines]'\n",
+    "!pip install \"xminigrid[baselines] @ git+https://github.com/corl-team/xland-minigrid.git\""
    ]
   },
   {
@@ -61,6 +62,7 @@
     "import optax\n",
     "import imageio\n",
     "import wandb\n",
+    "import matplotlib.pyplot as plt\n",
     "\n",
     "from flax import struct\n",
     "from flax.linen.initializers import glorot_normal, orthogonal, zeros_init\n",
@@ -623,7 +625,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "config = TrainConfig(env_id=\"MiniGrid-FourRooms\")\n",
+    "# should take ~8min on T4 gpu\n",
+    "config = TrainConfig(env_id=\"MiniGrid-FourRooms\", total_timesteps=50_000_000)\n",
     "\n",
     "rng, env, env_params, init_hstate, train_state = make_states(config)\n",
     "# replicating args across devices\n",
@@ -642,7 +645,7 @@
     "t = time.time()\n",
     "train_info = jax.block_until_ready(train_fn(rng, train_state, init_hstate))\n",
     "elapsed_time = time.time() - t\n",
-    "print(f\"Done in {elapsed_time:.2f}s\")\n",
+    "print(f\"Done in {elapsed_time / 60:.2f}min\")\n",
     "\n",
     "# unreplicating from multiple devices\n",
     "train_info = unreplicate(train_info)\n",
@@ -665,12 +668,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "env, env_params = xminigrid.make(conifg.env_id)\n",
+    "env, env_params = xminigrid.make(config.env_id)\n",
     "env = GymAutoResetWrapper(env)\n",
     "\n",
     "# you can use train_state from the final runner_state also\n",
     "# we just demo here how to do it if you loaded params from the checkpoint\n",
-    "params = train_info[\"params\"]\n",
+    "params = train_info[\"runner_state\"][1].params\n",
     "model = ActorCriticRNN(\n",
     "    num_actions=env.num_actions(env_params),\n",
     "    action_emb_dim=config.action_emb_dim,\n",
@@ -678,31 +681,31 @@
     "    rnn_num_layers=config.rnn_num_layers,\n",
     "    head_hidden_dim=config.head_hidden_dim,\n",
     ")\n",
+    "\n",
     "# jitting all functions\n",
     "apply_fn, reset_fn, step_fn = jax.jit(model.apply), jax.jit(env.reset), jax.jit(env.step)\n",
     "\n",
-    "# initial inputs\n",
-    "hidden = model.initialize_carry(1)\n",
-    "prev_reward = jnp.asarray(0)\n",
-    "prev_action = jnp.asarray(0, dtype=jnp.uint32)\n",
-    "obs = jnp.zeros(env.observation_shape(env_params))\n",
-    "\n",
     "# for logging\n",
-    "total_reward, num_episodes = 0, 0\n",
+    "total_reward = 0\n",
     "rendered_imgs = []\n",
     "\n",
-    "rng = jax.random.PRNGKey(0)\n",
+    "rng = jax.random.PRNGKey(1)\n",
     "rng, _rng = jax.random.split(rng)\n",
     "\n",
+    "# initial inputs\n",
+    "hidden = model.initialize_carry(1)\n",
+    "prev_reward = jnp.asarray(0)\n",
+    "prev_action = jnp.asarray(0)\n",
+    "\n",
     "timestep = reset_fn(env_params, _rng)\n",
     "rendered_imgs.append(env.render(env_params, timestep))\n",
     "\n",
-    "while num_episodes < TOTAL_EPISODES:\n",
+    "while not timestep.last():\n",
     "    rng, _rng = jax.random.split(rng)\n",
-    "    dist, value, hidden = apply_fn(\n",
+    "    dist, _, hidden = apply_fn(\n",
     "        params,\n",
     "        {\n",
-    "            \"observation\": obs[None, None, ...],\n",
+    "            \"observation\": timestep.observation[None, None, ...],\n",
     "            \"prev_action\": prev_action[None, None, ...],\n",
     "            \"prev_reward\": prev_reward[None, None, ...],\n",
     "        },\n",
@@ -713,14 +716,11 @@
     "    timestep = step_fn(env_params, timestep, action)\n",
     "    prev_action = action\n",
     "    prev_reward = timestep.reward\n",
-    "    obs = timestep.observation\n",
     "\n",
     "    total_reward += timestep.reward.item()\n",
-    "    num_episodes += int(timestep.last().item())\n",
-    "    # if not bool(timestep.last().item()):\n",
     "    rendered_imgs.append(env.render(env_params, timestep))\n",
     "\n",
-    "print(\"Total reward:\", total_reward)\n",
+    "print(\"Reward:\", total_reward)\n",
     "imageio.mimsave(\"eval_rollout.mp4\", rendered_imgs, fps=16, format=\"mp4\")"
    ]
   },
diff --git a/examples/walkthrough.ipynb b/examples/walkthrough.ipynb
index 6dc1df1..68e31f9 100644
--- a/examples/walkthrough.ipynb
+++ b/examples/walkthrough.ipynb
@@ -54,7 +54,8 @@
     "# jax is already installed on the colab, uncomment only if needed\n",
     "# !pip install --upgrade \"jax[cuda11_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
     "\n",
-    "!pip install xminigrid"
+    "# !pip install xminigrid\n",
+    "!pip install \"xminigrid[baselines] @ git+https://github.com/corl-team/xland-minigrid.git\""
    ]
   },
   {