diff --git a/os_faults/drivers/devstack.py b/os_faults/drivers/devstack.py index 31d254b..21edeca 100644 --- a/os_faults/drivers/devstack.py +++ b/os_faults/drivers/devstack.py @@ -21,7 +21,7 @@ from os_faults.ansible import executor from os_faults.api import cloud_management from os_faults.api import node_collection from os_faults.api import service - +from os_faults import utils HostClass = namedtuple('HostClass', ['ip', 'mac']) @@ -80,13 +80,15 @@ class DevStackService(service.Service): def get_nodes(self): return self.cloud_management.get_nodes() + @utils.require_variables('RESTART_CMD', 'SERVICE_NAME') def restart(self, nodes=None): task = {'command': self.RESTART_CMD} exec_res = self.cloud_management.execute(task) - logging.info('Restart the service, result: %s', exec_res) + logging.info('Restart %s result: %s', self.SERVICE_NAME, exec_res) class KeystoneService(DevStackService): + SERVICE_NAME = 'keystone' RESTART_CMD = 'service apache2 restart' diff --git a/os_faults/drivers/fuel.py b/os_faults/drivers/fuel.py index 35a3b43..1ee4fe5 100644 --- a/os_faults/drivers/fuel.py +++ b/os_faults/drivers/fuel.py @@ -24,6 +24,7 @@ from os_faults.api import cloud_management from os_faults.api import error from os_faults.api import node_collection from os_faults.api import service +from os_faults import utils class FuelNodeCollection(node_collection.NodeCollection): @@ -136,14 +137,14 @@ class FuelService(service.Service): power_management=self.power_management, hosts=hosts) + @utils.require_variables('RESTART_CMD', 'SERVICE_NAME') def restart(self, nodes=None): - if not getattr(self, 'RESTART_CMD'): - raise NotImplementedError('RESTART_CMD is undefined') nodes = nodes if nodes is not None else self.get_nodes() logging.info("Restart '%s' service on nodes: %s", self.SERVICE_NAME, nodes.get_ips()) self._run_task({'command': self.RESTART_CMD}, nodes) + @utils.require_variables('GREP', 'SERVICE_NAME') def kill(self, nodes=None): nodes = nodes if nodes is not None else self.get_nodes() logging.info("Kill '%s' service on nodes: %s", self.SERVICE_NAME, @@ -151,6 +152,7 @@ class FuelService(service.Service): cmd = {'kill': {'grep': self.GREP, 'sig': signal.SIGKILL}} self._run_task(cmd, nodes) + @utils.require_variables('GREP', 'SERVICE_NAME') def freeze(self, nodes=None, sec=None): nodes = nodes if nodes is not None else self.get_nodes() if sec: @@ -161,6 +163,7 @@ class FuelService(service.Service): ('for %s sec ' % sec) if sec else '', nodes.get_ips()) self._run_task(cmd, nodes) + @utils.require_variables('GREP', 'SERVICE_NAME') def unfreeze(self, nodes=None): nodes = nodes if nodes is not None else self.get_nodes() logging.info("Unfreeze '%s' service on nodes: %s", self.SERVICE_NAME, @@ -168,6 +171,7 @@ class FuelService(service.Service): cmd = {'kill': {'grep': self.GREP, 'sig': signal.SIGCONT}} self._run_task(cmd, nodes) + @utils.require_variables('PORT', 'SERVICE_NAME') def plug(self, nodes=None): nodes = nodes if nodes is not None else self.get_nodes() logging.info("Open port %d for '%s' service on nodes: %s", @@ -177,6 +181,7 @@ class FuelService(service.Service): 'action': 'unblock', 'service': self.SERVICE_NAME}}, nodes) + @utils.require_variables('PORT', 'SERVICE_NAME') def unplug(self, nodes=None): nodes = nodes if nodes is not None else self.get_nodes() logging.info("Close port %d for '%s' service on nodes: %s", diff --git a/os_faults/tests/unit/test_utils.py b/os_faults/tests/unit/test_utils.py index 448e1a2..cdc8031 100644 --- a/os_faults/tests/unit/test_utils.py +++ b/os_faults/tests/unit/test_utils.py @@ -70,3 +70,29 @@ class UtilsTestCase(test.TestCase): thread_1.join.assert_called_once() thread_2.join.assert_called_once() + + +class MyClass(object): + FOO = 10 + + @utils.require_variables('FOO') + def method(self, a, b): + return self.FOO + a + b + + @utils.require_variables('BAR', 'BAZ') + def method_that_miss_variables(self): + return self.BAR, self.BAZ + + +class RequiredVariablesTestCase(test.TestCase): + + def test_require_variables(self): + inst = MyClass() + self.assertEqual(inst.method(1, b=2), 13) + + def test_require_variables_not_implemented(self): + inst = MyClass() + err = self.assertRaises(NotImplementedError, + inst.method_that_miss_variables) + msg = 'BAR, BAZ required for MyClass.method_that_miss_variables' + self.assertEqual(str(err), msg) diff --git a/os_faults/utils.py b/os_faults/utils.py index 2e1f224..356356f 100644 --- a/os_faults/utils.py +++ b/os_faults/utils.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import logging import threading import traceback @@ -51,3 +52,22 @@ class ThreadsWrapper(object): def join_threads(self): for thread in self.threads: thread.join() + + +def require_variables(*variables): + """Class method decorator to check that required variables are present""" + def decorator(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kawrgs): + missing_vars = [] + for var in variables: + if not hasattr(self, var): + missing_vars.append(var) + if missing_vars: + missing_vars = ', '.join(missing_vars) + msg = '{} required for {}.{}'.format( + missing_vars, self.__class__.__name__, fn.__name__) + raise NotImplementedError(msg) + return fn(self, *args, **kawrgs) + return wrapper + return decorator