FIX: Consolidate get_client function & fix to have user-defined schemas respect --aws-profile (#122)

* Moved the get_client function to panther_analysis_tool/util.py and removed the duplicated/redundant set_env calls.
Changed the user-defined schema Client and Uploader to take the AWS profile from the argparse.Namespace, so the Lambda client they use respects the --aws-profile flag (a short usage sketch follows the changed-files summary below).

* Fixed linting issues - added type annotations

* PR feedback

* remove unused imports

* version bump

* remove debug print statement

Co-authored-by: lindsey-w <lindsey.whitehurst@runpanther.io>
wey-chiang and lindsey-w authored Aug 12, 2021
1 parent eaba5e1 commit d146ac2
Showing 7 changed files with 47 additions and 48 deletions.
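To make the intent concrete, a minimal usage sketch (not part of the commit; the profile name and schema path are hypothetical, the signatures are the ones introduced in the diffs below):

from panther_analysis_tool.log_schemas import user_defined
from panther_analysis_tool.util import get_client

# Without a profile, the helper falls back to the default boto3 credential chain.
lambda_client = get_client(None, "lambda")

# With --aws-profile, the same helper builds the client from that profile's session,
# and the schema Uploader now threads the flag through to its own Lambda client.
lambda_client = get_client("my-profile", "lambda")                   # hypothetical profile
uploader = user_defined.Uploader("./custom-schemas", "my-profile")   # hypothetical path
results = uploader.process()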
13 changes: 8 additions & 5 deletions panther_analysis_tool/log_schemas/user_defined.py
@@ -24,12 +24,13 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, cast

- import boto3
from botocore import client
from ruamel.yaml import YAML
from ruamel.yaml.parser import ParserError
from ruamel.yaml.scanner import ScannerError

+ from panther_analysis_tool.util import get_client

logger = logging.getLogger(__file__)


@@ -38,13 +39,14 @@ class Client:
_LIST_SCHEMAS_ENDPOINT = "ListSchemas"
_PUT_SCHEMA_ENDPOINT = "PutUserSchema"

- def __init__(self) -> None:
+ def __init__(self, aws_profile: str) -> None:
self._lambda_client = None
+ self._aws_profile = aws_profile

@property
def lambda_client(self) -> client.BaseClient:
if self._lambda_client is None:
- self._lambda_client = boto3.client("lambda")
+ self._lambda_client = get_client(self._aws_profile, "lambda")
return self._lambda_client

def list_schemas(self) -> Tuple[bool, dict]:
@@ -143,16 +145,17 @@ class Uploader:
_SCHEMA_NAME_PREFIX = "Custom."
_SCHEMA_FILE_GLOB_PATTERNS = ("*.yml", "*.yaml")

- def __init__(self, path: str):
+ def __init__(self, path: str, aws_profile: str):
self._path = path
self._files: Optional[List[str]] = None
self._api_client: Optional[Client] = None
self._existing_schemas: Optional[List[Dict[str, Any]]] = None
+ self._aws_profile = aws_profile

@property
def api_client(self) -> Client:
if self._api_client is None:
- self._api_client = Client()
+ self._api_client = Client(self._aws_profile)
return self._api_client

@property
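A short sketch of how the profile now reaches the Lambda client (shapes taken from the diff above; "my-profile" is a hypothetical value):

client = user_defined.Client("my-profile")   # no AWS call happens here
ok, response = client.list_schemas()         # first use builds the client via
                                             # get_client("my-profile", "lambda")

The client is still constructed lazily: the boto3 session is only created when lambda_client is first accessed, so code paths that never reach the Lambda API do not need AWS credentials.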
38 changes: 7 additions & 31 deletions panther_analysis_tool/main.py
@@ -40,7 +40,6 @@
from unittest.mock import MagicMock, patch
from uuid import uuid4

- import boto3
import botocore
import requests
import semver
@@ -75,6 +74,7 @@
SCHEDULED_QUERY_SCHEMA,
TYPE_SCHEMA,
)
+ from panther_analysis_tool.util import get_client

DATA_MODEL_LOCATION = "./data_models"
HELPERS_LOCATION = "./global_helpers"
@@ -335,7 +335,7 @@ def upload_analysis(args: argparse.Namespace) -> Tuple[int, str]:
if return_code == 1:
return return_code, ""

- client = get_client(args, "lambda")
+ client = get_client(args.aws_profile, "lambda")

with open(archive, "rb") as analysis_zip:
zip_bytes = analysis_zip.read()
@@ -387,7 +387,7 @@ def update_schemas(args: argparse.Namespace) -> Tuple[int, str]:
A tuple of return code and the archive filename.
"""

- client = get_client(args, "lambda")
+ client = get_client(args.aws_profile, "lambda")

logging.info("Fetching updates")
response = client.invoke(
@@ -463,15 +463,11 @@ def update_custom_schemas(args: argparse.Namespace) -> Tuple[int, str]:
Returns:
A tuple of return code and a placeholder string.
"""
- if args.aws_profile is not None:
- logging.info("Using AWS profile: %s", args.aws_profile)
- set_env("AWS_PROFILE", args.aws_profile)

normalized_path = user_defined.normalize_path(args.path)
if not normalized_path:
return 1, f"path not found: {args.path}"

- uploader = user_defined.Uploader(normalized_path)
+ uploader = user_defined.Uploader(normalized_path, args.aws_profile)
results = uploader.process()
has_errors = False
for failed, summary in user_defined.report_summary(normalized_path, results):
@@ -498,13 +494,8 @@ def generate_release_assets(args: argparse.Namespace) -> Tuple[int, str]:
if args.kms_key:
# Then generate the sha512 sum of the zip file
archive_hash = generate_hash(release_file)
- # optionally set env variable for profile passed as argument
- # this must be called prior to setting up the client
- if args.aws_profile is not None:
- logging.info("Using AWS profile: %s", args.aws_profile)
- set_env("AWS_PROFILE", args.aws_profile)

- client = get_client(args, "kms")
+ client = get_client(args.aws_profile, "kms")
try:
response = client.sign(
KeyId=args.kms_key,
@@ -706,7 +697,7 @@ def test_analysis(args: argparse.Namespace) -> Tuple[int, list]:
f"No analysis in {args.path} matched filters {args.filter} - {args.filter_inverted}"
]

- available_destinations = []
+ available_destinations: List[str] = []
if args.available_destination:
available_destinations.extend(args.available_destination)

@@ -1310,7 +1301,7 @@ def setup_parser() -> argparse.ArgumentParser:
+ "managing Panther policies and rules.",
prog="panther_analysis_tool",
)
parser.add_argument("--version", action="version", version="panther_analysis_tool 0.8.1")
parser.add_argument("--version", action="version", version="panther_analysis_tool 0.8.2")
parser.add_argument("--debug", action="store_true", dest="debug")
subparsers = parser.add_subparsers()

@@ -1545,21 +1536,6 @@ def parse_filter(filters: List[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
return parsed_filters, parsed_filters_inverted


- def get_client(args: argparse.Namespace, service: str) -> boto3.client:
- client = boto3.client(service)
- # optionally set env variable for profile passed as argument
- if args.aws_profile is not None:
- logging.info("Using AWS profile: %s", args.aws_profile)
- set_env("AWS_PROFILE", args.aws_profile)
- session = boto3.Session(profile_name=args.aws_profile)
- client = session.client(service)
- return client


- def set_env(key: str, value: str) -> None:
- os.environ[key] = value


def run() -> None:
parser = setup_parser()
# if no args are passed, print the help output
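For context, a rough sketch of the call pattern the handlers above now share (the Lambda function name and payload are hypothetical placeholders, not taken from this diff):

import json

client = get_client(args.aws_profile, "lambda")
response = client.invoke(
    FunctionName="example-panther-api",             # hypothetical function name
    Payload=json.dumps({"operation": "Example"}),   # hypothetical payload
)

Because get_client receives args.aws_profile directly, each handler drops its own "if args.aws_profile is not None: set_env(...)" preamble.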
8 changes: 4 additions & 4 deletions panther_analysis_tool/rule.py
@@ -28,7 +28,7 @@
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
- from typing import Any, Callable, List, Optional, Dict
+ from typing import Any, Callable, Dict, List, Optional

from panther_analysis_tool.enriched_event import PantherEvent
from panther_analysis_tool.exceptions import (
@@ -310,13 +310,13 @@ def __init__(self, config: Mapping):
self.rule_dedup_period_mins = config["dedupPeriodMinutes"]

if not ("tags" in config) or not isinstance(config["tags"], list):
- self.rule_tags = list()
+ self.rule_tags: List[str] = list()
else:
config["tags"].sort()
self.rule_tags = config["tags"]

if "reports" not in config:
- self.rule_reports = dict()
+ self.rule_reports: Dict[str, List[str]] = dict()
else:
# Reports are Dict[str, List[str]]
# Sorting the List before setting it
@@ -634,7 +634,7 @@ def _get_destinations( # pylint: disable=too-many-return-statements,too-many-ar

# Check for (in)valid destinations
invalid_destinations = []
- standardized_destinations = []
+ standardized_destinations: List[str] = []

# Standardize the destinations
for each_destination in destinations:
19 changes: 19 additions & 0 deletions panther_analysis_tool/util.py
@@ -17,11 +17,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

+ import logging
import os
from importlib import util as import_util
from pathlib import Path
from typing import Any

+ import boto3


def allowed_char(char: str) -> bool:
"""Return true if the character is part of a valid ID."""
@@ -53,3 +56,19 @@ def store_modules(path: str, body: str) -> None:
Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
with open(path, "w") as py_file:
py_file.write(body)


+ def get_client(aws_profile: str, service: str) -> boto3.client:
+ # optionally set env variable for profile passed as argument
+ if aws_profile is not None:
+ logging.info("Using AWS profile: %s", aws_profile)
+ set_env("AWS_PROFILE", aws_profile)
+ sess = boto3.Session(profile_name=aws_profile)
+ client = sess.client(service)
+ else:
+ client = boto3.client(service)
+ return client


+ def set_env(key: str, value: str) -> None:
+ os.environ[key] = value
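Usage sketch for the consolidated helper (profile names are hypothetical examples):

kms_client = get_client("deploy-profile", "kms")   # built from boto3.Session(profile_name=...)
lambda_client = get_client(None, "lambda")         # no profile: plain boto3.client(...)

Setting AWS_PROFILE alongside the explicit Session means boto3 clients created elsewhere in the same process also pick up the selected profile.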
4 changes: 2 additions & 2 deletions setup.py
@@ -21,14 +21,14 @@
name='panther_analysis_tool',
packages=['panther_analysis_tool', 'panther_analysis_tool/log_schemas'],
package_dir={'log_schemas': 'panther_analysis_tool/log_schemas'},
- version='0.8.1',
+ version='0.8.2',
license='AGPL-3.0',
description=
'Panther command line interface for writing, testing, and packaging policies/rules.',
author='Panther Labs Inc',
author_email='pypi@runpanther.io',
url='https://github.com/panther-labs/panther_analysis_tool',
- download_url = 'https://github.com/panther-labs/panther_analysis_tool/archive/v0.8.1.tar.gz',
+ download_url = 'https://github.com/panther-labs/panther_analysis_tool/archive/v0.8.2.tar.gz',
keywords=['Security', 'CLI'],
scripts=['bin/panther_analysis_tool'],
install_requires=install_requires,
8 changes: 4 additions & 4 deletions tests/unit/panther_analysis_tool/log_schemas/test_user_defined.py
@@ -98,20 +98,20 @@ def test_existing_schemas(self):
mock_uploader_client.list_schemas = mock.MagicMock(
return_value=(True, self.list_schemas_response)
)
- uploader = user_defined.Uploader(self.valid_schema_path)
+ uploader = user_defined.Uploader(self.valid_schema_path, None)
self.assertListEqual(uploader.existing_schemas, self.list_schemas_response['results'])
mock_uploader_client.list_schemas.assert_called_once()

def test_find_schema(self):
with mock.patch('panther_analysis_tool.log_schemas.user_defined.Uploader.existing_schemas',
self.list_schemas_response['results']):
- uploader = user_defined.Uploader(self.valid_schema_path)
+ uploader = user_defined.Uploader(self.valid_schema_path, None)
self.assertDictEqual(uploader.find_schema('Custom.SampleSchema2'),
self.list_schemas_response['results'][1])
self.assertIsNone(uploader.find_schema('unknown-schema'))

def test_files(self):
- uploader = user_defined.Uploader(self.valid_schema_path)
+ uploader = user_defined.Uploader(self.valid_schema_path, None)
self.assertListEqual(
uploader.files,
[os.path.join(self.valid_schema_path, 'schema-1.yml'),
@@ -138,7 +138,7 @@ def test_process(self):
mock_uploader_client.put_schema = mock.MagicMock(
side_effect=put_schema_responses
)
- uploader = user_defined.Uploader(self.valid_schema_path)
+ uploader = user_defined.Uploader(self.valid_schema_path, None)
results = uploader.process()
self.assertEqual(len(results), 2)
self.assertListEqual([r.name for r in results],
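A hypothetical test sketch (not part of this commit) for asserting that the profile reaches the client factory; note the patch target is get_client as imported into user_defined, not util:

with mock.patch('panther_analysis_tool.log_schemas.user_defined.get_client') as fake_get_client:
    user_defined.Client('myprofile').lambda_client  # triggers the lazy construction
    fake_get_client.assert_called_once_with('myprofile', 'lambda')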
5 changes: 3 additions & 2 deletions tests/unit/panther_analysis_tool/test_main.py
@@ -28,6 +28,7 @@
from nose.tools import assert_equal, assert_is_instance, assert_true

from panther_analysis_tool import main as pat
+ from panther_analysis_tool import util
from panther_analysis_tool.data_model import _DATAMODEL_FOLDER

FIXTURES_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../', 'fixtures'))
@@ -184,7 +185,7 @@ def test_aws_profiles(self):
aws_profile = 'AWS_PROFILE'
args = pat.setup_parser().parse_args(
f'upload --path {DETECTIONS_FIXTURES_PATH}/valid_analysis --aws-profile myprofile'.split())
- pat.set_env(aws_profile, args.aws_profile)
+ util.set_env(aws_profile, args.aws_profile)
assert_equal('myprofile', args.aws_profile)
assert_equal(args.aws_profile, os.environ.get(aws_profile))

@@ -341,7 +342,7 @@ def test_update_custom_schemas(self):

with mock.patch('panther_analysis_tool.log_schemas.user_defined.Uploader') as mock_uploader:
_, _ = pat.update_custom_schemas(args)
- mock_uploader.assert_called_once_with(f'{FIXTURES_PATH}/custom-schemas/valid')
+ mock_uploader.assert_called_once_with(f'{FIXTURES_PATH}/custom-schemas/valid', None)

with open(os.path.join(schema_path, 'schema-1.yml')) as f:
schema1 = f.read()
