From 52a4986b592885d752f5196dbbdc69bdeb1fd7b9 Mon Sep 17 00:00:00 2001
From: Damien Daspit
Date: Thu, 19 Oct 2023 10:02:14 -0500
Subject: [PATCH] Only test a best checkpoint if a model has one

---
 silnlp/nmt/config.py              |  5 +++++
 silnlp/nmt/hugging_face_config.py | 11 +++++++++++
 silnlp/nmt/open_nmt_config.py     | 14 ++++++++++++++
 silnlp/nmt/test.py                |  2 +-
 4 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py
index 21af58ac..bb96b428 100644
--- a/silnlp/nmt/config.py
+++ b/silnlp/nmt/config.py
@@ -459,6 +459,11 @@ def has_val_split(self) -> bool:
         return any(
             pair.is_val and (pair.size if pair.val_size is None else pair.val_size) > 0 for pair in self.corpus_pairs
         )
+
+    @property
+    @abstractmethod
+    def has_best_checkpoint(self) -> bool:
+        ...
 
     def set_seed(self) -> None:
         seed = self.data["seed"]
diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py
index f1e4ce55..9ab9d785 100644
--- a/silnlp/nmt/hugging_face_config.py
+++ b/silnlp/nmt/hugging_face_config.py
@@ -129,6 +129,13 @@ def get_best_checkpoint(model_dir: Path) -> Path:
     return model_dir / Path(trainer_state["best_model_checkpoint"]).name
 
 
+def has_best_checkpoint(model_dir: Path) -> bool:
+    trainer_state_path = model_dir / "trainer_state.json"
+    with trainer_state_path.open("r", encoding="utf-8") as f:
+        trainer_state = json.load(f)
+    return "best_model_checkpoint" in trainer_state and trainer_state["best_model_checkpoint"] is not None
+
+
 OPTIMIZER_STATE_FILES = {"optimizer.pt", "rng_state.pth", "scaler.pt", "scheduler.pt"}
 
 
@@ -269,6 +276,10 @@ def test_trg_lang(self) -> str:
         lang_codes: Dict[str, str] = self.data["lang_codes"]
         return lang_codes.get(self.default_test_trg_iso, self.default_test_trg_iso)
 
+    @property
+    def has_best_checkpoint(self) -> bool:
+        return has_best_checkpoint(self.model_dir)
+
     def create_model(self, mixed_precision: bool = False, num_devices: int = 1) -> NMTModel:
         return HuggingFaceNMTModel(self, mixed_precision, num_devices)
 
diff --git a/silnlp/nmt/open_nmt_config.py b/silnlp/nmt/open_nmt_config.py
index e2a22144..6168bea4 100644
--- a/silnlp/nmt/open_nmt_config.py
+++ b/silnlp/nmt/open_nmt_config.py
@@ -161,6 +161,16 @@ def get_last_checkpoint(model_dir: Path) -> Tuple[Path, int]:
     return checkpoint_path, step
 
 
+def has_best_checkpoint(model_dir: Path) -> bool:
+    export_path = model_dir / "export"
+    models = list(d.name for d in export_path.iterdir())
+    for model in sorted(models, key=lambda m: int(m), reverse=True):
+        path = export_path / model
+        if path.is_dir():
+            return True
+    return False
+
+
 @register_scorer(name="bleu_sp")
 class BLEUSentencepieceScorer(Scorer):
     def __init__(self):
@@ -428,6 +438,10 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
     def model_dir(self) -> Path:
         return Path(self.root["model_dir"])
 
+    @property
+    def has_best_checkpoint(self) -> bool:
+        return has_best_checkpoint(self.model_dir)
+
     def create_model(self, mixed_precision: bool = False, num_devices: int = 1) -> NMTModel:
         return OpenNMTModel(self, mixed_precision, num_devices)
 
diff --git a/silnlp/nmt/test.py b/silnlp/nmt/test.py
index 68e7ecfc..9cdba62f 100644
--- a/silnlp/nmt/test.py
+++ b/silnlp/nmt/test.py
@@ -523,7 +523,7 @@ def test(
             LOGGER.warn("No average checkpoint available.")
 
     best_step = 0
-    if best and config.has_val_split:
+    if best and config.has_best_checkpoint:
         _, best_step = model.get_checkpoint_path(CheckpointType.BEST)
         step = best_step
         if step not in results: