Skip to content

Commit

Permalink
Type hints, docstrings & refactor (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
rainx0r authored May 1, 2024
1 parent 6846120 commit 83ac03c
Show file tree
Hide file tree
Showing 207 changed files with 4,182 additions and 5,223 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ repos:
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.6.1"
hooks:
- id: mypy
exclude: docs/
args: [--ignore-missing-imports]
additional_dependencies: [numpy==1.26.1]
# - repo: https://github.com/pycqa/pydocstyle
# rev: 6.3.0
# hooks:
Expand Down
235 changes: 149 additions & 86 deletions metaworld/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,37 @@
"""Proposal for a simple, understandable MetaWorld API."""
"""The public-facing Metaworld API."""

from __future__ import annotations

import abc
import pickle
from collections import OrderedDict
from typing import List, NamedTuple, Type
from typing import Any

import numpy as np
import numpy.typing as npt

import metaworld.envs.mujoco.env_dict as _env_dict

EnvName = str


class Task(NamedTuple):
"""All data necessary to describe a single MDP.
Should be passed into a MetaWorldEnv's set_task method.
"""

env_name: EnvName
data: bytes # Contains env parameters like random_init and *a* goal
from metaworld.types import Task


class MetaWorldEnv:
class MetaWorldEnv(abc.ABC):
"""Environment that requires a task before use.
Takes no arguments to its constructor, and raises an exception if used
before `set_task` is called.
"""

@abc.abstractmethod
def set_task(self, task: Task) -> None:
"""Set the task.
"""Sets the task.
Raises:
ValueError: If task.env_name is different from the current task.
Args:
task: The task to set.
Raises:
ValueError: If `task.env_name` is different from the current task.
"""
raise NotImplementedError


class Benchmark(abc.ABC):
Expand All @@ -43,132 +40,224 @@ class Benchmark(abc.ABC):
When used to evaluate an algorithm, only a single instance should be used.
"""

_train_classes: _env_dict.EnvDict
_test_classes: _env_dict.EnvDict
_train_tasks: list[Task]
_test_tasks: list[Task]

@abc.abstractmethod
def __init__(self):
pass

@property
def train_classes(self) -> "OrderedDict[EnvName, Type]":
"""Get all of the environment classes used for training."""
def train_classes(self) -> _env_dict.EnvDict:
"""Returns all of the environment classes used for training."""
return self._train_classes

@property
def test_classes(self) -> "OrderedDict[EnvName, Type]":
"""Get all of the environment classes used for testing."""
def test_classes(self) -> _env_dict.EnvDict:
"""Returns all of the environment classes used for testing."""
return self._test_classes

@property
def train_tasks(self) -> List[Task]:
"""Get all of the training tasks for this benchmark."""
def train_tasks(self) -> list[Task]:
"""Returns all of the training tasks for this benchmark."""
return self._train_tasks

@property
def test_tasks(self) -> List[Task]:
"""Get all of the test tasks for this benchmark."""
def test_tasks(self) -> list[Task]:
"""Returns all of the test tasks for this benchmark."""
return self._test_tasks


_ML_OVERRIDE = dict(partially_observable=True)
"""The overrides for the Meta-Learning benchmarks. Disables the inclusion of the goal position in the observation."""

_MT_OVERRIDE = dict(partially_observable=False)
"""The overrides for the Multi-Task benchmarks. Enables the inclusion of the goal position in the observation."""

_N_GOALS = 50
"""The number of goals to generate for each environment."""


def _encode_task(env_name, data) -> Task:
"""Instantiates a new `Task` object after pickling the data.
def _encode_task(env_name, data):
Args:
env_name: The name of the environment.
data: The task data (will be pickled).
Returns:
A `Task` object.
"""
return Task(env_name=env_name, data=pickle.dumps(data))


def _make_tasks(classes, args_kwargs, kwargs_override, seed=None):
def _make_tasks(
classes: _env_dict.EnvDict,
args_kwargs: _env_dict.EnvArgsKwargsDict,
kwargs_override: dict,
seed: int | None = None,
) -> list[Task]:
"""Initialises goals for a given set of environments.
Args:
classes: The environment classes as an `EnvDict`.
args_kwargs: The environment arguments and keyword arguments.
kwargs_override: Any kwarg overrides.
seed: The random seed to use.
Returns:
A flat list of `Task` objects, `_N_GOALS` for each environment in `classes`.
"""
# Cache existing random state
if seed is not None:
st0 = np.random.get_state()
np.random.seed(seed)

tasks = []
for env_name, args in args_kwargs.items():
kwargs = args["kwargs"].copy()
assert isinstance(kwargs, dict)
assert len(args["args"]) == 0

# Init env
env = classes[env_name]()
env._freeze_rand_vec = False
env._set_task_called = True
rand_vecs = []
kwargs = args["kwargs"].copy()
rand_vecs: list[npt.NDArray[Any]] = []

# Set task
del kwargs["task_id"]
env._set_task_inner(**kwargs)
for _ in range(_N_GOALS):

for _ in range(_N_GOALS): # Generate random goals
env.reset()
assert env._last_rand_vec is not None
rand_vecs.append(env._last_rand_vec)

unique_task_rand_vecs = np.unique(np.array(rand_vecs), axis=0)
assert unique_task_rand_vecs.shape[0] == _N_GOALS, unique_task_rand_vecs.shape[
0
]
assert (
unique_task_rand_vecs.shape[0] == _N_GOALS
), f"Only generated {unique_task_rand_vecs.shape[0]} unique goals, not {_N_GOALS}"
env.close()

# Create a task for each random goal
for rand_vec in rand_vecs:
kwargs = args["kwargs"].copy()
assert isinstance(kwargs, dict)
del kwargs["task_id"]

kwargs.update(dict(rand_vec=rand_vec, env_cls=classes[env_name]))
kwargs.update(kwargs_override)

tasks.append(_encode_task(env_name, kwargs))

del env

# Restore random state
if seed is not None:
np.random.set_state(st0)

return tasks


def _ml1_env_names():
tasks = list(_env_dict.ML1_V2["train"])
assert len(tasks) == 50
return tasks
# MT Benchmarks


class ML1(Benchmark):
ENV_NAMES = _ml1_env_names()
class MT1(Benchmark):
"""The MT1 benchmark. A goal-conditioned RL environment for a single Metaworld task."""

ENV_NAMES = list(_env_dict.ALL_V2_ENVIRONMENTS.keys())

def __init__(self, env_name, seed=None):
super().__init__()
if env_name not in _env_dict.ALL_V2_ENVIRONMENTS:
raise ValueError(f"{env_name} is not a V2 environment")
cls = _env_dict.ALL_V2_ENVIRONMENTS[env_name]
self._train_classes = OrderedDict([(env_name, cls)])
self._test_classes = self._train_classes
self._train_ = OrderedDict([(env_name, cls)])
self._test_classes = OrderedDict([(env_name, cls)])
args_kwargs = _env_dict.ML1_args_kwargs[env_name]

self._train_tasks = _make_tasks(
self._train_classes, {env_name: args_kwargs}, _ML_OVERRIDE, seed=seed
self._train_classes, {env_name: args_kwargs}, _MT_OVERRIDE, seed=seed
)
self._test_tasks = _make_tasks(
self._test_classes,
{env_name: args_kwargs},
_ML_OVERRIDE,
seed=(seed + 1 if seed is not None else seed),

self._test_tasks = []


class MT10(Benchmark):
"""The MT10 benchmark. Contains 10 tasks in its train set. Has an empty test set."""

def __init__(self, seed=None):
super().__init__()
self._train_classes = _env_dict.MT10_V2
self._test_classes = OrderedDict()
train_kwargs = _env_dict.MT10_V2_ARGS_KWARGS
self._train_tasks = _make_tasks(
self._train_classes, train_kwargs, _MT_OVERRIDE, seed=seed
)

self._test_tasks = []
self._test_classes = []


class MT50(Benchmark):
"""The MT50 benchmark. Contains all (50) tasks in its train set. Has an empty test set."""

def __init__(self, seed=None):
super().__init__()
self._train_classes = _env_dict.MT50_V2
self._test_classes = OrderedDict()
train_kwargs = _env_dict.MT50_V2_ARGS_KWARGS
self._train_tasks = _make_tasks(
self._train_classes, train_kwargs, _MT_OVERRIDE, seed=seed
)

self._test_tasks = []
self._test_classes = []


# ML Benchmarks

class MT1(Benchmark):
ENV_NAMES = _ml1_env_names()

class ML1(Benchmark):
"""The ML1 benchmark. A meta-RL environment for a single Metaworld task. The train and test set contain different goal positions.
The goal position is not part of the observation."""

ENV_NAMES = list(_env_dict.ALL_V2_ENVIRONMENTS.keys())

def __init__(self, env_name, seed=None):
super().__init__()
if env_name not in _env_dict.ALL_V2_ENVIRONMENTS:
raise ValueError(f"{env_name} is not a V2 environment")

cls = _env_dict.ALL_V2_ENVIRONMENTS[env_name]
self._train_classes = OrderedDict([(env_name, cls)])
self._test_classes = OrderedDict([(env_name, cls)])
self._test_classes = self._train_classes
args_kwargs = _env_dict.ML1_args_kwargs[env_name]

self._train_tasks = _make_tasks(
self._train_classes, {env_name: args_kwargs}, _MT_OVERRIDE, seed=seed
self._train_classes, {env_name: args_kwargs}, _ML_OVERRIDE, seed=seed
)
self._test_tasks = _make_tasks(
self._test_classes,
{env_name: args_kwargs},
_ML_OVERRIDE,
seed=(seed + 1 if seed is not None else seed),
)

self._test_tasks = []


class ML10(Benchmark):
"""The ML10 benchmark. Contains 10 tasks in its train set and 5 tasks in its test set. The goal position is not part of the observation."""

def __init__(self, seed=None):
super().__init__()
self._train_classes = _env_dict.ML10_V2["train"]
self._test_classes = _env_dict.ML10_V2["test"]
train_kwargs = _env_dict.ml10_train_args_kwargs
train_kwargs = _env_dict.ML10_ARGS_KWARGS["train"]

test_kwargs = _env_dict.ml10_test_args_kwargs
test_kwargs = _env_dict.ML10_ARGS_KWARGS["test"]
self._train_tasks = _make_tasks(
self._train_classes, train_kwargs, _ML_OVERRIDE, seed=seed
)
Expand All @@ -179,12 +268,14 @@ def __init__(self, seed=None):


class ML45(Benchmark):
"""The ML45 benchmark. Contains 45 tasks in its train set and 5 tasks in its test set (50 in total). The goal position is not part of the observation."""

def __init__(self, seed=None):
super().__init__()
self._train_classes = _env_dict.ML45_V2["train"]
self._test_classes = _env_dict.ML45_V2["test"]
train_kwargs = _env_dict.ml45_train_args_kwargs
test_kwargs = _env_dict.ml45_test_args_kwargs
train_kwargs = _env_dict.ML45_ARGS_KWARGS["train"]
test_kwargs = _env_dict.ML45_ARGS_KWARGS["test"]

self._train_tasks = _make_tasks(
self._train_classes, train_kwargs, _ML_OVERRIDE, seed=seed
Expand All @@ -194,32 +285,4 @@ def __init__(self, seed=None):
)


class MT10(Benchmark):
def __init__(self, seed=None):
super().__init__()
self._train_classes = _env_dict.MT10_V2
self._test_classes = OrderedDict()
train_kwargs = _env_dict.MT10_V2_ARGS_KWARGS
self._train_tasks = _make_tasks(
self._train_classes, train_kwargs, _MT_OVERRIDE, seed=seed
)

self._test_tasks = []
self._test_classes = []


class MT50(Benchmark):
def __init__(self, seed=None):
super().__init__()
self._train_classes = _env_dict.MT50_V2
self._test_classes = OrderedDict()
train_kwargs = _env_dict.MT50_V2_ARGS_KWARGS

self._train_tasks = _make_tasks(
self._train_classes, train_kwargs, _MT_OVERRIDE, seed=seed
)

self._test_tasks = []


__all__ = ["ML1", "MT1", "ML10", "MT10", "ML45", "MT50"]
Loading

0 comments on commit 83ac03c

Please sign in to comment.