From 4c5a4db8937e083363c1ae7987842f00342a4838 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Thu, 31 Mar 2022 11:51:05 +0800 Subject: [PATCH] Fix Kestrel conn close on shutdown while response is writing (#40933) --- .../src/Internal/Http2/Http2FrameWriter.cs | 102 ++++++++++-------- .../src/Internal/Http2/Http2OutputProducer.cs | 30 ++---- .../Http2/Http2FrameWriterBenchmark.cs | 17 ++- .../Kestrel/shared/test/TestContextFactory.cs | 13 ++- .../Http2/Http2RequestTests.cs | 93 ++++++++++++++++ .../Interop.FunctionalTests/HttpHelpers.cs | 2 +- 6 files changed, 185 insertions(+), 72 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs index d04e6b74c60d..cec7e6d53566 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs @@ -172,8 +172,19 @@ public ValueTask Write100ContinueAsync(int streamId) | Padding (*) ... +---------------------------------------------------------------+ */ - public void WriteResponseHeaders(int streamId, int statusCode, Http2HeadersFrameFlags headerFrameFlags, HttpResponseHeaders headers) + public void WriteResponseHeaders(Http2Stream stream, int statusCode, bool endStream, HttpResponseHeaders headers) { + Http2HeadersFrameFlags headerFrameFlags; + if (endStream) + { + headerFrameFlags = Http2HeadersFrameFlags.END_STREAM; + stream.DecrementActiveClientStreamCount(); + } + else + { + headerFrameFlags = Http2HeadersFrameFlags.NONE; + } + lock (_writeLock) { if (_completed) @@ -184,24 +195,26 @@ public void WriteResponseHeaders(int streamId, int statusCode, Http2HeadersFrame try { _headersEnumerator.Initialize(headers); - _outgoingFrame.PrepareHeaders(headerFrameFlags, streamId); + _outgoingFrame.PrepareHeaders(headerFrameFlags, stream.StreamId); var buffer = _headerEncodingBuffer.AsSpan(); var done = HPackHeaderWriter.BeginEncodeHeaders(statusCode, _hpackEncoder, _headersEnumerator, buffer, out var payloadLength); - FinishWritingHeaders(streamId, payloadLength, done); + FinishWritingHeaders(stream.StreamId, payloadLength, done); } // Any exception from the HPack encoder can leave the dynamic table in a corrupt state. // Since we allow custom header encoders we don't know what type of exceptions to expect. catch (Exception ex) { - _log.HPackEncodingError(_connectionId, streamId, ex); + _log.HPackEncodingError(_connectionId, stream.StreamId, ex); _http2Connection.Abort(new ConnectionAbortedException(ex.Message, ex)); throw new InvalidOperationException(ex.Message, ex); // Report the error to the user if this was the first write. } } } - public ValueTask WriteResponseTrailersAsync(int streamId, HttpResponseTrailers headers) + public ValueTask WriteResponseTrailersAsync(Http2Stream stream, HttpResponseTrailers headers) { + stream.DecrementActiveClientStreamCount(); + lock (_writeLock) { if (_completed) @@ -212,16 +225,16 @@ public ValueTask WriteResponseTrailersAsync(int streamId, HttpRespo try { _headersEnumerator.Initialize(headers); - _outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_STREAM, streamId); + _outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_STREAM, stream.StreamId); var buffer = _headerEncodingBuffer.AsSpan(); var done = HPackHeaderWriter.BeginEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out var payloadLength); - FinishWritingHeaders(streamId, payloadLength, done); + FinishWritingHeaders(stream.StreamId, payloadLength, done); } // Any exception from the HPack encoder can leave the dynamic table in a corrupt state. // Since we allow custom header encoders we don't know what type of exceptions to expect. catch (Exception ex) { - _log.HPackEncodingError(_connectionId, streamId, ex); + _log.HPackEncodingError(_connectionId, stream.StreamId, ex); _http2Connection.Abort(new ConnectionAbortedException(ex.Message, ex)); } @@ -258,7 +271,7 @@ private void FinishWritingHeaders(int streamId, int payloadLength, bool done) } } - public ValueTask WriteDataAsync(int streamId, StreamOutputFlowControl flowControl, in ReadOnlySequence data, bool endStream, bool firstWrite, bool forceFlush) + public ValueTask WriteDataAsync(Http2Stream stream, StreamOutputFlowControl flowControl, in ReadOnlySequence data, bool endStream, bool firstWrite, bool forceFlush) { // Logic in this method is replicated in WriteDataAndTrailersAsync. // Changes here may need to be mirrored in WriteDataAndTrailersAsync. @@ -277,12 +290,12 @@ public ValueTask WriteDataAsync(int streamId, StreamOutputFlowContr // https://httpwg.org/specs/rfc7540.html#rfc.section.6.9.1 if (dataLength != 0 && dataLength > flowControl.Available) { - return WriteDataAsync(streamId, flowControl, data, dataLength, endStream, firstWrite); + return WriteDataAsync(stream, flowControl, data, dataLength, endStream, firstWrite); } // This cast is safe since if dataLength would overflow an int, it's guaranteed to be greater than the available flow control window. flowControl.Advance((int)dataLength); - WriteDataUnsynchronized(streamId, data, dataLength, endStream); + WriteDataUnsynchronized(stream, data, dataLength, endStream); if (forceFlush) { @@ -293,7 +306,7 @@ public ValueTask WriteDataAsync(int streamId, StreamOutputFlowContr } } - public ValueTask WriteDataAndTrailersAsync(int streamId, StreamOutputFlowControl flowControl, in ReadOnlySequence data, bool firstWrite, HttpResponseTrailers headers) + public ValueTask WriteDataAndTrailersAsync(Http2Stream stream, StreamOutputFlowControl flowControl, in ReadOnlySequence data, bool firstWrite, HttpResponseTrailers headers) { // This method combines WriteDataAsync and WriteResponseTrailers. // Changes here may need to be mirrored in WriteDataAsync. @@ -312,21 +325,21 @@ public ValueTask WriteDataAndTrailersAsync(int streamId, StreamOutp // https://httpwg.org/specs/rfc7540.html#rfc.section.6.9.1 if (dataLength != 0 && dataLength > flowControl.Available) { - return WriteDataAndTrailersAsyncCore(this, streamId, flowControl, data, dataLength, firstWrite, headers); + return WriteDataAndTrailersAsyncCore(this, stream, flowControl, data, dataLength, firstWrite, headers); } // This cast is safe since if dataLength would overflow an int, it's guaranteed to be greater than the available flow control window. flowControl.Advance((int)dataLength); - WriteDataUnsynchronized(streamId, data, dataLength, endStream: false); + WriteDataUnsynchronized(stream, data, dataLength, endStream: false); - return WriteResponseTrailersAsync(streamId, headers); + return WriteResponseTrailersAsync(stream, headers); } - static async ValueTask WriteDataAndTrailersAsyncCore(Http2FrameWriter writer, int streamId, StreamOutputFlowControl flowControl, ReadOnlySequence data, long dataLength, bool firstWrite, HttpResponseTrailers headers) + static async ValueTask WriteDataAndTrailersAsyncCore(Http2FrameWriter writer, Http2Stream stream, StreamOutputFlowControl flowControl, ReadOnlySequence data, long dataLength, bool firstWrite, HttpResponseTrailers headers) { - await writer.WriteDataAsync(streamId, flowControl, data, dataLength, endStream: false, firstWrite); + await writer.WriteDataAsync(stream, flowControl, data, dataLength, endStream: false, firstWrite); - return await writer.WriteResponseTrailersAsync(streamId, headers); + return await writer.WriteResponseTrailersAsync(stream, headers); } } @@ -339,12 +352,12 @@ static async ValueTask WriteDataAndTrailersAsyncCore(Http2FrameWrit | Padding (*) ... +---------------------------------------------------------------+ */ - private void WriteDataUnsynchronized(int streamId, in ReadOnlySequence data, long dataLength, bool endStream) + private void WriteDataUnsynchronized(Http2Stream stream, in ReadOnlySequence data, long dataLength, bool endStream) { Debug.Assert(dataLength == data.Length); // Note padding is not implemented - _outgoingFrame.PrepareData(streamId); + _outgoingFrame.PrepareData(stream.StreamId); if (dataLength > _maxFrameSize) // Minus padding { @@ -352,16 +365,7 @@ private void WriteDataUnsynchronized(int streamId, in ReadOnlySequence dat return; } - if (endStream) - { - _outgoingFrame.DataFlags |= Http2DataFrameFlags.END_STREAM; - } - - _outgoingFrame.PayloadLength = (int)dataLength; // Plus padding - - WriteHeaderUnsynchronized(); - - data.CopyTo(_outputWriter); + WriteDataUnsynchronizedCore(stream, endStream, dataLength, data); // Plus padding return; @@ -378,14 +382,8 @@ void TrimAndWriteDataUnsynchronized(in ReadOnlySequence data, long dataLen do { var currentData = remainingData.Slice(0, dataPayloadLength); - _outgoingFrame.PayloadLength = dataPayloadLength; // Plus padding - WriteHeaderUnsynchronized(); - - foreach (var buffer in currentData) - { - _outputWriter.Write(buffer.Span); - } + WriteDataUnsynchronizedCore(stream, endStream: false, dataPayloadLength, currentData); // Plus padding dataLength -= dataPayloadLength; @@ -393,25 +391,37 @@ void TrimAndWriteDataUnsynchronized(in ReadOnlySequence data, long dataLen } while (dataLength > dataPayloadLength); + WriteDataUnsynchronizedCore(stream, endStream, dataLength, remainingData); + + // Plus padding + } + + void WriteDataUnsynchronizedCore(Http2Stream stream, bool endStream, long dataLength, in ReadOnlySequence data) + { + Debug.Assert(dataLength == data.Length); + if (endStream) { _outgoingFrame.DataFlags |= Http2DataFrameFlags.END_STREAM; + + // When writing data, must decrement active stream count after flow control availability is checked. + // If active stream count becomes zero while a graceful shutdown is in progress then the input side of connection is closed. + // This is a problem if a large amount of data is being written. The server must keep processing incoming WINDOW_UPDATE frames. + // No WINDOW_UPDATE frames means response write could hit flow control and hang. + // Decrement also has to happen before writing END_STREAM to client to avoid race over active stream count. + stream.DecrementActiveClientStreamCount(); } + // It can be expensive to get length from ROS. Use already available value. _outgoingFrame.PayloadLength = (int)dataLength; // Plus padding WriteHeaderUnsynchronized(); - foreach (var buffer in remainingData) - { - _outputWriter.Write(buffer.Span); - } - - // Plus padding + data.CopyTo(_outputWriter); } } - private async ValueTask WriteDataAsync(int streamId, StreamOutputFlowControl flowControl, ReadOnlySequence data, long dataLength, bool endStream, bool firstWrite) + private async ValueTask WriteDataAsync(Http2Stream stream, StreamOutputFlowControl flowControl, ReadOnlySequence data, long dataLength, bool endStream, bool firstWrite) { FlushResult flushResult = default; @@ -436,13 +446,13 @@ private async ValueTask WriteDataAsync(int streamId, StreamOutputFl { if (actual < dataLength) { - WriteDataUnsynchronized(streamId, data.Slice(0, actual), actual, endStream: false); + WriteDataUnsynchronized(stream, data.Slice(0, actual), actual, endStream: false); data = data.Slice(actual); dataLength -= actual; } else { - WriteDataUnsynchronized(streamId, data, actual, endStream); + WriteDataUnsynchronized(stream, data, actual, endStream); dataLength = 0; } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs index fb507c02e563..c138befdf181 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs @@ -203,20 +203,11 @@ public void WriteResponseHeaders(int statusCode, string? reasonPhrase, HttpRespo // The headers will be the final frame if: // 1. There is no content // 2. There is no trailing HEADERS frame. - Http2HeadersFrameFlags http2HeadersFrame; + _streamEnded = appCompleted + && !_startedWritingDataFrames + && (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0); - if (appCompleted && !_startedWritingDataFrames && (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0)) - { - _streamEnded = true; - _stream.DecrementActiveClientStreamCount(); - http2HeadersFrame = Http2HeadersFrameFlags.END_STREAM; - } - else - { - http2HeadersFrame = Http2HeadersFrameFlags.NONE; - } - - _frameWriter.WriteResponseHeaders(StreamId, statusCode, http2HeadersFrame, responseHeaders); + _frameWriter.WriteResponseHeaders(_stream, statusCode, _streamEnded, responseHeaders); } } @@ -429,16 +420,15 @@ private async Task ProcessDataWrites() // Write any remaining content then write trailers _stream.ResponseTrailers.SetReadOnly(); - _stream.DecrementActiveClientStreamCount(); if (readResult.Buffer.Length > 0) { // It is faster to write data and trailers together. Locking once reduces lock contention. - flushResult = await _frameWriter.WriteDataAndTrailersAsync(StreamId, _flowControl, readResult.Buffer, firstWrite, _stream.ResponseTrailers); + flushResult = await _frameWriter.WriteDataAndTrailersAsync(_stream, _flowControl, readResult.Buffer, firstWrite, _stream.ResponseTrailers); } else { - flushResult = await _frameWriter.WriteResponseTrailersAsync(StreamId, _stream.ResponseTrailers); + flushResult = await _frameWriter.WriteResponseTrailersAsync(_stream, _stream.ResponseTrailers); } } else if (readResult.IsCompleted && _streamEnded) @@ -454,13 +444,7 @@ private async Task ProcessDataWrites() else { var endStream = readResult.IsCompleted; - - if (endStream) - { - _stream.DecrementActiveClientStreamCount(); - } - - flushResult = await _frameWriter.WriteDataAsync(StreamId, _flowControl, readResult.Buffer, endStream, firstWrite, forceFlush: true); + flushResult = await _frameWriter.WriteDataAsync(_stream, _flowControl, readResult.Buffer, endStream, firstWrite, forceFlush: true); } firstWrite = false; diff --git a/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs b/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs index 3a2844125c2f..bb5d2d7d8b1e 100644 --- a/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs +++ b/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs @@ -20,6 +20,7 @@ public class Http2FrameWriterBenchmark private Pipe _pipe; private Http2FrameWriter _frameWriter; private HttpResponseHeaders _responseHeaders; + private Http2Stream _stream; [GlobalSetup] public void GlobalSetup() @@ -45,6 +46,8 @@ public void GlobalSetup() _memoryPool, serviceContext); + _stream = new MockHttp2Stream(TestContextFactory.CreateHttp2StreamContext(streamId: 0)); + _responseHeaders = new HttpResponseHeaders(); var headers = (IHeaderDictionary)_responseHeaders; headers.ContentType = "application/json"; @@ -54,7 +57,7 @@ public void GlobalSetup() [Benchmark] public void WriteResponseHeaders() { - _frameWriter.WriteResponseHeaders(0, 200, Http2HeadersFrameFlags.END_HEADERS, _responseHeaders); + _frameWriter.WriteResponseHeaders(_stream, 200, endStream: true, _responseHeaders); } [GlobalCleanup] @@ -63,4 +66,16 @@ public void Dispose() _pipe.Writer.Complete(); _memoryPool?.Dispose(); } + + private class MockHttp2Stream : Http2Stream + { + public MockHttp2Stream(Http2StreamContext context) + { + Initialize(context); + } + + public override void Execute() + { + } + } } diff --git a/src/Servers/Kestrel/shared/test/TestContextFactory.cs b/src/Servers/Kestrel/shared/test/TestContextFactory.cs index 0f1fdba1f7ca..14df1153cbd8 100644 --- a/src/Servers/Kestrel/shared/test/TestContextFactory.cs +++ b/src/Servers/Kestrel/shared/test/TestContextFactory.cs @@ -155,7 +155,7 @@ public static Http2StreamContext CreateHttp2StreamContext( localEndPoint: localEndPoint, remoteEndPoint: remoteEndPoint, streamId: streamId ?? 0, - streamLifetimeHandler: streamLifetimeHandler, + streamLifetimeHandler: streamLifetimeHandler ?? new TestHttp2StreamLifetimeHandler(), clientPeerSettings: clientPeerSettings ?? new Http2PeerSettings(), serverPeerSettings: serverPeerSettings ?? new Http2PeerSettings(), frameWriter: frameWriter, @@ -201,6 +201,17 @@ public static Http3StreamContext CreateHttp3StreamContext( return context; } + private class TestHttp2StreamLifetimeHandler : IHttp2StreamLifetimeHandler + { + public void DecrementActiveClientStreamCount() + { + } + + public void OnStreamCompleted(Http2Stream stream) + { + } + } + private class TestMultiplexedConnectionContext : MultiplexedConnectionContext { public override string ConnectionId { get; set; } diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs index 142fb5577033..59834b29eff0 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs @@ -3,11 +3,16 @@ using System.Net; using System.Net.Http; +using System.Net.Http.Headers; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Headers; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; namespace Interop.FunctionalTests.Http2; @@ -37,6 +42,94 @@ public async Task GET_NoTLS_Http11RequestToHttp2Endpoint_400Result() } } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task GET_RequestReturnsLargeData_GracefulShutdownDuringRequest_RequestGracefullyCompletes(bool hasTrailers) + { + // Arrange + const int DataLength = 500_000; + var randomBytes = Enumerable.Range(1, DataLength).Select(i => (byte)((i % 10) + 48)).ToArray(); + + var syncPoint = new SyncPoint(); + + ILogger logger = null; + var builder = CreateHostBuilder( + async c => + { + await syncPoint.WaitToContinue(); + + var memory = c.Response.BodyWriter.GetMemory(randomBytes.Length); + + logger.LogInformation($"Server writing {randomBytes.Length} bytes response"); + randomBytes.CopyTo(memory); + + // It's important for this test that the large write is the last data written to + // the response and it's not awaited by the request delegate. + logger.LogInformation($"Server advancing {randomBytes.Length} bytes response"); + c.Response.BodyWriter.Advance(randomBytes.Length); + + if (hasTrailers) + { + c.Response.AppendTrailer("test-trailer", "value!"); + } + }, + protocol: HttpProtocols.Http2, + plaintext: true); + + using var host = builder.Build(); + logger = host.Services.GetRequiredService().CreateLogger("Test"); + + var client = HttpHelpers.CreateClient(); + + // Act + await host.StartAsync().DefaultTimeout(); + + var longRunningTask = StartLongRunningRequestAsync(logger, host, client); + + logger.LogInformation("Waiting for request on server"); + await syncPoint.WaitForSyncPoint().DefaultTimeout(); + + logger.LogInformation("Stopping server"); + var stopTask = host.StopAsync(); + + syncPoint.Continue(); + + var (readData, trailers) = await longRunningTask.DefaultTimeout(); + await stopTask.DefaultTimeout(); + + // Assert + Assert.Equal(randomBytes, readData); + if (hasTrailers) + { + Assert.Equal("value!", trailers.GetValues("test-trailer").Single()); + } + } + + private static async Task<(byte[], HttpResponseHeaders)> StartLongRunningRequestAsync(ILogger logger, IHost host, HttpMessageInvoker client) + { + var request = new HttpRequestMessage(HttpMethod.Get, $"http://127.0.0.1:{host.GetPort()}/"); + request.Version = HttpVersion.Version20; + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var responseMessage = await client.SendAsync(request, CancellationToken.None).DefaultTimeout(); + responseMessage.EnsureSuccessStatusCode(); + + var responseStream = await responseMessage.Content.ReadAsStreamAsync(); + + var data = new List(); + var buffer = new byte[1024 * 128]; + int readCount; + while ((readCount = await responseStream.ReadAsync(buffer)) != 0) + { + data.AddRange(buffer.AsMemory(0, readCount).ToArray()); + logger.LogInformation($"Received {readCount} bytes. Total {data.Count} bytes."); + } + logger.LogInformation($"Finished reading response content"); + + return (data.ToArray(), responseMessage.TrailingHeaders); + } + private IHostBuilder CreateHostBuilder(RequestDelegate requestDelegate, HttpProtocols? protocol = null, Action configureKestrel = null, bool? plaintext = null) { return HttpHelpers.CreateHostBuilder(AddTestLogging, requestDelegate, protocol, configureKestrel, plaintext); diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs index dcb49dd6ce45..6dec97958421 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs @@ -81,7 +81,7 @@ public static IHostBuilder CreateHostBuilder(Action configur } else { - o.ShutdownTimeout = TimeSpan.FromSeconds(1); + o.ShutdownTimeout = TimeSpan.FromSeconds(5); } }); }