From ec637f7946b454c6c659e7e34657ad5e50a8a09f Mon Sep 17 00:00:00 2001 From: Christian <6939810+chkr1011@users.noreply.github.com> Date: Sat, 4 Nov 2023 13:39:16 +0100 Subject: [PATCH] Handle unobserved tasks exceptions (#1871) * Access exception so that it will not treated as unhandled * Refactor code * Improve unit tests * Fix null reference exception * Update ReleaseNotes.md --- .github/workflows/ReleaseNotes.md | 3 +- MQTTnet.sln.DotSettings | 1 + .../ManagedMqttClient_Tests.cs | 192 +++++++++--------- .../MqttClient/MqttClient_Connection_Tests.cs | 32 +++ Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs | 4 +- .../MQTTnet.Tests/Mockups/TestEnvironment.cs | 126 ++++++++---- .../Server/HotSwapCerts_Tests.cs | 91 ++++----- Source/MQTTnet/Adapter/MqttChannelAdapter.cs | 30 ++- Source/MQTTnet/Client/MqttClient.cs | 15 +- .../LowLevelClient/LowLevelMqttClient.cs | 1 - 10 files changed, 289 insertions(+), 206 deletions(-) diff --git a/.github/workflows/ReleaseNotes.md b/.github/workflows/ReleaseNotes.md index 776fa1d01..4d23c9325 100644 --- a/.github/workflows/ReleaseNotes.md +++ b/.github/workflows/ReleaseNotes.md @@ -1,2 +1,3 @@ * [Server] Fixed not working _UpdateRetainedMessageAsync_ public api (#1858, thanks to @kimdiego2098). -* [Client] Added support for custom CA chain validation (#1851, thanks to @rido-min). \ No newline at end of file +* [Client] Added support for custom CA chain validation (#1851, thanks to @rido-min). +* [Client] Fixed handling of unobserved tasks exceptions (#1871). \ No newline at end of file diff --git a/MQTTnet.sln.DotSettings b/MQTTnet.sln.DotSettings index 49d8306ba..813b55202 100644 --- a/MQTTnet.sln.DotSettings +++ b/MQTTnet.sln.DotSettings @@ -240,6 +240,7 @@ See the LICENSE file in the project root for more information. True True True + True True True True diff --git a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs index 376e800b6..5726b1484 100644 --- a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs @@ -22,99 +22,6 @@ namespace MQTTnet.Tests.Clients.ManagedMqttClient [TestClass] public sealed class ManagedMqttClient_Tests : BaseTestClass { - [TestMethod] - public async Task Expose_Custom_Connection_Error() - { - using (var testEnvironment = CreateTestEnvironment()) - { - var server = await testEnvironment.StartServer(); - - server.ValidatingConnectionAsync += args => - { - args.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; - return CompletedTask.Instance; - }; - - var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); - - MqttClientDisconnectedEventArgs disconnectedArgs = null; - managedClient.DisconnectedAsync += args => - { - disconnectedArgs = args; - return CompletedTask.Instance; - }; - - var clientOptions = testEnvironment.Factory.CreateManagedMqttClientOptionsBuilder().WithClientOptions(testEnvironment.CreateDefaultClientOptions()).Build(); - await managedClient.StartAsync(clientOptions); - - await LongTestDelay(); - - Assert.IsNotNull(disconnectedArgs); - Assert.AreEqual(MqttClientConnectResultCode.BadUserNameOrPassword, disconnectedArgs.ConnectResult.ResultCode); - } - } - - [TestMethod] - public async Task Receive_While_Not_Cleanly_Disconnected() - { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) - { - await testEnvironment.StartServer(o => o.WithPersistentSessions()); - - var senderClient = await testEnvironment.ConnectClient(); - - // Prepare managed client. - var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); - await managedClient.SubscribeAsync("#"); - var receivedMessages = testEnvironment.CreateApplicationMessageHandler(managedClient); - - var managedClientOptions = new ManagedMqttClientOptions - { - ClientOptions = testEnvironment.Factory.CreateClientOptionsBuilder() - .WithTcpServer("127.0.0.1", testEnvironment.ServerPort) - .WithClientId(nameof(Receive_While_Not_Cleanly_Disconnected) + "_managed") - .WithCleanSession(false) - .Build() - }; - - await managedClient.StartAsync(managedClientOptions); - await LongTestDelay(); - await LongTestDelay(); - - // Send test data. - await senderClient.PublishStringAsync("topic1"); - await LongTestDelay(); - await LongTestDelay(); - - receivedMessages.AssertReceivedCountEquals(1); - - // Stop the managed client but keep session at server (not clean disconnect required). - await managedClient.StopAsync(false); - await LongTestDelay(); - - // Send new messages in the meantime. - await senderClient.PublishStringAsync("topic2", qualityOfServiceLevel: MqttQualityOfServiceLevel.ExactlyOnce); - await LongTestDelay(); - - // Start the managed client, it should receive the new message. - await managedClient.StartAsync(managedClientOptions); - await LongTestDelay(); - - receivedMessages.AssertReceivedCountEquals(2); - - // Stop and start again, no new message should be received. - for (var i = 0; i < 3; i++) - { - await managedClient.StopAsync(false); - await LongTestDelay(); - await managedClient.StartAsync(managedClientOptions); - await LongTestDelay(); - } - - receivedMessages.AssertReceivedCountEquals(2); - } - } - [TestMethod] public async Task Connect_To_Invalid_Server() { @@ -181,6 +88,38 @@ public async Task Drop_New_Messages_On_Full_Queue() } } + [TestMethod] + public async Task Expose_Custom_Connection_Error() + { + using (var testEnvironment = CreateTestEnvironment()) + { + var server = await testEnvironment.StartServer(); + + server.ValidatingConnectionAsync += args => + { + args.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; + return CompletedTask.Instance; + }; + + var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); + + MqttClientDisconnectedEventArgs disconnectedArgs = null; + managedClient.DisconnectedAsync += args => + { + disconnectedArgs = args; + return CompletedTask.Instance; + }; + + var clientOptions = testEnvironment.Factory.CreateManagedMqttClientOptionsBuilder().WithClientOptions(testEnvironment.CreateDefaultClientOptions()).Build(); + await managedClient.StartAsync(clientOptions); + + await LongTestDelay(); + + Assert.IsNotNull(disconnectedArgs); + Assert.AreEqual(MqttClientConnectResultCode.BadUserNameOrPassword, disconnectedArgs.ConnectResult.ResultCode); + } + } + [TestMethod] public async Task ManagedClients_Will_Message_Send() { @@ -224,6 +163,67 @@ public async Task ManagedClients_Will_Message_Send() } } + [TestMethod] + public async Task Receive_While_Not_Cleanly_Disconnected() + { + using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + { + await testEnvironment.StartServer(o => o.WithPersistentSessions()); + + var senderClient = await testEnvironment.ConnectClient(); + + // Prepare managed client. + var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); + await managedClient.SubscribeAsync("#"); + var receivedMessages = testEnvironment.CreateApplicationMessageHandler(managedClient); + + var managedClientOptions = new ManagedMqttClientOptions + { + ClientOptions = testEnvironment.Factory.CreateClientOptionsBuilder() + .WithTcpServer("127.0.0.1", testEnvironment.ServerPort) + .WithClientId(nameof(Receive_While_Not_Cleanly_Disconnected) + "_managed") + .WithCleanSession(false) + .Build() + }; + + await managedClient.StartAsync(managedClientOptions); + await LongTestDelay(); + await LongTestDelay(); + + // Send test data. + await senderClient.PublishStringAsync("topic1"); + await LongTestDelay(); + await LongTestDelay(); + + receivedMessages.AssertReceivedCountEquals(1); + + // Stop the managed client but keep session at server (not clean disconnect required). + await managedClient.StopAsync(false); + await LongTestDelay(); + + // Send new messages in the meantime. + await senderClient.PublishStringAsync("topic2", qualityOfServiceLevel: MqttQualityOfServiceLevel.ExactlyOnce); + await LongTestDelay(); + + // Start the managed client, it should receive the new message. + await managedClient.StartAsync(managedClientOptions); + await LongTestDelay(); + + receivedMessages.AssertReceivedCountEquals(2); + + // Stop and start again, no new message should be received. + for (var i = 0; i < 3; i++) + { + await managedClient.StopAsync(false); + await LongTestDelay(); + await managedClient.StartAsync(managedClientOptions); + await LongTestDelay(); + } + + receivedMessages.AssertReceivedCountEquals(2); + } + } + [TestMethod] public async Task Start_Stop() { @@ -375,7 +375,7 @@ public async Task Subscriptions_Are_Cleared_At_Logout() var clientOptions = new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1", testEnvironment.ServerPort); var receivedManagedMessages = new List(); - + var managedClient = testEnvironment.Factory.CreateManagedMqttClient(testEnvironment.CreateClient()); managedClient.ApplicationMessageReceivedAsync += e => { @@ -403,7 +403,7 @@ public async Task Subscriptions_Are_Cleared_At_Logout() // Make sure that it gets received after subscribing again. await managedClient.SubscribeAsync("topic"); await LongTestDelay(); - + Assert.AreEqual(2, receivedManagedMessages.Count); } } @@ -421,7 +421,7 @@ public async Task Subscriptions_Are_Published_Immediately() var receivingClient = await CreateManagedClientAsync(testEnvironment, null, connectionCheckInterval); var sendingClient = await testEnvironment.ConnectClient(); - await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment( new byte[] { 1 }), Retain = true }); + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment(new byte[] { 1 }), Retain = true }); var subscribeTime = DateTime.UtcNow; diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs index 996f71ebc..b49417ac5 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs @@ -167,6 +167,38 @@ public async Task Disconnect_Clean_With_User_Properties() } } + [TestMethod] + public async Task No_Unobserved_Exception() + { + using (var testEnvironment = CreateTestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + + var client = testEnvironment.CreateClient(); + var options = new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1").WithTimeout(TimeSpan.FromSeconds(2)).Build(); + + try + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(0.5))) + { + await client.ConnectAsync(options, timeout.Token); + } + } + catch (OperationCanceledException) + { + } + + client.Dispose(); + + // These delays and GC calls are required in order to make calling the finalizer reproducible. + GC.Collect(); + GC.WaitForPendingFinalizers(); + await LongTestDelay(); + await LongTestDelay(); + await LongTestDelay(); + } + } + [TestMethod] public async Task Return_Non_Success() { diff --git a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs index 36bca5e4d..554669cb2 100644 --- a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs +++ b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs @@ -52,7 +52,7 @@ public async Task Execute_Success_Parameters_Propagated_Correctly() var paramValue = "123"; var parameters = new Dictionary { - { TestParametersTopicGenerationStrategy.ExpectedParamName, "123" }, + { TestParametersTopicGenerationStrategy.ExpectedParamName, "123" } }; using (var testEnvironment = CreateTestEnvironment()) @@ -164,7 +164,7 @@ public async Task Execute_Timeout_MQTT_V5_Mixed_Clients() using (var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build())) { - var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); + await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); } } } diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs index 92f2543fa..c8d15185e 100644 --- a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Linq; @@ -25,9 +24,9 @@ namespace MQTTnet.Tests.Mockups public sealed class TestEnvironment : IDisposable { readonly List _clientErrors = new List(); - readonly ConcurrentBag _clients = new ConcurrentBag(); + readonly List _clients = new List(); readonly List _exceptions = new List(); - readonly ConcurrentBag _lowLevelClients = new ConcurrentBag(); + readonly List _lowLevelClients = new List(); readonly MqttProtocolVersion _protocolVersion; readonly List _serverErrors = new List(); @@ -40,6 +39,8 @@ public TestEnvironment(TestContext testContext, MqttProtocolVersion protocolVers _protocolVersion = protocolVersion; TestContext = testContext; + TaskScheduler.UnobservedTaskException += TrackUnobservedTaskException; + ServerLogger.LogMessagePublished += (s, e) => { if (Debugger.IsAttached) @@ -216,8 +217,9 @@ public TestApplicationMessageReceivedHandler CreateApplicationMessageHandler(IMa public IMqttClient CreateClient() { var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; - + var client = Factory.CreateMqttClient(logger); + client.ConnectingAsync += e => { if (TestContext != null) @@ -232,9 +234,12 @@ public IMqttClient CreateClient() return CompletedTask.Instance; }; - - _clients.Add(client); - + + lock (_clients) + { + _clients.Add(client); + } + return client; } @@ -245,13 +250,20 @@ public MqttClientOptions CreateDefaultClientOptions() public MqttClientOptionsBuilder CreateDefaultClientOptionsBuilder() { - return Factory.CreateClientOptionsBuilder().WithProtocolVersion(_protocolVersion).WithTcpServer("127.0.0.1", ServerPort).WithClientId(TestContext.TestName + "_" + Guid.NewGuid()); + return Factory.CreateClientOptionsBuilder() + .WithProtocolVersion(_protocolVersion) + .WithTcpServer("127.0.0.1", ServerPort) + .WithClientId(TestContext.TestName + "_" + Guid.NewGuid()); } public ILowLevelMqttClient CreateLowLevelClient() { var client = Factory.CreateLowLevelMqttClient(ClientLogger); - _lowLevelClients.Add(client); + + lock (_lowLevelClients) + { + _lowLevelClients.Add(client); + } return client; } @@ -287,45 +299,68 @@ public MqttServer CreateServer(MqttServerOptions options) public void Dispose() { - foreach (var mqttClient in _clients) + try { + lock (_clients) + { + foreach (var mqttClient in _clients) + { + try + { + //mqttClient.DisconnectAsync().GetAwaiter().GetResult(); + } + catch + { + // This can happen when the test already disconnected the client. + } + finally + { + mqttClient?.Dispose(); + } + } + + _clients.Clear(); + } + + lock (_lowLevelClients) + { + foreach (var lowLevelMqttClient in _lowLevelClients) + { + lowLevelMqttClient.Dispose(); + } + + _lowLevelClients.Clear(); + } + try { - //mqttClient.DisconnectAsync().GetAwaiter().GetResult(); + Server?.StopAsync().GetAwaiter().GetResult(); } catch { - // This can happen when the test already disconnected the client. + // This can happen when the test already stopped the server. } finally { - mqttClient?.Dispose(); + Server?.Dispose(); } - } - foreach (var lowLevelMqttClient in _lowLevelClients) - { - lowLevelMqttClient.Dispose(); - } + Server = null; - try - { - Server?.StopAsync().GetAwaiter().GetResult(); - } - catch - { - // This can happen when the test already stopped the server. + ThrowIfLogErrors(); + + GC.Collect(); + GC.WaitForFullGCComplete(); + GC.WaitForPendingFinalizers(); + + if (_exceptions.Any()) + { + throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions)); + } } finally { - Server?.Dispose(); - } - - ThrowIfLogErrors(); - - if (_exceptions.Any()) - { - throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions)); + TaskScheduler.UnobservedTaskException -= TrackUnobservedTaskException; } } @@ -342,7 +377,7 @@ public async Task StartServer(MqttServerOptionsBuilder optionsBuilde var options = optionsBuilder.Build(); var server = CreateServer(options); - await server.StartAsync(); + await server.StartAsync().ConfigureAwait(false); // The OS has chosen the port to we have to properly expose it to the tests. ServerPort = options.DefaultEndpointOptions.Port; @@ -370,13 +405,16 @@ public async Task StartServer(Action confi public void ThrowIfLogErrors() { - lock (_serverErrors) + if (!IgnoreServerLogErrors) { - if (!IgnoreServerLogErrors && _serverErrors.Count > 0) + lock (_serverErrors) { - var message = $"Server had {_serverErrors.Count} errors (${string.Join(Environment.NewLine, _serverErrors)})."; - Console.WriteLine(message); - throw new Exception(message); + if (_serverErrors.Count > 0) + { + var message = $"Server had {_serverErrors.Count} errors (${string.Join(Environment.NewLine, _serverErrors)})."; + Console.WriteLine(message); + throw new Exception(message); + } } } @@ -396,10 +434,20 @@ public void ThrowIfLogErrors() public void TrackException(Exception exception) { + if (exception == null) + { + return; + } + lock (_exceptions) { _exceptions.Add(exception); } } + + void TrackUnobservedTaskException(object sender, UnobservedTaskExceptionEventArgs e) + { + TrackException(e.Exception); + } } } \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs index 9c8df7008..f6b3c42f7 100644 --- a/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs +++ b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs @@ -1,4 +1,4 @@ -#if !(NET452 || NET461) +#if !(NET452 || NET461 || NET48) using System; using System.Collections.Concurrent; using System.Diagnostics; @@ -149,18 +149,14 @@ static X509Certificate2 CreateSelfSignedCertificate(string oid) } } - class ClientTestHarness : IDisposable + sealed class ClientTestHarness : IDisposable { readonly HotSwappableClientCertProvider _hotSwapClient = new HotSwappableClientCertProvider(); + IMqttClient _client; public string ClientId => _client.Options.ClientId; - public void ClearServerCerts() - { - _hotSwapClient.ClearServerCerts(); - } - public Task Connect() { return Run_Client_Connection(); @@ -169,6 +165,7 @@ public Task Connect() public void Dispose() { _client.Dispose(); + _hotSwapClient.Dispose(); } public X509Certificate2 GetCurrentClientCert() @@ -189,6 +186,8 @@ public void InstallNewServerCert(X509Certificate2 serverCert) public void WaitForConnectOrFail(TimeSpan timeout) { + Thread.Sleep(100); + if (!_client.IsConnected) { _client.ReconnectAsync().Wait(timeout); @@ -210,12 +209,12 @@ public void WaitForConnectToFail(TimeSpan timeout) Assert.IsFalse(_client.IsConnected, "Client connection success but test wanted fail"); } - public void WaitForDisconnect(TimeSpan timeout) + void WaitForDisconnect(TimeSpan timeout) { var timer = Stopwatch.StartNew(); while ((_client == null || _client.IsConnected) && timer.Elapsed < timeout) { - Thread.Sleep(5); + Thread.Sleep(100); } } @@ -252,22 +251,17 @@ void WaitForConnect(TimeSpan timeout) var timer = Stopwatch.StartNew(); while ((_client == null || !_client.IsConnected) && timer.Elapsed < timeout) { - Thread.Sleep(5); + Thread.Sleep(100); } } } - class ServerTestHarness : IDisposable + sealed class ServerTestHarness : IDisposable { readonly HotSwappableServerCertProvider _hotSwapServer = new HotSwappableServerCertProvider(); MqttServer _server; - public void ClearClientCerts() - { - _hotSwapServer.ClearClientCerts(); - } - public void Dispose() { if (_server != null) @@ -276,15 +270,12 @@ public void Dispose() _server.Dispose(); } - if (_hotSwapServer != null) - { - _hotSwapServer.Dispose(); - } + _hotSwapServer?.Dispose(); } - public async Task ForceDisconnectAsync(ClientTestHarness client) + public Task ForceDisconnectAsync(ClientTestHarness client) { - await _server.DisconnectClientAsync(client.ClientId, MqttDisconnectReasonCode.UnspecifiedError); + return _server.DisconnectClientAsync(client.ClientId, MqttDisconnectReasonCode.UnspecifiedError); } public X509Certificate2 GetCurrentServerCert() @@ -302,7 +293,7 @@ public void InstallNewClientCert(X509Certificate2 serverCert) _hotSwapServer.InstallNewClientCert(serverCert); } - public async Task StartServer() + public Task StartServer() { var mqttFactory = new MqttFactory(); @@ -312,24 +303,33 @@ public async Task StartServer() .Build(); mqttServerOptions.TlsEndpointOptions.ClientCertificateRequired = true; + _server = mqttFactory.CreateMqttServer(mqttServerOptions); - await _server.StartAsync(); + return _server.StartAsync(); } } - class HotSwappableClientCertProvider : IMqttClientCertificatesProvider + class HotSwappableClientCertProvider : IMqttClientCertificatesProvider, IDisposable { X509Certificate2Collection _certificates; - ConcurrentBag ServerCerts = new ConcurrentBag(); + ConcurrentBag _serverCerts = new ConcurrentBag(); public HotSwappableClientCertProvider() { _certificates = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); } - - public void ClearServerCerts() + + public void Dispose() { - ServerCerts = new ConcurrentBag(); + if (_certificates != null) + { + foreach (var certs in _certificates) + { +#if !NET452 + certs.Dispose(); +#endif + } + } } public X509CertificateCollection GetCertificates() @@ -339,18 +339,17 @@ public X509CertificateCollection GetCertificates() public void HotSwapCert() { - var newCert = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); - var oldCerts = Interlocked.Exchange(ref _certificates, newCert); + _certificates = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); } public void InstallNewServerCert(X509Certificate2 serverCert) { - ServerCerts.Add(serverCert); + _serverCerts.Add(serverCert); } public bool OnCertificateValidation(MqttClientCertificateValidationEventArgs certContext) { - var serverCerts = ServerCerts.ToArray(); + var serverCerts = _serverCerts.ToArray(); var providedCert = certContext.Certificate.GetRawCertData(); for (int i = 0, n = serverCerts.Length; i < n; i++) @@ -365,36 +364,18 @@ public bool OnCertificateValidation(MqttClientCertificateValidationEventArgs cer return false; } - - void Dispose() - { - if (_certificates != null) - { - foreach (var certs in _certificates) - { -#if !NET452 - certs.Dispose(); -#endif - } - } - } } - class HotSwappableServerCertProvider : ICertificateProvider, IDisposable + sealed class HotSwappableServerCertProvider : ICertificateProvider, IDisposable { + readonly ConcurrentBag _clientCerts = new ConcurrentBag(); X509Certificate2 _certificate; - ConcurrentBag ClientCerts = new ConcurrentBag(); public HotSwappableServerCertProvider() { _certificate = CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.1"); } - public void ClearClientCerts() - { - ClientCerts = new ConcurrentBag(); - } - public void Dispose() { #if !NET452 @@ -418,12 +399,12 @@ public void HotSwapCert() public void InstallNewClientCert(X509Certificate2 certificate) { - ClientCerts.Add(certificate); + _clientCerts.Add(certificate); } public bool RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { - var serverCerts = ClientCerts.ToArray(); + var serverCerts = _clientCerts.ToArray(); var providedCert = certificate.GetRawCertData(); for (int i = 0, n = serverCerts.Length; i < n; i++) diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 92a2e012b..a5842c16b 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -26,9 +26,7 @@ public sealed class MqttChannelAdapter : Disposable, IMqttChannelAdapter readonly IMqttChannel _channel; readonly byte[] _fixedHeaderBuffer = new byte[2]; readonly MqttNetSourceLogger _logger; - readonly byte[] _singleByteBuffer = new byte[1]; - readonly AsyncLock _syncRoot = new AsyncLock(); Statistics _statistics; // mutable struct, don't make readonly! @@ -47,6 +45,8 @@ public MqttChannelAdapter(IMqttChannel channel, MqttPacketFormatterAdapter packe _logger = logger.WithSource(nameof(MqttChannelAdapter)); } + public bool AllowPacketFragmentation { get; set; } = true; + public long BytesReceived => Volatile.Read(ref _statistics._bytesReceived); public long BytesSent => Volatile.Read(ref _statistics._bytesSent); @@ -63,8 +63,6 @@ public MqttChannelAdapter(IMqttChannel channel, MqttPacketFormatterAdapter packe public MqttPacketInspector PacketInspector { get; set; } - public bool AllowPacketFragmentation { get; set; } = true; - public async Task ConnectAsync(CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -78,12 +76,30 @@ public async Task ConnectAsync(CancellationToken cancellationToken) * block forever. Even a cancellation token is not supported properly. */ - var connectTask = _channel.ConnectAsync(cancellationToken); - var timeout = new TaskCompletionSource(); using (cancellationToken.Register(() => timeout.TrySetResult(null))) { + var connectTask = Task.Run( + async () => + { + try + { + await _channel.ConnectAsync(cancellationToken).ConfigureAwait(false); + } + catch + { + // If the timeout is already reached the exception is no longer of interest and + // must be catched. Otherwise it will arrive at the TaskScheduler.UnobservedTaskException. + if (!timeout.Task.IsCompleted) + { + throw; + } + } + }, + CancellationToken.None); + await Task.WhenAny(connectTask, timeout.Task).ConfigureAwait(false); + if (timeout.Task.IsCompleted && !connectTask.IsCompleted) { throw new OperationCanceledException("MQTT connect canceled.", cancellationToken); @@ -204,7 +220,7 @@ public async Task SendPacketAsync(MqttPacket packet, CancellationToken cancellat PacketInspector?.BeginSendPacket(packetBuffer); _logger.Verbose("TX ({0} bytes) >>> {1}", packetBuffer.Length, packet); - + if (packetBuffer.Payload.Count == 0 || !AllowPacketFragmentation) { await _channel.WriteAsync(packetBuffer.Join(), true, cancellationToken).ConfigureAwait(false); diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index a62530643..dc65140b9 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -708,6 +708,9 @@ async Task ProcessReceivedPublishPackets(CancellationToken cancellationToken) await eventArgs.AcknowledgeAsync(cancellationToken).ConfigureAwait(false); } } + catch (ObjectDisposedException) + { + } catch (OperationCanceledException) { } @@ -787,7 +790,7 @@ async Task ReceivePacketsLoop(CancellationToken cancellationToken) { try { - _logger.Verbose("Start receiving packets."); + _logger.Verbose("Start receiving packets"); while (!cancellationToken.IsCancellationRequested) { @@ -825,20 +828,22 @@ async Task ReceivePacketsLoop(CancellationToken cancellationToken) } else if (exception is MqttCommunicationException) { - _logger.Warning(exception, "Communication error while receiving packets."); + _logger.Warning(exception, "Communication error while receiving packets"); } else { - _logger.Error(exception, "Error while receiving packets."); + _logger.Error(exception, "Error while receiving packets"); } - _packetDispatcher.FailAll(exception); + // The packet dispatcher is set to null when the client is being disposed so it may + // already being gone! + _packetDispatcher?.FailAll(exception); await DisconnectInternal(_packetReceiverTask, exception, null).ConfigureAwait(false); } finally { - _logger.Verbose("Stopped receiving packets."); + _logger.Verbose("Stopped receiving packets"); } } diff --git a/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs b/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs index 75074320e..dc2f9342e 100644 --- a/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs +++ b/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs @@ -19,7 +19,6 @@ public sealed class LowLevelMqttClient : ILowLevelMqttClient readonly IMqttClientAdapterFactory _clientAdapterFactory; readonly AsyncEvent _inspectPacketEvent = new AsyncEvent(); readonly MqttNetSourceLogger _logger; - readonly IMqttNetLogger _rootLogger; IMqttChannelAdapter _adapter;