Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/mh_structlog/aws.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,39 @@
import os

import structlog
from aws_lambda_powertools.utilities.typing import LambdaContext


is_cold_start = True


def _reset_cold_start_flag() -> None:
"""Reset the cold start flag to True. This is primarily intended for testing purposes."""
global is_cold_start # noqa: PLW0603
is_cold_start = True


def bind_lambda_context(lambda_context: LambdaContext) -> None:
"""Bind AWS Lambda context information to the structlog context variables, so log entries contain Lambda function metadata.

Args:
lambda_context (LambdaContext): The AWS Lambda context object.
"""
global is_cold_start # noqa: PLW0603

if lambda_context:
structlog.contextvars.clear_contextvars()

if os.getenv('AWS_LAMBDA_INITIALIZATION_TYPE', '') == "provisioned-concurrency":
is_cold_start = False

structlog.contextvars.bind_contextvars(
function_name=lambda_context.function_name,
function_memory_size=lambda_context.memory_limit_in_mb,
function_arn=lambda_context.invoked_function_arn,
function_request_id=lambda_context.aws_request_id,
cold_start=is_cold_start,
)

# After the first invocation of an environment, set cold_start to False for further invocations
is_cold_start = False
14 changes: 11 additions & 3 deletions src/mh_structlog/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class StructlogLoggingConfigExceptionError(Exception):


def setup( # noqa: PLR0912, PLR0915, C901
log_format: Literal["console", "json", "gcp_json"] | None = None,
log_format: Literal["console", "json", "gcp_json", "aws_json"] | None = None,
logging_configs: list[dict] | None = None,
include_source_location: bool = False, # noqa: FBT001, FBT002
global_filter_level: int | None = None,
Expand Down Expand Up @@ -64,7 +64,14 @@ def setup( # noqa: PLR0912, PLR0915, C901
# Configure stdout formatter
if log_format is None:
log_format = "console" if sys.stdout.isatty() else "json"
if log_format not in {"console", "json", "gcp_json"}:

# Determine a more specific log format based on environment if possible.
if log_format == "json":
if os.environ.get("GCP_PROJECT"):
log_format = "gcp_json"
elif os.environ.get("AWS_REGION"):
log_format = "aws_json"
if log_format not in {"console", "json", "gcp_json", "aws_json"}:
raise StructlogLoggingConfigExceptionError("Unknown logging format requested.")

SELECTED_LOG_FORMAT = log_format
Expand All @@ -90,7 +97,7 @@ def setup( # noqa: PLR0912, PLR0915, C901

if log_format == "console":
selected_formatter = "mh_structlog_colored"
elif log_format in {"json", "gcp_json"}:
elif log_format in {"json", "gcp_json", "aws_json"}:
shared_processors.extend(
[structlog.processors.dict_tracebacks, processors.CapExceptionFrames(max_frames=2 * max_frames)]
)
Expand Down Expand Up @@ -199,6 +206,7 @@ def setup( # noqa: PLR0912, PLR0915, C901
processors.FieldRenamer(
log_format == 'gcp_json', 'level', 'severity'
), # rename the level field for GCP
processors.FieldTransformer(log_format == 'aws_json', 'level', lambda v: v.upper()),
processors.render_orjson,
],
"foreign_pre_chain": shared_processors,
Expand Down
15 changes: 15 additions & 0 deletions src/mh_structlog/processors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections.abc import Callable

import orjson
import structlog
Expand Down Expand Up @@ -81,6 +82,20 @@ def __call__(self, logger: logging.Logger, name: str, event_dict: EventDict) ->
return event_dict


class FieldTransformer:
"""Transform a field in the event dict using a provided function."""

def __init__(self, enable: bool, field_name: str, transform_function: Callable): # noqa: D107
self.enable = enable
self.field_name = field_name
self.transform_function = transform_function

def __call__(self, logger: logging.Logger, name: str, event_dict: EventDict) -> EventDict: # noqa: D102,ARG001,ARG002
if self.enable and self.field_name in event_dict:
event_dict[self.field_name] = self.transform_function(event_dict[self.field_name])
return event_dict


class CapExceptionFrames:
"""Limit the number of frames in the exception traceback.

Expand Down
25 changes: 24 additions & 1 deletion tests/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from aws_lambda_powertools.utilities.typing import LambdaContext
from structlog.contextvars import clear_contextvars, get_contextvars

from mh_structlog.aws import bind_lambda_context
from mh_structlog.aws import _reset_cold_start_flag, bind_lambda_context


def test_bind_lambda_context_non_empty():
_reset_cold_start_flag()

clear_contextvars()

mock_lambda_context = Mock(spec=LambdaContext)
Expand All @@ -24,6 +26,7 @@ def test_bind_lambda_context_non_empty():
'function_memory_size': 128,
'function_name': 'test_function',
'function_request_id': '1234-5678',
'cold_start': True,
}


Expand All @@ -35,3 +38,23 @@ def test_bind_lambda_context_empty():
bind_lambda_context(None)

assert get_contextvars() == {}


def test_bind_lambda_context_cold_start():
_reset_cold_start_flag()

clear_contextvars()

mock_lambda_context = Mock(spec=LambdaContext)
mock_lambda_context.function_name = "test_function"
mock_lambda_context.memory_limit_in_mb = 128
mock_lambda_context.invoked_function_arn = "arn:aws:lambda:region:account-id:function:test_function"
mock_lambda_context.aws_request_id = "1234-5678"

assert get_contextvars() == {}

bind_lambda_context(mock_lambda_context)
assert get_contextvars()['cold_start']

bind_lambda_context(mock_lambda_context)
assert not get_contextvars()['cold_start']
8 changes: 8 additions & 0 deletions tests/test_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
FieldDropper,
FieldRenamer,
FieldsAdder,
FieldTransformer,
add_flattened_extra,
cap_timestamp_to_ms_precision,
)
Expand Down Expand Up @@ -59,3 +60,10 @@ def test_field_adder():
event_dict = {"event": "startup"}
result = adder(None, None, event_dict)
assert result == {"event": "startup", "service": "my-service", "env": "production"}


def test_field_transformer_enabled():
transformer = FieldTransformer(enable=True, field_name="level", transform_function=lambda v: v.upper())
event_dict = {"event": "system alert", "level": "warning"}
result = transformer(None, None, event_dict)
assert result == {"event": "system alert", "level": "WARNING"}