Skip to content

Commit

Permalink
added test for define instance
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 10, 2024
1 parent 8339e40 commit d76c80e
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,3 +2207,63 @@ def test_get_slice_public(start, end, expected_slice, expected_array, ts):
out_array = ts.t[out_slice]
assert out_slice == expected_slice
assert np.all(out_array == expected_array)


@pytest.mark.parametrize(
"kwargs",
[
{},
{"columns": [1, 2]},
{"metadata": {"banana": [3, 4]}},
{"load_array": False},
{
"columns": ["a", "b"],
"metadata": {"banana": [3, 4]},
"load_array": False
},
]
)
@pytest.mark.parametrize(
"tsd",
[
nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0,15)),
nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0,15)),
nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0,15)),
nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0,15)),
nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), columns=["a", "b"]),
nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), metadata={"pineapple": [1, 2]}),
nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), load_array=True),
nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15),load_array=True, columns=["a", "b"], metadata={"pineapple": [1, 2]}),
nap.TsdTensor(t=np.arange(10), d=np.zeros((10, 2, 3)), time_support=nap.IntervalSet(0,15)),
]
)
def test_define_instance(tsd, kwargs):
t = tsd.t
d = getattr(tsd, "d", None)
iset = tsd.time_support
cols = kwargs.get("columns", None)

# metadata index must be cols if provided.
# clear metadata if cols are provided to avoid errors
if (cols is not None) and ("metadata" not in kwargs):
kwargs["metadata"] = {}

out = tsd._define_instance(t, iset, data=d, **kwargs)

# check data
np.testing.assert_array_equal(out.t, t)
np.testing.assert_array_equal(out.time_support, iset)
if hasattr(tsd, "d"):
np.testing.assert_array_equal(out.d, d)

# if TsdFrame check kwargs
if isinstance(tsd, nap.TsdFrame):
for key in ["columns", "load_array"]:
val = kwargs.get(key, getattr(tsd, key))
assert np.all(val == getattr(out, key))
# get expected metadata
meta = kwargs.get("metadata", getattr(tsd, "metadata"))
for key, val, in meta.items():
assert np.all(out.metadata[key] == val)


0 comments on commit d76c80e

Please sign in to comment.