diff --git a/watcher/db/api.py b/watcher/db/api.py index f257ec485..059e0bf49 100644 --- a/watcher/db/api.py +++ b/watcher/db/api.py @@ -36,7 +36,7 @@ class BaseConnection(object): @abc.abstractmethod def get_goal_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, eager=False): """Get specific columns for matching goals. Return a list of the specified columns for all goals that @@ -50,6 +50,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -72,31 +73,34 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_goal_by_id(self, context, goal_id): + def get_goal_by_id(self, context, goal_id, eager=False): """Return a goal given its ID. :param context: The security context :param goal_id: The ID of a goal + :param eager: If True, also loads One-to-X data (Default: False) :returns: A goal :raises: :py:class:`~.GoalNotFound` """ @abc.abstractmethod - def get_goal_by_uuid(self, context, goal_uuid): + def get_goal_by_uuid(self, context, goal_uuid, eager=False): """Return a goal given its UUID. :param context: The security context :param goal_uuid: The UUID of a goal + :param eager: If True, also loads One-to-X data (Default: False) :returns: A goal :raises: :py:class:`~.GoalNotFound` """ @abc.abstractmethod - def get_goal_by_name(self, context, goal_name): + def get_goal_by_name(self, context, goal_name, eager=False): """Return a goal given its name. :param context: The security context :param goal_name: The name of a goal + :param eager: If True, also loads One-to-X data (Default: False) :returns: A goal :raises: :py:class:`~.GoalNotFound` """ @@ -129,9 +133,17 @@ class BaseConnection(object): :raises: :py:class:`~.Invalid` """ + def soft_delete_goal(self, goal_id): + """Soft delete a goal. + + :param goal_id: The id or uuid of a goal. + :raises: :py:class:`~.GoalNotFound` + """ + @abc.abstractmethod def get_strategy_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, + eager=True): """Get specific columns for matching strategies. Return a list of the specified columns for all strategies that @@ -146,6 +158,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: Direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -170,31 +183,34 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_strategy_by_id(self, context, strategy_id): + def get_strategy_by_id(self, context, strategy_id, eager=False): """Return a strategy given its ID. :param context: The security context :param strategy_id: The ID of a strategy + :param eager: If True, also loads One-to-X data (Default: False) :returns: A strategy :raises: :py:class:`~.StrategyNotFound` """ @abc.abstractmethod - def get_strategy_by_uuid(self, context, strategy_uuid): + def get_strategy_by_uuid(self, context, strategy_uuid, eager=False): """Return a strategy given its UUID. :param context: The security context :param strategy_uuid: The UUID of a strategy + :param eager: If True, also loads One-to-X data (Default: False) :returns: A strategy :raises: :py:class:`~.StrategyNotFound` """ @abc.abstractmethod - def get_strategy_by_name(self, context, strategy_name): + def get_strategy_by_name(self, context, strategy_name, eager=False): """Return a strategy given its name. :param context: The security context :param strategy_name: The name of a strategy + :param eager: If True, also loads One-to-X data (Default: False) :returns: A strategy :raises: :py:class:`~.StrategyNotFound` """ @@ -217,10 +233,17 @@ class BaseConnection(object): :raises: :py:class:`~.Invalid` """ + def soft_delete_strategy(self, strategy_id): + """Soft delete a strategy. + + :param strategy_id: The id or uuid of a strategy. + :raises: :py:class:`~.StrategyNotFound` + """ + @abc.abstractmethod def get_audit_template_list(self, context, filters=None, limit=None, marker=None, sort_key=None, - sort_dir=None): + sort_dir=None, eager=False): """Get specific columns for matching audit templates. Return a list of the specified columns for all audit templates that @@ -234,6 +257,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -258,37 +282,43 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_audit_template_by_id(self, context, audit_template_id): + def get_audit_template_by_id(self, context, audit_template_id, + eager=False): """Return an audit template. :param context: The security context :param audit_template_id: The id of an audit template. + :param eager: If True, also loads One-to-X data (Default: False) :returns: An audit template. :raises: :py:class:`~.AuditTemplateNotFound` """ @abc.abstractmethod - def get_audit_template_by_uuid(self, context, audit_template_uuid): + def get_audit_template_by_uuid(self, context, audit_template_uuid, + eager=False): """Return an audit template. :param context: The security context :param audit_template_uuid: The uuid of an audit template. + :param eager: If True, also loads One-to-X data (Default: False) :returns: An audit template. :raises: :py:class:`~.AuditTemplateNotFound` """ - def get_audit_template_by_name(self, context, audit_template_name): + def get_audit_template_by_name(self, context, audit_template_name, + eager=False): """Return an audit template. :param context: The security context :param audit_template_name: The name of an audit template. + :param eager: If True, also loads One-to-X data (Default: False) :returns: An audit template. :raises: :py:class:`~.AuditTemplateNotFound` """ @abc.abstractmethod def destroy_audit_template(self, audit_template_id): - """Destroy an audit_template. + """Destroy an audit template. :param audit_template_id: The id or uuid of an audit template. :raises: :py:class:`~.AuditTemplateNotFound` @@ -306,7 +336,7 @@ class BaseConnection(object): @abc.abstractmethod def soft_delete_audit_template(self, audit_template_id): - """Soft delete an audit_template. + """Soft delete an audit template. :param audit_template_id: The id or uuid of an audit template. :raises: :py:class:`~.AuditTemplateNotFound` @@ -314,7 +344,7 @@ class BaseConnection(object): @abc.abstractmethod def get_audit_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, eager=False): """Get specific columns for matching audits. Return a list of the specified columns for all audits that match the @@ -328,6 +358,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -351,21 +382,23 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_audit_by_id(self, context, audit_id): + def get_audit_by_id(self, context, audit_id, eager=False): """Return an audit. :param context: The security context :param audit_id: The id of an audit. + :param eager: If True, also loads One-to-X data (Default: False) :returns: An audit. :raises: :py:class:`~.AuditNotFound` """ @abc.abstractmethod - def get_audit_by_uuid(self, context, audit_uuid): + def get_audit_by_uuid(self, context, audit_uuid, eager=False): """Return an audit. :param context: The security context :param audit_uuid: The uuid of an audit. + :param eager: If True, also loads One-to-X data (Default: False) :returns: An audit. :raises: :py:class:`~.AuditNotFound` """ @@ -392,13 +425,13 @@ class BaseConnection(object): """Soft delete an audit and all associated action plans. :param audit_id: The id or uuid of an audit. - :returns: An audit. :raises: :py:class:`~.AuditNotFound` """ @abc.abstractmethod def get_action_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, + eager=False): """Get specific columns for matching actions. Return a list of the specified columns for all actions that match the @@ -412,6 +445,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -436,21 +470,23 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_action_by_id(self, context, action_id): + def get_action_by_id(self, context, action_id, eager=False): """Return a action. :param context: The security context :param action_id: The id of a action. + :param eager: If True, also loads One-to-X data (Default: False) :returns: A action. :raises: :py:class:`~.ActionNotFound` """ @abc.abstractmethod - def get_action_by_uuid(self, context, action_uuid): + def get_action_by_uuid(self, context, action_uuid, eager=False): """Return a action. :param context: The security context :param action_uuid: The uuid of a action. + :param eager: If True, also loads One-to-X data (Default: False) :returns: A action. :raises: :py:class:`~.ActionNotFound` """ @@ -475,10 +511,17 @@ class BaseConnection(object): :raises: :py:class:`~.Invalid` """ + def soft_delete_action(self, action_id): + """Soft delete an action. + + :param action_id: The id or uuid of an action. + :raises: :py:class:`~.ActionNotFound` + """ + @abc.abstractmethod def get_action_plan_list( self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, eager=False): """Get specific columns for matching action plans. Return a list of the specified columns for all action plans that @@ -492,6 +535,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -506,21 +550,23 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_action_plan_by_id(self, context, action_plan_id): + def get_action_plan_by_id(self, context, action_plan_id, eager=False): """Return an action plan. :param context: The security context :param action_plan_id: The id of an action plan. + :param eager: If True, also loads One-to-X data (Default: False) :returns: An action plan. :raises: :py:class:`~.ActionPlanNotFound` """ @abc.abstractmethod - def get_action_plan_by_uuid(self, context, action_plan__uuid): + def get_action_plan_by_uuid(self, context, action_plan__uuid, eager=False): """Return a action plan. :param context: The security context :param action_plan__uuid: The uuid of an action plan. + :param eager: If True, also loads One-to-X data (Default: False) :returns: An action plan. :raises: :py:class:`~.ActionPlanNotFound` """ @@ -545,9 +591,17 @@ class BaseConnection(object): :raises: :py:class:`~.Invalid` """ + def soft_delete_action_plan(self, action_plan_id): + """Soft delete an action plan. + + :param action_plan_id: The id or uuid of an action plan. + :raises: :py:class:`~.ActionPlanNotFound` + """ + @abc.abstractmethod def get_efficacy_indicator_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, + eager=False): """Get specific columns for matching efficacy indicators. Return a list of the specified columns for all efficacy indicators that @@ -564,6 +618,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: Direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -588,31 +643,37 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_efficacy_indicator_by_id(self, context, efficacy_indicator_id): + def get_efficacy_indicator_by_id(self, context, efficacy_indicator_id, + eager=False): """Return an efficacy indicator given its ID. :param context: The security context :param efficacy_indicator_id: The ID of an efficacy indicator + :param eager: If True, also loads One-to-X data (Default: False) :returns: An efficacy indicator :raises: :py:class:`~.EfficacyIndicatorNotFound` """ @abc.abstractmethod - def get_efficacy_indicator_by_uuid(self, context, efficacy_indicator_uuid): + def get_efficacy_indicator_by_uuid(self, context, efficacy_indicator_uuid, + eager=False): """Return an efficacy indicator given its UUID. :param context: The security context :param efficacy_indicator_uuid: The UUID of an efficacy indicator + :param eager: If True, also loads One-to-X data (Default: False) :returns: An efficacy indicator :raises: :py:class:`~.EfficacyIndicatorNotFound` """ @abc.abstractmethod - def get_efficacy_indicator_by_name(self, context, efficacy_indicator_name): + def get_efficacy_indicator_by_name(self, context, efficacy_indicator_name, + eager=False): """Return an efficacy indicator given its name. :param context: The security context :param efficacy_indicator_name: The name of an efficacy indicator + :param eager: If True, also loads One-to-X data (Default: False) :returns: An efficacy indicator :raises: :py:class:`~.EfficacyIndicatorNotFound` """ @@ -626,7 +687,7 @@ class BaseConnection(object): """ @abc.abstractmethod - def update_efficacy_indicator(self, efficacy_indicator_uuid, values): + def update_efficacy_indicator(self, efficacy_indicator_id, values): """Update properties of an efficacy indicator. :param efficacy_indicator_uuid: The UUID of an efficacy indicator @@ -638,7 +699,7 @@ class BaseConnection(object): @abc.abstractmethod def get_scoring_engine_list( self, context, columns=None, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, eager=False): """Get specific columns for matching scoring engines. Return a list of the specified columns for all scoring engines that @@ -654,6 +715,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -668,31 +730,37 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_scoring_engine_by_id(self, context, scoring_engine_id): + def get_scoring_engine_by_id(self, context, scoring_engine_id, + eager=False): """Return a scoring engine by its id. :param context: The security context :param scoring_engine_id: The id of a scoring engine. + :param eager: If True, also loads One-to-X data (Default: False) :returns: A scoring engine. :raises: :py:class:`~.ScoringEngineNotFound` """ @abc.abstractmethod - def get_scoring_engine_by_uuid(self, context, scoring_engine_uuid): + def get_scoring_engine_by_uuid(self, context, scoring_engine_uuid, + eager=False): """Return a scoring engine by its uuid. :param context: The security context :param scoring_engine_uuid: The uuid of a scoring engine. + :param eager: If True, also loads One-to-X data (Default: False) :returns: A scoring engine. :raises: :py:class:`~.ScoringEngineNotFound` """ @abc.abstractmethod - def get_scoring_engine_by_name(self, context, scoring_engine_name): + def get_scoring_engine_by_name(self, context, scoring_engine_name, + eager=False): """Return a scoring engine by its name. :param context: The security context :param scoring_engine_name: The name of a scoring engine. + :param eager: If True, also loads One-to-X data (Default: False) :returns: A scoring engine. :raises: :py:class:`~.ScoringEngineNotFound` """ @@ -716,8 +784,8 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_service_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + def get_service_list(self, context, filters=None, limit=None, marker=None, + sort_key=None, sort_dir=None, eager=False): """Get specific columns for matching services. Return a list of the specified columns for all services that @@ -732,6 +800,7 @@ class BaseConnection(object): :param sort_key: Attribute by which results should be sorted. :param sort_dir: Direction in which results should be sorted. (asc, desc) + :param eager: If True, also loads One-to-X data (Default: False) :returns: A list of tuples of the specified columns. """ @@ -755,21 +824,23 @@ class BaseConnection(object): """ @abc.abstractmethod - def get_service_by_id(self, context, service_id): + def get_service_by_id(self, context, service_id, eager=False): """Return a service given its ID. :param context: The security context :param service_id: The ID of a service + :param eager: If True, also loads One-to-X data (Default: False) :returns: A service :raises: :py:class:`~.ServiceNotFound` """ @abc.abstractmethod - def get_service_by_name(self, context, service_name): + def get_service_by_name(self, context, service_name, eager=False): """Return a service given its name. :param context: The security context :param service_name: The name of a service + :param eager: If True, also loads One-to-X data (Default: False) :returns: A service :raises: :py:class:`~.ServiceNotFound` """ diff --git a/watcher/db/sqlalchemy/api.py b/watcher/db/sqlalchemy/api.py index 8be7bb326..cf5e0a973 100644 --- a/watcher/db/sqlalchemy/api.py +++ b/watcher/db/sqlalchemy/api.py @@ -24,7 +24,10 @@ from oslo_config import cfg from oslo_db import exception as db_exc from oslo_db.sqlalchemy import session as db_session from oslo_db.sqlalchemy import utils as db_utils +from oslo_utils import timeutils +from sqlalchemy.inspection import inspect from sqlalchemy.orm import exc +from sqlalchemy.orm import joinedload from watcher._i18n import _ from watcher.common import exception @@ -34,7 +37,6 @@ from watcher.db.sqlalchemy import models from watcher.objects import action as action_objects from watcher.objects import action_plan as ap_objects from watcher.objects import audit as audit_objects -from watcher.objects import utils as objutils CONF = cfg.CONF @@ -133,8 +135,9 @@ class Connection(api.BaseConnection): def __add_simple_filter(self, query, model, fieldname, value, operator_): field = getattr(model, fieldname) - if field.type.python_type is datetime.datetime: - value = objutils.datetime_or_str_or_none(value) + if field.type.python_type is datetime.datetime and value: + if not isinstance(value, datetime.datetime): + value = timeutils.parse_isotime(value) return query.filter(self.valid_operators[operator_](field, value)) @@ -233,8 +236,20 @@ class Connection(api.BaseConnection): return query - def _get(self, context, model, fieldname, value): + @staticmethod + def _set_eager_options(model, query): + relationships = inspect(model).relationships + for relationship in relationships: + if not relationship.uselist: + # We have a One-to-X relationship + query = query.options(joinedload(relationship.key)) + return query + + def _get(self, context, model, fieldname, value, eager): query = model_query(model) + if eager: + query = self._set_eager_options(model, query) + query = query.filter(getattr(model, fieldname) == value) if not context.show_deleted: query = query.filter(model.deleted_at.is_(None)) @@ -246,7 +261,8 @@ class Connection(api.BaseConnection): return obj - def _update(self, model, id_, values): + @staticmethod + def _update(model, id_, values): session = get_session() with session.begin(): query = model_query(model, session=session) @@ -259,7 +275,8 @@ class Connection(api.BaseConnection): ref.update(values) return ref - def _soft_delete(self, model, id_): + @staticmethod + def _soft_delete(model, id_): session = get_session() with session.begin(): query = model_query(model, session=session) @@ -271,7 +288,8 @@ class Connection(api.BaseConnection): query.soft_delete() - def _destroy(self, model, id_): + @staticmethod + def _destroy(model, id_): session = get_session() with session.begin(): query = model_query(model, session=session) @@ -398,10 +416,11 @@ class Connection(api.BaseConnection): # ### GOALS ### # - def get_goal_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): - + def get_goal_list(self, context, filters=None, limit=None, marker=None, + sort_key=None, sort_dir=None, eager=False): query = model_query(models.Goal) + if eager: + query = self._set_eager_options(models.Goal, query) query = self._add_goals_filters(query, filters) if not context.show_deleted: query = query.filter_by(deleted_at=None) @@ -422,21 +441,24 @@ class Connection(api.BaseConnection): raise exception.GoalAlreadyExists(uuid=values['uuid']) return goal - def _get_goal(self, context, fieldname, value): + def _get_goal(self, context, fieldname, value, eager): try: return self._get(context, model=models.Goal, - fieldname=fieldname, value=value) + fieldname=fieldname, value=value, eager=eager) except exception.ResourceNotFound: raise exception.GoalNotFound(goal=value) - def get_goal_by_id(self, context, goal_id): - return self._get_goal(context, fieldname="id", value=goal_id) + def get_goal_by_id(self, context, goal_id, eager=False): + return self._get_goal( + context, fieldname="id", value=goal_id, eager=eager) - def get_goal_by_uuid(self, context, goal_uuid): - return self._get_goal(context, fieldname="uuid", value=goal_uuid) + def get_goal_by_uuid(self, context, goal_uuid, eager=False): + return self._get_goal( + context, fieldname="uuid", value=goal_uuid, eager=eager) - def get_goal_by_name(self, context, goal_name): - return self._get_goal(context, fieldname="name", value=goal_name) + def get_goal_by_name(self, context, goal_name, eager=False): + return self._get_goal( + context, fieldname="name", value=goal_name, eager=eager) def destroy_goal(self, goal_id): try: @@ -463,9 +485,11 @@ class Connection(api.BaseConnection): # ### STRATEGIES ### # def get_strategy_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): - + marker=None, sort_key=None, sort_dir=None, + eager=True): query = model_query(models.Strategy) + if eager: + query = self._set_eager_options(models.Strategy, query) query = self._add_strategies_filters(query, filters) if not context.show_deleted: query = query.filter_by(deleted_at=None) @@ -486,23 +510,24 @@ class Connection(api.BaseConnection): raise exception.StrategyAlreadyExists(uuid=values['uuid']) return strategy - def _get_strategy(self, context, fieldname, value): + def _get_strategy(self, context, fieldname, value, eager): try: return self._get(context, model=models.Strategy, - fieldname=fieldname, value=value) + fieldname=fieldname, value=value, eager=eager) except exception.ResourceNotFound: raise exception.StrategyNotFound(strategy=value) - def get_strategy_by_id(self, context, strategy_id): - return self._get_strategy(context, fieldname="id", value=strategy_id) - - def get_strategy_by_uuid(self, context, strategy_uuid): + def get_strategy_by_id(self, context, strategy_id, eager=False): return self._get_strategy( - context, fieldname="uuid", value=strategy_uuid) + context, fieldname="id", value=strategy_id, eager=eager) - def get_strategy_by_name(self, context, strategy_name): + def get_strategy_by_uuid(self, context, strategy_uuid, eager=False): return self._get_strategy( - context, fieldname="name", value=strategy_name) + context, fieldname="uuid", value=strategy_uuid, eager=eager) + + def get_strategy_by_name(self, context, strategy_name, eager=False): + return self._get_strategy( + context, fieldname="name", value=strategy_name, eager=eager) def destroy_strategy(self, strategy_id): try: @@ -529,9 +554,12 @@ class Connection(api.BaseConnection): # ### AUDIT TEMPLATES ### # def get_audit_template_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, + eager=False): query = model_query(models.AuditTemplate) + if eager: + query = self._set_eager_options(models.AuditTemplate, query) query = self._add_audit_templates_filters(query, filters) if not context.show_deleted: query = query.filter_by(deleted_at=None) @@ -561,24 +589,27 @@ class Connection(api.BaseConnection): audit_template=values['name']) return audit_template - def _get_audit_template(self, context, fieldname, value): + def _get_audit_template(self, context, fieldname, value, eager): try: return self._get(context, model=models.AuditTemplate, - fieldname=fieldname, value=value) + fieldname=fieldname, value=value, eager=eager) except exception.ResourceNotFound: raise exception.AuditTemplateNotFound(audit_template=value) - def get_audit_template_by_id(self, context, audit_template_id): + def get_audit_template_by_id(self, context, audit_template_id, + eager=False): return self._get_audit_template( - context, fieldname="id", value=audit_template_id) + context, fieldname="id", value=audit_template_id, eager=eager) - def get_audit_template_by_uuid(self, context, audit_template_uuid): + def get_audit_template_by_uuid(self, context, audit_template_uuid, + eager=False): return self._get_audit_template( - context, fieldname="uuid", value=audit_template_uuid) + context, fieldname="uuid", value=audit_template_uuid, eager=eager) - def get_audit_template_by_name(self, context, audit_template_name): + def get_audit_template_by_name(self, context, audit_template_name, + eager=False): return self._get_audit_template( - context, fieldname="name", value=audit_template_name) + context, fieldname="name", value=audit_template_name, eager=eager) def destroy_audit_template(self, audit_template_id): try: @@ -609,8 +640,10 @@ class Connection(api.BaseConnection): # ### AUDITS ### # def get_audit_list(self, context, filters=None, limit=None, marker=None, - sort_key=None, sort_dir=None): + sort_key=None, sort_dir=None, eager=False): query = model_query(models.Audit) + if eager: + query = self._set_eager_options(models.Audit, query) query = self._add_audits_filters(query, filters) if not context.show_deleted: query = query.filter( @@ -636,30 +669,20 @@ class Connection(api.BaseConnection): raise exception.AuditAlreadyExists(uuid=values['uuid']) return audit - def get_audit_by_id(self, context, audit_id): - query = model_query(models.Audit) - query = query.filter_by(id=audit_id) + def _get_audit(self, context, fieldname, value, eager): try: - audit = query.one() - if not context.show_deleted: - if audit.state == audit_objects.State.DELETED: - raise exception.AuditNotFound(audit=audit_id) - return audit - except exc.NoResultFound: - raise exception.AuditNotFound(audit=audit_id) + return self._get(context, model=models.Audit, + fieldname=fieldname, value=value, eager=eager) + except exception.ResourceNotFound: + raise exception.AuditNotFound(audit=value) - def get_audit_by_uuid(self, context, audit_uuid): - query = model_query(models.Audit) - query = query.filter_by(uuid=audit_uuid) + def get_audit_by_id(self, context, audit_id, eager=False): + return self._get_audit( + context, fieldname="id", value=audit_id, eager=eager) - try: - audit = query.one() - if not context.show_deleted: - if audit.state == audit_objects.State.DELETED: - raise exception.AuditNotFound(audit=audit_uuid) - return audit - except exc.NoResultFound: - raise exception.AuditNotFound(audit=audit_uuid) + def get_audit_by_uuid(self, context, audit_uuid, eager=False): + return self._get_audit( + context, fieldname="uuid", value=audit_uuid, eager=eager) def destroy_audit(self, audit_id): def is_audit_referenced(session, audit_id): @@ -704,8 +727,10 @@ class Connection(api.BaseConnection): # ### ACTIONS ### # def get_action_list(self, context, filters=None, limit=None, marker=None, - sort_key=None, sort_dir=None): + sort_key=None, sort_dir=None, eager=False): query = model_query(models.Action) + if eager: + query = self._set_eager_options(models.Action, query) query = self._add_actions_filters(query, filters) if not context.show_deleted: query = query.filter( @@ -726,31 +751,20 @@ class Connection(api.BaseConnection): raise exception.ActionAlreadyExists(uuid=values['uuid']) return action - def get_action_by_id(self, context, action_id): - query = model_query(models.Action) - query = query.filter_by(id=action_id) + def _get_action(self, context, fieldname, value, eager): try: - action = query.one() - if not context.show_deleted: - if action.state == action_objects.State.DELETED: - raise exception.ActionNotFound( - action=action_id) - return action - except exc.NoResultFound: - raise exception.ActionNotFound(action=action_id) + return self._get(context, model=models.Action, + fieldname=fieldname, value=value, eager=eager) + except exception.ResourceNotFound: + raise exception.ActionNotFound(action=value) - def get_action_by_uuid(self, context, action_uuid): - query = model_query(models.Action) - query = query.filter_by(uuid=action_uuid) - try: - action = query.one() - if not context.show_deleted: - if action.state == action_objects.State.DELETED: - raise exception.ActionNotFound( - action=action_uuid) - return action - except exc.NoResultFound: - raise exception.ActionNotFound(action=action_uuid) + def get_action_by_id(self, context, action_id, eager=False): + return self._get_action( + context, fieldname="id", value=action_id, eager=eager) + + def get_action_by_uuid(self, context, action_uuid, eager=False): + return self._get_action( + context, fieldname="uuid", value=action_uuid, eager=eager) def destroy_action(self, action_id): session = get_session() @@ -765,12 +779,12 @@ class Connection(api.BaseConnection): # NOTE(dtantsur): this can lead to very strange errors if 'uuid' in values: raise exception.Invalid( - message=_("Cannot overwrite UUID for an existing " - "Action.")) + message=_("Cannot overwrite UUID for an existing Action.")) return self._do_update_action(action_id, values) - def _do_update_action(self, action_id, values): + @staticmethod + def _do_update_action(action_id, values): session = get_session() with session.begin(): query = model_query(models.Action, session=session) @@ -799,9 +813,11 @@ class Connection(api.BaseConnection): # ### ACTION PLANS ### # def get_action_plan_list( - self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + self, context, filters=None, limit=None, marker=None, + sort_key=None, sort_dir=None, eager=False): query = model_query(models.ActionPlan) + if eager: + query = self._set_eager_options(models.ActionPlan, query) query = self._add_action_plans_filters(query, filters) if not context.show_deleted: query = query.filter( @@ -824,32 +840,20 @@ class Connection(api.BaseConnection): raise exception.ActionPlanAlreadyExists(uuid=values['uuid']) return action_plan - def get_action_plan_by_id(self, context, action_plan_id): - query = model_query(models.ActionPlan) - query = query.filter_by(id=action_plan_id) + def _get_action_plan(self, context, fieldname, value, eager): try: - action_plan = query.one() - if not context.show_deleted: - if action_plan.state == ap_objects.State.DELETED: - raise exception.ActionPlanNotFound( - action_plan=action_plan_id) - return action_plan - except exc.NoResultFound: - raise exception.ActionPlanNotFound(action_plan=action_plan_id) + return self._get(context, model=models.ActionPlan, + fieldname=fieldname, value=value, eager=eager) + except exception.ResourceNotFound: + raise exception.ActionPlanNotFound(action_plan=value) - def get_action_plan_by_uuid(self, context, action_plan__uuid): - query = model_query(models.ActionPlan) - query = query.filter_by(uuid=action_plan__uuid) + def get_action_plan_by_id(self, context, action_plan_id, eager=False): + return self._get_action_plan( + context, fieldname="id", value=action_plan_id, eager=eager) - try: - action_plan = query.one() - if not context.show_deleted: - if action_plan.state == ap_objects.State.DELETED: - raise exception.ActionPlanNotFound( - action_plan=action_plan__uuid) - return action_plan - except exc.NoResultFound: - raise exception.ActionPlanNotFound(action_plan=action_plan__uuid) + def get_action_plan_by_uuid(self, context, action_plan_uuid, eager=False): + return self._get_action_plan( + context, fieldname="uuid", value=action_plan_uuid, eager=eager) def destroy_action_plan(self, action_plan_id): def is_action_plan_referenced(session, action_plan_id): @@ -883,7 +887,8 @@ class Connection(api.BaseConnection): return self._do_update_action_plan(action_plan_id, values) - def _do_update_action_plan(self, action_plan_id, values): + @staticmethod + def _do_update_action_plan(action_plan_id, values): session = get_session() with session.begin(): query = model_query(models.ActionPlan, session=session) @@ -912,9 +917,12 @@ class Connection(api.BaseConnection): # ### EFFICACY INDICATORS ### # def get_efficacy_indicator_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + marker=None, sort_key=None, sort_dir=None, + eager=False): query = model_query(models.EfficacyIndicator) + if eager: + query = self._set_eager_options(models.EfficacyIndicator, query) query = self._add_efficacy_indicators_filters(query, filters) if not context.show_deleted: query = query.filter_by(deleted_at=None) @@ -935,24 +943,30 @@ class Connection(api.BaseConnection): raise exception.EfficacyIndicatorAlreadyExists(uuid=values['uuid']) return efficacy_indicator - def _get_efficacy_indicator(self, context, fieldname, value): + def _get_efficacy_indicator(self, context, fieldname, value, eager): try: return self._get(context, model=models.EfficacyIndicator, - fieldname=fieldname, value=value) + fieldname=fieldname, value=value, eager=eager) except exception.ResourceNotFound: raise exception.EfficacyIndicatorNotFound(efficacy_indicator=value) - def get_efficacy_indicator_by_id(self, context, efficacy_indicator_id): + def get_efficacy_indicator_by_id(self, context, efficacy_indicator_id, + eager=False): return self._get_efficacy_indicator( - context, fieldname="id", value=efficacy_indicator_id) + context, fieldname="id", + value=efficacy_indicator_id, eager=eager) - def get_efficacy_indicator_by_uuid(self, context, efficacy_indicator_uuid): + def get_efficacy_indicator_by_uuid(self, context, efficacy_indicator_uuid, + eager=False): return self._get_efficacy_indicator( - context, fieldname="uuid", value=efficacy_indicator_uuid) + context, fieldname="uuid", + value=efficacy_indicator_uuid, eager=eager) - def get_efficacy_indicator_by_name(self, context, efficacy_indicator_name): + def get_efficacy_indicator_by_name(self, context, efficacy_indicator_name, + eager=False): return self._get_efficacy_indicator( - context, fieldname="name", value=efficacy_indicator_name) + context, fieldname="name", + value=efficacy_indicator_name, eager=eager) def update_efficacy_indicator(self, efficacy_indicator_id, values): if 'uuid' in values: @@ -995,9 +1009,11 @@ class Connection(api.BaseConnection): plain_fields=plain_fields) def get_scoring_engine_list( - self, context, columns=None, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + self, context, columns=None, filters=None, limit=None, + marker=None, sort_key=None, sort_dir=None, eager=False): query = model_query(models.ScoringEngine) + if eager: + query = self._set_eager_options(models.ScoringEngine, query) query = self._add_scoring_engine_filters(query, filters) if not context.show_deleted: query = query.filter_by(deleted_at=None) @@ -1019,24 +1035,27 @@ class Connection(api.BaseConnection): raise exception.ScoringEngineAlreadyExists(uuid=values['uuid']) return scoring_engine - def _get_scoring_engine(self, context, fieldname, value): + def _get_scoring_engine(self, context, fieldname, value, eager): try: return self._get(context, model=models.ScoringEngine, - fieldname=fieldname, value=value) + fieldname=fieldname, value=value, eager=eager) except exception.ResourceNotFound: raise exception.ScoringEngineNotFound(scoring_engine=value) - def get_scoring_engine_by_id(self, context, scoring_engine_id): + def get_scoring_engine_by_id(self, context, scoring_engine_id, + eager=False): return self._get_scoring_engine( - context, fieldname="id", value=scoring_engine_id) + context, fieldname="id", value=scoring_engine_id, eager=eager) - def get_scoring_engine_by_uuid(self, context, scoring_engine_uuid): + def get_scoring_engine_by_uuid(self, context, scoring_engine_uuid, + eager=False): return self._get_scoring_engine( - context, fieldname="uuid", value=scoring_engine_uuid) + context, fieldname="uuid", value=scoring_engine_uuid, eager=eager) - def get_scoring_engine_by_name(self, context, scoring_engine_name): + def get_scoring_engine_by_name(self, context, scoring_engine_name, + eager=False): return self._get_scoring_engine( - context, fieldname="name", value=scoring_engine_name) + context, fieldname="name", value=scoring_engine_name, eager=eager) def destroy_scoring_engine(self, scoring_engine_id): try: @@ -1046,9 +1065,9 @@ class Connection(api.BaseConnection): scoring_engine=scoring_engine_id) def update_scoring_engine(self, scoring_engine_id, values): - if 'id' in values: + if 'uuid' in values: raise exception.Invalid( - message=_("Cannot overwrite ID for an existing " + message=_("Cannot overwrite UUID for an existing " "Scoring Engine.")) try: @@ -1077,9 +1096,11 @@ class Connection(api.BaseConnection): query=query, model=models.Service, filters=filters, plain_fields=plain_fields) - def get_service_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): + def get_service_list(self, context, filters=None, limit=None, marker=None, + sort_key=None, sort_dir=None, eager=False): query = model_query(models.Service) + if eager: + query = self._set_eager_options(models.Service, query) query = self._add_services_filters(query, filters) if not context.show_deleted: query = query.filter_by(deleted_at=None) @@ -1096,18 +1117,20 @@ class Connection(api.BaseConnection): host=values['host']) return service - def _get_service(self, context, fieldname, value): + def _get_service(self, context, fieldname, value, eager): try: return self._get(context, model=models.Service, - fieldname=fieldname, value=value) + fieldname=fieldname, value=value, eager=eager) except exception.ResourceNotFound: raise exception.ServiceNotFound(service=value) - def get_service_by_id(self, context, service_id): - return self._get_service(context, fieldname="id", value=service_id) + def get_service_by_id(self, context, service_id, eager=False): + return self._get_service( + context, fieldname="id", value=service_id, eager=eager) - def get_service_by_name(self, context, service_name): - return self._get_service(context, fieldname="name", value=service_name) + def get_service_by_name(self, context, service_name, eager=False): + return self._get_service( + context, fieldname="name", value=service_name, eager=eager) def destroy_service(self, service_id): try: diff --git a/watcher/db/sqlalchemy/models.py b/watcher/db/sqlalchemy/models.py index 90157973e..97351d654 100644 --- a/watcher/db/sqlalchemy/models.py +++ b/watcher/db/sqlalchemy/models.py @@ -27,6 +27,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import Numeric +from sqlalchemy import orm from sqlalchemy import String from sqlalchemy import Text from sqlalchemy.types import TypeDecorator, TEXT @@ -57,6 +58,7 @@ def table_args(): class JsonEncodedType(TypeDecorator): """Abstract base type serialized as json-encoded string in db.""" + type = None impl = TEXT @@ -81,11 +83,13 @@ class JsonEncodedType(TypeDecorator): class JSONEncodedDict(JsonEncodedType): """Represents dict serialized as json-encoded string in db.""" + type = dict class JSONEncodedList(JsonEncodedType): """Represents list serialized as json-encoded string in db.""" + type = list @@ -111,23 +115,6 @@ class WatcherBase(models.SoftDeleteMixin, Base = declarative_base(cls=WatcherBase) -class Strategy(Base): - """Represents a strategy.""" - - __tablename__ = 'strategies' - __table_args__ = ( - UniqueConstraint('uuid', name='uniq_strategies0uuid'), - UniqueConstraint('name', 'deleted', name='uniq_strategies0name'), - table_args() - ) - id = Column(Integer, primary_key=True) - uuid = Column(String(36)) - name = Column(String(63), nullable=False) - display_name = Column(String(63), nullable=False) - goal_id = Column(Integer, ForeignKey('goals.id'), nullable=False) - parameters_spec = Column(JSONEncodedDict, nullable=True) - - class Goal(Base): """Represents a goal.""" @@ -137,13 +124,32 @@ class Goal(Base): UniqueConstraint('name', 'deleted', name='uniq_goals0name'), table_args(), ) - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) uuid = Column(String(36)) name = Column(String(63), nullable=False) display_name = Column(String(63), nullable=False) efficacy_specification = Column(JSONEncodedList, nullable=False) +class Strategy(Base): + """Represents a strategy.""" + + __tablename__ = 'strategies' + __table_args__ = ( + UniqueConstraint('uuid', name='uniq_strategies0uuid'), + UniqueConstraint('name', 'deleted', name='uniq_strategies0name'), + table_args() + ) + id = Column(Integer, primary_key=True, autoincrement=True) + uuid = Column(String(36)) + name = Column(String(63), nullable=False) + display_name = Column(String(63), nullable=False) + goal_id = Column(Integer, ForeignKey('goals.id'), nullable=False) + parameters_spec = Column(JSONEncodedDict, nullable=True) + + goal = orm.relationship(Goal, foreign_keys=goal_id, lazy=None) + + class AuditTemplate(Base): """Represents an audit template.""" @@ -163,6 +169,9 @@ class AuditTemplate(Base): version = Column(String(15), nullable=True) scope = Column(JSONEncodedList) + goal = orm.relationship(Goal, foreign_keys=goal_id, lazy=None) + strategy = orm.relationship(Strategy, foreign_keys=strategy_id, lazy=None) + class Audit(Base): """Represents an audit.""" @@ -172,7 +181,7 @@ class Audit(Base): UniqueConstraint('uuid', name='uniq_audits0uuid'), table_args() ) - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) uuid = Column(String(36)) audit_type = Column(String(20)) state = Column(String(20), nullable=True) @@ -183,24 +192,8 @@ class Audit(Base): strategy_id = Column(Integer, ForeignKey('strategies.id'), nullable=True) scope = Column(JSONEncodedList, nullable=True) - -class Action(Base): - """Represents an action.""" - - __tablename__ = 'actions' - __table_args__ = ( - UniqueConstraint('uuid', name='uniq_actions0uuid'), - table_args() - ) - id = Column(Integer, primary_key=True) - uuid = Column(String(36), nullable=False) - action_plan_id = Column(Integer, ForeignKey('action_plans.id'), - nullable=False) - # only for the first version - action_type = Column(String(255), nullable=False) - input_parameters = Column(JSONEncodedDict, nullable=True) - state = Column(String(20), nullable=True) - next = Column(String(36), nullable=True) + goal = orm.relationship(Goal, foreign_keys=goal_id, lazy=None) + strategy = orm.relationship(Strategy, foreign_keys=strategy_id, lazy=None) class ActionPlan(Base): @@ -211,7 +204,7 @@ class ActionPlan(Base): UniqueConstraint('uuid', name='uniq_action_plans0uuid'), table_args() ) - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) uuid = Column(String(36)) first_action_id = Column(Integer) audit_id = Column(Integer, ForeignKey('audits.id'), nullable=False) @@ -219,6 +212,31 @@ class ActionPlan(Base): state = Column(String(20), nullable=True) global_efficacy = Column(JSONEncodedDict, nullable=True) + audit = orm.relationship(Audit, foreign_keys=audit_id, lazy=None) + strategy = orm.relationship(Strategy, foreign_keys=strategy_id, lazy=None) + + +class Action(Base): + """Represents an action.""" + + __tablename__ = 'actions' + __table_args__ = ( + UniqueConstraint('uuid', name='uniq_actions0uuid'), + table_args() + ) + id = Column(Integer, primary_key=True, autoincrement=True) + uuid = Column(String(36), nullable=False) + action_plan_id = Column(Integer, ForeignKey('action_plans.id'), + nullable=False) + # only for the first version + action_type = Column(String(255), nullable=False) + input_parameters = Column(JSONEncodedDict, nullable=True) + state = Column(String(20), nullable=True) + next = Column(String(36), nullable=True) + + action_plan = orm.relationship( + ActionPlan, foreign_keys=action_plan_id, lazy=None) + class EfficacyIndicator(Base): """Represents an efficacy indicator.""" @@ -228,7 +246,7 @@ class EfficacyIndicator(Base): UniqueConstraint('uuid', name='uniq_efficacy_indicators0uuid'), table_args() ) - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) uuid = Column(String(36)) name = Column(String(63)) description = Column(String(255), nullable=True) @@ -237,6 +255,9 @@ class EfficacyIndicator(Base): action_plan_id = Column(Integer, ForeignKey('action_plans.id'), nullable=False) + action_plan = orm.relationship( + ActionPlan, foreign_keys=action_plan_id, lazy=None) + class ScoringEngine(Base): """Represents a scoring engine.""" @@ -247,7 +268,7 @@ class ScoringEngine(Base): UniqueConstraint('name', 'deleted', name='uniq_scoring_engines0name'), table_args() ) - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) uuid = Column(String(36), nullable=False) name = Column(String(63), nullable=False) description = Column(String(255), nullable=True) diff --git a/watcher/tests/api/v1/test_actions_plans.py b/watcher/tests/api/v1/test_actions_plans.py index fb2668bf3..404a29d5e 100644 --- a/watcher/tests/api/v1/test_actions_plans.py +++ b/watcher/tests/api/v1/test_actions_plans.py @@ -504,24 +504,6 @@ class TestPatchStateTransitionDenied(api_base.FunctionalTest): self.assertTrue(response.json['error_message']) -class TestPatchStateDeletedNotFound(api_base.FunctionalTest): - - @mock.patch.object( - db_api.BaseConnection, 'update_action_plan', - mock.Mock(side_effect=lambda ap: ap.save() or ap)) - def test_replace_state_pending_not_found(self): - action_plan = obj_utils.create_test_action_plan( - self.context, state=objects.action_plan.State.DELETED) - - response = self.get_json( - '/action_plans/%s' % action_plan.uuid, - expect_errors=True - ) - self.assertEqual(404, response.status_code) - self.assertEqual('application/json', response.content_type) - self.assertTrue(response.json['error_message']) - - class TestPatchStateTransitionOk(api_base.FunctionalTest): scenarios = [ diff --git a/watcher/tests/db/test_action.py b/watcher/tests/db/test_action.py index 106fd4201..406ed639a 100644 --- a/watcher/tests/db/test_action.py +++ b/watcher/tests/db/test_action.py @@ -245,12 +245,32 @@ class DbActionTestCase(base.DbTestCase): def test_get_action_list(self): uuids = [] - for _ in range(1, 6): + for _ in range(1, 4): action = utils.create_test_action(uuid=w_utils.generate_uuid()) uuids.append(six.text_type(action['uuid'])) - res = self.dbapi.get_action_list(self.context) - res_uuids = [r.uuid for r in res] - self.assertEqual(uuids.sort(), res_uuids.sort()) + actions = self.dbapi.get_action_list(self.context) + action_uuids = [a.uuid for a in actions] + self.assertEqual(3, len(action_uuids)) + self.assertEqual(sorted(uuids), sorted(action_uuids)) + for action in actions: + self.assertIsNone(action.action_plan) + + def test_get_action_list_eager(self): + _action_plan = utils.get_test_action_plan() + action_plan = self.dbapi.create_action_plan(_action_plan) + + uuids = [] + for i in range(1, 4): + action = utils.create_test_action( + id=i, uuid=w_utils.generate_uuid(), + action_plan_id=action_plan.id) + uuids.append(six.text_type(action['uuid'])) + actions = self.dbapi.get_action_list(self.context, eager=True) + action_map = {a.uuid: a for a in actions} + self.assertEqual(sorted(uuids), sorted(action_map.keys())) + eager_action = action_map[action.uuid] + self.assertEqual( + action_plan.as_dict(), eager_action.action_plan.as_dict()) def test_get_action_list_with_filters(self): audit = utils.create_test_audit(uuid=w_utils.generate_uuid()) @@ -299,8 +319,8 @@ class DbActionTestCase(base.DbTestCase): self.context, filters={'action_plan_uuid': action_plan['uuid']}) self.assertEqual( - [action1['id'], action3['id']].sort(), - [r.id for r in res].sort()) + sorted([action1['id'], action3['id']]), + sorted([r.id for r in res])) res = self.dbapi.get_action_list( self.context, diff --git a/watcher/tests/db/test_action_plan.py b/watcher/tests/db/test_action_plan.py index aaa04e61f..31cd5593f 100644 --- a/watcher/tests/db/test_action_plan.py +++ b/watcher/tests/db/test_action_plan.py @@ -242,16 +242,40 @@ class DbActionPlanTestCase(base.DbTestCase): def test_get_action_plan_list(self): uuids = [] - for i in range(1, 6): - audit = utils.create_test_action_plan(uuid=w_utils.generate_uuid()) - uuids.append(six.text_type(audit['uuid'])) - res = self.dbapi.get_action_plan_list(self.context) - res_uuids = [r.uuid for r in res] - self.assertEqual(uuids.sort(), res_uuids.sort()) + for _ in range(1, 4): + action_plan = utils.create_test_action_plan( + uuid=w_utils.generate_uuid()) + uuids.append(six.text_type(action_plan['uuid'])) + action_plans = self.dbapi.get_action_plan_list(self.context) + action_plan_uuids = [ap.uuid for ap in action_plans] + self.assertEqual(sorted(uuids), sorted(action_plan_uuids)) + for action_plan in action_plans: + self.assertIsNone(action_plan.audit) + self.assertIsNone(action_plan.strategy) + + def test_get_action_plan_list_eager(self): + _strategy = utils.get_test_strategy() + strategy = self.dbapi.create_strategy(_strategy) + _audit = utils.get_test_audit() + audit = self.dbapi.create_audit(_audit) + + uuids = [] + for _ in range(1, 4): + action_plan = utils.create_test_action_plan( + uuid=w_utils.generate_uuid()) + uuids.append(six.text_type(action_plan['uuid'])) + action_plans = self.dbapi.get_action_plan_list( + self.context, eager=True) + action_plan_map = {a.uuid: a for a in action_plans} + self.assertEqual(sorted(uuids), sorted(action_plan_map.keys())) + eager_action_plan = action_plan_map[action_plan.uuid] + self.assertEqual( + strategy.as_dict(), eager_action_plan.strategy.as_dict()) + self.assertEqual(audit.as_dict(), eager_action_plan.audit.as_dict()) def test_get_action_plan_list_with_filters(self): audit = self._create_test_audit( - id=1, + id=2, audit_type='ONESHOT', uuid=w_utils.generate_uuid(), deadline=None, diff --git a/watcher/tests/db/test_audit.py b/watcher/tests/db/test_audit.py index e5afd1a28..1f76cfc0b 100644 --- a/watcher/tests/db/test_audit.py +++ b/watcher/tests/db/test_audit.py @@ -20,7 +20,7 @@ import six from watcher.common import exception from watcher.common import utils as w_utils -from watcher.objects import audit as audit_objects +from watcher import objects from watcher.tests.db import base from watcher.tests.db import utils @@ -48,11 +48,11 @@ class TestDbAuditFilters(base.DbTestCase): with freezegun.freeze_time(self.FAKE_OLD_DATE): self.audit2 = utils.create_test_audit( audit_template_id=self.audit_template.id, id=2, uuid=None, - state=audit_objects.State.FAILED) + state=objects.audit.State.FAILED) with freezegun.freeze_time(self.FAKE_OLDER_DATE): self.audit3 = utils.create_test_audit( audit_template_id=self.audit_template.id, id=3, uuid=None, - state=audit_objects.State.CANCELLED) + state=objects.audit.State.CANCELLED) def _soft_delete_audits(self): with freezegun.freeze_time(self.FAKE_TODAY): @@ -66,15 +66,15 @@ class TestDbAuditFilters(base.DbTestCase): with freezegun.freeze_time(self.FAKE_TODAY): self.dbapi.update_audit( self.audit1.uuid, - values={"state": audit_objects.State.SUCCEEDED}) + values={"state": objects.audit.State.SUCCEEDED}) with freezegun.freeze_time(self.FAKE_OLD_DATE): self.dbapi.update_audit( self.audit2.uuid, - values={"state": audit_objects.State.SUCCEEDED}) + values={"state": objects.audit.State.SUCCEEDED}) with freezegun.freeze_time(self.FAKE_OLDER_DATE): self.dbapi.update_audit( self.audit3.uuid, - values={"state": audit_objects.State.SUCCEEDED}) + values={"state": objects.audit.State.SUCCEEDED}) def test_get_audit_list_filter_deleted_true(self): with freezegun.freeze_time(self.FAKE_TODAY): @@ -230,8 +230,8 @@ class TestDbAuditFilters(base.DbTestCase): def test_get_audit_list_filter_state_in(self): res = self.dbapi.get_audit_list( self.context, - filters={'state__in': (audit_objects.State.FAILED, - audit_objects.State.CANCELLED)}) + filters={'state__in': (objects.audit.State.FAILED, + objects.audit.State.CANCELLED)}) self.assertEqual( [self.audit2['id'], self.audit3['id']], @@ -240,8 +240,8 @@ class TestDbAuditFilters(base.DbTestCase): def test_get_audit_list_filter_state_notin(self): res = self.dbapi.get_audit_list( self.context, - filters={'state__notin': (audit_objects.State.FAILED, - audit_objects.State.CANCELLED)}) + filters={'state__notin': (objects.audit.State.FAILED, + objects.audit.State.CANCELLED)}) self.assertEqual( [self.audit1['id']], @@ -257,29 +257,52 @@ class DbAuditTestCase(base.DbTestCase): def test_get_audit_list(self): uuids = [] - for _ in range(1, 6): + for _ in range(1, 4): audit = utils.create_test_audit(uuid=w_utils.generate_uuid()) uuids.append(six.text_type(audit['uuid'])) - res = self.dbapi.get_audit_list(self.context) - res_uuids = [r.uuid for r in res] - self.assertEqual(uuids.sort(), res_uuids.sort()) + audits = self.dbapi.get_audit_list(self.context) + audit_uuids = [a.uuid for a in audits] + self.assertEqual(sorted(uuids), sorted(audit_uuids)) + for audit in audits: + self.assertIsNone(audit.goal) + self.assertIsNone(audit.strategy) + + def test_get_audit_list_eager(self): + _goal = utils.get_test_goal() + goal = self.dbapi.create_goal(_goal) + _strategy = utils.get_test_strategy() + strategy = self.dbapi.create_strategy(_strategy) + + uuids = [] + for i in range(1, 4): + audit = utils.create_test_audit( + id=i, uuid=w_utils.generate_uuid(), + goal_id=goal.id, strategy_id=strategy.id) + uuids.append(six.text_type(audit['uuid'])) + audits = self.dbapi.get_audit_list(self.context, eager=True) + audit_map = {a.uuid: a for a in audits} + self.assertEqual(sorted(uuids), sorted(audit_map.keys())) + eager_audit = audit_map[audit.uuid] + self.assertEqual(goal.as_dict(), eager_audit.goal.as_dict()) + self.assertEqual(strategy.as_dict(), eager_audit.strategy.as_dict()) def test_get_audit_list_with_filters(self): audit1 = self._create_test_audit( id=1, - audit_type='ONESHOT', + audit_type=objects.audit.AuditType.ONESHOT.value, uuid=w_utils.generate_uuid(), deadline=None, - state=audit_objects.State.ONGOING) + state=objects.audit.State.ONGOING) audit2 = self._create_test_audit( id=2, audit_type='CONTINUOUS', uuid=w_utils.generate_uuid(), deadline=None, - state=audit_objects.State.PENDING) + state=objects.audit.State.PENDING) - res = self.dbapi.get_audit_list(self.context, - filters={'audit_type': 'ONESHOT'}) + res = self.dbapi.get_audit_list( + self.context, + filters={'audit_type': objects.audit.AuditType.ONESHOT.value}) self.assertEqual([audit1['id']], [r.id for r in res]) res = self.dbapi.get_audit_list(self.context, @@ -288,12 +311,12 @@ class DbAuditTestCase(base.DbTestCase): res = self.dbapi.get_audit_list( self.context, - filters={'state': audit_objects.State.ONGOING}) + filters={'state': objects.audit.State.ONGOING}) self.assertEqual([audit1['id']], [r.id for r in res]) res = self.dbapi.get_audit_list( self.context, - filters={'state': audit_objects.State.PENDING}) + filters={'state': objects.audit.State.PENDING}) self.assertEqual([audit2['id']], [r.id for r in res]) def test_get_audit_list_with_filter_by_uuid(self): diff --git a/watcher/tests/db/test_audit_template.py b/watcher/tests/db/test_audit_template.py index b206fec5a..110c77c55 100644 --- a/watcher/tests/db/test_audit_template.py +++ b/watcher/tests/db/test_audit_template.py @@ -232,14 +232,40 @@ class DbAuditTemplateTestCase(base.DbTestCase): def test_get_audit_template_list(self): uuids = [] - for i in range(1, 6): + for i in range(1, 4): audit_template = utils.create_test_audit_template( + id=i, uuid=w_utils.generate_uuid(), name='My Audit Template {0}'.format(i)) uuids.append(six.text_type(audit_template['uuid'])) - res = self.dbapi.get_audit_template_list(self.context) - res_uuids = [r.uuid for r in res] - self.assertEqual(uuids.sort(), res_uuids.sort()) + audit_templates = self.dbapi.get_audit_template_list(self.context) + audit_template_uuids = [at.uuid for at in audit_templates] + self.assertEqual(sorted(uuids), sorted(audit_template_uuids)) + for audit_template in audit_templates: + self.assertIsNone(audit_template.goal) + self.assertIsNone(audit_template.strategy) + + def test_get_audit_template_list_eager(self): + _goal = utils.get_test_goal() + goal = self.dbapi.create_goal(_goal) + _strategy = utils.get_test_strategy() + strategy = self.dbapi.create_strategy(_strategy) + + uuids = [] + for i in range(1, 4): + audit_template = utils.create_test_audit_template( + id=i, uuid=w_utils.generate_uuid(), + name='My Audit Template {0}'.format(i), + goal_id=goal.id, strategy_id=strategy.id) + uuids.append(six.text_type(audit_template['uuid'])) + audit_templates = self.dbapi.get_audit_template_list( + self.context, eager=True) + audit_template_map = {a.uuid: a for a in audit_templates} + self.assertEqual(sorted(uuids), sorted(audit_template_map.keys())) + eager_audit_template = audit_template_map[audit_template.uuid] + self.assertEqual(goal.as_dict(), eager_audit_template.goal.as_dict()) + self.assertEqual( + strategy.as_dict(), eager_audit_template.strategy.as_dict()) def test_get_audit_template_list_with_filters(self): audit_template1 = self._create_test_audit_template( diff --git a/watcher/tests/db/test_efficacy_indicator.py b/watcher/tests/db/test_efficacy_indicator.py index 55cda9df3..e09835517 100644 --- a/watcher/tests/db/test_efficacy_indicator.py +++ b/watcher/tests/db/test_efficacy_indicator.py @@ -256,14 +256,37 @@ class DbEfficacyIndicatorTestCase(base.DbTestCase): def test_get_efficacy_indicator_list(self): uuids = [] action_plan = self._create_test_action_plan() - for id_ in range(1, 6): + for id_ in range(1, 4): efficacy_indicator = utils.create_test_efficacy_indicator( action_plan_id=action_plan.id, id=id_, uuid=None, name="efficacy_indicator", description="Test Indicator ") uuids.append(six.text_type(efficacy_indicator['uuid'])) - res = self.dbapi.get_efficacy_indicator_list(self.context) - res_uuids = [r.uuid for r in res] - self.assertEqual(uuids.sort(), res_uuids.sort()) + efficacy_indicators = self.dbapi.get_efficacy_indicator_list( + self.context) + efficacy_indicator_uuids = [ei.uuid for ei in efficacy_indicators] + self.assertEqual(sorted(uuids), sorted(efficacy_indicator_uuids)) + for efficacy_indicator in efficacy_indicators: + self.assertIsNone(efficacy_indicator.action_plan) + + def test_get_efficacy_indicator_list_eager(self): + _action_plan = utils.get_test_action_plan() + action_plan = self.dbapi.create_action_plan(_action_plan) + + uuids = [] + for i in range(1, 4): + efficacy_indicator = utils.create_test_efficacy_indicator( + id=i, uuid=w_utils.generate_uuid(), + action_plan_id=action_plan.id) + uuids.append(six.text_type(efficacy_indicator['uuid'])) + efficacy_indicators = self.dbapi.get_efficacy_indicator_list( + self.context, eager=True) + efficacy_indicator_map = {a.uuid: a for a in efficacy_indicators} + self.assertEqual(sorted(uuids), sorted(efficacy_indicator_map.keys())) + eager_efficacy_indicator = efficacy_indicator_map[ + efficacy_indicator.uuid] + self.assertEqual( + action_plan.as_dict(), + eager_efficacy_indicator.action_plan.as_dict()) def test_get_efficacy_indicator_list_with_filters(self): audit = utils.create_test_audit(uuid=w_utils.generate_uuid()) @@ -311,8 +334,8 @@ class DbEfficacyIndicatorTestCase(base.DbTestCase): self.context, filters={'action_plan_uuid': action_plan['uuid']}) self.assertEqual( - [efficacy_indicator1['id'], efficacy_indicator3['id']].sort(), - [r.id for r in res].sort()) + sorted([efficacy_indicator1['id'], efficacy_indicator3['id']]), + sorted([r.id for r in res])) def test_get_efficacy_indicator_list_with_filter_by_uuid(self): efficacy_indicator = self._create_test_efficacy_indicator() diff --git a/watcher/tests/db/test_goal.py b/watcher/tests/db/test_goal.py index 703a4bac2..cae9449ea 100644 --- a/watcher/tests/db/test_goal.py +++ b/watcher/tests/db/test_goal.py @@ -230,16 +230,16 @@ class DbGoalTestCase(base.DbTestCase): def test_get_goal_list(self): uuids = [] - for i in range(1, 6): + for i in range(1, 4): goal = utils.create_test_goal( id=i, uuid=w_utils.generate_uuid(), name="GOAL_%s" % i, display_name='My Goal %s' % i) uuids.append(six.text_type(goal['uuid'])) - res = self.dbapi.get_goal_list(self.context) - res_uuids = [r.uuid for r in res] - self.assertEqual(uuids.sort(), res_uuids.sort()) + goals = self.dbapi.get_goal_list(self.context) + goal_uuids = [g.uuid for g in goals] + self.assertEqual(sorted(uuids), sorted(goal_uuids)) def test_get_goal_list_with_filters(self): goal1 = self._create_test_goal( diff --git a/watcher/tests/db/test_scoring_engine.py b/watcher/tests/db/test_scoring_engine.py index 5322d1135..02da05e8c 100644 --- a/watcher/tests/db/test_scoring_engine.py +++ b/watcher/tests/db/test_scoring_engine.py @@ -235,7 +235,7 @@ class DbScoringEngineTestCase(base.DbTestCase): def test_get_scoring_engine_list(self): names = [] - for i in range(1, 6): + for i in range(1, 4): scoring_engine = utils.create_test_scoring_engine( id=i, uuid=w_utils.generate_uuid(), @@ -243,9 +243,9 @@ class DbScoringEngineTestCase(base.DbTestCase): description='My ScoringEngine {0}'.format(i), metainfo='a{0}=b{0}'.format(i)) names.append(six.text_type(scoring_engine['name'])) - res = self.dbapi.get_scoring_engine_list(self.context) - res_names = [r.name for r in res] - self.assertEqual(names.sort(), res_names.sort()) + scoring_engines = self.dbapi.get_scoring_engine_list(self.context) + scoring_engines_names = [se.name for se in scoring_engines] + self.assertEqual(sorted(names), sorted(scoring_engines_names)) def test_get_scoring_engine_list_with_filters(self): scoring_engine1 = self._create_test_scoring_engine( @@ -310,7 +310,7 @@ class DbScoringEngineTestCase(base.DbTestCase): self.assertRaises(exception.Invalid, self.dbapi.update_scoring_engine, scoring_engine['id'], - {'id': 5}) + {'uuid': w_utils.generate_uuid()}) def test_update_scoring_engine_that_does_not_exist(self): self.assertRaises(exception.ScoringEngineNotFound, diff --git a/watcher/tests/db/test_service.py b/watcher/tests/db/test_service.py index 3803a13cb..cda5470f3 100644 --- a/watcher/tests/db/test_service.py +++ b/watcher/tests/db/test_service.py @@ -18,7 +18,6 @@ """Tests for manipulating Service via the DB API""" import freezegun -import six from oslo_utils import timeutils @@ -237,15 +236,15 @@ class DbServiceTestCase(base.DbTestCase): def test_get_service_list(self): ids = [] - for i in range(1, 6): + for i in range(1, 4): service = utils.create_test_service( id=i, name="SERVICE_ID_%s" % i, host="controller_{0}".format(i)) - ids.append(six.text_type(service['id'])) - res = self.dbapi.get_service_list(self.context) - res_ids = [r.id for r in res] - self.assertEqual(ids.sort(), res_ids.sort()) + ids.append(service['id']) + services = self.dbapi.get_service_list(self.context) + service_ids = [s.id for s in services] + self.assertEqual(sorted(ids), sorted(service_ids)) def test_get_service_list_with_filters(self): service1 = self._create_test_service( diff --git a/watcher/tests/db/test_strategy.py b/watcher/tests/db/test_strategy.py index 6841c046f..081fa7973 100644 --- a/watcher/tests/db/test_strategy.py +++ b/watcher/tests/db/test_strategy.py @@ -245,17 +245,37 @@ class DbStrategyTestCase(base.DbTestCase): return strategy def test_get_strategy_list(self): - ids = [] - for i in range(1, 6): + uuids = [] + for i in range(1, 4): strategy = utils.create_test_strategy( id=i, uuid=w_utils.generate_uuid(), name="STRATEGY_ID_%s" % i, display_name='My Strategy {0}'.format(i)) - ids.append(six.text_type(strategy['uuid'])) - res = self.dbapi.get_strategy_list(self.context) - res_ids = [r.display_name for r in res] - self.assertEqual(ids.sort(), res_ids.sort()) + uuids.append(six.text_type(strategy['uuid'])) + strategies = self.dbapi.get_strategy_list(self.context) + strategy_uuids = [s.uuid for s in strategies] + self.assertEqual(sorted(uuids), sorted(strategy_uuids)) + for strategy in strategies: + self.assertIsNone(strategy.goal) + + def test_get_strategy_list_eager(self): + _goal = utils.get_test_goal() + goal = self.dbapi.create_goal(_goal) + uuids = [] + for i in range(1, 4): + strategy = utils.create_test_strategy( + id=i, + uuid=w_utils.generate_uuid(), + name="STRATEGY_ID_%s" % i, + display_name='My Strategy {0}'.format(i), + goal_id=goal.id) + uuids.append(six.text_type(strategy['uuid'])) + strategys = self.dbapi.get_strategy_list(self.context, eager=True) + strategy_map = {a.uuid: a for a in strategys} + self.assertEqual(sorted(uuids), sorted(strategy_map.keys())) + eager_strategy = strategy_map[strategy.uuid] + self.assertEqual(goal.as_dict(), eager_strategy.goal.as_dict()) def test_get_strategy_list_with_filters(self): strategy1 = self._create_test_strategy(