gen_envs.py
import logging
import numpy as np
from pathlib import Path
import torch
from torch import nn
from archs.autoencoder import Autoencoder


class AutoencoderEnv(nn.Module):
    """Wraps an Autoencoder as a trainable environment: forward() caches and
    returns a detached reconstruction, and step() applies an externally
    supplied gradient through that cached output."""

    def __init__(self, ae_block_type, ae_conv_type, max_layer, ckpt_path: Path = None):
        super().__init__()
        self.autoencoder = Autoencoder(ae_block_type, ae_conv_type, max_layer)
        if ckpt_path is not None:
            logging.debug(f'Loading autoencoder checkpoint from {ckpt_path}...')
            state_dict = torch.load(ckpt_path, map_location='cpu')
            self.autoencoder.load_state_dict(state_dict)
        self.opt = None
        self.gen = None

    def init_opt(self, opt_ckpt_path: Path = None):
        # Build the Adam optimizer, optionally restoring its state from a checkpoint.
        self.opt = torch.optim.Adam(self.autoencoder.parameters())
        if opt_ckpt_path is not None:
            logging.debug(f'Loading autoencoder optimizer checkpoint from {opt_ckpt_path}...')
            opt_state_dict = torch.load(opt_ckpt_path, map_location='cpu')
            self.opt.load_state_dict(opt_state_dict)

    def reset(self):
        # Re-initialize every convolution and batch-norm layer in place.
        def reset_parameters(m):
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                m.reset_parameters()

        self.autoencoder.apply(reset_parameters)

    def step(self, grad, max_grad_norm=None):
        # Backpropagate an externally computed gradient through the cached output,
        # optionally clip the gradient norm, and take one optimizer step.
        assert self.gen is not None
        self.opt.zero_grad()
        self.gen.backward(grad)
        if max_grad_norm is not None:
            grad_norm = nn.utils.clip_grad_norm_(self.autoencoder.parameters(), max_grad_norm)
        else:
            grad_norm = None
        self.opt.step()
        self.gen = None
        return grad_norm

    def forward(self, gt):
        # Cache the (non-detached) reconstruction for a later step() call and
        # return a detached copy to the caller.
        self.gen = self.autoencoder(gt)
        return self.gen.detach()

    def state_dicts(self, *args, **kwargs):
        return self.autoencoder.state_dict(*args, **kwargs), self.opt.state_dict()


class MixupEnv(nn.Module):
    """Mixup environment: blends each image with a randomly chosen partner from
    the same batch using Beta-distributed mixing coefficients."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def forward(self, gt):
        # One mixing coefficient per image, drawn from Beta(beta, beta).
        mixup_coefs = torch.tensor(
            np.random.beta(self.beta, self.beta, size=(gt.shape[0], 1, 1, 1)),
            dtype=torch.float32,
            device=gt.device,
        )
        # Blend each image with a randomly permuted partner from the batch.
        p = np.random.permutation(gt.shape[0])
        gen = mixup_coefs * gt + (1 - mixup_coefs) * gt[p, :, :, :]
        return gen


class MixupZeroEnv(nn.Module):
    def forward(self, gt):
        # Identity environment: return the ground-truth batch unchanged.
        return gt
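

# --- Usage sketch ---
# A minimal illustration of how these environment classes might be driven from a
# training loop. The block/conv type strings, layer count, batch shape, and the
# L2 gradient below are placeholder assumptions, not values taken from this
# repository.
if __name__ == '__main__':
    env = AutoencoderEnv(ae_block_type='residual', ae_conv_type='conv', max_layer=4)
    env.init_opt()

    gt = torch.rand(8, 3, 64, 64)  # hypothetical ground-truth batch
    gen = env(gt)                  # detached reconstruction from forward()

    # Gradient of a simple mean-squared reconstruction loss w.r.t. the generated
    # images, fed back through step() to update the autoencoder.
    grad = 2.0 * (gen - gt) / gen.numel()
    grad_norm = env.step(grad, max_grad_norm=1.0)
    print('grad norm:', grad_norm)

    # Mixup environments operate on ground-truth batches directly.
    mixed = MixupEnv(beta=0.4)(gt)
    passthrough = MixupZeroEnv()(gt)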