From cc58c267d9e018d5473bcda19c50f538dc3a349a Mon Sep 17 00:00:00 2001 From: Aaron Roney Date: Fri, 4 Oct 2019 20:33:51 -0700 Subject: [PATCH 1/7] Remove extraneous references. --- src/Core/AspNetCore.Proxy.csproj | 1 - src/Test/AspNetCore.Proxy.Tests.csproj | 6 ++++-- src/Test/UnitTests.cs | 5 ----- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/Core/AspNetCore.Proxy.csproj b/src/Core/AspNetCore.Proxy.csproj index efce8e7..83c7447 100644 --- a/src/Core/AspNetCore.Proxy.csproj +++ b/src/Core/AspNetCore.Proxy.csproj @@ -15,7 +15,6 @@ - diff --git a/src/Test/AspNetCore.Proxy.Tests.csproj b/src/Test/AspNetCore.Proxy.Tests.csproj index 726b8d4..737e0db 100644 --- a/src/Test/AspNetCore.Proxy.Tests.csproj +++ b/src/Test/AspNetCore.Proxy.Tests.csproj @@ -1,17 +1,19 @@ + netcoreapp3.0 false + - - + + \ No newline at end of file diff --git a/src/Test/UnitTests.cs b/src/Test/UnitTests.cs index 64e2770..f716acf 100644 --- a/src/Test/UnitTests.cs +++ b/src/Test/UnitTests.cs @@ -1,4 +1,3 @@ -using System; using System.Collections.Generic; using System.Linq; using System.Net; @@ -6,12 +5,8 @@ using System.Net.Http.Headers; using System.Text; using System.Threading.Tasks; -using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.TestHost; -using Microsoft.Extensions.DependencyInjection; using Newtonsoft.Json.Linq; using Xunit; From 1e80b28ab431bb11f301248ad12fd1524df3b0f7 Mon Sep 17 00:00:00 2001 From: Aaron Roney Date: Tue, 8 Oct 2019 21:55:47 -0700 Subject: [PATCH 2/7] Add WebSocket support. --- README.md | 20 ++- TODO.md | 54 +++++++ src/Core/Extensions.cs | 26 +++ src/Core/Helpers.cs | 151 +---------------- src/Core/HttpExtensions.cs | 152 ++++++++++++++++++ src/Core/ProxyOptions.cs | 20 +++ src/Core/ProxyRouteExtensions.cs | 62 ++++++- src/Core/WsExtensions.cs | 86 ++++++++++ src/Test/Extensions.cs | 63 ++++++++ src/Test/{Helpers.cs => Http/HttpHelpers.cs} | 41 +++-- .../{UnitTests.cs => Http/HttpUnitTests.cs} | 21 ++- src/Test/Mix/MixHelpers.cs | 58 +++++++ src/Test/Mix/MixUnitTests.cs | 74 +++++++++ src/Test/Ws/WsHelpers.cs | 58 +++++++ src/Test/Ws/WsUnitTests.cs | 71 ++++++++ 15 files changed, 781 insertions(+), 176 deletions(-) create mode 100644 TODO.md create mode 100644 src/Core/Extensions.cs create mode 100644 src/Core/HttpExtensions.cs create mode 100644 src/Core/WsExtensions.cs create mode 100644 src/Test/Extensions.cs rename src/Test/{Helpers.cs => Http/HttpHelpers.cs} (86%) rename src/Test/{UnitTests.cs => Http/HttpUnitTests.cs} (91%) create mode 100644 src/Test/Mix/MixHelpers.cs create mode 100644 src/Test/Mix/MixUnitTests.cs create mode 100644 src/Test/Ws/WsHelpers.cs create mode 100644 src/Test/Ws/WsUnitTests.cs diff --git a/README.md b/README.md index d9330f7..9f6095e 100644 --- a/README.md +++ b/README.md @@ -67,22 +67,38 @@ public class MyController : Controller { var options = ProxyOptions.Instance .WithShouldAddForwardedHeaders(false) + .WithHttpClientName("MyCustomClient") + .WithIntercept(async context => + { + if(c.Connection.RemotePort == 7777) + { + c.Response.StatusCode = 300; + await c.Response.WriteAsync("I don't like this port, so I am not proxying this request!"); + return true; + } + + return false; + }) .WithBeforeSend((c, hrm) => { // Set something that is needed for the downstream endpoint. hrm.Headers.Authorization = new AuthenticationHeaderValue("Basic"); + + return Task.CompletedTask; }) .WithAfterReceive((c, hrm) => { // Alter the content in some way before sending back to client. var newContent = new StringContent("It's all greek...er, Latin...to me!"); hrm.Content = newContent; + + return Task.CompletedTask; }) - .WithHandleFailure((c, e) => + .WithHandleFailure(async (c, e) => { // Return a custom error response. c.Response.StatusCode = 403; - c.Response.WriteAsync("Things borked."); + await c.Response.WriteAsync("Things borked."); }); return this.ProxyAsync($"https://jsonplaceholder.typicode.com/posts/{postId}"); diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..6c10c94 --- /dev/null +++ b/TODO.md @@ -0,0 +1,54 @@ +TODO: + * Should UseProxy require the user to set all of the proxies at once? YES, in 4.0.0...`UseProxies` with builders. + * Remove the [ProxyRoute] attribute? Maybe, in 4.0.0. If we keep it, change it to `UseStaticProxies`, and somehow return options? + * Round robin helper, and protocol helper for `RunProxy`? Maybe in 4.0.0. + * Add options for WebSocket calls. + * Make options handlers called `Async`? + * Allow the user to set options via a lambda for builder purposes? + +Some ideas of how `UseProxies` should work in 4.0.0. + +```csharp + +// Custom top-level extension method. +app.UseProxies(proxies => +{ + proxies.Map("/route/thingy") + .ToHttp("http://mysite.com/") // OR To(http, ws) + .WithOption1(); + + // OR + + proxies.Map("/route/thingy", proxy => + { + // Make sure the proxy builder has HttpContext on it. + proxy.ToHttp("http://mysite.com") + .WithOption1(...); + + proxy.ToWs(...); + }); +}); + +// OR? + +// Piggy-back on the ASP.NET Core 3 endpoints pattern. +app.UseEndpoints(endpoints => +{ + endpoints.Map("/my/path", context => + { + return context.ProxyAsync("http://mysite.com", options => + { + options.WithOption1(); + }); + + // OR? + + return context.HttpProxyTo("http://mysite.com", options => + { + options.WithOption1(); + }); + + // OR, maybe there is an `HttpProxyTo` and `WsProxyTo`, and a `ProxyTo` that does its best to decide. + }); +}) +``` \ No newline at end of file diff --git a/src/Core/Extensions.cs b/src/Core/Extensions.cs new file mode 100644 index 0000000..557ce87 --- /dev/null +++ b/src/Core/Extensions.cs @@ -0,0 +1,26 @@ +using Microsoft.AspNetCore.Http; +using System; +using System.Threading.Tasks; + +namespace AspNetCore.Proxy +{ + internal static class Extensions + { + internal static Task ExecuteProxyOperationAsync(this HttpContext context, string uri, ProxyOptions options = null) + { + if (context.WebSockets.IsWebSocketRequest) + { + if(!uri.StartsWith("ws", System.StringComparison.OrdinalIgnoreCase)) + throw new InvalidOperationException("A WebSocket request must forward to a WebSocket (ws[s]) endpoint."); + + return context.ExecuteWsProxyOperationAsync(uri, options); + } + + // Assume HTTP if not WebSocket. + if(!uri.StartsWith("http", System.StringComparison.OrdinalIgnoreCase)) + throw new InvalidOperationException("An HTTP request must forward to an HTTP (http[s]) endpoint."); + + return context.ExecuteHttpProxyOperationAsync(uri, options); + } + } +} diff --git a/src/Core/Helpers.cs b/src/Core/Helpers.cs index d5a1e60..8c8b2c4 100644 --- a/src/Core/Helpers.cs +++ b/src/Core/Helpers.cs @@ -1,20 +1,14 @@ -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyModel; using System; using System.Collections.Generic; -using System.Linq; -using System.Net.Http; -using System.Net.Sockets; using System.Reflection; -using System.Text; -using System.Threading.Tasks; +using Microsoft.Extensions.DependencyModel; namespace AspNetCore.Proxy { internal static class Helpers { - internal static readonly string ProxyClientName = "AspNetCore.Proxy.ProxyClient"; + internal static readonly string HttpProxyClientName = "AspNetCore.Proxy.HttpProxyClient"; + internal static readonly string[] WebSocketNotForwardedHeaders = new[] { "Connection", "Host", "Upgrade", "Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions" }; internal static IEnumerable GetReferencingAssemblies() { @@ -31,142 +25,5 @@ internal static IEnumerable GetReferencingAssemblies() } return assemblies; } - - internal static async Task ExecuteProxyOperation(HttpContext context, string uri, ProxyOptions options = null) - { - try - { - var proxiedRequest = context.CreateProxiedHttpRequest(uri, options?.ShouldAddForwardedHeaders ?? true); - - if(options?.BeforeSend != null) - await options.BeforeSend(context, proxiedRequest).ConfigureAwait(false); - var proxiedResponse = await context - .SendProxiedHttpRequest(proxiedRequest, options?.HttpClientName ?? Helpers.ProxyClientName) - .ConfigureAwait(false); - - if(options?.AfterReceive != null) - await options.AfterReceive(context, proxiedResponse).ConfigureAwait(false); - await context.WriteProxiedHttpResponse(proxiedResponse).ConfigureAwait(false); - } - catch (Exception e) - { - if (options?.HandleFailure == null) - { - // If the failures are not caught, then write a generic response. - context.Response.StatusCode = 502; - await context.Response.WriteAsync($"Request could not be proxied.\n\n{e.Message}\n\n{e.StackTrace}.").ConfigureAwait(false); - return; - } - - await options.HandleFailure(context, e).ConfigureAwait(false); - } - } - } - - internal static class Extensions - { - internal static HttpRequestMessage CreateProxiedHttpRequest(this HttpContext context, string uriString, bool shouldAddForwardedHeaders) - { - var uri = new Uri(uriString); - var request = context.Request; - - var requestMessage = new HttpRequestMessage(); - var requestMethod = request.Method; - - // Write to request content, when necessary. - if (!HttpMethods.IsGet(requestMethod) && - !HttpMethods.IsHead(requestMethod) && - !HttpMethods.IsDelete(requestMethod) && - !HttpMethods.IsTrace(requestMethod)) - { - var streamContent = new StreamContent(request.Body); - requestMessage.Content = streamContent; - } - - // Copy the request headers. - foreach (var header in context.Request.Headers) - if (!requestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray())) - requestMessage.Content?.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray()); - - // Add forwarded headers. - if(shouldAddForwardedHeaders) - AddForwardedHeadersToRequest(context, requestMessage); - - // Set destination and method. - requestMessage.Headers.Host = uri.Authority; - requestMessage.RequestUri = uri; - requestMessage.Method = new HttpMethod(request.Method); - - return requestMessage; - } - - internal static Task SendProxiedHttpRequest(this HttpContext context, HttpRequestMessage message, string httpClientName) - { - return context.RequestServices - .GetService() - .CreateClient(httpClientName) - .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, context.RequestAborted); - } - - internal static Task WriteProxiedHttpResponse(this HttpContext context, HttpResponseMessage responseMessage) - { - var response = context.Response; - - response.StatusCode = (int)responseMessage.StatusCode; - foreach (var header in responseMessage.Headers) - { - response.Headers[header.Key] = header.Value.ToArray(); - } - - foreach (var header in responseMessage.Content.Headers) - { - response.Headers[header.Key] = header.Value.ToArray(); - } - - response.Headers.Remove("transfer-encoding"); - - return responseMessage.Content.CopyToAsync(response.Body); - } - - private static void AddForwardedHeadersToRequest(HttpContext context, HttpRequestMessage requestMessage) - { - var request = context.Request; - var connection = context.Connection; - - var host = request.Host.ToString(); - var protocol = request.Scheme; - - var localIp = connection.LocalIpAddress?.ToString(); - var isLocalIpV6 = connection.LocalIpAddress?.AddressFamily == AddressFamily.InterNetworkV6; - - var remoteIp = context.Connection.RemoteIpAddress?.ToString(); - var isRemoteIpV6 = connection.RemoteIpAddress?.AddressFamily == AddressFamily.InterNetworkV6; - - if(remoteIp != null) - requestMessage.Headers.TryAddWithoutValidation("X-Forwarded-For", remoteIp); - requestMessage.Headers.TryAddWithoutValidation("X-Forwarded-Proto", protocol); - requestMessage.Headers.TryAddWithoutValidation("X-Forwarded-Host", host); - - // Fix IPv6 IPs for the `Forwarded` header. - var forwardedHeader = new StringBuilder($"proto={protocol};host={host};"); - - if(localIp != null) - { - if(isLocalIpV6) - localIp = $"\"[{localIp}]\""; - - forwardedHeader.Append($"by={localIp};"); - } - - if(remoteIp != null) - { - if(isRemoteIpV6) - remoteIp = $"\"[{remoteIp}]\""; - - forwardedHeader.Append($"for={remoteIp};"); - } - - requestMessage.Headers.TryAddWithoutValidation("Forwarded", forwardedHeader.ToString()); - } } -} +} \ No newline at end of file diff --git a/src/Core/HttpExtensions.cs b/src/Core/HttpExtensions.cs new file mode 100644 index 0000000..3728d89 --- /dev/null +++ b/src/Core/HttpExtensions.cs @@ -0,0 +1,152 @@ +using System; +using System.Linq; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; + +namespace AspNetCore.Proxy +{ + internal static class HttpExtensions + { + internal static async Task ExecuteHttpProxyOperationAsync(this HttpContext context, string uri, ProxyOptions options = null) + { + try + { + // If `true`, this proxy call has been intercepted. + if(options?.Intercept != null && await options.Intercept(context)) + return; + + var proxiedRequest = context.CreateProxiedHttpRequest(uri, options?.ShouldAddForwardedHeaders ?? true); + + if(options?.BeforeSend != null) + await options.BeforeSend(context, proxiedRequest).ConfigureAwait(false); + var proxiedResponse = await context + .SendProxiedHttpRequestAsync(proxiedRequest, options?.HttpClientName ?? Helpers.HttpProxyClientName) + .ConfigureAwait(false); + + if(options?.AfterReceive != null) + await options.AfterReceive(context, proxiedResponse).ConfigureAwait(false); + await context.WriteProxiedHttpResponseAsync(proxiedResponse).ConfigureAwait(false); + } + catch (Exception e) + { + if (options?.HandleFailure == null) + { + // If the failures are not caught, then write a generic response. + context.Response.StatusCode = 502; + await context.Response.WriteAsync($"Request could not be proxied.\n\n{e.Message}\n\n{e.StackTrace}.").ConfigureAwait(false); + return; + } + + await options.HandleFailure(context, e).ConfigureAwait(false); + } + } + + private static HttpRequestMessage CreateProxiedHttpRequest(this HttpContext context, string uriString, bool shouldAddForwardedHeaders) + { + var uri = new Uri(uriString); + var request = context.Request; + + var requestMessage = new HttpRequestMessage(); + var requestMethod = request.Method; + + // Write to request content, when necessary. + if (!HttpMethods.IsGet(requestMethod) && + !HttpMethods.IsHead(requestMethod) && + !HttpMethods.IsDelete(requestMethod) && + !HttpMethods.IsTrace(requestMethod)) + { + var streamContent = new StreamContent(request.Body); + requestMessage.Content = streamContent; + } + + // Copy the request headers. + foreach (var header in context.Request.Headers) + if (!requestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray())) + requestMessage.Content?.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray()); + + // Add forwarded headers. + if(shouldAddForwardedHeaders) + AddForwardedHeadersToHttpRequest(context, requestMessage); + + // Set destination and method. + requestMessage.Headers.Host = uri.Authority; + requestMessage.RequestUri = uri; + requestMessage.Method = new HttpMethod(request.Method); + + return requestMessage; + } + + private static Task SendProxiedHttpRequestAsync(this HttpContext context, HttpRequestMessage message, string httpClientName) + { + return context.RequestServices + .GetService() + .CreateClient(httpClientName) + .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, context.RequestAborted); + } + + private static Task WriteProxiedHttpResponseAsync(this HttpContext context, HttpResponseMessage responseMessage) + { + var response = context.Response; + + response.StatusCode = (int)responseMessage.StatusCode; + foreach (var header in responseMessage.Headers) + { + response.Headers[header.Key] = header.Value.ToArray(); + } + + foreach (var header in responseMessage.Content.Headers) + { + response.Headers[header.Key] = header.Value.ToArray(); + } + + response.Headers.Remove("transfer-encoding"); + + return responseMessage.Content.CopyToAsync(response.Body); + } + + private static void AddForwardedHeadersToHttpRequest(HttpContext context, HttpRequestMessage requestMessage) + { + var request = context.Request; + var connection = context.Connection; + + var host = request.Host.ToString(); + var protocol = request.Scheme; + + var localIp = connection.LocalIpAddress?.ToString(); + var isLocalIpV6 = connection.LocalIpAddress?.AddressFamily == AddressFamily.InterNetworkV6; + + var remoteIp = context.Connection.RemoteIpAddress?.ToString(); + var isRemoteIpV6 = connection.RemoteIpAddress?.AddressFamily == AddressFamily.InterNetworkV6; + + if(remoteIp != null) + requestMessage.Headers.TryAddWithoutValidation("X-Forwarded-For", remoteIp); + requestMessage.Headers.TryAddWithoutValidation("X-Forwarded-Proto", protocol); + requestMessage.Headers.TryAddWithoutValidation("X-Forwarded-Host", host); + + // Fix IPv6 IPs for the `Forwarded` header. + var forwardedHeader = new StringBuilder($"proto={protocol};host={host};"); + + if(localIp != null) + { + if(isLocalIpV6) + localIp = $"\"[{localIp}]\""; + + forwardedHeader.Append($"by={localIp};"); + } + + if(remoteIp != null) + { + if(isRemoteIpV6) + remoteIp = $"\"[{remoteIp}]\""; + + forwardedHeader.Append($"for={remoteIp};"); + } + + requestMessage.Headers.TryAddWithoutValidation("Forwarded", forwardedHeader.ToString()); + } + } +} \ No newline at end of file diff --git a/src/Core/ProxyOptions.cs b/src/Core/ProxyOptions.cs index b199e74..f818580 100644 --- a/src/Core/ProxyOptions.cs +++ b/src/Core/ProxyOptions.cs @@ -34,6 +34,15 @@ public class ProxyOptions /// A that is invoked once if the proxy operation fails. public Func HandleFailure { get; set; } + /// + /// Intercept property. + /// + /// + /// A that is invoked upon a call. + /// The result should be `true` if the call is intercepted and **not** meant to be forwarded. + /// + public Func> Intercept { get; set; } + /// /// BeforeSend property. /// @@ -61,12 +70,14 @@ private ProxyOptions( bool shouldAddForwardedHeaders, string httpClientName, Func handleFailure, + Func> intercept, Func beforeSend, Func afterReceive) { ShouldAddForwardedHeaders = shouldAddForwardedHeaders; HttpClientName = httpClientName; HandleFailure = handleFailure; + Intercept = intercept; BeforeSend = beforeSend; AfterReceive = afterReceive; } @@ -76,6 +87,7 @@ private static ProxyOptions CreateFrom( bool? shouldAddForwardedHeaders = null, string httpClientName = null, Func handleFailure = null, + Func> intercept = null, Func beforeSend = null, Func afterReceive = null) { @@ -83,6 +95,7 @@ private static ProxyOptions CreateFrom( shouldAddForwardedHeaders ?? old.ShouldAddForwardedHeaders, httpClientName ?? old.HttpClientName, handleFailure ?? old.HandleFailure, + intercept ?? old.Intercept, beforeSend ?? old.BeforeSend, afterReceive ?? old.AfterReceive); } @@ -114,6 +127,13 @@ private static ProxyOptions CreateFrom( /// A new instance of with the new value for the property. public ProxyOptions WithHandleFailure(Func handleFailure) => CreateFrom(this, handleFailure: handleFailure); + /// + /// Sets the property to a cloned instance of this . + /// + /// + /// A new instance of with the new value for the property. + public ProxyOptions WithIntercept(Func> intercept) => CreateFrom(this, intercept: intercept); + /// /// Sets the property to a cloned instance of this . /// diff --git a/src/Core/ProxyRouteExtensions.cs b/src/Core/ProxyRouteExtensions.cs index e945a5e..ba43bfa 100644 --- a/src/Core/ProxyRouteExtensions.cs +++ b/src/Core/ProxyRouteExtensions.cs @@ -30,7 +30,7 @@ public static class ProxyExtensions /// public static Task ProxyAsync(this ControllerBase controller, string uri, ProxyOptions options = null) { - return Helpers.ExecuteProxyOperation(controller.HttpContext, uri, options); + return controller.HttpContext.ExecuteProxyOperationAsync(uri, options); } /// @@ -42,9 +42,9 @@ public static Task ProxyAsync(this ControllerBase controller, string uri, ProxyO public static IServiceCollection AddProxies(this IServiceCollection services, Action configureProxyClient = null) { if(configureProxyClient != null) - services.AddHttpClient(Helpers.ProxyClientName, configureProxyClient); + services.AddHttpClient(Helpers.HttpProxyClientName, configureProxyClient); else - services.AddHttpClient(Helpers.ProxyClientName); + services.AddHttpClient(Helpers.HttpProxyClientName); return services; } @@ -94,6 +94,56 @@ public static void UseProxies(this IApplicationBuilder app) } } + #region RunProxy Overloads + + /// + /// Terminating middleware which creates a proxy over a specified endpoint. + /// + /// The ASP.NET . + /// The proxied address. + /// Extra options to apply during proxying. + public static void RunProxy(this IApplicationBuilder app, string proxiedAddress, ProxyOptions options = null) + { + app.Run(context => + { + return context.ExecuteProxyOperationAsync($"{proxiedAddress}{context.Request.Path}", options); + }); + } + + /// + /// Terminating middleware which creates a proxy over a specified endpoint. + /// + /// The ASP.NET . + /// A lambda { (context) => } which returns the address to which the request is proxied. + /// Extra options to apply during proxying. + public static void RunProxy(this IApplicationBuilder app, Func getProxiedAddress, ProxyOptions options = null) + { + app.Run(context => + { + return context.ExecuteProxyOperationAsync($"{getProxiedAddress(context)}{context.Request.Path}", options); + }); + } + + #endregion + + #region UseProxy Overloads + + /// + /// Middleware which creates an ad hoc proxy over a specified endpoint. + /// + /// The ASP.NET . + /// The local route endpoint. + /// The proxied address. + /// Extra options to apply during proxying. + public static void UseProxy(this IApplicationBuilder app, string endpoint, string proxiedAddress, ProxyOptions options = null) + { + UseProxy_GpaSync( + app, + endpoint, + (context, args) => proxiedAddress, + options); + } + /// /// Middleware which creates an ad hoc proxy over a specified endpoint. /// @@ -110,8 +160,6 @@ public static void UseProxy(this IApplicationBuilder app, string endpoint, Func< options); } - #region UseProxy Overloads - /// /// Middleware which creates an ad hoc proxy over a specified endpoint. /// @@ -199,7 +247,7 @@ private static void UseProxy_GpaAsync(this IApplicationBuilder app, string endpo builder.MapMiddlewareRoute(endpoint, proxyApp => { proxyApp.Run(async context => { var uri = await getProxiedAddress(context, context.GetRouteData().Values.ToDictionary(v => v.Key, v => v.Value)).ConfigureAwait(false); - await Helpers.ExecuteProxyOperation(context, uri, options); + await context.ExecuteProxyOperationAsync(uri, options); }); }); }); @@ -212,7 +260,7 @@ private static void UseProxy_GpaSync(this IApplicationBuilder app, string endpoi builder.MapMiddlewareRoute(endpoint, proxyApp => { proxyApp.Run(async context => { var uri = getProxiedAddress(context, context.GetRouteData().Values.ToDictionary(v => v.Key, v => v.Value)); - await Helpers.ExecuteProxyOperation(context, uri, options); + await context.ExecuteProxyOperationAsync(uri, options); }); }); }); diff --git a/src/Core/WsExtensions.cs b/src/Core/WsExtensions.cs new file mode 100644 index 0000000..b6a08f9 --- /dev/null +++ b/src/Core/WsExtensions.cs @@ -0,0 +1,86 @@ +using System; +using System.Linq; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace AspNetCore.Proxy +{ + internal static class WsExtensions + { + internal static async Task ExecuteWsProxyOperationAsync(this HttpContext context, string uri, ProxyOptions options = null) + { + using (var socketToEndpoint = new ClientWebSocket()) + { + foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols) + { + socketToEndpoint.Options.AddSubProtocol(protocol); + } + + foreach (var headerEntry in context.Request.Headers) + if (!Helpers.WebSocketNotForwardedHeaders.Contains(headerEntry.Key, StringComparer.OrdinalIgnoreCase)) + socketToEndpoint.Options.SetRequestHeader(headerEntry.Key, headerEntry.Value); + + // TODO: Add a proxy options for keep alive and set it here. + //client.Options.KeepAliveInterval = proxyService.Options.WebSocketKeepAliveInterval.Value; + + // TODO make a proxy option action to edit the web socket options. + + try + { + await socketToEndpoint.ConnectAsync(new Uri(uri), context.RequestAborted); + } + catch (Exception e) + { + if (options?.HandleFailure == null) + { + // If the failures are not caught, then write a generic response. + context.Response.StatusCode = 502; + await context.Response.WriteAsync($"Request could not be proxied.\n\n{e.Message}\n\n{e.StackTrace}.").ConfigureAwait(false); + return; + } + + await options.HandleFailure(context, e).ConfigureAwait(false); + } + + using (var socketToClient = await context.WebSockets.AcceptWebSocketAsync(socketToEndpoint.SubProtocol)) + { + // TODO: Add a buffer size option and set it here. + var bufferSize = 4096; + await Task.WhenAll(PumpWebSocket(socketToEndpoint, socketToClient, bufferSize, context.RequestAborted), PumpWebSocket(socketToClient, socketToEndpoint, bufferSize, context.RequestAborted)); + } + } + } + + private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken) + { + var buffer = new byte[bufferSize]; + + while (true) + { + WebSocketReceiveResult result; + + try + { + result = await source.ReceiveAsync(new ArraySegment(buffer), cancellationToken); + } + catch (OperationCanceledException) + { + await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, null, cancellationToken); + return; + } + + if (result.MessageType == WebSocketMessageType.Close) + { + await destination.CloseOutputAsync(source.CloseStatus.Value, source.CloseStatusDescription, cancellationToken); + return; + } + + // TODO: Add handlers here to allow the developer to edit message before forwarding, and vice versa? + + await destination.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken); + } + } + } +} \ No newline at end of file diff --git a/src/Test/Extensions.cs b/src/Test/Extensions.cs new file mode 100644 index 0000000..6414473 --- /dev/null +++ b/src/Test/Extensions.cs @@ -0,0 +1,63 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace AspNetCore.Proxy.Tests +{ + internal static class Extensions + { + const int BUFFER_SIZE = 4096; + internal static readonly string SupportedProtocol = "MyProtocol1"; + internal static readonly string CloseMessage = "PLEASE_CLOSE"; + internal static readonly string CloseDescription = "ARBITRARY"; + + internal static Task SendShortMessageAsync(this WebSocket socket, string message) + { + if(message.Length > BUFFER_SIZE / 8) + throw new InvalidOperationException($"Must send a short message (less than {BUFFER_SIZE / 8} characters)."); + + return socket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(message)), WebSocketMessageType.Text, true, CancellationToken.None); + } + + internal static async Task ReceiveShortMessageAsync(this WebSocket socket) + { + var buffer = new byte[BUFFER_SIZE]; + var result = await socket.ReceiveAsync(buffer, CancellationToken.None); + + if(!result.EndOfMessage) + throw new InvalidOperationException($"Must send a short message (less than {BUFFER_SIZE / 8} characters)."); + + return Encoding.UTF8.GetString(buffer, 0, result.Count); + } + + internal static async Task SocketBoomerang(this HttpContext context) + { + var socket = await context.WebSockets.AcceptWebSocketAsync(SupportedProtocol); + + while(true) + { + var message = await socket.ReceiveShortMessageAsync(); + + if(message == CloseMessage) + { + await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, CloseDescription, context.RequestAborted); + break; + } + + // Basically, this server just always sends back a message that is the message it received wrapped with "[]". + await socket.SendShortMessageAsync($"[{message}]"); + } + } + + internal static async Task HttpBoomerang(this HttpContext context) + { + var message = await new StreamReader(context.Request.Body).ReadToEndAsync(); + await context.Response.WriteAsync($"[{message}]"); + } + } +} \ No newline at end of file diff --git a/src/Test/Helpers.cs b/src/Test/Http/HttpHelpers.cs similarity index 86% rename from src/Test/Helpers.cs rename to src/Test/Http/HttpHelpers.cs index 66d79fa..ebeb064 100644 --- a/src/Test/Helpers.cs +++ b/src/Test/Http/HttpHelpers.cs @@ -10,7 +10,7 @@ namespace AspNetCore.Proxy.Tests { - public class Startup + internal class Startup { public void ConfigureServices(IServiceCollection services) { @@ -28,8 +28,8 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) { app.UseMiddleware(); app.UseRouting(); - app.UseEndpoints(endpoints => endpoints.MapControllers()); app.UseProxies(); + app.UseEndpoints(endpoints => endpoints.MapControllers()); app.UseProxy("echo/post", (context, args) => { return Task.FromResult($"https://postman-echo.com/post"); @@ -105,28 +105,43 @@ public static string ProxyToString(int postId) { return $"https://jsonplaceholder.typicode.com/posts/{postId}"; } + } - [ProxyRoute("api/posts")] - public static string ProxyPostRequest() + public class MvcController : ControllerBase + { + [Route("api/posts")] + public Task ProxyPostRequest() { - return $"https://jsonplaceholder.typicode.com/posts"; + return this.ProxyAsync($"https://jsonplaceholder.typicode.com/posts"); } - [ProxyRoute("api/catchall/{*rest}")] - public static string ProxyCatchAll(string rest) + [Route("api/catchall/{**rest}")] + public Task ProxyCatchAll(string rest) { - return $"https://jsonplaceholder.typicode.com/{rest}"; + return this.ProxyAsync($"https://jsonplaceholder.typicode.com/{rest}"); } - } - public class MvcController : ControllerBase - { [Route("api/controller/posts/{postId}")] public Task GetPosts(int postId) { return this.ProxyAsync($"https://jsonplaceholder.typicode.com/posts/{postId}"); } + [Route("api/controller/intercept/{postId}")] + public Task GetWithIntercept(int postId) + { + var options = ProxyOptions.Instance + .WithIntercept(async c => + { + c.Response.StatusCode = 200; + await c.Response.WriteAsync("This was intercepted and not proxied!"); + + return true; + }); + + return this.ProxyAsync($"https://jsonplaceholder.typicode.com/posts/{postId}", options); + } + [Route("api/controller/customrequest/{postId}")] public Task GetWithCustomRequest(int postId) { @@ -135,8 +150,8 @@ public Task GetWithCustomRequest(int postId) { hrm.RequestUri = new Uri("https://jsonplaceholder.typicode.com/posts/2"); return Task.CompletedTask; - }) - .WithShouldAddForwardedHeaders(false); + }) + .WithShouldAddForwardedHeaders(false); return this.ProxyAsync($"https://jsonplaceholder.typicode.com/posts/{postId}", options); } diff --git a/src/Test/UnitTests.cs b/src/Test/Http/HttpUnitTests.cs similarity index 91% rename from src/Test/UnitTests.cs rename to src/Test/Http/HttpUnitTests.cs index f716acf..8c9120f 100644 --- a/src/Test/UnitTests.cs +++ b/src/Test/Http/HttpUnitTests.cs @@ -12,12 +12,12 @@ namespace AspNetCore.Proxy.Tests { - public class UnitTests + public class HttpUnitTests { private readonly TestServer _server; private readonly HttpClient _client; - public UnitTests() + public HttpUnitTests() { _server = new TestServer(new WebHostBuilder().UseStartup()); _client = _server.CreateClient(); @@ -44,7 +44,7 @@ public async Task CanProxyAttributeToString() } [Fact] - public async Task CanProxyAttributePostRequest() + public async Task CanProxyControllerPostRequest() { var content = new StringContent("{\"title\": \"foo\", \"body\": \"bar\", \"userId\": 1}", Encoding.UTF8, "application/json"); var response = await _client.PostAsync("api/posts", content); @@ -54,9 +54,8 @@ public async Task CanProxyAttributePostRequest() Assert.Contains("101", JObject.Parse(responseString).Value("id")); } - [Fact] - public async Task CanProxyContentHeadersPostRequest() + public async Task CanProxyControllerContentHeadersPostRequest() { var content = "hello world"; var contentType = "application/xcustom"; @@ -75,7 +74,7 @@ public async Task CanProxyContentHeadersPostRequest() [Fact] - public async Task CanProxyAttributePostWithFormRequest() + public async Task CanProxyControllerPostWithFormRequest() { var content = new FormUrlEncodedContent(new Dictionary { { "xyz", "123" }, { "abc", "321" } }); var response = await _client.PostAsync("api/posts", content); @@ -90,7 +89,7 @@ public async Task CanProxyAttributePostWithFormRequest() } [Fact] - public async Task CanProxyAttributeCatchAll() + public async Task CanProxyControllerCatchAll() { var response = await _client.GetAsync("api/catchall/posts/1"); response.EnsureSuccessStatusCode(); @@ -212,6 +211,14 @@ public async Task CanGetCustomFailure() Assert.Equal("Things borked.", await response.Content.ReadAsStringAsync()); } + [Fact] + public async Task CanGetIntercept() + { + var response = await _client.GetAsync("api/controller/intercept/1"); + response.EnsureSuccessStatusCode(); + Assert.Equal("This was intercepted and not proxied!", await response.Content.ReadAsStringAsync()); + } + [Fact] public async Task CanProxyConcurrentCalls() { diff --git a/src/Test/Mix/MixHelpers.cs b/src/Test/Mix/MixHelpers.cs new file mode 100644 index 0000000..39a0d45 --- /dev/null +++ b/src/Test/Mix/MixHelpers.cs @@ -0,0 +1,58 @@ +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace AspNetCore.Proxy.Tests +{ + internal static class MixHelpers + { + internal static Task RunMixServers(CancellationToken token) + { + var proxiedServerTask = WebHost.CreateDefaultBuilder() + .SuppressStatusMessages(true) + .ConfigureLogging(logging => logging.ClearProviders()) + .ConfigureKestrel(options => + { + options.ListenLocalhost(5004); + }) + .Configure(app => app.UseWebSockets().Run(context => + { + if(context.WebSockets.IsWebSocketRequest) + return context.SocketBoomerang(); + + return context.HttpBoomerang(); + })) + .Build().RunAsync(token); + + var proxyServerTask = WebHost.CreateDefaultBuilder() + .SuppressStatusMessages(true) + .ConfigureLogging(logging => logging.ClearProviders()) + .ConfigureKestrel(options => + { + options.ListenLocalhost(5003); + }) + .ConfigureServices(services => services.AddProxies().AddRouting().AddControllers()) + .Configure(app => + { + app.UseWebSockets(); + app.UseRouting(); + app.UseEndpoints(end => end.MapControllers()); + + app.RunProxy(context => + { + if(context.WebSockets.IsWebSocketRequest) + return $"ws://localhost:5004"; + + return $"http://localhost:5004"; + }); + }) + .Build().RunAsync(token); + + return Task.WhenAll(proxiedServerTask, proxyServerTask); + } + } +} \ No newline at end of file diff --git a/src/Test/Mix/MixUnitTests.cs b/src/Test/Mix/MixUnitTests.cs new file mode 100644 index 0000000..e1e7cef --- /dev/null +++ b/src/Test/Mix/MixUnitTests.cs @@ -0,0 +1,74 @@ +using System; +using System.Net.Http; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace AspNetCore.Proxy.Tests +{ + public class MixServerFixture : IDisposable + { + private CancellationTokenSource _source; + + public MixServerFixture() + { + _source = new CancellationTokenSource(); + MixHelpers.RunMixServers(_source.Token); + } + + public void Dispose() + { + _source.Cancel(); + } + } + + public class MixUnitTests : IClassFixture + { + public readonly ClientWebSocket _client; + + public MixUnitTests(MixServerFixture fixture) + { + _client = new ClientWebSocket(); + _client.Options.AddSubProtocol(Extensions.SupportedProtocol); + } + + [Fact] + public async Task CanDoWebSockets() + { + var send1 = "TEST1"; + var expected1 = $"[{send1}]"; + + var send2 = "TEST2"; + var expected2 = $"[{send2}]"; + + await _client.ConnectAsync(new Uri("ws://localhost:5003/to/random/path"), CancellationToken.None); + Assert.Equal(Extensions.SupportedProtocol, _client.SubProtocol); + + // Send a message. + await _client.SendShortMessageAsync(send1); + await _client.SendShortMessageAsync(send2); + await _client.SendShortMessageAsync(Extensions.CloseMessage); + + // Receive responses. + var response1 = await _client.ReceiveShortMessageAsync(); + Assert.Equal(expected1, response1); + var response2 = await _client.ReceiveShortMessageAsync(); + Assert.Equal(expected2, response2); + + // Receive close. + var result = await _client.ReceiveAsync(new ArraySegment(new byte[4096]), CancellationToken.None); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); + Assert.Equal(Extensions.CloseDescription, result.CloseStatusDescription); + + await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None); + + // HTTP test. + HttpClient client = new HttpClient(); + var response = await client.PostAsync("http://localhost:5003/at/some/path", new StringContent(send1)); + Assert.Equal(expected1, await response.Content.ReadAsStringAsync()); + } + } +} \ No newline at end of file diff --git a/src/Test/Ws/WsHelpers.cs b/src/Test/Ws/WsHelpers.cs new file mode 100644 index 0000000..2fc88ea --- /dev/null +++ b/src/Test/Ws/WsHelpers.cs @@ -0,0 +1,58 @@ +using System; +using System.Diagnostics; +using System.Linq; +using System.Net; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace AspNetCore.Proxy.Tests +{ + public class WsController : ControllerBase + { + [Route("api/ws")] + public Task ProxyWsController() + { + return this.ProxyAsync("ws://localhost:5002/"); + } + } + + internal static class WsHelpers + { + internal static Task RunWsServers(CancellationToken token) + { + var proxiedServerTask = WebHost.CreateDefaultBuilder() + .SuppressStatusMessages(true) + .ConfigureLogging(logging => logging.ClearProviders()) + .ConfigureKestrel(options => + { + options.ListenLocalhost(5002); + }) + .Configure(app => app.UseWebSockets().Run(context => + { + return context.SocketBoomerang(); + })) + .Build().RunAsync(token); + + var proxyServerTask = WebHost.CreateDefaultBuilder() + .SuppressStatusMessages(true) + .ConfigureLogging(logging => logging.ClearProviders()) + .ConfigureKestrel(options => + { + options.ListenLocalhost(5001); + }) + .ConfigureServices(services => services.AddProxies().AddRouting().AddControllers()) + .Configure(app => app.UseWebSockets().UseRouting().UseEndpoints(end => end.MapControllers()).UseProxy("/ws", "ws://localhost:5002/")) + .Build().RunAsync(token); + + return Task.WhenAll(proxiedServerTask, proxyServerTask); + } + } +} \ No newline at end of file diff --git a/src/Test/Ws/WsUnitTests.cs b/src/Test/Ws/WsUnitTests.cs new file mode 100644 index 0000000..8c76bce --- /dev/null +++ b/src/Test/Ws/WsUnitTests.cs @@ -0,0 +1,71 @@ +using System; +using System.Net.Http; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace AspNetCore.Proxy.Tests +{ + public class WsServerFixture : IDisposable + { + private CancellationTokenSource _source; + + public WsServerFixture() + { + _source = new CancellationTokenSource(); + WsHelpers.RunWsServers(_source.Token); + } + + public void Dispose() + { + _source.Cancel(); + } + } + + public class WsUnitTests : IClassFixture + { + public readonly ClientWebSocket _client; + + public WsUnitTests(WsServerFixture fixture) + { + _client = new ClientWebSocket(); + _client.Options.AddSubProtocol(Extensions.SupportedProtocol); + } + + [Theory] + [InlineData("ws://localhost:5001/ws")] + [InlineData("ws://localhost:5001/api/ws")] + public async Task CanDoWebSockets(string server) + { + var send1 = "TEST1"; + var expected1 = $"[{send1}]"; + + var send2 = "TEST2"; + var expected2 = $"[{send2}]"; + + await _client.ConnectAsync(new Uri(server), CancellationToken.None); + Assert.Equal(Extensions.SupportedProtocol, _client.SubProtocol); + + // Send a message. + await _client.SendShortMessageAsync(send1); + await _client.SendShortMessageAsync(send2); + await _client.SendShortMessageAsync(Extensions.CloseMessage); + + // Receive responses. + var response1 = await _client.ReceiveShortMessageAsync(); + Assert.Equal(expected1, response1); + var response2 = await _client.ReceiveShortMessageAsync(); + Assert.Equal(expected2, response2); + + // Receive close. + var result = await _client.ReceiveAsync(new ArraySegment(new byte[4096]), CancellationToken.None); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); + Assert.Equal(Extensions.CloseDescription, result.CloseStatusDescription); + + await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None); + } + } +} \ No newline at end of file From f0ebfb1ef85799093606b4253aa710ce929231f1 Mon Sep 17 00:00:00 2001 From: Aaron Roney Date: Tue, 8 Oct 2019 22:03:35 -0700 Subject: [PATCH 3/7] Fix readme. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9f6095e..c7c6975 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ public class MyController : Controller await c.Response.WriteAsync("Things borked."); }); - return this.ProxyAsync($"https://jsonplaceholder.typicode.com/posts/{postId}"); + return this.ProxyAsync($"https://jsonplaceholder.typicode.com/posts/{postId}", options); } } ``` From b11beb6f30e98e9c4e2820fa0e8b7d42866dee15 Mon Sep 17 00:00:00 2001 From: Aaron Roney Date: Tue, 8 Oct 2019 23:11:56 -0700 Subject: [PATCH 4/7] Improve error handling and tests. --- src/Core/Extensions.cs | 39 ++++++++++++++++++++++-------- src/Core/HttpExtensions.cs | 39 +++++++++--------------------- src/Core/WsExtensions.cs | 29 +++++++++------------- src/Test/Extensions.cs | 6 +++++ src/Test/Mix/MixHelpers.cs | 6 +++++ src/Test/Mix/MixUnitTests.cs | 47 +++++++++++++++++++++++++----------- src/Test/Ws/WsUnitTests.cs | 23 ++++++++++++++++++ 7 files changed, 120 insertions(+), 69 deletions(-) diff --git a/src/Core/Extensions.cs b/src/Core/Extensions.cs index 557ce87..5fc38c3 100644 --- a/src/Core/Extensions.cs +++ b/src/Core/Extensions.cs @@ -6,21 +6,40 @@ namespace AspNetCore.Proxy { internal static class Extensions { - internal static Task ExecuteProxyOperationAsync(this HttpContext context, string uri, ProxyOptions options = null) + internal static async Task ExecuteProxyOperationAsync(this HttpContext context, string uri, ProxyOptions options = null) { - if (context.WebSockets.IsWebSocketRequest) + try { - if(!uri.StartsWith("ws", System.StringComparison.OrdinalIgnoreCase)) - throw new InvalidOperationException("A WebSocket request must forward to a WebSocket (ws[s]) endpoint."); + if (context.WebSockets.IsWebSocketRequest) + { + if(!uri.StartsWith("ws", System.StringComparison.OrdinalIgnoreCase)) + throw new InvalidOperationException("A WebSocket request must forward to a WebSocket (ws[s]) endpoint."); - return context.ExecuteWsProxyOperationAsync(uri, options); - } + await context.ExecuteWsProxyOperationAsync(uri, options).ConfigureAwait(false); + return; + } - // Assume HTTP if not WebSocket. - if(!uri.StartsWith("http", System.StringComparison.OrdinalIgnoreCase)) - throw new InvalidOperationException("An HTTP request must forward to an HTTP (http[s]) endpoint."); + // Assume HTTP if not WebSocket. + if(!uri.StartsWith("http", System.StringComparison.OrdinalIgnoreCase)) + throw new InvalidOperationException("An HTTP request must forward to an HTTP (http[s]) endpoint."); - return context.ExecuteHttpProxyOperationAsync(uri, options); + await context.ExecuteHttpProxyOperationAsync(uri, options).ConfigureAwait(false); + } + catch (Exception e) + { + if(!context.Response.HasStarted) + { + if (options?.HandleFailure == null) + { + // If the failures are not caught, then write a generic response. + context.Response.StatusCode = 502; + await context.Response.WriteAsync($"Request could not be proxied.\n\n{e.Message}\n\n{e.StackTrace}").ConfigureAwait(false); + return; + } + + await options.HandleFailure(context, e).ConfigureAwait(false); + } + } } } } diff --git a/src/Core/HttpExtensions.cs b/src/Core/HttpExtensions.cs index 3728d89..d1022a9 100644 --- a/src/Core/HttpExtensions.cs +++ b/src/Core/HttpExtensions.cs @@ -13,36 +13,21 @@ internal static class HttpExtensions { internal static async Task ExecuteHttpProxyOperationAsync(this HttpContext context, string uri, ProxyOptions options = null) { - try - { - // If `true`, this proxy call has been intercepted. - if(options?.Intercept != null && await options.Intercept(context)) - return; + // If `true`, this proxy call has been intercepted. + if(options?.Intercept != null && await options.Intercept(context)) + return; - var proxiedRequest = context.CreateProxiedHttpRequest(uri, options?.ShouldAddForwardedHeaders ?? true); + var proxiedRequest = context.CreateProxiedHttpRequest(uri, options?.ShouldAddForwardedHeaders ?? true); - if(options?.BeforeSend != null) - await options.BeforeSend(context, proxiedRequest).ConfigureAwait(false); - var proxiedResponse = await context - .SendProxiedHttpRequestAsync(proxiedRequest, options?.HttpClientName ?? Helpers.HttpProxyClientName) - .ConfigureAwait(false); + if(options?.BeforeSend != null) + await options.BeforeSend(context, proxiedRequest).ConfigureAwait(false); + var proxiedResponse = await context + .SendProxiedHttpRequestAsync(proxiedRequest, options?.HttpClientName ?? Helpers.HttpProxyClientName) + .ConfigureAwait(false); - if(options?.AfterReceive != null) - await options.AfterReceive(context, proxiedResponse).ConfigureAwait(false); - await context.WriteProxiedHttpResponseAsync(proxiedResponse).ConfigureAwait(false); - } - catch (Exception e) - { - if (options?.HandleFailure == null) - { - // If the failures are not caught, then write a generic response. - context.Response.StatusCode = 502; - await context.Response.WriteAsync($"Request could not be proxied.\n\n{e.Message}\n\n{e.StackTrace}.").ConfigureAwait(false); - return; - } - - await options.HandleFailure(context, e).ConfigureAwait(false); - } + if(options?.AfterReceive != null) + await options.AfterReceive(context, proxiedResponse).ConfigureAwait(false); + await context.WriteProxiedHttpResponseAsync(proxiedResponse).ConfigureAwait(false); } private static HttpRequestMessage CreateProxiedHttpRequest(this HttpContext context, string uriString, bool shouldAddForwardedHeaders) diff --git a/src/Core/WsExtensions.cs b/src/Core/WsExtensions.cs index b6a08f9..2581db4 100644 --- a/src/Core/WsExtensions.cs +++ b/src/Core/WsExtensions.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using System.Net.WebSockets; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; @@ -9,6 +10,8 @@ namespace AspNetCore.Proxy { internal static class WsExtensions { + internal static readonly int CloseMessageMaxSize = 123; + internal static async Task ExecuteWsProxyOperationAsync(this HttpContext context, string uri, ProxyOptions options = null) { using (var socketToEndpoint = new ClientWebSocket()) @@ -27,22 +30,7 @@ internal static async Task ExecuteWsProxyOperationAsync(this HttpContext context // TODO make a proxy option action to edit the web socket options. - try - { - await socketToEndpoint.ConnectAsync(new Uri(uri), context.RequestAborted); - } - catch (Exception e) - { - if (options?.HandleFailure == null) - { - // If the failures are not caught, then write a generic response. - context.Response.StatusCode = 502; - await context.Response.WriteAsync($"Request could not be proxied.\n\n{e.Message}\n\n{e.StackTrace}.").ConfigureAwait(false); - return; - } - - await options.HandleFailure(context, e).ConfigureAwait(false); - } + await socketToEndpoint.ConnectAsync(new Uri(uri), context.RequestAborted); using (var socketToClient = await context.WebSockets.AcceptWebSocketAsync(socketToEndpoint.SubProtocol)) { @@ -65,12 +53,17 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination, { result = await source.ReceiveAsync(new ArraySegment(buffer), cancellationToken); } - catch (OperationCanceledException) + catch (Exception e) { - await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, null, cancellationToken); + var closeMessageBytes = Encoding.UTF8.GetBytes($"WebSocket failure.\n\n{e.Message}\n\n{e.StackTrace}"); + var closeMessage = Encoding.UTF8.GetString(closeMessageBytes, 0, Math.Min(closeMessageBytes.Length, CloseMessageMaxSize)); + await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, closeMessage, cancellationToken); return; } + if(destination.State != WebSocketState.Open && destination.State != WebSocketState.CloseReceived) + return; + if (result.MessageType == WebSocketMessageType.Close) { await destination.CloseOutputAsync(source.CloseStatus.Value, source.CloseStatusDescription, cancellationToken); diff --git a/src/Test/Extensions.cs b/src/Test/Extensions.cs index 6414473..606b86d 100644 --- a/src/Test/Extensions.cs +++ b/src/Test/Extensions.cs @@ -14,6 +14,7 @@ internal static class Extensions const int BUFFER_SIZE = 4096; internal static readonly string SupportedProtocol = "MyProtocol1"; internal static readonly string CloseMessage = "PLEASE_CLOSE"; + internal static readonly string KillMessage = "PLEASE_KILL"; internal static readonly string CloseDescription = "ARBITRARY"; internal static Task SendShortMessageAsync(this WebSocket socket, string message) @@ -48,6 +49,11 @@ internal static async Task SocketBoomerang(this HttpContext context) await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, CloseDescription, context.RequestAborted); break; } + + if(message == KillMessage) + { + throw new Exception(); + } // Basically, this server just always sends back a message that is the message it received wrapped with "[]". await socket.SendShortMessageAsync($"[{message}]"); diff --git a/src/Test/Mix/MixHelpers.cs b/src/Test/Mix/MixHelpers.cs index 39a0d45..db031b7 100644 --- a/src/Test/Mix/MixHelpers.cs +++ b/src/Test/Mix/MixHelpers.cs @@ -44,6 +44,12 @@ internal static Task RunMixServers(CancellationToken token) app.RunProxy(context => { + if(context.Request.Path.StartsWithSegments("/should/forward/to/ws")) + return $"ws://localhost:5004"; + + if(context.Request.Path.StartsWithSegments("/should/forward/to/http")) + return $"http://localhost:5004"; + if(context.WebSockets.IsWebSocketRequest) return $"ws://localhost:5004"; diff --git a/src/Test/Mix/MixUnitTests.cs b/src/Test/Mix/MixUnitTests.cs index e1e7cef..070b490 100644 --- a/src/Test/Mix/MixUnitTests.cs +++ b/src/Test/Mix/MixUnitTests.cs @@ -1,4 +1,5 @@ using System; +using System.Net; using System.Net.Http; using System.Net.WebSockets; using System.Text; @@ -26,12 +27,15 @@ public void Dispose() public class MixUnitTests : IClassFixture { - public readonly ClientWebSocket _client; + private readonly ClientWebSocket _wsClient; + private readonly HttpClient _httpClient; public MixUnitTests(MixServerFixture fixture) { - _client = new ClientWebSocket(); - _client.Options.AddSubProtocol(Extensions.SupportedProtocol); + _wsClient = new ClientWebSocket(); + _wsClient.Options.AddSubProtocol(Extensions.SupportedProtocol); + + _httpClient = new HttpClient(); } [Fact] @@ -43,32 +47,47 @@ public async Task CanDoWebSockets() var send2 = "TEST2"; var expected2 = $"[{send2}]"; - await _client.ConnectAsync(new Uri("ws://localhost:5003/to/random/path"), CancellationToken.None); - Assert.Equal(Extensions.SupportedProtocol, _client.SubProtocol); + await _wsClient.ConnectAsync(new Uri("ws://localhost:5003/to/random/path"), CancellationToken.None); + Assert.Equal(Extensions.SupportedProtocol, _wsClient.SubProtocol); // Send a message. - await _client.SendShortMessageAsync(send1); - await _client.SendShortMessageAsync(send2); - await _client.SendShortMessageAsync(Extensions.CloseMessage); + await _wsClient.SendShortMessageAsync(send1); + await _wsClient.SendShortMessageAsync(send2); + await _wsClient.SendShortMessageAsync(Extensions.CloseMessage); // Receive responses. - var response1 = await _client.ReceiveShortMessageAsync(); + var response1 = await _wsClient.ReceiveShortMessageAsync(); Assert.Equal(expected1, response1); - var response2 = await _client.ReceiveShortMessageAsync(); + var response2 = await _wsClient.ReceiveShortMessageAsync(); Assert.Equal(expected2, response2); // Receive close. - var result = await _client.ReceiveAsync(new ArraySegment(new byte[4096]), CancellationToken.None); + var result = await _wsClient.ReceiveAsync(new ArraySegment(new byte[4096]), CancellationToken.None); Assert.Equal(WebSocketMessageType.Close, result.MessageType); Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); Assert.Equal(Extensions.CloseDescription, result.CloseStatusDescription); - await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None); + await _wsClient.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None); // HTTP test. - HttpClient client = new HttpClient(); - var response = await client.PostAsync("http://localhost:5003/at/some/path", new StringContent(send1)); + var response = await _httpClient.PostAsync("http://localhost:5003/at/some/path", new StringContent(send1)); Assert.Equal(expected1, await response.Content.ReadAsStringAsync()); } + + [Fact] + public async Task CanFailOnIncorrectForwardToWs() + { + var response = await _httpClient.GetAsync("http://localhost:5003/should/forward/to/ws"); + Assert.Equal(HttpStatusCode.BadGateway, response.StatusCode); + } + + [Fact] + public async Task CanFailOnIncorrectForwardToHttp() + { + await Assert.ThrowsAnyAsync(async () => + { + await _wsClient.ConnectAsync(new Uri("ws://localhost:5003/should/forward/to/http"), CancellationToken.None); + }); + } } } \ No newline at end of file diff --git a/src/Test/Ws/WsUnitTests.cs b/src/Test/Ws/WsUnitTests.cs index 8c76bce..2533531 100644 --- a/src/Test/Ws/WsUnitTests.cs +++ b/src/Test/Ws/WsUnitTests.cs @@ -31,6 +31,7 @@ public class WsUnitTests : IClassFixture public WsUnitTests(WsServerFixture fixture) { _client = new ClientWebSocket(); + _client.Options.SetRequestHeader("SomeHeader", "SomeValue"); _client.Options.AddSubProtocol(Extensions.SupportedProtocol); } @@ -67,5 +68,27 @@ public async Task CanDoWebSockets(string server) await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None); } + + [Fact] + public async Task CanCatchAbruptClose() + { + var send1 = "PLEASE_KILL"; + var expected1 = $"[{send1}]"; + + var send2 = "TEST2"; + var expected2 = $"[{send2}]"; + + await _client.ConnectAsync(new Uri("ws://localhost:5001/ws"), CancellationToken.None); + + // Send a message. + await _client.SendShortMessageAsync(send1); + + // Receive failed close. + var result = await _client.ReceiveAsync(new ArraySegment(new byte[4096]), CancellationToken.None); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.EndpointUnavailable, result.CloseStatus); + + await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None); + } } } \ No newline at end of file From 22ada5b1e0c7ce845675460cae1edffe727a3320 Mon Sep 17 00:00:00 2001 From: Aaron Roney Date: Tue, 8 Oct 2019 23:22:40 -0700 Subject: [PATCH 5/7] Code coverage fix. --- src/Test/Mix/MixHelpers.cs | 29 +++++++++++++++++++++-------- src/Test/Mix/MixUnitTests.cs | 10 ++++++++++ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/Test/Mix/MixHelpers.cs b/src/Test/Mix/MixHelpers.cs index db031b7..e73006d 100644 --- a/src/Test/Mix/MixHelpers.cs +++ b/src/Test/Mix/MixHelpers.cs @@ -35,30 +35,43 @@ internal static Task RunMixServers(CancellationToken token) { options.ListenLocalhost(5003); }) - .ConfigureServices(services => services.AddProxies().AddRouting().AddControllers()) + .ConfigureServices(services => services.AddProxies()) .Configure(app => { app.UseWebSockets(); - app.UseRouting(); - app.UseEndpoints(end => end.MapControllers()); app.RunProxy(context => { if(context.Request.Path.StartsWithSegments("/should/forward/to/ws")) - return $"ws://localhost:5004"; + return "ws://localhost:5004"; if(context.Request.Path.StartsWithSegments("/should/forward/to/http")) - return $"http://localhost:5004"; + return "http://localhost:5004"; if(context.WebSockets.IsWebSocketRequest) - return $"ws://localhost:5004"; + return "ws://localhost:5004"; - return $"http://localhost:5004"; + return "http://localhost:5004"; }); }) .Build().RunAsync(token); - return Task.WhenAll(proxiedServerTask, proxyServerTask); + var proxyServerTask2 = WebHost.CreateDefaultBuilder() + .SuppressStatusMessages(true) + .ConfigureLogging(logging => logging.ClearProviders()) + .ConfigureKestrel(options => + { + options.ListenLocalhost(5007); + }) + .ConfigureServices(services => services.AddProxies()) + .Configure(app => + { + app.UseWebSockets(); + app.RunProxy("http://localhost:5004"); + }) + .Build().RunAsync(token); + + return Task.WhenAll(proxiedServerTask, proxyServerTask, proxyServerTask2); } } } \ No newline at end of file diff --git a/src/Test/Mix/MixUnitTests.cs b/src/Test/Mix/MixUnitTests.cs index 070b490..8fae147 100644 --- a/src/Test/Mix/MixUnitTests.cs +++ b/src/Test/Mix/MixUnitTests.cs @@ -74,6 +74,16 @@ public async Task CanDoWebSockets() Assert.Equal(expected1, await response.Content.ReadAsStringAsync()); } + [Fact] + public async Task CanDoSimpleServer() + { + var send1 = "TEST1"; + var expected1 = $"[{send1}]"; + + var response = await _httpClient.PostAsync("http://localhost:5007/at/some/other/path", new StringContent(send1)); + Assert.Equal(expected1, await response.Content.ReadAsStringAsync()); + } + [Fact] public async Task CanFailOnIncorrectForwardToWs() { From 58e2524f2f4796b6eef16de789583f9ed605d299 Mon Sep 17 00:00:00 2001 From: Aaron Roney Date: Tue, 8 Oct 2019 23:31:15 -0700 Subject: [PATCH 6/7] Moar coverage. --- src/Test/Mix/MixHelpers.cs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Test/Mix/MixHelpers.cs b/src/Test/Mix/MixHelpers.cs index e73006d..159b22b 100644 --- a/src/Test/Mix/MixHelpers.cs +++ b/src/Test/Mix/MixHelpers.cs @@ -35,7 +35,13 @@ internal static Task RunMixServers(CancellationToken token) { options.ListenLocalhost(5003); }) - .ConfigureServices(services => services.AddProxies()) + .ConfigureServices(services => + { + services.AddProxies(client => + { + // This doesn't do anything, but it covers more code paths. :) + }); + }) .Configure(app => { app.UseWebSockets(); From ef261c6659a5b891b11cd1d35f9aa57d7a5ce4cf Mon Sep 17 00:00:00 2001 From: Aaron Roney Date: Tue, 8 Oct 2019 23:49:04 -0700 Subject: [PATCH 7/7] Update readme. --- README.md | 37 +++++++++++++++++++++++++++++++++++-- TODO.md | 1 + 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c7c6975..82c96a4 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,29 @@ public void ConfigureServices(IServiceCollection services) } ``` +#### Run a Proxy + +You can run a proxy over all endpoints. + +```csharp +app.RunProxy("https://google.com"); +``` + +In addition, you can route this proxy depending on the context. + +```csharp +app.RunProxy(context => +{ + if(context.WebSockets.IsWebSocketRequest) + return "wss://mysite.com/ws"; + + return "https://mysite.com"; +}); +``` + #### Existing Controller -You can use the proxy functionality on an existing `Controller` by leveraging the `Proxy` extension method. +You can define a proxy over a specific endpoint on an existing `Controller` by leveraging the `Proxy` extension method. ```csharp public class MyController : Controller @@ -57,6 +77,19 @@ public class MyController : Controller } ``` +In addition, you can proxy to WebSocket endpoints. + +```csharp +public class MyController : Controller +{ + [Route("ws")] + public Task OpenWs() + { + return this.ProxyAsync($"wss://myendpoint.com/ws"); + } +} +``` + You can also pass special options that apply when the proxy operation occurs. ```csharp @@ -108,7 +141,7 @@ public class MyController : Controller #### Application Builder -You can define a proxy in `Configure(IApplicationBuilder app, IHostingEnvironment env)`. The arguments are passed to the underlying lambda as a `Dictionary`. +You can define a proxy over a specific endpoint in `Configure(IApplicationBuilder app, IHostingEnvironment env)`. The arguments are passed to the underlying lambda as a `Dictionary`. ```csharp app.UseProxy("api/{arg1}/{arg2}", async (args) => { diff --git a/TODO.md b/TODO.md index 6c10c94..17d3155 100644 --- a/TODO.md +++ b/TODO.md @@ -5,6 +5,7 @@ TODO: * Add options for WebSocket calls. * Make options handlers called `Async`? * Allow the user to set options via a lambda for builder purposes? + * Add a `RunProxy` that takes a `getProxiedAddress` as a `Task`. Some ideas of how `UseProxies` should work in 4.0.0.