Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sva slova #4

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
12 changes: 5 additions & 7 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import pandas as pd
from torchvision import transforms

df = pd.read_csv('english.csv')
df.head()

class Data(Dataset):
def __init__(self, path_to_file):
def __init__(self, path_to_file, image_dim):
self.data = pd.read_csv(path_to_file)
self.transform = transforms.Compose([
#transforms.to_grayscale(),
transforms.Resize((image_dim, image_dim)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
transforms.Normalize(0.5, 0.5),
])

self.images = self.data.iloc[:, 0]
Expand All @@ -26,8 +25,7 @@ def __getitem__(self, idx):
label = self.captions[idx]

image = Image.open(image_path).convert('L')
if self.transform is not None:
image = self.transform(image)
image = self.transform(image)

# pretvoriti u torch tensor
return image, label
108 changes: 108 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from matplotlib.pylab import f
import numpy as np
import pandas as pd
import matplotlib as plib
import torch
from torch import nn

class Generator(nn.Module):
def __init__(self, in_channels, out_dim, device):
self.device = device
self.kernel_size = 5
self.in_channels = in_channels
super().__init__()
ngf = 32
self.embedding = nn.Embedding(62, in_channels * 50)
#self.kernel = nn.Parametar(torch.randn(broj_slova, z_dim))
self.gen = nn.Sequential(

nn.ConvTranspose2d(in_channels * 50, ngf * 8, self.kernel_size, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),

nn.ConvTranspose2d(ngf * 8, ngf * 4, self.kernel_size, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),

nn.ConvTranspose2d( ngf * 4, ngf * 2, self.kernel_size, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),

nn.ConvTranspose2d( ngf * 2, ngf, self.kernel_size, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),

nn.ConvTranspose2d( ngf, out_dim, self.kernel_size, 2, 1, bias=False),
nn.Tanh()
)

def forward(self, noise, label):

emb_label = self.embedding(torch.tensor(label))
emb_label = emb_label.view(-1, self.in_channels * 50, 1, 1)

return self.gen(emb_label + noise)

class Discriminator(nn.Module):
def __init__(self, out_channels, device):
self.device = device
self.kernel_size = 5
super().__init__()
ndf = 32
self.disc = nn.Sequential(

nn.Conv2d(1, ndf, self.kernel_size, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf, ndf * 2, self.kernel_size, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf * 2, ndf * 4, self.kernel_size, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf * 4, ndf * 8, self.kernel_size, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf * 8, out_channels, self.kernel_size, 1, 0, bias=False),
nn.Sigmoid()

)

def forward(self, img):
return self.disc(img)


class GAN(nn.Module):
def __init__(self, in_out_channels, out_dim, glr, dlr, device):
super().__init__()
self.gen = Generator(in_out_channels, out_dim, device).to(device)
self.disc = Discriminator(in_out_channels, device).to(device)
self.gen_learn_rate = glr
self.disc_learn_rate = dlr
self.device = device

self.gen_opt = torch.optim.Adam(self.gen.parameters(), betas=(0.5, 0.999), lr=self.gen_learn_rate)
self.disc_opt = torch.optim.Adam(self.disc.parameters(), betas=(0.5, 0.999), lr=self.disc_learn_rate)
self.criterion = nn.BCELoss()

def scale(self, tensor, homothety_coeff, translation_coeff):
return tensor * homothety_coeff + translation_coeff

def compress(self, labels):
clabels = []
for label in labels:
if(ord(label) >= 97 and ord(label) <= 122):
label = ord(label) - 97
clabels.append(label)
elif(ord(label) >= 65 and ord(label) <= 90):
label = ord(label) - 39
clabels.append(label)
else:
label = ord(label) + 4
clabels.append(label)

return clabels

46 changes: 0 additions & 46 deletions network.py

This file was deleted.

107 changes: 74 additions & 33 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,93 @@
from calendar import c
from locale import currency
from PIL import Image
import os
from network import *
from model import *
from dataset import *
from torch.utils.tensorboard import SummaryWriter

#print(torch.cuda.is_available())
device = torch.device('cpu')

device = 'cpu'
learn_rate = 3e-4
z_dim = 32
img_dim = 28*28
batch_size = 32
num_epochs = 1

df = pd.read_csv('english.csv')
train_dataset = Data('english.csv')
train_dataloader = DataLoader(train_dataset, batch_size, shuffle = True)
gen_learn_rate = 3e-4
disc_learn_rate = 3e-4
z_dim = 1
z_depth = 50
img_dim = 64
input_channels = 1
output_channels = 1
batch_size = 32
num_epochs = 200
epoch_offset = 0

print(len(df))
train_dataset = Data('english.csv', img_dim)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
gan = GAN(input_channels, output_channels, gen_learn_rate, disc_learn_rate, device).to(device)
writer = SummaryWriter()

gen = Generator(img_dim, z_dim).to(device)
disc = Discriminator(z_dim).to(device)
z = torch.randn(batch_size, z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=learn_rate)
disc_opt = torch.optim.Adam(disc.parameters(), lr=learn_rate)
criterion = nn.BCELoss()
#gan.load_state_dict(torch.load('models/model_epoch_499.pt'))

for epoch in range(num_epochs):
batch_idx = 0
trainer = 0

for image, label in train_dataloader:
print(image.shape, label)

trainer += 1
image = image.to(device)

#print(label)
label = gan.compress(label)
#print(label)
label = torch.tensor(label).to(device)

curr_batch_size = len(image)
zeros = torch.zeros(curr_batch_size, dtype=torch.int32).to(device)
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
disc_opt.zero_grad()
fake = gen(torch.randn(batch_size, z_dim).to(device))
disc_real = disc(image).view(1, -1)
lossD_real = criterion(disc_real, torch.ones_like(disc_real))
disc_fake = disc(fake.detach()).view(1, -1)
lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
gan.disc_opt.zero_grad()
gan.gen_opt.zero_grad()
noise = torch.randn(curr_batch_size, input_channels * z_depth, z_dim, z_dim)
noise = gan.scale(noise, 1, 0.5).to(device)

fake = gan.gen(noise, label).to(device)

disc_real = gan.disc(image.view(curr_batch_size, 1, img_dim, img_dim))
disc_fake = gan.disc(fake)

lossD_real = gan.criterion(disc_real, torch.ones_like(disc_real))
lossD_fake = gan.criterion(disc_fake, torch.zeros_like(disc_fake))

### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))

lossD = (lossD_real + lossD_fake) / 2
lossD.backward(retain_graph=True)
disc_opt.step()
gan.disc_opt.step()

### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
gen_opt.zero_grad()
output = disc(fake).view(-1)
lossG = criterion(output, torch.ones_like(output))
fake = gan.gen(noise, label)
#print(fake.shape)
disc_fake = gan.disc(fake)

lossG = gan.criterion(disc_fake, torch.ones_like(disc_fake))
lossG.backward()
gen_opt.step()

gan.gen_opt.step()

writer.add_scalar('Loss/Generator', lossG.item(), epoch)
writer.add_scalar('Loss/Discriminator', lossD.item(), epoch)

if batch_idx % 100 == 0:
#print(fake.shape)
for i in range(curr_batch_size):
img = fake[i]
img = img.view(output_channels, img_dim, img_dim)
writer.add_image(f'label: {label[i]}', img, global_step=epoch+epoch_offset)

batch_idx += 1
if batch_idx % 10 == 0:
print(
f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(df)} \
f"Epoch [{epoch+1+epoch_offset}/{num_epochs+epoch_offset}] Batch {batch_idx * curr_batch_size}/{len(train_dataset)} \
Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
)

torch.save(gan.state_dict(), f'models/model_epoch_{epoch+epoch_offset}.pt')

print(f"Epoch [{epoch+1+epoch_offset}] completed. \t\t\t\t Loss D: {lossD:.4f}, loss G: {lossG:.4f}")
Loading