Skip to content

Commit

Permalink
Only test a best checkpoint if a model has one
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit committed Oct 19, 2023
1 parent e5de92c commit 52a4986
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
5 changes: 5 additions & 0 deletions silnlp/nmt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,11 @@ def has_val_split(self) -> bool:
return any(
pair.is_val and (pair.size if pair.val_size is None else pair.val_size) > 0 for pair in self.corpus_pairs
)

@property
@abstractmethod
def has_best_checkpoint(self) -> bool:
    """Return True if the trained model has a best checkpoint available to test.

    Concrete configs define what counts as a best checkpoint for their
    framework (e.g. an entry in Hugging Face's trainer state, or an
    exported model directory for OpenNMT).
    """
    ...

def set_seed(self) -> None:
seed = self.data["seed"]
Expand Down
11 changes: 11 additions & 0 deletions silnlp/nmt/hugging_face_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ def get_best_checkpoint(model_dir: Path) -> Path:
return model_dir / Path(trainer_state["best_model_checkpoint"]).name


def has_best_checkpoint(model_dir: Path) -> bool:
    """Return True if the trainer state in ``model_dir`` records a best checkpoint.

    Returns False when the model has no ``trainer_state.json`` (i.e. it has
    not been trained yet) or when no best checkpoint was tracked during
    training — e.g. training ran without a validation metric.
    """
    trainer_state_path = model_dir / "trainer_state.json"
    # Guard against an untrained model: reading a missing trainer state would
    # otherwise raise FileNotFoundError instead of answering "no".
    if not trainer_state_path.is_file():
        return False
    with trainer_state_path.open("r", encoding="utf-8") as f:
        trainer_state = json.load(f)
    # "best_model_checkpoint" may be absent or explicitly null; both mean
    # no best checkpoint exists.
    return trainer_state.get("best_model_checkpoint") is not None


# File names of training-only state saved inside a Hugging Face checkpoint
# (optimizer, RNG, gradient scaler, LR scheduler).
# NOTE(review): presumably used elsewhere to detect or strip training state
# from a checkpoint directory — confirm against callers outside this view.
OPTIMIZER_STATE_FILES = {"optimizer.pt", "rng_state.pth", "scaler.pt", "scheduler.pt"}


Expand Down Expand Up @@ -269,6 +276,10 @@ def test_trg_lang(self) -> str:
lang_codes: Dict[str, str] = self.data["lang_codes"]
return lang_codes.get(self.default_test_trg_iso, self.default_test_trg_iso)

@property
def has_best_checkpoint(self) -> bool:
    """Whether this model's trainer state records a best checkpoint."""
    model_dir = self.model_dir
    return has_best_checkpoint(model_dir)

def create_model(self, mixed_precision: bool = False, num_devices: int = 1) -> NMTModel:
    """Create a Hugging Face NMT model bound to this config.

    Args:
        mixed_precision: forwarded to the model; presumably enables fp16/bf16
            training — confirm in HuggingFaceNMTModel.
        num_devices: forwarded to the model; number of devices to use.
    """
    return HuggingFaceNMTModel(self, mixed_precision, num_devices)

Expand Down
14 changes: 14 additions & 0 deletions silnlp/nmt/open_nmt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ def get_last_checkpoint(model_dir: Path) -> Tuple[Path, int]:
return checkpoint_path, step


def has_best_checkpoint(model_dir: Path) -> bool:
    """Return True if at least one exported (best) checkpoint exists.

    OpenNMT exports best checkpoints as subdirectories of
    ``<model_dir>/export``; any subdirectory there means a best checkpoint
    is available.
    """
    export_path = model_dir / "export"
    # Guard against an untrained model: iterating a missing export directory
    # would raise FileNotFoundError instead of answering "no".
    if not export_path.is_dir():
        return False
    # The original numeric descending sort was dead work — the loop returned
    # True on the first directory found — and int(name) crashed on any
    # non-numeric entry. "Is there any exported directory?" is the intent.
    return any(child.is_dir() for child in export_path.iterdir())


@register_scorer(name="bleu_sp")
class BLEUSentencepieceScorer(Scorer):
def __init__(self):
Expand Down Expand Up @@ -428,6 +438,10 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
def model_dir(self) -> Path:
    """Directory of the trained model, read from the "model_dir" entry of the root config."""
    return Path(self.root["model_dir"])

@property
def has_best_checkpoint(self) -> bool:
    """Whether an exported best checkpoint exists under the model directory."""
    return has_best_checkpoint(model_dir=self.model_dir)

def create_model(self, mixed_precision: bool = False, num_devices: int = 1) -> NMTModel:
    """Create an OpenNMT model bound to this config.

    Args:
        mixed_precision: forwarded to the model; presumably enables mixed
            precision training — confirm in OpenNMTModel.
        num_devices: forwarded to the model; number of devices to use.
    """
    return OpenNMTModel(self, mixed_precision, num_devices)

Expand Down
2 changes: 1 addition & 1 deletion silnlp/nmt/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def test(
LOGGER.warn("No average checkpoint available.")

best_step = 0
if best and config.has_val_split:
if best and config.has_best_checkpoint:
_, best_step = model.get_checkpoint_path(CheckpointType.BEST)
step = best_step
if step not in results:
Expand Down

0 comments on commit 52a4986

Please sign in to comment.