import re
import sqlalchemy
from migrate.changeset.databases.visitor import get_engine_visitor

# Public helpers exported by ``from ... import *``. The Changeset* mixin
# classes defined below are not listed here; _patch() at the bottom of the
# module grafts them onto SQLAlchemy's Table/Column/Index classes on import.
__all__ = [
'create_column',
'drop_column',
'alter_column',
'rename_table',
'rename_index',
]


def create_column(column, table=None, *p, **k):
    """Create ``column`` in the database.

    When ``table`` is given, creation is delegated to
    ``table.create_column(column, ...)``; otherwise ``column.create(...)``
    is called. Extra positional/keyword arguments are forwarded unchanged,
    and the delegate's return value is returned.
    """
    if table is None:
        return column.create(*p, **k)
    return table.create_column(column, *p, **k)

def drop_column(column, table=None, *p, **k):
    """Drop ``column`` from the database.

    Delegates to ``table.drop_column(column, ...)`` when ``table`` is given,
    or to ``column.drop(...)`` otherwise. Extra arguments are forwarded and
    the delegate's return value is returned.
    """
    if table is None:
        return column.drop(*p, **k)
    return table.drop_column(column, *p, **k)

def _to_table(table, engine=None):
    """Return ``table`` as a ``sqlalchemy.Table``.

    A Table object is returned as-is. Anything else is treated as a table
    name and wrapped in a new Table on a fresh MetaData, which is bound to
    ``engine`` when one is supplied.
    """
    if not isinstance(table, sqlalchemy.Table):
        metadata = sqlalchemy.MetaData()
        if engine is not None:
            metadata.bind = engine
        table = sqlalchemy.Table(table, metadata)
    return table
def _to_index(index, table=None, engine=None):
    """Return ``index`` as a ``sqlalchemy.Index``.

    An Index object is returned unchanged. Otherwise ``index`` is an index
    name, and ``table`` (a Table or table name, resolved with ``engine``) is
    attached to a newly built Index as its ``table`` attribute.
    """
    if isinstance(index, sqlalchemy.Index):
        return index
    owner = _to_table(table, engine)
    wrapped = sqlalchemy.Index(index)
    wrapped.table = owner
    return wrapped
    
def rename_table(table, name, engine=None):
    """Rename a table to ``name``.

    ``table`` may be a Table object or the table's current name; with a
    name, ``engine`` is used to bind the temporary Table built for it.
    """
    _to_table(table, engine).rename(name)

def rename_index(index, name, table=None, engine=None):
    """Rename an index to ``name``.

    ``index`` may be an Index object or an index name. With a name, the
    owning ``table`` (object or name) is required, and ``engine`` is used
    when the table is given by name. Neither is needed for an Index object.
    """
    _to_index(index, table, engine).rename(name)


def _engine_run_visitor(engine, visitorcallable, element, *args, **kwargs):
    """Run a schema visitor over ``element`` on a fresh connection.

    The visitor is built as ``visitorcallable(engine.dialect, connection=conn)``
    and handed to ``element.accept_schema_visitor``. The connection is always
    closed, even if the visitor raises.

    Extra positional and keyword arguments are accepted only so that callers
    can forward their own ``*args``/``**kwargs``. They are not used.
    """
    conn = engine.connect()
    try:
        visitor = visitorcallable(engine.dialect, connection=conn)
        element.accept_schema_visitor(visitor)
    finally:
        conn.close()

def alter_column(*p, **k):
    """Alter a column (ALTER TABLE ... ALTER COLUMN).

    Positional/keyword arguments follow the layouts accepted by
    ``_ColumnDelta``: a Column object (optionally followed by a second
    Column or new name/type), or a current column name plus ``table=``.
    ``table`` defaults to the column's table and ``engine`` to the table's
    ``bind``.

    When a Column object was passed, it is updated in place afterwards. This
    includes re-keying it in ``table.c`` on rename, when its key followed
    its name.

    Raises:
        TypeError: no Column object was given and ``table`` was omitted, so
            the table cannot be determined.
    """
    if p and isinstance(p[0], sqlalchemy.Column):
        col = p[0]
    else:
        col = None
    if 'table' not in k:
        if col is None:
            # Previously this dereferenced None.table and failed obscurely.
            raise TypeError("alter_column() requires a 'table' argument "
                            "when no Column object is given")
        k['table'] = col.table
    if 'engine' not in k:
        k['engine'] = k['table'].bind
    engine = k['engine']
    delta = _ColumnDelta(*p, **k)
    visitorcallable = get_engine_visitor(engine, 'schemachanger')
    _engine_run_visitor(engine, visitorcallable, delta)

    # Update the Python-side Column object to match the database.
    if col is not None:
        # Special case: change the column key on rename if the key was not
        # explicit, since SA exposes columns as table.c.[key].
        #
        # This misfires when the key was explicit AND equal to the column
        # name (the key is changed when it shouldn't be); that case can't be
        # told apart from here.
        if 'name' in delta and col.name == col.key:
            newname = delta['name']
            del col.table.c[col.key]
            col.key = newname
            col.table.c[col.key] = col
        # Change all other attrs
        for key, val in delta.items():
            setattr(col, key, val)

def _normalize_table(column, table):
    """Attach ``column`` to ``table`` if needed and return ``column.table``.

    When ``table`` is given and is not already the column's table, the
    column is appended to it. If the column is a primary-key column and the
    table's primary key already holds a column with the same name, that
    entry is removed first so the PK does not end up with a duplicate.
    With ``table=None`` the column's current table is returned untouched.
    """
    if table is not None and table is not column.table:
        # This is a bit of a hack: otherwise we end up with dupe PK columns.
        if column.primary_key:
            pk_names = [c.name for c in table.primary_key]
            if column.name in pk_names:
                # Was ``column_name`` (undefined) -> NameError on this path.
                del table.primary_key[pk_names.index(column.name)]
        table.append_column(column)
    return column.table


class _WrapRename(object):
    """Pair a schema item with a new name for a 'schemachanger' visitor.

    ``accept_schema_visitor`` dispatches to the visitor's ``visit_table``,
    ``visit_column`` or ``visit_index`` with a single ``(item, name)`` tuple.
    """
    def __init__(self, item, name):
        self.item = item
        self.name = name

    def accept_schema_visitor(self, visitor):
        """Call the visitor method matching the item's type.

        Raises:
            TypeError: the item is not a Table, Column or Index. Before this
                fix, the unbound ``suffix`` produced an UnboundLocalError.
        """
        if isinstance(self.item, sqlalchemy.Table):
            suffix = 'table'
        elif isinstance(self.item, sqlalchemy.Column):
            suffix = 'column'
        elif isinstance(self.item, sqlalchemy.Index):
            suffix = 'index'
        else:
            raise TypeError("Cannot rename object of type %s"
                            % type(self.item).__name__)
        func = getattr(visitor, 'visit_%s' % suffix)
        return func((self.item, self.name))

class _ColumnDelta(dict):
    """Mapping of alterable column attributes to the values an ALTER applies.

    Only keys listed in ``diff_keys`` are ever stored. When a Column object
    was supplied, a copy of it (``result_column``) is updated with the same
    values.
    """
    # Column attributes that can be altered
    diff_keys = ('name', 'type', 'nullable', 'default', 'primary_key',
                 'foreign_key')

    def __init__(self, *p, **k):
        """Collect the ALTER-able changes described by the arguments.

        Three argument layouts are recognised:
         * old_column, new_column, *params
            Attributes that differ between the two Column objects are
            recorded; explicitly passed params take precedence over them.
         * column, *params
            The params are recorded; the table defaults to column.table.
            Passing current_name=... marks a rename from current_name to
            column.name (unless name= is also given); without it the name is
            left unchanged.
         * current_name, *params, table=...
            'table' (a Table object or a table name) must be a keyword.
        In every layout, up to two leading positional params after the
        column(s)/current name become 'name' and 'type' (unless those are
        also passed as keywords).
        """
        if not (p and isinstance(p[0], sqlalchemy.Column)):
            # No leading Column object: name-based layout
            initializer = self._init_0col
        elif len(p) > 1 and isinstance(p[1], sqlalchemy.Column):
            initializer = self._init_2col
        else:
            initializer = self._init_1col
        self._set_diffs(initializer(*p, **k))

    def _get_table_name(self):
        tbl = self._table
        return tbl if isinstance(tbl, basestring) else tbl.name
    table_name = property(_get_table_name)

    def _get_table(self):
        tbl = self._table
        return None if isinstance(tbl, basestring) else tbl
    table = property(_get_table)

    def _init_0col(self, current_name, *p, **k):
        p, k = self._init_normalize_params(p, k)
        self._table = k.pop('table')
        self.current_name = current_name
        return k

    def _init_1col(self, col, *p, **k):
        p, k = self._init_normalize_params(p, k)
        self._table = k.pop('table', None) or col.table
        self.result_column = col.copy()
        if 'current_name' not in k:
            self.current_name = col.name
            return k
        # Rename: the column object already carries the target name
        self.current_name = k.pop('current_name')
        k.setdefault('name', col.name)
        return k

    def _init_2col(self, start_col, end_col, *p, **k):
        p, k = self._init_normalize_params(p, k)
        self.result_column = start_col.copy()
        self._table = k.pop('table', None) or start_col.table or end_col.table
        self.current_name = start_col.name
        # setdefault: explicitly passed params are never overwritten
        for attr in ('name', 'nullable', 'default', 'primary_key',
                     'foreign_key'):
            new_val = getattr(end_col, attr, None)
            if getattr(start_col, attr, None) != new_val:
                k.setdefault(attr, new_val)
        if not self.column_types_eq(start_col.type, end_col.type):
            k.setdefault('type', end_col.type)
        return k

    def _init_normalize_params(self, p, k):
        leftover = list(p)
        for attr in ('name', 'type'):
            if leftover:
                k.setdefault(attr, leftover.pop(0))
        # TODO: sequences? FKs?
        return leftover, k

    def _set_diffs(self, diffs):
        for attr in self.diff_keys:
            if attr not in diffs:
                continue
            value = diffs[attr]
            self[attr] = value
            if getattr(self, 'result_column', None) is not None:
                setattr(self.result_column, attr, value)

    def column_types_eq(self, this, that):
        same = isinstance(this, that.__class__) or isinstance(that, this.__class__)
        if same and isinstance(that, sqlalchemy.types.String):
            # A String length change counts as a type change
            same = getattr(this, 'length', None) == getattr(that, 'length', None)
        return same

    def accept_schema_visitor(self, visitor):
        return visitor.visit_column(self)

class ChangesetTable(object):
    """Changeset extensions to SQLAlchemy tables.

    Mixed into ``sqlalchemy.schema.Table`` by ``_patch()``, so ``self`` is a
    Table at runtime.
    """
    def create_column(self, column):
        """Create a column of this table in the database.

        ``column`` may be a Column object or the name of a column already
        defined on this table (looked up in ``self.c``).
        """
        if not isinstance(column, sqlalchemy.Column):
            # It's a column name
            column = getattr(self.c, str(column))
        column.create(table=self)

    def drop_column(self, column):
        """Drop a column, given its name or Column object.

        A name that is not in this table's ``c`` collection is still
        dropped. Only the name is needed for the DROP, so a bare Column with
        that name is built for it.
        """
        if not isinstance(column, sqlalchemy.Column):
            # It's a column name
            try:
                # No default argument: a missing column must raise
                # AttributeError so the fallback below is reached. (With a
                # None default the fallback was dead code and the call
                # crashed on None.drop.)
                column = getattr(self.c, str(column))
            except AttributeError:
                # That column isn't part of the table; build a dummy column
                # with the same name.
                column = sqlalchemy.Column(str(column))
        column.drop(table=self)

    def _meta_key(self):
        """Key of this table in ``metadata.tables`` (from name and schema)."""
        return sqlalchemy.schema._get_table_key(self.name, self.schema)

    def deregister(self):
        """Remove this table from its metadata (no-op if not registered)."""
        key = self._meta_key()
        meta = self.metadata
        if key in meta.tables:
            del meta.tables[key]

    def rename(self, name, *args, **kwargs):
        """Rename this table.

        Runs the dialect's 'schemachanger' visitor to rename the table in the
        database. It then deregisters the table from its MetaData, updates
        ``self.name``, and re-attaches it to the same MetaData via
        ``_set_parent`` (presumably re-registering it under the new key).
        """
        engine = self.bind
        visitorcallable = get_engine_visitor(engine, 'schemachanger')
        param = _WrapRename(self, name)
        _engine_run_visitor(engine, visitorcallable, param, *args, **kwargs)

        # Fix metadata registration
        meta = self.metadata
        self.deregister()
        self.name = name
        self._set_parent(meta)

    def _get_fullname(self):
        """Fullname is computed on access so it tracks renames."""
        # Copied from Table constructor
        if self.schema is not None:
            ret = "%s.%s" % (self.schema, self.name)
        else:
            ret = self.name
        return ret
    # The setter deliberately discards assignments. Presumably this lets code
    # such as the Table constructor assign ``fullname`` without error while
    # the getter stays authoritative.
    fullname = property(_get_fullname, (lambda self, val: None))

class ChangesetColumn(object):
    """Changeset extensions to SQLAlchemy columns.

    Mixed into ``sqlalchemy.schema.Column`` by ``_patch()``, so ``self`` is a
    Column at runtime.
    """
    def alter(self, *p, **k):
        """Alter this column's definition (ALTER TABLE ALTER COLUMN).

        Either a replacement Column or the properties to change may be given.
        These calls are equivalent:
            col.alter(Column('myint',Integer,nullable=False))
            col.alter('myint',Integer,nullable=False)
            col.alter(name='myint',type=Integer,nullable=False)

        Name, type, default and nullable can be changed. For defaults, only
        PassiveDefaults are managed by the database, so changing any other
        kind of default makes no sense here.

        ``table`` defaults to this column's table and ``engine`` to that
        table's ``bind``.
        """
        # Avoid touching self.table when a table was passed explicitly.
        table = k['table'] if 'table' in k else self.table
        k['table'] = table
        if 'engine' not in k:
            k['engine'] = table.bind
        return alter_column(self, *p, **k)

    def create(self, table=None, *args, **kwargs):
        """Create this column in the database; the table must already exist.

        ALTER TABLE ADD COLUMN on most databases. Returns ``self``.
        """
        engine = _normalize_table(self, table).bind
        generator = get_engine_visitor(engine, 'columngenerator')
        engine._run_visitor(generator, self, *args, **kwargs)
        return self

    def drop(self, table=None, *args, **kwargs):
        """Drop this column from the database, leaving its table in place.

        ALTER TABLE DROP COLUMN on most databases. Returns ``self``.
        """
        engine = _normalize_table(self, table).bind
        dropper = get_engine_visitor(engine, 'columndropper')

        def run_dropper(dialect, conn):
            # The columndropper visitor is built from the connection alone;
            # the first (dialect) argument is ignored.
            return dropper(conn)

        engine._run_visitor(run_dropper, self, *args, **kwargs)
        # NOTE: the column is not removed from the Python Table object here.
        return self
    
class ChangesetIndex(object):
    """Changeset extensions to SQLAlchemy Index objects.

    Mixed into ``sqlalchemy.schema.Index`` by ``_patch()``.
    """
    def rename(self, name, *args, **kwargs):
        """Rename this index in the database and on this Python object.

        Uses the 'schemachanger' visitor of the engine bound to the index's
        table, then sets ``self.name``.
        """
        bound_engine = self.table.bind
        _engine_run_visitor(bound_engine,
                            get_engine_visitor(bound_engine, 'schemachanger'),
                            _WrapRename(self, name),
                            *args, **kwargs)
        self.name = name


def _patch():
    """Graft the Changeset* mixins onto SQLAlchemy's schema classes.

    This is the 'ugly' part that patches SQLAlchemy's internals. Each mixin
    is appended to the target class's ``__bases__``, so every Table, Column
    and Index gains the changeset methods.
    """
    targets = ((sqlalchemy.schema.Table, ChangesetTable),
               (sqlalchemy.schema.Column, ChangesetColumn),
               (sqlalchemy.schema.Index, ChangesetIndex))
    for sa_class, mixin in targets:
        sa_class.__bases__ += (mixin,)
_patch()