diff --git a/src/NATS.Client.Core/Internal/SubscriptionManager.cs b/src/NATS.Client.Core/Internal/SubscriptionManager.cs index 0ad353a24..c61ad8343 100644 --- a/src/NATS.Client.Core/Internal/SubscriptionManager.cs +++ b/src/NATS.Client.Core/Internal/SubscriptionManager.cs @@ -2,6 +2,7 @@ using System.Collections.Concurrent; using System.Runtime.CompilerServices; using Microsoft.Extensions.Logging; +using NATS.Client.Core.Commands; namespace NATS.Client.Core.Internal; @@ -130,19 +131,32 @@ public ValueTask RemoveAsync(NatsSubBase sub) return _connection.UnsubscribeAsync(subMetadata.Sid); } - public async ValueTask ReconnectAsync(CancellationToken cancellationToken) + /// + /// Returns commands for all the live subscriptions to be used on reconnect so that they can rebuild their connection state on the server. + /// + /// + /// Commands returned form all the subscriptions will be run as a priority right after reconnection is established. + /// + /// Enumerable list of commands + public IEnumerable GetReconnectCommands() { - foreach (var (sid, sidMetadata) in _bySid) + var subs = new List<(NatsSubBase, int)>(); + lock (_gate) { - if (sidMetadata.WeakReference.TryGetTarget(out var sub)) + foreach (var (sid, sidMetadata) in _bySid) { - // yield return (sid, sub.Subject, sub.QueueGroup, sub.PendingMsgs); - await _connection - .SubscribeCoreAsync(sid, sub.Subject, sub.QueueGroup, sub.PendingMsgs, cancellationToken) - .ConfigureAwait(false); - await sub.ReadyAsync().ConfigureAwait(false); + if (sidMetadata.WeakReference.TryGetTarget(out var sub)) + { + subs.Add((sub, sid)); + } } } + + foreach (var (sub, sid) in subs) + { + foreach (var command in sub.GetReconnectCommands(sid)) + yield return command; + } } public ISubscriptionManager GetManagerFor(string subject) diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs index 81b71f8b7..d557565b7 100644 --- a/src/NATS.Client.Core/NatsConnection.cs +++ b/src/NATS.Client.Core/NatsConnection.cs @@ -94,6 +94,8 @@ public NatsConnection(NatsOptions options) internal string InboxPrefix { get; } + internal ObjectPool ObjectPool => _pool; + /// /// Connect socket and write CONNECT command to nats server. /// @@ -385,6 +387,12 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect) _writerState.PriorityCommands.Add(connectCommand); _writerState.PriorityCommands.Add(PingCommand.Create(_pool, GetCancellationTimer(CancellationToken.None))); + if (reconnect) + { + // Reestablish subscriptions and consumers + _writerState.PriorityCommands.AddRange(SubscriptionManager.GetReconnectCommands()); + } + // create the socket writer _socketWriter = new NatsPipeliningWriteProtocolProcessor(_socket!, _writerState, _pool, Counter); @@ -393,9 +401,6 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect) // receive COMMAND response (PONG or ERROR) await waitForPongOrErrorSignal.Task.ConfigureAwait(false); - - // Reestablish subscriptions and consumers - await SubscriptionManager.ReconnectAsync(_disposedCancellationTokenSource.Token).ConfigureAwait(false); } catch (Exception) { diff --git a/src/NATS.Client.Core/NatsSubBase.cs b/src/NATS.Client.Core/NatsSubBase.cs index 8124c3d3e..0cdcd319f 100644 --- a/src/NATS.Client.Core/NatsSubBase.cs +++ b/src/NATS.Client.Core/NatsSubBase.cs @@ -1,5 +1,6 @@ using System.Buffers; using System.Runtime.ExceptionServices; +using NATS.Client.Core.Commands; using NATS.Client.Core.Internal; namespace NATS.Client.Core; @@ -187,6 +188,21 @@ public virtual async ValueTask ReceiveAsync(string subject, string? replyTo, Rea internal void ClearException() => Interlocked.Exchange(ref _exception, null); + /// + /// Collect commands when reconnecting. + /// + /// + /// By default this will yield the required subscription command. + /// When overriden base must be called to yield the re-subscription command. + /// Additional command (e.g. publishing pull requests in case of JetStream consumers) can be yielded as part of the reconnect routine. + /// + /// SID which might be required to create subscription commands + /// IEnumerable list of commands + internal virtual IEnumerable GetReconnectCommands(int sid) + { + yield return AsyncSubscribeCommand.Create(Connection.ObjectPool, Connection.GetCancellationTimer(default), sid, Subject, QueueGroup, PendingMsgs); + } + /// /// Invoked when a MSG or HMSG arrives for the subscription. /// diff --git a/tests/NATS.Client.Core.Tests/ProtocolTest.cs b/tests/NATS.Client.Core.Tests/ProtocolTest.cs index 43edfaf78..dc194ca6a 100644 --- a/tests/NATS.Client.Core.Tests/ProtocolTest.cs +++ b/tests/NATS.Client.Core.Tests/ProtocolTest.cs @@ -1,3 +1,6 @@ +using System.Buffers; +using System.Text; + namespace NATS.Client.Core.Tests; public class ProtocolTest @@ -289,4 +292,79 @@ await Retry.Until( await reg; } } + + [Fact] + public async Task Reconnect_with_sub_and_additional_commands() + { + await using var server = NatsServer.Start(); + var (nats, proxy) = server.CreateProxiedClientConnection(); + + const string subject = "foo"; + + var sync = 0; + await using var sub = new NatsSubReconnectTest(nats, subject, i => Interlocked.Exchange(ref sync, i)); + await nats.SubAsync(sub.Subject, opts: default, sub); + + await Retry.Until( + "subscribed", + () => Volatile.Read(ref sync) == 1, + async () => await nats.PublishAsync(subject, 1)); + + var disconnected = new WaitSignal(); + nats.ConnectionDisconnected += (_, _) => disconnected.Pulse(); + + proxy.Reset(); + + await disconnected; + + await Retry.Until( + "re-subscribed", + () => Volatile.Read(ref sync) == 2, + async () => await nats.PublishAsync(subject, 2)); + + await Retry.Until( + "frames collected", + () => proxy.ClientFrames.Any(f => f.Message.StartsWith("PUB foo"))); + + var frames = proxy.ClientFrames.Select(f => f.Message).ToList(); + Assert.StartsWith("SUB foo", frames[0]); + Assert.StartsWith("PUB bar1", frames[1]); + Assert.StartsWith("PUB bar2", frames[2]); + Assert.StartsWith("PUB bar3", frames[3]); + Assert.StartsWith("PUB foo", frames[4]); + + await nats.DisposeAsync(); + } + + private sealed class NatsSubReconnectTest : NatsSubBase + { + private readonly Action _callback; + + internal NatsSubReconnectTest(NatsConnection connection, string subject, Action callback) + : base(connection, connection.SubscriptionManager, subject, default) => + _callback = callback; + + internal override IEnumerable GetReconnectCommands(int sid) + { + // Yield re-subscription + foreach (var command in base.GetReconnectCommands(sid)) + yield return command; + + // Any additional commands to send on reconnect + yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar1", default, default, default, default); + yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar2", default, default, default, default); + yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar3", default, default, default, default); + } + + protected override ValueTask ReceiveInternalAsync(string subject, string? replyTo, ReadOnlySequence? headersBuffer, ReadOnlySequence payloadBuffer) + { + _callback(int.Parse(Encoding.UTF8.GetString(payloadBuffer))); + DecrementMaxMsgs(); + return ValueTask.CompletedTask; + } + + protected override void TryComplete() + { + } + } }