diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index c51e7aae5a36..42e26d2fcf21 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -164,7 +164,10 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact /// /// Creates a implementation for . /// - /// A request handler with any number of custom parameters that often produces a response with its return value. + /// + /// A request handler with any number of custom parameters that often produces a response with its return value. + /// If delegate points to instance method, but is set to , target will be fetched from . + /// /// The used to configure the behavior of the handler. /// /// The result returned from if that was used to inferring metadata before creating the final RequestDelegate. @@ -178,23 +181,37 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact { ArgumentNullException.ThrowIfNull(handler); - var targetExpression = handler.Target switch + UnaryExpression? targetExpression = null; + Func? targetFactory = null; + Expression>? targetFactoryExpression = null; + + switch (handler.Target) { - object => Expression.Convert(TargetExpr, handler.Target.GetType()), - null => null, - }; + case object: + targetExpression = Expression.Convert(TargetExpr, handler.Target.GetType()); + targetFactory = (httpContext) => handler.Target; + targetFactoryExpression = (httpContext) => handler.Target; + + break; + + case null when !handler.Method.IsStatic: + targetExpression = Expression.Convert(TargetExpr, handler.Method.ReflectedType!); + targetFactory = (httpContext) => httpContext.RequestServices.GetRequiredService(handler.Method.ReflectedType!); + targetFactoryExpression = (httpContext) => httpContext.RequestServices.GetRequiredService(handler.Method.ReflectedType!); + + break; + } var factoryContext = CreateFactoryContext(options, metadataResult, handler); - Expression> targetFactory = (httpContext) => handler.Target; - var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext, targetFactory); + var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext, targetFactoryExpression); RequestDelegate finalRequestDelegate = targetableRequestDelegate switch { // handler is a RequestDelegate that has not been modified by a filter. Short-circuit and return the original RequestDelegate back. // It's possible a filter factory has still modified the endpoint metadata though. null => (RequestDelegate)handler, - _ => httpContext => targetableRequestDelegate(handler.Target, httpContext), + _ => httpContext => targetableRequestDelegate(targetFactory?.Invoke(httpContext), httpContext), }; return CreateRequestDelegateResult(finalRequestDelegate, factoryContext.EndpointBuilder); @@ -369,8 +386,9 @@ private static IReadOnlyList AsReadOnlyList(IList metadata) } } - // return null for plain RequestDelegates that have not been modified by filters so we can just pass back the original RequestDelegate. - if (filterPipeline is null && factoryContext.Handler is RequestDelegate) + // return null for plain RequestDelegates that have not been modified by filters so we can just pass back the original RequestDelegate + // but only when target is not injected + if (filterPipeline is null && factoryContext.Handler is RequestDelegate && (factoryContext.Handler.Method.IsStatic || factoryContext.Handler.Target is not null)) { return null; } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 61e7d2a46965..a872cb3a8624 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -1217,6 +1217,56 @@ public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAtt Assert.Same(myOriginalService, httpContext.Items["service"]); } + [Fact] + public async Task RequestDelegateInjectingHandlerForUnboundCustomDelegate() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + serviceCollection.AddScoped(); + + var services = serviceCollection.BuildServiceProvider(); + + using var requestScoped = services.CreateScope(); + + var httpContext = CreateHttpContext(); + httpContext.RequestServices = requestScoped.ServiceProvider; + + var requestMethod = typeof(HttpHandler).GetMethod(nameof(HttpHandler.Handle))!; + var requestMethodDelegate = requestMethod.CreateDelegate>(); + + var factoryResult = RequestDelegateFactory.Create(requestMethodDelegate, options: new() { ServiceProvider = services }); + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + + Assert.Equal(1, httpContext.Items["calls"]); + } + + [Fact] + public async Task RequestDelegateInjectingHandlerForUnboundRequestDelegate() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + serviceCollection.AddScoped(); + + var services = serviceCollection.BuildServiceProvider(); + + using var requestScoped = services.CreateScope(); + + var httpContext = CreateHttpContext(); + httpContext.RequestServices = requestScoped.ServiceProvider; + + var requestMethod = typeof(HttpHandler).GetMethod(nameof(HttpHandler.Handle))!; + var requestMethodDelegate = requestMethod.CreateDelegate(null); + + var factoryResult = RequestDelegateFactory.Create(requestMethodDelegate, options: new() { ServiceProvider = services }); + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + + Assert.Equal(1, httpContext.Items["calls"]); + } + [Fact] public async Task RequestDelegatePopulatesHttpContextParameterWithoutAttribute() { @@ -3659,14 +3709,19 @@ private class FromServiceAttribute : Attribute, IFromServiceMetadata { } - class HttpHandler + private class HttpHandler { private int _calls; - public void Handle(HttpContext httpContext) + /// + /// Method in form of . + /// + public Task Handle(HttpContext httpContext) { _calls++; httpContext.Items["calls"] = _calls; + + return Task.CompletedTask; } }