diff --git a/Datasets/Sample_Dataset/val/Bladder1.json b/Datasets/Sample_Dataset/val/Bladder1.json
new file mode 100644
index 0000000..f10b59f
--- /dev/null
+++ b/Datasets/Sample_Dataset/val/Bladder1.json
@@ -0,0 +1,10 @@
+{
+    "num_total": 183,
+    "num_pos": 15,
+    "num_neg": 168,
+    "percent_pos": 8.2,
+    "prob_thresh": 150,
+    "size_thresh": 30,
+    "size_thresh_upper": null,
+    "marker_thresh": null
+}
\ No newline at end of file
diff --git a/Datasets/Sample_Dataset/val/Bladder1.png b/Datasets/Sample_Dataset/val/Bladder1.png
new file mode 100644
index 0000000..dcbcc4d
Binary files /dev/null and b/Datasets/Sample_Dataset/val/Bladder1.png differ
diff --git a/Datasets/Sample_Dataset/val/Lung1.json b/Datasets/Sample_Dataset/val/Lung1.json
new file mode 100644
index 0000000..90fb847
--- /dev/null
+++ b/Datasets/Sample_Dataset/val/Lung1.json
@@ -0,0 +1,10 @@
+{
+    "num_total": 60,
+    "num_pos": 11,
+    "num_neg": 49,
+    "percent_pos": 18.3,
+    "prob_thresh": 150,
+    "size_thresh": 78,
+    "size_thresh_upper": null,
+    "marker_thresh": null
+}
\ No newline at end of file
diff --git a/Datasets/Sample_Dataset/val/Lung1.png b/Datasets/Sample_Dataset/val/Lung1.png
new file mode 100644
index 0000000..7223ef4
Binary files /dev/null and b/Datasets/Sample_Dataset/val/Lung1.png differ
diff --git a/Datasets/Sample_Dataset/val/metrics.json b/Datasets/Sample_Dataset/val/metrics.json
new file mode 100644
index 0000000..cdc25dd
--- /dev/null
+++ b/Datasets/Sample_Dataset/val/metrics.json
@@ -0,0 +1,22 @@
+{
+    "Lung1": {
+        "num_total": 60,
+        "num_pos": 11,
+        "num_neg": 49,
+        "percent_pos": 18.3,
+        "prob_thresh": 150,
+        "size_thresh": 78,
+        "size_thresh_upper": null,
+        "marker_thresh": null
+    },
+    "Bladder1": {
+        "num_total": 183,
+        "num_pos": 15,
+        "num_neg": 168,
+        "percent_pos": 8.2,
+        "prob_thresh": 150,
+        "size_thresh": 30,
+        "size_thresh_upper": null,
+        "marker_thresh": null
+    }
+}
\ No newline at end of file
diff --git a/Image_Processing/Augmentation.py b/Image_Processing/Augmentation.py
index b215eb0..a3fa06a 100644
--- a/Image_Processing/Augmentation.py
+++ b/Image_Processing/Augmentation.py
@@ -6,12 +6,13 @@
 class Augmentation:
-    def __init__(self, images):
+    def __init__(self, images, tile_size=512):
         self.images = images
         self.shape = self.images[list(self.images.keys())[0]].shape
         self.rotation_angle = np.random.choice([0, 90, 180, 270], 1)[0]
         # self.zoom_value = random.randint(0, 5)
         self.alpha_affine = 0.1
+        self.tile_size = tile_size
 
     def pipeline(self):
         """
@@ -30,12 +31,13 @@ def zoom(self):
         """
         :return:
         """
         new_size = random.randint(int(self.shape[0] * 0.75), self.shape[0])
+        assert self.shape[1] - new_size >= 0, f'self.shape[1] - new_size ({self.shape[1]} - {new_size}) should not be negative'
         start_point = (random.randint(0, self.shape[0] - new_size), random.randint(0, self.shape[1] - new_size))
         for key in self.images.keys():
             try:
-                self.images[key] = cv2.resize(self.images[key][start_point[0]: start_point[0] + new_size, start_point[1]: start_point[1] + new_size], (512, 512))
-            except:
-                print(key + ' not available')
+                self.images[key] = cv2.resize(self.images[key][start_point[0]: start_point[0] + new_size, start_point[1]: start_point[1] + new_size], (self.tile_size, self.tile_size))
+            except Exception as e:
+                print(e)
 
     def rotate(self):
         """
@@ -47,8 +49,8 @@ def rotate(self):
         for key in self.images.keys():
             try:
                 self.images[key] = ndimage.rotate(self.images[key], self.rotation_angle, reshape=False)
-            except:
-                print(key + ' not available')
+            except Exception as e:
+                print(e)
 
     def elastic_transform(self, random_state=None):
         """
@@ -78,5 +80,5 @@ def elastic_transform(self, random_state=None):
         for key in self.images.keys():
             try:
                 self.images[key] = cv2.warpAffine(self.images[key], M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)
-            except:
-                print(key + ' not available')
\ No newline at end of file
+            except Exception as e:
+                print(e)
diff --git a/Image_Processing/Image_Processing_Helper_Functions.py b/Image_Processing/Image_Processing_Helper_Functions.py
index 8f2e6b2..b71e552 100644
--- a/Image_Processing/Image_Processing_Helper_Functions.py
+++ b/Image_Processing/Image_Processing_Helper_Functions.py
@@ -87,7 +87,7 @@ def create_training_testing_dataset_from_given_directory(input_dir, output_dir,
             cv2.imwrite(os.path.join(all_dirs[i], filename), all_images[filename])
 
 
-def augment_set(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin', 'CD3', 'PanCK'], tile_size=512):
+def augment_set(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin', 'CD3', 'PanCK']):
     """
     This function augments a co-aligned dataset.
 
@@ -105,20 +105,29 @@ def augment_set(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin',
     """
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-    images = os.listdir(input_dir)
-    for img in images:
+    images_original = os.listdir(input_dir)
+    print(f'{len(images_original)} images found')
+
+    count = 0
+    for idx, img in enumerate(images_original):
        augmented = 0
        while augmented < aug_no:
            images = {}
            image = cv2.imread(os.path.join(input_dir, img))
+           if idx == 0:
+               tile_size = image.shape[0]
+               assert image.shape[1] >= len(modality_types) * tile_size, f'image width ({image.shape[1]}) is not enough for {len(modality_types)} modalities with tile size {tile_size}'
            for i in range(0, len(modality_types)):
                images[modality_types[i]] = image[:, i * tile_size: (i + 1) * tile_size]
            new_images = images.copy()
-           aug = Augmentation(new_images)
+           aug = Augmentation(new_images, tile_size)
            aug.pipeline()
            cv2.imwrite(os.path.join(output_dir, img.replace('.png', '_' + str(augmented) + '.png')),
                        np.concatenate(list(new_images.values()), 1))
            augmented += 1
+       count += 1
+       if count % 10 == 0 or count == len(images_original):
+           print(f'Done {count}/{len(images_original)}')
 
 
 def augment_created_dataset(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin', 'CD3', 'PanCK'], tile_size=512):
diff --git a/cli.py b/cli.py
index b089c15..bb3c58f 100644
--- a/cli.py
+++ b/cli.py
@@ -8,9 +8,10 @@
 import torch
 import numpy as np
 from PIL import Image
+from torchvision.transforms import ToPILImage
 
 from deepliif.data import create_dataset, transform
-from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model
+from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess
 from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
 from deepliif.util.util import mkdirs
 # from deepliif.util import infer_results_for_wsi
@@ -85,7 +86,7 @@ def cli():
               help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 '
                   'PatchGAN. n_layers allows you to specify the layers in the discriminator')
 @click.option('--net-g', default='resnet_9blocks',
-              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
 @click.option('--n-layers-d', default=4, help='only used if netD==n_layers')
 @click.option('--norm', default='batch',
               help='instance normalization or batch normalization [instance | batch | none]')
@@ -128,12 +129,15 @@ def cli():
               help='number of epochs with the initial learning rate')
 @click.option('--n-epochs-decay', type=int, default=100,
               help='number of epochs to linearly decay learning rate to zero')
+@click.option('--optimizer', type=str, default='adam',
+              help='optimizer from torch.optim to use, applied to both generators and discriminators [adam | sgd | adamw | ...]; the current parameters, however, are set up for adam, so other optimizers may encounter issues')
 @click.option('--beta1', default=0.5, help='momentum term of adam')
 @click.option('--lr', default=0.0002, help='initial learning rate for adam')
 @click.option('--lr-policy', default='linear',
               help='learning rate policy. [linear | step | plateau | cosine]')
 @click.option('--lr-decay-iters', type=int, default=50,
               help='multiply by a gamma every lr_decay_iters iterations')
+@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
 # visdom and HTML visualization parameters
 @click.option('--display-freq', default=400, help='frequency of showing training results on screen')
 @click.option('--display-ncols', default=4,
@@ -158,26 +162,32 @@ def cli():
               help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
 @click.option('--padding', type=str, default='zero',
               help='chooses the type of padding used by resnet generator. [reflect | zero]')
-@click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
-@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
 # DeepLIIFExt params
 @click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).')
 @click.option('--net-ds', type=str, default='n_layers',
               help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
 @click.option('--net-gs', type=str, default='unet_512',
-              help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+              help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-gs=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
 @click.option('--gan-mode', type=str, default='vanilla',
               help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp].
vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') @click.option('--gan-mode-s', type=str, default='lsgan', help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') +# DDP related arguments +@click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup') +# Others +@click.option('--with-val', is_flag=True, + help='use validation set to evaluate model performance at the end of each epoch') +@click.option('--debug', is_flag=True, + help='debug mode, limits the number of data points per epoch to a small value') +@click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; due to batch size, the epoch will be passed once the completed no. data points is greater than this value (e.g., for batch size 3, debug data size 10, the effective size used in training will be 12)') def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g, n_layers_d, norm, init_type, init_gain, no_dropout, direction, serial_batches, num_threads, batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter, verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env, display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter, - continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, beta1, lr, lr_decay_iters, - remote, local_rank, remote_transfer_cmd, seed, dataset_mode, padding, model, - modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s): + continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr, lr_decay_iters, + remote, remote_transfer_cmd, seed, dataset_mode, padding, model, + modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size): """General-purpose training script for multi-task image-to-image translation. This script works for various models (with option '--model': e.g., DeepLIIF) and @@ -201,6 +211,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd seg_no = 0 seg_gen = False + if optimizer != 'adam': + print(f'Optimizer torch.optim.{optimizer} is not tested. 
Be careful about the parameters of the optimizer.') + d_params = locals() if gpu_ids and gpu_ids[0] == -1: @@ -213,12 +226,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd if local_rank is not None: local_rank = int(local_rank) torch.cuda.set_device(gpu_ids[local_rank]) - gpu_ids=[gpu_ids[local_rank]] + gpu_ids=[local_rank] else: torch.cuda.set_device(gpu_ids[0]) if local_rank is not None: # LOCAL_RANK will be assigned a rank number if torchrun ddp is used - dist.init_process_group(backend='nccl') + dist.init_process_group(backend="nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE'])) print('local rank:',local_rank) flag_deterministic = set_seed(seed,local_rank) elif rank is not None: @@ -234,16 +247,35 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd dir_data_train = dataroot + '/train' fns = os.listdir(dir_data_train) fns = [x for x in fns if x.endswith('.png')] + print(f'{len(fns)} images found') img = Image.open(f"{dir_data_train}/{fns[0]}") + print(f'image shape:',img.size) num_img = img.size[0] / img.size[1] assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer' num_img = int(num_img) input_no = num_img - modalities_no - seg_no - assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0' + assert input_no > 0, f'inferred number of input images is {input_no} (modalities_no {modalities_no}, seg_no {seg_no}); should be greater than 0' d_params['input_no'] = input_no d_params['scale_size'] = img.size[1] + d_params['gpu_ids'] = gpu_ids + + # update generator arch + net_g = net_g.split(',') + assert len(net_g) in [1,modalities_no], f'net_g should contain either 1 architecture for all translation generators or the same number of architectures as the number of translation generators ({modalities_no})' + if len(net_g) == 1: + net_g = net_g*modalities_no + + net_gs = net_gs.split(',') + assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})' + if len(net_gs) == 1 and model == 'DeepLIIF': + net_gs = net_gs*(modalities_no + seg_no) + elif len(net_gs) == 1: + net_gs = net_gs*seg_no + + d_params['net_g'] = net_g + d_params['net_gs'] = net_gs # create a dataset given dataset_mode and other options # dataset = AlignedDataset(opt) @@ -251,9 +283,19 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd opt = Options(d_params=d_params) print_options(opt, save=True) + # set dir for train and val dataset = create_dataset(opt) + # get the number of images in the dataset. 
click.echo('The number of training images = %d' % len(dataset)) + + if with_val: + dataset_val = create_dataset(opt,phase='val') + data_val = [batch for batch in dataset_val] + click.echo('The number of validation images = %d' % len(dataset_val)) + + if model in ['DeepLIIF']: + metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json'))) # create a model given model and other options model = create_model(opt) @@ -283,6 +325,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd # inner loop within one epoch for i, data in enumerate(dataset): + # timer for computation per iteration iter_start_time = time.time() if total_iters % print_freq == 0: @@ -299,15 +342,15 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd if total_iters % display_freq == 0: save_result = total_iters % update_html_freq == 0 model.compute_visuals() - visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) + visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result) # print training losses and save logging information to the disk if total_iters % print_freq == 0: - losses = model.get_current_losses() + losses = model.get_current_losses() # get training losses t_comp = (time.time() - iter_start_time) / batch_size - visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) + visualizer.print_current_losses(epoch, epoch_iter, {**losses}, t_comp, t_data) if display_id > 0: - visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), losses) + visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses}) # cache our latest model every iterations if total_iters % save_latest_freq == 0: @@ -315,7 +358,11 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd save_suffix = 'iter_%d' % total_iters if save_by_iter else 'latest' model.save_networks(save_suffix) + iter_data_time = time.time() + if debug and epoch_iter >= debug_data_size: + print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})') + break # cache our model every epochs if epoch % save_epoch_freq == 0: @@ -323,6 +370,77 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd model.save_networks('latest') model.save_networks(epoch) + + + # validation loss and metrics calculation + if with_val: + losses = model.get_current_losses() # get training losses to print + + model.eval() + l_losses_val = [] + l_metrics_val = [] + + # for each val image, calculate validation loss and cell count metrics + for j, data_val_batch in enumerate(data_val): + # batch size is effectively 1 for validation + model.set_input(data_val_batch) + model.calculate_losses() # this does not optimize parameters + visuals = model.get_current_visuals() # get image results + + # val losses + losses_val_batch = model.get_current_losses() + l_losses_val += [(k,v) for k,v in losses_val_batch.items()] + + # calculate cell count metrics + if type(model).__name__ == 'DeepLIIFModel': + l_seg_names = ['fake_B_5'] + assert l_seg_names[0] in visuals.keys(), f'Cannot find {l_seg_names[0]} in generated image names ({list(visuals.keys())})' + seg_mod_suffix = l_seg_names[0].split('_')[-1] + l_seg_names += [x for x in visuals.keys() if x.startswith('fake') and x.split('_')[-1].startswith(seg_mod_suffix) and x != l_seg_names[0]] + # print(f'Running postprocess for {len(l_seg_names)} generated images ({l_seg_names})') + + 
img_name_current = data_val_batch['A_paths'][0].split('/')[-1][:-4] # remove .png + metrics_gt = metrics_val[img_name_current] + + for seg_name in l_seg_names: + images = {'Seg':ToPILImage()((visuals[seg_name][0].cpu()+1)/2), + #'Marker':ToPILImage()((visuals['fake_B_4'][0].cpu()+1)/2) + } + _, scoring = postprocess(ToPILImage()((data['A'][0]+1)/2), images, opt.scale_size, opt.model) + + for k,v in scoring.items(): + if k.startswith('num') or k.startswith('percent'): + # to calculate the rmse, here we calculate (x_pred - x_true) ** 2 + l_metrics_val.append((k+'_'+seg_name,(v - metrics_gt[k])**2)) + + if debug and epoch_iter >= debug_data_size: + print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})') + break + + d_losses_val = {k+'_val':0 for k in losses_val_batch.keys()} + for k,v in l_losses_val: + d_losses_val[k+'_val'] += v + for k in d_losses_val: + d_losses_val[k] = d_losses_val[k] / len(data_val) + + d_metrics_val = {} + for k,v in l_metrics_val: + try: + d_metrics_val[k] += v + except: + d_metrics_val[k] = v + for k in d_metrics_val: + # to calculate the rmse, this is the second part, where d_metrics_val[k] now represents sum((x_pred - x_true) ** 2) + d_metrics_val[k] = np.sqrt(d_metrics_val[k] / len(data_val)) + + + model.train() + t_comp = (time.time() - iter_start_time) / batch_size + visualizer.print_current_losses(epoch, epoch_iter, {**losses,**d_losses_val, **d_metrics_val}, t_comp, t_data) + if display_id > 0: + visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses,**d_losses_val,**d_metrics_val}) + + print('End of epoch %d / %d \t Time Taken: %d sec' % ( epoch, n_epochs + n_epochs_decay, time.time() - epoch_start_time)) # update learning rates at the end of every epoch. @@ -336,8 +454,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd help='name of the experiment. It decides where to store samples and models') @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU') @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here') -@click.option('--targets-no', default=5, help='number of targets') +@click.option('--modalities-no', default=4, type=int, help='number of targets') # model parameters +@click.option('--model', default='DeepLIIF', help='name of model class') @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') @click.option('--ngf', default=64, help='# of gen filters in the last conv layer') @@ -346,14 +465,13 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 ' 'PatchGAN. 
n_layers allows you to specify the layers in the discriminator')
 @click.option('--net-g', default='resnet_9blocks',
-              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
 @click.option('--n-layers-d', default=4, help='only used if netD==n_layers')
 @click.option('--norm', default='batch',
               help='instance normalization or batch normalization [instance | batch | none]')
 @click.option('--init-type', default='normal',
               help='network initialization [normal | xavier | kaiming | orthogonal]')
 @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.')
-@click.option('--padding-type', default='reflect', help='network padding type.')
 @click.option('--no-dropout', is_flag=True, help='no dropout for the generator')
 # dataset parameters
 @click.option('--direction', default='AtoB', help='AtoB or BtoA')
@@ -390,12 +508,15 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
               help='number of epochs with the initial learning rate')
 @click.option('--n-epochs-decay', type=int, default=100,
               help='number of epochs to linearly decay learning rate to zero')
+@click.option('--optimizer', type=str, default='adam',
+              help='optimizer from torch.optim to use, applied to both generators and discriminators [adam | sgd | adamw | ...]; the current parameters, however, are set up for adam, so other optimizers may encounter issues')
 @click.option('--beta1', default=0.5, help='momentum term of adam')
 @click.option('--lr', default=0.0002, help='initial learning rate for adam')
 @click.option('--lr-policy', default='linear',
               help='learning rate policy. [linear | step | plateau | cosine]')
 @click.option('--lr-decay-iters', type=int, default=50,
               help='multiply by a gamma every lr_decay_iters iterations')
+@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
 # visdom and HTML visualization parameters
 @click.option('--display-freq', default=400, help='frequency of showing training results on screen')
 @click.option('--display-ncols', default=4,
@@ -416,8 +537,29 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
 @click.option('--save-by-iter', is_flag=True, help='whether saves model by iteration')
 @click.option('--remote', type=bool, default=False, help='whether isolate visdom checkpoints or not; if False, you can run a separate visdom server anywhere that consumes the checkpoints')
 @click.option('--remote-transfer-cmd', type=str, default=None, help='module and function to be used to transfer remote files to target storage location, for example mymodule.myfunction')
+@click.option('--dataset-mode', type=str, default='aligned',
+              help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
+@click.option('--padding', type=str, default='zero',
+              help='chooses the type of padding used by resnet generator. [reflect | zero]')
+# DeepLIIFExt params
+@click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).')
+@click.option('--net-ds', type=str, default='n_layers',
+              help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+@click.option('--net-gs', type=str, default='unet_512',
+              help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-gs=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
+@click.option('--gan-mode', type=str, default='vanilla',
+              help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+@click.option('--gan-mode-s', type=str, default='lsgan',
+              help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+# DDP related arguments
 @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
-@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
+# Others
+@click.option('--with-val', is_flag=True,
+              help='use validation set to evaluate model performance at the end of each epoch')
+@click.option('--debug', is_flag=True,
+              help='debug mode, limits the number of data points per epoch to a small value')
+@click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; due to batch size, the epoch will be passed once the completed no.
data points is greater than this value (e.g., for batch size 3, debug data size 10, the effective size used in training will be 12)') +# trainlaunch DDP related arguments @click.option('--use-torchrun', type=str, default=None, help='provide torchrun options, all in one string, for example "-t3 --log_dir ~/log/ --nproc_per_node 1"; if your pytorch version is older than 1.10, torch.distributed.launch will be called instead of torchrun') def trainlaunch(**kwargs): """ @@ -448,6 +590,7 @@ def trainlaunch(**kwargs): elif args[i-1] not in l_arg_skip and arg not in l_arg_skip: # if the previous element is not an option name to skip AND if the current element is not an option to remove args_final.append(arg) + ## add quotes back to the input arg that had quotes, e.g., experiment name args_final = [f'"{arg}"' if ' ' in arg else arg for arg in args_final] @@ -457,16 +600,29 @@ def trainlaunch(**kwargs): #### locate train.py import deepliif - path_train_py = deepliif.__path__[0]+'/train.py' + path_train_py = deepliif.__path__[0]+'/scripts/train.py' + + #### find out GPUs to use + gpu_ids = [args_final[i+1] for i,v in enumerate(args_final) if v=='--gpu-ids'] + if len(gpu_ids) > 0 and gpu_ids[0] == -1: + gpu_ids = [] + + if len(gpu_ids) > 0: + opt_env = f"CUDA_VISIBLE_DEVICES=\"{','.join(gpu_ids)}\"" + else: + opt_env = '' #### execute train.py if kwargs['use_torchrun']: if version.parse(torch.__version__) >= version.parse('1.10.0'): - subprocess.run(f'torchrun {kwargs["use_torchrun"]} {path_train_py} {options}',shell=True) + cmd = f'{opt_env} torchrun {kwargs["use_torchrun"]} {path_train_py} {options}' else: - subprocess.run(f'python -m torch.distributed.launch {kwargs["use_torchrun"]} {path_train_py} {options}',shell=True) + cmd = f'{opt_env} python -m torch.distributed.launch {kwargs["use_torchrun"]} {path_train_py} {options}' else: - subprocess.run(f'python {path_train_py} {options}',shell=True) + cmd = f'{opt_env} python {path_train_py} {options}' + + print('Executing command:',cmd) + subprocess.run(cmd,shell=True) @@ -475,9 +631,10 @@ def trainlaunch(**kwargs): @click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here') @click.option('--output-dir', help='saves results here.') #@click.option('--tile-size', type=int, default=None, help='tile size') -@click.option('--device', default='cpu', type=str, help='device to load model for the similarity test, either cpu or gpu') +@click.option('--device', default='cpu', type=str, help='device to run serialization as well as load model for the similarity test, either cpu or gpu') +@click.option('--epoch', default='latest', type=str, help='epoch to load and serialize') @click.option('--verbose', default=0, type=int,help='saves results here.') -def serialize(model_dir, output_dir, device, verbose): +def serialize(model_dir, output_dir, device, epoch, verbose): """Serialize DeepLIIF models using Torchscript """ #if tile_size is None: @@ -490,12 +647,20 @@ def serialize(model_dir, output_dir, device, verbose): if model_dir != output_dir: shutil.copy(f'{model_dir}/train_opt.txt',f'{output_dir}/train_opt.txt') + # load and update opt for serialization opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test') + opt.epoch = epoch + if device == 'gpu': + opt.gpu_ids = [0] # use gpu 0, in case training was done on larger machines + else: + opt.gpu_ids = [] # use cpu + + print_options(opt) sample = transform(Image.new('RGB', (opt.scale_size, opt.scale_size))) sample = torch.cat([sample]*opt.input_no, 
1) with click.progressbar( - init_nets(model_dir, eager_mode=True, phase='test').items(), + init_nets(model_dir, eager_mode=True, opt=opt, phase='test').items(), label='Tracing nets', item_show_func=lambda n: n[0] if n else n ) as bar: @@ -514,8 +679,9 @@ def serialize(model_dir, output_dir, device, verbose): # test: whether the original and the serialized model produces highly similar predictions print('testing similarity between prediction from original vs serialized models...') - models_original = init_nets(model_dir,eager_mode=True,phase='test') - models_serialized = init_nets(output_dir,eager_mode=False,phase='test') + models_original = init_nets(model_dir,eager_mode=True,opt=opt,phase='test') + models_serialized = init_nets(output_dir,eager_mode=False,opt=opt,phase='test') + if device == 'gpu': sample = sample.cuda() else: @@ -523,7 +689,7 @@ def serialize(model_dir, output_dir, device, verbose): for name in models_serialized.keys(): print(name,':') model_original = models_original[name].cuda().eval() if device=='gpu' else models_original[name].cpu().eval() - model_serialized = models_serialized[name].cuda() if device=='gpu' else models_serialized[name].cpu().eval() + model_serialized = models_serialized[name].cuda().eval() if device=='gpu' else models_serialized[name].cpu().eval() if name.startswith('GS'): test_diff_original_serialized(model_original,model_serialized,torch.cat([sample, sample, sample], 1),verbose) else: diff --git a/deepliif/data/__init__.py b/deepliif/data/__init__.py index 6a4e6c4..2966985 100644 --- a/deepliif/data/__init__.py +++ b/deepliif/data/__init__.py @@ -55,28 +55,28 @@ def get_option_setter(dataset_name): return dataset_class.modify_commandline_options -def create_dataset(opt): +def create_dataset(opt, phase=None, batch_size=None): """Create a dataset given the option. This function wraps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' """ - return CustomDatasetDataLoader(opt) + return CustomDatasetDataLoader(opt, phase=phase if phase else opt.phase, batch_size=batch_size if batch_size else opt.batch_size) class CustomDatasetDataLoader(object): """Wrapper class of Dataset class that performs multi-threaded data loading""" - def __init__(self, opt): + def __init__(self, opt, phase=None, batch_size=None): """Initialize this class Step 1: create a dataset instance given the name [dataset_mode] Step 2: create a multi-threaded data loader. 
""" - self.batch_size = opt.batch_size + self.batch_size = batch_size if batch_size else opt.batch_size self.max_dataset_size = opt.max_dataset_size dataset_class = find_dataset_using_name(opt.dataset_mode) - self.dataset = dataset_class(opt) + self.dataset = dataset_class(opt, phase=phase if phase else opt.phase) print("dataset [%s] was created" % type(self.dataset).__name__) sampler = None @@ -95,7 +95,7 @@ def seed_worker(worker_id): self.dataloader = torch.utils.data.DataLoader( self.dataset, sampler=sampler, - batch_size=opt.batch_size, + batch_size=batch_size, shuffle=not opt.serial_batches if sampler is None else False, num_workers=int(opt.num_threads) ) @@ -106,7 +106,7 @@ def seed_worker(worker_id): self.dataloader = torch.utils.data.DataLoader( self.dataset, sampler=sampler, - batch_size=opt.batch_size, + batch_size=batch_size, shuffle=not opt.serial_batches if sampler is None else False, num_workers=int(opt.num_threads), worker_init_fn=seed_worker, diff --git a/deepliif/data/aligned_dataset.py b/deepliif/data/aligned_dataset.py index aa5928b..b1e7501 100644 --- a/deepliif/data/aligned_dataset.py +++ b/deepliif/data/aligned_dataset.py @@ -11,7 +11,7 @@ class AlignedDataset(BaseDataset): During test time, you need to prepare a directory '/path/to/data/test'. """ - def __init__(self, opt): + def __init__(self, opt, phase='train'): """Initialize this dataset class. Parameters: @@ -19,7 +19,7 @@ def __init__(self, opt): """ BaseDataset.__init__(self, opt.dataroot) self.preprocess = opt.preprocess - self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory + self.dir_AB = os.path.join(opt.dataroot, phase) # get the image directory self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths assert(opt.load_size >= opt.crop_size) # crop_size should be smaller than the size of loaded image self.input_nc = opt.output_nc if opt.direction == 'BtoA' else opt.input_nc @@ -95,7 +95,6 @@ def __getitem__(self, index): A = AB.crop((w2 * i, 0, w2 * (i+1), h)) A = A_transform(A) A_Array.append(A) - for i in range(self.input_no, self.input_no + self.modalities_no + 1): B = AB.crop((w2 * i, 0, w2 * (i + 1), h)) B = B_transform(B) diff --git a/deepliif/models/DeepLIIFExt_model.py b/deepliif/models/DeepLIIFExt_model.py index de7b0e6..0522f74 100644 --- a/deepliif/models/DeepLIIFExt_model.py +++ b/deepliif/models/DeepLIIFExt_model.py @@ -1,6 +1,7 @@ import torch from .base_model import BaseModel from . 
import networks +from .networks import get_optimizer class DeepLIIFExtModel(BaseModel): @@ -72,22 +73,19 @@ def __init__(self, opt): self.model_names.extend(['GS_' + str(i)]) # define networks (both generator and discriminator) + if isinstance(opt.net_g, str): + self.opt.net_g = [self.opt.net_g] * self.mod_gen_no + if isinstance(opt.net_gs, str): + self.opt.net_gs = [self.opt.net_gs]*self.mod_gen_no self.netG = [None for _ in range(self.mod_gen_no)] self.netGS = [None for _ in range(self.mod_gen_no)] for i in range(self.mod_gen_no): - self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm, + self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm, not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding) - print('***************************************') - print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm, - not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding) - print('***************************************') + for i in range(self.mod_gen_no): if self.opt.seg_gen: - # if i == 0: - # self.netGS[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm, - # not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids) - # else: - self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm, + self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs[i], self.opt.norm, not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids) if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc @@ -99,11 +97,6 @@ def __init__(self, opt): self.gpu_ids) for i in range(self.mod_gen_no): if self.opt.seg_gen: - # if i == 0: - # self.netDS[i] = networks.define_D(self.opt.input_nc + self.opt.output_nc, self.opt.ndf, self.opt.net_ds, - # self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain, - # self.gpu_ids) - # else: self.netDS[i] = networks.define_D(self.opt.input_nc * 3 + self.opt.output_nc, self.opt.ndf, self.opt.net_ds, self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain, self.gpu_ids) @@ -113,9 +106,7 @@ def __init__(self, opt): # define loss functions self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device) self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device) - self.criterionSmoothL1 = torch.nn.SmoothL1Loss() - self.criterionVGG = networks.VGGLoss().to(self.device) # initialize optimizers; schedulers will be automatically created by function . 
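Note on the hunk that follows: it replaces the hard-coded torch.optim.Adam constructor with a get_optimizer helper imported from .networks. The helper's body is not part of this diff; a minimal sketch of what it might look like, assuming it simply resolves an optimizer class from torch.optim by case-insensitive name:

    import torch

    def get_optimizer(name):
        # Collect the torch.optim classes, keyed by lower-cased class name
        # ('adam' -> Adam, 'sgd' -> SGD, 'adamw' -> AdamW, ...).
        optimizers = {cls.__name__.lower(): cls for cls in vars(torch.optim).values()
                      if isinstance(cls, type) and issubclass(cls, torch.optim.Optimizer)}
        return optimizers[name.lower()]

The callers first try Adam-style keyword arguments (betas) and fall back to lr only, since optimizers such as SGD do not accept a betas keyword.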
@@ -125,7 +116,11 @@ def __init__(self, opt): for i in range(len(self.netGS)): if self.netGS[i]: params += list(self.netGS[i].parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + try: + self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + except: + print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators') + self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr) params = [] for i in range(len(self.netD)): @@ -133,7 +128,11 @@ def __init__(self, opt): for i in range(len(self.netDS)): if self.netDS[i]: params += list(self.netDS[i].parameters()) - self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + try: + self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + except: + print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators') + self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) @@ -295,3 +294,29 @@ def optimize_parameters(self): self.optimizer_G.zero_grad() # set G's gradients to zero self.backward_G() # calculate graidents for G self.optimizer_G.step() # udpate G's weights + + def calculate_losses(self): + """ + Calculate losses but do not optimize parameters. Used in validation loss calculation during training. + """ + self.forward() # compute fake images: G(A) + # update D + for i in range(self.mod_gen_no): + self.set_requires_grad(self.netD[i], True) # enable backprop for D1 + for i in range(self.mod_gen_no): + if self.netDS[i]: + self.set_requires_grad(self.netDS[i], True) + + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for D + + # update G + for i in range(self.mod_gen_no): + self.set_requires_grad(self.netD[i], False) + for i in range(self.mod_gen_no): + if self.netDS[i]: + self.set_requires_grad(self.netDS[i], False) + + self.optimizer_G.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + diff --git a/deepliif/models/DeepLIIF_model.py b/deepliif/models/DeepLIIF_model.py index aeb57f8..8279e61 100644 --- a/deepliif/models/DeepLIIF_model.py +++ b/deepliif/models/DeepLIIF_model.py @@ -1,6 +1,7 @@ import torch from .base_model import BaseModel from . import networks +from .networks import get_optimizer class DeepLIIFModel(BaseModel): @@ -13,9 +14,11 @@ def __init__(self, opt): opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) + if not hasattr(opt,'net_gs'): + opt.net_gs = 'unet_512' # weights of the modalities in generating segmentation mask - self.seg_weights = [0.25, 0.15, 0.25, 0.1, 0.25] + self.seg_weights = [0.25, 0.25, 0.25, 0.0, 0.25] # loss weights in calculating the final loss self.loss_G_weights = [0.2, 0.2, 0.2, 0.2, 0.2] @@ -31,7 +34,7 @@ def __init__(self, opt): # specify the training losses you want to print out. The training/test scripts will call for i in range(1, self.opt.modalities_no + 1 + 1): self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)]) - self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)]) + self.visual_names.extend(['fake_B_' + str(i), 'fake_B_5' + str(i), 'real_B_' + str(i)]) # specify the images you want to save/display. The training/test scripts will call # specify the models you want to save to the disk. 
The training/test scripts will call and @@ -51,24 +54,31 @@ def __init__(self, opt): self.model_names.extend(['G5' + str(i)]) # define networks (both generator and discriminator) - self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + if isinstance(opt.netG, str): + opt.netG = [opt.netG] * 4 + if isinstance(opt.net_gs, str): + opt.net_gs = [opt.net_gs]*5 + + + self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output + self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) @@ -101,10 +111,18 @@ def __init__(self, opt): # initialize optimizers; schedulers will be automatically created by function . 
         params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
-        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+        try:
+            self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+        except:
+            print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
+            self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr)
         params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
-        self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+        try:
+            self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+        except:
+            print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
+            self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr)
         self.optimizers.append(self.optimizer_G)
         self.optimizers.append(self.optimizer_D)
@@ -322,3 +340,38 @@ def optimize_parameters(self):
         self.optimizer_G.zero_grad()        # set G's gradients to zero
         self.backward_G()                   # calculate graidents for G
         self.optimizer_G.step()             # udpate G's weights
+
+    def calculate_losses(self):
+        """
+        Calculate losses but do not optimize parameters. Used in validation loss calculation during training.
+        """
+
+        self.forward()                   # compute fake images: G(A)
+        # update D
+        self.set_requires_grad(self.netD1, True)   # enable backprop for D1
+        self.set_requires_grad(self.netD2, True)   # enable backprop for D2
+        self.set_requires_grad(self.netD3, True)   # enable backprop for D3
+        self.set_requires_grad(self.netD4, True)   # enable backprop for D4
+        self.set_requires_grad(self.netD51, True)  # enable backprop for D51
+        self.set_requires_grad(self.netD52, True)  # enable backprop for D52
+        self.set_requires_grad(self.netD53, True)  # enable backprop for D53
+        self.set_requires_grad(self.netD54, True)  # enable backprop for D54
+        self.set_requires_grad(self.netD55, True)  # enable backprop for D55
+
+        self.optimizer_D.zero_grad()     # set D's gradients to zero
+        self.backward_D()                # calculate gradients for D
+
+        # update G
+        self.set_requires_grad(self.netD1, False)   # D1 requires no gradients when optimizing G1
+        self.set_requires_grad(self.netD2, False)   # D2 requires no gradients when optimizing G2
+        self.set_requires_grad(self.netD3, False)   # D3 requires no gradients when optimizing G3
+        self.set_requires_grad(self.netD4, False)   # D4 requires no gradients when optimizing G4
+        self.set_requires_grad(self.netD51, False)  # D51 requires no gradients when optimizing G51
+        self.set_requires_grad(self.netD52, False)  # D52 requires no gradients when optimizing G52
+        self.set_requires_grad(self.netD53, False)  # D53 requires no gradients when optimizing G53
+        self.set_requires_grad(self.netD54, False)  # D54 requires no gradients when optimizing G54
+        self.set_requires_grad(self.netD55, False)  # D55 requires no gradients when optimizing G55
+
+        self.optimizer_G.zero_grad()        # set G's gradients to zero
+        self.backward_G()                   # calculate gradients for G
+
diff --git a/deepliif/models/SDG_model.py b/deepliif/models/SDG_model.py
index ed0896d..0db367f 100644
--- a/deepliif/models/SDG_model.py
+++ b/deepliif/models/SDG_model.py
@@ -1,6 +1,7 @@
 import torch
 from .base_model import BaseModel
 from . import networks
+from .networks import get_optimizer
 
 
 class SDGModel(BaseModel):
@@ -15,6 +16,7 @@ def __init__(self, opt):
         BaseModel.__init__(self, opt)
         self.mod_gen_no = self.opt.modalities_no
 
+
         # weights of the modalities in generating segmentation mask
         self.seg_weights = [0, 0, 0]
@@ -24,7 +26,7 @@ def __init__(self, opt):
         # self.seg_weights = opt.seg_weights
         # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
 
-        # print(self.seg_weights)
+
         # loss weights in calculating the final loss
         self.loss_G_weights = [1 / self.mod_gen_no] * self.mod_gen_no
         self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
@@ -32,11 +34,19 @@ def __init__(self, opt):
         self.loss_D_weights = [1 / self.mod_gen_no] * self.mod_gen_no
         self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
 
+        # self.gpu_ids is a possibly modified one for model initialization
+        # self.opt.gpu_ids is the original one received in the command
+        if not opt.is_train:
+            self.gpu_ids = []  # avoid the models being loaded as DP
+        else:
+            self.gpu_ids = opt.gpu_ids
+
         self.loss_names = []
         self.visual_names = ['real_A']
 
         # specify the training losses you want to print out. The training/test scripts will call
         for i in range(1, self.mod_gen_no + 1):
-            self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
+            self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'G_VGG_'+ str(i),
+                                    'D_real_' + str(i), 'D_fake_' + str(i)])
             self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
 
         # specify the images you want to save/display.
The training/test scripts will call @@ -52,41 +62,48 @@ def __init__(self, opt): self.model_names.extend(['G_' + str(i)]) # define networks (both generator and discriminator) + if isinstance(opt.net_g, str): + self.opt.net_g = [self.opt.net_g] * self.mod_gen_no + if isinstance(opt.net_gs, str): + self.opt.net_gs = [self.opt.net_gs]*self.mod_gen_no self.netG = [None for _ in range(self.mod_gen_no)] + self.netGS = [None for _ in range(self.mod_gen_no)] for i in range(self.mod_gen_no): - self.netG[i] = networks.define_G(self.opt.input_nc * self.opt.input_no, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm, - not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding) - print('***************************************') - print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm, - not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding) - print('***************************************') + self.netG[i] = networks.define_G(self.opt.input_nc * self.opt.input_no, self.opt.output_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm, + not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding) if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc self.netD = [None for _ in range(self.mod_gen_no)] for i in range(self.mod_gen_no): self.netD[i] = networks.define_D(self.opt.input_nc * self.opt.input_no + self.opt.output_nc, self.opt.ndf, self.opt.net_d, self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain, - self.opt.gpu_ids) + self.gpu_ids) if self.is_train: # define loss functions self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device) self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device) - self.criterionSmoothL1 = torch.nn.SmoothL1Loss() - self.criterionVGG = networks.VGGLoss().to(self.device) # initialize optimizers; schedulers will be automatically created by function . 
params = [] for i in range(len(self.netG)): params += list(self.netG[i].parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + try: + self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + except: + print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators') + self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr) params = [] for i in range(len(self.netD)): params += list(self.netD[i].parameters()) - self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + try: + self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + except: + print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators') + self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) @@ -136,7 +153,6 @@ def backward_D(self): self.loss_D_real.append(self.criterionGAN_mod(pred_real[i], True)) # combine losses and calculate gradients - # self.loss_D = (self.loss_D_fake[0] + self.loss_D_real[0]) * 0.5 * self.loss_D_weights[0] self.loss_D = torch.tensor(0., device=self.device) for i in range(0, self.mod_gen_no): self.loss_D += (self.loss_D_fake[i] + self.loss_D_real[i]) * 0.5 * self.loss_D_weights[i] @@ -159,15 +175,14 @@ def backward_G(self): for i in range(self.mod_gen_no): self.loss_G_L1.append(self.criterionSmoothL1(self.fake_B[i], self.real_B[i]) * self.opt.lambda_L1) - #self.loss_G_VGG = [] - #for i in range(self.mod_gen_no): - # self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat) + self.loss_G_VGG = [] + for i in range(self.mod_gen_no): + self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat) + - # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0] self.loss_G = torch.tensor(0., device=self.device) for i in range(0, self.mod_gen_no): - self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i] - # self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] + self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] self.loss_G.backward() def optimize_parameters(self): @@ -187,3 +202,22 @@ def optimize_parameters(self): self.optimizer_G.zero_grad() # set G's gradients to zero self.backward_G() # calculate graidents for G self.optimizer_G.step() # udpate G's weights + + def calculate_losses(self): + """ + Calculate losses but do not optimize parameters. Used in validation loss calculation during training. 
+ """ + self.forward() # compute fake images: G(A) + # update D + for i in range(self.mod_gen_no): + self.set_requires_grad(self.netD[i], True) # enable backprop for D1 + + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for D + + # update G + for i in range(self.mod_gen_no): + self.set_requires_grad(self.netD[i], False) + + self.optimizer_G.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G diff --git a/deepliif/models/__init__.py b/deepliif/models/__init__.py index ce027a1..3d3049d 100644 --- a/deepliif/models/__init__.py +++ b/deepliif/models/__init__.py @@ -115,7 +115,7 @@ def load_torchscript_model(model_pt_path, device): -def load_eager_models(opt, devices): +def load_eager_models(opt, devices=None): # create a model given model and other options model = create_model(opt) # regular setup: load and print networks; create schedulers @@ -138,7 +138,8 @@ def load_eager_models(opt, devices): net = net.module nets[name] = net - nets[name].to(devices[name]) + if devices: + nets[name].to(devices[name]) return nets @@ -154,8 +155,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'): """ if opt is None: opt = get_opt(model_dir, mode=phase) - opt.use_dp = False - #print_options(opt) + opt.use_dp = False if opt.model == 'DeepLIIF': net_groups = [ @@ -174,8 +174,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'): raise Exception(f'init_nets() not implemented for model {opt.model}') number_of_gpus_all = torch.cuda.device_count() - number_of_gpus = len(opt.gpu_ids) - #print(number_of_gpus) + number_of_gpus = min(len(opt.gpu_ids),number_of_gpus_all) if number_of_gpus > 0: mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)} @@ -263,6 +262,7 @@ def forward(input, model): seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0) res = {k: tensor_to_pil(v) for k, v in gens.items()} + res.update({k: tensor_to_pil(v) for k, v in segs.items()}) res['G5'] = tensor_to_pil(seg) return res @@ -290,7 +290,7 @@ def is_empty_old(tile): return all([True if calculate_background_area(t) > 98 else False for t in tile]) else: return True if calculate_background_area(tile) > 98 else False - + def is_empty(tile): thresh = 15 @@ -308,7 +308,12 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None): 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)), 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)), - 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)) + 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), } else: return run_fn(tile, model_path, eager_mode, opt) @@ -385,7 +390,14 @@ def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=Fals 'DAPI':'G2', 'Lap2':'G3', 'Marker':'G4', - 'Seg':'G5'} + 'Seg':'G5', + } + if return_seg_intermediate: + d_modality2net.update({'IHC_s':'G51', + 'Hema_s':'G52', + 'DAPI_s':'G53', + 'Lap2_s':'G54', + 'Marker_s':'G55',}) for k in d_modality2net.keys(): images[k] = create_image_for_stitching(tile_size, rows, cols) @@ -397,7 +409,7 @@ def 
inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=Fals for modality_name, net_name in d_modality2net.items(): stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j) - + for modality_name, output_img in images.items(): images[modality_name] = output_img.resize(img.size) @@ -490,7 +502,8 @@ def get_net_tiles(n): def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, - eager_mode=False, color_dapi=False, color_marker=False, opt=None): + eager_mode=False, color_dapi=False, color_marker=False, opt=None, + return_seg_intermediate=False): if not opt: opt = get_opt(model_path) #print_options(opt) @@ -519,6 +532,14 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, 'Marker': results['G4'], 'Seg': results['G5'], } + + if return_seg_intermediate: + images.update({'IHC_s':results['G51'], + 'Hema_s':results['G52'], + 'DAPI_s':results['G53'], + 'Lap2_s':results['G54'], + 'Marker_s':results['G55'],}) + if color_dapi: matrix = ( 0, 0, 0, 0, 299/1000, 587/1000, 114/1000, 0, @@ -557,7 +578,7 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='aut processed_images['SegRefined'] = Image.fromarray(refined) return processed_images, scoring - elif model == 'DeepLIIFExt': + elif model in ['DeepLIIFExt','SDG']: resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x') processed_images = {} scoring = {} @@ -578,7 +599,8 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='aut def infer_modalities(img, tile_size, model_dir, eager_mode=False, - color_dapi=False, color_marker=False, opt=None): + color_dapi=False, color_marker=False, opt=None, + return_seg_intermediate=False): """ This function is used to infer modalities for the given image using a trained model. :param img: The input image. 
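+ :param return_seg_intermediate: if True, the returned images also include the intermediate segmentation outputs ('IHC_s', 'Hema_s', 'DAPI_s', 'Lap2_s', 'Marker_s') routed from the new G51-G55 results. + + A hedged usage sketch (file name and tile size are illustrative): + img = Image.open('sample_IHC.png') + infer_modalities(img, tile_size=512, model_dir='./model_dir', return_seg_intermediate=True)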
@@ -604,7 +626,8 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False, eager_mode=eager_mode, color_dapi=color_dapi, color_marker=color_marker, - opt=opt + opt=opt, + return_seg_intermediate=return_seg_intermediate ) if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models diff --git a/deepliif/models/att_unet.py b/deepliif/models/att_unet.py new file mode 100644 index 0000000..a7920ba --- /dev/null +++ b/deepliif/models/att_unet.py @@ -0,0 +1,199 @@ +# adapted from https://github.com/LeeJunHyun/Image_Segmentation/blob/master/network.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + +class conv_block(nn.Module): + def __init__(self,ch_in,ch_out,innermost=False,outermost=False): + super(conv_block,self).__init__() + if outermost: + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True), + nn.LeakyReLU(0.2, True), + ) + elif innermost: + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True), + nn.ReLU(inplace=True), + ) + else: + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True), + nn.BatchNorm2d(ch_out), + nn.LeakyReLU(0.2, True), + ) + + def forward(self,x): + x = self.conv(x) + return x + +class up_conv(nn.Module): + def __init__(self,ch_in,ch_out,innermost=False,outermost=False): + super(up_conv,self).__init__() + use_bias=False + if outermost: + self.up = nn.Sequential( + nn.ConvTranspose2d(ch_in * 2, ch_out, + kernel_size=4, stride=2, + padding=1), + nn.Tanh()) + elif innermost: + self.up = nn.Sequential( + nn.ConvTranspose2d(ch_in, ch_out, + kernel_size=4, stride=2, + padding=1, bias=use_bias), + nn.BatchNorm2d(ch_out), + nn.ReLU(True)) + else: + self.up = nn.Sequential( + nn.ConvTranspose2d(ch_in * 2, ch_out, + kernel_size=4, stride=2, + padding=1, bias=use_bias), + nn.BatchNorm2d(ch_out), + nn.ReLU(True)) + + + def forward(self,x): + x = self.up(x) + return x + + +class Attention_block(nn.Module): + def __init__(self,F_g,F_l,F_int): + super(Attention_block,self).__init__() + self.W_g = nn.Sequential( + nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(F_int) + ) + + self.W_x = nn.Sequential( + nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(F_int) + ) + + self.psi = nn.Sequential( + nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(1), + nn.Sigmoid() + ) + + self.relu = 
nn.ReLU(inplace=True) + + def forward(self,g,x): + g1 = self.W_g(g) + x1 = self.W_x(x) + psi = self.relu(g1+x1) + psi = self.psi(psi) + + return x*psi + + + +class AttU_Net(nn.Module): + def __init__(self,img_ch=3,output_ch=1): + super(AttU_Net,self).__init__() + + self.Conv1 = conv_block(ch_in=img_ch,ch_out=64,outermost=True) + self.Conv2 = conv_block(ch_in=64,ch_out=128) + self.Conv3 = conv_block(ch_in=128,ch_out=256) + self.Conv4 = conv_block(ch_in=256,ch_out=512) + self.Conv5 = conv_block(ch_in=512,ch_out=512) + self.Conv6 = conv_block(ch_in=512,ch_out=512) + self.Conv7 = conv_block(ch_in=512,ch_out=512) + self.Conv8 = conv_block(ch_in=512,ch_out=512,innermost=True) + #self.Conv9 = conv_block(ch_in=512,ch_out=512,innermost=True) + + self.Up8 = up_conv(ch_in=512,ch_out=512,innermost=True) + self.Att8 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up7 = up_conv(ch_in=512,ch_out=512) + self.Att7 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up6 = up_conv(ch_in=512,ch_out=512) + self.Att6 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up5 = up_conv(ch_in=512,ch_out=512) + self.Att5 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up4 = up_conv(ch_in=512,ch_out=256) + self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) + + self.Up3 = up_conv(ch_in=256,ch_out=128) + self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) + + self.Up2 = up_conv(ch_in=128,ch_out=64) + self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) + + self.Up1 = up_conv(ch_in=64,ch_out=output_ch,outermost=True) + + def forward(self,x): + # encoding path + x1 = self.Conv1(x) + x2 = self.Conv2(x1) + x3 = self.Conv3(x2) + x4 = self.Conv4(x3) + x5 = self.Conv5(x4) + x6 = self.Conv6(x5) + x7 = self.Conv7(x6) + x8 = self.Conv8(x7) + #x9 = self.Conv9(x8) + + #d9 = self.Up + d8 = self.Up8(x8) + x7 = self.Att8(g=d8,x=x7) + d8 = torch.cat((x7,d8),dim=1) + + d7 = self.Up7(d8) + x6 = self.Att7(g=d7,x=x6) + d7 = torch.cat((x6,d7),dim=1) + + d6 = self.Up6(d7) + x5 = self.Att6(g=d6,x=x5) + d6 = torch.cat((x5,d6),dim=1) + + d5 = self.Up5(d6) + x4 = self.Att5(g=d5,x=x4) + d5 = torch.cat((x4,d5),dim=1) + + + d4 = self.Up4(d5) # x4: [2, 512, 4, 4], d4: [2, 256, 4, 4] + x3 = self.Att4(g=d4,x=x3) + d4 = torch.cat((x3,d4),dim=1) + + d3 = self.Up3(d4) + x2 = self.Att3(g=d3,x=x2) + d3 = torch.cat((x2,d3),dim=1) + + d2 = self.Up2(d3) + x1 = self.Att2(g=d2,x=x1) + d2 = torch.cat((x1,d2),dim=1) + + d1 = self.Up1(d2) + + return d1 + diff --git a/deepliif/models/base_model.py b/deepliif/models/base_model.py index cad9295..76fd9d2 100644 --- a/deepliif/models/base_model.py +++ b/deepliif/models/base_model.py @@ -3,7 +3,7 @@ from collections import OrderedDict from abc import ABC, abstractmethod from . 
import networks -from ..util import disable_batchnorm_tracking_stats +from ..util import disable_batchnorm_tracking_stats, enable_batchnorm_tracking_stats from deepliif.util import * import itertools @@ -88,6 +88,17 @@ def setup(self, opt): self.load_networks(load_suffix) self.print_networks(opt.verbose) + def train(self): + """Make models train mode, e.g., to resume training after a validation pass""" + for name in self.model_names: + if isinstance(name, str): + if '_' in name: + net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1] + else: + net = getattr(self, 'net' + name) + net.train() + net = enable_batchnorm_tracking_stats(net) + def eval(self): """Make models eval mode during test time""" for name in self.model_names: @@ -134,10 +145,14 @@ def get_current_visuals(self): for name in self.visual_names: if isinstance(name, str): if not hasattr(self, name): - if len(name.split('_')) == 2: - visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1] + if len(name.split('_')) != 2: + if self.opt.model == 'DeepLIIF': + img_name = name[:-1] + '_' + name[-1] + visual_ret[name] = getattr(self, img_name) + else: + visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1] else: - visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1] + visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1] else: visual_ret[name] = getattr(self, name) return visual_ret @@ -240,6 +255,7 @@ def load_networks(self, epoch): epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ for name in self.model_names: + if isinstance(name, str): load_filename = '%s_net_%s.pth' % (epoch, name) load_path = os.path.join(self.save_dir, load_filename) diff --git a/deepliif/models/networks.py b/deepliif/models/networks.py index ea99f6a..6aef419 100644 --- a/deepliif/models/networks.py +++ b/deepliif/models/networks.py @@ -3,15 +3,19 @@ from torch.nn import init import functools from torch.optim import lr_scheduler + import os from torchvision import models - +from .att_unet import AttU_Net ############################################################################### # Helper Functions ############################################################################### from deepliif.util import util +# as of pytorch 2.4, all optimizers start with an uppercase letter +OPTIMIZER_MAPPING = {optimizer_name.lower():optimizer_name for optimizer_name in dir(torch.optim) if optimizer_name[0].isupper()} + class Identity(nn.Module): def forward(self, x): @@ -37,6 +41,14 @@ def norm_layer(x): return Identity() raise NotImplementedError('normalization layer [%s] is not found' % norm_type) return norm_layer +def get_optimizer(optimizer_name): + try: + return getattr(torch.optim, optimizer_name) + except AttributeError: # not an exact attribute name; try the case-insensitive mapping + try: + return getattr(torch.optim, OPTIMIZER_MAPPING[optimizer_name]) + except KeyError: + raise NotImplementedError('optimizer [%s] is not found' % optimizer_name) def get_scheduler(optimizer, opt): """Return a learning rate scheduler @@ -125,7 +137,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): return net -def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], padding_type='reflect'): +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], padding_type='reflect', resize_conv=False): """Create a generator Parameters: @@ -156,7
+168,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in norm_layer = get_norm_layer(norm_type=norm) if netG == 'resnet_9blocks': - net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, padding_type=padding_type) + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, padding_type=padding_type, resize_conv=resize_conv) elif netG == 'resnet_6blocks': net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, padding_type=padding_type) elif netG == 'unet_128': @@ -165,6 +177,8 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'unet_512': net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_512_attention': + net = AttU_Net(img_ch=input_nc,output_ch=output_nc) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) return init_net(net, init_type, init_gain, gpu_ids) @@ -332,9 +346,12 @@ class ResnetGenerator(nn.Module): """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + + Resize-conv: optional replacement of ConvTranspose2d + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190#issuecomment-358546675 """ - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero'): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero', resize_conv=False): """Construct a Resnet-based generator Parameters: @@ -346,6 +363,7 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d n_blocks (int) -- the number of ResNet blocks padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero """ + print('resize_conv',resize_conv) assert(n_blocks >= 0) super(ResnetGenerator, self).__init__() if type(norm_layer) == functools.partial: @@ -378,11 +396,19 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d for i in range(n_downsampling): # add upsampling layers mult = 2 ** (n_downsampling - i) - model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + if resize_conv: + upsample_layer = [#nn.Upsample(scale_factor = 2, mode='bilinear',align_corners=True), + nn.Upsample(scale_factor = 2, mode='nearest'), + nn.ReflectionPad2d(1), + nn.Conv2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=1, padding=0)] + else: + upsample_layer = [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, - bias=use_bias), - norm_layer(int(ngf * mult / 2)), + bias=use_bias)] + model += upsample_layer + model += [norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] if padding_type == 'reflect': @@ -687,3 +713,15 @@ def forward(self, x, y): for i in range(len(x_vgg)): loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) return loss + + +class TotalVariationLoss(nn.Module): + """ + Absolute difference for neighbouring pixels (i,j to i+1,j, then i,j to i,j+1), averaged on pixel level + """ + def __init__(self): + 
super(TotalVariationLoss, self).__init__() + + def forward(self, x): + tv = torch.abs(x[:,:,1:,:]-x[:,:,:-1,:]).sum() + torch.abs(x[:,:,:,1:]-x[:,:,:,:-1]).sum() + return tv / torch.prod(torch.tensor(x.size())) diff --git a/deepliif/options/__init__.py b/deepliif/options/__init__.py index 5e2ea6b..5a734c0 100644 --- a/deepliif/options/__init__.py +++ b/deepliif/options/__init__.py @@ -3,6 +3,7 @@ from pathlib import Path import os from ..util.util import mkdirs +import re def read_model_params(file_addr): with open(file_addr) as f: @@ -11,8 +12,27 @@ def read_model_params(file_addr): for line in lines: if ':' in line: key = line.split(':')[0].strip() - val = line.split(':')[1].split('[')[0].strip() - param_dict[key] = val + val = ':'.join(line.split(':')[1:]) + + # drop default value + str_default = [x for x in re.findall(r"\[.+?\]", val) if x.startswith('[default')] + if len(str_default) > 1: + raise Exception(f'train_opt.txt should not contain multiple possible default keys in one line: {str_default}') + elif len(str_default) == 1: + str_default = str_default[0] + val = val.replace(str_default,'') + val = val.strip() + + # val = line.split(':')[1].split('[')[0].strip() + try: + param_dict[key] = eval(val) + #print(f'value of {key} is converted to {type(param_dict[key]).__name__}') + except Exception: # keep the raw string if eval() cannot parse it + param_dict[key] = val + + if isinstance(param_dict[key],list): + param_dict[key] = param_dict[key][0] + return param_dict class Options: @@ -33,13 +53,16 @@ def __init__(self, d_params=None, path_file=None, mode='train'): except: setattr(self,k,v) + self.optimizer = 'adam' if not hasattr(self,'optimizer') else self.optimizer + if mode == 'train': self.is_train = True - self.netG = 'resnet_9blocks' - self.netD = 'n_layers' + self.netG = self.net_g #'resnet_9blocks' + self.netD = self.net_d #'n_layers' self.n_layers_D = 4 self.lambda_L1 = 100 self.lambda_feat = 100 + else: self.phase = 'test' self.is_train = False diff --git a/deepliif/train.py b/deepliif/scripts/train.py old mode 100755 new mode 100644 similarity index 54% rename from deepliif/train.py rename to deepliif/scripts/train.py index 39f38be..4824f38 --- a/deepliif/train.py +++ b/deepliif/scripts/train.py @@ -1,30 +1,26 @@ """ +This script is used by cli.py trainlaunch command for DDP. Keep this train.py up-to-date with train() in cli.py! They are EXACTLY THE SAME. """ -import os -import json import time -import random - -import click -import cv2 -import torch -import numpy as np +from deepliif.options.train_options import TrainOptions +from deepliif.data import create_dataset +from deepliif.models import create_model, postprocess +from deepliif.options import read_model_params, Options, print_options +from deepliif.util.visualizer import Visualizer from PIL import Image - -from deepliif.data import create_dataset, AlignedDataset, transform -from deepliif.models import inference, postprocess, compute_overlap, init_nets, DeepLIIFModel -from deepliif.util import allowed_file, Visualizer - -import torch.distributed as dist import os -import torch import numpy as np import random +import json import torch +import torch.distributed as dist +from torchvision.transforms import ToPILImage + +import click def set_seed(seed=0,rank=None): """ @@ -57,8 +53,6 @@ def set_seed(seed=0,rank=None): return False - - @click.command() @click.option('--dataroot', required=True, type=str, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') @@ -66,8 +60,9 @@ def set_seed(seed=0,rank=None): help='name of the experiment.
It decides where to store samples and models') @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU') @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here') -@click.option('--targets-no', default=5, help='number of targets') +@click.option('--modalities-no', default=4, type=int, help='number of modality images to generate') # model parameters +@click.option('--model', default='DeepLIIF', help='name of model class') @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') @click.option('--ngf', default=64, help='# of gen filters in the last conv layer') @@ -76,14 +71,13 @@ def set_seed(seed=0,rank=None): help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 ' 'PatchGAN. n_layers allows you to specify the layers in the discriminator') @click.option('--net-g', default='resnet_9blocks', - help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]') + help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention') @click.option('--n-layers-d', default=4, help='only used if netD==n_layers') @click.option('--norm', default='batch', help='instance normalization or batch normalization [instance | batch | none]') @click.option('--init-type', default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.') -@click.option('--padding-type', default='reflect', help='network padding type.') @click.option('--no-dropout', is_flag=True, help='no dropout for the generator') # dataset parameters @click.option('--direction', default='AtoB', help='AtoB or BtoA') @@ -120,12 +114,15 @@ def set_seed(seed=0,rank=None): help='number of epochs with the initial learning rate') @click.option('--n-epochs-decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero') +@click.option('--optimizer', type=str, default='adam', + help='optimizer from torch.optim to use, applied to both generators and discriminators [adam | sgd | adamw | ...]; the current parameters however are set up for adam, so other optimizers may encounter issues') @click.option('--beta1', default=0.5, help='momentum term of adam') @click.option('--lr', default=0.0002, help='initial learning rate for adam') @click.option('--lr-policy', default='linear', help='learning rate policy.
[linear | step | plateau | cosine]') @click.option('--lr-decay-iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') +@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)') # visdom and HTML visualization parameters @click.option('--display-freq', default=400, help='frequency of showing training results on screen') @click.option('--display-ncols', default=4, @@ -146,15 +143,36 @@ def set_seed(seed=0,rank=None): @click.option('--save-by-iter', is_flag=True, help='whether saves model by iteration') @click.option('--remote', type=bool, default=False, help='whether isolate visdom checkpoints or not; if False, you can run a separate visdom server anywhere that consumes the checkpoints') @click.option('--remote-transfer-cmd', type=str, default=None, help='module and function to be used to transfer remote files to target storage location, for example mymodule.myfunction') +@click.option('--dataset-mode', type=str, default='aligned', + help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') +@click.option('--padding', type=str, default='zero', + help='chooses the type of padding used by resnet generator. [reflect | zero]') +# DeepLIIFExt params +@click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).') +@click.option('--net-ds', type=str, default='n_layers', + help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') +@click.option('--net-gs', type=str, default='unet_512', + help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-gs=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention') +@click.option('--gan-mode', type=str, default='vanilla', + help='the type of GAN objective for translation task. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') +@click.option('--gan-mode-s', type=str, default='lsgan', + help='the type of GAN objective for segmentation task. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') +# DDP related arguments @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup') -@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)') +# Others +@click.option('--with-val', is_flag=True, + help='use validation set to evaluate model performance at the end of each epoch') +@click.option('--debug', is_flag=True, + help='debug mode, limits the number of data points per epoch to a small value') +@click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; due to batch size, the epoch will be passed once the completed no.
data points is greater than this value (e.g., for batch size 3, debug data size 10, the effective size used in training will be 12)') +def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g, + n_layers_d, norm, init_type, init_gain, no_dropout, direction, serial_batches, num_threads, batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter, verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env, display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter, - continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, beta1, lr, lr_decay_iters, - remote, local_rank, remote_transfer_cmd, seed): + continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr, lr_decay_iters, + remote, remote_transfer_cmd, seed, dataset_mode, padding, model, + modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size): """General-purpose training script for multi-task image-to-image translation. This script works for various models (with option '--model': e.g., DeepLIIF) and @@ -166,6 +184,26 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output plot, and save models.The script supports continue/resume training. Use '--continue_train' to resume your previous training. """ + assert model in ['DeepLIIF','DeepLIIFExt','SDG'], f'model class {model} is not implemented' + if model == 'DeepLIIF': + seg_no = 1 + elif model == 'DeepLIIFExt': + if seg_gen: + seg_no = modalities_no + else: + seg_no = 0 + else: # SDG + seg_no = 0 + seg_gen = False + + if optimizer != 'adam': + print(f'Optimizer torch.optim.{optimizer} is not tested. 
Be careful about the parameters of the optimizer.') + + d_params = locals() + + if gpu_ids and gpu_ids[0] == -1: + gpu_ids = [] + local_rank = os.getenv('LOCAL_RANK') # DDP single node training triggered by torchrun has LOCAL_RANK rank = os.getenv('RANK') # if using DDP with multiple nodes, please provide global rank in env var RANK @@ -173,12 +211,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output if local_rank is not None: local_rank = int(local_rank) torch.cuda.set_device(gpu_ids[local_rank]) - gpu_ids=[gpu_ids[local_rank]] + gpu_ids=[local_rank] else: torch.cuda.set_device(gpu_ids[0]) if local_rank is not None: # LOCAL_RANK will be assigned a rank number if torchrun ddp is used - dist.init_process_group(backend='nccl') + dist.init_process_group(backend="nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE'])) print('local rank:',local_rank) flag_deterministic = set_seed(seed,local_rank) elif rank is not None: @@ -187,28 +225,71 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output flag_deterministic = set_seed(seed) if flag_deterministic: - padding_type = 'zero' + d_params['padding'] = 'zero' print('padding type is forced to zero padding, because neither reflection pad2d nor replication pad2d has a deterministic implementation') + # infer number of input images + dir_data_train = dataroot + '/train' + fns = os.listdir(dir_data_train) + fns = [x for x in fns if x.endswith('.png')] + print(f'{len(fns)} images found') + img = Image.open(f"{dir_data_train}/{fns[0]}") + print(f'image shape: {img.size}') + + num_img = img.size[0] / img.size[1] + assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer' + num_img = int(num_img) + + input_no = num_img - modalities_no - seg_no + assert input_no > 0, f'inferred number of input images is {input_no} (modalities_no {modalities_no}, seg_no {seg_no}); should be greater than 0' + # (a worked example of this tile arithmetic appears below) + d_params['input_no'] = input_no + d_params['scale_size'] = img.size[1] + d_params['gpu_ids'] = gpu_ids + + # update generator arch + net_g = net_g.split(',') + assert len(net_g) in [1,modalities_no], f'net_g should contain either 1 architecture for all translation generators or the same number of architectures as the number of translation generators ({modalities_no})' + if len(net_g) == 1: + net_g = net_g*modalities_no + + + net_gs = net_gs.split(',') + assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})' + if len(net_gs) == 1 and model == 'DeepLIIF': + net_gs = net_gs*(modalities_no + seg_no) + elif len(net_gs) == 1: + net_gs = net_gs*seg_no + + d_params['net_g'] = net_g + d_params['net_gs'] = net_gs + # create a dataset given dataset_mode and other options - dataset = AlignedDataset(dataroot, load_size, crop_size, input_nc, output_nc, direction, targets_no, preprocess, - no_flip, phase, max_dataset_size) + # dataset = AlignedDataset(opt) + + opt = Options(d_params=d_params) + print_options(opt, save=True) + + # set dir for train and val + dataset = create_dataset(opt) - dataset = create_dataset(dataset, batch_size, serial_batches, num_threads, max_dataset_size, gpu_ids) # get the number of images in the dataset.
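+ # worked example of the input-count inference above (numbers are illustrative): + # a 3072x512 training image holds 3072/512 = 6 tiles; with modalities_no=4 + # and seg_no=1 (DeepLIIF), input_no = 6 - 4 - 1 = 1 input image per sample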
click.echo('The number of training images = %d' % len(dataset)) + + if with_val: + dataset_val = create_dataset(opt,phase='val') + data_val = [batch for batch in dataset_val] + click.echo('The number of validation images = %d' % len(dataset_val)) + + if model in ['DeepLIIF']: + metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json'))) # create a model given model and other options - model = DeepLIIFModel(gpu_ids, is_train, checkpoints_dir, name, preprocess, targets_no, input_nc, output_nc, ngf, - net_g, norm, no_dropout, init_type, init_gain, padding_type, ndf, net_d, n_layers_d, lr, beta1, lambda_l1, - lr_policy, remote_transfer_cmd) + model = create_model(opt) # regular setup: load and print networks; create schedulers - model.setup(lr_policy, epoch_count, n_epochs, n_epochs_decay, lr_decay_iters, continue_train, load_iter, epoch, - verbose) + model.setup(opt) # create a visualizer that display/save images and plots - visualizer = Visualizer(display_id, is_train, no_html, display_winsize, name, display_port, display_ncols, - display_server, display_env, checkpoints_dir, remote, remote_transfer_cmd) + visualizer = Visualizer(opt) # the total number of training iterations total_iters = 0 @@ -230,6 +311,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output # inner loop within one epoch for i, data in enumerate(dataset): + # timer for computation per iteration iter_start_time = time.time() if total_iters % print_freq == 0: @@ -246,15 +328,15 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output if total_iters % display_freq == 0: save_result = total_iters % update_html_freq == 0 model.compute_visuals() - visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) + visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result) # print training losses and save logging information to the disk if total_iters % print_freq == 0: - losses = model.get_current_losses() + losses = model.get_current_losses() # get training losses t_comp = (time.time() - iter_start_time) / batch_size - visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) + visualizer.print_current_losses(epoch, epoch_iter, {**losses}, t_comp, t_data) if display_id > 0: - visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), losses) + visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses}) # cache our latest model every iterations if total_iters % save_latest_freq == 0: @@ -262,7 +344,11 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output save_suffix = 'iter_%d' % total_iters if save_by_iter else 'latest' model.save_networks(save_suffix) + iter_data_time = time.time() + if debug and epoch_iter >= debug_data_size: + print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})') + break # cache our model every epochs if epoch % save_epoch_freq == 0: @@ -270,6 +356,77 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output model.save_networks('latest') model.save_networks(epoch) + + + # validation loss and metrics calculation + if with_val: + losses = model.get_current_losses() # get training losses to print + + model.eval() + l_losses_val = [] + l_metrics_val = [] + + # for each val image, calculate validation loss and cell count metrics + for j, data_val_batch in enumerate(data_val): + # batch size is effectively 1 for validation + 
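# a hedged sketch of the RMSE bookkeeping in this loop: for each metric k, + # (pred_k - gt_k)**2 is stored per validation image, and after the loop the + # epoch-level value becomes sqrt(sum / N_val), i.e., an RMSE per metric +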
model.set_input(data_val_batch) + model.calculate_losses() # this does not optimize parameters + visuals = model.get_current_visuals() # get image results + + # val losses + losses_val_batch = model.get_current_losses() + l_losses_val += [(k,v) for k,v in losses_val_batch.items()] + + # calculate cell count metrics + if type(model).__name__ == 'DeepLIIFModel': + l_seg_names = ['fake_B_5'] + assert l_seg_names[0] in visuals.keys(), f'Cannot find {l_seg_names[0]} in generated image names ({list(visuals.keys())})' + seg_mod_suffix = l_seg_names[0].split('_')[-1] + l_seg_names += [x for x in visuals.keys() if x.startswith('fake') and x.split('_')[-1].startswith(seg_mod_suffix) and x != l_seg_names[0]] + # print(f'Running postprocess for {len(l_seg_names)} generated images ({l_seg_names})') + + img_name_current = data_val_batch['A_paths'][0].split('/')[-1][:-4] # remove .png + metrics_gt = metrics_val[img_name_current] + + for seg_name in l_seg_names: + images = {'Seg':ToPILImage()((visuals[seg_name][0].cpu()+1)/2), + #'Marker':ToPILImage()((visuals['fake_B_4'][0].cpu()+1)/2) + } + _, scoring = postprocess(ToPILImage()((data_val_batch['A'][0]+1)/2), images, opt.scale_size, opt.model) # use the current validation image, not the last training batch + + for k,v in scoring.items(): + if k.startswith('num') or k.startswith('percent'): + # to calculate the rmse, here we calculate (x_pred - x_true) ** 2 + l_metrics_val.append((k+'_'+seg_name,(v - metrics_gt[k])**2)) + + if debug and epoch_iter >= debug_data_size: + print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})') + break + + d_losses_val = {k+'_val':0 for k in losses_val_batch.keys()} + for k,v in l_losses_val: + d_losses_val[k+'_val'] += v + for k in d_losses_val: + d_losses_val[k] = d_losses_val[k] / len(data_val) + + d_metrics_val = {} + for k,v in l_metrics_val: + try: + d_metrics_val[k] += v + except KeyError: + d_metrics_val[k] = v + for k in d_metrics_val: + # to calculate the rmse, this is the second part, where d_metrics_val[k] now represents sum((x_pred - x_true) ** 2) + d_metrics_val[k] = np.sqrt(d_metrics_val[k] / len(data_val)) + + + model.train() + t_comp = (time.time() - iter_start_time) / batch_size + visualizer.print_current_losses(epoch, epoch_iter, {**losses,**d_losses_val, **d_metrics_val}, t_comp, t_data) + if display_id > 0: + visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses,**d_losses_val,**d_metrics_val}) + + print('End of epoch %d / %d \t Time Taken: %d sec' % ( epoch, n_epochs + n_epochs_decay, time.time() - epoch_start_time)) # update learning rates at the end of every epoch. @@ -277,4 +434,4 @@ if __name__ == '__main__': - train() \ No newline at end of file + train() diff --git a/deepliif/stat/__init__.py b/deepliif/stat/__init__.py new file mode 100644 index 0000000..559db4f --- /dev/null +++ b/deepliif/stat/__init__.py @@ -0,0 +1,83 @@ +import os +import json +from PIL import Image +from ..models import postprocess +import re + +def get_cell_count_metrics(dir_img, dir_save=None, model = None, + tile_size=512, single_tile=False, + use_marker = False, + save_individual = False): + """ + Obtain cell count metrics through postprocess functions. + Currently implemented only for ground truth tiles.
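+ + A hedged usage sketch (paths are illustrative): + get_cell_count_metrics('Datasets/MyData/val', dir_save='Datasets/MyData/val', model='DeepLIIF', tile_size=512, single_tile=False) + # a combined metrics.json covering all tiles found in dir_img is written to dir_save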
+ + dir_img: directory to load images for calculation + dir_save: directory to save the results out + model: model type (DeepLIIF, DeepLIIFExt, SDG) + tile_size: tile size used for postprocess calculation + single_tile: True if the images are single-tile images; use False if the + images contain a row of multiple tiles like those used in + training or validation + use_marker: whether to use the marker image (if True, assumes the marker + image is the second last tile (single_tile=False) or has a + suffix of "_4.png" (single_tile=True)) + save_individual: whether to also save one json of scores per image + """ + dir_save = dir_img if dir_save is None else dir_save # default to saving next to the input images + + if model is None: + if 'sdg' in dir_img: + model = 'SDG' + elif 'ext' in dir_img: + model = 'DeepLIIFExt' + else: + model = 'DeepLIIF' + + if single_tile: + fns = [x for x in os.listdir(dir_img) if x.endswith('_5.png') or x.endswith('_4.png')] + fns = list(set([x[:-6] for x in fns])) # fns do not have extension + else: + fns = [x for x in os.listdir(dir_img) if x.endswith('.png')] # fns have extension + + d_metrics = {} + count = 0 + for fn in fns: + if single_tile: + img_gt = Image.open(os.path.join(dir_img,fn+'_5.png')) + img_marker = Image.open(os.path.join(dir_img,fn+'_4.png')) + img_input = Image.open(os.path.join(dir_img.replace('/gt','/input'),fn+'.png')) + k = fn + else: + img = Image.open(os.path.join(dir_img,fn)) + w, h = img.size + + # assume in the row of tiles, the first is the input and the last is the ground truth + img_input = img.crop((0,0,h,h)) + img_gt = img.crop((w-h,0,w,h)) + img_marker = img.crop((w-h*2,0,w-h,h)) # the second last is marker, if marker is included + k = re.sub(r'\..*?$','',fn) # remove extension + + images = {'Seg':img_gt} + if use_marker: + images['Marker'] = img_marker + + post_images, scoring = postprocess(img_input, images, tile_size, model) + d_metrics[k] = scoring + + if save_individual: + with open(os.path.join( + dir_save, + k+'.json' + ), 'w') as f: + json.dump(scoring, f, indent=2) + + count += 1 + + if count % 100 == 0 or count == len(fns): + print(count,'/',len(fns)) + + with open(os.path.join( + dir_save, + 'metrics.json' + ), 'w') as f: + json.dump(d_metrics, f, indent=2) diff --git a/deepliif/util/__init__.py b/deepliif/util/__init__.py index 4911580..b8c8263 100644 --- a/deepliif/util/__init__.py +++ b/deepliif/util/__init__.py @@ -14,11 +14,6 @@ from ..postprocessing import imadjust import cv2 -import bioformats -import javabridge -import bioformats.omexml as ome -import tifffile as tf - import pickle import sys @@ -440,6 +435,115 @@ def get_information(filename): return size_x, size_y, size_z, size_c, size_t, pixel_type + + +def write_results_to_pickle_file(output_addr, results): + """ + This function writes data into the pickle file. + :param output_addr: The address of the pickle file to write data into. + :param results: The data to be written into the pickle file. + :return: + """ + pickle_obj = open(output_addr, "wb") + pickle.dump(results, pickle_obj) + pickle_obj.close() + + +def read_results_from_pickle_file(input_addr): + """ + This function reads data from a pickle file and returns it. + :param input_addr: The address to the pickle file. + :return: The data inside pickle file.
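+ + Example (file name is illustrative): + results = read_results_from_pickle_file('inference_results.pickle') +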
+ """ + pickle_obj = open(input_addr, "rb") + results = pickle.load(pickle_obj) + pickle_obj.close() + return results + +def test_diff_original_serialized(model_original,model_serialized,example,verbose=0): + threshold = 10 + + orig_res = model_original(example) + if verbose > 0: + print('Original:') + print(orig_res.shape) + print(orig_res[0, 0:10]) + print('min abs value:{}'.format(torch.min(torch.abs(orig_res)))) + + ts_res = model_serialized(example) + if verbose > 0: + print('Torchscript:') + print(ts_res.shape) + print(ts_res[0, 0:10]) + print('min abs value:{}'.format(torch.min(torch.abs(ts_res)))) + + abs_diff = torch.abs(orig_res-ts_res) + if verbose > 0: + print('Dif sum:') + print(torch.sum(abs_diff)) + print('max dif:{}'.format(torch.max(abs_diff))) + + assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}" + +def disable_batchnorm_tracking_stats(model): + # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16 + # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67 + # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158 + for m in model.modules(): + for child in m.children(): + if type(child) == torch.nn.BatchNorm2d: + child.track_running_stats = False + child.running_mean_backup = child.running_mean + child.running_mean = None + child.running_var_backup = child.running_var + child.running_var = None + return model + +def enable_batchnorm_tracking_stats(model): + """ + This is needed during training when val set loss/metrics calculation is enabled. + In this case, we need to switch to eval mode for inference, which triggers + disable_batchnorm_tracking_stats(). After the evaluation, the model should be + set back to train mode, where running stats are restored for batchnorm layers. + """ + for m in model.modules(): + for child in m.children(): + if type(child) == torch.nn.BatchNorm2d: + child.track_running_stats = True + assert hasattr(child, 'running_mean_backup') and hasattr(child, 'running_var_backup'), 'enable_batchnorm_tracking_stats() is supposed to be executed after disable_batchnorm_tracking_stats() is applied' + child.running_mean = child.running_mean_backup + child.running_var = child.running_var_backup + return model + + +def image_variance_gray(img): + px = np.asarray(img) if img.mode == 'L' else np.asarray(img.convert('L')) + idx = np.logical_and(px != 255, px != 0) + val = px[idx] + if val.shape[0] == 0: + return 0 + var = np.var(val) + return var + + +def image_variance_rgb(img): + px = np.asarray(img) if img.mode == 'RGB' else np.asarray(img.convert('RGB')) + nonwhite = np.any(px != [255, 255, 255], axis=-1) + nonblack = np.any(px != [0, 0, 0], axis=-1) + idx = np.logical_and(nonwhite, nonblack) + val = px[idx] + if val.shape[0] == 0: + return [0, 0, 0] + var = np.var(val, axis=0) + return var + + + +import bioformats +import javabridge +import bioformats.omexml as ome +import tifffile as tf + def write_big_tiff_file(output_addr, img, tile_size): """ This function write the image into a big tiff file using the tiling and compression. 
@@ -581,64 +685,3 @@ def write_ome_tiff_file_array(results_array, output_addr, size_t, size_z, size_c output_addr, SizeT=size_t, SizeZ=size_z, SizeC=len(channel_names), SizeX=size_x, SizeY=size_y, channel_names=channel_names) - - -def write_results_to_pickle_file(output_addr, results): - """ - This function writes data into the pickle file. - :param output_addr: The address of the pickle file to write data into. - :param results: The data to be written into the pickle file. - :return: - """ - pickle_obj = open(output_addr, "wb") - pickle.dump(results, pickle_obj) - pickle_obj.close() - - -def read_results_from_pickle_file(input_addr): - """ - This function reads data from a pickle file and returns it. - :param input_addr: The address to the pickle file. - :return: The data inside pickle file. - """ - pickle_obj = open(input_addr, "rb") - results = pickle.load(pickle_obj) - pickle_obj.close() - return results - -def test_diff_original_serialized(model_original,model_serialized,example,verbose=0): - threshold = 10 - - orig_res = model_original(example) - if verbose > 0: - print('Original:') - print(orig_res.shape) - print(orig_res[0, 0:10]) - print('min abs value:{}'.format(torch.min(torch.abs(orig_res)))) - - ts_res = model_serialized(example) - if verbose > 0: - print('Torchscript:') - print(ts_res.shape) - print(ts_res[0, 0:10]) - print('min abs value:{}'.format(torch.min(torch.abs(ts_res)))) - - abs_diff = torch.abs(orig_res-ts_res) - if verbose > 0: - print('Dif sum:') - print(torch.sum(abs_diff)) - print('max dif:{}'.format(torch.max(abs_diff))) - - assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}" - -def disable_batchnorm_tracking_stats(model): - # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16 - # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67 - # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158 - for m in model.modules(): - for child in m.children(): - if type(child) == torch.nn.BatchNorm2d: - child.track_running_stats = False - child.running_mean = None - child.running_var = None - return model diff --git a/deepliif/util/visualizer.py b/deepliif/util/visualizer.py index 0856b51..0bbf7cb 100644 --- a/deepliif/util/visualizer.py +++ b/deepliif/util/visualizer.py @@ -165,6 +165,7 @@ def display_current_results(self, visuals, epoch, save_result, **kwargs): if ncols > 0: # show all the images in one visdom panel ncols = min(ncols, len(visuals)) h, w = next(iter(visuals.values())).shape[:2] + table_css = """