from sqlalchemy import Table,Column,MetaData,String,Text,Integer,create_engine
from sqlalchemy.sql import and_
from sqlalchemy import exceptions as sa_exceptions
from migrate.versioning.repository import Repository
from migrate.versioning.util import loadModel
from migrate.versioning.version import VerNum
from migrate.versioning import exceptions, genmodel, schemadiff


class ControlledSchema(object):
    """A database under version control.

    Pairs a SQLAlchemy engine with a migrate ``Repository``; the database's
    current version is read from the repository's version table
    (``repository.version_table``) and cached in ``self.version``.
    """

    def __init__(self, engine, repository):
        """Bind *engine* to *repository* and load the stored version.

        :param engine: SQLAlchemy engine for the controlled database.
        :param repository: a ``Repository`` instance, or a path string from
            which one is constructed.
        :raises DatabaseNotControlledError: if the version table, or this
            repository's row in it, does not exist.
        """
        # Accept str and unicode paths alike, matching the other methods
        # of this class (an exact ``type(...) is str`` check missed unicode).
        if isinstance(repository, basestring):
            repository = Repository(repository)
        self.engine = engine
        self.repository = repository
        self.meta = MetaData(engine)
        # TODO: no repository guessing yet (see _get_repository).
        self._load()

    def __eq__(self, other):
        """Equal when both share the same Repository object and version."""
        if not isinstance(other, ControlledSchema):
            # Let Python fall back instead of raising AttributeError.
            return NotImplemented
        return (self.repository is other.repository
                and self.version == other.version)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep them in sync.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _load(self):
        """Load controlled schema version info from the database.

        Autoloads the version table on first use, then reads the row whose
        ``repository_id`` matches this repository and stores its
        ``version`` column in ``self.version``.

        :raises DatabaseNotControlledError: if the version table does not
            exist or holds no row for this repository.
        """
        tname = self.repository.version_table
        self.meta = MetaData(self.engine)
        if not hasattr(self, 'table') or self.table is None:
            try:
                self.table = Table(tname, self.meta, autoload=True)
            except sa_exceptions.NoSuchTableError:
                # Reflection raises SQLAlchemy's NoSuchTableError, not
                # migrate's; translate it into the migrate-level error.
                raise exceptions.DatabaseNotControlledError(tname)
            # TODO?: verify that the table is correct (# cols, etc.)
        result = self.engine.execute(self.table.select(
            self.table.c.repository_id == str(self.repository.id)))
        rows = list(result)
        if not rows:
            # Table exists but this repository was never registered in it;
            # previously this surfaced as a bare IndexError.
            raise exceptions.DatabaseNotControlledError(tname)
        # TODO?: exception if more than one row matches
        # TODO: check repository id, exception if incorrect
        self.version = rows[0]['version']

    def _get_repository(self):
        """Given a database engine, try to guess the repository."""
        # TODO: no guessing yet; for now, a repository must be supplied
        raise NotImplementedError()

    @classmethod
    def create(cls, engine, repository, version=None):
        """Declare a database to be under a repository's version control.

        :param engine: SQLAlchemy engine for the database.
        :param repository: a ``Repository`` or a path to one.
        :param version: initial version; ``None`` means 0.
        :returns: a new ``ControlledSchema`` for the database.
        :raises InvalidVersionError: if *version* is not valid for the
            repository.
        :raises DatabaseAlreadyControlledError: if the repository already
            has an entry in the version table.
        """
        if isinstance(repository, basestring):
            repository = Repository(repository)
        # Confirm that the version # is valid: positive, integer, exists in repos
        version = cls._validate_version(repository, version)
        cls._create_table_version(engine, repository, version)
        # TODO: history table
        # Load repository information and return
        return cls(engine, repository)

    @classmethod
    def _validate_version(cls, repository, version):
        """Ensure *version* is a valid version number for *repository*.

        ``None`` is treated as 0. Valid versions lie in
        ``0 .. repository.latest`` inclusive.

        :returns: the version as a ``VerNum``.
        :raises exceptions.InvalidVersionError: if the version cannot be
            converted or is out of range.
        """
        if version is None:
            version = 0
        try:
            version = VerNum(version)  # raises ValueError
            if version < 0 or version > repository.latest:
                raise ValueError()
        except ValueError:
            raise exceptions.InvalidVersionError(version)
        return version

    @classmethod
    def _create_table_version(cls, engine, repository, version):
        """Create the versioning table (if needed) and register *repository*.

        :returns: the version ``Table`` object.
        :raises DatabaseAlreadyControlledError: if an entry for this
            repository already exists.
        """
        tname = repository.version_table
        meta = MetaData(engine)
        # Build the table definition outside the try so a failure here is
        # reported directly instead of as an unbound ``table`` later on.
        table = Table(tname, meta,
            # MySQL needs a length for an indexed VARCHAR
            Column('repository_id', String(255), primary_key=True),
            Column('repository_path', Text),
            Column('version', Integer),
        )
        try:
            table.create()
        except (sa_exceptions.ArgumentError, sa_exceptions.SQLError):
            # The table already exists, skip creation.
            # NOTE(review): this also hides other DB errors raised by the
            # CREATE (e.g. permissions); the insert below would then fail.
            pass
        # Insert data
        try:
            engine.execute(table.insert(), repository_id=repository.id,
                repository_path=repository.path, version=int(version))
        except sa_exceptions.IntegrityError:
            # An entry for this repo already exists.
            raise exceptions.DatabaseAlreadyControlledError()
        return table

    @classmethod
    def compare_model_to_db(cls, engine, model, repository):
        """Compare the current model against the current database.

        The repository's version table is excluded from the comparison.
        """
        if isinstance(repository, basestring):
            repository = Repository(repository)
        model = loadModel(model)
        diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine,
            excludeTables=[repository.version_table])
        return diff

    @classmethod
    def create_model(cls, engine, repository, declarative=False):
        """Dump the current database as Python model source.

        Diffs an empty ``MetaData`` against the database (so every table
        shows up) and renders it via ``genmodel.ModelGenerator``.
        """
        if isinstance(repository, basestring):
            repository = Repository(repository)
        diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), engine,
            excludeTables=[repository.version_table])
        return genmodel.ModelGenerator(diff, declarative).toPython()

    def update_db_from_model(self, model):
        """Modify the database to match the structure of the Python model.

        Applies the schema diff, stamps the database with the repository's
        latest version and refreshes ``self.version`` to match.
        """
        if isinstance(self.repository, basestring):
            self.repository = Repository(self.repository)
        model = loadModel(model)
        diff = schemadiff.getDiffOfModelAgainstDatabase(model, self.engine,
            excludeTables=[self.repository.version_table])
        genmodel.ModelGenerator(diff).applyModel()
        update = self.table.update(
            self.table.c.repository_id == str(self.repository.id))
        self.engine.execute(update, version=int(self.repository.latest))
        # Keep the cached version in sync with what was just written.
        self._load()

    def drop(self):
        """Remove version control from a database by dropping its table.

        :raises DatabaseNotControlledError: if dropping the table fails.
        """
        try:
            self.table.drop()
        except sa_exceptions.SQLError:
            raise exceptions.DatabaseNotControlledError(str(self.table))

    def _engine_db(self, engine):
        """Return the database name of an engine - 'postgres', 'sqlite'...

        Derived from the last component of the dialect's module name.
        """
        # TODO: This is a bit of a hack...
        return str(engine.dialect.__module__).split('.')[-1]

    def changeset(self, version=None):
        """Return the repository changeset from the current version to
        *version* (``None`` is passed through to the repository)."""
        database = self._engine_db(self.engine)
        start_ver = self.version
        changeset = self.repository.changeset(database, start_ver, version)
        return changeset

    def runchange(self, ver, change, step):
        """Run one change script and record the resulting version.

        :param ver: version the database must currently be at.
        :param change: object whose ``run(engine, step)`` applies the change.
        :param step: added to *ver* to get the new version (presumably +1
            for upgrade, -1 for downgrade — confirm against changeset).
        :raises InvalidVersionError: if the database is not at *ver*.
        """
        startver = ver
        endver = ver + step
        # Current database version must be correct! Don't run if corrupt!
        if self.version != startver:
            raise exceptions.InvalidVersionError("%s is not %s" % (self.version, startver))
        # Run the change
        change.run(self.engine, step)
        # Update/refresh database version; the version condition guards
        # against concurrent changes to the row.
        update = self.table.update(and_(self.table.c.version == int(startver),
            self.table.c.repository_id == str(self.repository.id)))
        self.engine.execute(update, version=int(endver))
        self._load()

    def upgrade(self, version=None):
        """Upgrade (or downgrade) to a specified version, or latest version"""
        changeset = self.changeset(version)
        for ver, change in changeset:
            self.runchange(ver, change, changeset.step)