-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
73 lines (60 loc) · 3.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from argparse import ArgumentParser
import utils
import warnings
import lightning_datamodule
import lightning_module
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
if __name__ == "__main__":
    # Entry point: train and evaluate a (optionally pretrained) ResNet18 on
    # CIFAR-100 with PyTorch Lightning, logging to TensorBoard and keeping
    # the single best checkpoint by validation loss.

    # Silence warnings that are known to be harmless for this setup.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    warnings.filterwarnings("ignore", ".*Your `val_dataloader` has `shuffle=True`.*")
    warnings.filterwarnings("ignore", ".*Checkpoint directory.*")

    # Parse command line arguments.
    parser = ArgumentParser()
    parser.add_argument("--accelerator", default="gpu", help="Type of accelerator: 'gpu', 'cpu', 'auto'")
    parser.add_argument("--devices", default="auto", help="Number of devices (GPUs or CPU cores) to use: integer starting from 1 or 'auto'")
    parser.add_argument("--workers", type=int, default=4, help="Number of CPU cores to use as workers for the dataloaders: integer starting from 1 to maximum number of cores on this machine")
    parser.add_argument("--epochs", type=int, default=60, help="Maximum number of epochs to run for")
    parser.add_argument("--bs", type=int, default=256, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.1, help="Initial learning rate")
    # NOTE(review): the default is the *string* "False"; utils.args_interpreter
    # is presumably responsible for converting it to a bool — confirm there.
    parser.add_argument("--pretrained", default="False", help="Whether to load pretrained ResNet18 and fine tune or not and train from scratch (True/False)")
    args = parser.parse_args()

    # Print a summary of the selected arguments and adjust them if needed.
    args = utils.args_interpreter(args)

    # Instantiate the datamodule.
    ldm = lightning_datamodule.CIFAR100DataModule(batch_size=args.bs, num_workers=args.workers)

    # Instantiate the logger.
    tensorboard_logger = TensorBoardLogger(save_dir="logs")

    # Stop early when the epoch validation loss has not improved for 40 epochs.
    early_stopping = EarlyStopping("validation_loss", patience=40, verbose=True)

    # Log the learning rate once per epoch (useful with LR schedulers).
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Keep only the single best checkpoint (lowest validation loss), weights only.
    checkpoint = ModelCheckpoint(
        dirpath="./checkpoints/",
        filename="{epoch}-{validation_loss:.2f}",
        verbose=True,
        monitor="validation_loss",
        save_last=False,
        save_top_k=1,
        mode="min",
        save_weights_only=True
    )

    # Instantiate the trainer.
    trainer = Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        max_epochs=args.epochs,
        logger=tensorboard_logger,
        log_every_n_steps=1,
        callbacks=[lr_monitor, early_stopping, checkpoint]
    )

    # Instantiate the model/pipeline.
    lm = lightning_module.CIFAR100ResNet(learning_rate=args.lr, batch_size=args.bs, pretrained=args.pretrained)

    # Fit on the training set, then evaluate on the test set.
    trainer.fit(lm, ldm)
    trainer.test(lm, ldm)