-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
91 lines (69 loc) · 2.31 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
This script is the entry point for training and testing the model.
It instantiates all necessary modules, trains the model and tests it.
"""
import os
from typing import List
import warnings
import hydra
import torch
from lightning import (
Callback,
LightningDataModule,
LightningModule,
seed_everything,
Trainer,
)
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from torchmetrics import MetricCollection
@hydra.main(
    config_path="configs",
    config_name="run.yaml",
    version_base="1.3",
)
def main(cfg: DictConfig) -> None:
    """
    Instantiate all necessary modules, train and test the model.

    Builds, in this order, the LightningDataModule, the MetricCollection, the
    LightningModule (which receives the metrics), the callbacks, the logger
    and the Trainer, all via ``hydra.utils.instantiate``. It then fits the
    model and runs the test loop with ``ckpt_path="last"``.

    Args:
        cfg (DictConfig): Hydra configuration object, passed in by the @hydra.main decorator.
            Nodes read here: ``lightning_datamodule``, ``metrics``,
            ``lightning_module``, ``callbacks``, ``logging.logger`` and
            ``trainer``. ``callbacks`` may be null, in which case the Trainer
            is built with no callbacks.
    """
    # Instantiate LightningDataModule
    lightning_datamodule: LightningDataModule = hydra.utils.instantiate(
        cfg.lightning_datamodule
    )

    # Instantiate LightningModule.
    # `cfg.metrics` is converted with dict(), so it presumably maps
    # metric name -> metric config — TODO confirm against configs/.
    metrics: MetricCollection = MetricCollection(
        dict(hydra.utils.instantiate(cfg.metrics))
    )
    lightning_module: LightningModule = hydra.utils.instantiate(
        cfg.lightning_module,
        metrics=metrics,
    )

    # Instantiate Trainer.
    # A null `callbacks` node instantiates to None; use no callbacks in that
    # case instead of crashing on `None.values()`.
    instantiated_callbacks = hydra.utils.instantiate(cfg.callbacks)
    callbacks: List[Callback] = (
        [] if instantiated_callbacks is None else list(instantiated_callbacks.values())
    )
    logger: Logger = hydra.utils.instantiate(cfg.logging.logger)
    # _convert_="partial" lets the Trainer receive plain Python containers
    # rather than OmegaConf containers for its list/dict arguments.
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    # Train the model ⚡
    trainer.fit(lightning_module, datamodule=lightning_datamodule)

    # Test the model.
    # NOTE(review): ckpt_path="last" needs a checkpoint callback that actually
    # writes a "last" checkpoint (e.g. ModelCheckpoint(save_last=True)).
    # Otherwise Lightning is expected to fall back to the in-memory weights
    # with only a warning, and that warning is hidden when warnings are
    # filtered as in setup_environment(). Confirm the callbacks config saves one.
    trainer.test(ckpt_path="last", datamodule=lightning_datamodule)
def setup_environment(*, seed: int = 42) -> None:
    """
    Setup environment for training.

    Sets process-wide state intended to apply before ``main`` runs:
    suppresses all Python warnings, sets ``HYDRA_FULL_ERROR=1``, enables the
    cuDNN backend with benchmark (autotuning) mode, and seeds the RNGs via
    ``seed_everything``.

    Args:
        seed (int): Seed passed to ``seed_everything``. Defaults to 42, the
            value this function previously hard-coded.
    """
    # NOTE(review): this ignores *every* warning process-wide, including
    # library warnings that point at misconfiguration — consider narrowing it.
    warnings.filterwarnings("ignore")

    # Set environment variables for full trace of errors
    os.environ["HYDRA_FULL_ERROR"] = "1"

    # Enable CUDNN backend
    torch.backends.cudnn.enabled = True
    # Enable CUDNN benchmarking to choose the best algorithm for every new input size,
    # e.g. for convolutional layers choose between Winograd, GEMM-based, or FFT algorithms.
    # Trade-off: benchmark mode may select different algorithms across runs,
    # so results are not bit-for-bit reproducible even with a fixed seed.
    torch.backends.cudnn.benchmark = True

    # Sets seeds for numpy, torch and python.random; workers=True additionally
    # asks Lightning to seed dataloader worker processes.
    seed_everything(seed, workers=True)
if __name__ == "__main__":
    # Global process setup (warnings, env vars, cuDNN flags, seeds) runs first
    # so that it is already in effect while main() executes.
    setup_environment()
    # No arguments: the @hydra.main decorator composes the config and passes cfg.
    main()