Skip to content

Commit

Permalink
Reconnect NatsSub commands (#112)
Browse files Browse the repository at this point in the history
* Reconnect NatsSub commands

Collect all subscription and additional commands to be prioritized
when reconnecting. Additional commands will be important for
JetStream consumers to issue pull requests in order to reestablish
their state on the server.

* Fixed test and warnings
  • Loading branch information
mtmk authored Aug 14, 2023
1 parent 31797cf commit 55af9b1
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 11 deletions.
30 changes: 22 additions & 8 deletions src/NATS.Client.Core/Internal/SubscriptionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.Logging;
using NATS.Client.Core.Commands;

namespace NATS.Client.Core.Internal;

Expand Down Expand Up @@ -130,19 +131,32 @@ public ValueTask RemoveAsync(NatsSubBase sub)
return _connection.UnsubscribeAsync(subMetadata.Sid);
}

public async ValueTask ReconnectAsync(CancellationToken cancellationToken)
/// <summary>
/// Returns commands for all the live subscriptions to be used on reconnect so that they can rebuild their connection state on the server.
/// </summary>
/// <remarks>
/// Commands returned form all the subscriptions will be run as a priority right after reconnection is established.
/// </remarks>
/// <returns>Enumerable list of commands</returns>
public IEnumerable<ICommand> GetReconnectCommands()
{
foreach (var (sid, sidMetadata) in _bySid)
var subs = new List<(NatsSubBase, int)>();
lock (_gate)
{
if (sidMetadata.WeakReference.TryGetTarget(out var sub))
foreach (var (sid, sidMetadata) in _bySid)
{
// yield return (sid, sub.Subject, sub.QueueGroup, sub.PendingMsgs);
await _connection
.SubscribeCoreAsync(sid, sub.Subject, sub.QueueGroup, sub.PendingMsgs, cancellationToken)
.ConfigureAwait(false);
await sub.ReadyAsync().ConfigureAwait(false);
if (sidMetadata.WeakReference.TryGetTarget(out var sub))
{
subs.Add((sub, sid));
}
}
}

foreach (var (sub, sid) in subs)
{
foreach (var command in sub.GetReconnectCommands(sid))
yield return command;
}
}

public ISubscriptionManager GetManagerFor(string subject)
Expand Down
11 changes: 8 additions & 3 deletions src/NATS.Client.Core/NatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public NatsConnection(NatsOptions options)

internal string InboxPrefix { get; }

internal ObjectPool ObjectPool => _pool;

/// <summary>
/// Connect socket and write CONNECT command to nats server.
/// </summary>
Expand Down Expand Up @@ -385,6 +387,12 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
_writerState.PriorityCommands.Add(connectCommand);
_writerState.PriorityCommands.Add(PingCommand.Create(_pool, GetCancellationTimer(CancellationToken.None)));

if (reconnect)
{
// Reestablish subscriptions and consumers
_writerState.PriorityCommands.AddRange(SubscriptionManager.GetReconnectCommands());
}

// create the socket writer
_socketWriter = new NatsPipeliningWriteProtocolProcessor(_socket!, _writerState, _pool, Counter);

Expand All @@ -393,9 +401,6 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)

// receive COMMAND response (PONG or ERROR)
await waitForPongOrErrorSignal.Task.ConfigureAwait(false);

// Reestablish subscriptions and consumers
await SubscriptionManager.ReconnectAsync(_disposedCancellationTokenSource.Token).ConfigureAwait(false);
}
catch (Exception)
{
Expand Down
16 changes: 16 additions & 0 deletions src/NATS.Client.Core/NatsSubBase.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Buffers;
using System.Runtime.ExceptionServices;
using NATS.Client.Core.Commands;
using NATS.Client.Core.Internal;

namespace NATS.Client.Core;
Expand Down Expand Up @@ -187,6 +188,21 @@ public virtual async ValueTask ReceiveAsync(string subject, string? replyTo, Rea

internal void ClearException() => Interlocked.Exchange(ref _exception, null);

/// <summary>
/// Collect commands when reconnecting.
/// </summary>
/// <remarks>
/// By default this will yield the required subscription command.
/// When overriden base must be called to yield the re-subscription command.
/// Additional command (e.g. publishing pull requests in case of JetStream consumers) can be yielded as part of the reconnect routine.
/// </remarks>
/// <param name="sid">SID which might be required to create subscription commands</param>
/// <returns>IEnumerable list of commands</returns>
internal virtual IEnumerable<ICommand> GetReconnectCommands(int sid)
{
yield return AsyncSubscribeCommand.Create(Connection.ObjectPool, Connection.GetCancellationTimer(default), sid, Subject, QueueGroup, PendingMsgs);
}

/// <summary>
/// Invoked when a MSG or HMSG arrives for the subscription.
/// <remarks>
Expand Down
78 changes: 78 additions & 0 deletions tests/NATS.Client.Core.Tests/ProtocolTest.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using System.Buffers;
using System.Text;

namespace NATS.Client.Core.Tests;

public class ProtocolTest
Expand Down Expand Up @@ -289,4 +292,79 @@ await Retry.Until(
await reg;
}
}

[Fact]
public async Task Reconnect_with_sub_and_additional_commands()
{
await using var server = NatsServer.Start();
var (nats, proxy) = server.CreateProxiedClientConnection();

const string subject = "foo";

var sync = 0;
await using var sub = new NatsSubReconnectTest(nats, subject, i => Interlocked.Exchange(ref sync, i));
await nats.SubAsync(sub.Subject, opts: default, sub);

await Retry.Until(
"subscribed",
() => Volatile.Read(ref sync) == 1,
async () => await nats.PublishAsync(subject, 1));

var disconnected = new WaitSignal();
nats.ConnectionDisconnected += (_, _) => disconnected.Pulse();

proxy.Reset();

await disconnected;

await Retry.Until(
"re-subscribed",
() => Volatile.Read(ref sync) == 2,
async () => await nats.PublishAsync(subject, 2));

await Retry.Until(
"frames collected",
() => proxy.ClientFrames.Any(f => f.Message.StartsWith("PUB foo")));

var frames = proxy.ClientFrames.Select(f => f.Message).ToList();
Assert.StartsWith("SUB foo", frames[0]);
Assert.StartsWith("PUB bar1", frames[1]);
Assert.StartsWith("PUB bar2", frames[2]);
Assert.StartsWith("PUB bar3", frames[3]);
Assert.StartsWith("PUB foo", frames[4]);

await nats.DisposeAsync();
}

private sealed class NatsSubReconnectTest : NatsSubBase
{
private readonly Action<int> _callback;

internal NatsSubReconnectTest(NatsConnection connection, string subject, Action<int> callback)
: base(connection, connection.SubscriptionManager, subject, default) =>
_callback = callback;

internal override IEnumerable<ICommand> GetReconnectCommands(int sid)
{
// Yield re-subscription
foreach (var command in base.GetReconnectCommands(sid))
yield return command;

// Any additional commands to send on reconnect
yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar1", default, default, default, default);
yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar2", default, default, default, default);
yield return PublishBytesCommand.Create(Connection.ObjectPool, "bar3", default, default, default, default);
}

protected override ValueTask ReceiveInternalAsync(string subject, string? replyTo, ReadOnlySequence<byte>? headersBuffer, ReadOnlySequence<byte> payloadBuffer)
{
_callback(int.Parse(Encoding.UTF8.GetString(payloadBuffer)));
DecrementMaxMsgs();
return ValueTask.CompletedTask;
}

protected override void TryComplete()
{
}
}
}

0 comments on commit 55af9b1

Please sign in to comment.