From 4e26efe121fe6dcde3de7c7178a3c63c61f04829 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Mon, 4 Dec 2023 15:41:35 +0300 Subject: [PATCH] comment on training time --- examples/train_meta_standalone.ipynb | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/train_meta_standalone.ipynb b/examples/train_meta_standalone.ipynb index 1d14e05..fb84a1a 100644 --- a/examples/train_meta_standalone.ipynb +++ b/examples/train_meta_standalone.ipynb @@ -662,7 +662,8 @@ "metadata": {}, "outputs": [], "source": [ - "config = TrainConfig(env_id=\"XLand-MiniGrid-R4-9x9\", benchmark_id=\"Trivial\", total_timesteps=50_000_000)\n", + "# should take ~8min on colab gpu. It will be suboptimal due to the model size and default hyperparams!\n", + "config = TrainConfig(env_id=\"XLand-MiniGrid-R4-9x9\", benchmark_id=\"Trivial\", total_timesteps=100_000_000)\n", "\n", "rng, env, env_params, benchmark, init_hstate, train_state = make_states(config)\n", "# replicating args across devices\n", @@ -687,7 +688,7 @@ "train_info = unreplicate(train_info)\n", "\n", "print(\"Final return: \", float(train_info[\"loss_info\"][\"eval/returns_mean\"][-1]))\n", - "plt.plot(jnp.arange(config.num_meta_updates), train_info[\"loss_info\"][\"eval/returns_mean\"])" + "plt.plot(jnp.arange(config.num_meta_updates), train_info[\"loss_info\"][\"eval/returns_mean\"]);" ] }, { @@ -705,6 +706,8 @@ "metadata": {}, "outputs": [], "source": [ + "from xminigrid.rendering.text_render import print_ruleset\n", + "\n", "META_EPISODES = 10\n", "\n", "env, env_params = xminigrid.make(config.env_id)\n", @@ -764,6 +767,8 @@ " rendered_imgs.append(env.render(env_params, timestep))\n", "\n", "print(\"Reward:\", total_reward)\n", + "print(\"Ruleset:\")\n", + "print_ruleset(ruleset)\n", "imageio.mimsave(\"eval_rollout.mp4\", rendered_imgs, fps=16, format=\"mp4\")" ] },