Skip to content

Commit

Permalink
Merge pull request #29 from ryanharvey1/raw-module
Browse files Browse the repository at this point in the history
Raw module
  • Loading branch information
ryanharvey1 authored Nov 14, 2024
2 parents 9fea41e + f313902 commit 3b38a94
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 3 deletions.
2 changes: 2 additions & 0 deletions neuro_py/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ __all__ = [
"lfp",
"plotting",
"process",
"raw",
"session",
"spikes",
"stats",
Expand All @@ -22,6 +23,7 @@ from . import (
lfp,
plotting,
process,
raw,
session,
spikes,
stats,
Expand Down
4 changes: 1 addition & 3 deletions neuro_py/lfp/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import warnings
from typing import List, Tuple, Union
from typing import Tuple, Union

import nelpy as nel
import numpy as np
Expand Down
4 changes: 4 additions & 0 deletions neuro_py/raw/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import lazy_loader as _lazy

(__getattr__, __dir__, __all__) = _lazy.attach_stub(__name__, __file__)
del _lazy
3 changes: 3 additions & 0 deletions neuro_py/raw/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__all__ = ["zero_intervals_in_file"]

from .preprocessing import zero_intervals_in_file
75 changes: 75 additions & 0 deletions neuro_py/raw/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os
import warnings
from typing import List, Tuple

import numpy as np


def zero_intervals_in_file(
filepath: str,
n_channels: int,
zero_intervals: List[Tuple[int, int]],
precision: str = "int16",
) -> None:
"""
Zero out specified intervals in a binary file.
Parameters
----------
filepath : str
Path to the binary file.
n_channels : int
Number of channels in the file.
zero_intervals : List[Tuple[int, int]]
List of intervals (start, end) in sample indices to zero out.
precision : str, optional
Data precision, by default "int16".
Returns
-------
None
Examples
--------
>>> fs = 20_000
>>> zero_intervals_in_file(
>>> "U:\data\hpc_ctx_project\HP13\HP13_day12_20241112\HP13_day12_20241112.dat",
>>> n_channels=128,
>>> zero_intervals = (bad_intervals.data * fs).astype(int)
>>> )
"""
# Check if file exists
if not os.path.exists(filepath):
warnings.warn("File does not exist.")
return

# Open the file in memory-mapped mode for read/write
bytes_size = np.dtype(precision).itemsize
with open(filepath, "rb") as f:
startoffile = f.seek(0, 0)
endoffile = f.seek(0, 2)
n_samples = int((endoffile - startoffile) / n_channels / bytes_size)

# Map the file to memory in read-write mode
data = np.memmap(
filepath, dtype=precision, mode="r+", shape=(n_samples, n_channels)
)

# Zero out the specified intervals
zero_value = np.zeros((1, n_channels), dtype=precision)
for start, end in zero_intervals:
if 0 <= start < n_samples and 0 < end <= n_samples:
data[start:end, :] = zero_value
else:
warnings.warn(
f"Interval ({start}, {end}) is out of bounds and was skipped."
)

# Ensure changes are written to disk
data.flush()

# save log file with intervals zeroed out
log_file = filepath.replace(".dat", "_zeroed_intervals.log")
with open(log_file, "w") as f:
for start, end in zero_intervals:
f.write(f"{start} {end}\n")
51 changes: 51 additions & 0 deletions tests/test_zero_intervals_in_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import os
import tempfile
from neuro_py.raw.preprocessing import zero_intervals_in_file

def test_zero_intervals_in_file():
# Set up test parameters
n_channels = 4
n_samples = 100
precision = "int16"
zero_intervals = [(10, 20), (40, 50), (90, 100)]

# Create a temporary file to act as our binary data file
with tempfile.NamedTemporaryFile(delete=False, suffix=".dat") as tmpfile:
filepath = tmpfile.name

try:
# Generate some sample data and write it to the temporary file
original_data = (np.arange(n_samples * n_channels, dtype=precision)
.reshape(n_samples, n_channels))
original_data.tofile(filepath)

# Run the function
zero_intervals_in_file(filepath, n_channels, zero_intervals, precision)

# Load the file and check intervals are zeroed out
data = np.fromfile(filepath, dtype=precision).reshape(n_samples, n_channels)

for start, end in zero_intervals:
# Check that the specified intervals are zeroed
assert np.all(data[start:end, :] == 0), f"Interval ({start}, {end}) was not zeroed."

# Check that the other intervals are unchanged
for i in range(n_samples):
if not any(start <= i < end for start, end in zero_intervals):
expected_values = original_data[i, :]
assert np.array_equal(data[i, :], expected_values), f"Data outside intervals was altered at index {i}."

# Check if the log file was created and contains the correct intervals
log_filepath = filepath.replace(".dat", "_zeroed_intervals.log")
assert os.path.exists(log_filepath), "Log file was not created."

with open(log_filepath, "r") as log_file:
log_content = log_file.readlines()
for i, (start, end) in enumerate(zero_intervals):
assert log_content[i].strip() == f"{start} {end}", "Log file content is incorrect."

finally:
# Clean up the temporary file
os.remove(filepath)
os.remove(log_filepath)

0 comments on commit 3b38a94

Please sign in to comment.