Commit cafd45e

Add benchmark script.
1 parent 3978d2d commit cafd45e

4 files changed: +265 -0 lines changed

benchmark/README.md

Lines changed: 41 additions & 0 deletions
# Benchmark jdepp-python

## Dataset

Wiki40b

## Requirements

* Python
* Conda

## Install

```
$ python -m pip install -r requirements.txt
```

## Prepare data

We use huggingface datasets to download wiki40b, `ja_sentence_segmenter` to split text into sentences (chosen for speed; not all text is split correctly, though), and jagger for POS tagging.

Download and unpack the jagger model file: https://github.com/lighttransport/jagger-python/releases/tag/v0.1.0

Then run `prepare_dataset.py`. It writes POS-tagged sentences to `output-wiki-postagged.txt` in the format sketched below.

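Each token is written as its surface form and feature string separated by a tab (`<TAB>` below), one token per line, and each sentence is terminated by an `EOS` line, with no blank lines between sentences. Schematically:

```
surface<TAB>feature
surface<TAB>feature
...
EOS
surface<TAB>feature
...
EOS
```
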
## Benchmark with J.DepP

Download and extract the jdepp model file (`2ndpoly` recommended): https://github.com/lighttransport/jdepp-python/releases/tag/v0.1.0

Then,

```
$ python run-jdepp.py
```
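
For reference, the core API call that `run-jdepp.py` times looks roughly like this (a minimal sketch; `one-sentence-postagged.txt` is a hypothetical sample file in the POS-tagged format above):

```
import jdepp

parser = jdepp.Jdepp()
parser.load_model("model/knbc")  # extracted J.DepP model directory, as in run-jdepp.py

# One POS-tagged sentence: a list of "surface\tfeature" lines terminated by an "EOS" line.
sentence = open("one-sentence-postagged.txt", encoding="utf8").readlines()  # hypothetical sample
print(parser.parse_from_postagged(sentence))
```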
EoL.

benchmark/prepare_dataset.py

Lines changed: 162 additions & 0 deletions
import os
import functools
import signal
import concurrent.futures

#
import ja_sentence_segmenter
import datasets
import jagger
from tqdm import tqdm

model_path = "model/kwdlc/patterns"
tagger = jagger.Jagger()
tagger.load_model(model_path)


from ja_sentence_segmenter.common.pipeline import make_pipeline
from ja_sentence_segmenter.concatenate.simple_concatenator import concatenate_matching
from ja_sentence_segmenter.normalize.neologd_normalizer import normalize
from ja_sentence_segmenter.split.simple_splitter import split_newline, split_punctuation

# Assume the wiki text always uses '。' (never the ASCII period '.') as sentence-ending punctuation.
split_punc = functools.partial(split_punctuation, punctuations=r"。 !?")
# Re-join a segment ending with 'の' with the following segment, to avoid false splits.
concat_tail_no = functools.partial(concatenate_matching, former_matching_rule=r"^(?P<result>.+)(の)$", remove_former_matched=False)
segmenter = make_pipeline(normalize, split_newline, concat_tail_no, split_punc)


interrupted = False

def handler(signum, frame):
    # Graceful shutdown: remember that the signal was received.
    print('Signal handler called with signal', signum)

    global interrupted
    interrupted = True


dss = datasets.load_dataset("range3/wiki40b-ja")
print(dss)

def senter(text):
    # Split raw text into sentences and POS-tag each one with Jagger.
    result = list(segmenter(text))

    outputs = ''
    for sent in result:
        toks = tagger.tokenize(sent)

        pos_tagged = ''
        for tok in toks:
            pos_tagged += tok.surface() + '\t' + tok.feature() + '\n'

        pos_tagged += 'EOS\n'

        # No blank line between sentences.
        outputs += pos_tagged

    return outputs
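

# wiki40b stores each article as a single string with structure markers
# (_START_ARTICLE_, _START_SECTION_, _START_PARAGRAPH_); _NEWLINE_ encodes line
# breaks inside a paragraph. Both procedures below keep only the text that
# follows a _START_PARAGRAPH_ marker and feed it to senter().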
def singleprocess_proc(f):
    for example in tqdm(dss['train']):
        texts = example['text'].split()

        # Extract paragraph text only.
        in_paragraph = False

        for text in texts:
            if in_paragraph:
                text = text.replace("_NEWLINE_", '\n')
                text = senter(text)
                f.write(text)
                in_paragraph = False

            if text == "_START_PARAGRAPH_":
                in_paragraph = True

def run_task(texts: list):
    out_texts = []

    #global interrupted

    for text in texts:
        #if interrupted:
        #    return None

        lines = text.split()

        # Extract paragraph text only.
        in_paragraph = False

        txt_result = ''
        for line in lines:
            if in_paragraph:
                line = line.replace("_NEWLINE_", '\n')
                line = senter(line)

                txt_result += line
                in_paragraph = False

            if line == "_START_PARAGRAPH_":
                in_paragraph = True

        out_texts.append(txt_result)

    return {'text': out_texts}

def multiprocess_proc(f):
    split_name = 'train'

    nprocs = max(1, os.cpu_count() // 2)
    print("nprocs", nprocs)
    # A small batch size is slow (possibly due to Python future object creation overhead);
    # 10000 examples or more per batch is recommended for wiki40b/ja `train`.
    nexamples_per_batch = 10000

    # datasets.map is an easy solution, but it consumes lots of disk space,
    # so it is disabled at the moment.
    #
    # processed_ds = dss['train'].map(run_task, batched=True, batch_size=nexamples_per_batch, num_proc=nprocs)
    # for p in tqdm(processed_ds['text']):
    #     f.write(p)

    # ProcessPoolExecutor version

    chunks = []
    for i in tqdm(range(0, len(dss[split_name]['text']), nexamples_per_batch), desc="[chunking input]"):
        chunks.append(dss[split_name]['text'][i:i+nexamples_per_batch])

    signal.signal(signal.SIGINT, handler)
    total_ticks = len(chunks)
    with tqdm(total=total_ticks) as pbar:
        with concurrent.futures.ProcessPoolExecutor(max_workers=nprocs) as executor:
            futures = {executor.submit(run_task, chunks[i]): i for i in range(len(chunks))}

            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                # All file I/O happens here in the main process.
                for text in result['text']:
                    f.write(text)

                del result
                pbar.update(1)
                del future

if __name__ == '__main__':
    f = open("output-wiki-postagged.txt", 'w')

    # Multiprocessing finishes in a few minutes but consumes 20GB~40GB of memory.
    # Use singleprocess_proc() instead if you run out of memory.

    # singleprocess_proc(f)
    multiprocess_proc(f)

    f.close()

benchmark/requirements.txt

Lines changed: 5 additions & 0 deletions
datasets
jagger
jdepp
tqdm
zstandard

benchmark/run-jdepp.py

Lines changed: 57 additions & 0 deletions
import sys
import tqdm
import time
import jdepp

parser = jdepp.Jdepp()

model_path = "model/knbc"
parser.load_model(model_path)

input_filename = "output-wiki-postagged.txt"
#input_filename = "test-posttagged.txt"

print("reading test data:", input_filename)
lines = open(input_filename, 'r', encoding='utf8').readlines()

s = time.time()

nprocessed_sentences = 0

# Accumulate one POS-tagged sentence at a time and parse it when its "EOS" line is seen.
sents = []
for line in tqdm.tqdm(lines):
    if line == '\n':
        continue

    sents.append(line)

    if line.startswith("EOS"):
        result = parser.parse_from_postagged(sents)
        print(result)

        nprocessed_sentences += 1
        sents = []

e = time.time()
proc_sec = e - s
ms_per_sentence = 1000.0 * proc_sec / float(nprocessed_sentences)
sys.stderr.write("J.DepP: Total {} secs ({} sentences, {} ms per sentence)\n".format(proc_sec, nprocessed_sentences, ms_per_sentence))

# Leftover jagger batch-tokenization benchmark (kept for reference, not used here):
#
#total_secs = 0
#nlines_per_batch = 1024*128
#for i in tqdm.tqdm(range(0, len(lines), nlines_per_batch)):
#    text = '\n'.join(lines[i:i+nlines_per_batch])
#
#    print("run jagger for {} lines.".format(nlines_per_batch))
#    s = time.time()
#    toks_list = tokenizer.tokenize_batch(text)
#    e = time.time()
#    print("{} secs".format(e - s))
#
#    total_secs += (e - s)
#
#    # print result
#    #for toks in toks_list:
#    #    for tok in toks:
#    #        print(tok.surface(), tok.feature())
#
#print("Total processing time: {} secs".format(total_secs))
