diff --git a/Source/MQTTnet.Server/Internal/ISubscriptionChangedNotification.cs b/Source/MQTTnet.Server/Internal/ISubscriptionChangedNotification.cs index af9538df6..17ac338cf 100644 --- a/Source/MQTTnet.Server/Internal/ISubscriptionChangedNotification.cs +++ b/Source/MQTTnet.Server/Internal/ISubscriptionChangedNotification.cs @@ -2,7 +2,7 @@ namespace MQTTnet.Server.Internal { public interface ISubscriptionChangedNotification { - void OnSubscriptionsAdded(MqttSession clientSession, List subscriptionsTopics); + void OnSubscriptionsAdded(MqttSession clientSession, List subscriptionsTopics); void OnSubscriptionsRemoved(MqttSession clientSession, List subscriptionTopics); } diff --git a/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs b/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs index 2356b554b..dfef366b9 100644 --- a/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs +++ b/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs @@ -30,7 +30,8 @@ public sealed class MqttClientSessionsManager : ISubscriptionChangedNotification // The _sessions dictionary contains all session, the _subscriberSessions hash set contains subscriber sessions only. // See the MqttSubscription object for a detailed explanation. readonly MqttSessionsStorage _sessionsStorage = new(); - readonly HashSet _subscriberSessions = []; + readonly HashSet _subscriberSessionsWithWildcards = []; + readonly Dictionary> _simpleTopicToSessions = []; public MqttClientSessionsManager(MqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, MqttServerEventContainer eventContainer, IMqttNetLogger logger) { @@ -77,7 +78,7 @@ public async Task DeleteSessionAsync(string clientId) { if (_sessionsStorage.TryRemoveSession(clientId, out session)) { - _subscriberSessions.Remove(session); + CleanupClientSessionUnsafe(session); } } finally @@ -161,11 +162,30 @@ public async Task DispatchApplicationMessage( await _retainedMessagesManager.UpdateMessage(senderId, applicationMessage).ConfigureAwait(false); } - List subscriberSessions; + HashSet subscriberSessions; _sessionsManagementLock.EnterReadLock(); try { - subscriberSessions = _subscriberSessions.ToList(); + if (_simpleTopicToSessions.TryGetValue(applicationMessage.Topic, out var matchedSimpleTopicSessions)) + { + // Create the initial subscriberSessions from whichever set is larger to take advantage + // of the internal ConstructFrom other HashSet optimizations + if (matchedSimpleTopicSessions.Count > _subscriberSessionsWithWildcards.Count) + { + subscriberSessions = new HashSet(matchedSimpleTopicSessions); + subscriberSessions.UnionWith(_subscriberSessionsWithWildcards); + } + else + { + subscriberSessions = new HashSet(_subscriberSessionsWithWildcards); + subscriberSessions.UnionWith(matchedSimpleTopicSessions); + } + } + else + { + // Always include the sessions with wildcards. They need to be properly matched against the topic filter. + subscriberSessions = new HashSet(_subscriberSessionsWithWildcards); + } } finally { @@ -446,20 +466,32 @@ public async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter } } - public void OnSubscriptionsAdded(MqttSession clientSession, List topics) + public void OnSubscriptionsAdded(MqttSession clientSession, List subscriptions) { _sessionsManagementLock.EnterWriteLock(); try { - if (!clientSession.HasSubscribedTopics) + foreach (var subscription in subscriptions) { - // first subscribed topic - _subscriberSessions.Add(clientSession); - } - - foreach (var topic in topics) - { - clientSession.AddSubscribedTopic(topic); + if (subscription.TopicHasWildcard) + { + if (!clientSession.HasSubscribedWildcardTopics) + { + _subscriberSessionsWithWildcards.Add(clientSession); + } + } + else + { + if (_simpleTopicToSessions.TryGetValue(subscription.Topic, out var simpleTopicSessions)) + { + simpleTopicSessions.Add(clientSession); + } + else + { + _simpleTopicToSessions[subscription.Topic] = [clientSession]; + } + } + clientSession.AddSubscribedTopic(subscription.Topic, subscription.TopicHasWildcard); } } finally @@ -475,13 +507,21 @@ public void OnSubscriptionsRemoved(MqttSession clientSession, List subsc { foreach (var subscriptionTopic in subscriptionTopics) { + if (_simpleTopicToSessions.TryGetValue(subscriptionTopic, out var simpleTopicSessions)) + { + simpleTopicSessions.Remove(clientSession); + if (simpleTopicSessions.Count == 0) + { + _simpleTopicToSessions.Remove(subscriptionTopic); + } + } clientSession.RemoveSubscribedTopic(subscriptionTopic); } - if (!clientSession.HasSubscribedTopics) + if (!clientSession.HasSubscribedWildcardTopics) { - // last subscription removed - _subscriberSessions.Remove(clientSession); + // Last wildcard subscription removed + _subscriberSessionsWithWildcards.Remove(clientSession); } } finally @@ -564,7 +604,7 @@ async Task CreateClientConnection( if (connectPacket.CleanSession) { _logger.Verbose("Deleting existing session of client '{0}' due to clean start", connectPacket.ClientId); - _subscriberSessions.Remove(oldSession); + CleanupClientSessionUnsafe(oldSession); session = CreateSession(connectPacket, validatingConnectionEventArgs); } else @@ -669,6 +709,23 @@ MqttSession GetClientSession(string clientId) } } + //* Must be called with the _sessionsManagementLock held. + void CleanupClientSessionUnsafe(MqttSession session) + { + _subscriberSessionsWithWildcards.Remove(session); + foreach (var simpleTopic in session.SubscribedSimpleTopics) + { + if (_simpleTopicToSessions.TryGetValue(simpleTopic, out var simpleTopicSessions)) + { + simpleTopicSessions.Remove(session); + if (simpleTopicSessions.Count == 0) + { + _simpleTopicToSessions.Remove(simpleTopic); + } + } + } + } + async Task ReceiveConnectPacket(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) { try diff --git a/Source/MQTTnet.Server/Internal/MqttClientSubscriptionsManager.cs b/Source/MQTTnet.Server/Internal/MqttClientSubscriptionsManager.cs index 7c1023e52..aa6affe1a 100644 --- a/Source/MQTTnet.Server/Internal/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet.Server/Internal/MqttClientSubscriptionsManager.cs @@ -166,7 +166,7 @@ public async Task Subscribe(MqttSubscribePacket subscribePacket var retainedApplicationMessages = await _retainedMessagesManager.GetMessages().ConfigureAwait(false); var result = new SubscribeResult(subscribePacket.TopicFilters.Count); - var addedSubscriptions = new List(); + var addedSubscriptions = new List(); var finalTopicFilters = new List(); // The topic filters are order by its QoS so that the higher QoS will win over a @@ -195,7 +195,7 @@ public async Task Subscribe(MqttSubscribePacket subscribePacket var createSubscriptionResult = CreateSubscription(topicFilter, subscribePacket.SubscriptionIdentifier, interceptorEventArgs.Response.ReasonCode); - addedSubscriptions.Add(topicFilter.Topic); + addedSubscriptions.Add(createSubscriptionResult.Subscription); finalTopicFilters.Add(topicFilter); FilterRetainedApplicationMessages(retainedApplicationMessages, createSubscriptionResult, result); diff --git a/Source/MQTTnet.Server/Internal/MqttSession.cs b/Source/MQTTnet.Server/Internal/MqttSession.cs index 7e3ab1279..78a14ea9f 100644 --- a/Source/MQTTnet.Server/Internal/MqttSession.cs +++ b/Source/MQTTnet.Server/Internal/MqttSession.cs @@ -23,8 +23,8 @@ public sealed class MqttSession : IDisposable // Do not use a dictionary in order to keep the ordering of the messages. readonly List _unacknowledgedPublishPackets = new(); - // Bookkeeping to know if this is a subscribing client; lazy initialize later. - HashSet _subscribedTopics; + readonly HashSet _subscribedSimpleTopics = []; + readonly HashSet _subscribedWildcardTopics = []; public MqttSession( MqttConnectPacket connectPacket, @@ -50,7 +50,9 @@ public MqttSession( public uint ExpiryInterval => _connectPacket.SessionExpiryInterval; - public bool HasSubscribedTopics => _subscribedTopics != null && _subscribedTopics.Count > 0; + public bool HasSubscribedWildcardTopics => _subscribedWildcardTopics.Count > 0; + + public HashSet SubscribedSimpleTopics => _subscribedSimpleTopics; public string Id => _connectPacket.ClientId; @@ -79,14 +81,16 @@ public MqttPublishPacket AcknowledgePublishPacket(ushort packetIdentifier) return publishPacket; } - public void AddSubscribedTopic(string topic) + public void AddSubscribedTopic(string topic, bool isWildcardTopic) { - if (_subscribedTopics == null) + if (isWildcardTopic) { - _subscribedTopics = new HashSet(); + _subscribedWildcardTopics.Add(topic); + } + else + { + _subscribedSimpleTopics.Add(topic); } - - _subscribedTopics.Add(topic); } public Task DeleteAsync() @@ -208,7 +212,8 @@ public void Recover() public void RemoveSubscribedTopic(string topic) { - _subscribedTopics?.Remove(topic); + _subscribedSimpleTopics.Remove(topic); + _subscribedWildcardTopics.Remove(topic); } public Task Subscribe(MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) diff --git a/Source/MQTTnet.Tests/TopicFilterComparer_Tests.cs b/Source/MQTTnet.Tests/TopicFilterComparer_Tests.cs index 61f8987cd..7615cfeff 100644 --- a/Source/MQTTnet.Tests/TopicFilterComparer_Tests.cs +++ b/Source/MQTTnet.Tests/TopicFilterComparer_Tests.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Server; using MQTTnet.Server.Internal; namespace MQTTnet.Tests