-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from ryanharvey1/raw-module
Raw module
- Loading branch information
Showing
6 changed files
with
136 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import lazy_loader as _lazy | ||
|
||
(__getattr__, __dir__, __all__) = _lazy.attach_stub(__name__, __file__) | ||
del _lazy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
__all__ = ["zero_intervals_in_file"] | ||
|
||
from .preprocessing import zero_intervals_in_file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import os | ||
import warnings | ||
from typing import List, Tuple | ||
|
||
import numpy as np | ||
|
||
|
||
def zero_intervals_in_file( | ||
filepath: str, | ||
n_channels: int, | ||
zero_intervals: List[Tuple[int, int]], | ||
precision: str = "int16", | ||
) -> None: | ||
""" | ||
Zero out specified intervals in a binary file. | ||
Parameters | ||
---------- | ||
filepath : str | ||
Path to the binary file. | ||
n_channels : int | ||
Number of channels in the file. | ||
zero_intervals : List[Tuple[int, int]] | ||
List of intervals (start, end) in sample indices to zero out. | ||
precision : str, optional | ||
Data precision, by default "int16". | ||
Returns | ||
------- | ||
None | ||
Examples | ||
-------- | ||
>>> fs = 20_000 | ||
>>> zero_intervals_in_file( | ||
>>> "U:\data\hpc_ctx_project\HP13\HP13_day12_20241112\HP13_day12_20241112.dat", | ||
>>> n_channels=128, | ||
>>> zero_intervals = (bad_intervals.data * fs).astype(int) | ||
>>> ) | ||
""" | ||
# Check if file exists | ||
if not os.path.exists(filepath): | ||
warnings.warn("File does not exist.") | ||
return | ||
|
||
# Open the file in memory-mapped mode for read/write | ||
bytes_size = np.dtype(precision).itemsize | ||
with open(filepath, "rb") as f: | ||
startoffile = f.seek(0, 0) | ||
endoffile = f.seek(0, 2) | ||
n_samples = int((endoffile - startoffile) / n_channels / bytes_size) | ||
|
||
# Map the file to memory in read-write mode | ||
data = np.memmap( | ||
filepath, dtype=precision, mode="r+", shape=(n_samples, n_channels) | ||
) | ||
|
||
# Zero out the specified intervals | ||
zero_value = np.zeros((1, n_channels), dtype=precision) | ||
for start, end in zero_intervals: | ||
if 0 <= start < n_samples and 0 < end <= n_samples: | ||
data[start:end, :] = zero_value | ||
else: | ||
warnings.warn( | ||
f"Interval ({start}, {end}) is out of bounds and was skipped." | ||
) | ||
|
||
# Ensure changes are written to disk | ||
data.flush() | ||
|
||
# save log file with intervals zeroed out | ||
log_file = filepath.replace(".dat", "_zeroed_intervals.log") | ||
with open(log_file, "w") as f: | ||
for start, end in zero_intervals: | ||
f.write(f"{start} {end}\n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import numpy as np | ||
import os | ||
import tempfile | ||
from neuro_py.raw.preprocessing import zero_intervals_in_file | ||
|
||
def test_zero_intervals_in_file(): | ||
# Set up test parameters | ||
n_channels = 4 | ||
n_samples = 100 | ||
precision = "int16" | ||
zero_intervals = [(10, 20), (40, 50), (90, 100)] | ||
|
||
# Create a temporary file to act as our binary data file | ||
with tempfile.NamedTemporaryFile(delete=False, suffix=".dat") as tmpfile: | ||
filepath = tmpfile.name | ||
|
||
try: | ||
# Generate some sample data and write it to the temporary file | ||
original_data = (np.arange(n_samples * n_channels, dtype=precision) | ||
.reshape(n_samples, n_channels)) | ||
original_data.tofile(filepath) | ||
|
||
# Run the function | ||
zero_intervals_in_file(filepath, n_channels, zero_intervals, precision) | ||
|
||
# Load the file and check intervals are zeroed out | ||
data = np.fromfile(filepath, dtype=precision).reshape(n_samples, n_channels) | ||
|
||
for start, end in zero_intervals: | ||
# Check that the specified intervals are zeroed | ||
assert np.all(data[start:end, :] == 0), f"Interval ({start}, {end}) was not zeroed." | ||
|
||
# Check that the other intervals are unchanged | ||
for i in range(n_samples): | ||
if not any(start <= i < end for start, end in zero_intervals): | ||
expected_values = original_data[i, :] | ||
assert np.array_equal(data[i, :], expected_values), f"Data outside intervals was altered at index {i}." | ||
|
||
# Check if the log file was created and contains the correct intervals | ||
log_filepath = filepath.replace(".dat", "_zeroed_intervals.log") | ||
assert os.path.exists(log_filepath), "Log file was not created." | ||
|
||
with open(log_filepath, "r") as log_file: | ||
log_content = log_file.readlines() | ||
for i, (start, end) in enumerate(zero_intervals): | ||
assert log_content[i].strip() == f"{start} {end}", "Log file content is incorrect." | ||
|
||
finally: | ||
# Clean up the temporary file | ||
os.remove(filepath) | ||
os.remove(log_filepath) |