diff --git a/.github/workflows/codestyle.yml b/.github/workflows/codestyle.yml index cb6d575..e216776 100644 --- a/.github/workflows/codestyle.yml +++ b/.github/workflows/codestyle.yml @@ -24,7 +24,7 @@ jobs: pip install -e ".[dev]" - name: check codestyle run: | - ruff --config pyproject.toml --diff . + ruff check --config pyproject.toml --diff . - name: check type hints run: | pyright --project=pyproject.toml src/xminigrid \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 057a42b..c916567 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: # pyright checking - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.350 + rev: v1.1.371 hooks: - id: pyright args: [--project=pyproject.toml] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 99dceb9..5d2396e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,11 +85,10 @@ ignore = [ [tool.ruff.format] skip-magic-trailing-comma = false -[tool.ruff.isort] +[tool.ruff.lint.isort] # see https://github.com/astral-sh/ruff/issues/8571 known-third-party = ["wandb"] - [tool.pyright] include = ["src/xminigrid"] exclude = [ diff --git a/src/xminigrid/environment.py b/src/xminigrid/environment.py index 59b7441..616207a 100644 --- a/src/xminigrid/environment.py +++ b/src/xminigrid/environment.py @@ -41,7 +41,7 @@ def default_params(self, **kwargs: Any) -> EnvParamsT: def num_actions(self, params: EnvParamsT) -> int: return int(NUM_ACTIONS) - def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]: + def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int] | dict[str, Any]: return params.view_size, params.view_size, NUM_LAYERS @abc.abstractmethod diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py index cc854b4..0f7ca85 100644 --- a/src/xminigrid/experimental/img_obs.py +++ b/src/xminigrid/experimental/img_obs.py @@ -82,14 +82,33 @@ def _render_obs(obs: jax.Array) -> jax.Array: class RGBImgObservationWrapper(Wrapper): def observation_shape(self, params): - return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3 + new_shape = (params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3) + + base_shape = self._env.observation_shape(params) + if isinstance(base_shape, dict): + assert "img" in base_shape + obs_shape = {**base_shape, **{"img": new_shape}} + else: + obs_shape = new_shape + + return obs_shape + + def __convert_obs(self, timestep): + if isinstance(timestep.observation, dict): + assert "img" in timestep.observation + rendered_obs = {**timestep.observation, **{"img": _render_obs(timestep.observation["img"])}} + else: + rendered_obs = _render_obs(timestep.observation) + + timestep = timestep.replace(observation=rendered_obs) + return timestep def reset(self, params, key): timestep = self._env.reset(params, key) - timestep = timestep.replace(observation=_render_obs(timestep.observation)) + timestep = self.__convert_obs(timestep) return timestep def step(self, params, timestep, action): timestep = self._env.step(params, timestep, action) - timestep = timestep.replace(observation=_render_obs(timestep.observation)) + timestep = self.__convert_obs(timestep) return timestep diff --git a/src/xminigrid/types.py b/src/xminigrid/types.py index bcfd27a..f4547d1 100644 --- a/src/xminigrid/types.py +++ b/src/xminigrid/types.py @@ -56,7 +56,7 @@ class TimeStep(struct.PyTreeNode, Generic[EnvCarryT]): step_type: StepType reward: jax.Array discount: jax.Array - observation: jax.Array + observation: jax.Array | dict[str, jax.Array] def first(self): return self.step_type == StepType.FIRST diff --git a/src/xminigrid/wrappers.py b/src/xminigrid/wrappers.py index e75557d..a5bf790 100644 --- a/src/xminigrid/wrappers.py +++ b/src/xminigrid/wrappers.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import jax from .environment import Environment, EnvParamsT @@ -19,7 +21,7 @@ def default_params(self, **kwargs) -> EnvParamsT: def num_actions(self, params: EnvParamsT) -> int: return self._env.num_actions(params) - def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]: + def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int] | dict[str, Any]: return self._env.observation_shape(params) def _generate_problem(self, params: EnvParamsT, key: jax.Array) -> State[EnvCarryT]: @@ -67,3 +69,102 @@ def step(self, params, timestep, action): lambda: self._env.step(params, timestep, action), ) return timestep + + +# Yes, these are a bit stupid, but a tmp workaround to not write an actual system for spaces. +# May be, in the future, I will port the entire API to some existing one, like functional Gymnasium. +# For now, faster to do this stuff with dicts instead... +# NB: if you do not want to use this (due to the dicts as obs), +# just get needed parts from the original TimeStep and State dataclasses +class DirectionObservationWrapper(Wrapper): + def observation_shape(self, params): + base_shape = self._env.observation_shape(params) + if isinstance(base_shape, dict): + assert "img" in base_shape + obs_shape = {**base_shape, **{"direction": 4}} + else: + obs_shape = { + "img": self._env.observation_shape(params), + "direction": 4, + } + return obs_shape + + def __extend_obs(self, timestep): + direction = jax.nn.one_hot(timestep.state.agent.direction, num_classes=4) + if isinstance(timestep.observation, dict): + assert "img" in timestep.observation + extended_obs = { + **timestep.observation, + **{"direction": direction}, + } + else: + extended_obs = { + "img": timestep.observation, + "direction": direction, + } + + timestep = timestep.replace(observation=extended_obs) + return timestep + + def reset(self, params, key): + timestep = self._env.reset(params, key) + timestep = self.__extend_obs(timestep) + return timestep + + def step(self, params, timestep, action): + timestep = self._env.step(params, timestep, action) + timestep = self.__extend_obs(timestep) + return timestep + + +class RulesAndGoalsObservationWrapper(Wrapper): + def observation_shape(self, params): + base_shape = self._env.observation_shape(params) + if isinstance(base_shape, dict): + assert "img" in base_shape + obs_shape = { + **base_shape, + **{ + "goal_encoding": params.ruleset.goal.shape, + "rule_encoding": params.ruleset.rules.shape, + }, + } + else: + obs_shape = { + "img": self._env.observation_shape(params), + "goal_encoding": params.ruleset.goal.shape, + "rule_encoding": params.ruleset.rules.shape, + } + return obs_shape + + def __extend_obs(self, timestep): + goal_encoding = timestep.state.goal_encoding + rule_encoding = timestep.state.rule_encoding + if isinstance(timestep.observation, dict): + assert "img" in timestep.observation + extended_obs = { + **timestep.observation, + **{ + "goal_encoding": goal_encoding, + "rule_encoding": rule_encoding, + }, + } + else: + extended_obs = { + "img": timestep.observation, + "goal_encoding": goal_encoding, + "rule_encoding": rule_encoding, + } + + timestep = timestep.replace(observation=extended_obs) + return timestep + + def reset(self, params, key): + timestep = self._env.reset(params, key) + timestep = self.__extend_obs(timestep) + return timestep + + def step(self, params, timestep, action): + timestep = self._env.step(params, timestep, action) + timestep = self.__extend_obs(timestep) + return timestep