tent.py
# %%
from copy import deepcopy
import torch
import torch.jit
from torch.nn import BatchNorm2d, SyncBatchNorm, Module


class Tent(Module):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    """

    def __init__(self, model, optimizer, steps=1, episodic=False):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic
        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        # self.model_state, self.optimizer_state = \
        #     copy_model_and_optimizer(self.model, self.optimizer)
        # keep the saved states as None so reset() raises its intended error
        # instead of an AttributeError when no copy was made
        self.model_state, self.optimizer_state = None, None

    def forward(self, x):
        if self.episodic:
            self.reset()
        for _ in range(self.steps):
            outputs = forward_and_adapt(x, self.model, self.optimizer)
        return outputs

    def reset(self):
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)

    @staticmethod
    def collect_params(model):
        """Collect the affine scale + shift parameters from batch norms.

        Walk the model's modules and collect all batch normalization parameters.
        Return the parameters.

        Note: other choices of parameterization are possible!
        """
        params = []
        for _, m in model.named_modules():
            if isinstance(m, (BatchNorm2d, SyncBatchNorm)):
                for np, p in m.named_parameters():
                    if np in ['weight', 'bias']:  # weight is scale, bias is shift
                        params.append(p)
        return params

    def __deepcopy__(self, memo):
        # temporarily remove this hook so deepcopy falls back to its default
        # behavior, then rebuild an optimizer over the copy's norm parameters
        deepcopy_method = self.__deepcopy__
        self.__deepcopy__ = None
        cp = deepcopy(self, memo)
        params = self.collect_params(cp.model)
        cp.optimizer = type(self.optimizer)(params, lr=self.optimizer.defaults['lr'])
        cp.optimizer.load_state_dict(self.optimizer.state_dict())
        self.__deepcopy__ = deepcopy_method
        cp.__deepcopy__ = deepcopy_method
        return cp


@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax distribution from logits."""
    return -(x.softmax(1) * x.log_softmax(1)).sum(1)
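
# illustrative note (not in the original file): for uniform logits the softmax
# is flat and the entropy is at its maximum, log(num_classes); e.g.
# softmax_entropy(torch.zeros(2, 4)) returns roughly tensor([1.3863, 1.3863])
# since log(4) ~= 1.3863, while strongly peaked logits give entropy near zero.
# Tent minimizes this per-batch quantity during adaptation.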


@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt(x, model, optimizer):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.
    """
    # forward
    outputs = model(x)
    # adapt
    loss = softmax_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs


def copy_model_and_optimizer(model, optimizer):
    """Copy the model and optimizer states for resetting after adaptation."""
    model_state = deepcopy(model.state_dict())
    optimizer_state = deepcopy(optimizer.state_dict())
    return model_state, optimizer_state


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore the model and optimizer states from copies."""
    model.load_state_dict(model_state, strict=True)
    optimizer.load_state_dict(optimizer_state)


def configure_model(model):
    """Configure model for use with tent."""
    # train mode, because tent optimizes the model to minimize entropy
    model.train()
    # disable grad, to (re-)enable only what tent updates
    model.requires_grad_(False)
    # configure norm for tent updates: enable grad + force batch statistics
    for m in model.modules():
        if isinstance(m, (BatchNorm2d, SyncBatchNorm)):
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
    return model


def check_model(model):
    """Check model for compatibility with tent."""
    is_training = model.training
    assert is_training, "tent needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "tent needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "tent should not update all params: " \
                               "check which require grad"
    has_bn = any([isinstance(m, (BatchNorm2d, SyncBatchNorm))
                  for m in model.modules()])
    assert has_bn, "tent needs normalization for its optimization"
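

# %%
# A minimal usage sketch (not part of the original module): it wires the
# helpers above together in the order their docstrings describe. The toy
# network, optimizer choice, and learning rate below are illustrative
# assumptions, not values prescribed by this file.
if __name__ == "__main__":
    import torch.optim as optim

    base_model = torch.nn.Sequential(            # stand-in for a pretrained net
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.BatchNorm2d(16),
        torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(16, 10),
    )

    model = configure_model(base_model)          # train mode, BN-only grads
    check_model(model)                           # sanity-check the setup
    params = Tent.collect_params(model)          # BN scale/shift parameters
    optimizer = optim.Adam(params, lr=1e-3)      # illustrative hyperparameters
    tented_model = Tent(model, optimizer, steps=1, episodic=False)

    x = torch.randn(8, 3, 32, 32)                # stand-in test batch
    outputs = tented_model(x)                    # adapts on the batch, returns logits
    print(outputs.shape)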