Experimental N-Best results mocobeta#3
nakagami committed Oct 28, 2019
1 parent 0ed1105 commit 2d91edc
Showing 3 changed files with 110 additions and 16 deletions.
64 changes: 57 additions & 7 deletions janome/lattice.py
@@ -15,6 +15,7 @@
# limitations under the License.

import sys, os
import heapq
PY3 = sys.version_info[0] == 3

class NodeType:
@@ -28,14 +29,15 @@ class Node(object):
Node class
"""
__slots__ = [
- 'pos', 'index', 'surface', 'left_id', 'right_id', 'cost',
+ 'pos', 'epos', 'index', 'surface', 'left_id', 'right_id', 'cost',
'part_of_speech', 'infl_type', 'infl_form',
'base_form', 'reading', 'phonetic', 'node_type',
'min_cost', 'back_pos', 'back_index'
]

def __init__(self, dict_entry, node_type=NodeType.SYS_DICT):
self.pos = 0
+ self.epos = 0
self.index = 0
self.min_cost = 2147483647 # int(pow(2,31)-1)
self.back_pos = -1
@@ -59,10 +61,11 @@ class SurfaceNode(object):
"""
Node class with surface form only.
"""
- __slots__ = ['pos', 'index', 'num', 'surface', 'left_id', 'right_id', 'cost', 'node_type', 'min_cost', 'back_pos', 'back_index']
+ __slots__ = ['pos', 'epos', 'index', 'num', 'surface', 'left_id', 'right_id', 'cost', 'node_type', 'min_cost', 'back_pos', 'back_index']

def __init__(self, dict_entry, node_type=NodeType.SYS_DICT):
self.pos = 0
+ self.epos = 0
self.index = 0
self.min_cost = 2147483647 # int(pow(2,31)-1)
self.back_pos = -1
@@ -81,6 +84,7 @@ class BOS(object):
"""
def __init__(self):
self.pos = 0
self.epos = 1
self.index = 0
self.right_id = 0
self.cost = 0
@@ -102,6 +106,7 @@ class EOS(object):
def __init__(self, pos):
self.min_cost = 2147483647 # int(pow(2,31)-1)
self.pos = pos
self.epos = pos + 1
self.cost = 0
self.left_id = 0

@@ -111,11 +116,31 @@ def __str__(self):
def node_label(self):
return 'EOS'


class BackwardPath(object):
def __init__(self, dic, node, right_path=None):
self.cost_from_bos = node.min_cost
if right_path is None:
assert isinstance(node, EOS)
self.cost_from_eos = 0
self.path = [node]
else:
neighbor_node = right_path.path[-1]
self.cost_from_eos = right_path.cost_from_eos + neighbor_node.cost + dic.get_trans_cost(node.right_id, neighbor_node.left_id)
self.path = right_path.path[:]
self.path.append(node)

def __lt__(self, other):
return self.cost_from_bos + self.cost_from_eos < other.cost_from_bos + other.cost_from_eos

def is_complete(self):
return isinstance(self.path[-1], BOS)


class Lattice:
def __init__(self, size, dic):
self.snodes = [[BOS()]] + [[] for i in range(0, size + 1)]
self.enodes = [[], [BOS()]] + [[] for i in range(0, size + 1)]
self.conn_costs = [[]]
self.p = 1
self.dic = dic

@@ -130,10 +155,10 @@ def add(self, node):
node.back_index = best_node.index
node.back_pos = best_node.pos
node.pos = self.p
+ node.epos = self.p + (len(node.surface) if hasattr(node, 'surface') else 1)
node.index = len(self.snodes[self.p])
- self.snodes[self.p].append(node)
- node_len = len(node.surface) if hasattr(node, 'surface') else 1
- self.enodes[self.p + node_len].append(node)
+ self.snodes[node.pos].append(node)
+ self.enodes[node.epos].append(node)

def forward(self):
old_p = self.p
@@ -145,8 +170,9 @@ def forward(self):
def end(self):
eos = EOS(self.p)
self.add(eos)
- # truncate snodes
+ # truncate snodes, enodes
self.snodes = self.snodes[:self.p+1]
+ self.enodes = self.enodes[:self.p+2]

def backward(self):
assert isinstance(self.snodes[len(self.snodes)-1][0], EOS)
@@ -161,6 +187,30 @@ def backward(self):
path.reverse()
return path

def backward_astar(self, n):
paths = []
epos = len(self.enodes) - 1
node = self.enodes[epos][0]
assert isinstance(node, EOS)
pq = []
bp = BackwardPath(self.dic, node)
heapq.heappush(pq, bp)

while pq and n:
bp = heapq.heappop(pq)
if bp.is_complete():
bp.path.reverse()
paths.append(bp.path)
n -= 1
else:
node = bp.path[-1]
node_len = len(node.surface) if hasattr(node, 'surface') else 1
epos = node.epos - node_len
for index in range(len(self.enodes[epos])):
node = self.enodes[epos][index]
heapq.heappush(pq, BackwardPath(self.dic, node, bp))
return paths

# generate Graphviz dot file
def generate_dotfile(self, filename='lattice.gv'):
def is_unknown(node):
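The BackwardPath class added above is what makes the N-best search work: cost_from_bos is the exact minimum cost from BOS to the node, already computed by the forward (Viterbi) pass, while cost_from_eos accumulates the actual word and connection costs walked back from EOS. backward_astar keeps candidate partial paths in a heap ordered by the sum of the two, so it always expands the cheapest remaining continuation and, because the forward estimate never overestimates, complete paths come off the queue in non-decreasing total cost until n of them have been collected. A minimal, self-contained sketch of that ordering idea follows; it is not part of the commit, and the _Candidate class and cost values are made up for illustration.

import heapq

class _Candidate:
    # mirrors BackwardPath.__lt__: candidates compare by estimated total path cost
    def __init__(self, cost_from_bos, cost_from_eos, label):
        self.cost_from_bos = cost_from_bos  # exact minimum cost from BOS (forward pass)
        self.cost_from_eos = cost_from_eos  # actual cost accumulated backward from EOS
        self.label = label

    def __lt__(self, other):
        return (self.cost_from_bos + self.cost_from_eos
                < other.cost_from_bos + other.cost_from_eos)

pq = []
for c in (_Candidate(10, 5, 'a'), _Candidate(3, 4, 'b'), _Candidate(6, 2, 'c')):
    heapq.heappush(pq, c)
print([heapq.heappop(pq).label for _ in range(3)])  # -> ['b', 'c', 'a']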
36 changes: 27 additions & 9 deletions janome/tokenizer.py
@@ -182,7 +182,7 @@ def __init__(self, udic='', udic_enc='utf8', udic_type='ipadic', max_unknown_len
self.user_dic = None
self.max_unknown_length = max_unknown_length

- def tokenize(self, text, stream=False, wakati=False, baseform_unk=True, dotfile=''):
+ def tokenize(self, text, stream=False, wakati=False, baseform_unk=True, dotfile='', n_best=1):
u"""
Tokenize the input text.
@@ -196,6 +198,8 @@ def tokenize(self, text, stream=False, wakati=False, baseform_unk=True, dotfile=
"""
if self.wakati:
wakati = True
if n_best > 1 and len(text) < Tokenizer.MAX_CHUNK_SIZE:
return self.__tokenize_n_best(text, wakati, baseform_unk, n_best)
if stream:
return self.__tokenize_stream(text, wakati, baseform_unk, '')
elif dotfile and len(text) < Tokenizer.MAX_CHUNK_SIZE:
@@ -217,7 +219,23 @@ def __tokenize_stream(self, text, wakati, baseform_unk, dotfile):
def __tokenize_partial(self, text, wakati, baseform_unk, dotfile):
if self.wakati and not wakati:
raise WakatiModeOnlyException
lattice, next_pos = self.__build_lattice(text, baseform_unk)

min_cost_path = lattice.backward()
tokens = self.__nodes_to_tokens(min_cost_path, wakati)

if dotfile:
lattice.generate_dotfile(filename=dotfile)
return (tokens, next_pos)

def __tokenize_n_best(self, text, wakati, baseform_unk, n_best):
tokens_list = []
lattice, _ = self.__build_lattice(text, baseform_unk)
for path in lattice.backward_astar(n_best):
tokens_list.append(self.__nodes_to_tokens(path, wakati))
return tokens_list

def __build_lattice(self, text, baseform_unk):
chunk_size = min(len(text), Tokenizer.MAX_CHUNK_SIZE)
lattice = Lattice(chunk_size, self.sys_dic)
pos = 0
@@ -264,23 +282,23 @@ def __tokenize_partial(self, text, wakati, baseform_unk, dotfile):

pos += lattice.forward()
lattice.end()
- min_cost_path = lattice.backward()
- assert isinstance(min_cost_path[0], BOS)
- assert isinstance(min_cost_path[-1], EOS)
+ return lattice, pos

+ def __nodes_to_tokens(self, nodes, wakati):
+ assert isinstance(nodes[0], BOS)
+ assert isinstance(nodes[-1], EOS)
if wakati:
- tokens = [node.surface for node in min_cost_path[1:-1]]
+ tokens = [node.surface for node in nodes[1:-1]]
else:
tokens = []
- for node in min_cost_path[1:-1]:
+ for node in nodes[1:-1]:
if type(node) == SurfaceNode and node.node_type == NodeType.SYS_DICT:
tokens.append(Token(node, self.sys_dic.lookup_extra(node.num)))
elif type(node) == SurfaceNode and node.node_type == NodeType.USER_DICT:
tokens.append(Token(node, self.user_dic.lookup_extra(node.num)))
else:
tokens.append(Token(node))
- if dotfile:
- lattice.generate_dotfile(filename=dotfile)
- return (tokens, pos)
+ return tokens

def __should_split(self, text, pos):
return \
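From the caller's side, the new n_best parameter changes the return type: for text shorter than Tokenizer.MAX_CHUNK_SIZE, tokenize(text, n_best=k) with k > 1 returns a list of up to k token lists (lists of surface strings when wakati=True) instead of a single token list, cheapest segmentation first. Since cost_from_bos is an exact estimate, the first candidate should normally match the ordinary single-best result, up to ties in total cost. A short usage sketch, assuming the API exactly as introduced in this commit:

from janome.tokenizer import Tokenizer

t = Tokenizer()
text = u'すもももももももものうち'
# each candidate is a list of Token objects, ordered by total path cost
for rank, tokens in enumerate(t.tokenize(text, n_best=3), start=1):
    print(rank, ' '.join(token.surface for token in tokens))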
26 changes: 26 additions & 0 deletions tests/test_tokenizer.py
@@ -283,6 +283,32 @@ def test_tokenize_dotfile_large_text(self):
tokens = Tokenizer().tokenize(text, dotfile=dotfile)
self.assertFalse(os.path.exists(dotfile))

def test_tokenize_n_best(self):
text = u'すもももももももものうち'
tokens_list = Tokenizer().tokenize(text, n_best=3)
self.assertEqual(3, len(tokens_list))
self._check_token(tokens_list[0][0], u'すもも', u'名詞,一般,*,*,*,*,すもも,スモモ,スモモ', NodeType.SYS_DICT)
self._check_token(tokens_list[0][1], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT)
self._check_token(tokens_list[0][2], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT)
self._check_token(tokens_list[0][3], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT)
self._check_token(tokens_list[0][4], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT)
self._check_token(tokens_list[0][5], u'の', u'助詞,連体化,*,*,*,*,の,ノ,ノ', NodeType.SYS_DICT)
self._check_token(tokens_list[0][6], u'うち', u'名詞,非自立,副詞可能,*,*,*,うち,ウチ,ウチ', NodeType.SYS_DICT)
self._check_token(tokens_list[1][0], u'すもも', u'名詞,一般,*,*,*,*,すもも,スモモ,スモモ', NodeType.SYS_DICT)
self._check_token(tokens_list[1][1], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT)
self._check_token(tokens_list[1][2], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT)
self._check_token(tokens_list[1][3], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT)
self._check_token(tokens_list[1][4], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT)
self._check_token(tokens_list[1][5], u'の', u'助詞,連体化,*,*,*,*,の,ノ,ノ', NodeType.SYS_DICT)
self._check_token(tokens_list[1][6], u'うち', u'名詞,非自立,副詞可能,*,*,*,うち,ウチ,ウチ', NodeType.SYS_DICT)
self._check_token(tokens_list[2][0], u'すもも', u'名詞,一般,*,*,*,*,すもも,スモモ,スモモ', NodeType.SYS_DICT)
self._check_token(tokens_list[2][1], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT)
self._check_token(tokens_list[2][2], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT)
self._check_token(tokens_list[2][3], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT)
self._check_token(tokens_list[2][4], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT)
self._check_token(tokens_list[2][5], u'の', u'助詞,連体化,*,*,*,*,の,ノ,ノ', NodeType.SYS_DICT)
self._check_token(tokens_list[2][6], u'うち', u'名詞,非自立,副詞可能,*,*,*,うち,ウチ,ウチ', NodeType.SYS_DICT)
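# Note (not part of the commit): the three expected candidates above differ only
# in how the run of 'も' between the nouns is segmented:
#   rank 1: すもも / も / もも / も / もも / の / うち
#   rank 2: すもも / も / もも / もも / も / の / うち
#   rank 3: すもも / もも / も / もも / も / の / うち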

def _check_token(self, token, surface, detail, node_type):
self.assertEqual(surface, token.surface)
self.assertEqual(detail, ','.join([token.part_of_speech,token.infl_type,token.infl_form,token.base_form,token.reading,token.phonetic]))
