Skip to content

Commit

Permalink
add test_time_series
Browse files Browse the repository at this point in the history
  • Loading branch information
mtorabi59 committed Jul 17, 2024
1 parent 42dfe8f commit bfb675b
Showing 1 changed file with 256 additions and 0 deletions.
256 changes: 256 additions & 0 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import numpy as np
import pytest

from pydfc import TIME_SERIES


def test_create_time_series():
# create a random data with 100 regions and 1000 time points
data = np.random.rand(100, 1000)
# create a random locs with 100 regions and 3 coordinates
locs = np.random.rand(100, 3)
# create a random node_labels list with 100 regions
node_labels = [f"Region {i}" for i in range(100)]
time_series = TIME_SERIES(
data=data,
subj_id="sub-0001",
Fs=2,
locs=locs,
node_labels=node_labels,
)

assert time_series.data.shape == (100, 1000)
assert time_series.subj_id_lst == ["sub-0001"]
assert time_series.Fs == 2
assert time_series.n_time == 1000
assert time_series.n_regions == 100
assert np.all(time_series.nodes_lst == np.arange(0, time_series.n_regions, dtype=int))
assert np.all(time_series.time == 1 / 2 + np.arange(0, 1000 / 2, 1 / 2))
assert time_series.data_dict.keys() == {"sub-0001"}
assert np.all(time_series.data_dict["sub-0001"]["data"] == data)


def test_append_ts():
# create a random data with 100 regions and 1000 time points
data_1 = np.random.rand(100, 1000)
# create a random locs with 100 regions and 3 coordinates
locs = np.random.rand(100, 3)
# create a random node_labels list with 100 regions
node_labels = [f"Region {i}" for i in range(100)]
time_series = TIME_SERIES(
data=data_1,
subj_id="sub-0001",
Fs=2,
locs=locs,
node_labels=node_labels,
)

# create a random data with 100 regions and 1000 time points
data_2 = np.random.rand(100, 1000)
time_series.append_ts(
new_time_series=data_2,
subj_id="sub-0002",
)

assert time_series.data.shape == (100, 2000)
assert time_series.subj_id_lst == ["sub-0001", "sub-0002"]
assert time_series.Fs == 2
assert time_series.n_time is None
assert time_series.n_regions == 100
assert np.all(time_series.nodes_lst == np.arange(0, time_series.n_regions, dtype=int))
assert time_series.time is None
assert time_series.data_dict.keys() == {"sub-0001", "sub-0002"}
assert np.all(time_series.data_dict["sub-0001"]["data"] == data_1)
assert np.all(time_series.data_dict["sub-0002"]["data"] == data_2)
assert np.all(
time_series.data_dict["sub-0001"]["time_array"]
== 1 / 2 + np.arange(0, 1000 / 2, 1 / 2)
)
assert np.all(
time_series.data_dict["sub-0002"]["time_array"]
== 1 / 2 + np.arange(0, 1000 / 2, 1 / 2)
)

# check if visualization will raise warning correctly
with pytest.warns(
UserWarning, match="Multiple subjects are not supported in visualization."
):
time_series.visualize()


def test_concat_ts():
# create a random data with 100 regions and 1000 time points
data_1 = np.random.rand(100, 1000)
# create a random locs with 100 regions and 3 coordinates
locs = np.random.rand(100, 3)
# create a random node_labels list with 100 regions
node_labels = [f"Region {i}" for i in range(100)]
time_series = TIME_SERIES(
data=data_1,
subj_id="sub-0001",
Fs=2,
locs=locs,
node_labels=node_labels,
)

# create a random data with 100 regions and 1000 time points
data_2 = np.random.rand(100, 1000)
locs_2 = np.random.rand(100, 3)
node_labels_2 = [f"Parcel {i}" for i in range(100)]

time_series_2 = TIME_SERIES(
data=data_2,
subj_id="sub-0002",
Fs=1,
locs=locs,
node_labels=node_labels,
)

# check Fs mismatch assert error
with pytest.raises(AssertionError, match="Fs mismatch!"):
time_series.concat_ts(time_series_2)

time_series_2 = TIME_SERIES(
data=data_2,
subj_id="sub-0002",
Fs=2,
locs=locs_2,
node_labels=node_labels,
)
# check locs mismatch assert error
with pytest.raises(AssertionError, match="locs mismatch!"):
time_series.concat_ts(time_series_2)

time_series_2 = TIME_SERIES(
data=data_2,
subj_id="sub-0002",
Fs=2,
locs=locs,
node_labels=node_labels_2,
)
# check node_labels mismatch assert error
with pytest.raises(AssertionError, match="node_labels mismatch!"):
time_series.concat_ts(time_series_2)

time_series_2 = TIME_SERIES(
data=data_2,
subj_id="sub-0002",
Fs=2,
locs=locs,
node_labels=node_labels,
)
time_series.concat_ts(time_series_2)

assert time_series.data.shape == (100, 2000)
assert time_series.subj_id_lst == ["sub-0001", "sub-0002"]
assert time_series.Fs == 2
assert time_series.n_time is None
assert time_series.n_regions == 100
assert np.all(time_series.nodes_lst == np.arange(0, 100, dtype=int))
assert time_series.time is None
assert time_series.data_dict.keys() == {"sub-0001", "sub-0002"}
assert np.all(time_series.data_dict["sub-0001"]["data"] == data_1)
assert np.all(time_series.data_dict["sub-0002"]["data"] == data_2)
assert np.all(
time_series.data_dict["sub-0001"]["time_array"]
== 1 / 2 + np.arange(0, 1000 / 2, 1 / 2)
)
assert np.all(
time_series.data_dict["sub-0002"]["time_array"]
== 1 / 2 + np.arange(0, 1000 / 2, 1 / 2)
)


def test_get_subj_ts():
# create a random data with 100 regions and 1000 time points
data_1 = np.random.rand(100, 1000)
# create a random locs with 100 regions and 3 coordinates
locs = np.random.rand(100, 3)
# create a random node_labels list with 100 regions
node_labels = [f"Region {i}" for i in range(100)]
time_series = TIME_SERIES(
data=data_1,
subj_id="sub-0001",
Fs=2,
locs=locs,
node_labels=node_labels,
)
# create a random data with 100 regions and 1000 time points
data_2 = np.random.rand(100, 1000)
time_series_2 = TIME_SERIES(
data=data_2,
subj_id="sub-0002",
Fs=2,
locs=locs,
node_labels=node_labels,
)
time_series.concat_ts(time_series_2)

subj_ts = time_series.get_subj_ts(subjs_id="sub-0001")
assert subj_ts.data.shape == (100, 1000)
assert subj_ts.subj_id_lst == ["sub-0001"]
assert subj_ts.Fs == 2
assert subj_ts.n_time == 1000
assert subj_ts.n_regions == 100
assert np.all(subj_ts.nodes_lst == np.arange(0, 100, dtype=int))
assert np.all(subj_ts.time == 1 / 2 + np.arange(0, 1000 / 2, 1 / 2))
assert subj_ts.data_dict.keys() == {"sub-0001"}
assert np.all(subj_ts.data_dict["sub-0001"]["data"] == data_1)

subj_ts = time_series.get_subj_ts(subjs_id="sub-0002")
assert subj_ts.data.shape == (100, 1000)
assert subj_ts.subj_id_lst == ["sub-0002"]
assert subj_ts.Fs == 2
assert subj_ts.n_time == 1000
assert subj_ts.n_regions == 100
assert np.all(subj_ts.nodes_lst == np.arange(0, 100, dtype=int))
assert np.all(subj_ts.time == 1 / 2 + np.arange(0, 1000 / 2, 1 / 2))
assert subj_ts.data_dict.keys() == {"sub-0002"}
assert np.all(subj_ts.data_dict["sub-0002"]["data"] == data_2)

# check that the original time_series is not changed
assert time_series.data.shape == (100, 2000)
assert time_series.subj_id_lst == ["sub-0001", "sub-0002"]
assert time_series.Fs == 2
assert time_series.n_time is None
assert time_series.n_regions == 100
assert np.all(time_series.nodes_lst == np.arange(0, 100, dtype=int))
assert time_series.time is None
assert time_series.data_dict.keys() == {"sub-0001", "sub-0002"}
assert np.all(time_series.data_dict["sub-0001"]["data"] == data_1)
assert np.all(time_series.data_dict["sub-0002"]["data"] == data_2)
assert np.all(
time_series.data_dict["sub-0001"]["time_array"]
== 1 / 2 + np.arange(0, 1000 / 2, 1 / 2)
)
assert np.all(
time_series.data_dict["sub-0002"]["time_array"]
== 1 / 2 + np.arange(0, 1000 / 2, 1 / 2)
)


def test_select_nodes():
# create a random data with 100 regions and 1000 time points
data = np.random.rand(100, 1000)
# create a random locs with 100 regions and 3 coordinates
locs = np.random.rand(100, 3)
# create a random node_labels list with 100 regions
node_labels = [f"Region {i}" for i in range(100)]
time_series = TIME_SERIES(
data=data,
subj_id="sub-0001",
Fs=2,
locs=locs,
node_labels=node_labels,
)

# select 5 nodes
nodes_idx = np.arange(0, 100, 20, dtype=int)
time_series.select_nodes(nodes_idx=nodes_idx)
assert time_series.data.shape == (5, 1000)
assert np.all(time_series.data == data[nodes_idx, :])
assert time_series.n_regions == 5
assert np.all(time_series.nodes_lst == nodes_idx)
assert np.all(time_series.locs == locs[nodes_idx])
assert np.all(time_series.node_labels == [f"Region {i}" for i in nodes_idx])
assert np.all(time_series.nodes_selection_ == nodes_idx)

0 comments on commit bfb675b

Please sign in to comment.