diff --git a/src/pyff/builtins.py b/src/pyff/builtins.py index 0767dc5e..55881371 100644 --- a/src/pyff/builtins.py +++ b/src/pyff/builtins.py @@ -25,7 +25,7 @@ from .pipes import Plumbing, PipeException, PipelineCallback, pipe from .stats import set_metadata_info from .utils import total_seconds, dumptree, safe_write, root, with_tree, duration2timedelta, xslt_transform, validate_document -from .samlmd import iter_entities, annotate_entity, set_entity_attributes, discojson, set_pubinfo, set_reginfo +from .samlmd import sort_entities, iter_entities, annotate_entity, set_entity_attributes, discojson, set_pubinfo, set_reginfo from .fetch import Resource from six import StringIO from six.moves.urllib_parse import urlparse @@ -290,6 +290,37 @@ def info(req, *opts): return req.t +@pipe +def sort(req, *opts): + """ +Sorts the working entities by the value returned by the given xpath. +By default, entities are sorted by 'entityID' when the 'order_by [xpath]' option is omitted and +otherwise as second criteria. +Entities where no value exists for a given xpath are sorted last. + +:param req: The request +:param opts: Options: (see bellow) +:return: None + +Options are put directly after "sort". E.g: + +.. code-block:: yaml + + - sort order_by [xpath] + +**Options** +- order_by [xpath] : xpath expression selecting to the value used for sorting the entities. + """ + if req.t is None: + raise PipeException("Unable to sort empty document.") + + opts = dict(zip(opts[0:1], [" ".join(opts[1:])])) + opts.setdefault('order_by', None) + sort_entities(req.t, opts['order_by']) + + return req.t + + @pipe def publish(req, *opts): """ diff --git a/src/pyff/samlmd.py b/src/pyff/samlmd.py index b28e18ec..04690762 100644 --- a/src/pyff/samlmd.py +++ b/src/pyff/samlmd.py @@ -841,3 +841,37 @@ def expiration(t): return duration2timedelta(cache_duration) return None + + +def sort_entities(t, sxp=None): + """ +Sorts the working entities 't' by the value returned by the xpath 'sxp' +By default, entities are sorted by 'entityID' when this method is called without 'sxp', and otherwise as +second criteria. +Entities where no value exists for the given 'sxp' are sorted last. + +:param t: An element tree containing the entities to sort +:param sxp: xpath expression selecting the value used for sorting the entities +""" + def get_key(e): + eid = e.attrib.get('entityID') + sv = None + try: + sxp_values = e.xpath(sxp, namespaces=NS, smart_strings=False) + try: + sv = sxp_values[0] + try: + sv = sv.text + except AttributeError: + pass + except IndexError: + log.warn("Sort pipe: unable to sort entity by '%s'. " + "Entity '%s' has no such value" % (sxp, eid)) + except TypeError: + pass + + log.debug("Generated sort key for entityID='%s' and %s='%s'" % (eid, sxp, sv)) + return sv is None, sv, eid + + container = root(t) + container[:] = sorted(container, key=lambda e: get_key(e)) diff --git a/src/pyff/test/test_pipeline.py b/src/pyff/test/test_pipeline.py index 525bc75b..e6e9a92c 100644 --- a/src/pyff/test/test_pipeline.py +++ b/src/pyff/test/test_pipeline.py @@ -310,6 +310,115 @@ def test_no_fail_on_error_invalid_dir(self): print(sys.stderr.captured) +class SortTest(PipeLineTest): + EID1 = "https://idp.aco.net/idp/shibboleth" + EID2 = "https://idp.example.com/saml2/idp/metadata.php" + EID3 = "https://sharav.abes.fr/idp/shibboleth" + + @staticmethod + def _run_sort_test(expected_order, sxp, res, l): + if sxp is not None: + # Verify expected warnings for missing sort values + for e in expected_order: + try: + if not isinstance(e[1], bool): + raise TypeError + if not e[1]: + keygen_fail_str = ("Sort pipe: unable to sort entity by '%s'. " + "Entity '%s' has no such value" % (sxp, e[0])) + try: + assert (keygen_fail_str in unicode(l)) + except AssertionError: + print("Test failed on expecting missing sort value from: '%s'.\nCould not find string " + "on the output: '%s'.\nOutput was:\n %s" % (e[0], keygen_fail_str,unicode(l))) + raise + except (IndexError, TypeError): + print("Test failed for: '%s' due to 'order_by' xpath supplied without proper expectation tuple." % + "".join(e)) + raise + + # Verify order + for i, me in enumerate(expected_order): + try: + assert res[i].attrib.get("entityID") == me[0] + except AssertionError: + print(("Test failed on verifying sort position %i.\nExpected: %s; Found: %s " % + (i, me[0], res[i].attrib.get("entityID")))) + raise + + # Test sort by entityID only + def test_sort(self): + sxp = None + self.output = tempfile.NamedTemporaryFile('w').name + with patch.multiple("sys", exit=self.sys_exit, stdout=StreamCapturing(sys.stdout), + stderr=StreamCapturing(sys.stderr)): + from testfixtures import LogCapture + with LogCapture() as l: + res, md = self.exec_pipeline(""" + - load: + - %s/metadata + - %s/simple-pipeline/idp.aco.net.xml + - select: + - "!//md:EntityDescriptor[md:IDPSSODescriptor]" + - sort + - stats + """ % (self.datadir, self.datadir)) + print(sys.stdout.captured) + print(sys.stderr.captured) + + # tuple format (entityID, has value for 'order_by' xpath) + expected_order = [(self.EID1, ), (self.EID2, ), (self.EID3, )] + self._run_sort_test(expected_order, sxp, res, l) + + # Test sort entries first by registrationAuthority + def test_sort_by_ra(self): + sxp = ".//md:Extensions/mdrpi:RegistrationInfo/@registrationAuthority" + self.output = tempfile.NamedTemporaryFile('w').name + with patch.multiple("sys", exit=self.sys_exit, stdout=StreamCapturing(sys.stdout), + stderr=StreamCapturing(sys.stderr)): + from testfixtures import LogCapture + with LogCapture() as l: + res, md = self.exec_pipeline(""" + - load: + - %s/metadata + - %s/simple-pipeline/idp.aco.net.xml + - select: + - "!//md:EntityDescriptor[md:IDPSSODescriptor]" + - sort order_by %s + - stats + """ % (self.datadir, self.datadir, sxp)) + print(sys.stdout.captured) + print(sys.stderr.captured) + + # tuple format (entityID, has value for 'order_by' xpath) + expected_order = [(self.EID3, True), (self.EID1, False), (self.EID2, False)] + self._run_sort_test(expected_order, sxp, res, l) + + # Test group entries by specific NameIDFormat support + def test_sort_group(self): + sxp = ".//md:IDPSSODescriptor/md:NameIDFormat[./text()='urn:mace:shibboleth:1.0:nameIdentifier']" + self.output = tempfile.NamedTemporaryFile('w').name + with patch.multiple("sys", exit=self.sys_exit, stdout=StreamCapturing(sys.stdout), + stderr=StreamCapturing(sys.stderr)): + from testfixtures import LogCapture + with LogCapture() as l: + res, md = self.exec_pipeline(""" + - load: + - %s/metadata + - %s/simple-pipeline/idp.aco.net.xml + - select: + - "!//md:EntityDescriptor[md:IDPSSODescriptor]" + - sort order_by %s + - stats + """ % (self.datadir, self.datadir, sxp)) + print(sys.stdout.captured) + print(sys.stderr.captured) + + # tuple format (entityID, has value for 'order_by' xpath) + expected_order = [(self.EID1, True), (self.EID3, True), (self.EID2, False)] + self._run_sort_test(expected_order, sxp, res, l) + + # noinspection PyUnresolvedReferences class SigningTest(PipeLineTest):