Skip to content

Commit

Permalink
Fixed dataloader dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
gcervantes8 committed Dec 20, 2023
1 parent 7831e70 commit 6a23658
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/data/data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def color_transform(images, brightness=0.1, contrast=0.05, saturation=0.1, hue=0
return train_transform_augment(images)


def data_loader_from_config(data_config, using_gpu=False):
def data_loader_from_config(data_config, image_dtype=torch.float32, using_gpu=False):
data_dir = data_config['train_dir']
os_helper.is_valid_dir(data_dir, 'Invalid training data directory\nPath is an invalid directory: ' + data_dir)
image_height, image_width = get_image_height_and_width(data_config)
batch_size = int(data_config['batch_size'])
n_workers = int(data_config['workers'])
return create_data_loader(data_dir, image_height, image_width, using_gpu=using_gpu,
return create_data_loader(data_dir, image_height, image_width, image_dtype=image_dtype, using_gpu=using_gpu,
batch_size=batch_size, n_workers=n_workers)


Expand Down
4 changes: 2 additions & 2 deletions src/train_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def train(config_file_path: str):
data_config['image_height'] = str(int(data_config['base_height']) * (2 ** int(data_config['upsample_layers'])))
data_config['image_width'] = str(int(data_config['base_width']) * (2 ** int(data_config['upsample_layers'])))

data_loader = data_loader_from_config(data_config, using_gpu=not running_on_cpu)
data_loader = data_loader_from_config(data_config, image_dtype=torch_dtype, using_gpu=not running_on_cpu)
# Eval data loader is done to keep the same label distribution when evaluating
eval_data_loader = data_loader_from_config(data_config, using_gpu=not running_on_cpu)
eval_data_loader = data_loader_from_config(data_config, image_dtype=torch_dtype, using_gpu=not running_on_cpu)
logging.info('Data size is ' + str(len(data_loader.dataset)) + ' images')

data_loader, eval_data_loader = accelerator.prepare(data_loader, eval_data_loader)
Expand Down

0 comments on commit 6a23658

Please sign in to comment.