Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/common/Elsa.Mediator/Extensions/HandlerExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.ExceptionServices;
using Elsa.Mediator.Contracts;
using Elsa.Mediator.Middleware.Command;
using Elsa.Mediator.Middleware.Notification;
Expand Down Expand Up @@ -47,7 +48,7 @@ public static Task InvokeAsync(this INotificationHandler handler, MethodBase han
{
var notification = notificationContext.Notification;
var cancellationToken = notificationContext.CancellationToken;
return (Task)handleMethod.Invoke(handler, [notification, cancellationToken])!;
return InvokeAndUnwrap<Task>(handleMethod, handler, [notification, cancellationToken]);
}

/// <summary>
Expand All @@ -60,7 +61,22 @@ public static Task<TResult> InvokeAsync<TResult>(this ICommandHandler handler, M
{
var command = commandContext.Command;
var cancellationToken = commandContext.CancellationToken;
var task = (Task<TResult>)handleMethod.Invoke(handler, [command, cancellationToken])!;
return task;
return InvokeAndUnwrap<Task<TResult>>(handleMethod, handler, [command, cancellationToken]);
}

/// <summary>
/// Invokes a method via reflection and unwraps any TargetInvocationException to preserve the original exception's stack trace.
/// </summary>
private static T InvokeAndUnwrap<T>(MethodBase method, object target, object[] args) where T : Task
{
try
{
return (T)method.Invoke(target, args)!;
}
catch (TargetInvocationException ex) when (ex.InnerException is not null)
{
ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
throw; // Unreachable, but required for compiler
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@ public async ValueTask InvokeAsync(CommandContext context)
var executeMethodWithReturnType = executeMethod.MakeGenericMethod(resultType);

// Execute command.
var task = executeMethodWithReturnType.Invoke(strategy, [strategyContext]);
var task = (Task)executeMethodWithReturnType.Invoke(strategy, [strategyContext])!;
await task.ConfigureAwait(false);

// Await the task to get the result without blocking.
// Get result of task.
var taskWithReturnType = typeof(Task<>).MakeGenericType(resultType);
var taskInstance = (Task)task!;
await taskInstance.ConfigureAwait(false);
var resultProperty = taskWithReturnType.GetProperty(nameof(Task<object>.Result))!;
context.Result = resultProperty.GetValue(task);

// Invoke next middleware.
await next(context);
await next(context).ConfigureAwait(false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ public async ValueTask InvokeAsync(RequestContext context)
var handleMethod = handlerType.GetMethod("HandleAsync")!;
var cancellationToken = context.CancellationToken;
var task = (Task)handleMethod.Invoke(handler, [request, cancellationToken])!;
await task;
await task.ConfigureAwait(false);

// Get result of task.
var taskWithReturnType = typeof(Task<>).MakeGenericType(responseType);
var resultProperty = taskWithReturnType.GetProperty(nameof(Task<object>.Result))!;
context.Response = resultProperty.GetValue(task)!;

// Invoke next middleware.
await next(context);
await next(context).ConfigureAwait(false);
}
}
143 changes: 143 additions & 0 deletions test/unit/Elsa.Mediator.UnitTests/CommandCancellationBehaviorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
using Elsa.Mediator.Contracts;
using Elsa.Mediator.Models;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace Elsa.Mediator.UnitTests;

public class CommandCancellationBehaviorTests
{
[Fact]
public async Task SendAsync_WithSuccessfulCommand_ReturnsResult()
{
// Arrange
using var fixture = CreateCommandSender<EchoCommandHandler>();

// Act
var result = await fixture.CommandSender.SendAsync(new EchoCommand("Hello"));

// Assert
Assert.Equal("Hello", result);
}

[Fact]
public async Task SendAsync_WithCancelledToken_ThrowsOperationCanceledException()
{
// Arrange
using var fixture = CreateCommandSender<SlowCommandHandler>();
using var cts = new CancellationTokenSource();
cts.Cancel();

// Act & Assert
await Assert.ThrowsAnyAsync<OperationCanceledException>(
() => fixture.CommandSender.SendAsync(new SlowCommand(), cts.Token));
}

[Fact]
public async Task SendAsync_WithTimeout_ThrowsOperationCanceledException()
{
// Arrange
using var fixture = CreateCommandSender<SlowCommandHandler>();
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(100));

// Act & Assert
await Assert.ThrowsAnyAsync<OperationCanceledException>(
() => fixture.CommandSender.SendAsync(new SlowCommand(), cts.Token));
}

[Fact]
public async Task SendAsync_WithSelfCancellingHandler_ThrowsTaskCanceledException()
{
// Arrange
using var fixture = CreateCommandSender<SelfCancellingCommandHandler>();
using var cts = new CancellationTokenSource();

// Act & Assert
await Assert.ThrowsAsync<TaskCanceledException>(
() => fixture.CommandSender.SendAsync(new SelfCancellingCommand(cts)));
}

[Fact]
public async Task SendAsync_WithFailingHandler_ThrowsOriginalException()
{
// Arrange
using var fixture = CreateCommandSender<FailingCommandHandler>();

// Act & Assert
var ex = await Assert.ThrowsAsync<InvalidOperationException>(
() => fixture.CommandSender.SendAsync(new FailingCommand("Test error")));

Assert.Equal("Test error", ex.Message);
}

#region Helpers

private static CommandSenderFixture CreateCommandSender<THandler>() where THandler : class, ICommandHandler
{
var services = new ServiceCollection();
services.AddLogging(b => b.SetMinimumLevel(LogLevel.Warning));
services.AddMediator();
services.AddCommandHandler<THandler>();

var provider = services.BuildServiceProvider();
var scope = provider.CreateScope();
return new CommandSenderFixture(provider, scope);
}

private sealed class CommandSenderFixture(ServiceProvider provider, IServiceScope scope) : IDisposable
{
public ICommandSender CommandSender => scope.ServiceProvider.GetRequiredService<ICommandSender>();

public void Dispose()
{
scope.Dispose();
provider.Dispose();
}
}

#endregion

#region Test Commands

public record EchoCommand(string Message) : ICommand<string>;
public record SlowCommand : ICommand;
public record SelfCancellingCommand(CancellationTokenSource Cts) : ICommand;
public record FailingCommand(string ErrorMessage) : ICommand;

#endregion

#region Test Handlers

public class EchoCommandHandler : ICommandHandler<EchoCommand, string>
{
public Task<string> HandleAsync(EchoCommand command, CancellationToken cancellationToken)
=> Task.FromResult(command.Message);
}

public class SlowCommandHandler : ICommandHandler<SlowCommand, Unit>
{
public async Task<Unit> HandleAsync(SlowCommand command, CancellationToken cancellationToken)
{
await Task.Delay(TimeSpan.FromMilliseconds(500), cancellationToken);
return Unit.Instance;
}
}

public class SelfCancellingCommandHandler : ICommandHandler<SelfCancellingCommand, Unit>
{
public async Task<Unit> HandleAsync(SelfCancellingCommand command, CancellationToken cancellationToken)
{
await command.Cts.CancelAsync();
await Task.Delay(1000, command.Cts.Token);
return Unit.Instance;
}
}

public class FailingCommandHandler : ICommandHandler<FailingCommand, Unit>
{
public Task<Unit> HandleAsync(FailingCommand command, CancellationToken cancellationToken)
=> throw new InvalidOperationException(command.ErrorMessage);
}

#endregion
}
12 changes: 12 additions & 0 deletions test/unit/Elsa.Mediator.UnitTests/Elsa.Mediator.UnitTests.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<Include>[Elsa.Mediator]*</Include>
<Threshold>0</Threshold>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\src\common\Elsa.Mediator\Elsa.Mediator.csproj"/>
</ItemGroup>

</Project>
Loading