-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_classifier.py
More file actions
28 lines (23 loc) · 1.13 KB
/
train_classifier.py
File metadata and controls
28 lines (23 loc) · 1.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from shutil import rmtree
import os
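# Enable expandable segments in PyTorch's CUDA caching allocator to reduce memory
# fragmentation (and thus VRAM usage). Set ahead of the remaining imports so the
# variable is already in the environment when PyTorch is loaded.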
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
#os.environ["TORCHDYNAMO_VERBOSE"] = "1"
from clearml import Task
from pytorch_lightning.cli import LightningCLI
from models.cnn_classifier import CNNClassifier, ConvFormerClassifier
from models.poolformer_classifier import PoolFormerClassifier, FlexFormerClassifier, MetaFormerClassifier
from models.pretrained_metaformer_classifier import PretrainedMetaformer
from models.metaformer_classifier import AdaptiveMetaformerClassifier
from datasets.med_mnist_dataset import MedMNISTDataModule
from datasets.imagewoof_dataset import ImageWoofDataModule
# Register this run with ClearML. Automatic resource monitoring and framework
# hooks are switched off; the best checkpoint is uploaded explicitly below.
task = Task.init(
    project_name="FlexConv/Classification",
    auto_resource_monitoring=False,
    reuse_last_task_id=False,
    auto_connect_frameworks=False,
)

# training routine
# With its default arguments LightningCLI parses the command line and runs the
# requested subcommand while it is being constructed.
cli = LightningCLI()

# housekeeping
# NOTE(review): with multi-process strategies every rank may reach this code --
# confirm whether upload/cleanup should be limited to the global-zero rank.
trainer = cli.trainer

# The trainer may have no checkpoint callback at all, and best_model_path stays
# empty when no checkpoint was written (e.g. a run that never validated), so
# only upload when there is an actual file path to hand to ClearML.
checkpoint_callback = trainer.checkpoint_callback
best_model_path = (
    checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
)
if best_model_path:
    task.upload_artifact("best.ckpt", best_model_path, wait_on_upload=True)
else:
    print("No best checkpoint available; skipping 'best.ckpt' artifact upload.")

# Closed only on the success path on purpose: if anything above raises, the
# task is not explicitly closed here, so a crashed run is not closed as if it
# had finished normally.
task.close()

# Remove the local log directory once the run has finished. Some loggers report
# no log_dir, and the directory may never have been created, so guard both
# cases instead of letting rmtree fail after a successful run.
if trainer.logger is not None:
    log_dir = trainer.logger.log_dir
    if log_dir is not None and os.path.isdir(log_dir):
        rmtree(log_dir)