From 75494ae2269294f91745718f109ebcce71ae3eb7 Mon Sep 17 00:00:00 2001
From: iElectric <unknown>
Date: Tue, 9 Jun 2009 13:31:15 +0000
Subject: [PATCH] update tests for schema, refactor a bit

---
 TODO                               |   4 +
 migrate/versioning/exceptions.py   |  12 --
 migrate/versioning/schema.py       | 162 +++++++++++++------------
 test/fixture/pathed.py             |   5 +-
 test/versioning/test_genmodel.py   |  13 ++
 test/versioning/test_repository.py |   5 +
 test/versioning/test_schema.py     | 186 +++++++++++++++++++++++------
 test/versioning/test_shell.py      |   4 +-
 test/versioning/test_template.py   |   4 +-
 9 files changed, 261 insertions(+), 134 deletions(-)
 create mode 100644 test/versioning/test_genmodel.py

diff --git a/TODO b/TODO
index bc3453d..922778b 100644
--- a/TODO
+++ b/TODO
@@ -8,3 +8,7 @@ make_update_script_for_model:
 - calculated differences between models are actually differences between metas
 - columns are not compared?
 - even if two "models" are equal, it doesn't yield so
+
+
+- refactor test_shell to test_api and use TestScript for cmd line testing
+- controlledschema.drop() drops whole migrate table, maybe there are some other repositories bound to it!
diff --git a/migrate/versioning/exceptions.py b/migrate/versioning/exceptions.py
index 0b523e9..00b5dd6 100644
--- a/migrate/versioning/exceptions.py
+++ b/migrate/versioning/exceptions.py
@@ -5,12 +5,10 @@
 
 class Error(Exception):
     """Error base class."""
-    pass
 
 
 class ApiError(Error):
     """Base class for API errors."""
-    pass
 
 
 class KnownError(ApiError):
@@ -23,7 +21,6 @@ class UsageError(ApiError):
 
 class ControlledSchemaError(Error):
     """Base class for controlled schema errors."""
-    pass
 
 
 class InvalidVersionError(ControlledSchemaError):
@@ -44,44 +41,35 @@ class WrongRepositoryError(ControlledSchemaError):
 
 class NoSuchTableError(ControlledSchemaError):
     """The table does not exist."""
-    pass
 
 
 class PathError(Error):
     """Base class for path errors."""
-    pass
 
 
 class PathNotFoundError(PathError):
     """A path with no file was required; found a file."""
-    pass
 
 
 class PathFoundError(PathError):
     """A path with a file was required; found no file."""
-    pass
 
 
 class RepositoryError(Error):
     """Base class for repository errors."""
-    pass
 
 
 class InvalidRepositoryError(RepositoryError):
     """Invalid repository error."""
-    pass
 
 
 class ScriptError(Error):
     """Base class for script errors."""
-    pass
 
 
 class InvalidScriptError(ScriptError):
     """Invalid script error."""
-    pass
 
 
 class InvalidVersionError(Error):
     """Invalid version error."""
-    pass
diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py
index ae6d6e7..9b71ba7 100644
--- a/migrate/versioning/schema.py
+++ b/migrate/versioning/schema.py
@@ -5,6 +5,7 @@ from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
     create_engine)
 from sqlalchemy.sql import and_
 from sqlalchemy import exceptions as sa_exceptions
+from sqlalchemy.sql import bindparam
 
 from migrate.versioning import exceptions, genmodel, schemadiff
 from migrate.versioning.repository import Repository
@@ -34,36 +35,92 @@ class ControlledSchema(object):
         if not hasattr(self, 'table') or self.table is None:
             try:
                 self.table = Table(tname, self.meta, autoload=True)
-            except (exceptions.NoSuchTableError):
+            except sa_exceptions.NoSuchTableError:
                 raise exceptions.DatabaseNotControlledError(tname)
 
         # TODO?: verify that the table is correct (# cols, etc.)
         result = self.engine.execute(self.table.select(
                     self.table.c.repository_id == str(self.repository.id)))
-        data = list(result)[0]
-        # TODO?: exception if row count is bad
-        # TODO: check repository id, exception if incorrect
+
+        try:
+            data = list(result)[0]
+        except IndexError:
+            raise exceptions.DatabaseNotControlledError(tname)
+
         self.version = data['version']
+        return data
 
-    def _get_repository(self):
+    def drop(self):
         """
-        Given a database engine, try to guess the repository.
+        Remove version control from a database.
+        """
+        try:
+            self.table.drop()
+        except (sa_exceptions.SQLError):
+            raise exceptions.DatabaseNotControlledError(str(self.table))
 
-        :raise: :exc:`NotImplementedError`
+    def changeset(self, version=None):
+        """API to Changeset creation.
+        
+        Uses self.version for start version and engine.name to get database name."""
+        database = self.engine.name
+        start_ver = self.version
+        changeset = self.repository.changeset(database, start_ver, version)
+        return changeset
+
+    def runchange(self, ver, change, step):
+        startver = ver
+        endver = ver + step
+        # Current database version must be correct! Don't run if corrupt!
+        if self.version != startver:
+            raise exceptions.InvalidVersionError("%s is not %s" % \
+                                                     (self.version, startver))
+        # Run the change
+        change.run(self.engine, step)
+
+        # Update/refresh database version
+        self.update_repository_table(startver, endver)
+        self.load()
+
+    def update_repository_table(self, startver, endver):
+        """Update version_table with new information"""
+        update = self.table.update(and_(self.table.c.version == int(startver),
+             self.table.c.repository_id == str(self.repository.id)))
+        self.engine.execute(update, version=int(endver))
+
+    def upgrade(self, version=None):
         """
-        # TODO: no guessing yet; for now, a repository must be supplied
-        raise NotImplementedError()
+        Upgrade (or downgrade) to a specified version, or latest version.
+        """
+        changeset = self.changeset(version)
+        for ver, change in changeset:
+            self.runchange(ver, change, changeset.step)
+
+    def update_db_from_model(self, model):
+        """
+        Modify the database to match the structure of the current Python model.
+        """
+        model = load_model(model)
+
+        diff = schemadiff.getDiffOfModelAgainstDatabase(
+            model, self.engine, excludeTables=[self.repository.version_table])
+        genmodel.ModelGenerator(diff).applyModel()
+
+        self.update_repository_table(self.version, int(self.repository.latest))
+
+        self.load()
 
     @classmethod
     def create(cls, engine, repository, version=None):
         """
         Declare a database to be under a repository's version control.
 
+        :raises: :exc:`DatabaseAlreadyControlledError`
         :returns: :class:`ControlledSchema`
         """
         # Confirm that the version # is valid: positive, integer,
         # exists in repos
-        if isinstance(repository, str):
+        if isinstance(repository, basestring):
             repository = Repository(repository)
         version = cls._validate_version(repository, version)
         table = cls._create_table_version(engine, repository, version)
@@ -76,7 +133,7 @@ class ControlledSchema(object):
         """
         Ensures this is a valid version number for this repository.
 
-        :raises: :exc:`ControlledSchema.InvalidVersionError` if invalid
+        :raises: :exc:`InvalidVersionError` if invalid
         :return: valid version number
         """
         if version is None:
@@ -93,6 +150,8 @@ class ControlledSchema(object):
     def _create_table_version(cls, engine, repository, version):
         """
         Creates the versioning table in a database.
+
+        :raises: :exc:`DatabaseAlreadyControlledError`
         """
         # Create tables
         tname = repository.version_table
@@ -104,17 +163,21 @@ class ControlledSchema(object):
             Column('repository_path', Text),
             Column('version', Integer), )
 
+        # there can be multiple repositories/schemas in the same db
         if not table.exists():
             table.create()
 
+        # test for existing repository_id
+        s = table.select(table.c.repository_id == bindparam("repository_id"))
+        result = engine.execute(s, repository_id=repository.id)
+        if result.fetchone():
+            raise exceptions.DatabaseAlreadyControlledError
+
         # Insert data
-        try:
-            engine.execute(table.insert(), repository_id=repository.id,
+        engine.execute(table.insert().values(
+                           repository_id=repository.id,
                            repository_path=repository.path,
-                           version=int(version))
-        except sa_exceptions.IntegrityError:
-            # An Entry for this repo already exists.
-            raise exceptions.DatabaseAlreadyControlledError()
+                           version=int(version)))
         return table
 
     @classmethod
@@ -123,8 +186,9 @@ class ControlledSchema(object):
         Compare the current model against the current database.
         """
         if isinstance(repository, basestring):
-            repository=Repository(repository)
+            repository = Repository(repository)
         model = load_model(model)
+
         diff = schemadiff.getDiffOfModelAgainstDatabase(
             model, engine, excludeTables=[repository.version_table])
         return diff
@@ -135,66 +199,8 @@ class ControlledSchema(object):
         Dump the current database as a Python model.
         """
         if isinstance(repository, basestring):
-            repository=Repository(repository)
+            repository = Repository(repository)
+
         diff = schemadiff.getDiffOfModelAgainstDatabase(
             MetaData(), engine, excludeTables=[repository.version_table])
         return genmodel.ModelGenerator(diff, declarative).toPython()
-
-    def update_db_from_model(self, model):
-        """
-        Modify the database to match the structure of the current Python model.
-        """
-        if isinstance(self.repository, basestring):
-            self.repository=Repository(self.repository)
-        model = load_model(model)
-        diff = schemadiff.getDiffOfModelAgainstDatabase(
-            model, self.engine, excludeTables=[self.repository.version_table])
-        genmodel.ModelGenerator(diff).applyModel()
-        update = self.table.update(
-            self.table.c.repository_id == str(self.repository.id))
-        self.engine.execute(update, version=int(self.repository.latest))
-
-    def drop(self):
-        """
-        Remove version control from a database.
-        """
-        try:
-            self.table.drop()
-        except (sa_exceptions.SQLError):
-            raise exceptions.DatabaseNotControlledError(str(self.table))
-
-    def _engine_db(self, engine):
-        """
-        Returns the database name of an engine - ``postgres``, ``sqlite`` ...
-        """
-        return engine.name
-
-    def changeset(self, version=None):
-        database = self._engine_db(self.engine)
-        start_ver = self.version
-        changeset = self.repository.changeset(database, start_ver, version)
-        return changeset
-
-    def runchange(self, ver, change, step):
-        startver = ver
-        endver = ver + step
-        # Current database version must be correct! Don't run if corrupt!
-        if self.version != startver:
-            raise exceptions.InvalidVersionError("%s is not %s" % \
-                                                     (self.version, startver))
-        # Run the change
-        change.run(self.engine, step)
-        # Update/refresh database version
-        update = self.table.update(
-            and_(self.table.c.version == int(startver),
-                 self.table.c.repository_id == str(self.repository.id)))
-        self.engine.execute(update, version=int(endver))
-        self.load()
-
-    def upgrade(self, version=None):
-        """
-        Upgrade (or downgrade) to a specified version, or latest version.
-        """
-        changeset = self.changeset(version)
-        for ver, change in changeset:
-            self.runchange(ver, change, changeset.step)
diff --git a/test/fixture/pathed.py b/test/fixture/pathed.py
index e8c6ebc..a5c81ec 100644
--- a/test/fixture/pathed.py
+++ b/test/fixture/pathed.py
@@ -21,7 +21,10 @@ class Pathed(base.Base):
 
     def tearDown(self):
         super(Pathed, self).tearDown()
-        sys.path.remove(self.temp_usable_dir)
+        try:
+            sys.path.remove(self.temp_usable_dir)
+        except:
+            pass # w00t?
         Pathed.purge(self.temp_usable_dir)
 
     @classmethod
diff --git a/test/versioning/test_genmodel.py b/test/versioning/test_genmodel.py
new file mode 100644
index 0000000..8d6a6ee
--- /dev/null
+++ b/test/versioning/test_genmodel.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+
+from migrate.versioning.genmodel import *
+from migrate.versioning.exceptions import *
+
+from test import fixture
+
+
+class TestModelGenerator(fixture.Pathed, fixture.DB):
+    level = fixture.DB.TXN
diff --git a/test/versioning/test_repository.py b/test/versioning/test_repository.py
index 3d42006..88fd2dd 100644
--- a/test/versioning/test_repository.py
+++ b/test/versioning/test_repository.py
@@ -6,6 +6,7 @@ import shutil
 
 from migrate.versioning import exceptions
 from migrate.versioning.repository import *
+from migrate.versioning.script import *
 from nose.tools import raises
 
 from test import fixture
@@ -164,6 +165,9 @@ class TestVersionedRepository(fixture.Pathed):
         check_changeset((9,), 1)
         check_changeset((10,), 0)
 
+        # run changes
+        cs.run('postgres', 'upgrade')
+
         # Can't request a changeset of higher/lower version than this repository
         self.assertRaises(Exception, repos.changeset, 'postgres', 11)
         self.assertRaises(Exception, repos.changeset, 'postgres', -1)
@@ -186,6 +190,7 @@ class TestVersionedRepository(fixture.Pathed):
         # since we normally create 3 digit ones, let's see if we blow up
         self.assert_(os.path.exists('%s/versions/1000.py' % self.path_repos))
         self.assert_(os.path.exists('%s/versions/1001.py' % self.path_repos))
+
         
 # TODO: test manage file
 # TODO: test changeset
diff --git a/test/versioning/test_schema.py b/test/versioning/test_schema.py
index fc94329..81fbc69 100644
--- a/test/versioning/test_schema.py
+++ b/test/versioning/test_schema.py
@@ -1,90 +1,200 @@
-from test import fixture
-from migrate.versioning.schema import *
-from migrate.versioning import script,exceptions
-import os,shutil
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 
-class TestControlledSchema(fixture.Pathed,fixture.DB):
+import os
+import shutil
+
+from migrate.versioning.schema import *
+from migrate.versioning import script, exceptions, schemadiff
+
+from sqlalchemy import *
+
+from test import fixture
+
+
+class TestControlledSchema(fixture.Pathed, fixture.DB):
     # Transactions break postgres in this test; we'll clean up after ourselves
-    level=fixture.DB.CONNECT
+    level = fixture.DB.CONNECT
+
     
+    def setUp(self):
+        super(TestControlledSchema, self).setUp()
+        path_repos = self.temp_usable_dir + '/repo/'
+        self.repos = Repository.create(path_repos, 'repo_name')
+
     def _setup(self, url):
+        self.setUp()
         super(TestControlledSchema, self)._setup(url)
-        path_repos=self.tmp_repos()
-        self.repos=Repository.create(path_repos,'repository_name')
+        self.cleanup()
+
+    def _teardown(self):
+        super(TestControlledSchema, self)._teardown()
+        self.cleanup()
+        self.tearDown()
+
+    def cleanup(self):
         # drop existing version table if necessary
         try:
-            ControlledSchema(self.engine,self.repos).drop()
+            ControlledSchema(self.engine, self.repos).drop()
         except:
             # No table to drop; that's fine, be silent
             pass
 
+    def tearDown(self):
+        self.cleanup()
+        super(TestControlledSchema, self).tearDown()
+
     @fixture.usedb()
     def test_version_control(self):
         """Establish version control on a particular database"""
         # Establish version control on this database
-        dbcontrol=ControlledSchema.create(self.engine,self.repos)
-        
-        # We can load a controlled DB this way, too
-        dbcontrol0=ControlledSchema(self.engine,self.repos)
-        self.assertEquals(dbcontrol,dbcontrol0)
-        # We can also use a repository path, instead of a repository
-        dbcontrol0=ControlledSchema(self.engine,self.repos.path)
-        self.assertEquals(dbcontrol,dbcontrol0)
-        # We don't have to use the same connection
-        engine=create_engine(self.url)
-        dbcontrol0=ControlledSchema(self.engine,self.repos.path)
-        self.assertEquals(dbcontrol,dbcontrol0)
+        dbcontrol = ControlledSchema.create(self.engine, self.repos)
 
         # Trying to create another DB this way fails: table exists
-        self.assertRaises(exceptions.ControlledSchemaError,
-            ControlledSchema.create,self.engine,self.repos)
+        self.assertRaises(exceptions.DatabaseAlreadyControlledError,
+            ControlledSchema.create, self.engine, self.repos)
+        
+        # We can load a controlled DB this way, too
+        dbcontrol0 = ControlledSchema(self.engine, self.repos)
+        self.assertEquals(dbcontrol, dbcontrol0)
+
+        # We can also use a repository path, instead of a repository
+        dbcontrol0 = ControlledSchema(self.engine, self.repos.path)
+        self.assertEquals(dbcontrol, dbcontrol0)
+
+        # We don't have to use the same connection
+        engine = create_engine(self.url)
+        dbcontrol0 = ControlledSchema(engine, self.repos.path)
+        self.assertEquals(dbcontrol, dbcontrol0)
 
         # Clean up: 
-        # un-establish version control
         dbcontrol.drop()
+
         # Attempting to drop vc from a db without it should fail
-        self.assertRaises(exceptions.DatabaseNotControlledError,dbcontrol.drop)
+        self.assertRaises(exceptions.DatabaseNotControlledError, dbcontrol.drop)
+
+        # No table defined should raise error
+        self.assertRaises(exceptions.DatabaseNotControlledError,
+            ControlledSchema, self.engine, self.repos)
 
     @fixture.usedb()
     def test_version_control_specified(self):
         """Establish version control with a specified version"""
         # Establish version control on this database
-        version=0
-        dbcontrol=ControlledSchema.create(self.engine,self.repos,version)
-        self.assertEquals(dbcontrol.version,version)
+        version = 0
+        dbcontrol = ControlledSchema.create(self.engine, self.repos, version)
+        self.assertEquals(dbcontrol.version, version)
         
         # Correct when we load it, too
-        dbcontrol=ControlledSchema(self.engine,self.repos)
-        self.assertEquals(dbcontrol.version,version)
+        dbcontrol = ControlledSchema(self.engine, self.repos)
+        self.assertEquals(dbcontrol.version, version)
 
         dbcontrol.drop()
 
         # Now try it with a nonzero value
-        version=10
+        version = 10
         for i in range(version):
             self.repos.create_script('')
-        self.assertEquals(self.repos.latest,version)
+        self.assertEquals(self.repos.latest, version)
 
         # Test with some mid-range value
-        dbcontrol=ControlledSchema.create(self.engine,self.repos,5)
-        self.assertEquals(dbcontrol.version,5)
+        dbcontrol = ControlledSchema.create(self.engine,self.repos, 5)
+        self.assertEquals(dbcontrol.version, 5)
         dbcontrol.drop()
 
         # Test with max value
-        dbcontrol=ControlledSchema.create(self.engine,self.repos,version)
-        self.assertEquals(dbcontrol.version,version)
+        dbcontrol = ControlledSchema.create(self.engine, self.repos, version)
+        self.assertEquals(dbcontrol.version, version)
         dbcontrol.drop()
 
     @fixture.usedb()
     def test_version_control_invalid(self):
         """Try to establish version control with an invalid version"""
-        versions=('Thirteen','-1',-1,'',13)
+        versions = ('Thirteen', '-1', -1, '' , 13)
         # A fresh repository doesn't go up to version 13 yet
         for version in versions:
             #self.assertRaises(ControlledSchema.InvalidVersionError,
             # Can't have custom errors with assertRaises...
             try:
                 ControlledSchema.create(self.engine,self.repos,version)
-                self.assert_(False,repr(version))
+                self.assert_(False, repr(version))
             except exceptions.InvalidVersionError:
                 pass
+
+    @fixture.usedb()
+    def test_changeset(self):
+        """Create changeset from controlled schema"""
+        dbschema = ControlledSchema.create(self.engine, self.repos)
+        
+        # empty schema doesn't have changesets
+        cs = dbschema.changeset()
+        self.assertEqual(cs, {})
+
+        for i in range(5):
+            self.repos.create_script('')
+        self.assertEquals(self.repos.latest, 5)
+
+        cs = dbschema.changeset(5)
+        self.assertEqual(len(cs), 5)
+
+        # cleanup
+        dbschema.drop()
+
+    @fixture.usedb()
+    def test_upgrade_runchange(self):
+        dbschema = ControlledSchema.create(self.engine, self.repos)
+        
+        for i in range(10):
+            self.repos.create_script('')
+
+        self.assertEquals(self.repos.latest, 10)
+
+        dbschema.upgrade(10)
+
+        # TODO: test for table version in db
+
+        # cleanup
+        dbschema.drop()
+
+    @fixture.usedb()
+    def test_create_model(self):
+        """Test workflow to generate create_model"""
+        model = ControlledSchema.create_model(self.engine, self.repos, declarative=False)
+        self.assertTrue(isinstance(model, basestring))
+
+        model = ControlledSchema.create_model(self.engine, self.repos.path, declarative=True)
+        self.assertTrue(isinstance(model, basestring))
+
+    @fixture.usedb()
+    def test_compare_model_to_db(self):
+        meta = self.construct_model()
+
+        diff = ControlledSchema.compare_model_to_db(self.engine, meta, self.repos)
+        self.assertTrue(isinstance(diff, schemadiff.SchemaDiff))
+
+        diff = ControlledSchema.compare_model_to_db(self.engine, meta, self.repos.path)
+        self.assertTrue(isinstance(diff, schemadiff.SchemaDiff))
+        meta.drop_all(self.engine)
+
+    @fixture.usedb()
+    def test_update_db_from_model(self):
+        dbschema = ControlledSchema.create(self.engine, self.repos)
+
+        meta = self.construct_model()
+    
+        dbschema.update_db_from_model(meta)
+
+        # TODO: test for table version in db
+
+        # cleanup
+        dbschema.drop()
+        meta.drop_all(self.engine)
+
+    def construct_model(self):
+        meta = MetaData()
+
+        user = Table('temp_model_schema', meta, Column('id', Integer), Column('user', String))
+
+        return meta
+
+    # TODO: test how are tables populated in db
diff --git a/test/versioning/test_shell.py b/test/versioning/test_shell.py
index 1758cd8..951fec3 100644
--- a/test/versioning/test_shell.py
+++ b/test/versioning/test_shell.py
@@ -16,8 +16,6 @@ from migrate.versioning.exceptions import *
 from test import fixture
 
 
-python_version = sys.version[:3]
-
 class Shell(fixture.Shell):
 
     _cmd = os.path.join('python migrate', 'versioning', 'shell.py')
@@ -400,7 +398,7 @@ class TestShellDatabase(Shell, fixture.DB):
         self._run_test_sqlfile(upgrade_script,downgrade_script)
 
     @fixture.usedb()
-    def test_test(self):
+    def test_command_test(self):
         repos_name = 'repos_name'
         repos_path = self.tmp()
 
diff --git a/test/versioning/test_template.py b/test/versioning/test_template.py
index 72217ac..e92cb95 100644
--- a/test/versioning/test_template.py
+++ b/test/versioning/test_template.py
@@ -9,9 +9,9 @@ class TestPathed(fixture.Base):
         self.assert_(os.path.exists(path))
     def test_repository(self):
         """We can find the path to the default repository"""
-        path = template.get_repository() 
+        path = template.get_repository()
         self.assert_(os.path.exists(path))
     def test_script(self):
         """We can find the path to the default migration script"""
-        path = template.get_script() 
+        path = template.get_script()
         self.assert_(os.path.exists(path))