#
# Copyright 2014 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#
# Refer to the README and COPYING files for full details of the license
#
import collections
import contextlib
import functools
import json
import logging
import os
import pwd
import socket
import time
import uuid
import guestfs
import libvirt
import lxml.etree
import paramiko
import config
import brctl
import utils
import sysprep
[docs]def _gen_ssh_command_id():
return uuid.uuid1().hex[:8]
[docs]def _ip_to_mac(ip):
# Mac addrs of domains are 54:52:xx:xx:xx:xx where the last 4 octets are
# the hex repr of the IP address)
mac_addr_pieces = [0x54, 0x52] + [int(y) for y in ip.split('.')]
return ':'.join([('%02x' % x) for x in mac_addr_pieces])
[docs]def _guestfs_copy_path(g, guest_path, host_path):
if g.is_file(guest_path):
with open(host_path, 'w') as f:
f.write(g.read_file(guest_path))
elif g.is_dir(guest_path):
os.mkdir(host_path)
for path in g.ls(guest_path):
_guestfs_copy_path(
g,
os.path.join(
guest_path,
path,
),
os.path.join(
host_path,
os.path.basename(path)
),
)
[docs]def _path_to_xml(basename):
return os.path.join(
os.path.dirname(__file__),
basename,
)
class VirtEnv(object):
    '''The networks and VMs that belong to one prefix.

    Env properties:
    * prefix
    * vms
    * net
    * libvirt_con
    '''

    def __init__(self, prefix, vm_specs, net_specs):
        '''Build the environment from already loaded specs.

        Args:
            prefix: prefix object; its ``paths`` attribute is used to locate
                the uuid file and the files under the virt directory.
            vm_specs (dict): VM name -> VM spec.
            net_specs (dict): network name -> network spec; every spec needs
                a ``'type'`` of ``'nat'`` or ``'bridge'``.
        '''
        self.prefix = prefix

        with open(self.prefix.paths.uuid(), 'r') as f:
            self._uuid = f.read().strip()

        self._nets = {}
        for name, spec in net_specs.items():
            self._nets[name] = self._create_net(spec)

        self._vms = {}
        for name, spec in vm_specs.items():
            self._vms[name] = self._create_vm(spec)

        # Opened lazily by the ``libvirt_con`` property.
        self._libvirt_con = None

    def _create_net(self, net_spec):
        '''Instantiate the network class matching ``net_spec['type']``.

        Raises:
            RuntimeError: if the type is neither ``'nat'`` nor ``'bridge'``.
        '''
        net_type = net_spec['type']
        if net_type == 'nat':
            cls = NATNetwork
        elif net_type == 'bridge':
            cls = BridgeNetwork
        else:
            # Without this branch ``cls`` would be unbound below and the
            # caller would get an UnboundLocalError with no hint of the
            # offending spec.
            raise RuntimeError('Unknown network type: %s' % net_type)
        return cls(self, net_spec)

    def _create_vm(self, vm_spec):
        '''Instantiate a VM bound to this environment.'''
        return VM(self, vm_spec)

    def prefixed_name(self, unprefixed_name):
        '''Prepend the first eight characters of the env uuid to a name.'''
        return '%s-%s' % (self._uuid[:8], unprefixed_name)

    def virt_path(self, *args):
        '''Return a path below the prefix's virt directory.'''
        return self.prefix.paths.virt(*args)

    def bootstrap(self):
        '''Bootstrap all the VMs in parallel.'''
        utils.invoke_in_parallel(lambda vm: vm.bootstrap(), self._vms.values())

    @property
    def libvirt_con(self):
        '''Connection to ``qemu:///system``, opened on first access.'''
        if self._libvirt_con is None:
            self._libvirt_con = libvirt.open('qemu:///system')
        return self._libvirt_con

    def start(self):
        '''Start all networks and then all VMs.

        ``stop`` callbacks of everything already started are registered on a
        ``utils.RollbackContext`` and dropped again once everything is up.
        '''
        with utils.RollbackContext() as rollback:
            for net in self._nets.values():
                net.start()
                rollback.prependDefer(net.stop)

            for vm in self._vms.values():
                vm.start()
                rollback.prependDefer(vm.stop)

            rollback.clear()

    def stop(self):
        '''Stop all VMs and then all networks.'''
        logging.info("Stopping prefix")
        for vm in self._vms.values():
            vm.stop()
        for net in self._nets.values():
            net.stop()

    def get_nets(self):
        '''Return a shallow copy of the name -> network mapping.'''
        return self._nets.copy()

    def get_net(self, name=None):
        '''Look up a network.

        Args:
            name (str): network name; when empty or ``None`` a management
                network is returned instead.

        Returns:
            The named network, or ``None`` if there is no network with that
            name. Without a name, the last management network found.

        Raises:
            IndexError: if no name is given and no network is a management
                network.
        '''
        if name:
            return self.get_nets().get(name)

        management_nets = [
            net for net in self.get_nets().values() if net.is_management()
        ]
        return management_nets.pop()

    def get_vms(self):
        '''Return a shallow copy of the name -> VM mapping.'''
        return self._vms.copy()

    def get_vm(self, name):
        '''Return the VM called ``name``; raises ``KeyError`` if unknown.'''
        return self._vms[name]

    @classmethod
    def from_prefix(cls, prefix):
        '''Load a previously saved environment from the prefix's virt dir.

        Reads the ``env`` file and, for every network and VM named in it,
        the matching ``net-<name>`` / ``vm-<name>`` JSON spec.
        '''
        def load_json(name):
            path = os.path.join(prefix.paths.prefix(), 'virt', name)
            with open(path, 'r') as f:
                return json.load(f)

        env_dom = load_json('env')

        net_specs = {}
        for name in env_dom['nets']:
            net_specs[name] = load_json('net-%s' % name)

        vm_specs = {}
        for name in env_dom['vms']:
            vm_specs[name] = load_json('vm-%s' % name)

        return cls(prefix, vm_specs, net_specs)

    def save(self):
        '''Persist every network and VM spec plus the ``env`` index file.'''
        for net in self._nets.values():
            net.save()
        for vm in self._vms.values():
            vm.save()

        # Real lists: a ``dict.keys()`` view cannot be JSON encoded on
        # Python 3.
        spec = {
            'nets': list(self._nets),
            'vms': list(self._vms),
        }

        with open(self.virt_path('env'), 'w') as f:
            utils.json_dump(spec, f)

    def create_snapshots(self, name):
        '''Create a snapshot called ``name`` of every VM, in parallel.'''
        utils.invoke_in_parallel(
            lambda vm: vm.create_snapshot(name),
            self._vms.values(),
        )

    def revert_snapshots(self, name):
        '''Revert every VM to the snapshot called ``name``, in parallel.'''
        utils.invoke_in_parallel(
            lambda vm: vm.revert_snapshot(name),
            self._vms.values(),
        )
class Network(object):
    '''Common behaviour of the libvirt networks of an environment.

    ``start`` obtains the network definition from ``self._libvirt_xml()``,
    which this class does not define; ``NATNetwork`` and ``BridgeNetwork``
    provide it.
    '''

    def __init__(self, env, spec):
        '''Remember the owning environment and the network spec dict.'''
        self._env = env
        self._spec = spec

    def name(self):
        '''Return the unprefixed network name from the spec.'''
        return self._spec['name']

    def gw(self):
        '''Return the ``'gw'`` entry of the spec, ``None`` if absent.'''
        return self._spec.get('gw')

    def is_management(self):
        '''Return the ``'management'`` flag of the spec (default False).'''
        return self._spec.get('management', False)

    def add_mappings(self, mappings):
        '''Record several name -> ip mappings and save once at the end.

        Args:
            mappings: iterable of ``(name, ip, mac)`` triples; the mac is
                not stored.
        '''
        for host, address, _mac in mappings:
            self.add_mapping(host, address, save=False)
        self.save()

    def add_mapping(self, name, ip, save=True):
        '''Record ``name`` -> ``ip`` in the spec, saving unless told not to.'''
        mapping = self._spec['mapping']
        mapping[name] = ip
        if save:
            self.save()

    def resolve(self, name):
        '''Return the ip mapped to ``name``; ``KeyError`` if unmapped.'''
        mapping = self._spec['mapping']
        return mapping[name]

    def _libvirt_name(self):
        '''Return the network name prefixed by the environment.'''
        return self._env.prefixed_name(self.name())

    def alive(self):
        '''Tell whether libvirt lists a network with our prefixed name.'''
        known_names = [
            libvirt_net.name()
            for libvirt_net in self._env.libvirt_con.listAllNetworks()
        ]
        expected = self._libvirt_name()
        return expected in known_names

    def start(self):
        '''Create the network in libvirt unless it already exists.'''
        if self.alive():
            return
        logging.info('Creating network %s', self.name())
        self._env.libvirt_con.networkCreateXML(self._libvirt_xml())

    def stop(self):
        '''Destroy the network in libvirt if it exists.'''
        if not self.alive():
            return
        logging.info('Destroying network %s', self.name())
        libvirt_net = self._env.libvirt_con.networkLookupByName(
            self._libvirt_name(),
        )
        libvirt_net.destroy()

    def save(self):
        '''Dump the spec to ``net-<name>`` under the env's virt path.'''
        spec_path = self._env.virt_path('net-%s' % self.name())
        with open(spec_path, 'w') as spec_file:
            utils.json_dump(self._spec, spec_file)
class NATNetwork(Network):
    '''Network whose libvirt XML comes from ``net_nat_template.xml``.'''

    def _libvirt_xml(self):
        '''Render the libvirt network XML for this network.

        The ``@NAME@``, ``@BR_NAME@`` and ``@GW_ADDR@`` placeholders of the
        template are replaced (first occurrence only). When the spec has a
        ``'dhcp'`` entry a ``<dhcp>`` element is added below
        ``/network/ip``.
        '''
        with open(_path_to_xml('net_nat_template.xml')) as template:
            raw_xml = template.read()

        libvirt_name = self._libvirt_name()
        substitutions = {
            '@NAME@': libvirt_name,
            # The bridge name is cut to 12 characters.
            '@BR_NAME@': ('%s-nic' % libvirt_name)[:12],
            '@GW_ADDR@': self.gw(),
        }
        for placeholder, value in substitutions.items():
            raw_xml = raw_xml.replace(placeholder, value, 1)

        net_xml = lxml.etree.fromstring(raw_xml)
        if 'dhcp' in self._spec:
            self._add_dhcp(net_xml)

        return lxml.etree.tostring(net_xml)

    def _add_dhcp(self, net_xml):
        '''Append the ``<dhcp>`` element (range and static hosts).'''
        ip_node = net_xml.xpath('/network/ip')[0]

        def make_ip(last):
            # Same first three octets as the gateway, ``last`` as the
            # final one.
            return '.'.join(self.gw().split('.')[:-1] + [str(last)])

        dhcp = lxml.etree.SubElement(ip_node, 'dhcp')
        lxml.etree.SubElement(
            dhcp,
            'range',
            start=make_ip(self._spec['dhcp']['start']),
            end=make_ip(self._spec['dhcp']['end']),
        )

        # Static host entries are only emitted for the management network;
        # the MAC of each host is derived from its ip.
        if self.is_management():
            for hostname, address in self._spec['mapping'].items():
                lxml.etree.SubElement(
                    dhcp,
                    'host',
                    mac=_ip_to_mac(address),
                    ip=address,
                    name=hostname,
                )
class BridgeNetwork(Network):
    '''Network backed by a bridge that is created through ``brctl``.'''

    def _libvirt_xml(self):
        '''Render ``net_br_template.xml`` for this network.

        Both ``@NAME@`` and ``@BR_NAME@`` are replaced (first occurrence
        only) by the prefixed network name. The raw XML text is returned.
        '''
        with open(_path_to_xml('net_br_template.xml')) as f:
            net_raw_xml = f.read()

        libvirt_name = self._libvirt_name()
        replacements = {
            '@NAME@': libvirt_name,
            '@BR_NAME@': libvirt_name,
        }
        for k, v in replacements.items():
            net_raw_xml = net_raw_xml.replace(k, v, 1)

        return net_raw_xml

    def start(self):
        '''Create the bridge and then the libvirt network on top of it.

        If creating the libvirt network fails, the bridge is destroyed
        again and the original exception is re-raised so that callers can
        react to the failure.
        '''
        bridge_name = self._libvirt_name()
        brctl.create(bridge_name)
        try:
            super(BridgeNetwork, self).start()
        except BaseException:
            # Clean up the half-created state, but never hide the error.
            brctl.destroy(bridge_name)
            raise

    def stop(self):
        '''Destroy the libvirt network and then the bridge, if present.'''
        super(BridgeNetwork, self).stop()
        bridge_name = self._libvirt_name()
        if brctl.exists(bridge_name):
            brctl.destroy(bridge_name)
class ServiceState:
    '''Numeric states a service wrapper can report from ``state()``.'''
    # MISSING is 0, INACTIVE is 1 and ACTIVE is 2.
    MISSING, INACTIVE, ACTIVE = range(3)
[docs]class _Service:
def __init__(self, vm, name):
self._vm = vm
self._name = name
[docs] def exists(self):
return self.state() != ServiceState.MISSING
[docs] def alive(self):
return self.state() == ServiceState.ACTIVE
[docs] def start(self):
state = self.state()
if state == ServiceState.MISSING:
raise RuntimeError('Service %s not present' % self._name)
elif state == ServiceState.ACTIVE:
return
if self._request_start():
raise RuntimeError('Failed to start service')
[docs] def stop(self):
state = self.state()
if state == ServiceState.MISSING:
raise RuntimeError('Service %s not present' % self._name)
elif state == ServiceState.INACTIVE:
return
if self._request_stop():
raise RuntimeError('Failed to stop service')
@classmethod
[docs] def is_supported(cls, vm):
return vm.ssh(['test', '-e', cls.BIN_PATH]).code == 0
class _SystemdService(_Service):
    '''Service wrapper that drives ``systemctl`` on the VM.'''

    BIN_PATH = '/usr/bin/systemctl'

    def _request_start(self):
        '''Run ``systemctl start <name>`` on the VM.'''
        return self._vm.ssh([self.BIN_PATH, 'start', self._name])

    def _request_stop(self):
        '''Run ``systemctl stop <name>`` on the VM.'''
        return self._vm.ssh([self.BIN_PATH, 'stop', self._name])

    def state(self):
        '''Map the result of ``systemctl status <name>`` to a state.

        Returns:
            ``ServiceState.ACTIVE`` when the command exits with 0,
            ``ServiceState.INACTIVE`` when the last ``Loaded:`` line of the
            output says ``loaded``, ``ServiceState.MISSING`` otherwise
            (including output without any ``Loaded:`` line).
        '''
        ret = self._vm.ssh([self.BIN_PATH, 'status', self._name])
        if ret.code == 0:
            return ServiceState.ACTIVE

        loaded_lines = [
            line
            for line in (raw.strip() for raw in ret.out.split('\n'))
            if line.startswith('Loaded:')
        ]
        if loaded_lines:
            fields = loaded_lines[-1].split()
            if len(fields) > 1 and fields[1] == 'loaded':
                return ServiceState.INACTIVE

        return ServiceState.MISSING
class _SysVInitService(_Service):
    '''Service wrapper that drives ``/sbin/service`` on the VM.'''

    BIN_PATH = '/sbin/service'

    def _run(self, action):
        '''Run ``service <name> <action>`` on the VM and return the result.'''
        return self._vm.ssh([self.BIN_PATH, self._name, action])

    def _request_start(self):
        '''Run ``service <name> start`` on the VM.'''
        return self._run('start')

    def _request_stop(self):
        '''Run ``service <name> stop`` on the VM.'''
        return self._run('stop')

    def state(self):
        '''Map the result of ``service <name> status`` to a state.

        Exit code 0 means active; otherwise output ending in
        ``is stopped`` means inactive and anything else means missing.
        '''
        result = self._run('status')
        if result.code == 0:
            return ServiceState.ACTIVE

        reported_stopped = result.out.strip().endswith('is stopped')
        if reported_stopped:
            return ServiceState.INACTIVE
        return ServiceState.MISSING
class _SystemdContainerService(_Service):
    '''Service wrapper for services that may live in the ``vdsmc`` container.

    Every request is first sent to ``systemctl`` inside the container
    through ``docker exec``; the host's own ``systemctl`` is the fallback.
    '''

    BIN_PATH = '/usr/bin/docker'
    HOST_BIN_PATH = '/usr/bin/systemctl'

    def _container_command(self, action):
        '''Build the ``docker exec vdsmc systemctl <action> <name>`` argv.'''
        return [
            self.BIN_PATH,
            'exec vdsmc systemctl %s' % action,
            self._name,
        ]

    def _host_command(self, action):
        '''Build the host level ``systemctl <action> <name>`` argv.'''
        return [self.HOST_BIN_PATH, action, self._name]

    def _request(self, action):
        '''Run ``action`` in the container, falling back to the host.'''
        ret = self._vm.ssh(self._container_command(action))
        if ret.code == 0:
            return ret
        return self._vm.ssh(self._host_command(action))

    def _request_start(self):
        '''Start the service in the container, else on the host.'''
        return self._request('start')

    def _request_stop(self):
        '''Stop the service in the container, else on the host.'''
        return self._request('stop')

    @staticmethod
    def _state_from_status(ret):
        '''Map one ``systemctl status`` result to a ``ServiceState``.

        Exit code 0 means active; otherwise the last ``Loaded:`` line
        decides between inactive (``loaded``) and missing. Output without
        a ``Loaded:`` line counts as missing.
        '''
        if ret.code == 0:
            return ServiceState.ACTIVE

        loaded_lines = [
            line
            for line in (raw.strip() for raw in ret.out.split('\n'))
            if line.startswith('Loaded:')
        ]
        if loaded_lines:
            fields = loaded_lines[-1].split()
            if len(fields) > 1 and fields[1] == 'loaded':
                return ServiceState.INACTIVE

        return ServiceState.MISSING

    def state(self):
        '''Return the service state, asking the container before the host.

        The host is only queried when the container does not know the
        service (or produced no usable status output).
        '''
        for command in (
            self._container_command('status'),
            self._host_command('status'),
        ):
            state = self._state_from_status(self._vm.ssh(command))
            if state != ServiceState.MISSING:
                return state

        return ServiceState.MISSING
# Service manager name -> wrapper class. The insertion order is the order
# in which VM._detect_service_manager() probes the managers.
_SERVICE_WRAPPERS = collections.OrderedDict([
    ('systemd_container', _SystemdContainerService),
    ('systemd', _SystemdService),
    ('sysvinit', _SysVInitService),
])
[docs]class VM(object):
'''VM properties:
* name
* cpus
* memory
* disks
* metadata
* network/mac addr
'''
def __init__(self, env, spec):
    '''Bind the VM to its environment and normalize a copy of its spec.

    Args:
        env: the owning environment.
        spec (dict): VM spec; it is shallow-copied before normalization.
    '''
    self._env = env
    self._spec = self._normalize_spec(spec.copy())

    # ``None`` when the spec names no (known) service manager; it is then
    # detected lazily by ``service()``.
    service_class_name = self._spec.get('service_class', None)
    self._service_class = _SERVICE_WRAPPERS.get(service_class_name, None)

    self._ssh_client = None
def virt_env(self):
    '''Return the environment this VM belongs to.'''
    owning_env = self._env
    return owning_env
@classmethod
[docs] def _normalize_spec(cls, spec):
spec['snapshots'] = spec.get('snapshots', {})
spec['metadata'] = spec.get('metadata', {})
if 'root-password' not in spec:
spec['root-password'] = config.get('default_root_password')
return spec
[docs] def _open_ssh_client(self):
while self._ssh_client is None:
try:
client = paramiko.SSHClient()
client.set_missing_host_key_policy(
paramiko.AutoAddPolicy(),
)
client.connect(
self.ip(),
username='root',
key_filename=self._env.prefix.paths.ssh_id_rsa(),
timeout=1,
)
return client
except socket.error:
pass
except socket.timeout:
pass
[docs] def _check_alive(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if not self.alive():
raise RuntimeError('VM is not running')
return func(self, *args, **kwargs)
return wrapper
@_check_alive
def _get_ssh_client(self):
    '''Return a new, connected ``paramiko.SSHClient`` for this VM.

    Connects as root with the prefix's ssh key and a one second timeout.
    Socket errors (timeouts included) are retried without any upper bound
    and without a pause between attempts.
    '''
    while True:
        try:
            candidate = paramiko.SSHClient()
            candidate.set_missing_host_key_policy(
                paramiko.AutoAddPolicy(),
            )
            candidate.connect(
                self.ip(),
                username='root',
                key_filename=self._env.prefix.paths.ssh_id_rsa(),
                timeout=1,
            )
        except (socket.error, socket.timeout):
            # Not reachable yet: try again.
            continue
        else:
            return candidate
def ssh(self, command, data=None, show_output=True):
    '''Run a command on the VM over a fresh ssh connection.

    Args:
        command (list of str): words of the command; they are joined with
            single spaces, without any quoting, before being executed.
        data: optional payload sent to the remote command's stdin.
        show_output (bool): when false, ``stdout=None`` and
            ``stderr=None`` are passed to ``utils.drain_ssh_channel``;
            when true its defaults are used.

    Returns:
        ``utils.CommandStatus(rc, out, err)`` of the remote command.

    Raises:
        RuntimeError: if the VM is not alive.
    '''
    if not self.alive():
        raise RuntimeError('Attempt to ssh into offline host')

    # A conditional expression is required here: ``show_output and {} or
    # {...}`` always picks the second dict because ``{}`` is falsy.
    drain_kwargs = {} if show_output else {'stdout': None, 'stderr': None}

    client = self._get_ssh_client()
    transport = None
    channel = None
    try:
        transport = client.get_transport()
        channel = transport.open_session()
        joined_command = ' '.join(command)
        command_id = _gen_ssh_command_id()
        logging.debug(
            'Running %s on %s: %s%s',
            command_id,
            self.name(),
            joined_command,
            (' < "%s"' % data) if data is not None else '',
        )
        channel.exec_command(joined_command)
        if data is not None:
            channel.send(data)
        channel.shutdown_write()
        rc, out, err = utils.drain_ssh_channel(channel, **drain_kwargs)
    finally:
        # Release the connection even when the command could not be run.
        if channel is not None:
            channel.close()
        if transport is not None:
            transport.close()
        client.close()

    logging.debug(
        'Command %s on %s returned with %d',
        command_id,
        self.name(),
        rc,
    )

    if out:
        logging.debug(
            'Command %s on %s output:\n %s',
            command_id,
            self.name(),
            out,
        )
    if err:
        logging.debug(
            'Command %s on %s errors:\n %s',
            command_id,
            self.name(),
            err,
        )
    return utils.CommandStatus(rc, out, err)
def wait_for_ssh(self, connect_retries=50):
    '''Block until ``true`` can be run on the VM through ssh.

    Args:
        connect_retries (int): number of attempts; one second is slept
            after every failed attempt.

    Raises:
        RuntimeError: if every attempt returned a non-zero exit code.
    '''
    attempts_left = connect_retries
    while attempts_left:
        exit_code, _, _ = self.ssh(['true'])
        if exit_code == 0:
            return
        attempts_left -= 1
        time.sleep(1)

    raise RuntimeError('Failed to connect to remote shell')
def ssh_script(self, path, show_output=True):
    '''Run the local script at ``path`` on the VM through ``bash -s``.

    The content of the file is fed to the remote bash on stdin; the
    result of ``ssh`` is returned as is.
    '''
    with open(path) as script:
        script_text = script.read()
        return self.ssh(
            ['bash', '-s'],
            data=script_text,
            show_output=show_output,
        )
@contextlib.contextmanager
[docs] def _sftp(self):
client = self._get_ssh_client()
sftp = client.open_sftp()
try:
yield sftp
finally:
sftp.close()
client.close()
def copy_to(self, local_path, remote_path):
    '''Upload ``local_path`` from the host to ``remote_path`` on the VM.'''
    with self._sftp() as session:
        session.put(local_path, remote_path)
def copy_from(self, remote_path, local_path):
    '''Download ``remote_path`` from the VM to ``local_path`` on the host.'''
    with self._sftp() as session:
        session.get(remote_path, local_path)
@property
def metadata(self):
    '''Shallow copy of the ``'metadata'`` dict of the spec.'''
    stored = self._spec['metadata']
    return stored.copy()
def name(self):
    '''Return the unprefixed VM name from the spec, as ``str``.'''
    raw_name = self._spec['name']
    return str(raw_name)
def iscsi_name(self):
    '''Return the iSCSI name of this VM, derived from the VM name.'''
    vm_name = self.name()
    return 'iqn.2014-07.org.lago:%s' % vm_name
def ip(self):
    '''Return the address the env's default network maps to this VM.

    The network is the one returned by ``get_net()`` without a name.
    '''
    default_net = self._env.get_net()
    address = default_net.resolve(self.name())
    return str(address)
[docs] def _libvirt_name(self):
return self._env.prefixed_name(self.name())
def _libvirt_xml(self):
    '''Render the libvirt domain XML of this VM from ``dom_template.xml``.

    The ``@...@`` placeholders of the template are replaced (first
    occurrence only), the template's first ``<disk>`` is dropped, and one
    ``<disk>`` per spec disk plus one ``<interface>`` per spec nic are
    appended to ``/domain/devices``.

    Raises:
        IndexError: if none of the known qemu-kvm paths exists.
    '''
    with open(_path_to_xml('dom_template.xml')) as template:
        raw_xml = template.read()

    kvm_candidates = ('/usr/libexec/qemu-kvm', '/usr/bin/qemu-kvm')
    # pop() picks the last candidate that exists on this host.
    qemu_kvm_path = [
        candidate
        for candidate in kvm_candidates
        if os.path.exists(candidate)
    ].pop()

    substitutions = {
        '@NAME@': self._libvirt_name(),
        '@VCPU@': self._spec.get('vcpu', 4),
        '@CPU@': self._spec.get('cpu', 4),
        '@MEM_SIZE@': self._spec.get('memory', 16 * 1024),
        '@QEMU_KVM@': qemu_kvm_path,
    }
    for placeholder, value in substitutions.items():
        raw_xml = raw_xml.replace(placeholder, str(value), 1)

    dom_xml = lxml.etree.fromstring(raw_xml)
    devices = dom_xml.xpath('/domain/devices')[0]

    # The disk that ships with the template is replaced by the spec's.
    devices.remove(devices.xpath('disk')[0])

    for disk_spec in self._spec['disks']:
        disk = lxml.etree.SubElement(
            devices,
            'disk',
            type='file',
            device='disk',
        )
        lxml.etree.SubElement(
            disk,
            'driver',
            name='qemu',
            type=disk_spec['format'],
        )
        lxml.etree.SubElement(disk, 'source', file=disk_spec['path'])
        lxml.etree.SubElement(
            disk,
            'target',
            dev=disk_spec['dev'],
            bus='virtio',
        )

    for nic_spec in self._spec['nics']:
        interface = lxml.etree.SubElement(
            devices,
            'interface',
            type='network',
        )
        lxml.etree.SubElement(
            interface,
            'source',
            network=self._env.prefixed_name(nic_spec['net']),
        )
        lxml.etree.SubElement(interface, 'model', type='virtio')
        # A nic with a fixed ip gets the MAC derived from that ip.
        if 'ip' in nic_spec:
            lxml.etree.SubElement(
                interface,
                'mac',
                address=_ip_to_mac(nic_spec['ip']),
            )

    return lxml.etree.tostring(dom_xml)
def start(self):
    '''Create the libvirt domain of this VM unless it already exists.'''
    if self.alive():
        return
    logging.info('Starting VM %s', self.name())
    connection = self._env.libvirt_con
    connection.createXML(self._libvirt_xml())
def stop(self):
    '''Destroy the libvirt domain of this VM if it exists.

    ``self._ssh_client`` is reset to ``None`` before the domain is
    destroyed.
    '''
    if not self.alive():
        return
    self._ssh_client = None
    logging.info('Destroying VM %s', self.name())
    domain = self._env.libvirt_con.lookupByName(self._libvirt_name())
    domain.destroy()
def alive(self):
    '''Tell whether libvirt lists a domain with our prefixed name.'''
    known_names = [
        domain.name()
        for domain in self._env.libvirt_con.listAllDomains()
    ]
    expected = self._libvirt_name()
    return expected in known_names
def create_snapshot(self, name):
    '''Snapshot the VM (live if it is running) and save the spec.'''
    take_snapshot = (
        self._create_live_snapshot
        if self.alive()
        else self._create_dead_snapshot
    )
    take_snapshot(name)
    self.save()
[docs] def _create_dead_snapshot(self, name):
raise RuntimeError('Dead snapshots are not implemented yet')
def _create_live_snapshot(self, name):
    '''Take a disk-only, quiesced libvirt snapshot of the running VM.

    After the snapshot every spec disk points at the new qcow2 overlay
    reported by libvirt, while the snapshot record stored under
    ``self._spec['snapshots'][name]`` keeps a copy of each disk spec
    pointing at the overlay's backing file.

    Raises:
        libvirt.libvirtError: if libvirt fails to create the snapshot
            (logged and re-raised).
    '''
    logging.info(
        'Creating live snapshot named %s for %s',
        name,
        self.name(),
    )

    # The guest agent must be running for the QUIESCE flag below; the
    # guest's buffers are flushed first.
    self.wait_for_ssh()
    self.guest_agent().start()
    self.ssh(['sync'])

    dom = self._env.libvirt_con.lookupByName(self._libvirt_name())
    current_disks = lxml.etree.fromstring(dom.XMLDesc()).xpath(
        'devices/disk',
    )

    with open(_path_to_xml('snapshot_template.xml')) as template:
        snapshot_xml = lxml.etree.fromstring(template.read())

    # List every disk of the domain, by target device, in the request.
    snapshot_disks = snapshot_xml.xpath('disks')[0]
    for disk_node in current_disks:
        target_dev = disk_node.xpath('target')[0].attrib['dev']
        lxml.etree.SubElement(snapshot_disks, 'disk', name=target_dev)

    snapshot_flags = (
        libvirt.VIR_DOMAIN_SNAPSHOT_CREATE_DISK_ONLY |
        libvirt.VIR_DOMAIN_SNAPSHOT_CREATE_QUIESCE
    )
    try:
        dom.snapshotCreateXML(
            lxml.etree.tostring(snapshot_xml),
            snapshot_flags,
        )
    except libvirt.libvirtError:
        logging.exception(
            'Failed to create snapshot %s for %s', name, self.name(),
        )
        raise

    snap_info = []
    new_disks = lxml.etree.fromstring(dom.XMLDesc()).xpath('devices/disk')
    # Spec disks and domain disks are paired by position.
    for disk_spec, disk_node in zip(self._spec['disks'], new_disks):
        disk_spec['path'] = disk_node.xpath('source')[0].attrib['file']
        disk_spec['format'] = 'qcow2'

        backing_store = disk_node.xpath('backingStore')[0]
        backing_path = backing_store.xpath('source')[0].attrib['file']
        snap_info.append(dict(disk_spec, path=backing_path))

    self._reclaim_disks()
    self._spec['snapshots'][name] = snap_info
def revert_snapshot(self, name):
    '''Recreate every disk as a fresh overlay of the snapshot ``name``.

    A running VM is stopped first and started again afterwards. Every
    current disk file is deleted and recreated with ``qemu-img create``
    on top of the path recorded in the snapshot.

    Raises:
        RuntimeError: if there is no such snapshot or ``qemu-img`` fails.
    '''
    try:
        snap_info = self._spec['snapshots'][name]
    except KeyError:
        raise RuntimeError('No snapshot %s for %s' % (name, self.name()))

    logging.info('Reverting %s to snapshot %s', self.name(), name)

    was_alive = self.alive()
    if was_alive:
        self.stop()

    # Current disks and snapshot records are paired by position.
    for disk, disk_template in zip(self._spec['disks'], snap_info):
        overlay_path = disk['path']
        os.unlink(overlay_path)
        create_command = [
            'qemu-img',
            'create',
            '-f', 'qcow2',
            '-b', disk_template['path'],
            disk['path'],
        ]
        exit_code, _, _ = utils.run_command(
            create_command,
            cwd=os.path.dirname(disk['path']),
        )
        if exit_code != 0:
            raise RuntimeError('Failed to revert disk')

    self._reclaim_disks()
    if was_alive:
        self.start()
def save(self, path=None):
    '''Dump the VM spec as JSON.

    Args:
        path (str): destination file; defaults to ``vm-<name>`` under the
            env's virt path.
    '''
    target = path
    if target is None:
        target = self._env.virt_path('vm-%s' % self.name())

    with open(target, 'w') as spec_file:
        utils.json_dump(self._spec, spec_file)
def bootstrap(self):
    '''Run sysprep on the VM's first disk.

    Sets the hostname, the root password, the prefix's public ssh key,
    the iSCSI initiator name and enforcing SELinux, and configures DHCP
    on ``eth<i>`` for every nic that has an ``'ip'`` in the spec.
    '''
    logging.debug('Bootstrapping %s begin', self.name())

    root_disk = self._spec['disks'][0]['path']
    base_commands = [
        sysprep.set_hostname(self.name()),
        sysprep.set_root_password(self.root_password()),
        sysprep.add_ssh_key(
            self._env.prefix.paths.ssh_id_rsa_pub(),
            with_restorecon_fix=(self.distro() == 'fc23'),
        ),
        sysprep.set_iscsi_initiator_name(self.iscsi_name()),
        sysprep.set_selinux_mode('enforcing'),
    ]
    # The interface index counts every nic, also the ones without an ip.
    nic_commands = [
        sysprep.config_net_interface_dhcp(
            'eth%d' % index,
            _ip_to_mac(nic['ip']),
        )
        for index, nic in enumerate(self._spec['nics'])
        if 'ip' in nic
    ]
    sysprep.sysprep(root_disk, base_commands + nic_commands)

    logging.debug('Bootstrapping %s end', self.name())
[docs] def _reclaim_disk(self, path):
if pwd.getpwuid(os.stat(path).st_uid).pw_name == 'qemu':
utils.run_command(['sudo', '-u', 'qemu', 'chmod', 'a+rw', path])
else:
os.chmod(path, 0666)
[docs] def _reclaim_disks(self):
for disk in self._spec['disks']:
self._reclaim_disk(disk['path'])
@_check_alive
def vnc_port(self):
    '''Return the ``port`` attribute of the domain's last graphics device.

    The value is read from the live domain XML and returned as found
    there (a string).
    '''
    domain = self._env.libvirt_con.lookupByName(self._libvirt_name())
    domain_xml = lxml.etree.fromstring(domain.XMLDesc())
    graphics_nodes = domain_xml.xpath('devices/graphics')
    return graphics_nodes.pop().attrib['port']
def _detect_service_manager(self):
    '''Probe the VM for the first supported service manager.

    The managers are tried in the order of ``_SERVICE_WRAPPERS``. The
    first supported one is remembered on the instance and in the spec,
    which is then saved. Nothing changes when none is supported.
    '''
    logging.debug('Detecting service manager for %s', self.name())
    for manager_name, service_class in _SERVICE_WRAPPERS.items():
        if not service_class.is_supported(self):
            continue

        logging.debug(
            'Setting %s as service manager for %s',
            manager_name,
            self.name(),
        )
        self._service_class = service_class
        self._spec['service_class'] = manager_name
        self.save()
        return
@_check_alive
def service(self, name):
    '''Return a service wrapper for the service ``name`` on this VM.

    The service manager is detected on first use.

    Raises:
        RuntimeError: if the VM is not running, or if none of the known
            service managers is supported by the VM.
    '''
    if self._service_class is None:
        self._detect_service_manager()

    if self._service_class is None:
        # Detection found nothing; fail with a clear message instead of
        # trying to call ``None``.
        raise RuntimeError(
            'Could not detect a service manager on %s' % self.name()
        )

    return self._service_class(self, name)
def guest_agent(self):
    '''Return the service wrapper of the guest agent.

    The service name is looked up once, trying ``qemu-ga`` and then
    ``qemu-guest-agent``, and stored under ``'guest-agent'`` in the
    saved spec.

    Raises:
        RuntimeError: if neither service exists on the VM.
    '''
    if 'guest-agent' not in self._spec:
        candidates = ('qemu-ga', 'qemu-guest-agent')
        # Lazy: probing stops at the first candidate that exists.
        detected = next(
            (
                candidate
                for candidate in candidates
                if self.service(candidate).exists()
            ),
            None,
        )
        if detected is None:
            raise RuntimeError('Could not find guest agent service')

        self._spec['guest-agent'] = detected
        self.save()

    return self.service(self._spec['guest-agent'])
@_check_alive
def interactive_ssh(self, command):
    '''Run ``command`` interactively on the VM over a fresh connection.

    The words of ``command`` are joined with single spaces and handed to
    ``utils.interactive_ssh_channel``, whose result is returned. The
    channel, the transport and the client are closed afterwards.
    '''
    ssh_client = self._get_ssh_client()
    ssh_transport = ssh_client.get_transport()
    session = ssh_transport.open_session()
    try:
        return utils.interactive_ssh_channel(session, ' '.join(command))
    finally:
        session.close()
        ssh_transport.close()
        ssh_client.close()
def nics(self):
    '''Return a shallow copy of the nic specs of this VM.'''
    nic_specs = self._spec['nics']
    return nic_specs[:]
def distro(self):
    '''Return the ``'distro'`` of the template metadata, ``None`` if unset.'''
    template_metadata = self._template_metadata()
    return template_metadata.get('distro', None)
def root_password(self):
    '''Return the root password stored in the spec.'''
    password = self._spec['root-password']
    return password