apply PEP8 to version.py, fixed notification of missing test_db.cfg

This commit is contained in:
iElectric 2009-06-01 22:23:50 +00:00
parent 1b927fa427
commit 7cd2c3233b
3 changed files with 98 additions and 71 deletions

View File

@ -1,37 +1,54 @@
from migrate.versioning import exceptions,pathed,script
import os,re,shutil
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import re
import shutil
from migrate.versioning import exceptions, pathed, script
class VerNum(object):
"""A version number"""
_instances = dict()
def __new__(cls, value):
val = str(value)
if val not in cls._instances:
cls._instances[val] = super(VerNum, cls).__new__(cls)
ret = cls._instances[val]
return ret
def __init__(self,value):
self.value = str(int(value))
if self < 0:
raise ValueError("Version number cannot be negative")
def __repr__(self):
return str(self.value)
def __str__(self):
return str(self.value)
def __int__(self):
return int(self.value)
def __add__(self, value):
ret = int(self) + int(value)
return VerNum(ret)
def __sub__(self, value):
return self + (int(value) * -1)
def __cmp__(self, value):
return int(self) - int(value)
def __repr__(self):
return str(self.value)
def __str__(self):
return str(self.value)
def __int__(self):
return int(self.value)
def str_to_filename(s):
"""Replaces spaces, (double and single) quotes
and double underscores to underscores
"""
def strToFilename(s):
s = s.replace(' ', '_').replace('"', '_').replace("'", '_')
while '__' in s:
s = s.replace('__', '_')
@ -40,15 +57,19 @@ def strToFilename(s):
class Collection(pathed.Pathed):
"""A collection of versioning scripts in a repository"""
FILENAME_WITH_VERSION = re.compile(r'^(\d+).*')
def __init__(self, path):
super(Collection, self).__init__(path)
# Create temporary list of files, allowing skipped version numbers.
files = os.listdir(path)
if '1' in files:
raise Exception('It looks like you have a repository in the old format (with directories for each version). Please convert repository before proceeding.')
tempVersions = dict()
if '1' in files:
raise Exception('It looks like you have a repository in the old '
'format (with directories for each version). '
'Please convert repository before proceeding.')
for filename in files:
match = self.FILENAME_WITH_VERSION.match(filename)
if match:
@ -57,11 +78,13 @@ class Collection(pathed.Pathed):
else:
pass # Must be a helper file or something, let's ignore it.
# Create the versions member where the keys are VerNum's and the values are Version's.
# Create the versions member where the keys
# are VerNum's and the values are Version's.
self.versions = dict()
for num, files in tempVersions.items():
self.versions[VerNum(num)] = Version(num, path, files)
self.latest = max([VerNum(0)] + self.versions.keys()) # calculate latest version
# calculate latest version
self.latest = max([VerNum(0)] + self.versions.keys())
def version_path(self, ver):
return os.path.join(self.path, str(ver))
@ -81,18 +104,22 @@ class Collection(pathed.Pathed):
def createNewVersion(self, description, **k):
ver = self.getNewVersion()
extra = strToFilename(description)
extra = str_to_filename(description)
if extra:
if extra == '_':
extra = ''
elif not extra.startswith('_'):
extra = '_%s' % extra
filename = '%03d%s.py' % (ver, extra)
filepath = self.version_path(filename)
if os.path.exists(filepath):
raise Exception('Script already exists: %s' % filepath)
else:
script.PythonScript.create(filepath)
self.versions[ver] = Version(ver, self.path, [filename])
def createNewSQLVersion(self, database, **k):
@ -116,21 +143,22 @@ class Collection(pathed.Pathed):
super(Collection, cls).clear()
class extensions:
class Extensions:
"""A namespace for file extensions"""
py = 'py'
sql = 'sql'
class Version(object): # formerly inherit from: (pathed.Pathed):
"""A single version in a repository
"""
"""A single version in a repository """
def __init__(self, vernum, path, filelist):
# Version must be numeric
try:
self.version = VerNum(vernum)
except:
raise exceptions.InvalidVersionError(vernum)
# Collect scripts in this folder
self.sql = dict()
self.python = None
@ -143,18 +171,14 @@ class Version(object): # formerly inherit from: (pathed.Pathed):
self._add_script(os.path.join(path, script))
def script(self, database=None, operation=None):
#if database is None and operation is None:
# return self._script_py()
#print database,operation,self.sql
try:
# Try to return a .sql script first
try:
return self._script_sql(database, operation)
except KeyError:
pass # No .sql script exists
try:
# Try to return the default .sql script
try:
return self._script_sql('default', operation)
except KeyError:
pass # No .sql script exists
@ -163,8 +187,10 @@ class Version(object): # formerly inherit from: (pathed.Pathed):
assert ret is not None
return ret
def _script_py(self):
return self.python
def _script_sql(self, database, operation):
return self.sql[database][operation]
@ -184,14 +210,16 @@ class Version(object): # formerly inherit from: (pathed.Pathed):
return ret
def _add_script(self, path):
if path.endswith(extensions.py):
if path.endswith(Extensions.py):
self._add_script_py(path)
elif path.endswith(extensions.sql):
elif path.endswith(Extensions.sql):
self._add_script_sql(path)
SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
def _add_script_sql(self, path):
match = self.SQL_FILENAME.match(os.path.basename(path))
if match:
version, dbms, op = match.group(1), match.group(2), match.group(3)
else:
@ -203,9 +231,11 @@ class Version(object): # formerly inherit from: (pathed.Pathed):
dbmses[dbms] = dict()
ops = dbmses[dbms]
ops[op] = script.SqlScript(path)
def _add_script_py(self, path):
if self.python is not None:
raise Exception('You can only have one Python script per version, but you have: %s and %s' % (self.python, path))
raise Exception('You can only have one Python script per version,'
' but you have: %s and %s' % (self.python, path))
self.python = script.PythonScript(path)
def _rm_ignore(self, path):
@ -214,4 +244,3 @@ class Version(object): # formerly inherit from: (pathed.Pathed):
os.remove(path)
except OSError:
pass

View File

@ -13,10 +13,8 @@ def readurls():
try:
fd=open(fullpath)
except IOError:
print "You must specify the databases to use for testing!"
tmplfile = "%s.tmpl"%filename
print "Copy %s.tmpl to %s and edit your database URLs."%(tmplfile,filename)
raise
raise IOError("""You must specify the databases to use for testing!
Copy %(filename)s.tmpl to %(filename)s and edit your database URLs.""" % locals())
#fd = resource_stream('__main__',filename)
for line in fd:
if line.startswith('#'):

View File

@ -45,10 +45,10 @@ class TestVerNum(fixture.Base):
class TestDescriptionNaming(fixture.Base):
def test_names(self):
self.assertEquals(strToFilename(''), '')
self.assertEquals(strToFilename('a'), 'a')
self.assertEquals(strToFilename('Abc Def'), 'Abc_Def')
self.assertEquals(strToFilename('Abc "D" Ef'), 'Abc_D_Ef')
self.assertEquals(strToFilename("Abc's Stuff"), 'Abc_s_Stuff')
self.assertEquals(strToFilename("a b"), 'a_b')
self.assertEquals(str_to_filename(''), '')
self.assertEquals(str_to_filename('a'), 'a')
self.assertEquals(str_to_filename('Abc Def'), 'Abc_Def')
self.assertEquals(str_to_filename('Abc "D" Ef'), 'Abc_D_Ef')
self.assertEquals(str_to_filename("Abc's Stuff"), 'Abc_s_Stuff')
self.assertEquals(str_to_filename("a b"), 'a_b')