diff --git a/Simple.HttpClientFactory.Tests/ExceptionTranslatorTests.cs b/Simple.HttpClientFactory.Tests/ExceptionTranslatorTests.cs new file mode 100644 index 0000000..00d3e6e --- /dev/null +++ b/Simple.HttpClientFactory.Tests/ExceptionTranslatorTests.cs @@ -0,0 +1,110 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; +using Polly; +using Polly.Timeout; +using WireMock.RequestBuilders; +using WireMock.ResponseBuilders; +using WireMock.Server; +using Xunit; + +namespace Simple.HttpClientFactory.Tests +{ + public class ExceptionTranslatorTests + { + private readonly WireMockServer _server; + private readonly List _visitedMiddleware = new List(); + + public ExceptionTranslatorTests() + { + _server = WireMockServer.Start(); + _server.Given(Request.Create().WithPath("/hello/world").UsingAnyMethod()) + .RespondWith( + Response.Create() + .WithStatusCode(200) + .WithHeader("Content-Type", "text/plain") + .WithBody("Hello world!")); + + _server + .Given(Request.Create() + .WithPath("/timeout") + .UsingGet()) + .RespondWith(Response.Create() + .WithStatusCode(408)); + } + + public class TestException : Exception + { + public TestException(string message) : base(message) + { + } + } + + [Fact] + public async Task Exception_translator_can_translate_exception_types() + { + var clientWithRetry = HttpClientFactory.Create() + .WithMessageExceptionHandler(ex => true, ex => new TestException(ex.Message)) + .WithPolicy( + Policy + .Handle() + .OrResult(result => (int)result.StatusCode >= 500 || result.StatusCode == HttpStatusCode.RequestTimeout) + .WaitAndRetryAsync(3, retryAttempt => TimeSpan.FromSeconds(1))) + .WithPolicy(Policy.TimeoutAsync(TimeSpan.FromSeconds(4), TimeoutStrategy.Optimistic)) + .Build(); + + await Assert.ThrowsAsync(() => clientWithRetry.GetAsync(_server.Urls[0] + "/timeout")); + Assert.Equal(4, _server.LogEntries.Count()); + + } + + + [Fact] + public async Task Exception_translator_should_not_change_unhandled_exceptions() + { + var clientWithRetry = HttpClientFactory.Create() + .WithMessageExceptionHandler(ex => true, ex => ex) + .WithPolicy( + Policy + .Handle() + .OrResult(result => (int)result.StatusCode >= 500 || result.StatusCode == HttpStatusCode.RequestTimeout) + .WaitAndRetryAsync(3, retryAttempt => TimeSpan.FromSeconds(1))) + .WithPolicy(Policy.TimeoutAsync(TimeSpan.FromSeconds(4), TimeoutStrategy.Optimistic)) + .Build(); + + await Assert.ThrowsAsync(() => clientWithRetry.GetAsync(_server.Urls[0] + "/timeout")); + Assert.Equal(4, _server.LogEntries.Count()); + + } + + [Fact] + public async Task Exception_translator_without_errors_should_not_affect_anything() + { + var trafficRecorderMessageHandler = new TrafficRecorderMessageHandler(_visitedMiddleware); + var eventMessageHandler = new EventMessageHandler(_visitedMiddleware); + + var client = HttpClientFactory.Create() + .WithMessageExceptionHandler(ex => true, ex => ex) + .WithMessageHandler(eventMessageHandler) + .WithMessageHandler(trafficRecorderMessageHandler) + .Build(); + + var raisedEvent = await Assert.RaisesAsync( + h => eventMessageHandler.Request += h, + h => eventMessageHandler.Request -= h, + () => client.GetAsync(_server.Urls[0] + "/hello/world")); + + Assert.True(raisedEvent.Arguments.Request.Headers.Contains("foobar")); + Assert.Equal("foobar",raisedEvent.Arguments.Request.Headers.GetValues("foobar").FirstOrDefault()); + Assert.Single(trafficRecorderMessageHandler.Traffic); + + Assert.Equal(HttpStatusCode.OK, trafficRecorderMessageHandler.Traffic[0].Item2.StatusCode); + Assert.Equal(new [] { nameof(TrafficRecorderMessageHandler), nameof(EventMessageHandler) }, _visitedMiddleware); + } + + } +} diff --git a/Simple.HttpClientFactory.Tests/MiddlewareDelegateTests.cs b/Simple.HttpClientFactory.Tests/MiddlewareDelegateTests.cs index 1469955..c5682eb 100644 --- a/Simple.HttpClientFactory.Tests/MiddlewareDelegateTests.cs +++ b/Simple.HttpClientFactory.Tests/MiddlewareDelegateTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Net; +using System.Net.Http; using System.Threading.Tasks; using WireMock.RequestBuilders; using WireMock.ResponseBuilders; @@ -26,6 +27,7 @@ public MiddlewareDelegateTests() .WithBody("Hello world!")); } + [Fact] public async Task Single_middleware_handler_should_work() { diff --git a/Simple.HttpClientFactory/HttpClientBuilder.cs b/Simple.HttpClientFactory/HttpClientBuilder.cs index 45de2e2..5aa0bd8 100644 --- a/Simple.HttpClientFactory/HttpClientBuilder.cs +++ b/Simple.HttpClientFactory/HttpClientBuilder.cs @@ -71,7 +71,7 @@ public IHttpClientBuilder WithTimeout(in TimeSpan timeout) public IHttpClientBuilder WithMessageExceptionHandler( Func exceptionHandlingPredicate, Func exceptionHandler) => - WithMessageHandler(new ExceptionTranslatorRequestMiddleware(exceptionHandlingPredicate, exceptionHandler, null)); + WithMessageHandler(new ExceptionTranslatorRequestMiddleware(exceptionHandlingPredicate, exceptionHandler)); /// is public IHttpClientBuilder WithMessageHandler(DelegatingHandler handler) diff --git a/Simple.HttpClientFactory/MessageHandlers/ExceptionTranslatorRequestMiddleware.cs b/Simple.HttpClientFactory/MessageHandlers/ExceptionTranslatorRequestMiddleware.cs index 0907a2e..4182da0 100644 --- a/Simple.HttpClientFactory/MessageHandlers/ExceptionTranslatorRequestMiddleware.cs +++ b/Simple.HttpClientFactory/MessageHandlers/ExceptionTranslatorRequestMiddleware.cs @@ -13,6 +13,15 @@ public class ExceptionTranslatorRequestMiddleware : DelegatingHandler public event EventHandler RequestException; public event EventHandler TransformedRequestException; + public ExceptionTranslatorRequestMiddleware( + Func exceptionHandlingPredicate, + Func exceptionHandler) + { + _exceptionHandlingPredicate = exceptionHandlingPredicate ?? throw new ArgumentNullException(nameof(exceptionHandlingPredicate)); + _exceptionHandler = exceptionHandler ?? throw new ArgumentNullException(nameof(exceptionHandler)); + } + + public ExceptionTranslatorRequestMiddleware( Func exceptionHandlingPredicate, Func exceptionHandler, DelegatingHandler handler) : base(handler) diff --git a/Simple.HttpClientFactory/MessageHandlers/PollyHttpMessageHandler.cs b/Simple.HttpClientFactory/MessageHandlers/PollyHttpMessageHandler.cs index 15a0c9d..0ab93c6 100644 --- a/Simple.HttpClientFactory/MessageHandlers/PollyHttpMessageHandler.cs +++ b/Simple.HttpClientFactory/MessageHandlers/PollyHttpMessageHandler.cs @@ -45,22 +45,29 @@ private Task SendAsyncInternal(HttpRequestMessage request, var cleanUpContext = false; var context = GetOrCreatePolicyExecutionContext(request, ref cleanUpContext); - return _policy.ExecuteAsync( - async (c, ct) => await base.SendAsync(request, cancellationToken), context, cancellationToken) - .ContinueWith(t => - { - if(cleanUpContext) - request.SetPolicyExecutionContext(null); - return t.Result; - }, cancellationToken); + //do not await for the task so the async state machine won't grow big + var responseTask = + _policy.ExecuteAsync( + async (c, ct) => + await base.SendAsync(request, cancellationToken), context, cancellationToken) + .ContinueWith(t => + { + if(cleanUpContext) + request.SetPolicyExecutionContext(null); + return t.Result; + }, cancellationToken); + + responseTask.ConfigureAwait(false); + + return responseTask; - Context GetOrCreatePolicyExecutionContext(HttpRequestMessage httpRequestMessage, ref bool b) + Context GetOrCreatePolicyExecutionContext(HttpRequestMessage httpRequestMessage, ref bool shouldCleanupContext) { if (!httpRequestMessage.TryGetPolicyExecutionContext(out var fetchedContext)) { fetchedContext = new Context(); httpRequestMessage.SetPolicyExecutionContext(fetchedContext); - b = true; + shouldCleanupContext = true; } return fetchedContext;