diff --git a/src/mh_structlog/aws.py b/src/mh_structlog/aws.py index cbaa8be..9336b0f 100644 --- a/src/mh_structlog/aws.py +++ b/src/mh_structlog/aws.py @@ -1,18 +1,39 @@ +import os + import structlog from aws_lambda_powertools.utilities.typing import LambdaContext +is_cold_start = True + + +def _reset_cold_start_flag() -> None: + """Reset the cold start flag to True. This is primarily intended for testing purposes.""" + global is_cold_start # noqa: PLW0603 + is_cold_start = True + + def bind_lambda_context(lambda_context: LambdaContext) -> None: """Bind AWS Lambda context information to the structlog context variables, so log entries contain Lambda function metadata. Args: lambda_context (LambdaContext): The AWS Lambda context object. """ + global is_cold_start # noqa: PLW0603 + if lambda_context: structlog.contextvars.clear_contextvars() + + if os.getenv('AWS_LAMBDA_INITIALIZATION_TYPE', '') == "provisioned-concurrency": + is_cold_start = False + structlog.contextvars.bind_contextvars( function_name=lambda_context.function_name, function_memory_size=lambda_context.memory_limit_in_mb, function_arn=lambda_context.invoked_function_arn, function_request_id=lambda_context.aws_request_id, + cold_start=is_cold_start, ) + + # After the first invocation of an environment, set cold_start to False for further invocations + is_cold_start = False diff --git a/src/mh_structlog/config.py b/src/mh_structlog/config.py index c856266..f5e459b 100644 --- a/src/mh_structlog/config.py +++ b/src/mh_structlog/config.py @@ -22,7 +22,7 @@ class StructlogLoggingConfigExceptionError(Exception): def setup( # noqa: PLR0912, PLR0915, C901 - log_format: Literal["console", "json", "gcp_json"] | None = None, + log_format: Literal["console", "json", "gcp_json", "aws_json"] | None = None, logging_configs: list[dict] | None = None, include_source_location: bool = False, # noqa: FBT001, FBT002 global_filter_level: int | None = None, @@ -64,7 +64,14 @@ def setup( # noqa: PLR0912, PLR0915, C901 # Configure stdout formatter if log_format is None: log_format = "console" if sys.stdout.isatty() else "json" - if log_format not in {"console", "json", "gcp_json"}: + + # Determine a more specific log format based on environment if possible. + if log_format == "json": + if os.environ.get("GCP_PROJECT"): + log_format = "gcp_json" + elif os.environ.get("AWS_REGION"): + log_format = "aws_json" + if log_format not in {"console", "json", "gcp_json", "aws_json"}: raise StructlogLoggingConfigExceptionError("Unknown logging format requested.") SELECTED_LOG_FORMAT = log_format @@ -90,7 +97,7 @@ def setup( # noqa: PLR0912, PLR0915, C901 if log_format == "console": selected_formatter = "mh_structlog_colored" - elif log_format in {"json", "gcp_json"}: + elif log_format in {"json", "gcp_json", "aws_json"}: shared_processors.extend( [structlog.processors.dict_tracebacks, processors.CapExceptionFrames(max_frames=2 * max_frames)] ) @@ -199,6 +206,7 @@ def setup( # noqa: PLR0912, PLR0915, C901 processors.FieldRenamer( log_format == 'gcp_json', 'level', 'severity' ), # rename the level field for GCP + processors.FieldTransformer(log_format == 'aws_json', 'level', lambda v: v.upper()), processors.render_orjson, ], "foreign_pre_chain": shared_processors, diff --git a/src/mh_structlog/processors.py b/src/mh_structlog/processors.py index d28e9a6..c5154c4 100644 --- a/src/mh_structlog/processors.py +++ b/src/mh_structlog/processors.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Callable import orjson import structlog @@ -81,6 +82,20 @@ def __call__(self, logger: logging.Logger, name: str, event_dict: EventDict) -> return event_dict +class FieldTransformer: + """Transform a field in the event dict using a provided function.""" + + def __init__(self, enable: bool, field_name: str, transform_function: Callable): # noqa: D107 + self.enable = enable + self.field_name = field_name + self.transform_function = transform_function + + def __call__(self, logger: logging.Logger, name: str, event_dict: EventDict) -> EventDict: # noqa: D102,ARG001,ARG002 + if self.enable and self.field_name in event_dict: + event_dict[self.field_name] = self.transform_function(event_dict[self.field_name]) + return event_dict + + class CapExceptionFrames: """Limit the number of frames in the exception traceback. diff --git a/tests/test_aws.py b/tests/test_aws.py index 3835469..666080f 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -3,10 +3,12 @@ from aws_lambda_powertools.utilities.typing import LambdaContext from structlog.contextvars import clear_contextvars, get_contextvars -from mh_structlog.aws import bind_lambda_context +from mh_structlog.aws import _reset_cold_start_flag, bind_lambda_context def test_bind_lambda_context_non_empty(): + _reset_cold_start_flag() + clear_contextvars() mock_lambda_context = Mock(spec=LambdaContext) @@ -24,6 +26,7 @@ def test_bind_lambda_context_non_empty(): 'function_memory_size': 128, 'function_name': 'test_function', 'function_request_id': '1234-5678', + 'cold_start': True, } @@ -35,3 +38,23 @@ def test_bind_lambda_context_empty(): bind_lambda_context(None) assert get_contextvars() == {} + + +def test_bind_lambda_context_cold_start(): + _reset_cold_start_flag() + + clear_contextvars() + + mock_lambda_context = Mock(spec=LambdaContext) + mock_lambda_context.function_name = "test_function" + mock_lambda_context.memory_limit_in_mb = 128 + mock_lambda_context.invoked_function_arn = "arn:aws:lambda:region:account-id:function:test_function" + mock_lambda_context.aws_request_id = "1234-5678" + + assert get_contextvars() == {} + + bind_lambda_context(mock_lambda_context) + assert get_contextvars()['cold_start'] + + bind_lambda_context(mock_lambda_context) + assert not get_contextvars()['cold_start'] diff --git a/tests/test_processors.py b/tests/test_processors.py index 70f71ab..2fe9a45 100644 --- a/tests/test_processors.py +++ b/tests/test_processors.py @@ -2,6 +2,7 @@ FieldDropper, FieldRenamer, FieldsAdder, + FieldTransformer, add_flattened_extra, cap_timestamp_to_ms_precision, ) @@ -59,3 +60,10 @@ def test_field_adder(): event_dict = {"event": "startup"} result = adder(None, None, event_dict) assert result == {"event": "startup", "service": "my-service", "env": "production"} + + +def test_field_transformer_enabled(): + transformer = FieldTransformer(enable=True, field_name="level", transform_function=lambda v: v.upper()) + event_dict = {"event": "system alert", "level": "warning"} + result = transformer(None, None, event_dict) + assert result == {"event": "system alert", "level": "WARNING"}