diff --git a/melange/db/__init__.py b/melange/db/__init__.py index b1f68b7b..28596d86 100644 --- a/melange/db/__init__.py +++ b/melange/db/__init__.py @@ -25,6 +25,59 @@ db_api = utils.import_object(config.Config.get("db_api_implementation", "melange.db.sqlalchemy.api")) +class Query(object): + """Mimics sqlalchemy query object. + + This class allows us to store query conditions and use them with + bulk updates and deletes just like sqlalchemy query object. + Using this class makes the models independent of sqlalchemy + + """ + def __init__(self, model, query_func, **conditions): + self._query_func = query_func + self._model = model + self._conditions = conditions + + def all(self): + return db_api.list(self._query_func, self._model, **self._conditions) + + def count(self): + return db_api.count(self._query_func, self._model, **self._conditions) + + def __iter__(self): + return iter(self.all()) + + def update(self, **values): + db_api.update_all(self._query_func, self._model, self._conditions, + values) + + def delete(self): + db_api.delete_all(self._query_func, self._model, **self._conditions) + + def limit(self, limit=200, marker=None, marker_column=None): + return db_api.find_all_by_limit(self._query_func, + self._model, + self._conditions, + limit=limit, + marker=marker, + marker_column=marker_column) + + def paginated_collection(self, limit=200, marker=None, marker_column=None): + collection = self.limit(int(limit) + 1, marker, marker_column) + if len(collection) > int(limit): + return (collection[0:-1], collection[-2]['id']) + return (collection, None) + + +class Queryable(object): + + def __getattr__(self, item): + return lambda model, **conditions: Query( + model, query_func=getattr(db_api, item), **conditions) + +db_query = Queryable() + + def add_options(parser): """Adds any configuration options that the db layer might have. diff --git a/melange/db/sqlalchemy/api.py b/melange/db/sqlalchemy/api.py index 199204f9..fe723606 100644 --- a/melange/db/sqlalchemy/api.py +++ b/melange/db/sqlalchemy/api.py @@ -19,7 +19,6 @@ import sqlalchemy.exc from sqlalchemy import and_ from sqlalchemy import or_ from sqlalchemy.orm import aliased -from sqlalchemy.orm import joinedload from melange import ipam from melange.common import exception @@ -29,21 +28,22 @@ from melange.db.sqlalchemy import mappers from melange.db.sqlalchemy import session -def list(query): - return query.all() +def list(query_func, *args, **kwargs): + return query_func(*args, **kwargs).all() -def count(query): - return query.count() +def count(query, *args, **kwargs): + return query(*args, **kwargs).count() -def find_all_by(model, **conditions): +def find_all(model, **conditions): return _query_by(model, **conditions) def find_all_by_limit(query_func, model, conditions, limit, marker=None, marker_column=None): - return _limits(query_func, model, conditions, limit, marker, marker_column) + return _limits(query_func, model, conditions, limit, marker, + marker_column).all() def find_by(model, **kwargs): diff --git a/melange/ipam/models.py b/melange/ipam/models.py index 2bc2539e..33e5abd3 100644 --- a/melange/ipam/models.py +++ b/melange/ipam/models.py @@ -27,56 +27,12 @@ from melange.common import config from melange.common import exception from melange.common import utils from melange.db import db_api +from melange.db import db_query LOG = logging.getLogger('melange.ipam.models') -class Query(object): - """Mimics sqlalchemy query object. - - This class allows us to store query conditions and use them with - bulk updates and deletes just like sqlalchemy query object. - Using this class makes the models independent of sqlalchemy - - """ - def __init__(self, model, query_func=None, **conditions): - self._query_func = query_func or db_api.find_all_by - self._model = model - self._conditions = conditions - - def all(self): - return db_api.list(self._query_func(self._model, **self._conditions)) - - def count(self): - return db_api.count(self._query_func(self._model, **self._conditions)) - - def __iter__(self): - return iter(self.all()) - - def update(self, **values): - db_api.update_all(self._query_func, self._model, self._conditions, - values) - - def delete(self): - db_api.delete_all(self._query_func, self._model, **self._conditions) - - def limit(self, limit=200, marker=None, marker_column=None): - limit_query = db_api.find_all_by_limit(self._query_func, - self._model, - self._conditions, - limit=limit, - marker=marker, - marker_column=marker_column) - return db_api.list(limit_query) - - def paginated_collection(self, limit=200, marker=None, marker_column=None): - collection = self.limit(int(limit) + 1, marker, marker_column) - if len(collection) > int(limit): - return (collection[0:-1], collection[-2]['id']) - return (collection, None) - - class Converter(object): data_type_converters = { @@ -193,11 +149,11 @@ class ModelBase(object): @classmethod def find_all(cls, **kwargs): - return Query(cls, **cls._process_conditions(kwargs)) + return db_query.find_all(cls, **cls._process_conditions(kwargs)) @classmethod def count(cls, **conditions): - return Query(cls, **conditions).count() + return cls.find_all(**conditions).count() def merge_attributes(self, values): """dict.update() behaviour.""" @@ -571,16 +527,13 @@ class IpAddress(ModelBase): @classmethod def find_all_by_network(cls, network_id, **conditions): - return Query(cls, - query_func=db_api.find_all_ips_in_network, - network_id=network_id, - **conditions) + return db_query.find_all_ips_in_network(cls, + network_id=network_id, + **conditions) @classmethod def find_all_allocated_ips(cls, **conditions): - return Query(cls, - query_func=db_api.find_all_allocated_ips, - **conditions) + return db_query.find_all_allocated_ips(cls, **conditions) def delete(self): if self._explicitly_allowed_on_interfaces(): @@ -593,9 +546,8 @@ class IpAddress(ModelBase): super(IpAddress, self).delete() def _explicitly_allowed_on_interfaces(self): - return Query(IpAddress, - query_func=db_api.find_allowed_ips, - ip_address_id=self.id).count() > 0 + return db_query.find_allowed_ips(IpAddress, + ip_address_id=self.id).count() > 0 def _before_save(self): self.address = self._formatted(self.address) @@ -619,10 +571,9 @@ class IpAddress(ModelBase): self.update(marked_for_deallocation=False, deallocated_at=None) def inside_globals(self, **kwargs): - return Query(IpAddress, - query_func=db_api.find_inside_globals, - local_address_id=self.id, - **kwargs) + return db_query.find_inside_globals(IpAddress, + local_address_id=self.id, + **kwargs) def add_inside_globals(self, ip_addresses): db_api.save_nat_relationships([ @@ -633,10 +584,9 @@ class IpAddress(ModelBase): for global_address in ip_addresses]) def inside_locals(self, **kwargs): - return Query(IpAddress, - query_func=db_api.find_inside_locals, - global_address_id=self.id, - **kwargs) + return db_query.find_inside_locals(IpAddress, + global_address_id=self.id, + **kwargs) def remove_inside_globals(self, inside_global_address=None): return db_api.remove_inside_globals(self.id, inside_global_address) @@ -825,9 +775,8 @@ class Interface(ModelBase): db_api.remove_allowed_ip(interface_id=self.id, ip_address_id=ip.id) def ips_allowed(self): - explicitly_allowed = Query(IpAddress, - query_func=db_api.find_allowed_ips, - allowed_on_interface_id=self.id) + explicitly_allowed = db_query.find_allowed_ips( + IpAddress, allowed_on_interface_id=self.id) allocated_ips = IpAddress.find_all_allocated_ips(interface_id=self.id) return list(set(allocated_ips) | set(explicitly_allowed)) diff --git a/melange/tests/unit/test_ipam_models.py b/melange/tests/unit/test_ipam_models.py index 97ac4ee5..3c1e69a5 100644 --- a/melange/tests/unit/test_ipam_models.py +++ b/melange/tests/unit/test_ipam_models.py @@ -22,6 +22,7 @@ import netaddr from melange import tests from melange.common import exception from melange.common import utils +from melange.db import db_query from melange.ipam import models from melange.tests import unit from melange.tests.factories import models as factory_models @@ -91,7 +92,7 @@ class TestQuery(tests.BaseTest): block2 = factory_models.IpBlockFactory(network_id="1") noise_block = factory_models.IpBlockFactory(network_id="999") - blocks = models.Query(models.IpBlock, network_id="1").all() + blocks = db_query.find_all(models.IpBlock, network_id="1").all() self.assertModelsEqual(blocks, [block1, block2]) @@ -100,7 +101,7 @@ class TestQuery(tests.BaseTest): factory_models.IpBlockFactory(network_id="1") noise_block = factory_models.IpBlockFactory(network_id="999") - count = models.Query(models.IpBlock, network_id="1").count() + count = db_query.find_all(models.IpBlock, network_id="1").count() self.assertEqual(count, 2) @@ -109,7 +110,7 @@ class TestQuery(tests.BaseTest): block2 = factory_models.IpBlockFactory(network_id="1") noise_block = factory_models.IpBlockFactory(network_id="999") - query = models.Query(models.IpBlock, network_id="1") + query = db_query.find_all(models.IpBlock, network_id="1") blocks = [block for block in query] self.assertModelsEqual(blocks, [block1, block2]) @@ -123,8 +124,9 @@ class TestQuery(tests.BaseTest): ]) marker_block = blocks[1] - paginated_blocks = models.Query(models.IpBlock).limit(limit=2, - marker=marker_block.id) + all_blocks_query = db_query.find_all(models.IpBlock) + paginated_blocks = all_blocks_query.limit(limit=2, + marker=marker_block.id) self.assertEqual(len(paginated_blocks), 2) self.assertEqual(paginated_blocks, [blocks[2], blocks[3]]) @@ -134,7 +136,8 @@ class TestQuery(tests.BaseTest): block2 = factory_models.IpBlockFactory(network_id="1") noise_block = factory_models.IpBlockFactory(network_id="999") - models.Query(models.IpBlock, network_id="1").update(network_id="2") + db_query.find_all(models.IpBlock, + network_id="1").update(network_id="2") self.assertEqual(models.IpBlock.find(block1.id).network_id, "2") self.assertEqual(models.IpBlock.find(block2.id).network_id, "2") @@ -146,7 +149,7 @@ class TestQuery(tests.BaseTest): block2 = factory_models.IpBlockFactory(network_id="1") noise_block = factory_models.IpBlockFactory(network_id="999") - models.Query(models.IpBlock, network_id="1").delete() + db_query.find_all(models.IpBlock, network_id="1").delete() self.assertIsNone(models.IpBlock.get(block1.id)) self.assertIsNone(models.IpBlock.get(block2.id))