-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathexperiments.py
66 lines (57 loc) · 2.04 KB
/
experiments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import sys
import torch
import experiment_generalization
import finetuning_seq2seq
import finetuning_seq2seq_separated
import finetuning_windows
import training
import training_augmented
import training_augmented_noisy
import training_augmented_noisy_seq2seq
from choice_model import choice_model
from legacy import training_windowed, finetuning_seq2seq_fair
import training_noisy
from datetime import datetime
def parse_experiment_spec(spec):
    """Split an experiment spec string into the ``args`` dictionary.

    The spec has exactly four ``#``-separated fields, in this order:
    ``alias#rep#architecture_type#architecture``.

    Args:
        spec: The raw spec string (normally ``sys.argv[1]``).

    Returns:
        A dict with the keys ``'alias'``, ``'rep'``, ``'architecture_type'``
        and ``'architecture'``; every value is kept as a string.

    Raises:
        ValueError: If ``spec`` does not contain exactly four fields.
    """
    fields = spec.split('#')
    if len(fields) != 4:
        raise ValueError(
            "expected 'alias#rep#architecture_type#architecture', "
            f"got {spec!r}"
        )
    alias, rep, architecture_type, architecture = fields
    return {
        'alias': alias,
        'rep': rep,
        'architecture_type': architecture_type,
        'architecture': architecture,
    }


def select_trainer(architecture_type, exact, by_substring):
    """Return the trainer registered for ``architecture_type``.

    Args:
        architecture_type: The architecture type taken from the spec.
        exact: Mapping from a full architecture type name to its trainer.
        by_substring: Ordered iterable of ``(fragment, trainer)`` pairs that
            is consulted only when no exact name matches; the first pair
            whose fragment occurs in ``architecture_type`` wins.

    Returns:
        The matching trainer object.

    Raises:
        ValueError: If neither an exact name nor a fragment matches.
    """
    if architecture_type in exact:
        return exact[architecture_type]
    for fragment, trainer in by_substring:
        if fragment in architecture_type:
            return trainer
    # Fail with the offending value instead of leaving the trainer unbound.
    raise ValueError(f"unknown architecture_type: {architecture_type!r}")


def main(argv=None):
    """Run the experiment described by the first command-line argument.

    Creates the ``plots`` and ``models`` directories if needed, builds the
    model, picks the training module for the requested architecture type,
    then runs dataset creation, the training loop and the test run.

    Args:
        argv: Argument vector to use; defaults to ``sys.argv``.

    Raises:
        SystemExit: If no spec argument was given.
        ValueError: If the spec is malformed or names an unknown
            architecture type.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit(
            "usage: experiments.py "
            "alias#rep#architecture_type#architecture"
        )

    start = datetime.now()

    # exist_ok avoids the check-then-create race when several experiment
    # processes are launched at the same time.
    os.makedirs('plots', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    args = parse_experiment_spec(argv[1])
    rep = args['rep']
    architecture_type = args['architecture_type']
    architecture = args['architecture']

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print(f"device: {device}")
    model = choice_model(architecture, architecture_type)

    # Types that must match the spec field exactly.
    exact = {
        'windowed': training_windowed,
        'classification': training,
        'classification_noisy': training_noisy,
        'classification_augmented': training_augmented,
        'classification_augmented_noisy': training_augmented_noisy,
        'generalization': experiment_generalization,
        'seq2seq_noisy': training_augmented_noisy_seq2seq,
    }
    # Types matched by containment; order matters, first hit wins.
    by_substring = (
        ('finetuning_windows', finetuning_windows),
        ('finetuning_seq2seq', finetuning_seq2seq),
        ('finetuning_separated', finetuning_seq2seq_separated),
    )
    tr = select_trainer(architecture_type, exact, by_substring)

    data = tr.create_dataset(rep)
    writer, logging_enabled = tr.training_loop(data, device, model, args)
    tr.run_test(data, device, model, args, writer)
    if logging_enabled:
        writer.add_text("time cost", str(datetime.now() - start))


if __name__ == '__main__':
    main()