From 2a40c6154d8f1aa2801e6d0ca19add94c3a5f80f Mon Sep 17 00:00:00 2001 From: Jordy Thielen Date: Fri, 6 Dec 2024 15:01:42 +0100 Subject: [PATCH] fix encoding_length array --- CHANGELOG.md | 2 +- pyntbci/classifiers.py | 8 +++--- pyntbci/tests/classifiers.py | 20 ++++++++++++++ pyntbci/tests/utilities.py | 53 ++++++++++++++++++++++++++++++++---- pyntbci/utilities.py | 23 +++++++++++----- 5 files changed, 88 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2d4724..10d9d88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ ### Changed ### Fixed -- Fixed fit exception in `DistributionStopping` in `stopping` +- Fixed array `encoding_length` of `rCCA` in `classifiers` ## Version 1.8.0 (08-11-2024) diff --git a/pyntbci/classifiers.py b/pyntbci/classifiers.py index 7e15dac..bfe7cfe 100644 --- a/pyntbci/classifiers.py +++ b/pyntbci/classifiers.py @@ -862,9 +862,9 @@ def __init__( else: self.decoding_stride = decoding_stride if encoding_length is None: - self.encoding_length = 1 / fs + self.encoding_length = np.atleast_1d(1 / fs) else: - self.encoding_length = encoding_length + self.encoding_length = np.atleast_1d(encoding_length) if encoding_stride is None: self.encoding_stride = 1 / fs else: @@ -915,8 +915,8 @@ def _get_M( # Get encoding matrices E, self.events_ = event_matrix(stimulus, self.event, self.onset_event) - M = encoding_matrix(E, int(self.encoding_length * self.fs), int(self.encoding_stride * self.fs), amplitudes, - int(self.tmin * self.fs)) + M = encoding_matrix(E, (self.encoding_length * self.fs).astype("uint"), int(self.encoding_stride * self.fs), + amplitudes, int(self.tmin * self.fs)) M = M[:, :, :n_samples] return M diff --git a/pyntbci/tests/classifiers.py b/pyntbci/tests/classifiers.py index 94db2cc..35b3128 100644 --- a/pyntbci/tests/classifiers.py +++ b/pyntbci/tests/classifiers.py @@ -298,6 +298,26 @@ def test_rcca_shape(self): yh = rcca.predict(X) self.assertEqual((X.shape[0], ), yh.shape) + def test_rcca_encoding_length(self): + fs = 200 + encoding_length = [0.3, 0.1] + X = np.random.rand(111, 64, 2 * fs) + y = np.random.choice(5, 111) + V = np.random.rand(5, fs) > 0.5 + + rcca = pyntbci.classifiers.rCCA(stimulus=V, fs=fs, event="refe", encoding_length=encoding_length) + rcca.fit(X, y) + self.assertEqual((X.shape[1], 1), rcca.w_.shape) + self.assertEqual((int(sum(encoding_length) * fs), 1), rcca.r_.shape) + self.assertEqual((V.shape[0], 1, V.shape[1]), rcca.Ts_.shape) + self.assertEqual((V.shape[0], 1, V.shape[1]), rcca.Tw_.shape) + + z = rcca.decision_function(X) + self.assertEqual((X.shape[0], V.shape[0]), z.shape) + + yh = rcca.predict(X) + self.assertEqual((X.shape[0], ), yh.shape) + def test_rcca_score_metric(self): fs = 200 encoding_length = 0.3 diff --git a/pyntbci/tests/utilities.py b/pyntbci/tests/utilities.py index 69ade8b..a473816 100644 --- a/pyntbci/tests/utilities.py +++ b/pyntbci/tests/utilities.py @@ -135,15 +135,56 @@ def test_encoding_matrix_stride(self): self.assertEqual(M.shape[1], int(encoding_length / encoding_stride) * E.shape[1]) # response length(s) self.assertEqual(M.shape[2], S.shape[1]) # samples - def test_encoding_matrix_encoding_length_list(self): - encoding_length = [31, 41] - S = np.random.rand(17, 123) > 0.5 - E = pyntbci.utilities.event_matrix(stimulus=S, event="refe")[0] - M = pyntbci.utilities.encoding_matrix(stimulus=E, length=encoding_length) + def test_encoding_matrix_encoding_length(self): + S = np.tile(np.array([0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0]), reps=10).reshape([1, -1]) # 3 events + E, events = pyntbci.utilities.event_matrix(stimulus=S, event="dur") + + M = pyntbci.utilities.encoding_matrix(stimulus=E, length=31) + self.assertEqual(M.shape[0], S.shape[0]) # classes + self.assertEqual(M.shape[1], 3 * 31) # response length(s) + self.assertEqual(M.shape[2], S.shape[1]) # samples + + M = pyntbci.utilities.encoding_matrix(stimulus=E, length=[31]) + self.assertEqual(M.shape[0], S.shape[0]) # classes + self.assertEqual(M.shape[1], 3 * 31) # response length(s) + self.assertEqual(M.shape[2], S.shape[1]) # samples + + M = pyntbci.utilities.encoding_matrix(stimulus=E, length=(31,)) self.assertEqual(M.shape[0], S.shape[0]) # classes - self.assertEqual(M.shape[1], sum(encoding_length)) # response length(s) + self.assertEqual(M.shape[1], 3 * 31) # response length(s) self.assertEqual(M.shape[2], S.shape[1]) # samples + M = pyntbci.utilities.encoding_matrix(stimulus=E, length=np.array([31])) + self.assertEqual(M.shape[0], S.shape[0]) # classes + self.assertEqual(M.shape[1], 3 * 31) # response length(s) + self.assertEqual(M.shape[2], S.shape[1]) # samples + + M = pyntbci.utilities.encoding_matrix(stimulus=E, length=[31, 41, 51]) + self.assertEqual(M.shape[0], S.shape[0]) # classes + self.assertEqual(M.shape[1], sum([31, 41, 51])) # response length(s) + self.assertEqual(M.shape[2], S.shape[1]) # samples + + M = pyntbci.utilities.encoding_matrix(stimulus=E, length=(31, 41, 51)) + self.assertEqual(M.shape[0], S.shape[0]) # classes + self.assertEqual(M.shape[1], sum([31, 41, 51])) # response length(s) + self.assertEqual(M.shape[2], S.shape[1]) # samples + + M = pyntbci.utilities.encoding_matrix(stimulus=E, length=np.array([31, 41, 51])) + self.assertEqual(M.shape[0], S.shape[0]) # classes + self.assertEqual(M.shape[1], np.array([31, 41, 51]).sum()) # response length(s) + self.assertEqual(M.shape[2], S.shape[1]) # samples + + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, 31.0) # float + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, [31, 41]) # too few + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, [31, 41, 51, 61]) # too many + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, [31., 41., 51.]) # float + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, (31, 41)) # too few + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, (31, 41, 51, 61)) # too many + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, (31., 41., 51.)) # float + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, np.array([31, 41])) # too few + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, np.array([31, 41, 51, 61])) # too many + self.assertRaises(AssertionError, pyntbci.utilities.encoding_matrix, E, np.array([31., 41., 51.])) # float + class TestEventMatrix(unittest.TestCase): diff --git a/pyntbci/utilities.py b/pyntbci/utilities.py index 06bd485..2db4675 100644 --- a/pyntbci/utilities.py +++ b/pyntbci/utilities.py @@ -186,7 +186,7 @@ def decoding_matrix( def encoding_matrix( stimulus: np.array, - length: Union[int, list[int]], + length: Union[int, list[int], tuple[int,...], NDArray], stride: int = 1, amplitude: NDArray = None, tmin: float = 0, @@ -198,9 +198,10 @@ def encoding_matrix( ---------- stimulus: NDArray Stimulus matrix of shape (n_classes, n_events, n_samples). - length: int | list[int] - The length in samples of the temporal filter, i.e., the number of phase-shifted stimulus per event. If a list is - provided, it denotes the length per event. + length: int | list[int] | tuple[int] | NDArray + The length in samples of the temporal filter, i.e., the number of phase-shifted stimulus per event. If an array + is provided, it denotes the length per event. If one value is provided, it is assumed all event responses are of + the same length. stride: int (default: 1) The step size in samples over the length of the temporal filter. amplitude: NDArray (default: None) @@ -214,12 +215,20 @@ def encoding_matrix( """ n_classes, n_events, n_samples = stimulus.shape + assert (isinstance(length, int) or isinstance(length, list) or isinstance(length, tuple) or + isinstance(length, np.ndarray)), "encoding_length must be int, list[int], tuple[int], or np.ndarray()." if isinstance(length, int): length = n_events * [length] elif isinstance(length, list) or isinstance(length, tuple): - assert len(length) == n_events, "len(encoding_length) does not match S.shape[1]." - else: - raise Exception("encoding_length should be int, list[int], or tuple[int].") + if len(length) == 1: + length *= n_events + assert len(length) == n_events, "the number of events in encoding_length must match those in stimulus." + assert all([isinstance(value, int) for value in length]), "encoding_length must contain integer values." + elif isinstance(length, np.ndarray): + if length.size == 1: + length = np.repeat(length, n_events) + assert length.size == n_events, "the number of events in encoding_length must match those in stimulus." + assert np.issubdtype(length.dtype, np.integer), "encoding_length must contain integer values." # Create encoding window per event ematrix = []