156 lines
5.3 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import shutil
from StringIO import StringIO
import migrate
from migrate.versioning import exceptions, genmodel, schemadiff
from migrate.versioning.config import operations
from migrate.versioning.template import Template
from migrate.versioning.script import base
from migrate.versioning.util import import_path, load_model, construct_engine
class PythonScript(base.BaseScript):
"""Base for Python scripts"""
@classmethod
def create(cls, path, **opts):
"""Create an empty migration script at specified path
:returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
cls.require_notfound(path)
src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
shutil.copy(src, path)
return cls(path)
@classmethod
def make_update_script_for_model(cls, engine, oldmodel,
model, repository, **opts):
"""Create a migration script based on difference between two SA models.
:param repository: path to migrate repository
:param oldmodel: dotted.module.name:SAClass or SAClass object
:param model: dotted.module.name:SAClass or SAClass object
:param engine: SQLAlchemy engine
:type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
:type oldmodel: string or Class
:type model: string or Class
:type engine: Engine instance
:returns: Upgrade / Downgrade script
:rtype: string
"""
if isinstance(repository, basestring):
# oh dear, an import cycle!
from migrate.versioning.repository import Repository
repository = Repository(repository)
oldmodel = load_model(oldmodel)
model = load_model(model)
# Compute differences.
diff = schemadiff.getDiffOfModelAgainstModel(
oldmodel,
model,
engine,
excludeTables=[repository.version_table])
# TODO: diff can be False (there is no difference?)
decls, upgradeCommands, downgradeCommands = \
genmodel.ModelGenerator(diff).toUpgradeDowngradePython()
# Store differences into file.
src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
f = open(src)
contents = f.read()
f.close()
# generate source
search = 'def upgrade(migrate_engine):'
contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
if upgradeCommands:
contents = contents.replace(' pass', upgradeCommands, 1)
if downgradeCommands:
contents = contents.replace(' pass', downgradeCommands, 1)
return contents
@classmethod
def verify_module(cls, path):
"""Ensure path is a valid script
:param path: Script location
:type path: string
:raises: :exc:`InvalidScriptError <migrate.versioning.exceptions.InvalidScriptError>`
:returns: Python module
"""
# Try to import and get the upgrade() func
try:
module = import_path(path)
except:
# If the script itself has errors, that's not our problem
raise
try:
assert callable(module.upgrade)
except Exception, e:
raise exceptions.InvalidScriptError(path + ': %s' % str(e))
return module
def preview_sql(self, url, step, **args):
"""Mocks SQLAlchemy Engine to store all executed calls in a string
and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
:returns: SQL file
"""
buf = StringIO()
args['engine_arg_strategy'] = 'mock'
args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
engine = construct_engine(url, **args)
self.run(engine, step)
return buf.getvalue()
def run(self, engine, step):
"""Core method of Script file.
Exectues :func:`update` or :func:`downgrade` functions
:param engine: SQLAlchemy Engine
:param step: Operation to run
:type engine: string
:type step: int
"""
if step > 0:
op = 'upgrade'
elif step < 0:
op = 'downgrade'
else:
raise exceptions.ScriptError("%d is not a valid step" % step)
funcname = base.operations[op]
func = self._func(funcname)
try:
func(engine)
except TypeError:
print "upgrade/downgrade functions must accept engine parameter (since ver 0.5.5)"
raise
@property
def module(self):
"""Calls :meth:`migrate.versioning.script.py.verify_module`
and returns it.
"""
if not hasattr(self, '_module'):
self._module = self.verify_module(self.path)
return self._module
def _func(self, funcname):
try:
return getattr(self.module, funcname)
except AttributeError:
msg = "The function %s is not defined in this script"
raise exceptions.ScriptError(msg % funcname)