Skip to content

Commit a61adac

Browse files
committed
Updated standalone PPO
1 parent 8127fe9 commit a61adac

File tree

2 files changed

+76
-40
lines changed

2 files changed

+76
-40
lines changed

examples/train_meta_standalone.ipynb

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -143,40 +143,48 @@
143143
" RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n",
144144
")\n",
145145
"\n",
146-
"class MaxPool2d(nn.Module):\n",
147-
" kernel_size: tuple[int, int]\n",
148-
"\n",
149-
" @nn.compact\n",
150-
" def __call__(self, x):\n",
151-
" return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n",
152146
"\n",
153147
"class ActorCriticInput(TypedDict):\n",
154148
" observation: jax.Array\n",
155149
" prev_action: jax.Array\n",
156150
" prev_reward: jax.Array\n",
157151
"\n",
152+
"\n",
158153
"class ActorCriticRNN(nn.Module):\n",
159154
" num_actions: int\n",
160155
" action_emb_dim: int = 16\n",
161156
" rnn_hidden_dim: int = 64\n",
162157
" rnn_num_layers: int = 1\n",
163158
" head_hidden_dim: int = 64\n",
159+
" img_obs: bool = False\n",
164160
"\n",
165161
" @nn.compact\n",
166162
" def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n",
167163
" B, S = inputs[\"observation\"].shape[:2]\n",
168164
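" # non-image encoder adapted from https://github.com/lcswillems/rl-starter-files/blob/master/model.py (max-pooling removed); image observations use a separate strided 3x3 conv encoder\n",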
169-
" img_encoder = nn.Sequential(\n",
170-
" [\n",
171-
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
172-
" nn.relu,\n",
173-
" MaxPool2d((2, 2)),\n",
174-
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
175-
" nn.relu,\n",
176-
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
177-
" nn.relu,\n",
178-
" ]\n",
179-
" )\n",
165+
" if self.img_obs:\n",
166+
" img_encoder = nn.Sequential(\n",
167+
" [\n",
168+
" nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
169+
" nn.relu,\n",
170+
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
171+
" nn.relu,\n",
172+
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
173+
" nn.relu,\n",
174+
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
175+
" ]\n",
176+
" )\n",
177+
" else:\n",
178+
" img_encoder = nn.Sequential(\n",
179+
" [\n",
180+
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
181+
" nn.relu,\n",
182+
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
183+
" nn.relu,\n",
184+
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
185+
" nn.relu,\n",
186+
" ]\n",
187+
" )\n",
180188
" action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n",
181189
"\n",
182190
" rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n",
@@ -294,8 +302,6 @@
294302
" value_loss = jnp.square(value - targets)\n",
295303
" value_loss_clipped = jnp.square(value_pred_clipped - targets)\n",
296304
" value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean()\n",
297-
" # TODO: ablate this!\n",
298-
" # value_loss = jnp.square(value - targets).mean()\n",
299305
"\n",
300306
" # CALCULATE ACTOR LOSS\n",
301307
" ratio = jnp.exp(log_prob - transitions.log_prob)\n",
@@ -391,6 +397,7 @@
391397
"class TrainConfig:\n",
392398
" env_id: str = \"XLand-MiniGrid-R1-8x8\"\n",
393399
" benchmark_id: str = \"trivial-1m\"\n",
400+
" img_obs: bool = False\n",
394401
" # agent\n",
395402
" action_emb_dim: int = 16\n",
396403
" rnn_hidden_dim: int = 64\n",
@@ -444,6 +451,12 @@
444451
" env, env_params = xminigrid.make(config.env_id)\n",
445452
" env = GymAutoResetWrapper(env)\n",
446453
"\n",
454+
" # enabling image observations if needed\n",
455+
" if config.img_obs:\n",
456+
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
457+
"\n",
458+
" env = RGBImgObservationWrapper(env)\n",
459+
"\n",
447460
" # loading benchmark\n",
448461
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
449462
"\n",
@@ -457,6 +470,7 @@
457470
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
458471
" rnn_num_layers=config.rnn_num_layers,\n",
459472
" head_hidden_dim=config.head_hidden_dim,\n",
473+
" img_obs=config.img_obs,\n",
460474
" )\n",
461475
" # [batch_size, seq_len, ...]\n",
462476
" init_obs = {\n",

examples/train_single_standalone.ipynb

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
"source": [
5252
"import time\n",
5353
"import math\n",
54-
"from typing import TypedDict\n",
54+
"from typing import TypedDict, Optional\n",
5555
"\n",
5656
"import jax\n",
5757
"import jax.numpy as jnp\n",
@@ -142,40 +142,48 @@
142142
" RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n",
143143
")\n",
144144
"\n",
145-
"class MaxPool2d(nn.Module):\n",
146-
" kernel_size: tuple[int, int]\n",
147-
"\n",
148-
" @nn.compact\n",
149-
" def __call__(self, x):\n",
150-
" return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n",
151145
"\n",
152146
"class ActorCriticInput(TypedDict):\n",
153147
" observation: jax.Array\n",
154148
" prev_action: jax.Array\n",
155149
" prev_reward: jax.Array\n",
156150
"\n",
151+
"\n",
157152
"class ActorCriticRNN(nn.Module):\n",
158153
" num_actions: int\n",
159154
" action_emb_dim: int = 16\n",
160155
" rnn_hidden_dim: int = 64\n",
161156
" rnn_num_layers: int = 1\n",
162157
" head_hidden_dim: int = 64\n",
158+
" img_obs: bool = False\n",
163159
"\n",
164160
" @nn.compact\n",
165161
" def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n",
166162
" B, S = inputs[\"observation\"].shape[:2]\n",
167163
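" # non-image encoder adapted from https://github.com/lcswillems/rl-starter-files/blob/master/model.py (max-pooling removed); image observations use a separate strided 3x3 conv encoder\n",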
168-
" img_encoder = nn.Sequential(\n",
169-
" [\n",
170-
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
171-
" nn.relu,\n",
172-
" MaxPool2d((2, 2)),\n",
173-
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
174-
" nn.relu,\n",
175-
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
176-
" nn.relu,\n",
177-
" ]\n",
178-
" )\n",
164+
" if self.img_obs:\n",
165+
" img_encoder = nn.Sequential(\n",
166+
" [\n",
167+
" nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
168+
" nn.relu,\n",
169+
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
170+
" nn.relu,\n",
171+
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
172+
" nn.relu,\n",
173+
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
174+
" ]\n",
175+
" )\n",
176+
" else:\n",
177+
" img_encoder = nn.Sequential(\n",
178+
" [\n",
179+
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
180+
" nn.relu,\n",
181+
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
182+
" nn.relu,\n",
183+
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
184+
" nn.relu,\n",
185+
" ]\n",
186+
" )\n",
179187
" action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n",
180188
"\n",
181189
" rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n",
@@ -389,6 +397,9 @@
389397
"@dataclass\n",
390398
"class TrainConfig:\n",
391399
" env_id: str = \"MiniGrid-Empty-6x6\"\n",
400+
" benchmark_id: Optional[str] = None\n",
401+
" ruleset_id: Optional[int] = None\n",
402+
" img_obs: bool = False\n",
392403
" # agent\n",
393404
" action_emb_dim: int = 16\n",
394405
" rnn_hidden_dim: int = 1024\n",
@@ -428,12 +439,22 @@
428439
" return config.lr * frac\n",
429440
"\n",
430441
" # setup environment\n",
431-
" if \"XLand-MiniGrid\" in config.env_id:\n",
432-
" raise ValueError(\"Only single-task environments are supported.\")\n",
433-
"\n",
434442
" env, env_params = xminigrid.make(config.env_id)\n",
435443
" env = GymAutoResetWrapper(env)\n",
436444
"\n",
445+
" # for single-task XLand environments\n",
446+
" if config.benchmark_id is not None:\n",
447+
" assert \"XLand-MiniGrid\" in config.env_id, \"Benchmarks should be used only with XLand environments.\"\n",
448+
" assert config.ruleset_id is not None, \"Ruleset ID should be specified for benchmarks usage.\"\n",
449+
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
450+
" env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))\n",
451+
"\n",
452+
" # enabling image observations if needed\n",
453+
" if config.img_obs:\n",
454+
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
455+
"\n",
456+
" env = RGBImgObservationWrapper(env)\n",
457+
"\n",
437458
" # setup training state\n",
438459
" rng = jax.random.PRNGKey(config.seed)\n",
439460
" rng, _rng = jax.random.split(rng)\n",
@@ -444,6 +465,7 @@
444465
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
445466
" rnn_num_layers=config.rnn_num_layers,\n",
446467
" head_hidden_dim=config.head_hidden_dim,\n",
468+
" img_obs=config.img_obs,\n",
447469
" )\n",
448470
" # [batch_size, seq_len, ...]\n",
449471
" init_obs = {\n",

0 commit comments

Comments
 (0)