-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* accept reformatting * accept reformatting * seperate validators * accept reformatting
- Loading branch information
1 parent
f9ec761
commit 87113bd
Showing
8 changed files
with
328 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from abc import ABC, abstractmethod, abstractproperty | ||
from enum import Enum | ||
|
||
import requests | ||
from pydantic import BaseModel, ValidationError, validator | ||
|
||
from openeo_fastapi.client.exceptions import ( | ||
InvalidIssuerConfig, | ||
TokenCantBeValidated, | ||
TokenInvalid, | ||
) | ||
|
||
OIDC_WELLKNOWN_CONFIG_PATH = "/.well-known/openid-configuration" | ||
OIDC_USERINFO = "userinfo_endpoint" | ||
|
||
|
||
class Authenticator(ABC): | ||
# Authenticator validate method needs to know what decisions to make based on user info response from the issuer handler. | ||
# This will be different for different backends, so just put it as ABC for now. We might be able to define this if we want | ||
# to specify an auth config when initialising the backend. | ||
@abstractmethod | ||
def validate(self): | ||
pass | ||
|
||
|
||
class AuthMethod(Enum): | ||
"""Enum defining known auth methods.""" | ||
|
||
BASIC = "basic" | ||
OIDC = "oidc" | ||
|
||
|
||
# Breaks the OpenEO token format down into it's components. This makes it possible to use the token against the issuer. | ||
class AuthToken(BaseModel): | ||
""" """ | ||
|
||
bearer: bool | ||
method: AuthMethod | ||
provider: str | ||
token: str | ||
|
||
@validator("bearer", pre=True) | ||
def passwords_match(cls, v, values, **kwargs): | ||
if v != "Bearer ": | ||
return ValueError("Token not formatted correctly") | ||
return True | ||
|
||
@validator("provider", pre=True) | ||
def check_provider(cls, v, values, **kwargs): | ||
if v == "": | ||
raise ValidationError("Empty provider string.") | ||
return v | ||
|
||
@validator("token", pre=True) | ||
def check_token(cls, v, values, **kwargs): | ||
if v == "": | ||
raise ValidationError("Empty token string.") | ||
return v | ||
|
||
@classmethod | ||
def from_token(cls, token: str): | ||
"""Takes the openeo format token, splits it into the component parts, and returns an Auth token.""" | ||
return cls( | ||
**dict(zip(["bearer", "method", "provider", "token"], token.split("/"))) | ||
) | ||
|
||
|
||
# TODO Remove? Would be good to generate the user info model for each issuer that is provided. | ||
class UserInfo(BaseModel): | ||
""" """ | ||
|
||
info: dict | ||
|
||
|
||
class IssuerHandler(BaseModel): | ||
"""General token handler for querying provided tokens against issuers.""" | ||
|
||
issuer_url: str | ||
organisation: str | ||
# TODO Roles will need to be used by the Authenticator class to be checked against the user info. | ||
roles: list | ||
|
||
@validator("issuer_url", pre=True) | ||
def remove_trailing_slash(cls, v, values, **kwargs): | ||
if v.endswith("/"): | ||
return v.removesuffix("/") | ||
return v | ||
|
||
def _get_issuer_config(self): | ||
""" """ | ||
return requests.get(self.issuer_url + OIDC_WELLKNOWN_CONFIG_PATH) | ||
|
||
def _get_user_info(self, info_endpoint, token): | ||
""" """ | ||
return requests.get( | ||
info_endpoint, | ||
headers={ | ||
"Content-Type": "application/json", | ||
"Authorization": f"Bearer {token}", | ||
}, | ||
) | ||
|
||
def _validate_oidc_token(self, token: str) -> UserInfo: | ||
""" """ | ||
|
||
issuer_oidc_config = self._get_issuer_config() | ||
|
||
if issuer_oidc_config.status_code != 200: | ||
raise InvalidIssuerConfig() | ||
|
||
userinfo_url = issuer_oidc_config.json()[OIDC_USERINFO] | ||
resp = self._get_user_info(userinfo_url, token) | ||
|
||
if resp.status_code != 200: | ||
raise TokenInvalid() | ||
|
||
return UserInfo(info=resp.json()) | ||
|
||
def validate_token(self, token: str) -> UserInfo: | ||
"""Try to validate the token against the give OIDC provider.""" | ||
# TODO Handle validation exceptions | ||
parsed_token = AuthToken.from_token(token) | ||
|
||
if parsed_token.method.value == AuthMethod.OIDC.value: | ||
return self._validate_oidc_token(parsed_token.token) | ||
raise TokenCantBeValidated() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
class TokenInvalid(Exception): | ||
""" """ | ||
|
||
pass | ||
|
||
|
||
class TokenCantBeValidated(Exception): | ||
""" """ | ||
|
||
pass | ||
|
||
|
||
class InvalidIssuerConfig(Exception): | ||
""" """ | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[virtualenvs] | ||
create = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from pydantic import ValidationError | ||
|
||
from openeo_fastapi.client import auth, exceptions | ||
|
||
BASIC_TOKEN_EXAMPLE = "Bearer /basic/openeo/rubbish.not.a.token" | ||
OIDC_TOKEN_EXAMPLE = "Bearer /oidc/issuer/rubbish.not.a.token" | ||
|
||
INVALID_TOKEN_EXAMPLE_1 = "bearer /basic/openeo/rubbish.not.a.token" | ||
INVALID_TOKEN_EXAMPLE_2 = "Bearer /basicopeneorubbish.not.a.token" | ||
INVALID_TOKEN_EXAMPLE_3 = "Bearer //openeo/rubbish.not.a.token" | ||
INVALID_TOKEN_EXAMPLE_4 = "Bearer /basic//rubbish.not.a.token" | ||
INVALID_TOKEN_EXAMPLE_5 = "Bearer /basic/openeo/" | ||
|
||
|
||
def test_auth_method(): | ||
BASIC_VALUE = "basic" | ||
OIDC_VALUE = "oidc" | ||
|
||
basic = auth.AuthMethod(BASIC_VALUE) | ||
oidc = auth.AuthMethod(OIDC_VALUE) | ||
|
||
assert basic.value == BASIC_VALUE | ||
assert oidc.value == OIDC_VALUE | ||
|
||
with pytest.raises(ValueError): | ||
auth.AuthMethod("wrong") | ||
|
||
|
||
def test_auth_token(): | ||
def token_checks(token: auth.AuthToken, method: str, provider: str): | ||
assert token.bearer | ||
assert token.method.value == method | ||
assert token.provider == provider | ||
|
||
basic_token = auth.AuthToken.from_token(BASIC_TOKEN_EXAMPLE) | ||
token_checks(basic_token, "basic", "openeo") | ||
|
||
oidc_token = auth.AuthToken.from_token(OIDC_TOKEN_EXAMPLE) | ||
token_checks(oidc_token, "oidc", "issuer") | ||
|
||
# Check cases of invalid format raise a validation error. | ||
with pytest.raises(ValidationError): | ||
auth.AuthToken.from_token(INVALID_TOKEN_EXAMPLE_1) | ||
|
||
with pytest.raises(ValidationError): | ||
auth.AuthToken.from_token(INVALID_TOKEN_EXAMPLE_2) | ||
|
||
with pytest.raises(ValidationError): | ||
auth.AuthToken.from_token(INVALID_TOKEN_EXAMPLE_3) | ||
|
||
with pytest.raises(ValidationError): | ||
auth.AuthToken.from_token(INVALID_TOKEN_EXAMPLE_4) | ||
|
||
with pytest.raises(ValidationError): | ||
auth.AuthToken.from_token(INVALID_TOKEN_EXAMPLE_5) | ||
|
||
|
||
def test_issuer_handler_init(): | ||
test_issuer = auth.IssuerHandler( | ||
issuer_url="http://issuer.mycloud/", | ||
organisation="mycloud", | ||
roles=["admin", "user"], | ||
) | ||
|
||
# Check trailing slash removal | ||
assert not test_issuer.issuer_url.endswith("/") | ||
assert test_issuer.organisation == "mycloud" | ||
|
||
|
||
def test_issuer_handler__validate_oidc_token( | ||
mocked_oidc_config, mocked_oidc_userinfo, mocked_issuer | ||
): | ||
info = mocked_issuer._validate_oidc_token(token=OIDC_TOKEN_EXAMPLE) | ||
assert isinstance(info, auth.UserInfo) | ||
|
||
|
||
def test_issuer_handler__validate_oidc_token_bad_config( | ||
mocked_bad_oidc_config, mocked_oidc_userinfo, mocked_issuer | ||
): | ||
with pytest.raises(exceptions.InvalidIssuerConfig): | ||
mocked_issuer._validate_oidc_token(token=OIDC_TOKEN_EXAMPLE) | ||
|
||
|
||
def test_issuer_handler__validate_oidc_token_bad_userinfo( | ||
mocked_oidc_config, mocked_bad_oidc_userinfo, mocked_issuer | ||
): | ||
with pytest.raises(exceptions.TokenInvalid): | ||
mocked_issuer._validate_oidc_token(token=OIDC_TOKEN_EXAMPLE) | ||
|
||
|
||
def test_issuer_handler_validate_oidc_token( | ||
mocked_oidc_config, mocked_oidc_userinfo, mocked_issuer | ||
): | ||
info = mocked_issuer.validate_token(token=OIDC_TOKEN_EXAMPLE) | ||
assert isinstance(info, auth.UserInfo) | ||
|
||
|
||
def test_issuer_handler_validate_basic_token( | ||
mocked_oidc_config, mocked_oidc_userinfo, mocked_issuer | ||
): | ||
with pytest.raises(exceptions.TokenCantBeValidated): | ||
mocked_issuer.validate_token(token=BASIC_TOKEN_EXAMPLE) | ||
|
||
|
||
def test_issuer_handler_validate_broken_token( | ||
mocked_oidc_config, mocked_oidc_userinfo, mocked_issuer | ||
): | ||
with pytest.raises(ValidationError): | ||
mocked_issuer.validate_token(token=INVALID_TOKEN_EXAMPLE_1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters