Add a plugin interface for drivers

This change adds a plugin interface so that driver can be loaded dynamically.
Instead of importing each driver in the launcher, provider_manager and config,
the Drivers class discovers and loads driver from the driver directory.

This change also adds a reset() method to the driver Config interface to
reset the os_client_config reference when reloading the OpenStack driver.

Change-Id: Ia347aa2501de0e05b2a7dd014c4daf1b0a4e0fb5
This commit is contained in:
Tristan Cacqueray 2017-12-01 07:43:58 +00:00
parent 35defe52f3
commit d0a67878a3
14 changed files with 307 additions and 42 deletions

View File

@ -21,8 +21,7 @@ import yaml
from nodepool import zk
from nodepool.driver import ConfigValue
from nodepool.driver.fake.config import FakeProviderConfig
from nodepool.driver.openstack.config import OpenStackProviderConfig
from nodepool.driver import Drivers
class Config(ConfigValue):
@ -59,10 +58,8 @@ def get_provider_config(provider):
# Ensure legacy configuration still works when using fake cloud
if provider.get('name', '').startswith('fake'):
provider['driver'] = 'fake'
if provider['driver'] == 'fake':
return FakeProviderConfig(provider)
elif provider['driver'] == 'openstack':
return OpenStackProviderConfig(provider)
driver = Drivers.get(provider['driver'])
return driver['config'](provider)
def openConfig(path):
@ -90,8 +87,9 @@ def openConfig(path):
def loadConfig(config_path):
config = openConfig(config_path)
# Reset the shared os_client_config instance
OpenStackProviderConfig.os_client_config = None
# Call driver config reset now to clean global hooks like os_client_config
for driver in Drivers.drivers.values():
driver["config"].reset()
newconfig = Config()
newconfig.db = None

View File

@ -16,6 +16,10 @@
# limitations under the License.
import abc
import inspect
import importlib
import logging
import os
import six
@ -23,6 +27,79 @@ from nodepool import zk
from nodepool import exceptions
class Drivers:
"""The Drivers plugin interface"""
log = logging.getLogger("nodepool.driver.Drivers")
drivers = {}
drivers_paths = None
@staticmethod
def _load_class(driver_name, path, parent_class):
"""Return a driver class that implements the parent_class"""
spec = importlib.util.spec_from_file_location(driver_name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
obj = inspect.getmembers(
module,
lambda x: inspect.isclass(x) and issubclass(x, parent_class) and
x.__module__ == driver_name)
error = None
if len(obj) > 1:
error = "multiple %s implementation" % parent_class
if not obj:
error = "no %s implementation found" % parent_class
if error:
Drivers.log.error("%s: %s" % (path, error))
return False
return obj[0][1]
@staticmethod
def load(drivers_paths=[]):
"""Load drivers"""
if drivers_paths == Drivers.drivers_paths:
# Already loaded
return
Drivers.drivers.clear()
for drivers_path in drivers_paths + [os.path.dirname(__file__)]:
drivers = os.listdir(drivers_path)
for driver in drivers:
driver_path = os.path.join(drivers_path, driver)
if driver in Drivers.drivers:
Drivers.log.warning("%s: duplicate driver" % driver_path)
continue
if not os.path.isdir(driver_path) or \
"__init__.py" not in os.listdir(driver_path):
continue
Drivers.log.debug("%s: loading driver" % driver_path)
driver_obj = {}
for name, parent_class in (
("config", ProviderConfig),
("handler", NodeRequestHandler),
("provider", Provider),
):
driver_obj[name] = Drivers._load_class(
driver, os.path.join(driver_path, "%s.py" % name),
parent_class)
if not driver_obj[name]:
break
if not driver_obj[name]:
Drivers.log.error("%s: skipping incorrect driver" %
driver_path)
continue
Drivers.drivers[driver] = driver_obj
Drivers.drivers_paths = drivers_paths
@staticmethod
def get(name):
if not Drivers.drivers:
Drivers.load()
try:
return Drivers.drivers[name]
except KeyError:
raise RuntimeError("%s: unknown driver" % name)
@six.add_metaclass(abc.ABCMeta)
class Provider(object):
"""The Provider interface
@ -353,6 +430,10 @@ class ProviderConfig(ConfigValue):
def __eq__(self, other):
pass
@abc.abstractmethod
def reset():
pass
@abc.abstractmethod
def load(self, newconfig):
pass

View File

@ -66,13 +66,13 @@ class Dummy(object):
setattr(self, key, value)
def get_fake_quota():
return 100, 20, 1000000
class FakeOpenStackCloud(object):
log = logging.getLogger("nodepool.FakeOpenStackCloud")
@staticmethod
def _get_quota():
return 100, 20, 1000000
def __init__(self, images=None, networks=None):
self.pause_creates = False
self._image_list = images
@ -100,7 +100,8 @@ class FakeOpenStackCloud(object):
vcpus=4),
]
self._server_list = []
self.max_cores, self.max_instances, self.max_ram = get_fake_quota()
self.max_cores, self.max_instances, self.max_ram = FakeOpenStackCloud.\
_get_quota()
def _get(self, name_or_id, instance_list):
self.log.debug("Get %s in %s" % (name_or_id, repr(instance_list)))
@ -287,9 +288,11 @@ class FakeUploadFailCloud(FakeOpenStackCloud):
class FakeProvider(OpenStackProvider):
fake_cloud = FakeOpenStackCloud
def __init__(self, provider, use_taskmanager):
self.createServer_fails = 0
self.__client = FakeOpenStackCloud()
self.__client = FakeProvider.fake_cloud()
super(FakeProvider, self).__init__(provider, use_taskmanager)
def _getClient(self):

View File

@ -96,6 +96,10 @@ class OpenStackProviderConfig(ProviderConfig):
cloud_kwargs[arg] = self.provider[arg]
return cloud_kwargs
@staticmethod
def reset():
OpenStackProviderConfig.os_client_config = None
def load(self, config):
if OpenStackProviderConfig.os_client_config is None:
OpenStackProviderConfig.os_client_config = \

View File

@ -28,8 +28,7 @@ from nodepool import provider_manager
from nodepool import stats
from nodepool import config as nodepool_config
from nodepool import zk
from nodepool.driver.fake.handler import FakeNodeRequestHandler
from nodepool.driver.openstack.handler import OpenStackNodeRequestHandler
from nodepool.driver import Drivers
MINS = 60
@ -146,12 +145,8 @@ class PoolWorker(threading.Thread):
# ---------------------------------------------------------------
def _get_node_request_handler(self, provider, request):
if provider.driver.name == 'fake':
return FakeNodeRequestHandler(self, request)
elif provider.driver.name == 'openstack':
return OpenStackNodeRequestHandler(self, request)
else:
raise RuntimeError("Unknown provider driver %s" % provider.driver)
driver = Drivers.get(provider.driver.name)
return driver['handler'](self, request)
def _assignHandlers(self):
'''

View File

@ -18,17 +18,12 @@
import logging
from nodepool.driver.fake.provider import FakeProvider
from nodepool.driver.openstack.provider import OpenStackProvider
from nodepool.driver import Drivers
def get_provider(provider, use_taskmanager):
if provider.driver.name == 'fake':
return FakeProvider(provider, use_taskmanager)
elif provider.driver.name == 'openstack':
return OpenStackProvider(provider, use_taskmanager)
else:
raise RuntimeError("Unknown provider driver %s" % provider.driver)
driver = Drivers.get(provider.driver.name)
return driver['provider'](provider, use_taskmanager)
class ProviderManager(object):

View File

View File

@ -0,0 +1,46 @@
# Copyright 2017 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import voluptuous as v
from nodepool.driver import ConfigValue
from nodepool.driver import ProviderConfig
class TestPool(ConfigValue):
pass
class TestConfig(ProviderConfig):
def __eq__(self, other):
return self.name == other.name
@staticmethod
def reset():
pass
def load(self, newconfig):
self.pools = {}
for pool in self.provider.get('pools', []):
testpool = TestPool()
testpool.name = pool['name']
testpool.provider = self
for label in pool['labels']:
newconfig.labels[label].pools.append(testpool)
self.pools[pool['name']] = testpool
def get_schema(self):
pool = {'name': str,
'labels': [str]}
return v.Schema({'pools': [pool]})

View File

@ -0,0 +1,34 @@
# Copyright 2017 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from nodepool import zk
from nodepool.driver import NodeRequestHandler
class TestHandler(NodeRequestHandler):
log = logging.getLogger("nodepool.driver.test.TestHandler")
def run_handler(self):
self._setFromPoolWorker()
node = zk.Node()
node.state = zk.READY
node.external_id = "test-%s" % self.request.id
node.provider = self.provider.name
node.launcher = self.launcher_id
node.allocated_to = self.request.id
node.type = self.request.node_types[0]
self.nodeset.append(node)
self.zk.storeNode(node)

View File

@ -0,0 +1,46 @@
# Copyright (C) 2011-2013 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.
from nodepool.driver import Provider
class TestProvider(Provider):
def __init__(self, provider, *args):
self.provider = provider
def start(self):
pass
def stop(self):
pass
def join(self):
pass
def labelReady(self, name):
return True
def cleanupNode(self, node_id):
pass
def waitForNodeCleanup(self, node_id):
pass
def cleanupLeakedResources(self):
pass
def listNodes(self):
return []

View File

@ -0,0 +1,16 @@
zookeeper-servers:
- host: {zookeeper_host}
port: {zookeeper_port}
chroot: {zookeeper_chroot}
labels:
- name: test-label
min-ready: 1
providers:
- name: test-provider
driver: test
pools:
- name: test-pool
labels:
- test-label

View File

@ -18,6 +18,7 @@ import uuid
import fixtures
from nodepool import builder, exceptions, tests
from nodepool.driver import Drivers
from nodepool.driver.fake import provider as fakeprovider
from nodepool import zk
@ -119,8 +120,8 @@ class TestNodePoolBuilder(tests.DBTestCase):
def get_fake_client(*args, **kwargs):
return fake_client
self.useFixture(fixtures.MonkeyPatch(
'nodepool.driver.fake.provider.FakeProvider._getClient',
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'], '_getClient',
get_fake_client))
configfile = self.setup_config('node.yaml')

View File

@ -0,0 +1,39 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from nodepool import config as nodepool_config
from nodepool import tests
from nodepool.driver import Drivers
class TestDrivers(tests.DBTestCase):
def setup_config(self, filename):
drivers_dir = os.path.join(
os.path.dirname(__file__), 'fixtures', 'drivers')
Drivers.load([drivers_dir])
return super().setup_config(filename)
def test_external_driver_config(self):
configfile = self.setup_config('external_driver.yaml')
nodepool_config.loadConfig(configfile)
self.assertIn("config", Drivers.get("test"))
def test_external_driver_handler(self):
configfile = self.setup_config('external_driver.yaml')
pool = self.useNodepool(configfile, watermark_sleep=1)
pool.start()
nodes = self.waitForNodes('test-label')
self.assertEqual(len(nodes), 1)

View File

@ -17,10 +17,10 @@ import logging
import math
import time
import fixtures
import unittest.mock as mock
from nodepool import tests
from nodepool import zk
from nodepool.driver import Drivers
import nodepool.launcher
@ -111,8 +111,7 @@ class TestLauncher(tests.DBTestCase):
self.assertEqual(nodes[2].type, 'fake-label4')
self.assertEqual(nodes[3].type, 'fake-label2')
@mock.patch('nodepool.driver.fake.provider.get_fake_quota')
def _test_node_assignment_at_quota(self, mock_quota,
def _test_node_assignment_at_quota(self,
config='node_quota.yaml',
max_cores=100,
max_instances=20,
@ -124,7 +123,12 @@ class TestLauncher(tests.DBTestCase):
'''
# patch the cloud with requested quota
mock_quota.return_value = (max_cores, max_instances, max_ram)
def fake_get_quota():
return (max_cores, max_instances, max_ram)
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
fake_get_quota
))
configfile = self.setup_config(config)
self.useBuilder(configfile)
@ -258,9 +262,7 @@ class TestLauncher(tests.DBTestCase):
max_instances=math.inf,
max_ram=2 * 8192)
@mock.patch('nodepool.driver.fake.provider.get_fake_quota')
def test_over_quota(self, mock_quota,
config='node_quota_cloud.yaml'):
def test_over_quota(self, config='node_quota_cloud.yaml'):
'''
This tests what happens when a cloud unexpectedly returns an
over-quota error.
@ -272,7 +274,12 @@ class TestLauncher(tests.DBTestCase):
max_ram = math.inf
# patch the cloud with requested quota
mock_quota.return_value = (max_cores, max_instances, max_ram)
def fake_get_quota():
return (max_cores, max_instances, max_ram)
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
fake_get_quota
))
configfile = self.setup_config(config)
self.useBuilder(configfile)
@ -589,8 +596,8 @@ class TestLauncher(tests.DBTestCase):
def fail_delete(self, name):
raise RuntimeError('Fake Error')
fake_delete = 'nodepool.driver.fake.provider.FakeProvider.deleteServer'
self.useFixture(fixtures.MonkeyPatch(fake_delete, fail_delete))
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'], 'deleteServer', fail_delete))
configfile = self.setup_config('node.yaml')
pool = self.useNodepool(configfile, watermark_sleep=1)