diff --git a/src/main/java/dev/openfeature/sdk/FeatureProvider.java b/src/main/java/dev/openfeature/sdk/FeatureProvider.java index 22819ef10..694d5e14c 100644 --- a/src/main/java/dev/openfeature/sdk/FeatureProvider.java +++ b/src/main/java/dev/openfeature/sdk/FeatureProvider.java @@ -15,6 +15,23 @@ default List getProviderHooks() { return new ArrayList<>(); } + /** + * Returns all hooks that support the given flag value type. + * + * @param flagType the flag value type to support + * @return a list of hooks that support the given flag value type + */ + default List getProviderHooks(FlagValueType flagType) { + var allHooks = getProviderHooks(); + var filteredHooks = new ArrayList(allHooks.size()); + for (Hook hook : allHooks) { + if (hook.supportsFlagValueType(flagType)) { + filteredHooks.add(hook); + } + } + return filteredHooks; + } + ProviderEvaluation getBooleanEvaluation(String key, Boolean defaultValue, EvaluationContext ctx); ProviderEvaluation getStringEvaluation(String key, String defaultValue, EvaluationContext ctx); diff --git a/src/main/java/dev/openfeature/sdk/FlagEvaluationOptions.java b/src/main/java/dev/openfeature/sdk/FlagEvaluationOptions.java index f73bd9631..f17b85bb4 100644 --- a/src/main/java/dev/openfeature/sdk/FlagEvaluationOptions.java +++ b/src/main/java/dev/openfeature/sdk/FlagEvaluationOptions.java @@ -1,5 +1,7 @@ package dev.openfeature.sdk; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -19,4 +21,18 @@ public class FlagEvaluationOptions { @Builder.Default Map hookHints = new HashMap<>(); + + List getHooks(FlagValueType flagValueType) { + if (hooks == null || hooks.isEmpty()) { + return Collections.emptyList(); + } + + var result = new ArrayList(hooks.size()); + for (var hook : hooks) { + if (hook.supportsFlagValueType(flagValueType)) { + result.add(hook); + } + } + return result; + } } diff --git a/src/main/java/dev/openfeature/sdk/HookSupport.java b/src/main/java/dev/openfeature/sdk/HookSupport.java index 0254c07fd..2c80bf7ac 100644 --- a/src/main/java/dev/openfeature/sdk/HookSupport.java +++ b/src/main/java/dev/openfeature/sdk/HookSupport.java @@ -1,9 +1,12 @@ package dev.openfeature.sdk; import java.util.ArrayList; -import java.util.Collections; +import java.util.Collection; +import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentLinkedQueue; import lombok.extern.slf4j.Slf4j; /** @@ -17,20 +20,53 @@ class HookSupport { * Sets the {@link Hook}-{@link HookContext}-{@link Pair} list in the given data object with {@link HookContext} * set to null. Filters hooks by supported {@link FlagValueType}. * - * @param hookSupportData the data object to modify - * @param hooks the hooks to set - * @param type the flag value type to filter unsupported hooks + * @param hookSupportData the data object to modify + * @param providerHooks the hooks filtered for the proper flag value type from the respective layer + * @param flagOptionsHooks the hooks filtered for the proper flag value type from the respective layer + * @param clientHooks the hooks filtered for the proper flag value type from the respective layer + * @param apiHooks the hooks filtered for the proper flag value type from the respective layer */ - public void setHooks(HookSupportData hookSupportData, List hooks, FlagValueType type) { - List> hookContextPairs = new ArrayList<>(); - for (Hook hook : hooks) { - if (hook.supportsFlagValueType(type)) { - hookContextPairs.add(Pair.of(hook, null)); - } + public void setHooks( + HookSupportData hookSupportData, + List providerHooks, + List flagOptionsHooks, + ConcurrentLinkedQueue clientHooks, + ConcurrentLinkedQueue apiHooks) { + var lengthEstimate = 0; + + if (providerHooks != null) { + lengthEstimate += providerHooks.size(); + } + if (flagOptionsHooks != null) { + lengthEstimate += flagOptionsHooks.size(); + } + if (clientHooks != null) { + lengthEstimate += clientHooks.size(); } + if (apiHooks != null) { + lengthEstimate += apiHooks.size(); + } + + ArrayList> hookContextPairs = new ArrayList<>(lengthEstimate); + + addAll(hookContextPairs, providerHooks); + addAll(hookContextPairs, flagOptionsHooks); + addAll(hookContextPairs, clientHooks); + addAll(hookContextPairs, apiHooks); + hookSupportData.hooks = hookContextPairs; } + private void addAll(List> accumulator, Collection toAdd) { + if (toAdd == null || toAdd.isEmpty()) { + return; + } + + for (Hook hook : toAdd) { + accumulator.add(Pair.of(hook, null)); + } + } + /** * Creates & sets a {@link HookContext} for every {@link Hook}-{@link HookContext}-{@link Pair} * in the given data object with a new {@link HookData} instance. @@ -51,10 +87,9 @@ public void setHookContexts( public void executeBeforeHooks(HookSupportData data) { // These traverse backwards from normal. - List> reversedHooks = new ArrayList<>(data.getHooks()); - Collections.reverse(reversedHooks); - - for (Pair hookContextPair : reversedHooks) { + var hooks = data.getHooks(); + for (int i = hooks.size() - 1; i >= 0; i--) { + var hookContextPair = hooks.get(i); var hook = hookContextPair.getKey(); var hookContext = hookContextPair.getValue(); @@ -111,4 +146,26 @@ public void executeAfterAllHooks(HookSupportData data, FlagEvaluationDetails } } } + + static void addHooks(Map> hookMap, Hook... hooksToAdd) { + var types = FlagValueType.values(); + for (int i = 0; i < hooksToAdd.length; i++) { + var current = hooksToAdd[i]; + for (int j = 0; j < types.length; j++) { + var type = types[j]; + if (current.supportsFlagValueType(type)) { + hookMap.get(type).add(current); + } + } + } + } + + static ArrayList getAllUniqueHooks(Map> hookMap) { + // Hooks can be duplicated if they support multiple FlagValueTypes + var allHooks = new HashSet(); + for (var queue : hookMap.values()) { + allHooks.addAll(queue); + } + return new ArrayList<>(allHooks); + } } diff --git a/src/main/java/dev/openfeature/sdk/HookSupportData.java b/src/main/java/dev/openfeature/sdk/HookSupportData.java index 174702ea2..c9c8dfc7d 100644 --- a/src/main/java/dev/openfeature/sdk/HookSupportData.java +++ b/src/main/java/dev/openfeature/sdk/HookSupportData.java @@ -1,6 +1,6 @@ package dev.openfeature.sdk; -import java.util.List; +import java.util.ArrayList; import java.util.Map; import lombok.Getter; @@ -10,7 +10,7 @@ @Getter class HookSupportData { - List> hooks; + ArrayList> hooks; LayeredEvaluationContext evaluationContext; Map hints; diff --git a/src/main/java/dev/openfeature/sdk/LayeredEvaluationContext.java b/src/main/java/dev/openfeature/sdk/LayeredEvaluationContext.java index a58d82685..7a00d7b02 100644 --- a/src/main/java/dev/openfeature/sdk/LayeredEvaluationContext.java +++ b/src/main/java/dev/openfeature/sdk/LayeredEvaluationContext.java @@ -51,6 +51,10 @@ public LayeredEvaluationContext( } } + public static LayeredEvaluationContext empty() { + return new LayeredEvaluationContext(null, null, null, null); + } + @Override public String getTargetingKey() { return targetingKey; diff --git a/src/main/java/dev/openfeature/sdk/OpenFeatureAPI.java b/src/main/java/dev/openfeature/sdk/OpenFeatureAPI.java index 6d0d8feb4..14ec034be 100644 --- a/src/main/java/dev/openfeature/sdk/OpenFeatureAPI.java +++ b/src/main/java/dev/openfeature/sdk/OpenFeatureAPI.java @@ -3,12 +3,10 @@ import dev.openfeature.sdk.exceptions.OpenFeatureError; import dev.openfeature.sdk.internal.AutoCloseableLock; import dev.openfeature.sdk.internal.AutoCloseableReentrantReadWriteLock; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -24,14 +22,18 @@ public class OpenFeatureAPI implements EventBus { // package-private multi-read/single-write lock static AutoCloseableReentrantReadWriteLock lock = new AutoCloseableReentrantReadWriteLock(); - private final ConcurrentLinkedQueue apiHooks; + private final ConcurrentHashMap> apiHooks; private ProviderRepository providerRepository; private EventSupport eventSupport; private final AtomicReference evaluationContext = new AtomicReference<>(); private TransactionContextPropagator transactionContextPropagator; protected OpenFeatureAPI() { - apiHooks = new ConcurrentLinkedQueue<>(); + var values = FlagValueType.values(); + apiHooks = new ConcurrentHashMap<>(values.length); + for (FlagValueType value : values) { + apiHooks.put(value, new ConcurrentLinkedQueue<>()); + } providerRepository = new ProviderRepository(this); eventSupport = new EventSupport(); transactionContextPropagator = new NoOpTransactionContextPropagator(); @@ -304,7 +306,7 @@ public FeatureProvider getProvider(String domain) { * @param hooks The hook to add. */ public void addHooks(Hook... hooks) { - this.apiHooks.addAll(Arrays.asList(hooks)); + HookSupport.addHooks(apiHooks, hooks); } /** @@ -313,16 +315,16 @@ public void addHooks(Hook... hooks) { * @return A list of {@link Hook}s. */ public List getHooks() { - return new ArrayList<>(this.apiHooks); + return HookSupport.getAllUniqueHooks(apiHooks); } /** - * Returns a reference to the collection of {@link Hook}s. + * Fetch the hooks associated to this client, that support the given FlagValueType. * - * @return The collection of {@link Hook}s. + * @return A list of {@link Hook}s. */ - Collection getMutableHooks() { - return this.apiHooks; + ConcurrentLinkedQueue getHooks(FlagValueType type) { + return apiHooks.get(type); } /** diff --git a/src/main/java/dev/openfeature/sdk/OpenFeatureClient.java b/src/main/java/dev/openfeature/sdk/OpenFeatureClient.java index 0d5d0e643..06fb2b435 100644 --- a/src/main/java/dev/openfeature/sdk/OpenFeatureClient.java +++ b/src/main/java/dev/openfeature/sdk/OpenFeatureClient.java @@ -5,15 +5,13 @@ import dev.openfeature.sdk.exceptions.GeneralError; import dev.openfeature.sdk.exceptions.OpenFeatureError; import dev.openfeature.sdk.exceptions.ProviderNotReadyError; -import dev.openfeature.sdk.internal.ObjectUtils; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -47,7 +45,7 @@ public class OpenFeatureClient implements Client { @Getter private final String version; - private final ConcurrentLinkedQueue clientHooks; + private final ConcurrentHashMap> clientHooks; private final AtomicReference evaluationContext = new AtomicReference<>(); private final HookSupport hookSupport; @@ -69,7 +67,11 @@ public OpenFeatureClient(OpenFeatureAPI openFeatureAPI, String domain, String ve this.domain = domain; this.version = version; this.hookSupport = new HookSupport(); - this.clientHooks = new ConcurrentLinkedQueue<>(); + var values = FlagValueType.values(); + this.clientHooks = new ConcurrentHashMap<>(values.length); + for (FlagValueType value : values) { + this.clientHooks.put(value, new ConcurrentLinkedQueue<>()); + } } /** @@ -125,7 +127,7 @@ public void track(String trackingEventName, EvaluationContext context, TrackingE */ @Override public OpenFeatureClient addHooks(Hook... hooks) { - this.clientHooks.addAll(Arrays.asList(hooks)); + HookSupport.addHooks(clientHooks, hooks); return this; } @@ -134,7 +136,7 @@ public OpenFeatureClient addHooks(Hook... hooks) { */ @Override public List getHooks() { - return new ArrayList<>(this.clientHooks); + return HookSupport.getAllUniqueHooks(clientHooks); } /** @@ -185,9 +187,12 @@ private FlagEvaluationDetails evaluateFlag( final var state = stateManager.getState(); // Hooks are initialized as early as possible to enable the execution of error stages - var mergedHooks = ObjectUtils.merge( - provider.getProviderHooks(), flagOptions.getHooks(), clientHooks, openfeatureApi.getMutableHooks()); - hookSupport.setHooks(hookSupportData, mergedHooks, type); + hookSupport.setHooks( + hookSupportData, + provider.getProviderHooks(type), + flagOptions.getHooks(type), + clientHooks.get(type), + openfeatureApi.getHooks(type)); var sharedHookContext = new SharedHookContext(key, type, this.getMetadata(), provider.getMetadata(), defaultValue); diff --git a/src/test/java/dev/openfeature/sdk/DeveloperExperienceTest.java b/src/test/java/dev/openfeature/sdk/DeveloperExperienceTest.java index 19108bde5..03e851566 100644 --- a/src/test/java/dev/openfeature/sdk/DeveloperExperienceTest.java +++ b/src/test/java/dev/openfeature/sdk/DeveloperExperienceTest.java @@ -3,6 +3,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -14,6 +15,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -101,9 +103,10 @@ void brokenProvider() { void providerLockedPerTransaction() { final String defaultValue = "string-value"; - final OpenFeatureAPI api = new OpenFeatureAPI(); - var provider1 = TestProvider.builder().initsToReady(); - var provider2 = TestProvider.builder().initsToReady(); + final OpenFeatureAPI testApi = new OpenFeatureAPI(); + final var provider1 = TestProvider.builder().initsToReady(); + final var provider2 = TestProvider.builder().initsToReady(); + final var wasHookCalled = new AtomicBoolean(false); class MutatingHook implements Hook { @@ -112,24 +115,27 @@ class MutatingHook implements Hook { // change the provider during a before hook - this should not impact the evaluation in progress public Optional before(HookContext ctx, Map hints) { - api.setProviderAndWait(provider2); - + testApi.setProviderAndWait(provider2); + wasHookCalled.set(true); return Optional.empty(); } } - final Client client = api.getClient(); - api.setProviderAndWait(provider1); - api.addHooks(new MutatingHook()); + final Client client = testApi.getClient(); + testApi.setProviderAndWait(provider1); + testApi.addHooks(new MutatingHook()); // if provider is changed during an evaluation transaction it should proceed with the original provider client.getStringValue("val", defaultValue); assertEquals(1, provider1.getFlagEvaluations().size()); + assertEquals(0, provider2.getFlagEvaluations().size()); + assertTrue(wasHookCalled.get()); - api.clearHooks(); + testApi.clearHooks(); // subsequent evaluations should now use new provider set by hook client.getStringValue("val", defaultValue); + assertEquals(1, provider1.getFlagEvaluations().size()); assertEquals(1, provider2.getFlagEvaluations().size()); } diff --git a/src/test/java/dev/openfeature/sdk/FlagEvaluationSpecTest.java b/src/test/java/dev/openfeature/sdk/FlagEvaluationSpecTest.java index 82aa4e3cc..67dded945 100644 --- a/src/test/java/dev/openfeature/sdk/FlagEvaluationSpecTest.java +++ b/src/test/java/dev/openfeature/sdk/FlagEvaluationSpecTest.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import dev.openfeature.sdk.e2e.Flag; import dev.openfeature.sdk.exceptions.GeneralError; @@ -143,6 +144,8 @@ void provider_metadata() { void hook_addition() { Hook h1 = mock(Hook.class); Hook h2 = mock(Hook.class); + when(h1.supportsFlagValueType(any())).thenReturn(true); + when(h2.supportsFlagValueType(any())).thenReturn(true); api.addHooks(h1); assertEquals(1, api.getHooks().size()); @@ -150,7 +153,7 @@ void hook_addition() { api.addHooks(h2); assertEquals(2, api.getHooks().size()); - assertEquals(h2, api.getHooks().get(1)); + assertTrue(api.getHooks().contains(h2)); } @Specification( @@ -175,6 +178,8 @@ void hookRegistration() { Client c = _client(); Hook m1 = mock(Hook.class); Hook m2 = mock(Hook.class); + when(m1.supportsFlagValueType(any())).thenReturn(true); + when(m2.supportsFlagValueType(any())).thenReturn(true); c.addHooks(m1); c.addHooks(m2); List hooks = c.getHooks(); diff --git a/src/test/java/dev/openfeature/sdk/HookSupportTest.java b/src/test/java/dev/openfeature/sdk/HookSupportTest.java index 3b21aff84..2b3dceab7 100644 --- a/src/test/java/dev/openfeature/sdk/HookSupportTest.java +++ b/src/test/java/dev/openfeature/sdk/HookSupportTest.java @@ -8,11 +8,12 @@ import static org.mockito.Mockito.when; import dev.openfeature.sdk.fixtures.HookFixtures; -import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentLinkedQueue; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -38,7 +39,12 @@ void shouldMergeEvaluationContextsOnBeforeHooksCorrectly() { var sharedContext = getBaseHookContextForType(FlagValueType.STRING); var hookSupportData = new HookSupportData(); hookSupportData.evaluationContext = layered; - hookSupport.setHooks(hookSupportData, Arrays.asList(hook1, hook2), FlagValueType.STRING); + hookSupport.setHooks( + hookSupportData, + List.of(hook1, hook2), + Collections.emptyList(), + new ConcurrentLinkedQueue<>(), + new ConcurrentLinkedQueue<>()); hookSupport.setHookContexts(hookSupportData, sharedContext, layered); hookSupport.executeBeforeHooks(hookSupportData); @@ -50,14 +56,18 @@ void shouldMergeEvaluationContextsOnBeforeHooksCorrectly() { assertThat(result.getValue("baseKey").asString()).isEqualTo("baseValue"); } - @ParameterizedTest - @EnumSource(value = FlagValueType.class) + @Test @DisplayName("should always call generic hook") - void shouldAlwaysCallGenericHook(FlagValueType flagValueType) { + void shouldAlwaysCallGenericHook() { Hook genericHook = mockGenericHook(); var hookSupportData = new HookSupportData(); - hookSupport.setHooks(hookSupportData, List.of(genericHook), flagValueType); + hookSupport.setHooks( + hookSupportData, + List.of(genericHook), + Collections.emptyList(), + new ConcurrentLinkedQueue<>(), + new ConcurrentLinkedQueue<>()); callAllHooks(hookSupportData); @@ -73,11 +83,14 @@ void shouldAlwaysCallGenericHook(FlagValueType flagValueType) { void shouldPassDataAcrossStages(FlagValueType flagValueType) { var testHook = new TestHookWithData(); var hookSupportData = new HookSupportData(); - hookSupport.setHooks(hookSupportData, List.of(testHook), flagValueType); - hookSupport.setHookContexts( + hookSupport.setHooks( hookSupportData, - getBaseHookContextForType(flagValueType), - new LayeredEvaluationContext(null, null, null, null)); + List.of(testHook), + Collections.emptyList(), + new ConcurrentLinkedQueue<>(), + new ConcurrentLinkedQueue<>()); + hookSupport.setHookContexts( + hookSupportData, getBaseHookContextForType(flagValueType), LayeredEvaluationContext.empty()); hookSupport.executeBeforeHooks(hookSupportData); assertHookData(testHook, "before"); @@ -102,11 +115,14 @@ void shouldIsolateDataBetweenHooks(FlagValueType flagValueType) { var testHook2 = new TestHookWithData(2); var hookSupportData = new HookSupportData(); - hookSupport.setHooks(hookSupportData, List.of(testHook1, testHook2), flagValueType); - hookSupport.setHookContexts( + hookSupport.setHooks( hookSupportData, - getBaseHookContextForType(flagValueType), - new LayeredEvaluationContext(null, null, null, null)); + List.of(testHook1, testHook2), + Collections.emptyList(), + new ConcurrentLinkedQueue<>(), + new ConcurrentLinkedQueue<>()); + hookSupport.setHookContexts( + hookSupportData, getBaseHookContextForType(flagValueType), LayeredEvaluationContext.empty()); callAllHooks(hookSupportData); @@ -132,7 +148,12 @@ public Optional before(HookContext ctx, Map hints) { var layeredEvaluationContext = new LayeredEvaluationContext(evaluationContextWithValue("key", "value"), null, null, null); hookSupportData.evaluationContext = layeredEvaluationContext; - hookSupport.setHooks(hookSupportData, List.of(recursiveHook, emptyHook), FlagValueType.STRING); + hookSupport.setHooks( + hookSupportData, + List.of(recursiveHook, emptyHook), + Collections.emptyList(), + new ConcurrentLinkedQueue<>(), + new ConcurrentLinkedQueue<>()); hookSupport.setHookContexts( hookSupportData, getBaseHookContextForType(FlagValueType.STRING), layeredEvaluationContext); diff --git a/src/test/java/dev/openfeature/sdk/LayeredEvaluationContextTest.java b/src/test/java/dev/openfeature/sdk/LayeredEvaluationContextTest.java index edbea81d5..b9ec1e939 100644 --- a/src/test/java/dev/openfeature/sdk/LayeredEvaluationContextTest.java +++ b/src/test/java/dev/openfeature/sdk/LayeredEvaluationContextTest.java @@ -26,7 +26,7 @@ class LayeredEvaluationContextTest { @Test void creatingLayeredContextWithNullsWorks() { - LayeredEvaluationContext layeredContext = new LayeredEvaluationContext(null, null, null, null); + LayeredEvaluationContext layeredContext = LayeredEvaluationContext.empty(); assertNotNull(layeredContext); assertNull(layeredContext.getTargetingKey()); assertEquals(Map.of(), layeredContext.asMap()); @@ -38,7 +38,7 @@ void creatingLayeredContextWithNullsWorks() { @Test void addingNullHookWorks() { - LayeredEvaluationContext layeredContext = new LayeredEvaluationContext(null, null, null, null); + LayeredEvaluationContext layeredContext = LayeredEvaluationContext.empty(); assertDoesNotThrow(() -> layeredContext.putHookContext(null)); } @@ -205,7 +205,7 @@ void mapIsGeneratedCorrectly() { @Test void emptyContextGeneratesEmptyMap() { - LayeredEvaluationContext layeredContext = new LayeredEvaluationContext(null, null, null, null); + LayeredEvaluationContext layeredContext = LayeredEvaluationContext.empty(); assertEquals(Map.of(), layeredContext.asMap()); assertEquals(Map.of(), layeredContext.asUnmodifiableMap()); assertEquals(Map.of(), layeredContext.asObjectMap()); @@ -251,7 +251,7 @@ void mapIsGeneratedCorrectly() { @Test void creatingMapWithCachedEmptyKeySetWorks() { - LayeredEvaluationContext layeredContext = new LayeredEvaluationContext(null, null, null, null); + LayeredEvaluationContext layeredContext = LayeredEvaluationContext.empty(); assertNotNull(layeredContext.keySet()); assertEquals(Map.of(), layeredContext.asObjectMap()); } @@ -337,7 +337,7 @@ void mutatingObjectMapHasNoSideEffects() { class IsEmpty { @Test void isEmptyWhenAllContextsAreNull() { - LayeredEvaluationContext layeredContext = new LayeredEvaluationContext(null, null, null, null); + LayeredEvaluationContext layeredContext = LayeredEvaluationContext.empty(); assertTrue(layeredContext.isEmpty()); } @@ -389,14 +389,14 @@ void isNotEmptyWhenInvocationAndClientAndTransactionAndApiContextIsSet() { @Test void isNotEmptyWhenHookContextIsSet() { - LayeredEvaluationContext layeredContext = new LayeredEvaluationContext(null, null, null, null); + LayeredEvaluationContext layeredContext = LayeredEvaluationContext.empty(); layeredContext.putHookContext(hookContext); assertFalse(layeredContext.isEmpty()); } @Test void isEmptyIfHookContextIsEmpty() { - LayeredEvaluationContext layeredContext = new LayeredEvaluationContext(null, null, null, null); + LayeredEvaluationContext layeredContext = LayeredEvaluationContext.empty(); layeredContext.putHookContext(new MutableContext()); assertTrue(layeredContext.isEmpty()); } diff --git a/src/test/java/dev/openfeature/sdk/OpenFeatureClientTest.java b/src/test/java/dev/openfeature/sdk/OpenFeatureClientTest.java index 31937ec2d..c8d0cced8 100644 --- a/src/test/java/dev/openfeature/sdk/OpenFeatureClientTest.java +++ b/src/test/java/dev/openfeature/sdk/OpenFeatureClientTest.java @@ -2,6 +2,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -12,6 +13,7 @@ import dev.openfeature.sdk.fixtures.HookFixtures; import dev.openfeature.sdk.testutils.testProvider.TestProvider; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; import org.junit.jupiter.api.AfterEach; @@ -19,6 +21,7 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mockito; import org.simplify4u.slf4jmock.LoggerMock; @@ -130,11 +133,158 @@ void shouldSupportUsageOfHookData(boolean isError) { assertThat(testHook.hookData.get("before")).isEqualTo("test-data"); assertThat(testHook.hookData.get("finallyAfter")).isEqualTo("test-data"); if (isError) { - assertThat(testHook.hookData.get("after")).isEqualTo(null); + assertThat(testHook.hookData.get("after")).isNull(); assertThat(testHook.hookData.get("error")).isEqualTo("test-data"); } else { assertThat(testHook.hookData.get("after")).isEqualTo("test-data"); - assertThat(testHook.hookData.get("error")).isEqualTo(null); + assertThat(testHook.hookData.get("error")).isNull(); + } + } + + @ParameterizedTest + @EnumSource(FlagValueType.class) + @DisplayName("Should call hooks that support the flag value type") + void shouldExecuteAppropriateHooks(FlagValueType flagValueType) { + var allTypes = FlagValueType.values(); + var apiHooks = new TypedTestHook[allTypes.length]; + var clientHooks = new TypedTestHook[allTypes.length]; + var providerHooks = new TypedTestHook[allTypes.length]; + var evaluationHooks = new TypedTestHook[allTypes.length]; + for (int i = 0; i < allTypes.length; i++) { + apiHooks[i] = new TypedTestHook(allTypes[i]); + clientHooks[i] = new TypedTestHook(allTypes[i]); + providerHooks[i] = new TypedTestHook(allTypes[i]); + evaluationHooks[i] = new TypedTestHook(allTypes[i]); + } + var allHooks = new TypedTestHook[][] {apiHooks, clientHooks, providerHooks, evaluationHooks}; + + OpenFeatureAPI api = new OpenFeatureAPI(); + var provider = TestProvider.builder() + .withHooks(providerHooks) + .allowUnknownFlags() + .initsToReady(); + api.setProviderAndWait(provider); + + Client client = api.getClient(); + + api.addHooks(apiHooks); + client.addHooks(clientHooks); + + var options = + FlagEvaluationOptions.builder().hooks(List.of(evaluationHooks)).build(); + + if (flagValueType == FlagValueType.BOOLEAN) { + client.getBooleanDetails("key", true, ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.STRING) { + client.getStringDetails("key", "default", ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.INTEGER) { + client.getIntegerDetails("key", 42, ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.DOUBLE) { + client.getDoubleValue("key", 3.14, ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.OBJECT) { + client.getObjectDetails("key", new Value(1), ImmutableContext.EMPTY, options); + } + + for (TypedTestHook[] level : allHooks) { + for (TypedTestHook hook : level) { + assertEquals( + flagValueType == hook.flagValueType, + hook.beforeCalled.get(), + () -> hook.flagValueType + + " hook called? " + + hook.beforeCalled.get() + + ", should have been called? " + + (flagValueType == hook.flagValueType)); + assertEquals( + flagValueType == hook.flagValueType, + hook.afterCalled.get(), + () -> hook.flagValueType + + " hook called? " + + hook.afterCalled.get() + + ", should have been called? " + + (flagValueType == hook.flagValueType)); + assertEquals( + flagValueType == hook.flagValueType, + hook.finallyAfterCalled.get(), + () -> hook.flagValueType + + " hook called? " + + hook.finallyAfterCalled.get() + + ", should have been called? " + + (flagValueType == hook.flagValueType)); + assertFalse(hook.errorCalled.get()); + } + } + } + + @ParameterizedTest + @EnumSource(FlagValueType.class) + @DisplayName("Should call hooks that support the flag value type in error scenarios") + void shouldExecuteAppropriateErrorHooks(FlagValueType flagValueType) { + var allTypes = FlagValueType.values(); + var apiHooks = new TypedTestHook[allTypes.length]; + var clientHooks = new TypedTestHook[allTypes.length]; + var providerHooks = new TypedTestHook[allTypes.length]; + var evaluationHooks = new TypedTestHook[allTypes.length]; + for (int i = 0; i < allTypes.length; i++) { + apiHooks[i] = new TypedTestHook(allTypes[i]); + clientHooks[i] = new TypedTestHook(allTypes[i]); + providerHooks[i] = new TypedTestHook(allTypes[i]); + evaluationHooks[i] = new TypedTestHook(allTypes[i]); + } + var allHooks = new TypedTestHook[][] {apiHooks, clientHooks, providerHooks, evaluationHooks}; + + OpenFeatureAPI api = new OpenFeatureAPI(); + var provider = TestProvider.builder().withHooks(providerHooks).initsToReady(); + api.setProviderAndWait(provider); + + Client client = api.getClient(); + + api.addHooks(apiHooks); + client.addHooks(clientHooks); + + var options = + FlagEvaluationOptions.builder().hooks(List.of(evaluationHooks)).build(); + + if (flagValueType == FlagValueType.BOOLEAN) { + client.getBooleanDetails("key", true, ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.STRING) { + client.getStringDetails("key", "default", ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.INTEGER) { + client.getIntegerDetails("key", 42, ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.DOUBLE) { + client.getDoubleValue("key", 3.14, ImmutableContext.EMPTY, options); + } else if (flagValueType == FlagValueType.OBJECT) { + client.getObjectDetails("key", new Value(1), ImmutableContext.EMPTY, options); + } + + for (TypedTestHook[] level : allHooks) { + for (TypedTestHook hook : level) { + assertEquals( + flagValueType == hook.flagValueType, + hook.beforeCalled.get(), + () -> hook.flagValueType + + " hook called? " + + hook.beforeCalled.get() + + ", should have been called? " + + (flagValueType == hook.flagValueType)); + assertEquals( + flagValueType == hook.flagValueType, + hook.errorCalled.get(), + () -> hook.flagValueType + + " hook called? " + + hook.errorCalled.get() + + ", should have been called? " + + (flagValueType == hook.flagValueType)); + assertEquals( + flagValueType == hook.flagValueType, + hook.finallyAfterCalled.get(), + () -> hook.flagValueType + + " hook called? " + + hook.finallyAfterCalled.get() + + ", should have been called? " + + (flagValueType == hook.flagValueType)); + assertFalse(hook.afterCalled.get()); + } } } diff --git a/src/test/java/dev/openfeature/sdk/TypedTestHook.java b/src/test/java/dev/openfeature/sdk/TypedTestHook.java new file mode 100644 index 000000000..96ffed637 --- /dev/null +++ b/src/test/java/dev/openfeature/sdk/TypedTestHook.java @@ -0,0 +1,43 @@ +package dev.openfeature.sdk; + +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; + +public class TypedTestHook implements Hook { + public final FlagValueType flagValueType; + public final AtomicBoolean beforeCalled = new AtomicBoolean(false); + public final AtomicBoolean afterCalled = new AtomicBoolean(false); + public final AtomicBoolean errorCalled = new AtomicBoolean(false); + public final AtomicBoolean finallyAfterCalled = new AtomicBoolean(false); + + public TypedTestHook(FlagValueType flagValueType) { + this.flagValueType = flagValueType; + } + + @Override + public boolean supportsFlagValueType(FlagValueType flagValueType) { + return this.flagValueType == flagValueType; + } + + @Override + public Optional before(HookContext ctx, Map hints) { + beforeCalled.set(true); + return Optional.empty(); + } + + @Override + public void after(HookContext ctx, FlagEvaluationDetails details, Map hints) { + afterCalled.set(true); + } + + @Override + public void error(HookContext ctx, Exception error, Map hints) { + errorCalled.set(true); + } + + @Override + public void finallyAfter(HookContext ctx, FlagEvaluationDetails details, Map hints) { + finallyAfterCalled.set(true); + } +} diff --git a/src/test/java/dev/openfeature/sdk/vmlens/VmLensCT.java b/src/test/java/dev/openfeature/sdk/vmlens/VmLensCT.java index c09e254e6..bb9e070a5 100644 --- a/src/test/java/dev/openfeature/sdk/vmlens/VmLensCT.java +++ b/src/test/java/dev/openfeature/sdk/vmlens/VmLensCT.java @@ -6,6 +6,10 @@ import com.vmlens.api.AllInterleavings; import com.vmlens.api.Runner; +import dev.openfeature.sdk.Client; +import dev.openfeature.sdk.EvaluationContext; +import dev.openfeature.sdk.Hook; +import dev.openfeature.sdk.HookContext; import dev.openfeature.sdk.ImmutableContext; import dev.openfeature.sdk.OpenFeatureAPI; import dev.openfeature.sdk.OpenFeatureAPITestUtil; @@ -13,20 +17,26 @@ import dev.openfeature.sdk.providers.memory.Flag; import dev.openfeature.sdk.providers.memory.InMemoryProvider; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Optional; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; class VmLensCT { - final OpenFeatureAPI api = OpenFeatureAPITestUtil.createAPI(); + private OpenFeatureAPI api; + private Client client; @BeforeEach void setUp() { + api = OpenFeatureAPITestUtil.createAPI(); var flags = new HashMap>(); flags.put("a", Flag.builder().variant("a", "def").defaultVariant("a").build()); flags.put("b", Flag.builder().variant("a", "as").defaultVariant("a").build()); api.setProviderAndWait(new InMemoryProvider(flags)); + client = api.getClient(); } @AfterEach @@ -48,7 +58,6 @@ void concurrentClientCreations() { @Test void concurrentFlagEvaluations() { - var client = api.getClient(); try (AllInterleavings allInterleavings = new AllInterleavings("Concurrent evaluations")) { while (allInterleavings.hasNext()) { Runner.runParallel( @@ -58,19 +67,76 @@ void concurrentFlagEvaluations() { } } - @Test - void concurrentContextSetting() { - var client = api.getClient(); - var contextA = new ImmutableContext(Map.of("a", new Value("b"))); - var contextB = new ImmutableContext(Map.of("c", new Value("d"))); - try (AllInterleavings allInterleavings = - new AllInterleavings("Concurrently setting the context and evaluating a flag")) { - while (allInterleavings.hasNext()) { - Runner.runParallel( - () -> assertEquals("def", client.getStringValue("a", "a")), - () -> client.setEvaluationContext(contextA), - () -> client.setEvaluationContext(contextB)); - assertThat(client.getEvaluationContext()).isIn(contextA, contextB); + @Nested + class ConcurrentContext { + private final ImmutableContext contextA = new ImmutableContext(Map.of("a", new Value("b"))); + private final ImmutableContext contextB = new ImmutableContext(Map.of("c", new Value("d"))); + + @Test + void concurrentContextSetting() { + try (AllInterleavings allInterleavings = + new AllInterleavings("Concurrently setting the context and evaluating a flag")) { + while (allInterleavings.hasNext()) { + Runner.runParallel( + () -> assertEquals("def", client.getStringValue("a", "a")), + () -> client.setEvaluationContext(contextA), + () -> client.setEvaluationContext(contextB)); + assertThat(client.getEvaluationContext()).isIn(contextA, contextB); + } + } + } + } + + @Nested + class ConcurrentHooks { + private final Hook hook0 = new Hook<>() {}; + private final Hook hook1 = new Hook<>() { + @Override + public Optional before(HookContext ctx, Map hints) { + return Optional.of(new ImmutableContext(Map.of("c", new Value("d")))); + } + }; + + @Test + void concurrentAdditionOfHooksToClient() { + try (AllInterleavings allInterleavings = + new AllInterleavings("Concurrently adding client hooks and evaluating a flag")) { + while (allInterleavings.hasNext()) { + Runner.runParallel( + () -> assertEquals("def", client.getStringValue("a", "a")), + () -> client.addHooks(hook0), + () -> client.addHooks(hook1)); + assertThat(client.getHooks()).containsAll(List.of(hook0, hook1)); + } + } + } + + @Test + void concurrentAdditionOfHooksToApi() { + try (AllInterleavings allInterleavings = + new AllInterleavings("Concurrently adding api hooks and evaluating a flag")) { + while (allInterleavings.hasNext()) { + Runner.runParallel( + () -> assertEquals("def", client.getStringValue("a", "a")), + () -> api.addHooks(hook0), + () -> api.addHooks(hook1)); + assertThat(api.getHooks()).containsAll(List.of(hook0, hook1)); + } + } + } + + @Test + void concurrentAdditionOfHooksToApiAndClient() { + try (AllInterleavings allInterleavings = + new AllInterleavings("Concurrently adding api and client hooks and evaluating a flag")) { + while (allInterleavings.hasNext()) { + Runner.runParallel( + () -> assertEquals("def", client.getStringValue("a", "a")), + () -> api.addHooks(hook0), + () -> client.addHooks(hook1)); + assertThat(api.getHooks()).containsAll(List.of(hook0)); + assertThat(client.getHooks()).containsAll(List.of(hook1)); + } } } }