Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Broken SSH Sessions #12

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ sessions:
- `os: string` {linux, windows, online} (mandatory): Operating system of the host. Hosts of type `online` will only be checked for network availability.
- `user: string` (optional, default: robot): User on the host machine used for sending SSH commands.
- `port: int` (optional, default: none): The port that is checked to determine if a service on the host is already up.
- `ssh_port: int` (optional, default: `22`): The port that is used for SSH connections to the host.
- `hostname: string` (optional, default: `<key>` of `hosts` section): The hostname of the host PC.
- `check_nfs: bool` (optional, default: true): Whether the host should be checked for NFS status. Only supported on Linux.

Expand Down
42 changes: 28 additions & 14 deletions robmuxinator/robmuxinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,16 @@ def format(self, record):
DEFAULT_USER = "robot"
DEFAULT_HOST = socket.gethostname()
DEFAULT_PORT = None # default port None disables port check
DEFAULT_SSH_PORT = 22


class SSHClient:
"""Handle commands over ssh tunnel"""

def __init__(self, user, hostname, port=DEFAULT_PORT):
def __init__(self, user, hostname, port=DEFAULT_SSH_PORT):
self._user = user
self._hostname = hostname

if port is not None:
self._port = port
else:
self._port = 22
self._port = port

# check if user has sudo privileges
self._sudo_user = True if os.getuid() == 0 else False
Expand Down Expand Up @@ -207,6 +204,7 @@ def send_cmd(self, cmd, wait_for_exit_status=True, get_pty=False):
return returncode, stdout, stderr
except Exception as e:
logger.error("{}".format(e))
self.ssh_cli = None
return 1, None, None

def send_keys(self, session_name, keys):
Expand Down Expand Up @@ -279,12 +277,13 @@ def __init__(self, hostname, user, port=DEFAULT_PORT):
self._hostname = hostname
self._user = user
self._port = port
self._ssh_port = DEFAULT_SSH_PORT

def get_hostname(self):
return self._hostname

def get_port(self):
return self._port
def get_ssh_port(self):
return self._ssh_port

def shutdown(self, timeout=30):
pass
Expand Down Expand Up @@ -337,9 +336,10 @@ def wait_for_host(self, timeout=60):
class LinuxHost(Host):
"""Handle linux hosts"""

def __init__(self, hostname, user, port=DEFAULT_PORT, check_nfs=True):
def __init__(self, hostname, user, port=DEFAULT_PORT, ssh_port=DEFAULT_SSH_PORT, check_nfs=True):
super().__init__(hostname, user, port)
self._ssh_client = SSHClient(user, hostname, port)
self._ssh_port = ssh_port
self._ssh_client = SSHClient(user, hostname, ssh_port)
self._check_nfs = check_nfs

def shutdown(self, timeout=60):
Expand Down Expand Up @@ -391,7 +391,16 @@ def wait_for_host(self, timeout=60):
)
return False

logger.info(" {} nfs is up".format(self._hostname))
# Send an initial 'echo' command to verify if sending commands works
logger.info(" {} sending initial command".format(self._hostname))
ret = 1
while ret != 0:
ret, _, _ = self._ssh_client.send_cmd("echo", get_pty=True)
if ret != 0:
logger.error(" {} sending initial command failed".format(self._hostname))
time.sleep(0.25)
logger.info(" {} sending initial command succeeded".format(self._hostname))

return True


Expand Down Expand Up @@ -800,6 +809,11 @@ def main():
else:
port = DEFAULT_PORT

if "ssh_port" in yaml_hosts[key]:
Deleh marked this conversation as resolved.
Show resolved Hide resolved
ssh_port = yaml_hosts[key]["ssh_port"]
else:
ssh_port = DEFAULT_SSH_PORT

if "check_nfs" in yaml_hosts[key]:
check_nfs = yaml_hosts[key]["check_nfs"]
else:
Expand All @@ -819,7 +833,7 @@ def main():

if yaml_hosts[key]["os"].lower().strip() == "linux":
hosts[key] = LinuxHost(
hostname, user, port, check_nfs
hostname, user, port, ssh_port, check_nfs
)
elif yaml_hosts[key]["os"].lower().strip() == "windows":
hosts[key] = WindowsHost(hostname, user, port)
Expand Down Expand Up @@ -886,7 +900,7 @@ def main():
if key in args.sessions:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()),
key,
yaml_sessions[key],
envs
Expand All @@ -895,7 +909,7 @@ def main():
else:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()),
key,
yaml_sessions[key],
envs
Expand Down