diff --git a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs index 37f8eb0..d3fa8fd 100644 --- a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs @@ -66,24 +66,32 @@ public TResponse BuildResponse(RequestData requestData, byte[] respon { var body = responseBody ?? _responseBody; var data = requestData.PostData; - if (data != null) + + if (data is not null) { - using (var stream = requestData.MemoryStreamFactory.Create()) + using var stream = requestData.MemoryStreamFactory.Create(); + if (requestData.HttpCompression) + { + using var zipStream = new GZipStream(stream, CompressionMode.Compress); + data.Write(zipStream, requestData.ConnectionSettings); + } + else { - if (requestData.HttpCompression) - { - using var zipStream = new GZipStream(stream, CompressionMode.Compress); - data.Write(zipStream, requestData.ConnectionSettings); - } - else - data.Write(stream, requestData.ConnectionSettings); + data.Write(stream, requestData.ConnectionSettings); } } requestData.MadeItToResponse = true; var sc = statusCode ?? _statusCode; - Stream s = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse(requestData, _exception, sc, _headers, s, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); + Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); + + var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + + using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) + { + return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponse(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); + } } /// > @@ -93,17 +101,19 @@ public async Task BuildResponseAsync(RequestData requestDa { var body = responseBody ?? _responseBody; var data = requestData.PostData; - if (data != null) + + if (data is not null) { - using (var stream = requestData.MemoryStreamFactory.Create()) + using var stream = requestData.MemoryStreamFactory.Create(); + + if (requestData.HttpCompression) + { + using var zipStream = new GZipStream(stream, CompressionMode.Compress); + await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); + } + else { - if (requestData.HttpCompression) - { - using var zipStream = new GZipStream(stream, CompressionMode.Compress); - await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); - } - else - await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); + await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); } } requestData.MadeItToResponse = true; @@ -117,8 +127,8 @@ public async Task BuildResponseAsync(RequestData requestDa using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) { return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) - .ConfigureAwait(false); + .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) + .ConfigureAwait(false); } } } diff --git a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs new file mode 100644 index 0000000..57333bd --- /dev/null +++ b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs @@ -0,0 +1,38 @@ +// Licensed to Elasticsearch B.V under one or more agreements. +// Elasticsearch B.V licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information + +using System.IO; +using System.Text.Json; +using System.Threading.Tasks; +using Elastic.Transport.IntegrationTests.Plumbing; +using Elastic.Transport.Products.Elasticsearch; +using Microsoft.AspNetCore.Mvc; +using Xunit; + +namespace Elastic.Transport.IntegrationTests.Http; + +public class StreamResponseTests(TransportTestServer instance) : AssemblyServerTestsBase(instance) +{ + private const string Path = "/streamresponse"; + + [Fact] + public async Task StreamResponse_ShouldNotBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))); + var transport = new DistributedTransport(config); + + var response = await transport.PostAsync(Path, PostData.String("{}")); + + var sr = new StreamReader(response.Body); + var responseString = sr.ReadToEndAsync(); + } +} + +[ApiController, Route("[controller]")] +public class StreamResponseController : ControllerBase +{ + [HttpPost] + public Task Post([FromBody] JsonElement body) => Task.FromResult(body); +} diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index a64fdfe..4f0ca61 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -3,146 +3,70 @@ // See the LICENSE file in the project root for more information using System; -using System.Collections.Generic; using System.IO; -using System.Linq; using System.Threading; using System.Threading.Tasks; using Elastic.Transport.Tests.Plumbing; using FluentAssertions; using Xunit; -namespace Elastic.Transport.Tests +namespace Elastic.Transport.Tests; + +public class ResponseBuilderDisposeTests { - public class ResponseBuilderDisposeTests - { - private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); - private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); + private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); - [Fact] public async Task ResponseWithHttpStatusCode() => await AssertRegularResponse(false, 1); + [Fact] + public async Task ResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(expectedDisposed: false); - [Fact] public async Task ResponseBuilderWithNoHttpStatusCode() => await AssertRegularResponse(false); + [Fact] + public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(204); - [Fact] public async Task ResponseWithHttpStatusCodeDisableDirectStreaming() => - await AssertRegularResponse(true, 1); + [Fact] + public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(httpMethod: HttpMethod.HEAD); - [Fact] public async Task ResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() => - await AssertRegularResponse(true); + [Fact] + public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(contentLength: 0); - private async Task AssertRegularResponse(bool disableDirectStreaming, int? statusCode = null) + private async Task AssertResponse(int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true) + { + var settings = _settings; + var requestData = new RequestData(httpMethod, "/", null, settings, null, null, default) { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; - var memoryStreamFactory = new TrackMemoryStreamFactory(); - var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, memoryStreamFactory, default) - { - Node = new Node(new Uri("http://localhost:9200")) - }; - - var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, -1, null, null); - response.Should().NotBeNull(); - - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0); - if (disableDirectStreaming) - { - var memoryStream = memoryStreamFactory.Created[0]; - memoryStream.IsDisposed.Should().BeTrue(); - } - stream.IsDisposed.Should().BeTrue(); - - - stream = new TrackDisposeStream(); - var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, -1, null, null, - cancellationToken: ct); - response.Should().NotBeNull(); - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0); - if (disableDirectStreaming) - { - var memoryStream = memoryStreamFactory.Created[1]; - memoryStream.IsDisposed.Should().BeTrue(); - } - stream.IsDisposed.Should().BeTrue(); - } + Node = new Node(new Uri("http://localhost:9200")) + }; - [Fact] public async Task StreamResponseWithHttpStatusCode() => await AssertStreamResponse(false, 200); + var stream = new TrackDisposeStream(); - [Fact] public async Task StreamResponseBuilderWithNoHttpStatusCode() => await AssertStreamResponse(false); + var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); - [Fact] public async Task StreamResponseWithHttpStatusCodeDisableDirectStreaming() => - await AssertStreamResponse(true, 1); + response.Should().NotBeNull(); + stream.IsDisposed.Should().Be(expectedDisposed); - [Fact] public async Task StreamResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() => - await AssertStreamResponse(true); + stream = new TrackDisposeStream(); + var ct = new CancellationToken(); - private async Task AssertStreamResponse(bool disableDirectStreaming, int? statusCode = null) - { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; - var memoryStreamFactory = new TrackMemoryStreamFactory(); - - var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, memoryStreamFactory, default) - { - Node = new Node(new Uri("http://localhost:9200")) - }; - - var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, -1, null, null); - response.Should().NotBeNull(); - - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0); - stream.IsDisposed.Should().Be(true); - - stream = new TrackDisposeStream(); - var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, -1, null, null, - cancellationToken: ct); - response.Should().NotBeNull(); - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0); - stream.IsDisposed.Should().Be(true); - } + response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, + cancellationToken: ct); + response.Should().NotBeNull(); + stream.IsDisposed.Should().Be(expectedDisposed); + } - private class TrackDisposeStream : MemoryStream - { - public TrackDisposeStream() { } - - public TrackDisposeStream(byte[] bytes) : base(bytes) { } + private class TrackDisposeStream : MemoryStream + { + public TrackDisposeStream() { } - public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } + public TrackDisposeStream(byte[] bytes) : base(bytes) { } - public bool IsDisposed { get; private set; } + public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } - protected override void Dispose(bool disposing) - { - IsDisposed = true; - base.Dispose(disposing); - } - } + public bool IsDisposed { get; private set; } - private class TrackMemoryStreamFactory : MemoryStreamFactory + protected override void Dispose(bool disposing) { - public IList Created { get; } = new List(); - - public override MemoryStream Create() - { - var stream = new TrackDisposeStream(); - Created.Add(stream); - return stream; - } - - public override MemoryStream Create(byte[] bytes) - { - var stream = new TrackDisposeStream(bytes); - Created.Add(stream); - return stream; - } - - public override MemoryStream Create(byte[] bytes, int index, int count) - { - var stream = new TrackDisposeStream(bytes, index, count); - Created.Add(stream); - return stream; - } + IsDisposed = true; + base.Dispose(disposing); } } }