-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathattr_tree.py
92 lines (80 loc) · 2.83 KB
/
attr_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# import json
import logging
from collections import defaultdict
import pyparsing
from pyparsing import nestedExpr
# Attribute names already handed out by AttrTree.from_list. Cleared at the
# start of every AttrTree.from_string call and used to give repeated names a
# unique "_"-suffixed spelling, because count_attr_dists keys its tables by name.
ALL_STRINGS = set()

# Module logger; unlike the root-level logging.debug()/logging.warning()
# helpers it does not call logging.basicConfig() when no handler is configured.
_logger = logging.getLogger(__name__)


class AttrTree(list):
    """A nested list of attribute names parsed from a parenthesised expression.

    Every element is either a ``str`` (an attribute name) or a nested
    ``AttrTree`` (a parenthesised sub-group).  After ``count_attr_dists`` has
    run, ``self.d`` holds a pairwise distance table for the attributes of this
    subtree (see that method for its layout).
    """

    def __init__(self, seq=()):
        """Build a node from an iterable of names and/or sub-trees.

        Args:
            seq: Elements of the node; defaults to empty, like ``list()``.
        """
        super().__init__(seq)
        # Distance table, filled lazily by count_attr_dists().  Key "self" maps
        # attribute -> distance from this node; every other key is an
        # attribute name mapping other attribute -> distance.
        self.d = defaultdict(dict)

    @staticmethod
    def from_string(raw_string):
        """Parse *raw_string* into a fresh AttrTree.

        Before parsing, ``*`` characters and ``''`` pairs are removed and commas
        become spaces; the result is parsed with pyparsing's
        ``nestedExpr("(", ")")``.  ALL_STRINGS is reset first, so name
        de-duplication is per call.

        Args:
            raw_string: Text such as ``"(a, b, (c, d))"``.

        Returns:
            The parsed AttrTree, or an empty AttrTree when the cleaned text is
            blank or fails to parse (a warning is logged in the latter case).
        """
        ALL_STRINGS.clear()
        t_string = raw_string.replace("*", "").replace(",", " ").replace("''", "")
        # Blank (e.g. whitespace-only) text can never parse; skip the parser
        # and the misleading warning it would produce.
        if not t_string.strip():
            return AttrTree.from_list([])
        try:
            # NOTE(review): parseString() without parseAll=True parses only the
            # leading parenthesised group and ignores trailing text - confirm
            # that is intended.
            t_list = nestedExpr("(", ")").parseString(t_string).asList()
        except pyparsing.ParseException as err:
            _logger.warning(
                "pyparsing exception caused by: %s %s (%s); returning empty",
                raw_string,
                t_string,
                err,
            )
            return AttrTree.from_list([])
        return AttrTree.from_list(t_list)

    @staticmethod
    def from_list(t_list):
        """Recursively convert nested lists of strings into AttrTree nodes.

        A ``str`` becomes a name that is unique with respect to ALL_STRINGS:
        while the name is already taken, ``"_"`` is appended.  The chosen name
        is recorded in ALL_STRINGS.

        Args:
            t_list: A ``str``, or an iterable of ``str`` / nested iterables.

        Returns:
            The unique name for a ``str`` input, otherwise an AttrTree.
        """
        if isinstance(t_list, str):
            name = t_list
            # Keep appending "_" so third and later repeats (or a literal
            # "a_" following "a") still get a distinct table key.
            while name in ALL_STRINGS:
                name += "_"
            ALL_STRINGS.add(name)
            return name
        return AttrTree(AttrTree.from_list(elem) for elem in t_list)

    def gen_attr_dfs(self):
        """Yield every attribute name in depth-first, left-to-right order."""
        for elem in self:
            if isinstance(elem, str):
                yield elem
            else:
                yield from elem.gen_attr_dfs()

    def count_attr_dists(self):
        """Fill ``self.d`` with distances between the attributes of this subtree.

        Memoised: returns immediately if ``self.d`` is already non-empty.
        Sub-trees are processed first, recursively.  Resulting table:

        * ``d["self"][a]``: 1 for a name directly in this node, otherwise the
          smallest child ``d["self"]`` distance + 1.
        * ``d[a][b]``: pairs inside one sub-tree keep that sub-tree's value;
          two names directly in this node get 1; a direct name and a sub-tree
          attribute ``x`` get ``d["self"][x] + 1``; any pair still unconnected
          gets the sum of both ``d["self"]`` distances.

        NOTE(review): an attribute literally named "self" would collide with
        the "self" key of the table.
        """
        if self.d:
            return
        names = []
        self_dist = {}
        for elem in self:
            if isinstance(elem, str):
                names.append(elem)
                continue
            elem.count_attr_dists()
            # Copy the child's inner dicts: sharing them would make the writes
            # below leak this node's distances into the child's own table.
            self.d.update({key: dict(dists) for key, dists in elem.d.items()})
            for attr, val in elem.d["self"].items():
                self_dist[attr] = min(self_dist.get(attr, float("inf")), val + 1)
        self.d["self"] = self_dist
        for s in names:
            # Link this name to everything already reachable from this node.
            for attr, dist in self.d["self"].items():
                self.d[attr][s] = dist + 1
                self.d[s][attr] = dist + 1
            self.d["self"][s] = 1
            # Names sharing this node are at distance 1 from each other (this
            # overrides the 2 assigned above for earlier names).
            for s2 in names:
                if s != s2:
                    self.d[s][s2] = 1
                    self.d[s2][s] = 1
        # Join any still-unconnected pair through this node.
        for attr, dist in self.d["self"].items():
            for attr2, dist2 in self.d["self"].items():
                if attr == attr2 or attr2 in self.d[attr]:
                    continue
                _logger.debug("new_path: %s %s", attr, attr2)
                self.d[attr][attr2] = dist + dist2
                self.d[attr2][attr] = dist + dist2