Source code for

# Copyright 2014-2017 Red Hat, Inc.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA
# Refer to the README and COPYING files for full details of the license

from __future__ import absolute_import

from future.builtins import super
from collections import defaultdict
import functools
import logging
import time
from copy import deepcopy

import six

from lxml import etree as ET
import lago.providers.libvirt.utils as libvirt_utils
from lago import brctl, log_utils, utils
import libvirt

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# ``log_utils.LogTask`` pre-bound to this module's logger, so call sites only
# have to pass the task description.
LogTask = functools.partial(log_utils.LogTask, logger=LOGGER)

class Network(object):
    """
    Base class for a libvirt network that is described by a spec dict.

    Subclasses have to implement :meth:`_libvirt_xml`, which provides the
    XML definition handed to libvirt when the network is started.

    Args:
        env: environment object; this class uses its ``uuid`` attribute and
            its ``prefixed_name`` and ``virt_path`` methods
        spec (dict): network specification. ``name`` and ``mapping`` are
            read unconditionally; ``gw``, ``mtu`` and ``management`` are
            optional
        compat: compatibility marker, stored as-is for subclasses
    """

    def __init__(self, env, spec, compat):
        self._env = env
        self.libvirt_con = libvirt_utils.get_libvirt_connection(
            name=env.uuid,
        )
        self._spec = spec
        self.compat = compat

    def __del__(self):
        # __init__ may have failed before ``libvirt_con`` was assigned, in
        # which case the attribute does not exist at all.
        connection = getattr(self, 'libvirt_con', None)
        if connection is not None:
            connection.close()

    def name(self):
        """Return the network name as written in the spec."""
        return self._spec['name']

    def gw(self):
        """Return the gateway address from the spec, or None if unset."""
        return self._spec.get('gw')

    def mtu(self):
        """
        Return the MTU to use for this network.

        The value from the spec is only honoured when the libvirt library
        version is greater than 3001001; otherwise (and when the spec has
        no ``mtu`` key) ``'1500'`` is returned.
        """
        # NOTE(review): presumably older libvirt versions can not set the
        # MTU on a network - confirm against the libvirt changelog.
        if self.libvirt_con.getLibVersion() > 3001001:
            return self._spec.get('mtu', '1500')
        return '1500'

    def is_management(self):
        """Return the ``management`` flag of the spec (default False)."""
        return self._spec.get('management', False)

    def add_mappings(self, mappings):
        """
        Add several name to ip mappings and persist the spec once.

        Args:
            mappings (iterable): ``(name, ip, mac)`` tuples; the mac is
                not used here
        """
        for name, ip, _mac in mappings:
            self.add_mapping(name, ip, save=False)
        # Persist once for the whole batch instead of once per entry.
        self.save()

    def add_mapping(self, name, ip, save=True):
        """
        Map ``name`` to ``ip`` in the spec.

        Args:
            name (str): name to register
            ip (str): address the name maps to
            save (bool): persist the spec right away when True
        """
        self._spec['mapping'][name] = ip
        if save:
            self.save()

    def resolve(self, name):
        """
        Return the ip mapped to ``name``.

        Raises:
            KeyError: if ``name`` has no mapping
        """
        return self._spec['mapping'][name]

    def mapping(self):
        """Return the name to ip mapping dict of the spec (not a copy)."""
        return self._spec['mapping']

    def _libvirt_name(self):
        """Return the environment-prefixed name used on the libvirt side."""
        return self._env.prefixed_name(self.name(), max_length=15)

    def _libvirt_xml(self):
        """
        Return the libvirt XML definition of the network.

        Raises:
            NotImplementedError: always, subclasses must override this
        """
        raise NotImplementedError(
            'should be implemented by the specific network class'
        )

    def alive(self):
        """Return True if libvirt lists this network as transient/active."""
        flags = (
            libvirt.VIR_CONNECT_LIST_NETWORKS_TRANSIENT
            | libvirt.VIR_CONNECT_LIST_NETWORKS_ACTIVE
        )
        net_names = [
            net.name() for net in self.libvirt_con.listAllNetworks(flags)
        ]
        return self._libvirt_name() in net_names

    def start(self, attempts=5, timeout=2):
        """
        Start the network, will check if the network is active ``attempts``
        times, waiting ``timeout`` between each attempt.

        Args:
            attempts (int): number of attempts to check the network is active
            timeout (int): timeout for each attempt

        Returns:
            None

        Raises:
            RuntimeError: if network creation failed, or failed to verify it is
            active.
        """
        if self.alive():
            return

        with LogTask('Create network %s' % self.name()):
            # Render the XML once; it is needed both for the creation call
            # and for the error message.
            net_xml = self._libvirt_xml()
            net = self.libvirt_con.networkCreateXML(net_xml)
            if net is None:
                raise RuntimeError(
                    'failed to create network, XML: %s' % net_xml
                )

            for _ in range(attempts):
                if net.isActive():
                    return
                LOGGER.debug(
                    'waiting for network %s to become active',
                    net.name(),
                )
                time.sleep(timeout)

            raise RuntimeError(
                'failed to verify network %s is active' % net.name()
            )

    def stop(self):
        """Destroy the libvirt network if it is currently alive."""
        if self.alive():
            with LogTask('Destroy network %s' % self.name()):
                self.libvirt_con.networkLookupByName(
                    self._libvirt_name(),
                ).destroy()

    def save(self):
        """Dump the spec as JSON to the ``net-<name>`` file of the env."""
        with open(self._env.virt_path('net-%s' % self.name()), 'w') as f:
            utils.json_dump(self._spec, f)

    @property
    def spec(self):
        """Return a deep copy of the spec, safe for the caller to modify."""
        return deepcopy(self._spec)
class NATNetwork(Network):
    """
    Network whose libvirt XML is rendered from ``net_nat_template.xml``,
    optionally extended with MTU, DHCP and DNS sections taken from the spec.
    """

    def _generate_dns_forward(self, forward_ip):
        """
        Build a ``dns`` element that forwards plain names to ``forward_ip``.

        Args:
            forward_ip (str): address set on the ``forwarder`` child

        Returns:
            lxml.etree.Element: the ``dns`` element
        """
        dns = ET.Element('dns', forwardPlainNames='yes')
        dns.append(ET.Element('forwarder', addr=forward_ip))
        return dns

    def _generate_dns_disable(self):
        """Build a ``dns`` element with ``enable='no'``."""
        return ET.Element('dns', enable='no')

    def _generate_main_dns(self, records, subnet, forward_plain='no'):
        """
        Build the ``dns`` element holding one IPv4 and one IPv6 ``host``
        entry per distinct address in ``records``.

        Args:
            records (dict): hostname to IPv4 address
            subnet (str): subnet component used to build the IPv6 prefix
            forward_plain (str): value of the ``forwardPlainNames`` attribute

        Returns:
            lxml.etree.Element: the ``dns`` element
        """
        dns = ET.Element('dns', forwardPlainNames=forward_plain)
        ipv6_prefix = self._ipv6_prefix(subnet=subnet)

        # Group the hostnames by address: each address gets a single
        # ``host`` element listing all of its names.
        reverse_records = defaultdict(list)
        for hostname, ip in six.iteritems(records):
            reverse_records[ip].append(hostname)

        for ip, hostnames in six.iteritems(reverse_records):
            record_ipv4 = ET.Element('host', ip=ip)
            record_ipv6 = ET.Element('host', ip=ipv6_prefix + ip)
            for hostname in sorted(hostnames):
                host = ET.Element('hostname')
                host.text = hostname
                record_ipv4.append(host)
                # An element can only have one parent, so the IPv6 record
                # gets its own copy.
                record_ipv6.append(deepcopy(host))
            dns.append(record_ipv4)
            dns.append(record_ipv6)

        return dns

    def _ipv6_prefix(self, subnet, const='fd8f:1391:3a82:'):
        """Return ``const`` followed by ``subnet`` and a trailing ``::``."""
        return '{0}{1}::'.format(const, subnet)

    def _add_dhcp_elements(self, net_xml, ipv6_prefix):
        """Append the DHCP ranges and static hosts to both ``ip`` nodes."""
        ip_nodes = net_xml.xpath('/network/ip')
        ipv4 = ip_nodes[0]
        ipv6 = ip_nodes[1]

        # All addresses share the gateway's first three octets.
        gw_prefix = self.gw().split('.')[:-1]

        def make_ipv4(last):
            return '.'.join(gw_prefix + [str(last)])

        dhcp = ET.Element('dhcp')
        dhcpv6 = ET.Element('dhcp')
        ipv4.append(dhcp)
        ipv6.append(dhcpv6)

        range_start = make_ipv4(self._spec['dhcp']['start'])
        range_end = make_ipv4(self._spec['dhcp']['end'])
        dhcp.append(ET.Element('range', start=range_start, end=range_end))
        dhcpv6.append(
            ET.Element(
                'range',
                start=ipv6_prefix + range_start,
                end=ipv6_prefix + range_end,
            )
        )

        # Only the first hostname (in sorted order) of each address gets a
        # static host entry.
        mapping = self._spec['mapping']
        seen_ips = set()
        for hostname in sorted(six.iterkeys(mapping)):
            ip4 = mapping[hostname]
            if ip4 in seen_ips:
                continue
            seen_ips.add(ip4)

            mac = utils.ipv4_to_mac(ip4)
            dhcp.append(ET.Element('host', mac=mac, ip=ip4, name=hostname))
            dhcpv6.append(
                ET.Element(
                    'host',
                    id='0:3:0:1:' + mac,
                    ip=ipv6_prefix + ip4,
                    name=hostname
                )
            )

    def _add_dns_elements(self, net_xml, subnet):
        """Append the domain/DNS elements selected by ``self.compat``."""
        if utils.ver_cmp(self.compat, '0.36.11') >= 0:
            if self.is_management():
                net_xml.append(
                    ET.Element(
                        'domain',
                        name=self._spec['dns_domain_name'],
                        localOnly='yes'
                    )
                )
                net_xml.append(
                    self._generate_main_dns(self._spec['dns_records'], subnet)
                )
            elif self.libvirt_con.getLibVersion() < 2002000:
                net_xml.append(
                    self._generate_dns_forward(self._spec['dns_forward'])
                )
            else:
                net_xml.append(self._generate_dns_disable())
            return

        LOGGER.debug(
            'Generating network XML with compatibility prior to %s',
            self.compat
        )
        # Prior to v0.37, DNS records were only the mappings of the
        # management network.
        if self.is_management():
            if 'dns_domain_name' in self._spec:
                net_xml.append(
                    ET.Element(
                        'domain',
                        name=self._spec['dns_domain_name'],
                        localOnly='yes'
                    )
                )
            net_xml.append(
                self._generate_main_dns(
                    self._spec['mapping'], subnet, forward_plain='yes'
                )
            )
        else:
            net_xml.append(
                ET.Element('dns', forwardPlainNames='yes', enable='yes')
            )

    def _libvirt_xml(self):
        """
        Render the libvirt XML of the NAT network.

        The ``net_nat_template.xml`` template is filled with the libvirt
        name, bridge name, gateway and subnet, then extended with an ``mtu``
        element (when it differs from ``'1500'``), DHCP definitions (when
        the spec has a ``dhcp`` key) and DNS definitions.

        Returns:
            the serialized XML, as returned by ``lxml.etree.tostring``
        """
        net_raw_xml = libvirt_utils.get_template('net_nat_template.xml')

        gateway = self.gw()
        # The third octet of the gateway identifies the subnet.
        subnet = gateway.split('.')[2]
        ipv6_prefix = self._ipv6_prefix(subnet=subnet)
        mtu = self.mtu()
        libvirt_name = self._libvirt_name()

        replacements = {
            '@NAME@': libvirt_name,
            '@BR_NAME@': ('%s-nic' % libvirt_name)[:12],
            '@GW_ADDR@': gateway,
            '@SUBNET@': subnet
        }
        for placeholder, value in replacements.items():
            net_raw_xml = net_raw_xml.replace(placeholder, value, 1)

        parser = ET.XMLParser(remove_blank_text=True)
        net_xml = ET.fromstring(net_raw_xml, parser)

        if mtu != '1500':
            net_xml.append(ET.Element('mtu', size=str(mtu)))

        if 'dhcp' in self._spec:
            self._add_dhcp_elements(net_xml, ipv6_prefix)

        self._add_dns_elements(net_xml, subnet)

        # Lazy logging args: the pretty-printed XML is only formatted into
        # the message when debug logging is enabled.
        LOGGER.debug(
            'Generated Network XML\n %s',
            ET.tostring(net_xml, pretty_print=True)
        )
        return ET.tostring(net_xml)
class BridgeNetwork(Network):
    """
    Network backed by a bridge that is created and destroyed through
    ``brctl``; its libvirt XML is rendered from ``net_br_template.xml``.
    """

    def _libvirt_xml(self):
        """
        Render the libvirt XML of the bridge network.

        Both the network name and the bridge name placeholders are replaced
        with the libvirt name of this network.

        Returns:
            str: the filled-in template
        """
        net_raw_xml = libvirt_utils.get_template('net_br_template.xml')

        libvirt_name = self._libvirt_name()
        replacements = {
            '@NAME@': libvirt_name,
            '@BR_NAME@': libvirt_name,
        }
        for placeholder, value in replacements.items():
            net_raw_xml = net_raw_xml.replace(placeholder, value, 1)

        return net_raw_xml

    def start(self, attempts=5, timeout=2):
        """
        Create the bridge and start the libvirt network on top of it.

        Does nothing when the bridge already exists. If starting the libvirt
        network fails, the freshly created bridge is destroyed and the
        original error is re-raised.

        Args:
            attempts (int): forwarded to :meth:`Network.start`
            timeout (int): forwarded to :meth:`Network.start`

        Raises:
            RuntimeError: if the libvirt network could not be created or
                verified as active
        """
        bridge = self._libvirt_name()
        if brctl.exists(bridge):
            return

        brctl.create(bridge)
        try:
            super().start(attempts=attempts, timeout=timeout)
        except BaseException:
            # Do not leave a dangling bridge behind, but let the caller see
            # the failure instead of silently swallowing it.
            brctl.destroy(bridge)
            raise

    def stop(self):
        """Stop the libvirt network and destroy the bridge if it exists."""
        super().stop()
        bridge = self._libvirt_name()
        if brctl.exists(bridge):
            brctl.destroy(bridge)