Skip to content

Commit

Permalink
Fix Kestrel conn close on shutdown while response is writing (#40933)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Mar 31, 2022
1 parent cef52ff commit 4c5a4db
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 72 deletions.
102 changes: 56 additions & 46 deletions src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,19 @@ public ValueTask<FlushResult> Write100ContinueAsync(int streamId)
| Padding (*) ...
+---------------------------------------------------------------+
*/
public void WriteResponseHeaders(int streamId, int statusCode, Http2HeadersFrameFlags headerFrameFlags, HttpResponseHeaders headers)
public void WriteResponseHeaders(Http2Stream stream, int statusCode, bool endStream, HttpResponseHeaders headers)
{
Http2HeadersFrameFlags headerFrameFlags;
if (endStream)
{
headerFrameFlags = Http2HeadersFrameFlags.END_STREAM;
stream.DecrementActiveClientStreamCount();
}
else
{
headerFrameFlags = Http2HeadersFrameFlags.NONE;
}

lock (_writeLock)
{
if (_completed)
Expand All @@ -184,24 +195,26 @@ public void WriteResponseHeaders(int streamId, int statusCode, Http2HeadersFrame
try
{
_headersEnumerator.Initialize(headers);
_outgoingFrame.PrepareHeaders(headerFrameFlags, streamId);
_outgoingFrame.PrepareHeaders(headerFrameFlags, stream.StreamId);
var buffer = _headerEncodingBuffer.AsSpan();
var done = HPackHeaderWriter.BeginEncodeHeaders(statusCode, _hpackEncoder, _headersEnumerator, buffer, out var payloadLength);
FinishWritingHeaders(streamId, payloadLength, done);
FinishWritingHeaders(stream.StreamId, payloadLength, done);
}
// Any exception from the HPack encoder can leave the dynamic table in a corrupt state.
// Since we allow custom header encoders we don't know what type of exceptions to expect.
catch (Exception ex)
{
_log.HPackEncodingError(_connectionId, streamId, ex);
_log.HPackEncodingError(_connectionId, stream.StreamId, ex);
_http2Connection.Abort(new ConnectionAbortedException(ex.Message, ex));
throw new InvalidOperationException(ex.Message, ex); // Report the error to the user if this was the first write.
}
}
}

public ValueTask<FlushResult> WriteResponseTrailersAsync(int streamId, HttpResponseTrailers headers)
public ValueTask<FlushResult> WriteResponseTrailersAsync(Http2Stream stream, HttpResponseTrailers headers)
{
stream.DecrementActiveClientStreamCount();

lock (_writeLock)
{
if (_completed)
Expand All @@ -212,16 +225,16 @@ public ValueTask<FlushResult> WriteResponseTrailersAsync(int streamId, HttpRespo
try
{
_headersEnumerator.Initialize(headers);
_outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_STREAM, streamId);
_outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_STREAM, stream.StreamId);
var buffer = _headerEncodingBuffer.AsSpan();
var done = HPackHeaderWriter.BeginEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out var payloadLength);
FinishWritingHeaders(streamId, payloadLength, done);
FinishWritingHeaders(stream.StreamId, payloadLength, done);
}
// Any exception from the HPack encoder can leave the dynamic table in a corrupt state.
// Since we allow custom header encoders we don't know what type of exceptions to expect.
catch (Exception ex)
{
_log.HPackEncodingError(_connectionId, streamId, ex);
_log.HPackEncodingError(_connectionId, stream.StreamId, ex);
_http2Connection.Abort(new ConnectionAbortedException(ex.Message, ex));
}

Expand Down Expand Up @@ -258,7 +271,7 @@ private void FinishWritingHeaders(int streamId, int payloadLength, bool done)
}
}

public ValueTask<FlushResult> WriteDataAsync(int streamId, StreamOutputFlowControl flowControl, in ReadOnlySequence<byte> data, bool endStream, bool firstWrite, bool forceFlush)
public ValueTask<FlushResult> WriteDataAsync(Http2Stream stream, StreamOutputFlowControl flowControl, in ReadOnlySequence<byte> data, bool endStream, bool firstWrite, bool forceFlush)
{
// Logic in this method is replicated in WriteDataAndTrailersAsync.
// Changes here may need to be mirrored in WriteDataAndTrailersAsync.
Expand All @@ -277,12 +290,12 @@ public ValueTask<FlushResult> WriteDataAsync(int streamId, StreamOutputFlowContr
// https://httpwg.org/specs/rfc7540.html#rfc.section.6.9.1
if (dataLength != 0 && dataLength > flowControl.Available)
{
return WriteDataAsync(streamId, flowControl, data, dataLength, endStream, firstWrite);
return WriteDataAsync(stream, flowControl, data, dataLength, endStream, firstWrite);
}

// This cast is safe since if dataLength would overflow an int, it's guaranteed to be greater than the available flow control window.
flowControl.Advance((int)dataLength);
WriteDataUnsynchronized(streamId, data, dataLength, endStream);
WriteDataUnsynchronized(stream, data, dataLength, endStream);

if (forceFlush)
{
Expand All @@ -293,7 +306,7 @@ public ValueTask<FlushResult> WriteDataAsync(int streamId, StreamOutputFlowContr
}
}

public ValueTask<FlushResult> WriteDataAndTrailersAsync(int streamId, StreamOutputFlowControl flowControl, in ReadOnlySequence<byte> data, bool firstWrite, HttpResponseTrailers headers)
public ValueTask<FlushResult> WriteDataAndTrailersAsync(Http2Stream stream, StreamOutputFlowControl flowControl, in ReadOnlySequence<byte> data, bool firstWrite, HttpResponseTrailers headers)
{
// This method combines WriteDataAsync and WriteResponseTrailers.
// Changes here may need to be mirrored in WriteDataAsync.
Expand All @@ -312,21 +325,21 @@ public ValueTask<FlushResult> WriteDataAndTrailersAsync(int streamId, StreamOutp
// https://httpwg.org/specs/rfc7540.html#rfc.section.6.9.1
if (dataLength != 0 && dataLength > flowControl.Available)
{
return WriteDataAndTrailersAsyncCore(this, streamId, flowControl, data, dataLength, firstWrite, headers);
return WriteDataAndTrailersAsyncCore(this, stream, flowControl, data, dataLength, firstWrite, headers);
}

// This cast is safe since if dataLength would overflow an int, it's guaranteed to be greater than the available flow control window.
flowControl.Advance((int)dataLength);
WriteDataUnsynchronized(streamId, data, dataLength, endStream: false);
WriteDataUnsynchronized(stream, data, dataLength, endStream: false);

return WriteResponseTrailersAsync(streamId, headers);
return WriteResponseTrailersAsync(stream, headers);
}

static async ValueTask<FlushResult> WriteDataAndTrailersAsyncCore(Http2FrameWriter writer, int streamId, StreamOutputFlowControl flowControl, ReadOnlySequence<byte> data, long dataLength, bool firstWrite, HttpResponseTrailers headers)
static async ValueTask<FlushResult> WriteDataAndTrailersAsyncCore(Http2FrameWriter writer, Http2Stream stream, StreamOutputFlowControl flowControl, ReadOnlySequence<byte> data, long dataLength, bool firstWrite, HttpResponseTrailers headers)
{
await writer.WriteDataAsync(streamId, flowControl, data, dataLength, endStream: false, firstWrite);
await writer.WriteDataAsync(stream, flowControl, data, dataLength, endStream: false, firstWrite);

return await writer.WriteResponseTrailersAsync(streamId, headers);
return await writer.WriteResponseTrailersAsync(stream, headers);
}
}

Expand All @@ -339,29 +352,20 @@ static async ValueTask<FlushResult> WriteDataAndTrailersAsyncCore(Http2FrameWrit
| Padding (*) ...
+---------------------------------------------------------------+
*/
private void WriteDataUnsynchronized(int streamId, in ReadOnlySequence<byte> data, long dataLength, bool endStream)
private void WriteDataUnsynchronized(Http2Stream stream, in ReadOnlySequence<byte> data, long dataLength, bool endStream)
{
Debug.Assert(dataLength == data.Length);

// Note padding is not implemented
_outgoingFrame.PrepareData(streamId);
_outgoingFrame.PrepareData(stream.StreamId);

if (dataLength > _maxFrameSize) // Minus padding
{
TrimAndWriteDataUnsynchronized(in data, dataLength, endStream);
return;
}

if (endStream)
{
_outgoingFrame.DataFlags |= Http2DataFrameFlags.END_STREAM;
}

_outgoingFrame.PayloadLength = (int)dataLength; // Plus padding

WriteHeaderUnsynchronized();

data.CopyTo(_outputWriter);
WriteDataUnsynchronizedCore(stream, endStream, dataLength, data);

// Plus padding
return;
Expand All @@ -378,40 +382,46 @@ void TrimAndWriteDataUnsynchronized(in ReadOnlySequence<byte> data, long dataLen
do
{
var currentData = remainingData.Slice(0, dataPayloadLength);
_outgoingFrame.PayloadLength = dataPayloadLength; // Plus padding

WriteHeaderUnsynchronized();

foreach (var buffer in currentData)
{
_outputWriter.Write(buffer.Span);
}
WriteDataUnsynchronizedCore(stream, endStream: false, dataPayloadLength, currentData);

// Plus padding
dataLength -= dataPayloadLength;
remainingData = remainingData.Slice(dataPayloadLength);

} while (dataLength > dataPayloadLength);

WriteDataUnsynchronizedCore(stream, endStream, dataLength, remainingData);

// Plus padding
}

void WriteDataUnsynchronizedCore(Http2Stream stream, bool endStream, long dataLength, in ReadOnlySequence<byte> data)
{
Debug.Assert(dataLength == data.Length);

if (endStream)
{
_outgoingFrame.DataFlags |= Http2DataFrameFlags.END_STREAM;

// When writing data, must decrement active stream count after flow control availability is checked.
// If active stream count becomes zero while a graceful shutdown is in progress then the input side of connection is closed.
// This is a problem if a large amount of data is being written. The server must keep processing incoming WINDOW_UPDATE frames.
// No WINDOW_UPDATE frames means response write could hit flow control and hang.
// Decrement also has to happen before writing END_STREAM to client to avoid race over active stream count.
stream.DecrementActiveClientStreamCount();
}

// It can be expensive to get length from ROS. Use already available value.
_outgoingFrame.PayloadLength = (int)dataLength; // Plus padding

WriteHeaderUnsynchronized();

foreach (var buffer in remainingData)
{
_outputWriter.Write(buffer.Span);
}

// Plus padding
data.CopyTo(_outputWriter);
}
}

private async ValueTask<FlushResult> WriteDataAsync(int streamId, StreamOutputFlowControl flowControl, ReadOnlySequence<byte> data, long dataLength, bool endStream, bool firstWrite)
private async ValueTask<FlushResult> WriteDataAsync(Http2Stream stream, StreamOutputFlowControl flowControl, ReadOnlySequence<byte> data, long dataLength, bool endStream, bool firstWrite)
{
FlushResult flushResult = default;

Expand All @@ -436,13 +446,13 @@ private async ValueTask<FlushResult> WriteDataAsync(int streamId, StreamOutputFl
{
if (actual < dataLength)
{
WriteDataUnsynchronized(streamId, data.Slice(0, actual), actual, endStream: false);
WriteDataUnsynchronized(stream, data.Slice(0, actual), actual, endStream: false);
data = data.Slice(actual);
dataLength -= actual;
}
else
{
WriteDataUnsynchronized(streamId, data, actual, endStream);
WriteDataUnsynchronized(stream, data, actual, endStream);
dataLength = 0;
}

Expand Down
30 changes: 7 additions & 23 deletions src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,11 @@ public void WriteResponseHeaders(int statusCode, string? reasonPhrase, HttpRespo
// The headers will be the final frame if:
// 1. There is no content
// 2. There is no trailing HEADERS frame.
Http2HeadersFrameFlags http2HeadersFrame;
_streamEnded = appCompleted
&& !_startedWritingDataFrames
&& (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0);

if (appCompleted && !_startedWritingDataFrames && (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0))
{
_streamEnded = true;
_stream.DecrementActiveClientStreamCount();
http2HeadersFrame = Http2HeadersFrameFlags.END_STREAM;
}
else
{
http2HeadersFrame = Http2HeadersFrameFlags.NONE;
}

_frameWriter.WriteResponseHeaders(StreamId, statusCode, http2HeadersFrame, responseHeaders);
_frameWriter.WriteResponseHeaders(_stream, statusCode, _streamEnded, responseHeaders);
}
}

Expand Down Expand Up @@ -429,16 +420,15 @@ private async Task ProcessDataWrites()
// Write any remaining content then write trailers

_stream.ResponseTrailers.SetReadOnly();
_stream.DecrementActiveClientStreamCount();

if (readResult.Buffer.Length > 0)
{
// It is faster to write data and trailers together. Locking once reduces lock contention.
flushResult = await _frameWriter.WriteDataAndTrailersAsync(StreamId, _flowControl, readResult.Buffer, firstWrite, _stream.ResponseTrailers);
flushResult = await _frameWriter.WriteDataAndTrailersAsync(_stream, _flowControl, readResult.Buffer, firstWrite, _stream.ResponseTrailers);
}
else
{
flushResult = await _frameWriter.WriteResponseTrailersAsync(StreamId, _stream.ResponseTrailers);
flushResult = await _frameWriter.WriteResponseTrailersAsync(_stream, _stream.ResponseTrailers);
}
}
else if (readResult.IsCompleted && _streamEnded)
Expand All @@ -454,13 +444,7 @@ private async Task ProcessDataWrites()
else
{
var endStream = readResult.IsCompleted;

if (endStream)
{
_stream.DecrementActiveClientStreamCount();
}

flushResult = await _frameWriter.WriteDataAsync(StreamId, _flowControl, readResult.Buffer, endStream, firstWrite, forceFlush: true);
flushResult = await _frameWriter.WriteDataAsync(_stream, _flowControl, readResult.Buffer, endStream, firstWrite, forceFlush: true);
}

firstWrite = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class Http2FrameWriterBenchmark
private Pipe _pipe;
private Http2FrameWriter _frameWriter;
private HttpResponseHeaders _responseHeaders;
private Http2Stream _stream;

[GlobalSetup]
public void GlobalSetup()
Expand All @@ -45,6 +46,8 @@ public void GlobalSetup()
_memoryPool,
serviceContext);

_stream = new MockHttp2Stream(TestContextFactory.CreateHttp2StreamContext(streamId: 0));

_responseHeaders = new HttpResponseHeaders();
var headers = (IHeaderDictionary)_responseHeaders;
headers.ContentType = "application/json";
Expand All @@ -54,7 +57,7 @@ public void GlobalSetup()
[Benchmark]
public void WriteResponseHeaders()
{
_frameWriter.WriteResponseHeaders(0, 200, Http2HeadersFrameFlags.END_HEADERS, _responseHeaders);
_frameWriter.WriteResponseHeaders(_stream, 200, endStream: true, _responseHeaders);
}

[GlobalCleanup]
Expand All @@ -63,4 +66,16 @@ public void Dispose()
_pipe.Writer.Complete();
_memoryPool?.Dispose();
}

private class MockHttp2Stream : Http2Stream
{
public MockHttp2Stream(Http2StreamContext context)
{
Initialize(context);
}

public override void Execute()
{
}
}
}
13 changes: 12 additions & 1 deletion src/Servers/Kestrel/shared/test/TestContextFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public static Http2StreamContext CreateHttp2StreamContext(
localEndPoint: localEndPoint,
remoteEndPoint: remoteEndPoint,
streamId: streamId ?? 0,
streamLifetimeHandler: streamLifetimeHandler,
streamLifetimeHandler: streamLifetimeHandler ?? new TestHttp2StreamLifetimeHandler(),
clientPeerSettings: clientPeerSettings ?? new Http2PeerSettings(),
serverPeerSettings: serverPeerSettings ?? new Http2PeerSettings(),
frameWriter: frameWriter,
Expand Down Expand Up @@ -201,6 +201,17 @@ public static Http3StreamContext CreateHttp3StreamContext(
return context;
}

private class TestHttp2StreamLifetimeHandler : IHttp2StreamLifetimeHandler
{
public void DecrementActiveClientStreamCount()
{
}

public void OnStreamCompleted(Http2Stream stream)
{
}
}

private class TestMultiplexedConnectionContext : MultiplexedConnectionContext
{
public override string ConnectionId { get; set; }
Expand Down
Loading

0 comments on commit 4c5a4db

Please sign in to comment.