Add a plugin interface for drivers
This change adds a plugin interface so that driver can be loaded dynamically. Instead of importing each driver in the launcher, provider_manager and config, the Drivers class discovers and loads driver from the driver directory. This change also adds a reset() method to the driver Config interface to reset the os_client_config reference when reloading the OpenStack driver. Change-Id: Ia347aa2501de0e05b2a7dd014c4daf1b0a4e0fb5
This commit is contained in:
parent
35defe52f3
commit
d0a67878a3
@ -21,8 +21,7 @@ import yaml
|
||||
|
||||
from nodepool import zk
|
||||
from nodepool.driver import ConfigValue
|
||||
from nodepool.driver.fake.config import FakeProviderConfig
|
||||
from nodepool.driver.openstack.config import OpenStackProviderConfig
|
||||
from nodepool.driver import Drivers
|
||||
|
||||
|
||||
class Config(ConfigValue):
|
||||
@ -59,10 +58,8 @@ def get_provider_config(provider):
|
||||
# Ensure legacy configuration still works when using fake cloud
|
||||
if provider.get('name', '').startswith('fake'):
|
||||
provider['driver'] = 'fake'
|
||||
if provider['driver'] == 'fake':
|
||||
return FakeProviderConfig(provider)
|
||||
elif provider['driver'] == 'openstack':
|
||||
return OpenStackProviderConfig(provider)
|
||||
driver = Drivers.get(provider['driver'])
|
||||
return driver['config'](provider)
|
||||
|
||||
|
||||
def openConfig(path):
|
||||
@ -90,8 +87,9 @@ def openConfig(path):
|
||||
def loadConfig(config_path):
|
||||
config = openConfig(config_path)
|
||||
|
||||
# Reset the shared os_client_config instance
|
||||
OpenStackProviderConfig.os_client_config = None
|
||||
# Call driver config reset now to clean global hooks like os_client_config
|
||||
for driver in Drivers.drivers.values():
|
||||
driver["config"].reset()
|
||||
|
||||
newconfig = Config()
|
||||
newconfig.db = None
|
||||
|
@ -16,6 +16,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import inspect
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
import six
|
||||
|
||||
@ -23,6 +27,79 @@ from nodepool import zk
|
||||
from nodepool import exceptions
|
||||
|
||||
|
||||
class Drivers:
|
||||
"""The Drivers plugin interface"""
|
||||
|
||||
log = logging.getLogger("nodepool.driver.Drivers")
|
||||
drivers = {}
|
||||
drivers_paths = None
|
||||
|
||||
@staticmethod
|
||||
def _load_class(driver_name, path, parent_class):
|
||||
"""Return a driver class that implements the parent_class"""
|
||||
spec = importlib.util.spec_from_file_location(driver_name, path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
obj = inspect.getmembers(
|
||||
module,
|
||||
lambda x: inspect.isclass(x) and issubclass(x, parent_class) and
|
||||
x.__module__ == driver_name)
|
||||
error = None
|
||||
if len(obj) > 1:
|
||||
error = "multiple %s implementation" % parent_class
|
||||
if not obj:
|
||||
error = "no %s implementation found" % parent_class
|
||||
if error:
|
||||
Drivers.log.error("%s: %s" % (path, error))
|
||||
return False
|
||||
return obj[0][1]
|
||||
|
||||
@staticmethod
|
||||
def load(drivers_paths=[]):
|
||||
"""Load drivers"""
|
||||
if drivers_paths == Drivers.drivers_paths:
|
||||
# Already loaded
|
||||
return
|
||||
Drivers.drivers.clear()
|
||||
for drivers_path in drivers_paths + [os.path.dirname(__file__)]:
|
||||
drivers = os.listdir(drivers_path)
|
||||
for driver in drivers:
|
||||
driver_path = os.path.join(drivers_path, driver)
|
||||
if driver in Drivers.drivers:
|
||||
Drivers.log.warning("%s: duplicate driver" % driver_path)
|
||||
continue
|
||||
if not os.path.isdir(driver_path) or \
|
||||
"__init__.py" not in os.listdir(driver_path):
|
||||
continue
|
||||
Drivers.log.debug("%s: loading driver" % driver_path)
|
||||
driver_obj = {}
|
||||
for name, parent_class in (
|
||||
("config", ProviderConfig),
|
||||
("handler", NodeRequestHandler),
|
||||
("provider", Provider),
|
||||
):
|
||||
driver_obj[name] = Drivers._load_class(
|
||||
driver, os.path.join(driver_path, "%s.py" % name),
|
||||
parent_class)
|
||||
if not driver_obj[name]:
|
||||
break
|
||||
if not driver_obj[name]:
|
||||
Drivers.log.error("%s: skipping incorrect driver" %
|
||||
driver_path)
|
||||
continue
|
||||
Drivers.drivers[driver] = driver_obj
|
||||
Drivers.drivers_paths = drivers_paths
|
||||
|
||||
@staticmethod
|
||||
def get(name):
|
||||
if not Drivers.drivers:
|
||||
Drivers.load()
|
||||
try:
|
||||
return Drivers.drivers[name]
|
||||
except KeyError:
|
||||
raise RuntimeError("%s: unknown driver" % name)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Provider(object):
|
||||
"""The Provider interface
|
||||
@ -353,6 +430,10 @@ class ProviderConfig(ConfigValue):
|
||||
def __eq__(self, other):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset():
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, newconfig):
|
||||
pass
|
||||
|
@ -66,13 +66,13 @@ class Dummy(object):
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def get_fake_quota():
|
||||
return 100, 20, 1000000
|
||||
|
||||
|
||||
class FakeOpenStackCloud(object):
|
||||
log = logging.getLogger("nodepool.FakeOpenStackCloud")
|
||||
|
||||
@staticmethod
|
||||
def _get_quota():
|
||||
return 100, 20, 1000000
|
||||
|
||||
def __init__(self, images=None, networks=None):
|
||||
self.pause_creates = False
|
||||
self._image_list = images
|
||||
@ -100,7 +100,8 @@ class FakeOpenStackCloud(object):
|
||||
vcpus=4),
|
||||
]
|
||||
self._server_list = []
|
||||
self.max_cores, self.max_instances, self.max_ram = get_fake_quota()
|
||||
self.max_cores, self.max_instances, self.max_ram = FakeOpenStackCloud.\
|
||||
_get_quota()
|
||||
|
||||
def _get(self, name_or_id, instance_list):
|
||||
self.log.debug("Get %s in %s" % (name_or_id, repr(instance_list)))
|
||||
@ -287,9 +288,11 @@ class FakeUploadFailCloud(FakeOpenStackCloud):
|
||||
|
||||
|
||||
class FakeProvider(OpenStackProvider):
|
||||
fake_cloud = FakeOpenStackCloud
|
||||
|
||||
def __init__(self, provider, use_taskmanager):
|
||||
self.createServer_fails = 0
|
||||
self.__client = FakeOpenStackCloud()
|
||||
self.__client = FakeProvider.fake_cloud()
|
||||
super(FakeProvider, self).__init__(provider, use_taskmanager)
|
||||
|
||||
def _getClient(self):
|
||||
|
@ -96,6 +96,10 @@ class OpenStackProviderConfig(ProviderConfig):
|
||||
cloud_kwargs[arg] = self.provider[arg]
|
||||
return cloud_kwargs
|
||||
|
||||
@staticmethod
|
||||
def reset():
|
||||
OpenStackProviderConfig.os_client_config = None
|
||||
|
||||
def load(self, config):
|
||||
if OpenStackProviderConfig.os_client_config is None:
|
||||
OpenStackProviderConfig.os_client_config = \
|
||||
|
@ -28,8 +28,7 @@ from nodepool import provider_manager
|
||||
from nodepool import stats
|
||||
from nodepool import config as nodepool_config
|
||||
from nodepool import zk
|
||||
from nodepool.driver.fake.handler import FakeNodeRequestHandler
|
||||
from nodepool.driver.openstack.handler import OpenStackNodeRequestHandler
|
||||
from nodepool.driver import Drivers
|
||||
|
||||
|
||||
MINS = 60
|
||||
@ -146,12 +145,8 @@ class PoolWorker(threading.Thread):
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
def _get_node_request_handler(self, provider, request):
|
||||
if provider.driver.name == 'fake':
|
||||
return FakeNodeRequestHandler(self, request)
|
||||
elif provider.driver.name == 'openstack':
|
||||
return OpenStackNodeRequestHandler(self, request)
|
||||
else:
|
||||
raise RuntimeError("Unknown provider driver %s" % provider.driver)
|
||||
driver = Drivers.get(provider.driver.name)
|
||||
return driver['handler'](self, request)
|
||||
|
||||
def _assignHandlers(self):
|
||||
'''
|
||||
|
@ -18,17 +18,12 @@
|
||||
|
||||
import logging
|
||||
|
||||
from nodepool.driver.fake.provider import FakeProvider
|
||||
from nodepool.driver.openstack.provider import OpenStackProvider
|
||||
from nodepool.driver import Drivers
|
||||
|
||||
|
||||
def get_provider(provider, use_taskmanager):
|
||||
if provider.driver.name == 'fake':
|
||||
return FakeProvider(provider, use_taskmanager)
|
||||
elif provider.driver.name == 'openstack':
|
||||
return OpenStackProvider(provider, use_taskmanager)
|
||||
else:
|
||||
raise RuntimeError("Unknown provider driver %s" % provider.driver)
|
||||
driver = Drivers.get(provider.driver.name)
|
||||
return driver['provider'](provider, use_taskmanager)
|
||||
|
||||
|
||||
class ProviderManager(object):
|
||||
|
0
nodepool/tests/fixtures/drivers/test/__init__.py
vendored
Normal file
0
nodepool/tests/fixtures/drivers/test/__init__.py
vendored
Normal file
46
nodepool/tests/fixtures/drivers/test/config.py
vendored
Normal file
46
nodepool/tests/fixtures/drivers/test/config.py
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2017 Red Hat
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import voluptuous as v
|
||||
|
||||
from nodepool.driver import ConfigValue
|
||||
from nodepool.driver import ProviderConfig
|
||||
|
||||
|
||||
class TestPool(ConfigValue):
|
||||
pass
|
||||
|
||||
|
||||
class TestConfig(ProviderConfig):
|
||||
def __eq__(self, other):
|
||||
return self.name == other.name
|
||||
|
||||
@staticmethod
|
||||
def reset():
|
||||
pass
|
||||
|
||||
def load(self, newconfig):
|
||||
self.pools = {}
|
||||
for pool in self.provider.get('pools', []):
|
||||
testpool = TestPool()
|
||||
testpool.name = pool['name']
|
||||
testpool.provider = self
|
||||
for label in pool['labels']:
|
||||
newconfig.labels[label].pools.append(testpool)
|
||||
self.pools[pool['name']] = testpool
|
||||
|
||||
def get_schema(self):
|
||||
pool = {'name': str,
|
||||
'labels': [str]}
|
||||
return v.Schema({'pools': [pool]})
|
34
nodepool/tests/fixtures/drivers/test/handler.py
vendored
Normal file
34
nodepool/tests/fixtures/drivers/test/handler.py
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2017 Red Hat
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from nodepool import zk
|
||||
from nodepool.driver import NodeRequestHandler
|
||||
|
||||
|
||||
class TestHandler(NodeRequestHandler):
|
||||
log = logging.getLogger("nodepool.driver.test.TestHandler")
|
||||
|
||||
def run_handler(self):
|
||||
self._setFromPoolWorker()
|
||||
node = zk.Node()
|
||||
node.state = zk.READY
|
||||
node.external_id = "test-%s" % self.request.id
|
||||
node.provider = self.provider.name
|
||||
node.launcher = self.launcher_id
|
||||
node.allocated_to = self.request.id
|
||||
node.type = self.request.node_types[0]
|
||||
self.nodeset.append(node)
|
||||
self.zk.storeNode(node)
|
46
nodepool/tests/fixtures/drivers/test/provider.py
vendored
Normal file
46
nodepool/tests/fixtures/drivers/test/provider.py
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (C) 2011-2013 OpenStack Foundation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
#
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nodepool.driver import Provider
|
||||
|
||||
|
||||
class TestProvider(Provider):
|
||||
def __init__(self, provider, *args):
|
||||
self.provider = provider
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def join(self):
|
||||
pass
|
||||
|
||||
def labelReady(self, name):
|
||||
return True
|
||||
|
||||
def cleanupNode(self, node_id):
|
||||
pass
|
||||
|
||||
def waitForNodeCleanup(self, node_id):
|
||||
pass
|
||||
|
||||
def cleanupLeakedResources(self):
|
||||
pass
|
||||
|
||||
def listNodes(self):
|
||||
return []
|
16
nodepool/tests/fixtures/external_driver.yaml
vendored
Normal file
16
nodepool/tests/fixtures/external_driver.yaml
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
zookeeper-servers:
|
||||
- host: {zookeeper_host}
|
||||
port: {zookeeper_port}
|
||||
chroot: {zookeeper_chroot}
|
||||
|
||||
labels:
|
||||
- name: test-label
|
||||
min-ready: 1
|
||||
|
||||
providers:
|
||||
- name: test-provider
|
||||
driver: test
|
||||
pools:
|
||||
- name: test-pool
|
||||
labels:
|
||||
- test-label
|
@ -18,6 +18,7 @@ import uuid
|
||||
import fixtures
|
||||
|
||||
from nodepool import builder, exceptions, tests
|
||||
from nodepool.driver import Drivers
|
||||
from nodepool.driver.fake import provider as fakeprovider
|
||||
from nodepool import zk
|
||||
|
||||
@ -119,8 +120,8 @@ class TestNodePoolBuilder(tests.DBTestCase):
|
||||
def get_fake_client(*args, **kwargs):
|
||||
return fake_client
|
||||
|
||||
self.useFixture(fixtures.MonkeyPatch(
|
||||
'nodepool.driver.fake.provider.FakeProvider._getClient',
|
||||
self.useFixture(fixtures.MockPatchObject(
|
||||
Drivers.get('fake')['provider'], '_getClient',
|
||||
get_fake_client))
|
||||
|
||||
configfile = self.setup_config('node.yaml')
|
||||
|
39
nodepool/tests/test_drivers.py
Normal file
39
nodepool/tests/test_drivers.py
Normal file
@ -0,0 +1,39 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from nodepool import config as nodepool_config
|
||||
from nodepool import tests
|
||||
from nodepool.driver import Drivers
|
||||
|
||||
|
||||
class TestDrivers(tests.DBTestCase):
|
||||
def setup_config(self, filename):
|
||||
drivers_dir = os.path.join(
|
||||
os.path.dirname(__file__), 'fixtures', 'drivers')
|
||||
Drivers.load([drivers_dir])
|
||||
return super().setup_config(filename)
|
||||
|
||||
def test_external_driver_config(self):
|
||||
configfile = self.setup_config('external_driver.yaml')
|
||||
nodepool_config.loadConfig(configfile)
|
||||
self.assertIn("config", Drivers.get("test"))
|
||||
|
||||
def test_external_driver_handler(self):
|
||||
configfile = self.setup_config('external_driver.yaml')
|
||||
pool = self.useNodepool(configfile, watermark_sleep=1)
|
||||
pool.start()
|
||||
nodes = self.waitForNodes('test-label')
|
||||
self.assertEqual(len(nodes), 1)
|
@ -17,10 +17,10 @@ import logging
|
||||
import math
|
||||
import time
|
||||
import fixtures
|
||||
import unittest.mock as mock
|
||||
|
||||
from nodepool import tests
|
||||
from nodepool import zk
|
||||
from nodepool.driver import Drivers
|
||||
import nodepool.launcher
|
||||
|
||||
|
||||
@ -111,8 +111,7 @@ class TestLauncher(tests.DBTestCase):
|
||||
self.assertEqual(nodes[2].type, 'fake-label4')
|
||||
self.assertEqual(nodes[3].type, 'fake-label2')
|
||||
|
||||
@mock.patch('nodepool.driver.fake.provider.get_fake_quota')
|
||||
def _test_node_assignment_at_quota(self, mock_quota,
|
||||
def _test_node_assignment_at_quota(self,
|
||||
config='node_quota.yaml',
|
||||
max_cores=100,
|
||||
max_instances=20,
|
||||
@ -124,7 +123,12 @@ class TestLauncher(tests.DBTestCase):
|
||||
'''
|
||||
|
||||
# patch the cloud with requested quota
|
||||
mock_quota.return_value = (max_cores, max_instances, max_ram)
|
||||
def fake_get_quota():
|
||||
return (max_cores, max_instances, max_ram)
|
||||
self.useFixture(fixtures.MockPatchObject(
|
||||
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
|
||||
fake_get_quota
|
||||
))
|
||||
|
||||
configfile = self.setup_config(config)
|
||||
self.useBuilder(configfile)
|
||||
@ -258,9 +262,7 @@ class TestLauncher(tests.DBTestCase):
|
||||
max_instances=math.inf,
|
||||
max_ram=2 * 8192)
|
||||
|
||||
@mock.patch('nodepool.driver.fake.provider.get_fake_quota')
|
||||
def test_over_quota(self, mock_quota,
|
||||
config='node_quota_cloud.yaml'):
|
||||
def test_over_quota(self, config='node_quota_cloud.yaml'):
|
||||
'''
|
||||
This tests what happens when a cloud unexpectedly returns an
|
||||
over-quota error.
|
||||
@ -272,7 +274,12 @@ class TestLauncher(tests.DBTestCase):
|
||||
max_ram = math.inf
|
||||
|
||||
# patch the cloud with requested quota
|
||||
mock_quota.return_value = (max_cores, max_instances, max_ram)
|
||||
def fake_get_quota():
|
||||
return (max_cores, max_instances, max_ram)
|
||||
self.useFixture(fixtures.MockPatchObject(
|
||||
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
|
||||
fake_get_quota
|
||||
))
|
||||
|
||||
configfile = self.setup_config(config)
|
||||
self.useBuilder(configfile)
|
||||
@ -589,8 +596,8 @@ class TestLauncher(tests.DBTestCase):
|
||||
def fail_delete(self, name):
|
||||
raise RuntimeError('Fake Error')
|
||||
|
||||
fake_delete = 'nodepool.driver.fake.provider.FakeProvider.deleteServer'
|
||||
self.useFixture(fixtures.MonkeyPatch(fake_delete, fail_delete))
|
||||
self.useFixture(fixtures.MockPatchObject(
|
||||
Drivers.get('fake')['provider'], 'deleteServer', fail_delete))
|
||||
|
||||
configfile = self.setup_config('node.yaml')
|
||||
pool = self.useNodepool(configfile, watermark_sleep=1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user