diff --git a/janome/lattice.py b/janome/lattice.py index ed280c3..e0099e7 100644 --- a/janome/lattice.py +++ b/janome/lattice.py @@ -15,6 +15,7 @@ # limitations under the License. import sys, os +import heapq PY3 = sys.version_info[0] == 3 class NodeType: @@ -28,7 +29,7 @@ class Node(object): Node class """ __slots__ = [ - 'pos', 'index', 'surface', 'left_id', 'right_id', 'cost', + 'pos', 'epos', 'index', 'surface', 'left_id', 'right_id', 'cost', 'part_of_speech', 'infl_type', 'infl_form', 'base_form', 'reading', 'phonetic', 'node_type', 'min_cost', 'back_pos', 'back_index' @@ -36,6 +37,7 @@ class Node(object): def __init__(self, dict_entry, node_type=NodeType.SYS_DICT): self.pos = 0 + self.epos = 0 self.index = 0 self.min_cost = 2147483647 # int(pow(2,31)-1) self.back_pos = -1 @@ -59,10 +61,11 @@ class SurfaceNode(object): """ Node class with surface form only. """ - __slots__ = ['pos', 'index', 'num', 'surface', 'left_id', 'right_id', 'cost', 'node_type', 'min_cost', 'back_pos', 'back_index'] + __slots__ = ['pos', 'epos', 'index', 'num', 'surface', 'left_id', 'right_id', 'cost', 'node_type', 'min_cost', 'back_pos', 'back_index'] def __init__(self, dict_entry, node_type=NodeType.SYS_DICT): self.pos = 0 + self.epos = 0 self.index = 0 self.min_cost = 2147483647 # int(pow(2,31)-1) self.back_pos = -1 @@ -81,6 +84,7 @@ class BOS(object): """ def __init__(self): self.pos = 0 + self.epos = 1 self.index = 0 self.right_id = 0 self.cost = 0 @@ -102,6 +106,7 @@ class EOS(object): def __init__(self, pos): self.min_cost = 2147483647 # int(pow(2,31)-1) self.pos = pos + self.epos = pos + 1 self.cost = 0 self.left_id = 0 @@ -111,11 +116,31 @@ def __str__(self): def node_label(self): return 'EOS' + +class BackwardPath(object): + def __init__(self, dic, node, right_path=None): + self.cost_from_bos = node.min_cost + if right_path is None: + assert isinstance(node, EOS) + self.cost_from_eos = 0 + self.path = [node] + else: + neighbor_node = right_path.path[-1] + self.cost_from_eos = right_path.cost_from_eos + neighbor_node.cost + dic.get_trans_cost(node.right_id, neighbor_node.left_id) + self.path = right_path.path[:] + self.path.append(node) + + def __lt__(self, other): + return self.cost_from_bos + self.cost_from_eos < other.cost_from_bos + other.cost_from_eos + + def is_complete(self): + return isinstance(self.path[-1], BOS) + + class Lattice: def __init__(self, size, dic): self.snodes = [[BOS()]] + [[] for i in range(0, size + 1)] self.enodes = [[], [BOS()]] + [[] for i in range(0, size + 1)] - self.conn_costs = [[]] self.p = 1 self.dic = dic @@ -130,10 +155,10 @@ def add(self, node): node.back_index = best_node.index node.back_pos = best_node.pos node.pos = self.p + node.epos = self.p + (len(node.surface) if hasattr(node, 'surface') else 1) node.index = len(self.snodes[self.p]) - self.snodes[self.p].append(node) - node_len = len(node.surface) if hasattr(node, 'surface') else 1 - self.enodes[self.p + node_len].append(node) + self.snodes[node.pos].append(node) + self.enodes[node.epos].append(node) def forward(self): old_p = self.p @@ -145,8 +170,9 @@ def forward(self): def end(self): eos = EOS(self.p) self.add(eos) - # truncate snodes + # truncate snodes, enodes self.snodes = self.snodes[:self.p+1] + self.enodes = self.enodes[:self.p+2] def backward(self): assert isinstance(self.snodes[len(self.snodes)-1][0], EOS) @@ -161,6 +187,30 @@ def backward(self): path.reverse() return path + def backward_astar(self, n): + paths = [] + epos = len(self.enodes) - 1 + node = self.enodes[epos][0] + assert isinstance(node, EOS) + pq = [] + bp = BackwardPath(self.dic, node) + heapq.heappush(pq, bp) + + while pq and n: + bp = heapq.heappop(pq) + if bp.is_complete(): + bp.path.reverse() + paths.append(bp.path) + n -= 1 + else: + node = bp.path[-1] + node_len = len(node.surface) if hasattr(node, 'surface') else 1 + epos = node.epos - node_len + for index in range(len(self.enodes[epos])): + node = self.enodes[epos][index] + heapq.heappush(pq, BackwardPath(self.dic, node, bp)) + return paths + # generate Graphviz dot file def generate_dotfile(self, filename='lattice.gv'): def is_unknown(node): diff --git a/janome/tokenizer.py b/janome/tokenizer.py index c3d34b4..5c7a5f4 100644 --- a/janome/tokenizer.py +++ b/janome/tokenizer.py @@ -182,7 +182,7 @@ def __init__(self, udic='', udic_enc='utf8', udic_type='ipadic', max_unknown_len self.user_dic = None self.max_unknown_length = max_unknown_length - def tokenize(self, text, stream=False, wakati=False, baseform_unk=True, dotfile=''): + def tokenize(self, text, stream=False, wakati=False, baseform_unk=True, dotfile='', n_best=1): u""" Tokenize the input text. @@ -196,6 +196,8 @@ def tokenize(self, text, stream=False, wakati=False, baseform_unk=True, dotfile= """ if self.wakati: wakati = True + if n_best > 1 and len(text) < Tokenizer.MAX_CHUNK_SIZE: + return self.__tokenize_n_best(text, wakati, baseform_unk, n_best) if stream: return self.__tokenize_stream(text, wakati, baseform_unk, '') elif dotfile and len(text) < Tokenizer.MAX_CHUNK_SIZE: @@ -217,7 +219,23 @@ def __tokenize_stream(self, text, wakati, baseform_unk, dotfile): def __tokenize_partial(self, text, wakati, baseform_unk, dotfile): if self.wakati and not wakati: raise WakatiModeOnlyException + lattice, next_pos = self.__build_lattice(text, baseform_unk) + min_cost_path = lattice.backward() + tokens = self.__nodes_to_tokens(min_cost_path, wakati) + + if dotfile: + lattice.generate_dotfile(filename=dotfile) + return (tokens, next_pos) + + def __tokenize_n_best(self, text, wakati, baseform_unk, n_best): + tokens_list = [] + lattice, _ = self.__build_lattice(text, baseform_unk) + for path in lattice.backward_astar(n_best): + tokens_list.append(self.__nodes_to_tokens(path, wakati)) + return tokens_list + + def __build_lattice(self, text, baseform_unk): chunk_size = min(len(text), Tokenizer.MAX_CHUNK_SIZE) lattice = Lattice(chunk_size, self.sys_dic) pos = 0 @@ -264,23 +282,23 @@ def __tokenize_partial(self, text, wakati, baseform_unk, dotfile): pos += lattice.forward() lattice.end() - min_cost_path = lattice.backward() - assert isinstance(min_cost_path[0], BOS) - assert isinstance(min_cost_path[-1], EOS) + return lattice, pos + + def __nodes_to_tokens(self, nodes, wakati): + assert isinstance(nodes[0], BOS) + assert isinstance(nodes[-1], EOS) if wakati: - tokens = [node.surface for node in min_cost_path[1:-1]] + tokens = [node.surface for node in nodes[1:-1]] else: tokens = [] - for node in min_cost_path[1:-1]: + for node in nodes[1:-1]: if type(node) == SurfaceNode and node.node_type == NodeType.SYS_DICT: tokens.append(Token(node, self.sys_dic.lookup_extra(node.num))) elif type(node) == SurfaceNode and node.node_type == NodeType.USER_DICT: tokens.append(Token(node, self.user_dic.lookup_extra(node.num))) else: tokens.append(Token(node)) - if dotfile: - lattice.generate_dotfile(filename=dotfile) - return (tokens, pos) + return tokens def __should_split(self, text, pos): return \ diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 86f2941..6f92c94 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -283,6 +283,32 @@ def test_tokenize_dotfile_large_text(self): tokens = Tokenizer().tokenize(text, dotfile=dotfile) self.assertFalse(os.path.exists(dotfile)) + def test_tokenize_n_best(self): + text = u'すもももももももものうち' + tokens_list = Tokenizer().tokenize(text, n_best=3) + self.assertEqual(3, len(tokens_list)) + self._check_token(tokens_list[0][0], u'すもも', u'名詞,一般,*,*,*,*,すもも,スモモ,スモモ', NodeType.SYS_DICT) + self._check_token(tokens_list[0][1], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT) + self._check_token(tokens_list[0][2], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT) + self._check_token(tokens_list[0][3], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT) + self._check_token(tokens_list[0][4], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT) + self._check_token(tokens_list[0][5], u'の', u'助詞,連体化,*,*,*,*,の,ノ,ノ', NodeType.SYS_DICT) + self._check_token(tokens_list[0][6], u'うち', u'名詞,非自立,副詞可能,*,*,*,うち,ウチ,ウチ', NodeType.SYS_DICT) + self._check_token(tokens_list[1][0], u'すもも', u'名詞,一般,*,*,*,*,すもも,スモモ,スモモ', NodeType.SYS_DICT) + self._check_token(tokens_list[1][1], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT) + self._check_token(tokens_list[1][2], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT) + self._check_token(tokens_list[1][3], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT) + self._check_token(tokens_list[1][4], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT) + self._check_token(tokens_list[1][5], u'の', u'助詞,連体化,*,*,*,*,の,ノ,ノ', NodeType.SYS_DICT) + self._check_token(tokens_list[1][6], u'うち', u'名詞,非自立,副詞可能,*,*,*,うち,ウチ,ウチ', NodeType.SYS_DICT) + self._check_token(tokens_list[2][0], u'すもも', u'名詞,一般,*,*,*,*,すもも,スモモ,スモモ', NodeType.SYS_DICT) + self._check_token(tokens_list[2][1], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT) + self._check_token(tokens_list[2][2], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT) + self._check_token(tokens_list[2][3], u'もも', u'名詞,一般,*,*,*,*,もも,モモ,モモ', NodeType.SYS_DICT) + self._check_token(tokens_list[2][4], u'も', u'助詞,係助詞,*,*,*,*,も,モ,モ', NodeType.SYS_DICT) + self._check_token(tokens_list[2][5], u'の', u'助詞,連体化,*,*,*,*,の,ノ,ノ', NodeType.SYS_DICT) + self._check_token(tokens_list[2][6], u'うち', u'名詞,非自立,副詞可能,*,*,*,うち,ウチ,ウチ', NodeType.SYS_DICT) + def _check_token(self, token, surface, detail, node_type): self.assertEqual(surface, token.surface) self.assertEqual(detail, ','.join([token.part_of_speech,token.infl_type,token.infl_form,token.base_form,token.reading,token.phonetic]))