diff --git a/docs/api.rst b/docs/api.rst index 88ee989..2120eb9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -118,3 +118,17 @@ Module :mod:`repository <migrate.versioning.repository>` .. automodule:: migrate.versioning.repository :members: :synopsis: SQLAlchemy migrate repository management + +Module :mod:`schema <migrate.versioning.schema>` +------------------------------------------------ + +.. automodule:: migrate.versioning.schema + :members: + :synopsis: Database schema management + +Module :mod:`schemadiff <migrate.versioning.schemadiff>` +-------------------------------------------------------- + +.. automodule:: migrate.versioning.schemadiff + :members: + :synopsis: Database schema and model differencing diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py index 6b9afdd..ac4a57d 100644 --- a/migrate/versioning/schema.py +++ b/migrate/versioning/schema.py @@ -1,4 +1,8 @@ -from sqlalchemy import Table,Column,MetaData,String,Text,Integer,create_engine +""" + Database schema version management. +""" +from sqlalchemy import Table, Column, MetaData, String, Text, Integer, \ + create_engine from sqlalchemy.sql import and_ from sqlalchemy import exceptions as sa_exceptions from migrate.versioning.repository import Repository @@ -6,63 +10,70 @@ from migrate.versioning.util import loadModel from migrate.versioning.version import VerNum from migrate.versioning import exceptions, genmodel, schemadiff + class ControlledSchema(object): """A database under version control""" - #def __init__(self,engine,repository=None): - def __init__(self,engine,repository): + + def __init__(self, engine, repository): if type(repository) is str: repository=Repository(repository) self.engine = engine self.repository = repository self.meta=MetaData(engine) - #if self.repository is None: - # self._get_repository() self._load() - - def __eq__(self,other): + + def __eq__(self, other): return (self.repository is other.repository \ and self.version == other.version) - + def _load(self): """Load controlled schema version info from DB""" tname = 
self.repository.version_table self.meta=MetaData(self.engine) - if not hasattr(self,'table') or self.table is None: + if not hasattr(self, 'table') or self.table is None: try: - self.table = Table(tname,self.meta,autoload=True) + self.table = Table(tname, self.meta, autoload=True) except (exceptions.NoSuchTableError): raise exceptions.DatabaseNotControlledError(tname) # TODO?: verify that the table is correct (# cols, etc.) result = self.engine.execute(self.table.select( - self.table.c.repository_id == str(self.repository.id)) - ) + self.table.c.repository_id == str(self.repository.id))) data = list(result)[0] # TODO?: exception if row count is bad # TODO: check repository id, exception if incorrect self.version = data['version'] - + def _get_repository(self): - """Given a database engine, try to guess the repository""" + """ + Given a database engine, try to guess the repository. + + :raise: :exc:`NotImplementedError` + """ # TODO: no guessing yet; for now, a repository must be supplied raise NotImplementedError() - + @classmethod - def create(cls,engine,repository,version=None): - """Declare a database to be under a repository's version control""" - # Confirm that the version # is valid: positive, integer, exists in repos + def create(cls, engine, repository, version=None): + """ + Declare a database to be under a repository's version control. 
+ """ + # Confirm that the version # is valid: positive, integer, + # exists in repos if type(repository) is str: repository=Repository(repository) - version = cls._validate_version(repository,version) - table=cls._create_table_version(engine,repository,version) + version = cls._validate_version(repository, version) + table=cls._create_table_version(engine, repository, version) # TODO: history table # Load repository information and return - return cls(engine,repository) - + return cls(engine, repository) + @classmethod - def _validate_version(cls,repository,version): - """Ensures this is a valid version number for this repository - If invalid, raises cls.InvalidVersionError - Returns a valid version number + def _validate_version(cls, repository, version): + """ + Ensures this is a valid version number for this repository. + + :raises: :exc:`cls.InvalidVersionError` if invalid + :return: valid version number """ if version is None: version=0 @@ -73,98 +84,114 @@ class ControlledSchema(object): except ValueError: raise exceptions.InvalidVersionError(version) return version - + @classmethod - def _create_table_version(cls,engine,repository,version): - """Creates the versioning table in a database""" + def _create_table_version(cls, engine, repository, version): + """ + Creates the versioning table in a database. + """ # Create tables tname = repository.version_table meta = MetaData(engine) try: - table = Table(tname,meta, - #Column('repository_id',String,primary_key=True), # MySQL needs a length - Column('repository_id',String(255),primary_key=True), - Column('repository_path',Text), - Column('version',Integer), - ) + table = Table( + tname, meta, + Column('repository_id', String(255), primary_key=True), + Column('repository_path', Text), + Column('version', Integer), ) table.create() - except (sa_exceptions.ArgumentError,sa_exceptions.SQLError): + except (sa_exceptions.ArgumentError, sa_exceptions.SQLError): # The table already exists, skip creation. 
pass - #raise exceptions.DatabaseAlreadyControlledError() # Insert data try: - engine.execute(table.insert(),repository_id=repository.id, - repository_path=repository.path,version=int(version)) + engine.execute(table.insert(), repository_id=repository.id, + repository_path=repository.path, + version=int(version)) except sa_exceptions.IntegrityError: # An Entry for this repo already exists. raise exceptions.DatabaseAlreadyControlledError() return table - - @classmethod - def compare_model_to_db(cls,engine,model,repository): - """Compare the current model against the current database.""" + @classmethod + def compare_model_to_db(cls, engine, model, repository): + """ + Compare the current model against the current database. + """ if isinstance(repository, basestring): repository=Repository(repository) model = loadModel(model) - diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table]) + diff = schemadiff.getDiffOfModelAgainstDatabase( + model, engine, excludeTables=[repository.version_table]) return diff @classmethod - def create_model(cls,engine,repository,declarative=False): - """Dump the current database as a Python model.""" - + def create_model(cls, engine, repository, declarative=False): + """ + Dump the current database as a Python model. + """ if isinstance(repository, basestring): repository=Repository(repository) - diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), engine, excludeTables=[repository.version_table]) + diff = schemadiff.getDiffOfModelAgainstDatabase( + MetaData(), engine, excludeTables=[repository.version_table]) return genmodel.ModelGenerator(diff, declarative).toPython() - - def update_db_from_model(self,model): - """Modify the database to match the structure of the current Python model.""" + def update_db_from_model(self, model): + """ + Modify the database to match the structure of the current Python model. 
+ """ if isinstance(self.repository, basestring): self.repository=Repository(self.repository) model = loadModel(model) - diff = schemadiff.getDiffOfModelAgainstDatabase(model, self.engine, excludeTables=[self.repository.version_table]) + diff = schemadiff.getDiffOfModelAgainstDatabase( + model, self.engine, excludeTables=[self.repository.version_table]) genmodel.ModelGenerator(diff).applyModel() - update = self.table.update(self.table.c.repository_id == str(self.repository.id)) + update = self.table.update( + self.table.c.repository_id == str(self.repository.id)) self.engine.execute(update, version=int(self.repository.latest)) def drop(self): - """Remove version control from a database""" + """ + Remove version control from a database. + """ try: self.table.drop() except (sa_exceptions.SQLError): raise exceptions.DatabaseNotControlledError(str(self.table)) - - def _engine_db(self,engine): - """Returns the database name of an engine - 'postgres','sqlite'...""" + + def _engine_db(self, engine): + """ + Returns the database name of an engine - ``postgres``, ``sqlite`` ... + """ # TODO: This is a bit of a hack... return str(engine.dialect.__module__).split('.')[-1] - def changeset(self,version=None): + def changeset(self, version=None): database = self._engine_db(self.engine) start_ver = self.version - changeset = self.repository.changeset(database,start_ver,version) + changeset = self.repository.changeset(database, start_ver, version) return changeset - - def runchange(self,ver,change,step): + + def runchange(self, ver, change, step): startver = ver endver = ver + step # Current database version must be correct! Don't run if corrupt! 
if self.version != startver: - raise exceptions.InvalidVersionError("%s is not %s"%(self.version,startver)) + raise exceptions.InvalidVersionError("%s is not %s" % \ + (self.version, startver)) # Run the change - change.run(self.engine,step) + change.run(self.engine, step) # Update/refresh database version - update = self.table.update(and_(self.table.c.version == int(startver), - self.table.c.repository_id == str(self.repository.id))) + update = self.table.update( + and_(self.table.c.version == int(startver), + self.table.c.repository_id == str(self.repository.id))) self.engine.execute(update, version=int(endver)) self._load() - - def upgrade(self,version=None): - """Upgrade (or downgrade) to a specified version, or latest version""" + + def upgrade(self, version=None): + """ + Upgrade (or downgrade) to a specified version, or latest version. + """ changeset = self.changeset(version) - for ver,change in changeset: - self.runchange(ver,change,changeset.step) + for ver, change in changeset: + self.runchange(ver, change, changeset.step) diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py index 78ff4dc..6f300b3 100644 --- a/migrate/versioning/schemadiff.py +++ b/migrate/versioning/schemadiff.py @@ -1,44 +1,63 @@ - +""" + Schema differencing support. +""" import sqlalchemy def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None): - ''' Return differences of model against database. - Returned object will evaluate to True if there are differences else False. - ''' + """ + Return differences of model against database. + + :return: object which will evaluate to :keyword:`True` if there \ + are differences else :keyword:`False`. + """ return SchemaDiff(model, conn, excludeTables) - + + def getDiffOfModelAgainstModel(oldmodel, model, conn, excludeTables=None): - ''' Return differences of model against database. - Returned object will evaluate to True if there are differences else False. 
- ''' + """ + Return differences of model against another model. + + :return: object which will evaluate to :keyword:`True` if there \ + are differences else :keyword:`False`. + """ return SchemaDiff(model, conn, excludeTables, oldmodel=oldmodel) - + class SchemaDiff(object): - ''' Differences of model against database. ''' - + """ + Differences of model against database. + """ + def __init__(self, model, conn, excludeTables=None, oldmodel=None): - ''' Parameter model is your Python model's metadata and conn is an active database connection. ''' + """ + :param model: Python model's metadata + :param conn: active database connection. + """ self.model = model self.conn = conn - if not excludeTables: excludeTables = [] # [] can't be default value in Python parameter + if not excludeTables: + # [] can't be default value in Python parameter + excludeTables = [] self.excludeTables = excludeTables if oldmodel: self.reflected_model = oldmodel else: self.reflected_model = sqlalchemy.MetaData(conn, reflect=True) - self.tablesMissingInDatabase, self.tablesMissingInModel, self.tablesWithDiff = [], [], [] + self.tablesMissingInDatabase, self.tablesMissingInModel, \ + self.tablesWithDiff = [], [], [] self.colDiffs = {} self.compareModelToDatabase() - + def compareModelToDatabase(self): - ''' Do actual comparison. ''' - + """ + Do actual comparison. + """ # Setup common variables. cc = self.conn.contextual_connect() - schemagenerator = self.conn.dialect.schemagenerator(self.conn.dialect, cc) - + schemagenerator = self.conn.dialect.schemagenerator( + self.conn.dialect, cc) + # For each in model, find missing in database. for modelName, modelTable in self.model.tables.items(): if modelName in self.excludeTables: @@ -49,95 +68,130 @@ class SchemaDiff(object): pass else: self.tablesMissingInDatabase.append(modelTable) - + # For each in database, find missing in model. 
- for reflectedName, reflectedTable in self.reflected_model.tables.items(): + for reflectedName, reflectedTable in \ + self.reflected_model.tables.items(): if reflectedName in self.excludeTables: continue modelTable = self.model.tables.get(reflectedName, None) if modelTable: # Table exists. - + # Find missing columns in database. for modelCol in modelTable.columns: - databaseCol = reflectedTable.columns.get(modelCol.name, None) + databaseCol = reflectedTable.columns.get(modelCol.name, + None) if databaseCol: pass else: self.storeColumnMissingInDatabase(modelTable, modelCol) - + # Find missing columns in model. for databaseCol in reflectedTable.columns: modelCol = modelTable.columns.get(databaseCol.name, None) if modelCol: # Compare attributes of column. - modelDecl = schemagenerator.get_column_specification(modelCol) - databaseDecl = schemagenerator.get_column_specification(databaseCol) + modelDecl = \ + schemagenerator.get_column_specification( + modelCol) + databaseDecl = \ + schemagenerator.get_column_specification( + databaseCol) if modelDecl != databaseDecl: - # Unfortunately, sometimes the database decl won't quite match the model, even though they're the same. - mc, dc = modelCol.type.__class__, databaseCol.type.__class__ - if (issubclass(mc, dc) or issubclass(dc, mc)) and modelCol.nullable == databaseCol.nullable: + # Unfortunately, sometimes the database + # decl won't quite match the model, even + # though they're the same. + mc, dc = modelCol.type.__class__, \ + databaseCol.type.__class__ + if (issubclass(mc, dc) \ + or issubclass(dc, mc)) \ + and modelCol.nullable == \ + databaseCol.nullable: # Types and nullable are the same. 
pass else: - self.storeColumnDiff(modelTable, modelCol, databaseCol, modelDecl, databaseDecl) + self.storeColumnDiff( + modelTable, modelCol, databaseCol, + modelDecl, databaseDecl) else: self.storeColumnMissingInModel(modelTable, databaseCol) else: self.tablesMissingInModel.append(reflectedTable) - + def __str__(self): ''' Summarize differences. ''' - + def colDiffDetails(): colout = [] for table in self.tablesWithDiff: tableName = table.name - missingInDatabase, missingInModel, diffDecl = self.colDiffs[tableName] + missingInDatabase, missingInModel, diffDecl = \ + self.colDiffs[tableName] if missingInDatabase: - colout.append(' %s missing columns in database: %s' % (tableName, ', '.join([col.name for col in missingInDatabase]))) + colout.append( + ' %s missing columns in database: %s' % \ + (tableName, ', '.join( + [col.name for col in missingInDatabase]))) if missingInModel: - colout.append(' %s missing columns in model: %s' % (tableName, ', '.join([col.name for col in missingInModel]))) + colout.append( + ' %s missing columns in model: %s' % \ + (tableName, ', '.join( + [col.name for col in missingInModel]))) if diffDecl: - colout.append(' %s with different declaration of columns in database: %s' % (tableName, str(diffDecl))) + colout.append( + ' %s with different declaration of columns\ + in database: %s' % (tableName, str(diffDecl))) return colout - + out = [] if self.tablesMissingInDatabase: - out.append(' tables missing in database: %s' % ', '.join([table.name for table in self.tablesMissingInDatabase])) + out.append( + ' tables missing in database: %s' % \ + ', '.join( + [table.name for table in self.tablesMissingInDatabase])) if self.tablesMissingInModel: - out.append(' tables missing in model: %s' % ', '.join([table.name for table in self.tablesMissingInModel])) + out.append( + ' tables missing in model: %s' % \ + ', '.join( + [table.name for table in self.tablesMissingInModel])) if self.tablesWithDiff: - out.append(' tables with differences: %s' % ', 
'.join([table.name for table in self.tablesWithDiff])) - + out.append( + ' tables with differences: %s' % \ + ', '.join([table.name for table in self.tablesWithDiff])) + if out: out.insert(0, 'Schema diffs:') out.extend(colDiffDetails()) return '\n'.join(out) else: return 'No schema diffs' - - #__repr__ = __str__ - + def __len__(self): - ''' Used in bool evaluation, return of 0 means no diffs. ''' - return len(self.tablesMissingInDatabase) + len(self.tablesMissingInModel) + len(self.tablesWithDiff) - + """ + Used in bool evaluation, return of 0 means no diffs. + """ + return len(self.tablesMissingInDatabase) + \ + len(self.tablesMissingInModel) + len(self.tablesWithDiff) + def storeColumnMissingInDatabase(self, table, col): if table not in self.tablesWithDiff: self.tablesWithDiff.append(table) - missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], [])) + missingInDatabase, missingInModel, diffDecl = \ + self.colDiffs.setdefault(table.name, ([], [], [])) missingInDatabase.append(col) - + def storeColumnMissingInModel(self, table, col): if table not in self.tablesWithDiff: self.tablesWithDiff.append(table) - missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], [])) + missingInDatabase, missingInModel, diffDecl = \ + self.colDiffs.setdefault(table.name, ([], [], [])) missingInModel.append(col) - - def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl, databaseDecl): + + def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl, + databaseDecl): if table not in self.tablesWithDiff: self.tablesWithDiff.append(table) - missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], [])) - diffDecl.append( (modelCol, databaseCol, modelDecl, databaseDecl) ) - + missingInDatabase, missingInModel, diffDecl = \ + self.colDiffs.setdefault(table.name, ([], [], [])) + diffDecl.append((modelCol, databaseCol, modelDecl, databaseDecl))