Skip to content
This repository was archived by the owner on Dec 14, 2020. It is now read-only.

Commit e5536d8

Browse files
OfirArvivOfir Arviv
and
Ofir Arviv
authored
Mrp 2020 (#106)
* check for cycles * add support for 2020 mrp format * merge changes * fix * CR fixes * remove print Co-authored-by: Ofir Arviv <t-ofarvi@microsoft.com>
1 parent 3c195eb commit e5536d8

File tree

7 files changed

+118
-21
lines changed

7 files changed

+118
-21
lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ dynet==2.1
66
logbook>=1.5.2
77
word2number>=1.1
88
git+https://github.com/cfmrp/mtool.git#egg=mtool
9+
networkx==2.4
10+
matplotlib == 3.2.1

tupa/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
COMPOUND = "compound"
2323

2424
# Required number of edge labels per framework
25-
NODE_LABELS_NUM = {"amr": 1000, "dm": 1000, "psd": 1000, "eds": 1000, "ucca": 0}
26-
NODE_PROPERTY_NUM = {"amr": 1000, "dm": 510, "psd": 1000, "eds": 1000, "ucca": 0}
27-
EDGE_LABELS_NUM = {"amr": 141, "dm": 59, "psd": 90, "eds": 10, "ucca": 15}
25+
NODE_LABELS_NUM = {"amr": 1000, "dm": 1000, "psd": 1000, "eds": 1000, "ucca": 0, "ptg": 1000}
26+
NODE_PROPERTY_NUM = {"amr": 1000, "dm": 510, "psd": 1000, "eds": 1000, "ucca": 0, "ptg": 1000}
27+
EDGE_LABELS_NUM = {"amr": 141, "dm": 59, "psd": 90, "eds": 10, "ucca": 15, "ptg": 150}
2828
EDGE_ATTRIBUTE_NUM = {"amr": 0, "dm": 0, "psd": 0, "eds": 0, "ucca": 2}
2929
NN_ARG_NAMES = set()
3030
DYNET_ARG_NAMES = set()
@@ -537,20 +537,20 @@ def __str__(self):
537537

538538

539539
def requires_node_labels(framework):
540-
return framework != "ucca"
540+
return framework not in ("ucca", "drg", "ptg")
541541

542542

543543
def requires_node_properties(framework):
544-
return framework != "ucca"
544+
return framework not in ("ucca", "drg")
545545

546546

547547
def requires_edge_attributes(framework):
548548
return framework == "ucca"
549549

550550

551551
def requires_anchors(framework):
552-
return framework != "amr"
552+
return framework not in ("amr", "drg")
553553

554554

555555
def requires_tops(framework):
556-
return framework in ("ucca", "amr")
556+
return framework in ("ucca", "amr", "drg", "ptg")

tupa/constraints/ptg.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .validation import Constraints
2+
3+
4+
class PtgConstraints(Constraints):
5+
def __init__(self, **kwargs):
6+
super().__init__(multigraph=True, **kwargs)

tupa/constraints/validation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,18 @@ def eds_constraints(**kwargs):
155155
return EdsConstraints(**kwargs)
156156

157157

158+
def ptg_constraints(**kwargs):
159+
from .ptg import PtgConstraints
160+
return PtgConstraints(**kwargs)
161+
162+
158163
CONSTRAINTS = {
159164
"ucca": ucca_constraints,
160165
"amr": amr_constraints,
161166
"dm": sdp_constraints,
162167
"psd": sdp_constraints,
163168
"eds": eds_constraints,
169+
"ptg": ptg_constraints,
164170
}
165171

166172

tupa/oracle.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,12 @@ def get_node_label(self, state, node):
193193
return true_label, raw_true_label
194194

195195
def get_node_property_value(self, state, node):
196-
true_property_value = next((k, v) for k, v in node.ref_node.properties.items()
197-
if k not in (node.properties or ()))
196+
try:
197+
true_property_value = next((k, v) for k, v in (node.ref_node.properties.items()
198+
if node.ref_node.properties else [])
199+
if k not in (node.properties or ()))
200+
except StopIteration:
201+
return None
198202
if self.args.validate_oracle:
199203
try:
200204
state.check_valid_property_value(true_property_value, message=True)

tupa/parse.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def tokens_per_second(self):
5757

5858
class GraphParser(AbstractParser):
5959
""" Parser for a single graph, has a state and optionally an oracle """
60+
6061
def __init__(self, graph, *args, target=None, **kwargs):
6162
"""
6263
:param graph: gold Graph to get the correct nodes and edges from (in training), or just to get id from (in test)
@@ -74,20 +75,54 @@ def __init__(self, graph, *args, target=None, **kwargs):
7475
assert self.lang, "Attribute 'lang' is required per passage when using multilingual BERT"
7576
self.state_hash_history = set()
7677
self.state = self.oracle = None
77-
if self.framework == "amr" and self.alignment: # Copy alignments to anchors, updating graph
78+
if self.framework in ("amr", "drg", "ptg") and self.alignment: # Copy alignments to anchors, updating graph
7879
for alignment_node in self.alignment.nodes:
7980
node = self.graph.find_node(alignment_node.id)
8081
if node is None:
8182
self.config.log("graph %s: invalid alignment node %s" % (self.graph.id, alignment_node.id))
8283
continue
8384
if node.anchors is None:
8485
node.anchors = []
85-
for conllu_node_id in (alignment_node.label or []) + list(chain(*alignment_node.values or [])):
86-
conllu_node = self.conllu.find_node(conllu_node_id)
87-
if conllu_node is None:
88-
raise ValueError("Alignments incompatible with tokenization: token %s "
89-
"not found in graph %s" % (conllu_node_id, self.graph.id))
90-
node.anchors += conllu_node.anchors
86+
87+
conllu_node_id_list = None
88+
alignment_node_anchor_char_range_list = None
89+
if self.alignment.framework == "alignment":
90+
conllu_node_id_list = (alignment_node.label or []) + list(chain(*alignment_node.values or []))
91+
elif self.alignment.framework == "anchoring" and self.framework in ("amr", "ptg"):
92+
conllu_node_id_list = set([alignment_dict["#"] for alignment_dict in
93+
(alignment_node.anchors or [])
94+
+ ([anchor for anchor_list in (alignment_node.anchorings or []) for anchor in anchor_list])])
95+
elif self.alignment.framework == "anchoring" and self.framework == "drg":
96+
alignment_node_anchor_char_range_list = [(int(alignment_dict["from"]),(int(alignment_dict["to"]))) for alignment_dict in
97+
(alignment_node.anchors or [])
98+
+ ([anchor for anchor_list in (alignment_node.anchorings or []) for anchor in anchor_list])]
99+
assert all([len(conllu_node.anchors) == 1 for conllu_node in self.conllu.nodes])
100+
anchors_to_conllu_node = {(int(conllu_node.anchors[0]["from"]), int(conllu_node.anchors[0]["to"])):
101+
conllu_node
102+
for conllu_node in self.conllu.nodes}
103+
else:
104+
raise ValueError(f'Unknown alignments framework: {alignment_node.framework}')
105+
106+
if conllu_node_id_list is not None:
107+
assert self.framework in ("amr", "ptg")
108+
for conllu_node_id in conllu_node_id_list:
109+
conllu_node = self.conllu.find_node(conllu_node_id + 1)
110+
111+
if conllu_node is None:
112+
raise ValueError("Alignments incompatible with tokenization: token %s "
113+
"not found in graph %s" % (conllu_node_id, self.graph.id))
114+
node.anchors += conllu_node.anchors
115+
116+
elif alignment_node_anchor_char_range_list is not None:
117+
for alignment_node_char_range in alignment_node_anchor_char_range_list:
118+
for conllu_anchor_range in anchors_to_conllu_node:
119+
if alignment_node_char_range[0] <= conllu_anchor_range[0] \
120+
and alignment_node_char_range[1] >= conllu_anchor_range[1]:
121+
conllu_node = anchors_to_conllu_node[conllu_anchor_range]
122+
if conllu_node is None:
123+
raise ValueError("Alignments incompatible with tokenization: token %s "
124+
"not found in graph %s" % (conllu_anchor_range, self.graph.id))
125+
node.anchors += conllu_node.anchors
91126

92127
def init(self):
93128
self.config.set_framework(self.framework)
@@ -320,6 +355,7 @@ def num_tokens(self, _):
320355

321356
class BatchParser(AbstractParser):
322357
""" Parser for a single training iteration or single pass over dev/test graphs """
358+
323359
def __init__(self, *args, **kwargs):
324360
super().__init__(*args, **kwargs)
325361
self.seen_per_framework = defaultdict(int)
@@ -335,14 +371,11 @@ def parse(self, graphs, display=True, write=False, accuracies=None):
335371
if conllu is None:
336372
self.config.print("skipped '%s', no companion conllu data found" % graph.id)
337373
continue
338-
alignment = self.alignment.get(graph.id)
374+
alignment = self.alignment.get(graph.id) if self.alignment else None
339375
for target in graph.targets() or [graph.framework]:
340376
if not self.training and target not in self.model.classifier.labels:
341377
self.config.print("skipped target '%s' for '%s': did not train on it" % (target, graph.id), level=1)
342378
continue
343-
if target == "amr" and alignment is None:
344-
self.config.print("skipped target 'amr' for '%s': no companion alignment found" % graph.id, level=1)
345-
continue
346379
parser = GraphParser(
347380
graph, self.config, self.model, self.training, conllu=conllu, alignment=alignment, target=target)
348381
if self.config.args.verbose and display:
@@ -403,6 +436,7 @@ def time_per_graph(self):
403436

404437
class Parser(AbstractParser):
405438
""" Main class to implement transition-based meaning representation parser """
439+
406440
def __init__(self, model_file=None, config=None, training=None, conllu=None, alignment=None):
407441
super().__init__(config=config or Config(), model=Model(model_file or config.args.model),
408442
training=config.args.train if training is None else training,
@@ -646,7 +680,7 @@ def read_graphs_with_progress_bar(file_handle_or_graphs):
646680
if isinstance(file_handle_or_graphs, IOBase):
647681
graphs, _ = read_graphs(
648682
tqdm(file_handle_or_graphs, desc="Reading " + getattr(file_handle_or_graphs, "name", "input"),
649-
unit=" graphs"), format="mrp")
683+
unit=" graphs"), format="mrp", robust=True)
650684
return graphs
651685
return file_handle_or_graphs
652686

tupa/states/ref_graph.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
import sys
2+
from typing import List, Tuple
3+
14
from .anchors import expand_anchors
25
from .edge import StateEdge
36
from .node import StateNode
47
from ..constraints.amr import NAME
58
from ..constraints.validation import ROOT_ID, ROOT_LAB, ANCHOR_LAB
69
from ..recategorization import resolve, compress_name
710

11+
import networkx as nx
812

913
class RefGraph:
1014
def __init__(self, graph, conllu, framework):
@@ -27,6 +31,7 @@ def __init__(self, graph, conllu, framework):
2731
offset = len(conllu.nodes) + 1
2832
self.non_virtual_nodes = []
2933
self.edges = []
34+
have_anchors = False
3035
for graph_node in graph.nodes:
3136
node_id = graph_node.id + offset
3237
id2node[node_id] = node = \
@@ -43,7 +48,21 @@ def __init__(self, graph, conllu, framework):
4348
anchor_terminals = [min(self.terminals, key=lambda terminal: min(
4449
x - y for x in terminal.anchors for y in node.anchors))] # Must have anchors, get closest one
4550
for terminal in anchor_terminals:
51+
have_anchors = True
4652
self.edges.append(StateEdge(node, terminal, ANCHOR_LAB).add())
53+
54+
if not have_anchors:
55+
print(f'framework {graph.framework} graph id {graph.id} have no anchors', file=sys.stderr)
56+
57+
cycle = find_cycle(graph)
58+
while len(cycle) > 0:
59+
edge_list = list(graph.edges)
60+
first_edge_idx = \
61+
[i for i, edge in enumerate(graph.edges) if edge.src == cycle[0][0] and edge.tgt == cycle[0][1]][0]
62+
del edge_list[first_edge_idx]
63+
graph.edges = set(edge_list)
64+
cycle = find_cycle(graph)
65+
4766
for edge in graph.edges:
4867
if edge.src != edge.tgt: # Drop self-loops as the parser currently does not support them
4968
self.edges.append(StateEdge(id2node[edge.src + offset],
@@ -55,4 +74,30 @@ def __init__(self, graph, conllu, framework):
5574
node.properties = compress_name(node.properties)
5675
node.properties = {prop: resolve(node, value, introduce_placeholders=True)
5776
for prop, value in node.properties.items()}
77+
5878
node.label = resolve(node, node.label, introduce_placeholders=True) # Must be after properties in case NAME
79+
80+
81+
def find_cycle(graph, plot_graph=False) -> List[Tuple[int, int]]:
82+
edges_tuple = [(e.src, e.tgt) for e in graph.edges]
83+
nx_graph = nx.DiGraph()
84+
nx_graph.add_edges_from(edges_tuple)
85+
try:
86+
cycle = nx.find_cycle(nx_graph)
87+
except nx.exception.NetworkXNoCycle as e:
88+
cycle = []
89+
90+
if plot_graph:
91+
import matplotlib.pyplot as plt
92+
nx.draw(nx_graph, with_labels=True, font_weight='bold')
93+
plt.show()
94+
95+
return cycle
96+
97+
98+
def is_directed_acyclic_graph(graph) -> bool:
99+
edges_tuple = list(map(lambda x: (x.src, x.tgt), graph.edges))
100+
nx_graph = nx.DiGraph()
101+
nx_graph.add_edges_from(edges_tuple)
102+
103+
assert nx.is_directed_acyclic_graph(nx_graph)

0 commit comments

Comments
 (0)