diff --git a/mistral/db/sqlalchemy/base.py b/mistral/db/sqlalchemy/base.py index 40ce09d90..daaeb92fe 100644 --- a/mistral/db/sqlalchemy/base.py +++ b/mistral/db/sqlalchemy/base.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import cachetools from oslo_config import cfg from oslo_db import options @@ -28,7 +29,8 @@ from mistral_lib import utils # Note(dzimine): sqlite only works for basic testing. options.set_defaults(cfg.CONF, connection="sqlite:///mistral.sqlite") -_DB_SESSION_THREAD_LOCAL_NAME = "db_sql_alchemy_session" +_DB_SESSION_THREAD_LOCAL_NAME = "__db_sql_alchemy_session__" +_TX_SCOPED_CACHE_THREAD_LOCAL_NAME = "__tx_scoped_cache__" _facade = None _sqlalchemy_create_engine_orig = sa.create_engine @@ -86,6 +88,10 @@ def _get_thread_local_session(): return utils.get_thread_local(_DB_SESSION_THREAD_LOCAL_NAME) +def get_tx_scoped_cache(): + return utils.get_thread_local(_TX_SCOPED_CACHE_THREAD_LOCAL_NAME) + + def _get_or_create_thread_local_session(): ses = _get_thread_local_session() @@ -101,6 +107,14 @@ def _get_or_create_thread_local_session(): def _set_thread_local_session(session): utils.set_thread_local(_DB_SESSION_THREAD_LOCAL_NAME, session) + if session is not None: + utils.set_thread_local( + _TX_SCOPED_CACHE_THREAD_LOCAL_NAME, + cachetools.LRUCache(maxsize=1000) + ) + else: + utils.set_thread_local(_TX_SCOPED_CACHE_THREAD_LOCAL_NAME, None) + def session_aware(param_name="session"): """Decorator for methods working within db session.""" diff --git a/mistral/db/utils.py b/mistral/db/utils.py index ca8750842..d6f7bdf2e 100644 --- a/mistral/db/utils.py +++ b/mistral/db/utils.py @@ -14,7 +14,11 @@ from __future__ import absolute_import +from cachetools import keys as cachetools_keys +import decorator import functools +import inspect +import six from sqlalchemy import exc as sqla_exc @@ -24,8 +28,10 @@ from oslo_log import log as logging import tenacity from mistral import context +from mistral.db.sqlalchemy import base as db_base from mistral import exceptions as exc from mistral.services import security +from mistral_lib import utils as ml_utils LOG = logging.getLogger(__name__) @@ -124,3 +130,87 @@ def check_db_obj_access(db_obj): "Can not modify a system %s resource, ID: %s" % (db_obj.__class__.__name__, db_obj.id) ) + + +def tx_cached(use_args=None, ignore_args=None): + """Decorates any function to cache its result within a DB transaction. + + Since a DB transaction is coupled with the current thread, the scope + of the underlying cache doesn't go beyond the thread. The decorator + is mainly useful for situations when we know we can safely cache a + result of some calculation if we know that it's not going to change + till the end of the current transaction. + + :param use_args: A tuple with argument names of the decorated function + used to build a cache key. + :param ignore_args: A tuple with argument names of the decorated function + that should be ignored when building a cache key. + :return: Decorated function. + """ + + if use_args and ignore_args: + raise ValueError( + "Only one of 'use_args' and 'ignore_args' can be used." + ) + + def _build_cache_key(func, *args, **kw): + # { arg name => arg value } + arg_dict = inspect.getcallargs(func, *args, **kw) + + if ignore_args: + if not isinstance(ignore_args, (six.string_types, tuple)): + raise ValueError( + "'ignore_args' must be either a tuple or a string," + " actual type: %s" % type(ignore_args) + ) + + ignore_args_tup = ( + ignore_args if isinstance(ignore_args, tuple) else + (ignore_args,) + ) + + for arg_name in ignore_args_tup: + arg_dict.pop(arg_name, None) + + if use_args: + if not isinstance(use_args, (six.string_types, tuple)): + raise ValueError( + "'use_args' must be either a tuple or a string," + " actual type: %s" % type(use_args) + ) + + use_args_tup = ( + use_args if isinstance(use_args, tuple) else (use_args,) + ) + + for arg_name in arg_dict.keys(): + if arg_name not in tuple(use_args_tup): + arg_dict.pop(arg_name, None) + + return cachetools_keys.hashkey(**arg_dict) + + @decorator.decorator + def _decorator(func, *args, **kw): + cache = db_base.get_tx_scoped_cache() + + # A DB transaction may not be necessarily open at the moment. + if not cache: + return func(*args, **kw) + + cache_key = _build_cache_key(func, *args, **kw) + + result = cache.get(cache_key, default=ml_utils.NotDefined) + + if result is not ml_utils.NotDefined: + return result + + # We don't do any exception handling here. In case of an exception + # nothing will be put into the cache and the exception will just + # bubble up as if there wasn't any wrapper. + result = func(*args, **kw) + + cache[cache_key] = result + + return result + + return _decorator diff --git a/mistral/expressions/std_functions.py b/mistral/expressions/std_functions.py index f67703910..9800bb90b 100644 --- a/mistral/expressions/std_functions.py +++ b/mistral/expressions/std_functions.py @@ -20,6 +20,7 @@ from oslo_log import log as logging from oslo_serialization import jsonutils import yaml +from mistral.db import utils as db_utils from mistral.db.v2 import api as db_api from mistral.utils import filter_utils from mistral_lib import utils @@ -35,48 +36,50 @@ def env_(context): return context['__env'] +@db_utils.tx_cached(ignore_args='context') def executions_(context, id=None, root_execution_id=None, state=None, from_time=None, to_time=None): - fltr = {} + filter_ = {} if id is not None: - fltr = filter_utils.create_or_update_filter('id', id, "eq", fltr) + filter_ = filter_utils.create_or_update_filter('id', id, "eq", filter_) if root_execution_id is not None: - fltr = filter_utils.create_or_update_filter( + filter_ = filter_utils.create_or_update_filter( 'root_execution_id', root_execution_id, 'eq', - fltr + filter_ ) if state is not None: - fltr = filter_utils.create_or_update_filter( + filter_ = filter_utils.create_or_update_filter( 'state', state, 'eq', - fltr + filter_ ) if from_time is not None: - fltr = filter_utils.create_or_update_filter( + filter_ = filter_utils.create_or_update_filter( 'created_at', from_time, 'gte', - fltr + filter_ ) if to_time is not None: - fltr = filter_utils.create_or_update_filter( + filter_ = filter_utils.create_or_update_filter( 'created_at', to_time, 'lt', - fltr + filter_ ) - return db_api.get_workflow_executions(**fltr) + return db_api.get_workflow_executions(**filter_) +@db_utils.tx_cached(ignore_args='context') def execution_(context): wf_ex = db_api.get_workflow_execution(context['__execution']['id']) @@ -113,6 +116,7 @@ def yaml_dump_(context, data): return yaml.safe_dump(data, default_flow_style=False) +@db_utils.tx_cached(ignore_args='context') def task_(context, task_name=None): # This section may not exist in a context if it's calculated not in # task scope. @@ -224,6 +228,7 @@ def _get_tasks_from_db(workflow_execution_id=None, recursive=False, state=None, return task_execs +@db_utils.tx_cached(ignore_args='context') def tasks_(context, workflow_execution_id=None, recursive=False, state=None, flat=False): task_execs = _get_tasks_from_db(