-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment.py
138 lines (106 loc) · 6.33 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from base.experiment import GenericExperiment
from base.utils import load_pickle
from base.loss_function import compute_AU_F1, compute_AU_loss_BCE
from trainer import Trainer
from dataset import DataArranger, Dataset
from base.checkpointer import Checkpointer
from models.model import LFAN, CAN
from base.parameter_control import ResnetParamControl
import os
class Experiment(GenericExperiment):
    """Experiment driver that builds LFAN / CAN models and runs them fold by fold.

    Extends ``GenericExperiment`` with dataset arrangement, model construction
    and a per-fold ``run`` loop. As currently wired, ``run`` does not train:
    it builds (or restores from checkpoint) a Trainer and only calls
    ``Trainer.test`` in predict-only mode on the ``'extra'`` partition.
    """

    # Checkpoint directory used by ``run`` when ``args.predict_save_path`` is
    # absent or empty. NOTE(review): this is a machine-specific path kept for
    # backward compatibility; pass ``predict_save_path`` to point elsewhere.
    DEFAULT_PREDICT_SAVE_PATH = (
        '/home/data/zhangzr22/abaw/ABAW6/save/'
        'ABAW6_CAN_zzr_319_one_big_fold4_fold1_40.0001_window200_hop200_seed0.2_factor3407'
    )

    def __init__(self, args):
        """Store the command-line arguments this experiment reads directly.

        Args:
            args: parsed argument namespace; also forwarded to
                ``GenericExperiment.__init__``.
        """
        super().__init__(args)
        self.args = args
        # Options consumed by ResnetParamControl in ``run``.
        self.release_count = args.release_count
        self.gradual_release = args.gradual_release
        self.milestone = args.milestone
        self.backbone_mode = "ir"
        # Trainer / early-stopping options.
        self.min_num_epochs = args.min_num_epochs
        self.num_epochs = args.num_epochs
        self.early_stopping = args.early_stopping
        self.load_best_at_each_epoch = args.load_best_at_each_epoch
        # LFAN model hyper-parameters (see ``init_model``).
        self.num_heads = args.num_heads
        self.modal_dim = args.modal_dim
        self.tcn_kernel_size = args.tcn_kernel_size

    def prepare(self):
        """Load config, dataset metadata and normalisation statistics.

        Should be called before ``run``; it sets ``config``,
        ``feature_dimension``, ``multiplier``, ``time_delay``,
        ``continuous_label_dim``, ``dataset_info``, ``data_arranger`` and
        ``mean_std_dict``, which ``run`` and ``init_model`` rely on.
        """
        self.config = self.get_config()
        self.feature_dimension = self.get_feature_dimension(self.config)
        self.multiplier = self.get_multiplier(self.config)
        self.time_delay = self.get_time_delay(self.config)
        self.get_modality()
        # The label/output dimension is hard-coded. The original author noted
        # that the label dimension "has a problem".
        # get_selected_continuous_label_dim() is not used because it only
        # supports "arousal"/"valence". NOTE(review): 12 presumably equals the
        # number of predicted action units; confirm against the labels.
        self.continuous_label_dim = 12
        self.dataset_info = load_pickle(os.path.join(self.dataset_path, "dataset_info.pkl"))
        self.data_arranger = self.init_data_arranger()
        if self.calc_mean_std:
            # Presumably (re)generates mean_std_info.pkl, which is loaded just below.
            self.calc_mean_std_fn()
        self.mean_std_dict = load_pickle(os.path.join(self.dataset_path, "mean_std_info.pkl"))

    def init_data_arranger(self):
        """Return a ``DataArranger`` for this experiment's dataset."""
        arranger = DataArranger(self.dataset_info, self.dataset_path, self.debug)
        return arranger

    def run(self):
        """Run prediction for every fold in ``self.folds_to_run``.

        For each fold this:
        - builds the model and dataloaders and creates a ``Trainer`` plus a
          ``ResnetParamControl``;
        - restores both from ``<save_path>/checkpoint.pkl`` when
          ``self.resume`` is set, and otherwise starts a CSV log;
        - calls ``trainer.test(..., predict_only=1)`` on the ``'extra'``
          partition.

        The checkpoint directory is ``args.predict_save_path`` when that
        attribute exists and is non-empty, else ``DEFAULT_PREDICT_SAVE_PATH``.
        """
        criterion = compute_AU_loss_BCE()
        # NOTE(review): the per-run naming scheme (experiment/model/stamp/fold/
        # patience/lr/window/hop/seed/factor) is bypassed here. Every fold
        # therefore reads and writes the same directory. That scheme also
        # swapped its "seed" and "factor" labels, and the default directory
        # name reflects that swap.
        save_path = getattr(self.args, "predict_save_path", None) or self.DEFAULT_PREDICT_SAVE_PATH
        for fold in self.folds_to_run:
            os.makedirs(save_path, exist_ok=True)
            checkpoint_filename = os.path.join(save_path, "checkpoint.pkl")

            model = self.init_model()
            dataloaders = self.init_dataloader(fold)

            trainer_kwargs = {
                'device': self.device, 'emotion': self.emotion, 'model_name': self.model_name,
                'models': model, 'save_path': save_path, 'fold': fold,
                'min_epoch': self.min_num_epochs, 'max_epoch': self.num_epochs,
                'early_stopping': self.early_stopping, 'scheduler': self.scheduler,
                'learning_rate': self.learning_rate, 'min_learning_rate': self.min_learning_rate,
                'patience': self.patience, 'batch_size': self.batch_size,
                'criterion': criterion, 'factor': self.factor, 'verbose': True,
                'milestone': self.milestone, 'metrics': self.config['metrics'],
                'load_best_at_each_epoch': self.load_best_at_each_epoch,
                'save_plot': self.config['save_plot'],
            }
            trainer = Trainer(**trainer_kwargs)

            parameter_controller = ResnetParamControl(
                trainer,
                gradual_release=self.gradual_release,
                release_count=self.release_count,
                backbone_mode=["visual", "audio"],
            )
            checkpoint_controller = Checkpointer(
                checkpoint_filename, trainer, parameter_controller, resume=self.resume)

            if self.resume:
                trainer, parameter_controller = checkpoint_controller.load_checkpoint()
            else:
                checkpoint_controller.init_csv_logger(self.args, self.config)

            # Training is intentionally disabled; this run only predicts.
            # To train, re-enable:
            # if not trainer.fit_finished:
            #     trainer.fit(dataloaders, parameter_controller=parameter_controller,
            #                 checkpoint_controller=checkpoint_controller)

            test_kwargs = {'dataloader_dict': dataloaders, 'epoch': None, 'partition': 'extra'}
            trainer.test(checkpoint_controller, predict_only=1, **test_kwargs)

    def init_dataset(self, data, continuous_label_dim, mode, fold):
        """Build a ``Dataset`` for one fold/partition.

        Args:
            data: samples for this partition, passed straight to ``Dataset``.
            continuous_label_dim: label dimension forwarded to ``Dataset``.
            mode: partition key; also used to select normalisation stats.
            fold: fold key into ``self.mean_std_dict``.
        """
        dataset = Dataset(
            data, continuous_label_dim, self.modality, self.multiplier,
            self.feature_dimension, self.window_length,
            mode, mean_std=self.mean_std_dict[fold][mode], time_delay=self.time_delay)
        return dataset

    def init_model(self):
        """Build the model selected by ``self.model_name``.

        Calls ``init_randomness()`` first. Modalities whose name contains
        ``"continuous_label"`` are excluded from the model inputs.

        Returns:
            An ``LFAN`` (with ``init()`` already called) or a ``CAN`` instance.

        Raises:
            ValueError: if ``self.model_name`` is neither "LFAN" nor "CAN".
        """
        self.init_randomness()
        modality = [modal for modal in self.modality if "continuous_label" not in modal]
        if self.model_name == "LFAN":
            model = LFAN(backbone_settings=self.config['backbone_settings'],
                         modality=modality, example_length=self.window_length,
                         kernel_size=self.tcn_kernel_size,
                         tcn_channel=self.config['tcn']['channels'],
                         modal_dim=self.modal_dim, num_heads=self.num_heads,
                         root_dir=self.load_path, device=self.device)
            model.init()
        elif self.model_name == "CAN":
            model = CAN(root_dir=self.load_path, modalities=modality,
                        tcn_settings=self.config['tcn_settings'],
                        backbone_settings=self.config['backbone_settings'],
                        output_dim=self.continuous_label_dim, device=self.device)
            # CAN is used without an init() call (it was commented out originally).
        else:
            # Without this branch `model` would be unbound and `return model`
            # would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown model name: {self.model_name!r}; expected 'LFAN' or 'CAN'.")
        return model

    def get_modality(self):
        """No-op hook.

        NOTE(review): ``self.modality`` is read elsewhere but not set in this
        class; presumably it is set by ``GenericExperiment``. Confirm.
        """
        pass

    def get_config(self):
        """Return the ``config`` object from the ``configs`` module."""
        from configs import config
        return config

    def get_selected_continuous_label_dim(self):
        """Return the label column index for ``self.emotion``.

        Returns:
            ``[1]`` for "arousal" or ``[0]`` for "valence".

        Raises:
            ValueError: for any other emotion.
        """
        if self.emotion == "arousal":
            dim = [1]
        elif self.emotion == "valence":
            dim = [0]
        else:
            raise ValueError("Unknown emotion!")
        return dim