forked from Shinypuff/AdversarialAttacks_SMILES2024
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_classifier.py
93 lines (71 loc) · 2.62 KB
/
train_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import warnings
import hydra
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from src.data import MyDataset, load_data, transform_data
from src.training.train import Trainer
from src.utils import fix_seed, save_config, save_compiled_config
# Silence all warnings so third-party deprecation noise doesn't clutter training logs.
warnings.filterwarnings("ignore")
# Hydra config file name (without extension) looked up under config/my_configs.
CONFIG_NAME = "train_classifier_config"
@hydra.main(config_path="config/my_configs", config_name=CONFIG_NAME, version_base=None)
def main(cfg: DictConfig):
    """Train one or several classifier models driven by a hydra config.

    For each model id in ``[model_id_start, model_id_finish)`` the script
    seeds the RNGs, builds a Trainer (either via Optuna hyper-parameter
    optimization or directly from ``training_params``), trains it, and —
    unless this is a test run — saves the resulting model.

    Args:
        cfg: Hydra-composed configuration (see config/my_configs).
    """
    # Persist the config files unless this is a throwaway test run.
    if not cfg["test_run"]:
        save_config(cfg["save_path"], CONFIG_NAME, CONFIG_NAME)
        save_compiled_config(cfg, cfg["save_path"])

    # Optional augmentation pipeline: instantiate each transform listed in
    # the config, or None when no transforms are configured.
    augmentator = (
        [instantiate(trans) for trans in cfg["transform_data"]]
        if cfg["transform_data"]
        else None
    )

    # Load the dataset; skip datasets with more than two classes
    # (presumably the pipeline only supports binary classification — TODO confirm).
    X_train, y_train, X_test, y_test = load_data(cfg["dataset"]["name"])
    if len(set(y_test)) > 2:
        return None

    X_train, X_test, y_train, y_test = transform_data(
        X_train,
        X_test,
        y_train,
        y_test,
        slice_data=cfg["slice"],
    )

    train_loader = DataLoader(
        MyDataset(X_train, y_train, augmentator),
        batch_size=cfg["batch_size"],
        shuffle=True,
    )
    test_loader = DataLoader(
        MyDataset(X_test, y_test),
        batch_size=cfg["batch_size"],
        shuffle=False,
    )

    # cfg["cuda"] names the target device (e.g. "cuda:0"); fall back to CPU.
    device = torch.device(cfg["cuda"] if torch.cuda.is_available() else "cpu")

    # Train one model per id; the id doubles as the random seed so runs
    # are reproducible and distinct.
    for model_id in range(cfg["model_id_start"], cfg["model_id_finish"]):
        print("training model", model_id)  # fixed typo: was "trainig model"
        fix_seed(model_id)
        logger = SummaryWriter(cfg["save_path"] + "/tensorboard")
        const_params = {
            "logger": logger,
            "print_every": cfg["print_every"],
            "device": device,
            "seed": model_id,
            "train_self_supervised": cfg["train_self_supervised"],
        }
        if cfg["enable_optimization"]:
            # Let Optuna pick the remaining hyper-parameters.
            trainer = Trainer.initialize_with_optimization(
                train_loader, test_loader, cfg["optuna_optimizer"], const_params
            )
        else:
            trainer_params = dict(cfg["training_params"])
            trainer_params.update(const_params)
            trainer = Trainer.initialize_with_params(**trainer_params)
        # NOTE(review): train_model also runs after the optimization branch —
        # verify Trainer.initialize_with_optimization returns an untrained
        # (or resumable) trainer so this is not redundant training.
        trainer.train_model(train_loader, test_loader)
        logger.close()
        if not cfg["test_run"]:
            model_save_name = f'model_{model_id}_{cfg["dataset"]["name"]}'
            trainer.save_result(cfg["save_path"], model_save_name)


if __name__ == "__main__":
    main()