Skip to content

Commit

Permalink
change intervals data type
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanharvey1 committed Nov 15, 2024
1 parent 824cb83 commit df50197
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions tests/test_remove_artifacts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import numpy as np
import os
import tempfile

import numpy as np
from neuro_py.raw.preprocessing import remove_artifacts


def test_remove_artifacts():
# Parameters
n_channels = 4
n_samples = 100
precision = "int16"
zero_intervals = [(20, 30), (50, 60)]
zero_intervals = np.array([[20, 30], [50, 60]])

# Create a temporary binary file
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
Expand All @@ -25,23 +26,33 @@ def test_remove_artifacts():

# Test mode "zeros"
remove_artifacts(filepath, n_channels, zero_intervals, precision, mode="zeros")
data = np.memmap(filepath, dtype=precision, mode="r", shape=(n_samples, n_channels))
data = np.memmap(
filepath, dtype=precision, mode="r", shape=(n_samples, n_channels)
)
for start, end in zero_intervals:
assert np.all(data[start:end, :] == 0)
del data # Close memmap

# Test mode "linear"
remove_artifacts(filepath, n_channels, zero_intervals, precision, mode="linear")
data = np.memmap(filepath, dtype=precision, mode="r", shape=(n_samples, n_channels))
data = np.memmap(
filepath, dtype=precision, mode="r", shape=(n_samples, n_channels)
)
for start, end in zero_intervals:
for ch in range(n_channels):
expected = np.linspace(original_data[start, ch], original_data[end, ch], end - start)
expected = np.linspace(
original_data[start, ch], original_data[end, ch], end - start
)
np.testing.assert_allclose(data[start:end, ch], expected, rtol=1e-5)
del data # Close memmap

# Test mode "gaussian"
remove_artifacts(filepath, n_channels, zero_intervals, precision, mode="gaussian")
data = np.memmap(filepath, dtype=precision, mode="r", shape=(n_samples, n_channels))
remove_artifacts(
filepath, n_channels, zero_intervals, precision, mode="gaussian"
)
data = np.memmap(
filepath, dtype=precision, mode="r", shape=(n_samples, n_channels)
)
for start, end in zero_intervals:
for ch in range(n_channels):
segment = data[start:end, ch]
Expand All @@ -53,4 +64,5 @@ def test_remove_artifacts():
del data # Close memmap
finally:
# Clean up temporary file
os.remove(filepath)

os.remove(filepath)

0 comments on commit df50197

Please sign in to comment.