Skip to content

Commit

Permalink
comment on training time
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Dec 4, 2023
1 parent 2494918 commit 4e26efe
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions examples/train_meta_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,8 @@
"metadata": {},
"outputs": [],
"source": [
"config = TrainConfig(env_id=\"XLand-MiniGrid-R4-9x9\", benchmark_id=\"Trivial\", total_timesteps=50_000_000)\n",
"# should take ~8min on colab gpu. It will suboptimal due to the model size and default hyperparams!\n",
"config = TrainConfig(env_id=\"XLand-MiniGrid-R4-9x9\", benchmark_id=\"Trivial\", total_timesteps=100_000_000)\n",
"\n",
"rng, env, env_params, benchmark, init_hstate, train_state = make_states(config)\n",
"# replicating args across devices\n",
Expand All @@ -687,7 +688,7 @@
"train_info = unreplicate(train_info)\n",
"\n",
"print(\"Final return: \", float(train_info[\"loss_info\"][\"eval/returns_mean\"][-1]))\n",
"plt.plot(jnp.arange(config.num_meta_updates), train_info[\"loss_info\"][\"eval/returns_mean\"])"
"plt.plot(jnp.arange(config.num_meta_updates), train_info[\"loss_info\"][\"eval/returns_mean\"]);"
]
},
{
Expand All @@ -705,6 +706,8 @@
"metadata": {},
"outputs": [],
"source": [
"from xminigrid.rendering.text_render import print_ruleset\n",
"\n",
"META_EPISODES = 10\n",
"\n",
"env, env_params = xminigrid.make(config.env_id)\n",
Expand Down Expand Up @@ -764,6 +767,8 @@
" rendered_imgs.append(env.render(env_params, timestep))\n",
"\n",
"print(\"Reward:\", total_reward)\n",
"print(\"Ruleset:\")\n",
"print_ruleset(ruleset)\n",
"imageio.mimsave(\"eval_rollout.mp4\", rendered_imgs, fps=16, format=\"mp4\")"
]
},
Expand Down

0 comments on commit 4e26efe

Please sign in to comment.