
- Code cleanup in the remaining migrate.changeset files.
- Add migrate.changeset.exceptions and migrate.changeset.schema to api.rst.
- Add more information to the CHANGELOG.
410 lines
13 KiB
Python
410 lines
13 KiB
Python
"""
|
|
Schema module providing common schema operations.
|
|
"""
|
|
import re
|
|
import sqlalchemy
|
|
from migrate.changeset.databases.visitor import get_engine_visitor
|
|
|
|
# Public functions of this module.  The Changeset* mixin classes below are
# not exported; _patch() grafts them onto SQLAlchemy's own classes instead.
__all__ = ['create_column', 'drop_column', 'alter_column',
           'rename_table', 'rename_index']
|
|
|
|
|
|
def create_column(column, table=None, *p, **k):
    """Create ``column`` in the database.

    Without a ``table`` this delegates to ``column.create``; with one it
    delegates to ``table.create_column``.  Extra positional and keyword
    arguments are forwarded unchanged, and the delegate's result is returned.
    """
    if table is None:
        return column.create(*p, **k)
    return table.create_column(column, *p, **k)
|
|
|
|
|
|
def drop_column(column, table=None, *p, **k):
    """Drop ``column`` from the database.

    Without a ``table`` this delegates to ``column.drop``; with one it
    delegates to ``table.drop_column``.  Extra positional and keyword
    arguments are forwarded unchanged, and the delegate's result is returned.
    """
    if table is None:
        return column.drop(*p, **k)
    return table.drop_column(column, *p, **k)
|
|
|
|
|
|
def _to_table(table, engine=None):
    """Return ``table`` as a ``sqlalchemy.Table``.

    A Table instance is returned untouched.  Anything else is treated as a
    table name and wrapped in a new Table on a fresh MetaData, which is bound
    to ``engine`` when one is given.
    """
    if not isinstance(table, sqlalchemy.Table):
        metadata = sqlalchemy.MetaData()
        if engine is not None:
            metadata.bind = engine
        table = sqlalchemy.Table(table, metadata)
    return table
|
|
|
|
|
|
def _to_index(index, table=None, engine=None):
    """Return ``index`` as a ``sqlalchemy.Index``.

    An Index instance is returned untouched.  Otherwise ``index`` is taken to
    be an index name; a new Index with that name is created and attached to
    ``table`` (a Table object, or a table name resolved with ``engine``).
    """
    if isinstance(index, sqlalchemy.Index):
        return index
    # Only a name was given, so the owning table must be resolvable.
    owner = _to_table(table, engine)
    wrapped = sqlalchemy.Index(index)
    wrapped.table = owner
    return wrapped
|
|
|
|
|
|
def rename_table(table, name, engine=None):
    """Rename a table in the database.

    :param table: a Table object, or the table's current name
    :param name: the new table name
    :param engine: engine to bind to when ``table`` is given by name
    """
    _to_table(table, engine).rename(name)
|
|
|
|
|
|
def rename_index(index, name, table=None, engine=None):
    """Rename an index in the database.

    :param index: an Index object, or the index's current name
    :param name: the new index name
    :param table: Table object or table name; needed only when ``index``
        is given by name
    :param engine: engine used when ``table`` is given by name
    """
    _to_index(index, table, engine).rename(name)
|
|
|
|
|
|
def _engine_run_visitor(engine, visitorcallable, element, **kwargs):
|
|
conn = engine.connect()
|
|
try:
|
|
element.accept_schema_visitor(visitorcallable(engine.dialect,
|
|
connection=conn))
|
|
finally:
|
|
conn.close()
|
|
|
|
|
|
def alter_column(*p, **k):
    """Alter a column in the database and, if given, the Column object.

    Accepts the argument layouts understood by ``_ColumnDelta``:

    * ``column, [new_column,] *params`` -- ``table`` defaults to
      ``column.table``;
    * ``current_name, *params, table=...`` -- ``table`` (a Table object or
      a table name) must then be passed as a keyword argument.

    ``engine`` may be passed as a keyword argument; otherwise it is taken
    from ``table.bind``.  The remaining parameters (``name``, ``type``,
    ``nullable``, ...) describe the properties to change.

    When the first positional argument is a Column object, that object is
    updated in place after the ALTER has been issued.

    :raises TypeError: if ``table`` is missing and no Column object was
        given, or if ``engine`` is missing and ``table`` has no ``bind``
        (for instance when the table is given by name).
    """
    if p and isinstance(p[0], sqlalchemy.Column):
        col = p[0]
    else:
        col = None

    if 'table' not in k:
        if col is None:
            # Without this check the failure was an opaque
            # "'NoneType' object has no attribute 'table'".
            raise TypeError("alter_column() needs a 'table' keyword "
                            "argument when no Column object is given")
        k['table'] = col.table
    if 'engine' not in k:
        # A table given by name has no .bind; an unbound Table has None.
        # Either way the ALTER cannot run without an explicit engine.
        bind = getattr(k['table'], 'bind', None)
        if bind is None:
            raise TypeError("alter_column() needs an 'engine' keyword "
                            "argument when the table is not bound to one")
        k['engine'] = bind
    engine = k['engine']

    delta = _ColumnDelta(*p, **k)
    visitorcallable = get_engine_visitor(engine, 'schemachanger')
    _engine_run_visitor(engine, visitorcallable, delta)

    if col is None:
        return

    # Keep the Python Column in sync with the database.  SA looks columns
    # up by key (table.c.<key>), so when the key was implicit (equal to the
    # name) it follows the rename.  Known limitation: a key that was set
    # explicitly AND equal to the name is renamed as well.
    if 'name' in delta and col.name == col.key:
        new_name = delta['name']
        del col.table.c[col.key]
        col.key = new_name
        col.table.c[col.key] = col

    # Copy every altered attribute (including the name) onto the column.
    for key, val in delta.items():
        setattr(col, key, val)
|
|
|
|
|
|
def _normalize_table(column, table):
|
|
if table is not None:
|
|
if table is not column.table:
|
|
# This is a bit of a hack: we end up with dupe PK columns here
|
|
pk_names = map(lambda c: c.name, table.primary_key)
|
|
if column.primary_key and pk_names.count(column.name):
|
|
index = pk_names.index(column_name)
|
|
del table.primary_key[index]
|
|
table.append_column(column)
|
|
return column.table
|
|
|
|
|
|
class _WrapRename(object):
    """Pair a schema item with its new name for a rename visitor.

    ``accept_schema_visitor`` dispatches to the visitor's ``visit_table``,
    ``visit_column`` or ``visit_index`` method, depending on the type of the
    wrapped item, and passes it the tuple ``(item, new_name)``.
    """

    def __init__(self, item, name):
        # item: the Table, Column or Index to rename; name: its new name
        self.item = item
        self.name = name

    def accept_schema_visitor(self, visitor):
        """Call the visitor's ``visit_<kind>`` method with ``(item, name)``.

        :raises TypeError: if the item is not a Table, Column or Index
            (previously an UnboundLocalError on ``suffix``).
        """
        if isinstance(self.item, sqlalchemy.Table):
            suffix = 'table'
        elif isinstance(self.item, sqlalchemy.Column):
            suffix = 'column'
        elif isinstance(self.item, sqlalchemy.Index):
            suffix = 'index'
        else:
            raise TypeError("cannot rename %r: expected a Table, Column "
                            "or Index" % (self.item,))
        funcname = 'visit_%s' % suffix
        func = getattr(visitor, funcname)
        param = self.item, self.name
        return func(param)
|
|
|
|
|
|
class _ColumnDelta(dict):
    """Dictionary of the ALTER-able differences for one column.

    Keys are taken from ``diff_keys``; values are the requested new values.
    Besides the dict contents, instances carry ``current_name`` (the column's
    name in the database before the change), the table (``table`` /
    ``table_name``) and, when a Column object was given, ``result_column`` --
    a copy of that column with the differences applied.
    """

    # Column attributes that can be altered
    diff_keys = ('name', 'type', 'nullable', 'default', 'server_default',
                 'primary_key', 'foreign_key')

    def __init__(self, *p, **k):
        """Collect the changes described by the arguments.

        Three argument layouts are accepted:

        * ``old_column, new_column, *params`` -- attributes that differ
          between the two Column objects are recorded; explicitly passed
          parameters take precedence over those differences.
        * ``column, *params`` -- the given parameters are the changes and
          the table comes from the column.  Passing ``current_name=...``
          marks a rename from ``current_name`` to ``column.name``; otherwise
          the name is left alone.
        * ``current_name, *params, table=...`` -- ``table`` is a Table
          object or a table name.
        """
        def is_column_at(pos):
            return len(p) > pos and isinstance(p[pos], sqlalchemy.Column)

        if not is_column_at(0):
            initializer = self._init_0col
        elif is_column_at(1):
            initializer = self._init_2col
        else:
            initializer = self._init_1col
        self._set_diffs(initializer(*p, **k))

    @property
    def table_name(self):
        """Name of the table, whether it was given as a name or a Table."""
        target = self._table
        if isinstance(target, basestring):
            return target
        return target.name

    @property
    def table(self):
        """The Table object, or None when only a table name is known."""
        target = self._table
        return None if isinstance(target, basestring) else target

    def _init_0col(self, current_name, *p, **k):
        """Layout ``current_name, *params, table=...``; returns the changes."""
        p, k = self._init_normalize_params(p, k)
        table = k.pop('table')
        self.current_name = current_name
        self._table = table
        return k

    def _init_1col(self, col, *p, **k):
        """Layout ``column, *params``; returns the changes."""
        p, k = self._init_normalize_params(p, k)
        self._table = k.pop('table', None) or col.table
        self.result_column = col.copy()
        if 'current_name' not in k:
            self.current_name = col.name
            return k
        # Renamed: the database still knows the column as current_name.
        self.current_name = k.pop('current_name')
        k.setdefault('name', col.name)
        return k

    def _init_2col(self, start_col, end_col, *p, **k):
        """Layout ``old_column, new_column, *params``; returns the changes."""
        p, k = self._init_normalize_params(p, k)
        self.result_column = start_col.copy()
        self._table = (k.pop('table', None) or start_col.table
                       or end_col.table)
        self.current_name = start_col.name
        compared = ('name', 'nullable', 'default', 'server_default',
                    'primary_key', 'foreign_key')
        for attr in compared:
            new_value = getattr(end_col, attr, None)
            if getattr(start_col, attr, None) != new_value:
                k.setdefault(attr, new_value)
        # Types need a looser comparison than !=, see column_types_eq().
        if not self.column_types_eq(start_col.type, end_col.type):
            k.setdefault('type', end_col.type)
        return k

    def _init_normalize_params(self, p, k):
        """Move leading positional ``name`` and ``type`` into ``k``.

        Explicit keyword values win over positional ones.  Returns the
        leftover positional parameters (as a list) and ``k``.
        """
        p = list(p)
        for key in ('name', 'type'):
            if not p:
                break
            k.setdefault(key, p.pop(0))
        # TODO: sequences? FKs?
        return p, k

    def _set_diffs(self, diffs):
        """Store the recognized keys of ``diffs`` in self and result_column."""
        result_column = getattr(self, 'result_column', None)
        for key in self.diff_keys:
            if key not in diffs:
                continue
            value = diffs[key]
            self[key] = value
            if result_column is not None:
                setattr(result_column, key, value)

    def column_types_eq(self, this, that):
        """Return True if two column type instances count as the same type.

        Types match when either is an instance of the other's class; for
        String types the ``length`` must match as well.
        """
        same_family = (isinstance(this, that.__class__)
                       or isinstance(that, this.__class__))
        if same_family and isinstance(that, sqlalchemy.types.String):
            return (getattr(this, 'length', None) ==
                    getattr(that, 'length', None))
        return same_family

    def accept_schema_visitor(self, visitor):
        """Let ``visitor`` process this delta via its ``visit_column``."""
        return visitor.visit_column(self)
|
|
|
|
|
|
class ChangesetTable(object):
    """Changeset extensions to SQLAlchemy tables.

    ``_patch()`` appends this class to ``sqlalchemy.schema.Table.__bases__``,
    so these methods run with ``self`` being a real Table.
    """

    def create_column(self, column):
        """Create a column in the database.

        :param column: a Column definition, or the name of a column already
            present in this table's column collection (``self.c``).
        """
        if not isinstance(column, sqlalchemy.Column):
            # It's a column name: look its definition up on this table.
            column = getattr(self.c, str(column))
        column.create(table=self)

    def drop_column(self, column):
        """Drop a column from the database, given its name or definition.

        A name that is not in this table's column collection is still
        accepted.  Only the name is needed to drop a column, so a bare
        ``Column`` with that name is used in that case.
        """
        if not isinstance(column, sqlalchemy.Column):
            name = str(column)
            # getattr() with a default never raises AttributeError, so the
            # unknown-name case is detected through the None default.  (The
            # former try/except AttributeError was dead code and this path
            # ended in ``None.drop(...)``.)
            column = getattr(self.c, name, None)
            if column is None:
                column = sqlalchemy.Column(name)
        column.drop(table=self)

    def _meta_key(self):
        """Return the key this table is stored under in ``metadata.tables``."""
        return sqlalchemy.schema._get_table_key(self.name, self.schema)

    def deregister(self):
        """Remove this table from its metadata, if it is registered there."""
        key = self._meta_key()
        meta = self.metadata
        if key in meta.tables:
            del meta.tables[key]

    def rename(self, name, *args, **kwargs):
        """Rename this table.

        Issues the rename through the engine's ``schemachanger`` visitor, then
        updates this Table object: it is removed from its metadata under the
        old key, renamed, and handed back to the same metadata.  Extra
        ``args``/``kwargs`` are passed on to ``_engine_run_visitor``.
        """
        engine = self.bind
        visitorcallable = get_engine_visitor(engine, 'schemachanger')
        param = _WrapRename(self, name)
        _engine_run_visitor(engine, visitorcallable, param, *args, **kwargs)

        # Fix metadata registration: the metadata key depends on the name.
        meta = self.metadata
        self.deregister()
        self.name = name
        # NOTE(review): relies on SQLAlchemy's private Table._set_parent to
        # re-register under the new key -- confirm for the SA version in use.
        self._set_parent(meta)

    def _get_fullname(self):
        """Return ``schema.name`` (or just ``name``), computed on access.

        Mirrors the value the Table constructor computes, but stays correct
        after ``rename()`` changes ``self.name``.
        """
        # Copied from Table constructor
        if self.schema is not None:
            ret = "%s.%s"%(self.schema, self.name)
        else:
            ret = self.name
        return ret

    # The no-op setter discards assignments to fullname (presumably made by
    # Table's constructor) so the computed value above always wins.
    fullname = property(_get_fullname, (lambda self, val: None))
|
|
|
|
|
|
class ChangesetColumn(object):
    """Changeset extensions to SQLAlchemy columns."""

    def alter(self, *p, **k):
        """Change this column's definition (``ALTER TABLE ALTER COLUMN``).

        Accepts either a replacement Column object or the properties to
        change.  These calls are equivalent:

            col.alter(Column('myint', Integer, nullable=False))
            col.alter('myint', Integer, nullable=False)
            col.alter(name='myint', type=Integer, nullable=False)

        Name, type, default and nullability can be changed.  Only
        PassiveDefaults live in the database, so changing any other kind of
        default here is meaningless.

        ``table`` defaults to this column's table and ``engine`` to that
        table's bind.
        """
        table = k['table'] if 'table' in k else self.table
        k['table'] = table
        if 'engine' not in k:
            k['engine'] = table.bind
        return alter_column(self, *p, **k)

    def create(self, table=None, *args, **kwargs):
        """Add this column to an existing table in the database.

        Typically ``ALTER TABLE ADD COLUMN``.  Any foreign keys of the column
        are then created one by one.  Returns the column itself.
        """
        engine = _normalize_table(self, table).bind
        engine._run_visitor(get_engine_visitor(engine, 'columngenerator'),
                            self, *args, **kwargs)

        # Add in foreign keys, each through its own visitor run.
        for fk in self.foreign_keys or ():
            fk_visitor = get_engine_visitor(engine, 'columnfkgenerator')
            engine._run_visitor(fk_visitor, self, fk=fk)
        return self

    def drop(self, table=None, *args, **kwargs):
        """Remove this column from the database, keeping its table.

        Typically ``ALTER TABLE DROP COLUMN``.  Returns the column itself.
        """
        engine = _normalize_table(self, table).bind
        dropper = get_engine_visitor(engine, 'columndropper')

        def build_dropper(dialect, conn):
            # The dropper is built from the connection only.
            return dropper(conn)

        engine._run_visitor(build_dropper, self, *args, **kwargs)
        return self
|
|
|
|
|
|
class ChangesetIndex(object):
    """Changeset extensions to SQLAlchemy indexes."""

    def rename(self, name, *args, **kwargs):
        """Rename this index in the database, then update ``self.name``.

        Uses the ``schemachanger`` visitor of the engine bound to this
        index's table.  Extra ``args``/``kwargs`` are passed on to
        ``_engine_run_visitor``.
        """
        bind = self.table.bind
        _engine_run_visitor(bind,
                            get_engine_visitor(bind, 'schemachanger'),
                            _WrapRename(self, name),
                            *args, **kwargs)
        self.name = name
|
|
|
|
|
|
def _patch():
    """Graft the Changeset* mixins onto SQLAlchemy's schema classes.

    Each mixin is appended to the ``__bases__`` of the matching
    ``sqlalchemy.schema`` class, so every Table, Column and Index gains the
    changeset methods.  This is the "ugly" part that patches SQLAlchemy's
    internals.
    """
    grafts = ((sqlalchemy.schema.Table, ChangesetTable),
              (sqlalchemy.schema.Column, ChangesetColumn),
              (sqlalchemy.schema.Index, ChangesetIndex))
    for sa_class, mixin in grafts:
        sa_class.__bases__ += (mixin, )


# Applied once when this module is imported.
_patch()
|