Skip to content

Commit 27bbdc3

Browse files
authored
Merge pull request #829 from akrpic77/ak_add_unix_transport
add support for "unix" transport where socket module contains AF_UNIX
2 parents 8503635 + a0554dd commit 27bbdc3

File tree

3 files changed

+66
-24
lines changed

3 files changed

+66
-24
lines changed

src/paho/mqtt/client.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,10 @@ class Client:
682682
683683
:param transport: use "websockets" to use WebSockets as the transport
684684
mechanism. Set to "tcp" to use raw TCP, which is the default.
685+
Use "unix" to use Unix sockets as the transport mechanism; note that
686+
this option is only available on platforms that support Unix sockets,
687+
and the "host" argument is interpreted as the path to the Unix socket
688+
file in this case.
685689
686690
:param bool manual_ack: normally, when a message is received, the library automatically
687691
acknowledges after on_message callback returns. manual_ack=True allows the application to
@@ -733,14 +737,16 @@ def __init__(
733737
clean_session: bool | None = None,
734738
userdata: Any = None,
735739
protocol: MQTTProtocolVersion = MQTTv311,
736-
transport: Literal["tcp", "websockets"] = "tcp",
740+
transport: Literal["tcp", "websockets", "unix"] = "tcp",
737741
reconnect_on_failure: bool = True,
738742
manual_ack: bool = False,
739743
) -> None:
740744
transport = transport.lower() # type: ignore
741-
if transport not in ("websockets", "tcp"):
745+
if transport == "unix" and not hasattr(socket, "AF_UNIX"):
746+
raise ValueError('"unix" transport not supported')
747+
elif transport not in ("websockets", "tcp", "unix"):
742748
raise ValueError(
743-
f'transport must be "websockets" or "tcp", not {transport}')
749+
f'transport must be "websockets", "tcp" or "unix", not {transport}')
744750

745751
self._manual_ack = manual_ack
746752
self._transport = transport
@@ -931,7 +937,7 @@ def keepalive(self, value: int) -> None:
931937
self._keepalive = value
932938

933939
@property
934-
def transport(self) -> Literal["tcp", "websockets"]:
940+
def transport(self) -> Literal["tcp", "websockets", "unix"]:
935941
"""
936942
Transport method used for the connection ("tcp" or "websockets").
937943
@@ -4597,7 +4603,11 @@ def _get_proxy(self) -> dict[str, Any] | None:
45974603
return None
45984604

45994605
def _create_socket(self) -> SocketLike:
4600-
sock = self._create_socket_connection()
4606+
if self._transport == "unix":
4607+
sock = self._create_unix_socket_connection()
4608+
else:
4609+
sock = self._create_socket_connection()
4610+
46014611
if self._ssl:
46024612
sock = self._ssl_wrap_socket(sock)
46034613

@@ -4614,6 +4624,11 @@ def _create_socket(self) -> SocketLike:
46144624

46154625
return sock
46164626

4627+
def _create_unix_socket_connection(self) -> _socket.socket:
4628+
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
4629+
unix_socket.connect(self._host)
4630+
return unix_socket
4631+
46174632
def _create_socket_connection(self) -> _socket.socket:
46184633
proxy = self._get_proxy()
46194634
addr = (self._host, self._port)

tests/test_client.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_01_con_discon_success(self, proto_ver, callback_version, fake_broker):
3131
callback_version,
3232
"01-con-discon-success",
3333
protocol=proto_ver,
34+
transport=fake_broker.transport,
3435
)
3536

3637
def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
@@ -70,7 +71,8 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
7071

7172
def test_01_con_failure_rc(self, proto_ver, callback_version, fake_broker):
7273
mqttc = client.Client(
73-
callback_version, "01-con-failure-rc", protocol=proto_ver)
74+
callback_version, "01-con-failure-rc",
75+
protocol=proto_ver, transport=fake_broker.transport)
7476

7577
def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
7678
assert rc_or_reason_code > 0
@@ -107,7 +109,9 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
107109
mqttc.loop_stop()
108110

109111
def test_connection_properties(self, proto_ver, callback_version, fake_broker):
110-
mqttc = client.Client(CallbackAPIVersion.VERSION2, "client-id", protocol=proto_ver)
112+
mqttc = client.Client(
113+
CallbackAPIVersion.VERSION2, "client-id",
114+
protocol=proto_ver, transport=fake_broker.transport)
111115
mqttc.enable_logger()
112116

113117
is_connected = threading.Event()
@@ -131,7 +135,7 @@ def on_disconnect(*args):
131135
mqttc.keepalive = 7
132136
mqttc.max_inflight_messages = 7
133137
mqttc.max_queued_messages = 7
134-
mqttc.transport = "tcp"
138+
mqttc.transport = fake_broker.transport
135139
mqttc.username = "username"
136140
mqttc.password = "password"
137141

@@ -184,7 +188,7 @@ def on_disconnect(*args):
184188
mqttc.max_queued_messages = 7
185189

186190
with pytest.raises(RuntimeError):
187-
mqttc.transport = "tcp"
191+
mqttc.transport = fake_broker.transport
188192

189193
with pytest.raises(RuntimeError):
190194
mqttc.username = "username"
@@ -217,7 +221,9 @@ class Test_connect_v5:
217221
"""
218222

219223
def test_01_broker_no_support(self, fake_broker):
220-
mqttc = client.Client(CallbackAPIVersion.VERSION2, "01-broker-no-support", protocol=MQTTProtocolVersion.MQTTv5)
224+
mqttc = client.Client(
225+
CallbackAPIVersion.VERSION2, "01-broker-no-support",
226+
protocol=MQTTProtocolVersion.MQTTv5, transport=fake_broker.transport)
221227

222228
def on_connect(mqttc, obj, flags, reason, properties):
223229
assert reason == 132
@@ -261,6 +267,7 @@ def test_with_loop_start(self, fake_broker: FakeBroker):
261267
"test_with_loop_start",
262268
protocol=MQTTProtocolVersion.MQTTv311,
263269
reconnect_on_failure=False,
270+
transport=fake_broker.transport
264271
)
265272

266273
on_connect_reached = threading.Event()
@@ -311,6 +318,7 @@ def test_with_loop(self, fake_broker: FakeBroker):
311318
CallbackAPIVersion.VERSION1,
312319
"test_with_loop",
313320
clean_session=True,
321+
transport=fake_broker.transport,
314322
)
315323

316324
on_connect_reached = threading.Event()
@@ -367,6 +375,7 @@ def test_publish_before_connect(self, fake_broker: FakeBroker) -> None:
367375
mqttc = client.Client(
368376
CallbackAPIVersion.VERSION1,
369377
"test_publish_before_connect",
378+
transport=fake_broker.transport,
370379
)
371380

372381
def on_connect(mqttc, obj, flags, rc):
@@ -424,7 +433,7 @@ def on_connect(mqttc, obj, flags, rc):
424433
])
425434
class TestPublishBroker2Client:
426435
def test_invalid_utf8_topic(self, callback_version, fake_broker):
427-
mqttc = client.Client(callback_version, "client-id")
436+
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)
428437

429438
def on_message(client, userdata, msg):
430439
with pytest.raises(UnicodeDecodeError):
@@ -466,7 +475,7 @@ def on_message(client, userdata, msg):
466475
assert not packet_in # Check connection is closed
467476

468477
def test_valid_utf8_topic_recv(self, callback_version, fake_broker):
469-
mqttc = client.Client(callback_version, "client-id")
478+
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)
470479

471480
# It should be non-ascii multi-bytes character
472481
topic = unicodedata.lookup('SNOWMAN')
@@ -512,7 +521,7 @@ def on_message(client, userdata, msg):
512521
assert not packet_in # Check connection is closed
513522

514523
def test_valid_utf8_topic_publish(self, callback_version, fake_broker):
515-
mqttc = client.Client(callback_version, "client-id")
524+
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)
516525

517526
# It should be non-ascii multi-bytes character
518527
topic = unicodedata.lookup('SNOWMAN')
@@ -558,7 +567,7 @@ def test_valid_utf8_topic_publish(self, callback_version, fake_broker):
558567
assert not packet_in # Check connection is closed
559568

560569
def test_message_callback(self, callback_version, fake_broker):
561-
mqttc = client.Client(callback_version, "client-id")
570+
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)
562571
userdata = {
563572
'on_message': 0,
564573
'callback1': 0,
@@ -698,6 +707,7 @@ def test_callback_v1_mqtt3(self, fake_broker):
698707
CallbackAPIVersion.VERSION1,
699708
"client-id",
700709
userdata=callback_called,
710+
transport=fake_broker.transport,
701711
)
702712

703713
def on_connect(cl, userdata, flags, rc):
@@ -823,6 +833,7 @@ def test_callback_v2_mqtt3(self, fake_broker):
823833
CallbackAPIVersion.VERSION2,
824834
"client-id",
825835
userdata=callback_called,
836+
transport=fake_broker.transport,
826837
)
827838

828839
def on_connect(cl, userdata, flags, reason, properties):

tests/testsupport/broker.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import os
23
import socket
34
import socketserver
45
import threading
@@ -9,18 +10,27 @@
910

1011

1112
class FakeBroker:
12-
def __init__(self):
13-
# Bind to "localhost" for maximum performance, as described in:
14-
# http://docs.python.org/howto/sockets.html#ipc
15-
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
16-
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
13+
def __init__(self, transport):
14+
if transport == "tcp":
15+
# Bind to "localhost" for maximum performance, as described in:
16+
# http://docs.python.org/howto/sockets.html#ipc
17+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
18+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
19+
sock.bind(("localhost", 0))
20+
self.port = sock.getsockname()[1]
21+
elif transport == "unix":
22+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
23+
sock.bind("localhost")
24+
self.port = 1883
25+
else:
26+
raise ValueError(f"unsupported transport {transport}")
27+
1728
sock.settimeout(5)
18-
sock.bind(("localhost", 0))
19-
self.port = sock.getsockname()[1]
2029
sock.listen(1)
2130

2231
self._sock = sock
2332
self._conn = None
33+
self.transport = transport
2434

2535
def start(self):
2636
if self._sock is None:
@@ -39,6 +49,12 @@ def finish(self):
3949
self._sock.close()
4050
self._sock = None
4151

52+
if self.transport == 'unix':
53+
try:
54+
os.unlink('localhost')
55+
except OSError:
56+
pass
57+
4258
def receive_packet(self, num_bytes):
4359
if self._conn is None:
4460
raise ValueError('Connection is not open')
@@ -60,10 +76,10 @@ def expect_packet(self, name, packet):
6076
paho_test.expect_packet(self._conn, name, packet)
6177

6278

63-
@pytest.fixture
64-
def fake_broker():
79+
@pytest.fixture(params=["tcp"] + (["unix"] if hasattr(socket, 'AF_UNIX') else []))
80+
def fake_broker(request):
6581
# print('Setup broker')
66-
broker = FakeBroker()
82+
broker = FakeBroker(request.param)
6783

6884
yield broker
6985

0 commit comments

Comments
 (0)