Merge "Load drivers dynamically"

This commit is contained in:
Jenkins 2016-10-10 10:14:25 +00:00 committed by Gerrit Code Review
commit 1019286b71
11 changed files with 172 additions and 18 deletions

@ -18,10 +18,7 @@ import yaml
from os_faults.api import error
from os_faults.api import human
from os_faults.drivers import devstack
from os_faults.drivers import fuel
from os_faults.drivers import ipmi
from os_faults.drivers import libvirt_driver
from os_faults import registry
__version__ = pbr.version.VersionInfo('os_faults').version_string()
@ -54,26 +51,31 @@ def _read_config(config_filename):
raise error.OSFError(msg)
def _init_driver(params):
all_drivers = registry.get_drivers()
name = params.get('driver')
if not name:
return None
if name not in all_drivers:
raise error.OSFError('Driver %s is not found' % name)
return all_drivers[name](params)
def connect(cloud_config=None, config_filename=None):
if not cloud_config:
cloud_config = _read_config(config_filename)
cloud_management = None
cloud_management_params = cloud_config.get('cloud_management') or {}
cloud_management = _init_driver(cloud_management_params)
if cloud_management_params.get('driver') == 'fuel':
cloud_management = fuel.FuelManagement(cloud_management_params)
elif cloud_management_params.get('driver') == 'devstack':
cloud_management = devstack.DevStackManagement(cloud_management_params)
if not cloud_management:
raise error.OSFError('Cloud management driver name is not specified')
power_management = None
power_management_params = cloud_config.get('power_management') or {}
if power_management_params.get('driver') == 'libvirt':
power_management = libvirt_driver.LibvirtDriver(
power_management_params)
elif power_management_params.get('driver') == 'ipmi':
power_management = ipmi.IPMIDriver(power_management_params)
power_management = _init_driver(power_management_params)
cloud_management.set_power_management(power_management)

@ -0,0 +1,20 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class BaseDriver(object):
NAME = 'base'
@classmethod
def get_driver_name(cls):
return cls.NAME

@ -15,9 +15,11 @@ import abc
import six
from os_faults.api import base_driver
@six.add_metaclass(abc.ABCMeta)
class CloudManagement(object):
class CloudManagement(base_driver.BaseDriver):
def __init__(self):
self.power_management = None

@ -15,9 +15,11 @@ import abc
import six
from os_faults.api import base_driver
@six.add_metaclass(abc.ABCMeta)
class PowerManagement(object):
class PowerManagement(base_driver.BaseDriver):
@abc.abstractmethod
def poweroff(self, hosts):

@ -96,6 +96,8 @@ SERVICE_NAME_TO_CLASS = {
class DevStackManagement(cloud_management.CloudManagement):
NAME = 'devstack'
def __init__(self, cloud_management_params):
super(DevStackManagement, self).__init__()

@ -442,6 +442,8 @@ SERVICE_NAME_TO_CLASS = {
class FuelManagement(cloud_management.CloudManagement):
NAME = 'fuel'
def __init__(self, cloud_management_params):
super(FuelManagement, self).__init__()

@ -22,6 +22,8 @@ from os_faults import utils
class IPMIDriver(power_management.PowerManagement):
NAME = 'ipmi'
def __init__(self, params):
self.mac_to_bmc = params['mac_to_bmc']

@ -21,6 +21,8 @@ from os_faults import utils
class LibvirtDriver(power_management.PowerManagement):
NAME = 'libvirt'
def __init__(self, params):
self.connection_uri = params['connection_uri']
self._cached_conn = None

69
os_faults/registry.py Normal file

@ -0,0 +1,69 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
from oslo_utils import importutils
from os_faults.api import base_driver
from os_faults import drivers
DRIVERS = {}
def _import_modules_from_package():
drivers_folder = os.path.dirname(drivers.__file__)
library_root = os.path.normpath(os.path.join(
os.path.join(drivers_folder, os.pardir), os.pardir))
for root, dirs, files in os.walk(drivers_folder):
for filename in files:
if filename.startswith('__') or not filename.endswith('.py'):
continue
relative_path = os.path.relpath(os.path.join(root, filename),
library_root)
name = os.path.splitext(relative_path)[0] # remove extension
module_name = '.'.join(name.split(os.sep)) # convert / to .
if module_name not in sys.modules:
module = importutils.import_module(module_name)
sys.modules[module_name] = module
else:
module = sys.modules[module_name]
yield module
def _list_drivers():
modules = _import_modules_from_package()
for module in modules:
class_info_list = inspect.getmembers(module, inspect.isclass)
for class_info in class_info_list:
klazz = class_info[1]
if issubclass(klazz, base_driver.BaseDriver):
yield class_info[1]
def get_drivers():
global DRIVERS
if not DRIVERS:
DRIVERS = dict((k.get_driver_name(), k) for k in _list_drivers())
return DRIVERS

@ -79,6 +79,18 @@ class OSFaultsTestCase(test.TestCase):
self.assertIsInstance(destructor, fuel.FuelManagement)
self.assertIsInstance(destructor.power_management, ipmi.IPMIDriver)
def test_connect_driver_not_found(self):
cloud_config = {
'cloud_management': {
'driver': 'non-existing',
}
}
self.assertRaises(error.OSFError, os_faults.connect, cloud_config)
def test_connect_driver_not_specified(self):
cloud_config = {}
self.assertRaises(error.OSFError, os_faults.connect, cloud_config)
@mock.patch('os.path.exists', return_value=True)
def test_connect_with_config_file(self, mock_os_path_exists):
mock_os_faults_open = mock.mock_open(

@ -0,0 +1,39 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import sys
import mock
from os_faults.api import base_driver
from os_faults import drivers
from os_faults import registry
from os_faults.tests.unit import test
class TestDriver(base_driver.BaseDriver):
NAME = 'test'
class RegistryTestCase(test.TestCase):
@mock.patch('oslo_utils.importutils.import_module')
@mock.patch('os.walk')
def test_get_drivers(self, mock_os_walk, mock_import_module):
drivers_folder = os.path.dirname(drivers.__file__)
mock_os_walk.return_value = [(drivers_folder, [], ['test_driver.py'])]
mock_import_module.return_value = sys.modules[__name__]
registry.DRIVERS.clear() # reset global drivers list
self.assertEqual({'test': TestDriver}, registry.get_drivers())