
Commit e6d7e31 (parent: 5eee7ce)

Author: Yh Tian
Commit message: implement the predicting function

8 files changed (+176, -101 lines)


README.md (+14, -1)

@@ -78,9 +78,22 @@ Here are some important parameters:
 * `--feature_flag`: use `pos`, `chunk`, or `dep` knowledge
 * `--model_name`: the name of the model to save
 
+## Predicting
+
+`run_sample.sh` contains the command line to segment and tag the sentences in an input file ([./sample_data/sentence.txt](./sample_data/sentence.txt)).
+
+Here are some important parameters:
+
+* `--do_predict`: segment and tag the sentences using a pre-trained TwASP model.
+* `--input_file`: the file containing the sentences to be segmented and tagged. Each line contains one sentence; you can refer to [a sample input file](./sample_data/sentence.txt) for the input format.
+* `--output_file`: the path of the output file. Words are separated by a space; the POS label of each word is attached to it by an underscore ("_").
+* `--eval_model`: the pre-trained TwASP model used to segment and tag the sentences in the input file.
+
+To run a pre-trained TwASP model, you need to install SCT and BNP to obtain the auto-analyzed syntactic knowledge. See [data_preprocessing](./data_preprocessing) for more information on downloading the two toolkits.
+
 ## To-do List
 
-* Implement `predict` function in `twasp_main.py`
+* Regular maintenance
 
 You can leave comments in the `Issues` section, if you want us to implement any functions.
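
For illustration only (not part of this commit), the output format described above (space-separated words, each with its POS tag attached by an underscore) could be read back with a few lines of Python; the sample sentence and tags below are made up:

```python
# Illustrative sketch, not part of this commit: reading one line of the
# predicted output, where words are separated by spaces and each word's
# POS tag is attached with an underscore, e.g. "中国_NR 人民_NN".
def parse_output_line(line):
    pairs = []
    for token in line.strip().split(' '):
        # rsplit keeps words that themselves contain an underscore intact
        word, pos = token.rsplit('_', 1)
        pairs.append((word, pos))
    return pairs

print(parse_output_line('中国_NR 人民_NN'))  # [('中国', 'NR'), ('人民', 'NN')]
```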

run_sample.sh (+2)

@@ -6,3 +6,5 @@ python twasp_main.py --do_train --train_data_path=./sample_data/train.tsv --eval
 # test
 python twasp_main.py --do_test --eval_data_path=./sample_data/test.tsv --eval_model=./models/model_name/model.pt
 
+# predict
+python twasp_main.py --do_predict --input_file=./sample_data/sentence.txt --output_file=./sample_data/sentence.txt.out --eval_model=./models/model_name/model.pt

sample_data/sentence.txt (+5)

@@ -0,0 +1,5 @@
+共同创造美好的新世纪——二○○一年新年贺词
+(二○○○年十二月三十一日)(附图片1张)
+女士们,先生们,同志们,朋友们:
+2001年新年钟声即将敲响。人类社会前进的航船就要驶入21世纪的新航程。中国人民进入了向现代化建设第三步战略目标迈进的新征程。
+在这个激动人心的时刻,我很高兴通过中国国际广播电台、中央人民广播电台和中央电视台,向全国各族人民,向香港特别行政区同胞、澳门特别行政区同胞和台湾同胞、海外侨胞,向世界各国的朋友们,致以新世纪第一个新年的祝贺!

twasp_eval.py (+28, -18)

@@ -1,39 +1,49 @@
-from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
+from seqeval.metrics import f1_score, precision_score, recall_score
+
 
 def eval_sentence(y_pred, y, sentence, word2id):
     words = sentence.split(' ')
-    seg_true = []
+
+    if y is not None:
+        seg_true = []
+        word_true = ''
+        y_word = []
+        y_pos = []
+        for y_label in y:
+            y_word.append(y_label[0])
+            y_pos.append(y_label[2:])
+
+        for i in range(len(y_word)):
+            word_true += words[i]
+            if y_word[i] in ['S', 'E']:
+                pos_tag_true = y_pos[i]
+                word_pos_true = word_true + '_' + pos_tag_true
+                if word_true not in word2id:
+                    word_pos_true = '*' + word_pos_true + '*'
+                seg_true.append(word_pos_true)
+                word_true = ''
+
+        seg_true_str = ' '.join(seg_true)
+    else:
+        seg_true_str = None
+
     seg_pred = []
-    word_true = ''
     word_pred = ''
 
-    y_word = []
-    y_pos = []
     y_pred_word = []
     y_pred_pos = []
-    for y_label, y_pred_label in zip(y, y_pred):
-        y_word.append(y_label[0])
-        y_pos.append(y_label[2:])
+    for y_pred_label in y_pred:
         y_pred_word.append(y_pred_label[0])
         y_pred_pos.append(y_pred_label[2:])
 
-    for i in range(len(y_word)):
-        word_true += words[i]
+    for i in range(len(y_pred_word)):
         word_pred += words[i]
-        if y_word[i] in ['S', 'E']:
-            pos_tag_true = y_pos[i]
-            word_pos_true = word_true + '_' + pos_tag_true
-            if word_true not in word2id:
-                word_pos_true = '*' + word_pos_true + '*'
-            seg_true.append(word_pos_true)
-            word_true = ''
         if y_pred_word[i] in ['S', 'E']:
            pos_tag_pred = y_pred_pos[i]
            word_pos_pred = word_pred + '_' + pos_tag_pred
            seg_pred.append(word_pos_pred)
            word_pred = ''
 
-    seg_true_str = ' '.join(seg_true)
     seg_pred_str = ' '.join(seg_pred)
     return seg_true_str, seg_pred_str
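
A minimal usage sketch of the revised `eval_sentence` in the prediction path, assuming the module above is importable; the sentence and tags are made up, and the gold side is skipped by passing `y=None`:

```python
from twasp_eval import eval_sentence  # assumes the file shown above is on the path

y_pred = ['B-NR', 'E-NR', 'S-DEG']    # one predicted label per character (made-up values)
sentence = '中 国 的'                  # characters joined by single spaces
seg_true_str, seg_pred_str = eval_sentence(y_pred, None, sentence, word2id={})

print(seg_true_str)  # None, because no gold labels were supplied
print(seg_pred_str)  # 中国_NR 的_DEG
```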

twasp_helper.py (+20, -78)

@@ -42,6 +42,19 @@ def read_tsv(file_path):
     return sentence_list, label_list
 
 
+def read_sentence(file_path):
+    sentence = []
+    with open(file_path, 'r', encoding='utf8') as f:
+        lines = f.readlines()
+    for line in lines:
+        line = line.strip()
+        if line == '':
+            continue
+        sentence.append([char for char in line])
+
+    return sentence, None
+
+
 def get_word2id(train_path):
     word2id = {'<PAD>': 0}
     word = ''
@@ -97,7 +110,7 @@ def merge_results(results):
     return merged
 
 
-def request_features_from_stanford(data_path):
+def request_features_from_stanford(data_path, do_predict=False):
     data_dir = data_path[:data_path.rfind('/')]
     flag = data_path[data_path.rfind('/') + 1: data_path.rfind('.')]
 
@@ -107,7 +120,10 @@ def request_features_from_stanford(data_path):
 
     print('Requesting Stanford results for %s' % str(data_path))
 
-    all_sentences, _ = read_tsv(data_path)
+    if do_predict:
+        all_sentences, _ = read_sentence(data_path)
+    else:
+        all_sentences, _ = read_tsv(data_path)
     sentences_str = []
     for sentence in all_sentences:
         sentences_str.append(''.join(sentence))
@@ -126,13 +142,13 @@
         f.write('\n')
 
 
-def request_features_from_berkeley(data_path):
+def request_features_from_berkeley(data_path, do_predict=False):
     data_dir = data_path[:data_path.rfind('/')]
     flag = data_path[data_path.rfind('/') + 1: data_path.rfind('.')]
 
     if not os.path.exists(path.join(data_dir, flag + '.stanford.json')):
         print('Do not find the Stanford data file\nRequesting Stanford segmentation results for %s' % str(data_path))
-        request_features_from_stanford(data_path, flag)
+        request_features_from_stanford(data_path, do_predict=do_predict)
     else:
         print('The Stanford data file for %s already exists!' % str(data_path))
     if os.path.exists(path.join(data_dir, flag + '.berkeley.json')):
@@ -164,14 +180,7 @@ def request_features_from_berkeley(data_path):
         pos_tags = parse_tree.pos()
 
         for i, (bt, (w, pos)) in enumerate(zip(berkeley_data['tokens'], pos_tags)):
-            # w = w_pos[0]
-            # pos = w_pos[1]
-            # try:
             assert bt['word'] == w
-            # except AssertionError:
-            #     print('error in sentence: %s' % ''.join(word_list))
-            #     print('word error: excepted %s, get %s' % (bt['word'], w))
-            # else:
             berkeley_data['tokens'][i]['pos'] = pos
         berkeley_all_data.append(berkeley_data)
 
@@ -455,70 +464,3 @@ def renew_ngram_by_freq(all_sentences, ngram2count, min_feq, ngram_len=10):
                 new_ngram2count[n_gram] += 1
     new_ngram_dict = {gram: c for gram, c in new_ngram2count.items() if c > min_feq}
     return new_ngram_dict
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument("--dataset",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-
-    args = parser.parse_args()
-    base_min_freq = 1
-    av_threshold = 2
-
-    min_freq = base_min_freq
-
-    print('min freq: %d' % min_freq)
-
-    data_dir = path.join(DATA_DIR, args.dataset)
-
-    print(data_dir)
-
-    # getlabels(data_dir)
-
-    # get_word2id(data_dir)
-
-    # be(data_dir, 0, 10)
-
-    # oov_stat(data_dir, 'train')
-    # oov_stat(data_dir, 'dev')
-    # oov_stat(data_dir, 'test')
-    # request_features_from_stanford(data_dir, 'train')
-    # request_features_from_stanford(data_dir, 'dev')
-    # request_features_from_stanford(data_dir, 'test')
-
-    # request_features_from_stanford(data_dir, 'bc')
-    # request_features_from_stanford(data_dir, 'bn')
-    # request_features_from_stanford(data_dir, 'cs')
-    # request_features_from_stanford(data_dir, 'df')
-    # request_features_from_stanford(data_dir, 'mz')
-    # request_features_from_stanford(data_dir, 'nw')
-    # request_features_from_stanford(data_dir, 'sc')
-    # request_features_from_stanford(data_dir, 'wb')
-
-    # request_features_from_stanford('./data/POS/demo', 'demo')
-
-    # sfp = stanford_feature_processor(data_dir)
-    # sfp._pre_processing()
-    # sfp.read_features('train')
-    # sfp.read_features('test')
-    # sfp.feature_stat()
-
-    # bek = berkeley_feature_processor(data_dir)
-    # bek.request_knoledge('train')
-    # bek.request_knoledge('dev')
-    # bek.request_knoledge('test')
-    # bek.request_knoledge('demo')
-    # bek._pre_processing()
-    # bek.feature_stat()
-
-    # attentionn_gram_stat(data_dir, 0, 10)
-
-    print('')
-
-    # exit()
-
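
A small sanity-check sketch for the new `read_sentence` helper, assuming the repository layout above; it simply turns every non-empty line of the input file into a list of characters and returns no labels:

```python
from twasp_helper import read_sentence  # assumes the file shown above is on the path

sentences, labels = read_sentence('./sample_data/sentence.txt')
print(labels)            # None: prediction input carries no gold labels
print(sentences[2][:3])  # first three characters of the third sentence, e.g. ['女', '士', '们']
```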

twasp_main.py (+87, -2)

@@ -480,8 +480,93 @@ def test(args):
 
 
 def predict(args):
-    # In progressing
-    return None
+
+    if args.local_rank == -1 or args.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+        n_gpu = torch.cuda.device_count()
+    else:
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        n_gpu = 1
+        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.distributed.init_process_group(backend='nccl')
+    print("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
+        device, n_gpu, bool(args.local_rank != -1), args.fp16))
+
+    joint_model_checkpoint = torch.load(args.eval_model)
+    joint_model = TwASP.from_spec(joint_model_checkpoint['spec'], joint_model_checkpoint['state_dict'], args)
+
+    if joint_model.use_attention:
+        if joint_model.source == 'stanford':
+            request_features_from_stanford(args.input_file, do_predict=True)
+        elif joint_model.source == 'berkeley':
+            request_features_from_berkeley(args.input_file, do_predict=True)
+        else:
+            raise ValueError('Invalid source %s. '
+                             'Source must be one of \'stanford\' or \'berkeley\' if attentions are used.'
+                             % joint_model.source)
+
+    eval_examples = joint_model.load_data(args.input_file, do_predict=True)
+    convert_examples_to_features = joint_model.convert_examples_to_features
+    feature2input = joint_model.feature2input
+    num_labels = joint_model.num_labels
+    word2id = joint_model.word2id
+    label_map = {v: k for k, v in joint_model.labelmap.items()}
+    label_map[0] = 'O'
+
+    if args.fp16:
+        joint_model.half()
+    joint_model.to(device)
+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError(
+                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        joint_model = DDP(joint_model)
+    elif n_gpu > 1:
+        joint_model = torch.nn.DataParallel(joint_model)
+
+    joint_model.to(device)
+
+    joint_model.eval()
+    y_pred = []
+
+    for start_index in tqdm(range(0, len(eval_examples), args.eval_batch_size)):
+        eval_batch_examples = eval_examples[start_index: min(start_index + args.eval_batch_size,
+                                                             len(eval_examples))]
+        eval_features = convert_examples_to_features(eval_batch_examples)
+
+        feature_ids, input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, \
+        segment_ids, valid_ids, word_ids, word_matching_matrix = feature2input(device, eval_features)
+
+        with torch.no_grad():
+            _, tag_seq = joint_model(input_ids, segment_ids, input_mask, label_ids, valid_ids, l_mask,
+                                     word_ids, feature_ids, word_matching_matrix, word_matching_matrix,
+                                     ngram_ids, ngram_positions)
+
+        logits = tag_seq.to('cpu').numpy()
+        label_ids = label_ids.to('cpu').numpy()
+
+        for i, label in enumerate(label_ids):
+            temp = []
+            for j, m in enumerate(label):
+                if j == 0:
+                    continue
+                elif label_ids[i][j] == num_labels - 1:
+                    y_pred.append(temp)
+                    break
+                else:
+                    temp.append(label_map[logits[i][j]])
+
+    print('write results to %s' % str(args.output_file))
+    with open(args.output_file, 'w') as writer:
+        for i in range(len(y_pred)):
+            sentence = eval_examples[i].text_a
+            _, seg_pred_str = eval_sentence(y_pred[i], None, sentence, word2id)
+            writer.write('%s\n' % seg_pred_str)
+
 
 
 def main():
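
To make the decoding loop inside `predict` easier to follow, here is a standalone sketch with made-up arrays; in the real code `logits` and `label_ids` come from the model and `feature2input`, and the loop stops at the id reserved for the sentence-end label:

```python
# Standalone sketch of the per-sentence decoding loop in predict();
# all values below are made-up stand-ins for the real model outputs.
label_map = {1: 'B-NR', 2: 'E-NR', 3: 'S-DEG', 4: '[SEP]'}
num_labels = 5                    # stand-in for joint_model.num_labels
logits = [[0, 1, 2, 3, 4, 0]]     # predicted label ids; position 0 is skipped
label_ids = [[0, 1, 1, 1, 4, 0]]  # num_labels - 1 marks the end of the sentence

y_pred = []
for i, label in enumerate(label_ids):
    temp = []
    for j, _ in enumerate(label):
        if j == 0:
            continue                             # skip the first (sentence-start) position
        elif label_ids[i][j] == num_labels - 1:  # end-of-sentence id reached
            y_pred.append(temp)
            break
        else:
            temp.append(label_map[logits[i][j]])

print(y_pred)  # [['B-NR', 'E-NR', 'S-DEG']]
```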

twasp_model.py (+19, -2)

@@ -189,8 +189,12 @@ def from_spec(cls, spec, model, args):
         res.load_state_dict(model)
         return res
 
-    def load_data(self, data_path):
-        lines = readfile(data_path)
+    def load_data(self, data_path, do_predict=False):
+
+        if do_predict:
+            lines = read_sentence(data_path)
+        else:
+            lines = readfile(data_path)
 
         flag = data_path[data_path.rfind('/')+1: data_path.rfind('.')]
 
@@ -654,3 +658,16 @@ def readfile(filename):
             label = []
     return data
 
+
+def read_sentence(filename):
+    data = []
+    with open(filename, 'r', encoding='utf8') as f:
+        lines = f.readlines()
+    for line in lines:
+        line = line.strip()
+        if line == '':
+            continue
+        sentence = [char for char in line]
+        label = ['<UNK>' for _ in sentence]
+        data.append((sentence, label))
+    return data
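
And a quick sketch of what the new `read_sentence` in `twasp_model.py` produces (assuming the repository layout above): character lists paired with placeholder `<UNK>` labels, so that prediction input can flow through the same data pipeline as labelled data:

```python
from twasp_model import read_sentence  # assumes the file shown above is on the path

data = read_sentence('./sample_data/sentence.txt')
chars, labels = data[2]   # third sentence of the sample input file
print(chars[:3])          # e.g. ['女', '士', '们']
print(labels[:3])         # ['<UNK>', '<UNK>', '<UNK>'], one placeholder per character
```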

updates.md (+1)

@@ -1,3 +1,4 @@
 # Important Updates
 
+* July 14, 2020: Implement the `predict` function in `twasp_main.py`. You can use that function to segment and tag the sentences in an input file with a pre-trained TwASP model. See [run_sample.sh](./run_sample.sh) for the usage and [./sample_data/sentence.txt](./sample_data/sentence.txt) for the input format. If you run pre-trained TwASP models that use the Stanford CoreNLP Toolkit v3.9.2 or the Berkeley Neural Parser, you need to download these toolkits before running. See [data_preprocessing](./data_preprocessing) for more information on installing the toolkits.
 * July 7, 2020: the release of [pre-trained TwASP models](./models).
