Skip to content

Commit

Permalink
Replace RequestData with BoundConfiguration (#141)
Browse files Browse the repository at this point in the history
This PR involves a comprehensive refactor to replace the `RequestData`
class with a new class named `BoundConfiguration` across the codebase.
The `BoundConfiguration` class encapsulates configuration details and
implements the `IRequestConfiguration` interface, which avoids
introducing new overloads of the `Request` \ `RequestAsync` methods. We
type check in the transport, and when the provided
`IRequestConfiguration` is a `BoundConfiguration`, we use that without
rebinding.

Closes #138
  • Loading branch information
stevejgordon authored Nov 20, 2024
1 parent dbd958b commit 86a1ab5
Show file tree
Hide file tree
Showing 47 changed files with 550 additions and 514 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ public ExposingPipelineFactory(TConfiguration configuration)
private TConfiguration Configuration { get; }
public ITransport<TConfiguration> Transport { get; }

public override RequestPipeline Create(RequestData requestData) =>
new RequestPipeline(requestData);
public override RequestPipeline Create(BoundConfiguration boundConfiguration) => new(boundConfiguration);
}
#nullable restore
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ private void UpdateCluster(VirtualCluster cluster)
private bool IsPingRequest(Endpoint endpoint) => _productRegistration.IsPingRequest(endpoint);

/// <inheritdoc cref="IRequestInvoker.RequestAsync{TResponse}"/>>
public Task<TResponse> RequestAsync<TResponse>(Endpoint endpoint, RequestData requestData, PostData? postData, CancellationToken cancellationToken)
public Task<TResponse> RequestAsync<TResponse>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, CancellationToken cancellationToken)
where TResponse : TransportResponse, new() =>
Task.FromResult(Request<TResponse>(endpoint, requestData, postData));
Task.FromResult(Request<TResponse>(endpoint, boundConfiguration, postData));

/// <inheritdoc cref="IRequestInvoker.Request{TResponse}"/>>
public TResponse Request<TResponse>(Endpoint endpoint, RequestData requestData, PostData? postData)
public TResponse Request<TResponse>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData)
where TResponse : TransportResponse, new()
{
if (!_calls.ContainsKey(endpoint.Uri.Port))
Expand All @@ -138,11 +138,11 @@ public TResponse Request<TResponse>(Endpoint endpoint, RequestData requestData,
_ = Interlocked.Increment(ref state.Sniffed);
return HandleRules<TResponse, ISniffRule>(
endpoint,
requestData,
boundConfiguration,
postData,
nameof(VirtualCluster.Sniff),
_cluster.SniffingRules,
requestData.RequestTimeout,
boundConfiguration.RequestTimeout,
(r) => UpdateCluster(r.NewClusterState),
(r) => _productRegistration.CreateSniffResponseBytes(_cluster.Nodes, _cluster.ElasticsearchVersion, _cluster.PublishAddressOverride, _cluster.SniffShouldReturnFqnd)
);
Expand All @@ -152,36 +152,36 @@ public TResponse Request<TResponse>(Endpoint endpoint, RequestData requestData,
_ = Interlocked.Increment(ref state.Pinged);
return HandleRules<TResponse, IRule>(
endpoint,
requestData,
boundConfiguration,
postData,
nameof(VirtualCluster.Ping),
_cluster.PingingRules,
requestData.PingTimeout,
boundConfiguration.PingTimeout,
(r) => { },
(r) => null //HEAD request
);
}
_ = Interlocked.Increment(ref state.Called);
return HandleRules<TResponse, IClientCallRule>(
endpoint,
requestData,
boundConfiguration,
postData,
nameof(VirtualCluster.ClientCalls),
_cluster.ClientCallRules,
requestData.RequestTimeout,
boundConfiguration.RequestTimeout,
(r) => { },
CallResponse
);
}
catch (TheException e)
{
return ResponseFactory.Create<TResponse>(endpoint, requestData, postData, e, null, null, Stream.Null, null, -1, null, null);
return ResponseFactory.Create<TResponse>(endpoint, boundConfiguration, postData, e, null, null, Stream.Null, null, -1, null, null);
}
}

private TResponse HandleRules<TResponse, TRule>(
Endpoint endpoint,
RequestData requestData,
BoundConfiguration boundConfiguration,
PostData? postData,
string origin,
IList<TRule> rules,
Expand All @@ -203,28 +203,28 @@ private TResponse HandleRules<TResponse, TRule>(
if (rule.OnPort == null || rule.OnPort.Value != endpoint.Uri.Port) continue;

if (always)
return Always<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Always<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);

if (rule.ExecuteCount > times) continue;

return Sometimes<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Sometimes<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);
}
foreach (var rule in rules.Where(s => !s.OnPort.HasValue))
{
var always = rule.Times.Match(t => true, t => false);
var times = rule.Times.Match(t => -1, t => t);
if (always)
return Always<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Always<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);

if (rule.ExecuteCount > times) continue;

return Sometimes<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Sometimes<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);
}
var count = _calls.Select(kv => kv.Value.Called).Sum();
throw new Exception($@"No global or port specific {origin} rule ({endpoint.Uri.Port}) matches any longer after {count} calls in to the cluster");
}

private TResponse Always<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
private TResponse Always<TResponse, TRule>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
)
where TResponse : TransportResponse, new()
where TRule : IRule
Expand All @@ -233,20 +233,20 @@ private TResponse Always<TResponse, TRule>(Endpoint endpoint, RequestData reques
{
var time = timeout < rule.Takes.Value ? timeout : rule.Takes.Value;
_dateTimeProvider.ChangeTime(d => d.Add(time));
if (rule.Takes.Value > requestData.RequestTimeout)
if (rule.Takes.Value > boundConfiguration.RequestTimeout)
{
throw new TheException(
$"Request timed out after {time} : call configured to take {rule.Takes.Value} while requestTimeout was: {timeout}");
}
}

return rule.Succeeds
? Success<TResponse, TRule>(endpoint, requestData, postData, beforeReturn, successResponse, rule)
: Fail<TResponse, TRule>(endpoint, requestData, postData, rule);
? Success<TResponse, TRule>(endpoint, boundConfiguration, postData, beforeReturn, successResponse, rule)
: Fail<TResponse, TRule>(endpoint, boundConfiguration, postData, rule);
}

private TResponse Sometimes<TResponse, TRule>(
Endpoint endpoint, RequestData requestData, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
)
where TResponse : TransportResponse, new()
where TRule : IRule
Expand All @@ -255,20 +255,20 @@ private TResponse Sometimes<TResponse, TRule>(
{
var time = timeout < rule.Takes.Value ? timeout : rule.Takes.Value;
_dateTimeProvider.ChangeTime(d => d.Add(time));
if (rule.Takes.Value > requestData.RequestTimeout)
if (rule.Takes.Value > boundConfiguration.RequestTimeout)
{
throw new TheException(
$"Request timed out after {time} : call configured to take {rule.Takes.Value} while requestTimeout was: {timeout}");
}
}

if (rule.Succeeds)
return Success<TResponse, TRule>(endpoint, requestData, postData, beforeReturn, successResponse, rule);
return Success<TResponse, TRule>(endpoint, boundConfiguration, postData, beforeReturn, successResponse, rule);

return Fail<TResponse, TRule>(endpoint, requestData, postData, rule);
return Fail<TResponse, TRule>(endpoint, boundConfiguration, postData, rule);
}

private TResponse Fail<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, TRule rule, RuleOption<Exception, int>? returnOverride = null)
private TResponse Fail<TResponse, TRule>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, TRule rule, RuleOption<Exception, int>? returnOverride = null)
where TResponse : TransportResponse, new()
where TRule : IRule
{
Expand All @@ -282,13 +282,13 @@ private TResponse Fail<TResponse, TRule>(Endpoint endpoint, RequestData requestD

return ret.Match(
e => throw e,
statusCode => _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, requestData, postData, CallResponse(rule),
statusCode => _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, boundConfiguration, postData, CallResponse(rule),
//make sure we never return a valid status code in Fail responses because of a bad rule.
statusCode >= 200 && statusCode < 300 ? 502 : statusCode, rule.ReturnContentType)
);
}

private TResponse Success<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse,
private TResponse Success<TResponse, TRule>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse,
TRule rule
)
where TResponse : TransportResponse, new()
Expand All @@ -299,7 +299,7 @@ TRule rule
rule.RecordExecuted();

beforeReturn?.Invoke(rule);
return _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, requestData, postData, successResponse(rule), contentType: rule.ReturnContentType);
return _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, boundConfiguration, postData, successResponse(rule), contentType: rule.ReturnContentType);
}

private static byte[] CallResponse<TRule>(TRule rule)
Expand Down
4 changes: 2 additions & 2 deletions src/Elastic.Transport.VirtualizedCluster/Rules/RuleBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ public TRule ReturnResponse<T>(T response)
r = ms.ToArray();
}
Self.ReturnResponse = r;
Self.ReturnContentType = RequestData.DefaultContentType;
Self.ReturnContentType = BoundConfiguration.DefaultContentType;
return (TRule)this;
}

public TRule ReturnByteResponse(byte[] response, string responseContentType = RequestData.DefaultContentType)
public TRule ReturnByteResponse(byte[] response, string responseContentType = BoundConfiguration.DefaultContentType)
{
Self.ReturnResponse = response;
Self.ReturnContentType = responseContentType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
namespace Elastic.Transport;

/// <summary>
/// Where and how <see cref="IRequestInvoker.Request{TResponse}" /> should connect to.
/// <para>
/// Represents the cumulative configuration from <see cref="ITransportConfiguration" />
/// and <see cref="IRequestConfiguration" />.
/// </para>
/// </summary>
public sealed record RequestData
public sealed record BoundConfiguration : IRequestConfiguration
{
private const string OpaqueIdHeader = "X-Opaque-Id";

Expand All @@ -27,8 +24,8 @@ public sealed record RequestData
/// The security header used to run requests as a different user.
public const string RunAsSecurityHeader = "es-security-runas-user";

/// <inheritdoc cref="RequestData"/>
public RequestData(ITransportConfiguration global, IRequestConfiguration? local = null)
/// <inheritdoc cref="BoundConfiguration"/>
public BoundConfiguration(ITransportConfiguration global, IRequestConfiguration? local = null)
{
ConnectionSettings = global;
MemoryStreamFactory = global.MemoryStreamFactory;
Expand All @@ -55,7 +52,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
Accept = local?.Accept ?? global.Accept ?? DefaultContentType;
ThrowExceptions = local?.ThrowExceptions ?? global.ThrowExceptions ?? false;
RequestTimeout = local?.RequestTimeout ?? global.RequestTimeout ?? RequestConfiguration.DefaultRequestTimeout;
RequestMetaData = local?.RequestMetaData?.Items ?? EmptyReadOnly<string, string>.Dictionary;
RequestMetaData = local?.RequestMetaData;
AuthenticationHeader = local?.Authentication ?? global.Authentication;
AllowedStatusCodes = local?.AllowedStatusCodes ?? EmptyReadOnly<int>.Collection;
ClientCertificates = local?.ClientCertificates ?? global.ClientCertificates;
Expand All @@ -81,6 +78,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
Headers[key] = local.Headers[key];
}

OpaqueId = local?.OpaqueId;
if (!string.IsNullOrEmpty(local?.OpaqueId))
{
Headers ??= [];
Expand Down Expand Up @@ -115,6 +113,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
}

ProductResponseBuilders = global.ProductRegistration.ResponseBuilders;
DisableAuditTrail = local?.DisableAuditTrail ?? global.DisableAuditTrail ?? false;
}

/// <inheritdoc cref="ITransportConfiguration.MemoryStreamFactory"/>
Expand All @@ -140,7 +139,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
/// <inheritdoc cref="ITransportConfiguration.DnsRefreshTimeout"/>
public TimeSpan DnsRefreshTimeout { get; }
/// <inheritdoc cref="IRequestConfiguration.RequestMetaData"/>
public IReadOnlyDictionary<string, string> RequestMetaData { get; }
public RequestMetaData? RequestMetaData { get; }
/// <inheritdoc cref="IRequestConfiguration.Accept"/>
public string Accept { get; }
/// <inheritdoc cref="IRequestConfiguration.AllowedStatusCodes"/>
Expand Down Expand Up @@ -191,4 +190,45 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
public IReadOnlyCollection<IResponseBuilder> ProductResponseBuilders { get; }
/// <inheritdoc cref="IRequestConfiguration.ResponseBuilders"/>
public IReadOnlyCollection<IResponseBuilder> ResponseBuilders { get; }
/// <inheritdoc cref="IRequestConfiguration.DisableAuditTrail"/>
public bool DisableAuditTrail { get; }
/// <inheritdoc cref="IRequestConfiguration.OpaqueId"/>
public string? OpaqueId { get; }

string? IRequestConfiguration.Accept => Accept;
IReadOnlyCollection<int>? IRequestConfiguration.AllowedStatusCodes => AllowedStatusCodes;
AuthorizationHeader? IRequestConfiguration.Authentication => AuthenticationHeader;
X509CertificateCollection? IRequestConfiguration.ClientCertificates => ClientCertificates;
string? IRequestConfiguration.ContentType => ContentType;
bool? IRequestConfiguration.DisableDirectStreaming => DisableDirectStreaming;
bool? IRequestConfiguration.DisableAuditTrail => DisableAuditTrail;
bool? IRequestConfiguration.DisablePings => DisablePings;
bool? IRequestConfiguration.DisableSniff => DisableSniff;
bool? IRequestConfiguration.HttpPipeliningEnabled => HttpPipeliningEnabled;
bool? IRequestConfiguration.EnableHttpCompression => HttpCompression;
Uri? IRequestConfiguration.ForceNode => ForceNode;
int? IRequestConfiguration.MaxRetries => MaxRetries;
TimeSpan? IRequestConfiguration.MaxRetryTimeout => RequestTimeout;
string? IRequestConfiguration.OpaqueId => OpaqueId;
bool? IRequestConfiguration.ParseAllHeaders => ParseAllHeaders;
TimeSpan? IRequestConfiguration.PingTimeout => PingTimeout;
TimeSpan? IRequestConfiguration.RequestTimeout => RequestTimeout;
IReadOnlyCollection<IResponseBuilder> IRequestConfiguration.ResponseBuilders => ResponseBuilders;
HeadersList? IRequestConfiguration.ResponseHeadersToParse => ResponseHeadersToParse;
string? IRequestConfiguration.RunAs => RunAs;
bool? IRequestConfiguration.ThrowExceptions => ThrowExceptions;
bool? IRequestConfiguration.TransferEncodingChunked => TransferEncodingChunked;
NameValueCollection? IRequestConfiguration.Headers => Headers;
bool? IRequestConfiguration.EnableTcpStats => EnableTcpStats;
bool? IRequestConfiguration.EnableThreadPoolStats => EnableThreadPoolStats;
RequestMetaData? IRequestConfiguration.RequestMetaData => RequestMetaData;

/// <summary>
/// Create a cachable instance of <see cref="BoundConfiguration"/> for use in high-performance scenarios.
/// </summary>
/// <param name="transport">An existing <see cref="ITransport{TConfiguration}"/> from which to bind transport configuration.</param>
/// <param name="requestConfiguration">A request specific <see cref="IRequestConfiguration"/>.</param>
/// <returns></returns>
public static BoundConfiguration Create(ITransport<ITransportConfiguration> transport, IRequestConfiguration requestConfiguration) =>
new(transport.Configuration, requestConfiguration);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ internal sealed class DefaultResponseBuilder : IResponseBuilder
bool IResponseBuilder.CanBuild<TResponse>() => true;

/// <inheritdoc/>
public TResponse Build<TResponse>(ApiCallDetails apiCallDetails, RequestData requestData,
public TResponse Build<TResponse>(ApiCallDetails apiCallDetails, BoundConfiguration boundConfiguration,
Stream responseStream, string contentType, long contentLength)
where TResponse : TransportResponse, new() =>
SetBodyCoreAsync<TResponse>(false, apiCallDetails, requestData, responseStream).EnsureCompleted();
SetBodyCoreAsync<TResponse>(false, apiCallDetails, boundConfiguration, responseStream).EnsureCompleted();

/// <inheritdoc/>
public Task<TResponse> BuildAsync<TResponse>(
ApiCallDetails apiCallDetails, RequestData requestData, Stream responseStream, string contentType, long contentLength,
ApiCallDetails apiCallDetails, BoundConfiguration boundConfiguration, Stream responseStream, string contentType, long contentLength,
CancellationToken cancellationToken) where TResponse : TransportResponse, new() =>
SetBodyCoreAsync<TResponse>(true, apiCallDetails, requestData, responseStream, cancellationToken).AsTask();
SetBodyCoreAsync<TResponse>(true, apiCallDetails, boundConfiguration, responseStream, cancellationToken).AsTask();

private static async ValueTask<TResponse> SetBodyCoreAsync<TResponse>(bool isAsync,
ApiCallDetails details, RequestData requestData, Stream responseStream,
ApiCallDetails details, BoundConfiguration boundConfiguration, Stream responseStream,
CancellationToken cancellationToken = default)
where TResponse : TransportResponse, new()
{
TResponse response = null;

if (details.HttpStatusCode.HasValue &&
requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value))
boundConfiguration.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value))
{
return response;
}
Expand All @@ -51,9 +51,9 @@ private static async ValueTask<TResponse> SetBodyCoreAsync<TResponse>(bool isAsy
var beforeTicks = Stopwatch.GetTimestamp();

if (isAsync)
response = await requestData.ConnectionSettings.RequestResponseSerializer.DeserializeAsync<TResponse>(responseStream, cancellationToken).ConfigureAwait(false);
response = await boundConfiguration.ConnectionSettings.RequestResponseSerializer.DeserializeAsync<TResponse>(responseStream, cancellationToken).ConfigureAwait(false);
else
response = requestData.ConnectionSettings.RequestResponseSerializer.Deserialize<TResponse>(responseStream);
response = boundConfiguration.ConnectionSettings.RequestResponseSerializer.Deserialize<TResponse>(responseStream);

var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000);

Expand Down
Loading

0 comments on commit 86a1ab5

Please sign in to comment.