diff --git a/src/data/data_load.py b/src/data/data_load.py index be7a29e..2265593 100644 --- a/src/data/data_load.py +++ b/src/data/data_load.py @@ -38,13 +38,13 @@ def color_transform(images, brightness=0.1, contrast=0.05, saturation=0.1, hue=0 return train_transform_augment(images) -def data_loader_from_config(data_config, using_gpu=False): +def data_loader_from_config(data_config, image_dtype=torch.float32, using_gpu=False): data_dir = data_config['train_dir'] os_helper.is_valid_dir(data_dir, 'Invalid training data directory\nPath is an invalid directory: ' + data_dir) image_height, image_width = get_image_height_and_width(data_config) batch_size = int(data_config['batch_size']) n_workers = int(data_config['workers']) - return create_data_loader(data_dir, image_height, image_width, using_gpu=using_gpu, + return create_data_loader(data_dir, image_height, image_width, image_dtype=image_dtype, using_gpu=using_gpu, batch_size=batch_size, n_workers=n_workers) diff --git a/src/train_gan.py b/src/train_gan.py index ec5164e..3925fff 100644 --- a/src/train_gan.py +++ b/src/train_gan.py @@ -83,9 +83,9 @@ def train(config_file_path: str): data_config['image_height'] = str(int(data_config['base_height']) * (2 ** int(data_config['upsample_layers']))) data_config['image_width'] = str(int(data_config['base_width']) * (2 ** int(data_config['upsample_layers']))) - data_loader = data_loader_from_config(data_config, using_gpu=not running_on_cpu) + data_loader = data_loader_from_config(data_config, image_dtype=torch_dtype, using_gpu=not running_on_cpu) # Eval data loader is done to keep the same label distribution when evaluating - eval_data_loader = data_loader_from_config(data_config, using_gpu=not running_on_cpu) + eval_data_loader = data_loader_from_config(data_config, image_dtype=torch_dtype, using_gpu=not running_on_cpu) logging.info('Data size is ' + str(len(data_loader.dataset)) + ' images') data_loader, eval_data_loader = accelerator.prepare(data_loader, eval_data_loader)