From 7d379334896961c53d4211f40dbb865d97cb9001 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:31:06 -0400 Subject: [PATCH] pre-commit --- metaworld/__init__.py | 34 +++++++++------------------------- metaworld/wrappers.py | 2 +- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/metaworld/__init__.py b/metaworld/__init__.py index 19d5d5e1..156de0f1 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -337,9 +337,11 @@ def init_each_env( env_cls: type[SawyerXYZEnv], name: str, seed: int | None ) -> gym.Env: env = env_cls() + if seed: + env.seed(seed) env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) - if terminate_on_success: - env = AutoTerminateOnSuccessWrapper(env) + env = AutoTerminateOnSuccessWrapper(env) + env.toggle_terminate_on_success(terminate_on_success) env = gym.wrappers.RecordEpisodeStatistics(env) if use_one_hot: assert env_id is not None, "Need to pass env_id through constructor" @@ -349,29 +351,9 @@ def init_each_env( env = RandomTaskSelectWrapper(env, tasks, seed=seed) return env - if "MT1-" in name: - name = name.replace("MT1-", "") - benchmark = MT1(name, seed=seed) - return init_each_env( - env_cls=benchmark.train_classes[name], name=name, seed=seed - ) - elif "ML1-" in name: - benchmark = ML1( - name.replace("ML1-train-" if "train" in name else "ML1-test-", ""), - seed=seed, - ) # type: ignore - if "train" in name: - return init_each_env( - env_cls=benchmark.train_classes[name.replace("ML1-train-", "")], - name=name + "-train", - seed=seed, - ) # type: ignore - elif "test" in name: - return init_each_env( - env_cls=benchmark.test_classes[name.replace("ML1-test-", "")], - name=name + "-test", - seed=seed, - ) + name = name.replace("MT1-", "") + benchmark = MT1(name, seed=seed) + return init_each_env(env_cls=benchmark.train_classes[name], name=name, seed=seed) make_single_mt = partial(_make_single_env, terminate_on_success=False) @@ -405,6 +387,8 @@ def _make_single_ml( def make_env(env_cls: type[SawyerXYZEnv], tasks: list) -> gym.Env: env = env_cls() + if seed: + env.seed(seed) env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) env = AutoTerminateOnSuccessWrapper(env) env.toggle_terminate_on_success(terminate_on_success) diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index d4372937..e7425ccb 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -46,7 +46,7 @@ def _set_random_task(self): def __init__( self, env: Env, - tasks: list[Task], + tasks: List[Task], sample_tasks_on_reset: bool = True, seed: int | None = None, ):