From dcd7bb02fb69f53bfd63423dcd5c91ebfd96e015 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 9 Jan 2025 11:00:47 +0530 Subject: [PATCH 1/2] support precomputation. --- finetrainers/dataset.py | 11 ++++++++--- finetrainers/trainer.py | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py index 6054e49b..19ebb698 100644 --- a/finetrainers/dataset.py +++ b/finetrainers/dataset.py @@ -353,13 +353,18 @@ def _find_nearest_resolution(self, height, width): class PrecomputedDataset(Dataset): - def __init__(self, data_root: str) -> None: + def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None: super().__init__() self.data_root = Path(data_root) - self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME - self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME + if model_name and cleaned_model_id: + precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}" + self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME + self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME + else: + self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME + self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME self.latent_conditions = sorted(os.listdir(self.latents_path)) self.text_conditions = sorted(os.listdir(self.conditions_path)) diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index a05a8f0b..503cc161 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -225,7 +225,9 @@ def collate_fn(batch): if not should_precompute: logger.info("Precomputed conditions and latents found. Loading precomputed data.") self.dataloader = torch.utils.data.DataLoader( - PrecomputedDataset(self.args.data_root), + PrecomputedDataset( + data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id + ), batch_size=self.args.batch_size, shuffle=True, collate_fn=collate_fn, From 612d853922014e3360603c4ed94f2078532a713f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 9 Jan 2025 11:21:14 +0530 Subject: [PATCH 2/2] fixes --- finetrainers/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index 503cc161..81c82647 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -355,7 +355,9 @@ def collate_fn(batch): # Update dataloader to use precomputed conditions and latents self.dataloader = torch.utils.data.DataLoader( - PrecomputedDataset(self.args.data_root), + PrecomputedDataset( + data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id + ), batch_size=self.args.batch_size, shuffle=True, collate_fn=collate_fn,