Add RPC object for traits

Adds a new traits object to expose traits DB operations to the API. It
also adds a new traits field into the node object, with the appropriate
version compatibility logic.

get_node_by_port_addresses is modified to ensure we correctly join in
the tags and traits in that DB call, this avoids a orphaned db object
lazy load style failure.

_set_from_db_object in the base object is modified such that the new
traits object doesn't have to include the dictionary style compatibility
mix-in.

Change-Id: I69403b9875a020fab7a7975810b57bf646417953
Partial-Bug: #1722194
Co-Authored-By: Mark Goddard <mark@stackhpc.com>
This commit is contained in:
John Garbutt 2018-01-09 16:48:57 +00:00 committed by Mark Goddard
parent 248296a4f5
commit 2cd7232f14
14 changed files with 453 additions and 22 deletions

View File

@ -124,11 +124,13 @@ RELEASE_MAPPING = {
'api': '1.36',
'rpc': '1.43',
'objects': {
'Node': ['1.22'],
'Node': ['1.23'],
'Conductor': ['1.2'],
'Chassis': ['1.3'],
'Port': ['1.7'],
'Portgroup': ['1.3'],
'Trait': ['1.0'],
'TraitList': ['1.0'],
'VolumeConnector': ['1.0'],
'VolumeTarget': ['1.0'],
}

View File

@ -341,6 +341,9 @@ class Connection(api.Connection):
node = models.Node()
node.update(values)
# Set tags & traits to [] for new created node
node['tags'] = []
node['traits'] = []
with _session_for_write() as session:
try:
session.add(node)
@ -353,9 +356,6 @@ class Connection(api.Connection):
instance_uuid=values['instance_uuid'],
node=values['uuid'])
raise exception.NodeAlreadyExists(uuid=values['uuid'])
# Set tags & traits to [] for new created node
node['tags'] = []
node['traits'] = []
return node
def get_node_by_id(self, node_id):
@ -463,7 +463,7 @@ class Connection(api.Connection):
@oslo_db_api.retry_on_deadlock
def _do_update_node(self, node_id, values):
with _session_for_write():
query = model_query(models.Node)
query = _get_node_query_with_all()
query = add_identity_filter(query, node_id)
try:
ref = query.with_lockmode('update').one()
@ -978,7 +978,8 @@ class Connection(api.Connection):
return model_query(q.exists()).scalar()
def get_node_by_port_addresses(self, addresses):
q = model_query(models.Node).distinct().join(models.Port)
q = _get_node_query_with_all()
q = q.distinct().join(models.Port)
q = q.filter(models.Port.address.in_(addresses))
try:

View File

@ -29,5 +29,6 @@ def register_all():
__import__('ironic.objects.node')
__import__('ironic.objects.port')
__import__('ironic.objects.portgroup')
__import__('ironic.objects.trait')
__import__('ironic.objects.volume_connector')
__import__('ironic.objects.volume_target')

View File

@ -220,7 +220,7 @@ class IronicObject(object_base.VersionedObject):
"""
fields = fields or self.fields
for field in fields:
self[field] = db_object[field]
setattr(self, field, db_object[field])
@staticmethod
def _from_db_object(context, obj, db_object, fields=None):

View File

@ -82,6 +82,10 @@ class ObjectField(object_fields.ObjectField):
pass
class ListOfObjectsField(object_fields.ListOfObjectsField):
pass
class FlexibleDict(object_fields.FieldType):
@staticmethod
def coerce(obj, attr, value):

View File

@ -21,6 +21,7 @@ from oslo_versionedobjects import base as object_base
from ironic.common import exception
from ironic.common.i18n import _
from ironic.db import api as db_api
from ironic import objects
from ironic.objects import base
from ironic.objects import fields as object_fields
from ironic.objects import notification
@ -57,7 +58,8 @@ class Node(base.IronicObject, object_base.VersionedObjectDictCompat):
# Version 1.20: Type of network_interface changed to just nullable string
# Version 1.21: Add storage_interface field
# Version 1.22: Add rescue_interface field
VERSION = '1.22'
# Version 1.23: Add traits field
VERSION = '1.23'
dbapi = db_api.get_instance()
@ -128,6 +130,8 @@ class Node(base.IronicObject, object_base.VersionedObjectDictCompat):
'rescue_interface': object_fields.StringField(nullable=True),
'storage_interface': object_fields.StringField(nullable=True),
'vendor_interface': object_fields.StringField(nullable=True),
'traits': object_fields.ObjectField('TraitList', nullable=True),
}
def _validate_property_values(self, properties):
@ -157,6 +161,14 @@ class Node(base.IronicObject, object_base.VersionedObjectDictCompat):
{'node': self.uuid, 'msgs': ', '.join(invalid_msgs_list)})
raise exception.InvalidParameterValue(msg)
def _set_from_db_object(self, context, db_object, fields=None):
fields = set(fields or self.fields) - {'traits'}
super(Node, self)._set_from_db_object(context, db_object, fields)
self.traits = object_base.obj_make_list(
context, objects.TraitList(context),
objects.Trait, db_object['traits'])
self.traits.obj_reset_changes()
# NOTE(xek): We don't want to enable RPC on this call just yet. Remotable
# methods can be used in the future to replace current explicit RPC calls.
# Implications of calling new remote procedures should be thought through.
@ -329,6 +341,7 @@ class Node(base.IronicObject, object_base.VersionedObjectDictCompat):
"""
values = self.do_version_changes_for_db()
self._validate_property_values(values.get('properties'))
self._validate_and_remove_traits(values)
db_node = self.dbapi.create_node(values)
self._from_db_object(self._context, self, db_node)
@ -375,9 +388,30 @@ class Node(base.IronicObject, object_base.VersionedObjectDictCompat):
# Clean driver_internal_info when changes driver
self.driver_internal_info = {}
updates = self.do_version_changes_for_db()
self._validate_and_remove_traits(updates)
db_node = self.dbapi.update_node(self.uuid, updates)
self._from_db_object(self._context, self, db_node)
@staticmethod
def _validate_and_remove_traits(fields):
"""Validate traits in fields for create or update, remove if present.
:param fields: a dict of Node fields for create or update.
:raises: BadRequest if fields contains a traits that are not None.
"""
if 'traits' in fields:
# NOTE(mgoddard): Traits should be updated via the the node
# object's traits field, which is itself an object. We shouldn't
# get here with changes to traits, as this should be handled by the
# API. When services are pinned to Pike, we can get here with
# traits set to None in updates, due to changes made to the object
# in _convert_to_version.
if fields['traits']:
# NOTE(mgoddard): We shouldn't get here as this should be
# handled by the API.
raise exception.BadRequest()
fields.pop('traits')
# NOTE(xek): We don't want to enable RPC on this call just yet. Remotable
# methods can be used in the future to replace current explicit RPC calls.
# Implications of calling new remote procedures should be thought through.
@ -429,6 +463,9 @@ class Node(base.IronicObject, object_base.VersionedObjectDictCompat):
Version 1.22: rescue_interface field was added. Its default value is
None. For versions prior to this, it should be set to None (or
removed).
Version 1.23: traits field was added. Its default value is
None. For versions prior to this, it should be set to None (or
removed).
:param target_version: the desired version of the object
:param remove_unavailable_fields: True to remove fields that are
@ -453,6 +490,17 @@ class Node(base.IronicObject, object_base.VersionedObjectDictCompat):
# DB: set unavailable field to the default of None.
self.rescue_interface = None
traits_is_set = self.obj_attr_is_set('traits')
if target_version >= (1, 23):
# Target version supports traits.
if not traits_is_set:
self.traits = None
elif traits_is_set:
if remove_unavailable_fields:
delattr(self, 'traits')
elif self.traits is not None:
self.traits = None
@base.IronicObjectRegistry.register
class NodePayload(notification.NotificationPayloadBase):
@ -504,6 +552,8 @@ class NodePayload(notification.NotificationPayloadBase):
# field to payload and increment the object versions for all objects
# that inherit the NodePayload object.
# TODO(mgoddard): Add a traits field to the NodePayload object.
# Version 1.0: Initial version, based off of Node version 1.18.
# Version 1.1: Type of network_interface changed to just nullable string
# similar to version 1.20 of Node.

175
ironic/objects/trait.py Normal file
View File

@ -0,0 +1,175 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_versionedobjects import base as object_base
from ironic.db import api as db_api
from ironic.objects import base
from ironic.objects import fields as object_fields
@base.IronicObjectRegistry.register
class Trait(base.IronicObject):
# Version 1.0: Initial version
VERSION = '1.0'
dbapi = db_api.get_instance()
fields = {
'node_id': object_fields.StringField(),
'trait': object_fields.StringField(),
}
# NOTE(mgoddard): We don't want to enable RPC on this call just yet.
# Remotable methods can be used in the future to replace current explicit
# RPC calls. Implications of calling new remote procedures should be
# thought through.
# @object_base.remotable
def create(self, context=None):
"""Create a Trait record in the DB.
:param context: security context. NOTE: This should only
be used internally by the indirection_api.
Unfortunately, RPC requires context as the first
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Trait(context).
:raises: InvalidParameterValue if adding the trait would exceed the
per-node traits limit.
:raises: NodeNotFound if the node no longer appears in the database.
"""
values = self.do_version_changes_for_db()
db_trait = self.dbapi.add_node_trait(
values['node_id'], values['trait'], values['version'])
self._from_db_object(self._context, self, db_trait)
# NOTE(mgoddard): We don't want to enable RPC on this call just yet.
# Remotable methods can be used in the future to replace current explicit
# RPC calls. Implications of calling new remote procedures should be
# thought through.
# @object_base.remotable_classmethod
@classmethod
def destroy(cls, context, node_id, trait):
"""Delete the Trait from the DB.
:param context: security context. NOTE: This should only
be used internally by the indirection_api.
Unfortunately, RPC requires context as the first
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Trait(context).
:param node_id: The id of a node.
:param trait: A trait string.
:raises: NodeNotFound if the node no longer appears in the database.
:raises: NodeTraitNotFound if the trait is not found.
"""
cls.dbapi.delete_node_trait(node_id, trait)
# NOTE(mgoddard): We don't want to enable RPC on this call just yet.
# Remotable methods can be used in the future to replace current explicit
# RPC calls. Implications of calling new remote procedures should be
# thought through.
# @object_base.remotable_classmethod
@classmethod
def exists(cls, context, node_id, trait):
"""Check whether a Trait exists in the DB.
:param context: security context. NOTE: This should only
be used internally by the indirection_api.
Unfortunately, RPC requires context as the first
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Trait(context).
:param node_id: The id of a node.
:param trait: A trait string.
:returns: True if the trait exists otherwise False.
:raises: NodeNotFound if the node no longer appears in the database.
"""
return cls.dbapi.node_trait_exists(node_id, trait)
@base.IronicObjectRegistry.register
class TraitList(object_base.ObjectListBase, base.IronicObject):
# Version 1.0: Initial version
VERSION = '1.0'
dbapi = db_api.get_instance()
fields = {
'objects': object_fields.ListOfObjectsField('Trait'),
}
# NOTE(mgoddard): We don't want to enable RPC on this call just yet.
# Remotable methods can be used in the future to replace current explicit
# RPC calls. Implications of calling new remote procedures should be
# thought through.
# @object_base.remotable_classmethod
@classmethod
def get_by_node_id(cls, context, node_id):
"""Return all traits for the specified node.
:param context: security context. NOTE: This should only
be used internally by the indirection_api.
Unfortunately, RPC requires context as the first
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Trait(context).
:param node_id: The id of a node.
:raises: NodeNotFound if the node no longer appears in the database.
"""
db_traits = cls.dbapi.get_node_traits_by_node_id(node_id)
return object_base.obj_make_list(context, cls(), Trait, db_traits)
# NOTE(mgoddard): We don't want to enable RPC on this call just yet.
# Remotable methods can be used in the future to replace current explicit
# RPC calls. Implications of calling new remote procedures should be
# thought through.
# @object_base.remotable_classmethod
@classmethod
def create(cls, context, node_id, traits):
"""Replace all existing traits with the specified list.
:param context: security context. NOTE: This should only
be used internally by the indirection_api.
Unfortunately, RPC requires context as the first
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Trait(context).
:param node_id: The id of a node.
:param traits: List of Strings; traits to set.
:raises: InvalidParameterValue if adding the trait would exceed the
per-node traits limit.
:raises: NodeNotFound if the node no longer appears in the database.
"""
version = Trait.get_target_version()
db_traits = cls.dbapi.set_node_traits(node_id, traits, version)
return object_base.obj_make_list(context, cls(), Trait, db_traits)
# NOTE(mgoddard): We don't want to enable RPC on this call just yet.
# Remotable methods can be used in the future to replace current explicit
# RPC calls. Implications of calling new remote procedures should be
# thought through.
# @object_base.remotable_classmethod
@classmethod
def destroy(cls, context, node_id):
"""Delete all traits for the specified node.
:param context: security context. NOTE: This should only
be used internally by the indirection_api.
Unfortunately, RPC requires context as the first
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: Trait(context).
:param node_id: The id of a node.
:raises: NodeNotFound if the node no longer appears in the database.
"""
cls.dbapi.unset_node_traits(node_id)

View File

@ -97,6 +97,8 @@ class ReleaseMappingsTestCase(base.TestCase):
# releases or are sent through RPC should have their counterpart
# versioned objects.
model_names -= exceptions
# NodeTrait maps to two objects
model_names |= set(['Trait', 'TraitList'])
object_names = set(
release_mappings.RELEASE_MAPPING['master']['objects'])
self.assertEqual(model_names, object_names)

View File

@ -669,6 +669,7 @@ class DbNodeTestCase(base.DbTestCase):
res = self.dbapi.get_node_by_port_addresses(addresses)
self.assertEqual(node.uuid, res.uuid)
self.assertEqual([], res.traits)
def test_get_node_by_port_addresses_not_found(self):
node = utils.create_test_node(

View File

@ -25,6 +25,7 @@ from ironic.objects import conductor
from ironic.objects import node
from ironic.objects import port
from ironic.objects import portgroup
from ironic.objects import trait
from ironic.objects import volume_connector
from ironic.objects import volume_target
@ -508,9 +509,7 @@ def get_test_xclarity_driver_info():
def get_test_node_trait(**kw):
return {
# TODO(mgoddard): Replace None below with the NodeTrait RPC object
# VERSION when the RPC object is added.
'version': kw.get('version', None),
'version': kw.get('version', trait.Trait.VERSION),
"trait": kw.get("trait", "trait1"),
"node_id": kw.get("node_id", "123"),
'created_at': kw.get('created_at'),

View File

@ -131,6 +131,20 @@ class TestNodeObject(db_base.DbTestCase, obj_utils.SchemasTestMixIn):
res_updated_at = n.updated_at.replace(tzinfo=None)
self.assertEqual(test_time, res_updated_at)
def test_save_with_traits(self):
uuid = self.fake_node['uuid']
with mock.patch.object(self.dbapi, 'get_node_by_uuid',
autospec=True) as mock_get_node:
mock_get_node.return_value = self.fake_node
with mock.patch.object(self.dbapi, 'update_node',
autospec=True) as mock_update_node:
n = objects.Node.get(self.context, uuid)
trait = objects.Trait(self.context, node_id=n.id,
trait='CUSTOM_1')
n.traits = objects.TraitList(self.context, objects=[trait])
self.assertRaises(exception.BadRequest, n.save)
self.assertFalse(mock_update_node.called)
def test_refresh(self):
uuid = self.fake_node['uuid']
returns = [dict(self.fake_node, properties={"fake": "first"}),
@ -215,7 +229,7 @@ class TestNodeObject(db_base.DbTestCase, obj_utils.SchemasTestMixIn):
mock_touch.assert_called_once_with(node.id)
def test_create(self):
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
with mock.patch.object(self.dbapi, 'create_node',
autospec=True) as mock_create_node:
mock_create_node.return_value = db_utils.get_test_node()
@ -224,12 +238,19 @@ class TestNodeObject(db_base.DbTestCase, obj_utils.SchemasTestMixIn):
args, _kwargs = mock_create_node.call_args
self.assertEqual(objects.Node.VERSION, args[0]['version'])
self.assertEqual(1, mock_create_node.call_count)
def test_create_with_invalid_properties(self):
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
node.properties = {"local_gb": "5G"}
self.assertRaises(exception.InvalidParameterValue, node.create)
def test_create_with_traits(self):
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
trait = objects.Trait(self.context, node_id=node.id, trait='CUSTOM_1')
node.traits = objects.TraitList(self.context, objects=[trait])
self.assertRaises(exception.BadRequest, node.create)
def test_update_with_invalid_properties(self):
uuid = self.fake_node['uuid']
with mock.patch.object(self.dbapi, 'get_node_by_uuid',
@ -271,7 +292,7 @@ class TestConvertToVersion(db_base.DbTestCase):
def test_rescue_supported_missing(self):
# rescue_interface not set, should be set to default.
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
delattr(node, 'rescue_interface')
node.obj_reset_changes()
@ -283,7 +304,7 @@ class TestConvertToVersion(db_base.DbTestCase):
def test_rescue_supported_set(self):
# rescue_interface set, no change required.
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
node.rescue_interface = 'fake'
node.obj_reset_changes()
@ -293,7 +314,7 @@ class TestConvertToVersion(db_base.DbTestCase):
def test_rescue_unsupported_missing(self):
# rescue_interface not set, no change required.
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
delattr(node, 'rescue_interface')
node.obj_reset_changes()
@ -303,7 +324,7 @@ class TestConvertToVersion(db_base.DbTestCase):
def test_rescue_unsupported_set_remove(self):
# rescue_interface set, should be removed.
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
node.rescue_interface = 'fake'
node.obj_reset_changes()
@ -313,20 +334,94 @@ class TestConvertToVersion(db_base.DbTestCase):
def test_rescue_unsupported_set_no_remove_non_default(self):
# rescue_interface set, should be set to default.
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
node.rescue_interface = 'fake'
node.obj_reset_changes()
node._convert_to_version("1.21", False)
self.assertIsNone(node.rescue_interface)
self.assertEqual({'rescue_interface': None}, node.obj_get_changes())
self.assertEqual({'rescue_interface': None, 'traits': None},
node.obj_get_changes())
def test_rescue_unsupported_set_no_remove_default(self):
# rescue_interface set, no change required.
node = objects.Node(self.context, **self.fake_node)
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
node.rescue_interface = None
node.traits = None
node.obj_reset_changes()
node._convert_to_version("1.21", False)
self.assertIsNone(node.rescue_interface)
self.assertEqual({}, node.obj_get_changes())
def test_traits_supported_missing(self):
# traits not set, should be set to default.
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
delattr(node, 'traits')
node.obj_reset_changes()
node._convert_to_version("1.23")
self.assertIsNone(node.traits)
self.assertEqual({'traits': None},
node.obj_get_changes())
def test_traits_supported_set(self):
# traits set, no change required.
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
traits = objects.TraitList(
objects=[objects.Trait('CUSTOM_TRAIT')])
traits.obj_reset_changes()
node.traits = traits
node.obj_reset_changes()
node._convert_to_version("1.23")
self.assertEqual(traits, node.traits)
self.assertEqual({}, node.obj_get_changes())
def test_traits_unsupported_missing_remove(self):
# traits not set, no change required.
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
delattr(node, 'traits')
node.obj_reset_changes()
node._convert_to_version("1.22")
self.assertNotIn('traits', node)
self.assertEqual({}, node.obj_get_changes())
def test_traits_unsupported_missing(self):
# traits not set, should be set to default.
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
delattr(node, 'traits')
node.obj_reset_changes()
node._convert_to_version("1.22", False)
self.assertNotIn('traits', node)
self.assertEqual({}, node.obj_get_changes())
def test_trait_unsupported_set_no_remove_non_default(self):
# traits set, should be set to default.
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
node.traits = objects.TraitList(self.ctxt)
node.traits.obj_reset_changes()
node.obj_reset_changes()
node._convert_to_version("1.22", False)
self.assertIsNone(node.traits)
self.assertEqual({'traits': None},
node.obj_get_changes())
def test_trait_unsupported_set_no_remove_default(self):
# traits set, no change required.
node = obj_utils.get_test_node(self.ctxt, **self.fake_node)
node.traits = None
node.obj_reset_changes()
node._convert_to_version("1.22", False)
self.assertIsNone(node.traits)
self.assertEqual({}, node.obj_get_changes())

View File

@ -684,7 +684,7 @@ class TestObject(_LocalTest, _TestObject):
# version bump. It is an MD5 hash of the object fields and remotable methods.
# The fingerprint values should only be changed if there is a version bump.
expected_object_fingerprints = {
'Node': '1.22-f2c453dd0b42aec8d4833a69a9ac79f3',
'Node': '1.23-6bebf8dbcd2ce15407c946bd091f80b4',
'MyObj': '1.5-9459d30d6954bffc7a9afd347a807ca6',
'Chassis': '1.3-d656e039fd8ae9f34efc232ab3980905',
'Port': '1.7-898a47921f4a1f53fcdddd4eeb179e0b',
@ -717,6 +717,8 @@ expected_object_fingerprints = {
'VolumeConnectorCRUDPayload': '1.0-5e8dbb41e05b6149d8f7bfd4daff9339',
'VolumeTargetCRUDNotification': '1.0-59acc533c11d306f149846f922739c15',
'VolumeTargetCRUDPayload': '1.0-30dcc4735512c104a3a36a2ae1e2aeb2',
'Trait': '1.0-3f26cb70c8a10a3807d64c219453e347',
'TraitList': '1.0-33a2e1bb91ad4082f9f63429b77c1244',
}

View File

@ -0,0 +1,89 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from ironic.common import context
from ironic.db import api as dbapi
from ironic import objects
from ironic.tests.unit.db import base as db_base
from ironic.tests.unit.db import utils as db_utils
from ironic.tests.unit.objects import utils as obj_utils
class TestTraitObject(db_base.DbTestCase, obj_utils.SchemasTestMixIn):
def setUp(self):
super(TestTraitObject, self).setUp()
self.ctxt = context.get_admin_context()
self.fake_trait = db_utils.get_test_node_trait()
self.node_id = self.fake_trait['node_id']
@mock.patch.object(dbapi.IMPL, 'get_node_traits_by_node_id', autospec=True)
def test_get_by_id(self, mock_get_traits):
mock_get_traits.return_value = [self.fake_trait]
traits = objects.TraitList.get_by_node_id(self.context, self.node_id)
mock_get_traits.assert_called_once_with(self.node_id)
self.assertEqual(self.context, traits._context)
self.assertEqual(1, len(traits))
self.assertEqual(self.fake_trait['trait'], traits[0].trait)
self.assertEqual(self.fake_trait['node_id'], traits[0].node_id)
@mock.patch.object(dbapi.IMPL, 'set_node_traits', autospec=True)
def test_create_list(self, mock_set_traits):
traits = [self.fake_trait['trait']]
mock_set_traits.return_value = [self.fake_trait, self.fake_trait]
result = objects.TraitList.create(self.context, self.node_id, traits)
mock_set_traits.assert_called_once_with(self.node_id, traits, '1.0')
self.assertEqual(self.context, result._context)
self.assertEqual(2, len(result))
self.assertEqual(self.fake_trait['node_id'], result[0].node_id)
@mock.patch.object(dbapi.IMPL, 'unset_node_traits', autospec=True)
def test_destroy_list(self, mock_unset_traits):
objects.TraitList.destroy(self.context, self.node_id)
mock_unset_traits.assert_called_once_with(self.node_id)
@mock.patch.object(dbapi.IMPL, 'add_node_trait', autospec=True)
def test_create(self, mock_add_trait):
trait = objects.Trait(context=self.context, node_id=self.node_id,
trait="fake")
mock_add_trait.return_value = self.fake_trait
trait.create()
mock_add_trait.assert_called_once_with(self.node_id, 'fake', '1.0')
self.assertEqual(self.fake_trait['trait'], trait.trait)
self.assertEqual(self.fake_trait['node_id'], trait.node_id)
@mock.patch.object(dbapi.IMPL, 'delete_node_trait', autospec=True)
def test_destroy(self, mock_delete_trait):
objects.Trait.destroy(self.context, self.node_id, "trait")
mock_delete_trait.assert_called_once_with(self.node_id, "trait")
@mock.patch.object(dbapi.IMPL, 'node_trait_exists', autospec=True)
def test_exists(self, mock_trait_exists):
mock_trait_exists.return_value = True
result = objects.Trait.exists(self.context, self.node_id, "trait")
self.assertTrue(result)
mock_trait_exists.assert_called_once_with(self.node_id, "trait")

View File

@ -58,7 +58,17 @@ def get_test_node(ctxt, **kw):
del db_node['id']
node = objects.Node(ctxt)
for key in db_node:
setattr(node, key, db_node[key])
if key == 'traits':
# convert list of strings to object
raw_traits = db_node['traits']
trait_list = []
for raw_trait in raw_traits:
trait = objects.Trait(ctxt, trait=raw_trait)
trait_list.append(trait)
node.traits = objects.TraitList(ctxt, objects=trait_list)
node.traits.obj_reset_changes()
else:
setattr(node, key, db_node[key])
return node