diff --git a/openfeature/api.py b/openfeature/api.py index c7830204..847f7daa 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -36,7 +36,9 @@ ] _evaluation_context = EvaluationContext() -_evaluation_transaction_context_propagator = NoOpTransactionContextPropagator() +_evaluation_transaction_context_propagator: TransactionContextPropagator = ( + NoOpTransactionContextPropagator() +) _hooks: typing.List[Hook] = [] diff --git a/openfeature/transaction_context/context_var_transaction_context_propagator.py b/openfeature/transaction_context/context_var_transaction_context_propagator.py index e57687cb..1abc04fa 100644 --- a/openfeature/transaction_context/context_var_transaction_context_propagator.py +++ b/openfeature/transaction_context/context_var_transaction_context_propagator.py @@ -5,14 +5,14 @@ TransactionContextPropagator, ) -_transaction_context_var: ContextVar[EvaluationContext] = ContextVar( - "transaction_context", default=EvaluationContext() -) - class ContextVarsTransactionContextPropagator(TransactionContextPropagator): + _transaction_context_var: ContextVar[EvaluationContext] = ContextVar( + "transaction_context", default=EvaluationContext() + ) + def get_transaction_context(self) -> EvaluationContext: - return _transaction_context_var.get() + return self._transaction_context_var.get() def set_transaction_context(self, transaction_context: EvaluationContext) -> None: - _transaction_context_var.set(transaction_context) + self._transaction_context_var.set(transaction_context) diff --git a/tests/test_client.py b/tests/test_client.py index b51c460c..caab2d36 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,12 +1,14 @@ import time import uuid from concurrent.futures import ThreadPoolExecutor -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest +from openfeature import api from openfeature.api import add_hooks, clear_hooks, get_client, set_provider from openfeature.client import OpenFeatureClient +from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails from openfeature.exception import ErrorCode, OpenFeatureError from openfeature.flag_evaluation import FlagResolutionDetails, Reason @@ -14,6 +16,7 @@ from openfeature.provider import FeatureProvider, ProviderStatus from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider from openfeature.provider.no_op_provider import NoOpProvider +from openfeature.transaction_context import ContextVarsTransactionContextPropagator @pytest.mark.parametrize( @@ -384,3 +387,47 @@ def emit_events_task(): f2 = executor.submit(emit_events_task) f1.result() f2.result() + + +def test_client_should_merge_contexts(): + api.clear_hooks() + api.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + + provider = NoOpProvider() + provider.resolve_boolean_details = MagicMock(wraps=provider.resolve_boolean_details) + api.set_provider(provider) + + # Global evaluation context + global_context = EvaluationContext( + targeting_key="global", attributes={"global_attr": "global_value"} + ) + api.set_evaluation_context(global_context) + + # Transaction context + transaction_context = EvaluationContext( + targeting_key="transaction", + attributes={"transaction_attr": "transaction_value"}, + ) + api.set_transaction_context(transaction_context) + + # Client-specific context + client_context = EvaluationContext( + targeting_key="client", attributes={"client_attr": "client_value"} + ) + client = OpenFeatureClient(domain=None, version=None, context=client_context) + + # Invocation-specific context + invocation_context = EvaluationContext( + targeting_key="invocation", attributes={"invocation_attr": "invocation_value"} + ) + client.get_boolean_details("flag", False, invocation_context) + + # Retrieve the call arguments + args, kwargs = provider.resolve_boolean_details.call_args + flag_key, default_value, context = args + + assert context.targeting_key == "invocation" # Last one in the merge chain + assert context.attributes["global_attr"] == "global_value" + assert context.attributes["transaction_attr"] == "transaction_value" + assert context.attributes["client_attr"] == "client_value" + assert context.attributes["invocation_attr"] == "invocation_value" diff --git a/tests/test_transaction_context.py b/tests/test_transaction_context.py index bef2825f..84d1d7b0 100644 --- a/tests/test_transaction_context.py +++ b/tests/test_transaction_context.py @@ -15,9 +15,6 @@ NoOpTransactionContextPropagator, TransactionContextPropagator, ) -from openfeature.transaction_context.context_var_transaction_context_propagator import ( - _transaction_context_var, -) # Test cases @@ -97,8 +94,12 @@ def test_should_propagate_event_when_context_set(): set_transaction_context(evaluation_context) # Then - assert _transaction_context_var.get().targeting_key == "custom_key" - assert _transaction_context_var.get().attributes == {"attr1": "val1"} + assert ( + custom_propagator._transaction_context_var.get().targeting_key == "custom_key" + ) + assert custom_propagator._transaction_context_var.get().attributes == { + "attr1": "val1" + } def test_context_vars_transaction_context_propagator_multiple_threads():