import array
import fcntl
import functools
import select
import socket
import sys
import termios
import time
import tty
import uuid
import logging
import paramiko
from . import (
utils,
log_utils,
)
from .config import config
LOGGER = logging.getLogger(__name__)
LogTask = functools.partial(log_utils.LogTask, logger=LOGGER)
log_task = functools.partial(log_utils.log_task, logger=LOGGER)
[docs]def ssh(
ip_addr,
command,
host_name=None,
data=None,
show_output=True,
propagate_fail=True,
tries=None,
ssh_key=None,
username='root',
password='123456',
):
host_name = host_name or ip_addr
client = get_ssh_client(
ip_addr=ip_addr,
host_name=host_name,
propagate_fail=propagate_fail,
ssh_tries=tries,
ssh_key=ssh_key,
username=username,
password=password,
)
transport = client.get_transport()
channel = transport.open_session()
joined_command = ' '.join(command)
command_id = _gen_ssh_command_id()
LOGGER.debug(
'Running %s on %s: %s%s',
command_id,
host_name,
joined_command,
data is not None and (' < "%s"' % data) or '',
)
channel.exec_command(joined_command)
if data is not None:
channel.send(data)
channel.shutdown_write()
return_code, out, err = drain_ssh_channel(
channel, **(show_output and {} or {
'stdout': None,
'stderr': None
})
)
channel.close()
transport.close()
client.close()
LOGGER.debug(
'Command %s on %s returned with %d',
command_id,
host_name,
return_code,
)
if out:
LOGGER.debug(
'Command %s on %s output:\n %s',
command_id,
host_name,
out,
)
if err:
LOGGER.debug(
'Command %s on %s errors:\n %s',
command_id,
host_name,
err,
)
return utils.CommandStatus(return_code, out, err)
[docs]def wait_for_ssh(
ip_addr,
host_name=None,
connect_timeout=600, # 10 minutes
ssh_key=None,
username='root',
password='123456',
):
host_name = host_name or ip_addr
start_time = time.time()
while (time.time() - start_time) < connect_timeout:
try:
ret, _, _ = ssh(
ip_addr=ip_addr,
host_name=host_name,
command=['true'],
tries=1,
propagate_fail=False,
ssh_key=ssh_key,
username=username,
password=password,
)
except Exception as err:
ret = -1
sys.exc_clear()
LOGGER.debug(
'Got exception while sshing to %s: %s',
host_name,
err,
)
if ret == 0:
break
time.sleep(1)
else:
# Try one last time, using the ssh default timeout values, as we
# already waited for boot_time_sec for sure
ret, _, _ = ssh(
ip_addr=ip_addr,
host_name=host_name,
command=['true'],
ssh_key=ssh_key,
username=username,
password=password,
)
if ret != 0:
raise RuntimeError(
'Failed to connect remote shell to %s',
host_name,
)
LOGGER.debug('Wait succeeded for ssh to %s', host_name)
[docs]def ssh_script(
ip_addr,
path,
host_name=None,
show_output=True,
ssh_key=None,
username='root',
password='123456',
):
host_name = host_name or ip_addr
with open(path) as script_fd:
return ssh(
ip_addr=ip_addr,
host_name=host_name,
command=['bash', '-s'],
data=script_fd.read(),
show_output=show_output,
ssh_key=ssh_key,
username=username,
password=password,
)
[docs]def interactive_ssh(
ip_addr,
command=None,
host_name=None,
ssh_key=None,
username='root',
password='123456',
):
if command is None:
command = ['bash']
client = get_ssh_client(
ip_addr=ip_addr,
host_name=host_name,
ssh_key=ssh_key,
username=username,
password=password,
)
transport = client.get_transport()
channel = transport.open_session()
try:
return interactive_ssh_channel(channel, ' '.join(command))
finally:
channel.close()
transport.close()
client.close()
[docs]def drain_ssh_channel(chan, stdin=None, stdout=sys.stdout, stderr=sys.stderr):
chan.settimeout(0)
out_queue = []
out_all = []
err_queue = []
err_all = []
try:
stdout_is_tty = stdout.isatty()
tty_w = tty_h = -1
except AttributeError:
stdout_is_tty = False
done = False
while not done:
if stdout_is_tty:
arr = array.array('h', range(4))
if not fcntl.ioctl(stdout.fileno(), termios.TIOCGWINSZ, arr):
if tty_h != arr[0] or tty_w != arr[1]:
tty_h, tty_w = arr[:2]
chan.resize_pty(width=tty_w, height=tty_h)
read_streams = []
if not chan.closed:
read_streams.append(chan)
if stdin and not stdin.closed:
read_streams.append(stdin)
write_streams = []
if stdout and out_queue:
write_streams.append(stdout)
if stderr and err_queue:
write_streams.append(stderr)
read, write, _ = select.select(
read_streams,
write_streams,
[],
0.1,
)
if stdin in read:
chunk = utils.read_nonblocking(stdin)
if chunk:
chan.send(chunk)
else:
chan.shutdown_write()
try:
if chan.recv_ready():
chunk = chan.recv(1024)
if stdout:
out_queue.append(chunk)
out_all.append(chunk)
if chan.recv_stderr_ready():
chunk = chan.recv_stderr(1024)
if stderr:
err_queue.append(chunk)
err_all.append(chunk)
except socket.error:
pass
if stdout in write:
stdout.write(out_queue.pop(0))
stdout.flush()
if stderr in write:
stderr.write(err_queue.pop(0))
stderr.flush()
if chan.closed and not out_queue and not err_queue:
done = True
return (chan.exit_status, ''.join(out_all), ''.join(err_all))
[docs]def interactive_ssh_channel(chan, command=None, stdin=sys.stdin):
try:
stdin_is_tty = stdin.isatty()
except Exception:
stdin_is_tty = False
if stdin_is_tty:
oldtty = termios.tcgetattr(stdin)
chan.get_pty()
if command is not None:
chan.exec_command(command)
try:
if stdin_is_tty:
tty.setraw(stdin.fileno())
tty.setcbreak(stdin.fileno())
return utils.CommandStatus(*drain_ssh_channel(chan, stdin))
finally:
if stdin_is_tty:
termios.tcsetattr(stdin, termios.TCSADRAIN, oldtty)
[docs]def _gen_ssh_command_id():
return uuid.uuid1().hex[:8]
[docs]def get_ssh_client(
ip_addr,
ssh_key=None,
host_name=None,
ssh_tries=None,
propagate_fail=True,
username='root',
password='123456',
):
host_name = host_name or ip_addr
with LogTask(
'Get ssh client for %s' % host_name,
level='debug',
propagate_fail=propagate_fail,
):
ssh_timeout = int(config.get('ssh_timeout'))
if ssh_tries is None:
ssh_tries = int(config.get('ssh_tries'))
start_time = time.time()
while ssh_tries > 0:
LOGGER.debug(
'Still got %d tries for %s',
ssh_tries,
host_name,
)
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy(), )
try:
if ssh_key:
client.connect(
ip_addr,
username=username,
password=password,
key_filename=ssh_key,
timeout=ssh_timeout,
)
else:
client.connect(
ip_addr,
username='root',
timeout=ssh_timeout,
)
break
except (socket.error, socket.timeout) as err:
LOGGER.debug(
'Socket error connecting to %s: %s',
host_name,
err,
)
except paramiko.ssh_exception.SSHException as err:
LOGGER.debug(
'SSH error connecting to %s: %s',
host_name,
err,
)
ssh_tries -= 1
time.sleep(1)
else:
end_time = time.time()
raise RuntimeError(
'Timed out (in %d s) trying to ssh to %s' %
(end_time - start_time, host_name)
)
return client