diff --git a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs index 376e800b6..5726b1484 100644 --- a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs @@ -22,99 +22,6 @@ namespace MQTTnet.Tests.Clients.ManagedMqttClient [TestClass] public sealed class ManagedMqttClient_Tests : BaseTestClass { - [TestMethod] - public async Task Expose_Custom_Connection_Error() - { - using (var testEnvironment = CreateTestEnvironment()) - { - var server = await testEnvironment.StartServer(); - - server.ValidatingConnectionAsync += args => - { - args.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; - return CompletedTask.Instance; - }; - - var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); - - MqttClientDisconnectedEventArgs disconnectedArgs = null; - managedClient.DisconnectedAsync += args => - { - disconnectedArgs = args; - return CompletedTask.Instance; - }; - - var clientOptions = testEnvironment.Factory.CreateManagedMqttClientOptionsBuilder().WithClientOptions(testEnvironment.CreateDefaultClientOptions()).Build(); - await managedClient.StartAsync(clientOptions); - - await LongTestDelay(); - - Assert.IsNotNull(disconnectedArgs); - Assert.AreEqual(MqttClientConnectResultCode.BadUserNameOrPassword, disconnectedArgs.ConnectResult.ResultCode); - } - } - - [TestMethod] - public async Task Receive_While_Not_Cleanly_Disconnected() - { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) - { - await testEnvironment.StartServer(o => o.WithPersistentSessions()); - - var senderClient = await testEnvironment.ConnectClient(); - - // Prepare managed client. - var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); - await managedClient.SubscribeAsync("#"); - var receivedMessages = testEnvironment.CreateApplicationMessageHandler(managedClient); - - var managedClientOptions = new ManagedMqttClientOptions - { - ClientOptions = testEnvironment.Factory.CreateClientOptionsBuilder() - .WithTcpServer("127.0.0.1", testEnvironment.ServerPort) - .WithClientId(nameof(Receive_While_Not_Cleanly_Disconnected) + "_managed") - .WithCleanSession(false) - .Build() - }; - - await managedClient.StartAsync(managedClientOptions); - await LongTestDelay(); - await LongTestDelay(); - - // Send test data. - await senderClient.PublishStringAsync("topic1"); - await LongTestDelay(); - await LongTestDelay(); - - receivedMessages.AssertReceivedCountEquals(1); - - // Stop the managed client but keep session at server (not clean disconnect required). - await managedClient.StopAsync(false); - await LongTestDelay(); - - // Send new messages in the meantime. - await senderClient.PublishStringAsync("topic2", qualityOfServiceLevel: MqttQualityOfServiceLevel.ExactlyOnce); - await LongTestDelay(); - - // Start the managed client, it should receive the new message. - await managedClient.StartAsync(managedClientOptions); - await LongTestDelay(); - - receivedMessages.AssertReceivedCountEquals(2); - - // Stop and start again, no new message should be received. - for (var i = 0; i < 3; i++) - { - await managedClient.StopAsync(false); - await LongTestDelay(); - await managedClient.StartAsync(managedClientOptions); - await LongTestDelay(); - } - - receivedMessages.AssertReceivedCountEquals(2); - } - } - [TestMethod] public async Task Connect_To_Invalid_Server() { @@ -181,6 +88,38 @@ public async Task Drop_New_Messages_On_Full_Queue() } } + [TestMethod] + public async Task Expose_Custom_Connection_Error() + { + using (var testEnvironment = CreateTestEnvironment()) + { + var server = await testEnvironment.StartServer(); + + server.ValidatingConnectionAsync += args => + { + args.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; + return CompletedTask.Instance; + }; + + var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); + + MqttClientDisconnectedEventArgs disconnectedArgs = null; + managedClient.DisconnectedAsync += args => + { + disconnectedArgs = args; + return CompletedTask.Instance; + }; + + var clientOptions = testEnvironment.Factory.CreateManagedMqttClientOptionsBuilder().WithClientOptions(testEnvironment.CreateDefaultClientOptions()).Build(); + await managedClient.StartAsync(clientOptions); + + await LongTestDelay(); + + Assert.IsNotNull(disconnectedArgs); + Assert.AreEqual(MqttClientConnectResultCode.BadUserNameOrPassword, disconnectedArgs.ConnectResult.ResultCode); + } + } + [TestMethod] public async Task ManagedClients_Will_Message_Send() { @@ -224,6 +163,67 @@ public async Task ManagedClients_Will_Message_Send() } } + [TestMethod] + public async Task Receive_While_Not_Cleanly_Disconnected() + { + using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + { + await testEnvironment.StartServer(o => o.WithPersistentSessions()); + + var senderClient = await testEnvironment.ConnectClient(); + + // Prepare managed client. + var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); + await managedClient.SubscribeAsync("#"); + var receivedMessages = testEnvironment.CreateApplicationMessageHandler(managedClient); + + var managedClientOptions = new ManagedMqttClientOptions + { + ClientOptions = testEnvironment.Factory.CreateClientOptionsBuilder() + .WithTcpServer("127.0.0.1", testEnvironment.ServerPort) + .WithClientId(nameof(Receive_While_Not_Cleanly_Disconnected) + "_managed") + .WithCleanSession(false) + .Build() + }; + + await managedClient.StartAsync(managedClientOptions); + await LongTestDelay(); + await LongTestDelay(); + + // Send test data. + await senderClient.PublishStringAsync("topic1"); + await LongTestDelay(); + await LongTestDelay(); + + receivedMessages.AssertReceivedCountEquals(1); + + // Stop the managed client but keep session at server (not clean disconnect required). + await managedClient.StopAsync(false); + await LongTestDelay(); + + // Send new messages in the meantime. + await senderClient.PublishStringAsync("topic2", qualityOfServiceLevel: MqttQualityOfServiceLevel.ExactlyOnce); + await LongTestDelay(); + + // Start the managed client, it should receive the new message. + await managedClient.StartAsync(managedClientOptions); + await LongTestDelay(); + + receivedMessages.AssertReceivedCountEquals(2); + + // Stop and start again, no new message should be received. + for (var i = 0; i < 3; i++) + { + await managedClient.StopAsync(false); + await LongTestDelay(); + await managedClient.StartAsync(managedClientOptions); + await LongTestDelay(); + } + + receivedMessages.AssertReceivedCountEquals(2); + } + } + [TestMethod] public async Task Start_Stop() { @@ -375,7 +375,7 @@ public async Task Subscriptions_Are_Cleared_At_Logout() var clientOptions = new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1", testEnvironment.ServerPort); var receivedManagedMessages = new List(); - + var managedClient = testEnvironment.Factory.CreateManagedMqttClient(testEnvironment.CreateClient()); managedClient.ApplicationMessageReceivedAsync += e => { @@ -403,7 +403,7 @@ public async Task Subscriptions_Are_Cleared_At_Logout() // Make sure that it gets received after subscribing again. await managedClient.SubscribeAsync("topic"); await LongTestDelay(); - + Assert.AreEqual(2, receivedManagedMessages.Count); } } @@ -421,7 +421,7 @@ public async Task Subscriptions_Are_Published_Immediately() var receivingClient = await CreateManagedClientAsync(testEnvironment, null, connectionCheckInterval); var sendingClient = await testEnvironment.ConnectClient(); - await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment( new byte[] { 1 }), Retain = true }); + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment(new byte[] { 1 }), Retain = true }); var subscribeTime = DateTime.UtcNow; diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs index 996f71ebc..b49417ac5 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs @@ -167,6 +167,38 @@ public async Task Disconnect_Clean_With_User_Properties() } } + [TestMethod] + public async Task No_Unobserved_Exception() + { + using (var testEnvironment = CreateTestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + + var client = testEnvironment.CreateClient(); + var options = new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1").WithTimeout(TimeSpan.FromSeconds(2)).Build(); + + try + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(0.5))) + { + await client.ConnectAsync(options, timeout.Token); + } + } + catch (OperationCanceledException) + { + } + + client.Dispose(); + + // These delays and GC calls are required in order to make calling the finalizer reproducible. + GC.Collect(); + GC.WaitForPendingFinalizers(); + await LongTestDelay(); + await LongTestDelay(); + await LongTestDelay(); + } + } + [TestMethod] public async Task Return_Non_Success() { diff --git a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs index 36bca5e4d..554669cb2 100644 --- a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs +++ b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs @@ -52,7 +52,7 @@ public async Task Execute_Success_Parameters_Propagated_Correctly() var paramValue = "123"; var parameters = new Dictionary { - { TestParametersTopicGenerationStrategy.ExpectedParamName, "123" }, + { TestParametersTopicGenerationStrategy.ExpectedParamName, "123" } }; using (var testEnvironment = CreateTestEnvironment()) @@ -164,7 +164,7 @@ public async Task Execute_Timeout_MQTT_V5_Mixed_Clients() using (var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build())) { - var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); + await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); } } } diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs index 92f2543fa..c8d15185e 100644 --- a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Linq; @@ -25,9 +24,9 @@ namespace MQTTnet.Tests.Mockups public sealed class TestEnvironment : IDisposable { readonly List _clientErrors = new List(); - readonly ConcurrentBag _clients = new ConcurrentBag(); + readonly List _clients = new List(); readonly List _exceptions = new List(); - readonly ConcurrentBag _lowLevelClients = new ConcurrentBag(); + readonly List _lowLevelClients = new List(); readonly MqttProtocolVersion _protocolVersion; readonly List _serverErrors = new List(); @@ -40,6 +39,8 @@ public TestEnvironment(TestContext testContext, MqttProtocolVersion protocolVers _protocolVersion = protocolVersion; TestContext = testContext; + TaskScheduler.UnobservedTaskException += TrackUnobservedTaskException; + ServerLogger.LogMessagePublished += (s, e) => { if (Debugger.IsAttached) @@ -216,8 +217,9 @@ public TestApplicationMessageReceivedHandler CreateApplicationMessageHandler(IMa public IMqttClient CreateClient() { var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; - + var client = Factory.CreateMqttClient(logger); + client.ConnectingAsync += e => { if (TestContext != null) @@ -232,9 +234,12 @@ public IMqttClient CreateClient() return CompletedTask.Instance; }; - - _clients.Add(client); - + + lock (_clients) + { + _clients.Add(client); + } + return client; } @@ -245,13 +250,20 @@ public MqttClientOptions CreateDefaultClientOptions() public MqttClientOptionsBuilder CreateDefaultClientOptionsBuilder() { - return Factory.CreateClientOptionsBuilder().WithProtocolVersion(_protocolVersion).WithTcpServer("127.0.0.1", ServerPort).WithClientId(TestContext.TestName + "_" + Guid.NewGuid()); + return Factory.CreateClientOptionsBuilder() + .WithProtocolVersion(_protocolVersion) + .WithTcpServer("127.0.0.1", ServerPort) + .WithClientId(TestContext.TestName + "_" + Guid.NewGuid()); } public ILowLevelMqttClient CreateLowLevelClient() { var client = Factory.CreateLowLevelMqttClient(ClientLogger); - _lowLevelClients.Add(client); + + lock (_lowLevelClients) + { + _lowLevelClients.Add(client); + } return client; } @@ -287,45 +299,68 @@ public MqttServer CreateServer(MqttServerOptions options) public void Dispose() { - foreach (var mqttClient in _clients) + try { + lock (_clients) + { + foreach (var mqttClient in _clients) + { + try + { + //mqttClient.DisconnectAsync().GetAwaiter().GetResult(); + } + catch + { + // This can happen when the test already disconnected the client. + } + finally + { + mqttClient?.Dispose(); + } + } + + _clients.Clear(); + } + + lock (_lowLevelClients) + { + foreach (var lowLevelMqttClient in _lowLevelClients) + { + lowLevelMqttClient.Dispose(); + } + + _lowLevelClients.Clear(); + } + try { - //mqttClient.DisconnectAsync().GetAwaiter().GetResult(); + Server?.StopAsync().GetAwaiter().GetResult(); } catch { - // This can happen when the test already disconnected the client. + // This can happen when the test already stopped the server. } finally { - mqttClient?.Dispose(); + Server?.Dispose(); } - } - foreach (var lowLevelMqttClient in _lowLevelClients) - { - lowLevelMqttClient.Dispose(); - } + Server = null; - try - { - Server?.StopAsync().GetAwaiter().GetResult(); - } - catch - { - // This can happen when the test already stopped the server. + ThrowIfLogErrors(); + + GC.Collect(); + GC.WaitForFullGCComplete(); + GC.WaitForPendingFinalizers(); + + if (_exceptions.Any()) + { + throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions)); + } } finally { - Server?.Dispose(); - } - - ThrowIfLogErrors(); - - if (_exceptions.Any()) - { - throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions)); + TaskScheduler.UnobservedTaskException -= TrackUnobservedTaskException; } } @@ -342,7 +377,7 @@ public async Task StartServer(MqttServerOptionsBuilder optionsBuilde var options = optionsBuilder.Build(); var server = CreateServer(options); - await server.StartAsync(); + await server.StartAsync().ConfigureAwait(false); // The OS has chosen the port to we have to properly expose it to the tests. ServerPort = options.DefaultEndpointOptions.Port; @@ -370,13 +405,16 @@ public async Task StartServer(Action confi public void ThrowIfLogErrors() { - lock (_serverErrors) + if (!IgnoreServerLogErrors) { - if (!IgnoreServerLogErrors && _serverErrors.Count > 0) + lock (_serverErrors) { - var message = $"Server had {_serverErrors.Count} errors (${string.Join(Environment.NewLine, _serverErrors)})."; - Console.WriteLine(message); - throw new Exception(message); + if (_serverErrors.Count > 0) + { + var message = $"Server had {_serverErrors.Count} errors (${string.Join(Environment.NewLine, _serverErrors)})."; + Console.WriteLine(message); + throw new Exception(message); + } } } @@ -396,10 +434,20 @@ public void ThrowIfLogErrors() public void TrackException(Exception exception) { + if (exception == null) + { + return; + } + lock (_exceptions) { _exceptions.Add(exception); } } + + void TrackUnobservedTaskException(object sender, UnobservedTaskExceptionEventArgs e) + { + TrackException(e.Exception); + } } } \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs index 9c8df7008..f6b3c42f7 100644 --- a/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs +++ b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs @@ -1,4 +1,4 @@ -#if !(NET452 || NET461) +#if !(NET452 || NET461 || NET48) using System; using System.Collections.Concurrent; using System.Diagnostics; @@ -149,18 +149,14 @@ static X509Certificate2 CreateSelfSignedCertificate(string oid) } } - class ClientTestHarness : IDisposable + sealed class ClientTestHarness : IDisposable { readonly HotSwappableClientCertProvider _hotSwapClient = new HotSwappableClientCertProvider(); + IMqttClient _client; public string ClientId => _client.Options.ClientId; - public void ClearServerCerts() - { - _hotSwapClient.ClearServerCerts(); - } - public Task Connect() { return Run_Client_Connection(); @@ -169,6 +165,7 @@ public Task Connect() public void Dispose() { _client.Dispose(); + _hotSwapClient.Dispose(); } public X509Certificate2 GetCurrentClientCert() @@ -189,6 +186,8 @@ public void InstallNewServerCert(X509Certificate2 serverCert) public void WaitForConnectOrFail(TimeSpan timeout) { + Thread.Sleep(100); + if (!_client.IsConnected) { _client.ReconnectAsync().Wait(timeout); @@ -210,12 +209,12 @@ public void WaitForConnectToFail(TimeSpan timeout) Assert.IsFalse(_client.IsConnected, "Client connection success but test wanted fail"); } - public void WaitForDisconnect(TimeSpan timeout) + void WaitForDisconnect(TimeSpan timeout) { var timer = Stopwatch.StartNew(); while ((_client == null || _client.IsConnected) && timer.Elapsed < timeout) { - Thread.Sleep(5); + Thread.Sleep(100); } } @@ -252,22 +251,17 @@ void WaitForConnect(TimeSpan timeout) var timer = Stopwatch.StartNew(); while ((_client == null || !_client.IsConnected) && timer.Elapsed < timeout) { - Thread.Sleep(5); + Thread.Sleep(100); } } } - class ServerTestHarness : IDisposable + sealed class ServerTestHarness : IDisposable { readonly HotSwappableServerCertProvider _hotSwapServer = new HotSwappableServerCertProvider(); MqttServer _server; - public void ClearClientCerts() - { - _hotSwapServer.ClearClientCerts(); - } - public void Dispose() { if (_server != null) @@ -276,15 +270,12 @@ public void Dispose() _server.Dispose(); } - if (_hotSwapServer != null) - { - _hotSwapServer.Dispose(); - } + _hotSwapServer?.Dispose(); } - public async Task ForceDisconnectAsync(ClientTestHarness client) + public Task ForceDisconnectAsync(ClientTestHarness client) { - await _server.DisconnectClientAsync(client.ClientId, MqttDisconnectReasonCode.UnspecifiedError); + return _server.DisconnectClientAsync(client.ClientId, MqttDisconnectReasonCode.UnspecifiedError); } public X509Certificate2 GetCurrentServerCert() @@ -302,7 +293,7 @@ public void InstallNewClientCert(X509Certificate2 serverCert) _hotSwapServer.InstallNewClientCert(serverCert); } - public async Task StartServer() + public Task StartServer() { var mqttFactory = new MqttFactory(); @@ -312,24 +303,33 @@ public async Task StartServer() .Build(); mqttServerOptions.TlsEndpointOptions.ClientCertificateRequired = true; + _server = mqttFactory.CreateMqttServer(mqttServerOptions); - await _server.StartAsync(); + return _server.StartAsync(); } } - class HotSwappableClientCertProvider : IMqttClientCertificatesProvider + class HotSwappableClientCertProvider : IMqttClientCertificatesProvider, IDisposable { X509Certificate2Collection _certificates; - ConcurrentBag ServerCerts = new ConcurrentBag(); + ConcurrentBag _serverCerts = new ConcurrentBag(); public HotSwappableClientCertProvider() { _certificates = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); } - - public void ClearServerCerts() + + public void Dispose() { - ServerCerts = new ConcurrentBag(); + if (_certificates != null) + { + foreach (var certs in _certificates) + { +#if !NET452 + certs.Dispose(); +#endif + } + } } public X509CertificateCollection GetCertificates() @@ -339,18 +339,17 @@ public X509CertificateCollection GetCertificates() public void HotSwapCert() { - var newCert = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); - var oldCerts = Interlocked.Exchange(ref _certificates, newCert); + _certificates = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); } public void InstallNewServerCert(X509Certificate2 serverCert) { - ServerCerts.Add(serverCert); + _serverCerts.Add(serverCert); } public bool OnCertificateValidation(MqttClientCertificateValidationEventArgs certContext) { - var serverCerts = ServerCerts.ToArray(); + var serverCerts = _serverCerts.ToArray(); var providedCert = certContext.Certificate.GetRawCertData(); for (int i = 0, n = serverCerts.Length; i < n; i++) @@ -365,36 +364,18 @@ public bool OnCertificateValidation(MqttClientCertificateValidationEventArgs cer return false; } - - void Dispose() - { - if (_certificates != null) - { - foreach (var certs in _certificates) - { -#if !NET452 - certs.Dispose(); -#endif - } - } - } } - class HotSwappableServerCertProvider : ICertificateProvider, IDisposable + sealed class HotSwappableServerCertProvider : ICertificateProvider, IDisposable { + readonly ConcurrentBag _clientCerts = new ConcurrentBag(); X509Certificate2 _certificate; - ConcurrentBag ClientCerts = new ConcurrentBag(); public HotSwappableServerCertProvider() { _certificate = CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.1"); } - public void ClearClientCerts() - { - ClientCerts = new ConcurrentBag(); - } - public void Dispose() { #if !NET452 @@ -418,12 +399,12 @@ public void HotSwapCert() public void InstallNewClientCert(X509Certificate2 certificate) { - ClientCerts.Add(certificate); + _clientCerts.Add(certificate); } public bool RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { - var serverCerts = ClientCerts.ToArray(); + var serverCerts = _clientCerts.ToArray(); var providedCert = certificate.GetRawCertData(); for (int i = 0, n = serverCerts.Length; i < n; i++)