element_expert.py
# element_expert.py: trains a bagging ensemble of regressors on elemental features.
import numpy as np
import torch
from sklearn.metrics import r2_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, SubsetRandomSampler
from utils import *  # expected to provide Net_1, device, train_dataset, batch_size, criterion

n_estimators = 10  # number of bagged estimators in the ensemble
max_samples = 0.8  # fraction of the training set drawn (with replacement) per estimator
n_epochs = 1000
load = False       # set to True to resume from saved checkpoints

models = []
optimizers = []
lr_schedulers = []
for i in range(n_estimators):
    model_realnum = Net_1().to(device)
    if load:
        model_realnum = torch.load(f'model/element_expert_{i}.pth').to(device)
        print('model loaded')
    models.append(model_realnum)
    optimizer = torch.optim.Adam(model_realnum.parameters(), lr=0.001)
    optimizers.append(optimizer)
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.005, patience=500, verbose=True)
    lr_schedulers.append(lr_scheduler)

for epoch in range(n_epochs):
    for i in range(n_estimators):
        # Bootstrap sample: draw max_samples of the training set with replacement,
        # so each estimator sees a different subset (bagging).
        indices = np.random.choice(len(train_dataset), int(len(train_dataset) * max_samples), replace=True)
        sampler = SubsetRandomSampler(indices)
        subset_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
        model = models[i]
        optimizer = optimizers[i]
        lr_scheduler = lr_schedulers[i]
        all_loss = 0.0
        preds = []
        cs = []
        for element, element_real, physic, c in subset_dataloader:
            model.train()
            optimizer.zero_grad()
            element_real = element_real.to(device)
            physic = physic.to(device)  # moved to device but unused by this expert
            c = c.to(device).type(torch.float32).unsqueeze(1)
            output = model(element_real)
            loss = criterion(output, c)
            loss.backward()
            optimizer.step()
            all_loss += loss.item()
            preds += output.detach().cpu().numpy().tolist()
            cs += c.detach().cpu().numpy().tolist()
        epoch_loss = all_loss / len(subset_dataloader)
        lr_scheduler.step(epoch_loss)  # ReduceLROnPlateau must be stepped with the monitored metric
        train_r2 = r2_score(cs, preds)  # sklearn's signature is r2_score(y_true, y_pred)
        torch.save(model, f'model/element_expert_{i}.pth')
        print('estimator', i, 'Epoch: {}, Loss: {:.6f}, R2 Score: {:.6f}'.format(epoch + 1, epoch_loss, train_r2))
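
A minimal sketch of how the saved ensemble might be used at inference time. This is a hypothetical addition, not part of the original file: it assumes device comes from utils as above, and a test_loader (assumed name, not defined in this repo) that yields batches in the same (element, element_real, physic, c) format. The bagged prediction is the mean of the individual estimators' outputs.

# inference_sketch.py: hypothetical usage example, not part of the original repo.
# Assumes the checkpoints written by element_expert.py, plus device from utils
# and a test_loader (assumed name) yielding (element, element_real, physic, c).
import torch
from utils import *

n_estimators = 10
models = [torch.load(f'model/element_expert_{i}.pth').to(device).eval()
          for i in range(n_estimators)]

preds, targets = [], []
with torch.no_grad():
    for element, element_real, physic, c in test_loader:
        element_real = element_real.to(device)
        # Bagging prediction: average the outputs of all estimators.
        output = torch.stack([m(element_real) for m in models]).mean(dim=0)
        preds += output.cpu().numpy().tolist()
        targets += c.float().unsqueeze(1).numpy().tolist()

Averaging the bootstrap-trained estimators is what gives the bagging ensemble its variance reduction relative to a single Net_1.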