Skip to content

Commit

Permalink
Refactor DLC (#1047)
Browse files Browse the repository at this point in the history
Co-authored-by: Cody Baker <51133164+CodyCBakerPhD@users.noreply.github.com>
  • Loading branch information
h-mayorquin and CodyCBakerPhD authored Sep 2, 2024
1 parent e73745a commit 77d2252
Showing 1 changed file with 93 additions and 49 deletions.
142 changes: 93 additions & 49 deletions src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,40 +163,23 @@ def _ensure_individuals_in_header(df, individual_name):
return df


def _get_pes_args(
*,
config_file: Path,
h5file: Path,
individual_name: str,
timestamps_available: bool = False,
infer_timestamps: bool = True,
):
config_file = Path(config_file)
h5file = Path(h5file)

if "DLC" not in h5file.name or not h5file.suffix == ".h5":
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

cfg = _read_config(config_file)

vidname, scorer = h5file.stem.split("DLC")
scorer = "DLC" + scorer
video = None
def _get_graph_edges(metadata_file_path: Path):
"""
Extracts the part affinity field graph from the metadata pickle file.
df = _ensure_individuals_in_header(pd.read_hdf(h5file), individual_name)
Parameters
----------
metadata_file_path : Path
The path to the metadata pickle file.
# Fetch the corresponding metadata pickle file
Returns
-------
list
The part affinity field graph, which defines the edges between the keypoints in the pose estimation.
"""
paf_graph = []
filename = str(h5file.parent / h5file.stem)
for i, c in enumerate(filename[::-1]):
if c.isnumeric():
break
if i > 0:
filename = filename[:-i]
metadata_file = Path(filename + "_meta.pickle")

if metadata_file.exists():
with open(metadata_file, "rb") as file:
if metadata_file_path.exists():
with open(metadata_file_path, "rb") as file:
metadata = pickle.load(file)

test_cfg = metadata["data"]["DLC-model-config file"]
Expand All @@ -208,25 +191,64 @@ def _get_pes_args(
else:
warnings.warn("Metadata not found...")

return paf_graph


def _get_video_info_from_config_file(config_file_path: Path, vidname: str):
"""
Get the video information from the project config file.
Parameters
----------
config_file_path : Path
The path to the project config file.
vidname : str
The name of the video.
Returns
-------
tuple
A tuple containing the video file path and the image shape.
"""
config_file_path = Path(config_file_path)
cfg = _read_config(config_file_path)

video = None
for video_path, params in cfg["video_sets"].items():
if vidname in video_path:
video = video_path, params["crop"]
break

# find timestamps only if required:
if timestamps_available:
timestamps = None
else:
if video is None:
timestamps = df.index.tolist() # setting timestamps to dummy TODO: extract timestamps in DLC?
else:
timestamps = _get_movie_timestamps(video[0], infer_timestamps=infer_timestamps)

if video is None:
warnings.warn(f"The video file corresponding to {h5file} could not be found...")
video = "fake_path", "0, 0, 0, 0"
warnings.warn(f"The corresponding video file could not be found...")
video = None, "0, 0, 0, 0"

# The video in the config_file looks like this:
# video_sets:
# /Data/openfield-Pranav-2018-08-20/videos/m1s1.mp4:
# crop: 0, 640, 0, 480

video_file_path, image_shape = video

return video_file_path, image_shape


def _get_pes_args(
*,
h5file: Path,
individual_name: str,
):
h5file = Path(h5file)

if "DLC" not in h5file.name or not h5file.suffix == ".h5":
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

_, scorer = h5file.stem.split("DLC")
scorer = "DLC" + scorer

df = _ensure_individuals_in_header(pd.read_hdf(h5file), individual_name)

return scorer, df, video, paf_graph, timestamps, cfg
return scorer, df


def _write_pes_to_nwbfile(
Expand Down Expand Up @@ -332,15 +354,37 @@ def add_subject_to_nwbfile(
nwbfile : pynwb.NWBFile
nwbfile with pes written in the behavior module
"""
timestamps_available = timestamps is not None
scorer, df, video, paf_graph, dlc_timestamps, _ = _get_pes_args(
config_file=config_file,
h5file = Path(h5file)

scorer, df = _get_pes_args(
h5file=h5file,
individual_name=individual_name,
timestamps_available=timestamps_available,
)
if timestamps is None:
timestamps = dlc_timestamps

# Note the video here is a tuple of the video path and the image shape
vidname, scorer = h5file.stem.split("DLC")
video = _get_video_info_from_config_file(config_file_path=config_file, vidname=vidname)

# find timestamps only if required:``
timestamps_available = timestamps is not None
video_file_path = video[0]
if not timestamps_available:
if video_file_path is None:
timestamps = df.index.tolist() # setting timestamps to dummy
else:
timestamps = _get_movie_timestamps(video_file_path, infer_timestamps=True)

# Fetch the corresponding metadata pickle file, we extract the edges graph from here
# TODO: This is the original implementation way to extract the file name but looks very brittle
filename = str(h5file.parent / h5file.stem)
for i, c in enumerate(filename[::-1]):
if c.isnumeric():
break
if i > 0:
filename = filename[:-i]

metadata_file_path = Path(filename + "_meta.pickle")
paf_graph = _get_graph_edges(metadata_file_path=metadata_file_path)

df_animal = df.xs(individual_name, level="individuals", axis=1)

Expand Down

0 comments on commit 77d2252

Please sign in to comment.