diff --git a/src/common/Elsa.Mediator/Extensions/HandlerExtensions.cs b/src/common/Elsa.Mediator/Extensions/HandlerExtensions.cs index 695268c58b..044970500e 100644 --- a/src/common/Elsa.Mediator/Extensions/HandlerExtensions.cs +++ b/src/common/Elsa.Mediator/Extensions/HandlerExtensions.cs @@ -1,5 +1,6 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection; +using System.Runtime.ExceptionServices; using Elsa.Mediator.Contracts; using Elsa.Mediator.Middleware.Command; using Elsa.Mediator.Middleware.Notification; @@ -47,7 +48,7 @@ public static Task InvokeAsync(this INotificationHandler handler, MethodBase han { var notification = notificationContext.Notification; var cancellationToken = notificationContext.CancellationToken; - return (Task)handleMethod.Invoke(handler, [notification, cancellationToken])!; + return InvokeAndUnwrap(handleMethod, handler, [notification, cancellationToken]); } /// @@ -60,7 +61,22 @@ public static Task InvokeAsync(this ICommandHandler handler, M { var command = commandContext.Command; var cancellationToken = commandContext.CancellationToken; - var task = (Task)handleMethod.Invoke(handler, [command, cancellationToken])!; - return task; + return InvokeAndUnwrap>(handleMethod, handler, [command, cancellationToken]); + } + + /// + /// Invokes a method via reflection and unwraps any TargetInvocationException to preserve the original exception's stack trace. + /// + private static T InvokeAndUnwrap(MethodBase method, object target, object[] args) where T : Task + { + try + { + return (T)method.Invoke(target, args)!; + } + catch (TargetInvocationException ex) when (ex.InnerException is not null) + { + ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); + throw; // Unreachable, but required for compiler + } } } \ No newline at end of file diff --git a/src/common/Elsa.Mediator/Middleware/Command/Components/CommandHandlerInvokerMiddleware.cs b/src/common/Elsa.Mediator/Middleware/Command/Components/CommandHandlerInvokerMiddleware.cs index 647fe4b58e..96afdc9a72 100644 --- a/src/common/Elsa.Mediator/Middleware/Command/Components/CommandHandlerInvokerMiddleware.cs +++ b/src/common/Elsa.Mediator/Middleware/Command/Components/CommandHandlerInvokerMiddleware.cs @@ -39,16 +39,15 @@ public async ValueTask InvokeAsync(CommandContext context) var executeMethodWithReturnType = executeMethod.MakeGenericMethod(resultType); // Execute command. - var task = executeMethodWithReturnType.Invoke(strategy, [strategyContext]); + var task = (Task)executeMethodWithReturnType.Invoke(strategy, [strategyContext])!; + await task.ConfigureAwait(false); - // Await the task to get the result without blocking. + // Get result of task. var taskWithReturnType = typeof(Task<>).MakeGenericType(resultType); - var taskInstance = (Task)task!; - await taskInstance.ConfigureAwait(false); var resultProperty = taskWithReturnType.GetProperty(nameof(Task.Result))!; context.Result = resultProperty.GetValue(task); // Invoke next middleware. - await next(context); + await next(context).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/common/Elsa.Mediator/Middleware/Request/Components/RequestHandlerInvokerMiddleware.cs b/src/common/Elsa.Mediator/Middleware/Request/Components/RequestHandlerInvokerMiddleware.cs index 2055515f11..b3d47e17c4 100644 --- a/src/common/Elsa.Mediator/Middleware/Request/Components/RequestHandlerInvokerMiddleware.cs +++ b/src/common/Elsa.Mediator/Middleware/Request/Components/RequestHandlerInvokerMiddleware.cs @@ -31,7 +31,7 @@ public async ValueTask InvokeAsync(RequestContext context) var handleMethod = handlerType.GetMethod("HandleAsync")!; var cancellationToken = context.CancellationToken; var task = (Task)handleMethod.Invoke(handler, [request, cancellationToken])!; - await task; + await task.ConfigureAwait(false); // Get result of task. var taskWithReturnType = typeof(Task<>).MakeGenericType(responseType); @@ -39,6 +39,6 @@ public async ValueTask InvokeAsync(RequestContext context) context.Response = resultProperty.GetValue(task)!; // Invoke next middleware. - await next(context); + await next(context).ConfigureAwait(false); } } \ No newline at end of file diff --git a/test/unit/Elsa.Mediator.UnitTests/CommandCancellationBehaviorTests.cs b/test/unit/Elsa.Mediator.UnitTests/CommandCancellationBehaviorTests.cs new file mode 100644 index 0000000000..76e8fa047d --- /dev/null +++ b/test/unit/Elsa.Mediator.UnitTests/CommandCancellationBehaviorTests.cs @@ -0,0 +1,143 @@ +using Elsa.Mediator.Contracts; +using Elsa.Mediator.Models; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Elsa.Mediator.UnitTests; + +public class CommandCancellationBehaviorTests +{ + [Fact] + public async Task SendAsync_WithSuccessfulCommand_ReturnsResult() + { + // Arrange + using var fixture = CreateCommandSender(); + + // Act + var result = await fixture.CommandSender.SendAsync(new EchoCommand("Hello")); + + // Assert + Assert.Equal("Hello", result); + } + + [Fact] + public async Task SendAsync_WithCancelledToken_ThrowsOperationCanceledException() + { + // Arrange + using var fixture = CreateCommandSender(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert + await Assert.ThrowsAnyAsync( + () => fixture.CommandSender.SendAsync(new SlowCommand(), cts.Token)); + } + + [Fact] + public async Task SendAsync_WithTimeout_ThrowsOperationCanceledException() + { + // Arrange + using var fixture = CreateCommandSender(); + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(100)); + + // Act & Assert + await Assert.ThrowsAnyAsync( + () => fixture.CommandSender.SendAsync(new SlowCommand(), cts.Token)); + } + + [Fact] + public async Task SendAsync_WithSelfCancellingHandler_ThrowsTaskCanceledException() + { + // Arrange + using var fixture = CreateCommandSender(); + using var cts = new CancellationTokenSource(); + + // Act & Assert + await Assert.ThrowsAsync( + () => fixture.CommandSender.SendAsync(new SelfCancellingCommand(cts))); + } + + [Fact] + public async Task SendAsync_WithFailingHandler_ThrowsOriginalException() + { + // Arrange + using var fixture = CreateCommandSender(); + + // Act & Assert + var ex = await Assert.ThrowsAsync( + () => fixture.CommandSender.SendAsync(new FailingCommand("Test error"))); + + Assert.Equal("Test error", ex.Message); + } + + #region Helpers + + private static CommandSenderFixture CreateCommandSender() where THandler : class, ICommandHandler + { + var services = new ServiceCollection(); + services.AddLogging(b => b.SetMinimumLevel(LogLevel.Warning)); + services.AddMediator(); + services.AddCommandHandler(); + + var provider = services.BuildServiceProvider(); + var scope = provider.CreateScope(); + return new CommandSenderFixture(provider, scope); + } + + private sealed class CommandSenderFixture(ServiceProvider provider, IServiceScope scope) : IDisposable + { + public ICommandSender CommandSender => scope.ServiceProvider.GetRequiredService(); + + public void Dispose() + { + scope.Dispose(); + provider.Dispose(); + } + } + + #endregion + + #region Test Commands + + public record EchoCommand(string Message) : ICommand; + public record SlowCommand : ICommand; + public record SelfCancellingCommand(CancellationTokenSource Cts) : ICommand; + public record FailingCommand(string ErrorMessage) : ICommand; + + #endregion + + #region Test Handlers + + public class EchoCommandHandler : ICommandHandler + { + public Task HandleAsync(EchoCommand command, CancellationToken cancellationToken) + => Task.FromResult(command.Message); + } + + public class SlowCommandHandler : ICommandHandler + { + public async Task HandleAsync(SlowCommand command, CancellationToken cancellationToken) + { + await Task.Delay(TimeSpan.FromMilliseconds(500), cancellationToken); + return Unit.Instance; + } + } + + public class SelfCancellingCommandHandler : ICommandHandler + { + public async Task HandleAsync(SelfCancellingCommand command, CancellationToken cancellationToken) + { + await command.Cts.CancelAsync(); + await Task.Delay(1000, command.Cts.Token); + return Unit.Instance; + } + } + + public class FailingCommandHandler : ICommandHandler + { + public Task HandleAsync(FailingCommand command, CancellationToken cancellationToken) + => throw new InvalidOperationException(command.ErrorMessage); + } + + #endregion +} diff --git a/test/unit/Elsa.Mediator.UnitTests/Elsa.Mediator.UnitTests.csproj b/test/unit/Elsa.Mediator.UnitTests/Elsa.Mediator.UnitTests.csproj new file mode 100644 index 0000000000..431a924caa --- /dev/null +++ b/test/unit/Elsa.Mediator.UnitTests/Elsa.Mediator.UnitTests.csproj @@ -0,0 +1,12 @@ + + + + [Elsa.Mediator]* + 0 + + + + + + +