Query is better abstracted

This commit is contained in:
Rajaram Mallya 2011-11-17 12:21:09 +05:30
parent 124328b628
commit 23d9804321
4 changed files with 87 additions and 82 deletions

View File

@ -25,6 +25,59 @@ db_api = utils.import_object(config.Config.get("db_api_implementation",
"melange.db.sqlalchemy.api")) "melange.db.sqlalchemy.api"))
class Query(object):
"""Mimics sqlalchemy query object.
This class allows us to store query conditions and use them with
bulk updates and deletes just like sqlalchemy query object.
Using this class makes the models independent of sqlalchemy
"""
def __init__(self, model, query_func, **conditions):
self._query_func = query_func
self._model = model
self._conditions = conditions
def all(self):
return db_api.list(self._query_func, self._model, **self._conditions)
def count(self):
return db_api.count(self._query_func, self._model, **self._conditions)
def __iter__(self):
return iter(self.all())
def update(self, **values):
db_api.update_all(self._query_func, self._model, self._conditions,
values)
def delete(self):
db_api.delete_all(self._query_func, self._model, **self._conditions)
def limit(self, limit=200, marker=None, marker_column=None):
return db_api.find_all_by_limit(self._query_func,
self._model,
self._conditions,
limit=limit,
marker=marker,
marker_column=marker_column)
def paginated_collection(self, limit=200, marker=None, marker_column=None):
collection = self.limit(int(limit) + 1, marker, marker_column)
if len(collection) > int(limit):
return (collection[0:-1], collection[-2]['id'])
return (collection, None)
class Queryable(object):
def __getattr__(self, item):
return lambda model, **conditions: Query(
model, query_func=getattr(db_api, item), **conditions)
db_query = Queryable()
def add_options(parser): def add_options(parser):
"""Adds any configuration options that the db layer might have. """Adds any configuration options that the db layer might have.

View File

@ -19,7 +19,6 @@ import sqlalchemy.exc
from sqlalchemy import and_ from sqlalchemy import and_
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from melange import ipam from melange import ipam
from melange.common import exception from melange.common import exception
@ -29,21 +28,22 @@ from melange.db.sqlalchemy import mappers
from melange.db.sqlalchemy import session from melange.db.sqlalchemy import session
def list(query): def list(query_func, *args, **kwargs):
return query.all() return query_func(*args, **kwargs).all()
def count(query): def count(query, *args, **kwargs):
return query.count() return query(*args, **kwargs).count()
def find_all_by(model, **conditions): def find_all(model, **conditions):
return _query_by(model, **conditions) return _query_by(model, **conditions)
def find_all_by_limit(query_func, model, conditions, limit, marker=None, def find_all_by_limit(query_func, model, conditions, limit, marker=None,
marker_column=None): marker_column=None):
return _limits(query_func, model, conditions, limit, marker, marker_column) return _limits(query_func, model, conditions, limit, marker,
marker_column).all()
def find_by(model, **kwargs): def find_by(model, **kwargs):

View File

@ -27,56 +27,12 @@ from melange.common import config
from melange.common import exception from melange.common import exception
from melange.common import utils from melange.common import utils
from melange.db import db_api from melange.db import db_api
from melange.db import db_query
LOG = logging.getLogger('melange.ipam.models') LOG = logging.getLogger('melange.ipam.models')
class Query(object):
"""Mimics sqlalchemy query object.
This class allows us to store query conditions and use them with
bulk updates and deletes just like sqlalchemy query object.
Using this class makes the models independent of sqlalchemy
"""
def __init__(self, model, query_func=None, **conditions):
self._query_func = query_func or db_api.find_all_by
self._model = model
self._conditions = conditions
def all(self):
return db_api.list(self._query_func(self._model, **self._conditions))
def count(self):
return db_api.count(self._query_func(self._model, **self._conditions))
def __iter__(self):
return iter(self.all())
def update(self, **values):
db_api.update_all(self._query_func, self._model, self._conditions,
values)
def delete(self):
db_api.delete_all(self._query_func, self._model, **self._conditions)
def limit(self, limit=200, marker=None, marker_column=None):
limit_query = db_api.find_all_by_limit(self._query_func,
self._model,
self._conditions,
limit=limit,
marker=marker,
marker_column=marker_column)
return db_api.list(limit_query)
def paginated_collection(self, limit=200, marker=None, marker_column=None):
collection = self.limit(int(limit) + 1, marker, marker_column)
if len(collection) > int(limit):
return (collection[0:-1], collection[-2]['id'])
return (collection, None)
class Converter(object): class Converter(object):
data_type_converters = { data_type_converters = {
@ -193,11 +149,11 @@ class ModelBase(object):
@classmethod @classmethod
def find_all(cls, **kwargs): def find_all(cls, **kwargs):
return Query(cls, **cls._process_conditions(kwargs)) return db_query.find_all(cls, **cls._process_conditions(kwargs))
@classmethod @classmethod
def count(cls, **conditions): def count(cls, **conditions):
return Query(cls, **conditions).count() return cls.find_all(**conditions).count()
def merge_attributes(self, values): def merge_attributes(self, values):
"""dict.update() behaviour.""" """dict.update() behaviour."""
@ -571,16 +527,13 @@ class IpAddress(ModelBase):
@classmethod @classmethod
def find_all_by_network(cls, network_id, **conditions): def find_all_by_network(cls, network_id, **conditions):
return Query(cls, return db_query.find_all_ips_in_network(cls,
query_func=db_api.find_all_ips_in_network, network_id=network_id,
network_id=network_id, **conditions)
**conditions)
@classmethod @classmethod
def find_all_allocated_ips(cls, **conditions): def find_all_allocated_ips(cls, **conditions):
return Query(cls, return db_query.find_all_allocated_ips(cls, **conditions)
query_func=db_api.find_all_allocated_ips,
**conditions)
def delete(self): def delete(self):
if self._explicitly_allowed_on_interfaces(): if self._explicitly_allowed_on_interfaces():
@ -593,9 +546,8 @@ class IpAddress(ModelBase):
super(IpAddress, self).delete() super(IpAddress, self).delete()
def _explicitly_allowed_on_interfaces(self): def _explicitly_allowed_on_interfaces(self):
return Query(IpAddress, return db_query.find_allowed_ips(IpAddress,
query_func=db_api.find_allowed_ips, ip_address_id=self.id).count() > 0
ip_address_id=self.id).count() > 0
def _before_save(self): def _before_save(self):
self.address = self._formatted(self.address) self.address = self._formatted(self.address)
@ -619,10 +571,9 @@ class IpAddress(ModelBase):
self.update(marked_for_deallocation=False, deallocated_at=None) self.update(marked_for_deallocation=False, deallocated_at=None)
def inside_globals(self, **kwargs): def inside_globals(self, **kwargs):
return Query(IpAddress, return db_query.find_inside_globals(IpAddress,
query_func=db_api.find_inside_globals, local_address_id=self.id,
local_address_id=self.id, **kwargs)
**kwargs)
def add_inside_globals(self, ip_addresses): def add_inside_globals(self, ip_addresses):
db_api.save_nat_relationships([ db_api.save_nat_relationships([
@ -633,10 +584,9 @@ class IpAddress(ModelBase):
for global_address in ip_addresses]) for global_address in ip_addresses])
def inside_locals(self, **kwargs): def inside_locals(self, **kwargs):
return Query(IpAddress, return db_query.find_inside_locals(IpAddress,
query_func=db_api.find_inside_locals, global_address_id=self.id,
global_address_id=self.id, **kwargs)
**kwargs)
def remove_inside_globals(self, inside_global_address=None): def remove_inside_globals(self, inside_global_address=None):
return db_api.remove_inside_globals(self.id, inside_global_address) return db_api.remove_inside_globals(self.id, inside_global_address)
@ -825,9 +775,8 @@ class Interface(ModelBase):
db_api.remove_allowed_ip(interface_id=self.id, ip_address_id=ip.id) db_api.remove_allowed_ip(interface_id=self.id, ip_address_id=ip.id)
def ips_allowed(self): def ips_allowed(self):
explicitly_allowed = Query(IpAddress, explicitly_allowed = db_query.find_allowed_ips(
query_func=db_api.find_allowed_ips, IpAddress, allowed_on_interface_id=self.id)
allowed_on_interface_id=self.id)
allocated_ips = IpAddress.find_all_allocated_ips(interface_id=self.id) allocated_ips = IpAddress.find_all_allocated_ips(interface_id=self.id)
return list(set(allocated_ips) | set(explicitly_allowed)) return list(set(allocated_ips) | set(explicitly_allowed))

View File

@ -22,6 +22,7 @@ import netaddr
from melange import tests from melange import tests
from melange.common import exception from melange.common import exception
from melange.common import utils from melange.common import utils
from melange.db import db_query
from melange.ipam import models from melange.ipam import models
from melange.tests import unit from melange.tests import unit
from melange.tests.factories import models as factory_models from melange.tests.factories import models as factory_models
@ -91,7 +92,7 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1") block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999") noise_block = factory_models.IpBlockFactory(network_id="999")
blocks = models.Query(models.IpBlock, network_id="1").all() blocks = db_query.find_all(models.IpBlock, network_id="1").all()
self.assertModelsEqual(blocks, [block1, block2]) self.assertModelsEqual(blocks, [block1, block2])
@ -100,7 +101,7 @@ class TestQuery(tests.BaseTest):
factory_models.IpBlockFactory(network_id="1") factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999") noise_block = factory_models.IpBlockFactory(network_id="999")
count = models.Query(models.IpBlock, network_id="1").count() count = db_query.find_all(models.IpBlock, network_id="1").count()
self.assertEqual(count, 2) self.assertEqual(count, 2)
@ -109,7 +110,7 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1") block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999") noise_block = factory_models.IpBlockFactory(network_id="999")
query = models.Query(models.IpBlock, network_id="1") query = db_query.find_all(models.IpBlock, network_id="1")
blocks = [block for block in query] blocks = [block for block in query]
self.assertModelsEqual(blocks, [block1, block2]) self.assertModelsEqual(blocks, [block1, block2])
@ -123,8 +124,9 @@ class TestQuery(tests.BaseTest):
]) ])
marker_block = blocks[1] marker_block = blocks[1]
paginated_blocks = models.Query(models.IpBlock).limit(limit=2, all_blocks_query = db_query.find_all(models.IpBlock)
marker=marker_block.id) paginated_blocks = all_blocks_query.limit(limit=2,
marker=marker_block.id)
self.assertEqual(len(paginated_blocks), 2) self.assertEqual(len(paginated_blocks), 2)
self.assertEqual(paginated_blocks, [blocks[2], blocks[3]]) self.assertEqual(paginated_blocks, [blocks[2], blocks[3]])
@ -134,7 +136,8 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1") block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999") noise_block = factory_models.IpBlockFactory(network_id="999")
models.Query(models.IpBlock, network_id="1").update(network_id="2") db_query.find_all(models.IpBlock,
network_id="1").update(network_id="2")
self.assertEqual(models.IpBlock.find(block1.id).network_id, "2") self.assertEqual(models.IpBlock.find(block1.id).network_id, "2")
self.assertEqual(models.IpBlock.find(block2.id).network_id, "2") self.assertEqual(models.IpBlock.find(block2.id).network_id, "2")
@ -146,7 +149,7 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1") block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999") noise_block = factory_models.IpBlockFactory(network_id="999")
models.Query(models.IpBlock, network_id="1").delete() db_query.find_all(models.IpBlock, network_id="1").delete()
self.assertIsNone(models.IpBlock.get(block1.id)) self.assertIsNone(models.IpBlock.get(block1.id))
self.assertIsNone(models.IpBlock.get(block2.id)) self.assertIsNone(models.IpBlock.get(block2.id))