diff --git a/galaxy/web/security/__init__.py b/galaxy/web/security/__init__.py
new file mode 100644
index 0000000..9689a51
--- /dev/null
+++ b/galaxy/web/security/__init__.py
@@ -0,0 +1,161 @@
+import collections
+import os
+import os.path
+import logging
+
+import galaxy.exceptions
+
+from Crypto.Cipher import Blowfish
+from Crypto.Util.randpool import RandomPool
+from Crypto.Util import number
+
+log = logging.getLogger(__name__)
+
+if os.path.exists("/dev/urandom"):
+    # We have urandom, use it as the source of random data
+    random_fd = os.open("/dev/urandom", os.O_RDONLY)
+
+    def get_random_bytes(nbytes):
+        """Read ``nbytes`` bytes from /dev/urandom and return them hex encoded."""
+        value = os.read(random_fd, nbytes)
+        # Normally we should get as much as we need
+        if len(value) == nbytes:
+            return value.encode("hex")
+        # If we don't, keep reading (this is slow and should never happen)
+        while len(value) < nbytes:
+            value += os.read(random_fd, nbytes - len(value))
+        return value.encode("hex")
+else:
+    def get_random_bytes(nbytes):
+        """Return a random ``nbytes * 8`` bit number from a RandomPool, as a string.
+
+        NOTE(review): the /dev/urandom variant returns hex text while this one
+        returns ``str()`` of a number - confirm callers accept both formats.
+        """
+        nbits = nbytes * 8
+        random_pool = RandomPool(1064)
+        while random_pool.entropy < nbits:
+            random_pool.add_event()
+        random_pool.stir()
+        return str(number.getRandomNumber(nbits, random_pool.get_bytes))
+
+
+class SecurityHelper(object):
+    """Encrypt and decrypt ids and session keys with Blowfish ciphers."""
+
+    def __init__(self, **config):
+        """Create the ciphers.
+
+        ``config`` must contain ``id_secret``; ``per_kind_id_secret_base`` is
+        optional and falls back to ``id_secret``.
+        """
+        self.id_secret = config['id_secret']
+        self.id_cipher = Blowfish.new(self.id_secret)
+
+        per_kind_id_secret_base = config.get('per_kind_id_secret_base', self.id_secret)
+        self.id_ciphers_for_kind = _cipher_cache(per_kind_id_secret_base)
+
+    def encode_id(self, obj_id, kind=None):
+        """Return ``obj_id`` encrypted and hex encoded.
+
+        ``kind`` selects a per-kind cipher; a false value uses the default one.
+        Raises ``galaxy.exceptions.MalformedId`` when ``obj_id`` is None.
+        """
+        if obj_id is None:
+            raise galaxy.exceptions.MalformedId("Attempted to encode None id")
+        id_cipher = self.__id_cipher(kind)
+        # Convert to string
+        s = str(obj_id)
+        # Pad to a multiple of 8 with leading "!" (always adds 1 to 8 of them)
+        s = ("!" * (8 - len(s) % 8)) + s
+        # Encrypt
+        return id_cipher.encrypt(s).encode('hex')
+
+    def encode_dict_ids(self, a_dict, kind=None, skip_startswith=None):
+        """
+        Encode all ids in dictionary, in place, and return the dictionary. Ids
+        are identified by (a) an 'id' key or (b) a key that ends with '_id'.
+        Keys of form (b) that start with `skip_startswith` are left untouched.
+        """
+        for key, val in a_dict.items():
+            # `and` binds tighter than `or`: the prefix filter only applies to
+            # '*_id' keys. The extra parentheses just make that explicit.
+            if key == 'id' or (key.endswith('_id') and (skip_startswith is None or not key.startswith(skip_startswith))):
+                a_dict[key] = self.encode_id(val, kind=kind)
+
+        return a_dict
+
+    def encode_all_ids(self, rval, recursive=False):
+        """
+        Encodes, in place, the non-None values of the dict rval whose keys are
+        'id' or end with '_id' - excluding `tool_id`, which is consumed and
+        produced as is via the API, and `external_id` - plus the elements of
+        lists whose keys end with '_ids'. With `recursive`, nested dicts and
+        the items of other lists are processed too. Anything that fails to
+        encode is left unchanged; a non-dict rval is returned as is.
+        """
+        if not isinstance(rval, dict):
+            return rval
+        for k, v in rval.items():
+            if (k == 'id' or k.endswith('_id')) and v is not None and k not in ['tool_id', 'external_id']:
+                try:
+                    rval[k] = self.encode_id(v)
+                except Exception:
+                    pass  # probably already encoded
+            if (k.endswith("_ids") and isinstance(v, list)):
+                try:
+                    rval[k] = [self.encode_id(i) for i in v]
+                except Exception:
+                    pass  # best effort: keep the original list
+            else:
+                if recursive and isinstance(v, dict):
+                    rval[k] = self.encode_all_ids(v, recursive)
+                elif recursive and isinstance(v, list):
+                    # A list, not map(): the result must not be a lazy iterator
+                    rval[k] = [self.encode_all_ids(el, True) for el in v]
+        return rval
+
+    def decode_id(self, obj_id, kind=None):
+        """Decrypt the hex string ``obj_id`` and return the id as an int."""
+        id_cipher = self.__id_cipher(kind)
+        return int(id_cipher.decrypt(obj_id.decode('hex')).lstrip("!"))
+
+    def encode_guid(self, session_key):
+        """Return ``session_key`` encrypted with the default cipher, hex encoded."""
+        # Session keys are strings
+        # Pad to a multiple of 8 with leading "!"
+        s = ("!" * (8 - len(session_key) % 8)) + session_key
+        # Encrypt
+        return self.id_cipher.encrypt(s).encode('hex')
+
+    def decode_guid(self, session_key):
+        """Decrypt a hex encoded session key and strip its leading "!" padding."""
+        # Session keys are strings
+        return self.id_cipher.decrypt(session_key.decode('hex')).lstrip("!")
+
+    def get_new_guid(self):
+        # Generate a unique, high entropy 128 bit random number
+        return get_random_bytes(16)
+
+    def __id_cipher(self, kind):
+        """Return the default cipher, or the per-kind one when ``kind`` is set."""
+        if not kind:
+            id_cipher = self.id_cipher
+        else:
+            id_cipher = self.id_ciphers_for_kind[kind]
+        return id_cipher
+
+
+class _cipher_cache(collections.defaultdict):
+    """Maps a kind to a Blowfish cipher keyed by ``secret_base + "__" + kind``."""
+
+    def __init__(self, secret_base):
+        collections.defaultdict.__init__(self)
+        self.secret_base = secret_base
+
+    def __missing__(self, key):
+        # Remember the cipher so later lookups of the same kind reuse it
+        # instead of building a new one each time.
+        cipher = Blowfish.new(self.secret_base + "__" + key)
+        self[key] = cipher
+        return cipher
diff --git a/tests/test_security_helper.py b/tests/test_security_helper.py
new file mode 100644
index 0000000..b58d8fe
--- /dev/null
+++ b/tests/test_security_helper.py
@@ -0,0 +1,73 @@
+from galaxy.web import security
+
+
+test_helper_1 = security.SecurityHelper(id_secret="sec1")
+test_helper_2 = security.SecurityHelper(id_secret="sec2")
+
+
+def test_encode_decode():
+    # Different ids are encoded differently
+    assert test_helper_1.encode_id(1) != test_helper_1.encode_id(2)
+    # But decoding an encoded id brings back the original id
+    assert 1 == test_helper_1.decode_id(test_helper_1.encode_id(1))
+
+
+def test_nested_encoding():
+    # Does nothing if not a dict
+    assert test_helper_1.encode_all_ids(1) == 1
+
+    # Encodes top-level things ending in _id
+    assert test_helper_1.encode_all_ids(dict(history_id=1))["history_id"] == test_helper_1.encode_id(1)
+    # ..except tool_id
+    assert test_helper_1.encode_all_ids(dict(tool_id=1))["tool_id"] == 1
+
+    # Encodes lists at top level if the key ends in _ids
+    expected_ids = [test_helper_1.encode_id(1), test_helper_1.encode_id(2)]
+    assert test_helper_1.encode_all_ids(dict(history_ids=[1, 2]))["history_ids"] == expected_ids
+
+    # Encodes nested stuff if and only if recursive set to true.
+    nested_dict = dict(objects=dict(history_ids=[1, 2]))
+    assert test_helper_1.encode_all_ids(nested_dict)["objects"]["history_ids"] == [1, 2]
+    assert test_helper_1.encode_all_ids(nested_dict, recursive=False)["objects"]["history_ids"] == [1, 2]
+    assert test_helper_1.encode_all_ids(nested_dict, recursive=True)["objects"]["history_ids"] == expected_ids
+
+
+def test_per_kind_encode_decode():
+    # Different ids are encoded differently
+    assert test_helper_1.encode_id(1, kind="k1") != test_helper_1.encode_id(2, kind="k1")
+    # But decoding an encoded id brings back the original id
+    assert 1 == test_helper_1.decode_id(test_helper_1.encode_id(1, kind="k1"), kind="k1")
+
+
+def test_different_secrets_encode_differently():
+    assert test_helper_1.encode_id(1) != test_helper_2.encode_id(1)
+
+
+def test_per_kind_encodes_id_differently():
+    assert test_helper_1.encode_id(1) != test_helper_1.encode_id(1, kind="new_kind")
+
+
+def test_encode_dict():
+    test_dict = dict(
+        id=1,
+        other=2,
+        history_id=3,
+    )
+    encoded_dict = test_helper_1.encode_dict_ids(test_dict)
+    assert encoded_dict["id"] == test_helper_1.encode_id(1)
+    assert encoded_dict["other"] == 2
+    assert encoded_dict["history_id"] == test_helper_1.encode_id(3)
+
+
+def test_guid_generation():
+    guids = set()
+    for i in range(100):
+        guids.add(test_helper_1.get_new_guid())
+    assert len(guids) == 100  # No duplicate guids generated.
+
+
+def test_encode_decode_guid():
+    session_key = test_helper_1.get_new_guid()
+    encoded_key = test_helper_1.encode_guid(session_key)
+    decoded_key = test_helper_1.decode_guid(encoded_key).encode("utf-8")
+    assert session_key == decoded_key, "%s != %s" % (session_key, decoded_key)
diff --git a/update_galaxy_utils.sh b/update_galaxy_utils.sh index 942450b..55c6338 100755 --- a/update_galaxy_utils.sh +++ b/update_galaxy_utils.sh @@ -39,8 +39,8 @@ GALAXY_LIB_DIR=$GALAXY_DIRECTORY/lib GALAXY_UNIT_TEST_DIR=$GALAXY_DIRECTORY/test/unit -UTIL_FILES=(__init__.py aliaspickler.py bunch.py checkers.py compression_utils.py dictifiable.py docutils_template.txt filelock.py expressions.py hash_util.py heartbeat.py heartbeat.py image_util.py inflection.py json.py lazy_process.py odict.py oset.py object_wrapper.py plugin_config.py properties.py simplegraph.py sleeper.py sockets.py specs.py sqlite.py submodules.py topsort.py topsort.py xml_macros.py) -GALAXY_LIB=(galaxy/objectstore galaxy/tools/deps galaxy/tools/parser galaxy/tools/verify galaxy/jobs/metrics galaxy/tools/locations galaxy/tools/linters galaxy/tools/fetcher.py galaxy/tools/loader_directory.py galaxy/tools/loader.py galaxy/tools/lint.py galaxy/tools/lint_util.py galaxy/tools/deps galaxy/tools/toolbox galaxy/exceptions galaxy/tools/cwl galaxy/web/stack) +UTIL_FILES=(__init__.py aliaspickler.py bunch.py checkers.py compression_utils.py dictifiable.py docutils_template.txt filelock.py expressions.py hash_util.py heartbeat.py heartbeat.py image_util.py inflection.py json.py lazy_process.py odict.py oset.py object_wrapper.py plugin_config.py properties.py simplegraph.py sleeper.py sockets.py specs.py sqlite.py submodules.py tool_version.py topsort.py topsort.py xml_macros.py) +GALAXY_LIB=(galaxy/objectstore galaxy/tools/deps galaxy/tools/parser galaxy/tools/verify galaxy/jobs/metrics galaxy/tools/locations galaxy/tools/linters galaxy/tools/fetcher.py galaxy/tools/loader_directory.py galaxy/tools/loader.py 
galaxy/tools/lint.py galaxy/tools/lint_util.py galaxy/tools/deps galaxy/tools/toolbox galaxy/exceptions galaxy/tools/cwl galaxy/web/stack galaxy/web/security) TEST_FILES=(tools/test_parsing.py tools/test_toolbox_filters.py tools/test_watcher.py test_sqlite_utils.py tools/test_tool_deps.py tools/test_tool_loader.py test_topsort.py test_sockets.py test_objectstore.py test_lazy_process.py)