|
51 | 51 | "source": [
|
52 | 52 | "import time\n",
|
53 | 53 | "import math\n",
|
54 |
| - "from typing import TypedDict\n", |
| 54 | + "from typing import TypedDict, Optional\n", |
55 | 55 | "\n",
|
56 | 56 | "import jax\n",
|
57 | 57 | "import jax.numpy as jnp\n",
|
|
142 | 142 | " RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n",
|
143 | 143 | ")\n",
|
144 | 144 | "\n",
|
145 |
| - "class MaxPool2d(nn.Module):\n", |
146 |
| - " kernel_size: tuple[int, int]\n", |
147 |
| - "\n", |
148 |
| - " @nn.compact\n", |
149 |
| - " def __call__(self, x):\n", |
150 |
| - " return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n", |
151 | 145 | "\n",
|
152 | 146 | "class ActorCriticInput(TypedDict):\n",
|
153 | 147 | " observation: jax.Array\n",
|
154 | 148 | " prev_action: jax.Array\n",
|
155 | 149 | " prev_reward: jax.Array\n",
|
156 | 150 | "\n",
|
| 151 | + "\n", |
157 | 152 | "class ActorCriticRNN(nn.Module):\n",
|
158 | 153 | " num_actions: int\n",
|
159 | 154 | " action_emb_dim: int = 16\n",
|
160 | 155 | " rnn_hidden_dim: int = 64\n",
|
161 | 156 | " rnn_num_layers: int = 1\n",
|
162 | 157 | " head_hidden_dim: int = 64\n",
|
| 158 | + " img_obs: bool = False\n", |
163 | 159 | "\n",
|
164 | 160 | " @nn.compact\n",
|
165 | 161 | " def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n",
|
166 | 162 | " B, S = inputs[\"observation\"].shape[:2]\n",
|
167 | 163 | " # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n",
|
168 |
| - " img_encoder = nn.Sequential(\n", |
169 |
| - " [\n", |
170 |
| - " nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
171 |
| - " nn.relu,\n", |
172 |
| - " MaxPool2d((2, 2)),\n", |
173 |
| - " nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
174 |
| - " nn.relu,\n", |
175 |
| - " nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
176 |
| - " nn.relu,\n", |
177 |
| - " ]\n", |
178 |
| - " )\n", |
| 164 | + " if self.img_obs:\n", |
| 165 | + " img_encoder = nn.Sequential(\n", |
| 166 | + " [\n", |
| 167 | + " nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
| 168 | + " nn.relu,\n", |
| 169 | + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
| 170 | + " nn.relu,\n", |
| 171 | + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
| 172 | + " nn.relu,\n", |
| 173 | + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
| 174 | + " ]\n", |
| 175 | + " )\n", |
| 176 | + " else:\n", |
| 177 | + " img_encoder = nn.Sequential(\n", |
| 178 | + " [\n", |
| 179 | + " nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
| 180 | + " nn.relu,\n", |
| 181 | + " nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
| 182 | + " nn.relu,\n", |
| 183 | + " nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", |
| 184 | + " nn.relu,\n", |
| 185 | + " ]\n", |
| 186 | + " )\n", |
179 | 187 | " action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n",
|
180 | 188 | "\n",
|
181 | 189 | " rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n",
|
|
389 | 397 | "@dataclass\n",
|
390 | 398 | "class TrainConfig:\n",
|
391 | 399 | " env_id: str = \"MiniGrid-Empty-6x6\"\n",
|
| 400 | + " benchmark_id: Optional[str] = None\n", |
| 401 | + " ruleset_id: Optional[int] = None\n", |
| 402 | + " img_obs: bool = False\n", |
392 | 403 | " # agent\n",
|
393 | 404 | " action_emb_dim: int = 16\n",
|
394 | 405 | " rnn_hidden_dim: int = 1024\n",
|
|
428 | 439 | " return config.lr * frac\n",
|
429 | 440 | "\n",
|
430 | 441 | " # setup environment\n",
|
431 |
| - " if \"XLand-MiniGrid\" in config.env_id:\n", |
432 |
| - " raise ValueError(\"Only single-task environments are supported.\")\n", |
433 |
| - "\n", |
434 | 442 | " env, env_params = xminigrid.make(config.env_id)\n",
|
435 | 443 | " env = GymAutoResetWrapper(env)\n",
|
436 | 444 | "\n",
|
| 445 | + " # for single-task XLand environments\n", |
| 446 | + " if config.benchmark_id is not None:\n", |
| 447 | + " assert \"XLand-MiniGrid\" in config.env_id, \"Benchmarks should be used only with XLand environments.\"\n", |
| 448 | + " assert config.ruleset_id is not None, \"Ruleset ID should be specified for benchmarks usage.\"\n", |
| 449 | + " benchmark = xminigrid.load_benchmark(config.benchmark_id)\n", |
| 450 | + " env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))\n", |
| 451 | + "\n", |
| 452 | + " # enabling image observations if needed\n", |
| 453 | + " if config.img_obs:\n", |
| 454 | + " from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n", |
| 455 | + "\n", |
| 456 | + " env = RGBImgObservationWrapper(env)\n", |
| 457 | + "\n", |
437 | 458 | " # setup training state\n",
|
438 | 459 | " rng = jax.random.PRNGKey(config.seed)\n",
|
439 | 460 | " rng, _rng = jax.random.split(rng)\n",
|
|
444 | 465 | " rnn_hidden_dim=config.rnn_hidden_dim,\n",
|
445 | 466 | " rnn_num_layers=config.rnn_num_layers,\n",
|
446 | 467 | " head_hidden_dim=config.head_hidden_dim,\n",
|
| 468 | + " img_obs=config.img_obs,\n", |
447 | 469 | " )\n",
|
448 | 470 | " # [batch_size, seq_len, ...]\n",
|
449 | 471 | " init_obs = {\n",
|
|
0 commit comments