# run_kaggle_cnn.py - forked from songyingxin/Bert-TextClassification
# coding=utf-8
from main import main
import pandas as pd
import os
# Script entry point: derive the run configuration (label set, maximum
# sequence length, paths) from the CSV data set and hand it to ``main``.
if __name__ == "__main__":
    # Name of the model package whose ``args`` module builds the config.
    model_name = 'BertCNN'

    # Directory holding the raw ``train.csv`` / ``test.csv`` files.
    data_dir = '/data/hzy/Bert-Lab/Data/origin'
    output_dir = ".Result/"
    cache_dir = ".Cache/"
    log_dir = ".Logs/"
    # Train:validation ratio; only used by the commented-out split below.
    train_val_rate = '9:1'

    # bert-base
    bert_vocab_file = '/data/hzy/Bert-Lab/Bert_weight/bert-base-uncased-vocab.txt'
    bert_model_dir = '/data/hzy/Bert-Lab/Bert_weight/bert-base-uncased.tar.gz'
    # # bert-large
    # bert_vocab_file = "/search/hadoop02/suanfa/songyingxin/pytorch_Bert/bert-large-uncased-vocab.txt"
    # bert_model_dir = "/search/hadoop02/suanfa/songyingxin/pytorch_Bert/bert-large-uncased"

    train_val_data = pd.read_csv(os.path.join(data_dir, "train.csv"))
    test_data = pd.read_csv(os.path.join(data_dir, "test.csv"))
    all_data = pd.concat([train_val_data, test_data])

    # Longest sentence over train and test, counted in whitespace-separated
    # words. Missing sentences are skipped instead of raising, and the result
    # is converted to a plain ``int`` rather than a NumPy scalar.
    # NOTE(review): BERT's tokenizer may produce more tokens than this word
    # count (sub-word pieces, special tokens), so inputs could still be
    # truncated downstream - confirm how ``main`` uses ``max_seq_length``.
    words_per_sentence = all_data['Sentence'].str.split().str.len()
    max_length = int(words_per_sentence.max())

    # Distinct categories of the training data, in order of first appearance.
    label_list = train_val_data['Category'].drop_duplicates().tolist()

    # One-off preprocessing, kept for reference: splits the raw training data
    # according to ``train_val_rate`` and writes train/dev/test CSVs into the
    # parent directory of the raw data directory.
    # print(f'max: {max_length}\ncls_list: {label_list}')
    # train_val_nums = len(train_val_data)
    # train_rate = int(train_val_rate.split(":")[0])/10
    # train_data = train_val_data.iloc[:int(train_val_nums*train_rate)]
    # val_data = train_val_data.iloc[int(train_val_nums*train_rate):]
    # test_data['Label']=label_list[0]
    # test_data.drop('Id',axis=1,inplace=True)
    # train_data.to_csv(os.path.join(os.path.dirname(data_dir),'train.csv'),index=False)
    # val_data.to_csv(os.path.join(os.path.dirname(data_dir),'dev.csv'),index=False)
    # test_data.to_csv(os.path.join(os.path.dirname(data_dir),'test.csv'),index=False)

    # The configuration is pointed at the parent of the raw data directory,
    # which is where the split above writes its output files.
    data_dir = os.path.dirname(data_dir)

    # Import the argument builder that belongs to the selected model.
    if model_name == "BertOrigin":
        from BertOrigin import args
    elif model_name == "BertCNN":
        from BertCNN import args
    elif model_name == 'BertLSTM':
        from BertLSTM import args
    elif model_name == "BertATT":
        from BertATT import args
    elif model_name == "BertRCNN":
        from BertRCNN import args
    elif model_name == "BertCNNPlus":
        from BertCNNPlus import args
    elif model_name == "BertDPCNN":
        from BertDPCNN import args
    else:
        # Without this branch ``args`` would stay unbound and the script
        # would fail below with a NameError that hides the real cause.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of "
            "BertOrigin, BertCNN, BertLSTM, BertATT, BertRCNN, "
            "BertCNNPlus, BertDPCNN"
        )

    config = args.get_args(data_dir, output_dir, cache_dir,
                           bert_vocab_file, bert_model_dir, log_dir,
                           max_seq_length=max_length)

    main(config, config.save_name, label_list)