Decorator for required variables

Added decorator to check that required class variables are present.

Change-Id: I6b90a7fccd4dca5675322ea2b9457c0fd0dcc658
This commit is contained in:
Anton Studenov 2016-10-13 17:30:54 +03:00
parent dc64ad38c2
commit 80a82024ea
4 changed files with 57 additions and 4 deletions

View File

@ -21,7 +21,7 @@ from os_faults.ansible import executor
from os_faults.api import cloud_management
from os_faults.api import node_collection
from os_faults.api import service
from os_faults import utils
HostClass = namedtuple('HostClass', ['ip', 'mac'])
@ -80,13 +80,15 @@ class DevStackService(service.Service):
def get_nodes(self):
return self.cloud_management.get_nodes()
@utils.require_variables('RESTART_CMD', 'SERVICE_NAME')
def restart(self, nodes=None):
task = {'command': self.RESTART_CMD}
exec_res = self.cloud_management.execute(task)
logging.info('Restart the service, result: %s', exec_res)
logging.info('Restart %s result: %s', self.SERVICE_NAME, exec_res)
class KeystoneService(DevStackService):
SERVICE_NAME = 'keystone'
RESTART_CMD = 'service apache2 restart'

View File

@ -24,6 +24,7 @@ from os_faults.api import cloud_management
from os_faults.api import error
from os_faults.api import node_collection
from os_faults.api import service
from os_faults import utils
class FuelNodeCollection(node_collection.NodeCollection):
@ -136,14 +137,14 @@ class FuelService(service.Service):
power_management=self.power_management,
hosts=hosts)
@utils.require_variables('RESTART_CMD', 'SERVICE_NAME')
def restart(self, nodes=None):
if not getattr(self, 'RESTART_CMD'):
raise NotImplementedError('RESTART_CMD is undefined')
nodes = nodes if nodes is not None else self.get_nodes()
logging.info("Restart '%s' service on nodes: %s", self.SERVICE_NAME,
nodes.get_ips())
self._run_task({'command': self.RESTART_CMD}, nodes)
@utils.require_variables('GREP', 'SERVICE_NAME')
def kill(self, nodes=None):
nodes = nodes if nodes is not None else self.get_nodes()
logging.info("Kill '%s' service on nodes: %s", self.SERVICE_NAME,
@ -151,6 +152,7 @@ class FuelService(service.Service):
cmd = {'kill': {'grep': self.GREP, 'sig': signal.SIGKILL}}
self._run_task(cmd, nodes)
@utils.require_variables('GREP', 'SERVICE_NAME')
def freeze(self, nodes=None, sec=None):
nodes = nodes if nodes is not None else self.get_nodes()
if sec:
@ -161,6 +163,7 @@ class FuelService(service.Service):
('for %s sec ' % sec) if sec else '', nodes.get_ips())
self._run_task(cmd, nodes)
@utils.require_variables('GREP', 'SERVICE_NAME')
def unfreeze(self, nodes=None):
nodes = nodes if nodes is not None else self.get_nodes()
logging.info("Unfreeze '%s' service on nodes: %s", self.SERVICE_NAME,
@ -168,6 +171,7 @@ class FuelService(service.Service):
cmd = {'kill': {'grep': self.GREP, 'sig': signal.SIGCONT}}
self._run_task(cmd, nodes)
@utils.require_variables('PORT', 'SERVICE_NAME')
def plug(self, nodes=None):
nodes = nodes if nodes is not None else self.get_nodes()
logging.info("Open port %d for '%s' service on nodes: %s",
@ -177,6 +181,7 @@ class FuelService(service.Service):
'action': 'unblock',
'service': self.SERVICE_NAME}}, nodes)
@utils.require_variables('PORT', 'SERVICE_NAME')
def unplug(self, nodes=None):
nodes = nodes if nodes is not None else self.get_nodes()
logging.info("Close port %d for '%s' service on nodes: %s",

View File

@ -70,3 +70,29 @@ class UtilsTestCase(test.TestCase):
thread_1.join.assert_called_once()
thread_2.join.assert_called_once()
class MyClass(object):
FOO = 10
@utils.require_variables('FOO')
def method(self, a, b):
return self.FOO + a + b
@utils.require_variables('BAR', 'BAZ')
def method_that_miss_variables(self):
return self.BAR, self.BAZ
class RequiredVariablesTestCase(test.TestCase):
def test_require_variables(self):
inst = MyClass()
self.assertEqual(inst.method(1, b=2), 13)
def test_require_variables_not_implemented(self):
inst = MyClass()
err = self.assertRaises(NotImplementedError,
inst.method_that_miss_variables)
msg = 'BAR, BAZ required for MyClass.method_that_miss_variables'
self.assertEqual(str(err), msg)

View File

@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import threading
import traceback
@ -51,3 +52,22 @@ class ThreadsWrapper(object):
def join_threads(self):
for thread in self.threads:
thread.join()
def require_variables(*variables):
"""Class method decorator to check that required variables are present"""
def decorator(fn):
@functools.wraps(fn)
def wrapper(self, *args, **kawrgs):
missing_vars = []
for var in variables:
if not hasattr(self, var):
missing_vars.append(var)
if missing_vars:
missing_vars = ', '.join(missing_vars)
msg = '{} required for {}.{}'.format(
missing_vars, self.__class__.__name__, fn.__name__)
raise NotImplementedError(msg)
return fn(self, *args, **kawrgs)
return wrapper
return decorator