From cb624f15db39de58d9e317126403469e4230d8a9 Mon Sep 17 00:00:00 2001
From: percious17 <unknown>
Date: Mon, 19 Jan 2009 23:39:53 +0000
Subject: [PATCH] support for SA 0.5.1.

Only postgres is working fully.

MySQL has 2 broken tests.

sqlite has about 4 broken tests.
---
 migrate/changeset/ansisql.py     | 41 +++++++++++++++++++++++++-------
 migrate/changeset/schema.py      |  9 ++++++-
 test/changeset/test_changeset.py |  5 +++-
 test/fixture/database.py         |  2 +-
 4 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/migrate/changeset/ansisql.py b/migrate/changeset/ansisql.py
index e188100..4c721e8 100644
--- a/migrate/changeset/ansisql.py
+++ b/migrate/changeset/ansisql.py
@@ -22,7 +22,7 @@ class RawAlterTableVisitor(object):
         if isinstance(ret,sa.Table):
             ret = ret.fullname
         return ret
-
+    
     def _do_quote_table_identifier(self, identifier):
         return '"%s"'%identifier
     
@@ -81,12 +81,16 @@ class ANSIColumnGenerator(AlterTableVisitor,SchemaGenerator):
         pks = table.primary_key
         colspec = self.get_column_specification(column)
         self.append(colspec)
-        if column.foreign_keys:
-            for fk in column.foreign_keys:
-                self.append(";\n\t ")
-                self.add_foreignkey(fk.constraint)
-        else:
-            self.execute()
+
+#        if column.foreign_keys:
+#            self.append(" ")
+#            for fk in column.foreign_keys:
+#                self.add_foreignkey(fk.constraint)
+#                continue
+#                self.append(";\n\t ")
+#                self.define_foreign_key(fk.constraint)
+#        else:
+        self.execute()
 
 
     def visit_table(self,table):
@@ -283,14 +287,35 @@ class ANSIConstraintDropper(ANSIConstraintCommon):
 
     def visit_migrate_foreign_key_constraint(self,*p,**k):
         return self._visit_constraint(*p,**k)
-
+    
     def visit_migrate_check_constraint(self,*p,**k):
         return self._visit_constraint(*p,**k)
 
+class ANSIFKGenerator(AlterTableVisitor,SchemaGenerator):
+    """Extends ansisql generator for column creation (alter table add col)"""
+    def __init__(self, *args, **kwargs):
+        self.fk = kwargs.get('fk', None)
+        if self.fk:
+            del kwargs['fk']
+        super(ANSIFKGenerator, self).__init__(*args, **kwargs)
+
+    def visit_column(self,column):
+        """Create foreign keys for a column (table already exists); #32"""
+
+        if self.fk:
+            self.add_foreignkey(self.fk.constraint)
+
+        if self.buffer.getvalue() !='':
+            self.execute()
+
+    def visit_table(self,table):
+        pass
+
 class ANSIDialect(object):
     columngenerator = ANSIColumnGenerator
     columndropper = ANSIColumnDropper
     schemachanger = ANSISchemaChanger
+    columnfkgenerator = ANSIFKGenerator
 
     @classmethod
     def visitor(self,name):
diff --git a/migrate/changeset/schema.py b/migrate/changeset/schema.py
index d47de1b..bb78d9e 100644
--- a/migrate/changeset/schema.py
+++ b/migrate/changeset/schema.py
@@ -318,7 +318,14 @@ class ChangesetColumn(object):
         table = _normalize_table(self,table)
         engine = table.bind
         visitorcallable = get_engine_visitor(engine,'columngenerator')
-        engine._run_visitor(visitorcallable,self,*args,**kwargs)
+        engine._run_visitor(visitorcallable, self, *args,**kwargs)
+
+        #add in foreign keys
+        if self.foreign_keys:
+            for fk in self.foreign_keys:
+                visitorcallable = get_engine_visitor(engine,'columnfkgenerator')
+                engine._run_visitor(visitorcallable, self, fk=fk)
+    
         return self
     
     def drop(self,table=None,*args,**kwargs):
diff --git a/test/changeset/test_changeset.py b/test/changeset/test_changeset.py
index c3d199e..116b7b5 100644
--- a/test/changeset/test_changeset.py
+++ b/test/changeset/test_changeset.py
@@ -40,6 +40,7 @@ class TestAddDropColumn(fixture.DB):
 
         def _assert_numcols(expected,type_):
             result = len(self.table.c)
+            
             self.assertEquals(result,expected,
                 "# %s cols incorrect: %s != %s"%(type_,result,expected))
             if not col_k.get('primary_key',None):
@@ -53,9 +54,11 @@ class TestAddDropColumn(fixture.DB):
             # Changed: create/drop shouldn't mess with the objects
             #_assert_numcols(expected,'object')
             # Detect # database cols via autoload
-            self.meta.clear()
+            #self.meta.clear()
+            del self.meta.tables[self.table_name]
             self.table=Table(self.table_name,self.meta,autoload=True)
             _assert_numcols(expected,'database')
+            
         assert_numcols(1)
         if len(col_p) == 0:
             col_p = [String(40)]
diff --git a/test/fixture/database.py b/test/fixture/database.py
index 97d18b9..96d2219 100644
--- a/test/fixture/database.py
+++ b/test/fixture/database.py
@@ -42,7 +42,7 @@ def is_supported(url,supported,not_supported):
 
 #we make the engines global, which should make the tests run a bit faster
 urls = readurls()
-engines=dict([(url,create_engine(url)) for url in urls])
+engines=dict([(url,create_engine(url, echo=True)) for url in urls])
 
 
 def usedb(supported=None,not_supported=None):