train.py
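"""Hydra-driven training entry point.

Builds the model, dataloaders, loss, optimizer, and trainer from a Hydra
config, then runs training and validation for the configured number of
epochs followed by a final test pass.
"""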
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.nn import DataParallel

from utils.utils import init_run, print_model_params


@hydra.main(config_path="configs", version_base=None)
def main(config: DictConfig) -> None:
    # Initialise the run from the config and get the target device.
    device = init_run(config)

    # Build the model from its config entry; wrap it in DataParallel when
    # multi-GPU training is requested.
    if config.general.use_multi_gpu:
        model = DataParallel(instantiate(config=config.model)).to(device)
    else:
        model = instantiate(config=config.model).to(device)
    print_model_params(model)
    # Instantiate the dataloader factory and create the train/val/test splits.
    dataloader = instantiate(
        config=config.dataloader,
        dataset_type=config.dataset.type,
        dataset_name=config.dataset.name,
        train_dataset=config.dataset.train_data,
        val_dataset=config.dataset.val_data,
        test_dataset=config.dataset.test_data,
    )
    train_dataloader, val_dataloader, test_dataloader = dataloader.create_dataloaders()
    # Loss and optimizer come from the config; the optimizer receives the
    # model's parameters.
    criterion = instantiate(config=config.loss)
    optimizer = instantiate(config=config.optimizer, params=model.parameters())

    # The trainer bundles model, data, loss, and optimizer; no LR scheduler
    # is used here.
    trainer = instantiate(
        config=config.trainer,
        device=device,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        task=config.dataset.type,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
        log_to_wandb=config.general.log_to_wandb,
        save_dir=config.general.save_dir,
        scheduler=None,
    )
    # Train and validate for the configured number of epochs, then run the
    # final test pass.
    for epoch in range(1, config.general.max_epochs + 1):
        trainer.train(current_epoch_nr=epoch)
        trainer.evaluate(current_epoch_nr=epoch)
    trainer.test()


if __name__ == '__main__':
    main()
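
# A sketch of how this script might be launched, assuming Hydra composes the
# run from YAML files under configs/ (the config name and override key below
# are hypothetical, not confirmed by this repository):
#
#   python train.py --config-name=train general.max_epochs=50
#
# Hydra resolves each _target_ entry in the config (model, dataloader, loss,
# optimizer, trainer) via hydra.utils.instantiate, so swapping components
# only requires editing or overriding the YAML, not this script.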