From edb14eb1154edb97c8bbb1035b18521fa7cf02ea Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Tue, 17 Sep 2024 18:14:11 +0200 Subject: [PATCH] [MAINT] Update NumPy random generation (#11) --- src/pyparrm/parrm.py | 8 ++++---- tests/test_parrm.py | 29 +++++++++++++++++------------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/pyparrm/parrm.py b/src/pyparrm/parrm.py index ff80e6f..fdeba21 100644 --- a/src/pyparrm/parrm.py +++ b/src/pyparrm/parrm.py @@ -282,7 +282,7 @@ def _standardise_data(self) -> None: def _optimise_period_estimate(self) -> None: """Optimise artefact period estimate.""" - random_state = np.random.RandomState(self._random_seed) + random_state = np.random.default_rng(self._random_seed) estimated_period = self._assumed_periods @@ -329,7 +329,7 @@ def _get_centre_indices( self, use_n_samples: int, ignore_portion: float, - random_state: np.random.RandomState, + random_state: np.random.Generator, ) -> np.ndarray: """Get indices for samples in the centre of the data segment. @@ -341,7 +341,7 @@ def _get_centre_indices( ignore_portion : float Portion of the data segment to ignore when getting the indices. - random_state : numpy.random.RandomState + random_state : numpy.random.Generator Random state object to use to generate numbers if the available number of samples is less than that requested. @@ -367,7 +367,7 @@ def _get_centre_indices( ) return ( np.unique( - random_state.randint( + random_state.integers( 0, end_idx - start_idx, np.min((use_n_samples, end_idx - start_idx)) ) ) diff --git a/tests/test_parrm.py b/tests/test_parrm.py index b97f7ab..473d84f 100644 --- a/tests/test_parrm.py +++ b/tests/test_parrm.py @@ -13,8 +13,6 @@ from pyparrm.data import DATASETS from pyparrm._utils._power import compute_psd - -random = np.random.RandomState(44) sampling_freq = 20 # Hz artefact_freq = 10 # Hz @@ -25,7 +23,8 @@ @pytest.mark.parametrize("n_jobs", [1, -1]) def test_parrm(n_chans: int, n_samples: int, verbose: bool, n_jobs: int): """Test that PARRM can run.""" - data = random.rand(n_chans, n_samples) + random = np.random.default_rng(44) + data = random.standard_normal((n_chans, n_samples)) parrm = PARRM( data=data, @@ -43,7 +42,7 @@ def test_parrm(n_chans: int, n_samples: int, verbose: bool, n_jobs: int): assert filtered_data.shape == data.shape assert isinstance(filtered_data, np.ndarray) - other_data = random.rand(1, 50) + other_data = random.standard_normal((1, 50)) other_filtered_data = parrm.filter_data(other_data) assert other_filtered_data.shape == other_data.shape @@ -62,7 +61,8 @@ def test_parrm_attrs(): The returned attributes should simply be a copy of their private counterparts. """ - data = random.rand(1, 100) + random = np.random.default_rng(44) + data = random.standard_normal((1, 100)) parrm = PARRM( data=data, @@ -98,7 +98,8 @@ def test_parrm_attrs(): def test_parrm_wrong_type_inputs(): """Test that inputs of wrong types to PARRM are caught.""" - data = random.rand(1, 100) + random = np.random.default_rng(44) + data = random.standard_normal((1, 100)) # init object with pytest.raises(TypeError, match="`data` must be a NumPy array."): @@ -185,12 +186,13 @@ def test_parrm_wrong_type_inputs(): def test_parrm_wrong_value_inputs(): """Test that inputs of wrong values to PARRM are caught.""" - data = random.rand(1, 100) + random = np.random.default_rng(44) + data = random.standard_normal((1, 100)) # init object with pytest.raises(ValueError, match="`data` must be a 2D array."): PARRM( - data=random.rand(1, 1, 1), + data=random.standard_normal((1, 1, 1)), sampling_freq=sampling_freq, artefact_freq=artefact_freq, ) @@ -281,13 +283,14 @@ def test_parrm_wrong_value_inputs(): # filter_data with pytest.raises(ValueError, match="`data` must be a 2D array."): - parrm.filter_data(data=random.rand(100)) + parrm.filter_data(data=random.standard_normal((100,))) def test_parrm_premature_method_attribute_calls(): """Test that errors raised for PARRM methods/attrs. called prematurely.""" + random = np.random.default_rng(44) parrm = PARRM( - data=random.rand(1, 100), + data=random.standard_normal((1, 100)), sampling_freq=sampling_freq, artefact_freq=artefact_freq, verbose=False, @@ -316,8 +319,9 @@ def test_parrm_premature_method_attribute_calls(): def test_parrm_missing_filter_inputs(): """Test that PARRM can compute values for missing filter inputs.""" + random = np.random.default_rng(44) parrm = PARRM( - data=random.rand(1, 100), + data=random.standard_normal((1, 100)), sampling_freq=sampling_freq, artefact_freq=artefact_freq, verbose=False, @@ -333,7 +337,8 @@ def test_parrm_missing_filter_inputs(): @pytest.mark.parametrize("n_jobs", [1, 2]) def test_compute_psd(n_chans: int, n_jobs: int): """Test that PSD computation runs.""" - data = random.rand(n_chans, 100) + random = np.random.default_rng(44) + data = random.standard_normal((n_chans, 100)) n_freqs = 5 freqs, psd = compute_psd(