add flag to store all checkpoints
SobhanMP committed Feb 9, 2024
1 parent d232bcb commit 1bf8724
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/gflownet/config.py
@@ -60,6 +60,8 @@ class Config:
         The number of training steps after which to validate the model
     checkpoint_every : Optional[int]
         The number of training steps after which to checkpoint the model
+    store_all_checkpoints : bool
+        Whether to store all checkpoints or only the last one
     print_every : int
         The number of training steps after which to print the training loss
     start_at_step : int
@@ -85,6 +87,7 @@ class Config:
     seed: int = 0
     validate_every: int = 1000
     checkpoint_every: Optional[int] = None
+    store_all_checkpoints: bool = False
     print_every: int = 100
     start_at_step: int = 0
     num_final_gen_steps: Optional[int] = None
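
For context, a minimal usage sketch of the new option, assuming Config is the dataclass-style class shown above and is passed to the trainer as cfg. The log_dir field is taken from the trainer hunk below; the concrete values here are illustrative only:

    from gflownet.config import Config

    cfg = Config()
    cfg.log_dir = "./logs/run0"       # directory the trainer writes checkpoints to
    cfg.checkpoint_every = 1000       # checkpoint every 1000 training steps
    cfg.store_all_checkpoints = True  # keep a model_state_<step>.pt for every
                                      # checkpoint instead of only the latest one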
3 changes: 2 additions & 1 deletion src/gflownet/trainer.py
@@ -418,7 +418,8 @@ def _save_state(self, it):
                 state,
                 fd,
             )
-        shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt")
+        if self.cfg.store_all_checkpoints:
+            shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt")

     def log(self, info, index, key):
         if not hasattr(self, "_summary_writer"):
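
To make the change easier to read without the full file, here is a hedged reconstruction of _save_state. Only the tail of the method appears in the hunk above, so the base checkpoint filename ("model_state.pt"), the contents of state, and self.models are assumptions, not part of this diff:

    import pathlib
    import shutil

    import torch

    def _save_state(self, it):
        # Assumed: `state` is assembled earlier in the method from the model(s)
        # and current step; the diff only shows it being serialized.
        state = {"models_state_dict": [m.state_dict() for m in self.models], "step": it}
        # Assumed base filename for the rolling "latest" checkpoint.
        fn = pathlib.Path(self.cfg.log_dir) / "model_state.pt"
        with open(fn, "wb") as fd:
            torch.save(
                state,
                fd,
            )
        # Changed behavior: the per-step copy is now opt-in. With the default
        # store_all_checkpoints=False, each save overwrites the previous one
        # and only the latest checkpoint is kept on disk.
        if self.cfg.store_all_checkpoints:
            shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt")

Setting the flag restores the pre-change behavior, in which the unconditional copy kept a model_state_<it>.pt file for every checkpointing step.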
