Skip to content

Commit

Permalink
Send buffer fix (#380)
Browse files Browse the repository at this point in the history
* Send buffer fix

* Send half message failure fixed

* Fixing tests for Linux

* Fixing tests for Linux

* Partial send error tweak

* Partial send error tweak

* format fix

* test fix

* reverted fail count check

* More partial send error tweak

* rework partial send buffer counter (#381)

Signed-off-by: Caleb Lloyd <caleb@synadia.com>

* revert a few benign changes

Signed-off-by: Caleb Lloyd <caleb@synadia.com>

---------

Signed-off-by: Caleb Lloyd <caleb@synadia.com>
Co-authored-by: Caleb Lloyd <2414837+caleblloyd@users.noreply.github.com>
Co-authored-by: Caleb Lloyd <caleb@synadia.com>
  • Loading branch information
3 people authored Feb 8, 2024
1 parent bac5173 commit dd1efd3
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 69 deletions.
79 changes: 67 additions & 12 deletions src/NATS.Client.Core/Commands/CommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ internal sealed class CommandWriter : IAsyncDisposable
private readonly HeaderWriter _headerWriter;
private readonly Channel<int> _channelLock;
private readonly Channel<int> _channelSize;
private readonly CancellationTimerPool _ctPool;
private readonly PipeReader _pipeReader;
private readonly PipeWriter _pipeWriter;
private readonly SemaphoreSlim _semLock = new(1);
private readonly PartialSendFailureCounter _partialSendFailureCounter = new();
private ISocketConnection? _socketConnection;
private Task? _flushTask;
private Task? _readerLoopTask;
Expand Down Expand Up @@ -76,13 +76,6 @@ public CommandWriter(NatsConnection connection, ObjectPool pool, NatsOpts opts,
useSynchronizationContext: false));
_pipeReader = pipe.Reader;
_pipeWriter = pipe.Writer;

// We need a new ObjectPool here because of the root token (_cts.Token).
// When the root token is cancelled as this object is disposed, cancellation
// objects in the pooled CancellationTimer should not be reused since the
// root token would already be cancelled which means CancellationTimer tokens
// would always be in a cancelled state.
_ctPool = new CancellationTimerPool(new ObjectPool(opts.ObjectPoolSize), _cts.Token);
}

public void Reset(ISocketConnection socketConnection)
Expand All @@ -100,6 +93,7 @@ await ReaderLoopAsync(
_pipeReader,
_channelSize,
_consolidateMem,
_partialSendFailureCounter,
_ctsReader.Token)
.ConfigureAwait(false);
});
Expand Down Expand Up @@ -437,6 +431,7 @@ private static async Task ReaderLoopAsync(
PipeReader pipeReader,
Channel<int> channelSize,
Memory<byte> consolidateMem,
PartialSendFailureCounter partialSendFailureCounter,
CancellationToken cancellationToken)
{
try
Expand Down Expand Up @@ -520,7 +515,11 @@ private static async Task ReaderLoopAsync(
// only mark bytes as consumed if a full command was sent
if (totalSize > 0)
{
// mark totalSize bytes as consumed
consumed = buffer.GetPosition(totalSize);

// reset the partialSendFailureCounter, since a full command was consumed
partialSendFailureCounter.Reset();
}

// mark sent bytes as examined
Expand All @@ -533,6 +532,35 @@ private static async Task ReaderLoopAsync(
// throw if there was a send failure
if (sendEx != null)
{
if (pending > 0)
{
// A command was only partially sent before the send failed.
// If the same command is re-sent and fails again, it most likely means
// that the command is malformed and the nats-server is closing
// the connection with an error. In that case we discard the command,
// which is signalled by partialSendFailureCounter.Failed() returning true.
if (partialSendFailureCounter.Failed())
{
// throw away the rest of the partially sent command if it's in the buffer
if (buffer.Length >= pending)
{
consumed = buffer.GetPosition(pending);
examined = buffer.GetPosition(pending);
partialSendFailureCounter.Reset();
while (!channelSize.Reader.TryRead(out _))
{
// should never happen; channel sizes are written before flush is called
await channelSize.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}
}
}
else
{
// increment the counter
partialSendFailureCounter.Increment();
}
}

throw sendEx;
}
}
Expand All @@ -553,10 +581,6 @@ private static async Task ReaderLoopAsync(
{
// Expected during shutdown
}
catch (InvalidOperationException)
{
// We might still be using the previous pipe reader which might be completed already
}
catch (Exception e)
{
logger.LogError(NatsLogEvents.Buffer, e, "Unexpected error in send buffer reader loop");
Expand Down Expand Up @@ -820,4 +844,35 @@ private async ValueTask UnsubscribeStateMachineAsync(bool lockHeld, int sid, int
_semLock.Release();
}
}

/// <summary>
/// Thread-safe tally of send failures that interrupted a command part-way through.
/// The reader loop increments it on such a failure, resets it once a full command
/// has been consumed, and checks <see cref="Failed"/> to decide whether the
/// partially sent command should be discarded instead of re-sent.
/// </summary>
private class PartialSendFailureCounter
{
    // Number of recorded failures at which Failed() starts returning true.
    private const int MaxRetry = 1;

    // Only accessed through Volatile/Interlocked so no separate gate object is needed.
    private int _count;

    /// <summary>Returns true once the recorded failure count has reached <c>MaxRetry</c>.</summary>
    public bool Failed() => Volatile.Read(ref _count) >= MaxRetry;

    /// <summary>Records one more partial send failure.</summary>
    public void Increment() => Interlocked.Increment(ref _count);

    /// <summary>Clears the recorded failures.</summary>
    public void Reset() => Interlocked.Exchange(ref _count, 0);
}
}
4 changes: 3 additions & 1 deletion tests/NATS.Client.Core.Tests/ProtocolTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,11 @@ public async Task Protocol_parser_under_load(int size)
for (var i = 0; i < 3; i++)
{
await Task.Delay(1_000, cts.Token);
var subjectCount = counts.Count;
await server.RestartAsync();
Interlocked.Increment(ref r);
await Task.Delay(1_000, cts.Token);

await Retry.Until("subject count goes up", () => counts.Count > subjectCount);
}

foreach (var log in logger.Logs.Where(x => x.EventId == NatsLogEvents.Protocol && x.LogLevel == LogLevel.Error))
Expand Down
110 changes: 108 additions & 2 deletions tests/NATS.Client.Core.Tests/SendBufferTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Diagnostics;
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using NATS.Client.TestUtilities;

namespace NATS.Client.Core.Tests;
Expand All @@ -12,7 +14,10 @@ public class SendBufferTest
[Fact]
public async Task Send_cancel()
{
void Log(string m) => TmpFileLogger.Log(m);
// void Log(string m) => TmpFileLogger.Log(m);
void Log(string m)
{
}

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

Expand All @@ -26,7 +31,7 @@ public async Task Send_cancel()
}
},
Log,
cts.Token);
cancellationToken: cts.Token);

Log("__________________________________");

Expand Down Expand Up @@ -75,4 +80,105 @@ public async Task Send_cancel()
await tasks[i];
}
}

/// <summary>
/// Verifies that the connection recovers after the server drops the socket while a
/// large PUB is being sent: later publishes and pings must still succeed, and the
/// mock server must observe the exact PUB sequence asserted at the end of the test.
/// </summary>
[Fact]
public async Task Send_recover_half_sent()
{
// Swap in the commented-out line below to get file-based diagnostics while debugging.
// void Log(string m) => TmpFileLogger.Log(m);
void Log(string m)
{
}

// Hard upper bound for the whole test; every awaited call below observes this token.
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));

// Subjects of every PUB the mock server received, in arrival order.
// Guarded by lock (pubs) because the handler and the assertions both touch it.
List<string> pubs = new();
await using var server = new MockServer(
handler: (client, cmd) =>
{
if (cmd.Name == "PUB")
{
lock (pubs)
pubs.Add($"PUB {cmd.Subject}");
}
// A PUB on subject "close" makes the mock server close the client connection.
if (cmd is { Name: "PUB", Subject: "close" })
{
client.Close();
}
return Task.CompletedTask;
},
Log,
// Advertise an 8 MB max_payload so the large publish below is allowed.
info: $"{{\"max_payload\":{1024 * 1024 * 8}}}",
cancellationToken: cts.Token);

Log("__________________________________");

// Collect Error-level client logs; anything that is not a SocketException is also
// written to the test output so unexpected errors are visible.
var testLogger = new InMemoryTestLoggerFactory(LogLevel.Error, m =>
{
Log($"[NC] {m.Message}");
if (m.Exception is not SocketException)
_output.WriteLine($"ERROR: {m.Exception}");
});

await using var nats = new NatsConnection(new NatsOpts
{
Url = server.Url,
LoggerFactory = testLogger,
});

Log($"[C] connect {server.Url}");
await nats.ConnectAsync();

Log($"[C] ping");
var rtt = await nats.PingAsync(cts.Token);
Log($"[C] ping rtt={rtt}");

Log($"[C] publishing x1...");
await nats.PublishAsync("x1", "x", cancellationToken: cts.Token);

// we will close the connection in mock server when we receive subject "close"
Log($"[C] publishing close (8MB)...");
var pubTask = nats.PublishAsync("close", new byte[1024 * 1024 * 8], cancellationToken: cts.Token).AsTask();

await pubTask.WaitAsync(cts.Token);

// Ping until one succeeds: up to 10 attempts, waiting 10 * i ms after each
// OperationCanceledException; the exception from the 10th attempt is rethrown.
// NOTE(review): presumably pings are cancelled while the connection is being
// re-established after the server-side close — confirm against NatsConnection.
for (var i = 1; i <= 10; i++)
{
try
{
await nats.PingAsync(cts.Token);
break;
}
catch (OperationCanceledException)
{
if (i == 10)
throw;
await Task.Delay(10 * i, cts.Token);
}
}

Log($"[C] publishing x2...");
await nats.PublishAsync("x2", "x", cancellationToken: cts.Token);

// The ping round-trip acts as a flush, so x2 has reached the server before asserting.
Log($"[C] flush...");
await nats.PingAsync(cts.Token);

// Exactly two errors must have been logged, each carrying a SocketException whose
// code is ConnectionReset or Shutdown.
Assert.Equal(2, testLogger.Logs.Count);
foreach (var log in testLogger.Logs)
{
Assert.True(log.Exception is SocketException, "Socket exception expected");
var socketErrorCode = (log.Exception as SocketException)!.SocketErrorCode;
Assert.True(socketErrorCode is SocketError.ConnectionReset or SocketError.Shutdown, "Socket error code");
}

// The server must have seen "close" exactly twice, between x1 and x2.
// NOTE(review): this looks like one original send plus one re-send before the
// command is discarded (see PartialSendFailureCounter in CommandWriter) — confirm.
lock (pubs)
{
Assert.Equal(4, pubs.Count);
Assert.Equal("PUB x1", pubs[0]);
Assert.Equal("PUB close", pubs[1]);
Assert.Equal("PUB close", pubs[2]);
Assert.Equal("PUB x2", pubs[3]);
}
}
}
8 changes: 6 additions & 2 deletions tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

namespace NATS.Client.TestUtilities;

public class InMemoryTestLoggerFactory(LogLevel level) : ILoggerFactory
public class InMemoryTestLoggerFactory(LogLevel level, Action<InMemoryTestLoggerFactory.LogMessage>? logger = null) : ILoggerFactory
{
private readonly List<LogMessage> _messages = new();

Expand All @@ -28,7 +28,11 @@ public void Dispose()
/// <summary>
/// Records a log entry in the in-memory list and forwards it to the optional
/// <c>logger</c> callback supplied to the factory.
/// </summary>
/// <param name="categoryName">Category of the logger that produced the entry.</param>
/// <param name="logLevel">Severity of the entry.</param>
/// <param name="eventId">Event identifier of the entry.</param>
/// <param name="exception">Exception attached to the entry, if any.</param>
/// <param name="message">Formatted log message.</param>
private void Log(string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message)
{
    var logMessage = new LogMessage(categoryName, logLevel, eventId, exception, message);

    // Store the entry exactly once, and keep both the add and the callback inside
    // the lock so the callback observes entries in the same order they are stored.
    lock (_messages)
    {
        _messages.Add(logMessage);
        logger?.Invoke(logMessage);
    }
}

public record LogMessage(string Category, LogLevel LogLevel, EventId EventId, Exception? Exception, string Message);
Expand Down
Loading

0 comments on commit dd1efd3

Please sign in to comment.