diff --git a/packages/opal-common/opal_common/sources/api_policy_source.py b/packages/opal-common/opal_common/sources/api_policy_source.py index 7adc9ad70..22d2c06da 100644 --- a/packages/opal-common/opal_common/sources/api_policy_source.py +++ b/packages/opal-common/opal_common/sources/api_policy_source.py @@ -2,8 +2,10 @@ from pathlib import Path from typing import Optional, Tuple from urllib.parse import urlparse +from xml.etree import ElementTree import aiohttp +import aiofiles from fastapi import status from fastapi.exceptions import HTTPException from opal_common.git_utils.tar_file_to_local_git_extractor import ( @@ -17,6 +19,7 @@ hash_file, throw_if_bad_status_code, tuple_to_dict, + async_time_cache, ) from opal_server.config import PolicyBundleServerType from tenacity import AsyncRetrying @@ -43,6 +46,9 @@ class ApiPolicySource(BasePolicySource): token (str, optional): auth token to include in connections to bundle server. Defaults to POLICY_BUNDLE_SERVER_TOKEN. token_id (str, optional): auth token ID to include in connections to bundle server. Defaults to POLICY_BUNDLE_SERVER_TOKEN_ID. bundle_server_type (PolicyBundleServerType, optional): the type of bundle server + region (str, optional): the aws region of s3 bucket containing the bundle + aws_role_arn (str, optional): the aws iam role to assume when accessing the s3 bucket. Only required when using temporary sts credentials. + aws_web_id_token_file (str, optional): the file containing a web id token for the target aws iam role. Only required when using temporary sts credentials. 
@async_time_cache(ttl=3000)
async def get_temporary_sts_credentials(self) -> tuple[str, str, str]:
    """Fetch temporary credentials for an IAM role from Amazon STS.

    Reads the web identity token from ``self.aws_web_id_token_file`` and
    calls the regional STS endpoint with the ``AssumeRoleWithWebIdentity``
    action for ``self.aws_role_arn``.

    The result of this function is cached (see the ``async_time_cache``
    decorator above) to avoid being rate limited by STS. The cache ttl of
    3000 seconds is shorter than the 3600 seconds requested as
    ``DurationSeconds`` for the credentials.

    Returns:
        A ``(access_key_id, secret_access_key, session_token)`` tuple.
        When using temporary credentials, AWS requires the session token
        in addition to the id and secret key.

    Raises:
        ValueError: if the region, role arn or token file is not
            configured, or if the STS response does not contain a
            complete set of credentials.
        HTTPException: if STS answers with a status other than 200.
        aiohttp.ClientError: on connection problems.
    """
    # Explicit checks instead of ``assert``: asserts are stripped when
    # python runs with -O, which would let ``None`` leak into the request.
    if not (self.aws_web_id_token_file and self.aws_role_arn and self.region):
        raise ValueError(
            "region, aws_role_arn and aws_web_id_token_file are all required "
            "to request temporary STS credentials"
        )

    async with aiofiles.open(self.aws_web_id_token_file) as token_file:
        token = await token_file.read()

    sts_url = f"sts.{self.region}.amazonaws.com"
    params: dict[str, str] = {
        "Action": "AssumeRoleWithWebIdentity",
        "DurationSeconds": "3600",
        "RoleSessionName": "Opal",
        "RoleArn": self.aws_role_arn,
        "WebIdentityToken": token,
        "Version": "2011-06-15",
    }

    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(
                f"https://{sts_url}",
                params=params,
                headers={"Content-Type": "application/xml"},
            ) as response:
                if response.status == status.HTTP_404_NOT_FOUND:
                    logger.warning(
                        "requested url not found: {sts_url}",
                        sts_url=sts_url,
                    )
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail=f"requested url not found: {sts_url}",
                    )

                # Any other non-200 reply carries no credentials; fail
                # with the real status instead of a confusing parse error.
                if response.status != status.HTTP_200_OK:
                    logger.warning(
                        "unexpected status {status_code} from {sts_url}",
                        status_code=response.status,
                        sts_url=sts_url,
                    )
                    raise HTTPException(
                        status_code=response.status,
                        detail=f"unexpected status from {sts_url}",
                    )

                body = await response.read()

            # the default aws xml namespace
            ns = {"": "https://sts.amazonaws.com/doc/2011-06-15/"}

            root = ElementTree.fromstring(body)
            credentials = root.find(
                "AssumeRoleWithWebIdentityResult/Credentials", ns
            )
            # Compare with None: the truth value of an Element reflects
            # its number of children, not whether it was found.
            if credentials is None:
                raise ValueError("STS response does not contain credentials")

            access_key_id = credentials.findtext("AccessKeyId", namespaces=ns)
            secret_access_key = credentials.findtext(
                "SecretAccessKey", namespaces=ns
            )
            session_token = credentials.findtext("SessionToken", namespaces=ns)

            missing = [
                name
                for name, value in (
                    ("AccessKeyId", access_key_id),
                    ("SecretAccessKey", secret_access_key),
                    ("SessionToken", session_token),
                )
                if not value
            ]
            if missing:
                raise ValueError(
                    f"STS response is missing: {', '.join(missing)}"
                )

        except (aiohttp.ClientError, HTTPException) as e:
            logger.warning("server connection error: {err}", err=repr(e))
            raise
        except Exception as e:
            logger.error("unexpected server connection error: {err}", err=repr(e))
            raise

    logger.info("Successfully generated temporary AWS credentials")
    return access_key_id, secret_access_key, session_token
+223,8 @@ def build_auth_headers(self, token=None, path=None): and token is not None and self.token_id is not None ): + logger.info("Using provided token to log in to AWS_S3") + split_url = urlparse(self.remote_source_url) host = split_url.netloc path = split_url.path + "/" + path @@ -143,7 +232,25 @@ def build_auth_headers(self, token=None, path=None): return build_aws_rest_auth_headers( self.token_id, token, host, path, self.region ) + elif ( + self.server_type == PolicyBundleServerType.AWS_S3 + and self.aws_role_arn is not None + and self.aws_web_id_token_file is not None + and self.region is not None + ): + logger.info("Using IAM Web auth to log in to AWS_S3") + + split_url = urlparse(self.remote_source_url) + host = split_url.netloc + path = split_url.path + "/" + path + + id, key, session_token = await self.get_temporary_sts_credentials() + + return build_aws_rest_auth_headers( + id, key, host, path, self.region, session_token + ) else: + logger.info("Not authenticating on bundle endpoint") return {} async def fetch_policy_bundle_from_api_source( @@ -166,7 +273,7 @@ async def fetch_policy_bundle_from_api_source( """ path = "bundle.tar.gz" - auth_headers = self.build_auth_headers(token=token, path=path) + auth_headers = await self.build_auth_headers(token=token, path=path) etag_headers = ( {"ETag": self.etag, "If-None-Match": self.etag} if self.etag else {} ) @@ -278,4 +385,4 @@ async def check_for_changes(self): prev_head=prev, new_head=latest, ) - await self._on_new_policy(old=prev_commit, new=new_commit) + await self._on_new_policy(old=prev_commit, new=new_commit) \ No newline at end of file diff --git a/packages/opal-common/opal_common/utils.py b/packages/opal-common/opal_common/utils.py index 3897c058f..f0e86fbed 100644 --- a/packages/opal-common/opal_common/utils.py +++ b/packages/opal-common/opal_common/utils.py @@ -8,7 +8,9 @@ import threading from datetime import datetime from hashlib import sha1 -from typing import Coroutine, Dict, List, Tuple +from 
def async_time_cache(ttl: float):
    """Cache the result of an async function for at most ``ttl`` seconds.

    Time is divided into windows via ``round(time.time() / ttl)``. A cached
    entry is only reused while the current window equals the window it was
    created in, so an entry lives for at most ``ttl`` seconds (it can be
    shorter if the entry was created late in its window).

    The decorated function becomes a regular callable that returns an
    ``asyncio`` future wrapping the coroutine; callers simply ``await`` it.
    Calls with the same arguments inside one window share the same future,
    so the underlying coroutine runs only once.

    A future that failed or was cancelled is never reused: the next call
    starts the coroutine again instead of replaying the stored error until
    the window ends.

    Notes:
        * All arguments must be hashable, as they form the cache key.
        * The cache holds strong references to the arguments (including
          ``self`` for methods) until their window has passed and a later
          cache miss purges them.

    Args:
        ttl: maximum lifetime of a cached result, in seconds. Must be > 0.

    Raises:
        ValueError: if ``ttl`` is not positive.
    """
    if ttl <= 0:
        raise ValueError("ttl must be a positive number of seconds")

    def decorator(func: Callable):
        # key -> (window the entry was created in, future with the result)
        cache: dict = {}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            window = round(time.time() / ttl)
            key = (args, tuple(sorted(kwargs.items())))

            entry = cache.get(key)
            if entry is not None:
                cached_window, future = entry
                # check cancelled() first: exception() raises on a
                # cancelled future
                failed = future.done() and (
                    future.cancelled() or future.exception() is not None
                )
                if cached_window == window and not failed:
                    return future

            # Cache miss: drop entries from earlier windows so stale
            # results (and the arguments they reference) do not pile up.
            for stale_key in [
                k for k, (w, _) in cache.items() if w != window
            ]:
                del cache[stale_key]

            future = asyncio.ensure_future(func(*args, **kwargs))
            cache[key] = (window, future)
            return future

        return wrapper

    return decorator
This is set by AWS automatically in EKS, but can be overridden if required.", + ) POLICY_BUNDLE_TMP_PATH = confi.str( "POLICY_BUNDLE_TMP_PATH", "/tmp/bundle.tar.gz", diff --git a/packages/opal-server/opal_server/policy/watcher/factory.py b/packages/opal-server/opal_server/policy/watcher/factory.py index 6d94d6fc4..e08d7a997 100644 --- a/packages/opal-server/opal_server/policy/watcher/factory.py +++ b/packages/opal-server/opal_server/policy/watcher/factory.py @@ -1,5 +1,6 @@ from functools import partial from typing import Any, List, Optional +import os from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.confi.confi import load_conf_if_none @@ -129,6 +130,8 @@ def setup_watcher_task( policy_bundle_path=opal_server_config.POLICY_BUNDLE_TMP_PATH, policy_bundle_git_add_pattern=opal_server_config.POLICY_BUNDLE_GIT_ADD_PATTERN, region=policy_bundle_aws_region, + aws_role_arn=opal_server_config.POLICY_BUNDLE_AWS_ROLE_ARN, + aws_web_id_token_file=opal_server_config.POLICY_BUNDLE_AWS_WEB_IDENTITY_TOKEN_FILE, ) else: raise ValueError("Unknown value for OPAL_POLICY_SOURCE_TYPE")