Add a plugin interface for drivers

This change adds a plugin interface so that driver can be loaded dynamically.
Instead of importing each driver in the launcher, provider_manager and config,
the Drivers class discovers and loads driver from the driver directory.

This change also adds a reset() method to the driver Config interface to
reset the os_client_config reference when reloading the OpenStack driver.

Change-Id: Ia347aa2501de0e05b2a7dd014c4daf1b0a4e0fb5
This commit is contained in:
Tristan Cacqueray 2017-12-01 07:43:58 +00:00
parent 35defe52f3
commit d0a67878a3
14 changed files with 307 additions and 42 deletions

View File

@ -21,8 +21,7 @@ import yaml
from nodepool import zk from nodepool import zk
from nodepool.driver import ConfigValue from nodepool.driver import ConfigValue
from nodepool.driver.fake.config import FakeProviderConfig from nodepool.driver import Drivers
from nodepool.driver.openstack.config import OpenStackProviderConfig
class Config(ConfigValue): class Config(ConfigValue):
@ -59,10 +58,8 @@ def get_provider_config(provider):
# Ensure legacy configuration still works when using fake cloud # Ensure legacy configuration still works when using fake cloud
if provider.get('name', '').startswith('fake'): if provider.get('name', '').startswith('fake'):
provider['driver'] = 'fake' provider['driver'] = 'fake'
if provider['driver'] == 'fake': driver = Drivers.get(provider['driver'])
return FakeProviderConfig(provider) return driver['config'](provider)
elif provider['driver'] == 'openstack':
return OpenStackProviderConfig(provider)
def openConfig(path): def openConfig(path):
@ -90,8 +87,9 @@ def openConfig(path):
def loadConfig(config_path): def loadConfig(config_path):
config = openConfig(config_path) config = openConfig(config_path)
# Reset the shared os_client_config instance # Call driver config reset now to clean global hooks like os_client_config
OpenStackProviderConfig.os_client_config = None for driver in Drivers.drivers.values():
driver["config"].reset()
newconfig = Config() newconfig = Config()
newconfig.db = None newconfig.db = None

View File

@ -16,6 +16,10 @@
# limitations under the License. # limitations under the License.
import abc import abc
import inspect
import importlib
import logging
import os
import six import six
@ -23,6 +27,79 @@ from nodepool import zk
from nodepool import exceptions from nodepool import exceptions
class Drivers:
"""The Drivers plugin interface"""
log = logging.getLogger("nodepool.driver.Drivers")
drivers = {}
drivers_paths = None
@staticmethod
def _load_class(driver_name, path, parent_class):
"""Return a driver class that implements the parent_class"""
spec = importlib.util.spec_from_file_location(driver_name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
obj = inspect.getmembers(
module,
lambda x: inspect.isclass(x) and issubclass(x, parent_class) and
x.__module__ == driver_name)
error = None
if len(obj) > 1:
error = "multiple %s implementation" % parent_class
if not obj:
error = "no %s implementation found" % parent_class
if error:
Drivers.log.error("%s: %s" % (path, error))
return False
return obj[0][1]
@staticmethod
def load(drivers_paths=[]):
"""Load drivers"""
if drivers_paths == Drivers.drivers_paths:
# Already loaded
return
Drivers.drivers.clear()
for drivers_path in drivers_paths + [os.path.dirname(__file__)]:
drivers = os.listdir(drivers_path)
for driver in drivers:
driver_path = os.path.join(drivers_path, driver)
if driver in Drivers.drivers:
Drivers.log.warning("%s: duplicate driver" % driver_path)
continue
if not os.path.isdir(driver_path) or \
"__init__.py" not in os.listdir(driver_path):
continue
Drivers.log.debug("%s: loading driver" % driver_path)
driver_obj = {}
for name, parent_class in (
("config", ProviderConfig),
("handler", NodeRequestHandler),
("provider", Provider),
):
driver_obj[name] = Drivers._load_class(
driver, os.path.join(driver_path, "%s.py" % name),
parent_class)
if not driver_obj[name]:
break
if not driver_obj[name]:
Drivers.log.error("%s: skipping incorrect driver" %
driver_path)
continue
Drivers.drivers[driver] = driver_obj
Drivers.drivers_paths = drivers_paths
@staticmethod
def get(name):
if not Drivers.drivers:
Drivers.load()
try:
return Drivers.drivers[name]
except KeyError:
raise RuntimeError("%s: unknown driver" % name)
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class Provider(object): class Provider(object):
"""The Provider interface """The Provider interface
@ -353,6 +430,10 @@ class ProviderConfig(ConfigValue):
def __eq__(self, other): def __eq__(self, other):
pass pass
@abc.abstractmethod
def reset():
pass
@abc.abstractmethod @abc.abstractmethod
def load(self, newconfig): def load(self, newconfig):
pass pass

View File

@ -66,13 +66,13 @@ class Dummy(object):
setattr(self, key, value) setattr(self, key, value)
def get_fake_quota():
return 100, 20, 1000000
class FakeOpenStackCloud(object): class FakeOpenStackCloud(object):
log = logging.getLogger("nodepool.FakeOpenStackCloud") log = logging.getLogger("nodepool.FakeOpenStackCloud")
@staticmethod
def _get_quota():
return 100, 20, 1000000
def __init__(self, images=None, networks=None): def __init__(self, images=None, networks=None):
self.pause_creates = False self.pause_creates = False
self._image_list = images self._image_list = images
@ -100,7 +100,8 @@ class FakeOpenStackCloud(object):
vcpus=4), vcpus=4),
] ]
self._server_list = [] self._server_list = []
self.max_cores, self.max_instances, self.max_ram = get_fake_quota() self.max_cores, self.max_instances, self.max_ram = FakeOpenStackCloud.\
_get_quota()
def _get(self, name_or_id, instance_list): def _get(self, name_or_id, instance_list):
self.log.debug("Get %s in %s" % (name_or_id, repr(instance_list))) self.log.debug("Get %s in %s" % (name_or_id, repr(instance_list)))
@ -287,9 +288,11 @@ class FakeUploadFailCloud(FakeOpenStackCloud):
class FakeProvider(OpenStackProvider): class FakeProvider(OpenStackProvider):
fake_cloud = FakeOpenStackCloud
def __init__(self, provider, use_taskmanager): def __init__(self, provider, use_taskmanager):
self.createServer_fails = 0 self.createServer_fails = 0
self.__client = FakeOpenStackCloud() self.__client = FakeProvider.fake_cloud()
super(FakeProvider, self).__init__(provider, use_taskmanager) super(FakeProvider, self).__init__(provider, use_taskmanager)
def _getClient(self): def _getClient(self):

View File

@ -96,6 +96,10 @@ class OpenStackProviderConfig(ProviderConfig):
cloud_kwargs[arg] = self.provider[arg] cloud_kwargs[arg] = self.provider[arg]
return cloud_kwargs return cloud_kwargs
@staticmethod
def reset():
OpenStackProviderConfig.os_client_config = None
def load(self, config): def load(self, config):
if OpenStackProviderConfig.os_client_config is None: if OpenStackProviderConfig.os_client_config is None:
OpenStackProviderConfig.os_client_config = \ OpenStackProviderConfig.os_client_config = \

View File

@ -28,8 +28,7 @@ from nodepool import provider_manager
from nodepool import stats from nodepool import stats
from nodepool import config as nodepool_config from nodepool import config as nodepool_config
from nodepool import zk from nodepool import zk
from nodepool.driver.fake.handler import FakeNodeRequestHandler from nodepool.driver import Drivers
from nodepool.driver.openstack.handler import OpenStackNodeRequestHandler
MINS = 60 MINS = 60
@ -146,12 +145,8 @@ class PoolWorker(threading.Thread):
# --------------------------------------------------------------- # ---------------------------------------------------------------
def _get_node_request_handler(self, provider, request): def _get_node_request_handler(self, provider, request):
if provider.driver.name == 'fake': driver = Drivers.get(provider.driver.name)
return FakeNodeRequestHandler(self, request) return driver['handler'](self, request)
elif provider.driver.name == 'openstack':
return OpenStackNodeRequestHandler(self, request)
else:
raise RuntimeError("Unknown provider driver %s" % provider.driver)
def _assignHandlers(self): def _assignHandlers(self):
''' '''

View File

@ -18,17 +18,12 @@
import logging import logging
from nodepool.driver.fake.provider import FakeProvider from nodepool.driver import Drivers
from nodepool.driver.openstack.provider import OpenStackProvider
def get_provider(provider, use_taskmanager): def get_provider(provider, use_taskmanager):
if provider.driver.name == 'fake': driver = Drivers.get(provider.driver.name)
return FakeProvider(provider, use_taskmanager) return driver['provider'](provider, use_taskmanager)
elif provider.driver.name == 'openstack':
return OpenStackProvider(provider, use_taskmanager)
else:
raise RuntimeError("Unknown provider driver %s" % provider.driver)
class ProviderManager(object): class ProviderManager(object):

View File

View File

@ -0,0 +1,46 @@
# Copyright 2017 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import voluptuous as v
from nodepool.driver import ConfigValue
from nodepool.driver import ProviderConfig
class TestPool(ConfigValue):
pass
class TestConfig(ProviderConfig):
def __eq__(self, other):
return self.name == other.name
@staticmethod
def reset():
pass
def load(self, newconfig):
self.pools = {}
for pool in self.provider.get('pools', []):
testpool = TestPool()
testpool.name = pool['name']
testpool.provider = self
for label in pool['labels']:
newconfig.labels[label].pools.append(testpool)
self.pools[pool['name']] = testpool
def get_schema(self):
pool = {'name': str,
'labels': [str]}
return v.Schema({'pools': [pool]})

View File

@ -0,0 +1,34 @@
# Copyright 2017 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from nodepool import zk
from nodepool.driver import NodeRequestHandler
class TestHandler(NodeRequestHandler):
log = logging.getLogger("nodepool.driver.test.TestHandler")
def run_handler(self):
self._setFromPoolWorker()
node = zk.Node()
node.state = zk.READY
node.external_id = "test-%s" % self.request.id
node.provider = self.provider.name
node.launcher = self.launcher_id
node.allocated_to = self.request.id
node.type = self.request.node_types[0]
self.nodeset.append(node)
self.zk.storeNode(node)

View File

@ -0,0 +1,46 @@
# Copyright (C) 2011-2013 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.
from nodepool.driver import Provider
class TestProvider(Provider):
def __init__(self, provider, *args):
self.provider = provider
def start(self):
pass
def stop(self):
pass
def join(self):
pass
def labelReady(self, name):
return True
def cleanupNode(self, node_id):
pass
def waitForNodeCleanup(self, node_id):
pass
def cleanupLeakedResources(self):
pass
def listNodes(self):
return []

View File

@ -0,0 +1,16 @@
zookeeper-servers:
- host: {zookeeper_host}
port: {zookeeper_port}
chroot: {zookeeper_chroot}
labels:
- name: test-label
min-ready: 1
providers:
- name: test-provider
driver: test
pools:
- name: test-pool
labels:
- test-label

View File

@ -18,6 +18,7 @@ import uuid
import fixtures import fixtures
from nodepool import builder, exceptions, tests from nodepool import builder, exceptions, tests
from nodepool.driver import Drivers
from nodepool.driver.fake import provider as fakeprovider from nodepool.driver.fake import provider as fakeprovider
from nodepool import zk from nodepool import zk
@ -119,8 +120,8 @@ class TestNodePoolBuilder(tests.DBTestCase):
def get_fake_client(*args, **kwargs): def get_fake_client(*args, **kwargs):
return fake_client return fake_client
self.useFixture(fixtures.MonkeyPatch( self.useFixture(fixtures.MockPatchObject(
'nodepool.driver.fake.provider.FakeProvider._getClient', Drivers.get('fake')['provider'], '_getClient',
get_fake_client)) get_fake_client))
configfile = self.setup_config('node.yaml') configfile = self.setup_config('node.yaml')

View File

@ -0,0 +1,39 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from nodepool import config as nodepool_config
from nodepool import tests
from nodepool.driver import Drivers
class TestDrivers(tests.DBTestCase):
def setup_config(self, filename):
drivers_dir = os.path.join(
os.path.dirname(__file__), 'fixtures', 'drivers')
Drivers.load([drivers_dir])
return super().setup_config(filename)
def test_external_driver_config(self):
configfile = self.setup_config('external_driver.yaml')
nodepool_config.loadConfig(configfile)
self.assertIn("config", Drivers.get("test"))
def test_external_driver_handler(self):
configfile = self.setup_config('external_driver.yaml')
pool = self.useNodepool(configfile, watermark_sleep=1)
pool.start()
nodes = self.waitForNodes('test-label')
self.assertEqual(len(nodes), 1)

View File

@ -17,10 +17,10 @@ import logging
import math import math
import time import time
import fixtures import fixtures
import unittest.mock as mock
from nodepool import tests from nodepool import tests
from nodepool import zk from nodepool import zk
from nodepool.driver import Drivers
import nodepool.launcher import nodepool.launcher
@ -111,8 +111,7 @@ class TestLauncher(tests.DBTestCase):
self.assertEqual(nodes[2].type, 'fake-label4') self.assertEqual(nodes[2].type, 'fake-label4')
self.assertEqual(nodes[3].type, 'fake-label2') self.assertEqual(nodes[3].type, 'fake-label2')
@mock.patch('nodepool.driver.fake.provider.get_fake_quota') def _test_node_assignment_at_quota(self,
def _test_node_assignment_at_quota(self, mock_quota,
config='node_quota.yaml', config='node_quota.yaml',
max_cores=100, max_cores=100,
max_instances=20, max_instances=20,
@ -124,7 +123,12 @@ class TestLauncher(tests.DBTestCase):
''' '''
# patch the cloud with requested quota # patch the cloud with requested quota
mock_quota.return_value = (max_cores, max_instances, max_ram) def fake_get_quota():
return (max_cores, max_instances, max_ram)
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
fake_get_quota
))
configfile = self.setup_config(config) configfile = self.setup_config(config)
self.useBuilder(configfile) self.useBuilder(configfile)
@ -258,9 +262,7 @@ class TestLauncher(tests.DBTestCase):
max_instances=math.inf, max_instances=math.inf,
max_ram=2 * 8192) max_ram=2 * 8192)
@mock.patch('nodepool.driver.fake.provider.get_fake_quota') def test_over_quota(self, config='node_quota_cloud.yaml'):
def test_over_quota(self, mock_quota,
config='node_quota_cloud.yaml'):
''' '''
This tests what happens when a cloud unexpectedly returns an This tests what happens when a cloud unexpectedly returns an
over-quota error. over-quota error.
@ -272,7 +274,12 @@ class TestLauncher(tests.DBTestCase):
max_ram = math.inf max_ram = math.inf
# patch the cloud with requested quota # patch the cloud with requested quota
mock_quota.return_value = (max_cores, max_instances, max_ram) def fake_get_quota():
return (max_cores, max_instances, max_ram)
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
fake_get_quota
))
configfile = self.setup_config(config) configfile = self.setup_config(config)
self.useBuilder(configfile) self.useBuilder(configfile)
@ -589,8 +596,8 @@ class TestLauncher(tests.DBTestCase):
def fail_delete(self, name): def fail_delete(self, name):
raise RuntimeError('Fake Error') raise RuntimeError('Fake Error')
fake_delete = 'nodepool.driver.fake.provider.FakeProvider.deleteServer' self.useFixture(fixtures.MockPatchObject(
self.useFixture(fixtures.MonkeyPatch(fake_delete, fail_delete)) Drivers.get('fake')['provider'], 'deleteServer', fail_delete))
configfile = self.setup_config('node.yaml') configfile = self.setup_config('node.yaml')
pool = self.useNodepool(configfile, watermark_sleep=1) pool = self.useNodepool(configfile, watermark_sleep=1)