From 59d41406c65edd206c9547d7492b893688790986 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Wed, 31 Jan 2024 15:12:33 -0500 Subject: [PATCH 01/41] updated train.py to reflect changes in cli.py train & trainlaunch --- deepliif/{ => scripts}/train.py | 102 +++++++---- train.py | 291 ++++++++++++++++++++++++++++---- 2 files changed, 328 insertions(+), 65 deletions(-) rename deepliif/{ => scripts}/train.py (78%) mode change 100755 => 100644 diff --git a/deepliif/train.py b/deepliif/scripts/train.py old mode 100755 new mode 100644 similarity index 78% rename from deepliif/train.py rename to deepliif/scripts/train.py index 39f38be..34eb5cd --- a/deepliif/train.py +++ b/deepliif/scripts/train.py @@ -1,4 +1,5 @@ """ +This script is used by cli.py trainlaunch command for DDP. Keep this train.py up-to-date with train() in cli.py! They are EXACTLY THE SAME. """ @@ -14,17 +15,13 @@ import numpy as np from PIL import Image -from deepliif.data import create_dataset, AlignedDataset, transform -from deepliif.models import inference, postprocess, compute_overlap, init_nets, DeepLIIFModel -from deepliif.util import allowed_file, Visualizer +from deepliif.data import create_dataset +from deepliif.models import create_model +from deepliif.util import Visualizer +from deepliif.options import Options, print_options import torch.distributed as dist -import os -import torch - -import numpy as np -import random -import torch + def set_seed(seed=0,rank=None): """ @@ -66,8 +63,9 @@ def set_seed(seed=0,rank=None): help='name of the experiment. It decides where to store samples and models') @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU') @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here') -@click.option('--targets-no', default=5, help='number of targets') +@click.option('--modalities-no', default=4, type=int, help='number of targets') # model parameters +@click.option('--model', default='DeepLIIF', help='name of model class') @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') @click.option('--ngf', default=64, help='# of gen filters in the last conv layer') @@ -83,7 +81,6 @@ def set_seed(seed=0,rank=None): @click.option('--init-type', default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.') -@click.option('--padding-type', default='reflect', help='network padding type.') @click.option('--no-dropout', is_flag=True, help='no dropout for the generator') # dataset parameters @click.option('--direction', default='AtoB', help='AtoB or BtoA') @@ -126,6 +123,7 @@ def set_seed(seed=0,rank=None): help='learning rate policy. 
[linear | step | plateau | cosine]') @click.option('--lr-decay-iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') +@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)') # visdom and HTML visualization parameters @click.option('--display-freq', default=400, help='frequency of showing training results on screen') @click.option('--display-ncols', default=4, @@ -146,15 +144,30 @@ def set_seed(seed=0,rank=None): @click.option('--save-by-iter', is_flag=True, help='whether saves model by iteration') @click.option('--remote', type=bool, default=False, help='whether isolate visdom checkpoints or not; if False, you can run a separate visdom server anywhere that consumes the checkpoints') @click.option('--remote-transfer-cmd', type=str, default=None, help='module and function to be used to transfer remote files to target storage location, for example mymodule.myfunction') +@click.option('--dataset-mode', type=str, default='aligned', + help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') +@click.option('--padding', type=str, default='zero', + help='chooses the type of padding used by resnet generator. [reflect | zero]') +# DeepLIIFExt params +@click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).') +@click.option('--net-ds', type=str, default='n_layers', + help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') +@click.option('--net-gs', type=str, default='unet_512', + help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]') +@click.option('--gan-mode', type=str, default='vanilla', + help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') +@click.option('--gan-mode-s', type=str, default='lsgan', + help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. 
vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+# DDP related arguments
 @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
-@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
-def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output_nc, ngf, ndf, net_d, net_g,
-          n_layers_d, norm, init_type, init_gain, padding_type, no_dropout, direction, serial_batches, num_threads,
+def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g,
+          n_layers_d, norm, init_type, init_gain, no_dropout, direction, serial_batches, num_threads,
           batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter,
           verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env,
           display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter,
           continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, beta1, lr, lr_decay_iters,
-          remote, local_rank, remote_transfer_cmd, seed):
+          remote, remote_transfer_cmd, seed, dataset_mode, padding, model,
+          modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank):
     """General-purpose training script for multi-task image-to-image translation.
 
     This script works for various models (with option '--model': e.g., DeepLIIF) and
     different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
     You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
 
     It first creates model, dataset, and visualizer given the option.
     It then does standard network training. During the training, it also visualize/save the images, print/save the loss
@@ -166,19 +179,36 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output
     plot, and save models. The script supports continue/resume training.
     Use '--continue_train' to resume your previous training.
     """
+    assert model in ['DeepLIIF','DeepLIIFExt','SDG'], f'model class {model} is not implemented'
+    if model == 'DeepLIIF':
+        seg_no = 1
+    elif model == 'DeepLIIFExt':
+        if seg_gen:
+            seg_no = modalities_no
+        else:
+            seg_no = 0
+    else: # SDG
+        seg_no = 0
+        seg_gen = False
+
+    d_params = locals()
+
+    if gpu_ids and gpu_ids[0] == -1:
+        gpu_ids = []
+
     local_rank = os.getenv('LOCAL_RANK') # DDP single node training triggered by torchrun has LOCAL_RANK
     rank = os.getenv('RANK') # if using DDP with multiple nodes, please provide global rank in env var RANK
 
     if len(gpu_ids) > 0:
         if local_rank is not None:
             local_rank = int(local_rank)
-            torch.cuda.set_device(gpu_ids[local_rank])
-            gpu_ids=[gpu_ids[local_rank]]
+            torch.cuda.set_device(local_rank)#gpu_ids[local_rank])
+            gpu_ids = [local_rank]
         else:
             torch.cuda.set_device(gpu_ids[0])
 
     if local_rank is not None: # LOCAL_RANK will be assigned a rank number if torchrun ddp is used
-        dist.init_process_group(backend='nccl')
+        dist.init_process_group(backend="nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE']))
         print('local rank:',local_rank)
         flag_deterministic = set_seed(seed,local_rank)
     elif rank is not None:
@@ -187,28 +217,42 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, targets_no, input_nc, output
         flag_deterministic = set_seed(seed)
 
     if flag_deterministic:
-        padding_type = 'zero'
+        d_params['padding'] = 'zero'
         print('padding type is forced to zero padding, because neither reflection pad2d or replication pad2d has a deterministic implementation')
 
+    # infer number of input images
+    dir_data_train = dataroot + '/train'
+    fns = os.listdir(dir_data_train)
+    fns = [x for x in fns if x.endswith('.png')]
+    img = Image.open(f"{dir_data_train}/{fns[0]}")
+
+    num_img = img.size[0] / img.size[1]
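+    # note: each training sample is assumed to be a single row of equal-size square
+    # tiles (input image(s), then modalities, then segmentation mask(s)) concatenated
+    # horizontally, so width divided by height gives the number of tiles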
+ assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer' + num_img = int(num_img) + + input_no = num_img - modalities_no - seg_no + assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0' + d_params['input_no'] = input_no + d_params['scale_size'] = img.size[1] + d_params['gpu_ids'] = gpu_ids + # create a dataset given dataset_mode and other options - dataset = AlignedDataset(dataroot, load_size, crop_size, input_nc, output_nc, direction, targets_no, preprocess, - no_flip, phase, max_dataset_size) + # dataset = AlignedDataset(opt) - dataset = create_dataset(dataset, batch_size, serial_batches, num_threads, max_dataset_size, gpu_ids) + opt = Options(d_params=d_params) + print_options(opt, save=True) + + dataset = create_dataset(opt) # get the number of images in the dataset. click.echo('The number of training images = %d' % len(dataset)) # create a model given model and other options - model = DeepLIIFModel(gpu_ids, is_train, checkpoints_dir, name, preprocess, targets_no, input_nc, output_nc, ngf, - net_g, norm, no_dropout, init_type, init_gain, padding_type, ndf, net_d, n_layers_d, lr, beta1, lambda_l1, - lr_policy, remote_transfer_cmd) + model = create_model(opt) # regular setup: load and print networks; create schedulers - model.setup(lr_policy, epoch_count, n_epochs, n_epochs_decay, lr_decay_iters, continue_train, load_iter, epoch, - verbose) + model.setup(opt) # create a visualizer that display/save images and plots - visualizer = Visualizer(display_id, is_train, no_html, display_winsize, name, display_port, display_ncols, - display_server, display_env, checkpoints_dir, remote, remote_transfer_cmd) + visualizer = Visualizer(opt) # the total number of training iterations total_iters = 0 diff --git a/train.py b/train.py index 205b6ac..d326a08 100644 --- a/train.py +++ b/train.py @@ -12,18 +12,23 @@ from deepliif.options.train_options import TrainOptions from deepliif.data import create_dataset from deepliif.models import create_model +from deepliif.options import read_model_params, Options, print_options from deepliif.util.visualizer import Visualizer +from PIL import Image import os import numpy as np import random import torch +import click def set_seed(seed=0,rank=None): """ seed: basic seed rank: rank of the current process, using which to mutate basic seed to have a unique seed per process + + output: a boolean flag indicating whether deterministic training is enabled (True) or not (False) """ os.environ['DEEPLIIF_SEED'] = str(seed) @@ -43,61 +48,275 @@ def set_seed(seed=0,rank=None): torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True) print(f'deterministic training, seed set to {seed_final}') + return True else: print(f'not using deterministic training') + return False -if __name__ == '__main__': +@click.command() +@click.option('--dataroot', required=True, type=str, + help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') +@click.option('--name', default='experiment_name', + help='name of the experiment. 
It decides where to store samples and models') +@click.option('--gpu_ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU') +@click.option('--checkpoints_dir', default='./checkpoints', help='models are saved here') +@click.option('--modalities_no', default=4, type=int, help='number of targets') +# model parameters +@click.option('--model', default='DeepLIIF', help='name of model class') +@click.option('--input_nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') +@click.option('--output_nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') +@click.option('--ngf', default=64, help='# of gen filters in the last conv layer') +@click.option('--ndf', default=64, help='# of discrim filters in the first conv layer') +@click.option('--net_d', default='n_layers', + help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 ' + 'PatchGAN. n_layers allows you to specify the layers in the discriminator') +@click.option('--net_g', default='resnet_9blocks', + help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]') +@click.option('--n_layers_d', default=4, help='only used if netD==n_layers') +@click.option('--norm', default='batch', + help='instance normalization or batch normalization [instance | batch | none]') +@click.option('--init_type', default='normal', + help='network initialization [normal | xavier | kaiming | orthogonal]') +@click.option('--init_gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.') +@click.option('--no_dropout', is_flag=True, help='no dropout for the generator') +# dataset parameters +@click.option('--direction', default='AtoB', help='AtoB or BtoA') +@click.option('--serial_batches', is_flag=True, + help='if true, takes images in order to make batches, otherwise takes them randomly') +@click.option('--num_threads', default=4, help='# threads for loading data') +@click.option('--batch_size', default=1, help='input batch size') +@click.option('--load_size', default=512, help='scale images to this size') +@click.option('--crop_size', default=512, help='then crop to this size') +@click.option('--max_dataset_size', type=int, + help='Maximum number of samples allowed per dataset. If the dataset directory contains more than ' + 'max_dataset_size, only a subset is loaded.') +@click.option('--preprocess', type=str, + help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | ' + 'scale_width_and_crop | none]') +@click.option('--no_flip', is_flag=True, + help='if specified, do not flip the images for data augmentation') +@click.option('--display_winsize', default=512, help='display window size for both visdom and HTML') +# additional parameters +@click.option('--epoch', default='latest', + help='which epoch to load? set to latest to use latest cached model') +@click.option('--load_iter', default=0, + help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; '
+              'otherwise, the code will load models by [epoch]')
+@click.option('--verbose', is_flag=True, help='if specified, print more debugging information')
+@click.option('--lambda_L1', default=100.0, help='weight for L1 loss')
+@click.option('--is_train', is_flag=True, default=True)
+@click.option('--continue_train', is_flag=True, help='continue training: load the latest model')
+@click.option('--epoch_count', type=int, default=0,
+              help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>')
+@click.option('--phase', default='train', help='train, val, test, etc')
+# training parameters
+@click.option('--n_epochs', type=int, default=100,
+              help='number of epochs with the initial learning rate')
+@click.option('--n_epochs_decay', type=int, default=100,
+              help='number of epochs to linearly decay learning rate to zero')
+@click.option('--beta1', default=0.5, help='momentum term of adam')
+@click.option('--lr', default=0.0002, help='initial learning rate for adam')
+@click.option('--lr_policy', default='linear',
+              help='learning rate policy. [linear | step | plateau | cosine]')
+@click.option('--lr_decay_iters', type=int, default=50,
+              help='multiply by a gamma every lr_decay_iters iterations')
+@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
+# visdom and HTML visualization parameters
+@click.option('--display_freq', default=400, help='frequency of showing training results on screen')
+@click.option('--display_ncols', default=4,
+              help='if positive, display all images in a single visdom web panel with certain number of images per row.')
+@click.option('--display_id', default=1, help='window id of the web display')
+@click.option('--display_server', default="http://localhost", help='visdom server of the web display')
+@click.option('--display_env', default='main',
+              help='visdom display environment name (default is "main")')
+@click.option('--display_port', default=8097, help='visdom port of the web display')
+@click.option('--update_html_freq', default=1000, help='frequency of saving training results to html')
+@click.option('--print_freq', default=100, help='frequency of showing training results on console')
+@click.option('--no_html', is_flag=True,
+              help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+# network saving and loading parameters
+@click.option('--save_latest_freq', default=500, help='frequency of saving the latest results')
+@click.option('--save_epoch_freq', default=100,
+              help='frequency of saving checkpoints at the end of epochs')
+@click.option('--save_by_iter', is_flag=True, help='whether saves model by iteration')
+@click.option('--remote', type=bool, default=False, help='whether isolate visdom checkpoints or not; if False, you can run a separate visdom server anywhere that consumes the checkpoints')
+@click.option('--remote_transfer_cmd', type=str, default=None, help='module and function to be used to transfer remote files to target storage location, for example mymodule.myfunction')
+@click.option('--dataset_mode', type=str, default='aligned',
+              help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
+@click.option('--padding', type=str, default='zero',
+              help='chooses the type of padding used by resnet generator. 
[reflect | zero]')
+# DeepLIIFExt params
+@click.option('--seg_gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).')
+@click.option('--net_ds', type=str, default='n_layers',
+              help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+@click.option('--net_gs', type=str, default='unet_512',
+              help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+@click.option('--gan_mode', type=str, default='vanilla',
+              help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+@click.option('--gan_mode_s', type=str, default='lsgan',
+              help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+# DDP related arguments
+@click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
+def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g,
+          n_layers_d, norm, init_type, init_gain, no_dropout, direction, serial_batches, num_threads,
+          batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter,
+          verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env,
+          display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter,
+          continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, beta1, lr, lr_decay_iters,
+          remote, remote_transfer_cmd, seed, dataset_mode, padding, model,
+          modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank):
+    """General-purpose training script for multi-task image-to-image translation.
+
+    This script works for various models (with option '--model': e.g., DeepLIIF) and
+    different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
+    You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
+
+    It first creates model, dataset, and visualizer given the option.
+    It then does standard network training. During the training, it also visualize/save the images, print/save the loss
+    plot, and save models. The script supports continue/resume training.
+    Use '--continue_train' to resume your previous training.
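+
+    Example (illustrative invocation using the options defined above):
+        python train.py --dataroot ./datasets/my_dataset --name my_experiment --model DeepLIIF --gpu_ids 0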
+ """ + assert model in ['DeepLIIF','DeepLIIFExt','SDG'], f'model class {model} is not implemented' + if model == 'DeepLIIF': + seg_no = 1 + elif model == 'DeepLIIFExt': + if seg_gen: + seg_no = modalities_no + else: + seg_no = 0 + else: # SDG + seg_no = 0 + seg_gen = False + + d_params = locals() + + if gpu_ids and gpu_ids[0] == -1: + gpu_ids = [] + + local_rank = os.getenv('LOCAL_RANK') # DDP single node training triggered by torchrun has LOCAL_RANK + rank = os.getenv('RANK') # if using DDP with multiple nodes, please provide global rank in env var RANK + + if len(gpu_ids) > 0: + if local_rank is not None: + local_rank = int(local_rank) + torch.cuda.set_device(local_rank)#gpu_ids[local_rank]) + gpu_ids = [local_rank] + else: + torch.cuda.set_device(gpu_ids[0]) + + if local_rank is not None: # LOCAL_RANK will be assigned a rank number if torchrun ddp is used + dist.init_process_group(backend="nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE'])) + print('local rank:',local_rank) + flag_deterministic = set_seed(seed,local_rank) + elif rank is not None: + flag_deterministic = set_seed(seed, rank) + else: + flag_deterministic = set_seed(seed) + + if flag_deterministic: + d_params['padding'] = 'zero' + print('padding type is forced to zero padding, because neither refection pad2d or replication pad2d has a deterministic implementation') - opt = TrainOptions().parse() # get training options - set_seed(opt.seed) - dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options - dataset_size = len(dataset) # get the number of images in the dataset. - print('The number of training images = %d' % dataset_size) - - model = create_model(opt) # create a model given opt.model and other options - model.setup(opt) # regular setup: load and print networks; create schedulers - visualizer = Visualizer(opt) # create a visualizer that display/save images and plots - total_iters = 0 # the total number of training iterations - - for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): # outer loop for different epochs; we save the model by , + - epoch_start_time = time.time() # timer for entire epoch - iter_data_time = time.time() # timer for data loading per iteration - epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch - visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch - - for i, data in enumerate(dataset): # inner loop within one epoch - iter_start_time = time.time() # timer for computation per iteration - if total_iters % opt.print_freq == 0: + # infer number of input images + dir_data_train = dataroot + '/train' + fns = os.listdir(dir_data_train) + fns = [x for x in fns if x.endswith('.png')] + img = Image.open(f"{dir_data_train}/{fns[0]}") + + num_img = img.size[0] / img.size[1] + assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer' + num_img = int(num_img) + + input_no = num_img - modalities_no - seg_no + assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0' + d_params['input_no'] = input_no + d_params['scale_size'] = img.size[1] + d_params['gpu_ids'] = gpu_ids + + # create a dataset given dataset_mode and other options + # dataset = AlignedDataset(opt) + + opt = Options(d_params=d_params) + print_options(opt, save=True) + + dataset = create_dataset(opt) + # get the number of images in the dataset. 
+    click.echo('The number of training images = %d' % len(dataset))
+
+    # create a model given model and other options
+    model = create_model(opt)
+    # regular setup: load and print networks; create schedulers
+    model.setup(opt)
+
+    # create a visualizer that display/save images and plots
+    visualizer = Visualizer(opt)
+    # the total number of training iterations
+    total_iters = 0
+
+    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
+    for epoch in range(epoch_count, n_epochs + n_epochs_decay + 1):
+        # timer for entire epoch
+        epoch_start_time = time.time()
+        # timer for data loading per iteration
+        iter_data_time = time.time()
+        # the number of training iterations in current epoch, reset to 0 every epoch
+        epoch_iter = 0
+        # reset the visualizer: make sure it saves the results to HTML at least once every epoch
+        visualizer.reset()
+
+        # https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+        if local_rank is not None or os.getenv('RANK') is not None: # if DDP is used, either on one node or multi nodes
+            if not serial_batches: # if we want random order in mini batches
+                dataset.sampler.set_epoch(epoch)
+
+        # inner loop within one epoch
+        for i, data in enumerate(dataset):
+            # timer for computation per iteration
+            iter_start_time = time.time()
+            if total_iters % print_freq == 0:
                 t_data = iter_start_time - iter_data_time
 
-            total_iters += opt.batch_size
-            epoch_iter += opt.batch_size
-            model.set_input(data)         # unpack data from dataset and apply preprocessing
-            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights
+            total_iters += batch_size
+            epoch_iter += batch_size
+            # unpack data from dataset and apply preprocessing
+            model.set_input(data)
+            # calculate loss functions, get gradients, update network weights
+            model.optimize_parameters()
 
-            if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
-                save_result = total_iters % opt.update_html_freq == 0
+            # display images on visdom and save images to a HTML file
+            if total_iters % display_freq == 0:
+                save_result = total_iters % update_html_freq == 0
                 model.compute_visuals()
                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
 
-            if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
+            # print training losses and save logging information to the disk
+            if total_iters % print_freq == 0:
                 losses = model.get_current_losses()
-                t_comp = (time.time() - iter_start_time) / opt.batch_size
+                t_comp = (time.time() - iter_start_time) / batch_size
                 visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
-                if opt.display_id > 0:
-                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
+                if display_id > 0:
+                    visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), losses)
 
-            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
+            # cache our latest model every <save_latest_freq> iterations
+            if total_iters % save_latest_freq == 0:
                 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
-                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
+                save_suffix = 'iter_%d' % total_iters if save_by_iter else 'latest'
                 model.save_networks(save_suffix)
 
             iter_data_time = time.time()
-        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
+
+        # cache our model every <save_epoch_freq> epochs
+        if epoch % save_epoch_freq == 0:
             print('saving the model at the end of epoch
%d, iters %d' % (epoch, total_iters)) model.save_networks('latest') model.save_networks(epoch) - print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) - model.update_learning_rate() # update learning rates at the end of every epoch. + print('End of epoch %d / %d \t Time Taken: %d sec' % ( + epoch, n_epochs + n_epochs_decay, time.time() - epoch_start_time)) + # update learning rates at the end of every epoch. + model.update_learning_rate() + + +if __name__ == '__main__': + train() From 2f808f7b3a0a7260d5037be70ef8db92083d4305 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Wed, 31 Jan 2024 15:13:03 -0500 Subject: [PATCH 02/41] updated trainlaunch for single node multi gpu --- cli.py | 57 ++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/cli.py b/cli.py index 40e1204..604838c 100644 --- a/cli.py +++ b/cli.py @@ -134,6 +134,7 @@ def cli(): help='learning rate policy. [linear | step | plateau | cosine]') @click.option('--lr-decay-iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') +@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)') # visdom and HTML visualization parameters @click.option('--display-freq', default=400, help='frequency of showing training results on screen') @click.option('--display-ncols', default=4, @@ -158,8 +159,6 @@ def cli(): help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') @click.option('--padding', type=str, default='zero', help='chooses the type of padding used by resnet generator. [reflect | zero]') -@click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup') -@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)') # DeepLIIFExt params @click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).') @click.option('--net-ds', type=str, default='n_layers', @@ -170,14 +169,16 @@ def cli(): help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') @click.option('--gan-mode-s', type=str, default='lsgan', help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. 
vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') +# DDP related arguments +@click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup') def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g, n_layers_d, norm, init_type, init_gain, no_dropout, direction, serial_batches, num_threads, batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter, verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env, display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter, continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, beta1, lr, lr_decay_iters, - remote, local_rank, remote_transfer_cmd, seed, dataset_mode, padding, model, - modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s): + remote, remote_transfer_cmd, seed, dataset_mode, padding, model, + modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank): """General-purpose training script for multi-task image-to-image translation. This script works for various models (with option '--model': e.g., DeepLIIF) and @@ -213,12 +214,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd if local_rank is not None: local_rank = int(local_rank) torch.cuda.set_device(gpu_ids[local_rank]) - gpu_ids=[gpu_ids[local_rank]] + gpu_ids=[local_rank] else: torch.cuda.set_device(gpu_ids[0]) if local_rank is not None: # LOCAL_RANK will be assigned a rank number if torchrun ddp is used - dist.init_process_group(backend='nccl') + dist.init_process_group(backend="nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE'])) print('local rank:',local_rank) flag_deterministic = set_seed(seed,local_rank) elif rank is not None: @@ -244,6 +245,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0' d_params['input_no'] = input_no d_params['scale_size'] = img.size[1] + d_params['gpu_ids'] = gpu_ids # create a dataset given dataset_mode and other options # dataset = AlignedDataset(opt) @@ -336,8 +338,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd help='name of the experiment. 
It decides where to store samples and models') @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU') @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here') -@click.option('--targets-no', default=5, help='number of targets') +@click.option('--modalities-no', default=4, type=int, help='number of targets') # model parameters +@click.option('--model', default='DeepLIIF', help='name of model class') @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') @click.option('--ngf', default=64, help='# of gen filters in the last conv layer') @@ -353,7 +356,6 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd @click.option('--init-type', default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.') -@click.option('--padding-type', default='reflect', help='network padding type.') @click.option('--no-dropout', is_flag=True, help='no dropout for the generator') # dataset parameters @click.option('--direction', default='AtoB', help='AtoB or BtoA') @@ -396,6 +398,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd help='learning rate policy. [linear | step | plateau | cosine]') @click.option('--lr-decay-iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') +@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)') # visdom and HTML visualization parameters @click.option('--display-freq', default=400, help='frequency of showing training results on screen') @click.option('--display-ncols', default=4, @@ -416,8 +419,22 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd @click.option('--save-by-iter', is_flag=True, help='whether saves model by iteration') @click.option('--remote', type=bool, default=False, help='whether isolate visdom checkpoints or not; if False, you can run a separate visdom server anywhere that consumes the checkpoints') @click.option('--remote-transfer-cmd', type=str, default=None, help='module and function to be used to transfer remote files to target storage location, for example mymodule.myfunction') +@click.option('--dataset-mode', type=str, default='aligned', + help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') +@click.option('--padding', type=str, default='zero', + help='chooses the type of padding used by resnet generator. [reflect | zero]') +# DeepLIIFExt params +@click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).') +@click.option('--net-ds', type=str, default='n_layers', + help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') +@click.option('--net-gs', type=str, default='unet_512', + help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]') +@click.option('--gan-mode', type=str, default='vanilla', + help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. 
vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+@click.option('--gan-mode-s', type=str, default='lsgan',
+              help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+# DDP related arguments
 @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
-@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
 @click.option('--use-torchrun', type=str, default=None, help='provide torchrun options, all in one string, for example "-t3 --log_dir ~/log/ --nproc_per_node 1"; if your pytorch version is older than 1.10, torch.distributed.launch will be called instead of torchrun')
 def trainlaunch(**kwargs):
     """
@@ -448,6 +465,7 @@ def trainlaunch(**kwargs):
         elif args[i-1] not in l_arg_skip and arg not in l_arg_skip:
             # if the previous element is not an option name to skip AND if the current element is not an option to remove
             args_final.append(arg)
+
     ## add quotes back to the input arg that had quotes, e.g., experiment name
     args_final = [f'"{arg}"' if ' ' in arg else arg for arg in args_final]
 
     #### locate train.py
     import deepliif
-    path_train_py = deepliif.__path__[0]+'/train.py'
+    path_train_py = deepliif.__path__[0]+'/scripts/train.py'
+
+    #### find out GPUs to use
+    gpu_ids = [args_final[i+1] for i,v in enumerate(args_final) if v=='--gpu-ids']
+    if len(gpu_ids) > 0 and gpu_ids[0] == '-1':
+        gpu_ids = []
+
+    if len(gpu_ids) > 0:
+        opt_env = f"CUDA_VISIBLE_DEVICES=\"{','.join(gpu_ids)}\""
+    else:
+        opt_env = ''
 
     #### execute train.py
     if kwargs['use_torchrun']:
         if version.parse(torch.__version__) >= version.parse('1.10.0'):
-            subprocess.run(f'torchrun {kwargs["use_torchrun"]} {path_train_py} {options}',shell=True)
+            cmd = f'{opt_env} torchrun {kwargs["use_torchrun"]} {path_train_py} {options}'
         else:
-            subprocess.run(f'python -m torch.distributed.launch {kwargs["use_torchrun"]} {path_train_py} {options}',shell=True)
+            cmd = f'{opt_env} python -m torch.distributed.launch {kwargs["use_torchrun"]} {path_train_py} {options}'
     else:
-        subprocess.run(f'python {path_train_py} {options}',shell=True)
+        cmd = f'{opt_env} python {path_train_py} {options}'
+
+    print('Executing command:',cmd)
+    subprocess.run(cmd,shell=True)

From 04db7ce991ea90e9368c2fdaf8c91720a616d348 Mon Sep 17 00:00:00 2001
From: Wendy Wang
Date: Tue, 27 Feb 2024 14:50:14 +0000
Subject: [PATCH 03/41] added auto-determination of tile size for data augmentation

---
 Image_Processing/Augmentation.py              | 18 ++++++++++--------
 .../Image_Processing_Helper_Functions.py      | 17 +++++++++++++----
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/Image_Processing/Augmentation.py b/Image_Processing/Augmentation.py
index b215eb0..a3fa06a 100644
--- a/Image_Processing/Augmentation.py
+++ b/Image_Processing/Augmentation.py
@@ -6,12 +6,13 @@
 
 class Augmentation:
-    def __init__(self, images):
+    def __init__(self, images, tile_size=512):
         self.images = images
         self.shape = self.images[list(self.images.keys())[0]].shape
         self.rotation_angle = np.random.choice([0, 90, 180, 270], 1)[0]
         # self.zoom_value = random.randint(0, 5)
         self.alpha_affine = 0.1
+        self.tile_size = tile_size
 
     def pipeline(self):
         """
@@ -30,12 +31,13 @@ def zoom(self):
         :return:
         """
         new_size = random.randint(int(self.shape[0] * 0.75), self.shape[0])
+        assert
self.shape[1] - new_size >= 0, f'self.shape[1] - new_size ({self.shape[1]} - {new_size}) should not be negative'
         start_point = (random.randint(0, self.shape[0] - new_size), random.randint(0, self.shape[1] - new_size))
         for key in self.images.keys():
             try:
-                self.images[key] = cv2.resize(self.images[key][start_point[0]: start_point[0] + new_size, start_point[1]: start_point[1] + new_size], (512, 512))
-            except:
-                print(key + ' not available')
+                self.images[key] = cv2.resize(self.images[key][start_point[0]: start_point[0] + new_size, start_point[1]: start_point[1] + new_size], (self.tile_size, self.tile_size))
+            except Exception as e:
+                print(e)
 
     def rotate(self):
         """
@@ -47,8 +49,8 @@ def rotate(self):
         for key in self.images.keys():
             try:
                 self.images[key] = ndimage.rotate(self.images[key], self.rotation_angle, reshape=False)
-            except:
-                print(key + ' not available')
+            except Exception as e:
+                print(e)
 
     def elastic_transform(self, random_state=None):
         """
@@ -78,5 +80,5 @@ def elastic_transform(self, random_state=None):
         for key in self.images.keys():
             try:
                 self.images[key] = cv2.warpAffine(self.images[key], M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)
-            except:
-                print(key + ' not available')
\ No newline at end of file
+            except Exception as e:
+                print(e)
diff --git a/Image_Processing/Image_Processing_Helper_Functions.py b/Image_Processing/Image_Processing_Helper_Functions.py
index 8f2e6b2..b71e552 100644
--- a/Image_Processing/Image_Processing_Helper_Functions.py
+++ b/Image_Processing/Image_Processing_Helper_Functions.py
@@ -87,7 +87,7 @@ def create_training_testing_dataset_from_given_directory(input_dir, output_dir,
             cv2.imwrite(os.path.join(all_dirs[i], filename), all_images[filename])
 
 
-def augment_set(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin', 'CD3', 'PanCK'], tile_size=512):
+def augment_set(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin', 'CD3', 'PanCK']):
     """
     This function augments a co-aligned dataset.
@@ -105,20 +105,29 @@ def augment_set(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin',
     """
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-    images = os.listdir(input_dir)
-    for img in images:
+    images_original = os.listdir(input_dir)
+    print(f'{len(images_original)} images found')
+
+    count = 0
+    for i,img in enumerate(images_original):
         augmented = 0
         while augmented < aug_no:
             images = {}
             image = cv2.imread(os.path.join(input_dir, img))
+            if i == 0:
+                tile_size = image.shape[0]
+                assert image.shape[1] >= len(modality_types) * tile_size, f'image width ({image.shape[1]}) is not enough for {len(modality_types)} modalities with tile size {tile_size}'
             for i in range(0, len(modality_types)):
                 images[modality_types[i]] = image[:, i * tile_size: (i + 1) * tile_size]
             new_images = images.copy()
-            aug = Augmentation(new_images)
+            aug = Augmentation(new_images, tile_size)
             aug.pipeline()
             cv2.imwrite(os.path.join(output_dir, img.replace('.png', '_' + str(augmented) + '.png')),
                         np.concatenate(list(new_images.values()), 1))
             augmented += 1
+        count += 1
+        if count % 10 == 0 or count == len(images_original):
+            print(f'Done {count}/{len(images_original)}')
 
 
 def augment_created_dataset(input_dir, output_dir, aug_no=9, modality_types=['hematoxylin', 'CD3', 'PanCK'],
                             tile_size=512):

From d0033f4268c2129c45dd7357441568998b7e19de Mon Sep 17 00:00:00 2001
From: Wendy Wang
Date: Tue, 27 Feb 2024 14:55:38 +0000
Subject: [PATCH 04/41] allowed change of subfolder & batchsize when loading data (the former for val during training and the latter for batchsize change during val/test)

---
 deepliif/data/__init__.py        | 14 +++++++-------
 deepliif/data/aligned_dataset.py |  5 ++---
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/deepliif/data/__init__.py b/deepliif/data/__init__.py
index 6a4e6c4..2966985 100644
--- a/deepliif/data/__init__.py
+++ b/deepliif/data/__init__.py
@@ -55,28 +55,28 @@ def get_option_setter(dataset_name):
     return dataset_class.modify_commandline_options
 
 
-def create_dataset(opt):
+def create_dataset(opt, phase=None, batch_size=None):
     """Create a dataset given the option.
 
     This function wraps the class CustomDatasetDataLoader.
         This is the main interface between this package and 'train.py'/'test.py'
    """
-    return CustomDatasetDataLoader(opt)
+    return CustomDatasetDataLoader(opt, phase=phase if phase else opt.phase, batch_size=batch_size if batch_size else opt.batch_size)
 
 
 class CustomDatasetDataLoader(object):
     """Wrapper class of Dataset class that performs multi-threaded data loading"""
 
-    def __init__(self, opt):
+    def __init__(self, opt, phase=None, batch_size=None):
         """Initialize this class
 
         Step 1: create a dataset instance given the name [dataset_mode]
         Step 2: create a multi-threaded data loader.
""" - self.batch_size = opt.batch_size + self.batch_size = batch_size if batch_size else opt.batch_size self.max_dataset_size = opt.max_dataset_size dataset_class = find_dataset_using_name(opt.dataset_mode) - self.dataset = dataset_class(opt) + self.dataset = dataset_class(opt, phase=phase if phase else opt.phase) print("dataset [%s] was created" % type(self.dataset).__name__) sampler = None @@ -95,7 +95,7 @@ def seed_worker(worker_id): self.dataloader = torch.utils.data.DataLoader( self.dataset, sampler=sampler, - batch_size=opt.batch_size, + batch_size=batch_size, shuffle=not opt.serial_batches if sampler is None else False, num_workers=int(opt.num_threads) ) @@ -106,7 +106,7 @@ def seed_worker(worker_id): self.dataloader = torch.utils.data.DataLoader( self.dataset, sampler=sampler, - batch_size=opt.batch_size, + batch_size=batch_size, shuffle=not opt.serial_batches if sampler is None else False, num_workers=int(opt.num_threads), worker_init_fn=seed_worker, diff --git a/deepliif/data/aligned_dataset.py b/deepliif/data/aligned_dataset.py index aa5928b..b1e7501 100644 --- a/deepliif/data/aligned_dataset.py +++ b/deepliif/data/aligned_dataset.py @@ -11,7 +11,7 @@ class AlignedDataset(BaseDataset): During test time, you need to prepare a directory '/path/to/data/test'. """ - def __init__(self, opt): + def __init__(self, opt, phase='train'): """Initialize this dataset class. Parameters: @@ -19,7 +19,7 @@ def __init__(self, opt): """ BaseDataset.__init__(self, opt.dataroot) self.preprocess = opt.preprocess - self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory + self.dir_AB = os.path.join(opt.dataroot, phase) # get the image directory self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths assert(opt.load_size >= opt.crop_size) # crop_size should be smaller than the size of loaded image self.input_nc = opt.output_nc if opt.direction == 'BtoA' else opt.input_nc @@ -95,7 +95,6 @@ def __getitem__(self, index): A = AB.crop((w2 * i, 0, w2 * (i+1), h)) A = A_transform(A) A_Array.append(A) - for i in range(self.input_no, self.input_no + self.modalities_no + 1): B = AB.crop((w2 * i, 0, w2 * (i + 1), h)) B = B_transform(B) From 4aa5194b7462acd8153b079d138832f9202263b1 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Tue, 27 Feb 2024 14:57:25 +0000 Subject: [PATCH 05/41] updated segmentation metrics calculation for deepliifext --- DeepLIIF_Statistics/ComputeStatistics.py | 16 ++++++++++------ DeepLIIF_Statistics/Segmentation_Metrics.py | 10 +++++++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/DeepLIIF_Statistics/ComputeStatistics.py b/DeepLIIF_Statistics/ComputeStatistics.py index 804945e..7b35192 100644 --- a/DeepLIIF_Statistics/ComputeStatistics.py +++ b/DeepLIIF_Statistics/ComputeStatistics.py @@ -21,7 +21,7 @@ parser.add_argument('--gt_path', type=str, required=True) parser.add_argument('--model_path', type=str, required=True) parser.add_argument('--output_path', type=str, required=True) -parser.add_argument('--model_name', type=str, required=False, default='DeepLIIF') +parser.add_argument('--model_name', type=str, required=False, default='') parser.add_argument('--mode', type=str, default='Segmentation', help='Mode of the statistics computation including Segmentation, ImageSynthesis, All') parser.add_argument('--raw_segmentation', action='store_true') @@ -30,7 +30,8 @@ help='Batch size to use') parser.add_argument('--num_workers', type=int, default=8, help='Number of processes to use for data loading') 
-parser.add_argument('--image_types', type=str, default='Hema,DAPI,Lap2,Marker') +parser.add_argument('--image_types', type=str, default='Hema,DAPI,Lap2,Marker', help='These are non-seg modalities to be evaluated.') +parser.add_argument('--seg_type', type=str, default='Seg', help='This is the seg modality to be evaluated.') class Statistics: @@ -45,6 +46,7 @@ def __init__(self, args): self.num_workers = args.num_workers self.device = args.device self.image_types = args.image_types.replace(' ', '').split(',') + self.seg_type = args.seg_type # Image Similarity Metrics self.inception_avg = collections.defaultdict(float) @@ -184,8 +186,9 @@ def compute_segmentation_metrics(self): thresh = 100 boundary_thresh = 100 noise_size = 50 - print(thresh, noise_size) - self.segmentation_info, self.segmentation_metrics = compute_segmentation_metrics(self.gt_path, self.model_path, self.model_name, image_size=512, thresh=thresh, boundary_thresh=boundary_thresh, small_object_size=noise_size, raw_segmentation=self.raw_segmentation) + + self.segmentation_info, self.segmentation_metrics = compute_segmentation_metrics(self.gt_path, self.model_path, self.model_name, image_size=512, thresh=thresh, boundary_thresh=boundary_thresh, small_object_size=noise_size, raw_segmentation=self.raw_segmentation, suffix_seg=self.seg_type) + assert len(self.segmentation_info) > 0, 'No segmentation results returned; one typical cause is a wrong --seg-type and/or --image-types that cannot be found in any image filename' self.write_list_to_csv(self.segmentation_info, self.segmentation_info[0].keys(), filename='segmentation_info_' + self.mode + '_' + self.model_name + '_' + str(thresh) + '_' + str(noise_size) + '.csv') for key in self.segmentation_metrics: @@ -202,14 +205,14 @@ def compute_statistics(self): self.create_all_info() def write_dict_to_csv(self, info_dict, csv_columns, filename='info.csv'): - print('Writing in csv') + print('Writing in csv',os.path.join(self.output_path, filename)) info_csv = open(os.path.join(self.output_path, filename), 'w') writer = csv.DictWriter(info_csv, fieldnames=csv_columns) writer.writeheader() writer.writerow(info_dict) def write_list_to_csv(self, info_dict, csv_columns, filename='info.csv'): - print('Writing in csv') + print('Writing in csv',os.path.join(self.output_path, filename)) info_csv = open(os.path.join(self.output_path, filename), 'w') writer = csv.DictWriter(info_csv, fieldnames=csv_columns) writer.writeheader() @@ -228,3 +231,4 @@ def write_list_to_csv(self, info_dict, csv_columns, filename='info.csv'): stat.compute_segmentation_metrics() elif stat.mode == 'ImageSynthesis': stat.compute_image_similarity_metrics() + stat.create_all_info() diff --git a/DeepLIIF_Statistics/Segmentation_Metrics.py b/DeepLIIF_Statistics/Segmentation_Metrics.py index 65a62f2..4e2c50a 100644 --- a/DeepLIIF_Statistics/Segmentation_Metrics.py +++ b/DeepLIIF_Statistics/Segmentation_Metrics.py @@ -102,12 +102,16 @@ def compute_aji(gt_image, mask_image): return aji -def compute_segmentation_metrics(gt_dir, model_dir, model_name, image_size=512, thresh=100, boundary_thresh=100, small_object_size=20, raw_segmentation=True): +def compute_segmentation_metrics(gt_dir, model_dir, model_name, image_size=512, thresh=100, boundary_thresh=100, small_object_size=20, raw_segmentation=True, suffix_seg=None): info_dict = [] metrics = collections.defaultdict(float) images = os.listdir(model_dir) + counter = 0 - postfix = '_Seg.png' if raw_segmentation else '_SegRefined.png' + if suffix_seg is not None: + postfix = 
f'_{suffix_seg}.png' + else: + postfix = '_Seg.png' if raw_segmentation else '_SegRefined.png' for mask_name in images: if postfix in mask_name: counter += 1 @@ -123,7 +127,7 @@ def compute_segmentation_metrics(gt_dir, model_dir, model_name, image_size=512, positive_mask[positive_mask > 0] = 1 negative_mask[negative_mask > 0] = 1 - gt_img = cv2.cvtColor(cv2.imread(os.path.join(gt_dir, mask_name.replace(postfix, '.png'))), cv2.COLOR_BGR2RGB) + gt_img = cv2.cvtColor(cv2.imread(os.path.join(gt_dir, mask_name)), cv2.COLOR_BGR2RGB) gt_img = cv2.resize(gt_img, (image_size, image_size)) positive_gt = gt_img[:, :, 0] From bcd4d926328fa608b78de90374e5988c79a6aa69 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Tue, 5 Mar 2024 16:01:04 +0000 Subject: [PATCH 06/41] added resizeconv, currently inferred from model name / dir --- deepliif/models/SDG_model.py | 14 ++++++++------ deepliif/models/networks.py | 24 ++++++++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/deepliif/models/SDG_model.py b/deepliif/models/SDG_model.py index ed0896d..5798eed 100644 --- a/deepliif/models/SDG_model.py +++ b/deepliif/models/SDG_model.py @@ -15,6 +15,8 @@ def __init__(self, opt): BaseModel.__init__(self, opt) self.mod_gen_no = self.opt.modalities_no + opt.resize_conv = 'resizeconv' in opt.checkpoints_dir or 'resizeconv' in opt.name + # weights of the modalities in generating segmentation mask self.seg_weights = [0, 0, 0] @@ -55,7 +57,7 @@ def __init__(self, opt): self.netG = [None for _ in range(self.mod_gen_no)] for i in range(self.mod_gen_no): self.netG[i] = networks.define_G(self.opt.input_nc * self.opt.input_no, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm, - not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding) + not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding, resize_conv=opt.resize_conv) print('***************************************') print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm, not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding) @@ -159,15 +161,15 @@ def backward_G(self): for i in range(self.mod_gen_no): self.loss_G_L1.append(self.criterionSmoothL1(self.fake_B[i], self.real_B[i]) * self.opt.lambda_L1) - #self.loss_G_VGG = [] - #for i in range(self.mod_gen_no): - # self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat) + self.loss_G_VGG = [] + for i in range(self.mod_gen_no): + self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat) # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0] self.loss_G = torch.tensor(0., device=self.device) for i in range(0, self.mod_gen_no): - self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i] - # self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] + #self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i] + self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] self.loss_G.backward() def optimize_parameters(self): diff --git a/deepliif/models/networks.py b/deepliif/models/networks.py index ea99f6a..989639f 100644 --- a/deepliif/models/networks.py +++ b/deepliif/models/networks.py @@ -125,7 +125,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): return net -def 
define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], padding_type='reflect'): +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], padding_type='reflect', resize_conv=False): """Create a generator Parameters: @@ -156,7 +156,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in norm_layer = get_norm_layer(norm_type=norm) if netG == 'resnet_9blocks': - net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, padding_type=padding_type) + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, padding_type=padding_type, resize_conv=resize_conv) elif netG == 'resnet_6blocks': net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, padding_type=padding_type) elif netG == 'unet_128': @@ -332,9 +332,12 @@ class ResnetGenerator(nn.Module): """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + + Resize-conv: optional replacement of ConvTranspose2d + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190#issuecomment-358546675 """ - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero'): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero', resize_conv=False): """Construct a Resnet-based generator Parameters: @@ -346,6 +349,7 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d n_blocks (int) -- the number of ResNet blocks padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero """ + print('resize_conv',resize_conv) assert(n_blocks >= 0) super(ResnetGenerator, self).__init__() if type(norm_layer) == functools.partial: @@ -378,11 +382,19 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d for i in range(n_downsampling): # add upsampling layers mult = 2 ** (n_downsampling - i) - model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + if resize_conv: + upsample_layer = [#nn.Upsample(scale_factor = 2, mode='bilinear',align_corners=True), + nn.Upsample(scale_factor = 2, mode='nearest'), + nn.ReflectionPad2d(1), + nn.Conv2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=1, padding=0)] + else: + upsample_layer = [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, - bias=use_bias), - norm_layer(int(ngf * mult / 2)), + bias=use_bias)] + model += upsample_layer + model += [norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] if padding_type == 'reflect': From c03245b94dd030d7ef8a56af365630f2cefd9b9c Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Thu, 7 Mar 2024 20:56:10 +0000 Subject: [PATCH 07/41] fixed case when SDG has only 1 input modality --- deepliif/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepliif/models/__init__.py b/deepliif/models/__init__.py index 6c51259..61845d6 100644 --- a/deepliif/models/__init__.py +++ b/deepliif/models/__init__.py @@ -231,7 +231,7 @@ def run_dask(img, model_path, eager_mode=False, opt=None): model_dir = 
os.getenv('DEEPLIIF_MODEL_DIR', model_path) nets = init_nets(model_dir, eager_mode, opt) - if opt.input_no > 1: + if opt.input_no > 1 or opt.model == 'SDG': l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img] ts = torch.cat(l_ts, dim=1) else: From a497cf46883b46511c34c05af231c176a822287a Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Thu, 7 Mar 2024 20:59:54 +0000 Subject: [PATCH 08/41] added an experimental submodule to facilitate model evaluation and metric calculation --- deepliif/stat/ComputeStatistics.py | 268 ++++++++ deepliif/stat/Create_Loss_Diagram.py | 81 +++ deepliif/stat/HelperFunctions.py | 667 +++++++++++++++++++ deepliif/stat/PostProcessSegmentationMask.py | 171 +++++ deepliif/stat/Segmentation_Metrics.py | 237 +++++++ deepliif/stat/__init__.py | 323 +++++++++ deepliif/stat/fid.py | 334 ++++++++++ deepliif/stat/fid_official_tf.py | 370 ++++++++++ deepliif/stat/inception_score.py | 100 +++ deepliif/stat/swd.py | 158 +++++ 10 files changed, 2709 insertions(+) create mode 100644 deepliif/stat/ComputeStatistics.py create mode 100644 deepliif/stat/Create_Loss_Diagram.py create mode 100644 deepliif/stat/HelperFunctions.py create mode 100644 deepliif/stat/PostProcessSegmentationMask.py create mode 100644 deepliif/stat/Segmentation_Metrics.py create mode 100644 deepliif/stat/__init__.py create mode 100644 deepliif/stat/fid.py create mode 100644 deepliif/stat/fid_official_tf.py create mode 100644 deepliif/stat/inception_score.py create mode 100644 deepliif/stat/swd.py diff --git a/deepliif/stat/ComputeStatistics.py b/deepliif/stat/ComputeStatistics.py new file mode 100644 index 0000000..ca471fb --- /dev/null +++ b/deepliif/stat/ComputeStatistics.py @@ -0,0 +1,268 @@ +import os +import cv2 +import numpy as np +import csv +from numba import cuda +import time + +from .Segmentation_Metrics import compute_segmentation_metrics +#from fid_official_tf import calculate_fid_given_paths +from .fid import calculate_fid_given_paths +from .inception_score import calculate_inception_score + +from skimage.metrics import structural_similarity as ssim +from skimage.metrics import peak_signal_noise_ratio as psnr +from skimage.metrics import mean_squared_error +from skimage import img_as_float, io, measure +from skimage.color import rgb2gray +import collections +from .swd import compute_swd + + +""" +params: + gt_path + model_path + output_path + model_name + mode: Mode of the statistics computation including Segmentation, ImageSynthesis, All + raw_segmentation + device: Device to use. Like cuda, cuda:0 or cpu + batch_size: Batch size to use + num_workers: Number of processes to use for data loading + image_types: These are non-seg modalities to be evaluated. + seg_type: This is the seg modality to be evaluated. 
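+
+    Example (editor's illustrative sketch; hypothetical paths, otherwise defaults):
+        stats = Statistics(gt_path='./gt', model_path='./results', output_path='./metrics',
+                           model_name='DeepLIIF', mode='Segmentation', seg_type='Seg')
+        all_info = stats.run(write_to_csv=True)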
+""" + +class Statistics: + def __init__(self, gt_path, model_path, output_path, model_name='', mode='Segmentation', + raw_segmentation=False, device='cuda', batch_size=50, num_workers=8, + image_types='Hema,DAPI,Lap2,Marker', seg_type='Seg'): + self.gt_path = gt_path + self.model_path = model_path + self.output_path = output_path + self.model_name = model_name + self.mode = mode + self.raw_segmentation = raw_segmentation + self.batch_size = batch_size + self.num_workers = num_workers + self.device = device + self.image_types = image_types.replace(' ', '').split(',') + self.seg_type = seg_type + + # Image Similarity Metrics + self.inception_avg = collections.defaultdict(float) + self.inception_std = collections.defaultdict(float) + + self.mse_avg = collections.defaultdict(float) + self.mse_std = collections.defaultdict(float) + + self.ssim_avg = collections.defaultdict(float) + self.ssim_std = collections.defaultdict(float) + + self.psnr_avg = collections.defaultdict(float) + self.psnr_std = collections.defaultdict(float) + + self.fid_value = collections.defaultdict(float) + self.swd_value = collections.defaultdict(float) + + self.all_info = {} + self.all_info['Model'] = self.model_name + + # Segmentation Metrics + self.segmentation_metrics = collections.defaultdict(float) + self.segmentation_info = None + + if not os.path.exists(self.output_path): + os.makedirs(self.output_path) + + def compute_mse_ssim_scores(self): + for img_type in self.image_types: + images = os.listdir(self.model_path) + mse_arr = [] + ssim_arr = [] + # mse_info = [] + for img_name in images: + if img_type in img_name: + orig_img = img_as_float(rgb2gray(io.imread(os.path.join(self.gt_path, img_name)))) + mask_img = img_as_float(rgb2gray(io.imread(os.path.join(self.model_path, img_name)))) + + mse_mask = mean_squared_error(orig_img, mask_img) + ssim_mask = ssim(orig_img, mask_img, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1) + + mse_arr.append(mse_mask) + ssim_arr.append(ssim_mask) + + # mse_info.append({'image_name': img_name, 'image_type':img_type, 'mse': mse_mask, 'ssim': ssim_mask}) + # self.write_list_to_csv(mse_info, mse_info[0].keys(), + # filename='inference_info_' + img_type + '_' + self.model_name + '.csv') + self.mse_avg[img_type], self.mse_std[img_type] = np.mean(mse_arr), np.std(mse_arr) + self.ssim_avg[img_type], self.ssim_std[img_type] = np.mean(ssim_arr), np.std(ssim_arr) + self.all_info['mse_avg'] = self.mse_avg + self.all_info['mse_std'] = self.mse_std + self.all_info['ssim_avg'] = self.ssim_avg + self.all_info['ssim_std'] = self.ssim_std + + def compute_psnr_scores(self): + """ + Peak signal-to-noise ratio + """ + for img_type in self.image_types: + images = os.listdir(self.model_path) + mse_arr = [] + score_arr = [] + # mse_info = [] + for img_name in images: + if img_type in img_name: + orig_img = img_as_float(rgb2gray(io.imread(os.path.join(self.gt_path, img_name)))) + mask_img = img_as_float(rgb2gray(io.imread(os.path.join(self.model_path, img_name)))) + #print(orig_img) + + score_arr.append(psnr(orig_img, mask_img, data_range=1)) + self.psnr_avg[img_type], self.psnr_std[img_type] = np.mean(score_arr), np.std(score_arr) + self.all_info['psnr_avg'] = self.psnr_avg + self.all_info['psnr_std'] = self.psnr_std + + def compute_inception_score(self): + for img_type in self.image_types: + images = os.listdir(self.model_path) + real_images_array = [] + for img in images: + if img_type in img: + image = cv2.imread(os.path.join(self.model_path, img)) + 
image = cv2.resize(image, (299, 299)) + real_images_array.append(image) + real_images_array = np.array(real_images_array) + self.inception_avg[img_type], self.inception_std[img_type] = calculate_inception_score(real_images_array) + + def compute_fid_score(self): + for img_type in self.image_types: + os.environ['CUDA_VISIBLE_DEVICES'] = self.gpu + self.fid_value[img_type] = calculate_fid_given_paths([self.gt_path, self.model_path], None, low_profile=False) + print("FID: ", self.fid_value[img_type]) + # self.fid_value[img_type] = calculate_fid_given_paths(paths=[self.gt_path, self.model_path], batch_size=self.batch_size, dims=self.fid_dims, num_workers=self.num_workers, mod_type='_' + img_type) + + device = cuda.get_current_device() + device.reset() + + def compute_swd(self): + for img_type in self.image_types: + orig_images = [] + mask_images = [] + images = os.listdir(self.model_path) + for img_name in images: + if img_type in img_name: + orig_img = cv2.cvtColor(cv2.imread(os.path.join(self.gt_path, img_name)), cv2.COLOR_BGR2RGB) + mask_img = cv2.cvtColor(cv2.imread(os.path.join(self.model_path, img_name)), cv2.COLOR_BGR2RGB) + orig_images.append(orig_img) + mask_images.append(mask_img) + + self.swd_value[img_type] = compute_swd(np.array(orig_images), np.array(mask_images), self.device) + + def compute_image_similarity_metrics(self): + self.compute_mse_ssim_scores() + print('SSIM Computed') + self.compute_inception_score() + print('inception Computed') + self.compute_fid_score() + print('fid Computed') + self.compute_swd() + print('swd Computed') + + for key in self.mse_avg: + self.all_info[key + '_' + 'MSE_avg'] = self.mse_avg[key] + self.all_info[key + '_' + 'MSE_std'] = self.mse_std[key] + self.all_info[key + '_' + 'ssim_avg'] = self.ssim_avg[key] + self.all_info[key + '_' + 'ssim_std'] = self.ssim_std[key] + self.all_info[key + '_' + 'inception_avg'] = self.inception_avg[key] + self.all_info[key + '_' + 'inception_std'] = self.inception_std[key] + self.all_info[key + '_' + 'fid_value'] = self.fid_value[key] + self.all_info[key + '_' + 'swd_value'] = self.swd_value[key] + + def compute_IHC_scoring(self): + images = os.listdir(self.gt_path) + IHC_info = [] + for img in images: + gt_image = cv2.cvtColor(cv2.imread(os.path.join(self.gt_path, img)), cv2.COLOR_BGR2RGB) + if 'DeepLIIF' in self.model_name: + mask_image = cv2.cvtColor(cv2.imread(os.path.join(self.model_path, img.replace('_Seg', '_Seg_Refined'))), cv2.COLOR_BGR2RGB) + else: + mask_image = cv2.cvtColor(cv2.imread(os.path.join(self.model_path, img)), cv2.COLOR_BGR2RGB) + gt_image[gt_image < 10] = 0 + label_image_red_gt = measure.label(gt_image[:, :, 0], background=0) + label_image_blue_gt = measure.label(gt_image[:, :, 2], background=0) + number_of_positive_cells_gt = (len(np.unique(label_image_red_gt)) - 1) + number_of_negative_cells_gt = (len(np.unique(label_image_blue_gt)) - 1) + number_of_all_cells_gt = number_of_positive_cells_gt + number_of_negative_cells_gt + gt_IHC_score = number_of_positive_cells_gt / number_of_all_cells_gt if number_of_all_cells_gt > 0 else 0 + + mask_image[mask_image < 10] = 0 + label_image_red_mask = measure.label(mask_image[:, :, 0], background=0) + label_image_blue_mask = measure.label(mask_image[:, :, 2], background=0) + number_of_positive_cells_mask = (len(np.unique(label_image_red_mask)) - 1) + number_of_negative_cells_mask = (len(np.unique(label_image_blue_mask)) - 1) + number_of_all_cells_mask = number_of_positive_cells_mask + number_of_negative_cells_mask + mask_IHC_score = 
number_of_positive_cells_mask / number_of_all_cells_mask if number_of_all_cells_mask > 0 else 0 + diff = abs(gt_IHC_score * 100 - mask_IHC_score * 100) + IHC_info.append({'Model': self.model_name, 'Sample': img, 'Diff_IHC_Score': diff}) + self.write_list_to_csv(IHC_info, IHC_info[0].keys(), + filename='IHC_Scoring_info_' + self.mode + '_' + self.model_name + '.csv') + + def compute_segmentation_metrics(self): + # max_dice = [0, 0, 0] + # max_AJI = [0, 0, 0] + # for thresh in range(60, 150, 10): + # for noise_size in range(10, 80, 20): + thresh = 100 + boundary_thresh = 100 + noise_size = 50 + + self.segmentation_info, self.segmentation_metrics = compute_segmentation_metrics(self.gt_path, self.model_path, self.model_name, image_size=512, thresh=thresh, boundary_thresh=boundary_thresh, small_object_size=noise_size, raw_segmentation=self.raw_segmentation, suffix_seg=self.seg_type) + assert len(self.segmentation_info) > 0, 'No segmentation results returned; one typical cause is a wrong --seg-type and/or --image-types that cannot be found in any image filename' + self.write_list_to_csv(self.segmentation_info, self.segmentation_info[0].keys(), + filename='segmentation_info_' + self.mode + '_' + self.model_name + '_' + str(thresh) + '_' + str(noise_size) + '.csv') + for key in self.segmentation_metrics: + self.all_info[key] = self.segmentation_metrics[key] + print(key, self.all_info[key]) + print('-------------------------------------------------------') + + def create_all_info(self): + self.write_dict_to_csv(self.all_info, list(self.all_info.keys()), filename='metrics_' + self.mode + '_' + self.model_name + '.csv') + + def compute_statistics(self): + self.compute_image_similarity_metrics() + self.compute_segmentation_metrics() + self.create_all_info() + + def write_dict_to_csv(self, info_dict, csv_columns, filename='info.csv'): + print('Writing in csv',os.path.join(self.output_path, filename)) + info_csv = open(os.path.join(self.output_path, filename), 'w') + writer = csv.DictWriter(info_csv, fieldnames=csv_columns) + writer.writeheader() + writer.writerow(info_dict) + + def write_list_to_csv(self, info_dict, csv_columns, filename='info.csv'): + print('Writing in csv',os.path.join(self.output_path, filename)) + info_csv = open(os.path.join(self.output_path, filename), 'w') + writer = csv.DictWriter(info_csv, fieldnames=csv_columns) + writer.writeheader() + for data in info_dict: + writer.writerow(data) + + def run(self,write_to_csv=False): + if self.mode == 'All': + self.compute_statistics() + self.compute_IHC_scoring() + elif self.mode == 'Segmentation': + self.compute_segmentation_metrics() + elif self.mode == 'ImageSynthesis': + self.compute_image_similarity_metrics() + elif self.mode == 'SSIM': + self.compute_mse_ssim_scores() + elif self.mode == 'Upscaling': + self.compute_mse_ssim_scores() + self.compute_psnr_scores() + if write_to_csv: + self.create_all_info() + return self.all_info + diff --git a/deepliif/stat/Create_Loss_Diagram.py b/deepliif/stat/Create_Loss_Diagram.py new file mode 100644 index 0000000..586489e --- /dev/null +++ b/deepliif/stat/Create_Loss_Diagram.py @@ -0,0 +1,81 @@ +import collections +import numpy as np +import matplotlib.pyplot as plt + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def read_losses(file_name): + losses = {} + with open(file_name) as f: + lines = f.readlines() + for line in lines: + if line.startswith('(epoch'): + line = line.replace(',', '').replace('(', '').replace(')', '').replace(':', '').strip() + 
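+                # Editor's note: a typical loss_log.txt line (hypothetical values) looks like
+                #   (epoch: 12, iters: 400, time: 0.41, data: 0.02) G_GAN_1: 0.83 G_L1_1: 1.52 D_real_1: 0.65 D_fake_1: 0.70
+                # after stripping '(', ')', ',' and ':' above, it splits into alternating
+                # key/value tokens, which the loop below folds into current_losses.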
values = line.split(' ') + current_losses = {} + epoch_number = -1 + for i in range(len(values)): + if values[i] == 'epoch': + epoch_number = int(values[i + 1]) + else: + if not isfloat(values[i]) and values[i] != 'time' and values[i] != 'iters' and values[i] != 'data': + print(values[i], values[i + 1]) + current_losses[values[i]] = float(values[i + 1]) + + losses[epoch_number] = current_losses + loss_values = collections.defaultdict(list) + for key in losses: + loss_values['epoch'].append(key) + for k in losses[key].keys(): + loss_values[k].append(losses[key][k]) + + return loss_values + + +def create_loss_diagram(file_name): + loss_values = read_losses(file_name) + x = loss_values['epoch'] + plt.figure(figsize=(12, 4)) + for i in range(1, 6): + plt.plot(x, loss_values['G_GAN_' + str(i)], label='G_GAN_' + str(i)) + # plt.legend() + # plt.show() + + # plt.figure() + for i in range(1, 6): + plt.plot(x, loss_values['G_L1_' + str(i)], label='G_L1_' + str(i)) + # plt.legend() + # plt.show() + + # plt.figure() + for i in range(1, 6): + plt.plot(x, loss_values['D_real_' + str(i)], label='D_real_' + str(i)) + # plt.legend() + # plt.show() + + # plt.figure() + for i in range(1, 6): + plt.plot(x, loss_values['D_fake_' + str(i)], label='D_fake_' + str(i)) + # plt.legend() + # plt.gca().legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=5) + plt.legend(ncol=5) + plt.show() + # plt.scatter(x, y) + # plt.plot(x, y) + # plt.title("Connected Scatterplot points with line") + # plt.xlabel("epoch") + # plt.ylabel("G_GAN_1") + # plt.show() + # figure.tight_layout() + # print(loss_values) + + +# create_loss_diagram('D://DeepLIIF//checkpoints//DeepLIIF_Empty_500_Model//loss_log.txt') +# create_loss_diagram('C://Users//localadmin//Desktop//loss_log_SpectralNorm_SYN.txt') +create_loss_diagram('C://Users//localadmin//Desktop//loss_log_DeepLIIF.txt') \ No newline at end of file diff --git a/deepliif/stat/HelperFunctions.py b/deepliif/stat/HelperFunctions.py new file mode 100644 index 0000000..11ed519 --- /dev/null +++ b/deepliif/stat/HelperFunctions.py @@ -0,0 +1,667 @@ +import csv +import json +import math +import os +import h5py +import numpy as np +# import staintools +from matplotlib.colors import LinearSegmentedColormap +from numba import jit +from scipy import ndimage +import cv2 +import random +from skimage.feature import peak_local_max +import skimage.segmentation +from skimage.morphology import watershed as ws, remove_small_objects +import matplotlib.pyplot as plt + +def remove_small_objects_from_image(red_channel, min_size=100): + red_channel_copy = red_channel.copy() + red_channel_copy[red_channel > 0] = 1 + red_channel_copy = red_channel_copy.astype(np.bool) + removed_red_channel = remove_small_objects(red_channel_copy, min_size=min_size).astype(np.uint8) + red_channel[removed_red_channel == 0] = 0 + return red_channel + + +def read_BC_detection_mask(img_name, data_type='test'): + img_type = img_name.split('.')[-1] + base_dir = '/home/parmida/Downloads/BCData' + annotations_dir = os.path.join(base_dir, 'annotations') + images_dir = os.path.join(base_dir, 'images') + negative_dir = os.path.join(annotations_dir, data_type, 'negative') + positive_dir = os.path.join(annotations_dir, data_type, 'positive') + images_dir = os.path.join(images_dir, data_type) + print(os.path.join(negative_dir, img_name.replace('.png', '.h5'))) + gt_file_negative = h5py.File(os.path.join(negative_dir, img_name.replace('.' 
+ img_type, '.h5'))) + coordinates_negative = np.asarray(gt_file_negative['coordinates']) + gt_file_positive = h5py.File(os.path.join(positive_dir, img_name.replace('.' + img_type, '.h5'))) + coordinates_positive = np.asarray(gt_file_positive['coordinates']) + + positive_mask = np.zeros((640, 640), dtype=np.uint8) + negative_mask = np.zeros((640, 640), dtype=np.uint8) + for coord in coordinates_positive: + positive_mask[coord[1], coord[0]] = 255 + + for coord in coordinates_negative: + negative_mask[coord[1], coord[0]] = 255 + + return positive_mask, negative_mask + +def read_BC_detection_point(img_name, data_type='test'): + img_type = img_name.split('.')[-1] + base_dir = '/home/parmida/Downloads/BCData' + annotations_dir = os.path.join(base_dir, 'annotations') + images_dir = os.path.join(base_dir, 'images') + negative_dir = os.path.join(annotations_dir, data_type, 'negative') + positive_dir = os.path.join(annotations_dir, data_type, 'positive') + images_dir = os.path.join(images_dir, data_type) + print(os.path.join(negative_dir, img_name.replace('.png', '.h5'))) + gt_file_negative = h5py.File(os.path.join(negative_dir, img_name.replace('.' + img_type, '.h5'))) + coordinates_negative = np.asarray(gt_file_negative['coordinates']) + gt_file_positive = h5py.File(os.path.join(positive_dir, img_name.replace('.' + img_type, '.h5'))) + coordinates_positive = np.asarray(gt_file_positive['coordinates']) + + return coordinates_positive, coordinates_negative + + + +def compute_TP_FP_of_each_class(image, marked_class): + labeled, nr_objects = ndimage.label(image > 0) + TP = 0 + FP = 0 + for c in range(1, nr_objects): + component = np.zeros_like(image) + component[labeled == c] = image[labeled == c] + component = cv2.morphologyEx(component, cv2.MORPH_DILATE, kernel=np.ones((5, 5)), iterations=1) + TP, FP = compute_component_TP_FP(component, marked_class, TP, FP) + return TP, FP + + +@jit(nopython=True) +def compute_component_TP_FP(component, marked_class, TP, FP): + indices = np.nonzero(component) + cell_flag = False + for i in range(len(indices[0])): + if marked_class[indices[0][i], indices[1][i]] > 0: + TP += 1 + cell_flag = True + if not cell_flag: + FP += 1 + return TP, FP + + +def compute_precision_recall_f1(TP, FP, FN): + precision = TP / (TP + FP) if (TP + FP) > 0 else 1 + recall = TP / (TP + FN) if (TP + FN) > 0 else 1 + F1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + return precision, recall, F1 + + +def mark_Shiraz_image_with_markers(image, immunopositive, immunonegative, immunoTIL): + marked_image = image.copy() + positive = cv2.morphologyEx(immunopositive, cv2.MORPH_DILATE, kernel=np.ones((5,5))) + negative = cv2.morphologyEx(immunonegative, cv2.MORPH_DILATE, kernel=np.ones((5,5))) + TIL = cv2.morphologyEx(immunoTIL, cv2.MORPH_DILATE, kernel=np.ones((5,5))) + marked_image[positive > 0] = (0,0,255) + marked_image[negative > 0] = (255,0,0) + marked_image[TIL > 0] = (0,255,0) + return marked_image + +def read_NuClick_mask(img_name, dir_type='Train'): + # image_dir = '/home/parmida/Pathology/IHC_Nuclick/images/Train' + mask_dir = '/home/parmida/Pathology/IHC_Nuclick/masks/' + dir_type + mask = np.load(os.path.join(mask_dir, img_name.replace('.png', '.npy'))) + labeled_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + final_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + labels_no = np.max(mask) + 1 + color_dict = {} + color_dict[0] = (0, 0, 0) + for i in range(1, labels_no): + color_dict[i] = (random.randint(0, 
255), random.randint(0, 255), random.randint(0, 255)) + for i in range(mask.shape[0]): + for j in range(mask.shape[1]): + labeled_mask[i, j] = color_dict[mask[i, j]] + final_mask[i, j] = (0, 0, 0) + boundaries = cv2.Canny(labeled_mask, 100, 200) + # boundaries = cv2.dilate(boundaries, kernel=np.ones((3, 3), np.uint8)) + labeled_mask_bw = cv2.cvtColor(labeled_mask, cv2.COLOR_RGB2GRAY) + final_mask[labeled_mask_bw > 0] = (0, 0, 255) + # final_mask[labeled_mask_bw == 0] = (0, 0, 255) + # final_mask[boundaries > 0] = (255, 255, 255) + + contours, hierarchy = cv2.findContours(boundaries, + cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cv2.drawContours(final_mask, contours, -1, (255, 255, 255), 2) + # cv2.imshow('labeled_mask', labeled_mask) + # cv2.imshow('boundaries', boundaries) + # cv2.imshow('final', final_mask) + # cv2.waitKey(0) + boundaries[boundaries > 0] = 255 + return final_mask + + +def get_detection_points(seg_img): + seg_img = cv2.resize(seg_img, (640, 640)) + det_img = np.zeros_like(seg_img) + thresh = 50 + det_img[np.logical_and(seg_img[:, :, 2] > thresh, seg_img[:, :, 2] > seg_img[:, :, 0] + 50)] = (0, 0, 255) + det_img[np.logical_and(seg_img[:, :, 0] > thresh, seg_img[:, :, 0] >= seg_img[:, :, 2])] = (255, 0, 0) + det_img[seg_img[:, :, 1] > thresh] = 0 + det_img[:, :, 0] = remove_small_objects_from_image(det_img[:, :, 0], 80) + det_img[:, :, 2] = remove_small_objects_from_image(det_img[:, :, 2], 80) + det_img[:, :, 0] = ndimage.binary_fill_holes(det_img[:, :, 0]).astype(np.uint8) * 255 + det_img[:, :, 2] = ndimage.binary_fill_holes(det_img[:, :, 2]).astype(np.uint8) * 255 + # det_img[:, :, 0] = cv2.morphologyEx(det_img[:, :, 0], cv2.MORPH_ERODE, kernel=np.ones((3, 3)), iterations=2) + # det_img[:, :, 2] = cv2.morphologyEx(det_img[:, :, 2], cv2.MORPH_ERODE, kernel=np.ones((3, 3)), iterations=2) + # cv2.imshow('det_img', det_img) + det_img = np.squeeze(det_img).astype(np.uint8) + cells = watershed(det_img) + final_cells = [] + positive_points = [] + negative_points = [] + seen = np.zeros((seg_img.shape[0], seg_img.shape[1]), dtype=np.uint8) + for i in range(len(cells)): + p1 = cells[i] + x1, y1, c1 = int(p1[1]), int(p1[0]), int(p1[2]) + flag = False + seen[x1][y1] = 1 + for j in range(len(cells)): + p2 = cells[j] + x2, y2, c2 = int(p2[1]), int(p2[0]), int(p2[2]) + if seen[x2][y2] == 0: + if abs(x1 - x2) < 20 and abs(y1 - y2) < 20: + flag = True + # new_cell = int((x1 + x2) / 2), int((y1 + y2) / 2), int((c1 + c2)/2) + # final_cells.append(new_cell) + if not flag: + final_cells.append(p1) + if c1 == 2: + positive_points.append((x1, y1)) + elif c1 == 0: + negative_points.append((x1, y1)) + return final_cells, positive_points, negative_points + +def detect_circles(component, output): + gray_blurred = cv2.blur(component, (3, 3)) + + # Apply Hough transform on the blurred image. + detected_circles = cv2.HoughCircles(gray_blurred, + cv2.HOUGH_GRADIENT, 1, 20, param1=100, + param2=20, minRadius=1, maxRadius=40) + # circles = cv2.HoughCircles(component, cv2.HOUGH_GRADIENT, 1, 10) + # ensure at least some circles were found + if detected_circles is not None: + + # Convert the circle parameters a, b and r to integers. + detected_circles = np.uint16(np.around(detected_circles)) + + for pt in detected_circles[0, :]: + a, b, r = pt[0], pt[1], pt[2] + + # Draw the circumference of the circle. + cv2.circle(output, (a, b), r, (0, 255, 0), 2) + + # Draw a small circle (of radius 1) to show the center. 
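+    # Editor's note on the Hough parameters above: with cv2.HOUGH_GRADIENT,
+    # param1 is the upper threshold of the internal Canny edge detector and
+    # param2 is the accumulator threshold, so lowering param2 (here 20) detects
+    # more, possibly spurious, circles; the fourth argument (minDist=20)
+    # suppresses centers closer than 20 px to each other.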
+ # cv2.circle(output, (a, b), 1, (0, 0, 255), 3) + # cv2.imshow("Detected Circle", output) + # cv2.waitKey(0) + # cv2.imshow('component', component) + +def watershed(pred): + cells=[] + for ch in range(3): + gray=pred[:,:,ch] + D = ndimage.distance_transform_edt(gray) + localMax = peak_local_max(D, indices=False, min_distance=10,exclude_border=False,labels=gray) + markers = ndimage.label(localMax, structure=np.ones((3, 3)))[0] + labels = ws(-D, markers, mask=gray) + for label in np.unique(labels): + if label == 0: + continue + mask = np.zeros(gray.shape, dtype="uint8") + mask[labels == label] = 255 + cnts = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)[-2] + c = max(cnts, key=cv2.contourArea) + ((x, y), _) = cv2.minEnclosingCircle(c) + cells.append([x,y,ch]) + return np.array(cells) + +def read_PathoNet_data(img_addr): + print(img_addr) + points = np.loadtxt(img_addr.replace('.jpg', '_points.txt')) + image = cv2.imread(img_addr) + # positive_mask = np.zeros((640, 640), dtype=np.uint8) + # negative_mask = np.zeros((640, 640), dtype=np.uint8) + positive_points = [] + negative_points = [] + for p in points: + if int(p[2]) == 1: + image[int(p[1]), int(p[0])] = (255, 0, 255) + negative_points.append((int(p[1]), int(p[0]))) + else: + image[int(p[1]), int(p[0])] = (0, 255, 255) + positive_points.append((int(p[1]), int(p[0]))) + # cv2.imshow('image', image) + # cv2.waitKey(0) + return positive_points, negative_points + + +def crop_modalities(input_dir, img_name, img_types, location, size, output_dir): + for img_type in img_types: + image = cv2.imread(os.path.join(input_dir, img_name + img_type + '.png')) + crop = image[location[0]:location[0] + size[0], location[1]: location[1] + size[1]] + cv2.imwrite(os.path.join(output_dir, 'MYC_' + img_type + '.png'), crop) + + +def read_mask_rcnn_segmentation_masks(input_dir, image_size): + images = os.listdir(input_dir) + masks = {} + for img in images: + if '.png' in img and len(img.split('_')) > 5: + print(img) + splitted = img.split('_') + image_name = '' + for i in range(0, len(splitted) - 3): + image_name += splitted[i] + '_' + image_name += splitted[-3] + cell_type = 'positive' if splitted[-2] == 1 else 'negative' + image = cv2.imread(os.path.join(input_dir, img)) + image = cv2.resize(image, (image_size, image_size)) + image_bw = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) + image_bw[image[:,:,0] > 250] = 1 + image_bw[image[:,:,1] > 250] = 1 + image_bw[image[:,:,2] > 250] = 1 + if image_name not in masks.keys(): + masks[image_name] = {'positive': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8), 'negative': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8), 'binary': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)} + masks[image_name][cell_type][image_bw > 0] = 1 + masks[image_name]['binary'][image_bw > 0] = 1 + return masks + + +def read_mask_rcnn_detection_masks(input_dir, image_size): + images = os.listdir(input_dir) + masks = {} + for img in images: + if '_' in img and '.png' in img: + splitted = img.split('_') + image_name = '' + for i in range(0, len(splitted) - 3): + image_name += splitted[i] + '_' + image_name += splitted[-3] + cell_type = 'positive' if splitted[-2] == '1' else 'negative' + image = cv2.imread(os.path.join(input_dir, img)) + image = cv2.resize(image, (image_size, image_size)) + image_bw = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) + image_bw[image[:,:,0] > 250] = 1 + image_bw[image[:,:,1] > 250] = 1 + image_bw[image[:,:,2] > 250] = 1 + 
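+                # Editor's note: each instance's detection point is taken as the center
+                # of its bounding box, computed below from the min/max of the nonzero
+                # pixel coordinates of the binarized mask.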
points = np.nonzero(image_bw) + x = points[0] + y = points[1] + bounding_box = [np.min(x), np.min(y), np.max(x), np.max(y)] + center = (int((bounding_box[0] + bounding_box[2]) / 2), int((bounding_box[1] + bounding_box[3]) / 2)) + if image_name not in masks.keys(): + masks[image_name] = {'positive': [], 'negative': [], 'binary': []} + masks[image_name][cell_type].append(center) + masks[image_name]['binary'].append(center) + return masks + + +def read_unetplusplus_unet(input_npy, npy_path): + # results = np.load(input_dir) + imgs = np.load(input_npy) + with open(npy_path + 'dict_name.txt', 'r') as f: + names = json.load(f) + res_path = os.path.dirname(input_npy) + for i in range(imgs.shape[0]): + img = imgs[i] + print(img.shape) + plt.imsave(os.path.join(res_path, names[str(i)] + '.png'), (img * 255).astype(np.uint8)) + + +def read_Unet_plusplus_detection_masks(input_dir, image_size): + images = os.listdir(input_dir) + masks = {} + for img in images: + if '.png' in img: + image = cv2.imread(os.path.join(input_dir, img)) + image = cv2.resize(image, (image_size, image_size)) + # cv2.imshow('image1', image) + img = img.replace('.png', '') + new_image = np.zeros_like(image) + new_image[image[:,:,0] > 150] = (255, 0, 0) + new_image[image[:,:,2] > 150] = (0, 0, 255) + # cv2.imshow('image2', new_image) + det_img = np.squeeze(new_image).astype(np.uint8) + cells = watershed(det_img) + final_cells = [] + positive_points = [] + negative_points = [] + seen = np.zeros((new_image.shape[0], new_image.shape[1]), dtype=np.uint8) + for i in range(len(cells)): + p1 = cells[i] + x1, y1, c1 = int(p1[1]), int(p1[0]), int(p1[2]) + flag = False + seen[x1][y1] = 1 + for j in range(len(cells)): + p2 = cells[j] + x2, y2, c2 = int(p2[1]), int(p2[0]), int(p2[2]) + if seen[x2][y2] == 0: + if abs(x1 - x2) < 20 and abs(y1 - y2) < 20: + flag = True + # new_cell = int((x1 + x2) / 2), int((y1 + y2) / 2), int((c1 + c2)/2) + # final_cells.append(new_cell) + if not flag: + final_cells.append(p1) + if c1 == 2: + positive_points.append((x1, y1)) + elif c1 == 0: + negative_points.append((x1, y1)) + + # for p in positive_points: + # new_image[p[0]-5:p[0]+5, p[1]-5:p[1]+5] = (255,0,255) + # for p in negative_points: + # new_image[p[0]-5:p[0]+5, p[1]-5:p[1]+5] = (0,255,255) + + # cv2.imshow('image3', new_image) + # cv2.waitKey(0) + masks[img] = {'positive': [], 'negative': []} + masks[img]['positive'] = positive_points + masks[img]['negative'] = negative_points + # masks[img]['binary'].append(center) + return masks + + +def read_Unet_plusplus_segmentation_masks(input_dir, image_size): + images = os.listdir(input_dir) + masks = {} + for img in images: + if '.png' in img: + image = cv2.imread(os.path.join(input_dir, img)) + image = cv2.resize(image, (image_size, image_size)) + img = img.replace('.png', '') + new_image = np.zeros_like(image) + new_image[image[:,:,0] > 0] = (255, 0, 0) + new_image[image[:,:,2] > 0] = (0, 0, 255) + masks[img] = {'positive': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8), 'negative': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8), 'binary': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)} + masks[img]['positive'][image[:,:,0] > 0] = 1 + masks[img]['negative'][image[:,:,2] > 0] = 1 + masks[img]['binary'][image[:,:,0] > 0] = 1 + masks[img]['binary'][image[:,:,2] > 0] = 1 + return masks + +def read_DeepLIIF_segmentation_masks(input_dir, image_size, thresh=100): + images = os.listdir(input_dir) + masks = {} + for img in images: + if '_fake_B_5.png' in img: + image = 
cv2.imread(os.path.join(input_dir, img)) + image = cv2.resize(image, (image_size, image_size)) + img = img.replace('_fake_B_5.png', '') + new_image = np.zeros_like(image) + new_image[np.logical_and(image[:,:,0] > thresh, image[:,:,0] > image[:,:,2])] = (255, 0, 0) + new_image[np.logical_and(image[:,:,2] > thresh, image[:,:,2] >= image[:,:,0])] = (0, 0, 255) + # new_image[image[:,:,1] > thresh] = 0 + # new_image[image[:,:,0] > 100] = (255, 0, 0) + # new_image[image[:,:,2] > 100] = (0, 0, 255) + # cv2.imshow('image', image) + # cv2.imshow('new_image', new_image) + # cv2.waitKey(0) + + positive_mask = new_image[:, :, 0] + negative_mask = new_image[:, :, 2] + positive_mask = ndimage.binary_fill_holes(positive_mask, structure=np.ones((5, 5))).astype(np.uint8) + negative_mask = ndimage.binary_fill_holes(negative_mask, structure=np.ones((5, 5))).astype(np.uint8) + positive_mask = remove_small_objects_from_image(positive_mask, 50) + negative_mask = remove_small_objects_from_image(negative_mask, 50) + positive_mask[negative_mask > 0] = 0 + masks[img] = {'positive': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8), 'negative': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8), 'binary': np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)} + # masks[img]['positive'][image[:,:,0] > thresh] = 1 + # masks[img]['negative'][image[:,:,2] > thresh] = 1 + # masks[img]['binary'][image[:,:,0] > thresh] = 1 + # masks[img]['binary'][image[:,:,2] > thresh] = 1 + masks[img]['positive'][positive_mask > 0] = 1 + masks[img]['negative'][negative_mask > 0] = 1 + masks[img]['binary'][positive_mask > 0] = 1 + masks[img]['binary'][negative_mask > 0] = 1 + return masks + + +def read_ki67_detection_points(img): + input_dir = '/media/parmida/Work/DetectionDataset/test_Ki67' + image_size = 512 + image = cv2.imread(os.path.join(input_dir, img)) + image = image[:,5*512:] + image = cv2.resize(image, (image_size, image_size)) + cv2.imshow('mask', image) + # cv2.waitKey(0) + positive_image = np.zeros((image.shape[0], image.shape[1])) + positive_image[image[:, :, 2] > 0] = 1 + positive_mask = get_centers_of_objects(positive_image) + + negative_image = np.zeros((image.shape[0], image.shape[1])) + negative_image[image[:, :, 0] > 0] = 1 + negative_mask = get_centers_of_objects(negative_image) + + return positive_mask, negative_mask + + +def get_centers_of_objects(image): + mask = np.zeros((image.shape[0], image.shape[1])) + labeled, nr_objects = ndimage.label(image > 0) + for c in range(1, nr_objects): + component = np.zeros_like(image) + component[labeled == c] = image[labeled == c] + + points = np.nonzero(component) + x = points[0] + y = points[1] + bounding_box = [np.min(x), np.min(y), np.max(x), np.max(y)] + center = (int((bounding_box[0] + bounding_box[2]) / 2), int((bounding_box[1] + bounding_box[3]) / 2)) + mask[center[0], center[1]] = 255 + return mask + + +def read_Unet_plusplus_boundary_mask_image(img_addr, image_size): + print(img_addr) + image = cv2.cvtColor(cv2.imread(img_addr), cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (image_size, image_size)) + new_image = np.zeros_like(image) + res_image = np.zeros_like(image) + new_image[image[:,:,0] > 100] = (255, 0, 0) + new_image[image[:,:,2] > 100] = (0, 0, 255) + positive_mask = new_image[:,:,0] + negative_mask = new_image[:,:,2] + contours, hierarchy = cv2.findContours(positive_mask, + cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cv2.drawContours(res_image, contours, -1, (255, 0, 0), 2) + contours, hierarchy = cv2.findContours(negative_mask, + 
cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cv2.drawContours(res_image, contours, -1, (0, 0, 255), 2) + cv2.imwrite(img_addr.replace('.png', '_Seg.png'), cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)) + return res_image, new_image + +def read_DeepLIIF_boundary_mask_image(img_addr, image_size, thresh=120): + image = cv2.cvtColor(cv2.imread(img_addr), cv2.COLOR_BGR2RGB) + # cv2.imshow('image', image) + image = cv2.resize(image, (image_size, image_size)) + new_image = np.zeros_like(image) + + new_image[np.logical_and(image[:, :, 0] > thresh, image[:, :, 0] > image[:, :, 2])] = (255, 0, 0) + new_image[np.logical_and(image[:, :, 2] > thresh, image[:, :, 2] >= image[:, :, 0])] = (0, 0, 255) + + new_image[image[:,:,1] > 80] = 0 + # new_image[image[:,:,0] > 100] = (255, 0, 0) + # new_image[image[:,:,2] > 100] = (0, 0, 255) + # cv2.imshow('image', image) + # cv2.imshow('new_image', new_image) + # cv2.waitKey(0) + + positive_mask = new_image[:, :, 0] + negative_mask = new_image[:, :, 2] + positive_mask = ndimage.binary_fill_holes(positive_mask, structure=np.ones((5, 5))).astype(np.uint8) + negative_mask = ndimage.binary_fill_holes(negative_mask, structure=np.ones((5, 5))).astype(np.uint8) + positive_mask = remove_small_objects_from_image(positive_mask, 50) + negative_mask = remove_small_objects_from_image(negative_mask, 50) + # positive_mask[negative_mask > 0] = 0 + negative_mask[positive_mask > 0] = 0 + + # new_image[np.logical_and(image[:,:,0] > thresh, image[:,:,0] > image[:,:,2])] = (255, 0, 0) + # new_image[np.logical_and(image[:,:,2] > thresh, image[:,:,2] >= image[:,:,0])] = (0, 0, 255) + # new_image[image[:,:,1] > thresh] = 0 + # + res_image = np.zeros_like(image) + # positive_mask = new_image[:,:,0] + # negative_mask = new_image[:,:,2] + # # positive_mask = cv2.morphologyEx(positive_mask, cv2.MORPH_DILATE, kernel=np.ones((3,3))) + # # negative_mask = cv2.morphologyEx(negative_mask, cv2.MORPH_DILATE, kernel=np.ones((3,3))) + # positive_mask = ndimage.binary_fill_holes(positive_mask, structure=np.ones((5,5))).astype(np.uint8) + # negative_mask = ndimage.binary_fill_holes(negative_mask, structure=np.ones((5,5))).astype(np.uint8) + # positive_mask = remove_small_objects_from_image(positive_mask, 50) + # negative_mask = remove_small_objects_from_image(negative_mask, 50) + # # positive_mask[negative_mask > 0] = 0 + # negative_mask[positive_mask > 0] = 0 + # negative_mask = mask.copy() + new_image = np.zeros_like(image) + new_image[positive_mask > 0] = (255,0,0) + new_image[negative_mask > 0] = (0,0,255) + # cv2.imshow('positive_mask', positive_mask*255) + # cv2.imshow('negative_mask', negative_mask*255) + # cv2.waitKey(0) + + contours, hierarchy = cv2.findContours(positive_mask, + cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cv2.drawContours(res_image, contours, -1, (255, 0, 0), 2) + contours, hierarchy = cv2.findContours(negative_mask, + cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cv2.drawContours(res_image, contours, -1, (0, 0, 255), 2) + return res_image, new_image + + + +def create_log_area_mask_cell_type(predicted_mask, gt_mask, index=0, colormap='bwr'): + smooth = 0.0001 + predicted = predicted_mask[:,:,index] + gt = gt_mask[:,:,index] + # gt[gt_mask[:,:,1] > 0]=0 + final_mask = np.zeros((predicted.shape[0], predicted.shape[1])) + + labeled, nr_objects = ndimage.label(predicted > 0) + labeled_gt, nr_objects_gt = ndimage.label(gt > 0) + for c in range(1, nr_objects): + component = np.zeros_like(predicted) + component[labeled == c] = predicted[labeled == c] + nonzeros = np.nonzero(component) + 
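+        # Editor's note: for each predicted component, the block below locates the
+        # overlapping ground-truth component and scores the area ratio as
+        #   value = log2((|pred| + smooth) / (|gt| + smooth)), clipped to [-2, 2],
+        # with the sentinel value 5 meaning no overlapping ground truth was found.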
component_gt_size = 0 + # cv2.imshow('component', component) + for i in range(len(nonzeros[0])): + if gt[nonzeros[0][i], nonzeros[1][i]] > 0 and labeled_gt[nonzeros[0][i], nonzeros[1][i]] > 0: + label = labeled_gt[nonzeros[0][i], nonzeros[1][i]] + component_gt = np.zeros_like(gt) + component_gt[labeled_gt == label] = gt[labeled_gt == label] + nonzeros_gt = np.nonzero(component_gt) + component_gt_size = len(nonzeros_gt[0]) + # print(component_gt_size) + break + # if component_gt_size > 0: + component_size = len(nonzeros[0]) + # print('component_gt_size:', component_gt_size) + # print('component_size:', component_size) + if component_gt_size == 0: + value = 5 + # print('yes') + else: + value = np.log2((component_size+smooth)/(component_gt_size+smooth)) + value = min(value, 2) if value >= 0 else max(value, -2) + # print((len(nonzeros[0])+smooth)/(component_gt_size+smooth)) + # if value < 0: + # print(value) + # print(min(value, 2) if value >= 0 else max(value, -2)) + + final_mask[component > 0] = value + # cv2.waitKey(0) + # final_mask[predicted == 0] = 0 + # color_range = np.max(final_mask) + image_log = np.zeros_like(gt_mask) + for i in range(0, final_mask.shape[0]): + for j in range(0, final_mask.shape[1]): + value = final_mask[i][j] + if value == 5: + # print('yes!!!!!!!') + image_log[i, j] = (255, 255, 0) + elif -0.5 <= value <= 0.5: + image_log[i, j] = (255, 0, 0) if colormap == 'positive' else (0, 0, 255) + elif value > 0.5: + image_log[i, j] = (int(127.5/value), 0, 0) if colormap == 'positive' else (0, 0, int(127.5/value)) + elif value < -0.5: + image_log[i, j] = (255, 255 - int(127.5/abs(value)), 255 - int(127.5/abs(value))) if colormap == 'positive' else (255 - int(127.5/abs(value)), 255 - int(127.5/abs(value)), 255) + # print((255, 255 - int(127.5/abs(value)), 255 - int(127.5/abs(value)))) + # plt.tight_layout() + # plt.savefig('temp.png', bbox_inches='tight', pad_inches=0) + # + # image_log = plt.imread('temp.png') + # image_log = cv2.resize(image_log, (512, 512)) + # image_log[predicted_mask[:, :, index] == 0] = 0 + return image_log + + +def create_log_area_mask(predicted_mask, gt_mask): + log_positive = create_log_area_mask_cell_type(predicted_mask, gt_mask, index=0, colormap='positive') + log_negative = create_log_area_mask_cell_type(predicted_mask, gt_mask, index=2, colormap='negative') + log_positive = (log_positive[:,:,:3] * 255).astype(np.uint8) + log_negative = (log_negative[:,:,:3] * 255).astype(np.uint8) + log_image = np.zeros_like(predicted_mask) + log_image[predicted_mask[:,:,0] > 0] = log_positive[predicted_mask[:,:,0] > 0] + log_image[predicted_mask[:,:,2] > 0] = log_negative[predicted_mask[:,:,2] > 0] + return log_image + + + +def create_color_map_image(colormap): + image = np.zeros((400, 100)) + for i in range(image.shape[1]): + for j in range(image.shape[0]): + image[j, i] = (j - 200) / 100 + colormap_image = np.zeros((400, 100, 3), dtype=np.uint8) + for i in range(0, image.shape[0]): + for j in range(0, image.shape[1]): + value = image[i][j] + if -0.5 <= value <= 0.5: + colormap_image[i, j] = (255, 0, 0) if colormap == 'positive' else (0, 0, 255) + elif value > 0.5: + colormap_image[i, j] = (int(127.5/value), 0, 0) if colormap == 'positive' else (0, 0, int(127.5/value)) + elif value < -0.5: + colormap_image[i, j] = (255, 255 - int(127.5/abs(value)), 255 - int(127.5/abs(value))) if colormap == 'positive' else (255 - int(127.5/abs(value)), 255 - int(127.5/abs(value)), 255) + colormap_image = cv2.rotate(colormap_image, cv2.ROTATE_90_COUNTERCLOCKWISE) + return 
colormap_image + + +def read_image_write_crop_parts(input_dir, img_name, location, crop_size, output_dir): + types = ['', '_DAPI', '_DAPILap2', '_Hema', '_Ki67', '_Seg_Aligned_Bound'] + for image_type in types: + image = cv2.imread(os.path.join(input_dir, img_name + image_type + '.png')) + crop = image[location[0]: location[0] + crop_size[0], location[1]: location[1] + crop_size[1]] + cv2.imwrite(os.path.join(output_dir, img_name + image_type + '.png'), crop) + + +def overlay_ki67_on_DAPI(input_DAPI, input_ki67): + # overlaid_image = np.zeros((input_DAPI.shape[0], input_DAPI.shape[1], 3), dtype=np.uint8) + # overlaid_image[:,:,2] = input_ki67 + # overlaid_image[:,:,0] = input_DAPI + # overlaid_image[:,:,1] = np.floor((input_DAPI + input_ki67) / 2) + # overlaid_image[:,:,1] = input_DAPI + + # overlaid_image = cv2.addWeighted(input_DAPI, 0.6, input_Lap2, 0.4, 1) + # overlaid_image = cv2.addWeighted(input_DAPI, 0.9, input_ki67, 0.1, 1) + overlaid_image = input_DAPI.copy() + overlaid_image[input_ki67[:,:,2] >= 30] = input_ki67[input_ki67[:,:,2] >= 30] + # overlaid_image[input_ki67[:,:,2] < 30] = input_ki67[input_ki67[:,:,2] < 30] * 0.2 + overlaid_image[input_ki67[:,:,2]] * 0.8 + return overlaid_image + + +def count_cell_number(image, channel=1, thresh=0): + mask = image[:, :, channel] + labeled, nr_objects = ndimage.label(mask > thresh) + return nr_objects + diff --git a/deepliif/stat/PostProcessSegmentationMask.py b/deepliif/stat/PostProcessSegmentationMask.py new file mode 100644 index 0000000..07e1681 --- /dev/null +++ b/deepliif/stat/PostProcessSegmentationMask.py @@ -0,0 +1,171 @@ +import os.path +import cv2 +import numpy as np +from scipy import ndimage +from numba import jit +from skimage import measure, feature + + +def get_average_cell_size(image): + label_image = measure.label(image, background=0) + labels = np.unique(label_image) + average_cell_size = 0 + for _label in range(1, len(labels)): + indices = np.where(label_image == _label) + pixel_count = np.count_nonzero(image[indices]) + average_cell_size += pixel_count + average_cell_size /= len(labels) + return average_cell_size + + +@jit(nopython=True) +def get_average_cell_size_gpu(label_image, image_size, labels_no): + average_cell_size = 0 + for _label in labels_no: + if _label == 0: + continue + pixel_count = 0 + for index_x in range(image_size[0]): + for index_y in range(image_size[1]): + if label_image[index_x, index_y] == _label: + pixel_count += 1 + average_cell_size += pixel_count + average_cell_size /= len(labels_no) + return average_cell_size + + +@jit(nopython=True) +def compute_cell_mapping(new_mapping, image_size, small_object_size=20): + marked = [[False for _ in range(image_size[1])] for _ in range(image_size[0])] + for i in range(image_size[0]): + for j in range(image_size[1]): + if marked[i][j] is False and (new_mapping[i, j, 0] > 0 or new_mapping[i, j, 2] > 0): + cluster_red_no, cluster_blue_no = 0, 0 + pixels = [(i, j)] + cluster = [(i, j)] + marked[i][j] = True + while len(pixels) > 0: + pixel = pixels.pop() + if new_mapping[pixel[0], pixel[1], 0] > 0: + cluster_red_no += 1 + if new_mapping[pixel[0], pixel[1], 2] > 0: + cluster_blue_no += 1 + for neigh_i in range(-1, 2): + for neigh_j in range(-1, 2): + neigh_pixel = (pixel[0] + neigh_i, pixel[1] + neigh_j) + if 0 <= neigh_pixel[0] < image_size[0] and 0 <= neigh_pixel[1] < image_size[1] and marked[neigh_pixel[0]][neigh_pixel[1]] is False and (new_mapping[neigh_pixel[0], neigh_pixel[1], 0] > 0 or new_mapping[neigh_pixel[0], neigh_pixel[1], 2] > 0): + 
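+                        # Editor's note: this is an iterative 8-connected flood fill; every
+                        # unmarked red/blue neighbor joins the current cluster, which is later
+                        # recolored to its majority channel or erased if it has fewer than
+                        # small_object_size pixels.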
cluster.append(neigh_pixel) + pixels.append(neigh_pixel) + marked[neigh_pixel[0]][neigh_pixel[1]] = True + cluster_value = None + if cluster_red_no < cluster_blue_no: + cluster_value = (0, 0, 255) + else: + cluster_value = (255, 0, 0) + if len(cluster) < small_object_size: + cluster_value = (0, 0, 0) + if cluster_value is not None: + for node in cluster: + new_mapping[node[0], node[1]] = cluster_value + return new_mapping + + +@jit(nopython=True) +def remove_noises(channel, image_size, small_object_size=20): + marked = [[False for _ in range(image_size[1])] for _ in range(image_size[0])] + for i in range(image_size[0]): + for j in range(image_size[1]): + if marked[i][j] is False and channel[i, j] > 0: + pixels = [(i, j)] + cluster = [(i, j)] + marked[i][j] = True + while len(pixels) > 0: + pixel = pixels.pop() + for neigh_i in range(-1, 2): + for neigh_j in range(-1, 2): + neigh_pixel = (pixel[0] + neigh_i, pixel[1] + neigh_j) + if 0 <= neigh_pixel[0] < image_size[0] and 0 <= neigh_pixel[1] < image_size[1] and marked[neigh_pixel[0]][neigh_pixel[1]] is False and channel[neigh_pixel[0], neigh_pixel[1]] > 0: + cluster.append(neigh_pixel) + pixels.append(neigh_pixel) + marked[neigh_pixel[0]][neigh_pixel[1]] = True + + cluster_value = None + if len(cluster) < small_object_size: + cluster_value = 0 + if cluster_value is not None: + for node in cluster: + channel[node[0], node[1]] = cluster_value + return channel + + +def remove_noises_fill_empty_holes(label_img, size=200): + inverse_img = 255 - label_img + inverse_img_removed = remove_noises(inverse_img, inverse_img.shape, small_object_size=size) + label_img[inverse_img_removed == 0] = 255 + return label_img + + +def positive_negative_masks(mask, thresh=100, boundary_thresh=100, noise_objects_size=50): + positive_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8) + negative_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8) + + red = mask[:, :, 0] + blue = mask[:, :, 2] + boundary = mask[:, :, 1] + + # Filtering boundary pixels + boundary[boundary < boundary_thresh] = 0 + + positive_mask[red > thresh] = 255 + positive_mask[boundary > 0] = 0 + positive_mask[blue > red] = 0 + + negative_mask[blue > thresh] = 255 + negative_mask[boundary > 0] = 0 + negative_mask[red >= blue] = 0 + + cell_mapping = np.zeros_like(mask) + cell_mapping[:, :, 0] = positive_mask + cell_mapping[:, :, 2] = negative_mask + + compute_cell_mapping(cell_mapping, mask.shape, small_object_size=noise_objects_size) + cell_mapping[cell_mapping > 0] = 255 + + positive_mask = cell_mapping[:, :, 0] + negative_mask = cell_mapping[:, :, 2] + + # return remove_noises_fill_empty_holes(positive_mask, noise_objects_size), remove_noises_fill_empty_holes(negative_mask, noise_objects_size) + return positive_mask, negative_mask + + +def create_final_segmentation_mask_with_boundaries(positive_mask, negative_mask): + refined_mask = np.zeros((positive_mask.shape[0], positive_mask.shape[1], 3), dtype=np.uint8) + + refined_mask[positive_mask > 0] = (255, 0, 0) + refined_mask[negative_mask > 0] = (0, 0, 255) + + edges = feature.canny(refined_mask[:,:,0], sigma=3).astype(np.uint8) + contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cv2.drawContours(refined_mask, contours, -1, (0, 255, 0), 2) + + edges = feature.canny(refined_mask[:,:,2], sigma=3).astype(np.uint8) + contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + cv2.drawContours(refined_mask, contours, -1, (0, 255, 0), 2) + + return refined_mask + + +def 
count_number_of_cells(input_dir): + images = os.listdir(input_dir) + total_red = 0 + total_blue = 0 + for img in images: + image = cv2.cvtColor(cv2.imread(os.path.join(input_dir, img)), cv2.COLOR_BGR2RGB) + image = image[:,5*512:] + red = image[:,:,0] + blue = image[:,:,2] + labeled_red, nr_objects_red = ndimage.label(red > 0) + labeled_blue, nr_objects_blue = ndimage.label(blue > 0) + total_red += nr_objects_red + total_blue += nr_objects_blue + return total_red, total_blue diff --git a/deepliif/stat/Segmentation_Metrics.py b/deepliif/stat/Segmentation_Metrics.py new file mode 100644 index 0000000..afb2aac --- /dev/null +++ b/deepliif/stat/Segmentation_Metrics.py @@ -0,0 +1,237 @@ +import collections + +import numpy as np +import cv2 +import os +from numba import jit +from skimage import measure +import time +from .PostProcessSegmentationMask import positive_negative_masks + + +@jit(nopython=True) +def compute_metrics_gpu(mask_img, gt_img, image_size): + TP, FP, FN, TN = 0, 0, 0, 0 + for i in range(image_size[0]): + for j in range(image_size[1]): + if mask_img[i, j] > 0 and gt_img[i, j] > 0: + TP += 1 + elif mask_img[i, j] > 0 and gt_img[i, j] == 0: + FP += 1 + elif mask_img[i, j] == 0 and gt_img[i, j] > 0: + FN += 1 + elif mask_img[i, j] == 0 and gt_img[i, j] == 0: + TN += 1 + + smooth = 0.0001 + if TP == 0: + if np.count_nonzero(gt_img) > 0 or FP > 0: + IOU, precision, recall, Dice, f1, pixAcc = 0, 0, 0, 0, 0, 0 + else: + IOU, precision, recall, Dice, f1, pixAcc = 1, 1, 1, 1, 1, 1 + else: + IOU = (TP) / (TP + FP + FN) + precision = (TP) / (TP + FP) + recall = (TP) / (TP + FN) + f1 = 2 * precision * recall / (precision + recall) + Dice = (2 * TP) / (2 * TP + FP + FN) + pixAcc = (TP + TN) / (TP + TN + FP + FN) + return IOU, precision, recall, f1, Dice, pixAcc + + +def compute_metrics(mask_img, gt_img): + smooth = 0.0001 + intesection_TP = np.logical_and(gt_img, mask_img) + intesection_FN = np.logical_and(gt_img, 1 - mask_img) + intesection_FP = np.logical_and(1 - gt_img, mask_img) + intesection_TN = np.logical_and(1 - gt_img, 1 - mask_img) + union = np.logical_or(gt_img, mask_img) + + iou_score = (np.sum(intesection_TP) + smooth) / (np.sum(union) + smooth) + precision_score = (np.sum(intesection_TP) + smooth) / (np.sum(intesection_TP) + np.sum(intesection_FP) + smooth) + recall_score = (np.sum(intesection_TP) + smooth) / (np.sum(intesection_TP) + np.sum(intesection_FN) + smooth) + f1_score = 2 * (precision_score * recall_score) / (precision_score + recall_score) + dice_score = (2 * np.sum(intesection_TP) + smooth) / (2 * np.sum(intesection_TP) + np.sum(intesection_FN) + np.sum(intesection_FP) + smooth) + pix_acc_score = (np.sum(intesection_TP) + np.sum(intesection_TN) + smooth) / (np.sum(intesection_TP) + np.sum(intesection_TN) + np.sum(intesection_FN) + np.sum(intesection_FP) + smooth) + return iou_score, precision_score, recall_score, f1_score, dice_score, pix_acc_score + + +def compute_jaccard_index(set_1, set_2): + n = len(set_1.intersection(set_2)) + return n / float(len(set_1) + len(set_2) - n) + + +def compute_aji(gt_image, mask_image): + label_image_gt = measure.label(gt_image, background=0) + label_image_mask = measure.label(mask_image, background=0) + gt_labels = np.unique(label_image_gt) + mask_labels = np.unique(label_image_mask) + mask_components = [] + mask_marked = [] + for mask_label in mask_labels: + if mask_label == 0: + continue + comp = np.zeros((gt_image.shape[0], gt_image.shape[1]), dtype=np.uint8) + comp[label_image_mask == mask_label] = 1 + 
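+        # Editor's note: compute_aji implements the Aggregated Jaccard Index; each
+        # ground-truth component is greedily matched to the not-yet-used predicted
+        # component with the largest pixel overlap, and the final score below is
+        #   AJI = sum(intersections) / (sum(matched unions) + area of unmatched predictions).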
mask_components.append(comp) + mask_marked.append(False) + + total_intersection = 0 + total_union = 0 + total_U = 0 + for gt_label in gt_labels: + if gt_label == 0: + continue + comp = np.zeros((gt_image.shape[0], gt_image.shape[1]), dtype=np.uint8) + comp[label_image_gt == gt_label] = 1 + intersection = [0, 0, 0] # index, intersection, union + for i in range(len(mask_components)): + if not mask_marked[i]: + comp_intersection = np.sum(np.logical_and(comp, mask_components[i])) + if comp_intersection > intersection[1]: + union = np.sum(np.logical_or(comp, mask_components[i])) + intersection = [i, comp_intersection, union] + if intersection[1] > 0: + mask_marked[intersection[0]] = True + total_intersection += intersection[1] + total_union += intersection[2] + for i in range(len(mask_marked)): + if not mask_marked[i]: + total_U += np.sum(mask_components[i]) + aji = total_intersection / (total_union + total_U) if (total_union + total_U) > 0 else 0 + return aji + + +def compute_segmentation_metrics(gt_dir, model_dir, model_name, image_size=512, thresh=100, boundary_thresh=100, small_object_size=20, raw_segmentation=True, suffix_seg=None): + info_dict = [] + metrics = collections.defaultdict(float) + images = os.listdir(model_dir) + + counter = 0 + if suffix_seg is not None: + postfix = f'_{suffix_seg}.png' + else: + postfix = '_Seg.png' if raw_segmentation else '_SegRefined.png' + for mask_name in images: + if postfix in mask_name: + counter += 1 + + mask_image = cv2.cvtColor(cv2.imread(os.path.join(model_dir, mask_name)), cv2.COLOR_BGR2RGB) + mask_image = cv2.resize(mask_image, (image_size, image_size)) + if not raw_segmentation: + positive_mask = mask_image[:, :, 0] + negative_mask = mask_image[:, :, 2] + else: + positive_mask, negative_mask = positive_negative_masks(mask_image, thresh, boundary_thresh, small_object_size) + + positive_mask[positive_mask > 0] = 1 + negative_mask[negative_mask > 0] = 1 + + gt_img = cv2.cvtColor(cv2.imread(os.path.join(gt_dir, mask_name)), cv2.COLOR_BGR2RGB) + gt_img = cv2.resize(gt_img, (image_size, image_size)) + + positive_gt = gt_img[:, :, 0] + negative_gt = gt_img[:, :, 2] + + positive_gt[positive_gt > 0] = 1 + negative_gt[negative_gt > 0] = 1 + + # AJI_positive = compute_aji(positive_gt, positive_mask) + # start = time.time() + IOU_positive, precision_positive, recall_positive, f1_positive, Dice_positive, pixAcc_positive = compute_metrics_gpu(positive_mask, positive_gt, gt_img.shape) + # end = time.time() + # print(end - start) + # print('GPU: ', IOU_positive, precision_positive, recall_positive, f1_positive, Dice_positive, pixAcc_positive) + # start = time.time() + # IOU_positive, precision_positive, recall_positive, f1_positive, Dice_positive, pixAcc_positive = compute_metrics(positive_mask, positive_gt) + # end = time.time() + # print(end - start) + # print('CPU: ', end - start, IOU_positive, precision_positive, recall_positive, f1_positive, Dice_positive, pixAcc_positive) + + # AJI_negative = compute_aji(negative_gt, negative_mask) + # start = time.time() + IOU_negative, precision_negative, recall_negative, f1_negative, Dice_negative, pixAcc_negative = compute_metrics_gpu(negative_mask, negative_gt, gt_img.shape) + # end = time.time() + # print('GPU: ', IOU_negative, precision_negative, recall_negative, f1_negative, Dice_negative, pixAcc_negative) + # start = time.time() + # IOU_negative, precision_negative, recall_negative, f1_negative, Dice_negative, pixAcc_negative = compute_metrics(negative_mask, negative_gt) + # end = time.time() + # print('CPU: ', 
end - start, IOU_negative, precision_negative, recall_negative, f1_negative, Dice_negative, pixAcc_negative) + + info_dict.append({'Model': model_name, + 'image_name': mask_name, + 'cell_type': 'Positive', + 'precision': precision_positive * 100, + 'recall': recall_positive * 100, + 'f1': f1_positive * 100, + 'Dice': Dice_positive * 100, + 'IOU': IOU_positive * 100, + 'PixAcc': pixAcc_positive * 100 + # 'AJI': AJI_positive * 100 + }) + + info_dict.append({'Model': model_name, + 'image_name': mask_name, + 'cell_type': 'Negative', + 'precision': precision_negative * 100, + 'recall': recall_negative * 100, + 'f1': f1_negative * 100, + 'Dice': Dice_negative * 100, + 'IOU': IOU_negative * 100, + 'PixAcc': pixAcc_negative * 100 + # 'AJI': AJI_negative * 100 + }) + + precision = (precision_positive * 100 + precision_negative * 100) / 2 + recall = (recall_positive * 100 + recall_negative * 100) / 2 + f1 = (f1_positive * 100 + f1_negative * 100) / 2 + Dice = (Dice_positive * 100 + Dice_negative * 100) / 2 + IOU = (IOU_positive * 100 + IOU_negative * 100) / 2 + pixAcc = (pixAcc_positive * 100 + pixAcc_negative * 100) / 2 + # AJI = (AJI_positive * 100 + AJI_negative * 100) / 2 + + info_dict.append({'Model': model_name, + 'image_name': mask_name, + 'cell_type': 'Mean', + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'Dice': Dice, + 'IOU': IOU, + 'PixAcc': pixAcc, + # 'AJI': AJI + }) + + metrics['precision'] += precision + metrics['precision_positive'] += precision_positive + metrics['precision_negative'] += precision_negative + + metrics['recall'] += recall + metrics['recall_positive'] += recall_positive + metrics['recall_negative'] += recall_negative + + metrics['f1'] += f1 + metrics['f1_positive'] += f1_positive + metrics['f1_negative'] += f1_negative + + metrics['Dice'] += Dice + metrics['Dice_positive'] += Dice_positive + metrics['Dice_negative'] += Dice_negative + + metrics['IOU'] += IOU + metrics['IOU_positive'] += IOU_positive + metrics['IOU_negative'] += IOU_negative + + metrics['PixAcc'] += pixAcc + metrics['PixAcc_positive'] += pixAcc_positive + metrics['PixAcc_negative'] += pixAcc_negative + + # metrics['AJI'] += AJI + # metrics['AJI_positive'] += AJI_positive + # metrics['AJI_negative'] += AJI_negative + + for key in metrics: + metrics[key] /= counter + + return info_dict, metrics + diff --git a/deepliif/stat/__init__.py b/deepliif/stat/__init__.py new file mode 100644 index 0000000..7daac75 --- /dev/null +++ b/deepliif/stat/__init__.py @@ -0,0 +1,323 @@ + + +import os +import subprocess +import shutil +import time +from .ComputeStatistics import Statistics + +from ..options.test_options import TestOptions +from ..options import read_model_params, Options, print_options +from ..data import create_dataset +from ..models import create_model, init_nets, infer_modalities, infer_results_for_wsi +from ..util.visualizer import save_images +from ..util import html, allowed_file +import torch +import click + +from PIL import Image +import json + +def ensure_exists(d): + if not os.path.exists(d): + os.makedirs(d) + + +def format_file_structure(output_dir,source_folder='images',target_folder={'gt':'_real_B', 'pred':'_fake_B', 'input':'_real_A'}): + """ + target_folder: keys are new folder names, values are unique token with which to filter the needed images for this folder + 1) if segmentation metrics are needed and no seg images are generated from the main model, folder name "input" can be + specified when you want to run segmentation (using the seg model) on the original input as well + 
2) if you want to also run metrics on the original input (e.g., SSIM/PSNR for an upscaling task), then you should also specify
+      "input".
+    otherwise only "gt" and "pred" are used
+    """
+    dir_source = os.path.join(output_dir,source_folder)
+
+    for folder_name, unique_token in target_folder.items():
+        # create gt_dir and pred_dir
+        dir_folder = os.path.join(output_dir, folder_name)
+        if os.path.exists(dir_folder) and os.path.isdir(dir_folder):
+            shutil.rmtree(dir_folder)
+        os.makedirs(dir_folder)
+
+        # separate the images
+        subprocess.run(f"cp {dir_source}/*{unique_token}* {dir_folder}", shell=True, check=True)
+
+        # rename files
+        fns = os.listdir(dir_folder)
+        for fn in fns:
+            os.rename(f'{dir_folder}/{fn}', f"{dir_folder}/{fn.replace(f'{unique_token}','_')}")
+
+
+def generate_predictions(dataroot, results_dir, checkpoints_dir, name='', num_test=10000,
+                         phase='val', gpu_ids=(-1,), batch_size=None, epoch='latest'):
+    """
+    a function version of cli.py test / test.py, to be used with evaluate()
+    params:
+      dataroot: reads images from here; expected to have a subfolder named after `phase` (e.g., val)
+      results_dir: saves results here
+      checkpoints_dir: load models from here
+      name: name of the experiment, used as a subfolder under results_dir
+      num_test: only run test for num_test images
+      phase: this effectively refers to the subfolder name from which to load the images
+      gpu_ids: gpu-ids 0 gpu-ids 1, or gpu-ids -1 for CPU
+      batch_size: input batch size
+    """
+    # retrieve options used in the training setting, similar to cli.py test
+    model_dir = os.path.join(checkpoints_dir, name)
+    opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test')
+
+    if gpu_ids and gpu_ids[0] == -1:
+        gpu_ids = []
+
+    # overwrite/supply unseen options using the values from the options provided in the command
+    setattr(opt,'checkpoints_dir',checkpoints_dir)
+    setattr(opt,'dataroot',dataroot)
+    setattr(opt,'name',name)
+    setattr(opt,'results_dir',results_dir)
+    setattr(opt,'num_test',num_test)
+    setattr(opt,'phase',phase)
+    setattr(opt,'gpu_ids',gpu_ids)
+    setattr(opt,'batch_size',batch_size)
+    setattr(opt,'epoch',str(epoch))
+
+    if not hasattr(opt,'seg_gen'): # old settings for DeepLIIF models
+        opt.seg_gen = True
+
+    # hard-code some parameters for test.py
+    opt.aspect_ratio = 1.0    # from previous default setting
+    opt.display_winsize = 512 # from previous default setting
+    opt.use_dp = True         # whether to initialize the model in a DataParallel setting (all models to one gpu, then pytorch controls the usage of the specified set of GPUs for inference)
+    opt.num_threads = 0       # test code only supports num_threads = 0
+    opt.batch_size = 1        # test code only supports batch_size = 1
+    opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed
+    opt.no_flip = True        # no flip; comment this line if results on flipped images are needed
+    opt.display_id = -1       # no visdom display; the test code saves the results to an HTML file
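+    # NOTE: opt.batch_size is pinned to 1 here for the HTML-report path, but
+    # create_dataset() below still honors an explicit batch_size argument, so
+    # the effective loader batch size can differ from opt.batch_size.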
+ print_options(opt) + + batch_size = batch_size if batch_size else opt.batch_size + dataset = create_dataset(opt, phase=phase, batch_size=batch_size) # create a dataset given opt.dataset_mode and other options + model = create_model(opt) # create a model given opt.model and other options + model.setup(opt) # regular setup: load and print networks; create schedulers + torch.backends.cudnn.benchmark = False + # create a website + web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory + if opt.load_iter > 0: # load_iter is 0 by default + web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter) + print('creating web directory', web_dir) + webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) + # test with eval mode. This only affects layers like batchnorm and dropout. + # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. + # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. + model.eval() + # if opt.eval: + # model.eval() + + _start_time = time.time() + + for i, data in enumerate(dataset): + if i >= opt.num_test: # only apply our model to opt.num_test images. + break + + model.set_input(data) # unpack data from data loader + model.test() # run inference + visuals = model.get_current_visuals() # get image results + img_path = model.get_image_paths() # get image paths + if i % 5 == 0: # save images to an HTML file + print('processing (%04d)-th image... %s (batch size %s)' % (i*batch_size, img_path[-1], batch_size)) + save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) + + t_sec = round(time.time() - _start_time) + (t_min, t_sec) = divmod(t_sec, 60) + (t_hour, t_min) = divmod(t_min, 60) + print('Time passed: {}hour:{}min:{}sec'.format(t_hour, t_min, t_sec)) + webpage.save() # save the HTML + + +def generate_predictions_inference(input_dir, output_dir, tile_size=None, model_dir='./model-server/DeepLIIF_Latest_Model/', + gpu_ids=(-1,), region_size=20000, eager_mode=True, color_dapi=False, color_marker=False, + save_type_pattern='Seg'): + output_dir = output_dir or input_dir + ensure_exists(output_dir) + + image_files = sorted([fn for fn in os.listdir(input_dir) if allowed_file(fn)]) + print(len(image_files),'images found') + files = os.listdir(model_dir) + assert 'train_opt.txt' in files, f'file train_opt.txt is missing from model directory {model_dir}' + opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test') + opt.use_dp = False + + number_of_gpus_all = torch.cuda.device_count() + if number_of_gpus_all < len(gpu_ids) and -1 not in gpu_ids: + number_of_gpus = 0 + gpu_ids = [-1] + print(f'Specified to use GPU {opt.gpu_ids} for inference, but there are only {number_of_gpus_all} GPU devices. 
Switched to CPU inference.') + + if len(gpu_ids) > 0 and gpu_ids[0] == -1: + gpu_ids = [] + elif len(gpu_ids) == 0: + gpu_ids = list(range(number_of_gpus_all)) + + opt.gpu_ids = gpu_ids # overwrite gpu_ids; for test command, default gpu_ids at first is [] which will be translated to a list of all gpus + + # fix opt from old settings + if not hasattr(opt,'modalities_no') and hasattr(opt,'targets_no'): + opt.modalities_no = opt.targets_no - 1 + del opt.targets_no + print_options(opt) + + count = 0 + with click.progressbar( + image_files, + label=f'Processing {len(image_files)} images', + item_show_func=lambda fn: fn + ) as bar: + for filename in bar: + if '.svs' in filename: + start_time = time.time() + infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size) + print(time.time() - start_time) + else: + img = Image.open(os.path.join(input_dir, filename)).convert('RGB') + images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker, opt) + + for name, i in images.items(): + if save_type_pattern and save_type_pattern in name: + i.save(os.path.join( + output_dir, + filename.replace('.' + filename.split('.')[-1], f'_{name}.png') + )) + + with open(os.path.join( + output_dir, + filename.replace('.' + filename.split('.')[-1], f'.json') + ), 'w') as f: + json.dump(scoring, f, indent=2) + + count += 1 + if count % 100 == 0 or count == len(image_files): + print(f'Done {count}/{len(image_files)}') + + +def evaluate(model_dir, input_dir, output_dir, seg_model_dir=None, phase='val', subfolder=None, image_types='B_1,B_2,B_3,B_4', seg_type='B_5', + seg_gen=False, epoch='latest',mode='Segmentation', seg_on_input=False, include_input_metrics=False, batch_size=None, gpu_ids=(-1,),overwrite=False, verbose=False): + """ + params: + model_dir: directory of the trained model + input_dir: input directory of the whole dataset, which should contain multiple subfolders like train and test + output_dir: path to the trained model + seg_model_dir: only used when seg_gen=False; this is the model to generate segmentation mask for both gt and pred images + phase: this effectively is the subfolder name under model_dir in which is the input data + subfolder: this effectively is the subfolder name under output_dir in which is the folder for predictions "./images" locates + image_types: unique marker in filename for each modality; in model class DeepLIIFExt and SDG, it could be a string like B_1,B_2,B_3 + seg_type: unique marker in filename for segmentation, if exists; in model class DeepLIIFExt and SDG, it could be a string like B_4 + seg_gen: True (Translation and Segmentation), False (Only Translation). 
+      epoch: the name of epoch to load, default to "latest"
+      mode: mode of the statistics computation, one of Segmentation, ImageSynthesis, All, SSIM, Upscaling
+      seg_on_input: a flag to indicate whether to run the seg model on the input image; this applies only to the situation where
+        seg_gen=False, mode includes segmentation metrics, and seg_model_dir is provided; if True, 2 sets of segmentation metrics will
+        be returned (seg on fake vs seg on real, seg on input vs seg on real)
+      batch_size: batch size for inference; if not set, default is opt.batch_size
+      overwrite: overwrite results; otherwise, if output_dir already exists, prediction generation will be skipped
+      verbose: print more info if True
+    """
+    if not seg_gen and mode in ['Segmentation','All']:
+        if not seg_model_dir:
+            mode = 'SSIM'
+            print('seg_gen is False, cannot run segmentation metrics; mode is changed to SSIM')
+
+    d_res = {'elapsed_time':[]}
+    if verbose:
+        params = locals()
+        for k, v in params.items():
+            print(k, v)
+
+    if not subfolder:
+        subfolder = f'{phase}_{epoch}'
+
+    print('Generating predictions...')
+    time_s = time.time()
+    if os.path.exists(os.path.join(output_dir, subfolder)):
+        print(f'Folder {os.path.join(output_dir, subfolder)} already exists')
+        print(f'overwrite: {overwrite}')
+        if not overwrite:
+            print('Skip prediction generation')
+        else:
+            print('Deleting folder and regenerating predictions...')
+            # delete only this run's subfolder, not the whole results dir
+            shutil.rmtree(os.path.join(output_dir, subfolder))
+            generate_predictions(checkpoints_dir=model_dir, name='.', dataroot=input_dir,
+                                 results_dir=output_dir, gpu_ids=gpu_ids, phase=phase,
+                                 batch_size=batch_size, epoch=epoch)
+    else:
+        generate_predictions(checkpoints_dir=model_dir, name='.', dataroot=input_dir,
+                             results_dir=output_dir, gpu_ids=gpu_ids, phase=phase,
+                             batch_size=batch_size, epoch=epoch)
+    d_res['elapsed_time'].append(time.time() - time_s)
+
+    if not seg_gen and seg_model_dir:
+        print('Generating segmentation mask...')
+        time_s = time.time()
+        seg_output_dir = os.path.join(output_dir, subfolder, 'seg')
+        if os.path.exists(seg_output_dir):
+            print(f'Folder {os.path.join(output_dir, subfolder, "seg")} already exists')
+            print(f'overwrite: {overwrite}')
+            if not overwrite:
+                print('Skip segmentation mask generation')
+            else:
+                print('Deleting folder and regenerating segmentation mask...')
+                shutil.rmtree(seg_output_dir)
+                generate_predictions_inference(input_dir=os.path.join(output_dir,subfolder,'images'),
+                                               output_dir=seg_output_dir, tile_size=512, model_dir=seg_model_dir,
+                                               gpu_ids=gpu_ids, eager_mode=True)
+        else:
+            generate_predictions_inference(input_dir=os.path.join(output_dir,subfolder,'images'),
+                                           output_dir=seg_output_dir, tile_size=512, model_dir=seg_model_dir,
+                                           gpu_ids=gpu_ids, eager_mode=True)
+        d_res['elapsed_time'].append(time.time() - time_s)
+
+    print('Preparing folder structure and formatting filenames...')
+    time_s = time.time()
+    format_file_structure(os.path.join(output_dir,subfolder)) # the folder name val_latest
+    d_res['elapsed_time'].append(time.time() - time_s)
+
+    if not seg_gen and seg_model_dir:
+        print('Preparing folder structure and formatting filenames...')
+        time_s = time.time()
+        format_file_structure(os.path.join(output_dir,subfolder), source_folder='seg',
+                              target_folder={'gt_seg':'_real_B_1_SegRefined', 'pred_seg':'_fake_B_1_SegRefined', 'input_seg':'_real_A_SegRefined'}) # the folder name val_latest
+        d_res['elapsed_time'].append(time.time() - time_s)
+
+    print('Starting ComputeStatistics.py...')
+
+    time_s = time.time()
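+    # three cases below: (1) masks came from an external seg model, so compare the
+    # gt_seg/pred_seg folders; (2) mode == 'All' additionally computes image synthesis
+    # metrics on gt/pred; (3) otherwise run the requested mode directly on gt/pred
+    # (and optionally on the original input as well)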
+    if not seg_gen and seg_model_dir and mode in ['Segmentation', 'All']:
+        gt_dir = os.path.join(output_dir, subfolder, 'gt_seg')
+        pred_dir = os.path.join(output_dir, subfolder, 'pred_seg')
+        stats = Statistics(gt_path=gt_dir, model_path=pred_dir, output_path=output_dir,
+                           image_types=image_types, seg_type=seg_type, mode='Segmentation')
+        d_stat = stats.run()
+
+        if mode == 'All':
+            gt_dir = os.path.join(output_dir, subfolder, 'gt')
+            pred_dir = os.path.join(output_dir, subfolder, 'pred')
+            stats = Statistics(gt_path=gt_dir, model_path=pred_dir, output_path=output_dir,
+                               image_types=image_types, seg_type=seg_type, mode='ImageSynthesis')
+            d_stat_imagesynthesis = stats.run()
+            d_stat = {**d_stat, **d_stat_imagesynthesis}
+    else:
+        gt_dir = os.path.join(output_dir, subfolder, 'gt')
+        pred_dir = os.path.join(output_dir, subfolder, 'pred')
+        stats = Statistics(gt_path=gt_dir, model_path=pred_dir, output_path=output_dir,
+                           image_types=image_types, seg_type=seg_type, mode=mode)
+        d_stat = stats.run()
+
+        if include_input_metrics:
+            input_dir = os.path.join(output_dir, subfolder, 'input')
+            stats = Statistics(gt_path=gt_dir, model_path=input_dir, output_path=output_dir,
+                               image_types=image_types, seg_type=seg_type, mode=mode)
+            # previously the input statistics were constructed but never run; run them
+            # and merge with prefixed keys so they do not overwrite the prediction metrics
+            d_stat_input = stats.run()
+            d_stat = {**d_stat, **{f'input_{k}': v for k, v in d_stat_input.items()}}
+    d_res['elapsed_time'].append(time.time() - time_s)
+
+
+    return {**d_stat, **d_res}
+
+
diff --git a/deepliif/stat/fid.py b/deepliif/stat/fid.py
new file mode 100644
index 0000000..d57a50a
--- /dev/null
+++ b/deepliif/stat/fid.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+''' Calculates the Frechet Inception Distance (FID) to evaluate GANs.
+
+The FID metric calculates the distance between two distributions of images.
+Typically, we have summary statistics (mean & covariance matrix) of one
+of these distributions, while the 2nd distribution is given by a GAN.
+
+When run as a stand-alone program, it compares the distribution of
+images that are stored as PNG/JPEG at a specified location with a
+distribution given by summary statistics (in pickle format).
+
+The FID is calculated by assuming that X_1 and X_2 are the activations of
+the pool_3 layer of the inception net for generated samples and real world
+samples respectively.
+
+See --help to see further details.
+'''
+
+from __future__ import absolute_import, division, print_function
+import numpy as np
+import os
+import gzip, pickle
+import tensorflow as tf
+from imageio import imread
+from scipy import linalg
+import pathlib
+import urllib
+import warnings
+
+class InvalidFIDException(Exception):
+    pass
+
+
+def create_inception_graph(pth):
+    """Creates a graph from saved GraphDef file."""
+    # Creates graph from saved graph_def.pb.
+    with tf.io.gfile.GFile(pth, 'rb') as f:
+        graph_def = tf.compat.v1.GraphDef()
+        graph_def.ParseFromString(f.read())
+    _ = tf.import_graph_def(graph_def, name='FID_Inception_Net')
+#-------------------------------------------------------------------------------
+
+
+# code for handling inception net derived from
+# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
+def _get_inception_layer(sess):
+    """Prepares inception net for batched usage and returns pool_3 layer.
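+
+    Note: the loop below clears the static batch dimension recorded on each op
+    output so the imported TF1 Inception graph accepts arbitrary batch sizes.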
""" + layername = 'FID_Inception_Net/pool_3:0' + pool3 = sess.graph.get_tensor_by_name(layername) + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + if shape._dims is not None: + #shape = [s.value for s in shape] TF 1.x + shape = [s for s in shape] #TF 2.x + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__['_shape_val'] = tf.TensorShape(new_shape) + return pool3 +#------------------------------------------------------------------------------- + + +def get_activations(images, sess, batch_size=50, verbose=False): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- images : Numpy array of dimension (n_images, hi, wi, 3). The values + must lie between 0 and 256. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the disposable hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- A numpy array of dimension (num images, 2048) that contains the + activations of the given tensor when feeding inception with the query tensor. + """ + inception_layer = _get_inception_layer(sess) + n_images = images.shape[0] + if batch_size > n_images: + print("warning: batch size is bigger than the data size. setting batch size to data size") + batch_size = n_images + n_batches = n_images//batch_size # drops the last batch if < batch_size + pred_arr = np.empty((n_batches * batch_size,2048)) + for i in range(n_batches): + if verbose: + print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) + start = i*batch_size + + if start+batch_size < n_images: + end = start+batch_size + else: + end = n_images + + batch = images[start:end] + pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) + pred_arr[start:end] = pred.reshape(batch.shape[0],-1) + if verbose: + print(" done") + return pred_arr +#------------------------------------------------------------------------------- + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of the pool_3 layer of the + inception net ( like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted + on an representive data set. + -- sigma1: The covariance matrix over activations of the pool_3 layer for + generated samples. + -- sigma2: The covariance matrix over activations of the pool_3 layer, + precalcualted on an representive data set. + + Returns: + -- : The Frechet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean +#------------------------------------------------------------------------------- + + +def calculate_activation_statistics(images, sess, batch_size=50, verbose=False): + """Calculation of the statistics used by the FID. + Params: + -- images : Numpy array of dimension (n_images, hi, wi, 3). The values + must lie between 0 and 255. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the available hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the incption model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the incption model. + """ + act = get_activations(images, sess, batch_size, verbose) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +#------------------ +# The following methods are implemented to obtain a batched version of the activations. +# This has the advantage to reduce memory requirements, at the cost of slightly reduced efficiency. +# - Pyrestone +#------------------ + + +def load_image_batch(files): + """Convenience method for batch-loading images + Params: + -- files : list of paths to image files. Images need to have same dimensions for all files. + Returns: + -- A numpy array of dimensions (num_images,hi, wi, 3) representing the image pixel values. + """ + return np.array([imread(str(fn)).astype(np.float32) for fn in files]) + +def get_activations_from_files(files, sess, batch_size=50, verbose=False): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- files : list of paths to image files. Images need to have same dimensions for all files. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the disposable hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- A numpy array of dimension (num images, 2048) that contains the + activations of the given tensor when feeding inception with the query tensor. + """ + inception_layer = _get_inception_layer(sess) + n_imgs = len(files) + if batch_size > n_imgs: + print("warning: batch size is bigger than the data size. 
+        batch_size = n_imgs
+    n_batches = (n_imgs + batch_size - 1) // batch_size # ceil; the last batch may be partial
+    pred_arr = np.empty((n_imgs,2048))
+    for i in range(n_batches):
+        if verbose:
+            print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
+        start = i*batch_size
+        if start+batch_size < n_imgs:
+            end = start+batch_size
+        else:
+            end = n_imgs
+
+        batch = load_image_batch(files[start:end])
+        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
+        pred_arr[start:end] = pred.reshape(end - start, -1) # reshape by the actual (possibly partial) batch length
+        del batch #clean up memory
+    if verbose:
+        print(" done")
+    return pred_arr
+
+def calculate_activation_statistics_from_files(files, sess, batch_size=50, verbose=False):
+    """Calculation of the statistics used by the FID.
+    Params:
+    -- files       : list of paths to image files. Images need to have same dimensions for all files.
+    -- sess        : current session
+    -- batch_size  : the images numpy array is split into batches with batch size
+                     batch_size. A reasonable batch size depends on the available hardware.
+    -- verbose     : If set to True and parameter out_step is given, the number of calculated
+                     batches is reported.
+    Returns:
+    -- mu    : The mean over samples of the activations of the pool_3 layer of
+               the inception model.
+    -- sigma : The covariance matrix of the activations of the pool_3 layer of
+               the inception model.
+    """
+    act = get_activations_from_files(files, sess, batch_size, verbose)
+    mu = np.mean(act, axis=0)
+    sigma = np.cov(act, rowvar=False)
+    return mu, sigma
+
+#-------------------------------------------------------------------------------
+
+
+#-------------------------------------------------------------------------------
+# The following functions aren't needed for calculating the FID
+# they're just here to make this module work as a stand-alone script
+# for calculating FID scores
+#-------------------------------------------------------------------------------
+def check_or_download_inception(inception_path):
+    ''' Checks if the path to the inception file is valid, or downloads
+        the file if it is not present. '''
+    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+    if inception_path is None:
+        inception_path = '/tmp'
+    inception_path = pathlib.Path(inception_path)
+    model_file = inception_path / 'classify_image_graph_def.pb'
+    if not model_file.exists():
+        print("Downloading Inception model")
+        from urllib import request
+        import tarfile
+        fn, _ = request.urlretrieve(INCEPTION_URL)
+        with tarfile.open(fn, mode='r') as f:
+            f.extract('classify_image_graph_def.pb', str(model_file.parent))
+    return str(model_file)
+
+
+def _handle_path(path, sess, low_profile=False):
+    if path.endswith('.npz'):
+        f = np.load(path)
+        m, s = f['mu'][:], f['sigma'][:]
+        f.close()
+    else:
+        path = pathlib.Path(path)
+        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
+        if low_profile:
+            m, s = calculate_activation_statistics_from_files(files, sess)
+        else:
+            x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
+            m, s = calculate_activation_statistics(x, sess)
+            del x #clean up memory
+    return m, s
+
+
+def calculate_fid_given_paths(paths, inception_path, low_profile=False):
+    ''' Calculates the FID of two paths.
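+
+        A minimal usage sketch (hypothetical image folders):
+            fid_value = calculate_fid_given_paths(('real_imgs/', 'generated_imgs/'), inception_path=None)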
''' + inception_path = check_or_download_inception(inception_path) + + for p in paths: + if not os.path.exists(p): + raise RuntimeError("Invalid path: %s" % p) + + create_inception_graph(str(inception_path)) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m1, s1 = _handle_path(paths[0], sess, low_profile=low_profile) + m2, s2 = _handle_path(paths[1], sess, low_profile=low_profile) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + +if __name__ == "__main__": + from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument("path", type=str, nargs=2, + help='Path to the generated images or to .npz statistic files') + parser.add_argument("-i", "--inception", type=str, default=None, + help='Path to Inception model (will be downloaded if not provided)') + parser.add_argument("--gpu", default="", type=str, + help='GPU to use (leave blank for CPU only)') + parser.add_argument("--lowprofile", action="store_true", + help='Keep only one batch of images in memory at a time. This reduces memory footprint, but may decrease speed slightly.') + args = parser.parse_args() + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + fid_value = calculate_fid_given_paths(args.path, args.inception, low_profile=args.lowprofile) + print("FID: ", fid_value) diff --git a/deepliif/stat/fid_official_tf.py b/deepliif/stat/fid_official_tf.py new file mode 100644 index 0000000..ab67ce5 --- /dev/null +++ b/deepliif/stat/fid_official_tf.py @@ -0,0 +1,370 @@ +""" +@Brief: + Tensorflow implementation of FID score, should be the same as the official one + modified from official inception score implementation + [bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) +@Author: lzhbrian (https://lzhbrian.me) +@Date: 2019.4.7 +@Usage: + # CMD + # from 2 precalculated stats + python fid_official_tf.py res/stats_tf/fid_stats_imagenet_valid.npz res/stats_tf/fid_stats_imagenet_train.npz --gpu 0 + + # from 1 precalculated stats, 1 image foldername/ + python fid_official_tf.py res/stats_tf/fid_stats_imagenet_valid.npz /path/to/image/foldername/ --gpu 0 + + # from 2 image foldername/ + python fid_official_tf.py /path/to/image/foldername1/ /path/to/image/foldername2/ --gpu 0 + + # used in code + ``` + import tensorflow as tf + + # load from precalculated + f = np.load('res/stats_tf/fid_stats_imagenet_train.npz') + mu1, sigma1 = f['mu'][:], f['sigma'][:] + f.close() + + # calc from image ndarray + # images should be Numpy array of dimension (N, H, W, C). 
images should be in 0~255 + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + mu2, sigma2 = fid_official_tf.calculate_activation_statistics(images, sess, batch_size=100) + fid_score = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) + ``` + +@Note: + Need to first download stats_tf of datasets in stats_tf/, see README.md + + also, the same as inception_score_official_tf.py, the inception model used + contains resize and normalization layers + so the input of our images should be 0~255, and arbitrary HxW size + + For calculating mu and sigma for foldername/, see precalc_stats_official_tf.py +""" + +import numpy as np +import os +import tensorflow as tf +import scipy.misc +from scipy.misc import imread +from scipy import linalg +import pathlib +import urllib +import warnings +from tqdm import tqdm + + +cur_dirname = os.path.dirname(os.path.abspath(__file__)) + +MODEL_DIR = '%s/res/' % cur_dirname + +class InvalidFIDException(Exception): + pass + + +def create_inception_graph(pth): + """Creates a graph from saved GraphDef file.""" + # Creates graph from saved graph_def.pb. + with tf.gfile.FastGFile(pth, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + _ = tf.import_graph_def(graph_def, name='FID_Inception_Net') + + +# ------------------------------------------------------------------------------- + + +# code for handling inception net derived from +# https://github.com/openai/improved-gan/blob/master/inception_score/model.py +def _get_inception_layer(sess): + """Prepares inception net for batched usage and returns pool_3 layer. """ + layername = 'FID_Inception_Net/pool_3:0' + pool3 = sess.graph.get_tensor_by_name(layername) + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + if shape._dims != []: + shape = [s.value for s in shape] + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__['_shape_val'] = tf.TensorShape(new_shape) + return pool3 + + +# ------------------------------------------------------------------------------- + + +def get_activations(images, sess, batch_size=50, verbose=False): + """Calculates the activations of the pool_3 layer for all images. + Params: + -- images : Numpy array of dimension (n_images, hi, wi, 3). The values + must lie between 0 and 256. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the disposable hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- A numpy array of dimension (num images, 2048) that contains the + activations of the given tensor when feeding inception with the query tensor. + """ + inception_layer = _get_inception_layer(sess) + d0 = images.shape[0] + if batch_size > d0: + print("warning: batch size is bigger than the data size. 
setting batch size to data size")
+        batch_size = d0
+    n_batches = d0 // batch_size
+    n_used_imgs = n_batches * batch_size
+    pred_arr = np.empty((n_used_imgs, 2048))
+    for i in tqdm(range(n_batches)):
+        if verbose:
+            print("\rPropagating batch %d/%d" % (i + 1, n_batches))
+        start = i * batch_size
+        end = start + batch_size
+        batch = images[start:end]
+        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
+        pred_arr[start:end] = pred.reshape(batch_size, -1)
+    if verbose:
+        print(" done")
+    return pred_arr
+
+
+# -------------------------------------------------------------------------------
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+    Stable version by Dougal J. Sutherland.
+    Params:
+    -- mu1   : Numpy array containing the activations of the pool_3 layer of the
+               inception net ( like returned by the function 'get_predictions')
+               for generated samples.
+    -- mu2   : The sample mean over activations of the pool_3 layer, precalculated
+               on a representative data set.
+    -- sigma1: The covariance matrix over activations of the pool_3 layer for
+               generated samples.
+    -- sigma2: The covariance matrix over activations of the pool_3 layer,
+               precalculated on a representative data set.
+    Returns:
+    --   : The Frechet Distance.
+    """
+
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+
+    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
+    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
+
+    diff = mu1 - mu2
+
+    # product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
+        warnings.warn(msg)
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+    # numerical error might give slight imaginary component
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            raise ValueError("Imaginary component {}".format(m))
+        covmean = covmean.real
+
+    tr_covmean = np.trace(covmean)
+
+    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+# -------------------------------------------------------------------------------
+
+
+def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
+    """Calculation of the statistics used by the FID.
+    Params:
+    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values
+                     must lie between 0 and 255.
+    -- sess        : current session
+    -- batch_size  : the images numpy array is split into batches with batch size
+                     batch_size. A reasonable batch size depends on the available hardware.
+    -- verbose     : If set to True and parameter out_step is given, the number of calculated
+                     batches is reported.
+    Returns:
+    -- mu    : The mean over samples of the activations of the pool_3 layer of
+               the inception model.
+    -- sigma : The covariance matrix of the activations of the pool_3 layer of
+               the inception model.
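+
+    Note: per the module docstring, this Inception graph contains its own resize
+    and normalization layers, so input images may be arbitrary HxW with values
+    in 0~255.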
+ """ + act = get_activations(images, sess, batch_size, verbose) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +# ------------------ +# The following methods are implemented to obtain a batched version of the activations. +# This has the advantage to reduce memory requirements, at the cost of slightly reduced efficiency. +# - Pyrestone +# ------------------ + + +def load_image_batch(files): + """Convenience method for batch-loading images + Params: + -- files : list of paths to image files. Images need to have same dimensions for all files. + Returns: + -- A numpy array of dimensions (num_images,hi, wi, 3) representing the image pixel values. + """ + return np.array([imread(str(fn)).astype(np.float32) for fn in files]) + + +def get_activations_from_files(files, sess, batch_size=50, verbose=False): + """Calculates the activations of the pool_3 layer for all images. + Params: + -- files : list of paths to image files. Images need to have same dimensions for all files. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the disposable hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- A numpy array of dimension (num images, 2048) that contains the + activations of the given tensor when feeding inception with the query tensor. + """ + inception_layer = _get_inception_layer(sess) + d0 = len(files) + if batch_size > d0: + print("warning: batch size is bigger than the data size. setting batch size to data size") + batch_size = d0 + n_batches = d0 // batch_size + n_used_imgs = n_batches * batch_size + pred_arr = np.empty((n_used_imgs, 2048)) + for i in range(n_batches): + if verbose: + print("\rPropagating batch %d/%d" % (i + 1, n_batches)) + start = i * batch_size + end = start + batch_size + batch = load_image_batch(files[start:end]) + pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) + pred_arr[start:end] = pred.reshape(batch_size, -1) + del batch # clean up memory + if verbose: + print(" done") + return pred_arr + + +def calculate_activation_statistics_from_files(files, sess, batch_size=50, verbose=False): + """Calculation of the statistics used by the FID. + Params: + -- files : list of paths to image files. Images need to have same dimensions for all files. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the available hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. 
+ """ + act = get_activations_from_files(files, sess, batch_size, verbose) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +# ------------------------------------------------------------------------------- + + +# ------------------------------------------------------------------------------- +# The following functions aren't needed for calculating the FID +# they're just here to make this module work as a stand-alone script +# for calculating FID scores +# ------------------------------------------------------------------------------- +def check_or_download_inception(inception_path): + ''' Checks if the path to the inception file is valid, or downloads + the file if it is not present. ''' + INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' + if inception_path is None: + inception_path = MODEL_DIR + inception_path = pathlib.Path(inception_path) + model_file = inception_path / 'classify_image_graph_def.pb' + if not model_file.exists(): + print("Downloading Inception model") + from urllib import request + import tarfile + fn, _ = request.urlretrieve(INCEPTION_URL) + with tarfile.open(fn, mode='r') as f: + f.extract('classify_image_graph_def.pb', str(model_file.parent)) + return str(model_file) + + +def _handle_path(path, sess, low_profile=False): + if path.endswith('.npz'): + f = np.load(path) + m, s = f['mu'][:], f['sigma'][:] + f.close() + else: + path = pathlib.Path(path) + files = [] + for ext in ('*.png', '*.jpg', '*.jpeg', '.bmp'): + files.extend( list(path.glob(ext)) ) + + if low_profile: + m, s = calculate_activation_statistics_from_files(files, sess) + else: + # x = np.array([scipy.misc.imresize(imread(str(fn), mode='RGB'), (299, 299), interp='bilinear').astype(np.float32) for fn in files]) + x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) + m, s = calculate_activation_statistics(x, sess) + del x # clean up memory + return m, s + + +def calculate_fid_given_paths(paths, inception_path, low_profile=False): + ''' Calculates the FID of two paths. ''' + inception_path = check_or_download_inception(inception_path) + + for p in paths: + if not os.path.exists(p): + raise RuntimeError("Invalid path: %s" % p) + + create_inception_graph(str(inception_path)) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m1, s1 = _handle_path(paths[0], sess, low_profile=low_profile) + m2, s2 = _handle_path(paths[1], sess, low_profile=low_profile) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + +if __name__ == "__main__": + from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument("path", type=str, nargs=2, + help='Path to the generated images or to .npz statistic files') + parser.add_argument("-i", "--inception", type=str, default=None, + help='Path to Inception model (will be downloaded if not provided)') + parser.add_argument("--gpu", default="", type=str, + help='GPU to use (leave blank for CPU only)') + parser.add_argument("--lowprofile", action="store_true", + help='Keep only one batch of images in memory at a time. 
+    args = parser.parse_args()
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+    fid_value = calculate_fid_given_paths(args.path, args.inception, low_profile=args.lowprofile)
+    print("FID: ", fid_value)
diff --git a/deepliif/stat/inception_score.py b/deepliif/stat/inception_score.py
new file mode 100644
index 0000000..586b163
--- /dev/null
+++ b/deepliif/stat/inception_score.py
@@ -0,0 +1,100 @@
+# calculate inception score with Keras
+import os
+import sys
+from math import floor
+
+import numpy as np
+from numpy import ones
+from numpy import expand_dims
+from numpy import log
+from numpy import mean
+from numpy import std
+from numpy import exp
+try:
+    from tensorflow.keras.applications.inception_v3 import InceptionV3
+    from tensorflow.keras.applications.inception_v3 import preprocess_input
+except ImportError:
+    from keras.applications.inception_v3 import InceptionV3
+    from keras.applications.inception_v3 import preprocess_input
+import cv2
+
+
+# assumes images have the shape 299x299x3, pixels in [0,255]
+def calculate_inception_score(images, n_split=10, eps=1E-16):
+    # load inception v3 model
+    model = InceptionV3()
+    # convert from uint8 to float32
+    processed = images.astype('float32')
+    # pre-process raw images for inception v3 model
+    processed = preprocess_input(processed)
+    # predict class probabilities for images
+    yhat = model.predict(processed)
+    # enumerate splits of images/predictions
+    scores = list()
+    n_part = floor(images.shape[0] / n_split)
+    for i in range(n_split):
+        # retrieve p(y|x)
+        ix_start, ix_end = i * n_part, i * n_part + n_part
+        p_yx = yhat[ix_start:ix_end]
+        # calculate p(y)
+        p_y = expand_dims(p_yx.mean(axis=0), 0)
+        # calculate KL divergence using log probabilities
+        kl_d = p_yx * (log(p_yx + eps) - log(p_y + eps))
+        # sum over classes
+        sum_kl_d = kl_d.sum(axis=1)
+        # average over images
+        avg_kl_d = mean(sum_kl_d)
+        # undo the log
+        is_score = exp(avg_kl_d)
+        # store
+        scores.append(is_score)
+    # average across images
+    is_avg, is_std = mean(scores), std(scores)
+    return is_avg, is_std
+
+if __name__ == '__main__':
+    # input_dir = 'D:/Pathology/DMIT/datasets/IHC2DAPI/testB'
+    input_dir = sys.argv[1]
+    print(input_dir)
+    images = os.listdir(input_dir)
+    real_images_array = []
+    for img in images:
+        image = cv2.imread(os.path.join(input_dir, img))
+        image = cv2.resize(image, (299, 299))
+        real_images_array.append(image)
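+    # stack into a single (N, 299, 299, 3) array; calculate_inception_score expects
+    # pixel values in [0, 255] and applies preprocess_input() itself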
real_images_array = np.array(real_images_array) + is_avg, is_std = calculate_inception_score(real_images_array) + print('score', is_avg, is_std) \ No newline at end of file diff --git a/deepliif/stat/swd.py b/deepliif/stat/swd.py new file mode 100644 index 0000000..bd6a8b8 --- /dev/null +++ b/deepliif/stat/swd.py @@ -0,0 +1,158 @@ +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F +import torchvision + +# Gaussian blur kernel +def get_gaussian_kernel(device="cpu"): + kernel = np.array([ + [1, 4, 6, 4, 1], + [4, 16, 24, 16, 4], + [6, 24, 36, 24, 6], + [4, 16, 24, 16, 4], + [1, 4, 6, 4, 1]], np.float32) / 256.0 + gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5)).to(device) + return gaussian_k + +def pyramid_down(image, device="cpu"): + gaussian_k = get_gaussian_kernel(device=device) + # channel-wise conv(important) + multiband = [F.conv2d(image[:, i:i + 1,:,:], gaussian_k, padding=2, stride=2) for i in range(3)] + down_image = torch.cat(multiband, dim=1) + return down_image + +def pyramid_up(image, device="cpu"): + gaussian_k = get_gaussian_kernel(device=device) + upsample = F.interpolate(image, scale_factor=2) + multiband = [F.conv2d(upsample[:, i:i + 1,:,:], gaussian_k, padding=2) for i in range(3)] + up_image = torch.cat(multiband, dim=1) + return up_image + +def gaussian_pyramid(original, n_pyramids, device="cpu"): + x = original + # pyramid down + pyramids = [original] + for i in range(n_pyramids): + x = pyramid_down(x, device=device) + pyramids.append(x) + return pyramids + +def laplacian_pyramid(original, n_pyramids, device="cpu"): + # create gaussian pyramid + pyramids = gaussian_pyramid(original, n_pyramids, device=device) + + # pyramid up - diff + laplacian = [] + for i in range(len(pyramids) - 1): + diff = pyramids[i] - pyramid_up(pyramids[i + 1], device=device) + laplacian.append(diff) + # Add last gaussian pyramid + laplacian.append(pyramids[len(pyramids) - 1]) + return laplacian + +def minibatch_laplacian_pyramid(image, n_pyramids, batch_size, device="cpu"): + n = image.size(0) // batch_size + np.sign(image.size(0) % batch_size) + pyramids = [] + for i in range(n): + x = image[i * batch_size:(i + 1) * batch_size] + p = laplacian_pyramid(x.to(device), n_pyramids, device=device) + p = [x.cpu() for x in p] + pyramids.append(p) + del x + result = [] + for i in range(n_pyramids + 1): + x = [] + for j in range(n): + x.append(pyramids[j][i]) + result.append(torch.cat(x, dim=0)) + return result + +def extract_patches(pyramid_layer, slice_indices, + slice_size=7, unfold_batch_size=128, device="cpu"): + assert pyramid_layer.ndim == 4 + n = pyramid_layer.size(0) // unfold_batch_size + np.sign(pyramid_layer.size(0) % unfold_batch_size) + # random slice 7x7 + p_slice = [] + for i in range(n): + # [unfold_batch_size, ch, n_slices, slice_size, slice_size] + ind_start = i * unfold_batch_size + ind_end = min((i + 1) * unfold_batch_size, pyramid_layer.size(0)) + x = pyramid_layer[ind_start:ind_end].unfold( + 2, slice_size, 1).unfold(3, slice_size, 1).reshape( + ind_end - ind_start, pyramid_layer.size(1), -1, slice_size, slice_size) + # [unfold_batch_size, ch, n_descriptors, slice_size, slice_size] + x = x[:,:, slice_indices,:,:] + # [unfold_batch_size, n_descriptors, ch, slice_size, slice_size] + p_slice.append(x.permute([0, 2, 1, 3, 4])) + # sliced tensor per layer [batch, n_descriptors, ch, slice_size, slice_size] + x = torch.cat(p_slice, dim=0) + # normalize along ch + std, mean = torch.std_mean(x, dim=(0, 1, 3, 4), keepdim=True) + x = (x - mean) 
/ (std + 1e-8) + # reshape to 2rank + x = x.reshape(-1, 3 * slice_size * slice_size) + return x + +def swd(image1, image2, + n_pyramids=None, slice_size=7, n_descriptors=128, + n_repeat_projection=128, proj_per_repeat=4, device="cpu", return_by_resolution=False, + pyramid_batchsize=128): + # n_repeat_projectton * proj_per_repeat = 512 + # Please change these values according to memory usage. + # original = n_repeat_projection=4, proj_per_repeat=128 + assert image1.size() == image2.size() + assert image1.ndim == 4 and image2.ndim == 4 + + if n_pyramids is None: + n_pyramids = int(np.rint(np.log2(image1.size(2) // 16))) + with torch.no_grad(): + # minibatch laplacian pyramid for cuda memory reasons + pyramid1 = minibatch_laplacian_pyramid(image1, n_pyramids, pyramid_batchsize, device=device) + pyramid2 = minibatch_laplacian_pyramid(image2, n_pyramids, pyramid_batchsize, device=device) + result = [] + + for i_pyramid in range(n_pyramids + 1): + # indices + n = (pyramid1[i_pyramid].size(2) - 6) * (pyramid1[i_pyramid].size(3) - 6) + indices = torch.randperm(n)[:n_descriptors] + + # extract patches on CPU + # patch : 2rank (n_image*n_descriptors, slice_size**2*ch) + p1 = extract_patches(pyramid1[i_pyramid], indices, + slice_size=slice_size, device="cpu") + p2 = extract_patches(pyramid2[i_pyramid], indices, + slice_size=slice_size, device="cpu") + + p1, p2 = p1.to(device), p2.to(device) + + distances = [] + for j in range(n_repeat_projection): + # random + rand = torch.randn(p1.size(1), proj_per_repeat).to(device) # (slice_size**2*ch) + rand = rand / torch.std(rand, dim=0, keepdim=True) # noramlize + # projection + proj1 = torch.matmul(p1, rand) + proj2 = torch.matmul(p2, rand) + proj1, _ = torch.sort(proj1, dim=0) + proj2, _ = torch.sort(proj2, dim=0) + d = torch.abs(proj1 - proj2) + distances.append(torch.mean(d)) + + # swd + result.append(torch.mean(torch.stack(distances))) + + # average over resolution + result = torch.stack(result) * 1e3 + if return_by_resolution: + return result.cpu() + else: + return torch.mean(result).cpu() + + +def compute_swd(orig_images, mask_images, device): + orig_images = torch.as_tensor(orig_images.transpose([0, 3, 1, 2]).astype(np.float32) / 255.0) + mask_images = torch.as_tensor(mask_images.transpose([0, 3, 1, 2]).astype(np.float32) / 255.0) + + out = swd(orig_images, mask_images, device=device) # Fast estimation if device="cuda" + return out \ No newline at end of file From 20ee9a0285f7fe2e44fc6f7ac2c795ae4b45a500 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Wed, 20 Mar 2024 18:20:24 +0000 Subject: [PATCH 09/41] added tv loss --- deepliif/models/SDG_model.py | 13 +++++++++++-- deepliif/models/networks.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/deepliif/models/SDG_model.py b/deepliif/models/SDG_model.py index 5798eed..308f15c 100644 --- a/deepliif/models/SDG_model.py +++ b/deepliif/models/SDG_model.py @@ -16,6 +16,9 @@ def __init__(self, opt): self.mod_gen_no = self.opt.modalities_no opt.resize_conv = 'resizeconv' in opt.checkpoints_dir or 'resizeconv' in opt.name + print('opt.resize_conv', opt.resize_conv) + opt.tv_loss = 'tv' in opt.checkpoints_dir or 'tv' in opt.name + print('opt.tv_loss', opt.tv_loss) # weights of the modalities in generating segmentation mask @@ -38,7 +41,7 @@ def __init__(self, opt): self.visual_names = ['real_A'] # specify the training losses you want to print out. 
The training/test scripts will call for i in range(1, self.mod_gen_no + 1): - self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)]) + self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'G_VGG_'+ str(i),'G_TV_'+ str(i),'D_real_' + str(i), 'D_fake_' + str(i)]) self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)]) # specify the images you want to save/display. The training/test scripts will call @@ -78,6 +81,7 @@ def __init__(self, opt): self.criterionSmoothL1 = torch.nn.SmoothL1Loss() self.criterionVGG = networks.VGGLoss().to(self.device) + self.criterionTV = networks.TotalVariationLoss().to(self.device) # initialize optimizers; schedulers will be automatically created by function . params = [] @@ -164,12 +168,17 @@ def backward_G(self): self.loss_G_VGG = [] for i in range(self.mod_gen_no): self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat) + + self.loss_G_TV = [] + for i in range(self.mod_gen_no): + self.loss_G_TV.append(self.criterionTV(self.fake_B[i]) * 0.02) # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0] self.loss_G = torch.tensor(0., device=self.device) for i in range(0, self.mod_gen_no): #self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i] - self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] + #self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] + self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i] + self.loss_G_TV[i]) * self.loss_G_weights[i] self.loss_G.backward() def optimize_parameters(self): diff --git a/deepliif/models/networks.py b/deepliif/models/networks.py index 989639f..a5269f1 100644 --- a/deepliif/models/networks.py +++ b/deepliif/models/networks.py @@ -699,3 +699,15 @@ def forward(self, x, y): for i in range(len(x_vgg)): loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) return loss + + +class TotalVariationLoss(nn.Module): + """ + Absolute difference for neighbouring pixels (i,j to i+1,j, then i,j to i,j+1), averaged on pixel level + """ + def __init__(self): + super(TotalVariationLoss, self).__init__() + + def forward(self, x): + tv = torch.abs(x[:,:,1:,:]-x[:,:,:-1,:]).sum() + torch.abs(x[:,:,:,1:]-x[:,:,:,:-1]).sum() + return tv / torch.prod(torch.tensor(x.size())) From 507c70cda7e0a5938e176b23467089dd5cd1ee1c Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Thu, 4 Apr 2024 20:46:48 +0000 Subject: [PATCH 10/41] disabled TV loss --- deepliif/models/SDG_model.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/deepliif/models/SDG_model.py b/deepliif/models/SDG_model.py index 308f15c..f633a86 100644 --- a/deepliif/models/SDG_model.py +++ b/deepliif/models/SDG_model.py @@ -41,7 +41,9 @@ def __init__(self, opt): self.visual_names = ['real_A'] # specify the training losses you want to print out. The training/test scripts will call for i in range(1, self.mod_gen_no + 1): - self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'G_VGG_'+ str(i),'G_TV_'+ str(i),'D_real_' + str(i), 'D_fake_' + str(i)]) + self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'G_VGG_'+ str(i), + #'G_TV_'+ str(i), + 'D_real_' + str(i), 'D_fake_' + str(i)]) self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)]) # specify the images you want to save/display. 
The training/test scripts will call @@ -65,6 +67,8 @@ def __init__(self, opt): print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm, not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding) print('***************************************') + if i==0: + print(self.netG[i]) if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc self.netD = [None for _ in range(self.mod_gen_no)] @@ -169,16 +173,16 @@ def backward_G(self): for i in range(self.mod_gen_no): self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat) - self.loss_G_TV = [] - for i in range(self.mod_gen_no): - self.loss_G_TV.append(self.criterionTV(self.fake_B[i]) * 0.02) + # self.loss_G_TV = [] + # for i in range(self.mod_gen_no): + # self.loss_G_TV.append(self.criterionTV(self.fake_B[i]) * 0.02) # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0] self.loss_G = torch.tensor(0., device=self.device) for i in range(0, self.mod_gen_no): #self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i] - #self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] - self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i] + self.loss_G_TV[i]) * self.loss_G_weights[i] + self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i] + #self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i] + self.loss_G_TV[i]) * self.loss_G_weights[i] self.loss_G.backward() def optimize_parameters(self): From 08a767c099787c8d647bb617e00f81339903aff4 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Thu, 4 Apr 2024 20:47:55 +0000 Subject: [PATCH 11/41] added experimental model: attention unet --- deepliif/models/att_unet.py | 197 ++++++++++++++++++++++++++++++++++++ deepliif/models/networks.py | 4 +- 2 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 deepliif/models/att_unet.py diff --git a/deepliif/models/att_unet.py b/deepliif/models/att_unet.py new file mode 100644 index 0000000..e1a46e3 --- /dev/null +++ b/deepliif/models/att_unet.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + +class conv_block(nn.Module): + def __init__(self,ch_in,ch_out,innermost=False,outermost=False): + super(conv_block,self).__init__() + if outermost: + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, 
kernel_size=4,stride=2,padding=1,bias=True), + nn.LeakyReLU(0.2, True), + ) + elif innermost: + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True), + nn.ReLU(inplace=True), + ) + else: + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True), + nn.BatchNorm2d(ch_out), + nn.LeakyReLU(0.2, True), + ) + + def forward(self,x): + x = self.conv(x) + return x + +class up_conv(nn.Module): + def __init__(self,ch_in,ch_out,innermost=False,outermost=False): + super(up_conv,self).__init__() + use_bias=False + if outermost: + self.up = nn.Sequential( + nn.ConvTranspose2d(ch_in * 2, ch_out, + kernel_size=4, stride=2, + padding=1), + nn.Tanh()) + elif innermost: + self.up = nn.Sequential( + nn.ConvTranspose2d(ch_in, ch_out, + kernel_size=4, stride=2, + padding=1, bias=use_bias), + nn.BatchNorm2d(ch_out), + nn.ReLU(True)) + else: + self.up = nn.Sequential( + nn.ConvTranspose2d(ch_in * 2, ch_out, + kernel_size=4, stride=2, + padding=1, bias=use_bias), + nn.BatchNorm2d(ch_out), + nn.ReLU(True)) + + + def forward(self,x): + x = self.up(x) + return x + + +class Attention_block(nn.Module): + def __init__(self,F_g,F_l,F_int): + super(Attention_block,self).__init__() + self.W_g = nn.Sequential( + nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(F_int) + ) + + self.W_x = nn.Sequential( + nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(F_int) + ) + + self.psi = nn.Sequential( + nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(1), + nn.Sigmoid() + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self,g,x): + g1 = self.W_g(g) + x1 = self.W_x(x) + psi = self.relu(g1+x1) + psi = self.psi(psi) + + return x*psi + + + +class AttU_Net(nn.Module): + def __init__(self,img_ch=3,output_ch=1): + super(AttU_Net,self).__init__() + + self.Conv1 = conv_block(ch_in=img_ch,ch_out=64,outermost=True) + self.Conv2 = conv_block(ch_in=64,ch_out=128) + self.Conv3 = conv_block(ch_in=128,ch_out=256) + self.Conv4 = conv_block(ch_in=256,ch_out=512) + self.Conv5 = conv_block(ch_in=512,ch_out=512) + self.Conv6 = conv_block(ch_in=512,ch_out=512) + self.Conv7 = conv_block(ch_in=512,ch_out=512) + self.Conv8 = conv_block(ch_in=512,ch_out=512,innermost=True) + #self.Conv9 = conv_block(ch_in=512,ch_out=512,innermost=True) + + self.Up8 = up_conv(ch_in=512,ch_out=512,innermost=True) + self.Att8 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up7 = up_conv(ch_in=512,ch_out=512) + self.Att7 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up6 = up_conv(ch_in=512,ch_out=512) + self.Att6 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up5 = up_conv(ch_in=512,ch_out=512) + self.Att5 = Attention_block(F_g=512,F_l=512,F_int=512) + + self.Up4 = up_conv(ch_in=512,ch_out=256) + self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) + + self.Up3 = up_conv(ch_in=256,ch_out=128) + self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) + + self.Up2 = up_conv(ch_in=128,ch_out=64) + self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) + + self.Up1 = up_conv(ch_in=64,ch_out=img_ch,outermost=True) + + def forward(self,x): + # encoding path + x1 = self.Conv1(x) + x2 = self.Conv2(x1) + x3 = self.Conv3(x2) + x4 = self.Conv4(x3) + x5 = self.Conv5(x4) + x6 = self.Conv6(x5) + x7 = self.Conv7(x6) + x8 = self.Conv8(x7) + #x9 = self.Conv9(x8) + + #d9 = self.Up + d8 = self.Up8(x8) + x7 = self.Att8(g=d8,x=x7) + d8 = torch.cat((x7,d8),dim=1) + + d7 = self.Up7(d8) + x6 = 
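
# Standalone shape check for the Attention_block above: W_g and W_x project the
# decoder gating signal g and the encoder skip feature x into a common F_int
# space, and psi reduces their summed ReLU activation to a one-channel sigmoid
# map in (0, 1) that rescales x pixel-wise, so the gated output keeps the shape
# of the skip connection.

import torch

att = Attention_block(F_g=512, F_l=512, F_int=256).eval()  # eval: fixed BN stats
g = torch.randn(1, 512, 8, 8)   # gating features from the decoder
x = torch.randn(1, 512, 8, 8)   # skip features from the encoder
out = att(g, x)
assert out.shape == x.shape     # out = x * psi, psi broadcast across channels
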
self.Att7(g=d7,x=x6) + d7 = torch.cat((x6,d7),dim=1) + + d6 = self.Up6(d7) + x5 = self.Att6(g=d6,x=x5) + d6 = torch.cat((x5,d6),dim=1) + + d5 = self.Up5(d6) + x4 = self.Att5(g=d5,x=x4) + d5 = torch.cat((x4,d5),dim=1) + + + d4 = self.Up4(d5) # x4: [2, 512, 4, 4], d4: [2, 256, 4, 4] + x3 = self.Att4(g=d4,x=x3) + d4 = torch.cat((x3,d4),dim=1) + + d3 = self.Up3(d4) + x2 = self.Att3(g=d3,x=x2) + d3 = torch.cat((x2,d3),dim=1) + + d2 = self.Up2(d3) + x1 = self.Att2(g=d2,x=x1) + d2 = torch.cat((x1,d2),dim=1) + + d1 = self.Up1(d2) + + return d1 + diff --git a/deepliif/models/networks.py b/deepliif/models/networks.py index a5269f1..0e3ad9d 100644 --- a/deepliif/models/networks.py +++ b/deepliif/models/networks.py @@ -6,7 +6,7 @@ import os from torchvision import models - +from .att_unet import AttU_Net ############################################################################### # Helper Functions ############################################################################### @@ -165,6 +165,8 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'unet_512': net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_512_attention': + net = AttU_Net(img_ch=3,output_ch=3) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) return init_net(net, init_type, init_gain, gpu_ids) From aca2060f3ebb9b1badd6c2ae00f34fdf1b88cfd8 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Wed, 15 May 2024 20:48:46 +0000 Subject: [PATCH 12/41] allowed to change net_gs --- deepliif/models/DeepLIIF_model.py | 12 +++++++----- deepliif/options/__init__.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/deepliif/models/DeepLIIF_model.py b/deepliif/models/DeepLIIF_model.py index aeb57f8..e450980 100644 --- a/deepliif/models/DeepLIIF_model.py +++ b/deepliif/models/DeepLIIF_model.py @@ -13,6 +13,8 @@ def __init__(self, opt): opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) + if not hasattr(opt,'net_gs'): + opt.net_gs = 'unet_512' # weights of the modalities in generating segmentation mask self.seg_weights = [0.25, 0.15, 0.25, 0.1, 0.25] @@ -60,15 +62,15 @@ def __init__(self, opt): self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm, + self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, not opt.no_dropout, opt.init_type, 
opt.init_gain, self.gpu_ids)
-        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'unet_512', opt.norm,
+        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm,
                                         not opt.no_dropout, opt.init_type, opt.init_gain,
                                         self.gpu_ids)

diff --git a/deepliif/options/__init__.py b/deepliif/options/__init__.py
index 5e2ea6b..a350def 100644
--- a/deepliif/options/__init__.py
+++ b/deepliif/options/__init__.py
@@ -35,8 +35,8 @@ def __init__(self, d_params=None, path_file=None, mode='train'):

         if mode == 'train':
             self.is_train = True
-            self.netG = 'resnet_9blocks'
-            self.netD = 'n_layers'
+            self.netG = self.net_g #'resnet_9blocks'
+            self.netD = self.net_d #'n_layers'
             self.n_layers_D = 4
             self.lambda_L1 = 100
             self.lambda_feat = 100

From 7a9c30cf81bb496c5920869a0cea883e25c489cf Mon Sep 17 00:00:00 2001
From: Wendy Wang
Date: Wed, 15 May 2024 21:37:24 +0000
Subject: [PATCH 13/41] added argument return_seg_intermediate

---
 deepliif/models/__init__.py | 44 +++++++++++++++++++++++++++++--------
 1 file changed, 35 insertions(+), 9 deletions(-)

diff --git a/deepliif/models/__init__.py b/deepliif/models/__init__.py
index 61845d6..7649690 100644
--- a/deepliif/models/__init__.py
+++ b/deepliif/models/__init__.py
@@ -115,7 +115,7 @@ def load_torchscript_model(model_pt_path, device):

-def load_eager_models(opt, devices):
+def load_eager_models(opt, devices=None):
     # create a model given model and other options
     model = create_model(opt)
     # regular setup: load and print networks; create schedulers
@@ -138,7 +138,8 @@ def load_eager_models(opt, devices):
                 net = net.module

             nets[name] = net
-            nets[name].to(devices[name])
+            if devices:
+                nets[name].to(devices[name])

     return nets
@@ -253,10 +254,19 @@ def forward(input, model):
         lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
     segs = compute(lazy_segs)[0]

-    seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
-    seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
+    #seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
+    #seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
+    weights = {
+        'G51': 0.25,  # IHC
+        'G52': 0.25,  # Hema
+        'G53': 0.25,  # DAPI
+        'G54': 0.00,  # Lap2
+        'G55': 0.25,  # Marker
+    }
+    seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)

     res = {k: tensor_to_pil(v) for k, v in gens.items()}
+    res.update({k: tensor_to_pil(v) for k, v in segs.items()})
     res['G5'] = tensor_to_pil(seg)

     return res
@@ -280,12 +290,19 @@ def forward(input, model):

-def is_empty(tile):
+def is_empty_old(tile):
     # return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
     if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
        return all([True if calculate_background_area(t) > 98 else False for t in tile])
     else:
        return True if calculate_background_area(tile) > 98 else False
+
+
+def is_empty(tile):
+    if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
+        return all([True if np.max(image_variance_rgb(t)) < 15 else False for t in tile])
+    else:
+        return True if np.max(image_variance_rgb(tile)) < 15 else False


 def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
@@ -358,7 +375,7 @@ def get_net_tiles(n):

 def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
-
color_dapi=False, color_marker=False, opt=None, return_seg_intermediate=False): if not opt: opt = get_opt(model_path) #print_options(opt) @@ -373,7 +390,14 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, ea 'DAPI':'G2', 'Lap2':'G3', 'Marker':'G4', - 'Seg':'G5'} + 'Seg':'G5', + } + if return_seg_intermediate: + d_modality2net.update({'IHC_s':'G51', + 'Hema_s':'G52', + 'DAPI_s':'G53', + 'Lap2_s':'G54', + 'Marker_s':'G55',}) for k in d_modality2net.keys(): images[k] = create_image_for_stitching(tile_size, rows, cols) @@ -509,7 +533,8 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='aut def infer_modalities(img, tile_size, model_dir, eager_mode=False, - color_dapi=False, color_marker=False, opt=None): + color_dapi=False, color_marker=False, opt=None, + return_seg_intermediate=False): """ This function is used to infer modalities for the given image using a trained model. :param img: The input image. @@ -539,7 +564,8 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False, eager_mode=eager_mode, color_dapi=color_dapi, color_marker=color_marker, - opt=opt + opt=opt, + return_seg_intermediate=return_seg_intermediate ) if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models From 9aebf568fcdee06276885f9bd2a807cd9d8c98d3 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Wed, 26 Jun 2024 18:04:14 +0000 Subject: [PATCH 14/41] allowed cli.py inference and test.py to return raw segmentation predictions --- deepliif/models/DeepLIIF_model.py | 22 ++++++++++++---------- deepliif/models/__init__.py | 9 +++++++-- deepliif/models/base_model.py | 10 +++++++--- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/deepliif/models/DeepLIIF_model.py b/deepliif/models/DeepLIIF_model.py index e450980..f4af446 100644 --- a/deepliif/models/DeepLIIF_model.py +++ b/deepliif/models/DeepLIIF_model.py @@ -33,7 +33,7 @@ def __init__(self, opt): # specify the training losses you want to print out. The training/test scripts will call for i in range(1, self.opt.modalities_no + 1 + 1): self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)]) - self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)]) + self.visual_names.extend(['fake_B_' + str(i), 'fake_B_5' + str(i), 'real_B_' + str(i)]) # specify the images you want to save/display. The training/test scripts will call # specify the models you want to save to the disk. 
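
# Usage sketch for the new flag (hypothetical input path; assumes a trained
# DeepLIIF model directory): with return_seg_intermediate=True, inference() also
# stitches the five raw per-generator segmentation outputs (G51-G55) and exposes
# them under the *_s keys added to d_modality2net above, next to the final
# weighted 'Seg'.

from PIL import Image
from deepliif.models import infer_modalities

img = Image.open('sample_ihc.png')  # hypothetical 512x512 IHC tile
images, scoring = infer_modalities(img, tile_size=512, model_dir='./checkpoints/deepliif',
                                   return_seg_intermediate=True)
# images is assumed to include 'IHC_s', 'Hema_s', 'DAPI_s', 'Lap2_s', 'Marker_s'
# in addition to the usual inferred modalities and 'Seg'
print(sorted(images.keys()))
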
The training/test scripts will call and @@ -53,24 +53,26 @@ def __init__(self, opt): self.model_names.extend(['G5' + str(i)]) # define networks (both generator and discriminator) - self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + opt.netG = [opt.netG] * 4 + opt.net_gs = [opt.net_gs]*5 + self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding) - self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, + self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, + self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, + self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, + self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) - self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs, opt.norm, + self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) diff --git a/deepliif/models/__init__.py b/deepliif/models/__init__.py index 7649690..1e4af1f 100644 --- a/deepliif/models/__init__.py +++ b/deepliif/models/__init__.py @@ -313,7 +313,12 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None): 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)), 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)), - 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)) + 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), + 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), } else: return run_fn(tile, model_path, eager_mode, opt) @@ -409,7 +414,7 @@ 
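
# image_variance_rgb, used by the variance-based is_empty a few hunks up, is not
# shown anywhere in this series; a plausible reading (hypothetical
# implementation, for illustration only) is the per-channel pixel variance, so
# near-uniform background tiles fall below the threshold of 15, skip the model
# forward pass entirely, and receive the constant placeholder images defined in
# run_wrapper above, keeping every generator key available for stitching.

import numpy as np

def image_variance_rgb(tile):
    arr = np.array(tile.convert('RGB'), dtype=np.float32)
    return arr.reshape(-1, 3).var(axis=0)  # one variance value per R, G, B channel
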
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, ea for modality_name, net_name in d_modality2net.items(): stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j) - + for modality_name, output_img in images.items(): images[modality_name] = output_img.resize(img.size) diff --git a/deepliif/models/base_model.py b/deepliif/models/base_model.py index cad9295..e74ecd7 100644 --- a/deepliif/models/base_model.py +++ b/deepliif/models/base_model.py @@ -134,10 +134,14 @@ def get_current_visuals(self): for name in self.visual_names: if isinstance(name, str): if not hasattr(self, name): - if len(name.split('_')) == 2: - visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1] + if len(name.split('_')) != 2: + if self.opt.model == 'DeepLIIF': + img_name = name[:-1] + '_' + name[-1] + visual_ret[name] = getattr(self, img_name) + else: + visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1] else: - visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1] + visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1] else: visual_ret[name] = getattr(self, name) return visual_ret From 4cde80446e0f8e0168a325eae239cb7ab85ff697 Mon Sep 17 00:00:00 2001 From: Wendy Wang Date: Tue, 16 Jul 2024 20:08:53 +0000 Subject: [PATCH 15/41] allowed to specify modality-wise generator arch; added validation loss calculation --- cli.py | 106 +++++++++++++++++++++++-- deepliif/models/DeepLIIF_model.py | 19 +++-- deepliif/models/base_model.py | 11 +++ deepliif/util/visualizer.py | 127 +++++++++++++++++++++++++----- 4 files changed, 229 insertions(+), 34 deletions(-) diff --git a/cli.py b/cli.py index 604838c..e9e9fc4 100644 --- a/cli.py +++ b/cli.py @@ -8,9 +8,10 @@ import torch import numpy as np from PIL import Image +from torchvision.transforms import ToPILImage from deepliif.data import create_dataset, transform -from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model +from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats from deepliif.util.util import mkdirs, check_multi_scale # from deepliif.util import infer_results_for_wsi @@ -85,7 +86,7 @@ def cli(): help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 ' 'PatchGAN. n_layers allows you to specify the layers in the discriminator') @click.option('--net-g', default='resnet_9blocks', - help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]') + help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention') @click.option('--n-layers-d', default=4, help='only used if netD==n_layers') @click.option('--norm', default='batch', help='instance normalization or batch normalization [instance | batch | none]') @@ -164,7 +165,7 @@ def cli(): @click.option('--net-ds', type=str, default='n_layers', help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') @click.option('--net-gs', type=str, default='unet_512', - help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]') + help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention') @click.option('--gan-mode', type=str, default='vanilla', help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') @click.option('--gan-mode-s', type=str, default='lsgan', @@ -235,17 +236,34 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd dir_data_train = dataroot + '/train' fns = os.listdir(dir_data_train) fns = [x for x in fns if x.endswith('.png')] + print(f'{len(fns)} images found') img = Image.open(f"{dir_data_train}/{fns[0]}") + print(f'image shape:',img.size) num_img = img.size[0] / img.size[1] assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer' num_img = int(num_img) input_no = num_img - modalities_no - seg_no - assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0' + assert input_no > 0, f'inferred number of input images is {input_no} (modalities_no {modalities_no}, seg_no {seg_no}); should be greater than 0' d_params['input_no'] = input_no d_params['scale_size'] = img.size[1] d_params['gpu_ids'] = gpu_ids + + # update generator arch + net_g = net_g.split(',') + assert len(net_g) in [1,modalities_no], f'net_g should contain either 1 architecture for all translation generators or the same number of architectures as the number of translation generators ({modalities_no})' + if len(net_g) == 1: + net_g = net_g*modalities_no + + + net_gs = net_gs.split(',') + assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})' + if len(net_gs) == 1: + net_gs = net_gs*seg_no + + d_params['net_g'] = net_g + d_params['net_gs'] = net_gs # create a dataset given dataset_mode and other options # dataset = AlignedDataset(opt) @@ -253,9 +271,16 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd opt = Options(d_params=d_params) print_options(opt, save=True) + # set dir for train and val dataset = create_dataset(opt) + dataset_val = create_dataset(opt,phase='val') + data_val = [batch for batch in dataset_val] + metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json'))) + # get the number of images in the dataset. 
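
# Worked example of the strip-layout check above: DeepLIIF training samples are
# horizontal strips of square tiles, so a 3072x512 strip holds 3072 / 512 = 6
# tiles; with the defaults modalities_no=4 and seg_no=1 that leaves
# input_no = 6 - 4 - 1 = 1 input stain per sample. A non-integer ratio or
# input_no <= 0 means the strip was assembled with the wrong number of panels
# for the requested model.
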
    click.echo('The number of training images = %d' % len(dataset))
+    click.echo('The number of validation images = %d' % len(dataset_val))
+    click.echo('The number of validation batches = %d' % len(data_val))

     # create a model given model and other options
     model = create_model(opt)
@@ -285,6 +310,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd

         # inner loop within one epoch
         for i, data in enumerate(dataset):
+
             # timer for computation per iteration
             iter_start_time = time.time()
             if total_iters % print_freq == 0:
@@ -301,15 +327,15 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
             if total_iters % display_freq == 0:
                 save_result = total_iters % update_html_freq == 0
                 model.compute_visuals()
-                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+                visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result)

             # print training losses and save logging information to the disk
             if total_iters % print_freq == 0:
-                losses = model.get_current_losses()
+                losses = model.get_current_losses()  # get training losses
                 t_comp = (time.time() - iter_start_time) / batch_size
-                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
+                visualizer.print_current_losses(epoch, epoch_iter, {**losses}, t_comp, t_data)
                 if display_id > 0:
-                    visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), losses)
+                    visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses})

             # cache our latest model every <save_latest_freq> iterations
             if total_iters % save_latest_freq == 0:
@@ -325,6 +351,70 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
             model.save_networks('latest')
             model.save_networks(epoch)

+
+
+        # validation loss and metrics calculation
+        losses = model.get_current_losses()  # get training losses to print
+
+        model.eval()
+        l_losses_val = []
+        l_metrics_val = []
+
+        # for each val image, calculate validation loss and cell count metrics
+        for j, data_val_batch in enumerate(data_val):
+            # batch size is effectively 1 for validation
+            model.set_input(data_val_batch)
+            model.calculate_losses()  # this does not optimize parameters
+            visuals = model.get_current_visuals()  # get image results
+
+            # val losses
+            losses_val_batch = model.get_current_losses()
+            l_losses_val += [(k,v) for k,v in losses_val_batch.items()]
+
+            # calculate cell count metrics
+            l_seg_names = ['fake_B_5']
+            assert l_seg_names[0] in visuals.keys(), f'Cannot find {l_seg_names[0]} in generated image names ({list(visuals.keys())})'
+            seg_mod_suffix = l_seg_names[0].split('_')[-1]
+            l_seg_names += [x for x in visuals.keys() if x.startswith('fake') and x.split('_')[-1].startswith(seg_mod_suffix) and x != l_seg_names[0]]
+            # print(f'Running postprocess for {len(l_seg_names)} generated images ({l_seg_names})')
+
+            img_name_current = data_val_batch['A_paths'][0].split('/')[-1][:-4]  # remove .png
+            metrics_gt = metrics_val[img_name_current]
+
+            for seg_name in l_seg_names:
+                images = {'Seg':ToPILImage()((visuals[seg_name][0].cpu()+1)/2),
+                          'Marker':ToPILImage()((visuals['fake_B_4'][0].cpu()+1)/2)}
+                _, scoring = postprocess(ToPILImage()((data_val_batch['A'][0]+1)/2), images, opt.scale_size, opt.model)
+
+                for k,v in scoring.items():
+                    if k.startswith('num') or k.startswith('percent'):
+                        # to calculate the rmse, here we calculate (x_pred - x_true) ** 2
+                        l_metrics_val.append((k+'_'+seg_name,(v - metrics_gt[k])**2))
+
+        d_losses_val = {k+'_val':0 for k in losses_val_batch.keys()}
+        for k,v in l_losses_val:
+            d_losses_val[k+'_val'] += v
+        for k in d_losses_val:
+            d_losses_val[k] = d_losses_val[k] / len(data_val)
+
+        d_metrics_val = {}
+        for k,v in l_metrics_val:
+            try:
+                d_metrics_val[k] += v
+            except KeyError:
+                d_metrics_val[k] = v
+        for k in d_metrics_val:
+            # to calculate the rmse, this is the second part, where d_metrics_val[k] now represents sum((x_pred - x_true) ** 2)
+            d_metrics_val[k] = np.sqrt(d_metrics_val[k] / len(data_val))
+
+
+        model.train()
+        t_comp = (time.time() - iter_start_time) / batch_size
+        visualizer.print_current_losses(epoch, epoch_iter, {**losses, **d_losses_val, **d_metrics_val}, t_comp, t_data)
+        if display_id > 0:
+            visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses, **d_losses_val, **d_metrics_val})
+

         print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, n_epochs + n_epochs_decay, time.time() - epoch_start_time))
         # update learning rates at the end of every epoch.
diff --git a/deepliif/models/DeepLIIF_model.py b/deepliif/models/DeepLIIF_model.py
index f4af446..f851022 100644
--- a/deepliif/models/DeepLIIF_model.py
+++ b/deepliif/models/DeepLIIF_model.py
@@ -17,7 +17,7 @@ def __init__(self, opt):
             opt.net_gs = 'unet_512'

         # weights of the modalities in generating segmentation mask
-        self.seg_weights = [0.25, 0.15, 0.25, 0.1, 0.25]
+        self.seg_weights = [0.25, 0.25, 0.25, 0.0, 0.25]
         # loss weights in calculating the final loss
         self.loss_G_weights = [0.2, 0.2, 0.2, 0.2, 0.2]

@@ -53,8 +53,11 @@ def __init__(self, opt):
             self.model_names.extend(['G5' + str(i)])

         # define networks (both generator and discriminator)
-        opt.netG = [opt.netG] * 4
-        opt.net_gs = [opt.net_gs]*5
+        if isinstance(opt.netG, str):
+            opt.netG = [opt.netG] * 4
+        if isinstance(opt.net_gs, str):
+            opt.net_gs = [opt.net_gs]*5
+
         self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
         self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
@@ -64,15 +68,16 @@ def __init__(self, opt):
         self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)

+        # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
         self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
+        self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
+        self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
+        self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
+        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
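
# The seg_weights change above zeroes out the Lap2 contribution, mirroring the
# inference-side weights dict from the earlier patch; the final segmentation is
# a convex combination of the five per-modality predictions (weights sum to 1).
# A standalone sketch of that blend:

import torch

seg_weights = [0.25, 0.25, 0.25, 0.0, 0.25]  # IHC, Hema, DAPI, Lap2, Marker
preds = [torch.rand(1, 3, 512, 512) for _ in range(5)]  # per-modality seg outputs
seg = torch.stack([w * p for w, p in zip(seg_weights, preds)]).sum(dim=0)
assert seg.shape == preds[0].shape
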

diff --git a/deepliif/models/base_model.py b/deepliif/models/base_model.py
index e74ecd7..20fa424 100644
--- a/deepliif/models/base_model.py
+++ b/deepliif/models/base_model.py
@@ -88,6 +88,16 @@ def setup(self, opt):
             self.load_networks(load_suffix)
         self.print_networks(opt.verbose)

+    def train(self):
+        """Set models to train mode"""
+        for name in self.model_names:
+            if isinstance(name, str):
+                if '_' in name:
+                    net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
+                else:
+                    net = getattr(self, 'net' + name)
+                net.train()
+
     def eval(self):
         """Make models eval mode during test time"""
         for name in self.model_names:
@@ -244,6 +254,7 @@ def load_networks(self, epoch):
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
+           if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
diff --git a/deepliif/util/visualizer.py b/deepliif/util/visualizer.py
index 0856b51..f1c327c 100644
--- a/deepliif/util/visualizer.py
+++ b/deepliif/util/visualizer.py
@@ -165,6 +165,6 @@ def display_current_results(self, visuals, epoch, save_result, **kwargs):
         if ncols > 0:  # show all the images in one visdom panel
             ncols = min(ncols, len(visuals))
             h, w = next(iter(visuals.values())).shape[:2]
             table_css = """