diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py index 3e96d56..e710c34 100644 --- a/migrate/versioning/api.py +++ b/migrate/versioning/api.py @@ -284,7 +284,8 @@ def create_model(url,repository,**opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label engine=create_engine(url) - print cls_schema.create_model(engine,repository) + declarative = opts.get('declarative', False) + print cls_schema.create_model(engine,repository,declarative) def make_update_script_for_model(url,oldmodel,model,repository,**opts): """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH diff --git a/migrate/versioning/genmodel.py b/migrate/versioning/genmodel.py index 39714b6..384025a 100644 --- a/migrate/versioning/genmodel.py +++ b/migrate/versioning/genmodel.py @@ -13,10 +13,20 @@ from sqlalchemy import * meta = MetaData() """ +DECLARATIVE_HEADER = """ +## File autogenerated by genmodel.py + +from sqlalchemy import * +from sqlalchemy.ext import declarative + +Base = declarative.declarative_base() +""" + class ModelGenerator(object): - def __init__(self, diff): + def __init__(self, diff, declarative=False): self.diff = diff + self.declarative = declarative dialectModule = sys.modules[self.diff.conn.dialect.__module__] # is there an easier way to get this? self.colTypeMappings = dict( (v,k) for k,v in dialectModule.colspecs.items() ) @@ -59,21 +69,36 @@ class ModelGenerator(object): else: data['maybeComma'] = '' - return """Column(%(name)r, %(type)r %(maybeComma)s %(constraints)s %(args)s)""" % data + commonStuff = " %(maybeComma)s %(constraints)s %(args)s)""" % data + commonStuff = commonStuff.strip() + data['commonStuff'] = commonStuff + if self.declarative: + return """%(name)s = Column(%(type)r%(commonStuff)s""" % data + else: + return """Column(%(name)r, %(type)r%(commonStuff)s""" % data def getTableDefn(self, table): out = [] tableName = table.name - out.append("%(table)s = Table('%(table)s', meta," % {'table': tableName}) - for col in table.columns: - out.append(" %s," % self.column_repr(col)) - out.append(")") + if self.declarative: + out.append("class %(table)s(Base):" % {'table': tableName}) + out.append(" __tablename__ = '%(table)s'" % {'table': tableName}) + for col in table.columns: + out.append(" %s" % self.column_repr(col)) + else: + out.append("%(table)s = Table('%(table)s', meta," % {'table': tableName}) + for col in table.columns: + out.append(" %s," % self.column_repr(col)) + out.append(")") return out def toPython(self): ''' Assume database is current and model is empty. ''' out = [] - out.append(HEADER) + if self.declarative: + out.append(DECLARATIVE_HEADER) + else: + out.append(HEADER) out.append("") for table in self.diff.tablesMissingInModel: out.extend(self.getTableDefn(table)) diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py index ed77719..6b9afdd 100644 --- a/migrate/versioning/schema.py +++ b/migrate/versioning/schema.py @@ -112,13 +112,13 @@ class ControlledSchema(object): return diff @classmethod - def create_model(cls,engine,repository): + def create_model(cls,engine,repository,declarative=False): """Dump the current database as a Python model.""" if isinstance(repository, basestring): repository=Repository(repository) diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), engine, excludeTables=[repository.version_table]) - return genmodel.ModelGenerator(diff).toPython() + return genmodel.ModelGenerator(diff, declarative).toPython() def update_db_from_model(self,model): """Modify the database to match the structure of the current Python model."""