From a9e4d9ef61b3bff5ea952545d4967a3dbcc67391 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Mon, 28 Oct 2024 14:33:46 +0000 Subject: [PATCH] Improve stream handling and disposal in transport layer --- .../Pipeline/DefaultResponseBuilder.cs | 93 +++++++++--------- .../TransportClient/HttpRequestInvoker.cs | 1 - .../TransportClient/HttpWebRequestInvoker.cs | 35 +++---- .../TransportClient/InMemoryRequestInvoker.cs | 20 +--- .../Responses/Special/StreamResponse.cs | 32 ++++++- .../Http/StreamResponseTests.cs | 87 ++++++++++++++++- .../ResponseBuilderDisposeTests.cs | 96 ++++++++++++++++--- 7 files changed, 266 insertions(+), 98 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 414bbd5..86352e9 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -224,65 +224,70 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, details.ResponseBodyInBytes = bytes; } - if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - if (details.HttpStatusCode.HasValue && - requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) - return null; + using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) + { + if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; - var serializer = requestData.ConnectionSettings.RequestResponseSerializer; + if (details.HttpStatusCode.HasValue && + requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) + return null; - TResponse response; - if (requestData.CustomResponseBuilder != null) - { - var beforeTicks = Stopwatch.GetTimestamp(); + var serializer = requestData.ConnectionSettings.RequestResponseSerializer; - if (isAsync) - response = await requestData.CustomResponseBuilder - .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) - .ConfigureAwait(false) as TResponse; - else - response = requestData.CustomResponseBuilder - .DeserializeResponse(serializer, details, responseStream) as TResponse; + TResponse response; + if (requestData.CustomResponseBuilder != null) + { + var beforeTicks = Stopwatch.GetTimestamp(); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + if (isAsync) + response = await requestData.CustomResponseBuilder + .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) + .ConfigureAwait(false) as TResponse; + else + response = requestData.CustomResponseBuilder + .DeserializeResponse(serializer, details, responseStream) as TResponse; - return response; - } + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! - // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. - try - { - if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) - { - response = new TResponse(); - SetErrorOnResponse(response, error); return response; } - if (!requestData.ValidateResponseContentType(mimeType)) - return default; + // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! + // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. + try + { + if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) + { + response = new TResponse(); + SetErrorOnResponse(response, error); + return response; + } - var beforeTicks = Stopwatch.GetTimestamp(); + if (!requestData.ValidateResponseContentType(mimeType)) + return default; - if (isAsync) - response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); - else - response = serializer.Deserialize(responseStream); + var beforeTicks = Stopwatch.GetTimestamp(); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (isAsync) + response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); + else + response = serializer.Deserialize(responseStream); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - return response; - } - catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) - { - return default; + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + + return response; + } + catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) + { + return default; + } } } diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index a466b78..e58df63 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -157,7 +157,6 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); using (isStreamResponse ? DiagnosticSources.SingletonDisposable : receive) - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) { TResponse response; diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index fb138c9..4e5b859 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -162,31 +162,26 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req unregisterWaitHandle?.Invoke(); } - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + TResponse response; - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) - { - TResponse response; - - if (isAsync) - response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) - .ConfigureAwait(false); - else - response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); + if (isAsync) + response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) + .ConfigureAwait(false); + else + response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) + if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) + { + var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); + foreach (var attribute in attributes) { - var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); - foreach (var attribute in attributes) - { - Activity.Current?.SetTag(attribute.Key, attribute.Value); - } + Activity.Current?.SetTag(attribute.Key, attribute.Value); } - - return response; } + + return response; } private static Dictionary> ParseHeaders(RequestData requestData, HttpWebResponse responseMessage, Dictionary> responseHeaders) diff --git a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs index d3fa8fd..aead20a 100644 --- a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs @@ -85,13 +85,8 @@ public TResponse BuildResponse(RequestData requestData, byte[] respon var sc = statusCode ?? _statusCode; Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) - { - return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponse(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); - } + return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponse(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); } /// > @@ -122,13 +117,8 @@ public async Task BuildResponseAsync(RequestData requestDa Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) - { - return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) - .ConfigureAwait(false); - } + return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) + .ConfigureAwait(false); } } diff --git a/src/Elastic.Transport/Responses/Special/StreamResponse.cs b/src/Elastic.Transport/Responses/Special/StreamResponse.cs index 53dd22a..d920a06 100644 --- a/src/Elastic.Transport/Responses/Special/StreamResponse.cs +++ b/src/Elastic.Transport/Responses/Special/StreamResponse.cs @@ -10,13 +10,15 @@ namespace Elastic.Transport; /// /// A response that exposes the response as . /// -/// Must be disposed after use. +/// MUST be disposed after use to ensure the HTTP connection is freed for reuse. /// /// -public sealed class StreamResponse : +public class StreamResponse : TransportResponse, IDisposable { + private bool _disposed; + internal Action? Finalizer { get; set; } /// @@ -38,10 +40,30 @@ public StreamResponse(Stream body, string? mimeType) MimeType = mimeType ?? string.Empty; } - /// + /// + /// Disposes the underlying stream. + /// + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + Body.Dispose(); + Finalizer?.Invoke(); + } + + _disposed = true; + } + } + + /// + /// Disposes the underlying stream. + /// public void Dispose() { - Body.Dispose(); - Finalizer?.Invoke(); + Dispose(disposing: true); + GC.SuppressFinalize(this); } } diff --git a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs index 57333bd..e006567 100644 --- a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs +++ b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs @@ -2,11 +2,14 @@ // Elasticsearch B.V licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information +using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; using Elastic.Transport.IntegrationTests.Plumbing; using Elastic.Transport.Products.Elasticsearch; +using FluentAssertions; using Microsoft.AspNetCore.Mvc; using Xunit; @@ -25,8 +28,88 @@ public async Task StreamResponse_ShouldNotBeDisposed() var response = await transport.PostAsync(Path, PostData.String("{}")); - var sr = new StreamReader(response.Body); - var responseString = sr.ReadToEndAsync(); + // Ensure the stream is readable + using var sr = new StreamReader(response.Body); + _ = sr.ReadToEndAsync(); + } + + [Fact] + public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + _ = await transport.PostAsync(Path, PostData.String("{}")); + + var memoryStream = memoryStreamFactory.Created.Last(); + + memoryStream.IsDisposed.Should().BeFalse(); + } + + [Fact] + public async Task StringResponse_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + _ = await transport.PostAsync(Path, PostData.String("{}")); + + var memoryStream = memoryStreamFactory.Created.Last(); + + memoryStream.IsDisposed.Should().BeTrue(); + } + + private class TrackDisposeStream : MemoryStream + { + public TrackDisposeStream() { } + + public TrackDisposeStream(byte[] bytes) : base(bytes) { } + + public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } + + public bool IsDisposed { get; private set; } + + protected override void Dispose(bool disposing) + { + IsDisposed = true; + base.Dispose(disposing); + } + } + + private class TrackMemoryStreamFactory : MemoryStreamFactory + { + public IList Created { get; } = []; + + public override MemoryStream Create() + { + var stream = new TrackDisposeStream(); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes) + { + var stream = new TrackDisposeStream(bytes); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes, int index, int count) + { + var stream = new TrackDisposeStream(bytes, index, count); + Created.Add(stream); + return stream; + } } } diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index 4f0ca61..dade397 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information using System; +using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -15,42 +16,89 @@ namespace Elastic.Transport.Tests; public class ResponseBuilderDisposeTests { private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); + private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); [Fact] - public async Task ResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(expectedDisposed: false); + public async Task StreamResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(false, expectedDisposed: false); [Fact] - public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(204); + public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsNotDisposed() => await AssertResponse(true, expectedDisposed: false); [Fact] - public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(httpMethod: HttpMethod.HEAD); + public async Task StreamResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); [Fact] - public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(contentLength: 0); + public async Task StreamResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); - private async Task AssertResponse(int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true) + [Fact] + public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + + [Fact] + public async Task ResponseWithPotentialBody_StreamIsDisposed() => await AssertResponse(false, expectedDisposed: true); + + [Fact] + public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true); + + [Fact] + public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); + + [Fact] + public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); + + [Fact] + public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + + [Fact] + public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true, memoryStreamCreateExpected: 1); + + private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true, int memoryStreamCreateExpected = -1) + where T : TransportResponse, new() { - var settings = _settings; - var requestData = new RequestData(httpMethod, "/", null, settings, null, null, default) + var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; + var memoryStreamFactory = new TrackMemoryStreamFactory(); + + var requestData = new RequestData(httpMethod, "/", null, settings, null, memoryStreamFactory, default) { Node = new Node(new Uri("http://localhost:9200")) }; var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); + var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); response.Should().NotBeNull(); - stream.IsDisposed.Should().Be(expectedDisposed); + + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected : disableDirectStreaming ? 1 : 0); + if (disableDirectStreaming) + { + var memoryStream = memoryStreamFactory.Created[0]; + stream.IsDisposed.Should().BeTrue(); + memoryStream.IsDisposed.Should().Be(expectedDisposed); + } + else + { + stream.IsDisposed.Should().Be(expectedDisposed); + } stream = new TrackDisposeStream(); var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, + response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, cancellationToken: ct); response.Should().NotBeNull(); - stream.IsDisposed.Should().Be(expectedDisposed); + + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected + 1: disableDirectStreaming ? 2 : 0); + if (disableDirectStreaming) + { + var memoryStream = memoryStreamFactory.Created[0]; + stream.IsDisposed.Should().BeTrue(); + memoryStream.IsDisposed.Should().Be(expectedDisposed); + } + else + { + stream.IsDisposed.Should().Be(expectedDisposed); + } } private class TrackDisposeStream : MemoryStream @@ -69,4 +117,30 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } } + + private class TrackMemoryStreamFactory : MemoryStreamFactory + { + public IList Created { get; } = []; + + public override MemoryStream Create() + { + var stream = new TrackDisposeStream(); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes) + { + var stream = new TrackDisposeStream(bytes); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes, int index, int count) + { + var stream = new TrackDisposeStream(bytes, index, count); + Created.Add(stream); + return stream; + } + } }