Patch based training, but volume statistics #869
Hi all, I am training a super-resolution network on brain MRI data using TorchIO and PyTorch Lightning. Training is patch based, but I would like my validation statistics to be computed over whole volumes rather than over individual patches. Does anyone know how to do this properly? I can see a somewhat dirty approach that instantiates an inference GridSampler for every subject in my validation set, but I don't really want to go down that path... Below is a code snippet of the relevant parts of my training (based on PyTorch Lightning).

def training_step(self, batch, batch_idx, optimizer_idx):
    imgs_lr, imgs_hr = self.prepare_batch(batch)
    imgs_sr = self(imgs_lr)
    train_batches_done = batch_idx + self.current_epoch * self.train_len

    # ---------------------
    #  Train Generator
    # ---------------------
    if optimizer_idx == 0:
        loss_pixel = self.criterion_pixel(imgs_sr, imgs_hr)

        if train_batches_done < self.args.warmup_batches:
            # Warm-up (pixel loss only)
            g_loss = loss_pixel
            return g_loss

        loss_edge = self.criterion_edge(imgs_sr, imgs_hr)

        gen_features = self.netF(torch.repeat_interleave(imgs_sr, 3, 1))
        real_features = self.netF(torch.repeat_interleave(imgs_hr, 3, 1)).detach()
        loss_content = self.criterion_content(gen_features, real_features)

        # Extract validity predictions from discriminator
        pred_real = self.netD(imgs_hr).detach()
        pred_fake = self.netD(imgs_sr)

        if self.ragan:
            pred_fake -= pred_real.mean(0, keepdim=True)

        # Adversarial loss
        loss_adv = self.criterion_GAN(pred_fake, True)

        g_loss = 0.3 * loss_edge + 0.7 * loss_pixel + self.alpha_adv * loss_adv + loss_content
        return g_loss

    # ---------------------
    #  Train Discriminator
    # ---------------------
    if optimizer_idx == 1:
        # Extract validity predictions from discriminator
        pred_real = self.netD(imgs_hr)
        pred_fake = self.netD(imgs_sr.detach())

        if self.ragan:
            pred_real, pred_fake = pred_real - pred_fake.mean(0, keepdim=True), \
                                   pred_fake - pred_real.mean(0, keepdim=True)

        # Adversarial loss
        loss_real = self.criterion_GAN(pred_real, True)
        loss_fake = self.criterion_GAN(pred_fake, False)
        d_loss = (loss_real + loss_fake) / 2

        self.log('Epoch loss/discriminator', {"Train": d_loss},
                 on_step=False,
                 on_epoch=True,
                 sync_dist=True,
                 batch_size=self.batch_size)
        return d_loss


def validation_step(self, batch, batch_idx):
    with torch.no_grad():
        # ---------------------
        #  Validation Generator
        # ---------------------
        imgs_lr, imgs_hr = self.prepare_batch(batch)
        imgs_sr = self(imgs_lr)

        loss_pixel = self.criterion_pixel(imgs_sr, imgs_hr)
        loss_edge = self.criterion_edge(imgs_sr, imgs_hr)

        gen_features = self.netF(imgs_sr.repeat(1, 3, 1, 1))
        real_features = self.netF(imgs_hr.repeat(1, 3, 1, 1))
        loss_content = self.criterion_content(gen_features, real_features)

        # Extract validity predictions from discriminator
        pred_real = self.netD(imgs_hr)
        pred_fake = self.netD(imgs_sr)

        # Relativistic average GAN
        if self.ragan:
            pred_real, pred_fake = pred_real - pred_fake.mean(0, keepdim=True), \
                                   pred_fake - pred_real.mean(0, keepdim=True)

        # Adversarial loss (gradient penalty cannot be calculated during validation)
        loss_adv = self.criterion_GAN(pred_fake, True)

        g_loss = 0.3 * loss_edge + 0.7 * loss_pixel + self.alpha_adv * loss_adv + loss_content

        # ---------------------
        #  Validation Discriminator
        # ---------------------
        # Adversarial loss for real and fake images
        loss_real = self.criterion_GAN(pred_real, True)
        loss_fake = self.criterion_GAN(pred_fake, False)
        d_loss = (loss_real + loss_fake) / 2

    return g_loss, d_loss


def setup(self, stage='fit'):
    args = self.args
    data_path = os.path.join(args.root_dir, 'data')

    train_subjects = data_split('training',
                                patients_frac=self.patients_frac,
                                root_dir=data_path,
                                datasource=self.datasource,
                                numslices=50)
    val_subjects = data_split('validation',
                              patients_frac=self.patients_frac,
                              root_dir=data_path,
                              datasource=self.datasource,
                              numslices=50)

    training_transform = tio.Compose([
        Normalize(std=args.std),
        tio.RandomFlip(axes=(0, 1), flip_probability=0.5),
        tio.RandomAffine(degrees=(0, 0, 0, 0, 0, 360),
                         default_pad_value=0,
                         scales=0,
                         translation=0,
                         isotropic=True),
    ])

    self.training_set = tio.SubjectsDataset(
        train_subjects, transform=training_transform)
    self.val_set = tio.SubjectsDataset(
        val_subjects, transform=training_transform)

    overlap, nr_patches = calculate_overlap(train_subjects[0]['LR'],
                                            (self.patch_size, self.patch_size),
                                            (self.patch_overlap, self.patch_overlap))
    self.samples_per_volume = nr_patches

    probabilities = {0: 0, 1: 1}
    self.sampler = tio.data.LabelSampler(
        patch_size=(self.patch_size, self.patch_size, 1),
        label_name='HR_msk_bin',
        label_probabilities=probabilities,
    )


def train_dataloader(self):
    training_queue = tio.Queue(
        subjects_dataset=self.training_set,
        max_length=self.samples_per_volume * 10,
        samples_per_volume=self.samples_per_volume,
        sampler=self.sampler,
        num_workers=self.args.num_workers,
        shuffle_subjects=True,
        shuffle_patches=True,
    )
    training_loader = torch.utils.data.DataLoader(
        training_queue,
        batch_size=self.batch_size,
        num_workers=0,
    )
    self.train_len = len(training_loader)
    return training_loader


def val_dataloader(self):
    val_queue = tio.Queue(
        subjects_dataset=self.val_set,
        max_length=self.samples_per_volume * 5,
        samples_per_volume=self.samples_per_volume,
        sampler=self.sampler,
        num_workers=self.args.num_workers,
        shuffle_subjects=False,
        shuffle_patches=False,
    )
    val_loader = torch.utils.data.DataLoader(
        val_queue,
        batch_size=self.batch_size,
        num_workers=0,
    )
    self.val_len = len(val_loader)
    return val_loader
Hi, @rienboonstoppel. I don't think it's that clunky to instantiate a sampler and aggregator for each validation subject, and I don't see why you couldn't do early stopping following that approach. I've trained like that before without issues.
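
For what it's worth, here is a minimal sketch of that per-subject sampler/aggregator loop. It assumes the image names ('LR', 'HR') and the attributes from your snippet (self.patch_size, self.patch_overlap, self.batch_size, self.val_set); validate_whole_volumes and compute_metric are hypothetical placeholders for wherever you call this (e.g. a validation epoch-end hook) and for whatever whole-volume statistic you want to track:

import torch
import torchio as tio


def validate_whole_volumes(self):
    # Sketch: one GridSampler/GridAggregator pair per validation subject,
    # then compute the statistic on the reassembled volume.
    metrics = []
    for idx in range(len(self.val_set)):
        subject = self.val_set[idx]
        grid_sampler = tio.inference.GridSampler(
            subject,
            patch_size=(self.patch_size, self.patch_size, 1),
            patch_overlap=(self.patch_overlap, self.patch_overlap, 0),
        )
        patch_loader = torch.utils.data.DataLoader(
            grid_sampler, batch_size=self.batch_size)
        aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')

        with torch.no_grad():
            for patches_batch in patch_loader:
                # If prepare_batch squeezes the singleton slice dimension before the
                # forward pass, do the same here and restore it before add_batch.
                imgs_lr = patches_batch['LR'][tio.DATA].to(self.device)
                locations = patches_batch[tio.LOCATION]
                imgs_sr = self(imgs_lr)
                aggregator.add_batch(imgs_sr.cpu(), locations)

        sr_volume = aggregator.get_output_tensor()   # whole reassembled SR volume
        hr_volume = subject['HR'][tio.DATA]          # ground-truth HR volume
        metrics.append(compute_metric(sr_volume, hr_volume))  # hypothetical metric returning a float

    return sum(metrics) / len(metrics)

Building the sampler and aggregator per subject also keeps memory bounded to one reassembled volume at a time, so the overhead compared with the patch-based validation you have now is small.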