train.py
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from tqdm import trange

from model import Encoder, Decoder
# Training hyperparameters
EPOCH = 50
input_nc = 1   # number of input channels (MNIST is grayscale)
output_nc = 1  # number of output channels
batch_size = 64
# MNIST preprocessing: augment with a random horizontal flip, then
# convert to a tensor and normalize to roughly [-1, 1]
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the encoder/decoder and load the pretrained weights
enc = Encoder(input_nc, output_nc).to(device)
dec = Decoder(input_nc, output_nc).to(device)
enc.load_state_dict(torch.load("pretrained/enc.pth", map_location=device))
dec.load_state_dict(torch.load("pretrained/dec.pth", map_location=device))

# Only the decoder is optimized; the encoder stays frozen
dec_optimize = optim.Adam(dec.parameters())
loss_functionD = nn.MSELoss()
losses = []  # per-batch training loss history
def train():
    pbar = trange(EPOCH)
    for epoch in pbar:
        for images, _ in dataset_loader:
            x = images.to(device)

            # Encoder forward pass. The encoder is not trained in this script,
            # so its output is computed without tracking gradients.
            with torch.no_grad():
                E2E_output = enc(x)

            # Train the decoder to reconstruct the input from the encoding
            D2D_output = dec(E2E_output)
            loss_D = loss_functionD(D2D_output, x)
            dec_optimize.zero_grad()
            loss_D.backward()
            dec_optimize.step()

            losses.append(loss_D.item())
            pbar.set_description(f'epoch [{epoch + 1}/{EPOCH}], loss:{loss_D.item():.4f}')

    # Save the updated weights back to the checkpoint files
    torch.save(dec.state_dict(), 'pretrained/dec.pth')
    torch.save(enc.state_dict(), 'pretrained/enc.pth')

    # Save a simple loss curve (output path chosen arbitrarily)
    plt.plot(losses)
    plt.xlabel('iteration')
    plt.ylabel('MSE loss')
    plt.savefig('loss_curve.png')


if __name__ == '__main__':
    train()
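
train.py imports an Encoder and a Decoder from model.py, each constructed as Encoder(input_nc, output_nc) / Decoder(input_nc, output_nc). That file is not shown here, so the sketch below is only a minimal, assumed version of what those classes could look like: a small convolutional autoencoder for 1-channel 28x28 MNIST inputs, with a Tanh output to match the [-1, 1] normalization used above. The real model.py may define them differently.

# model.py (illustrative sketch only; the actual classes in this repo may differ)
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_nc, output_nc):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_nc, 16, kernel_size=3, stride=2, padding=1),  # 28 -> 14
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),        # 14 -> 7
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    def __init__(self, input_nc, output_nc):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),         # 7 -> 14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, output_nc, kernel_size=4, stride=2, padding=1),  # 14 -> 28
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized inputs
        )

    def forward(self, x):
        return self.net(x)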