diff --git a/CHANGELOG.md b/CHANGELOG.md index fe0c6d1e..ec1aaeb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v24.36 ### Pre-releases +- v24.36-alpha5 - v24.36-alpha4 - v24.36-alpha3 - ~~v24.36-alpha2~~ @@ -18,9 +19,13 @@ - Hotfix: Session expiration in userinfo must match access token expiration (#414, `v24.29-alpha7`) ### Features +- Improve the default configuration of LDAP credentials provider (#422, `v24.36-alpha5`) - Duplicating roles (#416, `v24.36-alpha1`) - Run Batman with warning when there is no ElasticSearch URL (#413, `v24.29-alpha6`) +### Refactoring +- Refactor LDAP credentials provider (#422, `v24.36-alpha5`) + --- diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index 7cadbfa6..1d476273 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -2,6 +2,7 @@ import base64 import datetime import contextlib +import typing from typing import Optional @@ -12,6 +13,7 @@ import asab import asab.proactor +import asab.config from .abc import CredentialsProviderABC @@ -30,10 +32,6 @@ } -class LDAPObject(ldap.ldapobject.LDAPObject, ldap.resiter.ResultProcessor): - pass - - class LDAPCredentialsService(asab.Service): def __init__(self, app, service_name="seacatauth.credentials.ldap"): @@ -56,8 +54,8 @@ class LDAPCredentialsProvider(CredentialsProviderABC): "username": "cn=admin,dc=example,dc=org", "password": "admin", "base": "dc=example,dc=org", - "filter": "(&(objectClass=inetOrgPerson)(cn=*))", # should filter valid users only - "attributes": "mail mobile", + "filter": "|(objectClass=organizationalPerson)(objectClass=inetOrgPerson)", + "attributes": "mail mobile userAccountControl displayName", # Path to CA file in PEM format "tls_cafile": "", @@ -74,7 +72,7 @@ class LDAPCredentialsProvider(CredentialsProviderABC): "tls_protocol_max": "", "tls_cipher_suite": "", - "attrusername": "cn", # LDAP attribute that should be used as a username, e.g. `uid` or `sAMAccountName` + "attrusername": "sAMAccountName", # LDAP attribute that should be used as a username, e.g. `uid` or `sAMAccountName` } @@ -85,21 +83,60 @@ def __init__(self, provider_id, config_section_name, proactor_svc): # synchronous library (python-ldap) to be used from asynchronous code self.ProactorService = proactor_svc - attr = set(self.Config["attributes"].split(" ")) - attr.add("createTimestamp") - attr.add("modifyTimestamp") - attr.add("cn") - attr.add(self.Config["attrusername"]) - self.AttrList = list(attr) + self.LdapUri = self.Config["uri"] + self.Base = self.Config["base"] + self.Filter: str = self.Config["filter"] + if not (self.Filter.startswith("(") and self.Filter.endswith(")")): + self.Filter = "({})".format(self.Filter) + self.AttrList = _prepare_attributes(self.Config) # Fields to filter by when locating a user - self._locate_filter_fields = ["cn", "mail", "mobile"] + self.IdentFields = ["mail", "mobile"] # If attrusername field is not empty, locate by it too if len(self.Config["attrusername"]) > 0: - self._locate_filter_fields.append(self.Config["attrusername"]) + self.IdentFields.append(self.Config["attrusername"]) - async def get_login_descriptors(self, credentials_id): + async def get(self, credentials_id: str, include: typing.Optional[typing.Iterable[str]] = None) -> Optional[dict]: + if not credentials_id.startswith(self.Prefix): + raise KeyError("Credentials {!r} not found".format(credentials_id)) + cn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") + try: + return await self.ProactorService.execute(self._get_worker, cn) + except KeyError as e: + raise KeyError("Credentials not found: {!r}".format(credentials_id)) from e + + + async def search(self, filter: dict = None, sort: dict = None, page: int = 0, limit: int = 0, **kwargs) -> list: + # TODO: Implement pagination + filterstr = self._build_search_filter(filter) + return await self.ProactorService.execute(self._search_worker, filterstr) + + + async def count(self, filtr=None) -> int: + filterstr = self._build_search_filter(filtr) + return await self.ProactorService.execute(self._count_worker, filterstr) + + + async def iterate(self, offset: int = 0, limit: int = -1, filtr: str = None): + filterstr = self._build_search_filter(filtr) + results = await self.ProactorService.execute(self._search_worker, filterstr) + for i in results[offset:(None if limit == -1 else limit + offset)]: + yield i + + + async def locate(self, ident: str, ident_fields: dict = None, login_dict: dict = None) -> str: + return await self.ProactorService.execute(self._locate_worker, ident, ident_fields) + + + async def authenticate(self, credentials_id: str, credentials: dict) -> bool: + dn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") + password = credentials.get("password") + return await self.ProactorService.execute(self._authenticate_worker, dn, password) + + + async def get_login_descriptors(self, credentials_id: str) -> typing.List[typing.Dict]: + # Only login with password is supported return [{ "id": "default", "label": "Use recommended login.", @@ -112,334 +149,222 @@ async def get_login_descriptors(self, credentials_id): @contextlib.contextmanager def _ldap_client(self): - ldap_client = LDAPObject(self.Config["uri"]) + ldap_client = _LDAPObject(self.LdapUri) ldap_client.protocol_version = ldap.VERSION3 ldap_client.set_option(ldap.OPT_REFERRALS, 0) - network_timeout = int(self.Config.get("network_timeout")) + network_timeout = self.Config.getint("network_timeout") ldap_client.set_option(ldap.OPT_NETWORK_TIMEOUT, network_timeout) - # Enable TLS - if self.Config["uri"].startswith("ldaps"): - self._enable_tls(ldap_client) + if self.LdapUri.startswith("ldaps"): + _enable_tls(ldap_client, self.Config) ldap_client.simple_bind_s(self.Config["username"], self.Config["password"]) - try: yield ldap_client - finally: ldap_client.unbind_s() - def _enable_tls(self, ldap_client): - tls_cafile = self.Config["tls_cafile"] - - # Add certificate authority - if len(tls_cafile) > 0: - ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, tls_cafile) - - # Set cert policy - if self.Config["tls_require_cert"] == "never": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) - elif self.Config["tls_require_cert"] == "demand": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - elif self.Config["tls_require_cert"] == "allow": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) - elif self.Config["tls_require_cert"] == "hard": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_HARD) - else: - L.error("Invalid 'tls_require_cert' value: {!r}. Defaulting to 'demand'.".format( - self.Config["tls_require_cert"] - )) - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - - # Misc TLS options - tls_protocol_min = self.Config["tls_protocol_min"] - if tls_protocol_min != "": - if tls_protocol_min not in _TLS_VERSION: - raise ValueError("'tls_protocol_min' must be one of {} or empty.".format(list(_TLS_VERSION))) - ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, _TLS_VERSION[tls_protocol_min]) - - tls_protocol_max = self.Config["tls_protocol_max"] - if tls_protocol_max != "": - if tls_protocol_max not in _TLS_VERSION: - raise ValueError("'tls_protocol_max' must be one of {} or empty.".format(list(_TLS_VERSION))) - ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MAX, _TLS_VERSION[tls_protocol_max]) - - if self.Config["tls_cipher_suite"] != "": - ldap_client.set_option(ldap.OPT_X_TLS_CIPHER_SUITE, self.Config["tls_cipher_suite"]) - # NEWCTX needs to be the last option, because it applies all the prepared options to the new context - ldap_client.set_option(ldap.OPT_X_TLS_NEWCTX, 0) - - def _get_worker(self, prefix, credentials_id, include=None) -> Optional[dict]: - - # TODO: Validate credetials_id with regex - - cn = base64.urlsafe_b64decode(credentials_id[len(prefix):]).decode("utf-8") - with self._ldap_client() as lc: + def _get_worker(self, cn: str) -> Optional[typing.Dict]: + with self._ldap_client() as ldap_client: try: - sr = lc.search_s( + results = ldap_client.search_s( cn, ldap.SCOPE_BASE, - filterstr=self.Config["filter"], + filterstr=self.Filter, attrlist=self.AttrList, ) except ldap.NO_SUCH_OBJECT as e: - L.error(e) - sr = [] - - if len(sr) == 0: - raise KeyError("Credentials {!r} not found".format(credentials_id)) + raise KeyError("CN matched no LDAP objects.") from e - assert len(sr) == 1 - dn, entry = sr[0] - - return _normalize_entry( - prefix, - self.Type, - self.ProviderID, - dn, - entry, - self.Config["attrusername"] - ) - - - async def get(self, credentials_id, include=None) -> Optional[dict]: - prefix = "{}:{}:".format(self.Type, self.ProviderID) - if not credentials_id.startswith(prefix): - raise KeyError("Credentials {!r} not found".format(credentials_id)) + if len(results) > 1: + L.error("CN matched multiple LDAP objects.", struct_data={"CN": cn}) + raise KeyError("CN matched multiple LDAP objects.") - return await self.ProactorService.execute(self._get_worker, prefix, credentials_id, include) + dn, entry = results[0] + return self._normalize_credentials(dn, entry) - def _count_worker(self, filterstr): - count = 0 - with self._ldap_client() as lc: - msgid = lc.search( - self.Config["base"], + def _search_worker(self, filterstr: str) -> typing.List[typing.Dict]: + # TODO: Implement sorting (Note that not all LDAP servers support server-side sorting) + results = [] + with self._ldap_client() as ldap_client: + msgid = ldap_client.search( + self.Base, ldap.SCOPE_SUBTREE, filterstr=filterstr, - attrsonly=1, # If attrsonly is non-zero - attrlist=["cn", "mail", "mobile"], # For counting, we need only absolutely minimum set of attributes + attrlist=self.AttrList, ) - result_iter = lc.allresults(msgid) + result_iter = ldap_client.allresults(msgid) for res_type, res_data, res_msgid, res_controls in result_iter: for dn, entry in res_data: - if dn is None: - continue - else: - count += 1 - - return count - - - async def count(self, filtr=None) -> int: - filterstr = self._build_search_filter(filtr) - return await self.ProactorService.execute(self._count_worker, filterstr) - + if dn is not None: + results.append(self._normalize_credentials(dn, entry)) - def _search_worker(self, filterstr): + return results - # TODO: sorting - prefix = "{}:{}:".format(self.Type, self.ProviderID) - result = [] - with self._ldap_client() as lc: - msgid = lc.search( - self.Config["base"], + def _count_worker(self, filterstr: str) -> int: + count = 0 + with self._ldap_client() as ldap_client: + msgid = ldap_client.search( + self.Base, ldap.SCOPE_SUBTREE, filterstr=filterstr, - attrlist=self.AttrList, + attrsonly=1, # If attrsonly is non-zero + attrlist=["cn", "mail", "mobile"], # For counting, we need only absolutely minimum set of attributes ) - result_iter = lc.allresults(msgid) + result_iter = ldap_client.allresults(msgid) for res_type, res_data, res_msgid, res_controls in result_iter: for dn, entry in res_data: if dn is not None: - result.append(_normalize_entry( - prefix, - self.Type, - self.ProviderID, - dn, - entry, - self.Config["attrusername"] - )) - - return result - - - async def search(self, filter: dict = None, **kwargs) -> list: - # TODO: Implement filtering and pagination - if filter is not None: - return [] - filterstr = self.Config["filter"] - return await self.ProactorService.execute(self._search_worker, filterstr) - - - async def iterate(self, offset: int = 0, limit: int = -1, filtr: str = None): - filterstr = self._build_search_filter(filtr) - arr = await self.ProactorService.execute(self._search_worker, filterstr) - for i in arr[offset:None if limit == -1 else limit + offset]: - yield i - - def _build_search_filter(self, filtr=None): - if not filtr: - filterstr = self.Config["filter"] - else: - # The query filter is the intersection of the filter from config - # and the filter defined by the search request - # The username must START WITH the given filter string - filter_template = "(&{}({}=*%s*))".format(self.Config["filter"], self.Config["attrusername"]) - assertion_values = ["{}".format(filtr.lower())] - filterstr = ldap.filter.filter_format( - filter_template=filter_template, - assertion_values=assertion_values - ) - return filterstr + count += 1 - def _locate_worker(self, ident: str): - with self._ldap_client() as lc: + return count - # Build the filter template - # Example: (|(cn=%s)(mail=%s)(mobile=%s)(sAMAccountName=%s)) - filter_template = "(|{})".format( - "".join("({}=%s)".format(field) for field in self._locate_filter_fields) - ) - assertion_values = tuple( - ident for _ in self._locate_filter_fields - ) - msgid = lc.search( - self.Config["base"], + def _locate_worker( + self, + ident: str, + ident_fields: typing.Optional[typing.Mapping[str, str]] = None + ) -> typing.Optional[str]: + # TODO: Implement configurable ident_fields support + with self._ldap_client() as ldap_client: + msgid = ldap_client.search( + self.Base, ldap.SCOPE_SUBTREE, filterstr=ldap.filter.filter_format( - filter_template=filter_template, - assertion_values=assertion_values + # Build the filter template + # Example: (|(cn=%s)(mail=%s)(mobile=%s)(sAMAccountName=%s)) + filter_template="(|{})".format( + "".join("({}=%s)".format(field) for field in self.IdentFields)), + assertion_values=tuple(ident for _ in self.IdentFields) ), - attrlist=["cn"] + attrlist=["cn"], ) - result_iter = lc.allresults(msgid) + result_iter = ldap_client.allresults(msgid) for res_type, res_data, res_msgid, res_controls in result_iter: for dn, entry in res_data: if dn is not None: - return "{}:{}:{}".format( - self.Type, - self.ProviderID, - base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii"), - ) + return self._format_credentials_id(dn) return None - async def locate(self, ident: str, ident_fields: dict = None, login_dict: dict = None) -> str: - # TODO: Implement ident_fields support - """ - Locate search for the exact match of provided ident and the username in the htpasswd file - """ - return await self.ProactorService.execute(self._locate_worker, ident) - - - def _authenticate_worker(self, credentials_id: str, credentials: dict) -> bool: - prefix = "{}:{}:".format(self.Type, self.ProviderID) - - password = credentials.get("password") - dn = base64.urlsafe_b64decode(credentials_id[len(prefix):]).decode("utf-8") - - lc = LDAPObject(self.Config["uri"]) - lc.protocol_version = ldap.VERSION3 - lc.set_option(ldap.OPT_REFERRALS, 0) + def _authenticate_worker(self, dn: str, password: str) -> bool: + ldap_client = _LDAPObject(self.LdapUri) + ldap_client.protocol_version = ldap.VERSION3 + ldap_client.set_option(ldap.OPT_REFERRALS, 0) - # Enable TLS - if self.Config["uri"].startswith("ldaps"): - self._enable_tls(lc) + if self.LdapUri.startswith("ldaps"): + _enable_tls(ldap_client, self.Config) try: - lc.simple_bind_s(dn, password) + ldap_client.simple_bind_s(dn, password) except ldap.INVALID_CREDENTIALS: - L.log(asab.LOG_NOTICE, "Authentication failed: Invalid LDAP credentials.", struct_data={ - "cid": credentials_id, "dn": dn}) + L.log(asab.LOG_NOTICE, "Authentication failed: Invalid LDAP credentials.", struct_data={"dn": dn}) return False - lc.unbind_s() + ldap_client.unbind_s() return True - async def authenticate(self, credentials_id: str, credentials: dict) -> bool: - return await self.ProactorService.execute(self._authenticate_worker, credentials_id, credentials) + def _normalize_credentials(self, dn: str, search_record: typing.Mapping) -> typing.Dict: + ret = { + "_id": self._format_credentials_id(dn), + "_type": self.Type, + "_provider_id": self.ProviderID, + } + + decoded_record = {"dn": dn} + for k, v in search_record.items(): + if k == "userPassword": + continue + if isinstance(v, list): + if len(v) == 0: + continue + elif len(v) == 1: + decoded_record[k] = v[0].decode("utf-8") + else: + decoded_record[k] = [i.decode("utf-8") for i in v] + + v = decoded_record.pop(self.Config["attrusername"], None) + if v is not None: + ret["username"] = v + else: + # This is fallback, since we need a username on various places + ret["username"] = dn + v = decoded_record.pop("cn", None) + if v is not None: + ret["full_name"] = v -def _normalize_entry(prefix, ptype, provider_id, dn, entry, attrusername: str = "cn"): - ret = { - "_id": prefix + base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii"), - "_type": ptype, - "_provider_id": provider_id, - } + v = decoded_record.pop("mail", None) + if v is not None: + ret["email"] = v - ldap_obj = { - "dn": dn, - } - for k, v in entry.items(): - if k in frozenset(["userPassword"]): - continue - if isinstance(v, list): - if len(v) == 1: - v = v[0].decode("utf-8") - else: - v = [i.decode("utf-8") for i in v] - ldap_obj[k] = v - - v = ldap_obj.pop(attrusername, None) - if v is not None: - ret["username"] = v - else: - # This is fallback, since we need a username on various places - ret["username"] = dn - - v = ldap_obj.pop("cn", None) - if v is not None: - ret["full_name"] = v - - v = ldap_obj.pop("mail", None) - if v is not None: - ret["email"] = v - - v = ldap_obj.pop("mobile", None) - if v is not None: - ret["phone"] = v - - v = ldap_obj.pop("userAccountControl", None) - if v is not None: - # userAccountControl is an array of binary flags returned as a decimal integer - # byte #1 is ACCOUNTDISABLE which corresponds to "suspended" status - # https://learn.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties - try: - ret["suspended"] = int(v) & 2 == 2 - except ValueError: - pass + v = decoded_record.pop("mobile", None) + if v is not None: + ret["phone"] = v - v = ldap_obj.pop("createTimestamp", None) - if v is not None: - ret["_c"] = _parse_timestamp(v) - else: - v = ldap_obj.pop("createTimeStamp", None) + v = decoded_record.pop("userAccountControl", None) + if v is not None: + # userAccountControl is an array of binary flags returned as a decimal integer + # byte #1 is ACCOUNTDISABLE which corresponds to "suspended" status + # https://learn.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties + try: + ret["suspended"] = int(v) & 2 == 2 + except ValueError: + pass + + v = decoded_record.pop("createTimestamp", None) if v is not None: ret["_c"] = _parse_timestamp(v) + else: + v = decoded_record.pop("createTimeStamp", None) + if v is not None: + ret["_c"] = _parse_timestamp(v) - v = ldap_obj.pop("modifyTimestamp", None) - if v is not None: - ret["_m"] = _parse_timestamp(v) - else: - v = ldap_obj.pop("modifyTimeStamp", None) + v = decoded_record.pop("modifyTimestamp", None) if v is not None: ret["_m"] = _parse_timestamp(v) + else: + v = decoded_record.pop("modifyTimeStamp", None) + if v is not None: + ret["_m"] = _parse_timestamp(v) + + if len(decoded_record) > 0: + ret["data"] = {k: v for k, v in decoded_record.items() if k in self.AttrList} - if len(ldap_obj) > 0: - ret["_ldap"] = ldap_obj + return ret - return ret + + def _format_credentials_id(self, dn: str) -> str: + return self.Prefix + base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii") + + + def _build_search_filter(self, filtr: typing.Optional[str] = None) -> str: + if not filtr: + filterstr = self.Filter + else: + # The query filter is the intersection of the filter from config + # and the filter defined by the search request + # The username must START WITH the given filter string + filter_template = "(&{}({}=%s*))".format(self.Filter, self.Config["attrusername"]) + assertion_values = ["{}".format(filtr.lower())] + filterstr = ldap.filter.filter_format( + filter_template=filter_template, + assertion_values=assertion_values + ) + return filterstr + + +class _LDAPObject(ldap.ldapobject.LDAPObject, ldap.resiter.ResultProcessor): + pass def _parse_timestamp(ts: str) -> datetime.datetime: @@ -449,3 +374,54 @@ def _parse_timestamp(ts: str) -> datetime.datetime: pass return datetime.datetime.strptime(ts, r"%Y%m%d%H%M%S.%fZ") + + +def _prepare_attributes(config: typing.Mapping) -> list: + attr = set(config["attributes"].split(" ")) + attr.add("createTimestamp") + attr.add("modifyTimestamp") + attr.add("cn") + attr.add(config["attrusername"]) + return list(attr) + + +def _enable_tls(ldap_client, config: typing.Mapping): + tls_cafile = config["tls_cafile"] + + # Add certificate authority + if len(tls_cafile) > 0: + ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, tls_cafile) + + # Set cert policy + if config["tls_require_cert"] == "never": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) + elif config["tls_require_cert"] == "demand": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + elif config["tls_require_cert"] == "allow": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) + elif config["tls_require_cert"] == "hard": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_HARD) + else: + L.error("Invalid 'tls_require_cert' value: {!r}. Defaulting to 'demand'.".format( + config["tls_require_cert"] + )) + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + + # Misc TLS options + tls_protocol_min = config["tls_protocol_min"] + if tls_protocol_min != "": + if tls_protocol_min not in _TLS_VERSION: + raise ValueError("'tls_protocol_min' must be one of {} or empty.".format(list(_TLS_VERSION))) + ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, _TLS_VERSION[tls_protocol_min]) + + tls_protocol_max = config["tls_protocol_max"] + if tls_protocol_max != "": + if tls_protocol_max not in _TLS_VERSION: + raise ValueError("'tls_protocol_max' must be one of {} or empty.".format(list(_TLS_VERSION))) + ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MAX, _TLS_VERSION[tls_protocol_max]) + + if config["tls_cipher_suite"] != "": + ldap_client.set_option(ldap.OPT_X_TLS_CIPHER_SUITE, config["tls_cipher_suite"]) + + # NEWCTX needs to be the last option, because it applies all the prepared options to the new context + ldap_client.set_option(ldap.OPT_X_TLS_NEWCTX, 0)