diff --git a/src/NATS.Client.Core/Internal/SubscriptionManager.cs b/src/NATS.Client.Core/Internal/SubscriptionManager.cs
index 0ad353a24..c61ad8343 100644
--- a/src/NATS.Client.Core/Internal/SubscriptionManager.cs
+++ b/src/NATS.Client.Core/Internal/SubscriptionManager.cs
@@ -2,6 +2,7 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.Logging;
+using NATS.Client.Core.Commands;
namespace NATS.Client.Core.Internal;
@@ -130,19 +131,32 @@ public ValueTask RemoveAsync(NatsSubBase sub)
return _connection.UnsubscribeAsync(subMetadata.Sid);
}
- public async ValueTask ReconnectAsync(CancellationToken cancellationToken)
+ ///
+ /// Returns commands for all the live subscriptions to be used on reconnect so that they can rebuild their connection state on the server.
+ ///
+ ///
+ /// Commands returned form all the subscriptions will be run as a priority right after reconnection is established.
+ ///
+ /// Enumerable list of commands
+ public IEnumerable GetReconnectCommands()
{
- foreach (var (sid, sidMetadata) in _bySid)
+ var subs = new List<(NatsSubBase, int)>();
+ lock (_gate)
{
- if (sidMetadata.WeakReference.TryGetTarget(out var sub))
+ foreach (var (sid, sidMetadata) in _bySid)
{
- // yield return (sid, sub.Subject, sub.QueueGroup, sub.PendingMsgs);
- await _connection
- .SubscribeCoreAsync(sid, sub.Subject, sub.QueueGroup, sub.PendingMsgs, cancellationToken)
- .ConfigureAwait(false);
- await sub.ReadyAsync().ConfigureAwait(false);
+ if (sidMetadata.WeakReference.TryGetTarget(out var sub))
+ {
+ subs.Add((sub, sid));
+ }
}
}
+
+ foreach (var (sub, sid) in subs)
+ {
+ foreach (var command in sub.GetReconnectCommands(sid))
+ yield return command;
+ }
}
public ISubscriptionManager GetManagerFor(string subject)
diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs
index 81b71f8b7..d557565b7 100644
--- a/src/NATS.Client.Core/NatsConnection.cs
+++ b/src/NATS.Client.Core/NatsConnection.cs
@@ -94,6 +94,8 @@ public NatsConnection(NatsOptions options)
internal string InboxPrefix { get; }
+ internal ObjectPool ObjectPool => _pool;
+
///
/// Connect socket and write CONNECT command to nats server.
///
@@ -385,6 +387,12 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
_writerState.PriorityCommands.Add(connectCommand);
_writerState.PriorityCommands.Add(PingCommand.Create(_pool, GetCancellationTimer(CancellationToken.None)));
+ if (reconnect)
+ {
+ // Reestablish subscriptions and consumers
+ _writerState.PriorityCommands.AddRange(SubscriptionManager.GetReconnectCommands());
+ }
+
// create the socket writer
_socketWriter = new NatsPipeliningWriteProtocolProcessor(_socket!, _writerState, _pool, Counter);
@@ -393,9 +401,6 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
// receive COMMAND response (PONG or ERROR)
await waitForPongOrErrorSignal.Task.ConfigureAwait(false);
-
- // Reestablish subscriptions and consumers
- await SubscriptionManager.ReconnectAsync(_disposedCancellationTokenSource.Token).ConfigureAwait(false);
}
catch (Exception)
{
diff --git a/src/NATS.Client.Core/NatsSubBase.cs b/src/NATS.Client.Core/NatsSubBase.cs
index 8124c3d3e..0cdcd319f 100644
--- a/src/NATS.Client.Core/NatsSubBase.cs
+++ b/src/NATS.Client.Core/NatsSubBase.cs
@@ -1,5 +1,6 @@
using System.Buffers;
using System.Runtime.ExceptionServices;
+using NATS.Client.Core.Commands;
using NATS.Client.Core.Internal;
namespace NATS.Client.Core;
@@ -187,6 +188,21 @@ public virtual async ValueTask ReceiveAsync(string subject, string? replyTo, Rea
internal void ClearException() => Interlocked.Exchange(ref _exception, null);
+ ///
+ /// Collect commands when reconnecting.
+ ///
+ ///
+ /// By default this will yield the required subscription command.
+ /// When overriden base must be called to yield the re-subscription command.
+ /// Additional command (e.g. publishing pull requests in case of JetStream consumers) can be yielded as part of the reconnect routine.
+ ///
+ /// SID which might be required to create subscription commands
+ /// IEnumerable list of commands
+ internal virtual IEnumerable GetReconnectCommands(int sid)
+ {
+ yield return AsyncSubscribeCommand.Create(Connection.ObjectPool, Connection.GetCancellationTimer(default), sid, Subject, QueueGroup, PendingMsgs);
+ }
+
///
/// Invoked when a MSG or HMSG arrives for the subscription.
///
diff --git a/tests/NATS.Client.Core.Tests/ProtocolTest.cs b/tests/NATS.Client.Core.Tests/ProtocolTest.cs
index 43edfaf78..dc194ca6a 100644
--- a/tests/NATS.Client.Core.Tests/ProtocolTest.cs
+++ b/tests/NATS.Client.Core.Tests/ProtocolTest.cs
@@ -1,3 +1,6 @@
+using System.Buffers;
+using System.Text;
+
namespace NATS.Client.Core.Tests;
public class ProtocolTest
@@ -289,4 +292,79 @@ await Retry.Until(
await reg;
}
}
+
+ [Fact]
+ public async Task Reconnect_with_sub_and_additional_commands()
+ {
+ await using var server = NatsServer.Start();
+ var (nats, proxy) = server.CreateProxiedClientConnection();
+
+ const string subject = "foo";
+
+ var sync = 0;
+ await using var sub = new NatsSubReconnectTest(nats, subject, i => Interlocked.Exchange(ref sync, i));
+ await nats.SubAsync(sub.Subject, opts: default, sub);
+
+ await Retry.Until(
+ "subscribed",
+ () => Volatile.Read(ref sync) == 1,
+ async () => await nats.PublishAsync(subject, 1));
+
+ var disconnected = new WaitSignal();
+ nats.ConnectionDisconnected += (_, _) => disconnected.Pulse();
+
+ proxy.Reset();
+
+ await disconnected;
+
+ await Retry.Until(
+ "re-subscribed",
+ () => Volatile.Read(ref sync) == 2,
+ async () => await nats.PublishAsync(subject, 2));
+
+ await Retry.Until(
+ "frames collected",
+ () => proxy.ClientFrames.Any(f => f.Message.StartsWith("PUB foo")));
+
+ var frames = proxy.ClientFrames.Select(f => f.Message).ToList();
+ Assert.StartsWith("SUB foo", frames[0]);
+ Assert.StartsWith("PUB bar1", frames[1]);
+ Assert.StartsWith("PUB bar2", frames[2]);
+ Assert.StartsWith("PUB bar3", frames[3]);
+ Assert.StartsWith("PUB foo", frames[4]);
+
+ await nats.DisposeAsync();
+ }
+
+ private sealed class NatsSubReconnectTest : NatsSubBase
+ {
+ private readonly Action _callback;
+
+ internal NatsSubReconnectTest(NatsConnection connection, string subject, Action callback)
+ : base(connection, connection.SubscriptionManager, subject, default) =>
+ _callback = callback;
+
+ internal override IEnumerable GetReconnectCommands(int sid)
+ {
+ // Yield re-subscription
+ foreach (var command in base.GetReconnectCommands(sid))
+ yield return command;
+
+ // Any additional commands to send on reconnect
+ yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar1", default, default, default, default);
+ yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar2", default, default, default, default);
+ yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar3", default, default, default, default);
+ }
+
+ protected override ValueTask ReceiveInternalAsync(string subject, string? replyTo, ReadOnlySequence? headersBuffer, ReadOnlySequence payloadBuffer)
+ {
+ _callback(int.Parse(Encoding.UTF8.GetString(payloadBuffer)));
+ DecrementMaxMsgs();
+ return ValueTask.CompletedTask;
+ }
+
+ protected override void TryComplete()
+ {
+ }
+ }
}