-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: train_segmentator.py
More file actions
26 lines (21 loc) · 951 Bytes
/
train_segmentator.py
File metadata and controls
26 lines (21 loc) · 951 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from shutil import rmtree
import os
# Reduce VRAM usage by reducing fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
#os.environ["TORCHDYNAMO_VERBOSE"] = "1"
from clearml import Task
from pytorch_lightning.cli import LightningCLI
from datasets.grazpedwri_dataset import SegGrazPedWriDataModule
from datasets.jsrt_dataset import JSRTDataModule
from models.poolformer_segmentor import MetaFormerSegmentator
from models.unet import UNetSegmentator, UNetOnPatchEmbedding
# Register this run with ClearML. Resource monitoring and automatic framework
# hooks are disabled; artifacts are uploaded explicitly below.
task = Task.init(project_name="FlexConv/Segmentation", auto_resource_monitoring=False, reuse_last_task_id=False,
                 auto_connect_frameworks=False)

# Training routine: LightningCLI parses the config/CLI args and runs fit.
# The dataset/model imports above make those classes resolvable by the CLI.
cli = LightningCLI()

# Housekeeping: persist the best checkpoint to ClearML, then clean up local logs.
trainer = cli.trainer
# NOTE(review): assumes a ModelCheckpoint callback is configured;
# trainer.checkpoint_callback would be None otherwise — confirm via the CLI config.
# Reuse the `task` handle from Task.init instead of re-fetching via Task.current_task().
task.upload_artifact("best.ckpt", trainer.checkpoint_callback.best_model_path, wait_on_upload=True)
task.close()

# Some loggers expose no on-disk directory (log_dir is None); also tolerate a
# directory that was never created, so a successful run cannot fail at cleanup.
if trainer.logger is not None and trainer.logger.log_dir is not None:
    rmtree(trainer.logger.log_dir, ignore_errors=True)