Skip to content

Commit

Permalink
tested single task on colab
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Dec 3, 2023
1 parent e29a50b commit 94004a7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
40 changes: 20 additions & 20 deletions examples/train_single_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"# jax is already installed on the colab, uncomment only if needed\n",
"# !pip install --upgrade \"jax[cuda11_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
"# !pip install 'xminigrid[baselines]'"
"# !pip install 'xminigrid[baselines]'\n",
"!pip install \"xminigrid[baselines] @ git+https://github.com/corl-team/xland-minigrid.git\""
]
},
{
Expand All @@ -61,6 +62,7 @@
"import optax\n",
"import imageio\n",
"import wandb\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from flax import struct\n",
"from flax.linen.initializers import glorot_normal, orthogonal, zeros_init\n",
Expand Down Expand Up @@ -623,7 +625,8 @@
"metadata": {},
"outputs": [],
"source": [
"config = TrainConfig(env_id=\"MiniGrid-FourRooms\")\n",
"# should take ~8min on T4 gpu\n",
"config = TrainConfig(env_id=\"MiniGrid-FourRooms\", total_timesteps=50_000_000)\n",
"\n",
"rng, env, env_params, init_hstate, train_state = make_states(config)\n",
"# replicating args across devices\n",
Expand All @@ -642,7 +645,7 @@
"t = time.time()\n",
"train_info = jax.block_until_ready(train_fn(rng, train_state, init_hstate))\n",
"elapsed_time = time.time() - t\n",
"print(f\"Done in {elapsed_time:.2f}s\")\n",
"print(f\"Done in {elapsed_time / 60:.2f}min\")\n",
"\n",
"# unreplicating from multiple devices\n",
"train_info = unreplicate(train_info)\n",
Expand All @@ -665,44 +668,44 @@
"metadata": {},
"outputs": [],
"source": [
"env, env_params = xminigrid.make(conifg.env_id)\n",
"env, env_params = xminigrid.make(config.env_id)\n",
"env = GymAutoResetWrapper(env)\n",
"\n",
"# you can use train_state from the final runner_state also\n",
"# we just demo here how to do it if you loaded params from the checkpoint\n",
"params = train_info[\"params\"]\n",
"params = train_info[\"runner_state\"][1].params\n",
"model = ActorCriticRNN(\n",
" num_actions=env.num_actions(env_params),\n",
" action_emb_dim=config.action_emb_dim,\n",
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
" rnn_num_layers=config.rnn_num_layers,\n",
" head_hidden_dim=config.head_hidden_dim,\n",
")\n",
"\n",
"# jitting all functions\n",
"apply_fn, reset_fn, step_fn = jax.jit(model.apply), jax.jit(env.reset), jax.jit(env.step)\n",
"\n",
"# initial inputs\n",
"hidden = model.initialize_carry(1)\n",
"prev_reward = jnp.asarray(0)\n",
"prev_action = jnp.asarray(0, dtype=jnp.uint32)\n",
"obs = jnp.zeros(env.observation_shape(env_params))\n",
"\n",
"# for logging\n",
"total_reward, num_episodes = 0, 0\n",
"total_reward = 0\n",
"rendered_imgs = []\n",
"\n",
"rng = jax.random.PRNGKey(0)\n",
"rng = jax.random.PRNGKey(1)\n",
"rng, _rng = jax.random.split(rng)\n",
"\n",
"# initial inputs\n",
"hidden = model.initialize_carry(1)\n",
"prev_reward = jnp.asarray(0)\n",
"prev_action = jnp.asarray(0)\n",
"\n",
"timestep = reset_fn(env_params, _rng)\n",
"rendered_imgs.append(env.render(env_params, timestep))\n",
"\n",
"while num_episodes < TOTAL_EPISODES:\n",
"while not timestep.last():\n",
" rng, _rng = jax.random.split(rng)\n",
" dist, value, hidden = apply_fn(\n",
" dist, _, hidden = apply_fn(\n",
" params,\n",
" {\n",
" \"observation\": obs[None, None, ...],\n",
" \"observation\": timestep.observation[None, None, ...],\n",
" \"prev_action\": prev_action[None, None, ...],\n",
" \"prev_reward\": prev_reward[None, None, ...],\n",
" },\n",
Expand All @@ -713,14 +716,11 @@
" timestep = step_fn(env_params, timestep, action)\n",
" prev_action = action\n",
" prev_reward = timestep.reward\n",
" obs = timestep.observation\n",
"\n",
" total_reward += timestep.reward.item()\n",
" num_episodes += int(timestep.last().item())\n",
" # if not bool(timestep.last().item()):\n",
" rendered_imgs.append(env.render(env_params, timestep))\n",
"\n",
"print(\"Total reward:\", total_reward)\n",
"print(\"Reward:\", total_reward)\n",
"imageio.mimsave(\"eval_rollout.mp4\", rendered_imgs, fps=16, format=\"mp4\")"
]
},
Expand Down
3 changes: 2 additions & 1 deletion examples/walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
"# jax is already installed on the colab, uncomment only if needed\n",
"# !pip install --upgrade \"jax[cuda11_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
"!pip install xminigrid"
"# !pip install xminigrid\n",
"!pip install \"xminigrid[baselines] @ git+https://github.com/corl-team/xland-minigrid.git\""
]
},
{
Expand Down

0 comments on commit 94004a7

Please sign in to comment.