-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreverseLayerTree.py
More file actions
27 lines (21 loc) · 858 Bytes
/
reverseLayerTree.py
File metadata and controls
27 lines (21 loc) · 858 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
def reverseLayerTree(psi_fin, shape_mid, wlist, wiso):
    """Contract a sequence of matrices into selected legs of a tensor.

    ``psi_fin`` is viewed as a tensor with leg dimensions ``shape_mid``.
    For each position ``p`` (in order), the matrix ``wiso[p]`` is applied to
    leg ``k = wlist[p]``: that leg, of current dimension ``shape[k]``, is
    replaced by a leg of dimension ``wiso[p].shape[0]``. All other legs keep
    their dimension and position.

    Args:
        psi_fin: array-like whose total size equals ``prod(shape_mid)``. Its
            own shape does not matter (e.g. a flat vector is fine) because it
            is reshaped before use.
        shape_mid: sequence of leg dimensions (list, tuple or 1-D integer
            array). It is NOT modified; the function works on a copy.
        wlist: leg indices to act on, processed in order. Negative indices
            count from the last leg.
        wiso: sequence of 2-D arrays (presumably isometries, given the name —
            not checked here). ``wiso[p].shape[1]`` must equal the current
            dimension of leg ``wlist[p]``. Only the first ``len(wlist)``
            entries are used.

    Returns:
        The transformed tensor, reshaped to the updated leg dimensions. If
        ``wlist`` is empty, ``psi_fin`` is returned unchanged.

    Raises:
        IndexError: if ``wiso`` has fewer entries than ``wlist`` or a leg
            index is out of range.
        ValueError: if sizes/dimensions are inconsistent (from numpy).
    """
    # Private copy: assigning the new leg dimension below must not leak into
    # the caller's object, and must work even when shape_mid is a tuple.
    shape = list(shape_mid)
    for p, k in enumerate(wlist):
        w = wiso[p]
        if k < 0:
            k += len(shape)
        d1 = int(np.prod(shape[:k]))  # combined dimension of legs before k
        d2 = shape[k]
        d3 = int(np.prod(shape[k + 1:]))  # combined dimension of legs after k

        # Bring leg k to the front and flatten the rest, so that applying w
        # to that leg is a single matrix product.
        psi_mat = (
            np.reshape(psi_fin, (d1, d2, d3))
            .transpose(1, 0, 2)
            .reshape(d2, d1 * d3)
        )
        psi_mat = w @ psi_mat

        chi0 = w.shape[0]
        shape[k] = chi0

        # Undo the permutation so the new leg sits back at position k.
        psi_fin = psi_mat.reshape(chi0, d1, d3).transpose(1, 0, 2).reshape(shape)
    return psi_fin