-
Notifications
You must be signed in to change notification settings - Fork 4
/
ddpm_eval.py
52 lines (38 loc) · 1.3 KB
/
ddpm_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import random
import sys
import numpy as np
import torch
import yaml
from evaluation.eval_cdm import run_inference as run_inference_only_cdm
from evaluation.evaluate_lidc_sampling_speed import eval_lidc_sampling_speed
from evaluation.evaluate_lidc_uncertainty import eval_lidc_uncertainty
def set_seeds(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch
    (CPU, the current CUDA device and all CUDA devices), and exports
    ``PYTHONHASHSEED`` into the process environment.

    :param seed: Seed to use. NumPy only accepts seeds below 2**32, so its
        seed is reduced modulo 2**32; the other generators get it unchanged.
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter startup, so setting it here
    # presumably only influences child processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    numpy_seed = seed % (2 ** 32)
    np.random.seed(numpy_seed)
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in torch_seeders:
        seeder(seed)
def main(argv):
    """Load the evaluation config and dispatch to the matching evaluation routine.

    :param argv: Command-line arguments with the program name at index 0
        (as in ``sys.argv``). If exactly one extra argument is given and it
        contains ``"params_"``, it replaces ``params_eval.yml`` as the YAML
        params file; any other extra arguments are ignored with a warning.
    :raises ValueError: If the params file is not a mapping with a
        ``dataset_file`` key, or if ``dataset_file`` matches none of the
        supported datasets.
    """
    set_seeds(0)
    params_file = "params_eval.yml"
    extra_args = argv[1:]
    if len(extra_args) == 1 and "params_" in extra_args[0]:
        params_file = extra_args[0]
        print(f"Overriding params file with {params_file}...")
    elif extra_args:
        # These arguments used to be dropped silently, which made it easy to
        # evaluate the default config while believing another one was in use.
        print(
            f"Warning: ignoring command-line arguments {extra_args}; expected a "
            f"single 'params_*' file. Using {params_file}.",
            file=sys.stderr,
        )

    with open(params_file, 'r', encoding='utf-8') as f:
        params = yaml.safe_load(f)

    # safe_load returns None for an empty file; fail with a clear message
    # instead of a TypeError/KeyError further down.
    if not isinstance(params, dict) or 'dataset_file' not in params:
        raise ValueError(
            f"Params file {params_file!r} must be a mapping with a 'dataset_file' key"
        )
    dataset_file = params['dataset_file']

    # Order matters: 'lidc_sampling_speed' also contains 'lidc', so it must be
    # tested before the generic LIDC branch.
    if 'lidc_sampling_speed' in dataset_file:
        # Point the config at the regular LIDC dataset module before running
        # the sampling-speed benchmark.
        params['dataset_file'] = "datasets.lidc"
        eval_lidc_sampling_speed(params)
    elif 'lidc' in dataset_file:
        eval_lidc_uncertainty(params)
    elif 'cityscapes' in dataset_file:
        run_inference_only_cdm(params)
    else:
        raise ValueError(f"Unknown dataset: {dataset_file!r}")
if __name__ == "__main__":
    # Hand over the full argv (program name included); main() looks at argv[1:]
    # for an optional params-file override.
    main(argv=sys.argv)