# Code to generate a Python model from a database, or the differences between
# a model and a database.
# Some of this is borrowed heavily from the AutoCode project at:
# http://code.google.com/p/sqlautocode/

import sys

import migrate
import sqlalchemy


# Preamble written at the top of a generated (non-declarative) model module.
# It must stay multi-line: the import and the ``meta`` assignment have to be
# real statements in the generated file, not part of the leading comment.
HEADER = """
## File autogenerated by genmodel.py

from sqlalchemy import *
meta = MetaData()
"""

# Preamble written at the top of a generated declarative model module.
DECLARATIVE_HEADER = """
## File autogenerated by genmodel.py

from sqlalchemy import *
from sqlalchemy.ext import declarative

Base = declarative.declarative_base()
"""


class ModelGenerator(object):
    """Generate model source code, or apply a model, from a schema diff.

    The ``diff`` object passed to the constructor is expected to expose the
    attributes this class reads: ``conn``, ``tablesMissingInModel``,
    ``tablesMissingInDatabase``, ``tablesWithDiff``, ``colDiffs`` and
    ``reflected_model``.
    """

    def __init__(self, diff, declarative=False):
        """Remember the diff and build the dialect-type -> generic-type map.

        :param diff: schema difference object (see class docstring).
        :param declarative: when true, emit declarative-style classes
            instead of ``Table(...)`` definitions.
        """
        self.diff = diff
        self.declarative = declarative
        # The dialect object only tells us its module name; look the module
        # up so we can read its ``colspecs`` mapping.
        dialectModule = sys.modules[self.diff.conn.dialect.__module__]
        # ``colspecs`` is inverted here so that a column type class found on
        # a reflected column can be looked up and replaced by the key it was
        # registered under.
        self.colTypeMappings = dict(
            (v, k) for k, v in dialectModule.colspecs.items())

    def column_repr(self, col):
        """Return the source text of a ``Column(...)`` definition for *col*.

        In declarative mode the result has the form ``name = Column(type...)``;
        otherwise it is ``Column('name', type...)``.  *col* is not modified.
        """
        # Collect (keyword, value) pairs for the non-default column options,
        # in a fixed order so the generated text is stable.
        kwargs = []
        if col.key != col.name:
            kwargs.append(('key', col.key))
        if col.primary_key:
            # Emit a literal True; the attribute itself may hold 1.
            kwargs.append(('primary_key', True))
        if not col.nullable:
            kwargs.append(('nullable', col.nullable))
        if col.onupdate:
            kwargs.append(('onupdate', col.onupdate))
        # Postgres automatically creates a default value for the sequence of
        # a primary key; do not show that one.
        if col.default and not col.primary_key:
            kwargs.append(('default', col.default))
        args = ', '.join(['%s=%r' % pair for pair in kwargs])

        # Use a native string for the name so the output carries neither a
        # u'' prefix (Python 2) nor a b'' prefix (Python 3).
        name = col.name
        if not isinstance(name, str):
            name = name.encode('utf8')

        colType = self.colTypeMappings.get(col.type.__class__)
        if colType:
            # Make the column type be an instance of the mapped type.
            colType = colType()
        else:
            # Already a model type; no mapping from a database-specific type
            # is needed.
            colType = col.type

        constraints = ', '.join([repr(cn) for cn in col.constraints])
        if constraints and args:
            args = ',' + args
        if constraints or args:
            maybeComma = ','
        else:
            maybeComma = ''

        data = {
            'name': name,
            'type': colType,
            'constraints': constraints,
            'args': args,
            'maybeComma': maybeComma,
        }
        # Everything after the type, including the closing parenthesis.
        commonStuff = " %(maybeComma)s %(constraints)s %(args)s)" % data
        data['commonStuff'] = commonStuff.strip()

        if self.declarative:
            return "%(name)s = Column(%(type)r%(commonStuff)s" % data
        return "Column(%(name)r, %(type)r%(commonStuff)s" % data

    def getTableDefn(self, table):
        """Return the source lines (a list of strings) defining *table*."""
        out = []
        tableName = table.name
        if self.declarative:
            out.append("class %(table)s(Base):" % {'table': tableName})
            out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
            for col in table.columns:
                out.append(" %s" % self.column_repr(col))
        else:
            out.append("%(table)s = Table('%(table)s', meta,"
                       % {'table': tableName})
            for col in table.columns:
                out.append(" %s," % self.column_repr(col))
            out.append(")")
        return out

    def toPython(self):
        """Return model source for every table missing from the model.

        Assumes the database is current and the model is empty.
        """
        out = []
        if self.declarative:
            out.append(DECLARATIVE_HEADER)
        else:
            out.append(HEADER)
        out.append("")
        for table in self.diff.tablesMissingInModel:
            out.extend(self.getTableDefn(table))
            out.append("")
        return '\n'.join(out)

    def toUpgradeDowngradePython(self, indent=' '):
        """Return ``(declarations, upgrade_body, downgrade_body)`` as strings.

        Assumes the model is most current and the database is out-of-date:
        tables missing from the model are dropped on upgrade, tables missing
        from the database are created on upgrade, and downgrade does the
        reverse.  Each command line is prefixed with *indent*.
        """
        decls = ['meta = MetaData(migrate_engine)']
        for table in (self.diff.tablesMissingInModel +
                      self.diff.tablesMissingInDatabase):
            decls.extend(self.getTableDefn(table))

        upgradeCommands, downgradeCommands = [], []
        for table in self.diff.tablesMissingInModel:
            tableName = table.name
            upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
            downgradeCommands.append(
                "%(table)s.create()" % {'table': tableName})
        for table in self.diff.tablesMissingInDatabase:
            tableName = table.name
            upgradeCommands.append(
                "%(table)s.create()" % {'table': tableName})
            downgradeCommands.append("%(table)s.drop()" % {'table': tableName})

        return (
            '\n'.join(decls),
            '\n'.join(['%s%s' % (indent, line) for line in upgradeCommands]),
            '\n'.join(['%s%s' % (indent, line) for line in downgradeCommands]),
        )

    def applyModel(self):
        """Apply the model to the current database.

        Drops tables missing from the model, creates tables missing from the
        database, and reconciles column differences for tables present in
        both.
        """
        # Yuck! We have to import from changeset to apply the monkey-patch
        # that allows column adding/dropping.  The name itself is unused.
        from migrate.changeset import schema

        def dbCanHandleThisChange(missingInDatabase, missingInModel,
                                  diffDecl):
            # Only adding columns: even sqlite can handle this.
            if missingInDatabase and not missingInModel and not diffDecl:
                return True
            return not self.diff.conn.url.drivername.startswith('sqlite')

        meta = sqlalchemy.MetaData(self.diff.conn.engine)

        for table in self.diff.tablesMissingInModel:
            table = table.tometadata(meta)
            table.drop()
        for table in self.diff.tablesMissingInDatabase:
            table = table.tometadata(meta)
            table.create()
        for modelTable in self.diff.tablesWithDiff:
            modelTable = modelTable.tometadata(meta)
            dbTable = self.diff.reflected_model.tables[modelTable.name]
            tableName = modelTable.name
            missingInDatabase, missingInModel, diffDecl = \
                self.diff.colDiffs[tableName]
            if dbCanHandleThisChange(missingInDatabase, missingInModel,
                                     diffDecl):
                for col in missingInDatabase:
                    modelTable.columns[col.name].create()
                for col in missingInModel:
                    dbTable.columns[col.name].drop()
                for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
                    databaseCol.alter(modelCol)
            else:
                self._recreateTable(modelTable, dbTable)

    def _getCopyStatement(self, modelTable, dbTable, tempName):
        """Build the INSERT that copies shared columns back from *tempName*."""
        # Compute the database column keys once instead of per model column.
        dbColumnKeys = dbTable.columns.keys()
        commonCols = [modelCol.name for modelCol in modelTable.columns
                      if modelCol.name in dbColumnKeys]
        commonColsStr = ', '.join(commonCols)
        return 'INSERT INTO %s (%s) SELECT %s FROM %s' % (
            modelTable.name, commonColsStr, commonColsStr, tempName)

    def _recreateTable(self, modelTable, dbTable):
        """Rebuild *modelTable* through a temporary copy of its data.

        Sqlite doesn't support drop column, so more work is needed: create a
        temp table, copy the data to it, drop the old table, create the new
        table, and copy the data back.
        """
        # NOTE: nothing guarantees this temporary name is unique.
        tempName = '_temp_%s' % modelTable.name

        connection = self.diff.conn.connect()
        try:
            # Move the data in one transaction, so that we don't leave the
            # database in a nasty state.
            trans = connection.begin()
            try:
                connection.execute(
                    'CREATE TEMPORARY TABLE %s as SELECT * from %s'
                    % (tempName, modelTable.name))
                # Pass the bind parameter so that the drop and create take
                # place inside our transaction.
                modelTable.drop(bind=connection)
                modelTable.create(bind=connection)
                connection.execute(
                    self._getCopyStatement(modelTable, dbTable, tempName))
                connection.execute('DROP TABLE %s' % tempName)
                trans.commit()
            except:
                # Deliberately bare: roll back on any failure, then re-raise.
                trans.rollback()
                raise
        finally:
            # Always release the connection obtained above.
            connection.close()