-
Notifications
You must be signed in to change notification settings - Fork 0
/
FK_2DSE.py
69 lines (61 loc) · 2.24 KB
/
FK_2DSE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch, os
from FK_DMFT import DMFT
import torch.multiprocessing as mp
from functools import partial
from tqdm import tqdm
from utils import myceil
import warnings
warnings.filterwarnings('ignore')
# Expose only GPU 0 to this process. CUDA_LAUNCH_BLOCKING=1 makes kernel
# launches synchronous, so CUDA errors are reported at the offending call
# (a debugging aid; it slows execution).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# parameters
L = 12  # linear size; per-sample arrays below use L ** 2 entries
data = f'FK_{L}_QPT'  # dataset name; files are read from datasets/{data}/{TYPE}
TYPE = 'train'  # dataset sub-directory (split)
processors = 0  # number of pool workers; 0 selects the serial loop in __main__
# Batch size for the serial loop. Defined unconditionally so the name always
# exists; in parallel mode computeSE is given its own per-worker chunk size.
bz = 100

# DMFT configs -- forwarded positionally, in this order, to DMFT(...) in
# __main__. Their physical meaning is defined by FK_DMFT.DMFT (not visible here).
count = 20
iota = 0.
momentum = 0.5
momDisor = 0.
maxEpoch = 2000
milestone = 30
f_filling = 0.5
d_filling = None
tol_sc = 1e-6
tol_bi = 1e-7
gap = 1.
double = True
device = torch.device("cuda")
def computeSE(i, bz, scf, path):
    """Run ``scf`` on the ``i``-th batch of the module-level ``H0``/``target`` data.

    The batch covers rows ``[i * bz, (i + 1) * bz)`` (the last batch may be
    shorter). The result is moved to the CPU, saved to ``{path}/SE_{i}.pt``
    and returned.

    ``H0`` and ``target`` are read as module globals; they are assigned in the
    ``__main__`` block, and the pool there uses the ``'fork'`` start method,
    so worker processes inherit them.

    Args:
        i: Batch index.
        bz: Number of rows per batch.
        scf: Callable invoked as ``scf(target[rows, 1], H0[rows], target[rows, 0])``.
        path: Directory in which ``SE_{i}.pt`` is written.

    Returns:
        The CPU tensor produced by ``scf`` for this batch
        (shape noted upstream as (bz, scf.count, L ** 2)).
    """
    rows = slice(i * bz, (i + 1) * bz)
    # Move to the CPU, matching the serial branch in __main__ (which also
    # calls .cpu()), so per-batch files and returned results are CPU tensors.
    SE = scf(target[rows, 1], H0[rows], target[rows, 0]).cpu()
    torch.save(SE, f'{path}/SE_{i}.pt')
    return SE  # (bz, scf.count, L ** 2)
if __name__ == "__main__":
    if processors > 0:
        # Limit each process's BLAS/OpenMP backends to one thread, presumably
        # to avoid oversubscribing the CPU when `processors` workers run at once.
        os.environ['OMP_NUM_THREADS'] = '1'
        os.environ['OPENBLAS_NUM_THREADS'] = '1'
        os.environ['MKL_NUM_THREADS'] = '1'
        os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
        os.environ['NUMEXPR_NUM_THREADS'] = '1'
        torch.set_num_threads(1)
    path = f'datasets/{data}/{TYPE}'
    # H0 and target are deliberately module globals: computeSE reads them
    # directly, and forked pool workers inherit them.
    H0 = torch.load(f'{path}/dataset.pt')  # (amount, scf.count, L ** 2, L ** 2)
    target = torch.load(f'{path}/labels.pt')  # (amount, 3)
    scf = DMFT(count, iota, momentum, momDisor, maxEpoch, milestone,
               f_filling, d_filling, tol_sc, tol_bi, gap, device, double)
    if processors > 0:
        # 'fork' lets workers inherit the globals H0/target read by computeSE.
        # NOTE(review): PyTorch cannot use CUDA in a forked child once CUDA is
        # initialised in the parent -- confirm this path runs with a CPU device.
        mp.set_start_method('fork', force=True)
        worker = partial(computeSE, bz=myceil(len(H0) / processors), scf=scf, path=path)
        # The context manager guarantees the pool is torn down even if a
        # worker raises; close/join first so a successful run shuts down gracefully.
        with mp.Pool(processes=processors) as pool:
            SE = pool.map(worker, range(processors))
            pool.close()
            pool.join()
        torch.save(torch.cat(SE, dim=0), f'{path}/SE.pt')  # (amount, scf.count, L ** 2)
    else:
        # Gather batches and concatenate once, instead of re-concatenating the
        # growing result every iteration (which copies it again each time).
        # The empty seed is the same initial tensor as before, so an empty
        # dataset still yields a (0, scf.count, L ** 2) tensor.
        chunks = [torch.zeros((0, scf.count, L ** 2), dtype=scf.iomega0.dtype)]
        for i in tqdm(range(myceil(len(H0) / bz))):
            rows = slice(i * bz, (i + 1) * bz)
            chunks.append(scf(target[rows, 1], H0[rows], target[rows, 0]).cpu())
        SE = torch.cat(chunks, dim=0)
        torch.save(SE, f'{path}/SE.pt')  # (amount, scf.count, L ** 2)