Skip to content

Commit

Permalink
[MAINT] Update NumPy random generation (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Sep 17, 2024
1 parent 04a304c commit edb14eb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
8 changes: 4 additions & 4 deletions src/pyparrm/parrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _standardise_data(self) -> None:

def _optimise_period_estimate(self) -> None:
"""Optimise artefact period estimate."""
random_state = np.random.RandomState(self._random_seed)
random_state = np.random.default_rng(self._random_seed)

estimated_period = self._assumed_periods

Expand Down Expand Up @@ -329,7 +329,7 @@ def _get_centre_indices(
self,
use_n_samples: int,
ignore_portion: float,
random_state: np.random.RandomState,
random_state: np.random.Generator,
) -> np.ndarray:
"""Get indices for samples in the centre of the data segment.
Expand All @@ -341,7 +341,7 @@ def _get_centre_indices(
ignore_portion : float
Portion of the data segment to ignore when getting the indices.
random_state : numpy.random.RandomState
random_state : numpy.random.Generator
Random state object to use to generate numbers if the available number of
samples is less than that requested.
Expand All @@ -367,7 +367,7 @@ def _get_centre_indices(
)
return (
np.unique(
random_state.randint(
random_state.integers(
0, end_idx - start_idx, np.min((use_n_samples, end_idx - start_idx))
)
)
Expand Down
29 changes: 17 additions & 12 deletions tests/test_parrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from pyparrm.data import DATASETS
from pyparrm._utils._power import compute_psd


random = np.random.RandomState(44)
sampling_freq = 20 # Hz
artefact_freq = 10 # Hz

Expand All @@ -25,7 +23,8 @@
@pytest.mark.parametrize("n_jobs", [1, -1])
def test_parrm(n_chans: int, n_samples: int, verbose: bool, n_jobs: int):
"""Test that PARRM can run."""
data = random.rand(n_chans, n_samples)
random = np.random.default_rng(44)
data = random.standard_normal((n_chans, n_samples))

parrm = PARRM(
data=data,
Expand All @@ -43,7 +42,7 @@ def test_parrm(n_chans: int, n_samples: int, verbose: bool, n_jobs: int):
assert filtered_data.shape == data.shape
assert isinstance(filtered_data, np.ndarray)

other_data = random.rand(1, 50)
other_data = random.standard_normal((1, 50))
other_filtered_data = parrm.filter_data(other_data)
assert other_filtered_data.shape == other_data.shape

Expand All @@ -62,7 +61,8 @@ def test_parrm_attrs():
The returned attributes should simply be a copy of their private
counterparts.
"""
data = random.rand(1, 100)
random = np.random.default_rng(44)
data = random.standard_normal((1, 100))

parrm = PARRM(
data=data,
Expand Down Expand Up @@ -98,7 +98,8 @@ def test_parrm_attrs():

def test_parrm_wrong_type_inputs():
"""Test that inputs of wrong types to PARRM are caught."""
data = random.rand(1, 100)
random = np.random.default_rng(44)
data = random.standard_normal((1, 100))

# init object
with pytest.raises(TypeError, match="`data` must be a NumPy array."):
Expand Down Expand Up @@ -185,12 +186,13 @@ def test_parrm_wrong_type_inputs():

def test_parrm_wrong_value_inputs():
"""Test that inputs of wrong values to PARRM are caught."""
data = random.rand(1, 100)
random = np.random.default_rng(44)
data = random.standard_normal((1, 100))

# init object
with pytest.raises(ValueError, match="`data` must be a 2D array."):
PARRM(
data=random.rand(1, 1, 1),
data=random.standard_normal((1, 1, 1)),
sampling_freq=sampling_freq,
artefact_freq=artefact_freq,
)
Expand Down Expand Up @@ -281,13 +283,14 @@ def test_parrm_wrong_value_inputs():

# filter_data
with pytest.raises(ValueError, match="`data` must be a 2D array."):
parrm.filter_data(data=random.rand(100))
parrm.filter_data(data=random.standard_normal((100,)))


def test_parrm_premature_method_attribute_calls():
"""Test that errors raised for PARRM methods/attrs. called prematurely."""
random = np.random.default_rng(44)
parrm = PARRM(
data=random.rand(1, 100),
data=random.standard_normal((1, 100)),
sampling_freq=sampling_freq,
artefact_freq=artefact_freq,
verbose=False,
Expand Down Expand Up @@ -316,8 +319,9 @@ def test_parrm_premature_method_attribute_calls():

def test_parrm_missing_filter_inputs():
"""Test that PARRM can compute values for missing filter inputs."""
random = np.random.default_rng(44)
parrm = PARRM(
data=random.rand(1, 100),
data=random.standard_normal((1, 100)),
sampling_freq=sampling_freq,
artefact_freq=artefact_freq,
verbose=False,
Expand All @@ -333,7 +337,8 @@ def test_parrm_missing_filter_inputs():
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_compute_psd(n_chans: int, n_jobs: int):
"""Test that PSD computation runs."""
data = random.rand(n_chans, 100)
random = np.random.default_rng(44)
data = random.standard_normal((n_chans, 100))

n_freqs = 5
freqs, psd = compute_psd(
Expand Down

0 comments on commit edb14eb

Please sign in to comment.