Skip to content

Commit

Permalink
Merge pull request #151 from TeskaLabs/feature/extra-locate-keys
Browse files Browse the repository at this point in the history
Custom login URI and custom login key
  • Loading branch information
byewokko authored Feb 24, 2023
2 parents d98e2a3 + 54deda2 commit c453b3a
Show file tree
Hide file tree
Showing 15 changed files with 118 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
python -m pip install --upgrade pip
pip install bson
pip install pymongo
pip install git+https://github.com/TeskaLabs/asab#egg=asab[storage_encryption]
pip install git+https://github.com/TeskaLabs/asab#egg=asab[encryption]
- name: Test with unittest
run: |
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
- Allow unsetting some client features (#148, PLUM Sprint 230113)
- OAuth 2.0 PKCE challenge (RFC7636) (#152, PLUM Sprint 230127)
- Session tracking ID introduced (#135, PLUM Sprint 230210)
- Clients can register a custom login_uri and login_key (#151, PLUM Sprint 230210)
- Authorize request adds client_id to login URL query (#151, PLUM Sprint 230210)

### Refactoring
- Regex validation of cookie_domain client attribute (#144, PLUM Sprint 230113)
Expand Down
22 changes: 19 additions & 3 deletions seacatauth/authn/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ async def login_prologue(self, request):
# Get arguments specified in login URL query
expiration = None
login_preferences = None
login_key = None
query_string = key.get("qs")
if query_string is not None:
query_dict = urllib.parse.parse_qs(query_string)

# Get requested session expiration
# TODO: This option should be moved to client config or removed completely
expiration = query_dict.get("expiration")
if expiration is not None:
try:
Expand All @@ -72,10 +74,16 @@ async def login_prologue(self, request):
L.warning("Error when parsing expiration: {}".format(e))

# Get preferred login descriptor IDs
# TODO: This option should be moved to client config or removed completely
login_preferences = query_dict.get("ldid")

# Get login key by client ID
client_id = query_dict.get("client_id")
if client_id is not None:
login_key = await self._get_client_login_key(client_id[0])

# Locate credentials
credentials_id = await self.CredentialsService.locate(ident, stop_at_first=True)
credentials_id = await self.CredentialsService.locate(ident, stop_at_first=True, key=login_key)
if credentials_id is None or credentials_id == []:
L.warning("Cannot locate credentials.", struct_data={"ident": ident})
# Empty credentials is used for creating a fake login session
Expand Down Expand Up @@ -128,7 +136,6 @@ async def login_prologue(self, request):
}
return asab.web.rest.json_response(request, response)


async def login(self, request):
lsid = request.match_info["lsid"]

Expand Down Expand Up @@ -271,7 +278,6 @@ async def smslogin(self, request):
body = {"result": "OK" if success is True else "FAILED"}
return aiohttp.web.Response(body=login_session.encrypt(body))


async def webauthn_login(self, request):
# Decode JSON request
lsid = request.match_info["lsid"]
Expand Down Expand Up @@ -303,3 +309,13 @@ async def webauthn_login(self, request):
await self.AuthenticationService.update_login_session(lsid, data=login_data)

return aiohttp.web.Response(body=login_session.encrypt(authentication_options))


async def _get_client_login_key(self, client_id):
client_service = self.AuthenticationService.App.get_service("seacatauth.ClientService")
try:
client = await client_service.get(client_id)
login_key = client.get("login_key")
except KeyError:
login_key = None
return login_key
12 changes: 11 additions & 1 deletion seacatauth/client/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@
"login_uri": { # NON-CANONICAL
"type": "string",
"description": "URL of preferred login page."},
"login_key": { # NON-CANONICAL
"type": "object",
"description": "Additional data used for locating the credentials at login.",
"patternProperties": {
"^[a-zA-Z][a-zA-Z0-9_-]{0,126}[a-zA-Z0-9]$": {"anyOf": [
{"type": "string"},
{"type": "number"},
{"type": "boolean"},
{"type": "null"}]}}},
"template": { # NON-CANONICAL
"type": "string",
"description": "Client template.",
Expand Down Expand Up @@ -339,7 +348,8 @@ async def register(

# Optional client metadata
for k in frozenset([
"client_name", "client_uri", "logout_uri", "cookie_domain", "custom_data", "login_uri", "template"]):
"client_name", "client_uri", "logout_uri", "cookie_domain", "custom_data", "login_uri", "login_key",
"template"]):
v = kwargs.get(k)
if v is not None and len(v) > 0:
upsertor.set(k, v)
Expand Down
2 changes: 1 addition & 1 deletion seacatauth/credentials/providers/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_info(self) -> dict:
}


async def locate(self, ident: str, ident_fields: dict = None) -> str:
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> str:
'''
Locate credentials based on the vague 'ident', which could be the username, password, phone number etc.
Return credentials_id or return None if not found.
Expand Down
2 changes: 1 addition & 1 deletion seacatauth/credentials/providers/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def delete(self, credentials_id) -> Optional[str]:
self.Dictionary.pop(credentials_id[len(prefix):])
return "OK"

async def locate(self, ident: str, ident_fields: dict = None) -> str:
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> str:
# TODO: Implement ident_fields support
# Fast match based on the username
credentials_id = hashlib.sha224(ident.encode('utf-8')).hexdigest()
Expand Down
2 changes: 1 addition & 1 deletion seacatauth/credentials/providers/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _nomalize_credentials(self, username, user):
return obj


async def locate(self, ident: str, ident_fields: dict = None) -> str:
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> str:
query = {"query": {"bool": {
"filter": [
{"match_phrase": {"type": "user"}},
Expand Down
2 changes: 1 addition & 1 deletion seacatauth/credentials/providers/htpasswd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, provider_id, config_section_name):
self.HT = HtpasswdFile(self.Config['path'])


async def locate(self, ident: str, ident_fields: dict = None) -> str:
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> str:
# TODO: Implement ident_fields support
'''
Locate search for the exact match of provided ident and the username in the htpasswd file
Expand Down
2 changes: 1 addition & 1 deletion seacatauth/credentials/providers/ldap.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _locate_worker(self, ident: str):
return None


async def locate(self, ident: str, ident_fields: dict = None) -> str:
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> str:
# TODO: Implement ident_fields support
'''
Locate search for the exact match of provided ident and the username in the htpasswd file
Expand Down
2 changes: 1 addition & 1 deletion seacatauth/credentials/providers/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def delete(self, credentials_id) -> Optional[str]:
return "OK"


async def locate(self, ident: str, ident_fields: dict = None) -> Optional[str]:
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> Optional[str]:
"""
Locate credentials by matching ident string against configured ident fields.
"""
Expand Down
7 changes: 5 additions & 2 deletions seacatauth/credentials/providers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ def __init__(self, app, provider_id, config_section_name):
self.DataFields = None


async def locate(self, ident: str, ident_fields: dict = None) -> Optional[str]:
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> Optional[str]:
kwargs = {"ident": ident}
if key is not None:
kwargs.update(key)
async with aiomysql.connect(**self.ConnectionParams) as connection:
async with connection.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(self.LocateQuery, {"ident": ident})
await cursor.execute(self.LocateQuery, kwargs)
result = await cursor.fetchone()
if result is None:
return None
Expand Down
7 changes: 5 additions & 2 deletions seacatauth/credentials/providers/xmongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,11 @@ def _prepare_query(self, query: str, query_args: dict):
return bson.json_util.loads(bound_query)


async def locate(self, ident: str, ident_fields: dict = None) -> Optional[str]:
query = self._prepare_query(self.LocateQuery, {"ident": ident})
async def locate(self, ident: str, ident_fields: dict = None, key: dict = None) -> Optional[str]:
kwargs = {"ident": ident}
if key is not None:
kwargs.update(key)
query = self._prepare_query(self.LocateQuery, kwargs)
cursor = self.Collection.aggregate(query)
result = None
async for obj in cursor:
Expand Down
4 changes: 2 additions & 2 deletions seacatauth/credentials/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ def register_provider(self, credentials_provider):
self.CredentialProviders[credentials_provider.ProviderID] = credentials_provider


async def locate(self, ident: str, stop_at_first: bool = False):
async def locate(self, ident: str, stop_at_first: bool = False, key: dict = None):
'''
Locate credentials based on the vague 'ident', which could be the username, password, phone number etc.
'''
ident = ident.strip()
credentials_ids = []
pending = [provider.locate(ident, self.IdentFields) for provider in self.CredentialProviders.values()]
pending = [provider.locate(ident, self.IdentFields, key) for provider in self.CredentialProviders.values()]
while len(pending) > 0:
done, pending = await asyncio.wait(pending)
for task in done:
Expand Down
35 changes: 34 additions & 1 deletion seacatauth/generic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
import logging

import urllib.parse
import aiohttp.web
import asab

Expand Down Expand Up @@ -137,6 +137,39 @@ async def nginx_introspection(
return response


def urlparse(url: str):
"""
Parse the URL into a dictionary.
Convenience wrapper around urllib.parse.urlparse().
"""
return urllib.parse.urlparse(url)._asdict()


def urlunparse(
*,
scheme: str = "",
netloc: str = "",
path: str = "",
params: str = "",
query: str = "",
fragment: str = ""
):
"""
Build URL from individual components.
Convenience wrapper around urllib.parse.urlunparse().
Example usage:
```python
parsed = parse_url("http://local.test?option=true")
parsed["path"] = "/some/subpath"
url = unparse_url(**parsed)
```
"""
return urllib.parse.urlunparse((scheme, netloc, path, params, query, fragment))


def generate_ergonomic_token(length: int):
'''
This function generates random string that is "ergonomic".
Expand Down
40 changes: 33 additions & 7 deletions seacatauth/openidconnect/handler/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ... import exceptions
from ..utils import AuthErrorResponseCode
from ..pkce import InvalidCodeChallengeMethodError
from ...generic import urlparse, urlunparse

#

Expand Down Expand Up @@ -479,12 +480,9 @@ async def reply_with_redirect_to_login(
)

login_query_params.append(("redirect_uri", authorize_redirect_uri))
login_query_params.append(("client_id", client_id))

login_url = "{}{}?{}".format(
self.AuthWebuiBaseUrl,
self.LoginPath,
urllib.parse.urlencode(login_query_params)
)
login_url = await self._build_login_uri(client_id, login_query_params)
response = aiohttp.web.HTTPNotFound(
headers={
"Location": login_url,
Expand All @@ -496,7 +494,6 @@ async def reply_with_redirect_to_login(
delete_cookie(self.App, response)
return response


async def reply_with_factor_setup_redirect(
self, session, missing_factors: list,
response_type: str, scope: list, client_id: str, redirect_uri: str,
Expand Down Expand Up @@ -560,7 +557,6 @@ async def reply_with_factor_setup_redirect(

return response


def reply_with_authentication_error(
self, error: str, redirect_uri: str,
error_description: str = None,
Expand Down Expand Up @@ -663,3 +659,33 @@ async def authorize_tenants_by_scope(self, scope, session, client_id):
raise exceptions.AccessDeniedError(subject=session.Credentials.Id)

return tenants


async def _build_login_uri(self, client_id, login_query_params):
"""
Check if the client has a registered login URI. If not, use the default.
Extend the URI with query parameters.
"""
try:
client_dict = await self.OpenIdConnectService.ClientService.get(client_id)
client_login_uri = client_dict.get("login_uri")
except KeyError:
client_login_uri = None
if client_login_uri is not None:
parsed = urlparse(client_login_uri)
query = urllib.parse.parse_qs(parsed["query"])
# WARNING: If the client's login URI includes query parameters with the same names
# as those used by Seacat Auth, they will be overwritten
query.update(login_query_params)
parsed["query"] = urllib.parse.urlencode(query)
login_url = urlunparse(**parsed)
else:
# Seacat Auth login expects the parameters to be at the end of the URL (in the fragment (hash) part)
# TODO: Consider using regular query parameters instead (UI refactoring needed)
# so that Seacat Auth UI does not need a special approach here
login_url = "{}{}?{}".format(
self.AuthWebuiBaseUrl,
self.LoginPath,
urllib.parse.urlencode(login_query_params)
)
return login_url

0 comments on commit c453b3a

Please sign in to comment.