Skip to content

Commit

Permalink
fix encoding_length array
Browse files Browse the repository at this point in the history
  • Loading branch information
Jordy Thielen committed Dec 6, 2024
1 parent d4a1212 commit 2a40c61
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 18 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
### Changed

### Fixed
- Fixed fit exception in `DistributionStopping` in `stopping`
- Fixed array `encoding_length` of `rCCA` in `classifiers`

## Version 1.8.0 (08-11-2024)

Expand Down
8 changes: 4 additions & 4 deletions pyntbci/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,9 +862,9 @@ def __init__(
else:
self.decoding_stride = decoding_stride
if encoding_length is None:
self.encoding_length = 1 / fs
self.encoding_length = np.atleast_1d(1 / fs)
else:
self.encoding_length = encoding_length
self.encoding_length = np.atleast_1d(encoding_length)
if encoding_stride is None:
self.encoding_stride = 1 / fs
else:
Expand Down Expand Up @@ -915,8 +915,8 @@ def _get_M(

# Get encoding matrices
E, self.events_ = event_matrix(stimulus, self.event, self.onset_event)
M = encoding_matrix(E, int(self.encoding_length * self.fs), int(self.encoding_stride * self.fs), amplitudes,
int(self.tmin * self.fs))
M = encoding_matrix(E, (self.encoding_length * self.fs).astype("uint"), int(self.encoding_stride * self.fs),
amplitudes, int(self.tmin * self.fs))
M = M[:, :, :n_samples]

return M
Expand Down
20 changes: 20 additions & 0 deletions pyntbci/tests/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,26 @@ def test_rcca_shape(self):
yh = rcca.predict(X)
self.assertEqual((X.shape[0], ), yh.shape)

def test_rcca_encoding_length(self):
fs = 200
encoding_length = [0.3, 0.1]
X = np.random.rand(111, 64, 2 * fs)
y = np.random.choice(5, 111)
V = np.random.rand(5, fs) > 0.5

rcca = pyntbci.classifiers.rCCA(stimulus=V, fs=fs, event="refe", encoding_length=encoding_length)
rcca.fit(X, y)
self.assertEqual((X.shape[1], 1), rcca.w_.shape)
self.assertEqual((int(sum(encoding_length) * fs), 1), rcca.r_.shape)
self.assertEqual((V.shape[0], 1, V.shape[1]), rcca.Ts_.shape)
self.assertEqual((V.shape[0], 1, V.shape[1]), rcca.Tw_.shape)

z = rcca.decision_function(X)
self.assertEqual((X.shape[0], V.shape[0]), z.shape)

yh = rcca.predict(X)
self.assertEqual((X.shape[0], ), yh.shape)

def test_rcca_score_metric(self):
fs = 200
encoding_length = 0.3
Expand Down
53 changes: 47 additions & 6 deletions pyntbci/tests/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,56 @@ def test_encoding_matrix_stride(self):
self.assertEqual(M.shape[1], int(encoding_length / encoding_stride) * E.shape[1]) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

def test_encoding_matrix_encoding_length_list(self):
encoding_length = [31, 41]
S = np.random.rand(17, 123) > 0.5
E = pyntbci.utilities.event_matrix(stimulus=S, event="refe")[0]
M = pyntbci.utilities.encoding_matrix(stimulus=E, length=encoding_length)
def test_encoding_matrix_encoding_length(self):
S = np.tile(np.array([0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0]), reps=10).reshape([1, -1]) # 3 events
E, events = pyntbci.utilities.event_matrix(stimulus=S, event="dur")

M = pyntbci.utilities.encoding_matrix(stimulus=E, length=31)
self.assertEqual(M.shape[0], S.shape[0]) # classes
self.assertEqual(M.shape[1], 3 * 31) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

M = pyntbci.utilities.encoding_matrix(stimulus=E, length=[31])
self.assertEqual(M.shape[0], S.shape[0]) # classes
self.assertEqual(M.shape[1], 3 * 31) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

M = pyntbci.utilities.encoding_matrix(stimulus=E, length=(31,))
self.assertEqual(M.shape[0], S.shape[0]) # classes
self.assertEqual(M.shape[1], sum(encoding_length)) # response length(s)
self.assertEqual(M.shape[1], 3 * 31) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

M = pyntbci.utilities.encoding_matrix(stimulus=E, length=np.array([31]))
self.assertEqual(M.shape[0], S.shape[0]) # classes
self.assertEqual(M.shape[1], 3 * 31) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

M = pyntbci.utilities.encoding_matrix(stimulus=E, length=[31, 41, 51])
self.assertEqual(M.shape[0], S.shape[0]) # classes
self.assertEqual(M.shape[1], sum([31, 41, 51])) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

M = pyntbci.utilities.encoding_matrix(stimulus=E, length=(31, 41, 51))
self.assertEqual(M.shape[0], S.shape[0]) # classes
self.assertEqual(M.shape[1], sum([31, 41, 51])) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

M = pyntbci.utilities.encoding_matrix(stimulus=E, length=np.array([31, 41, 51]))
self.assertEqual(M.shape[0], S.shape[0]) # classes
self.assertEqual(M.shape[1], np.array([31, 41, 51]).sum()) # response length(s)
self.assertEqual(M.shape[2], S.shape[1]) # samples

self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, 31.0) # float
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, [31, 41]) # too few
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, [31, 41, 51, 61]) # too many
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, [31., 41., 51.]) # float
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, (31, 41)) # too few
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, (31, 41, 51, 61)) # too many
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, (31., 41., 51.)) # float
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, np.array([31, 41])) # too few
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, np.array([31, 41, 51, 61])) # too many
self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, np.array([31., 41., 51.])) # float


class TestEventMatrix(unittest.TestCase):

Expand Down
23 changes: 16 additions & 7 deletions pyntbci/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def decoding_matrix(

def encoding_matrix(
stimulus: np.array,
length: Union[int, list[int]],
length: Union[int, list[int], tuple[int,...], NDArray],
stride: int = 1,
amplitude: NDArray = None,
tmin: float = 0,
Expand All @@ -198,9 +198,10 @@ def encoding_matrix(
----------
stimulus: NDArray
Stimulus matrix of shape (n_classes, n_events, n_samples).
length: int | list[int]
The length in samples of the temporal filter, i.e., the number of phase-shifted stimulus per event. If a list is
provided, it denotes the length per event.
length: int | list[int] | tuple[int] | NDArray
The length in samples of the temporal filter, i.e., the number of phase-shifted stimulus per event. If an array
is provided, it denotes the length per event. If one value is provided, it is assumed all event responses are of
the same length.
stride: int (default: 1)
The step size in samples over the length of the temporal filter.
amplitude: NDArray (default: None)
Expand All @@ -214,12 +215,20 @@ def encoding_matrix(
"""
n_classes, n_events, n_samples = stimulus.shape

assert (isinstance(length, int) or isinstance(length, list) or isinstance(length, tuple) or
isinstance(length, np.ndarray)), "encoding_length must be int, list[int], tuple[int], or np.ndarray()."
if isinstance(length, int):
length = n_events * [length]
elif isinstance(length, list) or isinstance(length, tuple):
assert len(length) == n_events, "len(encoding_length) does not match S.shape[1]."
else:
raise Exception("encoding_length should be int, list[int], or tuple[int].")
if len(length) == 1:
length *= n_events
assert len(length) == n_events, "the number of events in encoding_length must match those in stimulus."
assert all([isinstance(value, int) for value in length]), "encoding_length must contain integer values."
elif isinstance(length, np.ndarray):
if length.size == 1:
length = np.repeat(length, n_events)
assert length.size == n_events, "the number of events in encoding_length must match those in stimulus."
assert np.issubdtype(length.dtype, np.integer), "encoding_length must contain integer values."

# Create encoding window per event
ematrix = []
Expand Down

0 comments on commit 2a40c61

Please sign in to comment.