store.py (forked from IdentityPython/pyFF)
from six import StringIO
import time
from copy import deepcopy
import re
from redis import Redis
from .constants import NS, ATTRS
from .decorators import cached
from .logs import log
from .utils import root, dumptree, parse_xml, hex_digest, hash_id, valid_until_ts
from .samlmd import EntitySet, iter_entities, entity_attribute_dict, is_sp, is_idp
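
# Metadata store backends for pyFF: StoreBase defines the common interface,
# MemoryStore keeps parsed EntityDescriptors in in-process dicts indexed by
# entityID hashes and entity attribute values, and RedisStore persists the
# same information in Redis using expiry-scored sorted sets.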

def _now():
    return int(time.time())


DINDEX = ('sha1', 'sha256', 'null')


class StoreBase(object):
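    """Abstract interface shared by all metadata store backends."""
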
    def lookup(self, key):
        raise NotImplementedError()

    def clone(self):
        return self

    def __iter__(self):
        for e in self.lookup("entities"):
            log.debug("**** yield entityID=%s" % e.get('entityID'))
            yield e

    def periodic(self, stats):
        pass

    def size(self, a=None, v=None):
        raise NotImplementedError()

    def collections(self):
        raise NotImplementedError()

    def update(self, t, tid=None, ts=None, merge_strategy=None):
        raise NotImplementedError()

    def reset(self):
        raise NotImplementedError()

    def entity_ids(self):
        return set(e.get('entityID') for e in self.lookup('entities'))


class MemoryStore(StoreBase):
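    """In-memory store: entities are kept in dicts and indexed both by
    sha1/sha256/null hashes of the entityID and by entity attribute values."""
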
    def __init__(self):
        self.md = dict()
        self.index = dict()
        self.entities = dict()

        for hn in DINDEX:
            self.index.setdefault(hn, {})
        self.index.setdefault('attr', {})

    def __str__(self):
        return repr(self.index)

    def clone(self):
        return deepcopy(self)

    def size(self, a=None, v=None):
        if a is None:
            return len(self.entities)
        elif a is not None and v is None:
            return len(self.index.setdefault('attr', {}).setdefault(a, {}).keys())
        else:
            return len(self.index.setdefault('attr', {}).setdefault(a, {}).get(v, []))

    def attributes(self):
        return self.index.setdefault('attr', {}).keys()

    def attribute(self, a):
        return self.index.setdefault('attr', {}).setdefault(a, {}).keys()

    def _modify(self, entity, modifier):
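        """Apply 'add' or 'discard' for an entity across every index: the hash
        indexes, the per-attribute value indexes and the synthetic role index."""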
        def _m(idx, vv):
            getattr(idx.setdefault(vv, EntitySet()), modifier)(entity)

        for hn in DINDEX:
            _m(self.index[hn], hash_id(entity, hn, False))

        attr_idx = self.index.setdefault('attr', {})
        for attr, values in entity_attribute_dict(entity).items():
            vidx = attr_idx.setdefault(attr, {})
            for v in values:
                _m(vidx, v)

        vidx = attr_idx.setdefault(ATTRS['role'], {})
        if is_idp(entity):
            _m(vidx, "idp")
        if is_sp(entity):
            _m(vidx, "sp")

    def _index(self, entity):
        return self._modify(entity, "add")

    def _unindex(self, entity):
        return self._modify(entity, "discard")
    def _get_index(self, a, v):
        if a in DINDEX:
            return self.index[a].get(v, [])
        else:
            idx = self.index['attr'].setdefault(a, {})
            entities = idx.get(v, None)
            if entities is not None:
                return entities
            else:
                m = re.compile(v)
                entities = []
                for value, ents in idx.items():
                    if m.match(value):
                        entities.extend(ents)
                return entities

    def reset(self):
        self.__init__()

    def collections(self):
        return self.md.keys()
    def update(self, t, tid=None, ts=None, merge_strategy=None):
        # log.debug("memory store update: %s: %s" % (repr(t), tid))
        relt = root(t)
        assert (relt is not None)
        ne = 0

        if relt.tag == "{%s}EntityDescriptor" % NS['md']:
            # log.debug("memory store setting entity descriptor")
            self._unindex(relt)
            self._index(relt)
            self.entities[relt.get('entityID')] = relt  # TODO: merge?
            if tid is not None:
                self.md[tid] = [relt.get('entityID')]
            ne += 1
            # log.debug("keys %s" % self.md.keys())
        elif relt.tag == "{%s}EntitiesDescriptor" % NS['md']:
            if tid is None:
                tid = relt.get('Name')
            lst = []
            for e in iter_entities(t):
                self.update(e)
                lst.append(e.get('entityID'))
                ne += 1
            self.md[tid] = lst

        return ne
    def lookup(self, key):
        # log.debug("memory store lookup: %s" % key)
        return self._lookup(key)

    def _lookup(self, key):
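        """Resolve a lookup expression: 'entities' (or None) returns everything,
        'a+b+...' intersects sub-lookups, 'attr=value' and '{attr}v1;v2' query the
        attribute indexes, otherwise the key is tried against the 'null' index
        (entityIDs) and finally as a collection name in self.md."""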
        if key == 'entities' or key is None:
            return self.entities.values()

        if '+' in key:
            key = key.strip('+')
            # log.debug("lookup intersection of '%s'" % ' and '.join(key.split('+')))
            hits = None
            for f in key.split("+"):
                f = f.strip()
                if hits is None:
                    hits = set(self._lookup(f))
                else:
                    other = self._lookup(f)
                    hits.intersection_update(other)

                if not hits:
                    log.debug("empty intersection")
                    return []

            if hits is not None and hits:
                return list(hits)
            else:
                return []

        m = re.match("^(.+)=(.+)$", key)
        if m:
            return self._lookup("{%s}%s" % (m.group(1), str(m.group(2)).rstrip("/")))

        m = re.match("^{(.+)}(.+)$", key)
        if m:
            res = set()
            for v in str(m.group(2)).rstrip("/").split(';'):
                # log.debug("... adding %s=%s" % (m.group(1),v))
                res.update(self._get_index(m.group(1), v))
            return list(res)

        l = self._get_index("null", key)
        if l:
            return list(l)

        if key in self.md:
            # log.debug("entities list %s: %s" % (key, self.md[key]))
            lst = []
            for entityID in self.md[key]:
                lst.extend(self.lookup(entityID))
            return lst

        return []
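
# Illustrative sketch (not part of pyFF): given `t`, an ElementTree obtained via
# pyff.utils.parse_xml for some SAML metadata, a MemoryStore could be used like
# this (the tid and entityID below are hypothetical):
#
#     store = MemoryStore()
#     store.update(t, tid="https://md.example.org/all")
#     idps = store.lookup("{%s}idp" % ATTRS['role'])                 # all IdPs
#     one = store.lookup("https://idp.example.org/idp.xml")          # by entityID
#     both = store.lookup("{%s}idp+https://md.example.org/all" % ATTRS['role'])
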
class RedisStore(StoreBase):
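    """Redis-backed store.

    Key layout (as used by the code below): '<tid>#metadata' holds the serialized
    XML, '<tid>' a hash of the root element attributes plus 'expires',
    '<group>#members' a sorted set of member ids scored by expiry time,
    '<attr>#values' a sorted set of seen attribute values, '#collections' and
    '#attributes' sets of known groups and attributes, and '{<alg>}<digest>#alias'
    maps hex digests of a tid back to the tid itself.

    Note: zadd()/hmset() are used with the pre-3.0 redis-py calling convention;
    redis-py >= 3.0 expects zadd(name, {member: score}) and deprecates hmset().
    """
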
    def __init__(self, version=_now(), default_ttl=3600 * 24 * 4, respect_validity=True):
        self.rc = Redis()
        self.default_ttl = default_ttl
        self.respect_validity = respect_validity

    def _expiration(self, relt):
        ts = _now() + self.default_ttl
        if self.respect_validity:
            return valid_until_ts(relt, ts)

    def reset(self):
        self.rc.flushdb()

    def _drop_empty_av(self, attr, tag, ts):
        an = "#%s" % attr
        for c in self.rc.smembers(an):
            tn = "%s#members" % c
            self.rc.zremrangebyscore(tn, "-inf", ts)
            if not self.rc.zcard(tn) > 0:
                log.debug("dropping empty %s %s" % (attr, c))
                self.rc.srem(an, c)
    def periodic(self, stats):
        now = _now()
        stats['Last Periodic Maintenance'] = now
        log.debug("periodic maintenance...")
        self.rc.zremrangebyscore("members", "-inf", now)
        self._drop_empty_av("collections", "members", now)
        self._drop_empty_av("attributes", "values", now)
    def update_entity(self, relt, t, tid, ts, p=None):
        if p is None:
            p = self.rc
        p.set("%s#metadata" % tid, dumptree(t))
        self._get_metadata.invalidate(tid)  # invalidate the parse-cache entry
        if ts is not None:
            p.expireat("%s#metadata" % tid, ts)
        nfo = dict(expires=ts)
        nfo.update(**relt.attrib)
        p.hmset(tid, nfo)
        if ts is not None:
            p.expireat(tid, ts)

    def membership(self, gid, mid, ts, p=None):
        if p is None:
            p = self.rc
        p.zadd("%s#members" % gid, mid, ts)
        # p.zadd("%s#groups", mid, gid, ts)
        p.sadd("#collections", gid)

    def attributes(self):
        return self.rc.smembers("#attributes")

    def attribute(self, an):
        return self.rc.zrangebyscore("%s#values" % an, _now(), "+inf")

    def collections(self):
        return self.rc.smembers("#collections")
    def update(self, t, tid=None, ts=None, merge_strategy=None):  # TODO: merge ?
        log.debug("redis store update: %s: %s" % (t, tid))
        relt = root(t)
        ne = 0
        if ts is None:
            ts = int(_now() + 3600 * 24 * 4)  # 4 days is the arbitrary default expiration
        if relt.tag == "{%s}EntityDescriptor" % NS['md']:
            if tid is None:
                tid = relt.get('entityID')
            with self.rc.pipeline() as p:
                self.update_entity(relt, t, tid, ts, p)
                entity_id = relt.get("entityID")
                if entity_id is not None:
                    self.membership("entities", entity_id, ts, p)
                for ea, eav in entity_attribute_dict(relt).items():
                    for v in eav:
                        # log.debug("%s=%s" % (ea, v))
                        self.membership("{%s}%s" % (ea, v), tid, ts, p)
                        p.zadd("%s#values" % ea, v, ts)
                        p.sadd("#attributes", ea)
                for hn in ('sha1', 'sha256', 'md5'):
                    tid_hash = hex_digest(tid, hn)
                    p.set("{%s}%s#alias" % (hn, tid_hash), tid)
                    if ts is not None:
                        p.expireat(tid_hash, ts)
                p.execute()
            ne += 1
        elif relt.tag == "{%s}EntitiesDescriptor" % NS['md']:
            if tid is None:
                tid = relt.get('Name')
            ts = self._expiration(relt)
            with self.rc.pipeline() as p:
                self.update_entity(relt, t, tid, ts, p)
                for e in iter_entities(t):
                    ne += self.update(e, ts=ts)
                    entity_id = e.get("entityID")
                    if entity_id is not None:
                        self.membership(tid, entity_id, ts, p)
                        self.membership("entities", entity_id, ts, p)
                p.execute()
        else:
            raise ValueError("Bad metadata top-level element: '%s'" % root(t).tag)

        return ne
    def _members(self, k):
        mem = []
        if self.rc.exists("%s#members" % k):
            for entity_id in self.rc.zrangebyscore("%s#members" % k, _now(), "+inf"):
                mem.extend(self.lookup(entity_id))
        return mem

    @cached(ttl=30)
    def _get_metadata(self, key):
        return root(parse_xml(StringIO(self.rc.get("%s#metadata" % key))))
    def lookup(self, key):
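        """Resolve a lookup expression against Redis: 'a+b+...' intersects member
        sets (via a short-lived zinterstore result), 'attr=value' is rewritten to
        '{attr}value', '{attr}v1;v2' unions the per-value member sets, and a plain
        key is resolved through its #alias, #metadata or #members entries."""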
log.debug("redis store lookup: %s" % key)
if '+' in key:
hk = hex_digest(key)
if not self.rc.exists("%s#members" % hk):
self.rc.zinterstore("%s#members" % hk, ["%s#members" % k for k in key.split('+')], 'min')
self.rc.expire("%s#members" % hk, 30) # XXX bad juju - only to keep clients from hammering
return self.lookup(hk)
m = re.match("^(.+)=(.+)$", key)
if m:
return self.lookup("{%s}%s" % (m.group(1), m.group(2)))
m = re.match("^{(.+)}(.+)$", key)
if m and ';' in m.group(2):
hk = hex_digest(key)
if not self.rc.exists("%s#members" % hk):
self.rc.zunionstore("%s#members" % hk,
["{%s}%s#members" % (m.group(1), v) for v in str(m.group(2)).split(';')], 'min')
self.rc.expire("%s#members" % hk, 30) # XXX bad juju - only to keep clients from hammering
return self.lookup(hk)
elif self.rc.exists("%s#alias" % key):
return self.lookup(self.rc.get("%s#alias" % key))
elif self.rc.exists("%s#metadata" % key):
return [self._get_metadata(key)]
else:
return self._members(key)
    def size(self):
        return self.rc.zcount("entities#members", _now(), "+inf")
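
# Illustrative sketch (not part of pyFF): a RedisStore assumes a Redis server on
# localhost (Redis() with default settings) and is driven the same way as the
# MemoryStore; the tid below is hypothetical:
#
#     store = RedisStore(default_ttl=3600)
#     store.update(t, tid="https://md.example.org/all")
#     store.periodic({})                          # drop expired members
#     sps = store.lookup("{%s}sp" % ATTRS['role'])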