Skip to content

Commit

Permalink
seperate ssh port from port check
Browse files Browse the repository at this point in the history
  • Loading branch information
Deleh committed Jul 31, 2024
1 parent 4ccc173 commit 9f2fc28
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions robmuxinator/robmuxinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,16 @@ def format(self, record):
DEFAULT_USER = "robot"
DEFAULT_HOST = socket.gethostname()
DEFAULT_PORT = None # default port None disables port check
DEFAULT_SSH_PORT = 22


class SSHClient:
"""Handle commands over ssh tunnel"""

def __init__(self, user, hostname, port=DEFAULT_PORT):
def __init__(self, user, hostname, port=DEFAULT_SSH_PORT):
self._user = user
self._hostname = hostname

if port is not None:
self._port = port
else:
self._port = 22
self._port = port

# check if user has sudo privileges
self._sudo_user = True if os.getuid() == 0 else False
Expand Down Expand Up @@ -284,9 +281,6 @@ def __init__(self, hostname, user, port=DEFAULT_PORT):
def get_hostname(self):
return self._hostname

def get_port(self):
return self._port

def shutdown(self, timeout=30):
pass

Expand Down Expand Up @@ -338,11 +332,15 @@ def wait_for_host(self, timeout=60):
class LinuxHost(Host):
"""Handle linux hosts"""

def __init__(self, hostname, user, port=DEFAULT_PORT, check_nfs=True):
def __init__(self, hostname, user, port=DEFAULT_PORT, ssh_port=DEFAULT_SSH_PORT, check_nfs=True):
super().__init__(hostname, user, port)
self._ssh_client = SSHClient(user, hostname, port)
self._ssh_port = ssh_port
self._ssh_client = SSHClient(user, hostname, ssh_port)
self._check_nfs = check_nfs

def get_ssh_port(self):
return self._ssh_port

def shutdown(self, timeout=60):
logger.info(" shutting down {}...".format(self._hostname))
cmd = "nohup sh -c '( ( sudo shutdown now -P 0 > /dev/null 2>&1 ) & )'"
Expand Down Expand Up @@ -810,6 +808,11 @@ def main():
else:
port = DEFAULT_PORT

if "ssh_port" in yaml_hosts[key]:
ssh_port = yaml_hosts[key]["ssh_port"]
else:
ssh_port = DEFAULT_SSH_PORT

if "check_nfs" in yaml_hosts[key]:
check_nfs = yaml_hosts[key]["check_nfs"]
else:
Expand All @@ -829,7 +832,7 @@ def main():

if yaml_hosts[key]["os"].lower().strip() == "linux":
hosts[key] = LinuxHost(
hostname, user, port, check_nfs
hostname, user, port, ssh_port, check_nfs
)
elif yaml_hosts[key]["os"].lower().strip() == "windows":
hosts[key] = WindowsHost(hostname, user, port)
Expand Down Expand Up @@ -896,7 +899,7 @@ def main():
if key in args.sessions:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()),
key,
yaml_sessions[key],
envs
Expand All @@ -905,7 +908,7 @@ def main():
else:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()),
key,
yaml_sessions[key],
envs
Expand Down

0 comments on commit 9f2fc28

Please sign in to comment.