From 7da6e9db7a30cd2173e99c28e39e221bfc9adf3d Mon Sep 17 00:00:00 2001 From: leohoare Date: Sun, 12 Jan 2025 19:42:53 +1100 Subject: [PATCH 1/3] refactor, switch to single client with common code and fallback Signed-off-by: leohoare --- openfeature/client.py | 439 ++++++++++++++++++++++++++++--- openfeature/provider/__init__.py | 75 ++++++ tests/test_client.py | 34 ++- 3 files changed, 508 insertions(+), 40 deletions(-) diff --git a/openfeature/client.py b/openfeature/client.py index 1edfca63..7560915d 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -20,7 +20,7 @@ FlagType, Reason, ) -from openfeature.hook import Hook, HookContext +from openfeature.hook import Hook, HookContext, HookHints from openfeature.hook._hook_support import ( after_all_hooks, after_hooks, @@ -55,6 +55,28 @@ FlagResolutionDetails[typing.Union[dict, list]], ], ] +GetDetailCallableAsync = typing.Union[ + typing.Callable[ + [str, bool, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[bool]], + ], + typing.Callable[ + [str, int, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[int]], + ], + typing.Callable[ + [str, float, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[float]], + ], + typing.Callable[ + [str, str, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[str]], + ], + typing.Callable[ + [str, typing.Union[dict, list], typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[typing.Union[dict, list]]], + ], +] TypeMap = typing.Dict[ FlagType, typing.Union[ @@ -113,6 +135,21 @@ def get_boolean_value( flag_evaluation_options, ).value + async def get_boolean_value_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> bool: + details = await self.get_boolean_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_boolean_details( self, flag_key: str, @@ -128,6 +165,21 @@ def get_boolean_details( flag_evaluation_options, ) + async def get_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[bool]: + return await self.evaluate_flag_details_async( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_string_value( self, flag_key: str, @@ -142,6 +194,21 @@ def get_string_value( flag_evaluation_options, ).value + async def get_string_value_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> str: + details = await self.get_string_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_string_details( self, flag_key: str, @@ -157,6 +224,21 @@ def get_string_details( flag_evaluation_options, ) + async def get_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[str]: + return await self.evaluate_flag_details_async( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_integer_value( self, flag_key: str, @@ -171,6 +253,21 @@ def get_integer_value( flag_evaluation_options, ).value + async def get_integer_value_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> int: + details = await self.get_integer_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_integer_details( self, flag_key: str, @@ -186,6 +283,21 @@ def get_integer_details( flag_evaluation_options, ) + async def get_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[int]: + return await self.evaluate_flag_details_async( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_float_value( self, flag_key: str, @@ -200,6 +312,21 @@ def get_float_value( flag_evaluation_options, ).value + async def get_float_value_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> float: + details = await self.get_float_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_float_details( self, flag_key: str, @@ -215,6 +342,21 @@ def get_float_details( flag_evaluation_options, ) + async def get_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[float]: + return await self.evaluate_flag_details_async( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_object_value( self, flag_key: str, @@ -229,6 +371,21 @@ def get_object_value( flag_evaluation_options, ).value + async def get_object_value_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> typing.Union[dict, list]: + details = await self.get_object_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_object_details( self, flag_key: str, @@ -244,26 +401,35 @@ def get_object_details( flag_evaluation_options, ) - def evaluate_flag_details( # noqa: PLR0915 + async def get_object_details_async( self, - flag_type: FlagType, flag_key: str, - default_value: typing.Any, + default_value: typing.Union[dict, list], evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[typing.Any]: - """ - Evaluate the flag requested by the user from the clients provider. - - :param flag_type: the type of the flag being returned - :param flag_key: the string key of the selected flag - :param default_value: backup value returned if no result found by the provider - :param evaluation_context: Information for the purposes of flag evaluation - :param flag_evaluation_options: Additional flag evaluation information - :return: a FlagEvaluationDetails object with the fully evaluated flag from a - provider - """ + ) -> FlagEvaluationDetails[typing.Union[dict, list]]: + return await self.evaluate_flag_details_async( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def _establish_hooks_and_provider( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext], + flag_evaluation_options: typing.Optional[FlagEvaluationOptions], + ) -> typing.Tuple[ + FeatureProvider, + HookContext, + HookHints, + typing.List[Hook], + typing.List[Hook], + ]: if evaluation_context is None: evaluation_context = EvaluationContext() @@ -295,7 +461,17 @@ def evaluate_flag_details( # noqa: PLR0915 reversed_merged_hooks = merged_hooks[:] reversed_merged_hooks.reverse() + return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks + + def _assert_provider_status( + self, + flag_type: FlagType, + hook_context: HookContext, + reversed_merged_hooks: typing.List[Hook], + hook_hints: HookHints, + ) -> typing.Union[None, ErrorCode]: status = self.get_provider_status() + if status == ProviderStatus.NOT_READY: error_hooks( flag_type, @@ -304,43 +480,194 @@ def evaluate_flag_details( # noqa: PLR0915 reversed_merged_hooks, hook_hints, ) + return ErrorCode.PROVIDER_NOT_READY + if status == ProviderStatus.FATAL: + error_hooks( + flag_type, + hook_context, + ProviderFatalError(), + reversed_merged_hooks, + hook_hints, + ) + return ErrorCode.PROVIDER_FATAL + return None + + def _before_hooks_and_merge_context( + self, + flag_type: FlagType, + hook_context: HookContext, + merged_hooks: typing.List[Hook], + hook_hints: HookHints, + evaluation_context: typing.Optional[EvaluationContext], + ) -> EvaluationContext: + # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md + # Any resulting evaluation context from a before hook will overwrite + # duplicate fields defined globally, on the client, or in the invocation. + # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context + invocation_context = before_hooks( + flag_type, hook_context, merged_hooks, hook_hints + ) + if evaluation_context: + invocation_context = invocation_context.merge(ctx2=evaluation_context) + + # Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context + merged_context = ( + api.get_evaluation_context() + .merge(api.get_transaction_context()) + .merge(self.context) + .merge(invocation_context) + ) + return merged_context + + async def evaluate_flag_details_async( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a + provider + """ + provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( + self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + ) + error_code = self._assert_provider_status( + flag_type, + hook_context, + reversed_merged_hooks, + hook_hints, + ) + if error_code: return FlagEvaluationDetails( flag_key=flag_key, value=default_value, reason=Reason.ERROR, - error_code=ErrorCode.PROVIDER_NOT_READY, + error_code=error_code, ) - if status == ProviderStatus.FATAL: - error_hooks( + + try: + merged_context = self._before_hooks_and_merge_context( flag_type, hook_context, - ProviderFatalError(), + merged_hooks, + hook_hints, + evaluation_context, + ) + + flag_evaluation = await self._create_provider_evaluation_async( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + + after_hooks( + flag_type, + hook_context, + flag_evaluation, reversed_merged_hooks, hook_hints, ) + + return flag_evaluation + + except OpenFeatureError as err: + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + return FlagEvaluationDetails( flag_key=flag_key, value=default_value, reason=Reason.ERROR, - error_code=ErrorCode.PROVIDER_FATAL, + error_code=err.error_code, + error_message=err.error_message, + ) + # Catch any type of exception here since the user can provide any exception + # in the error hooks + except Exception as err: # pragma: no cover + logger.exception( + "Unable to correctly evaluate flag with key: '%s'", flag_key ) - try: - # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md - # Any resulting evaluation context from a before hook will overwrite - # duplicate fields defined globally, on the client, or in the invocation. - # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context - invocation_context = before_hooks( - flag_type, hook_context, merged_hooks, hook_hints + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + + error_message = getattr(err, "error_message", str(err)) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=error_message, + ) + + finally: + after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + + def evaluate_flag_details( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( + self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + ) + error_code = self._assert_provider_status( + flag_type, + hook_context, + reversed_merged_hooks, + hook_hints, + ) + if error_code: + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=error_code, ) - invocation_context = invocation_context.merge(ctx2=evaluation_context) - # Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context - merged_context = ( - api.get_evaluation_context() - .merge(api.get_transaction_context()) - .merge(self.context) - .merge(invocation_context) + try: + merged_context = self._before_hooks_and_merge_context( + flag_type, + hook_context, + merged_hooks, + hook_hints, + evaluation_context, ) flag_evaluation = self._create_provider_evaluation( @@ -392,6 +719,48 @@ def evaluate_flag_details( # noqa: PLR0915 finally: after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + async def _create_provider_evaluation_async( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagEvaluationDetails[typing.Any]: + args = ( + flag_key, + default_value, + evaluation_context, + ) + get_details_callables_async: typing.Mapping[ + FlagType, GetDetailCallableAsync + ] = { + FlagType.BOOLEAN: provider.resolve_boolean_details_async, + FlagType.INTEGER: provider.resolve_integer_details_async, + FlagType.FLOAT: provider.resolve_float_details_async, + FlagType.OBJECT: provider.resolve_object_details_async, + FlagType.STRING: provider.resolve_string_details_async, + } + get_details_callable = get_details_callables_async.get(flag_type) + if not get_details_callable: + raise GeneralError(error_message="Unknown flag type") + + resolution = await get_details_callable(*args) + resolution.raise_for_error() + + # we need to check the get_args to be compatible with union types. + _typecheck_flag_value(resolution.value, flag_type) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=resolution.value, + variant=resolution.variant, + flag_metadata=resolution.flag_metadata or {}, + reason=resolution.reason, + error_code=resolution.error_code, + error_message=resolution.error_message, + ) + def _create_provider_evaluation( self, provider: FeatureProvider, diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index b390f928..6a782635 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -47,6 +47,13 @@ def resolve_boolean_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[bool]: ... + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: ... + def resolve_string_details( self, flag_key: str, @@ -54,6 +61,13 @@ def resolve_string_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[str]: ... + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: ... + def resolve_integer_details( self, flag_key: str, @@ -61,6 +75,13 @@ def resolve_integer_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[int]: ... + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: ... + def resolve_float_details( self, flag_key: str, @@ -68,6 +89,13 @@ def resolve_float_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[float]: ... + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: ... + def resolve_object_details( self, flag_key: str, @@ -75,6 +103,13 @@ def resolve_object_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + class AbstractProvider(FeatureProvider): def attach( @@ -111,6 +146,14 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: pass + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return self.resolve_boolean_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_string_details( self, @@ -120,6 +163,14 @@ def resolve_string_details( ) -> FlagResolutionDetails[str]: pass + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return self.resolve_string_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_integer_details( self, @@ -129,6 +180,14 @@ def resolve_integer_details( ) -> FlagResolutionDetails[int]: pass + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return self.resolve_integer_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_float_details( self, @@ -138,6 +197,14 @@ def resolve_float_details( ) -> FlagResolutionDetails[float]: pass + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return self.resolve_float_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_object_details( self, @@ -147,6 +214,14 @@ def resolve_object_details( ) -> FlagResolutionDetails[typing.Union[dict, list]]: pass + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return self.resolve_object_details(flag_key, default_value, evaluation_context) + def emit_provider_ready(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_READY, details) diff --git a/tests/test_client.py b/tests/test_client.py index 7f0ca461..4c00a95a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +import asyncio import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -60,9 +61,13 @@ def test_should_get_flag_value_based_on_method_type( "flag_type, default_value, get_method", ( (bool, True, "get_boolean_details"), + (bool, True, "get_boolean_details_async"), (str, "String", "get_string_details"), + (str, "String", "get_string_details_async"), (int, 100, "get_integer_details"), + (int, 100, "get_integer_details_async"), (float, 10.23, "get_float_details"), + (float, 10.23, "get_float_details_async"), ( dict, { @@ -72,28 +77,47 @@ def test_should_get_flag_value_based_on_method_type( }, "get_object_details", ), + ( + dict, + { + "String": "string", + "Number": 2, + "Boolean": True, + }, + "get_object_details_async", + ), ( list, ["string1", "string2"], "get_object_details", ), + ( + list, + ["string1", "string2"], + "get_object_details_async", + ), ), ) -def test_should_get_flag_detail_based_on_method_type( +@pytest.mark.asyncio +async def test_should_get_flag_detail_based_on_method_type( flag_type, default_value, get_method, no_op_provider_client ): # Given # When - flag = getattr(no_op_provider_client, get_method)( - flag_key="Key", default_value=default_value - ) + method = getattr(no_op_provider_client, get_method) + if asyncio.iscoroutinefunction(method): + flag = await method(flag_key="Key", default_value=default_value) + else: + flag = method(flag_key="Key", default_value=default_value) # Then assert flag is not None assert flag.value == default_value assert isinstance(flag.value, flag_type) -def test_should_raise_exception_when_invalid_flag_type_provided(no_op_provider_client): +def test_should_raise_exception_when_invalid_flag_type_provided( + no_op_provider_client, +): # Given # When flag = no_op_provider_client.evaluate_flag_details( From a89bd74395963921bbb1f113a1dca505d021cc6e Mon Sep 17 00:00:00 2001 From: leohoare Date: Mon, 13 Jan 2025 22:16:01 +1100 Subject: [PATCH 2/3] add readme, async to in memory provider and a few more tests Signed-off-by: leohoare --- README.md | 51 +++++++++ openfeature/provider/in_memory_provider.py | 47 ++++++++ tests/provider/test_in_memory_provider.py | 127 ++++++++++++++------- 3 files changed, 182 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 8c078fab..40d3602c 100644 --- a/README.md +++ b/README.md @@ -390,6 +390,57 @@ class MyProvider(AbstractProvider): ... ``` +Providers can also be extended to support async functionality. +To support add asynchronous calls to a provider: +* Implement the `AbstractProvider` as shown above. +* Define asynchronous calls for each data type. + +```python +class MyProvider(AbstractProvider): + ... + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + ... + + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + ... + + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + ... + + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + ... + + async def resolve_object_details_async( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + ... + +``` + + > Built a new provider? [Let us know](https://github.com/open-feature/openfeature.dev/issues/new?assignees=&labels=provider&projects=&template=document-provider.yaml&title=%5BProvider%5D%3A+) so we can add it to the docs! ### Develop a hook diff --git a/openfeature/provider/in_memory_provider.py b/openfeature/provider/in_memory_provider.py index 322f4ed6..d64a7735 100644 --- a/openfeature/provider/in_memory_provider.py +++ b/openfeature/provider/in_memory_provider.py @@ -76,6 +76,14 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: return self._resolve(flag_key, evaluation_context) + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_string_details( self, flag_key: str, @@ -84,6 +92,14 @@ def resolve_string_details( ) -> FlagResolutionDetails[str]: return self._resolve(flag_key, evaluation_context) + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_integer_details( self, flag_key: str, @@ -92,6 +108,14 @@ def resolve_integer_details( ) -> FlagResolutionDetails[int]: return self._resolve(flag_key, evaluation_context) + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_float_details( self, flag_key: str, @@ -100,6 +124,14 @@ def resolve_float_details( ) -> FlagResolutionDetails[float]: return self._resolve(flag_key, evaluation_context) + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return await self._resolve_async(flag_key, evaluation_context) + def resolve_object_details( self, flag_key: str, @@ -108,6 +140,14 @@ def resolve_object_details( ) -> FlagResolutionDetails[typing.Union[dict, list]]: return self._resolve(flag_key, evaluation_context) + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return await self._resolve_async(flag_key, evaluation_context) + def _resolve( self, flag_key: str, @@ -117,3 +157,10 @@ def _resolve( if flag is None: raise FlagNotFoundError(f"Flag '{flag_key}' not found") return flag.resolve(evaluation_context) + + async def _resolve_async( + self, + flag_key: str, + evaluation_context: typing.Optional[EvaluationContext], + ) -> FlagResolutionDetails[V]: + return self._resolve(flag_key, evaluation_context) diff --git a/tests/provider/test_in_memory_provider.py b/tests/provider/test_in_memory_provider.py index 66d5239e..f3559363 100644 --- a/tests/provider/test_in_memory_provider.py +++ b/tests/provider/test_in_memory_provider.py @@ -17,16 +17,20 @@ def test_should_return_in_memory_provider_metadata(): assert metadata.name == "In-Memory Provider" -def test_should_handle_unknown_flags_correctly(): +@pytest.mark.asyncio +async def test_should_handle_unknown_flags_correctly(): # Given provider = InMemoryProvider({}) # When with pytest.raises(FlagNotFoundError): provider.resolve_boolean_details(flag_key="Key", default_value=True) + with pytest.raises(FlagNotFoundError): + await provider.resolve_integer_details_async(flag_key="Key", default_value=1) # Then -def test_calls_context_evaluator_if_present(): +@pytest.mark.asyncio +async def test_calls_context_evaluator_if_present(): # Given def context_evaluator(flag: InMemoryFlag, evaluation_context: dict): return FlagResolutionDetails( @@ -44,57 +48,81 @@ def context_evaluator(flag: InMemoryFlag, evaluation_context: dict): } ) # When - flag = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_async = await provider.resolve_boolean_details_async( + flag_key="Key", default_value=False + ) # Then - assert flag is not None - assert flag.value is False - assert isinstance(flag.value, bool) - assert flag.reason == Reason.TARGETING_MATCH + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value is False + assert isinstance(flag.value, bool) + assert flag.reason == Reason.TARGETING_MATCH -def test_should_resolve_boolean_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_boolean_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("true", {"true": True, "false": False})} ) # When - flag = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False) + flag_async = await provider.resolve_boolean_details_async( + flag_key="Key", default_value=False + ) # Then - assert flag is not None - assert flag.value is True - assert isinstance(flag.value, bool) - assert flag.variant == "true" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value is True + assert isinstance(flag.value, bool) + assert flag.variant == "true" -def test_should_resolve_integer_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_integer_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("hundred", {"zero": 0, "hundred": 100})} ) # When - flag = provider.resolve_integer_details(flag_key="Key", default_value=0) + flag_sync = provider.resolve_integer_details(flag_key="Key", default_value=0) + flag_async = await provider.resolve_integer_details_async( + flag_key="Key", default_value=0 + ) # Then - assert flag is not None - assert flag.value == 100 - assert isinstance(flag.value, Number) - assert flag.variant == "hundred" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == 100 + assert isinstance(flag.value, Number) + assert flag.variant == "hundred" -def test_should_resolve_float_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_float_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("ten", {"zero": 0.0, "ten": 10.23})} ) # When - flag = provider.resolve_float_details(flag_key="Key", default_value=0.0) + flag_sync = provider.resolve_float_details(flag_key="Key", default_value=0.0) + flag_async = await provider.resolve_float_details_async( + flag_key="Key", default_value=0.0 + ) # Then - assert flag is not None - assert flag.value == 10.23 - assert isinstance(flag.value, Number) - assert flag.variant == "ten" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == 10.23 + assert isinstance(flag.value, Number) + assert flag.variant == "ten" -def test_should_resolve_string_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_string_flag_from_in_memory(): # Given provider = InMemoryProvider( { @@ -105,29 +133,39 @@ def test_should_resolve_string_flag_from_in_memory(): } ) # When - flag = provider.resolve_string_details(flag_key="Key", default_value="Default") + flag_sync = provider.resolve_string_details(flag_key="Key", default_value="Default") + flag_async = await provider.resolve_string_details_async( + flag_key="Key", default_value="Default" + ) # Then - assert flag is not None - assert flag.value == "String" - assert isinstance(flag.value, str) - assert flag.variant == "stringVariant" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == "String" + assert isinstance(flag.value, str) + assert flag.variant == "stringVariant" -def test_should_resolve_list_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_list_flag_from_in_memory(): # Given provider = InMemoryProvider( {"Key": InMemoryFlag("twoItems", {"empty": [], "twoItems": ["item1", "item2"]})} ) # When - flag = provider.resolve_object_details(flag_key="Key", default_value=[]) + flag_sync = provider.resolve_object_details(flag_key="Key", default_value=[]) + flag_async = provider.resolve_object_details(flag_key="Key", default_value=[]) # Then - assert flag is not None - assert flag.value == ["item1", "item2"] - assert isinstance(flag.value, list) - assert flag.variant == "twoItems" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == ["item1", "item2"] + assert isinstance(flag.value, list) + assert flag.variant == "twoItems" -def test_should_resolve_object_flag_from_in_memory(): +@pytest.mark.asyncio +async def test_should_resolve_object_flag_from_in_memory(): # Given return_value = { "String": "string", @@ -138,9 +176,12 @@ def test_should_resolve_object_flag_from_in_memory(): {"Key": InMemoryFlag("obj", {"obj": return_value, "empty": {}})} ) # When - flag = provider.resolve_object_details(flag_key="Key", default_value={}) + flag_sync = provider.resolve_object_details(flag_key="Key", default_value={}) + flag_async = provider.resolve_object_details(flag_key="Key", default_value={}) # Then - assert flag is not None - assert flag.value == return_value - assert isinstance(flag.value, dict) - assert flag.variant == "obj" + assert flag_sync == flag_async + for flag in [flag_sync, flag_async]: + assert flag is not None + assert flag.value == return_value + assert isinstance(flag.value, dict) + assert flag.variant == "obj" From bb9a4e6aa23f60562a608f52124793dafb4b1d0c Mon Sep 17 00:00:00 2001 From: leohoare Date: Wed, 22 Jan 2025 20:56:16 +1100 Subject: [PATCH 3/3] add test coverage, async providers calling sync calls, async only client Signed-off-by: leohoare --- tests/provider/test_in_memory_provider.py | 4 +- tests/provider/test_provider_compatibility.py | 196 ++++++++++++++++++ tests/test_client.py | 140 ++++++++++--- 3 files changed, 309 insertions(+), 31 deletions(-) create mode 100644 tests/provider/test_provider_compatibility.py diff --git a/tests/provider/test_in_memory_provider.py b/tests/provider/test_in_memory_provider.py index f3559363..cdcea7bf 100644 --- a/tests/provider/test_in_memory_provider.py +++ b/tests/provider/test_in_memory_provider.py @@ -154,7 +154,9 @@ async def test_should_resolve_list_flag_from_in_memory(): ) # When flag_sync = provider.resolve_object_details(flag_key="Key", default_value=[]) - flag_async = provider.resolve_object_details(flag_key="Key", default_value=[]) + flag_async = await provider.resolve_object_details_async( + flag_key="Key", default_value=[] + ) # Then assert flag_sync == flag_async for flag in [flag_sync, flag_async]: diff --git a/tests/provider/test_provider_compatibility.py b/tests/provider/test_provider_compatibility.py new file mode 100644 index 00000000..d90859c4 --- /dev/null +++ b/tests/provider/test_provider_compatibility.py @@ -0,0 +1,196 @@ +import asyncio +from typing import Optional, Union + +import pytest + +from openfeature.api import OpenFeatureClient, get_client, set_provider +from openfeature.evaluation_context import EvaluationContext +from openfeature.flag_evaluation import FlagResolutionDetails +from openfeature.provider import AbstractProvider, Metadata + + +class SynchronousProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="SynchronousProvider") + + def get_provider_hooks(self): + return [] + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return FlagResolutionDetails(value="string") + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return FlagResolutionDetails(value=1) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return FlagResolutionDetails(value=10.0) + + def resolve_object_details( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + return FlagResolutionDetails(value={"key": "value"}) + + +@pytest.mark.parametrize( + "flag_type, default_value, get_method", + ( + (bool, True, "get_boolean_value_async"), + (str, "string", "get_string_value_async"), + (int, 1, "get_integer_value_async"), + (float, 10.0, "get_float_value_async"), + ( + dict, + {"key": "value"}, + "get_object_value_async", + ), + ), +) +@pytest.mark.asyncio +async def test_sync_provider_can_be_called_async(flag_type, default_value, get_method): + # Given + set_provider(SynchronousProvider(), "SynchronousProvider") + client = get_client("SynchronousProvider") + # When + async_callable = getattr(client, get_method) + flag = await async_callable(flag_key="Key", default_value=default_value) + # Then + assert flag is not None + assert flag == default_value + assert isinstance(flag, flag_type) + + +@pytest.mark.asyncio +async def test_sync_provider_can_be_extended_async(): + # Given + class ExtendedAsyncProvider(SynchronousProvider): + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=False) + + set_provider(ExtendedAsyncProvider(), "ExtendedAsyncProvider") + client = get_client("ExtendedAsyncProvider") + # When + flag = await client.get_boolean_value_async(flag_key="Key", default_value=True) + # Then + assert flag is not None + assert flag is False + + +# We're not allowing providers to only have async methods +def test_sync_methods_enforced_for_async_providers(): + # Given + class AsyncProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="AsyncProvider") + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + # When + with pytest.raises(TypeError) as exception: + set_provider(AsyncProvider(), "AsyncProvider") + + # Then + # assert + assert str(exception.value).startswith( + "Can't instantiate abstract class AsyncProvider with abstract methods resolve_boolean_details" + ) + + +@pytest.mark.asyncio +async def test_async_provider_not_implemented_exception_workaround(): + # Given + class SyncNotImplementedProvider(AbstractProvider): + def get_metadata(self): + return Metadata(name="AsyncProvider") + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails(value=True) + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + raise NotImplementedError("Use the async method") + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + raise NotImplementedError("Use the async method") + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + raise NotImplementedError("Use the async method") + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + raise NotImplementedError("Use the async method") + + def resolve_object_details( + self, + flag_key: str, + default_value: Union[dict, list], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Union[dict, list]]: + raise NotImplementedError("Use the async method") + + # When + set_provider(SyncNotImplementedProvider(), "SyncNotImplementedProvider") + client = get_client("SyncNotImplementedProvider") + flag = await client.get_boolean_value_async(flag_key="Key", default_value=False) + # Then + assert flag is not None + assert flag is True diff --git a/tests/test_client.py b/tests/test_client.py index 4c00a95a..695a92c1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ from openfeature import api from openfeature.api import add_hooks, clear_hooks, get_client, set_provider -from openfeature.client import OpenFeatureClient +from openfeature.client import GeneralError, OpenFeatureClient, _typecheck_flag_value from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails from openfeature.exception import ErrorCode, OpenFeatureError @@ -24,9 +24,13 @@ "flag_type, default_value, get_method", ( (bool, True, "get_boolean_value"), + (bool, True, "get_boolean_value_async"), (str, "String", "get_string_value"), + (str, "String", "get_string_value_async"), (int, 100, "get_integer_value"), + (int, 100, "get_integer_value_async"), (float, 10.23, "get_float_value"), + (float, 10.23, "get_float_value_async"), ( dict, { @@ -36,21 +40,38 @@ }, "get_object_value", ), + ( + dict, + { + "String": "string", + "Number": 2, + "Boolean": True, + }, + "get_object_value_async", + ), ( list, ["string1", "string2"], "get_object_value", ), + ( + list, + ["string1", "string2"], + "get_object_value_async", + ), ), ) -def test_should_get_flag_value_based_on_method_type( +@pytest.mark.asyncio +async def test_should_get_flag_value_based_on_method_type( flag_type, default_value, get_method, no_op_provider_client ): # Given # When - flag = getattr(no_op_provider_client, get_method)( - flag_key="Key", default_value=default_value - ) + method = getattr(no_op_provider_client, get_method) + if asyncio.iscoroutinefunction(method): + flag = await method(flag_key="Key", default_value=default_value) + else: + flag = method(flag_key="Key", default_value=default_value) # Then assert flag is not None assert flag == default_value @@ -115,19 +136,24 @@ async def test_should_get_flag_detail_based_on_method_type( assert isinstance(flag.value, flag_type) -def test_should_raise_exception_when_invalid_flag_type_provided( +@pytest.mark.asyncio +async def test_should_raise_exception_when_invalid_flag_type_provided( no_op_provider_client, ): # Given # When - flag = no_op_provider_client.evaluate_flag_details( + flag_sync = no_op_provider_client.evaluate_flag_details( + flag_type=None, flag_key="Key", default_value=True + ) + flag_async = await no_op_provider_client.evaluate_flag_details_async( flag_type=None, flag_key="Key", default_value=True ) # Then - assert flag.value - assert flag.error_message == "Unknown flag type" - assert flag.error_code == ErrorCode.GENERAL - assert flag.reason == Reason.ERROR + for flag in [flag_sync, flag_async]: + assert flag.value + assert flag.error_message == "Unknown flag type" + assert flag.error_code == ErrorCode.GENERAL + assert flag.reason == Reason.ERROR def test_should_pass_flag_metadata_from_resolution_to_evaluation_details(): @@ -226,7 +252,8 @@ def test_should_define_a_provider_status_accessor(no_op_provider_client): # Requirement 1.7.6 -def test_should_shortcircuit_if_provider_is_not_ready( +@pytest.mark.asyncio +async def test_should_shortcircuit_if_provider_is_not_ready( no_op_provider_client, monkeypatch ): # Given @@ -236,19 +263,26 @@ def test_should_shortcircuit_if_provider_is_not_ready( spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) # When - flag_details = no_op_provider_client.get_boolean_details( + flag_details_sync = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await no_op_provider_client.get_boolean_details_async( flag_key="Key", default_value=True ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY spy_hook.error.assert_called_once() # Requirement 1.7.7 -def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( +@pytest.mark.asyncio +async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( no_op_provider_client, monkeypatch ): # Given @@ -258,41 +292,87 @@ def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) # When - flag_details = no_op_provider_client.get_boolean_details( + flag_details_sync = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await no_op_provider_client.get_boolean_details_async( flag_key="Key", default_value=True ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL spy_hook.error.assert_called_once() -def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): +@pytest.mark.asyncio +async def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): # Given spy_hook = MagicMock(spec=Hook) provider = MagicMock(spec=FeatureProvider) provider.get_provider_hooks.return_value = [] - provider.resolve_boolean_details.return_value = FlagResolutionDetails( + mock_resolution = FlagResolutionDetails( value=True, reason=Reason.ERROR, error_code=ErrorCode.PROVIDER_FATAL, error_message="This is an error message", ) + provider.resolve_boolean_details.return_value = mock_resolution + provider.resolve_boolean_details_async.return_value = mock_resolution set_provider(provider) client = get_client() client.add_hooks([spy_hook]) # When - flag_details = client.get_boolean_details(flag_key="Key", default_value=True) + flag_details_sync = client.get_boolean_details(flag_key="Key", default_value=True) + spy_hook.error.assert_called_once() + spy_hook.reset_mock() + flag_details_async = await client.get_boolean_details_async( + flag_key="Key", default_value=True + ) # Then - assert flag_details is not None - assert flag_details.value - assert flag_details.reason == Reason.ERROR - assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL spy_hook.error.assert_called_once() +@pytest.mark.asyncio +async def test_client_type_mismatch_exceptions(): + # Given + client = get_client() + # When + flag_details_sync = client.get_boolean_details( + flag_key="Key", default_value="type mismatch" + ) + flag_details_async = await client.get_boolean_details_async( + flag_key="Key", default_value="type mismatch" + ) + # Then + for flag_details in [flag_details_sync, flag_details_async]: + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.TYPE_MISMATCH + + +@pytest.mark.asyncio +async def test_client_general_exception(): + # Given + flag_value = "A" + flag_type = None + # When + with pytest.raises(GeneralError) as e: + flag_type = _typecheck_flag_value(flag_value, flag_type) + # Then + assert e.value.error_message == "Unknown flag type" + + def test_provider_events(): # Given provider = NoOpProvider()