diff --git a/robmuxinator/robmuxinator.py b/robmuxinator/robmuxinator.py index 0fa5ee8..9facec6 100755 --- a/robmuxinator/robmuxinator.py +++ b/robmuxinator/robmuxinator.py @@ -115,19 +115,16 @@ def format(self, record): DEFAULT_USER = "robot" DEFAULT_HOST = socket.gethostname() DEFAULT_PORT = None # default port None disables port check +DEFAULT_SSH_PORT = 22 class SSHClient: """Handle commands over ssh tunnel""" - def __init__(self, user, hostname, port=DEFAULT_PORT): + def __init__(self, user, hostname, port=DEFAULT_SSH_PORT): self._user = user self._hostname = hostname - - if port is not None: - self._port = port - else: - self._port = 22 + self._port = port # check if user has sudo privileges self._sudo_user = True if os.getuid() == 0 else False @@ -280,12 +277,13 @@ def __init__(self, hostname, user, port=DEFAULT_PORT): self._hostname = hostname self._user = user self._port = port + self._ssh_port = DEFAULT_SSH_PORT def get_hostname(self): return self._hostname - def get_port(self): - return self._port + def get_ssh_port(self): + return self._ssh_port def shutdown(self, timeout=30): pass @@ -338,9 +336,10 @@ def wait_for_host(self, timeout=60): class LinuxHost(Host): """Handle linux hosts""" - def __init__(self, hostname, user, port=DEFAULT_PORT, check_nfs=True): + def __init__(self, hostname, user, port=DEFAULT_PORT, ssh_port=DEFAULT_SSH_PORT, check_nfs=True): super().__init__(hostname, user, port) - self._ssh_client = SSHClient(user, hostname, port) + self._ssh_port = ssh_port + self._ssh_client = SSHClient(user, hostname, ssh_port) self._check_nfs = check_nfs def shutdown(self, timeout=60): @@ -810,6 +809,11 @@ def main(): else: port = DEFAULT_PORT + if "ssh_port" in yaml_hosts[key]: + ssh_port = yaml_hosts[key]["ssh_port"] + else: + ssh_port = DEFAULT_SSH_PORT + if "check_nfs" in yaml_hosts[key]: check_nfs = yaml_hosts[key]["check_nfs"] else: @@ -829,7 +833,7 @@ def main(): if yaml_hosts[key]["os"].lower().strip() == "linux": hosts[key] = LinuxHost( - hostname, user, port, check_nfs + hostname, user, port, ssh_port, check_nfs ) elif yaml_hosts[key]["os"].lower().strip() == "windows": hosts[key] = WindowsHost(hostname, user, port) @@ -896,7 +900,7 @@ def main(): if key in args.sessions: sessions.append( Session( - SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()), + SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()), key, yaml_sessions[key], envs @@ -905,7 +909,7 @@ def main(): else: sessions.append( Session( - SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()), + SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()), key, yaml_sessions[key], envs