make repository format flatter and get rid of commit command

This commit is contained in:
christian.simms 2008-05-30 18:21:14 +00:00
parent 02da42cba6
commit 95b666783e
11 changed files with 213 additions and 304 deletions

View File

@ -11,8 +11,8 @@ __all__=[
'help', 'help',
'create', 'create',
'script', 'script',
'script_sql',
'make_update_script_for_model', 'make_update_script_for_model',
'commit',
'version', 'version',
'source', 'source',
'version_control', 'version_control',
@ -61,57 +61,42 @@ def create(repository,name,**opts):
except exceptions.PathFoundError,e: except exceptions.PathFoundError,e:
raise exceptions.KnownError("The path %s already exists"%e.args[0]) raise exceptions.KnownError("The path %s already exists"%e.args[0])
def script(path,**opts): def script(description,repository=None,**opts):
"""%prog script PATH """%prog script [--repository=REPOSITORY_PATH] DESCRIPTION
Create an empty change script at the specified path. Create an empty change script using the next unused version number appended with the given description.
For instance, manage.py script "Add initial tables" creates: repository/versions/001_Add_initial_tables.py
""" """
try: try:
cls_script_python.create(path,**opts) if repository is None:
raise exceptions.UsageError("A repository must be specified")
repos = cls_repository(repository)
repos.create_script(description,**opts)
except exceptions.PathFoundError,e: except exceptions.PathFoundError,e:
raise exceptions.KnownError("The path %s already exists"%e.args[0]) raise exceptions.KnownError("The path %s already exists"%e.args[0])
def commit(script,repository,database=None,operation=None,version=None,**opts): def script_sql(database,repository=None,**opts):
"""%prog commit SCRIPT_PATH.py REPOSITORY_PATH [VERSION] """%prog script_sql [--repository=REPOSITORY_PATH] DATABASE
%prog commit SCRIPT_PATH.sql REPOSITORY_PATH DATABASE OPERATION [VERSION] Create empty change SQL scripts for given DATABASE, where DATABASE is either specific ('postgres', 'mysql',
'oracle', 'sqlite', etc.) or generic ('default').
Commit a script to this repository. The committed script is added to the For instance, manage.py script_sql postgres creates:
repository, and the file disappears. repository/versions/001_upgrade_postgres.py and repository/versions/001_downgrade_postgres.py
Once a script has been committed, you can use it to upgrade a database with
the 'upgrade' command.
If a version is given, that version will be replaced instead of creating a
new version.
Normally, when writing change scripts in Python, you'll use the first form
of this command (DATABASE and OPERATION aren't specified). If you write
change scripts as .sql files, you'll need to specify DATABASE ('postgres',
'mysql', 'oracle', 'sqlite'...) and OPERATION ('upgrade' or 'downgrade').
You may commit multiple .sql files under the same version to complete
functionality for a particular version::
%prog commit upgrade.postgres.sql /repository/path postgres upgrade 1
%prog commit downgrade.postgres.sql /repository/path postgres downgrade 1
%prog commit upgrade.sqlite.sql /repository/path sqlite upgrade 1
%prog commit downgrade.sqlite.sql /repository/path sqlite downgrade 1
[etc...]
""" """
if (database is not None) and (operation is None) and (version is None): try:
# Version was supplied as a positional if repository is None:
version = database raise exceptions.UsageError("A repository must be specified")
database = None
repos = cls_repository(repository) repos = cls_repository(repository)
repos.commit(script,version,database=database,operation=operation) repos.create_script_sql(database,**opts)
except exceptions.PathFoundError,e:
raise exceptions.KnownError("The path %s already exists"%e.args[0])
def test(script,repository,url=None,**opts): def test(repository,url=None,**opts):
"""%prog test SCRIPT_PATH REPOSITORY_PATH URL [VERSION] """%prog test REPOSITORY_PATH URL [VERSION]
""" """
engine=create_engine(url) engine=create_engine(url)
schema = cls_schema(engine,repository) repos=cls_repository(repository)
script = cls_script_python(script) script = repos.version(None).script()
# Upgrade # Upgrade
print "Upgrading...", print "Upgrading...",
try: try:

View File

@ -126,9 +126,6 @@ class ModelGenerator(object):
for modelTable in self.diff.tablesWithDiff: for modelTable in self.diff.tablesWithDiff:
modelTable = modelTable.tometadata(meta) modelTable = modelTable.tometadata(meta)
dbTable = self.diff.reflected_model.tables[modelTable.name] dbTable = self.diff.reflected_model.tables[modelTable.name]
#print 'TODO DEBUG.cols1', [x.name for x in dbTable.columns]
#dbTable = dbTable.tometadata(meta)
#print 'TODO DEBUG.cols2', [x.name for x in dbTable.columns]
tableName = modelTable.name tableName = modelTable.name
missingInDatabase, missingInModel, diffDecl = self.diff.colDiffs[tableName] missingInDatabase, missingInModel, diffDecl = self.diff.colDiffs[tableName]
if dbCanHandleThisChange(missingInDatabase, missingInModel, diffDecl): if dbCanHandleThisChange(missingInDatabase, missingInModel, diffDecl):

View File

@ -0,0 +1,35 @@
''' Script to migrate repository. This shouldn't use any other migrate modules, so that it can work in any version. '''
import os, sys
def usage():
print '''Usage: %(prog)s repository-to-migrate
Upgrade your repository to the new flat format.
NOTE: You should probably make a backup before running this.
''' % {'prog': sys.argv[0]}
raise SystemExit(1)
def migrate_repository(repos):
print 'Migrating repository at: %s to new format' % repos
versions = '%s/versions' % repos
dirs = os.listdir(versions)
numdirs = [ int(dir) for dir in dirs if dir.isdigit() ] # Only use int's in list.
numdirs.sort() # Sort list.
for dir in numdirs:
origdir = '%s/%s' % (versions, dir)
print ' Working on directory: %s' % origdir
files = os.listdir(origdir)
pass # finish TODO xxx
if __name__ == '__main__':
if len(sys.argv) != 2:
usage()
migrate_repository(sys.argv[1])

View File

@ -113,9 +113,11 @@ class Repository(pathed.Pathed):
log.error("There was an error creating your repository") log.error("There was an error creating your repository")
return cls(path) return cls(path)
def commit(self,*p,**k): def create_script(self,description,**k):
reqd = self.config.get('db_settings','required_dbs') self.versions.createNewVersion(description,**k)
return self.versions.commit(required=reqd,*p,**k)
def create_script_sql(self,database,**k):
self.versions.createNewSQLVersion(database,**k)
latest=property(lambda self: self.versions.latest) latest=property(lambda self: self.versions.latest)
version_table=property(lambda self: self.config.get('db_settings','version_table')) version_table=property(lambda self: self.config.get('db_settings','version_table'))

View File

@ -8,7 +8,6 @@ import inspect
alias = dict( alias = dict(
s=api.script, s=api.script,
ci=api.commit,
vc=api.version_control, vc=api.version_control,
dbv=api.db_version, dbv=api.db_version,
v=api.version, v=api.version,

View File

@ -1,5 +1,5 @@
from migrate.versioning import exceptions,pathed,script from migrate.versioning import exceptions,pathed,script
import os,shutil import os,re,shutil
@ -31,20 +31,37 @@ class VerNum(object):
return int(self)-int(value) return int(self)-int(value)
def strToFilename(s):
s = s.replace(' ', '_').replace('"', '_').replace("'", '_')
while '__' in s:
s = s.replace('__', '_')
return s
class Collection(pathed.Pathed): class Collection(pathed.Pathed):
"""A collection of versioning scripts in a repository""" """A collection of versioning scripts in a repository"""
FILENAME_WITH_VERSION = re.compile(r'^(\d+).*')
def __init__(self,path): def __init__(self,path):
super(Collection,self).__init__(path) super(Collection,self).__init__(path)
self.versions=dict()
ver=self.latest=VerNum(1) # Create temporary list of files, allowing skipped version numbers.
vers=os.listdir(path) files = os.listdir(path)
# This runs up to the latest *complete* version; stops when one's missing if '1' in files:
while str(ver) in vers: raise Exception('It looks like you have a repository in the old format (with directories for each version). Please convert repository before proceeding.')
verpath=self.version_path(ver) tempVersions = dict()
self.versions[ver]=Version(verpath) for filename in files:
ver+=1 match = self.FILENAME_WITH_VERSION.match(filename)
self.latest=ver-1 if match:
num = int(match.group(1))
tempVersions.setdefault(num, []).append(filename)
else:
pass # Must be a helper file or something, let's ignore it.
# Create the versions member where the keys are VerNum's and the values are Version's.
self.versions=dict()
for num, files in tempVersions.items():
self.versions[VerNum(num)] = Version(num, path, files)
self.latest = max([VerNum(0)] + self.versions.keys()) # calculate latest version
def version_path(self,ver): def version_path(self,ver):
return os.path.join(self.path,str(ver)) return os.path.join(self.path,str(ver))
@ -54,47 +71,52 @@ class Collection(pathed.Pathed):
vernum = self.latest vernum = self.latest
return self.versions[VerNum(vernum)] return self.versions[VerNum(vernum)]
def commit(self,path,ver=None,*p,**k): def getNewVersion(self):
"""Commit a script to this collection of scripts ver = self.latest+1
"""
maxver = self.latest+1
if ver is None:
ver = maxver
# Ver must be valid: can't upgrade past the next version
# No change scripts exist for 0 (even though it's a valid version) # No change scripts exist for 0 (even though it's a valid version)
if ver > maxver or ver == 0: if ver <= 0:
raise exceptions.InvalidVersionError() raise exceptions.InvalidVersionError()
verpath = self.version_path(ver)
tmpname = None
try:
# If replacing an old version, copy it in case it gets trashed
if os.path.exists(verpath):
tmpname = os.path.join(os.path.split(verpath)[0],"%s_tmp"%ver)
shutil.copytree(verpath,tmpname)
version = Version(verpath)
else:
# Create version folder
version = Version.create(verpath)
self.versions[ver] = version
# Commit the individual script
script = version.commit(path,*p,**k)
except:
# Rollback everything we did in the try before dying, and reraise
# Remove the created version folder
shutil.rmtree(verpath)
# Rollback if a version already existed above
if tmpname is not None:
shutil.move(tmpname,verpath)
raise
# Success: mark latest; delete old version
if tmpname is not None:
shutil.rmtree(tmpname)
self.latest = ver self.latest = ver
return ver
def createNewVersion(self, description, **k):
ver = self.getNewVersion()
extra = strToFilename(description)
if extra:
if extra == '_':
extra = ''
elif not extra.startswith('_'):
extra = '_%s' % extra
filename = '%03d%s.py' % (ver, extra)
filepath = self.version_path(filename)
if os.path.exists(filepath):
raise Exception('Script already exists: %s' % filepath)
else:
script.PythonScript.create(filepath)
self.versions[ver] = Version(ver, self.path, [filename])
def createNewSQLVersion(self, database, **k):
# Determine version number to use.
if (not self.versions) or self.versions[self.latest].python:
# First version or current version already contains python script, so create a new version.
ver = self.getNewVersion()
self.versions[ver] = Version(ver, self.path, [])
else:
ver = self.latest
# Create new files.
for op in ('upgrade', 'downgrade'):
filename = '%03d_%s_%s.sql' % (ver, database, op)
filepath = self.version_path(filename)
if os.path.exists(filepath):
raise Exception('Script already exists: %s' % filepath)
else:
open(filepath, "w").close()
self.versions[ver]._add_script(filepath)
@classmethod @classmethod
def clear(cls): def clear(cls):
super(Collection,cls).clear() super(Collection,cls).clear()
Version.clear()
class extensions: class extensions:
@ -103,28 +125,25 @@ class extensions:
sql='sql' sql='sql'
class Version(pathed.Pathed): class Version(object): # formerly inherit from: (pathed.Pathed):
"""A single version in a repository """A single version in a repository
""" """
def __init__(self,path): def __init__(self,vernum,path,filelist):
super(Version,self).__init__(path)
# Version must be numeric # Version must be numeric
try: try:
self.version=VerNum(os.path.basename(path)) self.version=VerNum(vernum)
except: except:
raise exceptions.InvalidVersionError(path) raise exceptions.InvalidVersionError(vernum)
# Collect scripts in this folder # Collect scripts in this folder
self.sql = dict() self.sql = dict()
self.python = None self.python = None
try:
for script in os.listdir(path): for script in filelist:
# skip __init__.py, because we assume that it's # skip __init__.py, because we assume that it's
# just there to mark the package # just there to mark the package
if script == '__init__.py': if script == '__init__.py':
continue continue
self._add_script(os.path.join(path,script)) self._add_script(os.path.join(path,script))
except:
raise exceptions.InvalidVersionError(path)
def script(self,database=None,operation=None): def script(self,database=None,operation=None):
#if database is None and operation is None: #if database is None and operation is None:
@ -163,10 +182,13 @@ class Version(pathed.Pathed):
self._add_script_py(path) self._add_script_py(path)
elif path.endswith(extensions.sql): elif path.endswith(extensions.sql):
self._add_script_sql(path) self._add_script_sql(path)
SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
def _add_script_sql(self,path): def _add_script_sql(self,path):
try: match = self.SQL_FILENAME.match(os.path.basename(path))
version,dbms,op,ext=os.path.basename(path).split('.',3) if match:
except: version, dbms, op = match.group(1), match.group(2), match.group(3)
else:
raise exceptions.ScriptError("Invalid sql script name %s"%path) raise exceptions.ScriptError("Invalid sql script name %s"%path)
# File the script into a dictionary # File the script into a dictionary
@ -176,6 +198,8 @@ class Version(pathed.Pathed):
ops = dbmses[dbms] ops = dbmses[dbms]
ops[op] = script.SqlScript(path) ops[op] = script.SqlScript(path)
def _add_script_py(self,path): def _add_script_py(self,path):
if self.python is not None:
raise Exception('You can only have one Python script per version, but you have: %s and %s' % (self.python, path))
self.python = script.PythonScript(path) self.python = script.PythonScript(path)
def _rm_ignore(self,path): def _rm_ignore(self,path):
@ -185,29 +209,3 @@ class Version(pathed.Pathed):
except OSError: except OSError:
pass pass
def commit(self,path,database=None,operation=None,required=None):
if (database is not None) and (operation is not None):
return self._commit_sql(path,database,operation)
return self._commit_py(path,required)
def _commit_sql(self,path,database,operation):
if not path.endswith(extensions.sql):
msg = "Bad file extension: should end with %s"%extensions.sql
raise exceptions.ScriptError(msg)
dest=os.path.join(self.path,'%s.%s.%s.%s'%(
str(self.version),str(database),str(operation),extensions.sql))
# Move the committed py script to this version's folder
shutil.move(path,dest)
self._add_script(dest)
def _commit_py(self,path_py,required=None):
if (not os.path.exists(path_py)) or (not os.path.isfile(path_py)):
raise exceptions.InvalidVersionError(path_py)
dest = os.path.join(self.path,'%s.%s'%(str(self.version),extensions.py))
# Move the committed py script to this version's folder
shutil.move(path_py,dest)
self._add_script(dest)
# Also delete the .pyc file, if it exists
path_pyc = path_py+'c'
if os.path.exists(path_pyc):
self._rm_ignore(path_pyc)

View File

@ -53,28 +53,11 @@ class TestVersionedRepository(fixture.Pathed):
def setUp(self): def setUp(self):
Repository.clear() Repository.clear()
self.path_repos=self.tmp_repos() self.path_repos=self.tmp_repos()
self.path_script=self.tmp_py()
# Create repository, script # Create repository, script
Repository.create(self.path_repos,'repository_name') Repository.create(self.path_repos,'repository_name')
def test_commit(self):
"""Commit scripts to a repository and detect repository version"""
# Load repository; commit script by pathname; script should go away
self.script_cls.create(self.path_script)
repos=Repository(self.path_repos)
self.assert_(os.path.exists(self.path_script))
repos.commit(self.path_script)
self.assert_(not os.path.exists(self.path_script))
# .pyc file from the committed script shouldn't exist either
self.assert_(not os.path.exists(self.path_script+'c'))
version = repos.versions.version()
self.assert_(os.path.exists(os.path.join(version.path,
"%s.py" % version.version)))
self.assert_(os.path.exists(os.path.join(version.path,
"__init__.py")))
def test_version(self): def test_version(self):
"""We should correctly detect the version of a repository""" """We should correctly detect the version of a repository"""
self.script_cls.create(self.path_script)
repos=Repository(self.path_repos) repos=Repository(self.path_repos)
# Get latest version, or detect if a specified version exists # Get latest version, or detect if a specified version exists
self.assertEquals(repos.latest,0) self.assertEquals(repos.latest,0)
@ -82,15 +65,14 @@ class TestVersionedRepository(fixture.Pathed):
# (so we can't just assume the following tests are correct) # (so we can't just assume the following tests are correct)
self.assert_(repos.latest>=0) self.assert_(repos.latest>=0)
self.assert_(repos.latest<1) self.assert_(repos.latest<1)
# Commit a script and test again # Create a script and test again
repos.commit(self.path_script) repos.create_script('')
self.assertEquals(repos.latest,1) self.assertEquals(repos.latest,1)
self.assert_(repos.latest>=0) self.assert_(repos.latest>=0)
self.assert_(repos.latest>=1) self.assert_(repos.latest>=1)
self.assert_(repos.latest<2) self.assert_(repos.latest<2)
# Commit a new script and test again # Create a new script and test again
self.script_cls.create(self.path_script) repos.create_script('')
repos.commit(self.path_script)
self.assertEquals(repos.latest,2) self.assertEquals(repos.latest,2)
self.assert_(repos.latest>=0) self.assert_(repos.latest>=0)
self.assert_(repos.latest>=1) self.assert_(repos.latest>=1)
@ -98,62 +80,27 @@ class TestVersionedRepository(fixture.Pathed):
self.assert_(repos.latest<3) self.assert_(repos.latest<3)
def test_source(self): def test_source(self):
"""Get a script object by version number and view its source""" """Get a script object by version number and view its source"""
self.script_cls.create(self.path_script)
# Load repository and commit script # Load repository and commit script
repos=Repository(self.path_repos) repos=Repository(self.path_repos)
repos.commit(self.path_script) repos.create_script('')
# Get script object # Get script object
source=repos.version(1).script().source() source=repos.version(1).script().source()
# Source is valid: script must have an upgrade function # Source is valid: script must have an upgrade function
# (not a very thorough test, but should be plenty) # (not a very thorough test, but should be plenty)
self.assert_(source.find('def upgrade')>=0) self.assert_(source.find('def upgrade')>=0)
def test_latestversion(self): def test_latestversion(self):
self.script_cls.create(self.path_script)
"""Repository.version() (no params) returns the latest version""" """Repository.version() (no params) returns the latest version"""
repos=Repository(self.path_repos) repos=Repository(self.path_repos)
repos.commit(self.path_script) repos.create_script('')
self.assert_(repos.version(repos.latest) is repos.version()) self.assert_(repos.version(repos.latest) is repos.version())
self.assert_(repos.version() is not None) self.assert_(repos.version() is not None)
def xtest_commit_fail(self):
"""Failed commits shouldn't corrupt the repository
Test disabled - logsql ran the script on commit; now that that's gone,
the content of the script is not checked before commit
"""
repos=Repository(self.path_repos)
path_script=self.tmp_py()
text_script = """
from sqlalchemy import *
from migrate import *
# Upgrade is not declared; commit should fail
#def upgrade():
# raise Exception()
def downgrade():
raise Exception()
""".replace("\n ","\n")
fd=open(path_script,'w')
fd.write(text_script)
fd.close()
# Record current state, and commit
ver_pre = os.listdir(repos.versions.path)
repos_pre = os.listdir(repos.path)
self.assertRaises(Exception,repos.commit,path_script)
# Version is unchanged
self.assertEquals(repos.latest,0)
# No new files created; committed script not moved
self.assert_(os.path.exists(path_script))
self.assertEquals(os.listdir(repos.versions.path),ver_pre)
self.assertEquals(os.listdir(repos.path),repos_pre)
def test_changeset(self): def test_changeset(self):
"""Repositories can create changesets properly""" """Repositories can create changesets properly"""
# Create a nonzero-version repository of empty scripts # Create a nonzero-version repository of empty scripts
repos=Repository(self.path_repos) repos=Repository(self.path_repos)
for i in range(10): for i in range(10):
self.script_cls.create(self.path_script) repos.create_script('')
repos.commit(self.path_script)
def check_changeset(params,length): def check_changeset(params,length):
"""Creates and verifies a changeset""" """Creates and verifies a changeset"""

View File

@ -8,7 +8,6 @@ class TestRunChangeset(fixture.Pathed,fixture.DB):
def setUp(self): def setUp(self):
Repository.clear() Repository.clear()
self.path_repos=self.tmp_repos() self.path_repos=self.tmp_repos()
self.path_script=self.tmp_py()
# Create repository, script # Create repository, script
Repository.create(self.path_repos,'repository_name') Repository.create(self.path_repos,'repository_name')
@ -17,8 +16,7 @@ class TestRunChangeset(fixture.Pathed,fixture.DB):
"""Running a changeset against a repository gives expected results""" """Running a changeset against a repository gives expected results"""
repos=Repository(self.path_repos) repos=Repository(self.path_repos)
for i in range(10): for i in range(10):
script.PythonScript.create(self.path_script) repos.create_script('')
repos.commit(self.path_script)
try: try:
ControlledSchema(self.engine,repos).drop() ControlledSchema(self.engine,repos).drop()
except: except:

View File

@ -58,11 +58,9 @@ class TestControlledSchema(fixture.Pathed,fixture.DB):
dbcontrol.drop() dbcontrol.drop()
# Now try it with a nonzero value # Now try it with a nonzero value
script_path = self.tmp_py()
version=10 version=10
for i in range(version): for i in range(version):
script.PythonScript.create(script_path) self.repos.create_script('')
self.repos.commit(script_path)
self.assertEquals(self.repos.latest,version) self.assertEquals(self.repos.latest,version)
# Test with some mid-range value # Test with some mid-range value

View File

@ -120,17 +120,24 @@ class TestShellCommands(Shell):
def test_script(self): def test_script(self):
"""We can create a migration script via the command line""" """We can create a migration script via the command line"""
script=self.tmp_py() repos=self.tmp_repos()
# Creating a file that doesn't exist should succeed self.assertSuccess(self.cmd('create',repos,'repository_name'))
self.assertSuccess(self.cmd('script',script)) self.assertSuccess(self.cmd('script', '--repository=%s' % repos, 'Desc'))
self.assert_(os.path.exists(script)) self.assert_(os.path.exists('%s/versions/001_Desc.py' % repos))
# 's' instead of 'script' should work too # 's' instead of 'script' should work too
os.remove(script) self.assertSuccess(self.cmd('script', '--repository=%s' % repos, 'More'))
self.assert_(not os.path.exists(script)) self.assert_(os.path.exists('%s/versions/002_More.py' % repos))
self.assertSuccess(self.cmd('s',script))
self.assert_(os.path.exists(script)) def test_script_sql(self):
"""We can create a migration sql script via the command line"""
repos=self.tmp_repos()
self.assertSuccess(self.cmd('create',repos,'repository_name'))
self.assertSuccess(self.cmd('script_sql', '--repository=%s' % repos, 'mydb'))
self.assert_(os.path.exists('%s/versions/001_mydb_upgrade.sql' % repos))
self.assert_(os.path.exists('%s/versions/001_mydb_downgrade.sql' % repos))
# Can't create it again: it already exists # Can't create it again: it already exists
self.assertFailure(self.cmd('script',script)) self.assertFailure(self.cmd('script_sql', '--repository=%s' % repos, 'mydb'))
def test_manage(self): def test_manage(self):
"""Create a project management script""" """Create a project management script"""
@ -145,20 +152,8 @@ class TestShellRepository(Shell):
def setUp(self): def setUp(self):
"""Create repository, python change script""" """Create repository, python change script"""
self.path_repos=repos=self.tmp_repos() self.path_repos=repos=self.tmp_repos()
self.path_script=script=self.tmp_py()
self.assertSuccess(self.cmd('create',repos,'repository_name')) self.assertSuccess(self.cmd('create',repos,'repository_name'))
self.assertSuccess(self.cmd('script',script))
def test_commit_1(self):
"""Commits should work correctly; script should vanish after commit"""
self.assert_(os.path.exists(self.path_script))
self.assertSuccess(self.cmd('commit',self.path_script,self.path_repos))
self.assert_(not os.path.exists(self.path_script))
def test_commit_2(self):
"""Commits should work correctly with repository as a keyword param"""
self.assert_(os.path.exists(self.path_script))
self.assertSuccess(self.cmd('commit',self.path_script,'--repository=%s'%self.path_repos))
self.assert_(not os.path.exists(self.path_script))
def test_version(self): def test_version(self):
"""Correctly detect repository version""" """Correctly detect repository version"""
# Version: 0 (no scripts yet); successful execution # Version: 0 (no scripts yet); successful execution
@ -169,17 +164,17 @@ class TestShellRepository(Shell):
fd=self.execute(self.cmd('version',self.path_repos)) fd=self.execute(self.cmd('version',self.path_repos))
self.assertEquals(fd.read().strip(),"0") self.assertEquals(fd.read().strip(),"0")
self.assertSuccess(fd) self.assertSuccess(fd)
# Commit a script and version should increment # Create a script and version should increment
self.assertSuccess(self.cmd('commit',self.path_script,'--repository=%s'%self.path_repos)) self.assertSuccess(self.cmd('script', '--repository=%s' % self.path_repos, 'Desc'))
fd=self.execute(self.cmd('version',self.path_repos)) fd=self.execute(self.cmd('version',self.path_repos))
self.assertEquals(fd.read().strip(),"1") self.assertEquals(fd.read().strip(),"1")
self.assertSuccess(fd) self.assertSuccess(fd)
def test_source(self): def test_source(self):
"""Correctly fetch a script's source""" """Correctly fetch a script's source"""
source=open(self.path_script).read() self.assertSuccess(self.cmd('script', '--repository=%s' % self.path_repos, 'Desc'))
filename='%s/versions/001_Desc.py' % self.path_repos
source=open(filename).read()
self.assert_(source.find('def upgrade')>=0) self.assert_(source.find('def upgrade')>=0)
self.assertSuccess(self.cmd('commit',self.path_script,'--repository=%s'%self.path_repos))
# Later, we'll want to make repos optional somehow
# Version is now 1 # Version is now 1
fd=self.execute(self.cmd('version',self.path_repos)) fd=self.execute(self.cmd('version',self.path_repos))
self.assert_(fd.read().strip()=="1") self.assert_(fd.read().strip()=="1")
@ -190,50 +185,11 @@ class TestShellRepository(Shell):
self.assertSuccess(fd) self.assertSuccess(fd)
self.assert_(result.strip()==source.strip()) self.assert_(result.strip()==source.strip())
# We can also send the source to a file... test that too # We can also send the source to a file... test that too
self.assertSuccess(self.cmd('source',1,self.path_script,'--repository=%s'%self.path_repos)) self.assertSuccess(self.cmd('source',1,filename,'--repository=%s'%self.path_repos))
self.assert_(os.path.exists(self.path_script)) self.assert_(os.path.exists(filename))
fd=open(self.path_script) fd=open(filename)
result=fd.read() result=fd.read()
self.assert_(result.strip()==source.strip()) self.assert_(result.strip()==source.strip())
def test_commit_replace(self):
"""Commit can replace a specified version"""
# Commit the default script
self.assertSuccess(self.cmd('commit',self.path_script,self.path_repos))
self.assertEquals(self.cmd_version(self.path_repos),1)
# Read the default script's text
fd=self.execute(self.cmd('source',1,'--repository=%s'%self.path_repos))
script_src_1 = fd.read()
self.assertSuccess(fd)
# Commit a new script
script_text="""
from sqlalchemy import *
from migrate import *
# Our test is just that the source is different; so we don't have to
# do anything useful in here.
def upgrade():
pass
def downgrade():
pass
""".replace('\n ','\n')
fd=open(self.path_script,'w')
fd.write(script_text)
fd.close()
self.assertSuccess(self.cmd('commit',self.path_script,self.path_repos,1))
# We specified a version above - it should replace that, not create new
self.assertEquals(self.cmd_version(self.path_repos),1)
# Source should change
fd=self.execute(self.cmd('source',1,'--repository=%s'%self.path_repos))
script_src_2 = fd.read()
self.assertSuccess(fd)
self.assertNotEquals(script_src_1,script_src_2)
# source should be reasonable
self.assertEquals(script_src_2.strip(),script_text.strip())
self.assert_(script_src_1.count('from migrate import'))
self.assert_(script_src_1.count('from sqlalchemy import'))
class TestShellDatabase(Shell,fixture.DB): class TestShellDatabase(Shell,fixture.DB):
"""Commands associated with a particular database""" """Commands associated with a particular database"""
@ -263,8 +219,7 @@ class TestShellDatabase(Shell,fixture.DB):
path_script = self.tmp_py() path_script = self.tmp_py()
version=1 version=1
for i in range(version): for i in range(version):
self.assertSuccess(self.cmd('script',path_script)) self.assertSuccess(self.cmd('script', '--repository=%s' % path_repos, 'Desc'))
self.assertSuccess(self.cmd('commit',path_script,path_repos))
# Repository version is correct # Repository version is correct
fd=self.execute(self.cmd('version',path_repos)) fd=self.execute(self.cmd('version',path_repos))
self.assertEquals(fd.read().strip(),str(version)) self.assertEquals(fd.read().strip(),str(version))
@ -284,7 +239,6 @@ class TestShellDatabase(Shell,fixture.DB):
# Create a repository # Create a repository
repos_name = 'repos_name' repos_name = 'repos_name'
repos_path = self.tmp() repos_path = self.tmp()
script_path = self.tmp_py()
self.assertSuccess(self.cmd('create',repos_path,repos_name)) self.assertSuccess(self.cmd('create',repos_path,repos_name))
self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_version(repos_path),0)
# Version the DB # Version the DB
@ -301,8 +255,7 @@ class TestShellDatabase(Shell,fixture.DB):
self.assertFailure(self.cmd('upgrade',self.url,repos_path,-1)) self.assertFailure(self.cmd('upgrade',self.url,repos_path,-1))
# Add a script to the repository; upgrade the db # Add a script to the repository; upgrade the db
self.assertSuccess(self.cmd('script',script_path)) self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
self.assertSuccess(self.cmd('commit',script_path,repos_path))
self.assertEquals(self.cmd_version(repos_path),1) self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
@ -321,14 +274,6 @@ class TestShellDatabase(Shell,fixture.DB):
self.assertSuccess(self.cmd('drop_version_control',self.url,repos_path)) self.assertSuccess(self.cmd('drop_version_control',self.url,repos_path))
def _run_test_sqlfile(self,upgrade_script,downgrade_script): def _run_test_sqlfile(self,upgrade_script,downgrade_script):
upgrade_path = self.tmp_sql()
downgrade_path = self.tmp_sql()
upgrade = (upgrade_path,upgrade_script)
downgrade = (downgrade_path,downgrade_script)
for file_path,file_text in (upgrade,downgrade):
fd = open(file_path,'w')
fd.write(file_text)
fd.close()
repos_path = self.tmp() repos_path = self.tmp()
repos_name = 'repos' repos_name = 'repos'
@ -338,15 +283,12 @@ class TestShellDatabase(Shell,fixture.DB):
self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_version(repos_path),0)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
self.assertSuccess(self.cmd('commit',upgrade_path,repos_path,'postgres','upgrade')) beforeCount = len(os.listdir(os.path.join(repos_path,'versions'))) # hmm, this number changes sometimes based on running from svn
self.assertSuccess(self.cmd('script_sql', '--repository=%s' % repos_path, 'postgres'))
self.assertEquals(self.cmd_version(repos_path),1) self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(len(os.listdir(os.path.join(repos_path,'versions','1'))),2) self.assertEquals(len(os.listdir(os.path.join(repos_path,'versions'))), beforeCount + 2)
open('%s/versions/001_postgres_upgrade.sql' % repos_path, 'a').write(upgrade_script)
# Add, not replace open('%s/versions/001_postgres_downgrade.sql' % repos_path, 'a').write(downgrade_script)
self.assertSuccess(self.cmd('commit',downgrade_path,repos_path,'postgres','downgrade','--version=1'))
self.assertEquals(len(os.listdir(os.path.join(repos_path,'versions','1'))),3)
self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
self.assertRaises(Exception,self.engine.text('select * from t_table').execute) self.assertRaises(Exception,self.engine.text('select * from t_table').execute)
@ -392,7 +334,6 @@ class TestShellDatabase(Shell,fixture.DB):
def test_test(self): def test_test(self):
repos_name = 'repos_name' repos_name = 'repos_name'
repos_path = self.tmp() repos_path = self.tmp()
script_path = self.tmp_py()
self.assertSuccess(self.cmd('create',repos_path,repos_name)) self.assertSuccess(self.cmd('create',repos_path,repos_name))
self.exitcode(self.cmd('drop_version_control',self.url,repos_path)) self.exitcode(self.cmd('drop_version_control',self.url,repos_path))
@ -401,9 +342,9 @@ class TestShellDatabase(Shell,fixture.DB):
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
# Empty script should succeed # Empty script should succeed
self.assertSuccess(self.cmd('script',script_path)) self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
self.assertSuccess(self.cmd('test',script_path,repos_path,self.url)) self.assertSuccess(self.cmd('test',repos_path,self.url))
self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
# Error script should fail # Error script should fail
@ -423,8 +364,8 @@ class TestShellDatabase(Shell,fixture.DB):
file=open(script_path,'w') file=open(script_path,'w')
file.write(script_text) file.write(script_text)
file.close() file.close()
self.assertFailure(self.cmd('test',script_path,repos_path,self.url)) self.assertFailure(self.cmd('test',repos_path,self.url,'blah blah'))
self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
# Nonempty script using migrate_engine should succeed # Nonempty script using migrate_engine should succeed
@ -451,8 +392,8 @@ class TestShellDatabase(Shell,fixture.DB):
file=open(script_path,'w') file=open(script_path,'w')
file.write(script_text) file.write(script_text)
file.close() file.close()
self.assertSuccess(self.cmd('test',script_path,repos_path,self.url)) self.assertSuccess(self.cmd('test',repos_path,self.url))
self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
@fixture.usedb() @fixture.usedb()
@ -507,7 +448,7 @@ class TestShellDatabase(Shell,fixture.DB):
output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path)
self.assertEquals(output, "") self.assertEquals(output, "")
self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_version(repos_path),0)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) # version did not get bumped yet because new version not yet committed self.assertEquals(self.cmd_db_version(self.url,repos_path),0) # version did not get bumped yet because new version not yet created
output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path) output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path)
self.assertEquals(output, "No schema diffs") self.assertEquals(output, "No schema diffs")
output, exitcode = self.output_and_exitcode('python %s create_model' % script_path) output, exitcode = self.output_and_exitcode('python %s create_model' % script_path)
@ -545,13 +486,12 @@ class TestShellDatabase(Shell,fixture.DB):
tmp_account_rundiffs.drop() tmp_account_rundiffs.drop()
""") """)
# Commit the change. # Save the upgrade script.
upgrade_script_path = self.tmp_named('upgrade_script.py') self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
upgrade_script_path = '%s/versions/001_Desc.py' % repos_path
open(upgrade_script_path, 'w').write(output) open(upgrade_script_path, 'w').write(output)
#output, exitcode = self.output_and_exitcode('python %s test %s' % (script_path, upgrade_script_path)) # no, we already upgraded the db above #output, exitcode = self.output_and_exitcode('python %s test %s' % (script_path, upgrade_script_path)) # no, we already upgraded the db above
#self.assertEquals(output, "") #self.assertEquals(output, "")
output, exitcode = self.output_and_exitcode('python %s commit %s' % (script_path, upgrade_script_path))
self.assertEquals(output, "")
output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) # bump the db_version output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) # bump the db_version
self.assertEquals(output, "") self.assertEquals(output, "")
self.assertEquals(self.cmd_version(repos_path),1) self.assertEquals(self.cmd_version(repos_path),1)

View File

@ -42,3 +42,13 @@ class TestVerNum(fixture.Base):
self.assert_(VerNum(1)>=1) self.assert_(VerNum(1)>=1)
self.assert_(not VerNum(1)>=2) self.assert_(not VerNum(1)>=2)
self.assert_(VerNum(2)>=1) self.assert_(VerNum(2)>=1)
class TestDescriptionNaming(fixture.Base):
def test_names(self):
self.assertEquals(strToFilename(''), '')
self.assertEquals(strToFilename('a'), 'a')
self.assertEquals(strToFilename('Abc Def'), 'Abc_Def')
self.assertEquals(strToFilename('Abc "D" Ef'), 'Abc_D_Ef')
self.assertEquals(strToFilename("Abc's Stuff"), 'Abc_s_Stuff')
self.assertEquals(strToFilename("a b"), 'a_b')