From 02a0547816ce87aede695ed156f297b8a4c0e8cb Mon Sep 17 00:00:00 2001 From: rubywood <48224649+rubywood@users.noreply.github.com> Date: Mon, 15 Sep 2025 14:53:05 +0100 Subject: [PATCH 1/8] EchoDataset --- mmai25_hackathon/dataset.py | 48 +++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index 2121f48..69ad314 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -14,6 +14,7 @@ from torch.utils.data import Dataset, Sampler from torch_geometric.data import DataLoader +from load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list __all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"] @@ -110,6 +111,53 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +class EchoDataset(BaseDataset): + """Example subclass for an ECHO dataset.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.records = load_mimic_iv_echo_record_list(args.data_path) + self.subject_ids = self.records['subject_id'].tolist() + + def __len__(self) -> int: + """Return the number of samples in the dataset.""" + return len(self.records) + + def __getitem__(self, idx: int): + """Return a single sample from the dataset.""" + record = self.records[idx] + # Load and return the ECHO data for the given record + #print(f"Loading first ECHO DICOM from: {records.iloc[0]['echo_path']}") + sample_path = record["echo_path"] + frames, meta = load_echo_dicom(sample_path) + #meta_filtered = { + # k: meta[k] for k in ("NumberOfFrames", "Rows", "Columns", "FrameTime", "CineRate") if k in meta + #} + return {'frames': frames, 'metadata': meta, 'subject_id': record['subject_id']} + + def extra_repr(self) -> str: + """Return any extra information about the dataset.""" + return f"sample_size={len(self)}, subjects={len(set(self.subject_ids))}" + + def __add__(self, other): + """ + Combine with another dataset. + + Override in subclasses to implement multimodal aggregation. + + Args: + other: Another dataset to combine with this one. + + Initial Idea: + Use `__add__` to align and merge heterogeneous modalities into a single + dataset, keeping shared IDs synchronized. + Note: This is not mandatory; treat it as a sketch you can refine or replace. + """ + self.records = self.records.merge(other.records, on='subject_id', suffixes=('', '_other'), how='outer') + self.subject_ids = self.records['subject_id'].tolist() + return self + + class BaseDataLoader(DataLoader): """ DataLoader for graph and non-graph data. From 9ef0fae597e91b2fb99caa570925563ef244ad2d Mon Sep 17 00:00:00 2001 From: vvw1n22 Date: Mon, 15 Sep 2025 15:37:05 +0100 Subject: [PATCH 2/8] ECG dataloader --- mmai25_hackathon/dataset.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index 2121f48..0865c87 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -14,7 +14,7 @@ from torch.utils.data import Dataset, Sampler from torch_geometric.data import DataLoader - +from load_data.ecg import load_mimic_iv_ecg_record_list, load_ecg_record __all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"] @@ -108,7 +108,42 @@ class ECGDataset(BaseDataset): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Loading the ECG data (records contains patient id, hea path) in df frame + self.records = load_mimic_iv_ecg_record_list(args.data_path) + self.subject_ids = self.records['subject_id'].tolist() + + + def __getitem__(self, idx: int): + """Return a single sample from the dataset.""" + record_idx = self.records[idx] + signals, fields = load_ecg_record(record_idx["hea_path"]) + return {'signals': signals, 'fields': fields, 'subject_id': record_idx['subject_id']} + + def __repr__(self) -> str: + """Return a string representation of the dataset.""" + return f"{self.__class__.__name__}({self.extra_repr()})" + + def extra_repr(self) -> str: + """Return any extra information about the dataset.""" + return f"sample_size={len(self.subject_ids)}" + + def __add__(self, other): + """ + Combine with another dataset. + Override in subclasses to implement multimodal aggregation. + + Args: + other: Another dataset to combine with this one. + + Initial Idea: + Use `__add__` to align and merge heterogeneous modalities into a single + dataset, keeping shared IDs synchronized. + Note: This is not mandatory; treat it as a sketch you can refine or replace. + """ + self.records = self.records.merge(other.records, on='subject_id', suffixes=('', '_other'), how='outer') + self.subject_ids = self.records['subject_id'].tolist() + return self class BaseDataLoader(DataLoader): """ From 0154cb3a1c8dec8b3d2f2bfa76099f5deb21be41 Mon Sep 17 00:00:00 2001 From: rubywood <48224649+rubywood@users.noreply.github.com> Date: Mon, 15 Sep 2025 15:50:57 +0100 Subject: [PATCH 3/8] initial implementation of Multimodal Dataset --- mmai25_hackathon/dataset.py | 67 ++++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index 69ad314..c1012e7 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -12,9 +12,9 @@ BaseSampler: Template for custom samplers, e.g., for multimodal sampling. """ +from load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list from torch.utils.data import Dataset, Sampler from torch_geometric.data import DataLoader -from load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list __all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"] @@ -113,28 +113,28 @@ def __init__(self, *args, **kwargs): class EchoDataset(BaseDataset): """Example subclass for an ECHO dataset.""" - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.records = load_mimic_iv_echo_record_list(args.data_path) - self.subject_ids = self.records['subject_id'].tolist() + self.subject_ids = self.records["subject_id"].tolist() def __len__(self) -> int: """Return the number of samples in the dataset.""" return len(self.records) - + def __getitem__(self, idx: int): """Return a single sample from the dataset.""" - record = self.records[idx] + record = self.records.iloc[idx] # Load and return the ECHO data for the given record - #print(f"Loading first ECHO DICOM from: {records.iloc[0]['echo_path']}") + # print(f"Loading first ECHO DICOM from: {records.iloc[0]['echo_path']}") sample_path = record["echo_path"] frames, meta = load_echo_dicom(sample_path) - #meta_filtered = { + # meta_filtered = { # k: meta[k] for k in ("NumberOfFrames", "Rows", "Columns", "FrameTime", "CineRate") if k in meta - #} - return {'frames': frames, 'metadata': meta, 'subject_id': record['subject_id']} - + # } + return {"frames": frames, "metadata": meta, "subject_id": record["subject_id"]} + def extra_repr(self) -> str: """Return any extra information about the dataset.""" return f"sample_size={len(self)}, subjects={len(set(self.subject_ids))}" @@ -153,11 +153,40 @@ def __add__(self, other): dataset, keeping shared IDs synchronized. Note: This is not mandatory; treat it as a sketch you can refine or replace. """ - self.records = self.records.merge(other.records, on='subject_id', suffixes=('', '_other'), how='outer') - self.subject_ids = self.records['subject_id'].tolist() + self.records = self.records.merge(other.records, on="subject_id", suffixes=("", "_other"), how="outer") + self.subject_ids = self.records["subject_id"].tolist() return self +class MultimodalDataset(BaseDataset): + """Example subclass for a multimodal dataset.""" + + def __init__(self, datasets: list[BaseDataset], *args, **kwargs): + super().__init__(*args, **kwargs) + self.datasets = datasets + _dataset = datasets[0] + if not isinstance(_dataset, BaseDataset): + raise ValueError("All elements in datasets must be instances of BaseDataset.") + if len(datasets) > 1: + for ds in datasets[1:]: + if not isinstance(ds, BaseDataset): + raise ValueError("All elements in datasets must be instances of BaseDataset.") + _dataset.__add__(ds) + self.dataset = _dataset + + def __len__(self) -> int: + """Return the number of samples in the dataset.""" + return len(self.dataset) + + def __getitem__(self, idx: int): + """Return a single sample from the dataset.""" + return self.dataset.__getitem__(idx) + + def extra_repr(self) -> str: + """Return any extra information about the dataset.""" + return self.dataset.extra_repr() + + class BaseDataLoader(DataLoader): """ DataLoader for graph and non-graph data. @@ -180,6 +209,20 @@ class BaseDataLoader(DataLoader): Note: This is not a hard requirement. Consider it a future-facing idea you can evolve. """ + def __init__( + self, + dataset: BaseDataset, + batch_size: int = 1, + shuffle: bool = False, + follow_batch: list = None, + exclude_keys: list = None, + **kwargs, + ): + super().__init__(dataset, batch_size, shuffle, follow_batch, exclude_keys, **kwargs) + + # collate_fn=lambda data_list: Batch.from_data_list( + # data_list, follow_batch), + class MultimodalDataLoader(BaseDataLoader): """Example dataloader for handling multiple data modalities.""" From c1b2bc3715ba36eef652f5f6f920e9bf9d05697e Mon Sep 17 00:00:00 2001 From: rubywood <48224649+rubywood@users.noreply.github.com> Date: Mon, 15 Sep 2025 15:58:43 +0100 Subject: [PATCH 4/8] relative import --- mmai25_hackathon/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index 1028161..5bc1b9c 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -12,7 +12,7 @@ BaseSampler: Template for custom samplers, e.g., for multimodal sampling. """ -from load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list +from .load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list from torch.utils.data import Dataset, Sampler from torch_geometric.data import DataLoader From 37d85787485d0c360c62d9f2c53b0dc60dbcd938 Mon Sep 17 00:00:00 2001 From: rubywood <48224649+rubywood@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:31:32 +0100 Subject: [PATCH 5/8] Indexing by sample ID --- mmai25_hackathon/dataset.py | 107 +++++++++++++++++++++++++----------- 1 file changed, 76 insertions(+), 31 deletions(-) diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index 5bc1b9c..b44cea3 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -12,11 +12,12 @@ BaseSampler: Template for custom samplers, e.g., for multimodal sampling. """ -from .load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list from torch.utils.data import Dataset, Sampler from torch_geometric.data import DataLoader -from load_data.ecg import load_mimic_iv_ecg_record_list, load_ecg_record +from .load_data.ecg import load_ecg_record, load_mimic_iv_ecg_record_list +from .load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list + __all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"] @@ -112,13 +113,20 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Loading the ECG data (records contains patient id, hea path) in df frame self.records = load_mimic_iv_ecg_record_list(args.data_path) - self.subject_ids = self.records['subject_id'].tolist() + self.subject_ids = self.records["subject_id"].tolist() - def __getitem__(self, idx: int): - """Return a single sample from the dataset.""" - record_idx = self.records[idx] - signals, fields = load_ecg_record(record_idx["hea_path"]) - return {'signals': signals, 'fields': fields, 'subject_id': record_idx['subject_id']} + def __getitem__(self, sample_ID: int): + """Return samples for one sampleID from the dataset.""" + # record_idx = self.records[idx] + # signals, fields = load_ecg_record(record_idx["hea_path"]) + # return {"signals": signals, "fields": fields, "subject_id": record_idx["subject_id"]} + record_idx = self.records[self.records.subject_id == sample_ID] + samples = [] + for idx in record_idx: + signals, fields = load_ecg_record(idx["hea_path"]) + item = {"signals": signals, "fields": fields, "subject_id": record_idx["subject_id"]} + samples.append(item) + return samples def __repr__(self) -> str: """Return a string representation of the dataset.""" @@ -128,9 +136,13 @@ def extra_repr(self) -> str: """Return any extra information about the dataset.""" return f"sample_size={len(self.subject_ids)}" + def modality(self) -> str: + """Return the modality of the dataset.""" + return "ECG" + def __add__(self, other): """ - Combine with another dataset. + Combine with another dataset. Assume other is a single sample. Override in subclasses to implement multimodal aggregation. @@ -142,8 +154,12 @@ def __add__(self, other): dataset, keeping shared IDs synchronized. Note: This is not mandatory; treat it as a sketch you can refine or replace. """ - self.records = self.records.merge(other.records, on='subject_id', suffixes=('', '_other'), how='outer') - self.subject_ids = self.records['subject_id'].tolist() + self.records = self.records.merge(other.records, on="subject_id", suffixes=("", "_other"), how="outer") + self.subject_ids = self.records["subject_id"].tolist() + + # TODO: takes a single sample from other, find corresponding sample in this dataset? + # i.e. find idx where sample_id matches from other and call get_item on all of those indices? + return self @@ -159,22 +175,29 @@ def __len__(self) -> int: """Return the number of samples in the dataset.""" return len(self.records) - def __getitem__(self, idx: int): + def __getitem__(self, sample_ID: int): """Return a single sample from the dataset.""" - record = self.records.iloc[idx] - # Load and return the ECHO data for the given record - # print(f"Loading first ECHO DICOM from: {records.iloc[0]['echo_path']}") - sample_path = record["echo_path"] - frames, meta = load_echo_dicom(sample_path) - # meta_filtered = { - # k: meta[k] for k in ("NumberOfFrames", "Rows", "Columns", "FrameTime", "CineRate") if k in meta - # } - return {"frames": frames, "metadata": meta, "subject_id": record["subject_id"]} + # record = self.records.iloc[idx] + # sample_path = record["echo_path"] + # frames, meta = load_echo_dicom(sample_path) + # return {"frames": frames, "metadata": meta, "subject_id": record["subject_id"]} + record_idx = self.records[self.records.subject_id == sample_ID] + samples = [] + for idx in record_idx: + sample_path = idx["echo_path"] + frames, meta = load_echo_dicom(sample_path) + item = {"frames": frames, "metadata": meta, "subject_id": idx["subject_id"]} + samples.append(item) + return samples def extra_repr(self) -> str: """Return any extra information about the dataset.""" return f"sample_size={len(self)}, subjects={len(set(self.subject_ids))}" + def modality(self) -> str: + """Return the modality of the dataset.""" + return "echo" + def __add__(self, other): """ Combine with another dataset. @@ -200,15 +223,19 @@ class MultimodalDataset(BaseDataset): def __init__(self, datasets: list[BaseDataset], *args, **kwargs): super().__init__(*args, **kwargs) self.datasets = datasets - _dataset = datasets[0] - if not isinstance(_dataset, BaseDataset): - raise ValueError("All elements in datasets must be instances of BaseDataset.") - if len(datasets) > 1: - for ds in datasets[1:]: - if not isinstance(ds, BaseDataset): - raise ValueError("All elements in datasets must be instances of BaseDataset.") - _dataset.__add__(ds) - self.dataset = _dataset + # _dataset = datasets[0] + # if not isinstance(_dataset, BaseDataset): + # raise ValueError("All elements in datasets must be instances of BaseDataset.") + # if len(datasets) > 1: + # for ds in datasets[1:]: + # if not isinstance(ds, BaseDataset): + # raise ValueError("All elements in datasets must be instances of BaseDataset.") + # _dataset.__add__(ds) + # self.dataset = _dataset + + # get union of all subject IDs in each dataset + self.subject_ids = list(set().union(*(ds.subject_ids for ds in datasets))) + print(f"MultimodalDataset initialized with {len(self.subject_ids)} unique subjects.") def __len__(self) -> int: """Return the number of samples in the dataset.""" @@ -216,7 +243,25 @@ def __len__(self) -> int: def __getitem__(self, idx: int): """Return a single sample from the dataset.""" - return self.dataset.__getitem__(idx) + subject_ID = self.subject_ids[idx] + results = {} + for dataset in self.datasets: + items = dataset.__getitem__(subject_ID) + results[dataset.modality()] = items + return results + + # get dictionaries for each dataset + # results = {} + # for ds in self.datasets: + # dict_result = ds.__getitem__(idx) # TODO: assumes idx is same for each sample. replace idx with sample ID + # results[ds.modality()] = dict_result + + # or primary dataset and __add__ in others + # primary_ds = self.datasets[0] + # item = primary_ds.__getitem__(idx) + # for ds in self.datasets[1:]: + # items = ds.__add__(item) + # return results def extra_repr(self) -> str: """Return any extra information about the dataset.""" From 8192d736da7bb298677c79606d299006855cc154 Mon Sep 17 00:00:00 2001 From: rubywood <48224649+rubywood@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:36:15 +0100 Subject: [PATCH 6/8] pre-commit update --- tests/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b458d04..1c36aa7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -77,8 +77,8 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Data: # type: ignore[name-defined] return self._graphs[idx] - ds = GraphDataset() - loader = BaseDataLoader(ds, batch_size=2, shuffle=False) + ds = GraphDataset() # type: ignore[arg-type] + loader = BaseDataLoader(ds, batch_size=2, shuffle=False) # type: ignore[arg-type] total_graphs = 0 for batch in loader: From ba5bc89991ab456f01a3007cd0c8f84cf72c124d Mon Sep 17 00:00:00 2001 From: rubywood <48224649+rubywood@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:43:38 +0100 Subject: [PATCH 7/8] CXR Dataset --- mmai25_hackathon/dataset.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index b44cea3..e52eb4a 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -17,6 +17,7 @@ from .load_data.ecg import load_ecg_record, load_mimic_iv_ecg_record_list from .load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list +from .load_data.cxr import load_chest_xray_image, load_mimic_cxr_metadata __all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"] @@ -104,6 +105,25 @@ class CXRDataset(BaseDataset): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.records = load_mimic_cxr_metadata(args.data_path) + self.subject_ids = self.records["subject_id"].tolist() + + def __len__(self): + return len(self.records) + + def __getitem__(self, sample_ID: int): + record_idx = self.records[self.records.subject_id == sample_ID] + samples = [] + for idx in record_idx: + path = idx["cxr_path"] + image = load_chest_xray_image(path) + item = {"image": image, "subject_id": record_idx["subject_id"]} + samples.append(item) + return samples + + def modality(self) -> str: + """Return the modality of the dataset.""" + return "CXR" class ECGDataset(BaseDataset): @@ -115,6 +135,9 @@ def __init__(self, *args, **kwargs): self.records = load_mimic_iv_ecg_record_list(args.data_path) self.subject_ids = self.records["subject_id"].tolist() + def __len__(self): + return len(self.records) + def __getitem__(self, sample_ID: int): """Return samples for one sampleID from the dataset.""" # record_idx = self.records[idx] From 44b0f5e306663915074b501d1f2f414dd6a96cbf Mon Sep 17 00:00:00 2001 From: rubywood <48224649+rubywood@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:44:01 +0100 Subject: [PATCH 8/8] pre-commit --- mmai25_hackathon/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index e52eb4a..f4f774d 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -15,9 +15,9 @@ from torch.utils.data import Dataset, Sampler from torch_geometric.data import DataLoader +from .load_data.cxr import load_chest_xray_image, load_mimic_cxr_metadata from .load_data.ecg import load_ecg_record, load_mimic_iv_ecg_record_list from .load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list -from .load_data.cxr import load_chest_xray_image, load_mimic_cxr_metadata __all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"] @@ -120,7 +120,7 @@ def __getitem__(self, sample_ID: int): item = {"image": image, "subject_id": record_idx["subject_id"]} samples.append(item) return samples - + def modality(self) -> str: """Return the modality of the dataset.""" return "CXR"