@@ -57,6 +57,7 @@ def tokens_per_second(self):
57
57
58
58
class GraphParser (AbstractParser ):
59
59
""" Parser for a single graph, has a state and optionally an oracle """
60
+
60
61
def __init__ (self , graph , * args , target = None , ** kwargs ):
61
62
"""
62
63
:param graph: gold Graph to get the correct nodes and edges from (in training), or just to get id from (in test)
@@ -74,20 +75,54 @@ def __init__(self, graph, *args, target=None, **kwargs):
74
75
assert self .lang , "Attribute 'lang' is required per passage when using multilingual BERT"
75
76
self .state_hash_history = set ()
76
77
self .state = self .oracle = None
77
# Copy alignments to anchors, updating graph: every graph node referenced by
# the companion alignment graph receives the anchors of the conllu tokens that
# are aligned to it.
if self.framework in ("amr", "drg", "ptg") and self.alignment:
    alignment_framework = self.alignment.framework
    # Character-range -> conllu token lookup, used only for DRG anchorings.
    # It does not depend on the alignment node, so it is built once on first
    # use instead of once per alignment node.
    anchors_to_conllu_node = None
    for alignment_node in self.alignment.nodes:
        node = self.graph.find_node(alignment_node.id)
        if node is None:
            self.config.log("graph %s: invalid alignment node %s" % (self.graph.id, alignment_node.id))
            continue
        if node.anchors is None:
            node.anchors = []

        # Exactly one of these two is filled in, depending on how the
        # alignment refers to tokens: by token id or by character range.
        conllu_node_id_list = None
        alignment_node_anchor_char_range_list = None
        if alignment_framework == "alignment":
            conllu_node_id_list = (alignment_node.label or []) + list(chain(*alignment_node.values or []))
        elif alignment_framework == "anchoring":
            # Anchors of the node itself followed by the flattened per-item anchorings.
            alignment_dicts = (alignment_node.anchors or []) + [
                anchor for anchor_list in (alignment_node.anchorings or []) for anchor in anchor_list]
            if self.framework == "drg":
                alignment_node_anchor_char_range_list = [
                    (int(alignment_dict["from"]), int(alignment_dict["to"]))
                    for alignment_dict in alignment_dicts]
                if anchors_to_conllu_node is None:
                    # The range lookup below keys each token by its single anchor.
                    assert all(len(conllu_node.anchors) == 1 for conllu_node in self.conllu.nodes)
                    anchors_to_conllu_node = {
                        (int(conllu_node.anchors[0]["from"]), int(conllu_node.anchors[0]["to"])): conllu_node
                        for conllu_node in self.conllu.nodes}
            else:
                # "amr" or "ptg": the enclosing condition admits no other framework.
                conllu_node_id_list = {alignment_dict["#"] for alignment_dict in alignment_dicts}
        else:
            # Report the value that was actually tested by the branches above
            # (the framework of the alignment graph, not of a single node).
            raise ValueError(f"Unknown alignments framework: {alignment_framework}")

        if conllu_node_id_list is not None:
            assert self.framework in ("amr", "ptg")
            for conllu_node_id in conllu_node_id_list:
                # NOTE(review): the lookup is shifted by one -- presumably alignment
                # token ids are 0-based while conllu node ids are 1-based; confirm.
                # The error message below reports the unshifted alignment id.
                conllu_node = self.conllu.find_node(conllu_node_id + 1)
                if conllu_node is None:
                    raise ValueError("Alignments incompatible with tokenization: token %s "
                                     "not found in graph %s" % (conllu_node_id, self.graph.id))
                node.anchors += conllu_node.anchors
        elif alignment_node_anchor_char_range_list is not None:
            for range_start, range_end in alignment_node_anchor_char_range_list:
                # Add every token whose character span lies entirely inside the aligned range.
                for (token_start, token_end), conllu_node in anchors_to_conllu_node.items():
                    if range_start <= token_start and range_end >= token_end:
                        node.anchors += conllu_node.anchors
91
126
92
127
def init (self ):
93
128
self .config .set_framework (self .framework )
@@ -320,6 +355,7 @@ def num_tokens(self, _):
320
355
321
356
class BatchParser (AbstractParser ):
322
357
""" Parser for a single training iteration or single pass over dev/test graphs """
358
+
323
359
def __init__ (self , * args , ** kwargs ):
324
360
super ().__init__ (* args , ** kwargs )
325
361
self .seen_per_framework = defaultdict (int )
@@ -335,14 +371,11 @@ def parse(self, graphs, display=True, write=False, accuracies=None):
335
371
if conllu is None :
336
372
self .config .print ("skipped '%s', no companion conllu data found" % graph .id )
337
373
continue
338
- alignment = self .alignment .get (graph .id )
374
+ alignment = self .alignment .get (graph .id ) if self . alignment else None
339
375
for target in graph .targets () or [graph .framework ]:
340
376
if not self .training and target not in self .model .classifier .labels :
341
377
self .config .print ("skipped target '%s' for '%s': did not train on it" % (target , graph .id ), level = 1 )
342
378
continue
343
- if target == "amr" and alignment is None :
344
- self .config .print ("skipped target 'amr' for '%s': no companion alignment found" % graph .id , level = 1 )
345
- continue
346
379
parser = GraphParser (
347
380
graph , self .config , self .model , self .training , conllu = conllu , alignment = alignment , target = target )
348
381
if self .config .args .verbose and display :
@@ -403,6 +436,7 @@ def time_per_graph(self):
403
436
404
437
class Parser (AbstractParser ):
405
438
""" Main class to implement transition-based meaning representation parser """
439
+
406
440
def __init__ (self , model_file = None , config = None , training = None , conllu = None , alignment = None ):
407
441
super ().__init__ (config = config or Config (), model = Model (model_file or config .args .model ),
408
442
training = config .args .train if training is None else training ,
@@ -646,7 +680,7 @@ def read_graphs_with_progress_bar(file_handle_or_graphs):
646
680
if isinstance (file_handle_or_graphs , IOBase ):
647
681
graphs , _ = read_graphs (
648
682
tqdm (file_handle_or_graphs , desc = "Reading " + getattr (file_handle_or_graphs , "name" , "input" ),
649
- unit = " graphs" ), format = "mrp" )
683
+ unit = " graphs" ), format = "mrp" , robust = True )
650
684
return graphs
651
685
return file_handle_or_graphs
652
686
0 commit comments