From 0071e273720d65aca00b49fbde99c3b3ed1d0385 Mon Sep 17 00:00:00 2001 From: Edward Hope-Morley Date: Tue, 12 Mar 2024 22:27:07 +0000 Subject: [PATCH] Support checking service ports with ssl connection By default netcat is used to check if a service is listening on a port. This is generally ok except for services expecting SSL connections which need to be properly closed and netcat can't do that. So here we add support for optionally using the python ssl library to create an ssl connection to the port and close it properly once finished. Related-Bug: #1920770 --- charmhelpers/contrib/network/ip.py | 48 +++++++++++++++++-- charmhelpers/contrib/openstack/utils.py | 24 +++++++--- tests/contrib/network/test_ip.py | 34 +++++++++++++ .../contrib/openstack/test_openstack_utils.py | 2 +- 4 files changed, 95 insertions(+), 13 deletions(-) diff --git a/charmhelpers/contrib/network/ip.py b/charmhelpers/contrib/network/ip.py index cf9926b95..f3b4864f9 100644 --- a/charmhelpers/contrib/network/ip.py +++ b/charmhelpers/contrib/network/ip.py @@ -16,6 +16,7 @@ import re import subprocess import socket +import ssl from functools import partial @@ -527,19 +528,56 @@ def get_hostname(address, fqdn=True): return result.split('.')[0] -def port_has_listener(address, port): +class SSLPortCheckInfo(object): + + def __init__(self, key, cert, ca_cert, check_hostname=False): + self.key = key + self.cert = cert + self.ca_cert = ca_cert + # NOTE: by default we do not check hostname since the port check is + # typically performed using 0.0.0.0 which will not match the + # certificate. Hence the default for this is False. + self.check_hostname = check_hostname + + @property + def ssl_context(self): + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.load_cert_chain(self.cert, self.key) + context.load_verify_locations(self.ca_cert) + return context + + +def port_has_listener(address, port, sslinfo=None): """ Returns True if the address:port is open and being listened to, - else False. + else False. By default uses netcat to check ports but if sslinfo is + provided will use an SSL connection instead. @param address: an IP address or hostname @param port: integer port + @param sslinfo: optional SSLPortCheckInfo object. + If provided, the check is performed using an ssl + connection. Note calls 'zc' via a subprocess shell """ - cmd = ['nc', '-z', address, str(port)] - result = subprocess.call(cmd) - return not (bool(result)) + if not sslinfo: + cmd = ['nc', '-z', address, str(port)] + result = subprocess.call(cmd) + return not (bool(result)) + + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock: + ssock = sslinfo.ssl_context.wrap_socket(sock, + server_hostname=address) + ssock.connect((address, port)) + # this bit is crucial to ensure tls close_notify is sent + ssock.unwrap() + + return True + except ConnectionRefusedError: + return False def assert_charm_supports_ipv6(): diff --git a/charmhelpers/contrib/openstack/utils.py b/charmhelpers/contrib/openstack/utils.py index da711c65d..82c28d8ea 100644 --- a/charmhelpers/contrib/openstack/utils.py +++ b/charmhelpers/contrib/openstack/utils.py @@ -1207,12 +1207,14 @@ def _ows_check_services_running(services, ports): return ows_check_services_running(services, ports) -def ows_check_services_running(services, ports): +def ows_check_services_running(services, ports, ssl_check_info=None): """Check that the services that should be running are actually running and that any ports specified are being listened to. @param services: list of strings OR dictionary specifying services/ports @param ports: list of ports + @param ssl_check_info: SSLPortCheckInfo object. If provided, port checks + will be done using an SSL connection. @returns state, message: strings or None, None """ messages = [] @@ -1228,7 +1230,7 @@ def ows_check_services_running(services, ports): # also verify that the ports that should be open are open # NB, that ServiceManager objects only OPTIONALLY have ports map_not_open, ports_open = ( - _check_listening_on_services_ports(services)) + _check_listening_on_services_ports(services, ssl_check_info)) if not all(ports_open): # find which service has missing ports. They are in service # order which makes it a bit easier. @@ -1243,7 +1245,8 @@ def ows_check_services_running(services, ports): if ports is not None: # and we can also check ports which we don't know the service for - ports_open, ports_open_bools = _check_listening_on_ports_list(ports) + ports_open, ports_open_bools = \ + _check_listening_on_ports_list(ports, ssl_check_info) if not all(ports_open_bools): messages.append( "Ports which should be open, but are not: {}" @@ -1302,7 +1305,8 @@ def _check_running_services(services): return list(zip(services, services_running)), services_running -def _check_listening_on_services_ports(services, test=False): +def _check_listening_on_services_ports(services, test=False, + ssl_check_info=None): """Check that the unit is actually listening (has the port open) on the ports that the service specifies are open. If test is True then the function returns the services with ports that are open rather than @@ -1312,11 +1316,14 @@ def _check_listening_on_services_ports(services, test=False): @param services: OrderedDict(service: [port, ...], ...) @param test: default=False, if False, test for closed, otherwise open. + @param ssl_check_info: SSLPortCheckInfo object. If provided, port checks + will be done using an SSL connection. @returns OrderedDict(service: [port-not-open, ...]...), [boolean] """ test = not (not (test)) # ensure test is True or False all_ports = list(itertools.chain(*services.values())) - ports_states = [port_has_listener('0.0.0.0', p) for p in all_ports] + ports_states = [port_has_listener('0.0.0.0', p, ssl_check_info) + for p in all_ports] map_ports = OrderedDict() matched_ports = [p for p, opened in zip(all_ports, ports_states) if opened == test] # essentially opened xor test @@ -1327,16 +1334,19 @@ def _check_listening_on_services_ports(services, test=False): return map_ports, ports_states -def _check_listening_on_ports_list(ports): +def _check_listening_on_ports_list(ports, ssl_check_info=None): """Check that the ports list given are being listened to Returns a list of ports being listened to and a list of the booleans. + @param ssl_check_info: SSLPortCheckInfo object. If provided, port checks + will be done using an SSL connection. @param ports: LIST of port numbers. @returns [(port_num, boolean), ...], [boolean] """ - ports_open = [port_has_listener('0.0.0.0', p) for p in ports] + ports_open = [port_has_listener('0.0.0.0', p, ssl_check_info) + for p in ports] return zip(ports, ports_open), ports_open diff --git a/tests/contrib/network/test_ip.py b/tests/contrib/network/test_ip.py index 606fc8a39..3938a63d9 100644 --- a/tests/contrib/network/test_ip.py +++ b/tests/contrib/network/test_ip.py @@ -1,5 +1,6 @@ import subprocess import unittest +from contextlib import contextmanager import mock import netifaces @@ -784,6 +785,39 @@ def test_port_has_listener(self, subprocess_call): self.assertEqual(net_ip.port_has_listener('ip-address', 70), True) subprocess_call.assert_called_with(['nc', '-z', 'ip-address', '70']) + @patch('charmhelpers.contrib.network.ip.socket') + @patch('charmhelpers.contrib.network.ip.ssl') + def test_port_has_listener_ssl(self, mock_ssl, mock_socket): + ctxt = mock.MagicMock() + mock_ssl.create_default_context.return_value = ctxt + + @contextmanager + def fake_socket(*args, **kwargs): + for x in [1]: + yield x + + mock_socket.socket.side_effect = fake_socket + sslinfo = net_ip.SSLPortCheckInfo('/etc/ssl/key', '/etc/ssl/cert', + '/etc/ssl/ca_cert') + self.assertEqual(net_ip.port_has_listener('10.0.0.1', 50, sslinfo), + True) + + @patch('charmhelpers.contrib.network.ip.socket') + @patch('charmhelpers.contrib.network.ip.ssl') + def test_port_has_listener_ssl_false(self, mock_ssl, mock_socket): + ctxt = mock.MagicMock() + mock_ssl.create_default_context.return_value = ctxt + + @contextmanager + def fake_socket(*args, **kwargs): + raise ConnectionRefusedError + + mock_socket.socket.side_effect = fake_socket + sslinfo = net_ip.SSLPortCheckInfo('/etc/ssl/key', '/etc/ssl/cert', + '/etc/ssl/ca_cert') + self.assertEqual(net_ip.port_has_listener('10.0.0.1', 50, sslinfo), + False) + @patch.object(net_ip, 'log', lambda *args, **kwargs: None) @patch.object(net_ip, 'config') @patch.object(net_ip, 'network_get_primary_address') diff --git a/tests/contrib/openstack/test_openstack_utils.py b/tests/contrib/openstack/test_openstack_utils.py index 00976049c..19fecbb4e 100644 --- a/tests/contrib/openstack/test_openstack_utils.py +++ b/tests/contrib/openstack/test_openstack_utils.py @@ -1658,7 +1658,7 @@ def test_pause_unit_retry_port_check_retries( port_has_listener.side_effect = [True, False] wait_for_ports_func = openstack.make_wait_for_ports_barrier([77]) openstack.pause_unit(None, services=['service1'], ports=[77], charm_func=wait_for_ports_func) - port_has_listener.assert_has_calls([call('0.0.0.0', 77), call('0.0.0.0', 77)]) + port_has_listener.assert_has_calls([call('0.0.0.0', 77, None), call('0.0.0.0', 77, None)]) @patch('charmhelpers.contrib.openstack.utils.service_resume') @patch('charmhelpers.contrib.openstack.utils.clear_unit_paused')