diff --git a/.travis.yml b/.travis.yml index 6671e25..38fcbfd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: python python: - "2.7" - - "3.5" - "3.6" - "3.7" - "3.8" diff --git a/setup.py b/setup.py index 8aadc17..dcf6274 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license = f.read() setup(name='python-slimta', - version='4.1.1', + version='4.2.0', author='Ian Good', author_email='icgood@gmail.com', description='Lightweight, asynchronous SMTP libraries.', @@ -56,7 +56,6 @@ 'License :: OSI Approved :: MIT License', 'Programming Language :: Python', 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8']) diff --git a/slimta/edge/smtp.py b/slimta/edge/smtp.py index 612d6d0..fa04d03 100644 --- a/slimta/edge/smtp.py +++ b/slimta/edge/smtp.py @@ -66,6 +66,8 @@ class SmtpValidators(object): - ``handle_rset(reply)``: Called before replying to an RSET command. - ``handle_tls()``: Called after a successful TLS handshake. This may be at the beginning of the session or after a `STARTTLS` command. + - ``handle_tls2(ssl_socket)``: Identical to ``handle_tls()`` except the new + :class:`~ssl.SSLSocket` is passed in as an argument. :param session: When sub-classes are instantiated, instances are passed this object, stored and described in :attr:`session` below, @@ -137,8 +139,9 @@ def HELO(self, reply, helo_as): self.ehlo_as = helo_as self.envelope = None - def TLSHANDSHAKE(self): + def TLSHANDSHAKE2(self, ssl_socket): self._call_validator('tls') + self._call_validator('tls2', ssl_socket) self.security = 'TLS' def AUTH(self, reply, creds): diff --git a/slimta/smtp/__init__.py b/slimta/smtp/__init__.py index b98534f..2ceae5f 100644 --- a/slimta/smtp/__init__.py +++ b/slimta/smtp/__init__.py @@ -64,7 +64,12 @@ class BadReply(SmtpError): """ def __init__(self, data): - super(BadReply, self).__init__('Bad SMTP reply from server.') + if data: + data_str = data.decode('utf-8', 'replace') + msg = 'Bad SMTP reply from server:\r\n' + data_str + else: + msg = 'Bad SMTP reply from server.' + super(BadReply, self).__init__(msg) self.data = data diff --git a/slimta/smtp/server.py b/slimta/smtp/server.py index 83eeed6..60140b4 100644 --- a/slimta/smtp/server.py +++ b/slimta/smtp/server.py @@ -162,6 +162,7 @@ def _encrypt_session(self): if not self.io.encrypt_socket_server(self.context): return False self._call_custom_handler('TLSHANDSHAKE') + self._call_custom_handler('TLSHANDSHAKE2', self.io.socket) return True def _check_close_code(self, reply): diff --git a/slimta/util/dns.py b/slimta/util/dns.py index c393ced..85bf0f2 100644 --- a/slimta/util/dns.py +++ b/slimta/util/dns.py @@ -25,6 +25,7 @@ from __future__ import absolute_import +from collections import OrderedDict from functools import partial import pycares @@ -108,24 +109,59 @@ def _result_cb(cls, result, answer, errno): else: result.set(answer) + @classmethod + def _distinct(cls, read_fds, write_fds): + seen = set() + for fd in read_fds: + if fd not in seen: + yield fd + seen.add(fd) + for fd in write_fds: + if fd not in seen: + yield fd + seen.add(fd) + + @classmethod + def _register_fds(cls, poll, prev_fds_map): + # we must mimic the behavior of pycares sock_state_cb to maintain + # compatibility with custom DNSResolver.channel objects. + fds_map = OrderedDict() + _read_fds, _write_fds = cls._channel.getsock() + read_fds = set(_read_fds) + write_fds = set(_write_fds) + for fd in cls._distinct(_read_fds, _write_fds): + event = 0 + if fd in read_fds: + event |= select.POLLIN + if fd in write_fds: + event |= select.POLLOUT + fds_map[fd] = event + prev_event = prev_fds_map.pop(fd, 0) + if event != prev_event: + poll.register(fd, event) + for fd in prev_fds_map: + poll.unregister(fd) + return fds_map + @classmethod def _wait_channel(cls): + poll = select.poll() + fds_map = OrderedDict() try: while True: - read_fds, write_fds = cls._channel.getsock() - if not read_fds and not write_fds: + fds_map = cls._register_fds(poll, fds_map) + if not fds_map: break timeout = cls._channel.timeout() if not timeout: cls._channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD) continue - rlist, wlist, xlist = select.select( - read_fds, write_fds, [], timeout) - for fd in rlist: - cls._channel.process_fd(fd, pycares.ARES_SOCKET_BAD) - for fd in wlist: - cls._channel.process_fd(pycares.ARES_SOCKET_BAD, fd) + for fd, event in poll.poll(timeout): + if event & (select.POLLIN | select.POLLPRI): + cls._channel.process_fd(fd, pycares.ARES_SOCKET_BAD) + if event & select.POLLOUT: + cls._channel.process_fd(pycares.ARES_SOCKET_BAD, fd) except Exception: logging.log_exception(__name__) cls._channel.cancel() diff --git a/slimta/util/dnsbl.py b/slimta/util/dnsbl.py index fd5c65b..26beec1 100644 --- a/slimta/util/dnsbl.py +++ b/slimta/util/dnsbl.py @@ -91,7 +91,6 @@ def get(self, ip, timeout=None, strict=False): if exc.errno == ARES_ENOTFOUND: return False logging.log_exception(__name__, query=query) - return not strict else: return True return strict diff --git a/test/test_slimta_edge_smtp.py b/test/test_slimta_edge_smtp.py index f68b42f..e9da688 100644 --- a/test/test_slimta_edge_smtp.py +++ b/test/test_slimta_edge_smtp.py @@ -2,6 +2,7 @@ from mox3.mox import MoxTestBase, IsA, IgnoreArg import gevent from gevent.socket import create_connection +from gevent.ssl import SSLSocket from slimta.edge.smtp import SmtpEdge, SmtpSession from slimta.envelope import Envelope @@ -47,17 +48,19 @@ def test_extended_handshake(self): creds = self.mox.CreateMockAnything() creds.authcid = 'testuser' creds.authzid = 'testzid' + ssl_sock = self.mox.CreateMock(SSLSocket) mock = self.mox.CreateMockAnything() mock.__call__(IsA(SmtpSession)).AndReturn(mock) mock.handle_banner(IsA(Reply), ('127.0.0.1', 0)) mock.handle_ehlo(IsA(Reply), 'there') mock.handle_tls() + mock.handle_tls2(IsA(SSLSocket)) mock.handle_auth(IsA(Reply), creds) self.mox.ReplayAll() h = SmtpSession(('127.0.0.1', 0), mock, None) h.BANNER_(Reply('220')) h.EHLO(Reply('250'), 'there') - h.TLSHANDSHAKE() + h.TLSHANDSHAKE2(ssl_sock) h.AUTH(Reply('235'), creds) self.assertEqual('there', h.ehlo_as) self.assertTrue(h.extended_smtp) diff --git a/test/test_slimta_util_dns.py b/test/test_slimta_util_dns.py index 78a6c69..76ba2b6 100644 --- a/test/test_slimta_util_dns.py +++ b/test/test_slimta_util_dns.py @@ -42,28 +42,42 @@ def test_query(self): def test_wait_channel(self): DNSResolver._channel = channel = self.mox.CreateMockAnything() - self.mox.StubOutWithMock(select, 'select') - channel.getsock().AndReturn(('read', 'write')) + poll = self.mox.CreateMockAnything() + self.mox.StubOutWithMock(select, 'poll') + select.poll().AndReturn(poll) + channel.getsock().AndReturn(([1, 2], [2, 3])) channel.timeout().AndReturn(1.0) - select.select('read', 'write', [], 1.0).AndReturn( - ([1, 2, 3], [4, 5, 6], None)) - for fd in [1, 2, 3]: - channel.process_fd(fd, pycares.ARES_SOCKET_BAD) - for fd in [4, 5, 6]: - channel.process_fd(pycares.ARES_SOCKET_BAD, fd) - channel.getsock().AndReturn(('read', 'write')) + poll.register(1, select.POLLIN) + poll.register(2, select.POLLIN | select.POLLOUT) + poll.register(3, select.POLLOUT) + poll.poll(1.0).AndReturn([(1, select.POLLIN), (3, select.POLLOUT)]) + channel.process_fd(1, pycares.ARES_SOCKET_BAD) + channel.process_fd(pycares.ARES_SOCKET_BAD, 3) + channel.getsock().AndReturn(([1, 3], [4])) + channel.timeout().AndReturn(1.0) + poll.register(3, select.POLLIN) + poll.register(4, select.POLLOUT) + poll.unregister(2) + poll.poll(1.0).AndReturn([]) + channel.getsock().AndReturn(([1, 3], [4])) channel.timeout().AndReturn(None) channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD) - channel.getsock().AndReturn((None, None)) + channel.getsock().AndReturn(([], [])) + poll.unregister(1) + poll.unregister(3) + poll.unregister(4) self.mox.ReplayAll() DNSResolver._wait_channel() def test_wait_channel_error(self): DNSResolver._channel = channel = self.mox.CreateMockAnything() - self.mox.StubOutWithMock(select, 'select') - channel.getsock().AndReturn(('read', 'write')) + poll = self.mox.CreateMockAnything() + self.mox.StubOutWithMock(select, 'poll') + select.poll().AndReturn(poll) + channel.getsock().AndReturn(([1], [])) channel.timeout().AndReturn(1.0) - select.select('read', 'write', [], 1.0).AndRaise(ValueError(13)) + poll.register(1, select.POLLIN).AndReturn(None) + poll.poll(1.0).AndRaise(ValueError(13)) channel.cancel() self.mox.ReplayAll() with self.assertRaises(ValueError):