"""
   Schema differencing support.
"""
import logging

import sqlalchemy
from migrate.changeset import SQLA_06


# Logger named after this module, so callers can configure its verbosity
# through the standard logging hierarchy.
log = logging.getLogger(__name__)

def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
    """
    Compare a model's metadata with the database reachable through *conn*.

    :param model: metadata describing the expected schema.
    :param conn: connection used to reflect the current database schema.
    :param excludeTables: optional collection of table names to ignore.
    :return: a :class:`SchemaDiff`, which evaluates to :keyword:`True` \
      when differences exist and to :keyword:`False` otherwise.
    """
    diff = SchemaDiff(model, conn, excludeTables=excludeTables)
    return diff


def getDiffOfModelAgainstModel(oldmodel, model, conn, excludeTables=None):
    """
    Compare *model* against a previous model, *oldmodel*.

    :param oldmodel: metadata of the earlier schema to compare against.
    :param model: metadata of the current schema.
    :param conn: connection whose dialect is used to render column \
      declarations for comparison.
    :param excludeTables: optional collection of table names to ignore.
    :return: a :class:`SchemaDiff`, which evaluates to :keyword:`True` \
      when differences exist and to :keyword:`False` otherwise.
    """
    return SchemaDiff(model=model, conn=conn, excludeTables=excludeTables,
                      oldmodel=oldmodel)


class SchemaDiff(object):
    """
    Differences of model against database.

    An instance evaluates to :keyword:`True` when any difference was found
    (see :meth:`__len__`), and ``str()`` renders a human-readable summary.

    Attributes populated by :meth:`compareModelToDatabase`:

    * ``tablesMissingInDatabase`` -- model tables absent from the
      reflected (or old) model.
    * ``tablesMissingInModel`` -- reflected tables absent from the model.
    * ``tablesWithDiff`` -- model tables whose columns differ.
    * ``colDiffs`` -- maps a table name to a
      ``(missingInDatabase, missingInModel, diffDecl)`` triple of lists.
    """

    def __init__(self, model, conn, excludeTables=None, oldmodel=None):
        """
        :param model: Python model's metadata
        :param conn: active database connection.
        :param excludeTables: names of tables to skip during comparison.
        :param oldmodel: metadata to compare against instead of reflecting
          the database through *conn*.
        """
        self.model = model
        self.conn = conn
        # [] can't be default value in Python parameter
        if not excludeTables:
            excludeTables = []
        self.excludeTables = excludeTables
        # Explicit None check: a supplied model must never be silently
        # replaced by a reflection of the live database just because the
        # object happens to evaluate as falsy.
        if oldmodel is not None:
            self.reflected_model = oldmodel
        else:
            self.reflected_model = sqlalchemy.MetaData(conn, reflect=True)
        self.tablesMissingInDatabase = []
        self.tablesMissingInModel = []
        self.tablesWithDiff = []
        self.colDiffs = {}
        self.compareModelToDatabase()

    def compareModelToDatabase(self):
        """
        Do actual comparison.

        Fills ``tablesMissingInDatabase``, ``tablesMissingInModel``,
        ``tablesWithDiff`` and ``colDiffs``.
        """
        if SQLA_06:
            self._compareTables(self._sqla06ColumnSpecifier())
            return

        # Pre-0.6 SQLAlchemy renders column DDL through a schema generator
        # bound to a connection. Acquire that connection only here and
        # always release it (it used to be opened unconditionally and
        # never closed).
        cc = self.conn.contextual_connect()
        try:
            schemagenerator = self.conn.dialect.schemagenerator(
                self.conn.dialect, cc)
            self._compareTables(schemagenerator.get_column_specification)
        finally:
            cc.close()

    def _sqla06ColumnSpecifier(self):
        """Return a callable rendering a column's DDL via the 0.6+ compiler."""
        from sqlalchemy.ext import compiler
        from sqlalchemy.schema import DDLElement

        class DefineColumn(DDLElement):
            def __init__(self, col):
                self.col = col

        @compiler.compiles(DefineColumn)
        def _compileDefineColumn(elem, ddlcompiler, **kw):
            return ddlcompiler.get_column_specification(elem.col)

        def get_column_specification(col):
            return str(DefineColumn(col).compile(dialect=self.conn.dialect))

        return get_column_specification

    def _compareTables(self, get_column_specification):
        """Classify tables by presence; compare columns of shared tables."""
        modelTables = self.model.tables
        reflectedTables = self.reflected_model.tables

        # For each in model, find missing in database.
        for modelName, modelTable in modelTables.items():
            if modelName in self.excludeTables:
                continue
            if reflectedTables.get(modelName, None) is None:
                self.tablesMissingInDatabase.append(modelTable)

        # For each in database, find missing in model.
        for reflectedName, reflectedTable in reflectedTables.items():
            if reflectedName in self.excludeTables:
                continue
            modelTable = modelTables.get(reflectedName, None)
            if modelTable is None:
                self.tablesMissingInModel.append(reflectedTable)
            else:
                self._compareColumns(modelTable, reflectedTable,
                                     get_column_specification)

    def _compareColumns(self, modelTable, reflectedTable,
                        get_column_specification):
        """Record column-level differences between two same-named tables."""
        # Find missing columns in database.
        for modelCol in modelTable.columns:
            if reflectedTable.columns.get(modelCol.name, None) is None:
                self.storeColumnMissingInDatabase(modelTable, modelCol)

        # Find missing columns in model, and compare shared columns.
        # TODO: no test coverage here?   (mrb)
        for databaseCol in reflectedTable.columns:
            modelCol = modelTable.columns.get(databaseCol.name, None)
            if modelCol is None:
                self.storeColumnMissingInModel(modelTable, databaseCol)
                continue
            modelDecl = get_column_specification(modelCol)
            databaseDecl = get_column_specification(databaseCol)
            if modelDecl == databaseDecl:
                continue
            # Unfortunately, sometimes the database decl won't quite match
            # the model even though they're the same: accept related type
            # classes with equal nullability as a match.
            mc = modelCol.type.__class__
            dc = databaseCol.type.__class__
            sameType = issubclass(mc, dc) or issubclass(dc, mc)
            if sameType and modelCol.nullable == databaseCol.nullable:
                continue
            self.storeColumnDiff(modelTable, modelCol, databaseCol,
                                 modelDecl, databaseDecl)

    def __str__(self):
        ''' Summarize differences. '''

        def names(items):
            return ', '.join([item.name for item in items])

        def colDiffDetails():
            colout = []
            for table in self.tablesWithDiff:
                tableName = table.name
                missingInDatabase, missingInModel, diffDecl = \
                    self.colDiffs[tableName]
                if missingInDatabase:
                    colout.append('    %s missing columns in database: %s'
                                  % (tableName, names(missingInDatabase)))
                if missingInModel:
                    colout.append('    %s missing columns in model: %s'
                                  % (tableName, names(missingInModel)))
                if diffDecl:
                    colout.append('    %s with different declaration of'
                                  ' columns in database: %s'
                                  % (tableName, str(diffDecl)))
            return colout

        out = []
        if self.tablesMissingInDatabase:
            out.append('  tables missing in database: %s'
                       % names(self.tablesMissingInDatabase))
        if self.tablesMissingInModel:
            out.append('  tables missing in model: %s'
                       % names(self.tablesMissingInModel))
        if self.tablesWithDiff:
            out.append('  tables with differences: %s'
                       % names(self.tablesWithDiff))

        if not out:
            return 'No schema diffs'
        out.insert(0, 'Schema diffs:')
        out.extend(colDiffDetails())
        return '\n'.join(out)

    def __len__(self):
        """
        Used in bool evaluation, return of 0 means no diffs.
        """
        return len(self.tablesMissingInDatabase) + \
            len(self.tablesMissingInModel) + len(self.tablesWithDiff)

    def _colDiffsFor(self, table):
        """Mark *table* as differing and return its ``colDiffs`` triple."""
        if table not in self.tablesWithDiff:
            self.tablesWithDiff.append(table)
        return self.colDiffs.setdefault(table.name, ([], [], []))

    def storeColumnMissingInDatabase(self, table, col):
        """Record *col* of *table* as present in the model only."""
        self._colDiffsFor(table)[0].append(col)

    def storeColumnMissingInModel(self, table, col):
        """Record *col* of *table* as present in the database only."""
        self._colDiffsFor(table)[1].append(col)

    def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl,
                        databaseDecl):
        """Record a column whose declarations differ between both sides."""
        self._colDiffsFor(table)[2].append(
            (modelCol, databaseCol, modelDecl, databaseDecl))