Skip to content

Commit

Permalink
feat: add multi threading when mapping imports (#11)
Browse files Browse the repository at this point in the history
* add threading when getting imports

* fix formatting

* fix formatting

* fix func name
  • Loading branch information
tcm5343 authored Feb 10, 2025
1 parent 7bdfd46 commit 6fd9e02
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 8 deletions.
65 changes: 58 additions & 7 deletions src/cycl/cycl.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,87 @@
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from pathlib import Path
from queue import Queue
from typing import TYPE_CHECKING

import boto3
import networkx as nx
from botocore.config import Config

from cycl.utils.cdk import get_cdk_out_imports
from cycl.utils.cfn import get_all_exports, get_all_imports, parse_name_from_id

if TYPE_CHECKING:
from botocore.client import BaseClient

log = getLogger(__name__)


def build_dependency_graph(cdk_out_path: Path | None = None) -> nx.MultiDiGraph:
dep_graph = nx.MultiDiGraph()
def __get_imports_mapper(export: tuple[str, dict], cfn_client: BaseClient) -> dict:
export[1]['ExportingStackName'] = parse_name_from_id(export[1]['ExportingStackId'])
export[1].setdefault('ImportingStackNames', []).extend(
get_all_imports(export_name=export[1]['Name'], cfn_client=cfn_client)
)
return {export[0]: export[1]}


def __map_existing_exports_to_imports(exports: dict) -> dict:
res = {}
q = Queue()

max_workers = 10
boto_config = Config(
retries={
'max_attempts': 10,
'mode': 'adaptive',
},
max_pool_connections=max_workers,
)
cfn_client = boto3.client('cloudformation', config=boto_config)

for item in exports.items():
q.put(item)

futures = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
while not q.empty():
futures.append(executor.submit(__get_imports_mapper, export=q.get(), cfn_client=cfn_client))

for future in futures:
err = future.exception()
if err:
log.error(err)
else:
res.update(future.result())
return res


def __get_graph_data(cdk_out_path: Path | None = None) -> dict:
"""TODO: make this public and test"""
cdk_out_imports = {}
if cdk_out_path:
cdk_out_imports = get_cdk_out_imports(Path(cdk_out_path))

exports = get_all_exports()
for export_name in cdk_out_imports:
for export_name, importing_stack_names in cdk_out_imports.items():
if export_name in exports:
exports[export_name].setdefault('ImportingStackNames', []).extend(importing_stack_names)
if export_name not in exports:
log.warning(
'found an export (%s) which has not been deployed yet about to be imported stack(s): (%s)',
export_name,
cdk_out_imports[export_name],
)

for export in exports.values():
export['ExportingStackName'] = parse_name_from_id(export['ExportingStackId'])
export['ImportingStackNames'] = get_all_imports(export_name=export['Name'])
export.setdefault('ImportingStackNames', []).extend(cdk_out_imports.get(export['Name'], []))
return __map_existing_exports_to_imports(exports=exports)


def build_dependency_graph(cdk_out_path: Path | None = None) -> nx.MultiDiGraph:
dep_graph = nx.MultiDiGraph()
mapped_exports = __get_graph_data(cdk_out_path)
for export in mapped_exports.values():
edges = [
(export['ExportingStackName'], importing_stack_name) for importing_stack_name in export['ImportingStackNames']
]
Expand Down
27 changes: 26 additions & 1 deletion tests/cycl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
from cycl.cycl import build_dependency_graph


@pytest.fixture(autouse=True)
def mock_boto3():
with patch.object(cycl_module, 'boto3') as mock:
yield mock


@pytest.fixture(autouse=True)
def mock_config():
with patch.object(cycl_module, 'Config') as mock:
yield mock


@pytest.fixture(autouse=True)
def mock_parse_name_from_id():
with patch.object(cycl_module, 'parse_name_from_id') as mock:
Expand Down Expand Up @@ -123,7 +135,7 @@ def test_build_dependency_graph_returns_graph_with_multiple_exports(mock_get_all
},
}

def mock_get_all_imports_side_effect_func(export_name):
def mock_get_all_imports_side_effect_func(export_name, *_args, **_kwargs):
if export_name == 'some-name-1':
return [
'some-importing-stack-name-1',
Expand Down Expand Up @@ -290,3 +302,16 @@ def test_build_dependency_graph_returns_graph_with_cdk_out_path_and_no_existing_

assert nx.is_directed_acyclic_graph(actual_graph)
assert next(nx.simple_cycles(actual_graph), []) == []


def test_config_defined_as_expected(mock_config, mock_boto3):
build_dependency_graph()

mock_config.assert_called_once_with(
retries={
'max_attempts': 10,
'mode': 'adaptive',
},
max_pool_connections=10,
)
mock_boto3.client.assert_called_once_with('cloudformation', config=mock_config.return_value)

0 comments on commit 6fd9e02

Please sign in to comment.