From 7c04924144c362f11aa5d165b23a4cdddd71e51a Mon Sep 17 00:00:00 2001
From: Joachim Baumann <joachim.baumann@xinaris.de>
Date: Thu, 6 Jan 2022 18:59:27 +0100
Subject: [PATCH] Add optional timeout to subscribe()

---
 src/paho/mqtt/subscribe.py | 60 +++++++++++++++++++++++++++++---------
 1 file changed, 46 insertions(+), 14 deletions(-)

diff --git a/src/paho/mqtt/subscribe.py b/src/paho/mqtt/subscribe.py
index 643df9c1..a197507d 100644
--- a/src/paho/mqtt/subscribe.py
+++ b/src/paho/mqtt/subscribe.py
@@ -11,6 +11,7 @@
 #
 # Contributors:
 #    Roger Light - initial API and implementation
+#    Joachim Baumann - added timeout to subscribe.simple()
 
 """
 This module provides some helper functions to allow straightforward subscribing
@@ -22,6 +23,7 @@
 
 from .. import mqtt
 from . import client as paho
+from threading import Lock
 
 
 def _on_connect_v5(client, userdata, flags, rc, properties):
@@ -35,6 +37,7 @@ def _on_connect_v5(client, userdata, flags, rc, properties):
     else:
         client.subscribe(userdata['topics'], userdata['qos'])
 
+
 def _on_connect(client, userdata, flags, rc):
     """Internal v5 callback"""
     _on_connect_v5(client, userdata, flags, rc, None)
@@ -42,10 +45,11 @@ def _on_connect(client, userdata, flags, rc):
 
 def _on_message_callback(client, userdata, message):
     """Internal callback"""
-    userdata['callback'](client, userdata['userdata'], message)
+    userdata['callback'](client, userdata['userdata'],
+                         message, userdata['lock'])
 
 
-def _on_message_simple(client, userdata, message):
+def _on_message_simple(client, userdata, message, lock):
     """Internal callback"""
 
     if userdata['msg_count'] == 0:
@@ -60,22 +64,27 @@ def _on_message_simple(client, userdata, message):
     if userdata['messages'] is None and userdata['msg_count'] == 0:
         userdata['messages'] = message
         client.disconnect()
+        if lock:
+            lock.release()
         return
 
     userdata['messages'].append(message)
     if userdata['msg_count'] == 0:
         client.disconnect()
+        if lock:
+            lock.release()
 
 
 def callback(callback, topics, qos=0, userdata=None, hostname="localhost",
              port=1883, client_id="", keepalive=60, will=None, auth=None,
              tls=None, protocol=paho.MQTTv311, transport="tcp",
-             clean_session=True, proxy_args=None):
+             clean_session=True, proxy_args=None, timeout=None):
     """Subscribe to a list of topics and process them in a callback function.
 
     This function creates an MQTT client, connects to a broker and subscribes
     to a list of topics. Incoming messages are processed by the user provided
-    callback.  This is a blocking function and will never return.
+    callback.  This is a blocking function and will only return after the
+    timeout. If no timeout is given, the function will never return.
 
     callback : function of the form "on_message(client, userdata, message)" for
                processing the messages received.
@@ -132,16 +141,25 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost",
                     Defaults to True.
 
     proxy_args: a dictionary that will be given to the client.
+
+    timeout: the timeout value after which the client disconnects from the
+             broker. If no timeout is given, the client disconnects only
+             after "msg_count" messages have been received.
     """
 
     if qos < 0 or qos > 2:
         raise ValueError('qos must be in the range 0-2')
 
+    lock = None
+    if timeout is not None:
+        lock = Lock()
+
     callback_userdata = {
-        'callback':callback,
-        'topics':topics,
-        'qos':qos,
-        'userdata':userdata}
+        'callback': callback,
+        'topics': topics,
+        'qos': qos,
+        'lock': lock,
+        'userdata': userdata}
 
     client = paho.Client(client_id=client_id, userdata=callback_userdata,
                          protocol=protocol, transport=transport,
@@ -180,18 +198,27 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost",
             client.tls_set_context(tls)
 
     client.connect(hostname, port, keepalive)
-    client.loop_forever()
+
+    if timeout == None:
+        client.loop_forever()
+    else:
+        lock.acquire()
+        client.loop_start()
+        lock.acquire(timeout=timeout)
+        client.loop_stop()
+        client.disconnect()
 
 
 def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost",
            port=1883, client_id="", keepalive=60, will=None, auth=None,
            tls=None, protocol=paho.MQTTv311, transport="tcp",
-           clean_session=True, proxy_args=None):
+           clean_session=True, proxy_args=None, timeout=None):
     """Subscribe to a list of topics and return msg_count messages.
 
     This function creates an MQTT client, connects to a broker and subscribes
-    to a list of topics. Once "msg_count" messages have been received, it
-    disconnects cleanly from the broker and returns the messages.
+    to a list of topics. Once "msg_count" messages have been received or the
+    timeout has been reached, it disconnects cleanly from the broker and
+    returns the received messages.
 
     topics : either a string containing a single topic to subscribe to, or a
              list of topics to subscribe to.
@@ -253,6 +280,10 @@ def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost",
                     Defaults to True.
 
     proxy_args: a dictionary that will be given to the client.
+
+    timeout: the timeout value after which the client disconnects from the
+             broker. If no timeout is given, the client disconnects only
+             after "msg_count" messages have been received.
     """
 
     if msg_count < 1:
@@ -265,10 +296,11 @@ def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost",
     else:
         messages = []
 
-    userdata = {'retained':retained, 'msg_count':msg_count, 'messages':messages}
+    userdata = {'retained': retained,
+                'msg_count': msg_count, 'messages': messages}
 
     callback(_on_message_simple, topics, qos, userdata, hostname, port,
              client_id, keepalive, will, auth, tls, protocol, transport,
-             clean_session, proxy_args)
+             clean_session, proxy_args, timeout)
 
     return userdata['messages']