
Commit 786e97a

Author: Yh Tian (committed)
update the way to save model
1 parent e3ee0c0 commit 786e97a

File tree

6 files changed: +160 -120 lines changed
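This commit makes the saved checkpoint self-contained: hyper-parameters are gathered once with `TwASP.init_hyper_parameters(args)` and handed to the model, the `feature2input` helper and the data-split `flag` handling move into the model and the feature-request helpers, and `test()` rebuilds the model from `checkpoint['spec']` and `checkpoint['state_dict']` via `TwASP.from_spec(...)`. Below is a minimal, hypothetical sketch of that spec/state_dict save-and-load pattern; the `save_model`/`load_model` names and the exact contents of `spec` are assumptions for illustration, not code from this repository.

```python
import torch

def save_model(model, path):
    # Hypothetical helper: store the model's rebuild recipe ("spec",
    # e.g. hyper-parameters and vocabularies) next to the learned weights.
    torch.save({'spec': model.spec, 'state_dict': model.state_dict()}, path)

def load_model(model_cls, path, args):
    # Mirrors the pattern used in test(): torch.load, then from_spec.
    checkpoint = torch.load(path, map_location='cpu')
    return model_cls.from_spec(checkpoint['spec'], checkpoint['state_dict'], args)
```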

README.md

+5-5
@@ -9,14 +9,14 @@ We will keep updating this repository these days.
 If you use or extend our work, please cite our paper at ACL2020.
 
 ```
-@inproceedings{tian-etal-2020-improving,
-    title = "Improving {C}hinese Word Segmentation with Wordhood Memory Networks",
-    author = "Tian, Yuanhe and Song, Yan and Xia, Fei and Zhang, Tong and Wang, Yonggang",
+@inproceedings{tian-etal-2020-joint,
+    title = "Joint Chinese Word Segmentation and Part-of-speech Tagging via Two-way Attentions of Auto-analyzed Knowledge",
+    author = "Tian, Yuanhe and Song, Yan and Ao, Xiang and Xia, Fei and Quan, Xiaojun and Zhang, Tong and Wang, Yonggang",
     booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
     month = jul,
     year = "2020",
     address = "Online",
-    pages = "8274--8285",
+    pages = "8286--8296",
 }
 ```

@@ -42,7 +42,7 @@ Run `run_sample.sh` to train a model on the small sample data under the `sample_
 
 We use [CTB5](https://catalog.ldc.upenn.edu/LDC2005T01), [CTB6](https://catalog.ldc.upenn.edu/LDC2007T36), [CTB7](https://catalog.ldc.upenn.edu/LDC2010T07), [CTB9](https://catalog.ldc.upenn.edu/LDC2016T13), and [Universal Dependencies 2.4](https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-2988) (UD) in our paper.
 
-To obtain and pre-process the data, you can go to `data_preprocessing` directory and run `getdata.sh`. This script will download and process the official data from UD. For CTB5 (LDC05T01), CTB6 (LDC07T36), CTB7 (LDC10T07), and CTB9 (LDC2016T13), you need to obtain the official data yourself, and then put the raw data directory under the `data_preprocessing` directory.
+To obtain and pre-process the data, you can go to `data_preprocessing` directory and run `getdata.sh`. This script will download and process the official data from UD. For CTB5 (LDC05T01), CTB6 (LDC07T36), CTB7 (LDC10T07), and CTB9 (LDC2016T13), you need to obtain the official data yourself, and then put the raw data folder under the `data_preprocessing` directory.
 
 The script will also download the [Stanford CoreNLP Toolkit v3.9.2](https://stanfordnlp.github.io/CoreNLP/history.html) (SCT) and [Berkeley Neural Parser](https://github.com/nikitakit/self-attentive-parser) (BNP) to obtain the auto-analyzed syntactic knowledge. You can refer to their website for more information.
get_syninfo.py

+2-2
@@ -42,14 +42,14 @@
         if os.path.exists(out_file) and not args.overwrite:
             print('File already exists: %s' % str(out_file))
             continue
-        request_features_from_stanford(input_file, flag)
+        request_features_from_stanford(input_file)
 
     elif args.toolkit == 'BNP':
         out_file = os.path.join(input_dir, flag + '.berkeley.json')
         if os.path.exists(out_file) and not args.overwrite:
             print('File already exists: %s' % str(out_file))
             continue
-        request_features_from_berkeley(input_file, flag)
+        request_features_from_berkeley(input_file)
     else:
         raise ValueError('Invalid type of toolkit name: %s. Should be one of \'SCT\' and \'BNP\'.' % args.toolkit)

run_sample.sh

+1-1
@@ -1,7 +1,7 @@
 mkdir logs
 
 # train
-python twasp_main.py --do_train --train_data_path=./sample_data/train.tsv --eval_data_path=./sample_data/test.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=2 --eval_batch_size=2 --num_train_epochs=3 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=sample_model
+python twasp_main.py --do_train --train_data_path=./sample_data/train.tsv --eval_data_path=./sample_data/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=2 --eval_batch_size=2 --num_train_epochs=3 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=sample_model
 
 # test
 python twasp_main.py --do_test --eval_data_path=./sample_data/test.tsv --eval_model=./models/model_name/model.pt

twasp_helper.py

+5-2
@@ -97,8 +97,10 @@ def merge_results(results):
     return merged
 
 
-def request_features_from_stanford(data_path, flag):
+def request_features_from_stanford(data_path):
     data_dir = data_path[:data_path.rfind('/')]
+    flag = data_path[data_path.rfind('/') + 1: data_path.rfind('.')]
+
     if os.path.exists(path.join(data_dir, flag + '.stanford.json')):
         print('The Stanford data file for %s already exists!' % str(data_path))
         return None
@@ -124,8 +126,9 @@ def request_features_from_stanford(data_path, flag):
         f.write('\n')
 
 
-def request_features_from_berkeley(data_path, flag):
+def request_features_from_berkeley(data_path):
     data_dir = data_path[:data_path.rfind('/')]
+    flag = data_path[data_path.rfind('/') + 1: data_path.rfind('.')]
 
     if not os.path.exists(path.join(data_dir, flag + '.stanford.json')):
         print('Do not find the Stanford data file\nRequesting Stanford segmentation results for %s' % str(data_path))
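With this change, callers no longer pass the data-split `flag`; `request_features_from_stanford` and `request_features_from_berkeley` recover it from the file name itself. A small standalone sketch of the same slicing, using a hypothetical path for illustration:

```python
data_path = './sample_data/train.tsv'  # hypothetical example path

# Same slicing as the added lines: the directory before the last '/',
# and the file stem between the last '/' and the last '.'.
data_dir = data_path[:data_path.rfind('/')]
flag = data_path[data_path.rfind('/') + 1: data_path.rfind('.')]

print(data_dir)  # ./sample_data
print(flag)      # train -> used to name train.stanford.json / train.berkeley.json
```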

twasp_main.py

+25-65
@@ -26,8 +26,8 @@ def train(args):
     if args.use_bert and args.use_zen:
         raise ValueError('We cannot use both BERT and ZEN')
 
-    if not os.path.exists('./logs/'):
-        os.mkdir('./logs')
+    if not os.path.exists('./logs'):
+        os.mkdir('logs')
 
     now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
     log_file_name = './logs/log-' + now_time
@@ -88,12 +88,12 @@ def train(args):
 
     if args.use_attention:
         if args.source == 'stanford':
-            request_features_from_stanford(args.train_data_path, flag='train')
-            request_features_from_stanford(args.eval_data_path, flag='test')
+            request_features_from_stanford(args.train_data_path)
+            request_features_from_stanford(args.eval_data_path)
             processor = stanford_feature_processor()
         elif args.source == 'berkeley':
-            request_features_from_berkeley(args.train_data_path, flag='train')
-            request_features_from_berkeley(args.eval_data_path, flag='test')
+            request_features_from_berkeley(args.train_data_path)
+            request_features_from_berkeley(args.eval_data_path)
             processor = berkeley_feature_processor()
         else:
             raise ValueError('Source must be one of \'stanford\' or \'berkeley\' if attentions are used.')
@@ -103,12 +103,14 @@ def train(args):
         gram2id = None
         feature2id = None
 
-    joint_model = TwASP(word2id, gram2id, feature2id, label_map, processor, args)
+    hpara = TwASP.init_hyper_parameters(args)
+    joint_model = TwASP(word2id, gram2id, feature2id, label_map, processor, hpara, args)
 
-    train_examples = joint_model.load_data(args.train_data_path, flag='train')
-    eval_examples = joint_model.load_data(args.eval_data_path, flag='test')
+    train_examples = joint_model.load_data(args.train_data_path)
+    eval_examples = joint_model.load_data(args.eval_data_path)
     num_labels = joint_model.num_labels
     convert_examples_to_features = joint_model.convert_examples_to_features
+    feature2input = joint_model.feature2input
 
     total_params = sum(p.numel() for p in joint_model.parameters() if p.requires_grad)
     logger.info('# of trainable parameters: %d' % total_params)
@@ -194,7 +196,7 @@ def train(args):
                 continue
             train_features = convert_examples_to_features(batch_examples)
             feature_ids, input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, \
-                segment_ids, valid_ids, word_ids, word_matching_matrix = feature2input(args, device, train_features)
+                segment_ids, valid_ids, word_ids, word_matching_matrix = feature2input(device, train_features)
 
             loss, _ = joint_model(input_ids, segment_ids, input_mask, label_ids, valid_ids, l_mask, word_ids,
                                   feature_ids, word_matching_matrix, word_matching_matrix, ngram_ids, ngram_positions)
@@ -237,7 +239,7 @@ def train(args):
                 eval_features = convert_examples_to_features(eval_batch_examples)
 
                 feature_ids, input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, \
-                    segment_ids, valid_ids, word_ids, word_matching_matrix = feature2input(args, device, eval_features)
+                    segment_ids, valid_ids, word_ids, word_matching_matrix = feature2input(device, eval_features)
 
                 with torch.no_grad():
                     _, tag_seq = joint_model(input_ids, segment_ids, input_mask, label_ids, valid_ids, l_mask,
@@ -365,48 +367,6 @@ def train(args):
             f.write('\n')
 
 
-def feature2input(args, device, feature):
-    all_input_ids = torch.tensor([f.input_ids for f in feature], dtype=torch.long)
-    all_input_mask = torch.tensor([f.input_mask for f in feature], dtype=torch.long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in feature], dtype=torch.long)
-    all_label_ids = torch.tensor([f.label_id for f in feature], dtype=torch.long)
-    all_valid_ids = torch.tensor([f.valid_ids for f in feature], dtype=torch.long)
-    all_lmask_ids = torch.tensor([f.label_mask for f in feature], dtype=torch.long)
-
-    input_ids = all_input_ids.to(device)
-    input_mask = all_input_mask.to(device)
-    segment_ids = all_segment_ids.to(device)
-    label_ids = all_label_ids.to(device)
-    valid_ids = all_valid_ids.to(device)
-    l_mask = all_lmask_ids.to(device)
-    if args.use_attention:
-        all_word_ids = torch.tensor([f.word_ids for f in feature], dtype=torch.long)
-        all_feature_ids = torch.tensor([f.syn_feature_ids for f in feature], dtype=torch.long)
-        all_word_matching_matrix = torch.tensor([f.word_matching_matrix for f in feature],
-                                                dtype=torch.float)
-
-        word_ids = all_word_ids.to(device)
-        feature_ids = all_feature_ids.to(device)
-        word_matching_matrix = all_word_matching_matrix.to(device)
-    else:
-        word_ids = None
-        feature_ids = None
-        word_matching_matrix = None
-    if args.use_zen:
-        all_ngram_ids = torch.tensor([f.ngram_ids for f in feature], dtype=torch.long)
-        all_ngram_positions = torch.tensor([f.ngram_positions for f in feature], dtype=torch.long)
-        # all_ngram_lengths = torch.tensor([f.ngram_lengths for f in train_features], dtype=torch.long)
-        # all_ngram_seg_ids = torch.tensor([f.ngram_seg_ids for f in train_features], dtype=torch.long)
-        # all_ngram_masks = torch.tensor([f.ngram_masks for f in train_features], dtype=torch.long)
-
-        ngram_ids = all_ngram_ids.to(device)
-        ngram_positions = all_ngram_positions.to(device)
-    else:
-        ngram_ids = None
-        ngram_positions = None
-    return feature_ids, input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, segment_ids, valid_ids, word_ids, word_matching_matrix
-
-
 def test(args):
 
     if args.local_rank == -1 or args.no_cuda:
@@ -422,21 +382,23 @@ def test(args):
         device, n_gpu, bool(args.local_rank != -1), args.fp16))
 
     joint_model_checkpoint = torch.load(args.eval_model)
-    joint_model = TwASP.from_spec(joint_model_checkpoint['spec'], joint_model_checkpoint['state_dict'])
+    joint_model = TwASP.from_spec(joint_model_checkpoint['spec'], joint_model_checkpoint['state_dict'], args)
 
     if joint_model.use_attention:
-        if joint_model.spec['args'].source == 'stanford':
-            request_features_from_stanford(args.eval_data_path, flag='test')
-        elif joint_model.spec['args'].source == 'berkeley':
-            request_features_from_berkeley(args.eval_data_path, flag='test')
+        if joint_model.source == 'stanford':
+            request_features_from_stanford(args.eval_data_path)
+        elif joint_model.source == 'berkeley':
+            request_features_from_berkeley(args.eval_data_path)
         else:
-            raise ValueError('Source must be one of \'stanford\' or \'berkeley\' if attentions are used.')
+            raise ValueError('Invalid source %s. '
+                             'Source must be one of \'stanford\' or \'berkeley\' if attentions are used.'
+                             % joint_model.source)
 
-    eval_examples = joint_model.load_data(args.eval_data_path, flag='test')
+    eval_examples = joint_model.load_data(args.eval_data_path)
     convert_examples_to_features = joint_model.convert_examples_to_features
+    feature2input = joint_model.feature2input
     num_labels = joint_model.num_labels
     word2id = joint_model.word2id
-    model_args = joint_model.spec['args']
     label_map = {v: k for k, v in joint_model.labelmap.items()}
     label_map[0] = 'O'
@@ -457,8 +419,6 @@ def test(args):
     joint_model.to(device)
 
     joint_model.eval()
-    eval_loss, eval_accuracy = 0, 0
-    nb_eval_steps, nb_eval_examples = 0, 0
     y_true = []
     y_pred = []
@@ -468,7 +428,7 @@ def test(args):
             eval_features = convert_examples_to_features(eval_batch_examples)
 
             feature_ids, input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, \
-                segment_ids, valid_ids, word_ids, word_matching_matrix = feature2input(model_args, device, eval_features)
+                segment_ids, valid_ids, word_ids, word_matching_matrix = feature2input(device, eval_features)
 
             with torch.no_grad():
                 _, tag_seq = joint_model(input_ids, segment_ids, input_mask, label_ids, valid_ids, l_mask,
@@ -520,7 +480,7 @@ def test(args):
 
 
 def predict(args):
-
+    # In progressing
     return None
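The module-level `feature2input(args, device, feature)` removed above is now a method on the model, so `train()` and `test()` simply bind `feature2input = joint_model.feature2input` and call it as `feature2input(device, features)`; whether attention or ZEN tensors are built follows the model's own saved hyper-parameters instead of the command-line `args`. A rough, abbreviated sketch of that shape (the class name `TwASPSketch` and the reduced return value are illustrative assumptions, not the repository's implementation):

```python
import torch

class TwASPSketch(torch.nn.Module):
    # Stand-in for TwASP, only to show feature2input as a bound method.
    def __init__(self, use_attention, use_zen):
        super().__init__()
        self.use_attention = use_attention  # restored from the saved spec, not read from args
        self.use_zen = use_zen

    def feature2input(self, device, features):
        # Base inputs are always built from the feature objects.
        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)
        # Optional inputs depend on the model's own configuration.
        word_ids = (torch.tensor([f.word_ids for f in features], dtype=torch.long).to(device)
                    if self.use_attention else None)
        ngram_ids = (torch.tensor([f.ngram_ids for f in features], dtype=torch.long).to(device)
                     if self.use_zen else None)
        return input_ids, word_ids, ngram_ids
```

A caller would then write, for example, `input_ids, word_ids, ngram_ids = joint_model.feature2input(device, batch_features)`.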