Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MAINT] Update NumPy random generation #11

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/pyparrm/parrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _standardise_data(self) -> None:

def _optimise_period_estimate(self) -> None:
"""Optimise artefact period estimate."""
random_state = np.random.RandomState(self._random_seed)
random_state = np.random.default_rng(self._random_seed)

estimated_period = self._assumed_periods

Expand Down Expand Up @@ -329,7 +329,7 @@ def _get_centre_indices(
self,
use_n_samples: int,
ignore_portion: float,
random_state: np.random.RandomState,
random_state: np.random.Generator,
) -> np.ndarray:
"""Get indices for samples in the centre of the data segment.

Expand All @@ -341,7 +341,7 @@ def _get_centre_indices(
ignore_portion : float
Portion of the data segment to ignore when getting the indices.

random_state : numpy.random.RandomState
random_state : numpy.random.Generator
Random state object to use to generate numbers if the available number of
samples is less than that requested.

Expand All @@ -367,7 +367,7 @@ def _get_centre_indices(
)
return (
np.unique(
random_state.randint(
random_state.integers(
0, end_idx - start_idx, np.min((use_n_samples, end_idx - start_idx))
)
)
Expand Down
29 changes: 17 additions & 12 deletions tests/test_parrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from pyparrm.data import DATASETS
from pyparrm._utils._power import compute_psd


random = np.random.RandomState(44)
sampling_freq = 20 # Hz
artefact_freq = 10 # Hz

Expand All @@ -25,7 +23,8 @@
@pytest.mark.parametrize("n_jobs", [1, -1])
def test_parrm(n_chans: int, n_samples: int, verbose: bool, n_jobs: int):
"""Test that PARRM can run."""
data = random.rand(n_chans, n_samples)
random = np.random.default_rng(44)
data = random.standard_normal((n_chans, n_samples))

parrm = PARRM(
data=data,
Expand All @@ -43,7 +42,7 @@ def test_parrm(n_chans: int, n_samples: int, verbose: bool, n_jobs: int):
assert filtered_data.shape == data.shape
assert isinstance(filtered_data, np.ndarray)

other_data = random.rand(1, 50)
other_data = random.standard_normal((1, 50))
other_filtered_data = parrm.filter_data(other_data)
assert other_filtered_data.shape == other_data.shape

Expand All @@ -62,7 +61,8 @@ def test_parrm_attrs():
The returned attributes should simply be a copy of their private
counterparts.
"""
data = random.rand(1, 100)
random = np.random.default_rng(44)
data = random.standard_normal((1, 100))

parrm = PARRM(
data=data,
Expand Down Expand Up @@ -98,7 +98,8 @@ def test_parrm_attrs():

def test_parrm_wrong_type_inputs():
"""Test that inputs of wrong types to PARRM are caught."""
data = random.rand(1, 100)
random = np.random.default_rng(44)
data = random.standard_normal((1, 100))

# init object
with pytest.raises(TypeError, match="`data` must be a NumPy array."):
Expand Down Expand Up @@ -185,12 +186,13 @@ def test_parrm_wrong_type_inputs():

def test_parrm_wrong_value_inputs():
"""Test that inputs of wrong values to PARRM are caught."""
data = random.rand(1, 100)
random = np.random.default_rng(44)
data = random.standard_normal((1, 100))

# init object
with pytest.raises(ValueError, match="`data` must be a 2D array."):
PARRM(
data=random.rand(1, 1, 1),
data=random.standard_normal((1, 1, 1)),
sampling_freq=sampling_freq,
artefact_freq=artefact_freq,
)
Expand Down Expand Up @@ -281,13 +283,14 @@ def test_parrm_wrong_value_inputs():

# filter_data
with pytest.raises(ValueError, match="`data` must be a 2D array."):
parrm.filter_data(data=random.rand(100))
parrm.filter_data(data=random.standard_normal((100,)))


def test_parrm_premature_method_attribute_calls():
"""Test that errors raised for PARRM methods/attrs. called prematurely."""
random = np.random.default_rng(44)
parrm = PARRM(
data=random.rand(1, 100),
data=random.standard_normal((1, 100)),
sampling_freq=sampling_freq,
artefact_freq=artefact_freq,
verbose=False,
Expand Down Expand Up @@ -316,8 +319,9 @@ def test_parrm_premature_method_attribute_calls():

def test_parrm_missing_filter_inputs():
"""Test that PARRM can compute values for missing filter inputs."""
random = np.random.default_rng(44)
parrm = PARRM(
data=random.rand(1, 100),
data=random.standard_normal((1, 100)),
sampling_freq=sampling_freq,
artefact_freq=artefact_freq,
verbose=False,
Expand All @@ -333,7 +337,8 @@ def test_parrm_missing_filter_inputs():
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_compute_psd(n_chans: int, n_jobs: int):
"""Test that PSD computation runs."""
data = random.rand(n_chans, 100)
random = np.random.default_rng(44)
data = random.standard_normal((n_chans, 100))

n_freqs = 5
freqs, psd = compute_psd(
Expand Down