-
Notifications
You must be signed in to change notification settings - Fork 89
/
train_ditto.py
92 lines (78 loc) · 3.28 KB
/
train_ditto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import argparse
import json
import sys
import torch
import numpy as np
import random
sys.path.insert(0, "Snippext_public")
from ditto_light.dataset import DittoDataset
from ditto_light.summarize import Summarizer
from ditto_light.knowledge import *
from ditto_light.ditto import train
if __name__ == "__main__":
    # Script entry point: parse hyper-parameters, seed the RNGs, resolve the
    # task's dataset files from configs.json, optionally pre-process them
    # (summarization / domain-knowledge injection), then train and evaluate.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Structured/Beer")
    parser.add_argument("--run_id", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_len", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--n_epochs", type=int, default=20)
    parser.add_argument("--finetuning", dest="finetuning", action="store_true")
    parser.add_argument("--save_model", dest="save_model", action="store_true")
    parser.add_argument("--logdir", type=str, default="checkpoints/")
    parser.add_argument("--lm", type=str, default='distilbert')
    parser.add_argument("--fp16", dest="fp16", action="store_true")
    parser.add_argument("--da", type=str, default=None)
    parser.add_argument("--alpha_aug", type=float, default=0.8)
    parser.add_argument("--dk", type=str, default=None)
    parser.add_argument("--summarize", dest="summarize", action="store_true")
    parser.add_argument("--size", type=int, default=None)

    hp = parser.parse_args()

    # set seeds: the run id doubles as the seed for every RNG used here
    seed = hp.run_id
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # only a single task for baseline
    task = hp.task

    # create the tag of the run; '/' in the task name is replaced with '_'
    run_tag = '%s_lm=%s_da=%s_dk=%s_su=%s_size=%s_id=%d' % (task, hp.lm, hp.da,
            hp.dk, hp.summarize, str(hp.size), hp.run_id)
    run_tag = run_tag.replace('/', '_')

    # load task configuration; the context manager closes the file handle
    # as soon as the JSON has been parsed
    with open('configs.json') as config_file:
        configs = json.load(config_file)
    # configs.json holds a list of entries; index them by their 'name' field
    configs = {conf['name'] : conf for conf in configs}
    config = configs[task]

    trainset = config['trainset']
    validset = config['validset']
    testset = config['testset']

    # summarize the sequences up to the max sequence length
    # (each transform_file result replaces the corresponding dataset
    # reference -- presumably a path to the transformed file; TODO confirm)
    if hp.summarize:
        summarizer = Summarizer(config, lm=hp.lm)
        trainset = summarizer.transform_file(trainset, max_len=hp.max_len)
        validset = summarizer.transform_file(validset, max_len=hp.max_len)
        testset = summarizer.transform_file(testset, max_len=hp.max_len)

    # optionally inject domain knowledge: 'product' selects the product
    # injector, any other value falls back to the general injector
    if hp.dk is not None:
        if hp.dk == 'product':
            injector = ProductDKInjector(config, hp.dk)
        else:
            injector = GeneralDKInjector(config, hp.dk)

        trainset = injector.transform_file(trainset)
        validset = injector.transform_file(validset)
        testset = injector.transform_file(testset)

    # load train/dev/test sets
    train_dataset = DittoDataset(trainset,
                                 lm=hp.lm,
                                 max_len=hp.max_len,
                                 size=hp.size,
                                 da=hp.da)
    # NOTE(review): the valid/test datasets are built without
    # max_len=hp.max_len, so they rely on DittoDataset's own default --
    # confirm this is intended when --max_len is overridden.
    valid_dataset = DittoDataset(validset, lm=hp.lm)
    test_dataset = DittoDataset(testset, lm=hp.lm)

    # train and evaluate the model
    train(train_dataset,
          valid_dataset,
          test_dataset,
          run_tag, hp)