Skip to content

Commit

Permalink
Add add_epochs function and test
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Aug 14, 2023
1 parent c6fecee commit d332fdd
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/spikegadgets_to_nwb/convert_intervals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pandas as pd
from pynwb import NWBFile
from spikegadgets_to_nwb.convert_rec_header import read_header


def add_epochs(nwbfile: NWBFile, file_info: pd.DataFrame, date: int, animal: str):
session_info = file_info[(file_info.date == date) & (file_info.animal == animal)]
for epoch in set(session_info.epoch):
rec_file_list = session_info[
(session_info.epoch == epoch) & (session_info.file_extension == ".rec")
]
start_time = None
end_time = None
for rec_path in rec_file_list.full_path:
file_start_time = get_file_start_time(rec_path)
if start_time is None or file_start_time < start_time:
start_time = file_start_time
end_time = 0.0

tag = f"{epoch:02d}_{rec_file_list.tag.iloc[0]}"
nwbfile.add_epoch(start_time, end_time, tag)
return


def get_file_start_time(rec_file: str) -> float:
header = read_header(rec_file)
gconf = header.find("GlobalConfiguration")
return float(gconf.attrib["systemTimeAtCreation"].strip()) / 1000.0
29 changes: 29 additions & 0 deletions src/spikegadgets_to_nwb/tests/test_convert_intervals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path
import os

from spikegadgets_to_nwb.data_scanner import get_file_info
from spikegadgets_to_nwb.convert_intervals import add_epochs
from spikegadgets_to_nwb.convert_yaml import initialize_nwb, load_metadata
from spikegadgets_to_nwb.tests.test_convert_rec_header import default_test_xml_tree

path = os.path.dirname(os.path.abspath(__file__))


def test_add_epochs():
metadata_path = path + "/test_data/20230622_sample_metadata.yml"
metadata, _ = load_metadata(metadata_path, [])
nwbfile = initialize_nwb(metadata, default_test_xml_tree())
try:
# running on github
file_info = get_file_info(Path(os.environ.get("DOWNLOAD_DIR")))
except (TypeError, FileNotFoundError):
# running locally
file_info = get_file_info(Path(path))
add_epochs(nwbfile, file_info, 20230622, "sample")
epochs_df = nwbfile.epochs.to_dataframe()

assert len(epochs_df) == 2
assert list(epochs_df.index) == [0, 1]
assert list(epochs_df.tags) == [["01_a1"], ["02_a1"]]
assert list(epochs_df.start_time) == [1687474797.888, 1687474821.109]
assert list(epochs_df.stop_time) == [0.0, 0.0]

0 comments on commit d332fdd

Please sign in to comment.