From 1850e21d75cc26b2b9cec3696ed1a563f53d238f Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 17 Oct 2024 15:29:14 +0000 Subject: [PATCH 01/14] improve access edge performance --- plugins/aws/fix_plugin_aws/access_edges.py | 60 ++++++++++++++++++---- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index 5357fa4621..d47e5f1629 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -2,9 +2,9 @@ from attr import frozen, define import networkx from fix_plugin_aws.resource.base import AwsAccount, AwsResource, GraphBuilder - +from policy_sentry.querying.actions import get_actions_for_service from typing import Dict, List, Literal, Set, Optional, Tuple, Union, Pattern - +import fnmatch from networkx.algorithms.dag import is_directed_acyclic_graph from fixlib.baseresources import ( @@ -101,18 +101,60 @@ def find_all_allowed_actions(all_involved_policies: List[PolicyDocument], resour policy_actions.update(find_allowed_action(p, service_prefix)) return policy_actions.intersection(resource_actions) +@lru_cache(maxsize=1024) +def expand(action: str, service_prefix: str) -> list[str]: + if action == "*": + return get_actions_for_service(service_prefix=service_prefix) + elif "*" in action: + prefix = action.split(":", maxsplit=1)[0] + if prefix != service_prefix: + return [] + service_actions = get_actions_for_service(service_prefix=prefix) + expanded = [ + expanded_action + for expanded_action in service_actions + if fnmatch.fnmatchcase(expanded_action.lower(), action.lower()) + ] + + if not expanded: + return [action] + + return expanded + return [action] -def get_expanded_action(statement: StatementDetail, service_prefix: str) -> Set[str]: - actions = set() - expanded: List[str] = statement.expanded_actions or [] - for action in expanded: - if action.startswith(f"{service_prefix}:"): - actions.add(action) - 
return actions +def determine_actions_to_expand(action_list: list[str], service_prefix: str) -> list[str]: + new_action_list = [] + for action in action_list: + if "*" in action: + expanded_action = expand(action, service_prefix) + new_action_list.extend(expanded_action) + elif action.startswith(service_prefix): + new_action_list.append(action) + new_action_list.sort() + return new_action_list + + +@lru_cache(maxsize=4096) +def statement_expanded_actions(statement: StatementDetail, service_prefix: str) -> List[str]: + if statement.actions: + expanded: list[str] = determine_actions_to_expand(statement.actions, service_prefix) + return expanded + elif statement.not_action: + not_actions = statement.not_action_effective_actions or [] + return [na for na in not_actions if na.startswith(service_prefix)] + else: + log.warning("Statement has neither Actions nor NotActions") + return [] @lru_cache(maxsize=1024) +def get_expanded_action(statement: StatementDetail, service_prefix: str) -> List[str]: + expanded: List[str] = statement_expanded_actions(statement, service_prefix) + return expanded + + +@lru_cache(maxsize=10000) def make_resoruce_regex(aws_resorce_wildcard: str) -> Pattern[str]: # step 1: translate aws wildcard to python regex python_regex = aws_resorce_wildcard.replace("*", ".*").replace("?", ".") From c5cdb28e57a6352b7aa7b15f3806992f705ff65b Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 17 Oct 2024 15:34:50 +0000 Subject: [PATCH 02/14] linter fix --- plugins/aws/fix_plugin_aws/access_edges.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index d47e5f1629..37cc49ca84 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -101,6 +101,7 @@ def find_all_allowed_actions(all_involved_policies: List[PolicyDocument], resour policy_actions.update(find_allowed_action(p, service_prefix)) return 
policy_actions.intersection(resource_actions) + @lru_cache(maxsize=1024) def expand(action: str, service_prefix: str) -> list[str]: if action == "*": From 3ab808154df4364387745c1cc72bde691e1fe77c Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 24 Oct 2024 15:36:16 +0000 Subject: [PATCH 03/14] optimize action wildcard matching --- plugins/aws/fix_plugin_aws/access_edges.py | 51 +++++++++++++++++++--- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index 37cc49ca84..d3949d8f98 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -155,7 +155,7 @@ def get_expanded_action(statement: StatementDetail, service_prefix: str) -> List return expanded -@lru_cache(maxsize=10000) +@lru_cache(maxsize=1024) def make_resoruce_regex(aws_resorce_wildcard: str) -> Pattern[str]: # step 1: translate aws wildcard to python regex python_regex = aws_resorce_wildcard.replace("*", ".*").replace("?", ".") @@ -163,7 +163,7 @@ def make_resoruce_regex(aws_resorce_wildcard: str) -> Pattern[str]: return re.compile(f"^{python_regex}$", re.IGNORECASE) -def expand_wildcards_and_match(*, identifier: str, wildcard_string: str) -> bool: +def _expand_wildcards_and_match(*, identifier: str, wildcard_string: str) -> bool: """ helper function to expand wildcards and match the identifier @@ -175,6 +175,45 @@ def expand_wildcards_and_match(*, identifier: str, wildcard_string: str) -> bool return pattern.match(identifier) is not None +def expand_action_wildcards_and_match(action: str, wildcard_pattern: str) -> bool: + + if action == wildcard_pattern: + return True + + if wildcard_pattern == "*": + return True + + action = action.lower() + wildcard_pattern = wildcard_pattern.lower() + + splitted_action = action.split(":") + if len(splitted_action) < 2: + log.warning(f"Resource action {action} is not in the expected format") + return False + + 
action_service = splitted_action[0] + + splitted_pattern = wildcard_pattern.split(':') + if len(splitted_pattern) < 2: + log.warning(f"Wildcard action {wildcard_pattern} is not in the expected format") + return False + pattern_service = splitted_pattern[0] + pattern_name = splitted_pattern[1] + + if action_service != pattern_service: + return False + + if pattern_name == '*': + return True + + # all the other cases + return _expand_wildcards_and_match(identifier=action, wildcard_string=wildcard_pattern) + + +def expand_arn_wildcards_and_match(identifier: str, wildcard_string: str) -> bool: + return _expand_wildcards_and_match(identifier=identifier, wildcard_string=wildcard_string) + + def check_statement_match( statement: StatementDetail, effect: Optional[Literal["Allow", "Deny"]], @@ -241,14 +280,14 @@ def check_statement_match( action_match = False else: for a in statement.actions: - if expand_wildcards_and_match(identifier=action, wildcard_string=a): + if expand_action_wildcards_and_match(action=action, wildcard_pattern=a): action_match = True break else: # not_action action_match = True for na in statement.not_action: - if expand_wildcards_and_match(identifier=action, wildcard_string=na): + if expand_action_wildcards_and_match(action=action, wildcard_pattern=na): action_match = False break if not action_match: @@ -260,14 +299,14 @@ def check_statement_match( resource_matches = False if len(statement.resources) > 0: for resource_constraint in statement.resources: - if expand_wildcards_and_match(identifier=resource.arn, wildcard_string=resource_constraint): + if expand_arn_wildcards_and_match(identifier=resource.arn, wildcard_string=resource_constraint): matched_resource_constraints.append(resource_constraint) resource_matches = True break elif len(statement.not_resource) > 0: resource_matches = True for not_resource_constraint in statement.not_resource: - if expand_wildcards_and_match(identifier=resource.arn, wildcard_string=not_resource_constraint): + if 
expand_arn_wildcards_and_match(identifier=resource.arn, wildcard_string=not_resource_constraint): resource_matches = False break matched_resource_constraints.append("not " + not_resource_constraint) From 8f447dd09b224bd60bfbbd8f10ab1cbc0c0a448b Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 24 Oct 2024 16:37:28 +0000 Subject: [PATCH 04/14] more efficient action wildcard matching --- plugins/aws/fix_plugin_aws/access_edges.py | 73 +++++++++++++++------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index d3949d8f98..99dce59c6c 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -175,39 +175,70 @@ def _expand_wildcards_and_match(*, identifier: str, wildcard_string: str) -> boo return pattern.match(identifier) is not None -def expand_action_wildcards_and_match(action: str, wildcard_pattern: str) -> bool: +@lru_cache(maxsize=1024) +def _compile_action_pattern(wildcard_pattern: str) -> tuple[str, str, re.Pattern[str] | None]: + """ + Compile and cache the action pattern components. 
+ Returns (service, action_pattern, compiled_regex) + """ + wildcard_pattern = wildcard_pattern.lower() + parts = wildcard_pattern.split(':', 1) + if len(parts) != 2: + raise ValueError(f"Invalid action pattern format: {wildcard_pattern}") + + service, action_pattern = parts + + # Convert AWS wildcard pattern to regex pattern + if '*' in action_pattern: + pattern = '^' + re.escape(action_pattern).replace('\\*', '.*') + '$' + compiled = re.compile(pattern) + else: + compiled = None + + return service, action_pattern, compiled + +def expand_action_wildcards_and_match(action: str, wildcard_pattern: str) -> bool: + # Short circuit for exact matches if action == wildcard_pattern: return True + # Short circuit for global wildcard if wildcard_pattern == "*": return True - - action = action.lower() - wildcard_pattern = wildcard_pattern.lower() - splitted_action = action.split(":") - if len(splitted_action) < 2: - log.warning(f"Resource action {action} is not in the expected format") + # Normalize action + action = action.lower() + + # Split action + try: + action_service, action_name = action.split(':', 1) + except ValueError: return False - - action_service = splitted_action[0] - - splitted_pattern = wildcard_pattern.split(':') - if len(splitted_pattern) < 2: - log.warning(f"Wildcard action {wildcard_pattern} is not in the expected format") + + # Get cached pattern components + try: + pattern_service, pattern_action, compiled_regex = _compile_action_pattern(wildcard_pattern) + except ValueError: return False - pattern_service = splitted_pattern[0] - pattern_name = splitted_pattern[1] - + + # Check service match if action_service != pattern_service: return False - - if pattern_name == '*': + + # Handle full service wildcard + if pattern_action == '*': return True - - # all the other cases - return _expand_wildcards_and_match(identifier=action, wildcard_string=wildcard_pattern) + + # Handle exact action match + if pattern_action == action_name: + return True + + # Handle 
regex pattern match + if compiled_regex: + return bool(compiled_regex.match(action_name)) + + return False def expand_arn_wildcards_and_match(identifier: str, wildcard_string: str) -> bool: From ef20aae37f99d415cff278beabec145ed2a2b0f5 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Fri, 25 Oct 2024 10:30:09 +0000 Subject: [PATCH 05/14] linter fix --- plugins/aws/fix_plugin_aws/access_edges.py | 30 +++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index 99dce59c6c..84c8a95bdb 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -182,19 +182,19 @@ def _compile_action_pattern(wildcard_pattern: str) -> tuple[str, str, re.Pattern Returns (service, action_pattern, compiled_regex) """ wildcard_pattern = wildcard_pattern.lower() - parts = wildcard_pattern.split(':', 1) + parts = wildcard_pattern.split(":", 1) if len(parts) != 2: raise ValueError(f"Invalid action pattern format: {wildcard_pattern}") - + service, action_pattern = parts - + # Convert AWS wildcard pattern to regex pattern - if '*' in action_pattern: - pattern = '^' + re.escape(action_pattern).replace('\\*', '.*') + '$' + if "*" in action_pattern: + pattern = "^" + re.escape(action_pattern).replace("\\*", ".*") + "$" compiled = re.compile(pattern) else: compiled = None - + return service, action_pattern, compiled @@ -209,35 +209,35 @@ def expand_action_wildcards_and_match(action: str, wildcard_pattern: str) -> boo # Normalize action action = action.lower() - + # Split action try: - action_service, action_name = action.split(':', 1) + action_service, action_name = action.split(":", 1) except ValueError: return False - + # Get cached pattern components try: pattern_service, pattern_action, compiled_regex = _compile_action_pattern(wildcard_pattern) except ValueError: return False - + # Check service match if action_service != 
pattern_service: return False - + # Handle full service wildcard - if pattern_action == '*': + if pattern_action == "*": return True - + # Handle exact action match if pattern_action == action_name: return True - + # Handle regex pattern match if compiled_regex: return bool(compiled_regex.match(action_name)) - + return False From 033c52846e82badd6ab5e28a4b05dee633580eb6 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Fri, 25 Oct 2024 16:17:58 +0000 Subject: [PATCH 06/14] moar perf --- plugins/aws/fix_plugin_aws/access_edges.py | 197 +++++++++++++-------- plugins/aws/test/acccess_edges_test.py | 103 ++++++----- 2 files changed, 179 insertions(+), 121 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index 84c8a95bdb..f00d3fdf92 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -1,3 +1,4 @@ +from enum import Enum from functools import lru_cache from attr import frozen, define import networkx @@ -34,19 +35,70 @@ ALL_ACTIONS = get_all_actions() +class WildcardKind(Enum): + fixed = 1 + pattern = 2 + any = 3 + + +@frozen(slots=True) +class ActionWildcardPattern: + pattern: str + service: str + kind: WildcardKind + + +class FixStatementDetail(StatementDetail): + def __init__(self, statement: Json): + super().__init__(statement) + + def pattern_from_action(action: str) -> ActionWildcardPattern: + if action == "*": + return ActionWildcardPattern(pattern=action, service="*", kind=WildcardKind.any) + + action = action.lower() + service, action_name = action.split(":", 1) + if action_name == "*": + kind = WildcardKind.any + elif "*" in action_name: + kind = WildcardKind.pattern + else: + kind = WildcardKind.fixed + + return ActionWildcardPattern(pattern=action, service=service, kind=kind) + + self.actions_patterns = [pattern_from_action(action) for action in self.actions] + self.not_action_patterns = [pattern_from_action(action) for action in 
self.not_action] + + +class FixPolicyDocument(PolicyDocument): + def __init__(self, policy_document: Json): + super().__init__(policy_document) + + self.fix_statements = [FixStatementDetail(statement.json) for statement in self.statements] + + +@frozen(slots=True) +class ActionToCheck: + raw: str + raw_lower: str + service: str + action_name: str + + @define(slots=True) class IamRequestContext: principal: AwsResource - identity_policies: List[Tuple[PolicySource, PolicyDocument]] - permission_boundaries: List[PolicyDocument] # todo: use them too + identity_policies: List[Tuple[PolicySource, FixPolicyDocument]] + permission_boundaries: List[FixPolicyDocument] # todo: use them too # all service control policies applicable to the principal, # starting from the root, then all org units, then the account - service_control_policy_levels: List[List[PolicyDocument]] + service_control_policy_levels: List[List[FixPolicyDocument]] # technically we should also add a list of session policies here, but they don't exist in the collector context def all_policies( - self, resource_based_policies: Optional[List[Tuple[PolicySource, PolicyDocument]]] = None - ) -> List[PolicyDocument]: + self, resource_based_policies: Optional[List[Tuple[PolicySource, FixPolicyDocument]]] = None + ) -> List[FixPolicyDocument]: return ( [p[1] for p in self.identity_policies] + self.permission_boundaries @@ -81,7 +133,7 @@ def find_non_service_actions(resource_arn: str) -> Set[IamAction]: return set() -def find_all_allowed_actions(all_involved_policies: List[PolicyDocument], resource_arn: str) -> Set[IamAction]: +def find_all_allowed_actions(all_involved_policies: List[FixPolicyDocument], resource_arn: str) -> Set[IamAction]: resource_actions = set() try: resource_actions = set(get_actions_matching_arn(resource_arn)) @@ -176,7 +228,7 @@ def _expand_wildcards_and_match(*, identifier: str, wildcard_string: str) -> boo @lru_cache(maxsize=1024) -def _compile_action_pattern(wildcard_pattern: str) -> 
tuple[str, str, re.Pattern[str] | None]: +def _compile_action_pattern(wildcard_pattern: str) -> tuple[str, re.Pattern[str] | None]: """ Compile and cache the action pattern components. Returns (service, action_pattern, compiled_regex) @@ -186,7 +238,7 @@ def _compile_action_pattern(wildcard_pattern: str) -> tuple[str, str, re.Pattern if len(parts) != 2: raise ValueError(f"Invalid action pattern format: {wildcard_pattern}") - service, action_pattern = parts + _, action_pattern = parts # Convert AWS wildcard pattern to regex pattern if "*" in action_pattern: @@ -195,48 +247,33 @@ def _compile_action_pattern(wildcard_pattern: str) -> tuple[str, str, re.Pattern else: compiled = None - return service, action_pattern, compiled + return action_pattern, compiled -def expand_action_wildcards_and_match(action: str, wildcard_pattern: str) -> bool: - # Short circuit for exact matches - if action == wildcard_pattern: - return True +def expand_action_wildcards_and_match(action: ActionToCheck, wildcard_pattern: ActionWildcardPattern) -> bool: - # Short circuit for global wildcard - if wildcard_pattern == "*": + if wildcard_pattern.kind == WildcardKind.any: return True - # Normalize action - action = action.lower() + if wildcard_pattern.kind == WildcardKind.fixed: + return action.raw_lower == wildcard_pattern.pattern - # Split action - try: - action_service, action_name = action.split(":", 1) - except ValueError: + if action.service != wildcard_pattern.service: return False # Get cached pattern components try: - pattern_service, pattern_action, compiled_regex = _compile_action_pattern(wildcard_pattern) + pattern_action, compiled_regex = _compile_action_pattern(wildcard_pattern.pattern) except ValueError: return False - # Check service match - if action_service != pattern_service: - return False - - # Handle full service wildcard - if pattern_action == "*": - return True - # Handle exact action match - if pattern_action == action_name: + if pattern_action == action.action_name: 
return True # Handle regex pattern match if compiled_regex: - return bool(compiled_regex.match(action_name)) + return bool(compiled_regex.match(action.action_name)) return False @@ -246,9 +283,9 @@ def expand_arn_wildcards_and_match(identifier: str, wildcard_string: str) -> boo def check_statement_match( - statement: StatementDetail, + statement: FixStatementDetail, effect: Optional[Literal["Allow", "Deny"]], - action: str, + action: ActionToCheck, resource: AwsResource, principal: Optional[AwsResource], source_arn: Optional[str] = None, @@ -304,20 +341,20 @@ def check_statement_match( if statement.actions: # shortcuts for known AWS managed policies if source_arn == "arn:aws:iam::aws:policy/ReadOnlyAccess": - action_level = get_action_level(action) + action_level = get_action_level(action.raw) if action_level in [PermissionLevel.read or PermissionLevel.list]: action_match = True else: action_match = False else: - for a in statement.actions: + for a in statement.actions_patterns: if expand_action_wildcards_and_match(action=action, wildcard_pattern=a): action_match = True break else: # not_action action_match = True - for na in statement.not_action: + for na in statement.not_action_patterns: if expand_action_wildcards_and_match(action=action, wildcard_pattern=na): action_match = False break @@ -374,22 +411,22 @@ def check_principal_match(principal: AwsResource, aws_principal_list: List[str]) def collect_matching_statements( *, - policy: PolicyDocument, + policy: FixPolicyDocument, effect: Optional[Literal["Allow", "Deny"]], - action: str, + action: ActionToCheck, resource: AwsResource, principal: Optional[AwsResource], source_arn: Optional[str] = None, -) -> List[Tuple[StatementDetail, List[ResourceConstraint]]]: +) -> List[Tuple[FixStatementDetail, List[ResourceConstraint]]]: """ resoruce based policies contain principal field and need to be handled differently """ - results: List[Tuple[StatementDetail, List[ResourceConstraint]]] = [] + results: 
List[Tuple[FixStatementDetail, List[ResourceConstraint]]] = [] if resource.arn is None: raise ValueError("Resource ARN is missing, go and fix the filtering logic") - for statement in policy.statements: + for statement in policy.fix_statements: matches, maybe_resource_constraint = check_statement_match( statement, effect=effect, action=action, resource=resource, principal=principal, source_arn=source_arn @@ -403,8 +440,8 @@ def collect_matching_statements( def check_explicit_deny( request_context: IamRequestContext, resource: AwsResource, - action: str, - resource_based_policies: List[Tuple[PolicySource, PolicyDocument]], + action: ActionToCheck, + resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], ) -> Union[Literal["Denied", "NextStep"], List[Json]]: denied_when_any_is_true: List[Json] = [] @@ -450,7 +487,7 @@ def check_explicit_deny( return "NextStep" -def scp_allowed(request_context: IamRequestContext, action: str, resource: AwsResource) -> bool: +def scp_allowed(request_context: IamRequestContext, action: ActionToCheck, resource: AwsResource) -> bool: # traverse the SCPs: root -> OU -> account levels for scp_level_policies in request_context.service_control_policy_levels: @@ -492,9 +529,9 @@ class Deny: # as a shortcut we return the first allow statement we find, or a first seen condition. 
def check_resource_based_policies( principal: AwsResource, - action: str, + action: ActionToCheck, resource: AwsResource, - resource_based_policies: List[Tuple[PolicySource, PolicyDocument]], + resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], ) -> ResourceBasedPolicyResult: assert resource.arn @@ -552,7 +589,7 @@ def check_resource_based_policies( def check_identity_based_policies( - request_context: IamRequestContext, resource: AwsResource, action: str + request_context: IamRequestContext, resource: AwsResource, action: ActionToCheck ) -> List[PermissionScope]: scopes: List[PermissionScope] = [] @@ -571,7 +608,7 @@ def check_identity_based_policies( def check_permission_boundaries( - request_context: IamRequestContext, resource: AwsResource, action: str + request_context: IamRequestContext, resource: AwsResource, action: ActionToCheck ) -> Union[Literal["Denied", "NextStep"], List[Json]]: conditions: List[Json] = [] @@ -642,8 +679,8 @@ def get_action_level(action: str) -> PermissionLevel: def check_policies( request_context: IamRequestContext, resource: AwsResource, - action: str, - resource_based_policies: List[Tuple[PolicySource, PolicyDocument]], + action: ActionToCheck, + resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], ) -> Optional[AccessPermission]: # when any of the conditions evaluate to true, the action is explicitly denied @@ -686,7 +723,9 @@ def check_policies( for scope in scopes: final_resource_scopes.add(scope.with_deny_conditions(deny_conditions)) - return AccessPermission(action=action, level=get_action_level(action), scopes=tuple(final_resource_scopes)) + return AccessPermission( + action=action.raw, level=get_action_level(action.raw), scopes=tuple(final_resource_scopes) + ) if isinstance(resource_result, Continue): scopes = resource_result.scopes allowed_scopes.extend(scopes) @@ -720,7 +759,7 @@ def check_policies( # we don't collect session principals and session policies, so this step is skipped # 7. 
if we reached here, the action is allowed - level = get_action_level(action) + level = get_action_level(action.raw) final_scopes: Set[PermissionScope] = set() for scope in allowed_scopes: @@ -740,7 +779,7 @@ def check_policies( # return the result return AccessPermission( - action=action, + action=action.raw, level=level, scopes=tuple(final_scopes), ) @@ -749,7 +788,7 @@ def check_policies( def compute_permissions( resource: AwsResource, iam_context: IamRequestContext, - resource_based_policies: List[Tuple[PolicySource, PolicyDocument]], + resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], ) -> List[AccessPermission]: assert resource.arn @@ -760,7 +799,16 @@ def compute_permissions( # step 2: for every action, check if it is allowed for action in relevant_actions: - if p := check_policies(iam_context, resource, action, resource_based_policies): + try: + service, action_name = action.split(":", 1) + except ValueError: + log.error(f"Invalid action: {action}") + continue + + action_to_check = ActionToCheck( + service=service.lower(), action_name=action_name.lower(), raw_lower=action.lower(), raw=action + ) + if p := check_policies(iam_context, resource, action_to_check, resource_based_policies): all_permissions.append(p) return all_permissions @@ -776,11 +824,11 @@ def __init__(self, builder: GraphBuilder): def _init_principals(self) -> None: account_id = self.builder.account.id - service_control_policy_levels: List[List[PolicyDocument]] = [] + service_control_policy_levels: List[List[FixPolicyDocument]] = [] account = next(self.builder.nodes(clazz=AwsAccount, filter=lambda a: a.id == account_id), None) if account and account._service_control_policies: service_control_policy_levels = [ - [PolicyDocument(json) for json in level] for level in account._service_control_policies + [FixPolicyDocument(json) for json in level] for level in account._service_control_policies ] for node in self.builder.nodes(clazz=AwsResource): @@ -788,11 +836,12 @@ def 
_init_principals(self) -> None: identity_based_policies = self._get_user_based_policies(node) - permission_boundaries: List[PolicyDocument] = [] + permission_boundaries: List[FixPolicyDocument] = [] if (pb := node.user_permissions_boundary) and (pb_arn := pb.permissions_boundary_arn): for pb_policy in self.builder.nodes(clazz=AwsIamPolicy, filter=lambda p: p.arn == pb_arn): if pdj := pb_policy.policy_document_json(): - permission_boundaries.append(PolicyDocument(pdj)) + pd = FixPolicyDocument(pdj) + permission_boundaries.append(pd) request_context = IamRequestContext( principal=node, @@ -822,7 +871,7 @@ def _init_principals(self) -> None: if (pb := node.role_permissions_boundary) and (pb_arn := pb.permissions_boundary_arn): for pb_policy in self.builder.nodes(clazz=AwsIamPolicy, filter=lambda p: p.arn == pb_arn): if pdj := pb_policy.policy_document_json(): - permission_boundaries.append(PolicyDocument(pdj)) + permission_boundaries.append(FixPolicyDocument(pdj)) request_context = IamRequestContext( principal=node, @@ -833,11 +882,11 @@ def _init_principals(self) -> None: self.principals.append(request_context) - def _get_user_based_policies(self, principal: AwsIamUser) -> List[Tuple[PolicySource, PolicyDocument]]: + def _get_user_based_policies(self, principal: AwsIamUser) -> List[Tuple[PolicySource, FixPolicyDocument]]: inline_policies = [ ( PolicySource(kind=PolicySourceKind.principal, uri=principal.arn or ""), - PolicyDocument(policy.policy_document), + FixPolicyDocument(policy.policy_document), ) for policy in principal.user_policies if policy.policy_document @@ -850,7 +899,7 @@ def _get_user_based_policies(self, principal: AwsIamUser) -> List[Tuple[PolicySo attached_policies.append( ( PolicySource(kind=PolicySourceKind.principal, uri=to_node.arn or ""), - PolicyDocument(doc), + FixPolicyDocument(doc), ) ) @@ -862,7 +911,7 @@ def _get_user_based_policies(self, principal: AwsIamUser) -> List[Tuple[PolicySo group_policies.append( ( 
PolicySource(kind=PolicySourceKind.group, uri=group.arn or ""), - PolicyDocument(policy.policy_document), + FixPolicyDocument(policy.policy_document), ) ) # attached group policies @@ -872,18 +921,18 @@ def _get_user_based_policies(self, principal: AwsIamUser) -> List[Tuple[PolicySo group_policies.append( ( PolicySource(kind=PolicySourceKind.group, uri=group_successor.arn or ""), - PolicyDocument(doc), + FixPolicyDocument(doc), ) ) return inline_policies + attached_policies + group_policies - def _get_group_based_policies(self, principal: AwsIamGroup) -> List[Tuple[PolicySource, PolicyDocument]]: + def _get_group_based_policies(self, principal: AwsIamGroup) -> List[Tuple[PolicySource, FixPolicyDocument]]: # not really a principal, but could be useful to have access edges for groups inline_policies = [ ( PolicySource(kind=PolicySourceKind.group, uri=principal.arn or ""), - PolicyDocument(policy.policy_document), + FixPolicyDocument(policy.policy_document), ) for policy in principal.group_policies if policy.policy_document @@ -896,19 +945,19 @@ def _get_group_based_policies(self, principal: AwsIamGroup) -> List[Tuple[Policy attached_policies.append( ( PolicySource(kind=PolicySourceKind.group, uri=to_node.arn or ""), - PolicyDocument(doc), + FixPolicyDocument(doc), ) ) return inline_policies + attached_policies - def _get_role_based_policies(self, principal: AwsIamRole) -> List[Tuple[PolicySource, PolicyDocument]]: + def _get_role_based_policies(self, principal: AwsIamRole) -> List[Tuple[PolicySource, FixPolicyDocument]]: inline_policies = [] for doc in [p.policy_document for p in principal.role_policies if p.policy_document]: inline_policies.append( ( PolicySource(kind=PolicySourceKind.principal, uri=principal.arn or ""), - PolicyDocument(doc), + FixPolicyDocument(doc), ) ) @@ -919,7 +968,7 @@ def _get_role_based_policies(self, principal: AwsIamRole) -> List[Tuple[PolicySo attached_policies.append( ( PolicySource(kind=PolicySourceKind.principal, uri=to_node.arn or 
""), - PolicyDocument(policy_doc), + FixPolicyDocument(policy_doc), ) ) @@ -934,10 +983,10 @@ def add_access_edges(self) -> None: # small graph cycles avoidance optimization continue - resource_policies: List[Tuple[PolicySource, PolicyDocument]] = [] + resource_policies: List[Tuple[PolicySource, FixPolicyDocument]] = [] if isinstance(node, HasResourcePolicy): for source, json_policy in node.resource_policy(self.builder): - resource_policies.append((source, PolicyDocument(json_policy))) + resource_policies.append((source, FixPolicyDocument(json_policy))) permissions = compute_permissions(node, context, resource_policies) diff --git a/plugins/aws/test/acccess_edges_test.py b/plugins/aws/test/acccess_edges_test.py index f3a2c5c375..380b041513 100644 --- a/plugins/aws/test/acccess_edges_test.py +++ b/plugins/aws/test/acccess_edges_test.py @@ -1,5 +1,4 @@ from cloudsplaining.scan.policy_document import PolicyDocument -from cloudsplaining.scan.statement_detail import StatementDetail from fix_plugin_aws.resource.base import AwsResource from fix_plugin_aws.resource.iam import AwsIamUser, AwsIamGroup, AwsIamRole @@ -14,6 +13,9 @@ IamRequestContext, check_explicit_deny, compute_permissions, + FixPolicyDocument, + FixStatementDetail, + ActionToCheck, ) from fixlib.baseresources import PolicySourceKind, PolicySource, PermissionLevel @@ -66,6 +68,13 @@ def test_make_resoruce_regex() -> None: assert not regex.match("arn:aws:s3:::my-bucket/abc") +def atc(action: str) -> ActionToCheck: + splitted = action.split(":") + return ActionToCheck( + raw=action, raw_lower=action.lower(), service=splitted[0].lower(), action_name=splitted[1].lower() + ) + + def test_check_statement_match1() -> None: allow_statement = { "Effect": "Allow", @@ -73,41 +82,41 @@ def test_check_statement_match1() -> None: "Resource": "arn:aws:s3:::example-bucket/*", "Principal": {"AWS": ["arn:aws:iam::123456789012:user/example-user"]}, } - statement = StatementDetail(allow_statement) + statement = 
FixStatementDetail(allow_statement) resource = AwsResource(id="bucket", arn="arn:aws:s3:::example-bucket/object.txt") principal = AwsResource(id="principal", arn="arn:aws:iam::123456789012:user/example-user") # Test matching statement - result, constraints = check_statement_match(statement, "Allow", "s3:GetObject", resource, principal) + result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) assert result is True assert constraints == ["arn:aws:s3:::example-bucket/*"] # Test wrong effect - result, constraints = check_statement_match(statement, "Deny", "s3:GetObject", resource, principal) + result, constraints = check_statement_match(statement, "Deny", atc("s3:GetObject"), resource, principal) assert result is False assert constraints == [] # wrong principal does not match - result, constraints = check_statement_match(statement, "Allow", "s3:GetObject", resource, resource) + result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, resource) assert result is False # Test statement with condition allow_statement["Condition"] = {"StringEquals": {"s3:prefix": "private/"}} - statement = StatementDetail(allow_statement) - result, constraints = check_statement_match(statement, "Allow", "s3:GetObject", resource, principal) + statement = FixStatementDetail(allow_statement) + result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) assert result is True # not providing principaal works - result, constraints = check_statement_match(statement, "Allow", "s3:GetObject", resource, principal=None) + result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal=None) assert result is True # not providing effect works result, constraints = check_statement_match( - statement, effect=None, action="s3:GetObject", resource=resource, principal=None + statement, effect=None, action=atc("s3:GetObject"), 
resource=resource, principal=None ) assert result is True - result, constraints = check_statement_match(statement, "Allow", "s3:GetObject", resource, principal) + result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) assert result is True assert constraints == ["arn:aws:s3:::example-bucket/*"] @@ -118,8 +127,8 @@ def test_check_statement_match1() -> None: "Principal": {"AWS": ["arn:aws:iam::123456789012:user/example-user"]}, } - statement = StatementDetail(deny_statement) - result, constraints = check_statement_match(statement, "Deny", "s3:GetObject", resource, principal) + statement = FixStatementDetail(deny_statement) + result, constraints = check_statement_match(statement, "Deny", atc("s3:GetObject"), resource, principal) assert result is True assert constraints == ["arn:aws:s3:::example-bucket/*"] @@ -127,8 +136,8 @@ def test_check_statement_match1() -> None: not_resource_statement = dict(allow_statement) del not_resource_statement["Resource"] not_resource_statement["NotResource"] = "arn:aws:s3:::example-bucket/private/*" - statement = StatementDetail(not_resource_statement) - result, constraints = check_statement_match(statement, "Allow", "s3:GetObject", resource, principal) + statement = FixStatementDetail(not_resource_statement) + result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) assert result is True assert constraints == ["not arn:aws:s3:::example-bucket/private/*"] @@ -162,7 +171,7 @@ def test_no_explicit_deny() -> None: ) resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket") - action = "s3:GetObject" + action = atc("s3:GetObject") result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) assert result == "NextStep" @@ -177,10 +186,10 @@ def test_explicit_deny_in_identity_policy() -> None: "Version": "2012-10-17", "Statement": [{"Effect": "Deny", "Action": "s3:GetObject", "Resource": 
"arn:aws:s3:::example-bucket/*"}], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=principal.arn), policy_document)] - permission_boundaries: List[PolicyDocument] = [] - service_control_policy_levels: List[List[PolicyDocument]] = [] + permission_boundaries: List[FixPolicyDocument] = [] + service_control_policy_levels: List[List[FixPolicyDocument]] = [] request_context = IamRequestContext( principal=principal, @@ -190,7 +199,7 @@ def test_explicit_deny_in_identity_policy() -> None: ) resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") - action = "s3:GetObject" + action = atc("s3:GetObject") result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) assert result == "Denied" @@ -212,7 +221,7 @@ def test_explicit_deny_with_condition_in_identity_policy() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=principal.arn), policy_document)] request_context = IamRequestContext( @@ -223,7 +232,7 @@ def test_explicit_deny_with_condition_in_identity_policy() -> None: ) resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") - action = "s3:GetObject" + action = atc("s3:GetObject") result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) expected_conditions = [policy_json["Statement"][0]["Condition"]] @@ -238,7 +247,7 @@ def test_explicit_deny_in_scp() -> None: "Version": "2012-10-17", "Statement": [{"Effect": "Deny", "Action": "s3:GetObject", "Resource": "*"}], } - scp_policy_document = PolicyDocument(scp_policy_json) + scp_policy_document = FixPolicyDocument(scp_policy_json) service_control_policy_levels = [[scp_policy_document]] request_context = IamRequestContext( @@ -249,7 +258,7 
@@ def test_explicit_deny_in_scp() -> None: ) resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") - action = "s3:GetObject" + action = atc("s3:GetObject") result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) assert result == "Denied" @@ -270,7 +279,7 @@ def test_explicit_deny_with_condition_in_scp() -> None: } ], } - scp_policy_document = PolicyDocument(scp_policy_json) + scp_policy_document = FixPolicyDocument(scp_policy_json) service_control_policy_levels = [ [ scp_policy_document, @@ -285,7 +294,7 @@ def test_explicit_deny_with_condition_in_scp() -> None: ) resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") - action = "s3:GetObject" + action = atc("s3:GetObject") result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) expected_conditions = [scp_policy_json["Statement"][0]["Condition"]] @@ -314,13 +323,13 @@ def test_explicit_deny_in_resource_policy() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) resource_based_policies = [ (PolicySource(kind=PolicySourceKind.resource, uri="arn:aws:s3:::example-bucket"), policy_document) ] resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") - action = "s3:GetObject" + action = atc("s3:GetObject") result = check_explicit_deny(request_context, resource, action, resource_based_policies) assert result == "Denied" @@ -349,13 +358,13 @@ def test_explicit_deny_with_condition_in_resource_policy() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) resource_based_policies = [ (PolicySource(kind=PolicySourceKind.resource, uri="arn:aws:s3:::example-bucket"), policy_document) ] resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") - action = "s3:GetObject" + action = atc("s3:GetObject") result 
= check_explicit_deny(request_context, resource, action, resource_based_policies) expected_conditions = [policy_json["Statement"][0]["Condition"]] @@ -379,7 +388,7 @@ def test_compute_permissions_user_inline_policy_allow() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] @@ -418,7 +427,7 @@ def test_compute_permissions_user_inline_policy_allow_with_conditions() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] @@ -456,7 +465,7 @@ def test_compute_permissions_user_inline_policy_deny() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] @@ -489,7 +498,7 @@ def test_compute_permissions_user_inline_policy_deny_with_condition() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] @@ -520,7 +529,7 @@ def test_deny_overrides_allow() -> None: } ], } - deny_policy_document = PolicyDocument(deny_policy_json) + deny_policy_document = FixPolicyDocument(deny_policy_json) allow_policy_json = { "Version": "2012-10-17", @@ -533,7 +542,7 @@ def test_deny_overrides_allow() -> None: } ], } - allow_policy_document = PolicyDocument(allow_policy_json) + allow_policy_document = FixPolicyDocument(allow_policy_json) identity_policies = [ (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), @@ -566,7 +575,7 @@ def test_deny_different_action_does_not_override_allow() -> None: } ], } - deny_policy_document = 
PolicyDocument(deny_policy_json) + deny_policy_document = FixPolicyDocument(deny_policy_json) allow_policy_json = { "Version": "2012-10-17", @@ -579,7 +588,7 @@ def test_deny_different_action_does_not_override_allow() -> None: } ], } - allow_policy_document = PolicyDocument(allow_policy_json) + allow_policy_document = FixPolicyDocument(allow_policy_json) identity_policies = [ (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), @@ -615,7 +624,7 @@ def test_deny_overrides_allow_with_condition() -> None: } ], } - deny_policy_document = PolicyDocument(deny_policy_json) + deny_policy_document = FixPolicyDocument(deny_policy_json) allow_policy_json = { "Version": "2012-10-17", @@ -628,7 +637,7 @@ def test_deny_overrides_allow_with_condition() -> None: } ], } - allow_policy_document = PolicyDocument(allow_policy_json) + allow_policy_document = FixPolicyDocument(allow_policy_json) identity_policies = [ (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), @@ -672,7 +681,7 @@ def test_compute_permissions_resource_based_policy_allow() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) request_context = IamRequestContext( principal=user, identity_policies=[], permission_boundaries=[], service_control_policy_levels=[] @@ -718,7 +727,7 @@ def test_compute_permissions_permission_boundary_restrict() -> None: }, ], } - identity_policy_document = PolicyDocument(identity_policy_json) + identity_policy_document = FixPolicyDocument(identity_policy_json) permission_boundary_json = { "Version": "2012-10-17", @@ -726,7 +735,7 @@ def test_compute_permissions_permission_boundary_restrict() -> None: {"Sid": "Boundary", "Effect": "Allow", "Action": ["s3:ListBucket", "s3:PutObject"], "Resource": "*"} ], } - permission_boundary_document = PolicyDocument(permission_boundary_json) + permission_boundary_document = FixPolicyDocument(permission_boundary_json) 
identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), identity_policy_document)] @@ -769,7 +778,7 @@ def test_compute_permissions_scp_deny() -> None: } ], } - identity_policy_document = PolicyDocument(identity_policy_json) + identity_policy_document = FixPolicyDocument(identity_policy_json) scp_policy_json = { "Version": "2012-10-17", @@ -777,7 +786,7 @@ def test_compute_permissions_scp_deny() -> None: {"Sid": "DenyTerminateInstances", "Effect": "Deny", "Action": "ec2:TerminateInstances", "Resource": "*"} ], } - scp_policy_document = PolicyDocument(scp_policy_json) + scp_policy_document = FixPolicyDocument(scp_policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), identity_policy_document)] @@ -808,7 +817,7 @@ def test_compute_permissions_user_with_group_policies() -> None: {"Sid": "AllowS3ListBucket", "Effect": "Allow", "Action": "s3:ListBucket", "Resource": bucket.arn} ], } - group_policy_document = PolicyDocument(group_policy_json) + group_policy_document = FixPolicyDocument(group_policy_json) identity_policies = [] @@ -862,7 +871,7 @@ def test_compute_permissions_group_inline_policy_allow() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.group, uri=group.arn), policy_document)] @@ -899,7 +908,7 @@ def test_compute_permissions_role_inline_policy_allow() -> None: } ], } - policy_document = PolicyDocument(policy_json) + policy_document = FixPolicyDocument(policy_json) identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=role.arn), policy_document)] From b175c7c0045f696314839a8b8cadc9e1690af0f3 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Tue, 29 Oct 2024 12:19:20 +0000 Subject: [PATCH 07/14] pre-compute actions on the resource level --- plugins/aws/fix_plugin_aws/access_edges.py | 72 ++++++++++++++-- plugins/aws/test/acccess_edges_test.py | 97 
++++++++++++++++++---- 2 files changed, 147 insertions(+), 22 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index f00d3fdf92..85703b3578 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -22,8 +22,10 @@ from cloudsplaining.scan.policy_document import PolicyDocument from cloudsplaining.scan.statement_detail import StatementDetail -from policy_sentry.querying.actions import get_action_data, get_actions_matching_arn +from policy_sentry.querying.actions import get_action_data from policy_sentry.querying.all import get_all_actions +from policy_sentry.querying.arns import get_matching_raw_arns, get_resource_type_name_with_raw_arn +from policy_sentry.shared.iam_data import get_service_prefix_data from policy_sentry.util.arns import ARN, get_service_from_arn from fixlib.graph import EdgeKey import re @@ -133,12 +135,46 @@ def find_non_service_actions(resource_arn: str) -> Set[IamAction]: return set() -def find_all_allowed_actions(all_involved_policies: List[FixPolicyDocument], resource_arn: str) -> Set[IamAction]: - resource_actions = set() +@lru_cache(maxsize=1024) +def get_actions_matching_raw_arn(raw_arn: str) -> set[str]: + results: set[str] = set() + resource_type_name = get_resource_type_name_with_raw_arn(raw_arn) + if resource_type_name is None: + return results + + service_prefix = get_service_from_arn(raw_arn) + service_prefix_data = get_service_prefix_data(service_prefix) + for action_name, action_data in service_prefix_data["privileges"].items(): + if resource_type_name.lower() in action_data["resource_types_lower_name"]: + results.add(f"{service_prefix}:{action_name}") + + return results + + +def get_actions_matching_arn(arn: str) -> set[str]: + """ + Given a user-supplied ARN, get a list of all actions that correspond to that ARN. + + Arguments: + arn: A user-supplied arn + Returns: + List: A list of all actions that can match it. 
+ """ + results = set() try: - resource_actions = set(get_actions_matching_arn(resource_arn)) + raw_arns = get_matching_raw_arns(arn) + for raw_arn in raw_arns: + raw_arn_actions = get_actions_matching_raw_arn(raw_arn) + results.update(raw_arn_actions) except Exception as e: - log.debug(f"Error when trying to get actions matching ARN {resource_arn}: {e}") + log.debug(f"Error when trying to get actions for ARN {arn}: {e}") + + return results + + +def find_all_allowed_actions( + all_involved_policies: List[FixPolicyDocument], resource_arn: str, resource_actions: set[IamAction] +) -> Set[IamAction]: if additinal_actions := find_non_service_actions(resource_arn): resource_actions.update(additinal_actions) @@ -789,11 +825,16 @@ def compute_permissions( resource: AwsResource, iam_context: IamRequestContext, resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], + resource_actions: set[IamAction], ) -> List[AccessPermission]: assert resource.arn # step 1: find the relevant action to check - relevant_actions = find_all_allowed_actions(iam_context.all_policies(resource_based_policies), resource.arn) + relevant_actions = find_all_allowed_actions( + iam_context.all_policies(resource_based_policies), + resource.arn, + resource_actions, + ) all_permissions: List[AccessPermission] = [] @@ -820,6 +861,7 @@ def __init__(self, builder: GraphBuilder): self.builder = builder self.principals: List[IamRequestContext] = [] self._init_principals() + self.actions_for_resource: Dict[str, set[IamAction]] = self._compute_actions_for_resource() def _init_principals(self) -> None: @@ -882,6 +924,18 @@ def _init_principals(self) -> None: self.principals.append(request_context) + def _compute_actions_for_resource(self) -> Dict[str, set[IamAction]]: + + actions_for_resource: Dict[str, set[IamAction]] = {} + + for node in self.builder.nodes(clazz=AwsResource, filter=lambda r: r.arn is not None): + if not node.arn: + continue + + actions_for_resource[node.arn] = 
get_actions_matching_arn(node.arn) + + return actions_for_resource + def _get_user_based_policies(self, principal: AwsIamUser) -> List[Tuple[PolicySource, FixPolicyDocument]]: inline_policies = [ ( @@ -977,7 +1031,7 @@ def _get_role_based_policies(self, principal: AwsIamRole) -> List[Tuple[PolicySo def add_access_edges(self) -> None: for node in self.builder.nodes(clazz=AwsResource, filter=lambda r: r.arn is not None): - + assert node.arn for context in self.principals: if context.principal.arn == node.arn: # small graph cycles avoidance optimization @@ -988,7 +1042,9 @@ def add_access_edges(self) -> None: for source, json_policy in node.resource_policy(self.builder): resource_policies.append((source, FixPolicyDocument(json_policy))) - permissions = compute_permissions(node, context, resource_policies) + permissions = compute_permissions( + node, context, resource_policies, self.actions_for_resource.get(node.arn, set()) + ) if not permissions: continue diff --git a/plugins/aws/test/acccess_edges_test.py b/plugins/aws/test/acccess_edges_test.py index 380b041513..684f489cf1 100644 --- a/plugins/aws/test/acccess_edges_test.py +++ b/plugins/aws/test/acccess_edges_test.py @@ -16,6 +16,7 @@ FixPolicyDocument, FixStatementDetail, ActionToCheck, + get_actions_matching_arn, ) from fixlib.baseresources import PolicySourceKind, PolicySource, PermissionLevel @@ -396,7 +397,12 @@ def test_compute_permissions_user_inline_policy_allow() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 assert permissions[0].action == "s3:ListBucket" assert permissions[0].level == PermissionLevel.list @@ 
-435,7 +441,12 @@ def test_compute_permissions_user_inline_policy_allow_with_conditions() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 assert permissions[0].action == "s3:ListBucket" assert permissions[0].level == PermissionLevel.list @@ -473,7 +484,12 @@ def test_compute_permissions_user_inline_policy_deny() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 0 @@ -506,7 +522,12 @@ def test_compute_permissions_user_inline_policy_deny_with_condition() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) # deny does not grant any permissions by itself, even if the condition is met assert len(permissions) == 0 @@ -553,7 +574,12 @@ def test_deny_overrides_allow() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = 
compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 0 @@ -599,7 +625,12 @@ def test_deny_different_action_does_not_override_allow() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 @@ -648,7 +679,12 @@ def test_deny_overrides_allow_with_condition() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 p = permissions[0] @@ -690,7 +726,10 @@ def test_compute_permissions_resource_based_policy_allow() -> None: resource_based_policies = [(PolicySource(kind=PolicySourceKind.resource, uri=bucket.arn), policy_document)] permissions = compute_permissions( - resource=bucket, iam_context=request_context, resource_based_policies=resource_based_policies + resource=bucket, + iam_context=request_context, + resource_based_policies=resource_based_policies, + resource_actions=get_actions_matching_arn(bucket.arn or ""), ) assert len(permissions) == 1 @@ -748,7 +787,12 @@ def test_compute_permissions_permission_boundary_restrict() -> None: 
service_control_policy_levels=[], ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 p = permissions[0] @@ -799,7 +843,12 @@ def test_compute_permissions_scp_deny() -> None: service_control_policy_levels=service_control_policy_levels, ) - permissions = compute_permissions(resource=ec2_instance, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=ec2_instance, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(ec2_instance.arn or ""), + ) assert len(permissions) == 0 @@ -827,7 +876,12 @@ def test_compute_permissions_user_with_group_policies() -> None: principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 p = permissions[0] @@ -848,7 +902,12 @@ def test_compute_permissions_implicit_deny() -> None: principal=user, identity_policies=[], permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=table, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=table, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(table.arn or ""), + ) # Assert that permissions do not include any actions (implicit deny) assert len(permissions) == 0 @@ -879,7 +938,12 @@ def 
test_compute_permissions_group_inline_policy_allow() -> None: principal=group, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 assert permissions[0].action == "s3:ListBucket" @@ -916,7 +980,12 @@ def test_compute_permissions_role_inline_policy_allow() -> None: principal=role, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] ) - permissions = compute_permissions(resource=bucket, iam_context=request_context, resource_based_policies=[]) + permissions = compute_permissions( + resource=bucket, + iam_context=request_context, + resource_based_policies=[], + resource_actions=get_actions_matching_arn(bucket.arn or ""), + ) assert len(permissions) == 1 assert permissions[0].action == "s3:ListBucket" From 0af9023d9ee4dbc7a33b4fed70e1059c43d0f3da Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Tue, 5 Nov 2024 16:03:27 +0000 Subject: [PATCH 08/14] cache non-resource computations --- plugins/aws/fix_plugin_aws/access_edges.py | 290 ++++++++++--------- plugins/aws/test/acccess_edges_test.py | 314 +++++++++++---------- 2 files changed, 324 insertions(+), 280 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index 85703b3578..5e60367687 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -1,10 +1,10 @@ from enum import Enum from functools import lru_cache -from attr import frozen, define +from attr import frozen import networkx from fix_plugin_aws.resource.base import AwsAccount, AwsResource, GraphBuilder from policy_sentry.querying.actions 
import get_actions_for_service -from typing import Dict, List, Literal, Set, Optional, Tuple, Union, Pattern +from typing import Callable, Dict, List, Literal, Set, Optional, Tuple, Union, Pattern import fnmatch from networkx.algorithms.dag import is_directed_acyclic_graph @@ -88,22 +88,22 @@ class ActionToCheck: action_name: str -@define(slots=True) +@frozen(slots=True) class IamRequestContext: principal: AwsResource - identity_policies: List[Tuple[PolicySource, FixPolicyDocument]] - permission_boundaries: List[FixPolicyDocument] # todo: use them too + identity_policies: Tuple[Tuple[PolicySource, FixPolicyDocument], ...] + permission_boundaries: Tuple[FixPolicyDocument, ...] # todo: use them too # all service control policies applicable to the principal, # starting from the root, then all org units, then the account - service_control_policy_levels: List[List[FixPolicyDocument]] + service_control_policy_levels: Tuple[Tuple[FixPolicyDocument, ...], ...] # technically we should also add a list of session policies here, but they don't exist in the collector context def all_policies( - self, resource_based_policies: Optional[List[Tuple[PolicySource, FixPolicyDocument]]] = None + self, resource_based_policies: Optional[Tuple[Tuple[PolicySource, FixPolicyDocument], ...]] = None ) -> List[FixPolicyDocument]: return ( [p[1] for p in self.identity_policies] - + self.permission_boundaries + + list(self.permission_boundaries) + [p for group in self.service_control_policy_levels for p in group] + ([p[1] for p in (resource_based_policies or [])]) ) @@ -121,12 +121,11 @@ def find_allowed_action(policy_document: PolicyDocument, service_prefix: str) -> return allowed_actions -def find_non_service_actions(resource_arn: str) -> Set[IamAction]: +def find_non_service_actions(resource_arn: ARN) -> Set[IamAction]: try: - splitted = resource_arn.split(":") - service_prefix = splitted[2] + service_prefix = resource_arn.service_prefix if service_prefix == "iam": - resource_type = 
splitted[5] + resource_type = resource_arn.resource_string resource = resource_type.split("/")[0] if resource == "role": return {"sts:AssumeRole"} @@ -173,7 +172,7 @@ def get_actions_matching_arn(arn: str) -> set[str]: def find_all_allowed_actions( - all_involved_policies: List[FixPolicyDocument], resource_arn: str, resource_actions: set[IamAction] + all_involved_policies: List[FixPolicyDocument], resource_arn: ARN, resource_actions: set[IamAction] ) -> Set[IamAction]: if additinal_actions := find_non_service_actions(resource_arn): @@ -181,7 +180,7 @@ def find_all_allowed_actions( service_prefix = "" try: - service_prefix = get_service_from_arn(resource_arn) + service_prefix = resource_arn.service_prefix except Exception as e: log.debug(f"Error when trying to get service prefix from ARN {resource_arn}: {e}") policy_actions: Set[IamAction] = set() @@ -318,21 +317,19 @@ def expand_arn_wildcards_and_match(identifier: str, wildcard_string: str) -> boo return _expand_wildcards_and_match(identifier=identifier, wildcard_string=wildcard_string) +@lru_cache(maxsize=4096) def check_statement_match( statement: FixStatementDetail, effect: Optional[Literal["Allow", "Deny"]], action: ActionToCheck, - resource: AwsResource, principal: Optional[AwsResource], source_arn: Optional[str] = None, -) -> Tuple[bool, List[ResourceConstraint]]: +) -> Union[None, Callable[[ARN], Optional[List[ResourceConstraint]]]]: """ - check if a statement matches the given effect, action, resource and principal, - returns boolean if there is a match and optional resource constraint (if there were any) + check if a statement matches the given effect, action, and principal, + returns None if there is no match no matter what the resource is, + or a callable that can be used to check if the resource matches """ - if resource.arn is None: - raise ValueError("Resource ARN is missing, go and fix the filtering logic") - # step 1: check the principal if provided if principal: principal_match = False @@ -364,13 
+361,13 @@ def check_statement_match( if not principal_match: # principal does not match, we can shortcut here - return False, [] + return None # step 2: check if the effect matches if effect: if statement.effect != effect: # wrong effect, skip this statement - return False, [] + return None # step 3: check if the action matches action_match = False @@ -396,34 +393,37 @@ def check_statement_match( break if not action_match: # action does not match, skip this statement - return False, [] - - # step 4: check if the resource matches - matched_resource_constraints: List[ResourceConstraint] = [] - resource_matches = False - if len(statement.resources) > 0: - for resource_constraint in statement.resources: - if expand_arn_wildcards_and_match(identifier=resource.arn, wildcard_string=resource_constraint): - matched_resource_constraints.append(resource_constraint) - resource_matches = True - break - elif len(statement.not_resource) > 0: - resource_matches = True - for not_resource_constraint in statement.not_resource: - if expand_arn_wildcards_and_match(identifier=resource.arn, wildcard_string=not_resource_constraint): - resource_matches = False - break - matched_resource_constraints.append("not " + not_resource_constraint) - else: - # no Resource/NotResource specified, consider allowed - resource_matches = True - if not resource_matches: - # resource does not match, skip this statement - return False, [] + return None + + def check_resource_match(arn: ARN) -> Optional[List[ResourceConstraint]]: + # step 4: check if the resource matches + matched_resource_constraints: List[ResourceConstraint] = [] + resource_matches = False + if len(statement.resources) > 0: + for resource_constraint in statement.resources: + if expand_arn_wildcards_and_match(identifier=arn.arn, wildcard_string=resource_constraint): + matched_resource_constraints.append(resource_constraint) + resource_matches = True + break + elif len(statement.not_resource) > 0: + resource_matches = True + for 
not_resource_constraint in statement.not_resource: + if expand_arn_wildcards_and_match(identifier=arn.arn, wildcard_string=not_resource_constraint): + resource_matches = False + break + matched_resource_constraints.append("not " + not_resource_constraint) + else: + # no Resource/NotResource specified, consider allowed + resource_matches = True + if not resource_matches: + # resource does not match, skip this statement + return None - # step 5: (we're not doing this yet) check if the condition matches - # here we just return the statement and condition checking is the responsibility of the caller - return (True, matched_resource_constraints) + # step 5: (we're not doing this yet) check if the condition matches + # here we just return the statement and condition checking is the responsibility of the caller + return matched_resource_constraints + + return check_resource_match def check_principal_match(principal: AwsResource, aws_principal_list: List[str]) -> bool: @@ -445,93 +445,106 @@ def check_principal_match(principal: AwsResource, aws_principal_list: List[str]) return False +@lru_cache(maxsize=4096) def collect_matching_statements( *, policy: FixPolicyDocument, effect: Optional[Literal["Allow", "Deny"]], action: ActionToCheck, - resource: AwsResource, principal: Optional[AwsResource], source_arn: Optional[str] = None, -) -> List[Tuple[FixStatementDetail, List[ResourceConstraint]]]: +) -> Callable[[ARN], List[Tuple[FixStatementDetail, List[ResourceConstraint]]]]: """ resoruce based policies contain principal field and need to be handled differently """ - results: List[Tuple[FixStatementDetail, List[ResourceConstraint]]] = [] - - if resource.arn is None: - raise ValueError("Resource ARN is missing, go and fix the filtering logic") + matching_fns: List[Tuple[FixStatementDetail, Callable[[ARN], Optional[List[ResourceConstraint]]]]] = [] for statement in policy.fix_statements: - matches, maybe_resource_constraint = check_statement_match( - statement, effect=effect, 
action=action, resource=resource, principal=principal, source_arn=source_arn + match_fn = check_statement_match( + statement, effect=effect, action=action, principal=principal, source_arn=source_arn ) - if matches: - results.append((statement, maybe_resource_constraint)) + if not match_fn: + continue - return results + matching_fns.append((statement, match_fn)) + + def collect_matching_statements_closure(resource: ARN) -> List[Tuple[FixStatementDetail, List[ResourceConstraint]]]: + results: List[Tuple[FixStatementDetail, List[ResourceConstraint]]] = [] + for statement, match_fn in matching_fns: + if constraints := match_fn(resource): + results.append((statement, constraints)) + + return results + return collect_matching_statements_closure + +@lru_cache(maxsize=4096) def check_explicit_deny( request_context: IamRequestContext, - resource: AwsResource, action: ActionToCheck, - resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], -) -> Union[Literal["Denied", "NextStep"], List[Json]]: + resource_based_policies: Tuple[Tuple[PolicySource, FixPolicyDocument], ...], +) -> Callable[[ARN], Union[Literal["Denied", "NextStep"], List[Json]]]: - denied_when_any_is_true: List[Json] = [] + matching_fns = [] # we should skip service control policies for service linked roles if not is_service_linked_role(request_context.principal): for scp_level in request_context.service_control_policy_levels: for policy in scp_level: - policy_statements = collect_matching_statements( - policy=policy, effect="Deny", action=action, resource=resource, principal=request_context.principal + matching_fn = collect_matching_statements( + policy=policy, effect="Deny", action=action, principal=request_context.principal ) - for statement, _ in policy_statements: - if statement.condition: - denied_when_any_is_true.append(statement.condition) - else: - return "Denied" + matching_fns.append(matching_fn) # check permission boundaries for policy in request_context.permission_boundaries: - 
policy_statements = collect_matching_statements( - policy=policy, effect="Deny", action=action, resource=resource, principal=request_context.principal + matching_fn = collect_matching_statements( + policy=policy, effect="Deny", action=action, principal=request_context.principal ) - for statement, _ in policy_statements: - if statement.condition: - denied_when_any_is_true.append(statement.condition) - else: - return "Denied" + matching_fns.append(matching_fn) # check the rest of the policies - for _, policy in request_context.identity_policies + resource_based_policies: - policy_statements = collect_matching_statements( - policy=policy, effect="Deny", action=action, resource=resource, principal=request_context.principal + for _, policy in request_context.identity_policies: + matching_fn = collect_matching_statements( + policy=policy, effect="Deny", action=action, principal=request_context.principal ) - for statement, _ in policy_statements: - if statement.condition: - denied_when_any_is_true.append(statement.condition) - else: - return "Denied" + matching_fns.append(matching_fn) + + for _, policy in resource_based_policies: + matching_fn = collect_matching_statements( + policy=policy, effect="Deny", action=action, principal=request_context.principal + ) + matching_fns.append(matching_fn) + + def check_explicit_deny_closure(arn: ARN) -> Union[Literal["Denied", "NextStep"], List[Json]]: + + denied_when_any_is_true: List[Json] = [] - if denied_when_any_is_true: - return denied_when_any_is_true + for matching_fn in matching_fns: + for statement, _ in matching_fn(arn): + if statement.condition: + denied_when_any_is_true.append(statement.condition) + else: + return "Denied" - return "NextStep" + if denied_when_any_is_true: + return denied_when_any_is_true + return "NextStep" -def scp_allowed(request_context: IamRequestContext, action: ActionToCheck, resource: AwsResource) -> bool: + return check_explicit_deny_closure + + +def scp_allowed(request_context: 
IamRequestContext, action: ActionToCheck, resource: ARN) -> bool: # traverse the SCPs: root -> OU -> account levels for scp_level_policies in request_context.service_control_policy_levels: level_allows = False for policy in scp_level_policies: - statements = collect_matching_statements( - policy=policy, effect="Allow", action=action, resource=resource, principal=None - ) + matching_fn = collect_matching_statements(policy=policy, effect="Allow", action=action, principal=None) + statements = matching_fn(resource) if statements: # 'Allow' statements in SCP can't have conditions, we do not check them level_allows = True @@ -566,27 +579,26 @@ class Deny: def check_resource_based_policies( principal: AwsResource, action: ActionToCheck, - resource: AwsResource, - resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], + resource: ARN, + resource_based_policies: Tuple[Tuple[PolicySource, FixPolicyDocument], ...], ) -> ResourceBasedPolicyResult: - assert resource.arn scopes: List[PermissionScope] = [] - arn = ARN(resource.arn) + arn = resource explicit_allow_required = False if arn.service_prefix == "iam" or arn.service_prefix == "kms": explicit_allow_required = True for source, policy in resource_based_policies: - matching_statements = collect_matching_statements( + matching_fn = collect_matching_statements( policy=policy, effect="Allow", action=action, - resource=resource, principal=principal, ) + matching_statements = matching_fn(arn) if len(matching_statements) == 0: continue @@ -624,27 +636,38 @@ def check_resource_based_policies( return Continue(scopes) +@lru_cache(maxsize=4096) def check_identity_based_policies( - request_context: IamRequestContext, resource: AwsResource, action: ActionToCheck -) -> List[PermissionScope]: + request_context: IamRequestContext, action: ActionToCheck +) -> Callable[[ARN], List[PermissionScope]]: - scopes: List[PermissionScope] = [] + matching_fns: List[ + Tuple[PolicySource, Callable[[ARN], List[Tuple[FixStatementDetail, 
List[ResourceConstraint]]]]] + ] = [] for source, policy in request_context.identity_policies: - for statement, resource_constraints in collect_matching_statements( - policy=policy, effect="Allow", action=action, resource=resource, principal=None, source_arn=source.uri - ): - conditions = None - if statement.condition: - conditions = PermissionCondition(allow=(to_json_str(statement.condition),)) + matching_fn = collect_matching_statements( + policy=policy, effect="Allow", action=action, principal=None, source_arn=source.uri + ) + matching_fns.append((source, matching_fn)) + + def check_identity_policies_closure(resource: ARN) -> List[PermissionScope]: + scopes: List[PermissionScope] = [] + for source, matching_fn in matching_fns: + for statement, resource_constraints in matching_fn(resource): + conditions = None + if statement.condition: + conditions = PermissionCondition(allow=(to_json_str(statement.condition),)) - scopes.append(PermissionScope(source, tuple(resource_constraints), conditions=conditions)) + scopes.append(PermissionScope(source, tuple(resource_constraints), conditions=conditions)) - return scopes + return scopes + + return check_identity_policies_closure def check_permission_boundaries( - request_context: IamRequestContext, resource: AwsResource, action: ActionToCheck + request_context: IamRequestContext, resource: ARN, action: ActionToCheck ) -> Union[Literal["Denied", "NextStep"], List[Json]]: conditions: List[Json] = [] @@ -652,9 +675,8 @@ def check_permission_boundaries( # ignore policy sources and resource constraints because permission boundaries # can never allow access to a resource, only restrict it for policy in request_context.permission_boundaries: - for statement, _ in collect_matching_statements( - policy=policy, effect="Allow", action=action, resource=resource, principal=None - ): + matching_fn = collect_matching_statements(policy=policy, effect="Allow", action=action, principal=None) + for statement, _ in matching_fn(resource): if 
statement.condition: assert isinstance(statement.condition, dict) conditions.append(statement.condition) @@ -714,9 +736,9 @@ def get_action_level(action: str) -> PermissionLevel: # logic according to https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_evaluation-logic.html def check_policies( request_context: IamRequestContext, - resource: AwsResource, + resource: ARN, action: ActionToCheck, - resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], + resource_based_policies: Tuple[Tuple[PolicySource, FixPolicyDocument], ...], ) -> Optional[AccessPermission]: # when any of the conditions evaluate to true, the action is explicitly denied @@ -732,7 +754,7 @@ def check_policies( allowed_scopes: List[PermissionScope] = [] # 1. check for explicit deny. If denied, we can abort immediately - result = check_explicit_deny(request_context, resource, action, resource_based_policies) + result = check_explicit_deny(request_context, action, resource_based_policies)(resource) if result == "Denied": return None elif result == "NextStep": @@ -786,7 +808,7 @@ def check_policies( return None # otherwise continue with the resource based policies else: - identity_based_allowed = check_identity_based_policies(request_context, resource, action) + identity_based_allowed = check_identity_based_policies(request_context, action)(resource) if not identity_based_allowed: return None allowed_scopes.extend(identity_based_allowed) @@ -822,17 +844,16 @@ def check_policies( def compute_permissions( - resource: AwsResource, + resource: ARN, iam_context: IamRequestContext, - resource_based_policies: List[Tuple[PolicySource, FixPolicyDocument]], + resource_based_policies: Tuple[Tuple[PolicySource, FixPolicyDocument], ...], resource_actions: set[IamAction], ) -> List[AccessPermission]: - assert resource.arn # step 1: find the relevant action to check relevant_actions = find_all_allowed_actions( iam_context.all_policies(resource_based_policies), - resource.arn, + resource, 
resource_actions, ) @@ -866,17 +887,17 @@ def __init__(self, builder: GraphBuilder): def _init_principals(self) -> None: account_id = self.builder.account.id - service_control_policy_levels: List[List[FixPolicyDocument]] = [] + service_control_policy_levels: tuple[tuple[FixPolicyDocument, ...], ...] = () account = next(self.builder.nodes(clazz=AwsAccount, filter=lambda a: a.id == account_id), None) if account and account._service_control_policies: - service_control_policy_levels = [ - [FixPolicyDocument(json) for json in level] for level in account._service_control_policies - ] + service_control_policy_levels = tuple( + [tuple([FixPolicyDocument(json) for json in level]) for level in account._service_control_policies] + ) for node in self.builder.nodes(clazz=AwsResource): if isinstance(node, AwsIamUser): - identity_based_policies = self._get_user_based_policies(node) + identity_based_policies = tuple(self._get_user_based_policies(node)) permission_boundaries: List[FixPolicyDocument] = [] if (pb := node.user_permissions_boundary) and (pb_arn := pb.permissions_boundary_arn): @@ -888,26 +909,26 @@ def _init_principals(self) -> None: request_context = IamRequestContext( principal=node, identity_policies=identity_based_policies, - permission_boundaries=permission_boundaries, + permission_boundaries=tuple(permission_boundaries), service_control_policy_levels=service_control_policy_levels, ) self.principals.append(request_context) if isinstance(node, AwsIamGroup): - identity_based_policies = self._get_group_based_policies(node) + identity_based_policies = tuple(self._get_group_based_policies(node)) request_context = IamRequestContext( principal=node, identity_policies=identity_based_policies, - permission_boundaries=[], # permission boundaries are not applicable to groups + permission_boundaries=(), # permission boundaries are not applicable to groups service_control_policy_levels=service_control_policy_levels, ) self.principals.append(request_context) if isinstance(node, 
AwsIamRole): - identity_based_policies = self._get_role_based_policies(node) + identity_based_policies = tuple(self._get_role_based_policies(node)) # todo: colect these resources permission_boundaries = [] if (pb := node.role_permissions_boundary) and (pb_arn := pb.permissions_boundary_arn): @@ -918,7 +939,7 @@ def _init_principals(self) -> None: request_context = IamRequestContext( principal=node, identity_policies=identity_based_policies, - permission_boundaries=permission_boundaries, + permission_boundaries=tuple(permission_boundaries), service_control_policy_levels=service_control_policy_levels, ) @@ -1042,8 +1063,9 @@ def add_access_edges(self) -> None: for source, json_policy in node.resource_policy(self.builder): resource_policies.append((source, FixPolicyDocument(json_policy))) + resource_arn = ARN(node.arn) permissions = compute_permissions( - node, context, resource_policies, self.actions_for_resource.get(node.arn, set()) + resource_arn, context, tuple(resource_policies), self.actions_for_resource.get(node.arn, set()) ) if not permissions: diff --git a/plugins/aws/test/acccess_edges_test.py b/plugins/aws/test/acccess_edges_test.py index 684f489cf1..6e79115493 100644 --- a/plugins/aws/test/acccess_edges_test.py +++ b/plugins/aws/test/acccess_edges_test.py @@ -2,7 +2,8 @@ from fix_plugin_aws.resource.base import AwsResource from fix_plugin_aws.resource.iam import AwsIamUser, AwsIamGroup, AwsIamRole -from typing import Any, Dict, List +from typing import Any, Dict +from policy_sentry.util.arns import ARN import re from fix_plugin_aws.access_edges import ( @@ -84,41 +85,47 @@ def test_check_statement_match1() -> None: "Principal": {"AWS": ["arn:aws:iam::123456789012:user/example-user"]}, } statement = FixStatementDetail(allow_statement) - resource = AwsResource(id="bucket", arn="arn:aws:s3:::example-bucket/object.txt") + resource_arn = ARN("arn:aws:s3:::example-bucket/object.txt") + resource = AwsResource(id="bucket", arn=resource_arn.arn) principal = 
AwsResource(id="principal", arn="arn:aws:iam::123456789012:user/example-user") # Test matching statement - result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) - assert result is True + match_fn = check_statement_match(statement, "Allow", atc("s3:GetObject"), principal) + assert match_fn is not None + constraints = match_fn(resource_arn) assert constraints == ["arn:aws:s3:::example-bucket/*"] # Test wrong effect - result, constraints = check_statement_match(statement, "Deny", atc("s3:GetObject"), resource, principal) - assert result is False - assert constraints == [] + match_fn = check_statement_match(statement, "Deny", atc("s3:GetObject"), principal) + assert match_fn is None # wrong principal does not match - result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, resource) - assert result is False + match_fn = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource) + assert match_fn is None # Test statement with condition allow_statement["Condition"] = {"StringEquals": {"s3:prefix": "private/"}} statement = FixStatementDetail(allow_statement) - result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) - assert result is True + match_fn = check_statement_match(statement, "Allow", atc("s3:GetObject"), principal) + assert match_fn is not None + result = match_fn(resource_arn) + assert result is not None - # not providing principaal works - result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal=None) - assert result is True + # not providing principal works + match_fn = check_statement_match(statement, "Allow", atc("s3:GetObject"), principal=None) + assert match_fn is not None + result = match_fn(resource_arn) + assert result is not None # not providing effect works - result, constraints = check_statement_match( - statement, effect=None, 
action=atc("s3:GetObject"), resource=resource, principal=None - ) - assert result is True - - result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) - assert result is True + match_fn = check_statement_match(statement, effect=None, action=atc("s3:GetObject"), principal=None) + assert match_fn is not None + result = match_fn(resource_arn) + assert result is not None + + match_fn = check_statement_match(statement, "Allow", atc("s3:GetObject"), principal) + assert match_fn is not None + constraints = match_fn(resource_arn) assert constraints == ["arn:aws:s3:::example-bucket/*"] deny_statement = { @@ -129,8 +136,9 @@ def test_check_statement_match1() -> None: } statement = FixStatementDetail(deny_statement) - result, constraints = check_statement_match(statement, "Deny", atc("s3:GetObject"), resource, principal) - assert result is True + match_fn = check_statement_match(statement, "Deny", atc("s3:GetObject"), principal) + assert match_fn is not None + constraints = match_fn(resource_arn) assert constraints == ["arn:aws:s3:::example-bucket/*"] # test not resource @@ -138,8 +146,9 @@ def test_check_statement_match1() -> None: del not_resource_statement["Resource"] not_resource_statement["NotResource"] = "arn:aws:s3:::example-bucket/private/*" statement = FixStatementDetail(not_resource_statement) - result, constraints = check_statement_match(statement, "Allow", atc("s3:GetObject"), resource, principal) - assert result is True + match_fn = check_statement_match(statement, "Allow", atc("s3:GetObject"), principal) + assert match_fn is not None + constraints = match_fn(resource_arn) assert constraints == ["not arn:aws:s3:::example-bucket/private/*"] @@ -166,15 +175,15 @@ def test_no_explicit_deny() -> None: request_context = IamRequestContext( principal=principal, - identity_policies=[], - permission_boundaries=[], - service_control_policy_levels=[], + identity_policies=(), + permission_boundaries=(), + 
service_control_policy_levels=(), ) - resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket") + resource_arn = ARN("arn:aws:s3:::example-bucket") action = atc("s3:GetObject") - result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) + result = check_explicit_deny(request_context, action, resource_based_policies=())(resource_arn) assert result == "NextStep" @@ -188,9 +197,9 @@ def test_explicit_deny_in_identity_policy() -> None: "Statement": [{"Effect": "Deny", "Action": "s3:GetObject", "Resource": "arn:aws:s3:::example-bucket/*"}], } policy_document = FixPolicyDocument(policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=principal.arn), policy_document)] - permission_boundaries: List[FixPolicyDocument] = [] - service_control_policy_levels: List[List[FixPolicyDocument]] = [] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=principal.arn), policy_document)]) + permission_boundaries: tuple[FixPolicyDocument, ...] = () + service_control_policy_levels: tuple[tuple[FixPolicyDocument, ...], ...] 
= () request_context = IamRequestContext( principal=principal, @@ -199,10 +208,10 @@ def test_explicit_deny_in_identity_policy() -> None: service_control_policy_levels=service_control_policy_levels, ) - resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") + resource_arn = ARN("arn:aws:s3:::example-bucket/object.txt") action = atc("s3:GetObject") - result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) + result = check_explicit_deny(request_context, action, resource_based_policies=())(resource_arn) assert result == "Denied" @@ -223,19 +232,19 @@ def test_explicit_deny_with_condition_in_identity_policy() -> None: ], } policy_document = FixPolicyDocument(policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=principal.arn), policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=principal.arn), policy_document)]) request_context = IamRequestContext( principal=principal, identity_policies=identity_policies, - permission_boundaries=[], - service_control_policy_levels=[], + permission_boundaries=(), + service_control_policy_levels=(), ) - resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") + resource_arn = ARN("arn:aws:s3:::example-bucket/object.txt") action = atc("s3:GetObject") - result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) + result = check_explicit_deny(request_context, action, resource_based_policies=())(resource_arn) expected_conditions = [policy_json["Statement"][0]["Condition"]] assert result == expected_conditions @@ -249,19 +258,19 @@ def test_explicit_deny_in_scp() -> None: "Statement": [{"Effect": "Deny", "Action": "s3:GetObject", "Resource": "*"}], } scp_policy_document = FixPolicyDocument(scp_policy_json) - service_control_policy_levels = [[scp_policy_document]] + service_control_policy_levels = 
tuple([tuple([scp_policy_document])]) request_context = IamRequestContext( principal=principal, - identity_policies=[], - permission_boundaries=[], + identity_policies=(), + permission_boundaries=(), service_control_policy_levels=service_control_policy_levels, ) - resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") + resource_arn = ARN("arn:aws:s3:::example-bucket/object.txt") action = atc("s3:GetObject") - result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) + result = check_explicit_deny(request_context, action, resource_based_policies=())(resource_arn) assert result == "Denied" @@ -281,23 +290,27 @@ def test_explicit_deny_with_condition_in_scp() -> None: ], } scp_policy_document = FixPolicyDocument(scp_policy_json) - service_control_policy_levels = [ + service_control_policy_levels = tuple( [ - scp_policy_document, + tuple( + [ + scp_policy_document, + ] + ) ] - ] + ) request_context = IamRequestContext( principal=principal, - identity_policies=[], - permission_boundaries=[], + identity_policies=(), + permission_boundaries=(), service_control_policy_levels=service_control_policy_levels, ) - resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") + resource_arn = ARN("arn:aws:s3:::example-bucket/object.txt") action = atc("s3:GetObject") - result = check_explicit_deny(request_context, resource, action, resource_based_policies=[]) + result = check_explicit_deny(request_context, action, resource_based_policies=())(resource_arn) expected_conditions = [scp_policy_json["Statement"][0]["Condition"]] assert result == expected_conditions @@ -308,9 +321,9 @@ def test_explicit_deny_in_resource_policy() -> None: request_context = IamRequestContext( principal=principal, - identity_policies=[], - permission_boundaries=[], - service_control_policy_levels=[], + identity_policies=(), + permission_boundaries=(), + service_control_policy_levels=(), ) policy_json: Dict[str, 
Any] = { @@ -325,14 +338,14 @@ def test_explicit_deny_in_resource_policy() -> None: ], } policy_document = FixPolicyDocument(policy_json) - resource_based_policies = [ - (PolicySource(kind=PolicySourceKind.resource, uri="arn:aws:s3:::example-bucket"), policy_document) - ] + resource_based_policies = tuple( + [(PolicySource(kind=PolicySourceKind.resource, uri="arn:aws:s3:::example-bucket"), policy_document)] + ) - resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") + resource_arn = ARN("arn:aws:s3:::example-bucket/object.txt") action = atc("s3:GetObject") - result = check_explicit_deny(request_context, resource, action, resource_based_policies) + result = check_explicit_deny(request_context, action, resource_based_policies)(resource_arn) assert result == "Denied" @@ -342,9 +355,9 @@ def test_explicit_deny_with_condition_in_resource_policy() -> None: request_context = IamRequestContext( principal=principal, - identity_policies=[], - permission_boundaries=[], - service_control_policy_levels=[], + identity_policies=(), + permission_boundaries=(), + service_control_policy_levels=(), ) policy_json: Dict[str, Any] = { @@ -360,14 +373,14 @@ def test_explicit_deny_with_condition_in_resource_policy() -> None: ], } policy_document = FixPolicyDocument(policy_json) - resource_based_policies = [ - (PolicySource(kind=PolicySourceKind.resource, uri="arn:aws:s3:::example-bucket"), policy_document) - ] + resource_based_policies = tuple( + [(PolicySource(kind=PolicySourceKind.resource, uri="arn:aws:s3:::example-bucket"), policy_document)] + ) - resource = AwsResource(id="some-resource", arn="arn:aws:s3:::example-bucket/object.txt") + resource_arn = ARN("arn:aws:s3:::example-bucket/object.txt") action = atc("s3:GetObject") - result = check_explicit_deny(request_context, resource, action, resource_based_policies) + result = check_explicit_deny(request_context, action, resource_based_policies)(resource_arn) expected_conditions = 
[policy_json["Statement"][0]["Condition"]] assert result == expected_conditions @@ -376,7 +389,7 @@ def test_compute_permissions_user_inline_policy_allow() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket_arn = ARN("arn:aws:s3:::my-test-bucket") policy_json = { "Version": "2012-10-17", @@ -391,17 +404,17 @@ def test_compute_permissions_user_inline_policy_allow() -> None: } policy_document = FixPolicyDocument(policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)]) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( - resource=bucket, + resource=bucket_arn, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket_arn.arn), ) assert len(permissions) == 1 assert permissions[0].action == "s3:ListBucket" @@ -417,7 +430,7 @@ def test_compute_permissions_user_inline_policy_allow_with_conditions() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") condition = {"IpAddress": {"aws:SourceIp": "1.1.1.1"}} @@ -435,17 +448,17 @@ def test_compute_permissions_user_inline_policy_allow_with_conditions() -> None: } policy_document = FixPolicyDocument(policy_json) - identity_policies = 
[(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)]) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 1 assert permissions[0].action == "s3:ListBucket" @@ -463,7 +476,7 @@ def test_compute_permissions_user_inline_policy_deny() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") policy_json = { "Version": "2012-10-17", @@ -478,17 +491,17 @@ def test_compute_permissions_user_inline_policy_deny() -> None: } policy_document = FixPolicyDocument(policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)]) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + 
resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 0 @@ -498,7 +511,7 @@ def test_compute_permissions_user_inline_policy_deny_with_condition() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") condition = {"IpAddress": {"aws:SourceIp": "1.1.1.1"}} @@ -516,17 +529,17 @@ def test_compute_permissions_user_inline_policy_deny_with_condition() -> None: } policy_document = FixPolicyDocument(policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), policy_document)]) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) # deny does not grant any permissions by itself, even if the condition is met @@ -537,7 +550,7 @@ def test_deny_overrides_allow() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") deny_policy_json = { "Version": "2012-10-17", @@ -565,20 +578,22 @@ def test_deny_overrides_allow() -> None: } allow_policy_document = FixPolicyDocument(allow_policy_json) - identity_policies = [ - (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), - 
(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), allow_policy_document), - ] + identity_policies = tuple( + [ + (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), + (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), allow_policy_document), + ] + ) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 0 @@ -588,7 +603,7 @@ def test_deny_different_action_does_not_override_allow() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") deny_policy_json = { "Version": "2012-10-17", @@ -616,20 +631,22 @@ def test_deny_different_action_does_not_override_allow() -> None: } allow_policy_document = FixPolicyDocument(allow_policy_json) - identity_policies = [ - (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), - (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), allow_policy_document), - ] + identity_policies = tuple( + [ + (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), + (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), allow_policy_document), + ] + ) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=identity_policies, 
permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 1 @@ -639,7 +656,7 @@ def test_deny_overrides_allow_with_condition() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") condition = {"IpAddress": {"aws:SourceIp": "1.1.1.1"}} @@ -670,20 +687,22 @@ def test_deny_overrides_allow_with_condition() -> None: } allow_policy_document = FixPolicyDocument(allow_policy_json) - identity_policies = [ - (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), - (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), allow_policy_document), - ] + identity_policies = tuple( + [ + (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), deny_policy_document), + (PolicySource(kind=PolicySourceKind.principal, uri=user.arn), allow_policy_document), + ] + ) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 1 @@ -702,7 +721,7 @@ def test_deny_overrides_allow_with_condition() -> None: def test_compute_permissions_resource_based_policy_allow() -> None: user = AwsIamUser(id="user123", 
arn="arn:aws:iam::111122223333:user/test-user") - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") assert bucket.arn policy_json = { @@ -720,16 +739,16 @@ def test_compute_permissions_resource_based_policy_allow() -> None: policy_document = FixPolicyDocument(policy_json) request_context = IamRequestContext( - principal=user, identity_policies=[], permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=(), permission_boundaries=(), service_control_policy_levels=() ) - resource_based_policies = [(PolicySource(kind=PolicySourceKind.resource, uri=bucket.arn), policy_document)] + resource_based_policies = tuple([(PolicySource(kind=PolicySourceKind.resource, uri=bucket.arn), policy_document)]) permissions = compute_permissions( resource=bucket, iam_context=request_context, resource_based_policies=resource_based_policies, - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 1 @@ -747,7 +766,7 @@ def test_compute_permissions_permission_boundary_restrict() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") identity_policy_json = { "Version": "2012-10-17", @@ -776,22 +795,22 @@ def test_compute_permissions_permission_boundary_restrict() -> None: } permission_boundary_document = FixPolicyDocument(permission_boundary_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), identity_policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), identity_policy_document)]) - permission_boundaries = [permission_boundary_document] + permission_boundaries = tuple([permission_boundary_document]) request_context = IamRequestContext( 
principal=user, identity_policies=identity_policies, permission_boundaries=permission_boundaries, - service_control_policy_levels=[], + service_control_policy_levels=(), ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 1 @@ -809,7 +828,7 @@ def test_compute_permissions_scp_deny() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") assert user.arn - ec2_instance = AwsResource(id="instance123", arn="arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0") + ec2_instance = ARN("arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0") identity_policy_json = { "Version": "2012-10-17", @@ -832,22 +851,22 @@ def test_compute_permissions_scp_deny() -> None: } scp_policy_document = FixPolicyDocument(scp_policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), identity_policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=user.arn), identity_policy_document)]) - service_control_policy_levels = [[scp_policy_document]] + service_control_policy_levels = ((scp_policy_document,),) request_context = IamRequestContext( principal=user, identity_policies=identity_policies, - permission_boundaries=[], + permission_boundaries=(), service_control_policy_levels=service_control_policy_levels, ) permissions = compute_permissions( resource=ec2_instance, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(ec2_instance.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(ec2_instance.arn), ) assert len(permissions) == 0 @@ -855,7 +874,7 @@ def test_compute_permissions_scp_deny() -> None: def 
test_compute_permissions_user_with_group_policies() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") group = AwsResource(id="group123", arn="arn:aws:iam::123456789012:group/test-group") assert group.arn @@ -873,13 +892,16 @@ def test_compute_permissions_user_with_group_policies() -> None: identity_policies.append((PolicySource(kind=PolicySourceKind.group, uri=group.arn), group_policy_document)) request_context = IamRequestContext( - principal=user, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=user, + identity_policies=tuple(identity_policies), + permission_boundaries=(), + service_control_policy_levels=(), ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], + resource_based_policies=(), resource_actions=get_actions_matching_arn(bucket.arn or ""), ) @@ -896,17 +918,17 @@ def test_compute_permissions_user_with_group_policies() -> None: def test_compute_permissions_implicit_deny() -> None: user = AwsIamUser(id="user123", arn="arn:aws:iam::123456789012:user/test-user") - table = AwsResource(id="table123", arn="arn:aws:dynamodb:us-east-1:123456789012:table/my-table") + table = ARN("arn:aws:dynamodb:us-east-1:123456789012:table/my-table") request_context = IamRequestContext( - principal=user, identity_policies=[], permission_boundaries=[], service_control_policy_levels=[] + principal=user, identity_policies=(), permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=table, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(table.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(table.arn), ) # Assert that permissions do not include any actions (implicit 
deny) @@ -917,7 +939,7 @@ def test_compute_permissions_group_inline_policy_allow() -> None: group = AwsIamGroup(id="group123", arn="arn:aws:iam::123456789012:group/test-group") assert group.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") policy_json = { "Version": "2012-10-17", @@ -932,17 +954,17 @@ def test_compute_permissions_group_inline_policy_allow() -> None: } policy_document = FixPolicyDocument(policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.group, uri=group.arn), policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.group, uri=group.arn), policy_document)]) request_context = IamRequestContext( - principal=group, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=group, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 1 @@ -959,7 +981,7 @@ def test_compute_permissions_role_inline_policy_allow() -> None: role = AwsIamRole(id="role123", arn="arn:aws:iam::123456789012:role/test-role") assert role.arn - bucket = AwsResource(id="bucket123", arn="arn:aws:s3:::my-test-bucket") + bucket = ARN("arn:aws:s3:::my-test-bucket") policy_json = { "Version": "2012-10-17", @@ -974,17 +996,17 @@ def test_compute_permissions_role_inline_policy_allow() -> None: } policy_document = FixPolicyDocument(policy_json) - identity_policies = [(PolicySource(kind=PolicySourceKind.principal, uri=role.arn), policy_document)] + identity_policies = tuple([(PolicySource(kind=PolicySourceKind.principal, uri=role.arn), policy_document)]) request_context = IamRequestContext( - 
principal=role, identity_policies=identity_policies, permission_boundaries=[], service_control_policy_levels=[] + principal=role, identity_policies=identity_policies, permission_boundaries=(), service_control_policy_levels=() ) permissions = compute_permissions( resource=bucket, iam_context=request_context, - resource_based_policies=[], - resource_actions=get_actions_matching_arn(bucket.arn or ""), + resource_based_policies=(), + resource_actions=get_actions_matching_arn(bucket.arn), ) assert len(permissions) == 1 From 352733c58a3e37e2c107c226f325cf1732e48c18 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 14 Nov 2024 14:36:56 +0100 Subject: [PATCH 09/14] add a principal tree --- plugins/aws/fix_plugin_aws/access_edges.py | 212 ++++++++++++++++- plugins/aws/test/acccess_edges_test.py | 250 +++++++++++++++++++++ 2 files changed, 461 insertions(+), 1 deletion(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index 7bb6565704..f7c4a36b23 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -1,4 +1,5 @@ from enum import Enum +import enum from functools import lru_cache from attr import frozen import networkx @@ -88,6 +89,215 @@ class ActionToCheck: action_name: str +class ArnResourceValueKind(enum.Enum): + Static = 1 # the segment is a fixed value, e.g. "s3", "vpc/vpc-0e9801d129EXAMPLE", + Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", + Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. 
"*" + +@frozen(slots=True) +class ArnResource: + value: str + principal_arns: Set[str] + kind: ArnResourceValueKind + not_resource: bool + + def matches(self, segment: str) -> bool: + _match = False + match self.kind: + case ArnResourceValueKind.Any: + _match = True + case ArnResourceValueKind.Pattern: + _match = fnmatch.fnmatch(segment, self.value) + case ArnResourceValueKind.Static: + _match = segment == self.value + + + if self.not_resource: + _match = not _match + + + return _match + + + +@frozen(slots=True) +class ArnAccountId: + value: str + wildcard: bool # if the account is a wildcard, e.g. "*" or "::" + principal_arns: Set[str] + children: List[ArnResource] + + def matches(self, segment: str) -> bool: + return self.wildcard or self.value == segment + + +@frozen(slots=True) +class ArnRegion: + value: str + wildcard: bool # if the region is a wildcard, e.g. "*" or "::" + principal_arns: Set[str] + children: List[ArnAccountId] + + def matches(self, segment: str) -> bool: + return self.wildcard or self.value == segment + + +@frozen(slots=True) +class ArnService: + value: str + principal_arns: Set[str] + children: List[ArnRegion] + + def matches(self, segment: str) -> bool: + return self.value == segment + + +@frozen(slots=True) +class ArnPartition: + value: str + wildcard: bool # for the cases like "Allow": "*" on all resources + principal_arns: Set[str] + children: List[ArnService] + + def matches(self, segment: str) -> bool: + return self.wildcard or segment == self.value + + +def is_wildcard(segment: str) -> bool: + return segment == "*" or segment == "" + + +class PrincipalTree: + def __init__(self) -> None: + self.partitions: List[ArnPartition] = [] + + + def _add_allow_all_wildcard(self, principal_arn: str) -> None: + partition = next((p for p in self.partitions if p.value == "*"), None) + if not partition: + partition = ArnPartition(value="*", wildcard=True, principal_arns=set(), children=[]) + self.partitions.append(partition) + + 
partition.principal_arns.add(principal_arn) + + def _add_resource(self, resource_constraint: str, principal_arn: str, nr: bool = False) -> None: + """ + _add resource will add the principal arn at the resource level + """ + + + try: + arn = ARN(resource_constraint) + # Find existing or create partition + partition = next((p for p in self.partitions if p.value == arn.partition), None) + if not partition: + partition = ArnPartition(value=arn.partition, wildcard=False, principal_arns=set(), children=[]) + self.partitions.append(partition) + + # Find or create service + service = next((s for s in partition.children if s.value == arn.service_prefix), None) + if not service: + service = ArnService(value=arn.service_prefix, principal_arns=set(), children=[]) + partition.children.append(service) + + # Find or create region + region_wildcard = arn.region == "*" or not arn.region + region = next((r for r in service.children if r.value == (arn.region or "*")), None) + if not region: + region = ArnRegion(value=arn.region or "*", wildcard=region_wildcard, principal_arns=set(), children=[]) + service.children.append(region) + + # Find or create account + account_wildcard = arn.account == "*" or not arn.account + account = next((a for a in region.children if a.value == (arn.account or "*")), None) + if not account: + account = ArnAccountId(value=arn.account or "*", wildcard=account_wildcard, principal_arns=set(), children=[]) + region.children.append(account) + + # Add resource + resource = next((r for r in account.children if r.value == arn.resource_string and r.not_resource == nr), None) + if not resource: + if arn.resource_string == "*": + resource_kind = ArnResourceValueKind.Any + elif "*" in arn.resource_string: + resource_kind = ArnResourceValueKind.Pattern + else: + resource_kind = ArnResourceValueKind.Static + resource = ArnResource(value=arn.resource_string, principal_arns=set(), kind=resource_kind, not_resource=nr) + account.children.append(resource) + + 
resource.principal_arns.add(principal_arn) + + except Exception as e: + log.error(f"Error parsing ARN {principal_arn}: {e}") + pass + + + def _add_service(self, service_prefix: str, principal_arn: str) -> None: + # Find existing or create partition + partition = next((p for p in self.partitions if p.value == "*"), None) + if not partition: + partition = ArnPartition(value="*", wildcard=True, principal_arns=set(), children=[]) + self.partitions.append(partition) + + # Find or create service + service = next((s for s in partition.children if s.value == service_prefix), None) + if not service: + service = ArnService(value=service_prefix, principal_arns=set(), children=[]) + partition.children.append(service) + + service.principal_arns.add(principal_arn) + + + + def add_principal(self, principal_arn: str, policy_documents: List[FixPolicyDocument]) -> None: + """ + This method iterates over every policy statement and adds corresponding arns to principal tree. + """ + + for policy_doc in policy_documents: + for statement in policy_doc.fix_statements: + if statement.effect_allow: + has_wildcard_resource = False + for resource in statement.resources: + if resource == "*": + has_wildcard_resource = True + continue + self._add_resource(resource, principal_arn) + for not_resource in statement.not_resource: + self._add_resource(not_resource, principal_arn, nr=True) + + if has_wildcard_resource or (not statement.resources and not statement.not_resource): + for ap in statement.actions_patterns: + if ap.kind == WildcardKind.any: + self._add_allow_all_wildcard(principal_arn) + self._add_service(ap.service, principal_arn) + + + def list_principals(self, resource_arn: ARN) -> Set[str]: + """ + this will be called for every resource and it must be fast + """ + principals = set() + + matching_partitions = [p for p in self.partitions if p.value if p.matches(resource_arn.partition)] + + matching_services = [s for p in matching_partitions for s in p.children if 
s.matches(resource_arn.service_prefix)] + principals.update([arn for s in matching_services for arn in s.principal_arns]) + + matching_regions = [r for s in matching_services for r in s.children if r.matches(resource_arn.region)] + principals.update([arn for r in matching_regions for arn in r.principal_arns]) + + matching_account_ids = [a for r in matching_regions for a in r.children if r.matches(resource_arn.account)] + principals.update([arn for a in matching_account_ids for arn in a.principal_arns]) + + matching_resources = [r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string)] + principals.update([arn for r in matching_resources for arn in r.principal_arns]) + + return principals + + + @frozen(slots=True) class IamRequestContext: principal: AwsResource @@ -96,7 +306,7 @@ class IamRequestContext: # all service control policies applicable to the principal, # starting from the root, then all org units, then the account service_control_policy_levels: Tuple[Tuple[FixPolicyDocument, ...], ...] 
- # technically we should also add a list of session policies here, but they don't exist in the collector context + def all_policies( self, resource_based_policies: Optional[Tuple[Tuple[PolicySource, FixPolicyDocument], ...]] = None diff --git a/plugins/aws/test/acccess_edges_test.py b/plugins/aws/test/acccess_edges_test.py index 6e79115493..d59c897a2d 100644 --- a/plugins/aws/test/acccess_edges_test.py +++ b/plugins/aws/test/acccess_edges_test.py @@ -18,6 +18,8 @@ FixStatementDetail, ActionToCheck, get_actions_matching_arn, + PrincipalTree, + ArnResourceValueKind ) from fixlib.baseresources import PolicySourceKind, PolicySource, PermissionLevel @@ -1017,3 +1019,251 @@ def test_compute_permissions_role_inline_policy_allow() -> None: assert s.source.kind == PolicySourceKind.principal assert s.source.uri == role.arn assert s.constraints == ("arn:aws:s3:::my-test-bucket",) + + +def test_principal_tree_add_allow_all_wildcard() -> None: + """Test adding wildcard (*) permission to the principal tree.""" + tree = PrincipalTree() + principal_arn = "arn:aws:iam::123456789012:user/test-user" + + tree._add_allow_all_wildcard(principal_arn) + + # Verify the wildcard partition exists + assert len(tree.partitions) == 1 + partition = tree.partitions[0] + assert partition.value == "*" + assert partition.wildcard is True + assert principal_arn in partition.principal_arns + + +def test_principal_tree_add_resource() -> None: + """Test adding a resource ARN to the principal tree.""" + tree = PrincipalTree() + principal_arn = "arn:aws:iam::123456789012:user/test-user" + resource_arn = "arn:aws:s3:::my-bucket/my-object" + + tree._add_resource(resource_arn, principal_arn) + + # Verify the partition structure + assert len(tree.partitions) == 1 + partition = tree.partitions[0] + assert partition.value == "aws" + assert not partition.wildcard + + # Verify service level + assert len(partition.children) == 1 + service = partition.children[0] + assert service.value == "s3" + + # Verify region 
level + assert len(service.children) == 1 + region = service.children[0] + assert region.value == "*" + assert region.wildcard + + # Verify account level + assert len(region.children) == 1 + account = region.children[0] + assert account.value == "*" + assert account.wildcard + + # Verify resource level + assert len(account.children) == 1 + resource = account.children[0] + assert resource.value == "my-bucket/my-object" + assert resource.kind == ArnResourceValueKind.Static + assert principal_arn in resource.principal_arns + assert not resource.not_resource + + +def test_principal_tree_add_resource_with_wildcard() -> None: + """Test adding a resource ARN with wildcards to the principal tree.""" + tree = PrincipalTree() + principal_arn = "arn:aws:iam::123456789012:user/test-user" + resource_arn = "arn:aws:s3:::my-bucket/*" + + tree._add_resource(resource_arn, principal_arn) + + # Verify the resource level has correct wildcard pattern + partition = tree.partitions[0] + service = partition.children[0] + region = service.children[0] + account = region.children[0] + resource = account.children[0] + + assert resource.value == "my-bucket/*" + assert resource.kind == ArnResourceValueKind.Pattern + assert principal_arn in resource.principal_arns + + +def test_principal_tree_add_not_resource() -> None: + """Test adding a NotResource ARN to the principal tree.""" + tree = PrincipalTree() + principal_arn = "arn:aws:iam::123456789012:user/test-user" + resource_arn = "arn:aws:s3:::my-bucket/private/*" + + tree._add_resource(resource_arn, principal_arn, nr=True) + + # Verify the NotResource flag is set correctly through the tree + partition = tree.partitions[0] + service = partition.children[0] + region = service.children[0] + account = region.children[0] + resource = account.children[0] + assert resource.not_resource + + +def test_principal_tree_add_service() -> None: + """Test adding a service to the principal tree.""" + tree = PrincipalTree() + principal_arn = 
"arn:aws:iam::123456789012:user/test-user" + service_prefix = "s3" + + tree._add_service(service_prefix, principal_arn) + + # Verify service is added under wildcard partition + assert len(tree.partitions) == 1 + partition = tree.partitions[0] + assert partition.value == "*" + + assert len(partition.children) == 1 + service = partition.children[0] + assert service.value == "s3" + assert principal_arn in service.principal_arns + + +def test_principal_tree_add_principal_policy() -> None: + """Test adding a principal with policy documents to the principal tree.""" + tree = PrincipalTree() + principal_arn = "arn:aws:iam::123456789012:user/test-user" + + policy_json = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": ["s3:GetObject"], + "Resource": "arn:aws:s3:::my-bucket/*" + }, + { + "Effect": "Allow", + "Action": ["s3:ListAllMyBuckets"], + "Resource": "*" + } + ] + } + + policy_doc = FixPolicyDocument(policy_json) + tree.add_principal(principal_arn, [policy_doc]) + + # Verify both the specific resource and wildcard permissions are added + assert any( + p.value == "aws" and + any(s.value == "s3" and + any(r.value == "*" and + any(a.value == "*" and + any(res.value == "my-bucket/*" + for res in a.children) + for a in r.children) + for r in s.children) + for s in p.children) + for p in tree.partitions + ) + + +def test_principal_tree_list_principals() -> None: + """Test listing principals that have access to a given ARN.""" + tree = PrincipalTree() + principal1 = "arn:aws:iam::123456789012:user/test-user1" + principal2 = "arn:aws:iam::123456789012:user/test-user2" + + # Add different types of permissions + policy_doc1 = FixPolicyDocument({ + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Allow", + "Action": ["s3:GetObject"], + "Resource": "arn:aws:s3:::my-bucket/*" + }] + }) + + policy_doc2 = FixPolicyDocument({ + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Allow", + "Action": ["s3:ListAllMyBuckets"], + "Resource": 
"*" + }] + }) + + tree.add_principal(principal1, [policy_doc1]) + tree.add_principal(principal2, [policy_doc2]) + + # Test specific resource access + resource_arn = ARN("arn:aws:s3:::my-bucket/test.txt") + matching_principals = tree.list_principals(resource_arn) + + assert principal1 in matching_principals # Has specific access + assert principal2 in matching_principals # Has wildcard access + + +def test_principal_tree_add_multiple_statements() -> None: + """Test adding multiple statements for the same principal.""" + tree = PrincipalTree() + principal_arn = "arn:aws:iam::123456789012:user/test-user" + + policy_doc = FixPolicyDocument({ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": ["s3:GetObject"], + "Resource": "arn:aws:s3:::bucket1/*" + }, + { + "Effect": "Allow", + "Action": ["s3:PutObject"], + "Resource": "arn:aws:s3:::bucket2/*" + } + ] + }) + + tree.add_principal(principal_arn, [policy_doc]) + + # Test access to both buckets + bucket1_arn = ARN("arn:aws:s3:::bucket1/test.txt") + bucket2_arn = ARN("arn:aws:s3:::bucket2/test.txt") + + assert principal_arn in tree.list_principals(bucket1_arn) + assert principal_arn in tree.list_principals(bucket2_arn) + + +def test_principal_tree_not_resource() -> None: + """Test NotResource handling in the principal tree.""" + tree = PrincipalTree() + principal_arn = "arn:aws:iam::123456789012:user/test-user" + + policy_doc = FixPolicyDocument({ + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Allow", + "Action": ["s3:GetObject"], + "NotResource": ["arn:aws:s3:::private-bucket/*"] + }] + }) + + tree.add_principal(principal_arn, [policy_doc]) + + # Test access is denied to private bucket + private_arn = ARN("arn:aws:s3:::private-bucket/secret.txt") + public_arn = ARN("arn:aws:s3:::public-bucket/public.txt") + ec2 = ARN("arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0") + + matching_principals = tree.list_principals(private_arn) + assert principal_arn not in 
matching_principals + + matching_principals = tree.list_principals(public_arn) + assert principal_arn in matching_principals + + matching_principals = tree.list_principals(ec2) + assert len(matching_principals) == 0 \ No newline at end of file From 9a7112ba69201e3778cf2e9312453c2ebbb5d3f1 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Fri, 15 Nov 2024 15:30:02 +0100 Subject: [PATCH 10/14] enable principal tree filtering --- plugins/aws/fix_plugin_aws/access_edges.py | 87 +++++++++++++++++----- plugins/aws/fix_plugin_aws/collector.py | 17 ++--- 2 files changed, 77 insertions(+), 27 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index f7c4a36b23..bad285e630 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -281,17 +281,29 @@ def list_principals(self, resource_arn: ARN) -> Set[str]: principals = set() matching_partitions = [p for p in self.partitions if p.value if p.matches(resource_arn.partition)] + if not matching_partitions: + return principals matching_services = [s for p in matching_partitions for s in p.children if s.matches(resource_arn.service_prefix)] + if not matching_services: + return principals principals.update([arn for s in matching_services for arn in s.principal_arns]) + matching_regions = [r for s in matching_services for r in s.children if r.matches(resource_arn.region)] + if not matching_regions: + return principals principals.update([arn for r in matching_regions for arn in r.principal_arns]) matching_account_ids = [a for r in matching_regions for a in r.children if r.matches(resource_arn.account)] + if not matching_account_ids: + return principals principals.update([arn for a in matching_account_ids for arn in a.principal_arns]) matching_resources = [r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string)] + if not matching_resources: + return principals + 
principals.update([arn for r in matching_resources for arn in r.principal_arns]) return principals @@ -1093,6 +1105,8 @@ def __init__(self, builder: GraphBuilder): self.principals: List[IamRequestContext] = [] self._init_principals() self.actions_for_resource: Dict[str, set[IamAction]] = self._compute_actions_for_resource() + self.principal_tree = self._build_principal_tree() + self.arn_to_context = {context.principal.arn: context for context in self.principals} def _init_principals(self) -> None: @@ -1155,6 +1169,21 @@ def _init_principals(self) -> None: self.principals.append(request_context) + def _build_principal_tree(self) -> PrincipalTree: + + tree = PrincipalTree() + + for context in self.principals: + principal_arn = context.principal.arn + if not principal_arn: + continue + + principal_policies = context.all_policies() + tree.add_principal(principal_arn, principal_policies) + + return tree + + def _compute_actions_for_resource(self) -> Dict[str, set[IamAction]]: actions_for_resource: Dict[str, set[IamAction]] = {} @@ -1263,32 +1292,54 @@ def add_access_edges(self) -> None: for node in self.builder.nodes(clazz=AwsResource, filter=lambda r: r.arn is not None): assert node.arn - for context in self.principals: - if context.principal.arn == node.arn: - # small graph cycles avoidance optimization - continue + resource_arn = ARN(node.arn) + - resource_policies: List[Tuple[PolicySource, FixPolicyDocument]] = [] - if isinstance(node, HasResourcePolicy): + if not isinstance(node, HasResourcePolicy): + # here we have identity-based policies only and can prune some principals + for arn in self.principal_tree.list_principals(resource_arn): + context = self.arn_to_context.get(arn) + if not context: + raise ValueError(f"Principal {arn} not found in the context") + + permissions = compute_permissions( + resource_arn, context, tuple(), self.actions_for_resource.get(node.arn, set()) + ) + + if not permissions: + continue + + access: Dict[PermissionLevel, bool] = {} + for 
permission in permissions: + access[permission.level] = True + reported = to_json({"permissions": permissions} | access, strip_nulls=True) + self.builder.add_edge(from_node=context.principal, edge_type=EdgeType.iam, reported=reported, node=node) + + else: + # here we have resource-based policies and must check all principals. + for context in self.principals: + if context.principal.arn == node.arn: + # small graph cycles avoidance optimization + continue + + resource_policies: List[Tuple[PolicySource, FixPolicyDocument]] = [] for source, json_policy in node.resource_policy(self.builder): resource_policies.append((source, FixPolicyDocument(json_policy))) - resource_arn = ARN(node.arn) - permissions = compute_permissions( - resource_arn, context, tuple(resource_policies), self.actions_for_resource.get(node.arn, set()) - ) - - if not permissions: - continue + permissions = compute_permissions( + resource_arn, context, tuple(resource_policies), self.actions_for_resource.get(node.arn, set()) + ) - access: Dict[PermissionLevel, bool] = {} + if not permissions: + continue - for permission in permissions: - access[permission.level] = True + access: Dict[PermissionLevel, bool] = {} + for permission in permissions: + access[permission.level] = True + reported = to_json({"permissions": permissions} | access, strip_nulls=True) + self.builder.add_edge(from_node=context.principal, edge_type=EdgeType.iam, reported=reported, node=node) - reported = to_json({"permissions": permissions} | access, strip_nulls=True) - self.builder.add_edge(from_node=context.principal, edge_type=EdgeType.iam, reported=reported, node=node) all_principal_arns = {p.principal.arn for p in self.principals if p.principal.arn} diff --git a/plugins/aws/fix_plugin_aws/collector.py b/plugins/aws/fix_plugin_aws/collector.py index 01c27acd92..c0fa052482 100644 --- a/plugins/aws/fix_plugin_aws/collector.py +++ b/plugins/aws/fix_plugin_aws/collector.py @@ -50,8 +50,6 @@ backup, bedrock, scp, - guardduty, - 
inspector, ) from fix_plugin_aws.resource.base import ( AwsAccount, @@ -71,6 +69,7 @@ from fixlib.threading import ExecutorQueue, GatherFutures from fixlib.types import Json from .utils import global_region_by_partition +from pyinstrument import Profiler log = logging.getLogger("fix.plugins.aws") @@ -106,7 +105,6 @@ + elb.resources + elbv2.resources + glacier.resources - + guardduty.resources + kinesis.resources + kms.resources + lambda_.resources @@ -121,7 +119,6 @@ + backup.resources + amazonq.resources + bedrock.resources - + inspector.resources ) all_resources: List[Type[AwsResource]] = global_resources + regional_resources @@ -248,10 +245,6 @@ def get_last_run() -> Optional[datetime]: ) shared_queue.wait_for_submitted_work() - # call all registered after collect hooks - for after_collect in global_builder.after_collect_actions: - after_collect() - # connect nodes log.info(f"[Aws:{self.account.id}] Connect resources and create edges.") for node, data in list(self.graph.nodes(data=True)): @@ -271,12 +264,18 @@ def get_last_run() -> Optional[datetime]: log.warning(f"Unexpected node type {node} in graph") raise Exception("Only AWS resources expected") - access_edge_collection_enabled = os.environ.get("ACCESS_EDGE_COLLECTION_ENABLED", "false").lower() == "true" + access_edge_collection_enabled = True if access_edge_collection_enabled and global_builder.config.collect_access_edges: # add access edges + profiler = Profiler() + profiler.start() log.info(f"[Aws:{self.account.id}] Create access edges.") access_edge_creator = AccessEdgeCreator(global_builder) access_edge_creator.add_access_edges() + profiler.stop() + html_output = profiler.output_html() + with open(f"profiler_{self.account.id}.html", "w") as f: + f.write(html_output) # final hook when the graph is complete for node, data in list(self.graph.nodes(data=True)): From 2a4c39c966c8129104ee3ad6a9a8fbd539ec63f6 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 28 Nov 2024 10:55:45 +0100 Subject: [PATCH 
11/14] add the pre-computing for non resource policies --- plugins/aws/fix_plugin_aws/access_edges.py | 155 ++++++++++++++++++--- 1 file changed, 136 insertions(+), 19 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index bad285e630..5f1a08258b 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -333,8 +333,8 @@ def all_policies( IamAction = str - -def find_allowed_action(policy_document: PolicyDocument, service_prefix: str) -> Set[IamAction]: +@lru_cache(maxsize=4096) +def find_allowed_action(policy_document: FixPolicyDocument, service_prefix: str) -> Set[IamAction]: allowed_actions: Set[IamAction] = set() for statement in policy_document.statements: if statement.effect_allow: @@ -887,29 +887,37 @@ def check_identity_policies_closure(resource: ARN) -> List[PermissionScope]: return check_identity_policies_closure - +@lru_cache(maxsize=4096) def check_permission_boundaries( - request_context: IamRequestContext, resource: ARN, action: ActionToCheck -) -> Union[Literal["Denied", "NextStep"], List[Json]]: + request_context: IamRequestContext, action: ActionToCheck +) -> Callable[[ARN], Union[Literal["Denied", "NextStep"], List[Json]]]: - conditions: List[Json] = [] + + matching_fns = [] # ignore policy sources and resource constraints because permission boundaries # can never allow access to a resource, only restrict it for policy in request_context.permission_boundaries: matching_fn = collect_matching_statements(policy=policy, effect="Allow", action=action, principal=None) - for statement, _ in matching_fn(resource): - if statement.condition: - assert isinstance(statement.condition, dict) - conditions.append(statement.condition) - else: # if there is an allow statement without a condition, the action is allowed - return "NextStep" + matching_fns.append(matching_fn) + + def check_permission_boundaries_closure(resource: ARN) -> 
Union[Literal["Denied", "NextStep"], List[Json]]: + conditions: List[Json] = [] + for matching_fn in matching_fns: + for statement, _ in matching_fn(resource): + if statement.condition: + assert isinstance(statement.condition, dict) + conditions.append(statement.condition) + else: # if there is an allow statement without a condition, the action is allowed + return "NextStep" - if len(conditions) > 0: - return conditions + if len(conditions) > 0: + return conditions - # no matching permission boundaries that allow access - return "Denied" + # no matching permission boundaries that allow access + return "Denied" + + return check_permission_boundaries_closure def is_service_linked_role(principal: AwsResource) -> bool: @@ -1015,7 +1023,7 @@ def check_policies( # 4. to make it a bit simpler, we check the permission boundaries before checking identity based policies if len(request_context.permission_boundaries) > 0: - permission_boundary_result = check_permission_boundaries(request_context, resource, action) + permission_boundary_result = check_permission_boundaries(request_context, action)(resource) if permission_boundary_result == "Denied": return None elif permission_boundary_result == "NextStep": @@ -1065,6 +1073,110 @@ def check_policies( ) +# logic according to https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_evaluation-logic.html +@lru_cache(maxsize=4096) +def check_non_resource_policies( + request_context: IamRequestContext, + action: ActionToCheck, +) -> Callable[[ARN], Optional[AccessPermission]]: + + # step 1: calculate and cache the expensive function calls + explicit_deny_fn = check_explicit_deny(request_context, action, ()) + permission_boundary_fn = None + if len(request_context.permission_boundaries) > 0: + permission_boundary_fn = check_permission_boundaries(request_context, action) + identity_based_fn = check_identity_based_policies(request_context, action) + + # step 2: create the closure + def 
check_non_resource_policies_closure(resource: ARN) -> Optional[AccessPermission]: + + # shortcut: check if any identity based policies are present + if len(request_context.identity_policies) == 0: + return None + + # when any of the conditions evaluate to true, the action is explicitly denied + # comes from any explicit deny statements in all policies + deny_conditions: List[Json] = [] + + # when any of the conditions evaluate to false, the action is implicitly denied + # comes from the permission boundaries + restricting_conditions: List[Json] = [] + + # when any of the scopes evaluate to true, the action is allowed + # comes from the resource based policies and identity based policies + allowed_scopes: List[PermissionScope] = [] + + + # 1. check for explicit deny. If denied, we can abort immediately + result = explicit_deny_fn(resource) + if result == "Denied": + return None + elif result == "NextStep": + pass + else: + for c in result: + # satisfying any of the conditions above will deny the action + deny_conditions.append(c) + + + + # 2. check for organization SCPs # todo: move it outside the loop + if len(request_context.service_control_policy_levels) > 0 and not is_service_linked_role(request_context.principal): + org_scp_allowed = scp_allowed(request_context, action, resource) + if not org_scp_allowed: + return None + + # 3. skip resource based policies because the resource has none + + # 4. to make it a bit simpler, we check the permission boundaries before checking identity based policies + if permission_boundary_fn: + permission_boundary_result = permission_boundary_fn(resource) + if permission_boundary_result == "Denied": + return None + elif permission_boundary_result == "NextStep": + pass + else: + restricting_conditions.extend(permission_boundary_result) + + # 5. check identity based policies + identity_based_allowed = identity_based_fn(resource) + if not identity_based_allowed: + return None + allowed_scopes.extend(identity_based_allowed) + + # 6. 
check for session policies + # we don't collect session principals and session policies, so this step is skipped + + # 7. if we reached here, the action is allowed + level = get_action_level(action.raw) + + final_scopes: Set[PermissionScope] = set() + for scope in allowed_scopes: + if deny_conditions: + scope = scope.with_deny_conditions(deny_conditions) + final_scopes.add(scope) + + # if there is a scope with no conditions, we can ignore everything else + for scope in final_scopes: + if scope.has_no_condititons(): + final_scopes = {scope} + break + + log.debug( + f"Found access permission, {action} is allowed for {resource} by {request_context.principal}, level: {level}. Scopes: {len(final_scopes)}" + ) + + # return the result + return AccessPermission( + action=action.raw, + level=level, + scopes=tuple(final_scopes), + ) + + + return check_non_resource_policies_closure + + def compute_permissions( resource: ARN, iam_context: IamRequestContext, @@ -1092,8 +1204,13 @@ def compute_permissions( action_to_check = ActionToCheck( service=service.lower(), action_name=action_name.lower(), raw_lower=action.lower(), raw=action ) - if p := check_policies(iam_context, resource, action_to_check, resource_based_policies): - all_permissions.append(p) + + if resource_based_policies: + if p := check_policies(iam_context, resource, action_to_check, resource_based_policies): + all_permissions.append(p) + else: + if p := check_non_resource_policies(iam_context, action_to_check)(resource): + all_permissions.append(p) return all_permissions From 9b934889fba595e1c6e121344f60d2971bd1f85a Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 28 Nov 2024 12:44:12 +0100 Subject: [PATCH 12/14] use segmented arn matching for resources --- plugins/aws/fix_plugin_aws/access_edges.py | 123 +++++++++++++++++---- 1 file changed, 101 insertions(+), 22 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index 5f1a08258b..d348779a3d 
100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -72,6 +72,8 @@ def pattern_from_action(action: str) -> ActionWildcardPattern: self.actions_patterns = [pattern_from_action(action) for action in self.actions] self.not_action_patterns = [pattern_from_action(action) for action in self.not_action] + self.resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.resources] + self.not_resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.not_resource] class FixPolicyDocument(PolicyDocument): @@ -94,6 +96,14 @@ class ArnResourceValueKind(enum.Enum): Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. "*" + @staticmethod + def from_str(value: str) -> "ArnResourceValueKind": + if value == "*": + return ArnResourceValueKind.Any + if "*" in value: + return ArnResourceValueKind.Pattern + return ArnResourceValueKind.Static + @frozen(slots=True) class ArnResource: value: str @@ -310,6 +320,55 @@ def list_principals(self, resource_arn: ARN) -> Set[str]: +@frozen(slots=True) +class ResourceWildcardPattern: + raw_value: str + partition: str | None # None in case the whole string is "*" + service: str + region: str + region_value_kind: ArnResourceValueKind + account: str + account_value_kind: ArnResourceValueKind + resource: str + resource_value_kind: ArnResourceValueKind + + + @staticmethod + def from_str(value: str) -> "ResourceWildcardPattern": + if value == "*": + return ResourceWildcardPattern( + raw_value=value, + partition=None, + service="*", + region="*", + region_value_kind=ArnResourceValueKind.Any, + account="*", + account_value_kind=ArnResourceValueKind.Any, + resource="*", + resource_value_kind=ArnResourceValueKind.Any + ) + + try: + splitted = value.split(":", 5) + if len(splitted) != 6: + raise ValueError(f"Invalid resource pattern: {value}") + _, partition, service, 
region, account, resource = splitted + + return ResourceWildcardPattern( + raw_value=value, + partition=partition, + service=service, + region=region, + region_value_kind=ArnResourceValueKind.from_str(region), + account=account, + account_value_kind=ArnResourceValueKind.from_str(account), + resource=resource, + resource_value_kind=ArnResourceValueKind.from_str(resource) + ) + except Exception as e: + log.error(f"Error parsing resource pattern {value}: {e}") + raise e + @frozen(slots=True) class IamRequestContext: principal: AwsResource @@ -472,18 +531,6 @@ def make_resoruce_regex(aws_resorce_wildcard: str) -> Pattern[str]: return re.compile(f"^{python_regex}$", re.IGNORECASE) -def _expand_wildcards_and_match(*, identifier: str, wildcard_string: str) -> bool: - """ - helper function to expand wildcards and match the identifier - - use case: - match the resource constraint (wildcard) with the ARN - match the wildcard action with the specific action - """ - pattern = make_resoruce_regex(wildcard_string) - return pattern.match(identifier) is not None - - @lru_cache(maxsize=1024) def _compile_action_pattern(wildcard_pattern: str) -> tuple[str, re.Pattern[str] | None]: """ @@ -535,8 +582,40 @@ def expand_action_wildcards_and_match(action: ActionToCheck, wildcard_pattern: A return False -def expand_arn_wildcards_and_match(identifier: str, wildcard_string: str) -> bool: - return _expand_wildcards_and_match(identifier=identifier, wildcard_string=wildcard_string) +def match_pattern(resource_segment: str, wildcard_segment: str, wildcard_segment_kind: ArnResourceValueKind) -> bool: + match wildcard_segment_kind: + case ArnResourceValueKind.Any: + return True + case ArnResourceValueKind.Pattern: + return fnmatch.fnmatch(resource_segment, wildcard_segment) + case ArnResourceValueKind.Static: + return resource_segment == wildcard_segment + + + +def expand_arn_wildcards_and_match(identifier: ARN, wildcard_string: ResourceWildcardPattern) -> bool: + + # if wildard is *, we can 
shortcut here + if wildcard_string.partition is None: + return True + + # go through the ARN segments and match them + if not wildcard_string.partition == identifier.partition: + return False + + if not wildcard_string.service == identifier.service_prefix: + return False + + if not match_pattern(identifier.region, wildcard_string.region, wildcard_string.region_value_kind): + return False + + if not match_pattern(identifier.account, wildcard_string.account, wildcard_string.account_value_kind): + return False + + if not match_pattern(identifier.resource_string, wildcard_string.resource, wildcard_string.resource_value_kind): + return False + + return True @lru_cache(maxsize=4096) @@ -621,19 +700,19 @@ def check_resource_match(arn: ARN) -> Optional[List[ResourceConstraint]]: # step 4: check if the resource matches matched_resource_constraints: List[ResourceConstraint] = [] resource_matches = False - if len(statement.resources) > 0: - for resource_constraint in statement.resources: - if expand_arn_wildcards_and_match(identifier=arn.arn, wildcard_string=resource_constraint): - matched_resource_constraints.append(resource_constraint) + if len(statement.resource_patterns) > 0: + for resource_constraint in statement.resource_patterns: + if expand_arn_wildcards_and_match(identifier=arn, wildcard_string=resource_constraint): + matched_resource_constraints.append(resource_constraint.raw_value) resource_matches = True break - elif len(statement.not_resource) > 0: + elif len(statement.not_resource_patterns) > 0: resource_matches = True - for not_resource_constraint in statement.not_resource: - if expand_arn_wildcards_and_match(identifier=arn.arn, wildcard_string=not_resource_constraint): + for not_resource_constraint in statement.not_resource_patterns: + if expand_arn_wildcards_and_match(identifier=arn, wildcard_string=not_resource_constraint): resource_matches = False break - matched_resource_constraints.append("not " + not_resource_constraint) + 
matched_resource_constraints.append("not " + not_resource_constraint.raw_value) else: # no Resource/NotResource specified, consider allowed resource_matches = True From c93c2ceaf5154ec48d8e58a75d38b3d39b389000 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 28 Nov 2024 14:51:03 +0100 Subject: [PATCH 13/14] make the linter happy --- plugins/aws/fix_plugin_aws/access_edges.py | 119 +++++++------- plugins/aws/fix_plugin_aws/collector.py | 10 +- plugins/aws/test/acccess_edges_test.py | 171 ++++++++++----------- 3 files changed, 136 insertions(+), 164 deletions(-) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges.py index d348779a3d..32f6482b1e 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges.py @@ -92,9 +92,9 @@ class ActionToCheck: class ArnResourceValueKind(enum.Enum): - Static = 1 # the segment is a fixed value, e.g. "s3", "vpc/vpc-0e9801d129EXAMPLE", - Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", - Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. "*" + Static = 1 # the segment is a fixed value, e.g. "s3", "vpc/vpc-0e9801d129EXAMPLE", + Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", + Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. "*" @staticmethod def from_str(value: str) -> "ArnResourceValueKind": @@ -104,6 +104,7 @@ def from_str(value: str) -> "ArnResourceValueKind": return ArnResourceValueKind.Pattern return ArnResourceValueKind.Static + @frozen(slots=True) class ArnResource: value: str @@ -121,19 +122,16 @@ def matches(self, segment: str) -> bool: case ArnResourceValueKind.Static: _match = segment == self.value - if self.not_resource: _match = not _match - return _match - @frozen(slots=True) class ArnAccountId: value: str - wildcard: bool # if the account is a wildcard, e.g. "*" or "::" + wildcard: bool # if the account is a wildcard, e.g. 
"*" or "::" principal_arns: Set[str] children: List[ArnResource] @@ -144,7 +142,7 @@ def matches(self, segment: str) -> bool: @frozen(slots=True) class ArnRegion: value: str - wildcard: bool # if the region is a wildcard, e.g. "*" or "::" + wildcard: bool # if the region is a wildcard, e.g. "*" or "::" principal_arns: Set[str] children: List[ArnAccountId] @@ -160,12 +158,12 @@ class ArnService: def matches(self, segment: str) -> bool: return self.value == segment - + @frozen(slots=True) class ArnPartition: value: str - wildcard: bool # for the cases like "Allow": "*" on all resources + wildcard: bool # for the cases like "Allow": "*" on all resources principal_arns: Set[str] children: List[ArnService] @@ -181,7 +179,6 @@ class PrincipalTree: def __init__(self) -> None: self.partitions: List[ArnPartition] = [] - def _add_allow_all_wildcard(self, principal_arn: str) -> None: partition = next((p for p in self.partitions if p.value == "*"), None) if not partition: @@ -195,7 +192,6 @@ def _add_resource(self, resource_constraint: str, principal_arn: str, nr: bool = _add resource will add the principal arn at the resource level """ - try: arn = ARN(resource_constraint) # Find existing or create partition @@ -221,11 +217,15 @@ def _add_resource(self, resource_constraint: str, principal_arn: str, nr: bool = account_wildcard = arn.account == "*" or not arn.account account = next((a for a in region.children if a.value == (arn.account or "*")), None) if not account: - account = ArnAccountId(value=arn.account or "*", wildcard=account_wildcard, principal_arns=set(), children=[]) + account = ArnAccountId( + value=arn.account or "*", wildcard=account_wildcard, principal_arns=set(), children=[] + ) region.children.append(account) # Add resource - resource = next((r for r in account.children if r.value == arn.resource_string and r.not_resource == nr), None) + resource = next( + (r for r in account.children if r.value == arn.resource_string and r.not_resource == nr), None + ) if not 
resource: if arn.resource_string == "*": resource_kind = ArnResourceValueKind.Any @@ -233,7 +233,9 @@ def _add_resource(self, resource_constraint: str, principal_arn: str, nr: bool = resource_kind = ArnResourceValueKind.Pattern else: resource_kind = ArnResourceValueKind.Static - resource = ArnResource(value=arn.resource_string, principal_arns=set(), kind=resource_kind, not_resource=nr) + resource = ArnResource( + value=arn.resource_string, principal_arns=set(), kind=resource_kind, not_resource=nr + ) account.children.append(resource) resource.principal_arns.add(principal_arn) @@ -242,7 +244,6 @@ def _add_resource(self, resource_constraint: str, principal_arn: str, nr: bool = log.error(f"Error parsing ARN {principal_arn}: {e}") pass - def _add_service(self, service_prefix: str, principal_arn: str) -> None: # Find existing or create partition partition = next((p for p in self.partitions if p.value == "*"), None) @@ -258,11 +259,9 @@ def _add_service(self, service_prefix: str, principal_arn: str) -> None: service.principal_arns.add(principal_arn) - - def add_principal(self, principal_arn: str, policy_documents: List[FixPolicyDocument]) -> None: """ - This method iterates over every policy statement and adds corresponding arns to principal tree. + This method iterates over every policy statement and adds corresponding arns to principal tree. 
""" for policy_doc in policy_documents: @@ -276,30 +275,30 @@ def add_principal(self, principal_arn: str, policy_documents: List[FixPolicyDocu self._add_resource(resource, principal_arn) for not_resource in statement.not_resource: self._add_resource(not_resource, principal_arn, nr=True) - + if has_wildcard_resource or (not statement.resources and not statement.not_resource): for ap in statement.actions_patterns: if ap.kind == WildcardKind.any: self._add_allow_all_wildcard(principal_arn) self._add_service(ap.service, principal_arn) - def list_principals(self, resource_arn: ARN) -> Set[str]: """ this will be called for every resource and it must be fast """ - principals = set() + principals: Set[str] = set() matching_partitions = [p for p in self.partitions if p.value if p.matches(resource_arn.partition)] if not matching_partitions: return principals - matching_services = [s for p in matching_partitions for s in p.children if s.matches(resource_arn.service_prefix)] + matching_services = [ + s for p in matching_partitions for s in p.children if s.matches(resource_arn.service_prefix) + ] if not matching_services: return principals principals.update([arn for s in matching_services for arn in s.principal_arns]) - matching_regions = [r for s in matching_services for r in s.children if r.matches(resource_arn.region)] if not matching_regions: return principals @@ -310,20 +309,21 @@ def list_principals(self, resource_arn: ARN) -> Set[str]: return principals principals.update([arn for a in matching_account_ids for arn in a.principal_arns]) - matching_resources = [r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string)] + matching_resources = [ + r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string) + ] if not matching_resources: return principals - + principals.update([arn for r in matching_resources for arn in r.principal_arns]) return principals - @frozen(slots=True) class 
ResourceWildcardPattern: raw_value: str - partition: str | None # None in case the whole string is "*" + partition: str | None # None in case the whole string is "*" service: str region: str region_value_kind: ArnResourceValueKind @@ -332,22 +332,21 @@ class ResourceWildcardPattern: resource: str resource_value_kind: ArnResourceValueKind - @staticmethod def from_str(value: str) -> "ResourceWildcardPattern": if value == "*": return ResourceWildcardPattern( raw_value=value, - partition=None, - service="*", + partition=None, + service="*", region="*", - region_value_kind=ArnResourceValueKind.Any, + region_value_kind=ArnResourceValueKind.Any, account="*", - account_value_kind=ArnResourceValueKind.Any, - resource="*", - resource_value_kind=ArnResourceValueKind.Any + account_value_kind=ArnResourceValueKind.Any, + resource="*", + resource_value_kind=ArnResourceValueKind.Any, ) - + try: splitted = value.split(":", 5) if len(splitted) != 6: @@ -363,12 +362,13 @@ def from_str(value: str) -> "ResourceWildcardPattern": account=account, account_value_kind=ArnResourceValueKind.from_str(account), resource=resource, - resource_value_kind=ArnResourceValueKind.from_str(resource) + resource_value_kind=ArnResourceValueKind.from_str(resource), ) except Exception as e: log.error(f"Error parsing resource pattern {value}: {e}") raise e + @frozen(slots=True) class IamRequestContext: principal: AwsResource @@ -378,7 +378,6 @@ class IamRequestContext: # starting from the root, then all org units, then the account service_control_policy_levels: Tuple[Tuple[FixPolicyDocument, ...], ...] 
- def all_policies( self, resource_based_policies: Optional[Tuple[Tuple[PolicySource, FixPolicyDocument], ...]] = None ) -> List[FixPolicyDocument]: @@ -392,6 +391,7 @@ def all_policies( IamAction = str + @lru_cache(maxsize=4096) def find_allowed_action(policy_document: FixPolicyDocument, service_prefix: str) -> Set[IamAction]: allowed_actions: Set[IamAction] = set() @@ -592,29 +592,28 @@ def match_pattern(resource_segment: str, wildcard_segment: str, wildcard_segment return resource_segment == wildcard_segment - def expand_arn_wildcards_and_match(identifier: ARN, wildcard_string: ResourceWildcardPattern) -> bool: # if wildard is *, we can shortcut here if wildcard_string.partition is None: return True - + # go through the ARN segments and match them if not wildcard_string.partition == identifier.partition: return False - + if not wildcard_string.service == identifier.service_prefix: return False - + if not match_pattern(identifier.region, wildcard_string.region, wildcard_string.region_value_kind): return False - + if not match_pattern(identifier.account, wildcard_string.account, wildcard_string.account_value_kind): return False - + if not match_pattern(identifier.resource_string, wildcard_string.resource, wildcard_string.resource_value_kind): return False - + return True @@ -966,12 +965,12 @@ def check_identity_policies_closure(resource: ARN) -> List[PermissionScope]: return check_identity_policies_closure + @lru_cache(maxsize=4096) def check_permission_boundaries( request_context: IamRequestContext, action: ActionToCheck ) -> Callable[[ARN], Union[Literal["Denied", "NextStep"], List[Json]]]: - matching_fns = [] # ignore policy sources and resource constraints because permission boundaries @@ -995,7 +994,7 @@ def check_permission_boundaries_closure(resource: ARN) -> Union[Literal["Denied" # no matching permission boundaries that allow access return "Denied" - + return check_permission_boundaries_closure @@ -1185,7 +1184,6 @@ def 
check_non_resource_policies_closure(resource: ARN) -> Optional[AccessPermiss # comes from the resource based policies and identity based policies allowed_scopes: List[PermissionScope] = [] - # 1. check for explicit deny. If denied, we can abort immediately result = explicit_deny_fn(resource) if result == "Denied": @@ -1197,14 +1195,14 @@ def check_non_resource_policies_closure(resource: ARN) -> Optional[AccessPermiss # satisfying any of the conditions above will deny the action deny_conditions.append(c) - - # 2. check for organization SCPs # todo: move it outside the loop - if len(request_context.service_control_policy_levels) > 0 and not is_service_linked_role(request_context.principal): + if len(request_context.service_control_policy_levels) > 0 and not is_service_linked_role( + request_context.principal + ): org_scp_allowed = scp_allowed(request_context, action, resource) if not org_scp_allowed: return None - + # 3. skip resource based policies because the resource has none # 4. to make it a bit simpler, we check the permission boundaries before checking identity based policies @@ -1252,7 +1250,6 @@ def check_non_resource_policies_closure(resource: ARN) -> Optional[AccessPermiss scopes=tuple(final_scopes), ) - return check_non_resource_policies_closure @@ -1379,7 +1376,6 @@ def _build_principal_tree(self) -> PrincipalTree: return tree - def _compute_actions_for_resource(self) -> Dict[str, set[IamAction]]: actions_for_resource: Dict[str, set[IamAction]] = {} @@ -1490,14 +1486,13 @@ def add_access_edges(self) -> None: assert node.arn resource_arn = ARN(node.arn) - if not isinstance(node, HasResourcePolicy): # here we have identity-based policies only and can prune some principals for arn in self.principal_tree.list_principals(resource_arn): context = self.arn_to_context.get(arn) if not context: raise ValueError(f"Principal {arn} not found in the context") - + permissions = compute_permissions( resource_arn, context, tuple(), self.actions_for_resource.get(node.arn, 
set()) ) @@ -1509,7 +1504,9 @@ def add_access_edges(self) -> None: for permission in permissions: access[permission.level] = True reported = to_json({"permissions": permissions} | access, strip_nulls=True) - self.builder.add_edge(from_node=context.principal, edge_type=EdgeType.iam, reported=reported, node=node) + self.builder.add_edge( + from_node=context.principal, edge_type=EdgeType.iam, reported=reported, node=node + ) else: # here we have resource-based policies and must check all principals. @@ -1529,13 +1526,13 @@ def add_access_edges(self) -> None: if not permissions: continue - access: Dict[PermissionLevel, bool] = {} + access = {} for permission in permissions: access[permission.level] = True reported = to_json({"permissions": permissions} | access, strip_nulls=True) - self.builder.add_edge(from_node=context.principal, edge_type=EdgeType.iam, reported=reported, node=node) - - + self.builder.add_edge( + from_node=context.principal, edge_type=EdgeType.iam, reported=reported, node=node + ) all_principal_arns = {p.principal.arn for p in self.principals if p.principal.arn} diff --git a/plugins/aws/fix_plugin_aws/collector.py b/plugins/aws/fix_plugin_aws/collector.py index c0fa052482..03a08346c5 100644 --- a/plugins/aws/fix_plugin_aws/collector.py +++ b/plugins/aws/fix_plugin_aws/collector.py @@ -2,7 +2,6 @@ from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor from datetime import datetime, timedelta, timezone -import os from typing import List, Type, Optional, Union, cast, Any from fix_plugin_aws.access_edges import AccessEdgeCreator @@ -69,7 +68,6 @@ from fixlib.threading import ExecutorQueue, GatherFutures from fixlib.types import Json from .utils import global_region_by_partition -from pyinstrument import Profiler log = logging.getLogger("fix.plugins.aws") @@ -264,18 +262,12 @@ def get_last_run() -> Optional[datetime]: log.warning(f"Unexpected node type {node} in graph") raise Exception("Only AWS resources expected") 
- access_edge_collection_enabled = True + access_edge_collection_enabled = False if access_edge_collection_enabled and global_builder.config.collect_access_edges: # add access edges - profiler = Profiler() - profiler.start() log.info(f"[Aws:{self.account.id}] Create access edges.") access_edge_creator = AccessEdgeCreator(global_builder) access_edge_creator.add_access_edges() - profiler.stop() - html_output = profiler.output_html() - with open(f"profiler_{self.account.id}.html", "w") as f: - f.write(html_output) # final hook when the graph is complete for node, data in list(self.graph.nodes(data=True)): diff --git a/plugins/aws/test/acccess_edges_test.py b/plugins/aws/test/acccess_edges_test.py index d59c897a2d..cd15c1342d 100644 --- a/plugins/aws/test/acccess_edges_test.py +++ b/plugins/aws/test/acccess_edges_test.py @@ -19,7 +19,7 @@ ActionToCheck, get_actions_matching_arn, PrincipalTree, - ArnResourceValueKind + ArnResourceValueKind, ) from fixlib.baseresources import PolicySourceKind, PolicySource, PermissionLevel @@ -1025,9 +1025,9 @@ def test_principal_tree_add_allow_all_wildcard() -> None: """Test adding wildcard (*) permission to the principal tree.""" tree = PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" - + tree._add_allow_all_wildcard(principal_arn) - + # Verify the wildcard partition exists assert len(tree.partitions) == 1 partition = tree.partitions[0] @@ -1041,32 +1041,32 @@ def test_principal_tree_add_resource() -> None: tree = PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/my-object" - + tree._add_resource(resource_arn, principal_arn) - + # Verify the partition structure assert len(tree.partitions) == 1 partition = tree.partitions[0] assert partition.value == "aws" assert not partition.wildcard - + # Verify service level assert len(partition.children) == 1 service = partition.children[0] assert service.value == "s3" - + # Verify region level assert 
len(service.children) == 1 region = service.children[0] assert region.value == "*" assert region.wildcard - + # Verify account level assert len(region.children) == 1 account = region.children[0] assert account.value == "*" assert account.wildcard - + # Verify resource level assert len(account.children) == 1 resource = account.children[0] @@ -1081,16 +1081,16 @@ def test_principal_tree_add_resource_with_wildcard() -> None: tree = PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/*" - + tree._add_resource(resource_arn, principal_arn) - + # Verify the resource level has correct wildcard pattern partition = tree.partitions[0] service = partition.children[0] region = service.children[0] account = region.children[0] resource = account.children[0] - + assert resource.value == "my-bucket/*" assert resource.kind == ArnResourceValueKind.Pattern assert principal_arn in resource.principal_arns @@ -1101,9 +1101,9 @@ def test_principal_tree_add_not_resource() -> None: tree = PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/private/*" - + tree._add_resource(resource_arn, principal_arn, nr=True) - + # Verify the NotResource flag is set correctly through the tree partition = tree.partitions[0] service = partition.children[0] @@ -1118,14 +1118,14 @@ def test_principal_tree_add_service() -> None: tree = PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" service_prefix = "s3" - + tree._add_service(service_prefix, principal_arn) - + # Verify service is added under wildcard partition assert len(tree.partitions) == 1 partition = tree.partitions[0] assert partition.value == "*" - + assert len(partition.children) == 1 service = partition.children[0] assert service.value == "s3" @@ -1136,37 +1136,30 @@ def test_principal_tree_add_principal_policy() -> None: """Test adding a principal with policy documents to the principal tree.""" tree = 
PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" - + policy_json = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": ["s3:GetObject"], - "Resource": "arn:aws:s3:::my-bucket/*" - }, - { - "Effect": "Allow", - "Action": ["s3:ListAllMyBuckets"], - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": ["s3:GetObject"], "Resource": "arn:aws:s3:::my-bucket/*"}, + {"Effect": "Allow", "Action": ["s3:ListAllMyBuckets"], "Resource": "*"}, + ], } - + policy_doc = FixPolicyDocument(policy_json) tree.add_principal(principal_arn, [policy_doc]) - + # Verify both the specific resource and wildcard permissions are added assert any( - p.value == "aws" and - any(s.value == "s3" and - any(r.value == "*" and - any(a.value == "*" and - any(res.value == "my-bucket/*" - for res in a.children) - for a in r.children) - for r in s.children) - for s in p.children) + p.value == "aws" + and any( + s.value == "s3" + and any( + r.value == "*" + and any(a.value == "*" and any(res.value == "my-bucket/*" for res in a.children) for a in r.children) + for r in s.children + ) + for s in p.children + ) for p in tree.partitions ) @@ -1176,33 +1169,29 @@ def test_principal_tree_list_principals() -> None: tree = PrincipalTree() principal1 = "arn:aws:iam::123456789012:user/test-user1" principal2 = "arn:aws:iam::123456789012:user/test-user2" - + # Add different types of permissions - policy_doc1 = FixPolicyDocument({ - "Version": "2012-10-17", - "Statement": [{ - "Effect": "Allow", - "Action": ["s3:GetObject"], - "Resource": "arn:aws:s3:::my-bucket/*" - }] - }) - - policy_doc2 = FixPolicyDocument({ - "Version": "2012-10-17", - "Statement": [{ - "Effect": "Allow", - "Action": ["s3:ListAllMyBuckets"], - "Resource": "*" - }] - }) - + policy_doc1 = FixPolicyDocument( + { + "Version": "2012-10-17", + "Statement": [{"Effect": "Allow", "Action": ["s3:GetObject"], "Resource": "arn:aws:s3:::my-bucket/*"}], + } + ) + + policy_doc2 = FixPolicyDocument( + { + 
"Version": "2012-10-17", + "Statement": [{"Effect": "Allow", "Action": ["s3:ListAllMyBuckets"], "Resource": "*"}], + } + ) + tree.add_principal(principal1, [policy_doc1]) tree.add_principal(principal2, [policy_doc2]) - + # Test specific resource access resource_arn = ARN("arn:aws:s3:::my-bucket/test.txt") matching_principals = tree.list_principals(resource_arn) - + assert principal1 in matching_principals # Has specific access assert principal2 in matching_principals # Has wildcard access @@ -1211,29 +1200,23 @@ def test_principal_tree_add_multiple_statements() -> None: """Test adding multiple statements for the same principal.""" tree = PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" - - policy_doc = FixPolicyDocument({ - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": ["s3:GetObject"], - "Resource": "arn:aws:s3:::bucket1/*" - }, - { - "Effect": "Allow", - "Action": ["s3:PutObject"], - "Resource": "arn:aws:s3:::bucket2/*" - } - ] - }) - + + policy_doc = FixPolicyDocument( + { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Action": ["s3:GetObject"], "Resource": "arn:aws:s3:::bucket1/*"}, + {"Effect": "Allow", "Action": ["s3:PutObject"], "Resource": "arn:aws:s3:::bucket2/*"}, + ], + } + ) + tree.add_principal(principal_arn, [policy_doc]) - + # Test access to both buckets bucket1_arn = ARN("arn:aws:s3:::bucket1/test.txt") bucket2_arn = ARN("arn:aws:s3:::bucket2/test.txt") - + assert principal_arn in tree.list_principals(bucket1_arn) assert principal_arn in tree.list_principals(bucket2_arn) @@ -1242,28 +1225,28 @@ def test_principal_tree_not_resource() -> None: """Test NotResource handling in the principal tree.""" tree = PrincipalTree() principal_arn = "arn:aws:iam::123456789012:user/test-user" - - policy_doc = FixPolicyDocument({ - "Version": "2012-10-17", - "Statement": [{ - "Effect": "Allow", - "Action": ["s3:GetObject"], - "NotResource": ["arn:aws:s3:::private-bucket/*"] - }] - }) - + + 
policy_doc = FixPolicyDocument( + { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Action": ["s3:GetObject"], "NotResource": ["arn:aws:s3:::private-bucket/*"]} + ], + } + ) + tree.add_principal(principal_arn, [policy_doc]) - + # Test access is denied to private bucket private_arn = ARN("arn:aws:s3:::private-bucket/secret.txt") public_arn = ARN("arn:aws:s3:::public-bucket/public.txt") ec2 = ARN("arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0") - + matching_principals = tree.list_principals(private_arn) assert principal_arn not in matching_principals - + matching_principals = tree.list_principals(public_arn) assert principal_arn in matching_principals matching_principals = tree.list_principals(ec2) - assert len(matching_principals) == 0 \ No newline at end of file + assert len(matching_principals) == 0 From ca9fffc4f39bb04a8c3c52bf2b45c563b93b7bd7 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Mon, 9 Dec 2024 19:35:13 +0000 Subject: [PATCH 14/14] refactor the arn tree --- .../fix_plugin_aws/access_edges/__init__.py | 0 .../fix_plugin_aws/access_edges/arn_tree.py | 231 +++++++++++ .../edge_builder.py} | 373 ++---------------- .../aws/fix_plugin_aws/access_edges/types.py | 117 ++++++ plugins/aws/fix_plugin_aws/collector.py | 2 +- plugins/aws/test/acccess_edges_test.py | 83 ++-- 6 files changed, 419 insertions(+), 387 deletions(-) create mode 100644 plugins/aws/fix_plugin_aws/access_edges/__init__.py create mode 100644 plugins/aws/fix_plugin_aws/access_edges/arn_tree.py rename plugins/aws/fix_plugin_aws/{access_edges.py => access_edges/edge_builder.py} (79%) create mode 100644 plugins/aws/fix_plugin_aws/access_edges/types.py diff --git a/plugins/aws/fix_plugin_aws/access_edges/__init__.py b/plugins/aws/fix_plugin_aws/access_edges/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/aws/fix_plugin_aws/access_edges/arn_tree.py b/plugins/aws/fix_plugin_aws/access_edges/arn_tree.py new file mode 
100644 index 0000000000..581fea4f94 --- /dev/null +++ b/plugins/aws/fix_plugin_aws/access_edges/arn_tree.py @@ -0,0 +1,231 @@ +from typing import List, Set +from attrs import frozen +from fix_plugin_aws.access_edges.types import ArnResourceValueKind, FixPolicyDocument, WildcardKind +from policy_sentry.util.arns import ARN +import fnmatch +import logging + + +log = logging.getLogger("fix.plugins.aws") + + +@frozen(slots=True) +class ArnResource[T]: + key: str + values: Set[T] + kind: ArnResourceValueKind + not_resource: bool + + def matches(self, segment: str) -> bool: + _match = False + match self.kind: + case ArnResourceValueKind.Any: + _match = True + case ArnResourceValueKind.Pattern: + _match = fnmatch.fnmatch(segment, self.key) + case ArnResourceValueKind.Static: + _match = segment == self.key + + if self.not_resource: + _match = not _match + + return _match + + +@frozen(slots=True) +class ArnAccountId[T]: + key: str + wildcard: bool # if the account is a wildcard, e.g. "*" or "::" + values: Set[T] + children: List[ArnResource[T]] + + def matches(self, segment: str) -> bool: + return self.wildcard or self.key == segment + + +@frozen(slots=True) +class ArnRegion[T]: + key: str + wildcard: bool # if the region is a wildcard, e.g. 
"*" or "::" + values: Set[T] + children: List[ArnAccountId[T]] + + def matches(self, segment: str) -> bool: + return self.wildcard or self.key == segment + + +@frozen(slots=True) +class ArnService[T]: + key: str + values: Set[T] + children: List[ArnRegion[T]] + + def matches(self, segment: str) -> bool: + return self.key == segment + + +@frozen(slots=True) +class ArnPartition[T]: + key: str + wildcard: bool # for the cases like "Allow": "*" on all resources + values: Set[T] + children: List[ArnService[T]] + + def matches(self, segment: str) -> bool: + return self.wildcard or segment == self.key + + +class ArnTree[T]: + def __init__(self) -> None: + self.partitions: List[ArnPartition[T]] = [] + + def add_element(self, elem: T, policy_documents: List[FixPolicyDocument]) -> None: + """ + This method iterates over every policy statement and adds corresponding arns to principal tree. + """ + + for policy_doc in policy_documents: + for statement in policy_doc.fix_statements: + if statement.effect_allow: + has_wildcard_resource = False + for resource in statement.resources: + if resource == "*": + has_wildcard_resource = True + continue + self._add_resource(resource, elem) + for not_resource in statement.not_resource: + self._add_resource(not_resource, elem, nr=True) + + if has_wildcard_resource or (not statement.resources and not statement.not_resource): + for ap in statement.actions_patterns: + if ap.kind == WildcardKind.any: + self._add_allow_all_wildcard(elem) + self._add_service(ap.service, elem) + + def _add_allow_all_wildcard(self, elem: T) -> None: + partition = next((p for p in self.partitions if p.key == "*"), None) + if not partition: + partition = ArnPartition(key="*", wildcard=True, values=set(), children=[]) + self.partitions.append(partition) + + partition.values.add(elem) + + def _add_resource(self, resource_constraint: str, elem: T, nr: bool = False) -> None: + """ + _add resource will add the principal arn at the resource level + """ + + try: + arn = 
ARN(resource_constraint) + # Find existing or create partition + partition = next((p for p in self.partitions if p.key == arn.partition), None) + if not partition: + partition = ArnPartition[T](key=arn.partition, wildcard=False, values=set(), children=[]) + self.partitions.append(partition) + + # Find or create service + service = next((s for s in partition.children if s.key == arn.service_prefix), None) + if not service: + service = ArnService[T](key=arn.service_prefix, values=set(), children=[]) + partition.children.append(service) + + # Find or create region + region_wildcard = arn.region == "*" or not arn.region + region = next((r for r in service.children if r.key == (arn.region or "*")), None) + if not region: + region = ArnRegion[T](key=arn.region or "*", wildcard=region_wildcard, values=set(), children=[]) + service.children.append(region) + + # Find or create account + account_wildcard = arn.account == "*" or not arn.account + account = next((a for a in region.children if a.key == (arn.account or "*")), None) + if not account: + account = ArnAccountId[T](key=arn.account or "*", wildcard=account_wildcard, values=set(), children=[]) + region.children.append(account) + + # Add resource + resource = next( + (r for r in account.children if r.key == arn.resource_string and r.not_resource == nr), None + ) + if not resource: + if arn.resource_string == "*": + resource_kind = ArnResourceValueKind.Any + elif "*" in arn.resource_string: + resource_kind = ArnResourceValueKind.Pattern + else: + resource_kind = ArnResourceValueKind.Static + resource = ArnResource(key=arn.resource_string, values=set(), kind=resource_kind, not_resource=nr) + account.children.append(resource) + + resource.values.add(elem) + + except Exception as e: + log.error(f"Error parsing ARN {resource_constraint}: {e}") + pass + + def _add_service(self, service_prefix: str, elem: T) -> None: + # Find existing or create partition + partition = next((p for p in self.partitions if p.key == "*"), None) + 
if not partition: + partition = ArnPartition(key="*", wildcard=True, values=set(), children=[]) + self.partitions.append(partition) + + # Find or create service + service = next((s for s in partition.children if s.key == service_prefix), None) + if not service: + service = ArnService(key=service_prefix, values=set(), children=[]) + partition.children.append(service) + + service.values.add(elem) + + def find_matching_values(self, resource_arn: ARN) -> Set[T]: + """ + this will be called for every resource and it must be fast + """ + result: Set[T] = set() + + matching_partitions = [p for p in self.partitions if p.key if p.matches(resource_arn.partition)] + if not matching_partitions: + return result + + matching_services = [ + s for p in matching_partitions for s in p.children if s.matches(resource_arn.service_prefix) + ] + if not matching_services: + return result + result.update([arn for s in matching_services for arn in s.values]) + + matching_regions = [r for s in matching_services for r in s.children if r.matches(resource_arn.region)] + if not matching_regions: + return result + result.update([arn for r in matching_regions for arn in r.values]) + + matching_account_ids = [a for r in matching_regions for a in r.children if a.matches(resource_arn.account)] + if not matching_account_ids: + return result + result.update([arn for a in matching_account_ids for arn in a.values]) + + matching_resources = [ + r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string) + ] + if not matching_resources: + return result + + result.update([arn for r in matching_resources for arn in r.values]) + + return result + + +PrincipalArn = str + + +class PrincipalTree: + + def __init__(self) -> None: + self.arn_tree = ArnTree[PrincipalArn]() + + def add_principal(self, principal_arn: PrincipalArn, policy_documents: List[FixPolicyDocument]) -> None: + self.arn_tree.add_element(principal_arn, policy_documents) + + def list_principals(self,
resource_arn: ARN) -> Set[str]: + return self.arn_tree.find_matching_values(resource_arn) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges/edge_builder.py similarity index 79% rename from plugins/aws/fix_plugin_aws/access_edges.py rename to plugins/aws/fix_plugin_aws/access_edges/edge_builder.py index 32f6482b1e..a8e0a6cd66 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges/edge_builder.py @@ -1,88 +1,51 @@ -from enum import Enum -import enum +import fnmatch +import logging +import re from functools import lru_cache -from attr import frozen +from typing import Callable, Dict, List, Literal, Optional, Pattern, Set, Tuple, Union + import networkx +from attr import frozen +from cloudsplaining.scan.statement_detail import StatementDetail +from fix_plugin_aws.access_edges.arn_tree import PrincipalTree +from fix_plugin_aws.access_edges.types import ( + ActionWildcardPattern, + ArnResourceValueKind, + FixPolicyDocument, + FixStatementDetail, + ResourceWildcardPattern, + WildcardKind, +) from fix_plugin_aws.resource.base import AwsAccount, AwsResource, GraphBuilder -from policy_sentry.querying.actions import get_actions_for_service -from typing import Callable, Dict, List, Literal, Set, Optional, Tuple, Union, Pattern -import fnmatch +from fix_plugin_aws.resource.iam import AwsIamGroup, AwsIamPolicy, AwsIamRole, AwsIamUser from networkx.algorithms.dag import is_directed_acyclic_graph +from policy_sentry.querying.actions import get_action_data, get_actions_for_service +from policy_sentry.querying.all import get_all_actions +from policy_sentry.querying.arns import get_matching_raw_arns, get_resource_type_name_with_raw_arn +from policy_sentry.shared.iam_data import get_service_prefix_data +from policy_sentry.util.arns import ARN, get_service_from_arn from fixlib.baseresources import ( + AccessPermission, + EdgeType, + HasResourcePolicy, PermissionCondition, - PolicySource, + 
PermissionLevel, PermissionScope, - AccessPermission, + PolicySource, + PolicySourceKind, ResourceConstraint, ) -from fix_plugin_aws.resource.iam import AwsIamGroup, AwsIamPolicy, AwsIamUser, AwsIamRole -from fixlib.baseresources import EdgeType, PolicySourceKind, HasResourcePolicy, PermissionLevel +from fixlib.graph import EdgeKey from fixlib.json import to_json, to_json_str from fixlib.types import Json -from cloudsplaining.scan.policy_document import PolicyDocument -from cloudsplaining.scan.statement_detail import StatementDetail -from policy_sentry.querying.actions import get_action_data -from policy_sentry.querying.all import get_all_actions -from policy_sentry.querying.arns import get_matching_raw_arns, get_resource_type_name_with_raw_arn -from policy_sentry.shared.iam_data import get_service_prefix_data -from policy_sentry.util.arns import ARN, get_service_from_arn -from fixlib.graph import EdgeKey -import re -import logging - log = logging.getLogger("fix.plugins.aws") ALL_ACTIONS = get_all_actions() -class WildcardKind(Enum): - fixed = 1 - pattern = 2 - any = 3 - - -@frozen(slots=True) -class ActionWildcardPattern: - pattern: str - service: str - kind: WildcardKind - - -class FixStatementDetail(StatementDetail): - def __init__(self, statement: Json): - super().__init__(statement) - - def pattern_from_action(action: str) -> ActionWildcardPattern: - if action == "*": - return ActionWildcardPattern(pattern=action, service="*", kind=WildcardKind.any) - - action = action.lower() - service, action_name = action.split(":", 1) - if action_name == "*": - kind = WildcardKind.any - elif "*" in action_name: - kind = WildcardKind.pattern - else: - kind = WildcardKind.fixed - - return ActionWildcardPattern(pattern=action, service=service, kind=kind) - - self.actions_patterns = [pattern_from_action(action) for action in self.actions] - self.not_action_patterns = [pattern_from_action(action) for action in self.not_action] - self.resource_patterns = 
[ResourceWildcardPattern.from_str(resource) for resource in self.resources] - self.not_resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.not_resource] - - -class FixPolicyDocument(PolicyDocument): - def __init__(self, policy_document: Json): - super().__init__(policy_document) - - self.fix_statements = [FixStatementDetail(statement.json) for statement in self.statements] - - @frozen(slots=True) class ActionToCheck: raw: str @@ -91,284 +54,6 @@ class ActionToCheck: action_name: str -class ArnResourceValueKind(enum.Enum): - Static = 1 # the segment is a fixed value, e.g. "s3", "vpc/vpc-0e9801d129EXAMPLE", - Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", - Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. "*" - - @staticmethod - def from_str(value: str) -> "ArnResourceValueKind": - if value == "*": - return ArnResourceValueKind.Any - if "*" in value: - return ArnResourceValueKind.Pattern - return ArnResourceValueKind.Static - - -@frozen(slots=True) -class ArnResource: - value: str - principal_arns: Set[str] - kind: ArnResourceValueKind - not_resource: bool - - def matches(self, segment: str) -> bool: - _match = False - match self.kind: - case ArnResourceValueKind.Any: - _match = True - case ArnResourceValueKind.Pattern: - _match = fnmatch.fnmatch(segment, self.value) - case ArnResourceValueKind.Static: - _match = segment == self.value - - if self.not_resource: - _match = not _match - - return _match - - -@frozen(slots=True) -class ArnAccountId: - value: str - wildcard: bool # if the account is a wildcard, e.g. "*" or "::" - principal_arns: Set[str] - children: List[ArnResource] - - def matches(self, segment: str) -> bool: - return self.wildcard or self.value == segment - - -@frozen(slots=True) -class ArnRegion: - value: str - wildcard: bool # if the region is a wildcard, e.g. 
"*" or "::" - principal_arns: Set[str] - children: List[ArnAccountId] - - def matches(self, segment: str) -> bool: - return self.wildcard or self.value == segment - - -@frozen(slots=True) -class ArnService: - value: str - principal_arns: Set[str] - children: List[ArnRegion] - - def matches(self, segment: str) -> bool: - return self.value == segment - - -@frozen(slots=True) -class ArnPartition: - value: str - wildcard: bool # for the cases like "Allow": "*" on all resources - principal_arns: Set[str] - children: List[ArnService] - - def matches(self, segment: str) -> bool: - return self.wildcard or segment == self.value - - -def is_wildcard(segment: str) -> bool: - return segment == "*" or segment == "" - - -class PrincipalTree: - def __init__(self) -> None: - self.partitions: List[ArnPartition] = [] - - def _add_allow_all_wildcard(self, principal_arn: str) -> None: - partition = next((p for p in self.partitions if p.value == "*"), None) - if not partition: - partition = ArnPartition(value="*", wildcard=True, principal_arns=set(), children=[]) - self.partitions.append(partition) - - partition.principal_arns.add(principal_arn) - - def _add_resource(self, resource_constraint: str, principal_arn: str, nr: bool = False) -> None: - """ - _add resource will add the principal arn at the resource level - """ - - try: - arn = ARN(resource_constraint) - # Find existing or create partition - partition = next((p for p in self.partitions if p.value == arn.partition), None) - if not partition: - partition = ArnPartition(value=arn.partition, wildcard=False, principal_arns=set(), children=[]) - self.partitions.append(partition) - - # Find or create service - service = next((s for s in partition.children if s.value == arn.service_prefix), None) - if not service: - service = ArnService(value=arn.service_prefix, principal_arns=set(), children=[]) - partition.children.append(service) - - # Find or create region - region_wildcard = arn.region == "*" or not arn.region - region = next((r 
for r in service.children if r.value == (arn.region or "*")), None) - if not region: - region = ArnRegion(value=arn.region or "*", wildcard=region_wildcard, principal_arns=set(), children=[]) - service.children.append(region) - - # Find or create account - account_wildcard = arn.account == "*" or not arn.account - account = next((a for a in region.children if a.value == (arn.account or "*")), None) - if not account: - account = ArnAccountId( - value=arn.account or "*", wildcard=account_wildcard, principal_arns=set(), children=[] - ) - region.children.append(account) - - # Add resource - resource = next( - (r for r in account.children if r.value == arn.resource_string and r.not_resource == nr), None - ) - if not resource: - if arn.resource_string == "*": - resource_kind = ArnResourceValueKind.Any - elif "*" in arn.resource_string: - resource_kind = ArnResourceValueKind.Pattern - else: - resource_kind = ArnResourceValueKind.Static - resource = ArnResource( - value=arn.resource_string, principal_arns=set(), kind=resource_kind, not_resource=nr - ) - account.children.append(resource) - - resource.principal_arns.add(principal_arn) - - except Exception as e: - log.error(f"Error parsing ARN {principal_arn}: {e}") - pass - - def _add_service(self, service_prefix: str, principal_arn: str) -> None: - # Find existing or create partition - partition = next((p for p in self.partitions if p.value == "*"), None) - if not partition: - partition = ArnPartition(value="*", wildcard=True, principal_arns=set(), children=[]) - self.partitions.append(partition) - - # Find or create service - service = next((s for s in partition.children if s.value == service_prefix), None) - if not service: - service = ArnService(value=service_prefix, principal_arns=set(), children=[]) - partition.children.append(service) - - service.principal_arns.add(principal_arn) - - def add_principal(self, principal_arn: str, policy_documents: List[FixPolicyDocument]) -> None: - """ - This method iterates over every 
policy statement and adds corresponding arns to principal tree. - """ - - for policy_doc in policy_documents: - for statement in policy_doc.fix_statements: - if statement.effect_allow: - has_wildcard_resource = False - for resource in statement.resources: - if resource == "*": - has_wildcard_resource = True - continue - self._add_resource(resource, principal_arn) - for not_resource in statement.not_resource: - self._add_resource(not_resource, principal_arn, nr=True) - - if has_wildcard_resource or (not statement.resources and not statement.not_resource): - for ap in statement.actions_patterns: - if ap.kind == WildcardKind.any: - self._add_allow_all_wildcard(principal_arn) - self._add_service(ap.service, principal_arn) - - def list_principals(self, resource_arn: ARN) -> Set[str]: - """ - this will be called for every resource and it must be fast - """ - principals: Set[str] = set() - - matching_partitions = [p for p in self.partitions if p.value if p.matches(resource_arn.partition)] - if not matching_partitions: - return principals - - matching_services = [ - s for p in matching_partitions for s in p.children if s.matches(resource_arn.service_prefix) - ] - if not matching_services: - return principals - principals.update([arn for s in matching_services for arn in s.principal_arns]) - - matching_regions = [r for s in matching_services for r in s.children if r.matches(resource_arn.region)] - if not matching_regions: - return principals - principals.update([arn for r in matching_regions for arn in r.principal_arns]) - - matching_account_ids = [a for r in matching_regions for a in r.children if r.matches(resource_arn.account)] - if not matching_account_ids: - return principals - principals.update([arn for a in matching_account_ids for arn in a.principal_arns]) - - matching_resources = [ - r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string) - ] - if not matching_resources: - return principals - - principals.update([arn for r in 
matching_resources for arn in r.principal_arns]) - - return principals - - -@frozen(slots=True) -class ResourceWildcardPattern: - raw_value: str - partition: str | None # None in case the whole string is "*" - service: str - region: str - region_value_kind: ArnResourceValueKind - account: str - account_value_kind: ArnResourceValueKind - resource: str - resource_value_kind: ArnResourceValueKind - - @staticmethod - def from_str(value: str) -> "ResourceWildcardPattern": - if value == "*": - return ResourceWildcardPattern( - raw_value=value, - partition=None, - service="*", - region="*", - region_value_kind=ArnResourceValueKind.Any, - account="*", - account_value_kind=ArnResourceValueKind.Any, - resource="*", - resource_value_kind=ArnResourceValueKind.Any, - ) - - try: - splitted = value.split(":", 5) - if len(splitted) != 6: - raise ValueError(f"Invalid resource pattern: {value}") - _, partition, service, region, account, resource = splitted - - return ResourceWildcardPattern( - raw_value=value, - partition=partition, - service=service, - region=region, - region_value_kind=ArnResourceValueKind.from_str(region), - account=account, - account_value_kind=ArnResourceValueKind.from_str(account), - resource=resource, - resource_value_kind=ArnResourceValueKind.from_str(resource), - ) - except Exception as e: - log.error(f"Error parsing resource pattern {value}: {e}") - raise e - - @frozen(slots=True) class IamRequestContext: principal: AwsResource diff --git a/plugins/aws/fix_plugin_aws/access_edges/types.py b/plugins/aws/fix_plugin_aws/access_edges/types.py new file mode 100644 index 0000000000..ca9e83e2ca --- /dev/null +++ b/plugins/aws/fix_plugin_aws/access_edges/types.py @@ -0,0 +1,117 @@ +from enum import Enum +from attr import frozen +from cloudsplaining.scan.policy_document import PolicyDocument +from cloudsplaining.scan.statement_detail import StatementDetail +from fixlib.types import Json +import logging + + +log = logging.getLogger("fix.plugins.aws") + + +class 
WildcardKind(Enum): + fixed = 1 + pattern = 2 + any = 3 + + +@frozen(slots=True) +class ActionWildcardPattern: + pattern: str + service: str + kind: WildcardKind + + +class ArnResourceValueKind(Enum): + Static = 1 # the segment is a fixed value, e.g. "s3", "vpc/vpc-0e9801d129EXAMPLE", + Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", + Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. "*" + + @staticmethod + def from_str(value: str) -> "ArnResourceValueKind": + if value == "*": + return ArnResourceValueKind.Any + if "*" in value: + return ArnResourceValueKind.Pattern + return ArnResourceValueKind.Static + + +@frozen(slots=True) +class ResourceWildcardPattern: + raw_value: str + partition: str | None # None in case the whole string is "*" + service: str + region: str + region_value_kind: ArnResourceValueKind + account: str + account_value_kind: ArnResourceValueKind + resource: str + resource_value_kind: ArnResourceValueKind + + @staticmethod + def from_str(value: str) -> "ResourceWildcardPattern": + if value == "*": + return ResourceWildcardPattern( + raw_value=value, + partition=None, + service="*", + region="*", + region_value_kind=ArnResourceValueKind.Any, + account="*", + account_value_kind=ArnResourceValueKind.Any, + resource="*", + resource_value_kind=ArnResourceValueKind.Any, + ) + + try: + splitted = value.split(":", 5) + if len(splitted) != 6: + raise ValueError(f"Invalid resource pattern: {value}") + _, partition, service, region, account, resource = splitted + + return ResourceWildcardPattern( + raw_value=value, + partition=partition, + service=service, + region=region, + region_value_kind=ArnResourceValueKind.from_str(region), + account=account, + account_value_kind=ArnResourceValueKind.from_str(account), + resource=resource, + resource_value_kind=ArnResourceValueKind.from_str(resource), + ) + except Exception as e: + log.error(f"Error parsing resource pattern {value}: {e}") + raise e + + +class 
FixStatementDetail(StatementDetail): + def __init__(self, statement: Json): + super().__init__(statement) + + def pattern_from_action(action: str) -> ActionWildcardPattern: + if action == "*": + return ActionWildcardPattern(pattern=action, service="*", kind=WildcardKind.any) + + action = action.lower() + service, action_name = action.split(":", 1) + if action_name == "*": + kind = WildcardKind.any + elif "*" in action_name: + kind = WildcardKind.pattern + else: + kind = WildcardKind.fixed + + return ActionWildcardPattern(pattern=action, service=service, kind=kind) + + self.actions_patterns = [pattern_from_action(action) for action in self.actions] + self.not_action_patterns = [pattern_from_action(action) for action in self.not_action] + self.resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.resources] + self.not_resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.not_resource] + + +class FixPolicyDocument(PolicyDocument): + def __init__(self, policy_document: Json): + super().__init__(policy_document) + + self.fix_statements = [FixStatementDetail(statement.json) for statement in self.statements] diff --git a/plugins/aws/fix_plugin_aws/collector.py b/plugins/aws/fix_plugin_aws/collector.py index 03a08346c5..0fb21eda0c 100644 --- a/plugins/aws/fix_plugin_aws/collector.py +++ b/plugins/aws/fix_plugin_aws/collector.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone from typing import List, Type, Optional, Union, cast, Any -from fix_plugin_aws.access_edges import AccessEdgeCreator +from fix_plugin_aws.access_edges.edge_builder import AccessEdgeCreator from fix_plugin_aws.aws_client import AwsClient from fix_plugin_aws.configuration import AwsConfig from fix_plugin_aws.resource import ( diff --git a/plugins/aws/test/acccess_edges_test.py b/plugins/aws/test/acccess_edges_test.py index cd15c1342d..4650a73bde 100644 --- a/plugins/aws/test/acccess_edges_test.py +++ 
b/plugins/aws/test/acccess_edges_test.py @@ -6,7 +6,7 @@ from policy_sentry.util.arns import ARN import re -from fix_plugin_aws.access_edges import ( +from fix_plugin_aws.access_edges.edge_builder import ( find_allowed_action, make_resoruce_regex, check_statement_match, @@ -14,13 +14,12 @@ IamRequestContext, check_explicit_deny, compute_permissions, - FixPolicyDocument, - FixStatementDetail, ActionToCheck, get_actions_matching_arn, - PrincipalTree, - ArnResourceValueKind, ) +from fix_plugin_aws.access_edges.types import FixPolicyDocument, FixStatementDetail, ArnResourceValueKind + +from fix_plugin_aws.access_edges.arn_tree import ArnTree from fixlib.baseresources import PolicySourceKind, PolicySource, PermissionLevel from fixlib.json import to_json_str @@ -1023,7 +1022,7 @@ def test_compute_permissions_role_inline_policy_allow() -> None: def test_principal_tree_add_allow_all_wildcard() -> None: """Test adding wildcard (*) permission to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" tree._add_allow_all_wildcard(principal_arn) @@ -1031,14 +1030,14 @@ def test_principal_tree_add_allow_all_wildcard() -> None: # Verify the wildcard partition exists assert len(tree.partitions) == 1 partition = tree.partitions[0] - assert partition.value == "*" + assert partition.key == "*" assert partition.wildcard is True - assert principal_arn in partition.principal_arns + assert principal_arn in partition.values def test_principal_tree_add_resource() -> None: """Test adding a resource ARN to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/my-object" @@ -1047,38 +1046,38 @@ def test_principal_tree_add_resource() -> None: # Verify the partition structure assert len(tree.partitions) == 1 partition = tree.partitions[0] - assert partition.value == "aws" + assert partition.key == "aws" 
assert not partition.wildcard # Verify service level assert len(partition.children) == 1 service = partition.children[0] - assert service.value == "s3" + assert service.key == "s3" # Verify region level assert len(service.children) == 1 region = service.children[0] - assert region.value == "*" + assert region.key == "*" assert region.wildcard # Verify account level assert len(region.children) == 1 account = region.children[0] - assert account.value == "*" + assert account.key == "*" assert account.wildcard # Verify resource level assert len(account.children) == 1 resource = account.children[0] - assert resource.value == "my-bucket/my-object" + assert resource.key == "my-bucket/my-object" assert resource.kind == ArnResourceValueKind.Static - assert principal_arn in resource.principal_arns + assert principal_arn in resource.values assert not resource.not_resource def test_principal_tree_add_resource_with_wildcard() -> None: """Test adding a resource ARN with wildcards to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/*" @@ -1091,14 +1090,14 @@ def test_principal_tree_add_resource_with_wildcard() -> None: account = region.children[0] resource = account.children[0] - assert resource.value == "my-bucket/*" + assert resource.key == "my-bucket/*" assert resource.kind == ArnResourceValueKind.Pattern - assert principal_arn in resource.principal_arns + assert principal_arn in resource.values def test_principal_tree_add_not_resource() -> None: """Test adding a NotResource ARN to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/private/*" @@ -1115,7 +1114,7 @@ def test_principal_tree_add_not_resource() -> None: def test_principal_tree_add_service() -> None: """Test adding a service to the principal tree.""" - tree = PrincipalTree() + 
tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" service_prefix = "s3" @@ -1124,17 +1123,17 @@ def test_principal_tree_add_service() -> None: # Verify service is added under wildcard partition assert len(tree.partitions) == 1 partition = tree.partitions[0] - assert partition.value == "*" + assert partition.key == "*" assert len(partition.children) == 1 service = partition.children[0] - assert service.value == "s3" - assert principal_arn in service.principal_arns + assert service.key == "s3" + assert principal_arn in service.values def test_principal_tree_add_principal_policy() -> None: """Test adding a principal with policy documents to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" policy_json = { @@ -1146,16 +1145,16 @@ def test_principal_tree_add_principal_policy() -> None: } policy_doc = FixPolicyDocument(policy_json) - tree.add_principal(principal_arn, [policy_doc]) + tree.add_element(principal_arn, [policy_doc]) # Verify both the specific resource and wildcard permissions are added assert any( - p.value == "aws" + p.key == "aws" and any( - s.value == "s3" + s.key == "s3" and any( - r.value == "*" - and any(a.value == "*" and any(res.value == "my-bucket/*" for res in a.children) for a in r.children) + r.key == "*" + and any(a.key == "*" and any(res.key == "my-bucket/*" for res in a.children) for a in r.children) for r in s.children ) for s in p.children @@ -1166,7 +1165,7 @@ def test_principal_tree_add_principal_policy() -> None: def test_principal_tree_list_principals() -> None: """Test listing principals that have access to a given ARN.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal1 = "arn:aws:iam::123456789012:user/test-user1" principal2 = "arn:aws:iam::123456789012:user/test-user2" @@ -1185,12 +1184,12 @@ def test_principal_tree_list_principals() -> None: } ) - tree.add_principal(principal1, [policy_doc1]) - 
tree.add_principal(principal2, [policy_doc2]) + tree.add_element(principal1, [policy_doc1]) + tree.add_element(principal2, [policy_doc2]) # Test specific resource access resource_arn = ARN("arn:aws:s3:::my-bucket/test.txt") - matching_principals = tree.list_principals(resource_arn) + matching_principals = tree.find_matching_values(resource_arn) assert principal1 in matching_principals # Has specific access assert principal2 in matching_principals # Has wildcard access @@ -1198,7 +1197,7 @@ def test_principal_tree_list_principals() -> None: def test_principal_tree_add_multiple_statements() -> None: """Test adding multiple statements for the same principal.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" policy_doc = FixPolicyDocument( @@ -1211,19 +1210,19 @@ def test_principal_tree_add_multiple_statements() -> None: } ) - tree.add_principal(principal_arn, [policy_doc]) + tree.add_element(principal_arn, [policy_doc]) # Test access to both buckets bucket1_arn = ARN("arn:aws:s3:::bucket1/test.txt") bucket2_arn = ARN("arn:aws:s3:::bucket2/test.txt") - assert principal_arn in tree.list_principals(bucket1_arn) - assert principal_arn in tree.list_principals(bucket2_arn) + assert principal_arn in tree.find_matching_values(bucket1_arn) + assert principal_arn in tree.find_matching_values(bucket2_arn) def test_principal_tree_not_resource() -> None: """Test NotResource handling in the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" policy_doc = FixPolicyDocument( @@ -1235,18 +1234,18 @@ def test_principal_tree_not_resource() -> None: } ) - tree.add_principal(principal_arn, [policy_doc]) + tree.add_element(principal_arn, [policy_doc]) # Test access is denied to private bucket private_arn = ARN("arn:aws:s3:::private-bucket/secret.txt") public_arn = ARN("arn:aws:s3:::public-bucket/public.txt") ec2 = 
ARN("arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0") - matching_principals = tree.list_principals(private_arn) + matching_principals = tree.find_matching_values(private_arn) assert principal_arn not in matching_principals - matching_principals = tree.list_principals(public_arn) + matching_principals = tree.find_matching_values(public_arn) assert principal_arn in matching_principals - matching_principals = tree.list_principals(ec2) + matching_principals = tree.find_matching_values(ec2) assert len(matching_principals) == 0