diff --git a/.gitignore b/.gitignore
index c2bcccb..0da1b4f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,4 +174,4 @@ g
 # will remove later
 scripts/*testing*
 configs/wip
-scripts/eval_example.py
\ No newline at end of file
+scripts/eval_example.py
diff --git a/configs/meta.yaml b/configs/meta.yaml
index 12795d8..0c2122b 100644
--- a/configs/meta.yaml
+++ b/configs/meta.yaml
@@ -6,7 +6,7 @@ num_envs: 16384
 num_steps_per_env: 12800
 num_steps_per_update: 256
 update_epochs: 1
-num_minibatches: 32
+num_minibatches: 16
 eval_num_envs: 16384
 eval_num_episodes: 25
 train_seed: 5
\ No newline at end of file
diff --git a/configs/single.yaml b/configs/single.yaml
index 11d45fa..e106d04 100644
--- a/configs/single.yaml
+++ b/configs/single.yaml
@@ -6,6 +6,6 @@ total_timesteps: 1_000_000_000
 num_envs: 16384
 num_steps: 256
 update_epochs: 1
-num_minibatches: 8
+num_minibatches: 16
 eval_episodes: 512
 train_seed: 5
\ No newline at end of file
diff --git a/src/xminigrid/__init__.py b/src/xminigrid/__init__.py
index f57ecbe..02a3601 100644
--- a/src/xminigrid/__init__.py
+++ b/src/xminigrid/__init__.py
@@ -2,11 +2,11 @@ from .registration import make, register, registered_environments

 # TODO: add __all__

-__version__ = "0.8.0"
+__version__ = "0.9.0"

 # ---------- XLand-MiniGrid environments ----------

-# WARN: TMP, only for FPS measurements
+# WARN: TMP, only for FPS measurements, will remove later
 # register(
 #     id="MiniGrid-1Rules",
 #     entry_point="xminigrid.envs.xland_tmp:XLandMiniGrid",
@@ -79,9 +79,6 @@
 #     width=64,
 # )

-
-# TODO: reconsider grid sizes and time limits after the benchmarks are generated.
-# Should be enough space for initial tiles even in the hardest setting
 register(
     id="XLand-MiniGrid-R1-9x9",
     entry_point="xminigrid.envs.xland:XLandMiniGrid",
diff --git a/training/eval.py b/training/eval.py
deleted file mode 100644
index 450bdb5..0000000
--- a/training/eval.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# example on how to restore checkpoints and use them for inference
-# TODO: add argparse arguments to load generic checkpoints and envs
-# TODO: move this to examples/
-import imageio
-import jax
-import jax.numpy as jnp
-import orbax.checkpoint
-import xminigrid
-from nn import ActorCriticRNN
-from xminigrid.rendering.text_render import print_ruleset
-from xminigrid.wrappers import GymAutoResetWrapper
-
-TOTAL_EPISODES = 10
-
-
-def main():
-    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-    checkpoint = orbax_checkpointer.restore("../xland-minigrid-data/checkpoints")
-    config = checkpoint["config"]
-    params = checkpoint["params"]
-
-    env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
-    env = GymAutoResetWrapper(env)
-
-    ruleset = xminigrid.load_benchmark("trivial-1m").get_ruleset(3)
-    env_params = env_params.replace(ruleset=ruleset)
-
-    model = ActorCriticRNN(
-        num_actions=env.num_actions(env_params),
-        action_emb_dim=config["action_emb_dim"],
-        rnn_hidden_dim=config["rnn_hidden_dim"],
-        rnn_num_layers=config["rnn_num_layers"],
-        head_hidden_dim=config["head_hidden_dim"],
-    )
-    # jitting all functions
-    apply_fn, reset_fn, step_fn = jax.jit(model.apply), jax.jit(env.reset), jax.jit(env.step)
-
-    # initial inputs
-    prev_reward = jnp.asarray(0)
-    prev_action = jnp.asarray(0)
-    hidden = model.initialize_carry(1)
-
-    # for logging
-    total_reward, num_episodes = 0, 0
-    rendered_imgs = []
-
-    rng = jax.random.key(0)
-    rng, _rng = jax.random.split(rng)
-
-    timestep = reset_fn(env_params, _rng)
-    rendered_imgs.append(env.render(env_params, timestep))
-    while num_episodes < TOTAL_EPISODES:
-        rng, _rng = jax.random.split(rng)
-        dist, value, hidden = apply_fn(
-            params,
-            {
-                "observation": timestep.observation[None, None, ...],
-                "prev_action": prev_action[None, None, ...],
-                "prev_reward": prev_reward[None, None, ...],
-            },
-            hidden,
-        )
-        action = dist.sample(seed=_rng).squeeze()
-
-        timestep = step_fn(env_params, timestep, action)
-        prev_action = action
-        prev_reward = timestep.reward
-
-        total_reward += timestep.reward.item()
-        num_episodes += int(timestep.last().item())
-
-        rendered_imgs.append(env.render(env_params, timestep))
-
-    print("Total reward:", total_reward)
-    print_ruleset(ruleset)
-    imageio.mimsave("rollout.mp4", rendered_imgs, fps=8, format="mp4")
-    # imageio.mimsave("rollout.gif", rendered_imgs, duration=1000 * 1 / 8, format="gif")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/training/nn.py b/training/nn.py
index 759c579..14144d2 100644
--- a/training/nn.py
+++ b/training/nn.py
@@ -1,29 +1,32 @@
 # Model adapted from minigrid baselines:
 # https://github.com/lcswillems/rl-starter-files/blob/master/model.py
 import math
-from typing import TypedDict
+from typing import Optional, TypedDict

 import distrax
 import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+from flax.linen.dtypes import promote_dtype
 from flax.linen.initializers import glorot_normal, orthogonal, zeros_init
-
-# from xminigrid.core.constants import NUM_COLORS, NUM_TILES
+from flax.typing import Dtype
+from xminigrid.core.constants import NUM_COLORS, NUM_TILES


 class GRU(nn.Module):
     hidden_dim: int
+    dtype: Optional[Dtype] = None
+    param_dtype: Dtype = jnp.float32

     @nn.compact
     def __call__(self, xs, init_state):
         seq_len, input_dim = xs.shape
         # this init might not be optimal, for example bias for reset gate should be -1 (for now ok)
-        Wi = self.param("Wi", glorot_normal(in_axis=1, out_axis=0), (self.hidden_dim * 3, input_dim))
-        Wh = self.param("Wh", orthogonal(column_axis=0), (self.hidden_dim * 3, self.hidden_dim))
-        bi = self.param("bi", zeros_init(), (self.hidden_dim * 3,))
-        bn = self.param("bn", zeros_init(), (self.hidden_dim,))
+        Wi = self.param("Wi", glorot_normal(in_axis=1, out_axis=0), (self.hidden_dim * 3, input_dim), self.param_dtype)
+        Wh = self.param("Wh", orthogonal(column_axis=0), (self.hidden_dim * 3, self.hidden_dim), self.param_dtype)
+        bi = self.param("bi", zeros_init(), (self.hidden_dim * 3,), self.param_dtype)
+        bn = self.param("bn", zeros_init(), (self.hidden_dim,), self.param_dtype)

         def _step_fn(h, x):
             igates = jnp.split(Wi @ x + bi, 3)
@@ -36,6 +39,9 @@ def _step_fn(h, x):

             return next_h, next_h

+        # cast to the computation dtype
+        xs, init_state, Wi, Wh, bi, bn = promote_dtype(xs, init_state, Wi, Wh, bi, bn, dtype=self.dtype)
+
         last_state, all_states = jax.lax.scan(_step_fn, init=init_state, xs=xs)
         return all_states, last_state

@@ -43,6 +49,8 @@
 class RNNModel(nn.Module):
     hidden_dim: int
     num_layers: int
+    dtype: Optional[Dtype] = None
+    param_dtype: Dtype = jnp.float32

     @nn.compact
     def __call__(self, xs, init_state):
@@ -50,7 +58,7 @@ def __call__(self, xs, init_state):
         # init_state: [num_layers, hidden_dim]
         outs, states = [], []
         for layer in range(self.num_layers):
-            xs, state = GRU(hidden_dim=self.hidden_dim)(xs, init_state[layer])
+            xs, state = GRU(self.hidden_dim, self.dtype, self.param_dtype)(xs, init_state[layer])
             outs.append(xs)
             states.append(state)

@@ -63,103 +71,174 @@ def __call__(self, xs, init_state):
         )


-class MaxPool2d(nn.Module):
-    kernel_size: tuple[int, int]
+class EmbeddingEncoder(nn.Module):
+    emb_dim: int = 16
+    dtype: Optional[Dtype] = None
+    param_dtype: Dtype = jnp.float32

     @nn.compact
-    def __call__(self, x):
-        return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding="VALID")
-
-
-# not used currently
-# class EmbeddingEncoder(nn.Module):
-#     emb_dim: int = 2
-#
-#     @nn.compact
-#     def __call__(self, img):
-#         entity_emb = nn.Embed(NUM_TILES, self.emb_dim)
-#         color_emb = nn.Embed(NUM_COLORS, self.emb_dim)
-#
-#         # [..., channels]
-#         img_emb = jnp.concatenate([
-#             entity_emb(img[..., 0]),
-#             color_emb(img[..., 1]),
-#         ], axis=-1)
-#         return img_emb
+    def __call__(self, img):
+        entity_emb = nn.Embed(NUM_TILES, self.emb_dim, self.dtype, self.param_dtype)
+        color_emb = nn.Embed(NUM_COLORS, self.emb_dim, self.dtype, self.param_dtype)
+
+        # [..., channels]
+        img_emb = jnp.concatenate(
+            [
+                entity_emb(img[..., 0]),
+                color_emb(img[..., 1]),
+            ],
+            axis=-1,
+        )
+        return img_emb


 class ActorCriticInput(TypedDict):
-    observation: jax.Array
+    obs_img: jax.Array
+    obs_dir: jax.Array
     prev_action: jax.Array
     prev_reward: jax.Array


 class ActorCriticRNN(nn.Module):
     num_actions: int
+    obs_emb_dim: int = 16
     action_emb_dim: int = 16
     rnn_hidden_dim: int = 64
     rnn_num_layers: int = 1
     head_hidden_dim: int = 64
     img_obs: bool = False
+    dtype: Optional[Dtype] = None
+    param_dtype: Dtype = jnp.float32

     @nn.compact
     def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:
-        B, S = inputs["observation"].shape[:2]
+        B, S = inputs["obs_img"].shape[:2]

+        # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py
         if self.img_obs:
             img_encoder = nn.Sequential(
                 [
-                    nn.Conv(16, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
+                    nn.Conv(
+                        16,
+                        (3, 3),
+                        strides=2,
+                        padding="VALID",
+                        kernel_init=orthogonal(math.sqrt(2)),
+                        dtype=self.dtype,
+                        param_dtype=self.param_dtype,
+                    ),
                     nn.relu,
-                    nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
+                    nn.Conv(
+                        32,
+                        (3, 3),
+                        strides=2,
+                        padding="VALID",
+                        kernel_init=orthogonal(math.sqrt(2)),
+                        dtype=self.dtype,
+                        param_dtype=self.param_dtype,
+                    ),
                     nn.relu,
-                    nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
+                    nn.Conv(
+                        32,
+                        (3, 3),
+                        strides=2,
+                        padding="VALID",
+                        kernel_init=orthogonal(math.sqrt(2)),
+                        dtype=self.dtype,
+                        param_dtype=self.param_dtype,
+                    ),
                     nn.relu,
-                    nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
+                    nn.Conv(
+                        32,
+                        (3, 3),
+                        strides=2,
+                        padding="VALID",
+                        kernel_init=orthogonal(math.sqrt(2)),
+                        dtype=self.dtype,
+                        param_dtype=self.param_dtype,
+                    ),
                 ]
             )
         else:
             img_encoder = nn.Sequential(
                 [
-                    nn.Conv(16, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
+                    # For small dims nn.Embed is extremely slow in bf16, so we leave everything in default dtypes
+                    EmbeddingEncoder(emb_dim=self.obs_emb_dim),
+                    nn.Conv(
+                        16,
+                        (2, 2),
+                        padding="VALID",
+                        kernel_init=orthogonal(math.sqrt(2)),
+                        dtype=self.dtype,
+                        param_dtype=self.param_dtype,
+                    ),
                     nn.relu,
-                    # use this only for image sizes >= 7
-                    # MaxPool2d((2, 2)),
-                    nn.Conv(32, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
+                    nn.Conv(
+                        32,
+                        (2, 2),
+                        padding="VALID",
+                        kernel_init=orthogonal(math.sqrt(2)),
+                        dtype=self.dtype,
+                        param_dtype=self.param_dtype,
+                    ),
                     nn.relu,
-                    nn.Conv(64, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
+                    nn.Conv(
+                        64,
+                        (2, 2),
+                        padding="VALID",
+                        kernel_init=orthogonal(math.sqrt(2)),
+                        dtype=self.dtype,
+                        param_dtype=self.param_dtype,
+                    ),
                     nn.relu,
                 ]
             )

         action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)
+        direction_encoder = nn.Dense(self.action_emb_dim, dtype=self.dtype, param_dtype=self.param_dtype)

-        rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)
+        rnn_core = BatchedRNNModel(
+            self.rnn_hidden_dim, self.rnn_num_layers, dtype=self.dtype, param_dtype=self.param_dtype
+        )
         actor = nn.Sequential(
             [
-                nn.Dense(self.head_hidden_dim, kernel_init=orthogonal(2)),
+                nn.Dense(
+                    self.head_hidden_dim, kernel_init=orthogonal(2), dtype=self.dtype, param_dtype=self.param_dtype
+                ),
                 nn.tanh,
-                nn.Dense(self.num_actions, kernel_init=orthogonal(0.01)),
+                nn.Dense(
+                    self.num_actions, kernel_init=orthogonal(0.01), dtype=self.dtype, param_dtype=self.param_dtype
+                ),
             ]
         )
         critic = nn.Sequential(
             [
-                nn.Dense(self.head_hidden_dim, kernel_init=orthogonal(2)),
+                nn.Dense(
+                    self.head_hidden_dim, kernel_init=orthogonal(2), dtype=self.dtype, param_dtype=self.param_dtype
+                ),
                 nn.tanh,
-                nn.Dense(1, kernel_init=orthogonal(1.0)),
+                nn.Dense(1, kernel_init=orthogonal(1.0), dtype=self.dtype, param_dtype=self.param_dtype),
             ]
         )

         # [batch_size, seq_len, ...]
-        obs_emb = img_encoder(inputs["observation"]).reshape(B, S, -1)
+        obs_emb = img_encoder(inputs["obs_img"].astype(jnp.int32)).reshape(B, S, -1)
+        dir_emb = direction_encoder(inputs["obs_dir"])
         act_emb = action_encoder(inputs["prev_action"])
-        # [batch_size, seq_len, hidden_dim + act_emb_dim + 1]
-        out = jnp.concatenate([obs_emb, act_emb, inputs["prev_reward"][..., None]], axis=-1)
+
+        # [batch_size, seq_len, hidden_dim + 2 * act_emb_dim + 1]
+        out = jnp.concatenate([obs_emb, dir_emb, act_emb, inputs["prev_reward"][..., None]], axis=-1)
+
+        # core networks
         out, new_hidden = rnn_core(out, hidden)
-        dist = distrax.Categorical(logits=actor(out))
+
+        # casting to full precision for the loss, as softmax/log_softmax
+        # (inside Categorical) is not stable in bf16
+        logits = actor(out).astype(jnp.float32)
+
+        dist = distrax.Categorical(logits=logits)
         values = critic(out)

         return dist, jnp.squeeze(values, axis=-1), new_hidden

     def initialize_carry(self, batch_size):
-        return jnp.zeros((batch_size, self.rnn_num_layers, self.rnn_hidden_dim))
+        return jnp.zeros((batch_size, self.rnn_num_layers, self.rnn_hidden_dim), dtype=self.dtype)
diff --git a/training/train_meta_task.py b/training/train_meta_task.py
index 1ee16fd..8525f2c 100644
--- a/training/train_meta_task.py
+++ b/training/train_meta_task.py
@@ -23,7 +23,7 @@
 from utils import Transition, calculate_gae, ppo_update_networks, rollout
 from xminigrid.benchmarks import Benchmark
 from xminigrid.environment import Environment, EnvParams
-from xminigrid.wrappers import GymAutoResetWrapper
+from xminigrid.wrappers import DirectionObservationWrapper, GymAutoResetWrapper

 # this will be default in new jax versions anyway
 jax.config.update("jax_threefry_partitionable", True)
@@ -38,11 +38,13 @@ class TrainConfig:
     benchmark_id: str = "trivial-1m"
     img_obs: bool = False
     # agent
+    obs_emb_dim: int = 16
     action_emb_dim: int = 16
     rnn_hidden_dim: int = 1024
     rnn_num_layers: int = 1
     head_hidden_dim: int = 256
     # training
+    enable_bf16: bool = False
     num_envs: int = 8192
     num_steps_per_env: int = 4096
     num_steps_per_update: int = 32
@@ -90,6 +92,7 @@ def linear_schedule(count):

     env, env_params = xminigrid.make(config.env_id)
     env = GymAutoResetWrapper(env)
+    env = DirectionObservationWrapper(env)

     # enabling image observations if needed
     if config.img_obs:
@@ -106,15 +109,20 @@ def linear_schedule(count):

     network = ActorCriticRNN(
         num_actions=env.num_actions(env_params),
+        obs_emb_dim=config.obs_emb_dim,
         action_emb_dim=config.action_emb_dim,
         rnn_hidden_dim=config.rnn_hidden_dim,
         rnn_num_layers=config.rnn_num_layers,
         head_hidden_dim=config.head_hidden_dim,
         img_obs=config.img_obs,
+        dtype=jnp.bfloat16 if config.enable_bf16 else None,
     )
     # [batch_size, seq_len, ...]
+    shapes = env.observation_shape(env_params)
+
     init_obs = {
-        "observation": jnp.zeros((config.num_envs_per_device, 1, *env.observation_shape(env_params))),
+        "obs_img": jnp.zeros((config.num_envs_per_device, 1, *shapes["img"])),
+        "obs_dir": jnp.zeros((config.num_envs_per_device, 1, shapes["direction"])),
         "prev_action": jnp.zeros((config.num_envs_per_device, 1), dtype=jnp.int32),
         "prev_reward": jnp.zeros((config.num_envs_per_device, 1)),
     }
@@ -142,6 +150,8 @@ def train(
         train_state: TrainState,
         init_hstate: jax.Array,
     ):
+        eval_hstate = init_hstate[0][None]
+
         # META TRAIN LOOP
         def _meta_step(meta_state, _):
             rng, train_state = meta_state
@@ -171,7 +181,8 @@ def _env_step(runner_state, _):
                     train_state.params,
                     {
                         # [batch_size, seq_len=1, ...]
-                        "observation": prev_timestep.observation[:, None],
+                        "obs_img": prev_timestep.observation["img"][:, None],
+                        "obs_dir": prev_timestep.observation["direction"][:, None],
                         "prev_action": prev_action[:, None],
                         "prev_reward": prev_reward[:, None],
                     },
@@ -190,7 +201,8 @@ def _env_step(runner_state, _):
                     value=value,
                     reward=timestep.reward,
                     log_prob=log_prob,
-                    obs=prev_timestep.observation,
+                    obs=prev_timestep.observation["img"],
+                    dir=prev_timestep.observation["direction"],
                     prev_action=prev_action,
                     prev_reward=prev_reward,
                 )
@@ -207,7 +219,8 @@ def _env_step(runner_state, _):
             _, last_val, _ = train_state.apply_fn(
                 train_state.params,
                 {
-                    "observation": timestep.observation[:, None],
+                    "obs_img": timestep.observation["img"][:, None],
+                    "obs_dir": timestep.observation["direction"][:, None],
                     "prev_action": prev_action[:, None],
                     "prev_reward": prev_reward[:, None],
                 },
@@ -281,8 +294,7 @@ def _update_minbatch(train_state, batch_info):
             env,
             eval_env_params,
             train_state,
-            # TODO: make this a static method?
-            jnp.zeros((1, config.rnn_num_layers, config.rnn_hidden_dim)),
+            eval_hstate,
             config.eval_num_episodes,
         )
         eval_stats = jax.lax.pmean(eval_stats, axis_name="devices")
diff --git a/training/train_single_task.py b/training/train_single_task.py
index 65252fa..54041ec 100644
--- a/training/train_single_task.py
+++ b/training/train_single_task.py
@@ -18,7 +18,7 @@
 from nn import ActorCriticRNN
 from utils import Transition, calculate_gae, ppo_update_networks, rollout
 from xminigrid.environment import Environment, EnvParams
-from xminigrid.wrappers import GymAutoResetWrapper
+from xminigrid.wrappers import DirectionObservationWrapper, GymAutoResetWrapper

 # this will be default in new jax versions anyway
 jax.config.update("jax_threefry_partitionable", True)
@@ -34,11 +34,13 @@ class TrainConfig:
     ruleset_id: Optional[int] = None
     img_obs: bool = False
     # agent
+    obs_emb_dim: int = 16
     action_emb_dim: int = 16
     rnn_hidden_dim: int = 1024
     rnn_num_layers: int = 1
     head_hidden_dim: int = 256
     # training
+    enable_bf16: bool = False
     num_envs: int = 8192
     num_steps: int = 16
     update_epochs: int = 1
@@ -74,6 +76,7 @@ def linear_schedule(count):
     # setup environment
     env, env_params = xminigrid.make(config.env_id)
     env = GymAutoResetWrapper(env)
+    env = DirectionObservationWrapper(env)

     # for single-task XLand environments
     if config.benchmark_id is not None:
@@ -94,15 +97,21 @@ def linear_schedule(count):

     network = ActorCriticRNN(
         num_actions=env.num_actions(env_params),
+        obs_emb_dim=config.obs_emb_dim,
         action_emb_dim=config.action_emb_dim,
         rnn_hidden_dim=config.rnn_hidden_dim,
         rnn_num_layers=config.rnn_num_layers,
         head_hidden_dim=config.head_hidden_dim,
         img_obs=config.img_obs,
+        dtype=jnp.bfloat16 if config.enable_bf16 else None,
     )
+    # [batch_size, seq_len, ...]
+    shapes = env.observation_shape(env_params)
+
     init_obs = {
-        "observation": jnp.zeros((config.num_envs_per_device, 1, *env.observation_shape(env_params))),
+        "obs_img": jnp.zeros((config.num_envs_per_device, 1, *shapes["img"])),
+        "obs_dir": jnp.zeros((config.num_envs_per_device, 1, shapes["direction"])),
         "prev_action": jnp.zeros((config.num_envs_per_device, 1), dtype=jnp.int32),
         "prev_reward": jnp.zeros((config.num_envs_per_device, 1)),
     }
@@ -149,7 +158,8 @@ def _env_step(runner_state, _):
                 train_state.params,
                 {
                     # [batch_size, seq_len=1, ...]
-                    "observation": prev_timestep.observation[:, None],
+                    "obs_img": prev_timestep.observation["img"][:, None],
+                    "obs_dir": prev_timestep.observation["direction"][:, None],
                     "prev_action": prev_action[:, None],
                     "prev_reward": prev_reward[:, None],
                 },
@@ -167,7 +177,8 @@ def _env_step(runner_state, _):
                 value=value,
                 reward=timestep.reward,
                 log_prob=log_prob,
-                obs=prev_timestep.observation,
+                obs=prev_timestep.observation["img"],
+                dir=prev_timestep.observation["direction"],
                 prev_action=prev_action,
                 prev_reward=prev_reward,
             )
@@ -184,7 +195,8 @@ def _env_step(runner_state, _):
         _, last_val, _ = train_state.apply_fn(
             train_state.params,
             {
-                "observation": timestep.observation[:, None],
+                "obs_img": timestep.observation["img"][:, None],
+                "obs_dir": timestep.observation["direction"][:, None],
                 "prev_action": prev_action[:, None],
                 "prev_reward": prev_reward[:, None],
             },
diff --git a/training/utils.py b/training/utils.py
index 9644054..4b59fb4 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -13,7 +13,9 @@ class Transition(struct.PyTreeNode):
     value: jax.Array
     reward: jax.Array
     log_prob: jax.Array
+    # for obs
     obs: jax.Array
+    dir: jax.Array
     # for rnn policy
     prev_action: jax.Array
     prev_reward: jax.Array
@@ -61,7 +63,8 @@ def _loss_fn(params):
             params,
             {
                 # [batch_size, seq_len, ...]
-                "observation": transitions.obs,
+                "obs_img": transitions.obs,
+                "obs_dir": transitions.dir,
                 "prev_action": transitions.prev_action,
                 "prev_reward": transitions.prev_reward,
             },
@@ -74,6 +77,7 @@ def _loss_fn(params):
         value_loss = jnp.square(value - targets)
         value_loss_clipped = jnp.square(value_pred_clipped - targets)
         value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean()
+        # TODO: ablate this!
         # value_loss = jnp.square(value - targets).mean()
@@ -126,7 +130,8 @@ def _body_fn(carry):
         dist, _, hstate = train_state.apply_fn(
             train_state.params,
             {
-                "observation": timestep.observation[None, None, ...],
+                "obs_img": timestep.observation["img"][None, None, ...],
+                "obs_dir": timestep.observation["direction"][None, None, ...],
                 "prev_action": prev_action[None, None, ...],
                 "prev_reward": prev_reward[None, None, ...],
            },