-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #52 from ryanharvey1/behavior-preprocessing
Behavior preprocessing
- Loading branch information
Showing
3 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import os | ||
from typing import Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from scipy.io import loadmat, savemat | ||
|
||
import neuro_py as npy | ||
|
||
|
||
def filter_tracker_jumps( | ||
beh_df: pd.DataFrame, max_speed: Union[int, float] = 100 | ||
) -> pd.DataFrame: | ||
""" | ||
Filter out tracker jumps (to NaN) in the behavior data. | ||
Parameters | ||
---------- | ||
beh_df : pd.DataFrame | ||
Behavior data with columns x, y, and ts. | ||
max_speed : Union[int,float], optional | ||
Maximum allowed speed in pixels per second. | ||
Returns | ||
------- | ||
pd.DataFrame | ||
Notes | ||
----- | ||
Will force dtypes of x and y to float64 | ||
""" | ||
|
||
# Calculate the Euclidean distance between consecutive points | ||
beh_df["dx"] = beh_df["x"].diff() | ||
beh_df["dy"] = beh_df["y"].diff() | ||
beh_df["distance"] = np.sqrt(beh_df["dx"] ** 2 + beh_df["dy"] ** 2) | ||
|
||
# Calculate the time difference between consecutive timestamps | ||
beh_df["dt"] = beh_df["ts"].diff() | ||
|
||
# Calculate the speed between consecutive points (distance / time) | ||
beh_df["speed"] = beh_df["distance"] / beh_df["dt"] | ||
|
||
# Identify the start of each jump | ||
# A jump starts when the speed exceeds the threshold, and the previous speed did not | ||
jump_starts = (beh_df["speed"] > max_speed) & ( | ||
beh_df["speed"].shift(1) <= max_speed | ||
) | ||
|
||
# Mark x and y as NaN only for the first frame of each jump | ||
beh_df.loc[jump_starts, ["x", "y"]] = np.nan | ||
|
||
beh_df = beh_df.drop(columns=["dx", "dy", "distance", "dt", "speed"]) | ||
|
||
return beh_df | ||
|
||
|
||
def filter_tracker_jumps_in_file( | ||
basepath: str, epoch_number=None, epoch_interval=None | ||
) -> None: | ||
""" | ||
Filter out tracker jumps in the behavior data (to NaN) and save the filtered data back to the file. | ||
Parameters | ||
---------- | ||
basepath : str | ||
Basepath to the behavior file. | ||
epoch_number : int, optional | ||
Epoch number to filter the behavior data to. | ||
epoch_interval : tuple, optional | ||
Epoch interval to filter the behavior data to. | ||
Returns | ||
------- | ||
None | ||
Examples | ||
-------- | ||
>>> basepath = "path/to/behavior/file" | ||
>>> filter_tracker_jumps_in_file(basepath, epoch_number=1) | ||
""" | ||
|
||
# Load the behavior data | ||
file = os.path.join(basepath, os.path.basename(basepath) + "animal.behavior.mat") | ||
|
||
behavior = loadmat(file, simplify_cells=True) | ||
|
||
# Filter the behavior data to remove tracker jumps | ||
if epoch_number is not None: | ||
epoch_df = npy.io.load_epoch(basepath) | ||
idx = ( | ||
behavior["behavior"]["timestamps"] > epoch_df.loc[epoch_number].startTime | ||
) & (behavior["behavior"]["timestamps"] < epoch_df.loc[epoch_number].stopTime) | ||
elif epoch_interval is not None: | ||
idx = (behavior["behavior"]["timestamps"] > epoch_interval[0]) & ( | ||
behavior["behavior"]["timestamps"] < epoch_interval[1] | ||
) | ||
else: | ||
# bool length of the same length as the number of timestamps | ||
idx = np.ones(len(behavior["behavior"]["timestamps"]), dtype=bool) | ||
|
||
# Filter the behavior data and add to dataframe | ||
x = behavior["behavior"]["position"]["x"][idx] | ||
y = behavior["behavior"]["position"]["y"][idx] | ||
ts = behavior["behavior"]["timestamps"][idx] | ||
beh_df = pd.DataFrame({"x": x, "y": y, "ts": ts}) | ||
|
||
# Filter out tracker jumps | ||
beh_df = filter_tracker_jumps(beh_df) | ||
|
||
# Save the filtered behavior data back to the file | ||
behavior["behavior"]["position"]["x"][idx] = beh_df.x.values | ||
behavior["behavior"]["position"]["y"][idx] = beh_df.y.values | ||
|
||
savemat(file, behavior, long_field_names=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import numpy as np | ||
import pandas as pd | ||
|
||
# Import the function to be tested | ||
from neuro_py.behavior.preprocessing import filter_tracker_jumps | ||
|
||
|
||
def test_filter_tracker_jumps(): | ||
""" | ||
Test the filter_tracker_jumps function. | ||
""" | ||
# Create a sample DataFrame with tracker jumps | ||
data = { | ||
"x": [0, 1, 2, 100, 3, 4], # Example x coordinates with a jump | ||
"y": [0, 1, 2, 100, 3, 4], # Example y coordinates with a jump | ||
"ts": [0, 1, 2, 3, 4, 5], # Example timestamps | ||
} | ||
beh_df = pd.DataFrame(data) | ||
|
||
# Expected output after filtering jumps | ||
expected_data = { | ||
"x": [0, 1, 2, np.nan, 3, 4], # Jump replaced with NaN | ||
"y": [0, 1, 2, np.nan, 3, 4], # Jump replaced with NaN | ||
"ts": [0, 1, 2, 3, 4, 5], # Timestamps remain unchanged | ||
} | ||
expected_df = pd.DataFrame(expected_data) | ||
|
||
# Call the function to filter jumps | ||
filtered_df = filter_tracker_jumps(beh_df, max_speed=100) | ||
|
||
# Check if the output matches the expected DataFrame | ||
pd.testing.assert_frame_equal(filtered_df, expected_df) | ||
|
||
def test_filter_tracker_jumps_multi_jumps(): | ||
""" | ||
Test the filter_tracker_jumps function. | ||
""" | ||
# Create a sample DataFrame with tracker jumps | ||
data = { | ||
"x": [0, 1, 2, 100, 3, 4, 100, 5], # Example x coordinates with a jump | ||
"y": [0, 1, 2, 100, 3, 4, 100, 5], # Example y coordinates with a jump | ||
"ts": [0, 1, 2, 3, 4, 5, 6, 7], # Example timestamps | ||
} | ||
beh_df = pd.DataFrame(data) | ||
|
||
# Expected output after filtering jumps | ||
expected_data = { | ||
"x": [0, 1, 2, np.nan, 3, 4, np.nan, 5], # Jump replaced with NaN | ||
"y": [0, 1, 2, np.nan, 3, 4, np.nan, 5], # Jump replaced with NaN | ||
"ts": [0, 1, 2, 3, 4, 5, 6, 7], # Timestamps remain unchanged | ||
} | ||
expected_df = pd.DataFrame(expected_data) | ||
|
||
# Call the function to filter jumps | ||
filtered_df = filter_tracker_jumps(beh_df, max_speed=100) | ||
|
||
# Check if the output matches the expected DataFrame | ||
pd.testing.assert_frame_equal(filtered_df, expected_df) | ||
|
||
def test_filter_tracker_jumps_no_jumps(): | ||
""" | ||
Test the filter_tracker_jumps function when there are no jumps. | ||
""" | ||
# Create a sample DataFrame without jumps | ||
data = { | ||
"x": [0, 1, 2, 3, 4], # Example x coordinates without jumps | ||
"y": [0, 1, 2, 3, 4], # Example y coordinates without jumps | ||
"ts": [0, 1, 2, 3, 4], # Example timestamps | ||
} | ||
beh_df = pd.DataFrame(data) | ||
|
||
# Expected output (no changes) | ||
expected_df = beh_df.copy() | ||
|
||
# Call the function to filter jumps | ||
filtered_df = filter_tracker_jumps(beh_df, max_speed=100) | ||
|
||
# Check if the output matches the expected DataFrame | ||
pd.testing.assert_frame_equal(filtered_df, expected_df, check_dtype=False) | ||
|
||
|
||
def test_filter_tracker_jumps_empty_input(): | ||
""" | ||
Test the filter_tracker_jumps function with an empty DataFrame. | ||
""" | ||
# Create an empty DataFrame | ||
beh_df = pd.DataFrame(columns=["x", "y", "ts"]) | ||
|
||
# Expected output (empty DataFrame) | ||
expected_df = beh_df.copy() | ||
|
||
# Call the function to filter jumps | ||
filtered_df = filter_tracker_jumps(beh_df, max_speed=100) | ||
|
||
# Check if the output matches the expected DataFrame | ||
pd.testing.assert_frame_equal(filtered_df, expected_df) |