-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
117 lines (94 loc) · 3.82 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
class DiffusionModule(nn.Module):
    """Trainable diffusion model: a noise-prediction network plus a variance
    scheduler, providing the epsilon-matching training loss and ancestral
    sampling with optional classifier-free guidance (CFG).
    """

    def __init__(self, network, var_scheduler, **kwargs):
        """
        Args:
            network: noise-prediction model, called as
                ``network(x_t, timestep, class_label)``; must expose
                ``image_resolution`` and torch parameters.
            var_scheduler: scheduler providing ``uniform_sample_t``,
                ``add_noise``, ``timesteps`` and ``step``.
        """
        super().__init__()
        self.network = network
        self.var_scheduler = var_scheduler

    def get_loss(self, x0, class_label=None, noise=None):
        """Compute the noise-matching (epsilon-prediction) MSE loss.

        Args:
            x0: clean images, shape (B, C, H, W).
            class_label: optional conditioning labels forwarded to the network.
            noise: optional pre-drawn noise; if ``None`` the scheduler draws it.

        Returns:
            Scalar MSE loss between predicted and true noise.
        """
        ######## TODO ########
        # DO NOT change the code outside this part.
        # compute noise matching loss.
        B = x0.shape[0]
        timestep = self.var_scheduler.uniform_sample_t(B, self.device)
        # Forward-diffuse x0 to x_t; `eps` is the exact noise the network must
        # predict (renamed from the original typo `esp`).
        xt, eps = self.var_scheduler.add_noise(x0, timestep, noise)
        eps_theta = self.network(xt, timestep, class_label)
        loss = F.mse_loss(eps_theta, eps)
        ######################
        return loss

    @property
    def device(self):
        # Device of the network's parameters; assumes the module is not sharded.
        return next(self.network.parameters()).device

    @property
    def image_resolution(self):
        return self.network.image_resolution

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        return_traj=False,
        class_label: Optional[torch.Tensor] = None,
        guidance_scale: Optional[float] = 0.0,
    ):
        """Draw samples by iterating the scheduler from pure noise.

        Args:
            batch_size: number of images to generate.
            return_traj: if True, return the whole denoising trajectory
                (all but the last element moved to CPU); otherwise only the
                final sample.
            class_label: optional (batch_size,) conditioning labels.
            guidance_scale: CFG weight w; ``None`` or ``0.0`` disables guidance.

        Returns:
            Tensor of shape (batch_size, 3, H, H), or the full trajectory list.
        """
        x_T = torch.randn([batch_size, 3, self.image_resolution, self.image_resolution]).to(self.device)

        # Treat None like 0.0 (no guidance). The original compared
        # `guidance_scale > 0.0` directly, which raises TypeError for None
        # despite the Optional[float] annotation.
        do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 0.0

        if class_label is not None:
            # Move labels to the model device for BOTH the guided and the plain
            # conditional path. The original only did this inside the CFG
            # branch, so plain conditional sampling could hit a device mismatch.
            class_label = class_label.to(self.device)

        if do_classifier_free_guidance:
            ######## TODO ########
            # Assignment 2-3. Implement the classifier-free guidance.
            # Given (batch_size,) class labels, build a (2*batch_size,) tensor
            # whose first half is zeros (the null / unconditional label).
            assert class_label is not None
            assert len(class_label) == batch_size, f"len(class_label) != batch_size. {len(class_label)} != {batch_size}"
            class_label = torch.cat([torch.zeros_like(class_label), class_label])
            #######################

        traj = [x_T]
        for t in tqdm(self.var_scheduler.timesteps):
            x_t = traj[-1]
            if do_classifier_free_guidance:
                ######## TODO ########
                # Assignment 2. Implement the classifier-free guidance.
                # Unconditional prediction (null label) and conditional
                # prediction for the same x_t.
                null_c_noise_pred = self.network(
                    x_t,
                    timestep=t.to(self.device),
                    class_label=class_label[:batch_size],
                )
                c_noise_pred = self.network(
                    x_t,
                    timestep=t.to(self.device),
                    class_label=class_label[batch_size:],
                )
                # eps = (1 + w) * eps_cond - w * eps_uncond  (Ho & Salimans, CFG).
                noise_pred = (1 + guidance_scale) * c_noise_pred - guidance_scale * null_c_noise_pred
                #######################
            else:
                noise_pred = self.network(
                    x_t,
                    timestep=t.to(self.device),
                    class_label=class_label,
                )

            x_t_prev = self.var_scheduler.step(x_t, t, noise_pred)

            # Offload the finished step to CPU so long trajectories don't
            # accumulate on the GPU; the newest state stays on-device.
            traj[-1] = traj[-1].cpu()
            traj.append(x_t_prev.detach())

        if return_traj:
            return traj
        else:
            return traj[-1]

    def save(self, file_path):
        """Serialize the model: hparams (network & scheduler objects) + weights."""
        hparams = {
            "network": self.network,
            "var_scheduler": self.var_scheduler,
        }
        state_dict = self.state_dict()
        dic = {"hparams": hparams, "state_dict": state_dict}
        torch.save(dic, file_path)

    def load(self, file_path):
        """Restore a checkpoint written by :meth:`save`.

        NOTE(security): this checkpoint format stores full module objects, so
        torch.load must unpickle arbitrary code — only load trusted files.
        (On torch >= 2.6 the weights_only default changed; loading these
        checkpoints may require an explicit ``weights_only=False`` — verify
        against the deployed torch version.)
        """
        dic = torch.load(file_path, map_location="cpu")
        hparams = dic["hparams"]
        state_dict = dic["state_dict"]
        self.network = hparams["network"]
        self.var_scheduler = hparams["var_scheduler"]
        self.load_state_dict(state_dict)