Skip to content

Commit 7831e70

Browse files
authored
Merge pull request #4 from gcervantes8/hf-accelerate
Added Accelerate support from HuggingFace - MultiGPU & distributed support!
2 parents 5756b6e + 68325fd commit 7831e70

19 files changed

+464
-381
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
19+
python-version: ["3.8", "3.9", "3.10", "3.11"]
2020

2121
steps:
2222
- uses: actions/checkout@v3

README.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
![Fast Image Gans with a picture of a fig to the left of it](logo/FigsName.png)
2-
[![Python](https://img.shields.io/badge/Python-3.7--3.11-blue)](https://www.python.org/downloads/) [![License](https://img.shields.io/badge/License-GPL--3.0-yellow)](https://github.com/gcervantes8/Game-Image-Generator/blob/master/LICENSE) [![Python package](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml/badge.svg)](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml)
2+
3+
[![Python](https://img.shields.io/badge/Python-3.8--3.11-blue)](https://www.python.org/downloads/) [![License](https://img.shields.io/badge/License-GPL--3.0-yellow)](https://github.com/gcervantes8/Game-Image-Generator/blob/master/LICENSE) [![Python package](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml/badge.svg)](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml)
34

45

56
With this project, you can train Generative Adversarial Networks (GANs). While you can train with any type of image,
67
this repository focuses on generating images from games.
78

89
## Features
910

10-
- PyTorch 2.0 Compile
11-
- Mixed Precision training
11+
- PyTorch 2 Compile
12+
- Mixed Precision training (fp16 or bf16)
13+
- Gradient Accumulation
1214
- Inception Score and FID evaluation
15+
- HF🤗 Accelerate - Adds Multi-GPU, and TPU support
1316
- Easy to start training
1417
- Testing
1518

@@ -26,8 +29,8 @@ Provided in the code is a sample of the coil-100 dataset, which is used for test
2629

2730
## Requirements
2831
The following are the Python packages needed.
29-
- [Pytorch](https://pytorch.org/get-started/locally/), 1.9+
30-
- [torchvision](https://pypi.org/project/torchvision/) 0.9+
32+
- [Pytorch](https://pytorch.org/get-started/locally/), 2.0+
33+
- [torchvision](https://pypi.org/project/torchvision/) 1.5+
3134
- [SciPy](https://scipy.org/install/) 1.7+
3235
- [TorchMetrics](https://torchmetrics.readthedocs.io/en/stable/)
3336
- [torchinfo](https://github.com/TylerYep/torchinfo)

requirements.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
torch >= 1.9
2-
torchvision >= 0.9
1+
torch >= 2.0
2+
torchvision >= 0.15
33
torchinfo
44
torch-ema
55
Pillow
66
torchmetrics
77
scipy >= 1.7
88
tqdm
99
tensorboard
10-
six
10+
six
11+
bitsandbytes
12+
accelerate

src/data_load.py renamed to src/data/data_load.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import torch
1111
import PIL
1212
import torchvision.datasets as torch_data_set
13-
import torchvision.transforms as transforms
14-
from src import os_helper
13+
import torchvision.transforms.v2 as transforms
14+
from src.utils import os_helper
1515

1616

1717
def normalize(images, norm_mean=torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32),
@@ -38,13 +38,13 @@ def color_transform(images, brightness=0.1, contrast=0.05, saturation=0.1, hue=0
3838
return train_transform_augment(images)
3939

4040

41-
def data_loader_from_config(data_config, data_dtype=torch.float32, using_gpu=False):
41+
def data_loader_from_config(data_config, using_gpu=False):
4242
data_dir = data_config['train_dir']
4343
os_helper.is_valid_dir(data_dir, 'Invalid training data directory\nPath is an invalid directory: ' + data_dir)
4444
image_height, image_width = get_image_height_and_width(data_config)
4545
batch_size = int(data_config['batch_size'])
4646
n_workers = int(data_config['workers'])
47-
return create_data_loader(data_dir, image_height, image_width, dtype=data_dtype, using_gpu=using_gpu,
47+
return create_data_loader(data_dir, image_height, image_width, using_gpu=using_gpu,
4848
batch_size=batch_size, n_workers=n_workers)
4949

5050

@@ -69,28 +69,34 @@ def get_num_classes(data_config):
6969
data_loader = data_loader_from_config(data_config)
7070
return len(data_loader.dataset.classes)
7171

72+
def to_int16(label):
73+
return torch.tensor(label, dtype=torch.int16)
7274

73-
def create_data_loader(data_dir: str, image_height: int, image_width: int, dtype=torch.float32, using_gpu=False,
75+
def create_data_loader(data_dir: str, image_height: int, image_width: int, image_dtype=torch.float16, using_gpu=False,
7476
batch_size=1, n_workers=1):
7577

7678
data_transform = transforms.Compose([transforms.Resize((image_height, image_width)),
77-
transforms.ToTensor(),
78-
transforms.ConvertImageDtype(dtype)
79+
transforms.ToImage(),
80+
transforms.ToDtype(image_dtype, scale=True), # Float16 is tiny bit faster, and bit more VRAM. Strange.
81+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
7982
])
83+
label_transform = to_int16
8084
try:
81-
data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform)
85+
data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform, target_transform=label_transform)
8286
except FileNotFoundError:
8387
raise FileNotFoundError('Data directory provided should contain directories that have images in them, '
8488
'directory provided: ' + data_dir)
8589

8690
# Create the data-loader
87-
torch_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
88-
shuffle=True, num_workers=n_workers, pin_memory=using_gpu)
91+
torch_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True,
92+
num_workers=n_workers, pin_memory=using_gpu, drop_last=True)
8993
return torch_loader
9094

9195

9296
# Returns images of size: (batch_size, num_channels, height, width)
93-
def get_data_batch(data_loader, device):
97+
def get_data_batch(data_loader, device, unnormalize_batch=False):
98+
if unnormalize_batch:
99+
return unnormalize(next(iter(data_loader))[0]).to(device)
94100
return next(iter(data_loader))[0].to(device)
95101

96102

src/gan_inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
from torchvision import transforms
1313
import math
1414
import PIL
15-
from src import saver_and_loader, os_helper
15+
from utils import saver_and_loader
1616
from src.configs import ini_parser
1717
# from src.metrics import score_metrics
18-
from src.data_load import unnormalize, get_num_classes, create_latent_vector
19-
from src import create_model
18+
from data.data_load import unnormalize, get_num_classes, create_latent_vector
19+
from models import create_model
2020

2121
import os
2222
import logging
2323
import argparse
2424

25+
from utils import os_helper
26+
2527

2628
def generate_batch_image(ini_config, gan_model, num_images: int):
2729
model_arch_config, data_config = ini_config['MODEL ARCHITECTURE'], ini_config['DATA']

src/gan_model.py

Lines changed: 0 additions & 222 deletions
This file was deleted.

src/losses/loss_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def supported_losses():
2727
# Given a loss function returns a 3-tuple
2828
# The loss function, the fake label, and the real label
2929
# Returns 3-tuple of None if the loss function is not supported
30-
def supported_loss_functions(loss_name: str, device=None):
30+
def supported_loss_functions(loss_name: str):
3131
loss_functions = _losses()
3232
if loss_name in loss_functions:
3333
loss_fn, fake_label, real_label = loss_functions[loss_name.lower()]

0 commit comments

Comments
 (0)