Skip to content

Commit

Permalink
Fix gRPC retry calls not unregistering from cancellation token (#1398)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Sep 3, 2021
1 parent 15158af commit c804021
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/Grpc.Net.Client/Internal/IGrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,7 @@ internal interface IGrpcCall<TRequest, TResponse> : IDisposable
void StartDuplexStreaming();

Task WriteClientStreamAsync<TState>(Func<GrpcCall<TRequest, TResponse>, Stream, CallOptions, TState, ValueTask> writeFunc, TState state);

// Whether this call has already been disposed / cleaned up.
// RetryCallBase.CommitCall reads this right after committing a call: when it is already true,
// the wrapping retry call runs its own Cleanup() immediately.
bool Disposed { get; }
}
}
7 changes: 7 additions & 0 deletions src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ private async Task StartCall(Action<GrpcCall<TRequest, TResponse>> startCallFunc
}
}
}

if (CommitedCallTask.IsCompletedSuccessfully() && CommitedCallTask.Result == call)
{
// Wait until the committed call is finished and then clean up the hedging call.
await call.CallTask.ConfigureAwait(false);
Cleanup();
}
}

protected override void OnCommitCall(IGrpcCall<TRequest, TResponse> call)
Expand Down
10 changes: 10 additions & 0 deletions src/Grpc.Net.Client/Internal/Retry/RetryCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,16 @@ private async Task StartRetry(Action<GrpcCall<TRequest, TResponse>> startCallFun
}
finally
{
if (CommitedCallTask.IsCompletedSuccessfully())
{
if (CommitedCallTask.Result is GrpcCall<TRequest, TResponse> call)
{
// Wait until the committed call is finished and then clean up the retry call.
await call.CallTask.ConfigureAwait(false);
Cleanup();
}
}

Log.StoppingRetryWorker(Logger);
}
}
Expand Down
30 changes: 24 additions & 6 deletions src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ internal abstract partial class RetryCallBase<TRequest, TResponse> : IGrpcCall<T
private readonly TaskCompletionSource<IGrpcCall<TRequest, TResponse>> _commitedCallTcs;
private RetryCallBaseClientStreamReader<TRequest, TResponse>? _retryBaseClientStreamReader;
private RetryCallBaseClientStreamWriter<TRequest, TResponse>? _retryBaseClientStreamWriter;
private CancellationTokenRegistration? _ctsRegistration;

// Internal for unit testing.
internal CancellationTokenRegistration? _ctsRegistration;

protected object Lock { get; } = new object();
protected ILogger Logger { get; }
Expand All @@ -52,14 +54,14 @@ internal abstract partial class RetryCallBase<TRequest, TResponse> : IGrpcCall<T
protected int MaxRetryAttempts { get; }
protected CancellationTokenSource CancellationTokenSource { get; }
protected TaskCompletionSource<IGrpcCall<TRequest, TResponse>?>? NewActiveCallTcs { get; set; }
protected bool Disposed { get; private set; }

public GrpcChannel Channel { get; }
public Task<IGrpcCall<TRequest, TResponse>> CommitedCallTask => _commitedCallTcs.Task;
public IAsyncStreamReader<TResponse>? ClientStreamReader => _retryBaseClientStreamReader ??= new RetryCallBaseClientStreamReader<TRequest, TResponse>(this);
public IClientStreamWriter<TRequest>? ClientStreamWriter => _retryBaseClientStreamWriter ??= new RetryCallBaseClientStreamWriter<TRequest, TResponse>(this);
public WriteOptions? ClientStreamWriteOptions { get; internal set; }
public bool ClientStreamComplete { get; set; }
public bool Disposed { get; private set; }

protected int AttemptCount { get; private set; }
protected List<ReadOnlyMemory<byte>> BufferedMessages { get; }
Expand Down Expand Up @@ -345,6 +347,16 @@ protected void CommitCall(IGrpcCall<TRequest, TResponse> call, CommitReason comm

NewActiveCallTcs?.SetResult(null);
_commitedCallTcs.SetResult(call);

// If the committed call has finished and cleaned up then it is safe for
// the wrapping retry call to clean up. This is required to unregister
// from the cancellation token and avoid a memory leak.
//
// A committed call that has already cleaned up is likely a StatusGrpcCall.
if (call.Disposed)
{
Cleanup();
}
}
}
}
Expand Down Expand Up @@ -406,18 +418,24 @@ protected virtual void Dispose(bool disposing)

if (disposing)
{
_ctsRegistration?.Dispose();
CancellationTokenSource.Cancel();

if (CommitedCallTask.IsCompletedSuccessfully())
{
CommitedCallTask.Result.Dispose();
}

ClearRetryBuffer();
Cleanup();
}
}

/// <summary>
/// Releases what the retry call holds on to: disposes and clears the cancellation token
/// registration, cancels <see cref="CancellationTokenSource"/> and clears the retry buffer.
/// </summary>
protected void Cleanup()
{
    // Unregister from the cancellation token before anything else, then forget the registration.
    if (_ctsRegistration.HasValue)
    {
        _ctsRegistration.Value.Dispose();
    }
    _ctsRegistration = null;

    CancellationTokenSource.Cancel();
    ClearRetryBuffer();
}

internal bool TryAddToRetryBuffer(ReadOnlyMemory<byte> message)
{
lock (Lock)
Expand Down
1 change: 1 addition & 0 deletions src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ internal sealed class StatusGrpcCall<TRequest, TResponse> : IGrpcCall<TRequest,

public IClientStreamWriter<TRequest>? ClientStreamWriter => _clientStreamWriter ??= new StatusClientStreamWriter(_status);
public IAsyncStreamReader<TResponse>? ClientStreamReader => _clientStreamReader ??= new StatusStreamReader(_status);
// Always reports true. RetryCallBase.CommitCall checks this flag on the committed call and,
// when it is true, cleans up the wrapping retry call straight away.
public bool Disposed => true;

public StatusGrpcCall(Status status)
{
Expand Down
58 changes: 58 additions & 0 deletions test/FunctionalTests/Client/RetryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
Expand Down Expand Up @@ -356,6 +357,63 @@ Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context)
tcs.SetResult(new DataMessage());
}

[Test]
public async Task ServerStreaming_CancellatonTokenSpecified_TokenUnregisteredAndResourcesReleased()
{
    // Server method that completes immediately without writing any response messages.
    Task FakeServerStreamCall(DataMessage request, IServerStreamWriter<DataMessage> responseStream, ServerCallContext context)
    {
        return Task.CompletedTask;
    }

    // Arrange
    var method = Fixture.DynamicGrpc.AddServerStreamingMethod<DataMessage, DataMessage>(FakeServerStreamCall);

    var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List<StatusCode> { StatusCode.DeadlineExceeded });
    var channel = CreateChannel(serviceConfig: serviceConfig);

    var references = new List<WeakReference>();

    // Check that token registrations don't build up on the CTS and create a memory leak.
    var cts = new CancellationTokenSource();

    // Act
    // Send calls in a different method so there is no chance that a stack reference
    // to a gRPC call is still alive after calls are complete.
    await MakeCallsAsync(channel, method, references, cts.Token).DefaultTimeout();

    // Assert
    // There is a race when cleaning up the cancellation token registration.
    // Retry a few times to ensure GC is run after unregister.
    await TestHelpers.AssertIsTrueRetryAsync(() =>
    {
        GC.Collect();
        GC.WaitForPendingFinalizers();
        for (var i = 0; i < references.Count; i++)
        {
            if (references[i].IsAlive)
            {
                // At least one call's response stream is still reachable.
                return false;
            }
        }

        // Every tracked response stream has been collected.
        return true;
    }, "Assert that retry call resources are released.");

    // Keep the token source reachable until the assertion has finished. It is not otherwise
    // used after the calls are started, so without this the GC is free to collect it together
    // with any registrations still attached to it, which would let the assertion above pass
    // even when calls leak their registrations.
    GC.KeepAlive(cts);
}

// Makes ten server streaming calls, one after another, using the supplied cancellation token.
// A weak reference to each call's response stream is appended to 'references' so the caller can
// later check whether the streams are still reachable. Each call is expected to end without
// producing a message. The calls are intentionally not disposed here.
[MethodImpl(MethodImplOptions.NoInlining)]
private static async Task MakeCallsAsync(GrpcChannel channel, Method<DataMessage, DataMessage> method, List<WeakReference> references, CancellationToken cancellationToken)
{
    const int callCount = 10;

    var client = TestClientFactory.Create(channel, method);
    for (var attempt = 0; attempt < callCount; attempt++)
    {
        var callOptions = new CallOptions(cancellationToken: cancellationToken);
        var streamingCall = client.ServerStreamingCall(new DataMessage(), callOptions);

        // Track the stream weakly so this helper does not keep it alive.
        references.Add(new WeakReference(streamingCall.ResponseStream));

        var hasMessage = await streamingCall.ResponseStream.MoveNext();
        Assert.IsFalse(hasMessage);
    }
}

[TestCase(1)]
[TestCase(20)]
public async Task Unary_AttemptsGreaterThanDefaultClientLimit_LimitedAttemptsMade(int hedgingDelay)
Expand Down
32 changes: 32 additions & 0 deletions test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus()

// Act
hedgingCall.StartUnary(new HelloRequest());
Assert.IsNotNull(hedgingCall._ctsRegistration);

// Assert
await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout();
Expand All @@ -340,6 +341,37 @@ public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus()
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => hedgingCall.GetResponseAsync()).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
Assert.AreEqual("Call canceled by the client.", ex.Status.Detail);
Assert.IsNull(hedgingCall._ctsRegistration);
}

[Test]
public async Task AsyncUnaryCall_CancellationTokenSuccess_CleanedUp()
{
    // Arrange
    // The test HTTP handler blocks on this TCS so the response is only produced
    // after the registration has been observed below.
    var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
    var httpClient = ClientTestHelpers.CreateTestClient(async request =>
    {
        await tcs.Task;
        var reply = new HelloReply { Message = "Hello world" };
        var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
        return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
    });
    var cts = new CancellationTokenSource();
    var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10));
    var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
    var hedgingCall = new HedgingCall<HelloRequest, HelloReply>(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions(cancellationToken: cts.Token));

    // Act
    hedgingCall.StartUnary(new HelloRequest());
    // The call was created with a cancellation token, so a registration must exist while it is in flight.
    Assert.IsNotNull(hedgingCall._ctsRegistration);
    tcs.SetResult(null);

    // Assert
    await hedgingCall.GetResponseAsync().DefaultTimeout();

    // There is a race between unregistering and GetResponseAsync returning.
    // Bound the wait with the default timeout, like the other awaits in this test.
    await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._ctsRegistration == null, "Hedge call CTS unregistered.").DefaultTimeout();
}

[Test]
Expand Down

0 comments on commit c804021

Please sign in to comment.