diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..42ee230
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,134 @@
+# vscode
+.vscode/
+wandb/
+.DS_Store
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/examples/dueling_navix_empty.py b/examples/dueling_navix_empty.py
index 1808f13..ca8a411 100644
--- a/examples/dueling_navix_empty.py
+++ b/examples/dueling_navix_empty.py
@@ -30,7 +30,7 @@ def main(argv):
 
     # environment
     env = nx.environments.Room(5, 5, 100, observation_fn=nx.observations.categorical)
-    env = helx.environment.to_helx(env)  # type: ignore
+    env = helx.envs.interop.to_helx(env)
 
     # optimiser
     optimiser = optax.rmsprop(
diff --git a/helx/agents/dqn.py b/helx/agents/dqn.py
index 9bff670..8b5d3c2 100644
--- a/helx/agents/dqn.py
+++ b/helx/agents/dqn.py
@@ -29,7 +29,6 @@
 from helx.base.mdp import TERMINATION, Timestep
 from helx.base.memory import ReplayBuffer
 from helx.base.spaces import Discrete
-from helx.base import losses
 
 from .agent import Agent, HParams, Log, AgentState
 
diff --git a/helx/agents/dueling_dqn.py b/helx/agents/dueling_dqn.py
index c22a4ef..2db387d 100644
--- a/helx/agents/dueling_dqn.py
+++ b/helx/agents/dueling_dqn.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
-from functools import partial
 
 import jax.numpy as jnp
 import optax
@@ -36,6 +35,9 @@ class DuelingDQNState(DQNState):
 
 
 class DuelingDQN(DQN):
+    """Dueling DQN agent as described in https://arxiv.org/abs/1511.06581
+    Uses the average operator version to combine the advantage and value functions."""
+
     hparams: DuelingDQNHParams = struct.field(pytree_node=True)
     optimiser: optax.GradientTransformation = struct.field(pytree_node=True)
     critic: nn.Module = struct.field(pytree_node=True)
@@ -52,7 +54,10 @@ def create(
                 backbone,
                 Split(2),
                 Parallel((nn.Dense(1), nn.Dense(hparams.action_space.maximum))),  # v, A
-                Merge(partial(jnp.sum, axis=-1))  # q = v + A
+                Merge(
+                    lambda inputs: inputs[0]
+                    + (inputs[1] - jnp.mean(inputs[1], axis=-1, keepdims=True))
+                ),  # q = v + (A - mean(A))
             ]
         )
         return DuelingDQN(
diff --git a/helx/base/modules.py b/helx/base/modules.py
index aaf4a67..b338c4f 100644
--- a/helx/base/modules.py
+++ b/helx/base/modules.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
-from functools import partial
 
+from functools import partial
 from typing import Callable, Sequence, Tuple
 
-import flax.linen as nn
+from jax import Array
 import jax.numpy as jnp
 import jax.tree_util as jtu
-from jax import Array
+import flax.linen as nn
 
 
 class Split(nn.Module):
diff --git a/helx/envs/__init__.py b/helx/envs/__init__.py
index 0abeeb0..74454f9 100644
--- a/helx/envs/__init__.py
+++ b/helx/envs/__init__.py
@@ -24,5 +24,5 @@
     gymnasium,
     gymnax,
     interop,
-    # navix
+    navix
 )
diff --git a/requirements.txt b/requirements.txt
index 04bba32..78f535b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,4 +21,5 @@ minigrid
 procgen
 gymnax
 brax
+navix
 wandb
\ No newline at end of file