Skip to content

Commit

Permalink
Cancellation fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mtmk committed Jan 31, 2024
1 parent 722150f commit a8a5d5d
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 61 deletions.
141 changes: 82 additions & 59 deletions src/NATS.Client.Core/Commands/CommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ public CommandWriter(ObjectPool pool, NatsOpts opts, ConnectionStatsCounter coun
_cts = new CancellationTokenSource();
_readerLoopTask = Task.Run(ReaderLoopAsync);

// _ctPool = new CancellationTimerPool(_pool, _cts.Token);
_ctPool = new CancellationTimerPool(_pool, CancellationToken.None);
// We need a new ObjectPool here because of the root token (_cts.Token).
// When the root token is cancelled as this object is disposed, cancellation
// objects in the pooled CancellationTimer should not be reused since the
// root token would already be cancelled which means CancellationTimer tokens
// would always be in a cancelled state.
_ctPool = new CancellationTimerPool(new ObjectPool(opts.ObjectPoolSize), _cts.Token);
}

public void Reset(ISocketConnection? socketConnection)
Expand Down Expand Up @@ -102,21 +106,7 @@ public async ValueTask DisposeAsync()
public async ValueTask ConnectAsync(ClientOpts connectOpts, CancellationToken cancellationToken)
{
var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken);
Interlocked.Increment(ref _counter.PendingMessages);

try
{
await _channelLock.Writer.WriteAsync(1, cancellationTimer.Token).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
return;
}
catch (ChannelClosedException)
{
return;
}

await LockAsync(cancellationTimer.Token).ConfigureAwait(false);
try
{
if (_disposed)
Expand All @@ -126,24 +116,23 @@ public async ValueTask ConnectAsync(ClientOpts connectOpts, CancellationToken ca

var bw = GetWriter();
_protocolWriter.WriteConnect(bw, connectOpts);
await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
var result = await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
if (result.IsCanceled)
{
throw new OperationCanceledException();
}
}
finally
{
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

Interlocked.Decrement(ref _counter.PendingMessages);
await UnLockAsync().ConfigureAwait(false);
cancellationTimer.TryReturn();
}
}

public async ValueTask PingAsync(PingCommand pingCommand, CancellationToken cancellationToken)
{
await LockAsync(cancellationToken).ConfigureAwait(false);

var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken);
await LockAsync(cancellationTimer.Token).ConfigureAwait(false);
try
{
if (_disposed)
Expand All @@ -155,18 +144,23 @@ public async ValueTask PingAsync(PingCommand pingCommand, CancellationToken canc

var bw = GetWriter();
_protocolWriter.WritePing(bw);
await bw.FlushAsync(cancellationToken).ConfigureAwait(false);
var result = await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
if (result.IsCanceled)
{
throw new OperationCanceledException();
}
}
finally
{
await UnLockAsync(cancellationToken).ConfigureAwait(false);
await UnLockAsync().ConfigureAwait(false);
cancellationTimer.TryReturn();
}
}

public async ValueTask PongAsync(CancellationToken cancellationToken = default)
{
await LockAsync(cancellationToken).ConfigureAwait(false);

var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken);
await LockAsync(cancellationTimer.Token).ConfigureAwait(false);
try
{
if (_disposed)
Expand All @@ -176,11 +170,16 @@ public async ValueTask PongAsync(CancellationToken cancellationToken = default)

var bw = GetWriter();
_protocolWriter.WritePong(bw);
await bw.FlushAsync(cancellationToken).ConfigureAwait(false);
var result = await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
if (result.IsCanceled)
{
throw new OperationCanceledException();
}
}
finally
{
await UnLockAsync(cancellationToken).ConfigureAwait(false);
await UnLockAsync().ConfigureAwait(false);
cancellationTimer.TryReturn();
}
}

Expand All @@ -205,8 +204,8 @@ public ValueTask PublishAsync<T>(string subject, T? value, NatsHeaders? headers,

public async ValueTask SubscribeAsync(int sid, string subject, string? queueGroup, int? maxMsgs, CancellationToken cancellationToken)
{
await LockAsync(cancellationToken).ConfigureAwait(false);

var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken);
await LockAsync(cancellationTimer.Token).ConfigureAwait(false);
try
{
if (_disposed)
Expand All @@ -216,18 +215,23 @@ public async ValueTask SubscribeAsync(int sid, string subject, string? queueGrou

var bw = GetWriter();
_protocolWriter.WriteSubscribe(bw, sid, subject, queueGroup, maxMsgs);
await bw.FlushAsync(cancellationToken).ConfigureAwait(false);
var result = await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
if (result.IsCanceled)
{
throw new OperationCanceledException();
}
}
finally
{
await UnLockAsync(cancellationToken).ConfigureAwait(false);
await UnLockAsync().ConfigureAwait(false);
cancellationTimer.TryReturn();
}
}

public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken cancellationToken)
{
await LockAsync(cancellationToken).ConfigureAwait(false);

var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken);
await LockAsync(cancellationTimer.Token).ConfigureAwait(false);
try
{
if (_disposed)
Expand All @@ -237,34 +241,27 @@ public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken

var bw = GetWriter();
_protocolWriter.WriteUnsubscribe(bw, sid, maxMsgs);
await bw.FlushAsync(cancellationToken).ConfigureAwait(false);
var result = await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
if (result.IsCanceled)
{
throw new OperationCanceledException();
}
}
finally
{
await UnLockAsync(cancellationToken).ConfigureAwait(false);
await UnLockAsync().ConfigureAwait(false);
cancellationTimer.TryReturn();
}
}

[MethodImpl(MethodImplOptions.NoInlining)]
private static void ThrowOnDisconnected() => throw new NatsException("Connection hasn't been established yet.");

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private PipeWriter GetWriter()
{
lock (_lock)
{
if (_pipeWriter == null)
ThrowOnDisconnected();
return _pipeWriter!;
}
}

[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))]
private async ValueTask PublishLockedAsync(string subject, string? replyTo, NatsPooledBufferWriter<byte> payloadBuffer, NatsPooledBufferWriter<byte>? headersBuffer, CancellationToken cancellationToken)
{
var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken);
await LockAsync(cancellationTimer.Token).ConfigureAwait(false);

try
{
var payload = payloadBuffer.WrittenMemory;
Expand All @@ -287,27 +284,53 @@ private async ValueTask PublishLockedAsync(string subject, string? replyTo, Nat
_pool.Return(headersBuffer);
}

await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
var result = await bw.FlushAsync(cancellationTimer.Token).ConfigureAwait(false);
if (result.IsCanceled)
{
throw new OperationCanceledException();
}
}
finally
{
await UnLockAsync(cancellationTimer.Token).ConfigureAwait(false);
await UnLockAsync().ConfigureAwait(false);
cancellationTimer.TryReturn();
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ValueTask<int> UnLockAsync(CancellationToken cancellationToken)
private async ValueTask LockAsync(CancellationToken cancellationToken)
{
Interlocked.Increment(ref _counter.PendingMessages);
try
{
await _channelLock.Writer.WriteAsync(1, cancellationToken).ConfigureAwait(false);
}
catch (TaskCanceledException)
{
throw new OperationCanceledException();
}
catch (ChannelClosedException)
{
throw new OperationCanceledException();
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ValueTask<int> UnLockAsync()
{
Interlocked.Decrement(ref _counter.PendingMessages);
return _channelLock.Reader.ReadAsync(cancellationToken);
return _channelLock.Reader.ReadAsync(_cts.Token);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ValueTask LockAsync(CancellationToken cancellationToken)
private PipeWriter GetWriter()
{
Interlocked.Increment(ref _counter.PendingMessages);
return _channelLock.Writer.WriteAsync(1, cancellationToken);
lock (_lock)
{
if (_pipeWriter == null)
ThrowOnDisconnected();
return _pipeWriter!;
}
}

private async Task ReaderLoopAsync()
Expand Down
2 changes: 1 addition & 1 deletion src/NATS.Client.Core/Commands/PriorityCommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ internal sealed class PriorityCommandWriter : IAsyncDisposable

public PriorityCommandWriter(ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing)
{
CommandWriter = new CommandWriter(pool, opts, counter, enqueuePing, overrideCommandTimeout: TimeSpan.MaxValue);
CommandWriter = new CommandWriter(pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan);
CommandWriter.Reset(socketConnection);
}

Expand Down
2 changes: 1 addition & 1 deletion src/NATS.Client.Core/Internal/CancellationTimer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public static CancellationTimer Start(ObjectPool pool, CancellationToken rootTok
}

self._timeout = timeout;
if (timeout != TimeSpan.MaxValue)
if (timeout != Timeout.InfiniteTimeSpan)
{
self._cancellationTokenSource.CancelAfter(timeout);
}
Expand Down
18 changes: 18 additions & 0 deletions tests/NATS.Client.Core.Tests/CancellationTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,22 @@ await Assert.ThrowsAsync<TaskCanceledException>(async () =>
}
});
}

[Fact]
public async Task Cancellation_timer()
{
var objectPool = new ObjectPool(10);
var cancellationTimerPool = new CancellationTimerPool(objectPool, CancellationToken.None);
var cancellationTimer = cancellationTimerPool.Start(TimeSpan.FromSeconds(2), CancellationToken.None);

try
{
await Task.Delay(TimeSpan.FromSeconds(4), cancellationTimer.Token);
_output.WriteLine($"delayed 4 seconds");
}
catch (Exception e)
{
_output.WriteLine($"Exception: {e.GetType().Name}");
}
}
}
78 changes: 78 additions & 0 deletions tests/NATS.Client.Core.Tests/SendBufferTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
using System.Diagnostics;
using NATS.Client.TestUtilities;

namespace NATS.Client.Core.Tests;

public class SendBufferTest
{
private readonly ITestOutputHelper _output;

public SendBufferTest(ITestOutputHelper output) => _output = output;

[Fact]
public async Task Send_cancel()
{
void Log(string m) => TmpFileLogger.Log(m);

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

await using var server = new MockServer(
async (s, cmd) =>
{
if (cmd.Name == "PUB" && cmd.Subject == "pause")
{
s.Log("[S] pause");
await Task.Delay(10_000, cts.Token);
}
},
Log,
cts.Token);

Log("__________________________________");

await using var nats = new NatsConnection(new NatsOpts { Url = server.Url });

Log($"[C] connect {server.Url}");
await nats.ConnectAsync();

Log($"[C] ping");
var rtt = await nats.PingAsync(cts.Token);
Log($"[C] ping rtt={rtt}");

server.Log($"[C] publishing pause...");
await nats.PublishAsync("pause", "x", cancellationToken: cts.Token);

server.Log($"[C] publishing 1M...");
var payload = new byte[1024 * 1024];
var tasks = new List<Task>();
for (var i = 0; i < 10; i++)
{
var i1 = i;
tasks.Add(Task.Run(async () =>
{
var stopwatch = Stopwatch.StartNew();
try
{
Log($"[C] ({i1}) publish...");
await nats.PublishAsync("x", payload, cancellationToken: cts.Token);
}
catch (Exception e)
{
stopwatch.Stop();
Log($"[C] ({i1}) publish cancelled after {stopwatch.Elapsed.TotalSeconds:n0} s (exception: {e.GetType()})");
return;
}
stopwatch.Stop();
Log($"[C] ({i1}) publish took {stopwatch.Elapsed.TotalSeconds:n3} s");
}));
}

for (var i = 0; i < 10; i++)
{
Log($"[C] await tasks {i}...");
await tasks[i];
}
}
}
Loading

0 comments on commit a8a5d5d

Please sign in to comment.