-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
54 lines (42 loc) · 1.44 KB
/
main.py
File metadata and controls
54 lines (42 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from config.config import ProjectConfig
from tridi.core.evaluator import Evaluator
from tridi.core.sampler import Sampler
from tridi.core.trainer import Trainer
from tridi.model import get_model
from tridi.utils import training as training_utils
from tridi.utils.exp import init_exp, init_wandb, init_logging, parse_arguments
def main():
    """Set up torch and the experiment, then run the job named by ``cfg.run.job``.

    Supported jobs:

    * ``'train'``  -- build the model with ``get_model`` and train it via ``Trainer``.
    * ``'sample'`` -- build the model and sample via ``Sampler``. The output is
      chosen by ``cfg.sample.target``: ``'meshes'`` calls ``sample()`` and
      ``'hdf5'`` calls ``sample_to_hdf5()``.
    * ``'eval'``   -- run ``Evaluator``. No model is built in this function
      for evaluation.

    Raises:
        ValueError: if ``cfg.run.job`` is not one of the jobs above, or if a
            sampling run has an unrecognized ``cfg.sample.target``.
    """
    # Process-wide torch settings, applied before any data or model is created:
    # share tensors between processes through the file system, and allow
    # reduced-precision kernels for float32 matrix multiplications.
    torch.multiprocessing.set_sharing_strategy('file_system')
    torch.set_float32_matmul_precision('high')

    # Initialize the run from the command-line arguments.
    cfg: ProjectConfig = init_exp(parse_arguments())

    # Logging, plus Weights & Biases only when enabled in the config.
    init_logging(cfg)
    if cfg.logging.wandb:
        init_wandb(cfg)

    # Seed RNGs for reproducibility.
    training_utils.set_seed(cfg.run.seed)

    job = cfg.run.job

    # Evaluation does not need a model built here.
    if job == 'eval':
        Evaluator(cfg).evaluate()
        return

    if job not in ['train', 'sample']:
        raise ValueError(f"Invalid job type {job}")

    # Both training and sampling operate on a freshly built model.
    model = get_model(cfg)

    if job == 'train':
        Trainer(cfg, model).train()
        return

    # Remaining case: job == 'sample'.
    sampler = Sampler(cfg, model)
    target = cfg.sample.target
    if target == 'meshes':
        sampler.sample()
    elif target == 'hdf5':
        sampler.sample_to_hdf5()
    else:
        raise ValueError(f"Invalid target {target}")
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()