Skip to content

Commit

Permalink
Make config_file_path optional on DeepLabCutInterface (#1031)
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Sep 4, 2024
1 parent ee70ba3 commit 675ec48
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* The following classes and objects are now private `NWBMetaDataEncoder`, `check_if_imaging_fits_into_memory`, `NoDatesSafeLoader` [PR #1050](https://github.com/catalystneuro/neuroconv/pull/1050)

### Features
* Make `config_file_path` optional in `DeepLabCutInterface`: [PR #1031](https://github.com/catalystneuro/neuroconv/pull/1031)
* Added `get_stream_names` to `OpenEphysRecordingInterface`: [PR #1039](https://github.com/catalystneuro/neuroconv/pull/1039)

### Improvements
Expand Down
104 changes: 65 additions & 39 deletions src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@
from ruamel.yaml import YAML


def _read_config(config_file_path):
def _read_config(config_file_path: FilePath) -> dict:
"""
Reads structured config file defining a project.
"""

ruamelFile = YAML()
path = Path(config_file_path)
if path.exists():
try:
with open(path, "r") as f:
cfg = ruamelFile.load(f)
curr_dir = config_file_path.parent
if cfg["project_path"] != curr_dir:
cfg["project_path"] = curr_dir
except Exception as err:
if len(err.args) > 2:
if err.args[2] == "could not determine a constructor for the tag '!!python/tuple'":
with open(path, "r") as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
else:
raise

else:
raise FileNotFoundError(
"Config file is not found. Please make sure that the file exists and/or that you passed the path of the config file correctly!"
)
if not path.exists():
raise FileNotFoundError(f"Config file {path} not found.")

try:
with open(path, "r") as f:
cfg = ruamelFile.load(f)
curr_dir = config_file_path.parent
if cfg["project_path"] != curr_dir:
cfg["project_path"] = curr_dir
except Exception as err:
if len(err.args) > 2:
if err.args[2] == "could not determine a constructor for the tag '!!python/tuple'":
with open(path, "r") as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
else:
raise

return cfg


Expand Down Expand Up @@ -154,12 +154,30 @@ def _infer_nan_timestamps(timestamps):
return timestamps


def _ensure_individuals_in_header(df, individual_name):
def _ensure_individuals_in_header(df, individual_name: str):
"""
Ensure that the 'individuals' column is present in the header of the given DataFrame.
Parameters:
df (pandas.DataFrame): The DataFrame to modify.
individual_name (str): The name of the individual to add to the header.
Returns:
pandas.DataFrame: The modified DataFrame with the 'individuals' column added to the header.
Notes:
- If the 'individuals' column is already present in the header, no modifications are made.
- If the 'individuals' column is not present, a new DataFrame is created with the 'individual_name'
as the column name, and the 'individuals' column is added to the header of the DataFrame.
- The order of the columns in the header is preserved.
"""
if "individuals" not in df.columns.names:
# Single animal project -> add individual row to
# the header of single animal projects.
temp = pd.concat({individual_name: df}, names=["individuals"], axis=1)
df = temp.reorder_levels(["scorer", "individuals", "bodyparts", "coords"], axis=1)

return df


Expand Down Expand Up @@ -220,7 +238,7 @@ def _get_video_info_from_config_file(config_file_path: Path, vidname: str):
break

if video is None:
warnings.warn(f"The corresponding video file could not be found...")
warnings.warn(f"The corresponding video file could not be found in the config file")
video = None, "0, 0, 0, 0"

# The video in the config_file looks like this:
Expand All @@ -240,9 +258,6 @@ def _get_pes_args(
):
h5file = Path(h5file)

if "DLC" not in h5file.name or not h5file.suffix == ".h5":
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

_, scorer = h5file.stem.split("DLC")
scorer = "DLC" + scorer

Expand All @@ -256,7 +271,8 @@ def _write_pes_to_nwbfile(
animal,
df_animal,
scorer,
video, # Expects this to be a tuple; first index is string path, second is the image shape as "0, width, 0, height"
video_file_path,
image_shape,
paf_graph,
timestamps,
exclude_nans,
Expand Down Expand Up @@ -295,12 +311,13 @@ def _write_pes_to_nwbfile(
if is_deeplabcut_installed:
deeplabcut_version = importlib.metadata.version(distribution_name="deeplabcut")

# TODO, taken from the original implementation, improve it if the video is passed
dimensions = [list(map(int, image_shape.split(",")))[1::2]]
pose_estimation_default_kwargs = dict(
pose_estimation_series=pose_estimation_series,
description="2D keypoint coordinates estimated using DeepLabCut.",
original_videos=[video[0]],
# TODO check if this is a mandatory arg in ndx-pose (can skip if video is not found_
dimensions=[list(map(int, video[1].split(",")))[1::2]],
original_videos=[video_file_path],
dimensions=dimensions,
scorer=scorer,
source_software="DeepLabCut",
source_software_version=deeplabcut_version,
Expand All @@ -326,7 +343,7 @@ def add_subject_to_nwbfile(
nwbfile: NWBFile,
h5file: FilePath,
individual_name: str,
config_file: FilePath,
config_file: Optional[FilePath] = None,
timestamps: Optional[Union[list, np.ndarray]] = None,
pose_estimation_container_kwargs: Optional[dict] = None,
) -> NWBFile:
Expand All @@ -342,7 +359,7 @@ def add_subject_to_nwbfile(
individual_name : str
Name of the subject (whose pose is predicted) for single-animal DLC project.
For multi-animal projects, the names from the DLC project will be used directly.
config_file : str or path
config_file : str or path, optional
Path to a project config.yaml file
timestamps : list, np.ndarray or None, default: None
Alternative timestamps vector. If None, then use the inferred timestamps from DLC2NWB
Expand All @@ -356,26 +373,34 @@ def add_subject_to_nwbfile(
"""
h5file = Path(h5file)

scorer, df = _get_pes_args(
h5file=h5file,
individual_name=individual_name,
)
if "DLC" not in h5file.name or not h5file.suffix == ".h5":
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

video_name, scorer = h5file.stem.split("DLC")
scorer = "DLC" + scorer

df = _ensure_individuals_in_header(pd.read_hdf(h5file), individual_name)

# Note the video here is a tuple of the video path and the image shape
vidname, scorer = h5file.stem.split("DLC")
video = _get_video_info_from_config_file(config_file_path=config_file, vidname=vidname)
if config_file is not None:
video_file_path, image_shape = _get_video_info_from_config_file(
config_file_path=config_file,
vidname=video_name,
)
else:
video_file_path = None
image_shape = "0, 0, 0, 0"

# find timestamps only if required:``
timestamps_available = timestamps is not None
video_file_path = video[0]
if not timestamps_available:
if video_file_path is None:
timestamps = df.index.tolist() # setting timestamps to dummy
else:
timestamps = _get_movie_timestamps(video_file_path, infer_timestamps=True)

# Fetch the corresponding metadata pickle file, we extract the edges graph from here
# TODO: This is the original implementation way to extract the file name but looks very brittle
# TODO: This is the original implementation way to extract the file name but looks very brittle. Improve it
filename = str(h5file.parent / h5file.stem)
for i, c in enumerate(filename[::-1]):
if c.isnumeric():
Expand All @@ -393,7 +418,8 @@ def add_subject_to_nwbfile(
individual_name,
df_animal,
scorer,
video,
video_file_path,
image_shape,
paf_graph,
timestamps,
exclude_nans=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_source_schema(cls) -> dict:
def __init__(
self,
file_path: FilePath,
config_file_path: FilePath,
config_file_path: Optional[FilePath] = None,
subject_name: str = "ind1",
verbose: bool = True,
):
Expand All @@ -40,7 +40,7 @@ def __init__(
----------
file_path : FilePath
path to the h5 file output by dlc.
config_file_path : FilePath
config_file_path : FilePath, optional
path to .yml config file
subject_name : str, default: "ind1"
the name of the subject for which the :py:class:`~pynwb.file.NWBFile` is to be created.
Expand All @@ -53,17 +53,22 @@ def __init__(
if "DLC" not in file_path.stem or ".h5" not in file_path.suffixes:
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

self._config_file = _read_config(config_file_path=config_file_path)
self.config_dict = dict()
if config_file_path is not None:
self.config_dict = _read_config(config_file_path=config_file_path)
self.subject_name = subject_name
self.verbose = verbose
super().__init__(file_path=file_path, config_file_path=config_file_path)

def get_metadata(self):
metadata = super().get_metadata()
metadata["NWBFile"].update(
session_description=self._config_file["Task"],
experimenter=[self._config_file["scorer"]],
)

if self.config_dict:
metadata["NWBFile"].update(
session_description=self.config_dict["Task"],
experimenter=[self.config_dict["scorer"]],
)

return metadata

def get_original_timestamps(self) -> np.ndarray:
Expand Down Expand Up @@ -110,7 +115,7 @@ def add_to_nwbfile(
nwbfile=nwbfile,
h5file=str(self.source_data["file_path"]),
individual_name=self.subject_name,
config_file=str(self.source_data["config_file_path"]),
config_file=self.source_data["config_file_path"],
timestamps=self._timestamps,
pose_estimation_container_kwargs=dict(name=container_name),
)
26 changes: 26 additions & 0 deletions tests/test_on_data/behavior/test_behavior_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,32 @@ def check_read_nwb(self, nwbfile_path: str):
assert all(expected_pose_estimation_series_are_in_nwb_file)


class TestDeepLabCutInterfaceNoConfigFile(DataInterfaceTestMixin):
data_interface_cls = DeepLabCutInterface
interface_kwargs = dict(
file_path=str(BEHAVIOR_DATA_PATH / "DLC" / "m3v1mp4DLC_resnet50_openfieldAug20shuffle1_30000.h5"),
config_file_path=None,
subject_name="ind1",
)
save_directory = OUTPUT_PATH

def check_read_nwb(self, nwbfile_path: str):
with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io:
nwbfile = io.read()
assert "behavior" in nwbfile.processing
processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces
assert "PoseEstimation" in processing_module_interfaces

pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series
expected_pose_estimation_series = ["ind1_leftear", "ind1_rightear", "ind1_snout", "ind1_tailbase"]

expected_pose_estimation_series_are_in_nwb_file = [
pose_estimation in pose_estimation_series_in_nwb for pose_estimation in expected_pose_estimation_series
]

assert all(expected_pose_estimation_series_are_in_nwb_file)


class TestDeepLabCutInterfaceSetTimestamps(DeepLabCutInterfaceMixin):
data_interface_cls = DeepLabCutInterface
interface_kwargs = dict(
Expand Down

0 comments on commit 675ec48

Please sign in to comment.