Skip to content

Commit

Permalink
Merge pull request #52 from ryanharvey1/behavior-preprocessing
Browse files Browse the repository at this point in the history
Behavior preprocessing
  • Loading branch information
ryanharvey1 authored Feb 3, 2025
2 parents e06bca3 + fe8686c commit 94e6b0d
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
3 changes: 3 additions & 0 deletions neuro_py/behavior/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ __all__ = [
"find_last_non_center_well",
"get_correct_inbound_outbound",
"score_inbound_outbound",
"filter_tracker_jumps",
"filter_tracker_jumps_in_file",
]

from .cheeseboard import plot_grid_with_circle_and_random_dots
Expand Down Expand Up @@ -52,3 +54,4 @@ from .well_traversal_classification import (
segment_path,
shift_well_enters,
)
from .preprocessing import filter_tracker_jumps, filter_tracker_jumps_in_file
115 changes: 115 additions & 0 deletions neuro_py/behavior/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os
from typing import Union

import numpy as np
import pandas as pd
from scipy.io import loadmat, savemat

import neuro_py as npy


def filter_tracker_jumps(
    beh_df: pd.DataFrame, max_speed: Union[int, float] = 100
) -> pd.DataFrame:
    """
    Filter out tracker jumps (to NaN) in the behavior data.

    The speed between consecutive samples is computed from the Euclidean
    distance between positions divided by the timestamp difference. The first
    sample of every run of samples whose speed exceeds ``max_speed`` has its
    ``x`` and ``y`` replaced by NaN.

    Parameters
    ----------
    beh_df : pd.DataFrame
        Behavior data with columns x, y, and ts. The frame is not modified;
        a filtered copy is returned.
    max_speed : Union[int,float], optional
        Maximum allowed speed, in position units per timestamp unit
        (e.g. pixels per second). Default is 100.

    Returns
    -------
    pd.DataFrame
        Copy of ``beh_df`` with the first frame of each jump set to NaN in
        the x and y columns. All other columns are returned unchanged.

    Notes
    -----
    When at least one jump is found, non-float x and y columns are converted
    to float64 so that they can hold NaN. Without jumps the dtypes are left
    as they are.

    A sample whose preceding speed is undefined (the second sample, or the
    sample after a gap of NaN positions) is treated as the start of a jump
    when its own speed exceeds ``max_speed``.
    """

    # Work on a copy so the caller's frame keeps its data and its columns
    filtered = beh_df.copy()

    # Speed between consecutive points (Euclidean distance / elapsed time),
    # kept in local Series rather than temporary columns on the frame
    distance = np.sqrt(filtered["x"].diff() ** 2 + filtered["y"].diff() ** 2)
    speed = distance / filtered["ts"].diff()

    # A jump starts when the speed exceeds the threshold and the previous
    # speed did not. Comparisons with NaN are False, so an undefined previous
    # speed counts as "did not exceed".
    too_fast = speed > max_speed
    previous_too_fast = too_fast.shift(1, fill_value=False)
    jump_starts = too_fast & ~previous_too_fast

    if jump_starts.any():
        # NaN cannot be stored in an integer column; convert explicitly
        # instead of relying on implicit upcasting during assignment
        for column in ("x", "y"):
            if not pd.api.types.is_float_dtype(filtered[column]):
                filtered[column] = filtered[column].astype("float64")

        # Mark x and y as NaN only for the first frame of each jump
        filtered.loc[jump_starts, ["x", "y"]] = np.nan

    return filtered


def filter_tracker_jumps_in_file(
    basepath: str, epoch_number=None, epoch_interval=None
) -> None:
    """
    Filter out tracker jumps in the behavior data (to NaN) and save the filtered data back to the file.

    The file ``<basename>.animal.behavior.mat`` inside ``basepath`` is loaded,
    the selected samples are passed through ``filter_tracker_jumps`` (with its
    default ``max_speed``), and the file is overwritten with the result.

    Parameters
    ----------
    basepath : str
        Path to the session folder containing the behavior file. The file
        name is derived from the last component of this path.
    epoch_number : int, optional
        Label of the row in the table returned by ``npy.io.load_epoch`` whose
        ``startTime``/``stopTime`` bound the samples to filter. Takes
        precedence over ``epoch_interval`` when both are given.
    epoch_interval : tuple, optional
        ``(start, stop)`` interval bounding the samples to filter. Only used
        when ``epoch_number`` is None.

    Returns
    -------
    None

    Notes
    -----
    Interval bounds are exclusive: samples whose timestamp equals the start
    or stop time are left untouched. When neither ``epoch_number`` nor
    ``epoch_interval`` is given, all samples are filtered.

    Examples
    --------
    >>> basepath = "path/to/behavior/file"
    >>> filter_tracker_jumps_in_file(basepath, epoch_number=1)
    """

    # Load the behavior data; the file is named "<basename>.animal.behavior.mat"
    file = os.path.join(basepath, os.path.basename(basepath) + ".animal.behavior.mat")

    behavior = loadmat(file, simplify_cells=True)

    # References into the loaded structure, so edits below modify `behavior`
    timestamps = behavior["behavior"]["timestamps"]
    position = behavior["behavior"]["position"]

    # Select the samples to filter
    if epoch_number is not None:
        epoch = npy.io.load_epoch(basepath).loc[epoch_number]
        idx = (timestamps > epoch.startTime) & (timestamps < epoch.stopTime)
    elif epoch_interval is not None:
        idx = (timestamps > epoch_interval[0]) & (timestamps < epoch_interval[1])
    else:
        # boolean mask of the same length as the number of timestamps
        idx = np.ones(len(timestamps), dtype=bool)

    # Gather the selected samples into a dataframe
    beh_df = pd.DataFrame(
        {"x": position["x"][idx], "y": position["y"][idx], "ts": timestamps[idx]}
    )

    # Filter out tracker jumps
    beh_df = filter_tracker_jumps(beh_df)

    # Write the filtered positions back into the loaded structure
    position["x"][idx] = beh_df.x.values
    position["y"][idx] = beh_df.y.values

    # loadmat adds "__header__"/"__version__"/"__globals__" entries; savemat
    # does not write names starting with an underscore, so drop them up front
    contents = {
        name: value for name, value in behavior.items() if not name.startswith("__")
    }

    savemat(file, contents, long_field_names=True)
96 changes: 96 additions & 0 deletions tests/test_behavior_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import numpy as np
import pandas as pd

# Import the function to be tested
from neuro_py.behavior.preprocessing import filter_tracker_jumps


def test_filter_tracker_jumps():
    """
    A single tracker jump is replaced by NaN in x and y; timestamps are kept.
    """
    timestamps = [0, 1, 2, 3, 4, 5]
    raw_coords = [0, 1, 2, 100, 3, 4]  # sample 3 is the tracker jump
    cleaned_coords = [0, 1, 2, np.nan, 3, 4]  # jump replaced with NaN

    tracking = pd.DataFrame({"x": raw_coords, "y": raw_coords, "ts": timestamps})
    expected = pd.DataFrame(
        {"x": cleaned_coords, "y": cleaned_coords, "ts": timestamps}
    )

    result = filter_tracker_jumps(tracking, max_speed=100)

    # The filtered frame must match the expectation exactly (dtypes included)
    pd.testing.assert_frame_equal(result, expected)

def test_filter_tracker_jumps_multi_jumps():
    """
    Two separate tracker jumps are each replaced by NaN in x and y.
    """
    timestamps = [0, 1, 2, 3, 4, 5, 6, 7]
    raw_coords = [0, 1, 2, 100, 3, 4, 100, 5]  # samples 3 and 6 are jumps
    cleaned_coords = [0, 1, 2, np.nan, 3, 4, np.nan, 5]  # both jumps set to NaN

    tracking = pd.DataFrame({"x": raw_coords, "y": raw_coords, "ts": timestamps})
    expected = pd.DataFrame(
        {"x": cleaned_coords, "y": cleaned_coords, "ts": timestamps}
    )

    result = filter_tracker_jumps(tracking, max_speed=100)

    # The filtered frame must match the expectation exactly (dtypes included)
    pd.testing.assert_frame_equal(result, expected)

def test_filter_tracker_jumps_no_jumps():
    """
    Data without any tracker jump comes back with unchanged values.
    """
    steps = [0, 1, 2, 3, 4]  # smooth trajectory, one unit per sample
    tracking = pd.DataFrame({"x": steps, "y": steps, "ts": steps})

    # Snapshot the input before the call; nothing is expected to change
    expected = tracking.copy()

    result = filter_tracker_jumps(tracking, max_speed=100)

    # Only the values are compared; dtypes are allowed to differ
    pd.testing.assert_frame_equal(result, expected, check_dtype=False)


def test_filter_tracker_jumps_empty_input():
    """
    An empty DataFrame with the x, y and ts columns is returned unchanged.
    """
    empty_tracking = pd.DataFrame(columns=["x", "y", "ts"])

    # Snapshot the input before the call; an empty frame is expected back
    expected = empty_tracking.copy()

    result = filter_tracker_jumps(empty_tracking, max_speed=100)

    pd.testing.assert_frame_equal(result, expected)

0 comments on commit 94e6b0d

Please sign in to comment.