From efeb520917e048d5158c13c8123dfdf5b36e0564 Mon Sep 17 00:00:00 2001 From: Gianluca Rossi Date: Sat, 23 Mar 2024 16:49:07 -0400 Subject: [PATCH] refactor: raise explicit error when image paths are missing --- blazingai/vision/data.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/blazingai/vision/data.py b/blazingai/vision/data.py index 6496bf9..65d6cbe 100644 --- a/blazingai/vision/data.py +++ b/blazingai/vision/data.py @@ -175,6 +175,11 @@ def __init__( def setup(self, stage: Optional[str] = None) -> None: if stage == "fit": + if self.trn_img_paths is None: + raise ValueError("Missing trn_img_path") + if self.val_img_paths is None: + raise ValueError("Missing val_img_path") + self.trn_ds = get_dataset[self.task]( img_paths=self.trn_img_paths, trgt=self.trn_trgt, @@ -186,11 +191,12 @@ def setup(self, stage: Optional[str] = None) -> None: aug=self.val_aug, ) elif stage == "predict": - if self.tst_img_paths: - self.test_ds = get_dataset[self.task]( - img_paths=self.tst_img_paths, - aug=self.tst_aug, - ) + if self.tst_img_paths is None: + raise ValueError("Missing tst_img_paths") + self.test_ds = get_dataset[self.task]( + img_paths=self.tst_img_paths, + aug=self.tst_aug, + ) else: raise ValueError(f"stage `{stage}` currently not supported")