diff --git a/config/inference/base.yaml b/config/inference/base.yaml
index 1c8cae1..f3b450b 100644
--- a/config/inference/base.yaml
+++ b/config/inference/base.yaml
@@ -17,7 +17,11 @@ inference:
   align_motif: True
   symmetric_self_cond: True
   final_step: 1
+
+  # If deterministic == False, seed is ignored.
   deterministic: False
+  seed: 0
+
   trb_save_ckpt_path: null
   schedule_directory_path: null
   model_directory_path: null
diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 2a3bf36..60795c0 100755
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -39,7 +39,7 @@ def make_deterministic(seed=0):
 def main(conf: HydraConfig) -> None:
     log = logging.getLogger(__name__)
     if conf.inference.deterministic:
-        make_deterministic()
+        make_deterministic(conf.inference.seed)
 
     # Check for available GPU and print result of check
     if torch.cuda.is_available():
@@ -70,7 +70,7 @@ def main(conf: HydraConfig) -> None:
 
     for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):
         if conf.inference.deterministic:
-            make_deterministic(i_des)
+            make_deterministic(conf.inference.seed + i_des)
 
         start_time = time.time()
         out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"