misc.py
import numpy as np
import torch


def cheap_stack(tensors, dim):
    # Stacking a single tensor reduces to an unsqueeze, which returns a view
    # rather than performing the copy that torch.stack would.
    if len(tensors) == 1:
        return tensors[0].unsqueeze(dim)
    else:
        return torch.stack(tensors, dim=dim)
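
# Example (an illustrative sketch): the single-tensor case and the general case
# produce the same shapes; only the copy behaviour differs.
#
#     >>> x = torch.randn(3)
#     >>> cheap_stack([x], dim=0).shape
#     torch.Size([1, 3])
#     >>> cheap_stack([x, x], dim=0).shape
#     torch.Size([2, 3])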


def tridiagonal_solve(b, A_upper, A_diagonal, A_lower):
    """Solves a tridiagonal system Ax = b.

    The arguments A_upper, A_diagonal, A_lower correspond to the three diagonals of A. Letting U = A_upper,
    D = A_diagonal and L = A_lower, and assuming for simplicity that there are no batch dimensions, then the matrix A
    is assumed to be of size (k, k), with entries:

    D[0] U[0]
    L[0] D[1] U[1]
         L[1] D[2] U[2]                     0
              L[2] D[3] U[3]
                   .    .    .
                        .    .    .
                             .    .    .
                        L[k - 3] D[k - 2] U[k - 2]
        0                        L[k - 2] D[k - 1]

    Arguments:
        b: A tensor of shape (..., k), where '...' is zero or more batch dimensions.
        A_upper: A tensor of shape (..., k - 1).
        A_diagonal: A tensor of shape (..., k).
        A_lower: A tensor of shape (..., k - 1).

    Returns:
        A tensor of shape (..., k), corresponding to the x solving Ax = b.

    Warning:
        This implementation isn't super fast. You probably want to cache the result, if possible.
    """

    # This implementation is very much written for clarity rather than speed.
    # (It is the Thomas algorithm: forward elimination followed by back substitution.)

    A_upper, _ = torch.broadcast_tensors(A_upper, b[..., :-1])
    A_lower, _ = torch.broadcast_tensors(A_lower, b[..., :-1])
    A_diagonal, b = torch.broadcast_tensors(A_diagonal, b)

    channels = b.size(-1)

    new_b = np.empty(channels, dtype=object)
    new_A_diagonal = np.empty(channels, dtype=object)
    outs = np.empty(channels, dtype=object)

    # Forward sweep: eliminate the lower diagonal, updating the main diagonal
    # and the right-hand side as we go.
    new_b[0] = b[..., 0]
    new_A_diagonal[0] = A_diagonal[..., 0]
    for i in range(1, channels):
        w = A_lower[..., i - 1] / new_A_diagonal[i - 1]
        new_A_diagonal[i] = A_diagonal[..., i] - w * A_upper[..., i - 1]
        new_b[i] = b[..., i] - w * new_b[i - 1]

    # Back substitution through the resulting upper-bidiagonal system.
    outs[channels - 1] = new_b[channels - 1] / new_A_diagonal[channels - 1]
    for i in range(channels - 2, -1, -1):
        outs[i] = (new_b[i] - A_upper[..., i] * outs[i + 1]) / new_A_diagonal[i]

    return torch.stack(outs.tolist(), dim=-1)
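

if __name__ == '__main__':
    # A minimal sanity check (illustrative sketch): build a small, diagonally
    # dominant tridiagonal system, solve it, and compare against a dense solve
    # assembled from the same three diagonals.
    k = 5
    A_upper = torch.randn(k - 1)
    A_lower = torch.randn(k - 1)
    A_diagonal = torch.randn(k) + 10.  # shift the diagonal so the system is well-conditioned
    b = torch.randn(k)

    x = tridiagonal_solve(b, A_upper, A_diagonal, A_lower)

    A = torch.diag(A_diagonal) + torch.diag(A_upper, 1) + torch.diag(A_lower, -1)
    x_dense = torch.linalg.solve(A, b)
    print(torch.allclose(x, x_dense, atol=1e-5))  # expected: True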