Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed Aug 10, 2023
1 parent 6cffcfc commit 8edacbe
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 11 deletions.
2 changes: 1 addition & 1 deletion carl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import datetime

name = "CARL"
package_name = "carl-gym"
package_name = "carl-bench"
author = __author__

author_email = "benjamins@tnt.uni-hannover.de"
Expand Down
11 changes: 6 additions & 5 deletions carl/envs/gymnasium/classic_control/carl_acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def reset(
size=(2,),
)
self.env.unwrapped.state = np.concatenate([angles, velocities])
return (
np.array(
state = np.array(
[
np.cos(self.env.unwrapped.state[0]),
np.sin(self.env.unwrapped.state[0]),
Expand All @@ -96,6 +95,8 @@ def reset(
self.env.unwrapped.state[3],
],
dtype=np.float32,
),
{},
)
)
info = {}
state = self._add_context_to_state(state)
info["context_id"] = self.context_id
return state, info
6 changes: 5 additions & 1 deletion carl/envs/gymnasium/classic_control/carl_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,8 @@ def reset(
high=self.context["initial_state_upper"],
size=(4,),
)
return np.array(self.env.unwrapped.state, dtype=np.float32), {}
state = np.array(self.env.unwrapped.state, dtype=np.float32)
info = {}
state = self._add_context_to_state(state)
info["context_id"] = self.context_id
return state, info
6 changes: 5 additions & 1 deletion carl/envs/gymnasium/classic_control/carl_mountaincar.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,8 @@ def reset(
high=self.context["max_velocity_start"],
)
self.env.unwrapped.state = np.array([position, velocity])
return np.array(self.env.unwrapped.state, dtype=np.float32), {}
state = np.array(self.env.unwrapped.state, dtype=np.float32)
info = {}
state = self._add_context_to_state(state)
info["context_id"] = self.context_id
return state, info
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,8 @@ def reset(
high=self.context["max_velocity_start"],
)
self.env.unwrapped.state = np.array([position, velocity])
return np.array(self.env.unwrapped.state, dtype=np.float32), {}
state = np.array(self.env.unwrapped.state, dtype=np.float32)
info = {}
state = self._add_context_to_state(state)
info["context_id"] = self.context_id
return state, info
6 changes: 5 additions & 1 deletion carl/envs/gymnasium/classic_control/carl_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ def reset(
theta = self.env.np_random.uniform(high=self.context["initial_angle_max"])
thetadot = self.env.np_random.uniform(high=self.context["initial_velocity_max"])
self.env.unwrapped.state = np.array([theta, thetadot], dtype=np.float32)
return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32), {}
state = np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32)
info = {}
state = self._add_context_to_state(state)
info["context_id"] = self.context_id
return state, info
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
" _states = []\n",
" for i in tqdm(range(n_initial_states)):\n",
" s, _ = env.reset()\n",
" _states.append(s)\n",
" _states.append(s[\"obs\"])\n",
" _renders.append(env.render())\n",
" renders[env_cls.__name__] = _renders\n",
" states[env_cls.__name__] = np.array(_states)\n",
Expand Down

0 comments on commit 8edacbe

Please sign in to comment.