-
Notifications
You must be signed in to change notification settings - Fork 0
/
linalg.py
368 lines (322 loc) · 11.6 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
from .models.model_base import ClassificationModelBase
from .models.utils import get_current_gradients
from .models.utils import get_trainable_parameters, \
get_trainable_parameters_in_layer, \
set_trainable_parameters, \
set_trainable_parameters_in_layer, \
flip_parameters_to_tensors
from typing import Union, Tuple
from torch.utils.data import Dataset, DataLoader
from hessian_eigenthings import compute_hessian_eigenthings
from copy import deepcopy
import numpy as np
import torch
import math
import torch.nn.functional as F
def cos_vec_vec(
        vec_1: torch.Tensor,
        vec_2: torch.Tensor) -> float:
    """ Computes cosine similarity between vec_1 and vec_2.
    - vec_1, vec_2: tensors with the same number of elements;
      they are flattened before the dot product is taken.
    Returns:
    - the cosine similarity as a Python float, or float('inf')
      when either vector has zero norm (the similarity is then
      undefined; 'inf' is kept as the sentinel callers already get).
    """
    # detach() lets tensors that require grad be used; the result is a
    # plain float, so no graph is needed.
    flat_1 = vec_1.detach().cpu().reshape(-1)
    flat_2 = vec_2.detach().cpu().reshape(-1)

    # torch.dot needs matching dtypes; integer inputs are promoted to the
    # default float dtype so that norm() is defined.
    dtype = torch.promote_types(flat_1.dtype, flat_2.dtype)
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.get_default_dtype()
    flat_1 = flat_1.to(dtype)
    flat_2 = flat_2.to(dtype)

    norm_1 = flat_1.norm()
    norm_2 = flat_2.norm()
    if norm_1 > 0 and norm_2 > 0:
        return (torch.dot(flat_1, flat_2) / (norm_1 * norm_2)).item()
    return float('inf')
def get_random_ortho_matrix(
        D: int,
        d: int,
        device: torch.device) -> torch.Tensor:
    """ Computes a random (D, d) matrix with sparse unit-norm columns.
    Each column has entries in {-c, 0, +c}: every coordinate is kept
    with probability 1/sqrt(D), gets a random sign, and the column is
    scaled to unit length. Columns are drawn independently and no
    explicit orthogonalisation step is performed.
    Adapted from https://github.com/jeffiar/cs229-final-project
    - D: ambient dimension (number of rows)
    - d: number of columns
    - device: device on which the returned matrix is allocated
    """
    M = torch.zeros(D, d, device=device)
    if d == 0:
        return M
    prob = 1 / math.sqrt(D)
    for i in range(d):
        # For small D the random support can come out empty; normalising
        # an all-zero column would give 0/0 = NaN, so redraw until at
        # least one coordinate is selected.
        support = torch.rand(D) < prob
        while not support.any():
            support = torch.rand(D) < prob
        col = torch.zeros(D)
        col[support] = 1
        col[torch.rand(D) < 0.5] *= -1
        col /= col.norm()
        M[:, i] = col
    return M
def sparse_vector(
        D: int,
        n: int) -> torch.Tensor:
    """ Builds a float32 vector of length D with n non-zero entries.
    The n positions are sampled without replacement and every non-zero
    entry equals +sqrt(n) or -sqrt(n) with a random sign.
    Adapted from https://github.com/jeffiar/cs229-final-project
    NOTE(review): with magnitude sqrt(n) the vector has Euclidean norm n;
    confirm whether 1/sqrt(n) (unit norm) was intended.
    """
    # Draw order (positions first, then signs) matters for reproducibility
    # under a fixed numpy seed.
    positions = np.random.choice(range(D), size=n, replace=False)
    polarity = np.random.choice([-1., 1.], size=n)
    magnitude = math.sqrt(n)
    values = torch.from_numpy(polarity * magnitude * np.ones(n)).float()

    out = torch.zeros(D).float()
    out[positions] = values
    return out
def goldilocks(
        model: ClassificationModelBase,
        dataset: Dataset,
        dim: int,
        device: torch.device,
        layer: Union[int, None] = None) -> Tuple[float, float]:
    """ Summarises the Hessian spectrum in a random dim-dimensional subspace.
    A random (P, dim) matrix R is drawn with get_random_ortho_matrix, the
    eigenvalues of the Hessian restricted to the subspace spanned by R
    are computed with eigenvvs, and two statistics are returned.
    - dim: dimension of the random subspace
    - device: device on which R is created
    - layer: restrict to the parameters of the given layer (if any)
    Returns:
    - frac_pos: fraction of eigenvalues that are strictly positive
    - trace_norm: sum of eigenvalues divided by their 2-norm
      (float('inf') when the norm is zero)
    """
    if layer is None:
        w = get_trainable_parameters(model).to(device)
    else:
        w = get_trainable_parameters_in_layer(model, layer).to(device)
    # numel() rather than len(): the count must not depend on how the
    # parameter tensor happens to be shaped.
    num_params = w.numel()

    # The previous version also built `dim` sparse vectors of length P here
    # and never used them; that dead (dim, P) allocation has been removed.
    R = get_random_ortho_matrix(num_params, dim, device).to(model.dtype)
    L, _ = eigenvvs(
        model=model,
        dataset=dataset,
        top_k=-1,
        is_subspace=True,
        R=R,
        layer=layer)
    L = L.to(model.dtype)

    frac_pos = (L > 0).sum().item() / L.numel()
    trace = torch.sum(L).item()
    norm = torch.linalg.norm(L).item()
    trace_norm = float('inf') if norm == 0 else trace / norm
    return frac_pos, trace_norm
def eigenvvs(
        model: ClassificationModelBase,
        dataset: Dataset,
        top_k: int,
        is_subspace: bool = False,
        R: Union[torch.Tensor, None] = None,
        layer: Union[int, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Computes top_k eigenvalues and eigenvectors of the model Hessian
    on dataset (top_k <= 0, e.g. -1, means all).
    - is_subspace: use a subspace specified by R
    - R (P, d): column vectors spanning the subspace
    - layer: compute the Hessian for a specified tensor layer.
    Returns:
    - L_k: eigenvalues (real parts) sorted in descending order
    - V_k: matching eigenvectors (real parts), one per row
    """
    H = hessian(model, dataset, is_subspace, R, layer)
    H = H.squeeze()
    if H.dim() == 0:
        # A one-dimensional problem squeezes down to a scalar; eig needs
        # a square matrix.
        H = H.reshape(1, 1)

    L, V = torch.linalg.eig(H)
    # eig returns eigenvectors as columns; transpose so row i pairs with L[i].
    V = torch.real(V.T)
    L = torch.real(L)

    # Stable descending sort keeps tied eigenvalues in their original
    # order, as the previous stable Python sort did.
    _, order = torch.sort(L, descending=True, stable=True)
    L_sorted = L[order]
    V_sorted = V[order]
    if top_k > 0:
        return L_sorted[:top_k], V_sorted[:top_k]
    return L_sorted, V_sorted
def hessian(
        model: ClassificationModelBase,
        dataset: Dataset,
        is_subspace: bool = False,
        R: Union[torch.Tensor, None] = None,
        layer: Union[int, None] = None) -> torch.Tensor:
    """ Computes the full Hessian of the cross-entropy loss of the model
    on the dataset, via torch.autograd.functional.hessian.
    - is_subspace: use a subspace specified by R
    - R (P, d): column vectors spanning the subspace; must be given
      when is_subspace is True (R.shape is read in that branch)
    - layer: compute the Hessian for a specified tensor layer.
    Returns:
    - H: the Hessian with singleton dimensions squeezed out.
    Side effects visible here: model.zero_grad() and model.eval() are
    called on the passed model and the mode is not restored.
    NOTE(review): the original warning said this function removes all
    parameters from the model through flip_parameters_to_tensors and
    can therefore be called only once. The code below applies
    flip_parameters_to_tensors to a deepcopy of the model, so the
    caller's model is presumably left trainable - confirm and drop
    this note or restore the warning accordingly.
    """
    model.zero_grad()
    model.eval()
    if not is_subspace:
        # Full-space case: differentiate with respect to the parameters
        # themselves, expressed as W = I @ w + 0.
        if layer is None:
            w = get_trainable_parameters(model)
        else:
            w = get_trainable_parameters_in_layer(model, layer)
        w.requires_grad_(True)
        # Dense identity: memory grows quadratically with w.numel().
        R = torch.eye(w.numel()).to(model.device)
        d = torch.zeros_like(w).to(model.device)
        # assumes w is a 2-D column (P, 1) so torch.mm(R, w) below is
        # valid - TODO confirm against get_trainable_parameters
    else:
        # Subspace case: the current parameters are the offset d and the
        # Hessian is taken with respect to subspace coordinates w,
        # evaluated at w = 0.
        if layer is None:
            d = get_trainable_parameters(model)
        else:
            d = get_trainable_parameters_in_layer(model, layer)
        w = torch.zeros(R.shape[1]).requires_grad_(True)
        w = w.reshape(-1,1).to(model.dtype).to(model.device)
    # Work on a copy so that parameter replacement does not touch `model`.
    model_copy = deepcopy(model)
    flip_parameters_to_tensors(model_copy)
    dataloader = DataLoader(dataset, batch_size=128)
    def func(w):
        """ Loss over the whole dataset as a function of w. """
        # Map coordinates to parameter space: W = R @ w + d.
        W = torch.mm(R, w)
        W = W + d.reshape(W.shape)
        if layer is None:
            set_trainable_parameters(model_copy, W)
        else:
            set_trainable_parameters_in_layer(model_copy, W, layer)
        y_preds = []
        y_trues = []
        model_copy.eval()
        # Outputs of every batch are kept so that a single loss is
        # computed over the whole dataset.
        for X,y in dataloader:
            X = X.to(model_copy.device)
            y = y.to(model_copy.device)
            y_preds.append(model_copy(X))
            y_trues.append(y)
        y_preds = torch.vstack(y_preds)
        y_trues = torch.cat(y_trues)
        return F.cross_entropy(y_preds, y_trues)
    H = torch.autograd.functional.hessian(func, w).squeeze()
    del model_copy
    return H
def eigenthings(
        model: ClassificationModelBase,
        loss: callable,
        dataset: Dataset,
        num_things: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Computes num_things Hessian eigenvalues and eigenvectors by
    delegating to compute_hessian_eigenthings.
    - loss: loss callable forwarded to compute_hessian_eigenthings
    - dataset: wrapped in a DataLoader with batch size 128
    - num_things: number of eigenpairs requested
    Returns:
    - the (eigenvalues, eigenvectors) pair produced by the library.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=128)
    # GPU mode is requested exactly when the model lives on a CUDA device.
    on_cuda = model.device.type == 'cuda'
    eigenvalues, eigenvectors = compute_hessian_eigenthings(
        model,
        loader,
        loss,
        num_things,
        use_gpu=on_cuda)
    return eigenvalues, eigenvectors
def _get_G_term(
J: torch.Tensor,
p: torch.Tensor) -> torch.Tensor:
""" Computes the G-term from the Gauss-Newton decomposition.
- J (K, d): Jacobian
- p (K): softmax output
"""
num_p = J.shape[1]
A = p.reshape(1,-1) @ J
Jp = J*p.repeat(num_p, 1).T
G = (J.T @ Jp) - (A.T @ A)
return G
def get_G_term(
        J: torch.Tensor,
        p: torch.Tensor) -> torch.Tensor:
    """ Computes the G-term from the Gauss-Newton decomposition,
    averaged over samples:
    G = mean_s [ J_s^T diag(p_s) J_s - (p_s J_s)^T (p_s J_s) ].
    - J (S, K, d): Jacobian
    - p (S, K): softmax output
    Returns:
    - G (d, d)
    Raises:
    - ValueError: if there are no samples (S == 0).
    """
    S, _, d = J.shape
    if S == 0:
        raise ValueError("get_G_term requires at least one sample")

    # Batched computation: avoids stacking S dense (d, d) matrices.
    weighted_J = J * p.unsqueeze(-1)  # (S, K, d): rows scaled by p[s, k]
    # Sum over s of J_s^T diag(p_s) J_s as one (d, S*K) @ (S*K, d) product.
    second_moment = J.reshape(-1, d).T @ weighted_J.reshape(-1, d)
    # Row s holds p_s @ J_s; the product sums the per-sample outer products.
    means = weighted_J.sum(dim=1)  # (S, d)
    outer = means.T @ means
    return (second_moment - outer) / S
def get_Jacobian(
        model: ClassificationModelBase,
        dataset: Dataset,
        K: int,
        R: Union[torch.Tensor, None] = None) -> torch.Tensor:
    """ Computes the Jacobian of the model logits on dataset, one
    backward pass per (sample, class) pair.
    - K: number of classes;
    - R (P, d): subspace (if any); gradients are projected as g @ R.
    Returns:
    - J (S, K, d): Jacobian with respect to samples in dataset
      (d is the parameter count when R is None).
    """
    params = get_trainable_parameters(model)
    width = params.numel() if R is None else R.shape[1]
    J = torch.zeros(len(dataset), K, width).to(model.dtype)

    for s, (X, _) in enumerate(dataset):
        # A batch dimension is added unless the leading dimension is
        # already 1. NOTE(review): this relies on X.shape[0] to tell
        # batched from unbatched samples - confirm for 1-channel inputs.
        inputs = X.unsqueeze(0) if X.shape[0] > 1 else X
        for k in range(K):
            model.zero_grad()
            logit = model(inputs).squeeze()[k]
            logit.backward()
            grad = get_current_gradients(model).detach().squeeze()
            if R is not None:
                grad = grad.reshape(1, -1) @ R
            J[s, k] = grad.reshape(-1)

    # Leave no stale gradients behind on the model.
    model.zero_grad()
    return J
def Gamma(
        P: torch.Tensor) -> float:
    """ Computes Gamma(P) for a matrix of softmax outputs:
    trace(M) / ||M||_F with M = mean_s [ diag(p_s) - p_s p_s^T ].
    - P (S, K): matrix of softmax outputs.
    Returns:
    - Gamma(P) as a Python float; 0.0 when the trace or the
      Frobenius norm of M is zero.
    Raises:
    - ValueError: if P has no rows.
    """
    S = P.shape[0]
    if S == 0:
        raise ValueError("Gamma requires at least one softmax output")
    probs = P.reshape(S, -1)

    # Vectorised mean over samples: mean_s diag(p_s) = diag(mean_s p_s)
    # and mean_s p_s p_s^T = P^T P / S.
    M = torch.diag(probs.mean(dim=0)) - (probs.T @ probs) / S
    trace = torch.trace(M)
    norm = torch.linalg.norm(M, 'fro')
    if trace == 0 or norm == 0:
        return 0.0
    # Always a plain float, matching the annotation and the zero branch.
    return (trace / norm).item()
def EG_curvature(
        var_E: float,
        var_C: float,
        d: int,
        P: torch.Tensor) -> float:
    """ Computes positive curvature of the expected G_term per Eq. 9.
    - var_E: estimated variance of logit gradients
    - var_C: estimated variance of logit gradient means
    - d: dimension of the subspace
    - P (S, K): matrix of softmax outputs.
    Returns:
    - sqrt(d) * (var_E + var_C)
      / sqrt(var_E^2 + 2 var_E var_C + d var_C^2 / gamma^2)
      as a Python float, with gamma = Gamma(P); 0.0 when
      0 <= gamma < 1 (which also covers gamma == 0 and so guards
      the division by gamma^2).
    """
    # Gamma may hand back a tensor or a Python number; normalise it once
    # so the arithmetic below yields one result type.
    gamma = float(Gamma(P))
    if 0 <= gamma < 1:
        return 0.0
    numer = np.sqrt(d) * (var_E + var_C)
    denom = np.sqrt(
        var_E ** 2 + 2 * var_E * var_C + d * (var_C ** 2) / (gamma ** 2))
    return float(numer / denom)
def hessian_vector_product(
model: ClassificationModelBase,
loss: torch.Tensor,
vec: torch.Tensor) -> torch.Tensor:
""" Computes hessian-vector product H @ vec wrt to
model parameters given a scalar loss tensor.
"""
grad = torch.autograd.grad(
loss,
model.parameters(),
create_graph=True,
allow_unused=True)
grad = torch.cat([g.reshape(-1) for g in grad if g is not None])
g = grad @ vec
hvp = torch.autograd.grad(g, model.parameters(), allow_unused=True)
hvp = torch.cat([g.reshape(-1) for g in hvp if g is not None])
return hvp.detach()
def hutch_tr_H(
        model: ClassificationModelBase,
        train_X: torch.Tensor,
        train_y: torch.LongTensor,
        maxiter: int = 100) -> float:
    """ Estimates the trace of the model Hessian of the cross-entropy
    loss on (train_X, train_y) with Hutchinson's stochastic method:
    the average of v^T H v over maxiter random +/-1 probe vectors v.
    """
    params = get_trainable_parameters(model)
    estimates = []
    for _ in range(maxiter):
        model.zero_grad()
        # Probe with entries in {-1, +1}: draw {0, 1}, then map 0 -> -1.
        probe = torch.randint_like(params, high=2).squeeze()
        probe.masked_fill_(probe == 0, -1)
        # The loss is rebuilt each round; the graph from the previous
        # Hessian-vector product is not reused.
        batch_loss = F.cross_entropy(model(train_X), train_y)
        Hv = hessian_vector_product(model, batch_loss, probe)
        estimates.append((probe @ Hv).item())
    return np.mean(estimates).item()
def hutch_fr_H(
        model: ClassificationModelBase,
        train_X: torch.Tensor,
        train_y: torch.LongTensor,
        maxiter: int = 100) -> float:
    """ Computes Frobenius norm of the model Hessian computed
    on data (train_X, train_y) using Hutchinson's stochastic
    approximation method using maxiter iterations.
    For random +/-1 probes v, ||H v||^2 averages to ||H||_F^2, so
    the square root of the sample mean is returned.
    """
    w = get_trainable_parameters(model)
    frobs = []
    for _ in range(maxiter):
        model.zero_grad()
        # Probe with entries in {-1, +1}: draw {0, 1}, then map 0 -> -1.
        vec = torch.randint_like(w, high=2).squeeze()
        vec[vec == 0] = -1
        loss = F.cross_entropy(model(train_X), train_y)
        hvp = hessian_vector_product(model, loss, vec).squeeze()
        # Store plain floats (as hutch_tr_H does): numpy cannot average a
        # list of CUDA tensors.
        frobs.append((hvp @ hvp).item())
    return np.sqrt(np.mean(frobs)).item()