import random
import pickle as pkl
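
# NOTE (assumption): `check_srl` is called below but neither defined nor
# imported in this file; in the SRLOOD repository it presumably validates
# or filters the semantic-role-labeling annotations attached to each
# example. A pass-through placeholder is sketched here so the module is
# runnable in isolation; replace it with the project's real helper.
def check_srl(dataset):
    # Hypothetical stand-in: the real helper is assumed to drop or repair
    # examples with missing or malformed SRL annotations.
    return dataset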
# Map each task to its (sentence1, sentence2) feature names; None in the
# second slot means the task takes a single-sentence input.
task_to_keys = {
    "mnli": ("premise", "hypothesis"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "20ng": ("text", None),
    "trec": ("text", None),
    "imdb": ("text", None),
    "wmt16": ("en", None),
    "multi30k": ("text", None),
}
task_to_keys = {"mnli": ("premise", "hypothesis"),"rte": ("sentence1", "sentence2"),"sst2": ("sentence", None),'20ng': ("text", None),'trec': ("text", None),'imdb': ("text", None),'wmt16': ("en", None),'multi30k': ("text", None),}
# Map each task to the pickle file holding its preprocessed splits.
task_to_pickle = {
    "mnli": "_DSs/_MNLI.pkl",
    "rte": "_DSs/_RTE.pkl",
    "sst2": "_DSs/_SST2.pkl",
    "20ng": "_DSs/_20NG.pkl",
    "trec": "_DSs/_TREC.pkl",
    "imdb": "_DSs/_IMDB.pkl",
    "wmt16": "_DSs/_WMT16.pkl",
    "multi30k": "_DSs/_MULTI30K.pkl",
}


def _load_pickled_datasets(task_name):
    """Load the pickled splits for `task_name` from disk."""
    if task_name not in task_to_pickle:
        raise ValueError("Unknown task: {}".format(task_name))
    with open(task_to_pickle[task_name], "rb") as f:
        return pkl.load(f)


def load(task_name, tokenizer, max_seq_length=256, is_id=False):
    """Return the (train, validation, test) splits for an in-distribution task.

    `tokenizer`, `max_seq_length`, and `is_id` are accepted for interface
    compatibility; tokenization is handled by the caller.
    """
    sentence1_key, sentence2_key = task_to_keys[task_name]
    print("Loading {}".format(task_name))
    datasets = _load_pickled_datasets(task_name)
    train_dataset = check_srl(datasets["train"])
    dev_dataset = check_srl(datasets["validation"])
    test_dataset = check_srl(datasets["test"])
    if task_name == "imdb":
        # Re-split the combined train+dev pool: hold out 2,500 randomly
        # sampled examples as the new validation set.
        imdb_all = train_dataset + dev_dataset
        inds = list(range(len(imdb_all)))
        new_val_ind = random.sample(inds, 2500)
        new_train_ind = list(set(inds) - set(new_val_ind))
        dev_dataset = [imdb_all[j] for j in new_val_ind]
        train_dataset = [imdb_all[i] for i in new_train_ind]
    return train_dataset, dev_dataset, test_dataset
def load_ood(task_name, tokenizer, max_seq_length=256, is_id=False):
    """Return only the test split of `task_name`, used as OOD data."""
    sentence1_key, sentence2_key = task_to_keys[task_name]
    print("Loading {}".format(task_name))
    datasets = _load_pickled_datasets(task_name)
    test_dataset = check_srl(datasets["test"])
    return test_dataset
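

if __name__ == "__main__":
    # Minimal smoke test (a sketch): assumes the pickled splits exist
    # under _DSs/. The tokenizer argument is unused by these loaders,
    # so None is passed here purely to satisfy the signature.
    train, dev, test = load("sst2", tokenizer=None)
    print("sst2 splits:", len(train), len(dev), len(test))
    ood = load_ood("20ng", tokenizer=None)
    print("20ng OOD test:", len(ood))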