diff --git a/src/spikegadgets_to_nwb/convert_intervals.py b/src/spikegadgets_to_nwb/convert_intervals.py index e69de29..441910d 100644 --- a/src/spikegadgets_to_nwb/convert_intervals.py +++ b/src/spikegadgets_to_nwb/convert_intervals.py @@ -0,0 +1,28 @@ +import pandas as pd +from pynwb import NWBFile +from spikegadgets_to_nwb.convert_rec_header import read_header + + +def add_epochs(nwbfile: NWBFile, file_info: pd.DataFrame, date: int, animal: str): + session_info = file_info[(file_info.date == date) & (file_info.animal == animal)] + for epoch in set(session_info.epoch): + rec_file_list = session_info[ + (session_info.epoch == epoch) & (session_info.file_extension == ".rec") + ] + start_time = None + end_time = None + for rec_path in rec_file_list.full_path: + file_start_time = get_file_start_time(rec_path) + if start_time is None or file_start_time < start_time: + start_time = file_start_time + end_time = 0.0 + + tag = f"{epoch:02d}_{rec_file_list.tag.iloc[0]}" + nwbfile.add_epoch(start_time, end_time, tag) + return + + +def get_file_start_time(rec_file: str) -> float: + header = read_header(rec_file) + gconf = header.find("GlobalConfiguration") + return float(gconf.attrib["systemTimeAtCreation"].strip()) / 1000.0 diff --git a/src/spikegadgets_to_nwb/tests/test_convert_intervals.py b/src/spikegadgets_to_nwb/tests/test_convert_intervals.py index e69de29..5674552 100644 --- a/src/spikegadgets_to_nwb/tests/test_convert_intervals.py +++ b/src/spikegadgets_to_nwb/tests/test_convert_intervals.py @@ -0,0 +1,29 @@ +from pathlib import Path +import os + +from spikegadgets_to_nwb.data_scanner import get_file_info +from spikegadgets_to_nwb.convert_intervals import add_epochs +from spikegadgets_to_nwb.convert_yaml import initialize_nwb, load_metadata +from spikegadgets_to_nwb.tests.test_convert_rec_header import default_test_xml_tree + +path = os.path.dirname(os.path.abspath(__file__)) + + +def test_add_epochs(): + metadata_path = path + "/test_data/20230622_sample_metadata.yml" + metadata, _ = load_metadata(metadata_path, []) + nwbfile = initialize_nwb(metadata, default_test_xml_tree()) + try: + # running on github + file_info = get_file_info(Path(os.environ.get("DOWNLOAD_DIR"))) + except (TypeError, FileNotFoundError): + # running locally + file_info = get_file_info(Path(path)) + add_epochs(nwbfile, file_info, 20230622, "sample") + epochs_df = nwbfile.epochs.to_dataframe() + + assert len(epochs_df) == 2 + assert list(epochs_df.index) == [0, 1] + assert list(epochs_df.tags) == [["01_a1"], ["02_a1"]] + assert list(epochs_df.start_time) == [1687474797.888, 1687474821.109] + assert list(epochs_df.stop_time) == [0.0, 0.0]