Source code for lago.ssh

import array
import fcntl
import functools
import select
import socket
import sys
import termios
import time
import tty
import uuid
import logging

import paramiko

from . import (config, utils, log_utils, )

LOGGER = logging.getLogger(__name__)
# The log_utils task helpers, pre-bound so that everything they log from
# this module goes through LOGGER.
LogTask, log_task = (
    functools.partial(helper, logger=LOGGER)
    for helper in (log_utils.LogTask, log_utils.log_task)
)


def ssh(
    ip_addr,
    command,
    host_name=None,
    data=None,
    show_output=True,
    propagate_fail=True,
    tries=None,
    ssh_key=None,
):
    """
    Run ``command`` on a remote host over ssh and collect its result.

    Args:
        ip_addr (str): Address to connect to.
        command (list of str): Command words; they are joined with single
            spaces and executed remotely as one command line.
        host_name (str): Name used in log messages. Defaults to ``ip_addr``.
        data (str): If not None, sent to the remote command's stdin; the
            channel's write side is shut down afterwards in any case.
        show_output (bool): If True, echo the remote stdout/stderr to the
            local stdout/stderr (the defaults of :func:`drain_ssh_channel`)
            while collecting it; if False, only collect it.
        propagate_fail (bool): Forwarded to :func:`get_ssh_client`.
        tries (int): Connection attempts, forwarded to
            :func:`get_ssh_client` as ``ssh_tries`` (None = its default).
        ssh_key (str): Private key file, forwarded to :func:`get_ssh_client`.

    Returns:
        utils.CommandStatus: built from (return code, stdout, stderr).
    """
    host_name = host_name or ip_addr
    client = get_ssh_client(
        ip_addr=ip_addr,
        host_name=host_name,
        propagate_fail=propagate_fail,
        ssh_tries=tries,
        ssh_key=ssh_key,
    )

    # None means "not opened yet", so cleanup only closes what was created;
    # everything is closed even if running/draining the command raises.
    transport = None
    channel = None
    try:
        transport = client.get_transport()
        channel = transport.open_session()

        joined_command = ' '.join(command)
        command_id = _gen_ssh_command_id()
        LOGGER.debug(
            'Running %s on %s: %s%s',
            command_id,
            host_name,
            joined_command,
            ' < "%s"' % data if data is not None else '',
        )
        channel.exec_command(joined_command)
        if data is not None:
            channel.send(data)
        channel.shutdown_write()

        # stdout/stderr=None makes drain_ssh_channel collect without echoing.
        # (A conditional is required here: the previous
        # `show_output and {} or {...}` always picked the second dict
        # because an empty dict is falsy.)
        if show_output:
            drain_kwargs = {}
        else:
            drain_kwargs = {'stdout': None, 'stderr': None}
        return_code, out, err = drain_ssh_channel(channel, **drain_kwargs)
    finally:
        if channel is not None:
            channel.close()
        if transport is not None:
            transport.close()
        client.close()

    LOGGER.debug(
        'Command %s on %s returned with %d',
        command_id,
        host_name,
        return_code,
    )
    if out:
        LOGGER.debug(
            'Command %s on %s output:\n %s',
            command_id,
            host_name,
            out,
        )
    if err:
        LOGGER.debug(
            'Command %s on %s  errors:\n %s',
            command_id,
            host_name,
            err,
        )
    return utils.CommandStatus(return_code, out, err)
def wait_for_ssh(ip_addr, host_name=None, connect_retries=50, ssh_key=None):
    """
    Block until a trivial command can be run over ssh on ``ip_addr``.

    Polls by running ``true`` remotely with a single connection attempt per
    poll, tolerating any exception, and sleeps 1 second between failed
    polls. If all ``connect_retries`` polls fail, one final attempt is made
    with the default connection settings; exceptions from that attempt are
    not caught here.

    Args:
        ip_addr (str): Address to connect to.
        host_name (str): Name used in log messages. Defaults to ``ip_addr``.
        connect_retries (int): Number of tolerant polls before the final
            attempt. Values <= 0 go straight to the final attempt.
        ssh_key (str): Private key file to authenticate with.

    Raises:
        RuntimeError: If the final attempt's command returns non-zero.
    """
    host_name = host_name or ip_addr
    # `> 0` (not plain truthiness) so a negative count cannot loop forever.
    while connect_retries > 0:
        try:
            ret, _, _ = ssh(
                ip_addr=ip_addr,
                host_name=host_name,
                command=['true'],
                tries=1,
                propagate_fail=False,
                ssh_key=ssh_key,
            )
        except Exception as err:
            # Best effort while polling: any failure just counts as "not up".
            ret = -1
            # sys.exc_clear only exists on Python 2, where it releases the
            # handled exception's traceback; it is a no-op concept on Py3.
            if hasattr(sys, 'exc_clear'):
                sys.exc_clear()
            LOGGER.debug(
                'Got exception while sshing to %s: %s',
                host_name,
                err,
            )

        if ret == 0:
            break

        connect_retries -= 1
        time.sleep(1)
    else:
        # Try one last time, using the ssh default timeout values, as we
        # already waited through all connect_retries polls.
        ret, _, _ = ssh(
            ip_addr=ip_addr,
            host_name=host_name,
            command=['true'],
            ssh_key=ssh_key,
        )
        if ret != 0:
            raise RuntimeError(
                'Failed to connect remote shell to %s' % host_name
            )

    LOGGER.debug('Wait succeeded for ssh to %s', host_name)
def ssh_script(ip_addr, path, host_name=None, show_output=True, ssh_key=None):
    """
    Run a local script on a remote host by feeding it to ``bash -s``.

    Args:
        ip_addr (str): Address to connect to.
        path (str): Local path of the script; its whole content is sent as
            the remote command's stdin.
        host_name (str): Name used in log messages. Defaults to ``ip_addr``.
        show_output (bool): Forwarded to :func:`ssh`.
        ssh_key (str): Private key file, forwarded to :func:`ssh`.

    Returns:
        Whatever :func:`ssh` returns for the ``bash -s`` invocation.
    """
    effective_name = host_name or ip_addr
    with open(path) as script_file:
        script_body = script_file.read()

    return ssh(
        ip_addr=ip_addr,
        host_name=effective_name,
        command=['bash', '-s'],
        data=script_body,
        show_output=show_output,
        ssh_key=ssh_key,
    )
def interactive_ssh(ip_addr, command=None, host_name=None, ssh_key=None):
    """
    Run a command on a remote host attached to the local terminal.

    Args:
        ip_addr (str): Address to connect to.
        command (list of str): Command words, joined with single spaces.
            Defaults to ``['bash']``.
        host_name (str): Name used in log messages (forwarded to
            :func:`get_ssh_client`).
        ssh_key (str): Private key file, forwarded to :func:`get_ssh_client`.

    Returns:
        The result of :func:`interactive_ssh_channel` for the command.
    """
    if command is None:
        command = ['bash']

    client = get_ssh_client(
        ip_addr=ip_addr,
        host_name=host_name,
        ssh_key=ssh_key,
    )
    # Cleanup covers get_transport/open_session too, so a failure while
    # opening the session does not leak the connected client.
    transport = None
    channel = None
    try:
        transport = client.get_transport()
        channel = transport.open_session()
        return interactive_ssh_channel(channel, ' '.join(command))
    finally:
        if channel is not None:
            channel.close()
        if transport is not None:
            transport.close()
        client.close()
def drain_ssh_channel(chan, stdin=None, stdout=sys.stdout, stderr=sys.stderr):
    """
    Shuttle data between an ssh channel and local streams until the channel
    is closed and all of its buffered output has been consumed.

    Note that the ``stdout``/``stderr`` defaults are the ``sys`` streams as
    they were when this module was imported.

    Args:
        chan: Channel to drain, used through the ``paramiko.Channel``
            interface (``recv``, ``recv_stderr``, ``send``, ``closed``, ...).
        stdin (file): If given, data read from it is forwarded to the
            channel; an empty read shuts down the channel's write side and
            stops further reads from ``stdin``.
        stdout (file): Where remote stdout is echoed; None to only collect.
        stderr (file): Where remote stderr is echoed; None to only collect.

    Returns:
        tuple: ``(chan.exit_status, all stdout data, all stderr data)``.
    """
    chan.settimeout(0)
    out_queue = []
    out_all = []
    err_queue = []
    err_all = []
    stdin_eof = False

    try:
        stdout_is_tty = stdout.isatty()
        tty_w = tty_h = -1
    except AttributeError:
        # e.g. stdout is None
        stdout_is_tty = False

    while True:
        if stdout_is_tty:
            # Propagate local terminal size changes to the remote side.
            arr = array.array('h', range(4))
            if not fcntl.ioctl(stdout.fileno(), termios.TIOCGWINSZ, arr):
                if tty_h != arr[0] or tty_w != arr[1]:
                    tty_h, tty_w = arr[:2]
                    chan.resize_pty(width=tty_w, height=tty_h)

        read_streams = []
        if not chan.closed:
            read_streams.append(chan)
        # After EOF stdin stays "readable" forever; polling it again would
        # only spin the loop and re-issue shutdown_write.
        if stdin and not stdin_eof and not stdin.closed:
            read_streams.append(stdin)

        write_streams = []
        if stdout and out_queue:
            write_streams.append(stdout)
        if stderr and err_queue:
            write_streams.append(stderr)

        read, write, _ = select.select(read_streams, write_streams, [], 0.1)

        if stdin in read:
            chunk = utils.read_nonblocking(stdin)
            if chunk:
                chan.send(chunk)
            else:
                # Empty read is treated as EOF on the local input.
                chan.shutdown_write()
                stdin_eof = True

        try:
            if chan.recv_ready():
                chunk = chan.recv(1024)
                if stdout:
                    out_queue.append(chunk)
                out_all.append(chunk)

            if chan.recv_stderr_ready():
                chunk = chan.recv_stderr(1024)
                if stderr:
                    err_queue.append(chunk)
                err_all.append(chunk)
        except socket.error:
            pass

        if stdout in write:
            stdout.write(out_queue.pop(0))
            stdout.flush()

        if stderr in write:
            stderr.write(err_queue.pop(0))
            stderr.flush()

        # A closed channel may still hold buffered data (only 1024 bytes are
        # read per iteration), so stop only once it is drained as well as
        # fully echoed.
        if (
            chan.closed and not out_queue and not err_queue
            and not chan.recv_ready() and not chan.recv_stderr_ready()
        ):
            break

    return (chan.exit_status, ''.join(out_all), ''.join(err_all))
def interactive_ssh_channel(chan, command=None, stdin=sys.stdin):
    """
    Run ``command`` (if given) on ``chan`` and pump it against ``stdin``.

    When ``stdin`` is a terminal, a pty is requested on the channel and the
    local terminal is switched to raw/cbreak mode while the channel is
    drained; its original attributes are restored afterwards, even on error.

    Args:
        chan: Channel to run on (``paramiko.Channel`` interface).
        command (str): Command line to execute; if None, nothing is executed
            and the channel is only drained.
        stdin (file): Local input forwarded to the channel. Defaults to
            ``sys.stdin`` as it was when this module was imported.

    Returns:
        utils.CommandStatus: built from the :func:`drain_ssh_channel` result.
    """
    try:
        is_terminal = stdin.isatty()
    except Exception:
        # Anything that cannot answer isatty() is treated as a non-terminal.
        is_terminal = False

    saved_attrs = None
    if is_terminal:
        saved_attrs = termios.tcgetattr(stdin)
        chan.get_pty()

    if command is not None:
        chan.exec_command(command)

    try:
        if is_terminal:
            stdin_fd = stdin.fileno()
            tty.setraw(stdin_fd)
            tty.setcbreak(stdin_fd)
        drained = drain_ssh_channel(chan, stdin)
        return utils.CommandStatus(*drained)
    finally:
        if is_terminal:
            termios.tcsetattr(stdin, termios.TCSADRAIN, saved_attrs)
[docs]def _gen_ssh_command_id(): return uuid.uuid1().hex[:8]
def get_ssh_client(
    ip_addr,
    ssh_key=None,
    host_name=None,
    ssh_tries=None,
    propagate_fail=True,
):
    """
    Open a paramiko ssh connection to ``ip_addr`` as ``root``, with retries.

    Unknown host keys are accepted automatically (``AutoAddPolicy``). Each
    failed attempt's client is closed, and there is a 1 second pause before
    the next attempt. The whole operation runs inside a debug-level
    ``LogTask``.

    Args:
        ip_addr (str): Address to connect to.
        ssh_key (str): If set, private key file passed as ``key_filename``.
        host_name (str): Name used in log messages. Defaults to ``ip_addr``.
        ssh_tries (int): Number of connection attempts; None reads the
            ``ssh_tries`` config value.
        propagate_fail (bool): Passed to ``LogTask``. Presumably controls
            whether a failure propagates to the caller; confirm in log_utils.

    Returns:
        paramiko.SSHClient: The connected client.

    Raises:
        RuntimeError: When all attempts fail (subject to ``propagate_fail``).
    """
    host_name = host_name or ip_addr
    with LogTask(
        'Get ssh client for %s' % host_name,
        level='debug',
        propagate_fail=propagate_fail,
    ):
        ssh_timeout = int(config.get('ssh_timeout'))
        if ssh_tries is None:
            ssh_tries = int(config.get('ssh_tries'))

        connect_kwargs = {'username': 'root', 'timeout': ssh_timeout}
        if ssh_key:
            connect_kwargs['key_filename'] = ssh_key

        start_time = time.time()
        while ssh_tries > 0:
            LOGGER.debug('Still got %d tries for %s', ssh_tries, host_name)
            client = paramiko.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                client.connect(ip_addr, **connect_kwargs)
                break
            except (socket.error, socket.timeout) as err:
                LOGGER.debug(
                    'Socket error connecting to %s: %s',
                    host_name,
                    err,
                )
            except paramiko.ssh_exception.SSHException as err:
                LOGGER.debug(
                    'SSH error connecting to %s: %s',
                    host_name,
                    err,
                )

            # A failed connect can leave a half-open transport behind;
            # release it before building a new client.
            client.close()
            ssh_tries -= 1
            # No point waiting after the last attempt.
            if ssh_tries > 0:
                time.sleep(1)
        else:
            end_time = time.time()
            raise RuntimeError(
                'Timed out (in %d s) trying to ssh to %s' %
                (end_time - start_time, host_name)
            )

    return client