diff --git a/src/spikegadgets_to_nwb/convert.py b/src/spikegadgets_to_nwb/convert.py index 1911fa6..980de6a 100644 --- a/src/spikegadgets_to_nwb/convert.py +++ b/src/spikegadgets_to_nwb/convert.py @@ -2,6 +2,7 @@ import pandas as pd +from spikegadgets_to_nwb.convert_dios import add_dios from spikegadgets_to_nwb.convert_ephys import add_raw_ephys from spikegadgets_to_nwb.convert_position import add_position from spikegadgets_to_nwb.convert_rec_header import ( @@ -14,7 +15,6 @@ add_acquisition_devices, add_associated_files, add_cameras, - add_dios, add_electrode_groups, add_subject, add_tasks, diff --git a/src/spikegadgets_to_nwb/convert_analog.py b/src/spikegadgets_to_nwb/convert_analog.py index bf6a4ad..30e843d 100644 --- a/src/spikegadgets_to_nwb/convert_analog.py +++ b/src/spikegadgets_to_nwb/convert_analog.py @@ -38,10 +38,17 @@ def add_analog_data(nwbfile: NWBFile, rec_file_path: list[str], **kwargs) -> Non analog_channel_ids.append(channel.attrib["id"]) # make the data chunk iterator + # TODO use the stream name instead of the stream index to be more robust rec_dci = RecFileDataChunkIterator( - rec_file_path, nwb_hw_channel_order=analog_channel_ids, stream_index=0 + rec_file_path, + nwb_hw_channel_order=analog_channel_ids, + stream_index=2, + is_analog=True, ) + # add headstage channel IDs to the list of analog channel IDs + analog_channel_ids.extend(rec_dci.neo_io[0].multiplexed_channel_xml.keys()) + # (16384, 32) chunks of dtype int16 (2 bytes) is 1 MB, which is recommended # by studies by the NWB team. # could also add compression here. zstd/blosc-zstd are recommended by the NWB team, but diff --git a/src/spikegadgets_to_nwb/convert_dios.py b/src/spikegadgets_to_nwb/convert_dios.py index e69de29..20eb9f8 100644 --- a/src/spikegadgets_to_nwb/convert_dios.py +++ b/src/spikegadgets_to_nwb/convert_dios.py @@ -0,0 +1,83 @@ +import numpy as np +from pynwb import NWBFile, TimeSeries +from pynwb.behavior import BehavioralEvents + +from .spike_gadgets_raw_io import SpikeGadgetsRawIO + + +def _get_channel_name_map(metadata: dict) -> dict[str, str]: + """Parses behavioral events metadata from the yaml file + + Parameters + ---------- + metadata : dict + metadata from the yaml generator + + Returns + ------- + channel_name_map : dict + Parsed behavioral events metadata mapping hardware event name to human-readable name + """ + dio_metadata = metadata["behavioral_events"] + channel_name_map = {} + for dio_event in dio_metadata: + channel_name_map[dio_event["description"]] = dio_event["name"] + return channel_name_map + + +def add_dios(nwbfile: NWBFile, recfile: list[str], metadata: dict) -> None: + """Adds DIO event information and data to nwb file + + Parameters + ---------- + nwbfile : NWBFile + nwb file being assembled + recfile : list[str] + list of paths to rec files + metadata : dict + metadata from the yaml generator + """ + + # TODO remove redundancy with convert_ephys.py + neo_io = [ + SpikeGadgetsRawIO(filename=file) for file in recfile + ] # get all streams for all files + [neo_io.parse_header() for neo_io in neo_io] + + # Make a processing module for behavior and add to the nwbfile + if not "behavior" in nwbfile.processing: + nwbfile.create_processing_module( + name="behavior", description="Contains all behavior-related data" + ) + + # Make BehavioralEvents object to hold DIO data + beh_events = BehavioralEvents(name="behavioral_events") + + # Map hardware event name (encoded in `description` in metadata YAML) + # to a human-readable name (encoded in `name`) + channel_name_map = _get_channel_name_map(metadata) + + # Loop through the channels from the metadata YAML and add a TimeSeries for each one + stream_name = "ECU_digital" + for channel_name in channel_name_map: + # merge streams from multiple files + all_timestamps = np.array([], dtype=np.float64) + all_state_changes = np.array([], dtype=np.uint8) + for io in neo_io: + timestamps, state_changes = io.get_digitalsignal( + stream_name, "ECU_" + channel_name + ) + all_timestamps = np.concatenate((all_timestamps, timestamps)) + all_state_changes = np.concatenate((all_state_changes, state_changes)) + + ts = TimeSeries( + name=channel_name_map[channel_name], + description=channel_name, + data=all_state_changes, + unit="-1", # TODO change to "N/A", + timestamps=all_timestamps, # TODO adjust timestamps + ) + beh_events.add_timeseries(ts) + + # Add the BehavioralEvents object to the file + nwbfile.processing["behavior"].add(beh_events) diff --git a/src/spikegadgets_to_nwb/convert_ephys.py b/src/spikegadgets_to_nwb/convert_ephys.py index 1582ae2..2aa9f0a 100644 --- a/src/spikegadgets_to_nwb/convert_ephys.py +++ b/src/spikegadgets_to_nwb/convert_ephys.py @@ -21,20 +21,26 @@ def __init__( rec_file_path: list[str], nwb_hw_channel_order=[], conversion: float = 1.0, - stream_index: int = 1, + stream_index: int = 3, # TODO use the stream name instead of the index + is_analog: bool = False, **kwargs, ): self.conversion = conversion + self.is_analog = is_analog self.neo_io = [ SpikeGadgetsRawIO(filename=file) for file in rec_file_path ] # get all streams for all files [neo_io.parse_header() for neo_io in self.neo_io] # TODO see what else spikeinterface does and whether it is necessary - # for now, make sure that there is only one block, one segment, and two streams + # for now, make sure that there is only one block, one segment, and four streams: + # Controller_DIO_digital + # ECU_digital + # ECU_analog + # trodes assert all([neo_io.block_count() == 1 for neo_io in self.neo_io]) assert all([neo_io.segment_count(0) == 1 for neo_io in self.neo_io]) - assert all([neo_io.signal_streams_count() == 2 for neo_io in self.neo_io]) + assert all([neo_io.signal_streams_count() == 4 for neo_io in self.neo_io]) self.block_index = 0 self.seg_index = 0 @@ -64,6 +70,9 @@ def __init__( self.n_channel = self.neo_io[0].signal_channels_count( stream_index=self.stream_index ) + self.n_multiplexed_channel = 0 + if self.is_analog: + self.n_multiplexed_channel += len(self.neo_io[0].multiplexed_channel_xml) # order that the hw channels are in within the nwb table if len(nwb_hw_channel_order) == 0: # TODO: raise error instead? @@ -93,11 +102,18 @@ def _get_data(self, selection: Tuple[slice]) -> np.ndarray: # DCI will want channels 0 to X first to put into the array in that order # those are stored in the file as channel IDs # make into list form passed to neo_io - channel_ids = [str(x) for x in self.nwb_hw_channel_order[selection[1]]] + selection_list = list(selection) + if self.is_analog: + selection_list[1] = slice( + selection[1].start, + min(selection[1].stop, self.n_channel), + selection[1].step, + ) + channel_ids = [str(x) for x in self.nwb_hw_channel_order[selection_list[1]]] # what global index each file starts at file_start_ind = np.append(np.zeros(1), np.cumsum(self.n_time)) # the time indexes we want - time_index = np.arange(self._get_maxshape()[0])[selection[0]] + time_index = np.arange(self._get_maxshape()[0])[selection_list[0]] data = [] i = time_index[0] while i < min(time_index[-1], self._get_maxshape()[0]): @@ -134,12 +150,13 @@ def _get_data(self, selection: Tuple[slice]) -> np.ndarray: time_index[-1] - i, # if finished in this stream ) data = (np.array(data) * self.conversion).astype("int16") + return data def _get_maxshape(self) -> Tuple[int, int]: return ( np.sum(self.n_time), - self.n_channel, + self.n_channel + self.n_multiplexed_channel, ) # TODO: Is this right for maxshape @rly def _get_dtype(self) -> np.dtype: diff --git a/src/spikegadgets_to_nwb/convert_yaml.py b/src/spikegadgets_to_nwb/convert_yaml.py index 35a1069..76efdf1 100644 --- a/src/spikegadgets_to_nwb/convert_yaml.py +++ b/src/spikegadgets_to_nwb/convert_yaml.py @@ -365,43 +365,6 @@ def add_tasks(nwbfile: NWBFile, metadata: dict) -> None: nwbfile.processing["tasks"].add(task) -def add_dios(nwbfile: NWBFile, metadata: dict) -> None: - """Adds DIO event information and data to nwb file - - Parameters - ---------- - nwbfile : NWBFile - nwb file being assembled - metadata : dict - metadata from the yaml generator - """ - # TODO: pass the dio data and include in this - # Make a processing module for behavior and add to the nwbfile - if not "behavior" in nwbfile.processing: - nwbfile.create_processing_module( - name="behavior", description="Contains all behavior-related data" - ) - # Make Behavioral events object to hold DIO data - events = BehavioralEvents(name="behavioral_events") - # Loop through and add timeseries for each one - dio_metadata = metadata["behavioral_events"] - for dio_event in dio_metadata: - events.add_timeseries( - TimeSeries( - name=dio_event["name"], - description=dio_event["description"], - data=np.array( - [] - ), # TODO: from rec file // self.data[dio_event['description']], - unit="N/A", - timestamps=np.array([]), - # TODO: data, timestamps, - ) - ) - # add it to your file - nwbfile.processing["behavior"].add(events) - - def add_associated_files(nwbfile: NWBFile, metadata: dict) -> None: """Adds associated files processing module. Reads in file referenced in metadata and stores in processing diff --git a/src/spikegadgets_to_nwb/spike_gadgets_raw_io.py b/src/spikegadgets_to_nwb/spike_gadgets_raw_io.py index 498e22b..c1e0fab 100644 --- a/src/spikegadgets_to_nwb/spike_gadgets_raw_io.py +++ b/src/spikegadgets_to_nwb/spike_gadgets_raw_io.py @@ -81,7 +81,7 @@ def _parse_header(self): # explore xml header root = ElementTree.fromstring(header_txt) - gconf = sr = root.find("GlobalConfiguration") + gconf = root.find("GlobalConfiguration") hconf = root.find("HardwareConfiguration") sconf = root.find("SpikeConfiguration") @@ -98,11 +98,11 @@ def _parse_header(self): # explore sub stream and count packet size # first bytes is 0x55 packet_size = 1 - stream_bytes = {} + device_bytes = {} for device in hconf: - stream_id = device.attrib["name"] + device_name = device.attrib["name"] num_bytes = int(device.attrib["numBytes"]) - stream_bytes[stream_id] = packet_size + device_bytes[device_name] = packet_size packet_size += num_bytes # timestamps 4 uint32 @@ -127,12 +127,46 @@ def _parse_header(self): signal_streams = [] signal_channels = [] - # walk in xml device and keep only "analog" one + self._mask_channels_ids = {} self._mask_channels_bytes = {} self._mask_channels_bits = {} # for digital data + + self.multiplexed_channel_xml = {} # dictionary from id to channel xml + self._multiplexed_byte_start = device_bytes["Multiplexed"] + + # walk through xml devices for device in hconf: - stream_id = device.attrib["name"] + device_name = device.attrib["name"] for channel in device: + if ( + device.attrib["name"] == "Multiplexed" + and channel.attrib["dataType"] == "analog" + ): + # the multiplexed analog device has interleaved data from multiple sources + # that are sampled at a lower rate. + # for each packet, + # the interleavedDataIDByte and the interleavedDataIDBit indicate which + # channel has an updated value. + # the startByte contains the int16 updated value. + # if there was no update, use the last value received. + # thus, there is a value at every timestamp, but usually it will be the same + # as the previous value. + # it is assumed that for a given startByte, only one of the + # interleavedDataIDByte and interleavedDataIDBit combinations that + # use that startByte is active at any given timestamp, + # i.e. there should be at most one 1 in the interleavedDataIDByte value + # at each timestamp. + + # the typical mask approach will not work, so store the channel specs + # and use them to read the analog data on demand. + self.multiplexed_channel_xml[channel.attrib["id"]] = channel + continue + + # one device can have streams with different data types, + # so create a stream_id that differentiates them. + # users need to be aware of this when using the API + stream_id = device_name + "_" + channel.attrib["dataType"] + if "interleavedDataIDByte" in channel.attrib: # TODO LATER: deal with "headstageSensor" which have interleaved continue @@ -142,6 +176,7 @@ def _parse_header(self): stream_ids.append(stream_id) stream_name = stream_id signal_streams.append((stream_name, stream_id)) + self._mask_channels_ids[stream_id] = [] self._mask_channels_bytes[stream_id] = [] self._mask_channels_bits[stream_id] = [] @@ -165,7 +200,9 @@ def _parse_header(self): ) ) - num_bytes = stream_bytes[stream_id] + int( + self._mask_channels_ids[stream_id].append(channel.attrib["id"]) + + num_bytes = device_bytes[device_name] + int( channel.attrib["startByte"] ) chan_mask_bytes = np.zeros(packet_size, dtype="bool") @@ -175,31 +212,54 @@ def _parse_header(self): chan_mask_bits = np.zeros(packet_size * 8, dtype="bool") # TODO self._mask_channels_bits[stream_id].append(chan_mask_bits) - # elif channel.attrib['dataType'] == 'digital' and channel.attrib['input'] == '0': # handle DIO - - # if stream_id not in stream_ids: - # stream_ids.append(stream_id) - # stream_name = stream_id - # signal_streams.append((stream_name, stream_id)) - # self._mask_channels_bytes[stream_id] = [] - # self._mask_channels_bits[stream_id] = [] - - # name = channel.attrib['id'] - # chan_id = channel.attrib['id'] - # dtype = 'bool' - # # TODO LATER : handle gain correctly according the file version - # units = '' - # gain = 1. - # offset = 0. - # signal_channels.append((name, chan_id, self._sampling_rate, dtype, - # units, gain, offset, stream_id)) - - # start_bit = stream_bytes[stream_id] + int(channel.attrib['startByte']) + int(channel.attrib['bit']) - # chan_mask_bytes = np.zeros(packet_size, dtype='bool') # TODO - # self._mask_channels_bytes[stream_id].append(chan_mask_bytes) - # chan_mask_bits = np.zeros(packet_size * 8, dtype='bool') - # chan_mask_bits[start_bit] = True - # self._mask_channels_bits[stream_id].append(chan_mask_bits) + elif channel.attrib["dataType"] == "digital": # handle DIO + if stream_id not in stream_ids: + stream_ids.append(stream_id) + stream_name = stream_id + signal_streams.append((stream_name, stream_id)) + self._mask_channels_ids[stream_id] = [] + self._mask_channels_bytes[stream_id] = [] + self._mask_channels_bits[stream_id] = [] + + # NOTE store data in signal_channels to make neo happy + name = channel.attrib["id"] + chan_id = channel.attrib["id"] + dtype = "int8" + units = "" + gain = 1.0 + offset = 0.0 + + signal_channels.append( + ( + name, + chan_id, + self._sampling_rate, + dtype, + units, + gain, + offset, + stream_id, + ) + ) + + self._mask_channels_ids[stream_id].append(channel.attrib["id"]) + + # to handle digital data, need to split the data by bits + num_bytes = device_bytes[device_name] + int( + channel.attrib["startByte"] + ) + chan_byte_mask = np.zeros(packet_size, dtype="bool") + chan_byte_mask[num_bytes] = True + self._mask_channels_bytes[stream_id].append(chan_byte_mask) + + # within the concatenated, masked bytes, mask the bit (flipped order) + chan_bit_mask = np.zeros(8 * 1, dtype="bool") + chan_bit_mask[int(channel.attrib["bit"])] = True + chan_bit_mask = np.flip(chan_bit_mask) + self._mask_channels_bits[stream_id].append(chan_bit_mask) + + # NOTE: _mask_channels_ids, _mask_channels_bytes, and + # _mask_channels_bits are parallel lists if num_ephy_channels > 0: stream_id = "trodes" @@ -347,6 +407,12 @@ def _get_analogsignal_chunk( if re_order is not None: raw_unit16 = raw_unit16[:, re_order] + if stream_id == "ECU_analog": + # automatically include the interleaved analog signals: + analog_multiplexed_data = self.get_analogsignal_multiplexed() + analog_multiplexed_data = analog_multiplexed_data[i_start:i_stop, :] + raw_unit16 = np.concatenate((raw_unit16, analog_multiplexed_data), axis=1) + return raw_unit16 def get_analogsignal_timestamps(self, i_start, i_stop): @@ -355,3 +421,118 @@ def get_analogsignal_timestamps(self, i_start, i_stop): ] raw_uint32 = raw_uint8.flatten().view("uint32") return raw_uint32 + + def get_analogsignal_multiplexed(self, channel_names=None) -> dict[str, np.ndarray]: + if channel_names is None: + # read all multiplexed channels + channel_names = list(self.multiplexed_channel_xml.keys()) + else: + for ch_name in channel_names: + if ch_name not in self.multiplexed_channel_xml: + raise ValueError(f"Channel name '{ch_name}' not found in file.") + + # because of the encoding scheme, it is easiest to read all the data in sequence + # one packet at a time + num_packet = self._raw_memmap.shape[0] + analog_multiplexed_data = np.empty( + (num_packet, len(channel_names)), dtype=np.int16 + ) + for i, packet in enumerate(self._raw_memmap): + for j, ch_name in enumerate(channel_names): + ch_xml = self.multiplexed_channel_xml[ch_name] + interleaved_data_id_byte = int(ch_xml.attrib["interleavedDataIDByte"]) + interleaved_data_id_bit = int(ch_xml.attrib["interleavedDataIDBit"]) + relative_start_byte = int(ch_xml.attrib["startByte"]) + start_byte = self._multiplexed_byte_start + relative_start_byte + interleaved_data_id_byte_value = packet[ + self._multiplexed_byte_start + interleaved_data_id_byte + ] + interleaved_data_id_bit_value = np.flip( + np.unpackbits(interleaved_data_id_byte_value) + )[interleaved_data_id_bit] + if ( + i == 0 or interleaved_data_id_bit_value == 1 + ): # initialize the stream or record new value + data = packet[start_byte : start_byte + 2].view("int16")[ + 0 + ] # int16 = two bytes + else: # copy last value + data = analog_multiplexed_data[i - 1, j] + analog_multiplexed_data[i, j] = data + + return analog_multiplexed_data + + def get_digitalsignal(self, stream_id, channel_id): + # stream_id = self.header["signal_streams"][stream_index]["id"] + + # for now, allow only reading the entire dataset + i_start = 0 + i_stop = self._raw_memmap.shape[0] + raw_packets = self._raw_memmap[i_start:i_stop] + + channel_index = -1 + for i, chan_id in enumerate(self._mask_channels_ids[stream_id]): + if chan_id == channel_id: + channel_index = i + break + assert ( + channel_index >= 0 + ), f"channel_id {channel_id} not found in stream {stream_id}" + + # num_chan = len(self._mask_channels_bytes[stream_id]) + # re_order = None + # if channel_indexes is None: + # # no loop : entire stream mask + # stream_mask = self._mask_streams[stream_id] + # else: + # # accumulate mask + # if isinstance(channel_indexes, slice): + # chan_inds = np.arange(num_chan)[channel_indexes] + # else: + # chan_inds = channel_indexes + + # if np.any(np.diff(channel_indexes) < 0): + # # handle channel are not ordered + # sorted_channel_indexes = np.sort(channel_indexes) + # re_order = np.array( + # [ + # list(sorted_channel_indexes).index(ch) + # for ch in channel_indexes + # ] + # ) + + # stream_mask = np.zeros(raw_packets.shape[1], dtype="bool") + # for chan_ind in chan_inds: + # chan_mask = self._mask_channels_bytes[stream_id][chan_ind] + # stream_mask |= chan_mask + + # this copies the data from the memmap into memory + byte_mask = self._mask_channels_bytes[stream_id][channel_index] + raw_packets_masked = raw_packets[:, byte_mask] + + bit_mask = self._mask_channels_bits[stream_id][channel_index] + continuous_dio = np.unpackbits(raw_packets_masked, axis=1)[ + :, bit_mask + ].flatten() + change_dir = np.diff(continuous_dio).astype( + np.int8 + ) # possible values: [-1, 0, 1] + change_dir_trim = change_dir[change_dir != 0] # keeps -1 and 1 + change_dir_trim[change_dir_trim == -1] = 0 # change -1 to 0 + # resulting array has 1 when there is a change from 0 to 1, + # 0 when there is change from 1 to 0 + + # track the timestamps when there is a change from 0 to 1 or 1 to 0 + timestamps = self.get_analogsignal_timestamps(i_start, i_stop) + dio_change_times = timestamps[np.where(change_dir)[0] + 1] + + # insert the first timestamp with the first value + dio_change_times = np.insert(dio_change_times, 0, timestamps[0]) + change_dir_trim = np.insert(change_dir_trim, 0, continuous_dio[0]) + + change_dir_trim = change_dir_trim.astype(np.uint8) + + # if re_order is not None: + # raw_unit16 = raw_unit16[:, re_order] + + return dio_change_times, change_dir_trim diff --git a/src/spikegadgets_to_nwb/tests/test_convert_analog.py b/src/spikegadgets_to_nwb/tests/test_convert_analog.py index d5befbc..e2a060b 100644 --- a/src/spikegadgets_to_nwb/tests/test_convert_analog.py +++ b/src/spikegadgets_to_nwb/tests/test_convert_analog.py @@ -53,6 +53,7 @@ def test_add_analog_data(): "analog" ].description.split(" ")[:-1] index_order = [old_id_order.index(id) for id in id_order] + # TODO check that all the same channels are present # compare data assert ( diff --git a/src/spikegadgets_to_nwb/tests/test_convert_dios.py b/src/spikegadgets_to_nwb/tests/test_convert_dios.py index e69de29..cf75e4f 100644 --- a/src/spikegadgets_to_nwb/tests/test_convert_dios.py +++ b/src/spikegadgets_to_nwb/tests/test_convert_dios.py @@ -0,0 +1,124 @@ +import numpy as np +import os +import pynwb +from spikegadgets_to_nwb.convert_dios import add_dios +from spikegadgets_to_nwb import convert_yaml, convert_rec_header +from spikegadgets_to_nwb.tests.test_convert_rec_header import default_test_xml_tree + +path = os.path.dirname(os.path.abspath(__file__)) + + +def test_add_dios_single_rec(): + # load metadata yml and make nwb file + metadata_path = path + "/test_data/20230622_sample_metadata.yml" + probe_metadata = [ + path + "/test_data/tetrode_12.5.yml", + ] + metadata, _ = convert_yaml.load_metadata(metadata_path, probe_metadata) + nwbfile = convert_yaml.initialize_nwb(metadata, default_test_xml_tree()) + + try: + # running on github + recfile = [os.environ.get("DOWNLOAD_DIR") + "/20230622_sample_01_a1.rec"] + rec_to_nwb_file = os.environ.get("DOWNLOAD_DIR") + "/20230622_155936.nwb" + except (TypeError, FileNotFoundError): + # running locally + recfile = [path + "/test_data/20230622_sample_01_a1.rec"] + rec_to_nwb_file = path + "/test_data/20230622_155936.nwb" + + add_dios(nwbfile, recfile, metadata) + + filename = "test_add_dios_single_rec.nwb" + with pynwb.NWBHDF5IO(filename, "w") as io: + io.write(nwbfile) + + with pynwb.NWBHDF5IO(filename, "r", load_namespaces=True) as io: + read_nwbfile = io.read() + assert "behavior" in read_nwbfile.processing + assert ( + "behavioral_events" in read_nwbfile.processing["behavior"].data_interfaces + ) + expected_dios = ["Poke_1", "Light_1", "Light_2"] + for name in expected_dios: + assert ( + name + in read_nwbfile.processing["behavior"]["behavioral_events"].time_series + ) + + with pynwb.NWBHDF5IO(rec_to_nwb_file, "r", load_namespaces=True) as io2: + old_nwbfile = io2.read() + for old_dio in old_nwbfile.processing["behavior"][ + "behavioral_events" + ].time_series.values(): + current_dio = read_nwbfile.processing["behavior"]["behavioral_events"][ + old_dio.name + ] + # check that timeseries match + np.testing.assert_array_equal(current_dio.data, old_dio.data) + np.testing.assert_array_equal( + current_dio.timestamps, old_dio.timestamps + ) + assert current_dio.unit == old_dio.unit + assert current_dio.description == old_dio.description + + os.remove(filename) + + +def test_add_dios_two_epoch(): + # load metadata yml and make nwb file + metadata_path = path + "/test_data/20230622_sample_metadata.yml" + probe_metadata = [ + path + "/test_data/tetrode_12.5.yml", + ] + metadata, _ = convert_yaml.load_metadata(metadata_path, probe_metadata) + nwbfile = convert_yaml.initialize_nwb(metadata, default_test_xml_tree()) + + try: + # running on github + recfile = [ + os.environ.get("DOWNLOAD_DIR") + "/20230622_sample_01_a1.rec", + os.environ.get("DOWNLOAD_DIR") + "/20230622_sample_02_a1.rec", + ] + rec_to_nwb_file = os.environ.get("DOWNLOAD_DIR") + "/minirec20230622_.nwb" + except (TypeError, FileNotFoundError): + # running locally + recfile = [ + path + "/test_data/20230622_sample_01_a1.rec", + path + "/test_data/20230622_sample_02_a1.rec", + ] + rec_to_nwb_file = path + "/test_data/minirec20230622_.nwb" + + add_dios(nwbfile, recfile, metadata) + + filename = "test_add_dios_two_epoch.nwb" + with pynwb.NWBHDF5IO(filename, "w") as io: + io.write(nwbfile) + + with pynwb.NWBHDF5IO(filename, "r", load_namespaces=True) as io: + read_nwbfile = io.read() + assert "behavior" in read_nwbfile.processing + assert ( + "behavioral_events" in read_nwbfile.processing["behavior"].data_interfaces + ) + expected_dios = ["Poke_1", "Light_1", "Light_2"] + for name in expected_dios: + assert ( + name + in read_nwbfile.processing["behavior"]["behavioral_events"].time_series + ) + + with pynwb.NWBHDF5IO(rec_to_nwb_file, "r", load_namespaces=True) as io2: + old_nwbfile = io2.read() + for old_dio in old_nwbfile.processing["behavior"][ + "behavioral_events" + ].time_series.values(): + current_dio = read_nwbfile.processing["behavior"]["behavioral_events"][ + old_dio.name + ] + # check that timeseries match + np.testing.assert_array_equal(current_dio.data, old_dio.data) + # np.testing.assert_array_equal(current_dio.timestamps, old_dio.timestamps) # TODO uncomment + assert current_dio.unit == old_dio.unit + assert current_dio.description == old_dio.description + + os.remove(filename) diff --git a/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadata.yml b/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadata.yml index f56fa4f..9531e6b 100644 --- a/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadata.yml +++ b/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadata.yml @@ -83,18 +83,13 @@ units: times_period_multiplier: 1 raw_data_to_volts: 1.95e-07 default_header_file_path: /stelmo/sam/test_data -behavioral_events: - - description: Din1 - name: Light_1 - - description: Din1 - name: Light_2 device: name: - device1 behavioral_events: - description: Din1 name: Light_1 - - description: Din1 + - description: Din2 name: Light_2 - description: Dout2 name: Poke_1 diff --git a/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadataProbeReconfig.yml b/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadataProbeReconfig.yml index 9f731d5..20bf195 100644 --- a/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadataProbeReconfig.yml +++ b/src/spikegadgets_to_nwb/tests/test_data/20230622_sample_metadataProbeReconfig.yml @@ -83,18 +83,13 @@ units: times_period_multiplier: 1 raw_data_to_volts: 1.95e-07 default_header_file_path: default/header/path/ -behavioral_events: - - description: Din1 - name: Light_1 - - description: Din1 - name: Light_2 device: name: - device1 behavioral_events: - description: Din1 name: Light_1 - - description: Din1 + - description: Din2 name: Light_2 - description: Dout2 name: Poke_1