diff --git a/panther_analysis_tool/log_schemas/user_defined.py b/panther_analysis_tool/log_schemas/user_defined.py
index 33477de3..a84eb3da 100644
--- a/panther_analysis_tool/log_schemas/user_defined.py
+++ b/panther_analysis_tool/log_schemas/user_defined.py
@@ -24,12 +24,13 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, cast
 
-import boto3
 from botocore import client
 from ruamel.yaml import YAML
 from ruamel.yaml.parser import ParserError
 from ruamel.yaml.scanner import ScannerError
 
+from panther_analysis_tool.util import get_client
+
 logger = logging.getLogger(__file__)
 
 
@@ -38,13 +39,14 @@ class Client:
     _LIST_SCHEMAS_ENDPOINT = "ListSchemas"
     _PUT_SCHEMA_ENDPOINT = "PutUserSchema"
 
-    def __init__(self) -> None:
+    def __init__(self, aws_profile: str) -> None:
         self._lambda_client = None
+        self._aws_profile = aws_profile
 
     @property
     def lambda_client(self) -> client.BaseClient:
         if self._lambda_client is None:
-            self._lambda_client = boto3.client("lambda")
+            self._lambda_client = get_client(self._aws_profile, "lambda")
         return self._lambda_client
 
     def list_schemas(self) -> Tuple[bool, dict]:
@@ -143,16 +145,17 @@ class Uploader:
     _SCHEMA_NAME_PREFIX = "Custom."
     _SCHEMA_FILE_GLOB_PATTERNS = ("*.yml", "*.yaml")
 
-    def __init__(self, path: str):
+    def __init__(self, path: str, aws_profile: str):
         self._path = path
         self._files: Optional[List[str]] = None
         self._api_client: Optional[Client] = None
         self._existing_schemas: Optional[List[Dict[str, Any]]] = None
+        self._aws_profile = aws_profile
 
     @property
     def api_client(self) -> Client:
         if self._api_client is None:
-            self._api_client = Client()
+            self._api_client = Client(self._aws_profile)
         return self._api_client
 
     @property
diff --git a/panther_analysis_tool/main.py b/panther_analysis_tool/main.py
index 8685f75d..a4782f43 100644
--- a/panther_analysis_tool/main.py
+++ b/panther_analysis_tool/main.py
@@ -40,7 +40,6 @@
 from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
-import boto3
 import botocore
 import requests
 import semver
@@ -75,6 +74,7 @@
     SCHEDULED_QUERY_SCHEMA,
     TYPE_SCHEMA,
 )
+from panther_analysis_tool.util import get_client
 
 DATA_MODEL_LOCATION = "./data_models"
 HELPERS_LOCATION = "./global_helpers"
@@ -335,7 +335,7 @@ def upload_analysis(args: argparse.Namespace) -> Tuple[int, str]:
     if return_code == 1:
         return return_code, ""
 
-    client = get_client(args, "lambda")
+    client = get_client(args.aws_profile, "lambda")
 
     with open(archive, "rb") as analysis_zip:
         zip_bytes = analysis_zip.read()
@@ -387,7 +387,7 @@ def update_schemas(args: argparse.Namespace) -> Tuple[int, str]:
 
         A tuple of return code and the archive filename.
     """
-    client = get_client(args, "lambda")
+    client = get_client(args.aws_profile, "lambda")
 
     logging.info("Fetching updates")
     response = client.invoke(
@@ -463,15 +463,11 @@ def update_custom_schemas(args: argparse.Namespace) -> Tuple[int, str]:
     Returns:
         A tuple of return code and a placeholder string.
     """
-    if args.aws_profile is not None:
-        logging.info("Using AWS profile: %s", args.aws_profile)
-        set_env("AWS_PROFILE", args.aws_profile)
-
     normalized_path = user_defined.normalize_path(args.path)
     if not normalized_path:
         return 1, f"path not found: {args.path}"
 
-    uploader = user_defined.Uploader(normalized_path)
+    uploader = user_defined.Uploader(normalized_path, args.aws_profile)
     results = uploader.process()
     has_errors = False
     for failed, summary in user_defined.report_summary(normalized_path, results):
@@ -498,13 +494,8 @@ def generate_release_assets(args: argparse.Namespace) -> Tuple[int, str]:
     if args.kms_key:
         # Then generate the sha512 sum of the zip file
         archive_hash = generate_hash(release_file)
-        # optionally set env variable for profile passed as argument
-        # this must be called prior to setting up the client
-        if args.aws_profile is not None:
-            logging.info("Using AWS profile: %s", args.aws_profile)
-            set_env("AWS_PROFILE", args.aws_profile)
 
-        client = get_client(args, "kms")
+        client = get_client(args.aws_profile, "kms")
         try:
             response = client.sign(
                 KeyId=args.kms_key,
@@ -706,7 +697,7 @@ def test_analysis(args: argparse.Namespace) -> Tuple[int, list]:
             f"No analysis in {args.path} matched filters {args.filter} - {args.filter_inverted}"
         ]
 
-    available_destinations = []
+    available_destinations: List[str] = []
     if args.available_destination:
         available_destinations.extend(args.available_destination)
 
@@ -1310,7 +1301,7 @@ def setup_parser() -> argparse.ArgumentParser:
         + "managing Panther policies and rules.",
         prog="panther_analysis_tool",
     )
-    parser.add_argument("--version", action="version", version="panther_analysis_tool 0.8.1")
+    parser.add_argument("--version", action="version", version="panther_analysis_tool 0.8.2")
     parser.add_argument("--debug", action="store_true", dest="debug")
     subparsers = parser.add_subparsers()
 
@@ -1545,21 +1536,6 @@ def parse_filter(filters: List[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     return parsed_filters, parsed_filters_inverted
 
 
-def get_client(args: argparse.Namespace, service: str) -> boto3.client:
-    client = boto3.client(service)
-    # optionally set env variable for profile passed as argument
-    if args.aws_profile is not None:
-        logging.info("Using AWS profile: %s", args.aws_profile)
-        set_env("AWS_PROFILE", args.aws_profile)
-        session = boto3.Session(profile_name=args.aws_profile)
-        client = session.client(service)
-    return client
-
-
-def set_env(key: str, value: str) -> None:
-    os.environ[key] = value
-
-
 def run() -> None:
     parser = setup_parser()
     # if no args are passed, print the help output
diff --git a/panther_analysis_tool/rule.py b/panther_analysis_tool/rule.py
index 3ca70e81..bce26fcb 100644
--- a/panther_analysis_tool/rule.py
+++ b/panther_analysis_tool/rule.py
@@ -28,7 +28,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from types import ModuleType
-from typing import Any, Callable, List, Optional, Dict
+from typing import Any, Callable, Dict, List, Optional
 
 from panther_analysis_tool.enriched_event import PantherEvent
 from panther_analysis_tool.exceptions import (
@@ -310,13 +310,13 @@ def __init__(self, config: Mapping):
         self.rule_dedup_period_mins = config["dedupPeriodMinutes"]
 
         if not ("tags" in config) or not isinstance(config["tags"], list):
-            self.rule_tags = list()
+            self.rule_tags: List[str] = list()
         else:
             config["tags"].sort()
             self.rule_tags = config["tags"]
 
         if "reports" not in config:
-            self.rule_reports = dict()
+            self.rule_reports: Dict[str, List[str]] = dict()
         else:
             # Reports are Dict[str, List[str]]
             # Sorting the List before setting it
@@ -634,7 +634,7 @@ def _get_destinations(  # pylint: disable=too-many-return-statements,too-many-ar
 
         # Check for (in)valid destinations
         invalid_destinations = []
-        standardized_destinations = []
+        standardized_destinations: List[str] = []
 
         # Standardize the destinations
         for each_destination in destinations:
diff --git a/panther_analysis_tool/util.py b/panther_analysis_tool/util.py
index 1c520f2c..733299b5 100644
--- a/panther_analysis_tool/util.py
+++ b/panther_analysis_tool/util.py
@@ -17,11 +17,14 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
 
+import logging
 import os
 from importlib import util as import_util
 from pathlib import Path
 from typing import Any
 
+import boto3
+
 
 def allowed_char(char: str) -> bool:
     """Return true if the character is part of a valid ID."""
@@ -53,3 +56,19 @@ def store_modules(path: str, body: str) -> None:
     Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
     with open(path, "w") as py_file:
         py_file.write(body)
+
+
+def get_client(aws_profile: str, service: str) -> boto3.client:
+    # optionally set env variable for profile passed as argument
+    if aws_profile is not None:
+        logging.info("Using AWS profile: %s", aws_profile)
+        set_env("AWS_PROFILE", aws_profile)
+        sess = boto3.Session(profile_name=aws_profile)
+        client = sess.client(service)
+    else:
+        client = boto3.client(service)
+    return client
+
+
+def set_env(key: str, value: str) -> None:
+    os.environ[key] = value
diff --git a/setup.py b/setup.py
index 3f7e6b0e..62259257 100644
--- a/setup.py
+++ b/setup.py
@@ -21,14 +21,14 @@
     name='panther_analysis_tool',
     packages=['panther_analysis_tool', 'panther_analysis_tool/log_schemas'],
     package_dir={'log_schemas': 'panther_analysis_tool/log_schemas'},
-    version='0.8.1',
+    version='0.8.2',
     license='AGPL-3.0',
     description=
     'Panther command line interface for writing, testing, and packaging policies/rules.',
     author='Panther Labs Inc',
     author_email='pypi@runpanther.io',
     url='https://github.com/panther-labs/panther_analysis_tool',
-    download_url = 'https://github.com/panther-labs/panther_analysis_tool/archive/v0.8.1.tar.gz',
+    download_url = 'https://github.com/panther-labs/panther_analysis_tool/archive/v0.8.2.tar.gz',
     keywords=['Security', 'CLI'],
     scripts=['bin/panther_analysis_tool'],
     install_requires=install_requires,
diff --git a/tests/unit/panther_analysis_tool/log_schemas/test_user_defined.py b/tests/unit/panther_analysis_tool/log_schemas/test_user_defined.py
index b0526db6..d1ae0f9a 100644
--- a/tests/unit/panther_analysis_tool/log_schemas/test_user_defined.py
+++ b/tests/unit/panther_analysis_tool/log_schemas/test_user_defined.py
@@ -98,20 +98,20 @@ def test_existing_schemas(self):
         mock_uploader_client.list_schemas = mock.MagicMock(
             return_value=(True, self.list_schemas_response)
         )
-        uploader = user_defined.Uploader(self.valid_schema_path)
+        uploader = user_defined.Uploader(self.valid_schema_path, None)
         self.assertListEqual(uploader.existing_schemas, self.list_schemas_response['results'])
         mock_uploader_client.list_schemas.assert_called_once()
 
     def test_find_schema(self):
         with mock.patch('panther_analysis_tool.log_schemas.user_defined.Uploader.existing_schemas',
                         self.list_schemas_response['results']):
-            uploader = user_defined.Uploader(self.valid_schema_path)
+            uploader = user_defined.Uploader(self.valid_schema_path, None)
             self.assertDictEqual(uploader.find_schema('Custom.SampleSchema2'),
                                  self.list_schemas_response['results'][1])
             self.assertIsNone(uploader.find_schema('unknown-schema'))
 
     def test_files(self):
-        uploader = user_defined.Uploader(self.valid_schema_path)
+        uploader = user_defined.Uploader(self.valid_schema_path, None)
         self.assertListEqual(
             uploader.files,
             [os.path.join(self.valid_schema_path, 'schema-1.yml'),
@@ -138,7 +138,7 @@ def test_process(self):
         mock_uploader_client.put_schema = mock.MagicMock(
             side_effect=put_schema_responses
         )
-        uploader = user_defined.Uploader(self.valid_schema_path)
+        uploader = user_defined.Uploader(self.valid_schema_path, None)
         results = uploader.process()
         self.assertEqual(len(results), 2)
         self.assertListEqual([r.name for r in results],
diff --git a/tests/unit/panther_analysis_tool/test_main.py b/tests/unit/panther_analysis_tool/test_main.py
index e1d9d14f..cf23f98b 100644
--- a/tests/unit/panther_analysis_tool/test_main.py
+++ b/tests/unit/panther_analysis_tool/test_main.py
@@ -28,6 +28,7 @@
 from nose.tools import assert_equal, assert_is_instance, assert_true
 
 from panther_analysis_tool import main as pat
+from panther_analysis_tool import util
 from panther_analysis_tool.data_model import _DATAMODEL_FOLDER
 
 FIXTURES_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../', 'fixtures'))
@@ -184,7 +185,7 @@ def test_aws_profiles(self):
         aws_profile = 'AWS_PROFILE'
         args = pat.setup_parser().parse_args(
             f'upload --path {DETECTIONS_FIXTURES_PATH}/valid_analysis --aws-profile myprofile'.split())
-        pat.set_env(aws_profile, args.aws_profile)
+        util.set_env(aws_profile, args.aws_profile)
         assert_equal('myprofile', args.aws_profile)
         assert_equal(args.aws_profile, os.environ.get(aws_profile))
 
@@ -341,7 +342,7 @@ def test_update_custom_schemas(self):
 
         with mock.patch('panther_analysis_tool.log_schemas.user_defined.Uploader') as mock_uploader:
             _, _ = pat.update_custom_schemas(args)
-            mock_uploader.assert_called_once_with(f'{FIXTURES_PATH}/custom-schemas/valid')
+            mock_uploader.assert_called_once_with(f'{FIXTURES_PATH}/custom-schemas/valid', None)
 
         with open(os.path.join(schema_path, 'schema-1.yml')) as f:
             schema1 = f.read()