Skip to content

Commit

Permalink
Improve stream handling and disposal in transport layer
Browse files Browse the repository at this point in the history
  • Loading branch information
stevejgordon committed Oct 28, 2024
1 parent 2300050 commit a9e4d9e
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 98 deletions.
93 changes: 49 additions & 44 deletions src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,65 +224,70 @@ private async ValueTask<TResponse> SetBodyCoreAsync<TResponse>(bool isAsync,
details.ResponseBodyInBytes = bytes;
}

if (SetSpecialTypes<TResponse>(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r;
var isStreamResponse = typeof(TResponse) == typeof(StreamResponse);

if (details.HttpStatusCode.HasValue &&
requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value))
return null;
using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null)
{
if (SetSpecialTypes<TResponse>(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r;

var serializer = requestData.ConnectionSettings.RequestResponseSerializer;
if (details.HttpStatusCode.HasValue &&
requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value))
return null;

TResponse response;
if (requestData.CustomResponseBuilder != null)
{
var beforeTicks = Stopwatch.GetTimestamp();
var serializer = requestData.ConnectionSettings.RequestResponseSerializer;

if (isAsync)
response = await requestData.CustomResponseBuilder
.DeserializeResponseAsync(serializer, details, responseStream, cancellationToken)
.ConfigureAwait(false) as TResponse;
else
response = requestData.CustomResponseBuilder
.DeserializeResponse(serializer, details, responseStream) as TResponse;
TResponse response;
if (requestData.CustomResponseBuilder != null)
{
var beforeTicks = Stopwatch.GetTimestamp();

var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000);
if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested)
Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs);
if (isAsync)
response = await requestData.CustomResponseBuilder
.DeserializeResponseAsync(serializer, details, responseStream, cancellationToken)
.ConfigureAwait(false) as TResponse;
else
response = requestData.CustomResponseBuilder
.DeserializeResponse(serializer, details, responseStream) as TResponse;

return response;
}
var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000);
if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested)
Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs);

// TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid!
// ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying.
try
{
if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError())
{
response = new TResponse();
SetErrorOnResponse(response, error);
return response;
}

if (!requestData.ValidateResponseContentType(mimeType))
return default;
// TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid!
// ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying.
try
{
if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError())
{
response = new TResponse();
SetErrorOnResponse(response, error);
return response;
}

var beforeTicks = Stopwatch.GetTimestamp();
if (!requestData.ValidateResponseContentType(mimeType))
return default;

if (isAsync)
response = await serializer.DeserializeAsync<TResponse>(responseStream, cancellationToken).ConfigureAwait(false);
else
response = serializer.Deserialize<TResponse>(responseStream);
var beforeTicks = Stopwatch.GetTimestamp();

var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000);
if (isAsync)
response = await serializer.DeserializeAsync<TResponse>(responseStream, cancellationToken).ConfigureAwait(false);
else
response = serializer.Deserialize<TResponse>(responseStream);

if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested)
Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs);
var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000);

return response;
}
catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens"))
{
return default;
if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested)
Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs);

return response;
}
catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens"))
{
return default;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ private async ValueTask<TResponse> RequestCoreAsync<TResponse>(bool isAsync, Req
var isStreamResponse = typeof(TResponse) == typeof(StreamResponse);

using (isStreamResponse ? DiagnosticSources.SingletonDisposable : receive)
using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null)
{
TResponse response;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,31 +162,26 @@ private async ValueTask<TResponse> RequestCoreAsync<TResponse>(bool isAsync, Req
unregisterWaitHandle?.Invoke();
}

var isStreamResponse = typeof(TResponse) == typeof(StreamResponse);
TResponse response;

using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null)
{
TResponse response;

if (isAsync)
response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync<TResponse>
(requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken)
.ConfigureAwait(false);
else
response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse<TResponse>
(requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats);
if (isAsync)
response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync<TResponse>
(requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken)
.ConfigureAwait(false);
else
response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse<TResponse>
(requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats);

if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false))
if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false))
{
var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails);
foreach (var attribute in attributes)
{
var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails);
foreach (var attribute in attributes)
{
Activity.Current?.SetTag(attribute.Key, attribute.Value);
}
Activity.Current?.SetTag(attribute.Key, attribute.Value);
}

return response;
}

return response;
}

private static Dictionary<string, IEnumerable<string>> ParseHeaders(RequestData requestData, HttpWebResponse responseMessage, Dictionary<string, IEnumerable<string>> responseHeaders)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,8 @@ public TResponse BuildResponse<TResponse>(RequestData requestData, byte[] respon
var sc = statusCode ?? _statusCode;
Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody);

var isStreamResponse = typeof(TResponse) == typeof(StreamResponse);

using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null)
{
return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder
.ToResponse<TResponse>(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null);
}
return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder
.ToResponse<TResponse>(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null);
}

/// <inheritdoc cref="BuildResponse{TResponse}"/>>
Expand Down Expand Up @@ -122,13 +117,8 @@ public async Task<TResponse> BuildResponseAsync<TResponse>(RequestData requestDa

Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody);

var isStreamResponse = typeof(TResponse) == typeof(StreamResponse);

using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null)
{
return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder
.ToResponseAsync<TResponse>(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken)
.ConfigureAwait(false);
}
return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder
.ToResponseAsync<TResponse>(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken)
.ConfigureAwait(false);
}
}
32 changes: 27 additions & 5 deletions src/Elastic.Transport/Responses/Special/StreamResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ namespace Elastic.Transport;
/// <summary>
/// A response that exposes the response <see cref="TransportResponse{T}.Body"/> as <see cref="Stream"/>.
/// <para>
/// Must be disposed after use.
/// <strong>MUST</strong> be disposed after use to ensure the HTTP connection is freed for reuse.
/// </para>
/// </summary>
public sealed class StreamResponse :
public class StreamResponse :
TransportResponse<Stream>,
IDisposable
{
private bool _disposed;

internal Action? Finalizer { get; set; }

/// <summary>
Expand All @@ -38,10 +40,30 @@ public StreamResponse(Stream body, string? mimeType)
MimeType = mimeType ?? string.Empty;
}

/// <inheritdoc cref="IDisposable.Dispose"/>
/// <summary>
/// Disposes the underlying stream.
/// </summary>
/// <param name="disposing"></param>
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
Body.Dispose();
Finalizer?.Invoke();
}

_disposed = true;
}
}

/// <summary>
/// Disposes the underlying stream.
/// </summary>
public void Dispose()
{
Body.Dispose();
Finalizer?.Invoke();
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information

using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Threading.Tasks;
using Elastic.Transport.IntegrationTests.Plumbing;
using Elastic.Transport.Products.Elasticsearch;
using FluentAssertions;
using Microsoft.AspNetCore.Mvc;
using Xunit;

Expand All @@ -25,8 +28,88 @@ public async Task StreamResponse_ShouldNotBeDisposed()

var response = await transport.PostAsync<StreamResponse>(Path, PostData.String("{}"));

var sr = new StreamReader(response.Body);
var responseString = sr.ReadToEndAsync();
// Ensure the stream is readable
using var sr = new StreamReader(response.Body);
_ = sr.ReadToEndAsync();
}

[Fact]
public async Task StreamResponse_MemoryStreamShouldNotBeDisposed()
{
var nodePool = new SingleNodePool(Server.Uri);
var memoryStreamFactory = new TrackMemoryStreamFactory();
var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient)))
.MemoryStreamFactory(memoryStreamFactory)
.DisableDirectStreaming(true);

var transport = new DistributedTransport(config);

_ = await transport.PostAsync<StreamResponse>(Path, PostData.String("{}"));

var memoryStream = memoryStreamFactory.Created.Last();

memoryStream.IsDisposed.Should().BeFalse();
}

[Fact]
public async Task StringResponse_MemoryStreamShouldBeDisposed()
{
var nodePool = new SingleNodePool(Server.Uri);
var memoryStreamFactory = new TrackMemoryStreamFactory();
var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient)))
.MemoryStreamFactory(memoryStreamFactory)
.DisableDirectStreaming(true);

var transport = new DistributedTransport(config);

_ = await transport.PostAsync<StringResponse>(Path, PostData.String("{}"));

var memoryStream = memoryStreamFactory.Created.Last();

memoryStream.IsDisposed.Should().BeTrue();
}

private class TrackDisposeStream : MemoryStream
{
public TrackDisposeStream() { }

public TrackDisposeStream(byte[] bytes) : base(bytes) { }

public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { }

public bool IsDisposed { get; private set; }

protected override void Dispose(bool disposing)
{
IsDisposed = true;
base.Dispose(disposing);
}
}

private class TrackMemoryStreamFactory : MemoryStreamFactory
{
public IList<TrackDisposeStream> Created { get; } = [];

public override MemoryStream Create()
{
var stream = new TrackDisposeStream();
Created.Add(stream);
return stream;
}

public override MemoryStream Create(byte[] bytes)
{
var stream = new TrackDisposeStream(bytes);
Created.Add(stream);
return stream;
}

public override MemoryStream Create(byte[] bytes, int index, int count)
{
var stream = new TrackDisposeStream(bytes, index, count);
Created.Add(stream);
return stream;
}
}
}

Expand Down
Loading

0 comments on commit a9e4d9e

Please sign in to comment.