From 0fa7d4103e849b01f52df9321206c8be3637e960 Mon Sep 17 00:00:00 2001
From: "jan.dittberner" <unknown>
Date: Sun, 25 Jan 2009 21:09:05 +0000
Subject: [PATCH] migrate.versioning.schema and schemadiff PEP-8 clean, added
 to api.rst

---
 docs/api.rst                     |  14 +++
 migrate/versioning/schema.py     | 167 ++++++++++++++++++-------------
 migrate/versioning/schemadiff.py | 164 ++++++++++++++++++++----------
 3 files changed, 220 insertions(+), 125 deletions(-)

diff --git a/docs/api.rst b/docs/api.rst
index 88ee989..2120eb9 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -118,3 +118,17 @@ Module :mod:`repository <migrate.versioning.repository>`
 .. automodule:: migrate.versioning.repository
    :members:
    :synopsis: SQLAlchemy migrate repository management
+
+Module :mod:`schema <migrate.versioning.schema>`
+------------------------------------------------
+
+.. automodule:: migrate.versioning.schema
+   :members:
+   :synopsis: Database schema management
+
+Module :mod:`schemadiff <migrate.versioning.schemadiff>`
+--------------------------------------------------------
+
+.. automodule:: migrate.versioning.schemadiff
+   :members:
+   :synopsis: Database schema and model differencing
diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py
index 6b9afdd..ac4a57d 100644
--- a/migrate/versioning/schema.py
+++ b/migrate/versioning/schema.py
@@ -1,4 +1,8 @@
-from sqlalchemy import Table,Column,MetaData,String,Text,Integer,create_engine
+"""
+   Database schema version management.
+"""
+from sqlalchemy import Table, Column, MetaData, String, Text, Integer, \
+    create_engine
 from sqlalchemy.sql import and_
 from sqlalchemy import exceptions as sa_exceptions
 from migrate.versioning.repository import Repository
@@ -6,63 +10,70 @@ from migrate.versioning.util import loadModel
 from migrate.versioning.version import VerNum
 from migrate.versioning import exceptions, genmodel, schemadiff
 
+
 class ControlledSchema(object):
     """A database under version control"""
-    #def __init__(self,engine,repository=None):
-    def __init__(self,engine,repository):
+
+    def __init__(self, engine, repository):
         if type(repository) is str:
             repository=Repository(repository)
         self.engine = engine
         self.repository = repository
         self.meta=MetaData(engine)
-        #if self.repository is None:
-        #   self._get_repository()
         self._load()
-    
-    def __eq__(self,other):
+
+    def __eq__(self, other):
         return (self.repository is other.repository \
             and self.version == other.version)
-    
+
     def _load(self):
         """Load controlled schema version info from DB"""
         tname = self.repository.version_table
         self.meta=MetaData(self.engine)
-        if not hasattr(self,'table') or self.table is None:
+        if not hasattr(self, 'table') or self.table is None:
             try:
-                self.table = Table(tname,self.meta,autoload=True)
+                self.table = Table(tname, self.meta, autoload=True)
             except (exceptions.NoSuchTableError):
                 raise exceptions.DatabaseNotControlledError(tname)
         # TODO?: verify that the table is correct (# cols, etc.)
         result = self.engine.execute(self.table.select(
-                    self.table.c.repository_id == str(self.repository.id))
-                )
+                    self.table.c.repository_id == str(self.repository.id)))
         data = list(result)[0]
         # TODO?: exception if row count is bad
         # TODO: check repository id, exception if incorrect
         self.version = data['version']
-    
+
     def _get_repository(self):
-        """Given a database engine, try to guess the repository"""
+        """
+        Given a database engine, try to guess the repository.
+
+        :raise: :exc:`NotImplementedError`
+        """
         # TODO: no guessing yet; for now, a repository must be supplied
         raise NotImplementedError()
-    
+
     @classmethod
-    def create(cls,engine,repository,version=None):
-        """Declare a database to be under a repository's version control"""
-        # Confirm that the version # is valid: positive, integer, exists in repos
+    def create(cls, engine, repository, version=None):
+        """
+        Declare a database to be under a repository's version control.
+        """
+        # Confirm that the version # is valid: positive, integer,
+        # exists in repos
         if type(repository) is str:
             repository=Repository(repository)
-        version = cls._validate_version(repository,version)
-        table=cls._create_table_version(engine,repository,version)
+        version = cls._validate_version(repository, version)
+        table=cls._create_table_version(engine, repository, version)
         # TODO: history table
         # Load repository information and return
-        return cls(engine,repository)
-    
+        return cls(engine, repository)
+
     @classmethod
-    def _validate_version(cls,repository,version):
-        """Ensures this is a valid version number for this repository
-        If invalid, raises cls.InvalidVersionError
-        Returns a valid version number
+    def _validate_version(cls, repository, version):
+        """
+        Ensures this is a valid version number for this repository.
+
+        :raises: :exc:`cls.InvalidVersionError` if invalid
+        :return: valid version number
         """
         if version is None:
             version=0
@@ -73,98 +84,114 @@ class ControlledSchema(object):
         except ValueError:
             raise exceptions.InvalidVersionError(version)
         return version
-    
+
     @classmethod
-    def _create_table_version(cls,engine,repository,version):
-        """Creates the versioning table in a database"""
+    def _create_table_version(cls, engine, repository, version):
+        """
+        Creates the versioning table in a database.
+        """
         # Create tables
         tname = repository.version_table
         meta = MetaData(engine)
         try:
-            table = Table(tname,meta,
-                #Column('repository_id',String,primary_key=True), # MySQL needs a length
-                Column('repository_id',String(255),primary_key=True),
-                Column('repository_path',Text),
-                Column('version',Integer),
-            )
+            table = Table(
+                tname, meta,
+                Column('repository_id', String(255), primary_key=True),
+                Column('repository_path', Text),
+                Column('version', Integer), )
             table.create()
-        except (sa_exceptions.ArgumentError,sa_exceptions.SQLError):
+        except (sa_exceptions.ArgumentError, sa_exceptions.SQLError):
             # The table already exists, skip creation.
             pass
-            #raise exceptions.DatabaseAlreadyControlledError()
         # Insert data
         try:
-            engine.execute(table.insert(),repository_id=repository.id,
-                repository_path=repository.path,version=int(version))
+            engine.execute(table.insert(), repository_id=repository.id,
+                           repository_path=repository.path,
+                           version=int(version))
         except sa_exceptions.IntegrityError:
             # An Entry for this repo already exists.
             raise exceptions.DatabaseAlreadyControlledError()
         return table
-    
-    @classmethod
-    def compare_model_to_db(cls,engine,model,repository):
-        """Compare the current model against the current database."""
 
+    @classmethod
+    def compare_model_to_db(cls, engine, model, repository):
+        """
+        Compare the current model against the current database.
+        """
         if isinstance(repository, basestring):
             repository=Repository(repository)
         model = loadModel(model)
-        diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table])
+        diff = schemadiff.getDiffOfModelAgainstDatabase(
+            model, engine, excludeTables=[repository.version_table])
         return diff
 
     @classmethod
-    def create_model(cls,engine,repository,declarative=False):
-        """Dump the current database as a Python model."""
-
+    def create_model(cls, engine, repository, declarative=False):
+        """
+        Dump the current database as a Python model.
+        """
         if isinstance(repository, basestring):
             repository=Repository(repository)
-        diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), engine, excludeTables=[repository.version_table])
+        diff = schemadiff.getDiffOfModelAgainstDatabase(
+            MetaData(), engine, excludeTables=[repository.version_table])
         return genmodel.ModelGenerator(diff, declarative).toPython()
-    
-    def update_db_from_model(self,model):
-        """Modify the database to match the structure of the current Python model."""
 
+    def update_db_from_model(self, model):
+        """
+        Modify the database to match the structure of the current Python model.
+        """
         if isinstance(self.repository, basestring):
             self.repository=Repository(self.repository)
         model = loadModel(model)
-        diff = schemadiff.getDiffOfModelAgainstDatabase(model, self.engine, excludeTables=[self.repository.version_table])
+        diff = schemadiff.getDiffOfModelAgainstDatabase(
+            model, self.engine, excludeTables=[self.repository.version_table])
         genmodel.ModelGenerator(diff).applyModel()
-        update = self.table.update(self.table.c.repository_id == str(self.repository.id))
+        update = self.table.update(
+            self.table.c.repository_id == str(self.repository.id))
         self.engine.execute(update, version=int(self.repository.latest))
 
     def drop(self):
-        """Remove version control from a database"""
+        """
+        Remove version control from a database.
+        """
         try:
             self.table.drop()
         except (sa_exceptions.SQLError):
             raise exceptions.DatabaseNotControlledError(str(self.table))
-    
-    def _engine_db(self,engine):
-        """Returns the database name of an engine - 'postgres','sqlite'..."""
+
+    def _engine_db(self, engine):
+        """
+        Returns the database name of an engine - ``postgres``, ``sqlite`` ...
+        """
         # TODO: This is a bit of a hack...
         return str(engine.dialect.__module__).split('.')[-1]
 
-    def changeset(self,version=None):
+    def changeset(self, version=None):
         database = self._engine_db(self.engine)
         start_ver = self.version
-        changeset = self.repository.changeset(database,start_ver,version)
+        changeset = self.repository.changeset(database, start_ver, version)
         return changeset
-    
-    def runchange(self,ver,change,step):
+
+    def runchange(self, ver, change, step):
         startver = ver
         endver = ver + step
         # Current database version must be correct! Don't run if corrupt!
         if self.version != startver:
-            raise exceptions.InvalidVersionError("%s is not %s"%(self.version,startver))
+            raise exceptions.InvalidVersionError("%s is not %s" % \
+                                                     (self.version, startver))
         # Run the change
-        change.run(self.engine,step)
+        change.run(self.engine, step)
         # Update/refresh database version
-        update = self.table.update(and_(self.table.c.version == int(startver),
-                                   self.table.c.repository_id == str(self.repository.id)))
+        update = self.table.update(
+            and_(self.table.c.version == int(startver),
+                 self.table.c.repository_id == str(self.repository.id)))
         self.engine.execute(update, version=int(endver))
         self._load()
-        
-    def upgrade(self,version=None):
-        """Upgrade (or downgrade) to a specified version, or latest version"""
+
+    def upgrade(self, version=None):
+        """
+        Upgrade (or downgrade) to a specified version, or latest version.
+        """
         changeset = self.changeset(version)
-        for ver,change in changeset:
-            self.runchange(ver,change,changeset.step)
+        for ver, change in changeset:
+            self.runchange(ver, change, changeset.step)
diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py
index 78ff4dc..6f300b3 100644
--- a/migrate/versioning/schemadiff.py
+++ b/migrate/versioning/schemadiff.py
@@ -1,44 +1,63 @@
-
+"""
+   Schema differencing support.
+"""
 import sqlalchemy
 
 
 def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
-    ''' Return differences of model against database.
-        Returned object will evaluate to True if there are differences else False.
-    '''
+    """
+    Return differences of model against database.
+
+    :return: object which will evaluate to :keyword:`True` if there \
+      are differences else :keyword:`False`.
+    """
     return SchemaDiff(model, conn, excludeTables)
-    
+
+
 def getDiffOfModelAgainstModel(oldmodel, model, conn, excludeTables=None):
-    ''' Return differences of model against database.
-        Returned object will evaluate to True if there are differences else False.
-    '''
+    """
+    Return differences of model against another model.
+
+    :return: object which will evaluate to :keyword:`True` if there \
+      are differences else :keyword:`False`.
+    """
     return SchemaDiff(model, conn, excludeTables, oldmodel=oldmodel)
-    
+
 
 class SchemaDiff(object):
-    ''' Differences of model against database. '''
-    
+    """
+    Differences of model against database.
+    """
+
     def __init__(self, model, conn, excludeTables=None, oldmodel=None):
-        ''' Parameter model is your Python model's metadata and conn is an active database connection. '''
+        """
+        :param model: Python model's metadata
+        :param conn: active database connection.
+        """
         self.model = model
         self.conn = conn
-        if not excludeTables: excludeTables = []  # [] can't be default value in Python parameter
+        if not excludeTables:
+            # [] can't be default value in Python parameter
+            excludeTables = []
         self.excludeTables = excludeTables
         if oldmodel:
             self.reflected_model = oldmodel
         else:
             self.reflected_model = sqlalchemy.MetaData(conn, reflect=True)
-        self.tablesMissingInDatabase, self.tablesMissingInModel, self.tablesWithDiff = [], [], []
+        self.tablesMissingInDatabase, self.tablesMissingInModel, \
+            self.tablesWithDiff = [], [], []
         self.colDiffs = {}
         self.compareModelToDatabase()
-        
+
     def compareModelToDatabase(self):
-        ''' Do actual comparison. '''
-       
+        """
+        Do actual comparison.
+        """
         # Setup common variables.
         cc = self.conn.contextual_connect()
-        schemagenerator = self.conn.dialect.schemagenerator(self.conn.dialect, cc)
-        
+        schemagenerator = self.conn.dialect.schemagenerator(
+            self.conn.dialect, cc)
+
         # For each in model, find missing in database.
         for modelName, modelTable in self.model.tables.items():
             if modelName in self.excludeTables:
@@ -49,95 +68,130 @@ class SchemaDiff(object):
                 pass
             else:
                 self.tablesMissingInDatabase.append(modelTable)
-        
+
         # For each in database, find missing in model.
-        for reflectedName, reflectedTable in self.reflected_model.tables.items():
+        for reflectedName, reflectedTable in \
+                self.reflected_model.tables.items():
             if reflectedName in self.excludeTables:
                 continue
             modelTable = self.model.tables.get(reflectedName, None)
             if modelTable:
                 # Table exists.
-                
+
                 # Find missing columns in database.
                 for modelCol in modelTable.columns:
-                    databaseCol = reflectedTable.columns.get(modelCol.name, None)
+                    databaseCol = reflectedTable.columns.get(modelCol.name,
+                                                             None)
                     if databaseCol:
                         pass
                     else:
                         self.storeColumnMissingInDatabase(modelTable, modelCol)
-                
+
                 # Find missing columns in model.
                 for databaseCol in reflectedTable.columns:
                     modelCol = modelTable.columns.get(databaseCol.name, None)
                     if modelCol:
                         # Compare attributes of column.
-                        modelDecl = schemagenerator.get_column_specification(modelCol)
-                        databaseDecl = schemagenerator.get_column_specification(databaseCol)
+                        modelDecl = \
+                            schemagenerator.get_column_specification(
+                            modelCol)
+                        databaseDecl = \
+                            schemagenerator.get_column_specification(
+                            databaseCol)
                         if modelDecl != databaseDecl:
-                            # Unfortunately, sometimes the database decl won't quite match the model, even though they're the same.
-                            mc, dc = modelCol.type.__class__, databaseCol.type.__class__
-                            if (issubclass(mc, dc) or issubclass(dc, mc)) and modelCol.nullable == databaseCol.nullable:
+                            # Unfortunately, sometimes the database
+                            # decl won't quite match the model, even
+                            # though they're the same.
+                            mc, dc = modelCol.type.__class__, \
+                                databaseCol.type.__class__
+                            if (issubclass(mc, dc) \
+                                    or issubclass(dc, mc)) \
+                                    and modelCol.nullable == \
+                                    databaseCol.nullable:
                                 # Types and nullable are the same.
                                 pass
                             else:
-                                self.storeColumnDiff(modelTable, modelCol, databaseCol, modelDecl, databaseDecl)
+                                self.storeColumnDiff(
+                                    modelTable, modelCol, databaseCol,
+                                    modelDecl, databaseDecl)
                     else:
                         self.storeColumnMissingInModel(modelTable, databaseCol)
             else:
                 self.tablesMissingInModel.append(reflectedTable)
-        
+
     def __str__(self):
         ''' Summarize differences. '''
-        
+
         def colDiffDetails():
             colout = []
             for table in self.tablesWithDiff:
                 tableName = table.name
-                missingInDatabase, missingInModel, diffDecl = self.colDiffs[tableName]
+                missingInDatabase, missingInModel, diffDecl = \
+                    self.colDiffs[tableName]
                 if missingInDatabase:
-                    colout.append('    %s missing columns in database: %s' % (tableName, ', '.join([col.name for col in missingInDatabase])))
+                    colout.append(
+                        '    %s missing columns in database: %s' % \
+                            (tableName, ', '.join(
+                                [col.name for col in missingInDatabase])))
                 if missingInModel:
-                    colout.append('    %s missing columns in model: %s' % (tableName, ', '.join([col.name for col in missingInModel])))
+                    colout.append(
+                        '    %s missing columns in model: %s' % \
+                            (tableName, ', '.join(
+                                [col.name for col in missingInModel])))
                 if diffDecl:
-                    colout.append('    %s with different declaration of columns in database: %s' % (tableName, str(diffDecl)))
+                    colout.append(
+                        '    %s with different declaration of columns\
+ in database: %s' % (tableName, str(diffDecl)))
             return colout
-            
+
         out = []
         if self.tablesMissingInDatabase:
-            out.append('  tables missing in database: %s' % ', '.join([table.name for table in self.tablesMissingInDatabase]))
+            out.append(
+                '  tables missing in database: %s' % \
+                    ', '.join(
+                    [table.name for table in self.tablesMissingInDatabase]))
         if self.tablesMissingInModel:
-            out.append('  tables missing in model: %s' % ', '.join([table.name for table in self.tablesMissingInModel]))
+            out.append(
+                '  tables missing in model: %s' % \
+                    ', '.join(
+                    [table.name for table in self.tablesMissingInModel]))
         if self.tablesWithDiff:
-            out.append('  tables with differences: %s' % ', '.join([table.name for table in self.tablesWithDiff]))
-            
+            out.append(
+                '  tables with differences: %s' % \
+                    ', '.join([table.name for table in self.tablesWithDiff]))
+
         if out:
             out.insert(0, 'Schema diffs:')
             out.extend(colDiffDetails())
             return '\n'.join(out)
         else:
             return 'No schema diffs'
-    
-    #__repr__ = __str__
-    
+
     def __len__(self):
-        ''' Used in bool evaluation, return of 0 means no diffs. '''
-        return len(self.tablesMissingInDatabase) + len(self.tablesMissingInModel) + len(self.tablesWithDiff)
-    
+        """
+        Used in bool evaluation, return of 0 means no diffs.
+        """
+        return len(self.tablesMissingInDatabase) + \
+            len(self.tablesMissingInModel) + len(self.tablesWithDiff)
+
     def storeColumnMissingInDatabase(self, table, col):
         if table not in self.tablesWithDiff:
             self.tablesWithDiff.append(table)
-        missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], []))
+        missingInDatabase, missingInModel, diffDecl = \
+            self.colDiffs.setdefault(table.name, ([], [], []))
         missingInDatabase.append(col)
-   
+
     def storeColumnMissingInModel(self, table, col):
         if table not in self.tablesWithDiff:
             self.tablesWithDiff.append(table)
-        missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], []))
+        missingInDatabase, missingInModel, diffDecl = \
+            self.colDiffs.setdefault(table.name, ([], [], []))
         missingInModel.append(col)
-   
-    def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl, databaseDecl):
+
+    def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl,
+                        databaseDecl):
         if table not in self.tablesWithDiff:
             self.tablesWithDiff.append(table)
-        missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], []))
-        diffDecl.append( (modelCol, databaseCol, modelDecl, databaseDecl) )
-   
+        missingInDatabase, missingInModel, diffDecl = \
+            self.colDiffs.setdefault(table.name, ([], [], []))
+        diffDecl.append((modelCol, databaseCol, modelDecl, databaseDecl))