From f542b1e62fd78fbfeb9d32a852c77349d1d0e3e9 Mon Sep 17 00:00:00 2001 From: maximumG Date: Tue, 7 Aug 2018 10:16:41 +0200 Subject: [PATCH] Add support of manual aut_methods for SSH2 connection --- Exscript/protocols/protocol.py | 26 +++++++++++++++++++++++- Exscript/protocols/ssh2.py | 2 ++ tests/Exscript/protocols/ProtocolTest.py | 14 +++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/Exscript/protocols/protocol.py b/Exscript/protocols/protocol.py index 5fdaf7ec..d2687bac 100644 --- a/Exscript/protocols/protocol.py +++ b/Exscript/protocols/protocol.py @@ -224,7 +224,8 @@ def __init__(self, verify_fingerprint=True, account_factory=None, banner_timeout=20, - encoding='latin-1'): + encoding='latin-1', + auth_methods=[]): """ Constructor. The following events are provided: @@ -252,6 +253,9 @@ def __init__(self, :keyword banner_timeout: The time to wait for the banner. :type encoding: str :keyword encoding: The encoding of data received from the remote host. + :type auth_methods: list + :keyword auth_methods: The SSH authentication method to process (default to all supported + by the remote device) """ self.data_received_event = Event() self.otp_requested_event = Event() @@ -282,6 +286,8 @@ def __init__(self, self.banner_timeout = banner_timeout self.encoding = encoding self.send_data = None + self.auth_methods = auth_methods + if stdout is None: self.stdout = StringIO() else: @@ -611,6 +617,24 @@ def get_timeout(self): """ return self.timeout + def set_auth_methods(self, methods): + """ + Defines the SSH2 list of authentication methods allowed + + :type methods: list + :param methods: A list of authentication methods (check Exscript.protocols.ssh2.auth_type) + """ + self.auth_methods = methods + + def get_auth_methods(self): + """ + Returns the current SSH2 authentication methods allowed. + + :rtype: list + :return: A list of authentication SSH2 methods allowed. + """ + return self.auth_methods + def _connect_hook(self, host, port): """ Should be overwritten. diff --git a/Exscript/protocols/ssh2.py b/Exscript/protocols/ssh2.py index 85c00ac7..67570492 100644 --- a/Exscript/protocols/ssh2.py +++ b/Exscript/protocols/ssh2.py @@ -269,6 +269,8 @@ def _paramiko_auth_autokey(self, username, password): def _get_auth_methods(self, allowed_types): auth_methods = [] + if self.auth_methods: + allowed_types = self.auth_methods for method in allowed_types: for type_name in auth_types[method]: auth_methods.append(getattr(self, type_name)) diff --git a/tests/Exscript/protocols/ProtocolTest.py b/tests/Exscript/protocols/ProtocolTest.py index c20445d1..11d83a95 100644 --- a/tests/Exscript/protocols/ProtocolTest.py +++ b/tests/Exscript/protocols/ProtocolTest.py @@ -171,6 +171,20 @@ def testSetDriver(self): def testGetDriver(self): pass # Already tested in testSetDriver() + def testSetAuthMethods(self): + self.assertListEqual(self.protocol.get_auth_methods(), []) + + self.protocol.set_auth_methods([]) + self.assertTrue(self.protocol.get_auth_methods() is not None) + self.assertListEqual(self.protocol.get_auth_methods(), []) + + self.protocol.set_auth_methods(['password', 'publickey']) + self.assertTrue(self.protocol.get_auth_methods() is not None) + self.assertListEqual(self.protocol.get_auth_methods(), ['password', 'publickey']) + + def testGetAuthMethods(self): + pass + def testGetBanner(self): self.assertEqual(self.protocol.get_banner(), None) if self.protocol.__class__ == Protocol: