-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_histogram.py
186 lines (141 loc) · 8.71 KB
/
train_histogram.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import torchnet as tnt
import json, os
import random
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error, r2_score
import torch.optim as optim
from data_preparation.dataset_histogram import YDataset_Hist
import torch
import pickle as pkl
from data_preparation.utils import evalMetrics
from data_preparation.utils_deeplearning import *
from torch.nn import MSELoss
from models.models_1d import LSTM
from models.models_1d import Hist1D
from models.cnn_3d_you import Hist_3D_You
from torch.utils.tensorboard import SummaryWriter
import hydra
from omegaconf import DictConfig, OmegaConf
@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Train a histogram-based yield model and evaluate it on two test splits.

    Workflow:
      1. Create a seed-suffixed results directory and persist the resolved config.
      2. Seed numpy / random / torch for reproducibility.
      3. Build train / validation / two test datasets and their loaders.
      4. Train with early stopping on validation RMSE, checkpointing the best model.
      5. Reload the best checkpoint and report/save metrics on both test splits.

    Args:
        cfg: Hydra config; reads ``cfg.training``, ``cfg.dataset`` and
            ``cfg.model_config`` sub-trees (paths, seed, model hyperparameters).
    """
    # create results dir (suffixed with the seed so reruns don't collide) and save config
    cfg.model_config.res_dir = cfg.model_config.res_dir + '_{}'.format(cfg.training.seed)
    os.makedirs(cfg.model_config.res_dir, exist_ok=True)
    OmegaConf.save(cfg, os.path.join(cfg.model_config.res_dir, 'config.yaml'))

    # set seeds for reproducibility
    np.random.seed(cfg.training.seed)
    random.seed(cfg.training.seed)
    torch.manual_seed(cfg.training.seed)
    device = torch.device(cfg.training.device)

    def _make_dataset(lookup, mode):
        # Single place for the four near-identical YDataset_Hist constructions.
        return YDataset_Hist(cfg.model_config.npy_path, cfg.dataset.label_path,
                             norm_path=cfg.dataset.norm_path, lookup=lookup,
                             mode=mode, seed=cfg.training.seed,
                             start_doy_idx=cfg.dataset.start_doy_idx,
                             end_doy_idx=cfg.dataset.end_doy_idx,
                             feature_idx=cfg.dataset.feature_idx)

    def _make_loader(dataset, shuffle):
        # Single place for the four near-identical DataLoader constructions.
        return torch.utils.data.DataLoader(dataset,
                                           num_workers=cfg.training.num_workers,
                                           batch_size=cfg.model_config.batch_size,
                                           shuffle=shuffle)

    train_dataset = _make_dataset(cfg.dataset.train_years, 'train')
    val_dataset = _make_dataset(cfg.dataset.train_years, 'validation')
    test_dataset_d = _make_dataset(cfg.dataset.test_years_d, None)
    test_dataset_nd = _make_dataset(cfg.dataset.test_years_nd, None)

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=True)
    test_loader_d = _make_loader(test_dataset_d, shuffle=False)
    test_loader_nd = _make_loader(test_dataset_nd, shuffle=False)
    print('Train {}, Val {}, Test {}, Test {}'.format(len(train_loader), len(val_loader),
                                                      len(test_loader_d), len(test_loader_nd)))

    # Initialize the model based on the configured name
    if cfg.model_config.name == "histo_1d":
        model = Hist1D(input_dim=cfg.dataset.input_dim, seq_length=cfg.dataset.seq_length,
                       kernel_size=cfg.model_config.kernel_size, hidden_dims=cfg.model_config.hidden_dim,
                       num_layers=cfg.model_config.num_layers,
                       dropout=cfg.model_config.dropout)
    elif cfg.model_config.name == "histo_3d":
        model = Hist_3D_You(input_dim=cfg.dataset.input_dim, dropout=cfg.model_config.dropout,
                            seq_length=cfg.dataset.seq_length, dense_features=None)
    else:
        raise ValueError(f"Unknown model name: {cfg.model_config.name}")
    # use the torch.device built above (was cfg.training.device string) for consistency
    model = model.to(device)

    # Create optimizer from Hydra config (e.g. 'adam' -> optim.Adam)
    optimizer_class = getattr(optim, cfg.training.optimizer.capitalize())
    optimizer = optimizer_class(model.parameters(), lr=cfg.model_config.lr)
    criterion = MSELoss()

    # Initialize TensorBoard SummaryWriters.
    # NOTE(review): both writers point at the same log_dir, so two event files are
    # produced there; tags keep the scalars distinct. Kept as-is to preserve the
    # existing TensorBoard layout.
    writer_train = SummaryWriter(log_dir=cfg.model_config.res_dir)
    writer_val = SummaryWriter(log_dir=cfg.model_config.res_dir)

    # holder for logging training performance
    trainlog = {}
    best_RMSE = np.inf
    epochs_no_improve = 0
    for epoch in range(1, cfg.model_config.epochs + 1):
        print('EPOCH {}/{}'.format(epoch, cfg.model_config.epochs))
        model.train()
        train_metrics = train_epoch(model, optimizer, criterion, train_loader,
                                    device=device, display_step=cfg.training.display_step)
        # Log training metrics to TensorBoard
        writer_train.add_scalar('Loss/train', train_metrics['train_loss'], epoch)
        writer_train.add_scalar('R2/train', train_metrics['train_R2'], epoch)

        model.eval()
        val_metrics = evaluation(model, criterion, val_loader, device=device, mode='val')
        print('Loss {:.4f}, RMSE {:.4f}, R2 {:.4f}'.format(val_metrics['val_loss'],
                                                           val_metrics['val_rmse'],
                                                           val_metrics['val_R2']))
        # Log validation metrics to TensorBoard
        writer_val.add_scalar('Loss/val', val_metrics['val_loss'], epoch)
        writer_val.add_scalar('R2/val', val_metrics['val_R2'], epoch)
        trainlog[epoch] = {**train_metrics, **val_metrics}
        checkpoint(trainlog, cfg.model_config.res_dir)

        # Early stopping on validation RMSE
        if val_metrics['val_rmse'] < best_RMSE:
            best_epoch = epoch
            best_RMSE = val_metrics['val_rmse']
            epochs_no_improve = 0  # reset the counter on improvement
            torch.save({'best epoch': best_epoch, 'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       os.path.join(cfg.model_config.res_dir, 'model.pth.tar'))
        else:
            epochs_no_improve += 1  # no improvement this epoch
            if epochs_no_improve >= cfg.model_config.patience:
                # BUGFIX: epoch is already 1-based, printing epoch+1 overstated it by one
                print(f'Early stopping at epoch {epoch}')
                break

    # load best model; map_location so a GPU-saved checkpoint also loads on CPU
    best_ckpt = torch.load(os.path.join(cfg.model_config.res_dir, 'model.pth.tar'),
                           map_location=device)
    model.load_state_dict(best_ckpt['state_dict'])

    def _evaluate_test(loader, dataset, test_year):
        # Evaluate on one test split, print metrics and persist predictions.
        model.eval()
        test_metrics, y_true, y_pred = evaluation(model, criterion, loader,
                                                  device=device, mode='test')
        print('========== Test Metrics ===========')
        print('Loss {:.4f}, RMSE {:.4f}, R2 {:.4f}'.format(test_metrics['test_loss'],
                                                           test_metrics['test_rmse'],
                                                           test_metrics['test_R2']))
        save_results(test_metrics, cfg.model_config.res_dir, y_true, y_pred,
                     dataset.geoid, test_year)

    _evaluate_test(test_loader_d, test_dataset_d, cfg.dataset.test_years_d[0])
    _evaluate_test(test_loader_nd, test_dataset_nd, cfg.dataset.test_years_nd[0])

    # close the TensorBoard writers
    writer_train.close()
    writer_val.close()
# Script entry point: Hydra parses CLI overrides and injects the composed config.
if __name__ == "__main__":
    main()