Skip to content

Commit

Permalink
Add KeepAliveMode option
Browse files Browse the repository at this point in the history
  • Loading branch information
Shane32 committed Oct 7, 2024
1 parent a84c553 commit 3b7645b
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/Transports.AspNetCore/GraphQLHttpMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ protected virtual Task WriteJsonResponseAsync<TResult>(HttpContext context, Http
/// <summary>
/// Gets a list of WebSocket sub-protocols supported.
/// </summary>
protected virtual IEnumerable<string> SupportedWebSocketSubProtocols => _supportedSubProtocols;
protected virtual IEnumerable<string> SupportedWebSocketSubProtocols => _options.WebSockets.SupportedWebSocketSubProtocols;

/// <summary>
/// Creates an <see cref="IWebSocketConnection"/>, a WebSocket message pump.
Expand Down
72 changes: 66 additions & 6 deletions src/Transports.AspNetCore/WebSockets/BaseSubscriptionServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,33 @@ protected virtual Task OnNotAuthorizedPolicyAsync(OperationMessage message, Auth
/// <br/><br/>
/// Otherwise, the connection is acknowledged via <see cref="OnConnectionAcknowledgeAsync(OperationMessage)"/>,
/// <see cref="TryInitialize"/> is called to indicate that this WebSocket connection is ready to accept requests,
/// and keep-alive messages are sent via <see cref="OnSendKeepAliveAsync"/> if configured to do so.
/// Keep-alive messages are only sent if no messages have been sent over the WebSockets connection for the
/// length of time configured in <see cref="GraphQLWebSocketOptions.KeepAliveTimeout"/>.
/// and <see cref="OnSendKeepAliveAsync"/> is called to start sending keep-alive messages if configured to do so.
/// </summary>
protected virtual async Task OnConnectionInitAsync(OperationMessage message)
{
if (!await AuthorizeAsync(message))
{
return;
}
await OnConnectionAcknowledgeAsync(message);
if (TryInitialize() == false)

Check notice

Code scanning / CodeQL

Unnecessarily complex Boolean expression Note

The expression 'A == false' can be simplified to '!A'.
return;

_ = OnKeepAliveLoopAsync();
}

/// <summary>
/// Executes when the client is attempting to initialize the connection.
/// <br/><br/>
/// By default, this first checks <see cref="AuthorizeAsync(OperationMessage)"/> to validate that the
/// request has passed authentication. If validation fails, the connection is closed with an Access
/// Denied message.
/// <br/><br/>
/// Otherwise, the connection is acknowledged via <see cref="OnConnectionAcknowledgeAsync(OperationMessage)"/>,
/// <see cref="TryInitialize"/> is called to indicate that this WebSocket connection is ready to accept requests,
/// and <see cref="OnSendKeepAliveAsync"/> is called to start sending keep-alive messages if configured to do so.
/// </summary>
[Obsolete($"Please use the {nameof(OnConnectionInitAsync)}(message) and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")]
protected virtual async Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive)
{
if (!await AuthorizeAsync(message))
Expand All @@ -277,12 +300,49 @@ protected virtual async Task OnConnectionInitAsync(OperationMessage message, boo
if (keepAliveTimeout > TimeSpan.Zero)
{
if (smartKeepAlive)
_ = StartSmartKeepAliveLoopAsync();
_ = OnKeepAliveLoopAsync(keepAliveTimeout, KeepAliveMode.Timeout);
else
_ = StartKeepAliveLoopAsync();
_ = OnKeepAliveLoopAsync(keepAliveTimeout, KeepAliveMode.Interval);
}
}

/// <summary>
/// Starts sending keep-alive messages if configured to do so. Inspects the configured
/// <see cref="GraphQLWebSocketOptions"/> and passes control to <see cref="OnKeepAliveLoopAsync(TimeSpan, KeepAliveMode)"/>
/// if keep-alive messages are enabled.
/// </summary>
protected virtual Task OnKeepAliveLoopAsync()
{
return OnKeepAliveLoopAsync(
_options.KeepAliveTimeout ?? DefaultKeepAliveTimeout,
_options.KeepAliveMode);
}

/// <summary>
/// Sends keep-alive messages according to the specified timeout period and method.
/// See <see cref="KeepAliveMode"/> for implementation details for each supported mode.
/// </summary>
protected virtual async Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode)
{
if (keepAliveTimeout <= TimeSpan.Zero)
return;

switch (keepAliveMode)
{
case KeepAliveMode.Default:
case KeepAliveMode.Timeout:
await StartSmartKeepAliveLoopAsync();
break;
case KeepAliveMode.Interval:
await StartDumbKeepAliveLoopAsync();
break;
case KeepAliveMode.TimeoutWithPayload:
throw new NotImplementedException($"{nameof(KeepAliveMode.TimeoutWithPayload)} is not implemented within the {nameof(BaseSubscriptionServer)} class.");
default:
throw new ArgumentOutOfRangeException(nameof(keepAliveMode));
}

async Task StartKeepAliveLoopAsync()
async Task StartDumbKeepAliveLoopAsync()
{
while (!CancellationToken.IsCancellationRequested)
{
Expand Down
19 changes: 19 additions & 0 deletions src/Transports.AspNetCore/WebSockets/GraphQLWebSocketOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ public class GraphQLWebSocketOptions
/// </summary>
public TimeSpan? KeepAliveTimeout { get; set; }

/// <summary>
/// Gets or sets the keep-alive mode used for websocket subscriptions.
/// This property is only applicable when using the GraphQLWs protocol.
/// </summary>
public KeepAliveMode KeepAliveMode { get; set; } = KeepAliveMode.Default;

/// <summary>
/// The amount of time to wait to attempt a graceful teardown of the WebSockets protocol.
/// A value of <see langword="null"/> indicates the default value defined by the implementation.
Expand All @@ -42,4 +48,17 @@ public class GraphQLWebSocketOptions
/// Disconnects a subscription from the client in the event of any GraphQL errors during a subscription. The default value is <see langword="false"/>.
/// </summary>
public bool DisconnectAfterAnyError { get; set; }

/// <summary>
/// The list of supported WebSocket sub-protocols.
/// Defaults to <see cref="GraphQLWs.SubscriptionServer.SubProtocol"/> and <see cref="SubscriptionsTransportWs.SubscriptionServer.SubProtocol"/>.
/// Adding other sub-protocols require the <see cref="GraphQLHttpMiddleware.CreateMessageProcessor(IWebSocketConnection, string)"/> method
/// to be overridden to handle the new sub-protocol.
/// </summary>
/// <remarks>
/// When the <see cref="KeepAliveMode"/> is set to <see cref="KeepAliveMode.TimeoutWithPayload"/>, you may wish to remove
/// <see cref="SubscriptionsTransportWs.SubscriptionServer.SubProtocol"/> from this list to prevent clients from using
/// protocols which do not support the <see cref="KeepAliveMode.TimeoutWithPayload"/> keep-alive mode.
/// </remarks>
public List<string> SupportedWebSocketSubProtocols { get; set; } = [GraphQLWs.SubscriptionServer.SubProtocol, SubscriptionsTransportWs.SubscriptionServer.SubProtocol];
}
12 changes: 12 additions & 0 deletions src/Transports.AspNetCore/WebSockets/GraphQLWs/PingPayload.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace GraphQL.Server.Transports.AspNetCore.WebSockets.GraphQLWs;

/// <summary>
/// The payload of the ping message.
/// </summary>
public class PingPayload
{
/// <summary>
/// The unique identifier of the ping message.
/// </summary>
public string? id { get; set; }
}
111 changes: 109 additions & 2 deletions src/Transports.AspNetCore/WebSockets/GraphQLWs/SubscriptionServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ namespace GraphQL.Server.Transports.AspNetCore.WebSockets.GraphQLWs;
public class SubscriptionServer : BaseSubscriptionServer
{
private readonly IWebSocketAuthenticationService? _authenticationService;
private readonly IGraphQLSerializer _serializer;
private readonly GraphQLWebSocketOptions _options;
private DateTime _lastPongReceivedUtc;
private string? _lastPingId;
private readonly object _lastPingLock = new();

/// <summary>
/// The WebSocket sub-protocol used for this protocol.
Expand Down Expand Up @@ -67,6 +72,8 @@ public SubscriptionServer(
UserContextBuilder = userContextBuilder ?? throw new ArgumentNullException(nameof(userContextBuilder));
Serializer = serializer ?? throw new ArgumentNullException(nameof(serializer));
_authenticationService = authenticationService;
_serializer = serializer;
_options = options;
}

/// <inheritdoc/>
Expand All @@ -90,7 +97,9 @@ public override async Task OnMessageReceivedAsync(OperationMessage message)
}
else
{
#pragma warning disable CS0618 // Type or member is obsolete
await OnConnectionInitAsync(message, true);

Check warning

Code scanning / CodeQL

Call to obsolete method Warning

Call to obsolete method
OnConnectionInitAsync
.
#pragma warning restore CS0618 // Type or member is obsolete
}
return;
}
Expand All @@ -113,6 +122,69 @@ public override async Task OnMessageReceivedAsync(OperationMessage message)
}
}

/// <inheritdoc/>
[Obsolete($"Please use the {nameof(OnConnectionInitAsync)} and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")]
protected override Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive)
{
if (smartKeepAlive)
return base.OnConnectionInitAsync(message);
else
return base.OnConnectionInitAsync(message, smartKeepAlive);

Check notice

Code scanning / CodeQL

Missed ternary opportunity Note

Both branches of this 'if' statement return - consider using '?' to express intent better.
}

/// <inheritdoc/>
protected override Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode)
{
if (keepAliveMode == KeepAliveMode.TimeoutWithPayload)
{
if (keepAliveTimeout <= TimeSpan.Zero)
return Task.CompletedTask;
return SecureKeepAliveLoopAsync(keepAliveTimeout, keepAliveTimeout);
}
return base.OnKeepAliveLoopAsync(keepAliveTimeout, keepAliveMode);

// pingInterval is the time since the last pong was received before sending a new ping
// pongInterval is the time to wait for a pong after a ping was sent before forcibly closing the connection
async Task SecureKeepAliveLoopAsync(TimeSpan pingInterval, TimeSpan pongInterval)
{
lock (_lastPingLock)
_lastPongReceivedUtc = DateTime.UtcNow;
while (!CancellationToken.IsCancellationRequested)
{
// Wait for the next ping interval
TimeSpan interval;
var now = DateTime.UtcNow;
DateTime lastPongReceivedUtc;
lock (_lastPingLock)
{
lastPongReceivedUtc = _lastPongReceivedUtc;
}
var nextPing = _lastPongReceivedUtc.Add(pingInterval);
interval = nextPing.Subtract(now);
if (interval > TimeSpan.Zero) // could easily be zero or less, if pongInterval is equal or greater than pingInterval
await Task.Delay(interval, CancellationToken);

// Send a new ping message
await OnSendKeepAliveAsync();

// Wait for the pong response
await Task.Delay(pongInterval, CancellationToken);
bool abort;
lock (_lastPingLock)
{
abort = _lastPongReceivedUtc == lastPongReceivedUtc;
}
if (abort)
{
// Forcibly close the connection if the client has not responded to the keep-alive message.
// Do not send a close message to the client or wait for a response.
Connection.HttpContext.Abort();
return;
}
}
}
}

/// <summary>
/// Pong is a required response to a ping, and also a unidirectional keep-alive packet,
/// whereas ping is a bidirectional keep-alive packet.
Expand All @@ -129,11 +201,46 @@ protected virtual Task OnPingAsync(OperationMessage message)
/// Executes when a pong message is received.
/// </summary>
protected virtual Task OnPongAsync(OperationMessage message)
=> Task.CompletedTask;
{
if (_options.KeepAliveMode == KeepAliveMode.TimeoutWithPayload)
{
try
{
var pingId = _serializer.ReadNode<PingPayload>(message.Payload)?.id;
lock (_lastPingLock)
{
if (_lastPingId == pingId)
_lastPongReceivedUtc = DateTime.UtcNow;
}
}
catch { } // ignore deserialization errors in case the pong message does not match the expected format

Check notice

Code scanning / CodeQL

Poor error handling: empty catch block Note

Poor error handling: empty catch block.

Check notice

Code scanning / CodeQL

Generic catch clause Note

Generic catch clause.
}
return Task.CompletedTask;
}

/// <inheritdoc/>
protected override Task OnSendKeepAliveAsync()
=> Connection.SendMessageAsync(_pongMessage);
{
if (_options.KeepAliveMode == KeepAliveMode.TimeoutWithPayload)
{
var lastPingId = Guid.NewGuid().ToString("N");
lock (_lastPingLock)
{
_lastPingId = lastPingId;
}
return Connection.SendMessageAsync(
new()
{
Type = MessageType.Ping,
Payload = new PingPayload { id = lastPingId }
}
);
}
else
{
return Connection.SendMessageAsync(_pongMessage);
}
}

private static readonly OperationMessage _connectionAckMessage = new() { Type = MessageType.ConnectionAck };
/// <inheritdoc/>
Expand Down
36 changes: 36 additions & 0 deletions src/Transports.AspNetCore/WebSockets/KeepAliveMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
namespace GraphQL.Server.Transports.AspNetCore.WebSockets;

/// <summary>
/// Specifies the mode of keep-alive behavior.
/// </summary>
public enum KeepAliveMode
{
/// <summary>
/// Same as <see cref="Timeout"/>: Sends a unidirectional keep-alive message when no message has been received within the specified timeout period.
/// </summary>
Default = 0,

/// <summary>
/// Sends a unidirectional keep-alive message when no message has been received within the specified timeout period.
/// </summary>
Timeout = 1,

/// <summary>
/// Sends a unidirectional keep-alive message at a fixed interval, regardless of message activity.
/// </summary>
Interval = 2,

/// <summary>
/// Sends a Ping message with a payload after the specified timeout from the last received Pong,
/// and waits for a corresponding Pong response. Requires that the client reflects the payload
/// in the response. Forcibly disconnects the client if the client does not respond with a Pong
/// message within the specified timeout. This means that a dead connection will be closed after
/// a maximum of double the <see cref="GraphQLWebSocketOptions.KeepAliveTimeout"/> period.
/// </summary>
/// <remarks>
/// This mode is particularly useful when backpressure causes subscription messages to be delayed
/// due to a slow or unresponsive client connection. The server can detect that the client is not
/// processing messages in a timely manner and disconnect the client to free up resources.
/// </remarks>
TimeoutWithPayload = 3,
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ public override async Task OnMessageReceivedAsync(OperationMessage message)
}
else
{
#pragma warning disable CS0618 // Type or member is obsolete
await OnConnectionInitAsync(message, false);

Check warning

Code scanning / CodeQL

Call to obsolete method Warning

Call to obsolete method
OnConnectionInitAsync
.
#pragma warning restore CS0618 // Type or member is obsolete
}
return;
}
Expand All @@ -108,6 +110,26 @@ public override async Task OnMessageReceivedAsync(OperationMessage message)
}
}

/// <inheritdoc/>
[Obsolete($"Please use the {nameof(OnConnectionInitAsync)} and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")]
protected override Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive)
{
if (!smartKeepAlive)
return base.OnConnectionInitAsync(message);
else
return base.OnConnectionInitAsync(message, smartKeepAlive);

Check notice

Code scanning / CodeQL

Missed ternary opportunity Note

Both branches of this 'if' statement return - consider using '?' to express intent better.
}

/// <inheritdoc/>
/// <remarks>
/// This implementation overrides <see cref="GraphQLWebSocketOptions.KeepAliveMode"/> to <see cref="KeepAliveMode.Interval"/>
/// as this protocol does not support the other modes. Override this method to support your own implementation.
/// </remarks>
protected override Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode)
=> base.OnKeepAliveLoopAsync(
keepAliveTimeout,
KeepAliveMode.Interval);

private static readonly OperationMessage _keepAliveMessage = new() { Type = MessageType.GQL_CONNECTION_KEEP_ALIVE };
/// <inheritdoc/>
protected override Task OnSendKeepAliveAsync()
Expand Down
Loading

0 comments on commit 3b7645b

Please sign in to comment.