Query is better abstracted

This commit is contained in:
Rajaram Mallya 2011-11-17 12:21:09 +05:30
parent 124328b628
commit 23d9804321
4 changed files with 87 additions and 82 deletions

View File

@ -25,6 +25,59 @@ db_api = utils.import_object(config.Config.get("db_api_implementation",
"melange.db.sqlalchemy.api"))
class Query(object):
"""Mimics sqlalchemy query object.
This class allows us to store query conditions and use them with
bulk updates and deletes just like sqlalchemy query object.
Using this class makes the models independent of sqlalchemy
"""
def __init__(self, model, query_func, **conditions):
self._query_func = query_func
self._model = model
self._conditions = conditions
def all(self):
return db_api.list(self._query_func, self._model, **self._conditions)
def count(self):
return db_api.count(self._query_func, self._model, **self._conditions)
def __iter__(self):
return iter(self.all())
def update(self, **values):
db_api.update_all(self._query_func, self._model, self._conditions,
values)
def delete(self):
db_api.delete_all(self._query_func, self._model, **self._conditions)
def limit(self, limit=200, marker=None, marker_column=None):
return db_api.find_all_by_limit(self._query_func,
self._model,
self._conditions,
limit=limit,
marker=marker,
marker_column=marker_column)
def paginated_collection(self, limit=200, marker=None, marker_column=None):
collection = self.limit(int(limit) + 1, marker, marker_column)
if len(collection) > int(limit):
return (collection[0:-1], collection[-2]['id'])
return (collection, None)
class Queryable(object):
def __getattr__(self, item):
return lambda model, **conditions: Query(
model, query_func=getattr(db_api, item), **conditions)
db_query = Queryable()
def add_options(parser):
"""Adds any configuration options that the db layer might have.

View File

@ -19,7 +19,6 @@ import sqlalchemy.exc
from sqlalchemy import and_
from sqlalchemy import or_
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from melange import ipam
from melange.common import exception
@ -29,21 +28,22 @@ from melange.db.sqlalchemy import mappers
from melange.db.sqlalchemy import session
def list(query):
return query.all()
def list(query_func, *args, **kwargs):
return query_func(*args, **kwargs).all()
def count(query):
return query.count()
def count(query, *args, **kwargs):
return query(*args, **kwargs).count()
def find_all_by(model, **conditions):
def find_all(model, **conditions):
return _query_by(model, **conditions)
def find_all_by_limit(query_func, model, conditions, limit, marker=None,
marker_column=None):
return _limits(query_func, model, conditions, limit, marker, marker_column)
return _limits(query_func, model, conditions, limit, marker,
marker_column).all()
def find_by(model, **kwargs):

View File

@ -27,56 +27,12 @@ from melange.common import config
from melange.common import exception
from melange.common import utils
from melange.db import db_api
from melange.db import db_query
LOG = logging.getLogger('melange.ipam.models')
class Query(object):
"""Mimics sqlalchemy query object.
This class allows us to store query conditions and use them with
bulk updates and deletes just like sqlalchemy query object.
Using this class makes the models independent of sqlalchemy
"""
def __init__(self, model, query_func=None, **conditions):
self._query_func = query_func or db_api.find_all_by
self._model = model
self._conditions = conditions
def all(self):
return db_api.list(self._query_func(self._model, **self._conditions))
def count(self):
return db_api.count(self._query_func(self._model, **self._conditions))
def __iter__(self):
return iter(self.all())
def update(self, **values):
db_api.update_all(self._query_func, self._model, self._conditions,
values)
def delete(self):
db_api.delete_all(self._query_func, self._model, **self._conditions)
def limit(self, limit=200, marker=None, marker_column=None):
limit_query = db_api.find_all_by_limit(self._query_func,
self._model,
self._conditions,
limit=limit,
marker=marker,
marker_column=marker_column)
return db_api.list(limit_query)
def paginated_collection(self, limit=200, marker=None, marker_column=None):
collection = self.limit(int(limit) + 1, marker, marker_column)
if len(collection) > int(limit):
return (collection[0:-1], collection[-2]['id'])
return (collection, None)
class Converter(object):
data_type_converters = {
@ -193,11 +149,11 @@ class ModelBase(object):
@classmethod
def find_all(cls, **kwargs):
return Query(cls, **cls._process_conditions(kwargs))
return db_query.find_all(cls, **cls._process_conditions(kwargs))
@classmethod
def count(cls, **conditions):
return Query(cls, **conditions).count()
return cls.find_all(**conditions).count()
def merge_attributes(self, values):
"""dict.update() behaviour."""
@ -571,16 +527,13 @@ class IpAddress(ModelBase):
@classmethod
def find_all_by_network(cls, network_id, **conditions):
return Query(cls,
query_func=db_api.find_all_ips_in_network,
network_id=network_id,
**conditions)
return db_query.find_all_ips_in_network(cls,
network_id=network_id,
**conditions)
@classmethod
def find_all_allocated_ips(cls, **conditions):
return Query(cls,
query_func=db_api.find_all_allocated_ips,
**conditions)
return db_query.find_all_allocated_ips(cls, **conditions)
def delete(self):
if self._explicitly_allowed_on_interfaces():
@ -593,9 +546,8 @@ class IpAddress(ModelBase):
super(IpAddress, self).delete()
def _explicitly_allowed_on_interfaces(self):
return Query(IpAddress,
query_func=db_api.find_allowed_ips,
ip_address_id=self.id).count() > 0
return db_query.find_allowed_ips(IpAddress,
ip_address_id=self.id).count() > 0
def _before_save(self):
self.address = self._formatted(self.address)
@ -619,10 +571,9 @@ class IpAddress(ModelBase):
self.update(marked_for_deallocation=False, deallocated_at=None)
def inside_globals(self, **kwargs):
return Query(IpAddress,
query_func=db_api.find_inside_globals,
local_address_id=self.id,
**kwargs)
return db_query.find_inside_globals(IpAddress,
local_address_id=self.id,
**kwargs)
def add_inside_globals(self, ip_addresses):
db_api.save_nat_relationships([
@ -633,10 +584,9 @@ class IpAddress(ModelBase):
for global_address in ip_addresses])
def inside_locals(self, **kwargs):
return Query(IpAddress,
query_func=db_api.find_inside_locals,
global_address_id=self.id,
**kwargs)
return db_query.find_inside_locals(IpAddress,
global_address_id=self.id,
**kwargs)
def remove_inside_globals(self, inside_global_address=None):
return db_api.remove_inside_globals(self.id, inside_global_address)
@ -825,9 +775,8 @@ class Interface(ModelBase):
db_api.remove_allowed_ip(interface_id=self.id, ip_address_id=ip.id)
def ips_allowed(self):
explicitly_allowed = Query(IpAddress,
query_func=db_api.find_allowed_ips,
allowed_on_interface_id=self.id)
explicitly_allowed = db_query.find_allowed_ips(
IpAddress, allowed_on_interface_id=self.id)
allocated_ips = IpAddress.find_all_allocated_ips(interface_id=self.id)
return list(set(allocated_ips) | set(explicitly_allowed))

View File

@ -22,6 +22,7 @@ import netaddr
from melange import tests
from melange.common import exception
from melange.common import utils
from melange.db import db_query
from melange.ipam import models
from melange.tests import unit
from melange.tests.factories import models as factory_models
@ -91,7 +92,7 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999")
blocks = models.Query(models.IpBlock, network_id="1").all()
blocks = db_query.find_all(models.IpBlock, network_id="1").all()
self.assertModelsEqual(blocks, [block1, block2])
@ -100,7 +101,7 @@ class TestQuery(tests.BaseTest):
factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999")
count = models.Query(models.IpBlock, network_id="1").count()
count = db_query.find_all(models.IpBlock, network_id="1").count()
self.assertEqual(count, 2)
@ -109,7 +110,7 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999")
query = models.Query(models.IpBlock, network_id="1")
query = db_query.find_all(models.IpBlock, network_id="1")
blocks = [block for block in query]
self.assertModelsEqual(blocks, [block1, block2])
@ -123,8 +124,9 @@ class TestQuery(tests.BaseTest):
])
marker_block = blocks[1]
paginated_blocks = models.Query(models.IpBlock).limit(limit=2,
marker=marker_block.id)
all_blocks_query = db_query.find_all(models.IpBlock)
paginated_blocks = all_blocks_query.limit(limit=2,
marker=marker_block.id)
self.assertEqual(len(paginated_blocks), 2)
self.assertEqual(paginated_blocks, [blocks[2], blocks[3]])
@ -134,7 +136,8 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999")
models.Query(models.IpBlock, network_id="1").update(network_id="2")
db_query.find_all(models.IpBlock,
network_id="1").update(network_id="2")
self.assertEqual(models.IpBlock.find(block1.id).network_id, "2")
self.assertEqual(models.IpBlock.find(block2.id).network_id, "2")
@ -146,7 +149,7 @@ class TestQuery(tests.BaseTest):
block2 = factory_models.IpBlockFactory(network_id="1")
noise_block = factory_models.IpBlockFactory(network_id="999")
models.Query(models.IpBlock, network_id="1").delete()
db_query.find_all(models.IpBlock, network_id="1").delete()
self.assertIsNone(models.IpBlock.get(block1.id))
self.assertIsNone(models.IpBlock.get(block2.id))