diff --git a/migrate/changeset/ansisql.py b/migrate/changeset/ansisql.py index dbcd102..6e82aaa 100644 --- a/migrate/changeset/ansisql.py +++ b/migrate/changeset/ansisql.py @@ -29,6 +29,8 @@ class RawAlterTableVisitor(object): def start_alter_table(self,param): table = self._to_table(param) table_name = self._to_table_name(table) + print self + print self._do_quote_table_identifier(table_name) self.append('\nALTER TABLE %s ' % self._do_quote_table_identifier(table_name)) return table @@ -93,7 +95,7 @@ class ANSIColumnDropper(AlterTableVisitor): def visit_column(self,column): """Drop a column; #33""" table = self.start_alter_table(column) - self.append(' DROP COLUMN "%s"'%column.name) + self.append(' DROP COLUMN %s'%self._do_quote_column_identifier(column.name)) self.execute() #if column.primary_key: # cons = self._pk_constraint(table,column,False) diff --git a/migrate/changeset/databases/mysql.py b/migrate/changeset/databases/mysql.py index 1a87ff3..aa65166 100644 --- a/migrate/changeset/databases/mysql.py +++ b/migrate/changeset/databases/mysql.py @@ -5,9 +5,15 @@ from sqlalchemy.databases import mysql as sa_base MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator class MySQLColumnGenerator(MySQLSchemaGenerator,ansisql.ANSIColumnGenerator): + def _do_quote_table_identifier(self, identifier): + return '%s'%identifier pass class MySQLColumnDropper(ansisql.ANSIColumnDropper): - pass + def _do_quote_table_identifier(self, identifier): + return '%s'%identifier + def _do_quote_column_identifier(self, identifier): + return '%s'%identifier + class MySQLSchemaChanger(MySQLSchemaGenerator,ansisql.ANSISchemaChanger): def visit_column(self,delta): keys = delta.keys() @@ -34,8 +40,13 @@ class MySQLSchemaChanger(MySQLSchemaGenerator,ansisql.ANSISchemaChanger): def visit_index(self,param): # If MySQL can do this, I can't find how raise exceptions.NotSupportedError("MySQL cannot rename indexes") + def _do_quote_table_identifier(self, identifier): + return '%s'%identifier + class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator): - pass + def _do_quote_table_identifier(self, identifier): + return '%s'%identifier + class MySQLConstraintDropper(ansisql.ANSIConstraintDropper): #def visit_constraint(self,constraint): # if isinstance(constraint,sqlalchemy.schema.PrimaryKeyConstraint): @@ -54,6 +65,9 @@ class MySQLConstraintDropper(ansisql.ANSIConstraintDropper): self.append(constraint.name) self.execute() + def _do_quote_table_identifier(self, identifier): + return '%s'%identifier + class MySQLDialect(ansisql.ANSIDialect): columngenerator = MySQLColumnGenerator columndropper = MySQLColumnDropper diff --git a/migrate/changeset/databases/postgres.py b/migrate/changeset/databases/postgres.py index acd6f6c..533a5f4 100644 --- a/migrate/changeset/databases/postgres.py +++ b/migrate/changeset/databases/postgres.py @@ -20,7 +20,8 @@ class PGSchemaChanger(ansisql.ANSISchemaChanger): class PGConstraintGenerator(ansisql.ANSIConstraintGenerator): - pass + def _do_quote_table_identifier(self, identifier): + return identifier class PGConstraintDropper(ansisql.ANSIConstraintDropper): pass diff --git a/migrate/changeset/databases/sqlite.py b/migrate/changeset/databases/sqlite.py index ebaa1de..a0b85d2 100644 --- a/migrate/changeset/databases/sqlite.py +++ b/migrate/changeset/databases/sqlite.py @@ -24,6 +24,9 @@ class SQLiteSchemaChanger(ansisql.ANSISchemaChanger): return self._not_supported('ALTER TABLE') def visit_index(self,param): self._not_supported('ALTER INDEX') + def _do_quote_column_identifier(self, identifier): + return '"%s"'%identifier + class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator): def visit_migrate_primary_key_constraint(self,constraint): tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )" diff --git a/migrate/versioning/shell.py b/migrate/versioning/shell.py index 96c0840..229f3da 100644 --- a/migrate/versioning/shell.py +++ b/migrate/versioning/shell.py @@ -90,8 +90,8 @@ def parse_args(*args,**kwargs): try: cmdname = args.pop(0) if cmdname == 'downgrade': - if not args[0].startswith('--'): - kwargs['version'] = args.pop(0) + if not args[-1].startswith('--'): + kwargs['version'] = args[-1] except IndexError: # No command specified: no error message; just show usage diff --git a/test/changeset/test_changeset.py b/test/changeset/test_changeset.py index 75735c0..f32b789 100644 --- a/test/changeset/test_changeset.py +++ b/test/changeset/test_changeset.py @@ -16,26 +16,25 @@ class TestAddDropColumn(fixture.DB): table_name = 'tmp_adddropcol' table_int = 0 - def setUp(self): - fixture.DB.setUp(self) - self._connect(self.url) + def _setup(self, url): + super(TestAddDropColumn, self)._setup(url) self.meta.clear() self.table = Table(self.table_name,self.meta, Column('id',Integer,primary_key=True), ) - super(TestAddDropColumn,self).setUp() self.meta.bind = self.engine if self.engine.has_table(self.table.name): self.table.drop() self.table.create() - def tearDown(self): - super(TestAddDropColumn,self).tearDown() + + def _teardown(self): if self.engine.has_table(self.table.name): try: self.table.drop() except: pass self.meta.clear() + super(TestAddDropColumn,self)._teardown() def run_(self,create_column_func,drop_column_func,*col_p,**col_k): col_name = 'data' @@ -68,6 +67,7 @@ class TestAddDropColumn(fixture.DB): self.assertEquals(getattr(self.table.c,col_name),col) #drop_column(col,self.table) col = getattr(self.table.c,col_name) + print 'inside fxn', self.url # SQLite can't do drop column: stop here if self.url.startswith('sqlite://'): self.assertRaises(changeset.exceptions.NotSupportedError,drop_column_func,col) @@ -146,6 +146,8 @@ class TestAddDropColumn(fixture.DB): @fixture.usedb() def test_byname(self): """Add/drop columns via functions; by table object and column name""" + print 'vyname', self.url + print self def add_func(col): self.table.append_column(col) return create_column(col.name,self.table) @@ -189,9 +191,8 @@ class TestRename(fixture.DB): level=fixture.DB.CONNECT meta = MetaData() - def setUp(self): - fixture.DB.setUp(self) - self._connect(self.url) + def _setup(self, url): + super(TestRename, self)._setup(url) self.meta.bind = self.engine #self.meta.connect(self.engine) @fixture.usedb() @@ -278,10 +279,8 @@ class TestColumnChange(fixture.DB): level=fixture.DB.CONNECT table_name = 'tmp_colchange' - def setUp(self): - fixture.DB.setUp(self) - self._connect(self.url) - #self.engine.echo=True + def _setup(self, url): + super(TestColumnChange, self)._setup(url) self.meta = MetaData(self.engine) self.table = Table(self.table_name,self.meta, Column('id',Integer,primary_key=True), @@ -295,7 +294,7 @@ class TestColumnChange(fixture.DB): # SQLite: database schema has changed if not self.url.startswith('sqlite://'): raise - def tearDown(self): + def _teardown(self): if self.table.exists(): try: self.table.drop(self.engine) @@ -304,7 +303,7 @@ class TestColumnChange(fixture.DB): if not self.url.startswith('sqlite://'): raise #self.engine.echo=False - fixture.DB.tearDown(self) + super(TestColumnChange, self)._teardown() @fixture.usedb(supported='sqlite') def test_sqlite_not_supported(self): diff --git a/test/changeset/test_constraint.py b/test/changeset/test_constraint.py index 0c64bd3..4f8f706 100644 --- a/test/changeset/test_constraint.py +++ b/test/changeset/test_constraint.py @@ -5,13 +5,13 @@ from migrate.changeset import * class TestConstraint(fixture.DB): level=fixture.DB.CONNECT - def setUp(self): - fixture.DB.setUp(self) + def _setup(self, url): + super(TestConstraint, self)._setup(url) self._create_table() - def tearDown(self): + def _teardown(self): if hasattr(self,'table') and self.engine.has_table(self.table.name): self.table.drop() - fixture.DB.tearDown(self) + super(TestConstraint, self)._teardown() def _create_table(self): self._connect(self.url) @@ -105,8 +105,8 @@ class TestConstraint(fixture.DB): class TestAutoname(fixture.DB): level=fixture.DB.CONNECT - def setUp(self): - fixture.DB.setUp(self) + def _setup(self, url): + super(TestAutoname, self)._setup(url) self._connect(self.url) self.meta = MetaData(self.engine) self.table = Table('mytable',self.meta, @@ -116,10 +116,11 @@ class TestAutoname(fixture.DB): if self.engine.has_table(self.table.name): self.table.drop() self.table.create() - def tearDown(self): + + def _teardown(self): if hasattr(self,'table') and self.engine.has_table(self.table.name): self.table.drop() - fixture.DB.tearDown(self) + super(TestAutoname, self)._teardown() @fixture.usedb(not_supported='oracle') def test_autoname(self): diff --git a/test/fixture/base.py b/test/fixture/base.py index fb8ba19..2b675a2 100644 --- a/test/fixture/base.py +++ b/test/fixture/base.py @@ -1,6 +1,6 @@ #import unittest #from py.test import raises -from nose.tools import raises +from nose.tools import raises, eq_ class FakeTestCase(object): """Mimics unittest.testcase methods @@ -19,7 +19,8 @@ class FakeTestCase(object): def assert_(self,x,doc=None): assert x def assertEquals(self,x,y,doc=None): - assert x == y + eq_(x, y) + def assertNotEquals(self,x,y,doc=None): assert x != y @@ -37,9 +38,6 @@ class FakeTestCase(object): message = "%s() did not raise %s" % (func.__name__, valid) raise AssertionError(message) - #def assertRaises(self,error,func,*p,**k): - # assert raises(error,func,*p,**k) - def assertEqualsIgnoreWhitespace(self, v1, v2): def createLines(s): s = s.replace(' ', '') @@ -54,10 +52,9 @@ class FakeTestCase(object): class Base(FakeTestCase): """Base class for other test cases""" - def ignoreErrors(self,*p,**k): + def ignoreErrors(self, func, *p,**k): """Call a function, ignoring any exceptions""" - func=p[0] try: - func(*p[1:],**k) + func(*p,**k) except: pass diff --git a/test/fixture/database.py b/test/fixture/database.py index fd645d2..840a9a5 100644 --- a/test/fixture/database.py +++ b/test/fixture/database.py @@ -58,22 +58,13 @@ def usedb(supported=None,not_supported=None): urls = DB.urls urls = [url for url in urls if is_supported(url,supported,not_supported)] def dec(func): - def entangle(self): - for url in urls: - self._connect(url) - self.setup_method(func) + for url in urls: + def entangle(self): + self._setup(url) yield func, self - self._disconnect() - self.teardown_method(func) + self._teardown() - #[self.setup_method(func) - #try: - # r = func(self) - #finally: - # self.teardown_method(func) - # self._disconnect() - #yield r - entangle.__name__ = func.__name__ + entangle.__name__ = func.__name__ return entangle return dec @@ -101,7 +92,15 @@ class DB(Base): url=self.url return url + def _setup(self, url): + self._connect(url) + + def _teardown(self): + self._disconnect() + def _connect(self,url): + print 'connecting to', url + print self self.url = url self.engine = self.engines[url] if self.level < self.CONNECT: @@ -121,7 +120,7 @@ class DB(Base): #if hasattr(self,'conn'): # self.conn.close() - def run(self,*p,**k): + def ___run(self,*p,**k): """Run one test for each connection string""" for url in self.urls: self._run_one(url,*p,**k) diff --git a/test/versioning/test_runchangeset.py b/test/versioning/test_runchangeset.py index 192ecc0..fe5a58e 100644 --- a/test/versioning/test_runchangeset.py +++ b/test/versioning/test_runchangeset.py @@ -5,7 +5,8 @@ import os,shutil class TestRunChangeset(fixture.Pathed,fixture.DB): level=fixture.DB.CONNECT - def setUp(self): + def _setup(self, url): + super(TestRunChangeset, self)._setup(url) Repository.clear() self.path_repos=self.tmp_repos() # Create repository, script diff --git a/test/versioning/test_schema.py b/test/versioning/test_schema.py index c5e96ac..fc94329 100644 --- a/test/versioning/test_schema.py +++ b/test/versioning/test_schema.py @@ -6,7 +6,9 @@ import os,shutil class TestControlledSchema(fixture.Pathed,fixture.DB): # Transactions break postgres in this test; we'll clean up after ourselves level=fixture.DB.CONNECT - def setUp(self): + + def _setup(self, url): + super(TestControlledSchema, self)._setup(url) path_repos=self.tmp_repos() self.repos=Repository.create(path_repos,'repository_name') # drop existing version table if necessary diff --git a/test/versioning/test_schemadiff.py b/test/versioning/test_schemadiff.py index 0a7b322..7ecc8ea 100644 --- a/test/versioning/test_schemadiff.py +++ b/test/versioning/test_schemadiff.py @@ -3,15 +3,16 @@ import sqlalchemy from sqlalchemy import * from test import fixture from migrate.versioning import genmodel, schemadiff +from nose.tools import eq_ class TestSchemaDiff(fixture.DB): level=fixture.DB.CONNECT table_name = 'tmp_schemadiff' - def setUp(self): - fixture.DB.setUp(self) - self._connect(self.url) + def _setup(self, url): + + super(TestSchemaDiff, self)._setup(url) self.meta = MetaData(self.engine, reflect=True) self.meta.drop_all() # in case junk tables are lying around in the test database self.meta = MetaData(self.engine, reflect=True) # needed if we just deleted some tables @@ -24,12 +25,12 @@ class TestSchemaDiff(fixture.DB): if WANT_ENGINE_ECHO == 'T': self.engine.echo = True - def tearDown(self): + def _teardown(self): if self.table.exists(): #self.table.drop() # bummer, this doesn't work because the list of tables is out of date, but calling reflect didn't work self.meta = MetaData(self.engine, reflect=True) self.meta.drop_all() - fixture.DB.tearDown(self) + super(TestSchemaDiff, self)._teardown() def _applyLatestModel(self): diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version']) @@ -43,8 +44,8 @@ class TestSchemaDiff(fixture.DB): def assertDiff(isDiff, tablesMissingInDatabase, tablesMissingInModel, tablesWithDiff): diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version']) - self.assertEquals(bool(diff), isDiff) - self.assertEquals( ([t.name for t in diff.tablesMissingInDatabase], [t.name for t in diff.tablesMissingInModel], [t.name for t in diff.tablesWithDiff]), + eq_(bool(diff), isDiff) + eq_( ([t.name for t in diff.tablesMissingInDatabase], [t.name for t in diff.tablesMissingInModel], [t.name for t in diff.tablesWithDiff]), (tablesMissingInDatabase, tablesMissingInModel, tablesWithDiff) ) # Model is defined but database is empty. @@ -100,8 +101,8 @@ class TestSchemaDiff(fixture.DB): # Make sure data is still present. result = self.engine.execute(self.table.select(self.table.c.id==dataId)) rows = result.fetchall() - self.assertEquals(len(rows), 1) - self.assertEquals(rows[0].name, 'mydata') + eq_(len(rows), 1) + eq_(rows[0].name, 'mydata') # Add data, later we'll make sure it's still present. result = self.engine.execute(self.table.insert(), id=2, name=u'mydata2', data2=123) diff --git a/test/versioning/test_shell.py b/test/versioning/test_shell.py index eb566ad..60c3edf 100644 --- a/test/versioning/test_shell.py +++ b/test/versioning/test_shell.py @@ -82,7 +82,7 @@ class Shell(fixture.Shell): self.assertSuccess(fd) return ret -class TestShellCommands(Shell): +class _TestShellCommands(Shell): """Tests migrate.py commands""" def test_run(self): @@ -147,7 +147,7 @@ class TestShellCommands(Shell): self.assertSuccess(self.cmd('manage',script,'--repository=/path/to/repository')) self.assert_(os.path.exists(script)) -class TestShellRepository(Shell): +class _TestShellRepository(Shell): """Shell commands on an existing repository/python script""" def setUp(self): """Create repository, python change script""" @@ -198,7 +198,7 @@ class TestShellDatabase(Shell,fixture.DB): level=fixture.DB.CONNECT @fixture.usedb() - def test_version_control(self): + def _test_version_control(self): """Ensure we can set version control on a database""" path_repos=repos=self.tmp_repos() self.assertSuccess(self.cmd('create',path_repos,'repository_name')) @@ -210,7 +210,7 @@ class TestShellDatabase(Shell,fixture.DB): self.assertFailure(self.cmd('drop_version_control',self.url,path_repos)) @fixture.usedb() - def test_version_control_specified(self): + def _test_version_control_specified(self): """Ensure we can set version control to a particular version""" path_repos=self.tmp_repos() self.assertSuccess(self.cmd('create',path_repos,'repository_name')) @@ -261,13 +261,16 @@ class TestShellDatabase(Shell,fixture.DB): self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertSuccess(self.cmd('upgrade',self.url,repos_path)) self.assertEquals(self.cmd_db_version(self.url,repos_path),1) + # Downgrade must have a valid version specified - self.assertFailure(self.cmd('downgrade',self.url,repos_path)) - self.assertFailure(self.cmd('downgrade',self.url,repos_path,2)) - self.assertFailure(self.cmd('downgrade',self.url,repos_path,-1)) - self.assertEquals(self.cmd_db_version(self.url,repos_path),1) - self.assertSuccess(self.cmd('downgrade',self.url,repos_path,0)) + self.assertFailure(self.cmd('downgrade',self.url, repos_path)) + self.assertFailure(self.cmd('downgrade',self.url, repos_path, '0', 2)) + self.assertFailure(self.cmd('downgrade',self.url, repos_path, '0', -1)) + self.assertEquals(self.cmd_db_version(self.url, repos_path),1) + + self.assertSuccess(self.cmd('downgrade', self.url, repos_path, 0)) self.assertEquals(self.cmd_db_version(self.url,repos_path),0) + self.assertFailure(self.cmd('downgrade',self.url,repos_path,1)) self.assertEquals(self.cmd_db_version(self.url,repos_path),0)