Source code for lago.virt

#
# Copyright 2014 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA
#
# Refer to the README and COPYING files for full details of the license
#
import functools
import hashlib
import json
import logging
import os
import uuid

import lxml.etree

from . import (config, brctl, utils, log_utils, plugins, libvirt_utils, )

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# ``log_utils.LogTask`` / ``log_utils.log_task`` with this module's logger
# pre-bound, so the call sites below only have to pass the task description.
LogTask = functools.partial(log_utils.LogTask, logger=LOGGER)
log_task = functools.partial(log_utils.log_task, logger=LOGGER)


def _gen_ssh_command_id():
    """
    Generate a short identifier

    Returns:
        str: the first 8 hexadecimal digits of a freshly generated UUID1
    """
    full_hex = uuid.uuid1().hex
    return full_hex[:8]
def _guestfs_copy_path(g, guest_path, host_path):
    """
    Recursively copy a file or a directory tree from a guest to the host

    Paths that are neither a regular file nor a directory according to ``g``
    are silently skipped.

    Args:
        g: guest filesystem handle; only its ``is_file``, ``is_dir``,
            ``read_file`` and ``ls`` methods are used
        guest_path(str): path inside the guest to copy from
        host_path(str): path on the host to create; for directories it must
            not exist yet, as it is created with :func:`os.mkdir`

    Returns:
        None
    """
    if g.is_file(guest_path):
        data = g.read_file(guest_path)
        # Guest files are arbitrary bytes: write them untouched in binary
        # mode, encoding only if the handle handed back text
        if not isinstance(data, bytes):
            data = data.encode('utf-8')
        with open(host_path, 'wb') as f:
            f.write(data)
    elif g.is_dir(guest_path):
        os.mkdir(host_path)
        for path in g.ls(guest_path):
            _guestfs_copy_path(
                g,
                os.path.join(guest_path, path),
                os.path.join(host_path, os.path.basename(path)),
            )
def _path_to_xml(basename):
    """
    Build the path of a file that lives next to this module

    Args:
        basename(str): file name to look for in this module's directory

    Returns:
        str: ``basename`` joined to the directory of this module's file
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, basename)
class VirtEnv(object):
    """
    Virtual environment of a prefix: its networks, its VMs and the libvirt
    connection used to manage them

    Attributes:
        prefix: prefix object this environment belongs to
        uuid(str): stripped contents of the prefix uuid file
        vm_types: VM plugins loaded from the ``vm`` plugin entry point,
            looked up by VM type name
        libvirt_con: connection returned by
            ``libvirt_utils.get_libvirt_connection``
    """

    def __init__(self, prefix, vm_specs, net_specs):
        """
        Args:
            prefix: prefix object; ``prefix.paths.uuid()`` must point to a
                readable file holding the uuid
            vm_specs(dict): VM name -> VM spec
            net_specs(dict): network name -> network spec
        """
        self.vm_types = plugins.load_plugins(
            plugins.PLUGIN_ENTRY_POINTS['vm'],
            instantiate=False,
        )
        self.prefix = prefix

        with open(self.prefix.paths.uuid(), 'r') as uuid_fd:
            self.uuid = uuid_fd.read().strip()

        self._nets = {}
        for name, spec in net_specs.items():
            self._nets[name] = self._create_net(spec)

        self._vms = {}
        for name, spec in vm_specs.items():
            self._vms[name] = self._create_vm(spec)

        libvirt_url = config.get('libvirt_url')
        self.libvirt_con = libvirt_utils.get_libvirt_connection(
            name=self.uuid + libvirt_url,
            libvirt_url=libvirt_url,
        )

    def get_cpu_model(self):
        """
        Get the host CPU model from the libvirt capabilities

        Returns:
            str: text of the first ``/capabilities/host/cpu/model`` element

        Raises:
            IndexError: if the capabilities XML has no such element
        """
        cap_tree = lxml.etree.fromstring(self.libvirt_con.getCapabilities())
        cpu_model = cap_tree.xpath('/capabilities/host/cpu/model')[0].text
        return cpu_model

    def _create_net(self, net_spec):
        """
        Instantiate the network class matching the spec ``type``

        Args:
            net_spec(dict): network spec, its ``type`` must be ``nat`` or
                ``bridge``

        Returns:
            Network: :class:`NATNetwork` or :class:`BridgeNetwork` instance

        Raises:
            RuntimeError: if the network type is not supported
        """
        net_type = net_spec['type']
        if net_type == 'nat':
            cls = NATNetwork
        elif net_type == 'bridge':
            cls = BridgeNetwork
        else:
            # Fail with a meaningful message instead of an unbound local
            raise RuntimeError(
                'Unknown network type %r for network %r' %
                (net_type, net_spec.get('name'))
            )
        return cls(self, net_spec)

    def _create_vm(self, vm_spec):
        """
        Instantiate the VM plugin selected by the spec

        The type is read from the ``vm-type`` key of the spec, falling back
        to the ``default_vm_type`` config value; the resolved type name is
        written back into ``vm_spec['vm-type']``.

        Args:
            vm_spec(dict): VM spec

        Returns:
            instance of the selected VM plugin class

        Raises:
            RuntimeError: if no plugin is loaded for the requested type
        """
        default_vm_type = config.get('default_vm_type')
        vm_type_name = vm_spec.get('vm-type', default_vm_type)
        vm_type = self.vm_types.get(vm_type_name)
        if vm_type is None:
            # Calling ``None`` below would only give an obscure TypeError
            raise RuntimeError('Unknown VM type %r' % (vm_type_name,))
        vm_spec['vm-type'] = vm_type_name
        return vm_type(self, vm_spec)

    def prefixed_name(self, unprefixed_name, max_length=0):
        """
        Returns a uuid prefixed identifier

        Args:
            unprefixed_name(str): Name to add a prefix to
            max_length(int): maximum length of the resultant prefixed name,
                will adapt the given name and the length of the uuid to fit
                it, 0 means no limit

        Returns:
            str: prefixed identifier for the given unprefixed name

        Raises:
            RuntimeError: if ``max_length`` is non-zero and less than 6
        """
        if max_length == 0:
            return '%s-%s' % (self.uuid[:8], unprefixed_name)

        if max_length < 6:
            raise RuntimeError(
                "Can't prefix with less than 6 chars (%s)" % unprefixed_name
            )

        # Short limits only get 4 chars of the uuid to leave room for the name
        if max_length < 16:
            _uuid = self.uuid[:4]
        else:
            _uuid = self.uuid[:8]

        # One extra char is taken by the '-' separator
        name_max_length = max_length - len(_uuid) - 1

        if name_max_length < len(unprefixed_name):
            # Names that do not fit are replaced by a truncated sha1 digest;
            # hashlib only accepts bytes, so encode text names first
            name_bytes = unprefixed_name
            if not isinstance(name_bytes, bytes):
                name_bytes = name_bytes.encode('utf-8')
            hashed_name = hashlib.sha1(name_bytes).hexdigest()
            unprefixed_name = hashed_name[:name_max_length]

        return '%s-%s' % (_uuid, unprefixed_name)

    def virt_path(self, *args):
        """
        Returns:
            whatever ``prefix.paths.virt`` returns for the given arguments
        """
        return self.prefix.paths.virt(*args)

    def bootstrap(self):
        """Call ``bootstrap()`` on every VM through
        ``utils.invoke_in_parallel``"""
        utils.invoke_in_parallel(lambda vm: vm.bootstrap(), self._vms.values())

    def start(self, vm_names=None):
        """
        Start the networks and VMs of the environment

        Anything already started is registered in a rollback context, so it
        gets stopped again if a later start raises.

        Args:
            vm_names(list of str): names of the VMs to start together with
                the networks they use, all VMs and networks are started if
                none or empty list passed

        Returns:
            None
        """
        if not vm_names:
            log_msg = 'Start Prefix'
            vms = self._vms.values()
            nets = self._nets.values()
        else:
            log_msg = 'Start specified VMs'
            vms = [self._vms[vm_name] for vm_name in vm_names]
            # Only the networks the requested VMs are attached to
            nets = set()
            for vm in vms:
                nets.update(self._nets[net_name] for net_name in vm.nets())

        with LogTask(log_msg), utils.RollbackContext() as rollback:
            with LogTask('Start nets'):
                for net in nets:
                    net.start()
                    rollback.prependDefer(net.stop)

            with LogTask('Start vms'):
                for vm in vms:
                    vm.start()
                    rollback.prependDefer(vm.stop)
            # Everything is up, nothing to undo
            rollback.clear()

    def stop(self, vm_names=None):
        """
        Stop VMs and the networks that are no longer needed

        Args:
            vm_names(list of str): names of the VMs to stop, all VMs and
                networks are stopped if none or empty list passed. When
                given, a network is only stopped if no other defined VM
                outside of ``vm_names`` uses it

        Returns:
            None
        """
        if not vm_names:
            log_msg = 'Stop prefix'
            vms = self._vms.values()
            nets = self._nets.values()
        else:
            log_msg = 'Stop specified VMs'
            vms = [self._vms[vm_name] for vm_name in vm_names]
            stoppable_nets = set()
            for vm in vms:
                stoppable_nets.update(vm.nets())
            # Keep the networks still in use by defined VMs we are not
            # stopping
            for vm in self._vms.values():
                if not vm.defined() or vm.name() in vm_names:
                    continue
                for net in vm.nets():
                    stoppable_nets.discard(net)
            nets = [self._nets[net] for net in stoppable_nets]

        with LogTask(log_msg):
            with LogTask('Stop vms'):
                for vm in vms:
                    vm.stop()
            with LogTask('Stop nets'):
                for net in nets:
                    net.stop()

    def get_nets(self):
        """
        Returns:
            dict: shallow copy of the network name -> network mapping
        """
        return self._nets.copy()

    def get_net(self, name=None):
        """
        Get a network by name, or the default one

        Args:
            name(str): name of the network to get, if empty or None the
                default network is returned

        Returns:
            Network: the network with that name or None if there is no such
            network; with no name, the last network flagged as management,
            or the last network of the environment if none is

        Raises:
            IndexError: if no name was given and there are no networks
        """
        nets = self.get_nets()
        if name:
            return nets.get(name)

        management_nets = [
            net for net in nets.values() if net.is_management()
        ]
        if management_nets:
            return management_nets[-1]
        # ``values()`` is not a list on python 3, so index a copy instead of
        # popping from it
        return list(nets.values())[-1]

    def get_vms(self):
        """
        Returns:
            dict: shallow copy of the VM name -> VM mapping
        """
        return self._vms.copy()

    def get_vm(self, name):
        """
        Args:
            name(str): name of the VM to get

        Returns:
            the VM with that name

        Raises:
            KeyError: if there is no VM with that name
        """
        return self._vms[name]

    @classmethod
    def from_prefix(cls, prefix):
        """
        Build the environment from the json files saved under the prefix

        Reads the ``env`` file and then one ``net-<name>`` and one
        ``vm-<name>`` file per network and VM listed in it.

        Args:
            prefix: prefix object to load the environment from

        Returns:
            VirtEnv: environment built from the loaded specs
        """
        virt_path = functools.partial(prefix.paths.prefixed, 'virt')

        with open(virt_path('env'), 'r') as f:
            env_dom = json.load(f)

        net_specs = {}
        for name in env_dom['nets']:
            with open(virt_path('net-%s' % name), 'r') as f:
                net_specs[name] = json.load(f)

        vm_specs = {}
        for name in env_dom['vms']:
            with open(virt_path('vm-%s' % name), 'r') as f:
                vm_specs[name] = json.load(f)

        return cls(prefix, vm_specs, net_specs)

    @log_task('Save prefix')
    def save(self):
        """
        Save every network and VM, then write the ``env`` file listing
        their names

        Returns:
            None
        """
        with LogTask('Save nets'):
            for net in self._nets.values():
                net.save()

        with LogTask('Save VMs'):
            for vm in self._vms.values():
                vm.save()

        # Plain lists: dict key views are not json serializable
        spec = {
            'nets': list(self._nets),
            'vms': list(self._vms),
        }

        with LogTask('Save env'):
            with open(self.virt_path('env'), 'w') as f:
                utils.json_dump(spec, f)

    @log_task('Create VMs snapshots')
    def create_snapshots(self, name):
        """
        Call ``create_snapshot(name)`` on every VM through
        ``utils.invoke_in_parallel``

        Args:
            name(str): name passed to each VM ``create_snapshot``
        """
        utils.invoke_in_parallel(
            lambda vm: vm.create_snapshot(name),
            self._vms.values(),
        )

    @log_task('Revert VMs snapshots')
    def revert_snapshots(self, name):
        """
        Call ``revert_snapshot(name)`` on every VM through
        ``utils.invoke_in_parallel``

        Args:
            name(str): name passed to each VM ``revert_snapshot``
        """
        utils.invoke_in_parallel(
            lambda vm: vm.revert_snapshot(name),
            self._vms.values(),
        )

    def get_snapshots(self, domains=None):
        """
        Get the list of snapshots for each domain

        Args:
            domains(list of str): list of the domains to get the snapshots
                for, all will be returned if none or empty list passed

        Returns:
            dict of str -> list(str): with the domain names and the list of
            snapshots for each
        """
        snapshots = {}
        for vm_name, vm in self.get_vms().items():
            if domains and vm_name not in domains:
                continue
            snapshots[vm_name] = vm._spec['snapshots']

        return snapshots
class Network(object):
    """
    Base class for the networks of an environment

    Subclasses provide ``_libvirt_xml()``, which :meth:`start` passes to
    libvirt to create the network.
    """

    def __init__(self, env, spec):
        """
        Args:
            env: environment the network belongs to, its ``uuid`` is used to
                name the libvirt connection
            spec(dict): network spec
        """
        self._env = env
        url = config.get('libvirt_url')
        self.libvirt_con = libvirt_utils.get_libvirt_connection(
            name=env.uuid + url,
            libvirt_url=url,
        )
        self._spec = spec

    def name(self):
        """
        Returns:
            the ``name`` entry of the spec
        """
        return self._spec['name']

    def gw(self):
        """
        Returns:
            the ``gw`` entry of the spec, None if it has none
        """
        return self._spec.get('gw')

    def is_management(self):
        """
        Returns:
            the ``management`` entry of the spec, False if it has none
        """
        return self._spec.get('management', False)

    def add_mappings(self, mappings):
        """
        Add several name -> ip mappings and save the spec once at the end

        Args:
            mappings: iterable of ``(name, ip, mac)`` triplets, the mac is
                not used here
        """
        for host_name, address, _mac in mappings:
            self.add_mapping(host_name, address, save=False)
        self.save()

    def add_mapping(self, name, ip, save=True):
        """
        Add a name -> ip entry to the ``mapping`` of the spec

        Args:
            name(str): name to map
            ip(str): address to map the name to
            save(bool): if True, save the spec right after
        """
        mapping = self._spec['mapping']
        mapping[name] = ip
        if not save:
            return
        self.save()

    def resolve(self, name):
        """
        Returns:
            the address mapped to the given name

        Raises:
            KeyError: if the name has no mapping
        """
        mapping = self._spec['mapping']
        return mapping[name]

    def _libvirt_name(self):
        """
        Returns:
            the network name prefixed by the environment, limited to 15
            chars
        """
        return self._env.prefixed_name(self.name(), max_length=15)

    def alive(self):
        """
        Returns:
            bool: True if libvirt lists a network with this network's
            libvirt name
        """
        known = [net.name() for net in self.libvirt_con.listAllNetworks()]
        wanted = self._libvirt_name()
        return wanted in known

    def start(self):
        """Create the network in libvirt, unless it is already alive"""
        if self.alive():
            return
        with LogTask('Create network %s' % self.name()):
            self.libvirt_con.networkCreateXML(self._libvirt_xml())

    def stop(self):
        """Destroy the network in libvirt, if it is alive"""
        if not self.alive():
            return
        with LogTask('Destroy network %s' % self.name()):
            libvirt_net = self.libvirt_con.networkLookupByName(
                self._libvirt_name(),
            )
            libvirt_net.destroy()

    def save(self):
        """Dump the spec as json to the ``net-<name>`` file of the
        environment"""
        spec_path = self._env.virt_path('net-%s' % self.name())
        with open(spec_path, 'w') as f:
            utils.json_dump(self._spec, f)
class NATNetwork(Network):
    """Network whose libvirt definition is rendered from
    ``net_nat_template.xml``"""

    def _libvirt_xml(self):
        """
        Render the libvirt XML for this network

        The template placeholders are filled with the libvirt name, a bridge
        name derived from it and the gateway. A ``domain`` element is added
        if the spec has a ``dns_domain_name``. If the spec has a ``dhcp``
        entry, DHCP ranges are added to both ``ip`` elements, and, for
        management networks, one DHCP and DNS host entry per mapping.

        Returns:
            the serialized network XML, as returned by
            ``lxml.etree.tostring``
        """
        with open(_path_to_xml('net_nat_template.xml')) as template_fd:
            raw_xml = template_fd.read()

        libvirt_name = self._libvirt_name()
        placeholders = {
            '@NAME@': libvirt_name,
            '@BR_NAME@': ('%s-nic' % self._libvirt_name())[:12],
            '@GW_ADDR@': self.gw(),
        }
        # Only the first occurrence of each placeholder is replaced
        for token, value in placeholders.items():
            raw_xml = raw_xml.replace(token, value, 1)

        net_xml = lxml.etree.fromstring(raw_xml)

        dns_domain_name = self._spec.get('dns_domain_name', None)
        if dns_domain_name is not None:
            net_xml.append(
                lxml.etree.Element(
                    'domain',
                    name=dns_domain_name,
                    localOnly='yes',
                )
            )

        if 'dhcp' not in self._spec:
            return lxml.etree.tostring(net_xml)

        IPV6_PREFIX = 'fd8f:1391:3a82:5e0d::'

        ip_nodes = net_xml.xpath('/network/ip')
        ipv4 = ip_nodes[0]
        ipv6 = ip_nodes[1]
        dns = net_xml.xpath('/network/dns')[0]

        def make_ipv4(last):
            # Gateway address with its last dotted component swapped
            octets = self.gw().split('.')
            octets[-1] = str(last)
            return '.'.join(octets)

        dhcp = lxml.etree.Element('dhcp')
        dhcpv6 = lxml.etree.Element('dhcp')
        ipv4.append(dhcp)
        ipv6.append(dhcpv6)

        dhcp_spec = self._spec['dhcp']
        range_v4 = lxml.etree.Element(
            'range',
            start=make_ipv4(dhcp_spec['start']),
            end=make_ipv4(dhcp_spec['end']),
        )
        dhcp.append(range_v4)
        # The v6 range reuses the v4 addresses, appended to the fixed prefix
        range_v6 = lxml.etree.Element(
            'range',
            start=IPV6_PREFIX + make_ipv4(dhcp_spec['start']),
            end=IPV6_PREFIX + make_ipv4(dhcp_spec['end']),
        )
        dhcpv6.append(range_v6)

        if self.is_management():
            for hostname, ip4 in self._spec['mapping'].items():
                ip6 = IPV6_PREFIX + ip4

                dhcp.append(
                    lxml.etree.Element(
                        'host',
                        mac=utils.ipv4_to_mac(ip4),
                        ip=ip4,
                        name=hostname
                    )
                )
                dhcpv6.append(
                    lxml.etree.Element(
                        'host',
                        id='0:3:0:1:' + utils.ipv4_to_mac(ip4),
                        ip=ip6,
                        name=hostname
                    )
                )

                # One dns host entry per address family, v4 first
                for address in (ip4, ip6):
                    dns_host = lxml.etree.Element('host', ip=address)
                    dns_name = lxml.etree.SubElement(dns_host, 'hostname')
                    dns_name.text = hostname
                    dns.append(dns_host)

        return lxml.etree.tostring(net_xml)
class BridgeNetwork(Network):
    """
    Network whose libvirt definition is rendered from
    ``net_br_template.xml`` and whose bridge is created and destroyed
    through ``brctl``
    """

    def _libvirt_xml(self):
        """
        Render the libvirt XML for this network

        Returns:
            str: the template contents with the first ``@NAME@`` and
            ``@BR_NAME@`` placeholders replaced by the libvirt name
        """
        with open(_path_to_xml('net_br_template.xml')) as f:
            net_raw_xml = f.read()

        replacements = {
            '@NAME@': self._libvirt_name(),
            '@BR_NAME@': self._libvirt_name(),
        }
        for k, v in replacements.items():
            net_raw_xml = net_raw_xml.replace(k, v, 1)

        return net_raw_xml

    def start(self):
        """
        Create the bridge and then the libvirt network

        Does nothing if the bridge already exists. If creating the libvirt
        network fails, the bridge is destroyed again and the error is
        propagated, so callers can tell the network did not come up.
        """
        if brctl.exists(self._libvirt_name()):
            return

        brctl.create(self._libvirt_name())
        try:
            super(BridgeNetwork, self).start()
        except BaseException:
            # Undo the bridge creation, but never hide the failure: the
            # caller would otherwise carry on with a network that is not
            # there
            brctl.destroy(self._libvirt_name())
            raise

    def stop(self):
        """Stop the libvirt network and destroy the bridge if it still
        exists"""
        super(BridgeNetwork, self).stop()
        if brctl.exists(self._libvirt_name()):
            brctl.destroy(self._libvirt_name())