Skip to content

Commit

Permalink
adapted training scripts for img obs
Browse files (browse the repository at this point in the history)
  • Loading branch information
Howuhh committed Mar 5, 2024
1 parent 845eac9 commit 27c5c72
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 12 deletions.
39 changes: 27 additions & 12 deletions training/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,38 @@ class ActorCriticRNN(nn.Module):
rnn_hidden_dim: int = 64
rnn_num_layers: int = 1
head_hidden_dim: int = 64
img_obs: bool = False

@nn.compact
def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:
B, S = inputs["observation"].shape[:2]
# encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py
img_encoder = nn.Sequential(
[
nn.Conv(16, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
# use this only for image sizes >= 7
# MaxPool2d((2, 2)),
nn.Conv(32, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
]
)
if self.img_obs:
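# slightly modified Nature DQN CNN (Mnih et al., 2015): the second and third convs use
# strides 3 and 2 (instead of 2 and 1), and a fourth 2x2 conv without a trailing ReLU is added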
img_encoder = nn.Sequential(
[
nn.Conv(32, (8, 8), strides=4, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (4, 4), strides=3, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (2, 2), strides=1, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
]
)
else:
img_encoder = nn.Sequential(
[
nn.Conv(16, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
# use this only for image sizes >= 7
# MaxPool2d((2, 2)),
nn.Conv(32, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
]
)
action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)

rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)
Expand Down
8 changes: 8 additions & 0 deletions training/train_meta_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class TrainConfig:
name: str = "meta-task-ppo"
env_id: str = "XLand-MiniGrid-R1-9x9"
benchmark_id: str = "trivial-1m"
img_obs: bool = False
# agent
action_emb_dim: int = 16
rnn_hidden_dim: int = 1024
Expand Down Expand Up @@ -90,6 +91,12 @@ def linear_schedule(count):
env, env_params = xminigrid.make(config.env_id)
env = GymAutoResetWrapper(env)

# enabling image observations if needed
if config.img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# loading benchmark
benchmark = xminigrid.load_benchmark(config.benchmark_id)

Expand All @@ -103,6 +110,7 @@ def linear_schedule(count):
rnn_hidden_dim=config.rnn_hidden_dim,
rnn_num_layers=config.rnn_num_layers,
head_hidden_dim=config.head_hidden_dim,
img_obs=config.img_obs,
)
# [batch_size, seq_len, ...]
init_obs = {
Expand Down
8 changes: 8 additions & 0 deletions training/train_single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TrainConfig:
group: str = "default"
name: str = "single-task-ppo"
env_id: str = "MiniGrid-Empty-6x6"
img_obs: bool = False
# agent
action_emb_dim: int = 16
rnn_hidden_dim: int = 1024
Expand Down Expand Up @@ -74,6 +75,12 @@ def linear_schedule(count):
env, env_params = xminigrid.make(config.env_id)
env = GymAutoResetWrapper(env)

# enabling image observations if needed
if config.img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# setup training state
rng = jax.random.PRNGKey(config.seed)
rng, _rng = jax.random.split(rng)
Expand All @@ -84,6 +91,7 @@ def linear_schedule(count):
rnn_hidden_dim=config.rnn_hidden_dim,
rnn_num_layers=config.rnn_num_layers,
head_hidden_dim=config.head_hidden_dim,
img_obs=config.img_obs,
)
# [batch_size, seq_len, ...]
init_obs = {
Expand Down

0 comments on commit 27c5c72

Please sign in to comment.