diff --git a/src/pyff/builtins.py b/src/pyff/builtins.py index b7950a55..0e592bd2 100644 --- a/src/pyff/builtins.py +++ b/src/pyff/builtins.py @@ -1,8 +1,9 @@ -from __future__ import print_function - """Package that contains the basic set of pipes - functions that can be used to put together a processing pipeling for pyFF. """ + +from __future__ import absolute_import, print_function + import base64 import hashlib import json @@ -19,13 +20,14 @@ from iso8601 import iso8601 from lxml.etree import DocumentInvalid -from pyff.constants import NS -from pyff.decorators import deprecated -from pyff.logs import log -from pyff.pipes import Plumbing, PipeException, PipelineCallback, pipe -from pyff.stats import set_metadata_info -from pyff.utils import total_seconds, dumptree, safe_write, root, duration2timedelta, xslt_transform, \ - iter_entities, validate_document +from .constants import NS +from .decorators import deprecated +from .logs import log +from .pipes import Plumbing, PipeException, PipelineCallback, pipe +from .stats import set_metadata_info +from .utils import total_seconds, dumptree, safe_write, root, duration2timedelta, xslt_transform, validate_document +from .samlmd import iter_entities, annotate_entity, set_entity_attributes +from .fetch import Resource try: from cStringIO import StringIO @@ -148,10 +150,10 @@ def fork(req, *opts): if req.t is not None: nt = deepcopy(req.t) - ip = Plumbing(pipeline=req.args, pid="{}.fork".format(req.plumbing.pid)) + ip = Plumbing(pipeline=req.args, pid="%s.fork" % req.plumbing.pid) # ip.process(req.md,t=nt) ireq = Plumbing.Request(ip, req.md, nt) - ip._process(ireq) + ip.iprocess(ireq) if req.t is not None and ireq.t is not None and len(root(ireq.t)) > 0: if 'merge' in opts: @@ -239,7 +241,7 @@ def _pipe(req, *opts): """ # req.process(Plumbing(pipeline=req.args, pid="%s.pipe" % req.plumbing.pid)) - ot = Plumbing(pipeline=req.args, pid="{}.pipe".format(req.plumbing.id))._process(req) + ot = Plumbing(pipeline=req.args, pid="%s.pipe" % req.plumbing.id).iprocess(req) req.done = False return ot @@ -252,7 +254,6 @@ def when(req, condition, *values): :param req: The request :param condition: The condition key :param values: The condition values -:param opts: More Options (unused) :return: None The inner pipeline is executed if the at least one of the condition values is present for the specified key in @@ -270,12 +271,9 @@ def when(req, condition, *values): The condition operates on the state: if 'foo' is present in the state (with any value), then the something branch is followed. If 'bar' is present in the state with the value 'bill' then the other branch is followed. 
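The fork, _pipe and when builtins above all share one pattern: construct a child Plumbing whose pid is derived from the parent, wrap the current state in a Plumbing.Request, and drive it through the renamed iprocess() entry point. A minimal sketch of that pattern, assuming only the signatures visible in this diff (`run_inner` is a hypothetical helper, not part of the patch):

```python
# Hypothetical helper illustrating the inner-pipeline pattern used by
# fork/_pipe/when: a child Plumbing with a derived pid processes a
# Plumbing.Request, and the (possibly new) tree is read back from it.
from pyff.pipes import Plumbing

def run_inner(req, fragment, tag="inner"):
    ip = Plumbing(pipeline=fragment, pid="%s.%s" % (req.plumbing.pid, tag))
    ireq = Plumbing.Request(ip, req.md, req.t)
    ip.iprocess(ireq)
    return ireq.t
```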
""" - # log.debug("condition key: %s" % repr(condition)) c = req.state.get(condition, None) - # log.debug("condition %s" % repr(c)) - if c is not None: - if not values or _any(values, c): - return Plumbing(pipeline=req.args, pid="%s.when" % req.plumbing.id)._process(req) + if c is not None and (not values or _any(values, c)): + return Plumbing(pipeline=req.args, pid="%s.when" % req.plumbing.id).iprocess(req) return req.t @@ -357,7 +355,7 @@ def loadstats(req, *opts): :param opts: Options: (none) :return: None """ - from stats import metadata + from .stats import metadata _stats = None try: if 'json' in opts: @@ -467,36 +465,17 @@ def load(req, *opts): params.setdefault('as', url) - post = None + def _null(t): + return t + + post = _null if params['via'] is not None: post = PipelineCallback(params['via'], req) - if "://" in url: - log.debug("load {} verify {} as {} via {}".format(url, params['verify'], params['as'], params['via'])) - remotes.append((url, params['verify'], params['as'], post)) - elif os.path.exists(url): - if os.path.isdir(url): - log.debug("directory {} verify {} as {} via {}".format(url, params['verify'], params['as'], params['via'])) - req.md.load_dir(url, url=params['as'], validate=opts['validate'], post=post, - fail_on_error=opts['fail_on_error'], filter_invalid=opts['filter_invalid']) - elif os.path.isfile(url): - log.debug("file {} verify {} as {} via {}".format(url, params['verify'], params['as'], params['via'])) - remotes.append(("file://%s" % url, params['verify'], params['as'], post)) - else: - error = "Unknown file type for load: '{}'".format(url) - if opts['fail_on_error']: - raise PipeException(error) - log.error(error) - else: - error = "Don't know how to load '{}' as {} verify {} via {} (file does not exist?)".format(url, - params['as'], - params['verify'], - params['via']) - if opts['fail_on_error']: - raise PipeException(error) - log.error(error) + req.md.rm.add(Resource(url, post, **params)) - req.md.fetch_metadata(remotes, **opts) + log.debug("Refreshing all resources") + req.md.reload() def _select_args(req): @@ -915,6 +894,7 @@ def validate(req, *opts): return req.t + @pipe def prune(req, *opts): """ @@ -1018,7 +998,7 @@ def certreport(req, *opts): keysize = cdict['modulus'].bit_length() cert = cdict['cert'] if keysize < error_bits: - req.md.annotate(entity_elt, + annotate_entity(entity_elt, "certificate-error", "keysize too small", "%s has keysize of %s bits (less than %s)" % (cert.getSubject(), @@ -1026,7 +1006,7 @@ def certreport(req, *opts): error_bits)) log.error("%s has keysize of %s" % (eid, keysize)) elif keysize < warning_bits: - req.md.annotate(entity_elt, + annotate_entity(entity_elt, "certificate-warning", "keysize small", "%s has keysize of %s bits (less than %s)" % (cert.getSubject(), @@ -1036,7 +1016,7 @@ def certreport(req, *opts): notafter = cert.getNotAfter() if notafter is None: - req.md.annotate(entity_elt, + annotate_entity(entity_elt, "certificate-error", "certificate has no expiration time", "%s has no expiration time" % cert.getSubject()) @@ -1046,23 +1026,24 @@ def certreport(req, *opts): now = datetime.now() dt = et - now if total_seconds(dt) < error_seconds: - req.md.annotate(entity_elt, + annotate_entity(entity_elt, "certificate-error", "certificate has expired", "%s expired %s ago" % (cert.getSubject(), -dt)) log.error("%s expired %s ago" % (eid, -dt)) elif total_seconds(dt) < warning_seconds: - req.md.annotate(entity_elt, + annotate_entity(entity_elt, "certificate-warning", "certificate about to expire", "%s expires in %s" 
% (cert.getSubject(), dt)) log.warn("%s expires in %s" % (eid, dt)) except ValueError as ex: - req.md.annotate(entity_elt, + annotate_entity(entity_elt, "certificate-error", "certificate has unknown expiration time", "%s unknown expiration time %s" % (cert.getSubject(), notafter)) + req.md.store.update(entity_elt) except Exception as ex: log.error(ex) @@ -1133,7 +1114,7 @@ def signcerts(req, *opts): if req.t is None: raise PipeException("Your pipeline is missing a select statement.") - for fp, pem in xmlsec.crypto.CertDict(req.t).iteritems(): + for fp, pem in xmlsec.crypto.CertDict(req.t).items(): log.info("found signing cert with fingerprint %s" % fp) return req.t @@ -1191,12 +1172,12 @@ def finalize(req, *opts): mdid = req.args.get('ID', 'prefix _') if re.match('(\s)*prefix(\s)*', mdid): prefix = re.sub('^(\s)*prefix(\s)*', '', mdid) - ID = now.strftime(prefix + "%Y%m%dT%H%M%SZ") + _id = now.strftime(prefix + "%Y%m%dT%H%M%SZ") else: - ID = mdid + _id = mdid if not e.get('ID'): - e.set('ID', ID) + e.set('ID', _id) valid_until = str(req.args.get('validUntil', e.get('validUntil', None))) if valid_until is not None and len(valid_until) > 0: @@ -1210,7 +1191,7 @@ def finalize(req, *opts): dt = dt.replace(tzinfo=None) # make dt "naive" (tz-unaware) offset = dt - now e.set('validUntil', dt.strftime("%Y-%m-%dT%H:%M:%SZ")) - except ValueError, ex: + except ValueError as ex: log.error("Unable to parse validUntil: %s (%s)" % (valid_until, ex)) # set a reasonable default: 50% of the validity @@ -1315,6 +1296,7 @@ def _setattr(req, *opts): for e in iter_entities(req.t): # log.debug("setting %s on %s" % (req.args,e.get('entityID'))) - req.md.set_entity_attributes(e, req.args) + set_entity_attributes(e, req.args) + req.md.store.update(e) return req.t diff --git a/src/pyff/constants.py b/src/pyff/constants.py index 99b70a8b..4a2921d4 100644 --- a/src/pyff/constants.py +++ b/src/pyff/constants.py @@ -2,8 +2,6 @@ Useful constants for pyFF. Mostly XML namespace declarations. 
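All of the tunables added to Config below are routed through pyconfig, so a deployment can override them without touching code. A sketch under the assumption that pyconfig's standard set()/setting() API applies; the key names come from this diff, the values are examples only:

```python
# Illustrative overrides of the settings defined in Config below;
# pyconfig resolves each pyconfig.setting() at attribute-access time.
import pyconfig

pyconfig.set('pyff.worker_pool_size', 20)  # more parallel fetches
pyconfig.set('pyff.request_timeout', 30)   # tolerate slow upstreams

from pyff.constants import config
assert config.worker_pool_size == 20
```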
""" -import os -import sys import pyconfig import logging @@ -30,14 +28,9 @@ 'software': 'http://pyff.io/software', 'domain': 'http://pyff.io/domain'} -DIGESTS = ['sha1', 'md5', 'null'] - -EVENT_DROP_ENTITY = 'event.drop.entity' -EVENT_RETRY_URL = 'event.retry.url' -EVENT_IMPORTED_METADATA = 'event.imported.metadata' -EVENT_IMPORT_FAIL = 'event.import.failed' -EVENT_REPOSITORY_LIVE = 'event.repository.live' +PLACEHOLDER_ICON = 'data:image/gif;base64,R0lGODlhAQABAIABAP///wAAACH5BAEKAAEALAAAAAABAAEAAAICTAEAOw==' +DIGESTS = ['sha1', 'md5', 'null'] class Config(object): google_api_key = pyconfig.setting("pyff.google_api_key", "google+api+key+not+set") @@ -56,9 +49,16 @@ class Config(object): aliases = pyconfig.setting("pyff.aliases", ATTRS) base_dir = pyconfig.setting("pyff.base_dir", None) proxy = pyconfig.setting("pyff.proxy", False) - store = pyconfig.setting("pyff.store", None) allow_shutdown = pyconfig.setting("pyff.allow_shutdown", False) modules = pyconfig.setting("pyff.modules", []) - + cache_ttl = pyconfig.setting("pyff.cache.ttl", 300) + default_cache_duration = pyconfig.setting("pyff.default.cache_duration", "PT1H") + respect_cache_duration = pyconfig.setting("pyff.respect_cache_duration", True) + info_buffer_size = pyconfig.setting("pyff.info_buffer_size", 10) + worker_pool_size = pyconfig.setting("pyff.worker_pool_size", 10) + store_class = pyconfig.setting("pyff.store.class", "pyff.store:MemoryStore") + update_frequency = pyconfig.setting("pyff.update_frequency",600) + request_timeout = pyconfig.setting("pyff.request_timeout",10) + request_cache_time = pyconfig.setting("pyff.request_cache_time", 5) config = Config() diff --git a/src/pyff/decorators.py b/src/pyff/decorators.py index 241630e4..a442d243 100644 --- a/src/pyff/decorators.py +++ b/src/pyff/decorators.py @@ -1,5 +1,3 @@ -from __future__ import print_function - """ Various decorators used in pyFF. """ @@ -40,7 +38,7 @@ def f_retry(*args, **kwargs): try: return f(*args, **kwargs) except ex as e: - msg = "{}, Retrying in {:d} seconds...".format(str(e), mdelay) + msg = "%s, Retrying in %d seconds..." % (str(e), mdelay) if logger: logger.warn(msg) else: @@ -81,6 +79,7 @@ class _HashedSeq(list): __slots__ = 'hashvalue' def __init__(self, tup, thehash=hash): + super(_HashedSeq, self).__init__() self[:] = tup self.hashvalue = thehash(tup) @@ -95,8 +94,24 @@ def _make_key(args, kwds, typed, thetuple=tuple, thetype=type, thelen=len): - 'Make a cache key from optionally typed positional and keyword arguments' + """ + + :param args: + :param kwds: + :param typed: + :param kwd_mark: + :param fasttypes: + :param thesorted: + :param thetuple: + :param thetype: + :param thelen: + :return: + + Make a cache key from optionally typed positional and keyword arguments + + """ key = args + sorted_items = dict() if kwds: sorted_items = thesorted(kwds.items()) key += kwd_mark diff --git a/src/pyff/exceptions.py b/src/pyff/exceptions.py new file mode 100644 index 00000000..54316128 --- /dev/null +++ b/src/pyff/exceptions.py @@ -0,0 +1,14 @@ + +__author__ = 'leifj' + + +class MetadataException(Exception): + pass + + +class MetadataExpiredException(MetadataException): + pass + + +class PyffException(Exception): + pass diff --git a/src/pyff/fetch.py b/src/pyff/fetch.py new file mode 100644 index 00000000..b80379c6 --- /dev/null +++ b/src/pyff/fetch.py @@ -0,0 +1,182 @@ +""" + +An abstraction layer for metadata fetchers. Supports both syncronous and asyncronous fetchers with cache. 
+ +""" + +from __future__ import absolute_import, unicode_literals +from .logs import log +import os +import requests +from requests_file import FileAdapter +from .constants import config +from datetime import datetime +from collections import deque +from UserDict import DictMixin +from concurrent import futures +from .parse import parse_resource +from itertools import chain +from requests_cache.core import CachedSession + +requests.packages.urllib3.disable_warnings() + +try: + from cStringIO import StringIO +except ImportError: # pragma: no cover + print(" *** install cStringIO for better performance") + from StringIO import StringIO + + +class ResourceException(Exception): + def __init__(self, msg, wrapped=None, data=None): + self._wraped = wrapped + self._data = data + super(self.__class__, self).__init__(msg) + + def raise_wraped(self): + raise self._wraped + + +class ResourceManager(DictMixin): + + def __init__(self): + self._resources = dict() + self.shutdown = False + + def __setitem__(self, key, value): + if not isinstance(value, Resource): + raise ValueError("I can only store Resources") + self._resources[key] = value + + def __getitem__(self, key): + return self._resources[key] + + def __delitem__(self, key): + if key in self: + del self._resources[key] + + def keys(self): + return self._resources.keys() + + def values(self): + return self._resources.values() + + def walk(self, url=None): + if url is not None: + return self[url].walk() + else: + i = [r.walk() for r in self.values()] + return chain(*i) + + def add(self, r): + if not isinstance(r, Resource): + raise ValueError("I can only store Resources") + self[r.name] = r + + def __contains__(self, item): + return item in self._resources + + def reload(self, url=None): + # type: (object, basestring) -> None + with futures.ThreadPoolExecutor(max_workers=config.worker_pool_size) as executor: + tasks = dict((executor.submit(r.fetch), r) for r in self.walk(url)) + i = 0 + for future in futures.as_completed(tasks): + r = tasks[future] + try: + res = future.result() + except Exception as ex: + from traceback import print_exc + print_exc() + + log.debug("finished...") + +class Resource(object): + def __init__(self, url, post, **kwargs): + self.url = url + self.post = post + self.opts = kwargs + self.t = None + self.type = "text/plain" + self.expire_time = None + self.last_seen = None + self._infos = deque(maxlen=config.info_buffer_size) + self.children = [] + + self.opts.setdefault('fail_on_error', False) + self.opts.setdefault('as', None) + self.opts.setdefault('verify', None) + self.opts.setdefault('filter_invalid', False) + self.opts.setdefault('validate', True) + + if "://" not in self.url: + if os.path.isdir(self.url) or os.path.isfile(self.url): + self.url = "file://{}".format(os.path.abspath(self.url)) + + def __str__(self): + return "Resource {} expires at {} using ".format(self.url, self.expire_time) + \ + ",".join(["{}={}".format(k, v) for k, v in self.opts.items()]) + + def walk(self): + yield self + for c in self.children: + for cn in c.walk(): + yield cn + + def is_expired(self): + now = datetime.now() + return self.expire_time is not None and self.expire_time < now + + def is_valid(self): + return self.t is not None and not self.is_expired() + + def add_info(self, info): + self._infos.append(info) + + def add_child(self, url): + self.children.append(Resource(url, self.post, **self.opts)) + + @property + def name(self): + if 'as' in self.opts: + return self.opts['as'] + else: + return self.url + + @property + def 
info(self): + return self._infos[0] + + def fetch(self): + s = None + if 'file://' in self.url: + s = requests.session() + s.mount('file://', FileAdapter()) + else: + s = CachedSession(cache_name="pyff_cache", expire_after=config.request_cache_time) + + r = s.get(self.url, verify=False, timeout=config.request_timeout) + info = dict() + info['Response Headers'] = r.headers + log.debug(r.encoding) + data = r.text + log.debug(type(data)) + + if r.ok and data: + info.update(parse_resource(self, data)) + if self.t: + self.last_seen = datetime.now() + if self.post is not None: + self.t = self.post(self.t) + + if self.is_expired(): + raise ResourceException("Resource at {} has expired".format(r.url)) + + for (eid, error) in info['Validation Errors'].items(): + log.error(error) + else: + log.error("Got no valid data from {}".format(r.url)) + + self.add_info(info) + else: + raise ResourceException("Got status={:d} while fetching {}".format(r.status_code, r.url)) \ No newline at end of file diff --git a/src/pyff/locks.py b/src/pyff/locks.py index d62e1bf2..4ed49508 100644 --- a/src/pyff/locks.py +++ b/src/pyff/locks.py @@ -137,6 +137,7 @@ def acquireWrite(self, timeout=None): * In case timeout is None, the call to acquireWrite blocks until the lock request can be serviced. * In case the timeout expires before the lock could be serviced, a RuntimeError is thrown.""" + endtime = None if timeout is not None: endtime = time() + timeout me, upgradewriter = currentThread(), False diff --git a/src/pyff/logs.py b/src/pyff/logs.py index 77f2b080..839c698b 100644 --- a/src/pyff/logs.py +++ b/src/pyff/logs.py @@ -5,6 +5,13 @@ import cherrypy +def printable(s): + if isinstance(s,unicode): + return s.encode('utf8', errors='ignore').decode('utf8') + elif isinstance(s,str): + return s.decode("utf8", errors="ignore").encode('utf8') + else: + return repr(s) class PyFFLogger(object): def __init__(self): @@ -17,9 +24,9 @@ def __init__(self): def _l(self, severity, msg): if '' in cherrypy.tree.apps: - cherrypy.tree.apps[''].log("%s" % msg, severity=severity) + cherrypy.tree.apps[''].log(printable(msg), severity=severity) elif severity in self._loggers: - self._loggers[severity]("%s" % msg) + self._loggers[severity](printable(msg)) else: raise ValueError("unknown severity %s" % severity) diff --git a/src/pyff/md.py b/src/pyff/md.py index 626c8155..dc02bb8b 100644 --- a/src/pyff/md.py +++ b/src/pyff/md.py @@ -1,5 +1,3 @@ -from __future__ import print_function - """ pyFF is the SAML metadata aggregator @@ -18,7 +16,6 @@ from . 
import __version__ from .mdrepo import MDRepository from .pipes import plumbing -from .store import MemoryStore from .constants import config @@ -36,9 +33,6 @@ def main(): print(__doc__) sys.exit(2) - if config.store is None: - config.store = MemoryStore() - if config.loglevel is None: config.loglevel = logging.WARN @@ -55,9 +49,6 @@ def main(): raise ValueError('Invalid log level: %s' % a) elif o in '--logfile': config.logfile = a - elif o in '-R': - from pyff.store import RedisStore - config.store = RedisStore() elif o in ('-m', '--module'): config.modules.append(a) elif o in '--version': @@ -74,7 +65,7 @@ def main(): importlib.import_module(mn) try: - md = MDRepository(store=config.store) + md = MDRepository() for p in args: plumbing(p).process(md, state={'batch': True, 'stats': {}}) sys.exit(0) diff --git a/src/pyff/mdrepo.py b/src/pyff/mdrepo.py index 832650ea..53967ab7 100644 --- a/src/pyff/mdrepo.py +++ b/src/pyff/mdrepo.py @@ -3,10 +3,10 @@ This is the implementation of the active repository of SAML metadata. The 'local' and 'remote' pipes operate on this. """ -import traceback -from pyff.stats import set_metadata_info, get_metadata_info -from pyff.store import entity_attribute_dict +from __future__ import absolute_import, unicode_literals + +from .stats import get_metadata_info try: from cStringIO import StringIO @@ -14,256 +14,45 @@ print(" *** install cStringIO for better performance") from StringIO import StringIO -from copy import deepcopy -from datetime import datetime -from UserDict import UserDict -import os import operator -from concurrent import futures from lxml import etree -from lxml.builder import ElementMaker -from lxml.etree import DocumentInvalid -import xmlsec import ipaddr -from .constants import ATTRS from . import merge_strategies from .logs import log -from .utils import schema, check_signature, filter_lang, root, duration2timedelta, \ - hash_id, MetadataException, find_merge_strategy, entities_list, url2host, subdomains, avg_domain_distance, \ - iter_entities, validate_document, load_url, iso2datetime, xml_error, find_entity -from .constants import NS, NF_URI, EVENT_DROP_ENTITY, EVENT_IMPORT_FAIL +from .samlmd import entitiesdescriptor, find_merge_strategy, find_entity, iter_entities, entity_simple_summary +from .utils import root, MetadataException, avg_domain_distance, load_callable +from .constants import NS, config +from .fetch import ResourceManager etree.set_default_parser(etree.XMLParser(resolve_entities=False)) -class Event(UserDict): - pass - - -class Observable(object): - def __init__(self): - self.callbacks = [] - - def subscribe(self, callback): - self.callbacks.append(callback) - - def fire(self, **attrs): - e = Event(attrs) - e['time'] = datetime.now() - for fn in self.callbacks: - fn(e) - - -def _trunc(x, l): - return (x[:l] + '..') if len(x) > l else x - - -class MDRepository(Observable): +class MDRepository(): """A class representing a set of SAML Metadata. Instances present as dict-like objects where the keys are URIs and values are EntitiesDescriptor elements containing sets of metadata. 
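The repository now delegates fetching to ResourceManager and builds its store from config.store_class on every reload; a sketch of the resulting flow (the feed URL is an example, and the final lookup assumes the store accepts a collection name as key):

```python
# Register, reload, query - reload() swaps in a freshly built store.
from pyff.mdrepo import MDRepository
from pyff.fetch import Resource

md = MDRepository()
md.rm.add(Resource("http://md.example.com/feed.xml", None, verify=None))
md.reload()  # store.update(r.t, tid=r.name) for each fetched resource
entities = md.lookup("http://md.example.com/feed.xml")
```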
""" - def __init__(self, metadata_cache_enabled=False, min_cache_ttl="PT5M", store=None): - self.metadata_cache_enabled = metadata_cache_enabled - self.min_cache_ttl = min_cache_ttl - - if not isinstance(self.min_cache_ttl, int): - try: - self.min_cache_ttl = duration2timedelta(self.min_cache_ttl).total_seconds() - except Exception as ex: - log.error(ex) - self.min_cache_ttl = 300 - self.respect_cache_duration = True - self.default_cache_duration = "PT10M" - self.retry_limit = 5 + def __init__(self): + # if not isinstance(self.min_cache_ttl, int): + # try: + # self.min_cache_ttl = duration2timedelta(self.min_cache_ttl).total_seconds() + # except Exception as ex: + # log.error(ex) + # self.min_cache_ttl = 300 + self.store = None + self.rm = ResourceManager() + self.store_class = load_callable(config.store_class) self.store = None - if store is not None: - if hasattr(store, '__call__'): - self.store = store() - else: - self.store = store - else: - from .store import MemoryStore - - self.store = MemoryStore() - super(MDRepository, self).__init__() - - def clone(self): - return MDRepository(metadata_cache_enabled=self.metadata_cache_enabled, - min_cache_ttl=self.min_cache_ttl, - store=self.store.clone()) - - def sha1_id(self, e): - return hash_id(e, 'sha1') - - def is_idp(self, e): - return bool(e.find(".//{%s}IDPSSODescriptor" % NS['md']) is not None) - - def is_sp(self, e): - return bool(e.find(".//{%s}SPSSODescriptor" % NS['md']) is not None) - - def icon(self, e, langs=None): - for icon in filter_lang(e.iter("{%s}Logo" % NS['mdui']), langs=langs): - return dict(url=icon.text,width=icon.get('width'),height=icon.get('height')) - - def psu(self, entity, langs): - for url in filter_lang(entity.iter("{%s}PrivacyStatementURL" % NS['mdui']), langs=langs): - return url.text - - def geoloc(self, entity): - for loc in entity.iter("{%s}GeolocationHint" % NS['mdui']): - pos = loc.text[5:].split(",") - return dict(lat=pos[0],long=pos[1]) - - def domains(self, entity): - domains = [] - for d in entity.iter("{%s}DomainHint" % NS['mdui']): - if d.text == '.': - return [] - domains.append(d.text) - if not domains: - domains.append(url2host(entity.get('entityID'))) - return domains - - def ext_display(self, entity, langs=None): - """Utility-method for computing a displayable string for a given entity. - - :param entity: An EntityDescriptor element - """ - display = entity.get('entityID') - info = '' - - for organizationName in filter_lang(entity.iter("{%s}OrganizationName" % NS['md']), langs=langs): - info = display - display = organizationName.text - - for organizationDisplayName in filter_lang(entity.iter("{%s}OrganizationDisplayName" % NS['md']), langs=langs): - info = display - display = organizationDisplayName.text - - for serviceName in filter_lang(entity.iter("{%s}ServiceName" % NS['md']), langs=langs): - info = display - display = serviceName.text - - for displayName in filter_lang(entity.iter("{%s}DisplayName" % NS['mdui']), langs=langs): - info = display - display = displayName.text - - for organizationUrl in filter_lang(entity.iter("{%s}OrganizationURL" % NS['md']), langs=langs): - info = organizationUrl.text - - for description in filter_lang(entity.iter("{%s}Description" % NS['mdui']), langs=langs): - info = description.text - - if info == entity.get('entityID'): - info = '' - - return _trunc(display.strip(), 40), _trunc(info.strip(), 256) - - def display(self, entity, langs=None): - """Utility-method for computing a displayable string for a given entity. 
- - :param entity: An EntityDescriptor element - """ - for displayName in filter_lang(entity.iter("{%s}DisplayName" % NS['mdui']), langs=langs): - return displayName.text - - for serviceName in filter_lang(entity.iter("{%s}ServiceName" % NS['md']), langs=langs): - return serviceName.text - - for organizationDisplayName in filter_lang(entity.iter("{%s}OrganizationDisplayName" % NS['md']), langs=langs): - return organizationDisplayName.text - - for organizationName in filter_lang(entity.iter("{%s}OrganizationName" % NS['md']), langs=langs): - return organizationName.text - - return entity.get('entityID') - - def sub_domains(self, e): - lst = [] - domains = self.domains(e) - for d in domains: - for sub in subdomains(d): - if not sub in lst: - lst.append(sub) - return lst - - def scopes(self, e): - elt = e.findall(".//{%s}IDPSSODescriptor/{%s}Extensions/{%s}Scope" % (NS['md'], NS['md'], NS['shibmd'])) - if elt is None or len(elt) == 0: - return None - return [s.text for s in elt] - - def discojson(self, e, langs=None): - if e is None: - return dict() - - title, descr = self.ext_display(e) - entity_id = e.get('entityID') - - d = dict(title=title, - descr=descr, - auth='saml', - entityID=entity_id) - - eattr = entity_attribute_dict(e) - if 'idp' in eattr[ATTRS['role']]: - d['type'] = 'idp' - d['hidden'] = 'true' - if 'http://pyff.io/category/discoverable' in eattr[ATTRS['entity-category']]: - d['hidden'] = 'false' - elif 'sp' in eattr[ATTRS['role']]: - d['type'] = 'sp' - - icon_info = self.icon(e) - if icon_info is not None: - d['icon'] = icon_info.get('url', 'data:image/gif;base64,R0lGODlhAQABAIABAP///wAAACH5BAEKAAEALAAAAAABAAEAAAICTAEAOw==') - d['icon_height'] = icon_info.get('height', 64) - d['icon_width'] = icon_info.get('width', 64) - - scopes = self.scopes(e) - if scopes is not None and len(scopes) > 0: - d['scope'] = ",".join(scopes) - - keywords = filter_lang(e.iter("{%s}Keywords" % NS['mdui']), langs=langs) - if keywords is not None: - lst = [elt.text for elt in keywords] - if len(lst) > 0: - d['keywords'] = ",".join(lst) - psu = self.psu(e, langs) - if psu: - d['psu'] = psu - geo = self.geoloc(e) - if geo: - d['geo'] = geo - - return d - - def simple_summary(self, e): - if e is None: - return dict() - - title, descr = self.ext_display(e) - entity_id = e.get('entityID') - d = dict(title=title, - descr=descr, - entityID=entity_id, - domains=";".join(self.sub_domains(e)), - id=hash_id(e, 'sha1')) - icon_info = self.icon(e) - if icon_info is not None: - url = icon_info.get('url', 'data:image/gif;base64,R0lGODlhAQABAIABAP///wAAACH5BAEKAAEALAAAAAABAAEAAAICTAEAOw==') - d['icon_url'] = url - d['icon'] = url - - psu = self.psu(e, None) - if psu: - d['psu'] = psu - - return d + def reload(self): + self.rm.reload() + store = self.store_class() + for r in self.rm.walk(): + if r.t: + store.update(r.t, tid=r.name) + self.store = store def search(self, query=None, path=None, page=None, page_limit=10, entity_filter=None, related=None): """ @@ -280,9 +69,9 @@ def search(self, query=None, path=None, page=None, page_limit=10, entity_filter= The dict in the list contains three items: -:param title: A displayable string, useful as a UI label -:param value: The entityID of the EntityDescriptor -:param id: A sha1-ID of the entityID - on the form {sha1} +:title: A displayable string, useful as a UI label +:value: The entityID of the EntityDescriptor +:id: A sha1-ID of the entityID - on the form {sha1} """ match_query = bool(len(query) > 0) @@ -340,21 +129,19 @@ def _match(qq, elt): res = [] for e in 
self.lookup(mexpr): d = None - #log.debug("query: %s" % query) if match_query: m = _match(query, e) if m is not None: - d = self.simple_summary(e) + d = entity_simple_summary(e) ll = d['title'].lower() if m != ll and not query[0] in ll: d['title'] = "%s - %s" % (d['title'], m) else: - - d = self.simple_summary(e) + d = entity_simple_summary(e) if d is not None: if related is not None: - d['ddist'] = avg_domain_distance(related, d['domains']) + d['ddist'] = avg_domain_distance(related, d['entity_domains']) else: d['ddist'] = 0 @@ -379,423 +166,6 @@ def sane(self): """ return len(self.store.collections()) > 0 - def extensions(self, e): - """Return a list of the Extensions elements in the EntityDescriptor - -:param e: an EntityDescriptor -:return: a list - """ - ext = e.find("./{%s}Extensions" % NS['md']) - if ext is None: - ext = etree.Element("{%s}Extensions" % NS['md']) - e.insert(0, ext) - return ext - - def annotate(self, e, category, title, message, source=None): - """Add an ATOM annotation to an EntityDescriptor or an EntitiesDescriptor. This is a simple way to - add non-normative text annotations to metadata, eg for the purpuse of generating reports. - -:param e: An EntityDescriptor or an EntitiesDescriptor element -:param category: The ATOM category -:param title: The ATOM title -:param message: The ATOM content -:param source: An optional source URL. It is added as a element with @rel='saml-metadata-source' - """ - if e.tag != "{%s}EntityDescriptor" % NS['md'] and e.tag != "{%s}EntitiesDescriptor" % NS['md']: - raise MetadataException("I can only annotate EntityDescriptor or EntitiesDescriptor elements") - subject = e.get('Name', e.get('entityID', None)) - atom = ElementMaker(nsmap={'atom': 'http://www.w3.org/2005/Atom'}, namespace='http://www.w3.org/2005/Atom') - args = [atom.published("%s" % datetime.now().isoformat()), - atom.link(href=subject, rel="saml-metadata-subject")] - if source is not None: - args.append(atom.link(href=source, rel="saml-metadata-source")) - args.extend([atom.title(title), - atom.category(term=category), - atom.content(message, type="text/plain")]) - self.extensions(e).append(atom.entry(*args)) - self.store.update(e) - - def _entity_attributes(self, e): - ext = self.extensions(e) - ea = ext.find(".//{%s}EntityAttributes" % NS['mdattr']) - if ea is None: - ea = etree.Element("{%s}EntityAttributes" % NS['mdattr']) - ext.append(ea) - return ea - - def _eattribute(self, e, attr, nf): - ea = self._entity_attributes(e) - a = ea.xpath(".//saml:Attribute[@NameFormat='%s' and @Name='%s']" % (nf, attr), - namespaces=NS, - smart_strings=False) - if a is None or len(a) == 0: - a = etree.Element("{%s}Attribute" % NS['saml']) - a.set('NameFormat', nf) - a.set('Name', attr) - ea.append(a) - else: - a = a[0] - return a - - def set_entity_attributes(self, e, d, nf=NF_URI): - - """Set an entity attribute on an EntityDescriptor - -:param e: The EntityDescriptor element -:param d: A dict of attribute-value pairs that should be added as entity attributes -:param nf: The nameFormat (by default "urn:oasis:names:tc:SAML:2.0:attrname-format:uri") to use. 
-:raise: MetadataException unless e is an EntityDescriptor element - """ - if e.tag != "{%s}EntityDescriptor" % NS['md']: - raise MetadataException("I can only add EntityAttribute(s) to EntityDescriptor elements") - - for attr, value in d.iteritems(): - a = self._eattribute(e, attr, nf) - velt = etree.Element("{%s}AttributeValue" % NS['saml']) - velt.text = value - a.append(velt) - - self.store.update(e) - - def set_pubinfo(self, e, publisher=None, creation_instant=None): - if e.tag != "{%s}EntitiesDescriptor" % NS['md']: - raise MetadataException("I can only set RegistrationAuthority to EntitiesDescriptor elements") - if publisher is None: - raise MetadataException("At least publisher must be provided") - - if creation_instant is None: - now = datetime.utcnow() - creation_instant = now.strftime("%Y-%m-%dT%H:%M:%SZ") - - ext = self.extensions(e) - pi = ext.find(".//{%s}PublicationInfo" % NS['mdrpi']) - if pi is not None: - raise MetadataException("A PublicationInfo element is already present") - pi = etree.Element("{%s}PublicationInfo" % NS['mdrpi']) - pi.set('publisher', publisher) - if creation_instant: - pi.set('creationInstant', creation_instant) - ext.append(pi) - - def set_reginfo(self, e, policy=None, authority=None): - if e.tag != "{%s}EntityDescriptor" % NS['md']: - raise MetadataException("I can only set RegistrationAuthority to EntityDescriptor elements") - if authority is None: - raise MetadataException("At least authority must be provided") - if policy is None: - policy = dict() - - ext = self.extensions(e) - ri = ext.find(".//{%s}RegistrationInfo" % NS['mdrpi']) - if ri is not None: - raise MetadataException("A RegistrationInfo element is already present") - - ri = etree.Element("{%s}RegistrationInfo" % NS['mdrpi']) - ext.append(ri) - ri.set('registrationAuthority', authority) - for lang, policy_url in policy.iteritems(): - rp = etree.Element("{%s}RegistrationPolicy" % NS['mdrpi']) - rp.text = policy_url - rp.set('{%s}lang' % NS['xml'], lang) - ri.append(rp) - - def expiration(self, t): - relt = root(t) - if relt.tag in ('{%s}EntityDescriptor' % NS['md'], '{%s}EntitiesDescriptor' % NS['md']): - cache_duration = self.default_cache_duration - valid_until = relt.get('validUntil', None) - if valid_until is not None: - now = datetime.utcnow() - vu = iso2datetime(valid_until) - now = now.replace(microsecond=0) - vu = vu.replace(microsecond=0, tzinfo=None) - return vu - now - elif self.respect_cache_duration: - cache_duration = relt.get('cacheDuration', self.default_cache_duration) - return duration2timedelta(cache_duration) - - return None - - def fetch_metadata(self, resources, max_workers=5, timeout=120, max_tries=5, validate=False, fail_on_error=False, filter_invalid=True): - """Fetch a series of metadata URLs and optionally verify signatures. - -:param resources: A list of triples (url,cert-or-fingerprint,id, post-callback) -:param max_workers: The maximum number of parallell downloads to run -:param validate: Turn on or off schema validation - -The list of triples is processed by first downloading the URL. If a cert-or-fingerprint -is supplied it is used to validate the signature on the received XML. Two forms of XML -is supported: SAML Metadata and XRD. - -SAML metadata is (if valid and contains a valid signature) stored under the 'id' -identifier (which defaults to the URL unless provided in the triple. 
- -XRD elements are processed thus: for all elements that contain a ds;KeyInfo -elements with a X509Certificate and where the element contains the string -'urn:oasis:names:tc:SAML:2.0:metadata', the corresponding element is download -and verified. - """ - resources = [(url, verifier, tid, post, True) for url, verifier, tid, post in resources] - return self._fetch_metadata(resources, - max_workers=max_workers, - timeout=timeout, - max_tries=max_tries, - validate=validate, - fail_on_error=fail_on_error, - filter_invalid=filter_invalid) - - def _fetch_metadata(self, resources, max_workers=5, timeout=120, max_tries=5, validate=False, fail_on_error=False, filter_invalid=True): - tries = dict() - - def _process_url(rurl, verifier, tid, post, enable_cache=True): - tries.setdefault(rurl, 0) - - try: - resource = load_url(rurl, timeout=timeout, enable_cache=enable_cache) - except Exception as ex: - raise MetadataException(ex, "Exception fetching '%s': %s" % (rurl, str(ex)) ) - if not resource.result: - raise MetadataException("error fetching '%s'" % rurl) - xml = resource.result.strip() - retry_resources = [] - info = { - 'Time Spent': "%s seconds" % resource.time - } - - tries[rurl] += 1 - info['Tries'] = str(tries[rurl]) - - if resource.result is not None: - info['Bytes'] = str(len(resource.result)) - else: - raise MetadataException("empty response fetching '%s'" % resource.url) - - info['URL'] = str(rurl) - info['Cached'] = str(resource.cached) - info['Date'] = str(resource.date) - info['Last-Modified'] = str(resource.last_modified) - info['Validation Errors'] = dict() - info['Description'] = "Remote metadata" - info['Status'] = 'success' - if not validate: - info['Status'] = 'warning' - info['Description'] += " (un-validated)" - if not enable_cache: - info['Status'] = 'info' - - if resource.resp is not None: - info['HTTP Response'] = resource.resp - - t, offset = self.parse_metadata(StringIO(xml), - key=verifier, - base_url=rurl, - fail_on_error=fail_on_error, - filter_invalid=filter_invalid, - validate=validate, - validation_errors=info['Validation Errors'], - expiration=self.expiration, - post=post) - - if t is None: - self.fire(type=EVENT_IMPORT_FAIL, url=rurl) - raise MetadataException("no valid metadata found at '%s'" % rurl) - - relt = root(t) - - expired = False - if offset is not None: - expire_time = datetime.now() + offset - ttl = offset.total_seconds() - info['Expiration Time'] = str(expire_time) - info['Cache TTL'] = str(ttl) - if ttl < self.min_cache_ttl: - if tries[rurl] < max_tries: # try to get fresh md but we'll use what we have anyway - retry_resources.append((rurl, verifier, tid, post, False)) - else: - log.error("giving up on %s" % rurl) - if ttl < 0: - expired = True - - if not expired: - if relt.tag in ('{%s}XRD' % NS['xrd'], '{%s}XRDS' % NS['xrd']): - log.debug("%s looks like an xrd document" % rurl) - for xrd in t.iter("{%s}XRD" % NS['xrd']): - for link in xrd.findall(".//{%s}Link[@rel='%s']" % (NS['xrd'], NS['md'])): - link_href = link.get("href") - certs = xmlsec.crypto.CertDict(link) - fingerprints = certs.keys() - fp = None - if len(fingerprints) > 0: - fp = fingerprints[0] - log.debug("XRD: '%s' verified by '%s'" % (link_href, link)) - tries.setdefault(link_href, 0) - if tries[link_href] < max_tries: - retry_resources.append((link_href, fp, link_href, post, True)) - elif relt.tag in ('{%s}EntityDescriptor' % NS['md'], '{%s}EntitiesDescriptor' % NS['md']): - n = self.store.update(t, tid) - info['Size'] = str(n) - else: - raise MetadataException("unknown metadata type 
for '%s' (%s)" % (rurl, relt.tag)) - - set_metadata_info(tid, info) - log.debug(info) - - return retry_resources - - while resources: - with futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_url = dict((executor.submit(_process_url, url, verifier, tid, post, enable_cache), url) - for url, verifier, tid, post, enable_cache in resources) - - next_resources = [] - for future in futures.as_completed(future_to_url): - url = future_to_url[future] - if future.exception() is not None: - if fail_on_error: - log.error('fetching %r generated an exception' % url) - raise future.exception() - else: - log.error('fetching %s generated an exception: %s' % (url, future.exception())) - else: - next_resources.extend(future.result()) - resources = next_resources - log.debug("retrying %s" % resources) - - def filter_invalids(self, t, base_url, validation_errors): - xsd = schema() - for e in iter_entities(t): - if not xsd.validate(e): - error = xml_error(xsd.error_log, m=base_url) - entity_id = e.get("entityID") - log.warn("removing '%s': schema validation failed (%s)" % (entity_id, error)) - validation_errors[entity_id] = error - if e.getparent() is None: - return None - e.getparent().remove(e) - self.fire(type=EVENT_DROP_ENTITY, url=base_url, entityID=entity_id, error=error) - return t - - def parse_metadata(self, - source, - key=None, - base_url=None, - fail_on_error=False, - filter_invalid=True, - validate=True, - validation_errors=None, - expiration=None, - post=None): - """Parse a piece of XML and return an EntitiesDescriptor element after validation. - -:param source: a file-like object containing SAML metadata -:param key: a certificate (file) or a SHA1 fingerprint to use for signature verification -:param base_url: use this base url to resolve relative URLs for XInclude processing -:param fail_on_error: (default: False) -:param filter_invalid: (default True) remove invalid EntityDescriptor elements rather than raise an errror -:param validate: (default: True) set to False to turn off all XML schema validation -:param post: A callable that will be called to modify the parse-tree after schema validation -:param validation_errors: A dict that will be used to return validation errors to the caller -:param expiration: A callable that returns the valid_until datetime of the prsed metadata -(but after xinclude processing and signature validation) - """ - - if validation_errors is None: - validation_errors = dict() - - try: - parser = etree.XMLParser(resolve_entities=False) - t = etree.parse(source, base_url=base_url, parser=parser) - t.xinclude() - - valid_until = None - if expiration is not None: - valid_until = expiration(t) - - t = check_signature(t, key) - - # get rid of ID as early as possible - probably not unique - for e in iter_entities(t): - if e.get('ID') is not None: - del e.attrib['ID'] - - t = root(t) - - if validate: - if filter_invalid: - t = self.filter_invalids(t, base_url=base_url, validation_errors=validation_errors) - else: # all or nothing - try: - validate_document(t) - except DocumentInvalid as ex: - raise MetadataException("schema validation failed: [%s] '%s': %s" % - (base_url, source, xml_error(ex.error_log, m=base_url))) - - if t is not None: - if t.tag == "{%s}EntityDescriptor" % NS['md']: - t = self.entity_set([t], base_url, copy=False, nsmap=t.nsmap) - - if post is not None: - t = post(t) - - except Exception as ex: - if fail_on_error: - raise ex - traceback.print_exc(ex) - log.error(ex) - return None, None - - log.debug("returning %d valid entities" % 
len(list(iter_entities(t)))) - - return t, valid_until - - def load_dir(self, directory, ext=".xml", url=None, validate=False, post=None, description=None, fail_on_error=True, filter_invalid=True): - """ -:param directory: A directory to walk. -:param ext: Include files with this extension (default .xml) - -Traverse a directory tree looking for metadata. Files ending in the specified extension are included. Directories -starting with '.' are excluded. - """ - if url is None: - url = directory - - if description is None: - description = "All entities found in %s" % directory - - entities = [] - for top, dirs, files in os.walk(directory): - for dn in dirs: - if dn.startswith("."): - dirs.remove(dn) - for nm in files: - if nm.endswith(ext): - log.debug("parsing from file %s" % nm) - fn = os.path.join(top, nm) - try: - validation_errors = dict() - t, valid_until = self.parse_metadata(fn, - base_url=url, - fail_on_error=fail_on_error, - filter_invalid=filter_invalid, - validate=validate, - validation_errors=validation_errors, - post=post) - entities.extend(entities_list(t)) # local metadata is assumed to be ok - for (eid, error) in validation_errors.iteritems(): - log.error(error) - except Exception as ex: - if fail_on_error: - raise MetadataException('Error parsing "%s": %s' % (fn, str(ex))) - log.error(ex) - - if entities: - info = dict(Description=description) - n = self.store.update(self.entity_set(entities, url, validate=validate, copy=False), url) - info['Size'] = str(n) - set_metadata_info(url, info) - else: - log.info("no entities found in %s" % directory) - def find(self, t, member): relt = root(t) if type(member) is str or type(member) is unicode: @@ -870,60 +240,30 @@ def lookup(self, member, xp=None): log.debug("got %d entities after filtering" % len(l)) return l - def entity_set(self, entities, name, lookup_fn=None, cacheDuration=None, validUntil=None, validate=True, copy=True, nsmap=dict()): + def entity_set(self, + entities, + name, + lookup_fn=None, + cache_duration=None, + valid_until=None, + validate=True, + copy=True): """ :param entities: a set of entities specifiers (lookup is used to find entities from this set) :param name: the @Name attribute -:param cacheDuration: an XML timedelta expression, eg PT1H for 1hr -:param validUntil: a relative time eg 2w 4d 1h for 2 weeks, 4 days and 1hour from now. +:param cache_duration: an XML timedelta expression, eg PT1H for 1hr +:param valid_until: a relative time eg 2w 4d 1h for 2 weeks, 4 days and 1hour from now. +:param lookup_fn: a callable used to lookup entities by entityID Produce an EntityDescriptors set from a list of entities. Optional Name, cacheDuration and validUntil are affixed. 
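entity_set now forwards to samlmd.entitiesdescriptor with the renamed snake_case keywords; a call sketch, where the feed name and entityID are examples:

```python
# md is an MDRepository; cacheDuration/validUntil became snake_case.
feed = md.entity_set(["https://idp.example.org/idp/shibboleth"],
                     name="http://md.example.com/feed",
                     cache_duration="PT1H",
                     validate=True)
```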
""" - if lookup_fn is None: - lookup_fn = self.lookup - - def _resolve(member, l_fn): - if hasattr(member, 'tag'): - return [member] - else: - return l_fn(member) - - nsmap.update(NS) - resolved_entities = set() - for member in entities: - for entity in _resolve(member, lookup_fn): - resolved_entities.add(entity) - - if not resolved_entities: - return None - - for entity in resolved_entities: - nsmap.update(entity.nsmap) - - log.debug("selecting %d entities before validation" % len(resolved_entities)) - - attrs = dict(Name=name, nsmap=nsmap) - if cacheDuration is not None: - attrs['cacheDuration'] = cacheDuration - if validUntil is not None: - attrs['validUntil'] = validUntil - t = etree.Element("{%s}EntitiesDescriptor" % NS['md'], **attrs) - for entity in resolved_entities: - entity_id = entity.get('entityID', None) - if (entity is not None) and (entity_id is not None): - ent_insert = entity - if copy: - ent_insert = deepcopy(ent_insert) - t.append(ent_insert) - - if validate: - try: - validate_document(t) - except DocumentInvalid as ex: - log.debug(xml_error(ex.error_log)) - raise MetadataException("XML schema validation failed: %s" % name) - return t + return entitiesdescriptor(entities, name, + lookup_fn=self.lookup, + cache_duration=cache_duration, + valid_until=valid_until, + validate=validate, + copy=copy) def summary(self, uri): @@ -978,8 +318,8 @@ def merge(self, t, nt, strategy=merge_strategies.replace_existing, strategy_name first the strategy callable is called with the old and new EntityDescriptor elements as parameters. The strategy callable thus must implement the following pattern: -:param old_e: The EntityDescriptor from t -:param e: The EntityDescriptor from nt +:old_e: The EntityDescriptor from t +:e: The EntityDescriptor from nt :return: A merged EntityDescriptor element Before each call to strategy old_e is removed from the MDRepository index and after diff --git a/src/pyff/mdx.py b/src/pyff/mdx.py index f93603ef..f4a5a8c9 100644 --- a/src/pyff/mdx.py +++ b/src/pyff/mdx.py @@ -23,8 +23,6 @@ Listen on the specified port -H|--host= Listen on the specified interface - -R - Use redis-based store --frequency= Wake up every and run the update pipeline. By default the frequency is set to 600. @@ -44,6 +42,9 @@ One or more pipeline files """ + +from __future__ import absolute_import, print_function, unicode_literals + import importlib import pkg_resources @@ -67,29 +68,31 @@ from cherrypy.process.plugins import Monitor, SimplePlugin from cherrypy.lib import caching from simplejson import dumps -from pyff.constants import ATTRS, EVENT_REPOSITORY_LIVE, config -from pyff.locks import ReadWriteLock -from pyff.mdrepo import MDRepository -from pyff.pipes import plumbing -from pyff.utils import resource_string, xslt_transform, dumptree, duration2timedelta, debug_observer, render_template -from pyff.logs import log, SysLogLibHandler +from .constants import config +from .locks import ReadWriteLock +from .mdrepo import MDRepository +from .pipes import plumbing +from .utils import resource_string, xslt_transform, dumptree, duration2timedelta, \ + debug_observer, render_template, hash_id +from .logs import log, SysLogLibHandler +from .samlmd import entity_simple_summary, entity_display_name import logging -from pyff.stats import stats +from .stats import stats from lxml import html from datetime import datetime from lxml import etree -from pyff import __version__ as pyff_version -from pyff.store import MemoryStore, RedisStore +from . 
import __version__ as pyff_version from publicsuffix import PublicSuffixList -import i18n +from .i18n import language +from . import samlmd -_ = i18n.language.ugettext +_ = language.ugettext site_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "site") class MDUpdate(Monitor): - def __init__(self, bus, frequency=600, server=None): + def __init__(self, bus, frequency=int(config.update_frequency), server=None): self.lock = Lock() self.server = server self.bus = bus @@ -102,26 +105,11 @@ def run(self, server): try: self.lock.acquire() locked = True - md = self.server.md.clone() for p in server.plumbings: - state = {'update': True, 'stats': {}} - p.process(md, state) - stats.update(state.get('stats', {})) - - with server.lock.writelock: - log.debug("update produced new repository with %d entities" % server.md.store.size()) - server.md = md - server.md.fire(type=EVENT_REPOSITORY_LIVE, size=server.md.store.size()) - stats['Repository Update Time'] = datetime.now() - stats['Repository Size'] = server.md.store.size() + state = {'update': True} + p.process(self.server.md, state) - self.nruns += 1 - - stats['Updates Since Server Start'] = self.nruns - - if hasattr(self.server.md.store, 'periodic'): - self.server.md.store.periodic(stats) except Exception as ex: log.error(ex.message) finally: @@ -244,12 +232,12 @@ def webfinger(self, resource=None, rel=None): links = list() jrd['links'] = links - def _links(a): + def _links(url): links.append( dict(rel='urn:oasis:names:tc:SAML:2.0:metadata', role="provider", - href='%s/%s.xml' % (cherrypy.request.base, a))) - links.append(dict(rel='disco-json', href='%s/%s.json' % (cherrypy.request.base, a))) + href='%s/%s.xml' % (cherrypy.request.base, url))) + links.append(dict(rel='disco-json', href='%s/%s.json' % (cherrypy.request.base, url))) for a in self.server.md.store.collections(): if '://' not in a: @@ -476,7 +464,7 @@ def __init__(self, pipes=None, observers=None): self.refresh.subscribe() self.aliases = config.aliases self.psl = PublicSuffixList() - self.md = MDRepository(metadata_cache_enabled=config.caching_enabled, store=config.store) + self.md = MDRepository() if config.autoreload: for f in pipes: @@ -484,7 +472,7 @@ def __init__(self, pipes=None, observers=None): @property def ready(self): - return self.md.store.ready() + return self.md.store is not None def reload_pipeline(self): new_plumbings = [plumbing(v) for v in self._pipes] @@ -582,7 +570,7 @@ def _d(x, do_split=True): entity_id = kwargs.get('entityID', None) if entity_id is None: raise HTTPError(400, _("400 Bad Request - missing entityID")) - pdict['sp'] = self.md.sha1_id(entity_id) + pdict['sp'] = hash_id(entity_id, 'sha1') e = self.md.store.lookup(entity_id) if e is None or len(e) == 0: raise HTTPError(404) @@ -590,7 +578,7 @@ def _d(x, do_split=True): if len(e) > 1: raise HTTPError(400, _("400 Bad Request - multiple matches for") + " %s" % entity_id) - pdict['entity'] = self.md.simple_summary(e[0]) + pdict['entity'] = entity_simple_summary(e[0]) if not path: pdict['search'] = "/search/" pdict['list'] = "/role/idp.json" @@ -612,6 +600,7 @@ def _d(x, do_split=True): if query is None: log.debug("empty query - creating one") query = [cherrypy.request.remote.ip] + # XXX fix this - urlparse is not 3.x and also this way to handle extra info sucks referrer = cherrypy.request.headers.get('referrer', None) if referrer is not None: log.debug("including referrer: %s" % referrer) @@ -646,6 +635,7 @@ def _d(x, do_split=True): title = _("Metadata By Attributes") return 
render_template("index.html", md=self.md, + samlmd=samlmd, alias=alias, aliases=self.aliases, title=title) @@ -656,6 +646,7 @@ def _d(x, do_split=True): if len(entities) > 1: return render_template("metadata.html", md=self.md, + samlmd=samlmd, subheading=q, entities=entities) else: @@ -675,7 +666,7 @@ def _d(x, do_split=True): p.text = c_txt xml = dumptree(t, xml_declaration=False).decode('utf-8') return render_template("entity.html", - headline=self.md.display(entity).strip(), + headline=entity_display_name(entity), subheading=entity.get('entityID'), entity_id=entity.get('entityID'), content=xml) @@ -714,9 +705,6 @@ def main(): print(__doc__) sys.exit(2) - if config.store is None: - config.store = MemoryStore() - if config.loglevel is None: config.loglevel = logging.INFO @@ -748,8 +736,6 @@ def main(): config.port = int(a) elif o in ('--pidfile', '-p'): config.pid_file = a - elif o in '-R': - config.store = RedisStore() elif o in ('--no-caching', '-C'): config.caching_enabled = False elif o in ('--caching-delay', 'D'): @@ -774,7 +760,7 @@ def main(): elif o in ('-m', '--module'): config.modules.append(a) elif o in '--version': - print("pyffd version {} (cherrypy version {})".format(pyff_version, cherrypy.__version__)) + print("pyffd version %s (cherrypy version %s)" % (pyff_version, cherrypy.__version__)) sys.exit(0) else: raise ValueError("Unknown option '%s'" % o) @@ -836,6 +822,7 @@ def error_page(code, **kwargs): 'tools.caching.antistampede_timeout': 30, 'tools.caching.delay': 3600, # this is how long we keep static stuff 'tools.cpstats.on': True, + 'checker.on': False, 'tools.proxy.on': config.proxy, 'allow_shutdown': config.allow_shutdown, 'error_page.404': lambda **kwargs: error_page(404, _=_, **kwargs), diff --git a/src/pyff/merge_strategies.py b/src/pyff/merge_strategies.py index e63f6cb0..104bacce 100644 --- a/src/pyff/merge_strategies.py +++ b/src/pyff/merge_strategies.py @@ -13,4 +13,4 @@ def replace_existing(old, new): def remove(old, new): if old is not None: old.getparent().remove(old) - return None \ No newline at end of file + return None diff --git a/src/pyff/parse.py b/src/pyff/parse.py new file mode 100644 index 00000000..ec94fcea --- /dev/null +++ b/src/pyff/parse.py @@ -0,0 +1,59 @@ + +import os + +__author__ = 'leifj' + +class ParserException(Exception): + def __init__(self, msg, wrapped=None, data=None): + self._wraped = wrapped + self._data = data + super(self.__class__, self).__init__(msg) + + def raise_wraped(self): + raise self._wraped + + +class NoParser(): + def magic(self, content): + return True + + def parse(self, resource, content): + raise ParserException("No matching parser found for %s" % resource.url) + + +class DirectoryParser(): + + def __init__(self, ext): + self.ext = ext + + def magic(self, content): + return os.path.isdir(content) + + def _find_matching_files(self, dir): + for top, dirs, files in os.walk(dir): + for dn in dirs: + if dn.startswith("."): + dirs.remove(dn) + + for nm in files: + if nm.endswith(self.ext): + fn = os.path.join(top, nm) + yield fn + + def parse(self, resource, content): + resource.children = [] + for fn in self._find_matching_files(dir): + resource.add_child("file://"+fn) + + return dict() + + +_parsers = [DirectoryParser('.xml'), NoParser()] + +def add_parser(parser): + _parsers.insert(0,parser) + +def parse_resource(resource, content): + for parser in _parsers: + if parser.magic(content): + return parser.parse(resource, content) \ No newline at end of file diff --git a/src/pyff/pipes.py b/src/pyff/pipes.py index 
b76c3853..e6ef8a8d 100644 --- a/src/pyff/pipes.py +++ b/src/pyff/pipes.py @@ -112,7 +112,7 @@ def _n(_d): class PipelineCallback(object): """ -A delayed pipeline callback used as a post for parse_metadata +A delayed pipeline callback used as a post for parse_saml_metadata """ def __init__(self, entry_point, req): @@ -234,16 +234,16 @@ def process(self, md, state=None, t=None): if not state: state = dict() # req = Plumbing.Request(self, md, t, state=state) - # self._process(req) + # self.iprocess(req) # return req.t return Plumbing.Request(self, md, t, state=state).process(self) - def _process(self, req): + def iprocess(self, req): """The inner request pipeline processor. :param req: The request to run through the pipeline """ - log.debug('Processing \n%s' % self) + #log.debug('Processing \n%s' % self) for p in self.pipeline: try: pipe, opts, name, args = load_pipe(p) diff --git a/src/pyff/samlmd.py b/src/pyff/samlmd.py new file mode 100644 index 00000000..f02adbc4 --- /dev/null +++ b/src/pyff/samlmd.py @@ -0,0 +1,718 @@ +from __future__ import absolute_import, unicode_literals +from datetime import datetime +from .utils import parse_xml, check_signature, root, validate_document, xml_error, \ + schema, iso2datetime, duration2timedelta, filter_lang, url2host, trunc_str, subdomains, \ + has_tag, hash_id + +from .logs import log +from .constants import config, NS, ATTRS, NF_URI, PLACEHOLDER_ICON +from lxml import etree +from lxml.builder import ElementMaker +from lxml.etree import DocumentInvalid +from itertools import chain +from copy import deepcopy +from .exceptions import * +from StringIO import StringIO + + +class EntitySet(object): + def __init__(self, initial=None): + self._e = dict() + if initial is not None: + for e in initial: + self.add(e) + + def add(self, value): + self._e[value.get('entityID')] = value + + def discard(self, value): + entity_id = value.get('entityID') + if entity_id in self._e: + del self._e[entity_id] + + def __iter__(self): + for e in self._e.values(): + yield e + + def __len__(self): + return len(self._e.keys()) + + def __contains__(self, item): + return item.get('entityID') in self._e.keys() + + +def find_merge_strategy(strategy_name): + if '.' not in strategy_name: + strategy_name = "pyff.merge_strategies.%s" % strategy_name + (mn, sep, fn) = strategy_name.rpartition('.') + # log.debug("import %s from %s" % (fn,mn)) + module = None + if '.' in mn: + (pn, sep, modn) = mn.rpartition('.') + module = getattr(__import__(pn, globals(), locals(), [modn], -1), modn) + else: + module = __import__(mn, globals(), locals(), [], -1) + strategy = getattr(module, fn) # we might aswell let this fail early if the strategy is wrongly named + + if strategy is None: + raise MetadataException("Unable to find merge strategy %s" % strategy_name) + + return strategy + + +def parse_saml_metadata(source, + key=None, + base_url=None, + fail_on_error=False, + filter_invalid=True, + validate=True, + validation_errors=None): + """Parse a piece of XML and return an EntitiesDescriptor element after validation. 
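A usage sketch for the parser extracted into samlmd (the file name is illustrative; the second return value is a timedelta offset, or None when no expiration applies):

```python
# Parse a local file, collect per-entity validation errors, and turn
# the returned offset into an absolute expiry time.
from datetime import datetime
from pyff.samlmd import parse_saml_metadata

errors = dict()
t, expire_offset = parse_saml_metadata("metadata.xml",
                                       validate=True,
                                       filter_invalid=True,
                                       validation_errors=errors)
if expire_offset is not None:
    print("expires at", datetime.now() + expire_offset)
for entity_id, err in errors.items():
    print(entity_id, err)
```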
+
+:param source: a file-like object containing SAML metadata
+:param key: a certificate (file) or a SHA1 fingerprint to use for signature verification
+:param base_url: use this base url to resolve relative URLs for XInclude processing
+:param fail_on_error: (default: False) if True, parse and validation errors are raised instead of being logged
+:param filter_invalid: (default True) remove invalid EntityDescriptor elements rather than raising an error
+:param validate: (default: True) set to False to turn off all XML schema validation
+:param validation_errors: A dict that will be used to return validation errors to the caller
+(but after xinclude processing and signature validation)
+    """
+
+    if validation_errors is None:
+        validation_errors = dict()
+
+    try:
+        t = parse_xml(source, base_url=base_url)
+        t.xinclude()
+
+        expire_time_offset = metadata_expiration(t)
+
+        t = check_signature(t, key)
+
+        # get rid of ID as early as possible - probably not unique
+        for e in iter_entities(t):
+            if e.get('ID') is not None:
+                del e.attrib['ID']
+
+        t = root(t)
+
+        if validate:
+            if filter_invalid:
+                t = filter_invalids_from_document(t, base_url=base_url, validation_errors=validation_errors)
+            else:  # all or nothing
+                try:
+                    validate_document(t)
+                except DocumentInvalid as ex:
+                    raise MetadataException("schema validation failed: [%s] '%s': %s" %
+                                            (base_url, source, xml_error(ex.error_log, m=base_url)))
+
+        if t is not None:
+            if t.tag == "{%s}EntityDescriptor" % NS['md']:
+                t = entitiesdescriptor([t], base_url, copy=False)
+
+    except Exception as ex:
+        if fail_on_error:
+            raise ex
+        #traceback.print_exc(ex)
+        log.error(ex)
+        return None, None
+
+    log.debug("returning %d valid entities" % len(list(iter_entities(t))))
+
+    return t, expire_time_offset
+
+
+class SAMLMetadataResourceParser():
+
+    def __init__(self):
+        pass
+
+    def magic(self, content):
+        return "EntitiesDescriptor" in content or "EntityDescriptor" in content
+
+    def parse(self, resource, content):
+        info = dict()
+        info['Validation Errors'] = dict()
+        t, expire_time_offset = parse_saml_metadata(StringIO(content.encode('utf8')),
+                                                    key=resource.opts['verify'],
+                                                    base_url=resource.url,
+                                                    fail_on_error=resource.opts['fail_on_error'],
+                                                    filter_invalid=resource.opts['filter_invalid'],
+                                                    validate=resource.opts['validate'],
+                                                    validation_errors=info['Validation Errors'])
+
+        if expire_time_offset is not None:
+            expire_time = datetime.now() + expire_time_offset
+            resource.expire_time = expire_time
+            info['Expiration Time'] = str(expire_time)
+
+        resource.t = t
+        resource.type = "application/samlmetadata+xml"
+
+        return info
+
+
+from .parse import add_parser
+add_parser(SAMLMetadataResourceParser())
+
+
+def metadata_expiration(t):
+    relt = root(t)
+    if relt.tag in ('{%s}EntityDescriptor' % NS['md'], '{%s}EntitiesDescriptor' % NS['md']):
+        cache_duration = config.default_cache_duration
+        valid_until = relt.get('validUntil', None)
+        if valid_until is not None:
+            now = datetime.utcnow()
+            vu = iso2datetime(valid_until)
+            now = now.replace(microsecond=0)
+            vu = vu.replace(microsecond=0, tzinfo=None)
+            return vu - now
+        elif config.respect_cache_duration:
+            cache_duration = relt.get('cacheDuration', config.default_cache_duration)
+            if not cache_duration:
+                cache_duration = config.default_cache_duration
+            return duration2timedelta(cache_duration)
+
+    return None
+
+
+def filter_invalids_from_document(t, base_url, validation_errors):
+    xsd = schema()
+    for e in iter_entities(t):
+        if not xsd.validate(e):
+            error = xml_error(xsd.error_log, m=base_url)
+            entity_id = e.get("entityID")
+            log.warn('removing \'%s\': schema
validation failed (%s)' % (entity_id, error))
+            validation_errors[entity_id] = error
+            if e.getparent() is None:
+                return None
+            e.getparent().remove(e)
+    return t
+
+
+def entitiesdescriptor(entities, name, lookup_fn=None, cache_duration=None, valid_until=None, validate=True, copy=True):
+    """
+:param lookup_fn: a function used to lookup entities by name
+:param entities: a set of entity specifiers (lookup is used to find entities from this set)
+:param name: the @Name attribute
+:param cache_duration: an XML timedelta expression, eg PT1H for 1hr
+:param valid_until: a relative time eg 2w 4d 1h for 2 weeks, 4 days and 1 hour from now.
+:param copy: set to False to avoid making a copy of all the entities in list. This may be dangerous.
+:param validate: set to False to skip schema validation of the resulting EntitiesDescriptor element. This is dangerous!
+
+Produce an EntitiesDescriptor element from a list of entities. Optional Name, cacheDuration and validUntil are affixed.
+    """
+
+    def _insert(ent):
+        entity_id = ent.get('entityID', None)
+        # log.debug("adding %s to set" % entity_id)
+        if (ent is not None) and (entity_id is not None) and (entity_id not in seen):
+            ent_insert = ent
+            if copy:
+                ent_insert = deepcopy(ent_insert)
+            t.append(ent_insert)
+            # log.debug("really adding %s to set" % entity_id)
+            seen[entity_id] = True
+
+    attrs = dict(Name=name, nsmap=NS)
+    if cache_duration is not None:
+        attrs['cacheDuration'] = cache_duration
+    if valid_until is not None:
+        attrs['validUntil'] = valid_until
+    t = etree.Element("{%s}EntitiesDescriptor" % NS['md'], **attrs)
+    nent = 0
+    seen = {}  # TODO make better de-duplication
+    for member in entities:
+        if hasattr(member, 'tag'):
+            _insert(member)
+            nent += 1
+        else:
+            for entity in lookup_fn(member):
+                _insert(entity)
+                nent += 1
+
+    log.debug("selecting %d entities before validation" % nent)
+
+    if not nent:
+        return None
+
+    if validate:
+        try:
+            validate_document(t)
+        except DocumentInvalid as ex:
+            log.debug(xml_error(ex.error_log))
+            raise MetadataException("XML schema validation failed: %s" % name)
+    return t
+
+
+def entities_list(t=None):
+    """
+    :param t: An EntitiesDescriptor or EntityDescriptor element
+
+    Returns the contained EntityDescriptor elements
+    """
+    if t is None:
+        return []
+    elif root(t).tag == "{%s}EntityDescriptor" % NS['md']:
+        return [root(t)]
+    else:
+        return iter_entities(t)
+
+
+def iter_entities(t):
+    if t is None:
+        return []
+    return t.iter('{%s}EntityDescriptor' % NS['md'])
+
+
+def find_entity(t, e_id, attr='entityID'):
+    for e in iter_entities(t):
+        if e.get(attr) == e_id:
+            return e
+    return None
+
+
+# semantics copied from https://github.com/lordal/md-summary/blob/master/md-summary
+# many thanks to Anders Lordahl & Scotty Logan for the idea
+def guess_entity_software(e):
+    for elt in chain(e.findall(".//{%s}SingleSignOnService" % NS['md']),
+                     e.findall(".//{%s}AssertionConsumerService" % NS['md'])):
+        location = elt.get('Location')
+        if location:
+            if 'Shibboleth.sso' in location \
+                    or 'profile/SAML2/POST/SSO' in location \
+                    or 'profile/SAML2/Redirect/SSO' in location \
+                    or 'profile/Shibboleth/SSO' in location:
+                return 'Shibboleth'
+            if location.endswith('saml2/idp/SSOService.php') or 'saml/sp/saml2-acs.php' in location:
+                return 'SimpleSAMLphp'
+            if location.endswith('user/authenticate'):
+                return 'KalturaSSP'
+            if location.endswith('adfs/ls') or location.endswith('adfs/ls/'):
+                return 'ADFS'
+            if '/oala/' in location or 'login.openathens.net' in location:
+                return 'OpenAthens'
+            if
'/idp/SSO.saml2' in location or '/sp/ACS.saml2' in location \
+                    or 'sso.connect.pingidentity.com' in location:
+                return 'PingFederate'
+            if 'idp/saml2/sso' in location:
+                return 'Authentic2'
+            if 'nidp/saml2/sso' in location:
+                return 'Novell Access Manager'
+            if 'affwebservices/public/saml2sso' in location:
+                return 'CASiteMinder'
+            if 'FIM/sps' in location:
+                return 'IBMTivoliFIM'
+            if 'sso/post' in location \
+                    or 'sso/redirect' in location \
+                    or 'saml2/sp/acs' in location \
+                    or 'saml2/ls' in location \
+                    or 'saml2/acs' in location \
+                    or 'acs/redirect' in location \
+                    or 'acs/post' in location \
+                    or 'saml2/sp/ls/' in location:
+                return 'PySAML'
+            if 'engine.surfconext.nl' in location:
+                return 'SURFConext'
+            if 'opensso' in location:
+                return 'OpenSSO'
+            if 'my.salesforce.com' in location:
+                return 'Salesforce'
+
+    entity_id = e.get('entityID')
+    if '/shibboleth' in entity_id:
+        return 'Shibboleth'
+    if entity_id.endswith('/metadata.php'):
+        return 'SimpleSAMLphp'
+    if '/openathens' in entity_id:
+        return 'OpenAthens'
+
+    return 'other'
+
+
+def is_idp(entity):
+    return has_tag(entity, "{%s}IDPSSODescriptor" % NS['md'])
+
+
+def is_sp(entity):
+    return has_tag(entity, "{%s}SPSSODescriptor" % NS['md'])
+
+
+def is_aa(entity):
+    return has_tag(entity, "{%s}AttributeAuthorityDescriptor" % NS['md'])
+
+
+def _domains(entity):
+    domains = [url2host(entity.get('entityID'))]
+    for d in entity.iter("{%s}DomainHint" % NS['mdui']):
+        if d.text not in domains:
+            domains.append(d.text)
+    return domains
+
+
+def with_entity_attributes(entity, cb):
+    def _stext(e):
+        if e.text is not None:
+            return e.text.strip()
+
+    for ea in entity.iter("{%s}EntityAttributes" % NS['mdattr']):
+        for a in ea.iter("{%s}Attribute" % NS['saml']):
+            an = a.get('Name', None)
+            if an is not None:
+                values = filter(lambda x: x is not None, [_stext(v) for v in a.iter("{%s}AttributeValue" % NS['saml'])])
+                cb(an, values)
+
+
+def _all_domains_and_subdomains(entity):
+    dlist = []
+    try:
+        for dn in _domains(entity):
+            for sub in subdomains(dn):
+                dlist.append(sub)
+    except ValueError:
+        pass
+    return dlist
+
+
+def entity_attribute_dict(entity):
+    d = {}
+
+    def _u(an, values):
+        d[an] = values
+
+    with_entity_attributes(entity, _u)
+
+    d[ATTRS['domain']] = _all_domains_and_subdomains(entity)
+
+    roles = d.setdefault(ATTRS['role'], [])
+    if is_idp(entity):
+        roles.append('idp')
+        eca = ATTRS['entity-category']
+        ec = d.setdefault(eca, [])
+        if 'http://refeds.org/category/hide-from-discovery' not in ec:
+            ec.append('http://pyff.io/category/discoverable')
+    if is_sp(entity):
+        roles.append('sp')
+    if is_aa(entity):
+        roles.append('aa')
+
+    if ATTRS['software'] not in d:
+        d[ATTRS['software']] = [guess_entity_software(entity)]
+
+    return d
+
+
+def entity_icon(e, langs=None):
+    for ico in filter_lang(e.iter("{%s}Logo" % NS['mdui']), langs=langs):
+        return dict(url=ico.text, width=ico.get('width'), height=ico.get('height'))
+
+
+def privacy_statement_url(entity, langs):
+    for url in filter_lang(entity.iter("{%s}PrivacyStatementURL" % NS['mdui']), langs=langs):
+        return url.text
+
+
+def entity_geoloc(entity):
+    for loc in entity.iter("{%s}GeolocationHint" % NS['mdui']):
+        pos = loc.text[5:].split(",")
+        return dict(lat=pos[0], long=pos[1])
+
+
+def entity_domains(entity):
+    domains = []
+    for d in entity.iter("{%s}DomainHint" % NS['mdui']):
+        if d.text == '.':
+            return []
+        domains.append(d.text)
+    if not domains:
+        domains.append(url2host(entity.get('entityID')))
+    return domains
+
+
+def entity_extended_display(entity,
langs=None): + """Utility-method for computing a displayable string for a given entity. + + :param entity: An EntityDescriptor element + :param langs: The list of languages to search in priority order + """ + display = entity.get('entityID') + info = '' + + for organizationName in filter_lang(entity.iter("{%s}OrganizationName" % NS['md']), langs=langs): + info = display + display = organizationName.text + + for organizationDisplayName in filter_lang(entity.iter("{%s}OrganizationDisplayName" % NS['md']), langs=langs): + info = display + display = organizationDisplayName.text + + for serviceName in filter_lang(entity.iter("{%s}ServiceName" % NS['md']), langs=langs): + info = display + display = serviceName.text + + for displayName in filter_lang(entity.iter("{%s}DisplayName" % NS['mdui']), langs=langs): + info = display + display = displayName.text + + for organizationUrl in filter_lang(entity.iter("{%s}OrganizationURL" % NS['md']), langs=langs): + info = organizationUrl.text + + for description in filter_lang(entity.iter("{%s}Description" % NS['mdui']), langs=langs): + info = description.text + + if info == entity.get('entityID'): + info = '' + + return trunc_str(display.strip(), 40), trunc_str(info.strip(), 256) + + +def entity_display_name(entity, langs=None): + """Utility-method for computing a displayable string for a given entity. + + :param entity: An EntityDescriptor element + :param langs: The list of languages to search in priority order + """ + for displayName in filter_lang(entity.iter("{%s}DisplayName" % NS['mdui']), langs=langs): + return displayName.text.strip() + + for serviceName in filter_lang(entity.iter("{%s}ServiceName" % NS['md']), langs=langs): + return serviceName.text.strip() + + for organizationDisplayName in filter_lang(entity.iter("{%s}OrganizationDisplayName" % NS['md']), langs=langs): + return organizationDisplayName.text.strip() + + for organizationName in filter_lang(entity.iter("{%s}OrganizationName" % NS['md']), langs=langs): + return organizationName.text.strip() + + return entity.get('entityID').strip() + + +def sub_domains(e): + lst = [] + domains = entity_domains(e) + for d in domains: + for sub in subdomains(d): + if sub not in lst: + lst.append(sub) + return lst + + +def entity_scopes(e): + elt = e.findall('.//{%s}IDPSSODescriptor/{%s}Extensions/{%s}Scope' % (NS['md'], NS['md'], NS['shibmd'])) + if elt is None or len(elt) == 0: + return None + return [s.text for s in elt] + + +def discojson(e, langs=None): + if e is None: + return dict() + + title, descr = entity_extended_display(e) + entity_id = e.get('entityID') + + d = dict(title=title, + descr=descr, + auth='saml', + entityID=entity_id) + + eattr = entity_attribute_dict(e) + if 'idp' in eattr[ATTRS['role']]: + d['type'] = 'idp' + d['hidden'] = 'true' + if 'http://pyff.io/category/discoverable' in eattr[ATTRS['entity-category']]: + d['hidden'] = 'false' + elif 'sp' in eattr[ATTRS['role']]: + d['type'] = 'sp' + + icon_info = entity_icon(e) + if icon_info is not None: + d['entity_icon'] = icon_info.get('url', PLACEHOLDER_ICON) + d['icon_height'] = icon_info.get('height', 64) + d['icon_width'] = icon_info.get('width', 64) + + scopes = entity_scopes(e) + if scopes is not None and len(scopes) > 0: + d['scope'] = ",".join(scopes) + + keywords = filter_lang(e.iter("{%s}Keywords" % NS['mdui']), langs=langs) + if keywords is not None: + lst = [elt.text for elt in keywords] + if len(lst) > 0: + d['keywords'] = ",".join(lst) + psu = privacy_statement_url(e, langs) + if psu: + d['privacy_statement_url'] = psu + 
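# mdui:GeolocationHint, when present, is passed through for discovery UIs +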
geo = entity_geoloc(e)
+    if geo:
+        d['geo'] = geo
+
+    return d
+
+
+def sha1_id(e):
+    return hash_id(e, 'sha1')
+
+
+def entity_simple_summary(e):
+    if e is None:
+        return dict()
+
+    title, descr = entity_extended_display(e)
+    entity_id = e.get('entityID')
+    d = dict(title=title,
+             descr=descr,
+             entityID=entity_id,
+             domains=";".join(sub_domains(e)),
+             id=hash_id(e, 'sha1'))
+    icon_info = entity_icon(e)
+    if icon_info is not None:
+        url = icon_info.get('url', 'data:image/gif;base64,R0lGODlhAQABAIABAP///wAAACH5BAEKAAEALAAAAAABAAEAAAICTAEAOw==')
+        d['icon_url'] = url
+        d['entity_icon'] = url
+
+    psu = privacy_statement_url(e, None)
+    if psu:
+        d['privacy_statement_url'] = psu
+
+    return d
+
+
+def entity_extensions(e):
+    """Return a list of the Extensions elements in the EntityDescriptor
+
+:param e: an EntityDescriptor
+:return: a list
+    """
+    ext = e.find("./{%s}Extensions" % NS['md'])
+    if ext is None:
+        ext = etree.Element("{%s}Extensions" % NS['md'])
+        e.insert(0, ext)
+    return ext
+
+
+def annotate_entity(e, category, title, message, source=None):
+    """Add an ATOM annotation to an EntityDescriptor or an EntitiesDescriptor. This is a simple way to
+    add non-normative text annotations to metadata, eg for the purpose of generating reports.
+
+:param e: An EntityDescriptor or an EntitiesDescriptor element
+:param category: The ATOM category
+:param title: The ATOM title
+:param message: The ATOM content
+:param source: An optional source URL. It is added as an atom:link element with @rel='saml-metadata-source'
+    """
+    if e.tag != "{%s}EntityDescriptor" % NS['md'] and e.tag != "{%s}EntitiesDescriptor" % NS['md']:
+        raise MetadataException('I can only annotate EntityDescriptor or EntitiesDescriptor elements')
+    subject = e.get('Name', e.get('entityID', None))
+    atom = ElementMaker(nsmap={'atom': 'http://www.w3.org/2005/Atom'}, namespace='http://www.w3.org/2005/Atom')
+    args = [atom.published("%s" % datetime.now().isoformat()),
+            atom.link(href=subject, rel="saml-metadata-subject")]
+    if source is not None:
+        args.append(atom.link(href=source, rel="saml-metadata-source"))
+    args.extend([atom.title(title),
+                 atom.category(term=category),
+                 atom.content(message, type="text/plain")])
+    entity_extensions(e).append(atom.entry(*args))
+
+
+def _entity_attributes(e):
+    ext = entity_extensions(e)
+    ea = ext.find(".//{%s}EntityAttributes" % NS['mdattr'])
+    if ea is None:
+        ea = etree.Element("{%s}EntityAttributes" % NS['mdattr'])
+        ext.append(ea)
+    return ea
+
+
+def _eattribute(e, attr, nf):
+    ea = _entity_attributes(e)
+    a = ea.xpath(".//saml:Attribute[@NameFormat='%s' and @Name='%s']" % (nf, attr),
+                 namespaces=NS,
+                 smart_strings=False)
+    if a is None or len(a) == 0:
+        a = etree.Element("{%s}Attribute" % NS['saml'])
+        a.set('NameFormat', nf)
+        a.set('Name', attr)
+        ea.append(a)
+    else:
+        a = a[0]
+    return a
+
+
+def set_entity_attributes(e, d, nf=NF_URI):
+    """Set an entity attribute on an EntityDescriptor
+
+:param e: The EntityDescriptor element
+:param d: A dict of attribute-value pairs that should be added as entity attributes
+:param nf: The nameFormat (by default "urn:oasis:names:tc:SAML:2.0:attrname-format:uri") to use.
+
+:raise: MetadataException unless e is an EntityDescriptor element
+    """
+    if e.tag != "{%s}EntityDescriptor" % NS['md']:
+        raise MetadataException("I can only add EntityAttribute(s) to EntityDescriptor elements")
+
+    for attr, value in d.iteritems():
+        a = _eattribute(e, attr, nf)
+        velt = etree.Element("{%s}AttributeValue" % NS['saml'])
+        velt.text = value
+        a.append(velt)
+
+
+def set_pubinfo(e, publisher=None, creation_instant=None):
+    if e.tag != "{%s}EntitiesDescriptor" % NS['md']:
+        raise MetadataException("I can only set PublicationInfo on EntitiesDescriptor elements")
+    if publisher is None:
+        raise MetadataException("At least publisher must be provided")
+
+    if creation_instant is None:
+        now = datetime.utcnow()
+        creation_instant = now.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+    ext = entity_extensions(e)
+    pi = ext.find(".//{%s}PublicationInfo" % NS['mdrpi'])
+    if pi is not None:
+        raise MetadataException("A PublicationInfo element is already present")
+    pi = etree.Element("{%s}PublicationInfo" % NS['mdrpi'])
+    pi.set('publisher', publisher)
+    if creation_instant:
+        pi.set('creationInstant', creation_instant)
+    ext.append(pi)
+
+
+def set_reginfo(e, policy=None, authority=None):
+    if e.tag != "{%s}EntityDescriptor" % NS['md']:
+        raise MetadataException("I can only set RegistrationInfo on EntityDescriptor elements")
+    if authority is None:
+        raise MetadataException("At least authority must be provided")
+    if policy is None:
+        policy = dict()
+
+    ext = entity_extensions(e)
+    ri = ext.find(".//{%s}RegistrationInfo" % NS['mdrpi'])
+    if ri is not None:
+        raise MetadataException("A RegistrationInfo element is already present")
+
+    ri = etree.Element("{%s}RegistrationInfo" % NS['mdrpi'])
+    ext.append(ri)
+    ri.set('registrationAuthority', authority)
+    for lang, policy_url in policy.iteritems():
+        rp = etree.Element("{%s}RegistrationPolicy" % NS['mdrpi'])
+        rp.text = policy_url
+        rp.set('{%s}lang' % NS['xml'], lang)
+        ri.append(rp)
+
+
+def expiration(t):
+    relt = root(t)
+    if relt.tag in ('{%s}EntityDescriptor' % NS['md'], '{%s}EntitiesDescriptor' % NS['md']):
+        cache_duration = config.default_cache_duration
+        valid_until = relt.get('validUntil', None)
+        if valid_until is not None:
+            now = datetime.utcnow()
+            vu = iso2datetime(valid_until)
+            now = now.replace(microsecond=0)
+            vu = vu.replace(microsecond=0, tzinfo=None)
+            return vu - now
+        elif config.respect_cache_duration:
+            cache_duration = relt.get('cacheDuration', config.default_cache_duration)
+            return duration2timedelta(cache_duration)
+
+    return None
diff --git a/src/pyff/store.py b/src/pyff/store.py
index 9e89bf78..ac8ba026 100644
--- a/src/pyff/store.py
+++ b/src/pyff/store.py
@@ -10,83 +10,11 @@ import re
 from redis import Redis
-from pyff.constants import NS, ATTRS
-from pyff.decorators import cached
-from pyff.logs import log
-from pyff.utils import root, dumptree, parse_xml, hex_digest, hash_id, EntitySet, \
-    url2host, subdomains, has_tag, iter_entities, valid_until_ts, guess_entity_software
-
-
-def is_idp(entity):
-    return has_tag(entity, "{%s}IDPSSODescriptor" % NS['md'])
-
-
-def is_sp(entity):
-    return has_tag(entity, "{%s}SPSSODescriptor" % NS['md'])
-
-
-def is_aa(entity):
-    return has_tag(entity, "{%s}AttributeAuthorityDescriptor" % NS['md'])
-
-
-def _domains(entity):
-    domains = [url2host(entity.get('entityID'))]
-    for d in entity.iter("{%s}DomainHint" % NS['mdui']):
-        if d.text not in domains:
-            domains.append(d.text)
-    return domains
-
-
-def with_entity_attributes(entity, cb):
-
-    def _stext(e):
-        if e.text
is not None: - return e.text.strip() - - for ea in entity.iter("{%s}EntityAttributes" % NS['mdattr']): - for a in ea.iter("{%s}Attribute" % NS['saml']): - an = a.get('Name', None) - if a is not None: - values = filter(lambda x: x is not None, [_stext(v) for v in a.iter("{%s}AttributeValue" % NS['saml'])]) - cb(an, values) - - -def _all_domains_and_subdomains(entity): - dlist = [] - try: - for dn in _domains(entity): - for sub in subdomains(dn): - dlist.append(sub) - except ValueError: - pass - return dlist - - -def entity_attribute_dict(entity): - d = {} - - def _u(an, values): - d[an] = values - with_entity_attributes(entity, _u) - - d[ATTRS['domain']] = _all_domains_and_subdomains(entity) - - roles = d.setdefault(ATTRS['role'], []) - if is_idp(entity): - roles.append('idp') - eca = ATTRS['entity-category'] - ec = d.setdefault(eca, []) - if 'http://refeds.org/category/hide-from-discovery' not in ec: - ec.append('http://pyff.io/category/discoverable') - if is_sp(entity): - roles.append('sp') - if is_aa(entity): - roles.append('aa') - - if ATTRS['software'] not in d: - d[ATTRS['software']] = [guess_entity_software(entity)] - - return d +from .constants import NS, ATTRS +from .decorators import cached +from .logs import log +from .utils import root, dumptree, parse_xml, hex_digest, hash_id, valid_until_ts +from .samlmd import EntitySet, iter_entities, entity_attribute_dict, is_sp, is_idp def _now(): @@ -95,14 +23,10 @@ def _now(): DINDEX = ('sha1', 'sha256', 'null') - class StoreBase(object): def lookup(self, key): raise NotImplementedError() - def ready(self): - raise NotImplementedError() - def clone(self): return self @@ -132,7 +56,6 @@ def __init__(self): self.md = dict() self.index = dict() self.entities = dict() - self._ready = False for hn in DINDEX: self.index.setdefault(hn, {}) @@ -158,12 +81,6 @@ def attributes(self): def attribute(self, a): return self.index.setdefault('attr', {}).setdefault(a, {}).keys() - def ready(self): - return self._ready - - def periodic(self, stats): - self._ready = True - def _modify(self, entity, modifier): def _m(idx, vv): @@ -268,22 +185,20 @@ def _lookup(self, key): m = re.match("^(.+)=(.+)$", key) if m: - return self._lookup("{%s}%s" % (m.group(1), m.group(2).rstrip("/"))) + return self._lookup("{%s}%s" % (m.group(1), str(m.group(2)).rstrip("/"))) m = re.match("^{(.+)}(.+)$", key) if m: res = set() - for v in m.group(2).rstrip("/").split(';'): + for v in str(m.group(2)).rstrip("/").split(';'): # log.debug("... 
adding %s=%s" % (m.group(1),v)) res.update(self._get_index(m.group(1), v)) return list(res) - # log.debug("trying null index lookup %s" % key) l = self._get_index("null", key) if l: return list(l) - # log.debug("trying main index lookup %s: " % key) if key in self.md: # log.debug("entities list %s: %s" % (key, self.md[key])) lst = [] @@ -305,9 +220,6 @@ def _expiration(self, relt): if self.respect_validity: return valid_until_ts(relt, ts) - def ready(self): - return True - def reset(self): self.rc.flushdb() @@ -432,7 +344,7 @@ def lookup(self, key): hk = hex_digest(key) if not self.rc.exists("%s#members" % hk): self.rc.zunionstore("%s#members" % hk, - ["{%s}%s#members" % (m.group(1), v) for v in m.group(2).split(';')], 'min') + ["{%s}%s#members" % (m.group(1), v) for v in str(m.group(2)).split(';')], 'min') self.rc.expire("%s#members" % hk, 30) # XXX bad juju - only to keep clients from hammering return self.lookup(hk) elif self.rc.exists("%s#alias" % key): diff --git a/src/pyff/templates/metadata.html b/src/pyff/templates/metadata.html index cdd89bc7..00ca8e95 100644 --- a/src/pyff/templates/metadata.html +++ b/src/pyff/templates/metadata.html @@ -7,15 +7,15 @@
    {% for entity in entities %}
-
-        {% if md.is_idp(entity) %}
+
+        {% if samlmd.is_idp(entity) %}
-        {% elif md.is_sp(entity) %}
+        {% elif samlmd.is_sp(entity) %}
         {% else %}
         {% endif %}
-        {{ md.display(entity) }}
+        {{ samlmd.entity_display_name(entity) }}
     {% endfor %}
diff --git a/src/pyff/test/test_repo.py b/src/pyff/test/test_repo.py
index ef384445..507e80c0 100644
--- a/src/pyff/test/test_repo.py
+++ b/src/pyff/test/test_repo.py
@@ -99,9 +99,9 @@ def test_utils(self):
         assert (summary['title'] == 'Example University')
         assert (summary['descr'] == 'Identity Provider for Example University')
         assert (summary['entityID'] == entity_id)
-        assert ('icon' in summary)
-        assert ('icon_url' in summary and summary['icon'] == summary['icon_url'])
-        assert ('domains' in summary)
+        assert ('entity_icon' in summary)
+        assert ('icon_url' in summary and summary['entity_icon'] == summary['icon_url'])
+        assert ('entity_domains' in summary)
         assert ('id' in summary)
         empty = self.md.simple_summary(None)
diff --git a/src/pyff/utils.py b/src/pyff/utils.py
index 47f7654e..1635eb53 100644
--- a/src/pyff/utils.py
+++ b/src/pyff/utils.py
@@ -1,4 +1,6 @@
 # coding=utf-8
+from __future__ import print_function, unicode_literals, absolute_import
+
 """
 This module contains various utilities.
@@ -7,43 +9,32 @@
 import hashlib
 import io
 import tempfile
-from collections import namedtuple
 from datetime import timedelta, datetime
 from email.utils import parsedate
 from threading import local
-from time import gmtime, strftime, clock
-from traceback import print_exc
+from time import gmtime, strftime
 from urlparse import urlparse
 from itertools import chain
-import urllib
-from markupsafe import Markup
+
 import xmlsec
 import cherrypy
-import httplib2
 import iso8601
 import os
 import pkg_resources
 import re
 from jinja2 import Environment, PackageLoader
 from lxml import etree
-
-from .constants import NS
-from .constants import config
-from .decorators import retry
+from .constants import config, NS
 from .logs import log
+from .exceptions import *
+from .i18n import language

 __author__ = 'leifj'

-import i18n
-
 sentinel = object()
 thread_data = local()

-class PyffException(Exception):
-    pass
-
-
 def xml_error(error_log, m=None):
     def _f(x):
         if ":WARNING:" in x:
@@ -59,6 +50,9 @@ def debug_observer(e):
     log.error(repr(e))

+def trunc_str(x, l):
+    return (x[:l] + '..') if len(x) > l else x
+
 def resource_string(name, pfx=None):
     """
     Attempt to load and return the contents (as a string) of the resource named by
@@ -163,7 +157,7 @@ def resolve(self, system_url, public_id, context):
         """
         Resolves URIs using the resource API
         """
-        log.debug("resolve SYSTEM URL' %s' for '%s'" % (system_url, public_id))
+        #log.debug("resolve SYSTEM URL' %s' for '%s'" % (system_url, public_id))
         path = system_url.split("/")
         fn = path[len(path) - 1]
         if pkg_resources.resource_exists(__name__, fn):
@@ -240,7 +234,11 @@ def safe_write(fn, data):
 site_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "site")
 env = Environment(loader=PackageLoader(__package__, 'templates'), extensions=['jinja2.ext.i18n'])
-env.install_gettext_callables(i18n.language.gettext, i18n.language.ngettext, newstyle=True)
+getattr(env, 'install_gettext_callables')(language.gettext, language.ngettext, newstyle=True)
+
+import urllib
+from markupsafe import Markup
+
 def urlencode_filter(s):
     if type(s) == 'Markup':
@@ -273,68 +271,12 @@ def render_template(name, **kwargs):
     return template(name).render(**kwargs)

-_Resource = namedtuple("Resource", ["result", "cached", "date", "last_modified", "resp", "time"])
-
-
 def parse_date(s):
     if s is None:
         return datetime.now()
     return datetime(*parsedate(s)[:6])

-@retry((IOError, httplib2.HttpLib2Error))
-def load_url(url, enable_cache=True, timeout=60):
-    start_time = clock()
-    cache =
httplib2.FileCache(".cache") - headers = {'Accept': 'application/samlmetadata+xml,text/xml,application/xml'} - if not enable_cache: - headers['cache-control'] = 'no-cache' - - log.debug("fetching (caching: %s) '%s'" % (enable_cache, url)) - - if url.startswith('file://'): - path = url[7:] - if not os.path.exists(path): - log.error("file not found: %s" % path) - return _Resource(result=None, - cached=False, - date=None, - resp=None, - time=None, - last_modified=None) - - with io.open(path, 'rb') as fd: - return _Resource(result=fd.read(), - cached=False, - date=datetime.now(), - resp=None, - time=clock() - start_time, - last_modified=datetime.fromtimestamp(os.stat(path).st_mtime)) - else: - h = httplib2.Http(cache=cache, - timeout=timeout, - disable_ssl_certificate_validation=True) # trust is done using signatures over here - log.debug("about to request %s" % url) - try: - resp, content = h.request(url, headers=headers) - except Exception as ex: - print_exc(ex) - raise ex - log.debug("got status: %d" % resp.status) - if resp.status != 200: - log.debug("got resp code %d (%d bytes)" % (resp.status, len(content))) - raise IOError(resp.reason) - log.debug("last-modified header: %s" % resp.get('last-modified')) - log.debug("date header: %s" % resp.get('date')) - log.debug("last modified: %s" % resp.get('date', resp.get('last-modified', None))) - return _Resource(result=content, - cached=resp.fromcache, - date=parse_date(resp['date']), - resp=resp, - time=clock() - start_time, - last_modified=parse_date(resp.get('date', resp.get('last-modified', None)))) - - def root(t): if hasattr(t, 'getroot') and hasattr(t.getroot, '__call__'): return t.getroot() @@ -458,88 +400,8 @@ def hex_digest(data, hn='sha1'): return m.hexdigest() -def parse_xml(source, base_url=None): - return etree.parse(source, base_url=base_url, parser=etree.XMLParser(resolve_entities=False)) - - -class EntitySet(object): - def __init__(self, initial=None): - self._e = dict() - if initial is not None: - for e in initial: - self.add(e) - - def add(self, value): - self._e[value.get('entityID')] = value - - def discard(self, value): - entity_id = value.get('entityID') - if entity_id in self._e: - del self._e[entity_id] - - def __iter__(self): - for e in self._e.values(): - yield e - - def __len__(self): - return len(self._e.keys()) - - def __contains__(self, item): - return item.get('entityID') in self._e.keys() - - -class MetadataException(Exception): - pass - - -class MetadataExpiredException(MetadataException): - pass - - -def find_merge_strategy(strategy_name): - if '.' not in strategy_name: - strategy_name = "pyff.merge_strategies.%s" % strategy_name - (mn, sep, fn) = strategy_name.rpartition('.') - # log.debug("import %s from %s" % (fn,mn)) - module = None - if '.' 
in mn:
-        (pn, sep, modn) = mn.rpartition('.')
-        module = getattr(__import__(pn, globals(), locals(), [modn], -1), modn)
-    else:
-        module = __import__(mn, globals(), locals(), [], -1)
-    strategy = getattr(module, fn)  # we might aswell let this fail early if the strategy is wrongly named
-
-    if strategy is None:
-        raise MetadataException("Unable to find merge strategy %s" % strategy_name)
-
-    return strategy
-
-
-def entities_list(t=None):
-    """
-    :param t: An EntitiesDescriptor or EntityDescriptor element
-
-    Returns the list of contained EntityDescriptor elements
-    """
-    if t is None:
-        return []
-    elif root(t).tag == "{%s}EntityDescriptor" % NS['md']:
-        return [root(t)]
-    else:
-        return iter_entities(t)
-
-
-def iter_entities(t):
-    if t is None:
-        return []
-    return t.iter('{%s}EntityDescriptor' % NS['md'])
-
-
-def find_entity(t, e_id, attr='entityID'):
-    for e in iter_entities(t):
-        if e.get(attr) == e_id:
-            return e
-    return None
+def parse_xml(source, base_url=None):
+    return etree.parse(source, base_url=base_url, parser=etree.XMLParser(resolve_entities=False))

 def has_tag(t, tag):
@@ -600,6 +462,12 @@ def sync_nsmap(nsmap, elt):
     pass

+def load_callable(name):
+    from importlib import import_module
+    p, m = name.rsplit(':', 1)
+    mod = import_module(p)
+    return getattr(mod, m)
+
 # semantics copied from https://github.com/lordal/md-summary/blob/master/md-summary
 # many thanks to Anders Lordahl & Scotty Logan for the idea
 def guess_entity_software(e):
@@ -655,3 +523,10 @@ def guess_entity_software(e):
         return 'OpenAthens'

     return 'other'
+
+
+def printable(s):
+    if type(s) is unicode:
+        return s.encode('ascii', errors='ignore').decode()
+    else:
+        return s.decode("ascii", errors="ignore").encode()
\ No newline at end of file
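For illustration, not part of the patch: load_callable resolves a 'module:attribute' string to the attribute itself, and because it splits on ':' with rsplit, the module part may be a dotted path. A hypothetical use:

    from pyff.utils import load_callable

    loads = load_callable('json:loads')  # equivalent to: import json; json.loads
    assert loads('{"a": 1}')['a'] == 1

    merge = load_callable('pyff.merge_strategies:replace_existing')  # dotted module path also works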