144 lines
6.7 KiB
Python

import sqlalchemy
def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
    ''' Compare a model's metadata with the schema found in the database.

    model is the Python model's metadata, conn is the database connection
    used for the comparison, and excludeTables optionally lists table
    names to leave out of the comparison.

    The returned SchemaDiff is truthy when differences were found and
    falsy when there were none.
    '''
    diff = SchemaDiff(model, conn, excludeTables=excludeTables)
    return diff
def getDiffOfModelAgainstModel(oldmodel, model, conn, excludeTables=None):
    ''' Compare a model's metadata with another (older) model's metadata.

    oldmodel takes the place of the schema that would otherwise be
    reflected from the database; conn is still handed to SchemaDiff, and
    excludeTables optionally lists table names to leave out of the
    comparison.

    The returned SchemaDiff is truthy when differences were found and
    falsy when there were none.
    '''
    diff = SchemaDiff(model, conn, excludeTables=excludeTables, oldmodel=oldmodel)
    return diff
class SchemaDiff(object):
    ''' Differences of model against database.

    The comparison runs once, in the constructor, and fills in:

    tablesMissingInDatabase -- model tables with no counterpart in the
        reflected schema.
    tablesMissingInModel -- reflected tables with no counterpart in the
        model.
    tablesWithDiff -- model tables present on both sides whose columns
        differ.
    colDiffs -- dict mapping a table name to a tuple of three lists:
        (columns missing in database, columns missing in model,
        declaration diffs), where each declaration diff is a tuple
        (modelCol, databaseCol, modelDecl, databaseDecl).

    An instance is truthy when any difference was recorded (see __len__).
    '''

    def __init__(self, model, conn, excludeTables=None, oldmodel=None):
        ''' Parameter model is your Python model's metadata and conn is an active database connection.

        excludeTables is an optional collection of table names to skip.
        oldmodel, when given, is used as the "database" side of the
        comparison instead of reflecting the schema through conn.
        '''
        self.model = model
        self.conn = conn
        # A mutable [] must not be the default parameter value, so
        # normalise a missing or empty argument here.
        if not excludeTables:
            excludeTables = []
        self.excludeTables = excludeTables
        # Test identity rather than truthiness so that a supplied model is
        # never silently swapped for a reflected one.
        if oldmodel is not None:
            self.reflected_model = oldmodel
        else:
            self.reflected_model = sqlalchemy.MetaData(conn, reflect=True)
        self.tablesMissingInDatabase = []
        self.tablesMissingInModel = []
        self.tablesWithDiff = []
        self.colDiffs = {}
        self.compareModelToDatabase()

    def compareModelToDatabase(self):
        ''' Do actual comparison.

        Appends to the result attributes; it is called once by __init__.
        '''
        cc = self.conn.contextual_connect()
        try:
            schemagenerator = self.conn.dialect.schemagenerator(self.conn.dialect, cc)
            modelTables = self.model.tables
            reflectedTables = self.reflected_model.tables

            # For each in model, find missing in database.
            for modelName, modelTable in modelTables.items():
                if modelName in self.excludeTables:
                    continue
                # Table objects must be tested with "is None": their
                # boolean value is not reliably defined.
                if reflectedTables.get(modelName, None) is None:
                    self.tablesMissingInDatabase.append(modelTable)

            # For each in database, find missing in model; tables present
            # on both sides get their columns compared.
            for reflectedName, reflectedTable in reflectedTables.items():
                if reflectedName in self.excludeTables:
                    continue
                modelTable = modelTables.get(reflectedName, None)
                if modelTable is None:
                    self.tablesMissingInModel.append(reflectedTable)
                    continue
                self._compareColumns(modelTable, reflectedTable, schemagenerator)
        finally:
            # Release the connection obtained above even if the
            # comparison raises.
            cc.close()

    def _compareColumns(self, modelTable, reflectedTable, schemagenerator):
        ''' Record the column-level differences between two versions of one table. '''
        # Find missing columns in database.
        for modelCol in modelTable.columns:
            if reflectedTable.columns.get(modelCol.name, None) is None:
                self.storeColumnMissingInDatabase(modelTable, modelCol)

        # Find missing columns in model, and compare the shared ones.
        for databaseCol in reflectedTable.columns:
            modelCol = modelTable.columns.get(databaseCol.name, None)
            if modelCol is None:
                self.storeColumnMissingInModel(modelTable, databaseCol)
                continue

            modelDecl = schemagenerator.get_column_specification(modelCol)
            databaseDecl = schemagenerator.get_column_specification(databaseCol)
            if modelDecl == databaseDecl:
                continue

            # Unfortunately, sometimes the database decl won't quite match
            # the model, even though they're the same. Treat the columns
            # as equal when one type class derives from the other and the
            # nullable flags agree.
            mc, dc = modelCol.type.__class__, databaseCol.type.__class__
            relatedTypes = issubclass(mc, dc) or issubclass(dc, mc)
            if relatedTypes and modelCol.nullable == databaseCol.nullable:
                continue
            self.storeColumnDiff(modelTable, modelCol, databaseCol, modelDecl, databaseDecl)

    def _colDiffDetails(self):
        ''' Return one summary line per kind of column difference, per table. '''
        colout = []
        for table in self.tablesWithDiff:
            tableName = table.name
            missingInDatabase, missingInModel, diffDecl = self.colDiffs[tableName]
            if missingInDatabase:
                colout.append(' %s missing columns in database: %s' % (tableName, ', '.join([col.name for col in missingInDatabase])))
            if missingInModel:
                colout.append(' %s missing columns in model: %s' % (tableName, ', '.join([col.name for col in missingInModel])))
            if diffDecl:
                colout.append(' %s with different declaration of columns in database: %s' % (tableName, str(diffDecl)))
        return colout

    def __str__(self):
        ''' Summarize differences. '''
        out = []
        if self.tablesMissingInDatabase:
            out.append(' tables missing in database: %s' % ', '.join([table.name for table in self.tablesMissingInDatabase]))
        if self.tablesMissingInModel:
            out.append(' tables missing in model: %s' % ', '.join([table.name for table in self.tablesMissingInModel]))
        if self.tablesWithDiff:
            out.append(' tables with differences: %s' % ', '.join([table.name for table in self.tablesWithDiff]))
        if not out:
            return 'No schema diffs'
        out.insert(0, 'Schema diffs:')
        out.extend(self._colDiffDetails())
        return '\n'.join(out)

    def __len__(self):
        ''' Used in bool evaluation, return of 0 means no diffs. '''
        return len(self.tablesMissingInDatabase) + len(self.tablesMissingInModel) + len(self.tablesWithDiff)

    def _colDiffsFor(self, table):
        ''' Mark table as differing and return its (missingInDatabase, missingInModel, diffDecl) lists. '''
        if table not in self.tablesWithDiff:
            self.tablesWithDiff.append(table)
        return self.colDiffs.setdefault(table.name, ([], [], []))

    def storeColumnMissingInDatabase(self, table, col):
        ''' Record that model column col of table has no counterpart in the database. '''
        missingInDatabase, _missingInModel, _diffDecl = self._colDiffsFor(table)
        missingInDatabase.append(col)

    def storeColumnMissingInModel(self, table, col):
        ''' Record that database column col of table has no counterpart in the model. '''
        _missingInDatabase, missingInModel, _diffDecl = self._colDiffsFor(table)
        missingInModel.append(col)

    def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl, databaseDecl):
        ''' Record that a column of table is declared differently in model and database. '''
        _missingInDatabase, _missingInModel, diffDecl = self._colDiffsFor(table)
        diffDecl.append((modelCol, databaseCol, modelDecl, databaseDecl))