diff --git a/src/NATS.Client.Core/Commands/CommandWriter.cs b/src/NATS.Client.Core/Commands/CommandWriter.cs index 501a761ae..6860e684e 100644 --- a/src/NATS.Client.Core/Commands/CommandWriter.cs +++ b/src/NATS.Client.Core/Commands/CommandWriter.cs @@ -40,10 +40,10 @@ internal sealed class CommandWriter : IAsyncDisposable private readonly HeaderWriter _headerWriter; private readonly Channel _channelLock; private readonly Channel _channelSize; - private readonly CancellationTimerPool _ctPool; private readonly PipeReader _pipeReader; private readonly PipeWriter _pipeWriter; private readonly SemaphoreSlim _semLock = new(1); + private readonly PartialSendFailureCounter _partialSendFailureCounter = new(); private ISocketConnection? _socketConnection; private Task? _flushTask; private Task? _readerLoopTask; @@ -76,13 +76,6 @@ public CommandWriter(NatsConnection connection, ObjectPool pool, NatsOpts opts, useSynchronizationContext: false)); _pipeReader = pipe.Reader; _pipeWriter = pipe.Writer; - - // We need a new ObjectPool here because of the root token (_cts.Token). - // When the root token is cancelled as this object is disposed, cancellation - // objects in the pooled CancellationTimer should not be reused since the - // root token would already be cancelled which means CancellationTimer tokens - // would always be in a cancelled state. - _ctPool = new CancellationTimerPool(new ObjectPool(opts.ObjectPoolSize), _cts.Token); } public void Reset(ISocketConnection socketConnection) @@ -100,6 +93,7 @@ await ReaderLoopAsync( _pipeReader, _channelSize, _consolidateMem, + _partialSendFailureCounter, _ctsReader.Token) .ConfigureAwait(false); }); @@ -437,6 +431,7 @@ private static async Task ReaderLoopAsync( PipeReader pipeReader, Channel channelSize, Memory consolidateMem, + PartialSendFailureCounter partialSendFailureCounter, CancellationToken cancellationToken) { try @@ -520,7 +515,11 @@ private static async Task ReaderLoopAsync( // only mark bytes as consumed if a full command was sent if (totalSize > 0) { + // mark totalSize bytes as consumed consumed = buffer.GetPosition(totalSize); + + // reset the partialSendFailureCounter, since a full command was consumed + partialSendFailureCounter.Reset(); } // mark sent bytes as examined @@ -533,6 +532,35 @@ private static async Task ReaderLoopAsync( // throw if there was a send failure if (sendEx != null) { + if (pending > 0) + { + // there was a partially sent command + // if this command is re-sent and fails again, it most likely means + // that the command is malformed and the nats-server is closing + // the connection with an error. we want to throw this command + // away if partialSendFailureCounter.Failed() returns true + if (partialSendFailureCounter.Failed()) + { + // throw away the rest of the partially sent command if it's in the buffer + if (buffer.Length >= pending) + { + consumed = buffer.GetPosition(pending); + examined = buffer.GetPosition(pending); + partialSendFailureCounter.Reset(); + while (!channelSize.Reader.TryRead(out _)) + { + // should never happen; channel sizes are written before flush is called + await channelSize.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false); + } + } + } + else + { + // increment the counter + partialSendFailureCounter.Increment(); + } + } + throw sendEx; } } @@ -553,10 +581,6 @@ private static async Task ReaderLoopAsync( { // Expected during shutdown } - catch (InvalidOperationException) - { - // We might still be using the previous pipe reader which might be completed already - } catch (Exception e) { logger.LogError(NatsLogEvents.Buffer, e, "Unexpected error in send buffer reader loop"); @@ -820,4 +844,35 @@ private async ValueTask UnsubscribeStateMachineAsync(bool lockHeld, int sid, int _semLock.Release(); } } + + private class PartialSendFailureCounter + { + private const int MaxRetry = 1; + private readonly object _gate = new(); + private int _count; + + public bool Failed() + { + lock (_gate) + { + return _count >= MaxRetry; + } + } + + public void Increment() + { + lock (_gate) + { + _count++; + } + } + + public void Reset() + { + lock (_gate) + { + _count = 0; + } + } + } } diff --git a/tests/NATS.Client.Core.Tests/ProtocolTest.cs b/tests/NATS.Client.Core.Tests/ProtocolTest.cs index 27202246d..f6806d12d 100644 --- a/tests/NATS.Client.Core.Tests/ProtocolTest.cs +++ b/tests/NATS.Client.Core.Tests/ProtocolTest.cs @@ -397,9 +397,11 @@ public async Task Protocol_parser_under_load(int size) for (var i = 0; i < 3; i++) { await Task.Delay(1_000, cts.Token); + var subjectCount = counts.Count; await server.RestartAsync(); Interlocked.Increment(ref r); - await Task.Delay(1_000, cts.Token); + + await Retry.Until("subject count goes up", () => counts.Count > subjectCount); } foreach (var log in logger.Logs.Where(x => x.EventId == NatsLogEvents.Protocol && x.LogLevel == LogLevel.Error)) diff --git a/tests/NATS.Client.Core.Tests/SendBufferTest.cs b/tests/NATS.Client.Core.Tests/SendBufferTest.cs index d6d098804..2873d5dab 100644 --- a/tests/NATS.Client.Core.Tests/SendBufferTest.cs +++ b/tests/NATS.Client.Core.Tests/SendBufferTest.cs @@ -1,4 +1,6 @@ using System.Diagnostics; +using System.Net.Sockets; +using Microsoft.Extensions.Logging; using NATS.Client.TestUtilities; namespace NATS.Client.Core.Tests; @@ -12,7 +14,10 @@ public class SendBufferTest [Fact] public async Task Send_cancel() { - void Log(string m) => TmpFileLogger.Log(m); + // void Log(string m) => TmpFileLogger.Log(m); + void Log(string m) + { + } using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); @@ -26,7 +31,7 @@ public async Task Send_cancel() } }, Log, - cts.Token); + cancellationToken: cts.Token); Log("__________________________________"); @@ -75,4 +80,105 @@ public async Task Send_cancel() await tasks[i]; } } + + [Fact] + public async Task Send_recover_half_sent() + { + // void Log(string m) => TmpFileLogger.Log(m); + void Log(string m) + { + } + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + List pubs = new(); + await using var server = new MockServer( + handler: (client, cmd) => + { + if (cmd.Name == "PUB") + { + lock (pubs) + pubs.Add($"PUB {cmd.Subject}"); + } + + if (cmd is { Name: "PUB", Subject: "close" }) + { + client.Close(); + } + + return Task.CompletedTask; + }, + Log, + info: $"{{\"max_payload\":{1024 * 1024 * 8}}}", + cancellationToken: cts.Token); + + Log("__________________________________"); + + var testLogger = new InMemoryTestLoggerFactory(LogLevel.Error, m => + { + Log($"[NC] {m.Message}"); + if (m.Exception is not SocketException) + _output.WriteLine($"ERROR: {m.Exception}"); + }); + + await using var nats = new NatsConnection(new NatsOpts + { + Url = server.Url, + LoggerFactory = testLogger, + }); + + Log($"[C] connect {server.Url}"); + await nats.ConnectAsync(); + + Log($"[C] ping"); + var rtt = await nats.PingAsync(cts.Token); + Log($"[C] ping rtt={rtt}"); + + Log($"[C] publishing x1..."); + await nats.PublishAsync("x1", "x", cancellationToken: cts.Token); + + // we will close the connection in mock server when we receive subject "close" + Log($"[C] publishing close (8MB)..."); + var pubTask = nats.PublishAsync("close", new byte[1024 * 1024 * 8], cancellationToken: cts.Token).AsTask(); + + await pubTask.WaitAsync(cts.Token); + + for (var i = 1; i <= 10; i++) + { + try + { + await nats.PingAsync(cts.Token); + break; + } + catch (OperationCanceledException) + { + if (i == 10) + throw; + await Task.Delay(10 * i, cts.Token); + } + } + + Log($"[C] publishing x2..."); + await nats.PublishAsync("x2", "x", cancellationToken: cts.Token); + + Log($"[C] flush..."); + await nats.PingAsync(cts.Token); + + Assert.Equal(2, testLogger.Logs.Count); + foreach (var log in testLogger.Logs) + { + Assert.True(log.Exception is SocketException, "Socket exception expected"); + var socketErrorCode = (log.Exception as SocketException)!.SocketErrorCode; + Assert.True(socketErrorCode is SocketError.ConnectionReset or SocketError.Shutdown, "Socket error code"); + } + + lock (pubs) + { + Assert.Equal(4, pubs.Count); + Assert.Equal("PUB x1", pubs[0]); + Assert.Equal("PUB close", pubs[1]); + Assert.Equal("PUB close", pubs[2]); + Assert.Equal("PUB x2", pubs[3]); + } + } } diff --git a/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs b/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs index fddde99e8..c04100c24 100644 --- a/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs +++ b/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs @@ -2,7 +2,7 @@ namespace NATS.Client.TestUtilities; -public class InMemoryTestLoggerFactory(LogLevel level) : ILoggerFactory +public class InMemoryTestLoggerFactory(LogLevel level, Action? logger = null) : ILoggerFactory { private readonly List _messages = new(); @@ -28,7 +28,11 @@ public void Dispose() private void Log(string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message) { lock (_messages) - _messages.Add(new LogMessage(categoryName, logLevel, eventId, exception, message)); + { + var logMessage = new LogMessage(categoryName, logLevel, eventId, exception, message); + _messages.Add(logMessage); + logger?.Invoke(logMessage); + } } public record LogMessage(string Category, LogLevel LogLevel, EventId EventId, Exception? Exception, string Message); diff --git a/tests/NATS.Client.TestUtilities/MockServer.cs b/tests/NATS.Client.TestUtilities/MockServer.cs index 5ebfa186f..cecb73712 100644 --- a/tests/NATS.Client.TestUtilities/MockServer.cs +++ b/tests/NATS.Client.TestUtilities/MockServer.cs @@ -9,76 +9,90 @@ public class MockServer : IAsyncDisposable { private readonly Action _logger; private readonly TcpListener _server; + private readonly List _clients = new(); private readonly Task _accept; + private readonly CancellationTokenSource _cts; public MockServer( - Func handler, + Func handler, Action logger, - CancellationToken cancellationToken) + string info = "{\"max_payload\":1048576}", + CancellationToken cancellationToken = default) { _logger = logger; + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = _cts.Token; _server = new TcpListener(IPAddress.Parse("127.0.0.1"), 0); - _server.Start(); + _server.Start(10); Port = ((IPEndPoint)_server.LocalEndpoint).Port; _accept = Task.Run( async () => { - var client = await _server.AcceptTcpClientAsync(); - - var stream = client.GetStream(); - - var sw = new StreamWriter(stream, Encoding.ASCII); - await sw.WriteAsync("INFO {\"max_payload\":1048576}\r\n"); - await sw.FlushAsync(); - - var sr = new StreamReader(stream, Encoding.ASCII); - + var n = 0; while (!cancellationToken.IsCancellationRequested) { - Log($"[S] >>> READ LINE"); - var line = (await sr.ReadLineAsync())!; + var tcpClient = await _server.AcceptTcpClientAsync(cancellationToken); + var client = new Client(this, tcpClient); + n++; + Log($"[S] [{n}] New client connected"); + var stream = tcpClient.GetStream(); - if (line.StartsWith("CONNECT")) - { - Log($"[S] RCV CONNECT"); - } - else if (line.StartsWith("PING")) - { - Log($"[S] RCV PING"); - await sw.WriteAsync("PONG\r\n"); - await sw.FlushAsync(); - Log($"[S] SND PONG"); - } - else if (line.StartsWith("SUB")) - { - var m = Regex.Match(line, @"^SUB\s+(?\S+)"); - var subject = m.Groups["subject"].Value; - Log($"[S] RCV SUB {subject}"); - await handler(this, new Cmd("SUB", subject, 0)); - } - else if (line.StartsWith("PUB") || line.StartsWith("HPUB")) + var sw = new StreamWriter(stream, Encoding.ASCII); + await sw.WriteAsync($"INFO {info}\r\n"); + await sw.FlushAsync(); + + var sr = new StreamReader(stream, Encoding.ASCII); + + _clients.Add(Task.Run(async () => { - var m = Regex.Match(line, @"^(H?PUB)\s+(?\S+).*?(?\d+)$"); - var size = int.Parse(m.Groups["size"].Value); - var subject = m.Groups["subject"].Value; - Log($"[S] RCV PUB {subject} {size}"); - var read = 0; - var buffer = new byte[size]; - while (read < size) + while (!cancellationToken.IsCancellationRequested) { - var received = await stream.ReadAsync(buffer, read, size - read); - read += received; - Log($"[S] RCV {received} bytes (size={size} read={read})"); - } + var line = (await sr.ReadLineAsync())!; - await handler(this, new Cmd("PUB", subject, size)); - await sr.ReadLineAsync(); - } - else - { - Log($"[S] RCV LINE: {line}"); - } + if (line.StartsWith("CONNECT")) + { + Log($"[S] [{n}] RCV CONNECT"); + } + else if (line.StartsWith("PING")) + { + Log($"[S] [{n}] RCV PING"); + await sw.WriteAsync("PONG\r\n"); + await sw.FlushAsync(); + Log($"[S] [{n}] SND PONG"); + } + else if (line.StartsWith("SUB")) + { + var m = Regex.Match(line, @"^SUB\s+(?\S+)"); + var subject = m.Groups["subject"].Value; + Log($"[S] [{n}] RCV SUB {subject}"); + await handler(client, new Cmd("SUB", subject, 0)); + } + else if (line.StartsWith("PUB") || line.StartsWith("HPUB")) + { + var m = Regex.Match(line, @"^(H?PUB)\s+(?\S+).*?(?\d+)$"); + var size = int.Parse(m.Groups["size"].Value); + var subject = m.Groups["subject"].Value; + Log($"[S] [{n}] RCV PUB {subject} {size}"); + await handler(client, new Cmd("PUB", subject, size)); + var read = 0; + var buffer = new char[size]; + while (read < size) + { + var received = await sr.ReadAsync(buffer, read, size - read); + read += received; + Log($"[S] [{n}] RCV {received} bytes (size={size} read={read})"); + } + + // Log($"[S] RCV PUB payload: {new string(buffer)}"); + await sr.ReadLineAsync(); + } + else + { + Log($"[S] [{n}] RCV LINE: {line}"); + } + } + })); } }, cancellationToken); @@ -90,7 +104,28 @@ public MockServer( public async ValueTask DisposeAsync() { + _cts.Cancel(); _server.Stop(); + foreach (var client in _clients) + { + try + { + await client; + } + catch (ObjectDisposedException) + { + } + catch (OperationCanceledException) + { + } + catch (SocketException) + { + } + catch (IOException) + { + } + } + try { await _accept; @@ -109,4 +144,20 @@ public async ValueTask DisposeAsync() public void Log(string m) => _logger(m); public record Cmd(string Name, string Subject, int Size); + + public class Client + { + private readonly MockServer _server; + private readonly TcpClient _tcpClient; + + public Client(MockServer server, TcpClient tcpClient) + { + _server = server; + _tcpClient = tcpClient; + } + + public void Log(string m) => _server.Log(m); + + public void Close() => _tcpClient.Close(); + } }