
Commit b24b17e
feat: handle deprecation of auto_lr_find and auto_scale_batch_size arguments from Trainer
IamGianluca committed Mar 21, 2024
1 parent 4b31719 commit b24b17e
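
Context for the change: Lightning 2.0 removed the `auto_lr_find` and `auto_scale_batch_size` arguments from `Trainer`, along with `Trainer.tune()`; their replacements are the `LearningRateFinder` and `BatchSizeFinder` callbacks that this commit wires in. Below is a minimal, self-contained sketch of that callback flow, assuming a lightning 2.x install — the `ToyRegressor` module and every hyperparameter are illustrative placeholders, not code from this repo:

import lightning as pl
import torch
from lightning.pytorch.callbacks import BatchSizeFinder, LearningRateFinder
from torch.utils.data import DataLoader, TensorDataset


class ToyRegressor(pl.LightningModule):
    def __init__(self, lr: float = 1e-3, batch_size: int = 32):
        super().__init__()
        # The finder callbacks look up and overwrite these attributes:
        # LearningRateFinder targets `lr`/`learning_rate` by default,
        # BatchSizeFinder targets `batch_size` (batch_arg_name default).
        self.lr = lr
        self.batch_size = batch_size
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def train_dataloader(self):
        ds = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
        return DataLoader(ds, batch_size=self.batch_size)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[LearningRateFinder(), BatchSizeFinder()],
)
# Both finders run automatically at the start of fit();
# no separate trainer.tune() call exists (or is needed) in 2.x.
trainer.fit(ToyRegressor())

Because the callbacks rewrite the module's attributes themselves, the recipe's config flags only need to decide whether each callback gets attached, which is exactly what the diff below does.
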
Showing 3 changed files with 41 additions and 19 deletions.
54 changes: 35 additions & 19 deletions blazingai/recipes.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 from types import ModuleType
-from typing import Tuple
+from typing import List, Tuple
 
 import lightning as pl
 import numpy as np
@@ -16,6 +16,8 @@
 
 from lightning.pytorch import callbacks
 from lightning.pytorch.callbacks import RichProgressBar
+
+from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import Logger
 from omegaconf import DictConfig, OmegaConf
 from timm.data import transforms_factory
@@ -97,44 +99,58 @@ def image_classification_recipe(
     )
 
     model = learner.ImageClassifier(cfg=cfg)
 
+    cbacks: List[Callback] = list()
     checkpoint_callback = callbacks.ModelCheckpoint(
         monitor="val_metric",
         mode=cfg.metric_mode,
         dirpath=const.ckpt_path,
         filename=f"model_{cfg.name}_fold{cfg.fold}",
         save_weights_only=True,
     )
+    cbacks.append(checkpoint_callback)
     lr_callback = callbacks.LearningRateMonitor(
         logging_interval="step", log_momentum=True
     )
+    cbacks.append(lr_callback)
+
+    if cfg.auto_lr_find:
+        from lightning.pytorch.callbacks import LearningRateFinder
+
+        lr_finder_callback = LearningRateFinder()
+        cbacks.append(lr_finder_callback)
+
+    if cfg.auto_scale_batch_size:
+        from lightning.pytorch.callbacks import BatchSizeFinder
+
+        bs_finder_callback = BatchSizeFinder()
+        cbacks.append(bs_finder_callback)
+
+    progress_bar_callback = RichProgressBar()
+    cbacks.append(progress_bar_callback)
 
     trainer = pl.Trainer(
         accelerator="gpu",
         max_epochs=cfg.epochs,
         devices=[0],
         precision=cfg.precision,
         overfit_batches=cfg.overfit_batches,
-        auto_lr_find=cfg.auto_lr,
         accumulate_grad_batches=cfg.accumulate_grad_batches,
-        auto_scale_batch_size=cfg.auto_batch_size,
         logger=logger,
-        callbacks=[checkpoint_callback, lr_callback, RichProgressBar()],
+        callbacks=cbacks,
     )
-    if cfg.auto_lr or cfg.auto_batch_size:
-        trainer.tune(model, data)
-    else:
-        trainer.fit(model, datamodule=data)
-    print_mtrc(cfg.metric, model.best_train_metric, model.best_val_metric)  # type: ignore
-
-    trgt = df_val.loc[:, const.trgt_cols].values.tolist()
-    pred = trainer.predict(model, datamodule=data, ckpt_path="best")
-    pred = [p[0] * 100 for b in pred for p in b]  # type: ignore
-    return (
-        model.best_train_metric.detach().cpu().numpy(),  # type: ignore
-        model.best_val_metric.detach().cpu().numpy(),  # type: ignore
-        trgt,
-        pred,
-    )
+    trainer.fit(model, datamodule=data)
+    print_mtrc(cfg.metric, model.best_train_metric, model.best_val_metric)  # type: ignore
+
+    trgt = df_val.loc[:, const.trgt_cols].values.tolist()
+    pred = trainer.predict(model, datamodule=data, ckpt_path="best")
+    pred = [p[0] * 100 for b in pred for p in b]  # type: ignore
+    return (
+        model.best_train_metric.detach().cpu().numpy(),  # type: ignore
+        model.best_val_metric.detach().cpu().numpy(),  # type: ignore
+        trgt,
+        pred,
+    )
 
 
 def log_mtrc(logger: Logger, metrics: CrossValMetrics) -> None:
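The callback route above is one of two migration paths; Lightning 2.x also ships an imperative `Tuner` wrapper, the closer analogue of the removed `trainer.tune()` call. A hedged sketch of that alternative (not what this commit does), reusing the `ToyRegressor` placeholder from the earlier example:

import lightning as pl
from lightning.pytorch.tuner import Tuner

model = ToyRegressor()  # placeholder module from the sketch above
trainer = pl.Trainer(max_epochs=1)  # note: no finder callbacks attached

tuner = Tuner(trainer)
# Stands in for Trainer(auto_scale_batch_size=...) + trainer.tune(...)
tuner.scale_batch_size(model, mode="power")
# Stands in for Trainer(auto_lr_find=...) + trainer.tune(...)
lr_finder = tuner.lr_find(model)
if lr_finder is not None:
    print(lr_finder.suggestion())  # suggested initial learning rate

trainer.fit(model)

The callback approach chosen in this commit keeps the recipe to a single `trainer.fit()` call, whereas the `Tuner` route would have reintroduced the tune-then-fit branching that the old `cfg.auto_lr or cfg.auto_batch_size` block had.
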
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dev = [
     "ipdb",
     "mypy",
     "usort",
+    "pynvim",
 ]
 medical = [
     "pydicom>2.3.1",
5 changes: 5 additions & 0 deletions requirements.txt
@@ -57,6 +57,8 @@ fsspec==2024.2.0
     # lightning
     # pytorch-lightning
     # torch
+greenlet==3.0.3
+    # via pynvim
 huggingface-hub==0.21.4
     # via
     # datasets
@@ -95,6 +97,8 @@ moreorless==0.4.0
     # via usort
 mpmath==1.3.0
     # via sympy
+msgpack==1.0.8
+    # via pynvim
 multidict==6.0.5
     # via
     # aiohttp
@@ -193,6 +197,7 @@ pygments==2.17.2
 pylibjpeg==2.0.0
 pylibjpeg-libjpeg==2.0.2
 pylibjpeg-openjpeg==2.1.1
+pynvim==0.5.0
 pytest==8.1.1
 python-dateutil==2.9.0.post0
     # via pandas
