Query is better abstracted
This commit is contained in:
parent
124328b628
commit
23d9804321
@ -25,6 +25,59 @@ db_api = utils.import_object(config.Config.get("db_api_implementation",
|
||||
"melange.db.sqlalchemy.api"))
|
||||
|
||||
|
||||
class Query(object):
|
||||
"""Mimics sqlalchemy query object.
|
||||
|
||||
This class allows us to store query conditions and use them with
|
||||
bulk updates and deletes just like sqlalchemy query object.
|
||||
Using this class makes the models independent of sqlalchemy
|
||||
|
||||
"""
|
||||
def __init__(self, model, query_func, **conditions):
|
||||
self._query_func = query_func
|
||||
self._model = model
|
||||
self._conditions = conditions
|
||||
|
||||
def all(self):
|
||||
return db_api.list(self._query_func, self._model, **self._conditions)
|
||||
|
||||
def count(self):
|
||||
return db_api.count(self._query_func, self._model, **self._conditions)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.all())
|
||||
|
||||
def update(self, **values):
|
||||
db_api.update_all(self._query_func, self._model, self._conditions,
|
||||
values)
|
||||
|
||||
def delete(self):
|
||||
db_api.delete_all(self._query_func, self._model, **self._conditions)
|
||||
|
||||
def limit(self, limit=200, marker=None, marker_column=None):
|
||||
return db_api.find_all_by_limit(self._query_func,
|
||||
self._model,
|
||||
self._conditions,
|
||||
limit=limit,
|
||||
marker=marker,
|
||||
marker_column=marker_column)
|
||||
|
||||
def paginated_collection(self, limit=200, marker=None, marker_column=None):
|
||||
collection = self.limit(int(limit) + 1, marker, marker_column)
|
||||
if len(collection) > int(limit):
|
||||
return (collection[0:-1], collection[-2]['id'])
|
||||
return (collection, None)
|
||||
|
||||
|
||||
class Queryable(object):
|
||||
|
||||
def __getattr__(self, item):
|
||||
return lambda model, **conditions: Query(
|
||||
model, query_func=getattr(db_api, item), **conditions)
|
||||
|
||||
db_query = Queryable()
|
||||
|
||||
|
||||
def add_options(parser):
|
||||
"""Adds any configuration options that the db layer might have.
|
||||
|
||||
|
@ -19,7 +19,6 @@ import sqlalchemy.exc
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from melange import ipam
|
||||
from melange.common import exception
|
||||
@ -29,21 +28,22 @@ from melange.db.sqlalchemy import mappers
|
||||
from melange.db.sqlalchemy import session
|
||||
|
||||
|
||||
def list(query):
|
||||
return query.all()
|
||||
def list(query_func, *args, **kwargs):
|
||||
return query_func(*args, **kwargs).all()
|
||||
|
||||
|
||||
def count(query):
|
||||
return query.count()
|
||||
def count(query, *args, **kwargs):
|
||||
return query(*args, **kwargs).count()
|
||||
|
||||
|
||||
def find_all_by(model, **conditions):
|
||||
def find_all(model, **conditions):
|
||||
return _query_by(model, **conditions)
|
||||
|
||||
|
||||
def find_all_by_limit(query_func, model, conditions, limit, marker=None,
|
||||
marker_column=None):
|
||||
return _limits(query_func, model, conditions, limit, marker, marker_column)
|
||||
return _limits(query_func, model, conditions, limit, marker,
|
||||
marker_column).all()
|
||||
|
||||
|
||||
def find_by(model, **kwargs):
|
||||
|
@ -27,56 +27,12 @@ from melange.common import config
|
||||
from melange.common import exception
|
||||
from melange.common import utils
|
||||
from melange.db import db_api
|
||||
from melange.db import db_query
|
||||
|
||||
|
||||
LOG = logging.getLogger('melange.ipam.models')
|
||||
|
||||
|
||||
class Query(object):
|
||||
"""Mimics sqlalchemy query object.
|
||||
|
||||
This class allows us to store query conditions and use them with
|
||||
bulk updates and deletes just like sqlalchemy query object.
|
||||
Using this class makes the models independent of sqlalchemy
|
||||
|
||||
"""
|
||||
def __init__(self, model, query_func=None, **conditions):
|
||||
self._query_func = query_func or db_api.find_all_by
|
||||
self._model = model
|
||||
self._conditions = conditions
|
||||
|
||||
def all(self):
|
||||
return db_api.list(self._query_func(self._model, **self._conditions))
|
||||
|
||||
def count(self):
|
||||
return db_api.count(self._query_func(self._model, **self._conditions))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.all())
|
||||
|
||||
def update(self, **values):
|
||||
db_api.update_all(self._query_func, self._model, self._conditions,
|
||||
values)
|
||||
|
||||
def delete(self):
|
||||
db_api.delete_all(self._query_func, self._model, **self._conditions)
|
||||
|
||||
def limit(self, limit=200, marker=None, marker_column=None):
|
||||
limit_query = db_api.find_all_by_limit(self._query_func,
|
||||
self._model,
|
||||
self._conditions,
|
||||
limit=limit,
|
||||
marker=marker,
|
||||
marker_column=marker_column)
|
||||
return db_api.list(limit_query)
|
||||
|
||||
def paginated_collection(self, limit=200, marker=None, marker_column=None):
|
||||
collection = self.limit(int(limit) + 1, marker, marker_column)
|
||||
if len(collection) > int(limit):
|
||||
return (collection[0:-1], collection[-2]['id'])
|
||||
return (collection, None)
|
||||
|
||||
|
||||
class Converter(object):
|
||||
|
||||
data_type_converters = {
|
||||
@ -193,11 +149,11 @@ class ModelBase(object):
|
||||
|
||||
@classmethod
|
||||
def find_all(cls, **kwargs):
|
||||
return Query(cls, **cls._process_conditions(kwargs))
|
||||
return db_query.find_all(cls, **cls._process_conditions(kwargs))
|
||||
|
||||
@classmethod
|
||||
def count(cls, **conditions):
|
||||
return Query(cls, **conditions).count()
|
||||
return cls.find_all(**conditions).count()
|
||||
|
||||
def merge_attributes(self, values):
|
||||
"""dict.update() behaviour."""
|
||||
@ -571,16 +527,13 @@ class IpAddress(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def find_all_by_network(cls, network_id, **conditions):
|
||||
return Query(cls,
|
||||
query_func=db_api.find_all_ips_in_network,
|
||||
network_id=network_id,
|
||||
**conditions)
|
||||
return db_query.find_all_ips_in_network(cls,
|
||||
network_id=network_id,
|
||||
**conditions)
|
||||
|
||||
@classmethod
|
||||
def find_all_allocated_ips(cls, **conditions):
|
||||
return Query(cls,
|
||||
query_func=db_api.find_all_allocated_ips,
|
||||
**conditions)
|
||||
return db_query.find_all_allocated_ips(cls, **conditions)
|
||||
|
||||
def delete(self):
|
||||
if self._explicitly_allowed_on_interfaces():
|
||||
@ -593,9 +546,8 @@ class IpAddress(ModelBase):
|
||||
super(IpAddress, self).delete()
|
||||
|
||||
def _explicitly_allowed_on_interfaces(self):
|
||||
return Query(IpAddress,
|
||||
query_func=db_api.find_allowed_ips,
|
||||
ip_address_id=self.id).count() > 0
|
||||
return db_query.find_allowed_ips(IpAddress,
|
||||
ip_address_id=self.id).count() > 0
|
||||
|
||||
def _before_save(self):
|
||||
self.address = self._formatted(self.address)
|
||||
@ -619,10 +571,9 @@ class IpAddress(ModelBase):
|
||||
self.update(marked_for_deallocation=False, deallocated_at=None)
|
||||
|
||||
def inside_globals(self, **kwargs):
|
||||
return Query(IpAddress,
|
||||
query_func=db_api.find_inside_globals,
|
||||
local_address_id=self.id,
|
||||
**kwargs)
|
||||
return db_query.find_inside_globals(IpAddress,
|
||||
local_address_id=self.id,
|
||||
**kwargs)
|
||||
|
||||
def add_inside_globals(self, ip_addresses):
|
||||
db_api.save_nat_relationships([
|
||||
@ -633,10 +584,9 @@ class IpAddress(ModelBase):
|
||||
for global_address in ip_addresses])
|
||||
|
||||
def inside_locals(self, **kwargs):
|
||||
return Query(IpAddress,
|
||||
query_func=db_api.find_inside_locals,
|
||||
global_address_id=self.id,
|
||||
**kwargs)
|
||||
return db_query.find_inside_locals(IpAddress,
|
||||
global_address_id=self.id,
|
||||
**kwargs)
|
||||
|
||||
def remove_inside_globals(self, inside_global_address=None):
|
||||
return db_api.remove_inside_globals(self.id, inside_global_address)
|
||||
@ -825,9 +775,8 @@ class Interface(ModelBase):
|
||||
db_api.remove_allowed_ip(interface_id=self.id, ip_address_id=ip.id)
|
||||
|
||||
def ips_allowed(self):
|
||||
explicitly_allowed = Query(IpAddress,
|
||||
query_func=db_api.find_allowed_ips,
|
||||
allowed_on_interface_id=self.id)
|
||||
explicitly_allowed = db_query.find_allowed_ips(
|
||||
IpAddress, allowed_on_interface_id=self.id)
|
||||
allocated_ips = IpAddress.find_all_allocated_ips(interface_id=self.id)
|
||||
return list(set(allocated_ips) | set(explicitly_allowed))
|
||||
|
||||
|
@ -22,6 +22,7 @@ import netaddr
|
||||
from melange import tests
|
||||
from melange.common import exception
|
||||
from melange.common import utils
|
||||
from melange.db import db_query
|
||||
from melange.ipam import models
|
||||
from melange.tests import unit
|
||||
from melange.tests.factories import models as factory_models
|
||||
@ -91,7 +92,7 @@ class TestQuery(tests.BaseTest):
|
||||
block2 = factory_models.IpBlockFactory(network_id="1")
|
||||
noise_block = factory_models.IpBlockFactory(network_id="999")
|
||||
|
||||
blocks = models.Query(models.IpBlock, network_id="1").all()
|
||||
blocks = db_query.find_all(models.IpBlock, network_id="1").all()
|
||||
|
||||
self.assertModelsEqual(blocks, [block1, block2])
|
||||
|
||||
@ -100,7 +101,7 @@ class TestQuery(tests.BaseTest):
|
||||
factory_models.IpBlockFactory(network_id="1")
|
||||
noise_block = factory_models.IpBlockFactory(network_id="999")
|
||||
|
||||
count = models.Query(models.IpBlock, network_id="1").count()
|
||||
count = db_query.find_all(models.IpBlock, network_id="1").count()
|
||||
|
||||
self.assertEqual(count, 2)
|
||||
|
||||
@ -109,7 +110,7 @@ class TestQuery(tests.BaseTest):
|
||||
block2 = factory_models.IpBlockFactory(network_id="1")
|
||||
noise_block = factory_models.IpBlockFactory(network_id="999")
|
||||
|
||||
query = models.Query(models.IpBlock, network_id="1")
|
||||
query = db_query.find_all(models.IpBlock, network_id="1")
|
||||
blocks = [block for block in query]
|
||||
|
||||
self.assertModelsEqual(blocks, [block1, block2])
|
||||
@ -123,8 +124,9 @@ class TestQuery(tests.BaseTest):
|
||||
])
|
||||
|
||||
marker_block = blocks[1]
|
||||
paginated_blocks = models.Query(models.IpBlock).limit(limit=2,
|
||||
marker=marker_block.id)
|
||||
all_blocks_query = db_query.find_all(models.IpBlock)
|
||||
paginated_blocks = all_blocks_query.limit(limit=2,
|
||||
marker=marker_block.id)
|
||||
|
||||
self.assertEqual(len(paginated_blocks), 2)
|
||||
self.assertEqual(paginated_blocks, [blocks[2], blocks[3]])
|
||||
@ -134,7 +136,8 @@ class TestQuery(tests.BaseTest):
|
||||
block2 = factory_models.IpBlockFactory(network_id="1")
|
||||
noise_block = factory_models.IpBlockFactory(network_id="999")
|
||||
|
||||
models.Query(models.IpBlock, network_id="1").update(network_id="2")
|
||||
db_query.find_all(models.IpBlock,
|
||||
network_id="1").update(network_id="2")
|
||||
|
||||
self.assertEqual(models.IpBlock.find(block1.id).network_id, "2")
|
||||
self.assertEqual(models.IpBlock.find(block2.id).network_id, "2")
|
||||
@ -146,7 +149,7 @@ class TestQuery(tests.BaseTest):
|
||||
block2 = factory_models.IpBlockFactory(network_id="1")
|
||||
noise_block = factory_models.IpBlockFactory(network_id="999")
|
||||
|
||||
models.Query(models.IpBlock, network_id="1").delete()
|
||||
db_query.find_all(models.IpBlock, network_id="1").delete()
|
||||
|
||||
self.assertIsNone(models.IpBlock.get(block1.id))
|
||||
self.assertIsNone(models.IpBlock.get(block2.id))
|
||||
|
Loading…
x
Reference in New Issue
Block a user