diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java index 417826437..1a3e1d352 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java @@ -20,6 +20,7 @@ public final class Config { static final int DEFAULT_MAX_CACHE_SIZE = 1000; static final int DEFAULT_OFFLINE_POLL_MS = 5000; static final long DEFAULT_KEEP_ALIVE = 0; + static final String DEFAULT_REINITIALIZE_ON_ERROR = "false"; static final String RESOLVER_ENV_VAR = "FLAGD_RESOLVER"; static final String HOST_ENV_VAR_NAME = "FLAGD_HOST"; @@ -51,6 +52,7 @@ public final class Config { static final String KEEP_ALIVE_MS_ENV_VAR_NAME = "FLAGD_KEEP_ALIVE_TIME_MS"; static final String TARGET_URI_ENV_VAR_NAME = "FLAGD_TARGET_URI"; static final String STREAM_RETRY_GRACE_PERIOD = "FLAGD_RETRY_GRACE_PERIOD"; + static final String REINITIALIZE_ON_ERROR_ENV_VAR_NAME = "FLAGD_REINITIALIZE_ON_ERROR"; static final String RESOLVER_RPC = "rpc"; static final String RESOLVER_IN_PROCESS = "in-process"; diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java index 17e86e6d1..4cda34df4 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java @@ -204,6 +204,16 @@ public class FlagdOptions { @Builder.Default private String defaultAuthority = fallBackToEnvOrDefault(Config.DEFAULT_AUTHORITY_ENV_VAR_NAME, null); + /** + * !EXPERIMENTAL! + * Whether to reinitialize the channel (TCP connection) after the grace period is exceeded. + * This can help recover from connection issues by creating fresh connections. + * Particularly useful for troubleshooting network issues related to proxies or service meshes. + */ + @Builder.Default + private boolean reinitializeOnError = Boolean.parseBoolean( + fallBackToEnvOrDefault(Config.REINITIALIZE_ON_ERROR_ENV_VAR_NAME, Config.DEFAULT_REINITIALIZE_ON_ERROR)); + /** * Builder overwrite in order to customize the "build" method. * diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java index e54c938cf..49e61e847 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java @@ -41,6 +41,7 @@ public class InProcessResolver implements Resolver { private final Consumer onConnectionEvent; private final Operator operator; private final String scope; + private final QueueSource queueSource; /** * Resolves flag values using @@ -52,7 +53,8 @@ public class InProcessResolver implements Resolver { * connection/stream */ public InProcessResolver(FlagdOptions options, Consumer onConnectionEvent) { - this.flagStore = new FlagStore(getConnector(options, onConnectionEvent)); + this.queueSource = getQueueSource(options, onConnectionEvent); + this.flagStore = new FlagStore(queueSource); this.onConnectionEvent = onConnectionEvent; this.operator = new Operator(); this.scope = options.getSelector(); @@ -94,6 +96,19 @@ public void init() throws Exception { stateWatcher.start(); } + /** + * Called when the provider enters error state after grace period. + * Attempts to reinitialize the sync connector if enabled. + */ + @Override + public void onError() { + if (queueSource instanceof SyncStreamQueueSource) { + SyncStreamQueueSource syncConnector = (SyncStreamQueueSource) queueSource; + // only reinitialize if option is enabled + syncConnector.reinitializeChannelComponents(); + } + } + /** * Shutdown in-process resolver. * @@ -147,7 +162,7 @@ public ProviderEvaluation objectEvaluation(String key, Value defaultValue .build(); } - static QueueSource getConnector(final FlagdOptions options, Consumer onConnectionEvent) { + static QueueSource getQueueSource(final FlagdOptions options, Consumer onConnectionEvent) { if (options.getCustomConnector() != null) { return options.getCustomConnector(); } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java index a3b01f913..8fa12c245 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSource.java @@ -28,12 +28,13 @@ import lombok.extern.slf4j.Slf4j; /** - * Implements the {@link QueueSource} contract and emit flags obtained from flagd sync gRPC contract. + * Implements the {@link QueueSource} contract and emit flags obtained from + * flagd sync gRPC contract. */ @Slf4j @SuppressFBWarnings( value = {"EI_EXPOSE_REP"}, - justification = "Random is used to generate a variation & flag configurations require exposing") + justification = "We need to expose the BlockingQueue to allow consumers to read from it") public class SyncStreamQueueSource implements QueueSource { private static final int QUEUE_SIZE = 5; @@ -45,13 +46,32 @@ public class SyncStreamQueueSource implements QueueSource { private final String selector; private final String providerId; private final boolean syncMetadataDisabled; - private final ChannelConnector channelConnector; + private final boolean reinitializeOnError; + private final FlagdOptions options; + private final Consumer onConnectionEvent; private final BlockingQueue outgoingQueue = new LinkedBlockingQueue<>(QUEUE_SIZE); - private final FlagSyncServiceStub flagSyncStub; - private final FlagSyncServiceBlockingStub metadataStub; + private volatile GrpcComponents grpcComponents; /** - * Creates a new SyncStreamQueueSource responsible for observing the event stream. + * Container for gRPC components to ensure atomicity during reinitialization. + * All three components are updated together to prevent consumers from seeing + * an inconsistent state where components are from different channel instances. + */ + private static class GrpcComponents { + final ChannelConnector channelConnector; + final FlagSyncServiceStub flagSyncStub; + final FlagSyncServiceBlockingStub metadataStub; + + GrpcComponents(ChannelConnector connector, FlagSyncServiceStub stub, FlagSyncServiceBlockingStub blockingStub) { + this.channelConnector = connector; + this.flagSyncStub = stub; + this.metadataStub = blockingStub; + } + } + + /** + * Creates a new SyncStreamQueueSource responsible for observing the event + * stream. */ public SyncStreamQueueSource(final FlagdOptions options, Consumer onConnectionEvent) { streamDeadline = options.getStreamDeadlineMs(); @@ -60,11 +80,10 @@ public SyncStreamQueueSource(final FlagdOptions options, Consumer {}; + this.grpcComponents = new GrpcComponents(connectorMock, stubMock, blockingStubMock); + } + + /** Initialize channel connector and stubs. */ + private synchronized void initializeChannelComponents() { + ChannelConnector newConnector = + new ChannelConnector(options, onConnectionEvent, ChannelBuilder.nettyChannel(options)); + FlagSyncServiceStub newFlagSyncStub = + FlagSyncServiceGrpc.newStub(newConnector.getChannel()).withWaitForReady(); + FlagSyncServiceBlockingStub newMetadataStub = + FlagSyncServiceGrpc.newBlockingStub(newConnector.getChannel()).withWaitForReady(); + + // atomic assignment of all components as a single unit + grpcComponents = new GrpcComponents(newConnector, newFlagSyncStub, newMetadataStub); + } + + /** Reinitialize channel connector and stubs on error. */ + public synchronized void reinitializeChannelComponents() { + if (!reinitializeOnError || shutdown.get()) { + return; + } + + log.info("Reinitializing channel gRPC components in attempt to restore stream."); + GrpcComponents oldComponents = grpcComponents; + + try { + // create new channel components first + initializeChannelComponents(); + } catch (Exception e) { + log.error("Failed to reinitialize channel components", e); + return; + } + + // shutdown old connector after successful reinitialization + if (oldComponents != null && oldComponents.channelConnector != null) { + try { + oldComponents.channelConnector.shutdown(); + } catch (Exception e) { + log.debug("Error shutting down old channel connector during reinitialization", e); + } + } } /** Initialize sync stream connector. */ public void init() throws Exception { - channelConnector.initialize(); + grpcComponents.channelConnector.initialize(); Thread listener = new Thread(this::observeSyncStream); listener.setDaemon(true); listener.start(); @@ -109,7 +169,7 @@ public void shutdown() throws InterruptedException { log.debug("Shutdown already in progress or completed"); return; } - this.channelConnector.shutdown(); + grpcComponents.channelConnector.shutdown(); } /** Contains blocking calls, to be used concurrently. */ @@ -159,13 +219,14 @@ private void observeSyncStream() { log.info("Shutdown invoked, exiting event stream listener"); } - // TODO: remove the metadata call entirely after https://github.com/open-feature/flagd/issues/1584 + // TODO: remove the metadata call entirely after + // https://github.com/open-feature/flagd/issues/1584 private Struct getMetadata() { if (syncMetadataDisabled) { return null; } - FlagSyncServiceBlockingStub localStub = metadataStub; + FlagSyncServiceBlockingStub localStub = grpcComponents.metadataStub; if (deadline > 0) { localStub = localStub.withDeadlineAfter(deadline, TimeUnit.MILLISECONDS); @@ -180,7 +241,8 @@ private Struct getMetadata() { return null; } catch (StatusRuntimeException e) { - // In newer versions of flagd, metadata is part of the sync stream. If the method is unimplemented, we + // In newer versions of flagd, metadata is part of the sync stream. If the + // method is unimplemented, we // can ignore the error if (e.getStatus() != null && Status.Code.UNIMPLEMENTED.equals(e.getStatus().getCode())) { @@ -192,7 +254,7 @@ private Struct getMetadata() { } private void syncFlags(SyncStreamObserver streamObserver) { - FlagSyncServiceStub localStub = flagSyncStub; // don't mutate the stub + FlagSyncServiceStub localStub = grpcComponents.flagSyncStub; // don't mutate the stub if (streamDeadline > 0) { localStub = localStub.withDeadlineAfter(streamDeadline, TimeUnit.MILLISECONDS); } diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java index ffd898447..56ce4cc68 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java @@ -17,6 +17,9 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import dev.openfeature.contrib.providers.flagd.Config; import dev.openfeature.contrib.providers.flagd.FlagdOptions; @@ -51,6 +54,26 @@ import org.junit.jupiter.api.Test; class InProcessResolverTest { + @Test + void onError_delegatesToQueueSource() throws Exception { + // given + FlagdOptions options = FlagdOptions.builder().build(); // option value doesn't matter here + SyncStreamQueueSource mockConnector = mock(SyncStreamQueueSource.class); + InProcessResolver resolver = new InProcessResolver(options, e -> {}); + + // Inject mock connector + java.lang.reflect.Field queueSourceField = InProcessResolver.class.getDeclaredField("queueSource"); + queueSourceField.setAccessible(true); + queueSourceField.set(resolver, mockConnector); + + // when + resolver.onError(); + + // then + // InProcessResolver should always delegate to the queue source. + // The decision to re-initialize or not is handled within SyncStreamQueueSource. + verify(mockConnector, times(1)).reinitializeChannelComponents(); + } @Test public void connectorSetup() { @@ -70,9 +93,9 @@ public void connectorSetup() { .build(); // then - assertInstanceOf(SyncStreamQueueSource.class, InProcessResolver.getConnector(forGrpcOptions, e -> {})); - assertInstanceOf(FileQueueSource.class, InProcessResolver.getConnector(forOfflineOptions, e -> {})); - assertInstanceOf(MockConnector.class, InProcessResolver.getConnector(forCustomConnectorOptions, e -> {})); + assertInstanceOf(SyncStreamQueueSource.class, InProcessResolver.getQueueSource(forGrpcOptions, e -> {})); + assertInstanceOf(FileQueueSource.class, InProcessResolver.getQueueSource(forOfflineOptions, e -> {})); + assertInstanceOf(MockConnector.class, InProcessResolver.getQueueSource(forCustomConnectorOptions, e -> {})); } @Test diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java index 9116b8142..b3b999616 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/storage/connector/sync/SyncStreamQueueSourceTest.java @@ -33,6 +33,75 @@ import org.mockito.stubbing.Answer; class SyncStreamQueueSourceTest { + @Test + void reinitializeChannelComponents_reinitializesWhenEnabled() throws InterruptedException { + FlagdOptions options = FlagdOptions.builder().reinitializeOnError(true).build(); + ChannelConnector initialConnector = mock(ChannelConnector.class); + FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); + FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); + SyncStreamQueueSource queueSource = + new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); + + try { + // save reference to old GrpcComponents + Object oldComponents = getPrivateField(queueSource, "grpcComponents"); + queueSource.reinitializeChannelComponents(); + Object newComponents = getPrivateField(queueSource, "grpcComponents"); + // should have replaced grpcComponents + assertNotNull(newComponents); + org.junit.jupiter.api.Assertions.assertNotSame(oldComponents, newComponents); + } finally { + queueSource.shutdown(); + } + } + + @Test + void reinitializeChannelComponents_doesNothingWhenDisabled() throws InterruptedException { + FlagdOptions options = FlagdOptions.builder().reinitializeOnError(false).build(); + ChannelConnector initialConnector = mock(ChannelConnector.class); + FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); + FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); + SyncStreamQueueSource queueSource = + new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); + + try { + Object oldComponents = getPrivateField(queueSource, "grpcComponents"); + queueSource.reinitializeChannelComponents(); + Object newComponents = getPrivateField(queueSource, "grpcComponents"); + // should NOT have replaced grpcComponents + org.junit.jupiter.api.Assertions.assertSame(oldComponents, newComponents); + } finally { + queueSource.shutdown(); + } + } + + @Test + void reinitializeChannelComponents_doesNothingWhenShutdown() throws InterruptedException { + FlagdOptions options = FlagdOptions.builder().reinitializeOnError(true).build(); + ChannelConnector initialConnector = mock(ChannelConnector.class); + FlagSyncServiceStub initialStub = mock(FlagSyncServiceStub.class); + FlagSyncServiceBlockingStub initialBlockingStub = mock(FlagSyncServiceBlockingStub.class); + SyncStreamQueueSource queueSource = + new SyncStreamQueueSource(options, initialConnector, initialStub, initialBlockingStub); + + queueSource.shutdown(); + Object oldComponents = getPrivateField(queueSource, "grpcComponents"); + queueSource.reinitializeChannelComponents(); + Object newComponents = getPrivateField(queueSource, "grpcComponents"); + // should NOT have replaced grpcComponents + org.junit.jupiter.api.Assertions.assertSame(oldComponents, newComponents); + } + // helper to access private fields via reflection + private static Object getPrivateField(Object instance, String fieldName) { + try { + java.lang.reflect.Field field = instance.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(instance); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + private ChannelConnector mockConnector; private FlagSyncServiceBlockingStub blockingStub; private FlagSyncServiceStub stub; @@ -42,47 +111,54 @@ class SyncStreamQueueSourceTest { private CountDownLatch latch; // used to wait for observer to be initialized @BeforeEach + @SuppressWarnings("deprecation") public void setup() throws Exception { blockingStub = mock(FlagSyncServiceBlockingStub.class); when(blockingStub.withDeadlineAfter(anyLong(), any())).thenReturn(blockingStub); when(blockingStub.getMetadata(any())).thenReturn(GetMetadataResponse.getDefaultInstance()); mockConnector = mock(ChannelConnector.class); - doNothing().when(mockConnector).initialize(); // Mock the initialize method + doNothing().when(mockConnector).initialize(); // mock the initialize method stub = mock(FlagSyncServiceStub.class); when(stub.withDeadlineAfter(anyLong(), any())).thenReturn(stub); doAnswer((Answer) invocation -> { Object[] args = invocation.getArguments(); - observer = (StreamObserver) args[1]; + @SuppressWarnings("unchecked") + StreamObserver obs = (StreamObserver) args[1]; + observer = obs; latch.countDown(); return null; }) .when(stub) - .syncFlags(any(SyncFlagsRequest.class), any(StreamObserver.class)); // Mock the initialize + .syncFlags(any(SyncFlagsRequest.class), any()); // mock the initialize syncErrorStub = mock(FlagSyncServiceStub.class); when(syncErrorStub.withDeadlineAfter(anyLong(), any())).thenReturn(syncErrorStub); doAnswer((Answer) invocation -> { Object[] args = invocation.getArguments(); - observer = (StreamObserver) args[1]; + @SuppressWarnings("unchecked") + StreamObserver obs = (StreamObserver) args[1]; + observer = obs; latch.countDown(); throw new StatusRuntimeException(io.grpc.Status.NOT_FOUND); }) .when(syncErrorStub) - .syncFlags(any(SyncFlagsRequest.class), any(StreamObserver.class)); // Mock the initialize + .syncFlags(any(SyncFlagsRequest.class), any()); // mock the initialize asyncErrorStub = mock(FlagSyncServiceStub.class); when(asyncErrorStub.withDeadlineAfter(anyLong(), any())).thenReturn(asyncErrorStub); doAnswer((Answer) invocation -> { Object[] args = invocation.getArguments(); - observer = (StreamObserver) args[1]; + @SuppressWarnings("unchecked") + StreamObserver obs = (StreamObserver) args[1]; + observer = obs; latch.countDown(); - // Start a thread to call onError after a short delay + // start a thread to call onError after a short delay new Thread(() -> { try { - Thread.sleep(10); // Wait 100ms before calling onError + Thread.sleep(10); // wait 10ms before calling onError observer.onError(new StatusRuntimeException(io.grpc.Status.INTERNAL)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -93,7 +169,7 @@ public void setup() throws Exception { return null; }) .when(asyncErrorStub) - .syncFlags(any(SyncFlagsRequest.class), any(StreamObserver.class)); // Mock the initialize + .syncFlags(any(SyncFlagsRequest.class), any()); // mock the initialize } @Test @@ -114,7 +190,7 @@ void syncInitError_DoesNotBusyWait() throws Exception { QueuePayload payload = streamQueue.poll(1000, TimeUnit.MILLISECONDS); assertNotNull(payload); assertEquals(QueuePayloadType.ERROR, payload.getType()); - Thread.sleep(maxBackoffMs + (maxBackoffMs / 2)); // wait 1.5x our delay for reties + Thread.sleep(maxBackoffMs + (maxBackoffMs / 2)); // wait 1.5x our delay for retries // should have retried the stream (2 calls); initial + 1 retry // it's very important that the retry count is low, to confirm no busy-loop @@ -139,7 +215,7 @@ void asyncInitError_DoesNotBusyWait() throws Exception { QueuePayload payload = streamQueue.poll(1000, TimeUnit.MILLISECONDS); assertNotNull(payload); assertEquals(QueuePayloadType.ERROR, payload.getType()); - Thread.sleep(maxBackoffMs + (maxBackoffMs / 2)); // wait 1.5x our delay for reties + Thread.sleep(maxBackoffMs + (maxBackoffMs / 2)); // wait 1.5x our delay for retries // should have retried the stream (2 calls); initial + 1 retry // it's very important that the retry count is low, to confirm no busy-loop @@ -168,6 +244,7 @@ void onNextEnqueuesDataPayload() throws Exception { } @Test + @SuppressWarnings("deprecation") void onNextEnqueuesDataPayloadMetadataDisabled() throws Exception { // disable GetMetadata call SyncStreamQueueSource queueSource = new SyncStreamQueueSource(