From 37eeab12277bead5e64604994346d0137c6e9e3e Mon Sep 17 00:00:00 2001 From: Quahu Date: Sat, 3 Aug 2024 19:51:23 +0200 Subject: [PATCH] Disconnect bot on voice connection exceptions and run cancellation --- examples/Voice/BasicVoice/AudioModule.cs | 14 ++++ .../VoiceExtension.cs | 65 ++++++++++++------- .../Default/DefaultVoiceConnection.cs | 3 + 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/examples/Voice/BasicVoice/AudioModule.cs b/examples/Voice/BasicVoice/AudioModule.cs index d6827274b..041af45db 100644 --- a/examples/Voice/BasicVoice/AudioModule.cs +++ b/examples/Voice/BasicVoice/AudioModule.cs @@ -116,4 +116,18 @@ public async Task Skip() return Response(new LocalInteractionMessageResponse().WithContent("Skipped.").WithIsEphemeral()); } + + [SlashCommand("stop")] + [Description("Stops the playback and disconnects the bot from the voice channel.")] + public async Task Stop() + { + var player = await _playerService.GetPlayerAsync(Context.GuildId); + if (player == null) + { + return Response("Not playing."); + } + + await _playerService.DisposePlayerAsync(Context.GuildId); + return Response("Disconnected."); + } } diff --git a/src/Disqord.Extensions.Voice/VoiceExtension.cs b/src/Disqord.Extensions.Voice/VoiceExtension.cs index a2e27c990..ac8c73ebb 100644 --- a/src/Disqord.Extensions.Voice/VoiceExtension.cs +++ b/src/Disqord.Extensions.Voice/VoiceExtension.cs @@ -1,5 +1,4 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -17,6 +16,7 @@ public class VoiceExtension : DiscordClientExtension { private readonly IVoiceConnectionFactory _connectionFactory; + private readonly IThreadSafeDictionary _pendingConnections; private readonly IThreadSafeDictionary _connections; public VoiceExtension( @@ -26,6 +26,7 @@ public VoiceExtension( { _connectionFactory = connectionFactory; + _pendingConnections = ThreadSafeDictionary.Monitor.Create(); _connections = ThreadSafeDictionary.Monitor.Create(); } @@ -40,7 +41,7 @@ protected override ValueTask InitializeAsync(CancellationToken cancellationToken private Task VoiceServerUpdatedAsync(object? sender, VoiceServerUpdatedEventArgs e) { - GetConnection(e.GuildId)?.OnVoiceServerUpdate(e.Token, e.Endpoint); + _pendingConnections.GetValueOrDefault(e.GuildId)?.OnVoiceServerUpdate(e.Token, e.Endpoint); return Task.CompletedTask; } @@ -52,7 +53,7 @@ private Task VoiceStateUpdatedAsync(object? sender, VoiceStateUpdatedEventArgs e } var voiceState = e.NewVoiceState; - GetConnection(e.GuildId)?.OnVoiceStateUpdate(voiceState.ChannelId, voiceState.SessionId); + _pendingConnections.GetValueOrDefault(e.GuildId)?.OnVoiceStateUpdate(voiceState.ChannelId, voiceState.SessionId); return Task.CompletedTask; } @@ -106,20 +107,28 @@ public async ValueTask ConnectAsync(Snowflake guildId, Snowfla return new(shard.SetVoiceStateAsync(guildId, channelId, false, true, cancellationToken)); }); - var connectionInfo = new VoiceConnectionInfo(connection, Cts.Linked(Client.StoppingToken)); - _connections[guildId] = connectionInfo; - try + _pendingConnections[guildId] = connection; + using (var linkedReadyCts = Cts.Linked(cancellationToken, Client.StoppingToken)) { - var readyTask = connection.WaitUntilReadyAsync(cancellationToken); - _ = connection.RunAsync(connectionInfo.Cts.Token); + var readyTask = connection.WaitUntilReadyAsync(linkedReadyCts.Token); - await readyTask.ConfigureAwait(false); - } - catch - { - _connections.Remove(guildId); - await connectionInfo.DisposeAsync(); - throw; + var linkedRunCts = Cts.Linked(Client.StoppingToken); + Task runTask; + try + { + runTask = connection.RunAsync(linkedRunCts.Token); + await readyTask.ConfigureAwait(false); + } + catch + { + _pendingConnections.Remove(guildId); + linkedRunCts.Cancel(); + linkedRunCts.Dispose(); + + throw; + } + + _connections[guildId] = new VoiceConnectionInfo(connection, runTask, linkedRunCts); } return connection; @@ -133,33 +142,43 @@ public async ValueTask ConnectAsync(Snowflake guildId, Snowfla /// Use to obtain a new connection afterward. /// /// The ID of the guild. - public ValueTask DisconnectAsync(Snowflake guildId) + public async ValueTask DisconnectAsync(Snowflake guildId) { if (!_connections.TryRemove(guildId, out var connectionInfo)) { - return default; + return; } - return connectionInfo.DisposeAsync(); + await connectionInfo.StopAsync().ConfigureAwait(false); } - private readonly struct VoiceConnectionInfo : IAsyncDisposable + private readonly struct VoiceConnectionInfo { public IVoiceConnection Connection { get; } + public Task RunTask { get; } + public Cts Cts { get; } - public VoiceConnectionInfo(IVoiceConnection connection, Cts cts) + public VoiceConnectionInfo(IVoiceConnection connection, Task runTask, Cts cts) { Connection = connection; + RunTask = runTask; Cts = cts; } - public async ValueTask DisposeAsync() + public async ValueTask StopAsync() { Cts.Cancel(); + + try + { + await RunTask.ConfigureAwait(false); + } + catch { } + Cts.Dispose(); - await Connection.DisposeAsync(); + await Connection.DisposeAsync().ConfigureAwait(false); } } } diff --git a/src/Disqord.Voice/Default/DefaultVoiceConnection.cs b/src/Disqord.Voice/Default/DefaultVoiceConnection.cs index 5a0ff72c2..07c0d72a1 100644 --- a/src/Disqord.Voice/Default/DefaultVoiceConnection.cs +++ b/src/Disqord.Voice/Default/DefaultVoiceConnection.cs @@ -344,6 +344,7 @@ await Gateway.SendAsync(new VoiceGatewayPayloadJsonModel } catch (OperationCanceledException ex) when (ex.CancellationToken == linkedCancellationToken && stoppingToken.IsCancellationRequested) { + await _setVoiceStateDelegate(GuildId, null, default).ConfigureAwait(false); _readyTcs.Cancel(ex.CancellationToken); return; } @@ -389,6 +390,8 @@ await Gateway.SendAsync(new VoiceGatewayPayloadJsonModel } catch (Exception ex) { + await _setVoiceStateDelegate(GuildId, null, default).ConfigureAwait(false); + lock (_readyTcs) { _readyTcs.Throw(ex);