diff --git a/src/imitation/data/serialize.py b/src/imitation/data/serialize.py index ee9d0d3f8..7db00c3c6 100644 --- a/src/imitation/data/serialize.py +++ b/src/imitation/data/serialize.py @@ -19,7 +19,7 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None: path: Trajectories are saved to this path. trajectories: The trajectories to save. """ - p = util.parse_path(path) + p = str(util.parse_path(path)) huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p) logging.info(f"Dumped demonstrations to {p}.") diff --git a/tests/data/conftest.py b/tests/data/conftest.py new file mode 100644 index 000000000..81903cee7 --- /dev/null +++ b/tests/data/conftest.py @@ -0,0 +1,62 @@ +import gymnasium as gym +import numpy as np +import pytest + +from imitation.data import types + +SPACES = [ + gym.spaces.Discrete(3), + gym.spaces.MultiDiscrete([3, 4]), + gym.spaces.Box(-1, 1, shape=(1,)), + gym.spaces.Box(-1, 1, shape=(2,)), + gym.spaces.Box(-np.inf, np.inf, shape=(2,)), +] +DICT_SPACE = gym.spaces.Dict( + {"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))}, +) +LENGTHS = [0, 1, 2, 10] + + +@pytest.fixture(params=SPACES) +def act_space(request): + return request.param + + +@pytest.fixture(params=SPACES + [DICT_SPACE]) +def obs_space(request): + return request.param + + +@pytest.fixture(params=LENGTHS) +def length(request): + return request.param + + +@pytest.fixture +def trajectory( + obs_space: gym.Space, + act_space: gym.Space, + length: int, +) -> types.Trajectory: + """Fixture to generate trajectory of length `length` iid sampled from spaces.""" + if length == 0: + pytest.skip() + + raw_obs = [obs_space.sample() for _ in range(length + 1)] + if isinstance(obs_space, gym.spaces.Dict): + obs: types.Observation = types.DictObs.from_obs_list(raw_obs) + else: + obs = np.array(raw_obs) + acts = np.array([act_space.sample() for _ in range(length)]) + infos = np.array([{f"key{i}": i} for i in range(length)]) + return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True) + + +@pytest.fixture +def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew: + """Like `trajectory` but with reward randomly sampled from a Gaussian.""" + rews = np.random.randn(len(trajectory)) + return types.TrajectoryWithRew( + **types.dataclass_quick_asdict(trajectory), + rews=rews, + ) diff --git a/tests/data/test_serialize.py b/tests/data/test_serialize.py new file mode 100644 index 000000000..1af6805c3 --- /dev/null +++ b/tests/data/test_serialize.py @@ -0,0 +1,63 @@ +"""Tests for `imitation.data.serialize`.""" + +import pathlib + +import gymnasium as gym +import numpy as np +import pytest + +from imitation.data import serialize, types +from imitation.data.types import DictObs + + +@pytest.fixture +def data_path(tmp_path): + return tmp_path / "data" + + +@pytest.mark.parametrize("path_type", [str, pathlib.Path]) +def test_save_trajectory(data_path, trajectory, path_type): + if isinstance(trajectory.obs, DictObs): + pytest.skip("serialize.save does not yet support DictObs") + + serialize.save(path_type(data_path), [trajectory]) + assert data_path.exists() + + +@pytest.mark.parametrize("path_type", [str, pathlib.Path]) +def test_save_trajectory_rew(data_path, trajectory_rew, path_type): + if isinstance(trajectory_rew.obs, DictObs): + pytest.skip("serialize.save does not yet support DictObs") + serialize.save(path_type(data_path), [trajectory_rew]) + assert data_path.exists() + + +def test_save_load_trajectory(data_path, trajectory): + if isinstance(trajectory.obs, DictObs): + pytest.skip("serialize.save does not yet support DictObs") + serialize.save(data_path, [trajectory]) + + reconstructed = list(serialize.load(data_path)) + reconstructedi = reconstructed[0] + + assert len(reconstructed) == 1 + assert np.allclose(reconstructedi.obs, trajectory.obs) + assert np.allclose(reconstructedi.acts, trajectory.acts) + assert np.allclose(reconstructedi.terminal, trajectory.terminal) + assert not hasattr(reconstructedi, "rews") + + +@pytest.mark.parametrize("load_fn", [serialize.load, serialize.load_with_rewards]) +def test_save_load_trajectory_rew(data_path, trajectory_rew, load_fn): + if isinstance(trajectory_rew.obs, DictObs): + pytest.skip("serialize.save does not yet support DictObs") + serialize.save(data_path, [trajectory_rew]) + + reconstructed = list(load_fn(data_path)) + reconstructedi = reconstructed[0] + + assert len(reconstructed) == 1 + assert np.allclose(reconstructedi.obs, trajectory_rew.obs) + assert np.allclose(reconstructedi.acts, trajectory_rew.acts) + assert np.allclose(reconstructedi.terminal, trajectory_rew.terminal) + assert np.allclose(reconstructedi.rews, trajectory_rew.rews) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index 74c658c26..fcb77e143 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -15,21 +15,6 @@ from imitation.data import serialize, types from imitation.util import util -SPACES = [ - gym.spaces.Discrete(3), - gym.spaces.MultiDiscrete([3, 4]), - gym.spaces.Box(-1, 1, shape=(1,)), - gym.spaces.Box(-1, 1, shape=(2,)), - gym.spaces.Box(-np.inf, np.inf, shape=(2,)), -] -DICT_SPACE = gym.spaces.Dict( - {"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))}, -) - -OBS_SPACES = SPACES + [DICT_SPACE] -ACT_SPACES = SPACES -LENGTHS = [0, 1, 2, 10] - def _check_1d_shape(fn: Callable[[np.ndarray], Any], length: int, expected_msg: str): for shape in [(), (length, 1), (length, 2), (length - 1,), (length + 1,)]: @@ -37,36 +22,6 @@ def _check_1d_shape(fn: Callable[[np.ndarray], Any], length: int, expected_msg: fn(np.zeros(shape)) -@pytest.fixture -def trajectory( - obs_space: gym.Space, - act_space: gym.Space, - length: int, -) -> types.Trajectory: - """Fixture to generate trajectory of length `length` iid sampled from spaces.""" - if length == 0: - pytest.skip() - - raw_obs = [obs_space.sample() for _ in range(length + 1)] - if isinstance(obs_space, gym.spaces.Dict): - obs: types.Observation = types.DictObs.from_obs_list(raw_obs) - else: - obs = np.array(raw_obs) - acts = np.array([act_space.sample() for _ in range(length)]) - infos = np.array([{f"key{i}": i} for i in range(length)]) - return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True) - - -@pytest.fixture -def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew: - """Like `trajectory` but with reward randomly sampled from a Gaussian.""" - rews = np.random.randn(len(trajectory)) - return types.TrajectoryWithRew( - **types.dataclass_quick_asdict(trajectory), - rews=rews, - ) - - @pytest.fixture def transitions_min( obs_space: gym.Space, @@ -134,9 +89,6 @@ def pushd(dir_path): os.chdir(orig_dir) -@pytest.mark.parametrize("obs_space", OBS_SPACES) -@pytest.mark.parametrize("act_space", ACT_SPACES) -@pytest.mark.parametrize("length", LENGTHS) class TestData: """Tests of imitation.util.data.