Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move contents of common.yaml to train and inference.
Browse files Browse the repository at this point in the history
Virginia committed Jan 28, 2025
1 parent 9d36b3b commit dd92a8e
Showing 4 changed files with 417 additions and 254 deletions.
61 changes: 61 additions & 0 deletions models/mednist_ddpm/configs/inference.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,65 @@
# This defines an inference script for generating a random image to a Pytorch file
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator

# Common elements to all yaml files
-
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

network_def:
_target_: monai.networks.nets.DiffusionModelUNet
spatial_dims: 2
in_channels: 1
out_channels: 1
channels: [64, 128, 128]
attention_levels: [false, true, true]
num_res_blocks: 1
num_head_channels: 128

network: $@network_def.to(@device)
bundle_root: .
ckpt_path: $@bundle_root + '/models/model.pt'
use_amp: true
image_dim: 64
image_size: [1, '@image_dim', '@image_dim']
num_train_timesteps: 1000

base_transforms:
- _target_: LoadImaged
keys: '@image'
image_only: true
- _target_: EnsureChannelFirstd
keys: '@image'
- _target_: ScaleIntensityRanged
keys: '@image'
a_min: 0.0
a_max: 255.0
b_min: 0.0
b_max: 1.0
clip: true

scheduler:
_target_: monai.networks.schedulers.DDPMScheduler
num_train_timesteps: '@num_train_timesteps'

inferer:
_target_: monai.inferers.DiffusionInferer
scheduler: '@scheduler'

# Inference-specific

batch_size: 1
num_workers: 0
61 changes: 61 additions & 0 deletions models/mednist_ddpm/configs/train.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,65 @@
# This defines the training script for the network
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator

# Common elements to all training files
-
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

network_def:
_target_: monai.networks.nets.DiffusionModelUNet
spatial_dims: 2
in_channels: 1
out_channels: 1
channels: [64, 128, 128]
attention_levels: [false, true, true]
num_res_blocks: 1
num_head_channels: 128

network: $@network_def.to(@device)
bundle_root: .
ckpt_path: $@bundle_root + '/models/model.pt'
use_amp: true
image_dim: 64
image_size: [1, '@image_dim', '@image_dim']
num_train_timesteps: 1000

base_transforms:
- _target_: LoadImaged
keys: '@image'
image_only: true
- _target_: EnsureChannelFirstd
keys: '@image'
- _target_: ScaleIntensityRanged
keys: '@image'
a_min: 0.0
a_max: 255.0
b_min: 0.0
b_max: 1.0
clip: true

scheduler:
_target_: monai.networks.schedulers.DDPMScheduler
num_train_timesteps: '@num_train_timesteps'

inferer:
_target_: monai.inferers.DiffusionInferer
scheduler: '@scheduler'

# Training-specific

# choose a new directory for every run
output_dir: $datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S')
53 changes: 53 additions & 0 deletions models/mednist_ddpm/configs/train_multigpu.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,57 @@
# This can be mixed in with the training script to enable multi-GPU training
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator

# Common elements to all training files
-
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

network_def:
_target_: monai.networks.nets.DiffusionModelUNet
spatial_dims: 2
in_channels: 1
out_channels: 1
channels: [64, 128, 128]
attention_levels: [false, true, true]
num_res_blocks: 1
num_head_channels: 128

base_transforms:
- _target_: LoadImaged
keys: '@image'
image_only: true
- _target_: EnsureChannelFirstd
keys: '@image'
- _target_: ScaleIntensityRanged
keys: '@image'
a_min: 0.0
a_max: 255.0
b_min: 0.0
b_max: 1.0
clip: true

scheduler:
_target_: monai.networks.schedulers.DDPMScheduler
num_train_timesteps: '@num_train_timesteps'

inferer:
_target_: monai.inferers.DiffusionInferer
scheduler: '@scheduler'

# Training specific

network:
_target_: torch.nn.parallel.DistributedDataParallel
496 changes: 242 additions & 254 deletions models/mednist_ddpm/docs/2d_ddpm_bundle_tutorial.ipynb

Large diffs are not rendered by default.

0 comments on commit dd92a8e

Please sign in to comment.