From b78a249b5207e3e0ab2df3b469c3af48acfe5ff7 Mon Sep 17 00:00:00 2001 From: Pete Keen Date: Fri, 10 Jun 2011 13:56:30 -0700 Subject: [PATCH] Allow descriptions in sql change script filenames --- migrate/tests/versioning/test_repository.py | 3 ++- migrate/versioning/api.py | 14 ++++++------ migrate/versioning/repository.py | 4 ++-- migrate/versioning/version.py | 25 ++++++++++++++++----- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/migrate/tests/versioning/test_repository.py b/migrate/tests/versioning/test_repository.py index 9ef45c6..20aa433 100644 --- a/migrate/tests/versioning/test_repository.py +++ b/migrate/tests/versioning/test_repository.py @@ -120,13 +120,14 @@ class TestVersionedRepository(fixture.Pathed): # Load repository and commit script repo = Repository(self.path_repos) repo.create_script('') - repo.create_script_sql('postgres') + repo.create_script_sql('postgres', 'foo bar') # Source is valid: script must have an upgrade function # (not a very thorough test, but should be plenty) source = repo.version(1).script().source() self.assertTrue(source.find('def upgrade') >= 0) + import pprint; pprint.pprint(repo.version(2).sql) source = repo.version(2).script('postgres', 'upgrade').source() self.assertEqual(source.strip(), '') diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py index c159a49..eee74be 100644 --- a/migrate/versioning/api.py +++ b/migrate/versioning/api.py @@ -110,19 +110,19 @@ def script(description, repository, **opts): @catch_known_errors -def script_sql(database, repository, **opts): - """%prog script_sql DATABASE REPOSITORY_PATH +def script_sql(database, description, repository, **opts): + """%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH Create empty change SQL scripts for given DATABASE, where DATABASE - is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.) + is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.) or generic ('default'). - For instance, manage.py script_sql postgres creates: - repository/versions/001_postgres_upgrade.sql and - repository/versions/001_postgres_postgres.sql + For instance, manage.py script_sql postgresql description creates: + repository/versions/001_description_postgresql_upgrade.sql and + repository/versions/001_description_postgresql_postgres.sql """ repo = Repository(repository) - repo.create_script_sql(database, **opts) + repo.create_script_sql(database, description, **opts) def version(repository, **opts): diff --git a/migrate/versioning/repository.py b/migrate/versioning/repository.py index 5032cfa..9b1c08a 100644 --- a/migrate/versioning/repository.py +++ b/migrate/versioning/repository.py @@ -157,10 +157,10 @@ class Repository(pathed.Pathed): k['use_timestamp_numbering'] = self.use_timestamp_numbering self.versions.create_new_python_version(description, **k) - def create_script_sql(self, database, **k): + def create_script_sql(self, database, description, **k): """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`""" k['use_timestamp_numbering'] = self.use_timestamp_numbering - self.versions.create_new_sql_version(database, **k) + self.versions.create_new_sql_version(database, description, **k) @property def latest(self): diff --git a/migrate/versioning/version.py b/migrate/versioning/version.py index fdb78a9..f41a71c 100644 --- a/migrate/versioning/version.py +++ b/migrate/versioning/version.py @@ -114,14 +114,22 @@ class Collection(pathed.Pathed): script.PythonScript.create(filepath, **k) self.versions[ver] = Version(ver, self.path, [filename]) - def create_new_sql_version(self, database, **k): + def create_new_sql_version(self, database, description, **k): """Create SQL files for new version""" ver = self._next_ver_num(k.pop('use_timestamp_numbering', False)) self.versions[ver] = Version(ver, self.path, []) + extra = str_to_filename(description) + + if extra: + if extra == '_': + extra = '' + elif not extra.startswith('_'): + extra = '_%s' % extra + # Create new files. for op in ('upgrade', 'downgrade'): - filename = '%03d_%s_%s.sql' % (ver, database, op) + filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op) filepath = self._version_path(filename) script.SqlScript.create(filepath, **k) self.versions[ver].add_script(filepath) @@ -185,18 +193,23 @@ class Version(object): elif path.endswith(Extensions.sql): self._add_script_sql(path) - SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql') + SQL_FILENAME = re.compile(r'^.*\.sql') def _add_script_sql(self, path): basename = os.path.basename(path) match = self.SQL_FILENAME.match(basename) - + if match: - version, dbms, op = match.group(1), match.group(2), match.group(3) + basename = basename.replace('.sql', '') + parts = basename.split('_') + assert len(parts) >= 3 + version = parts[0] + op = parts[-1] + dbms = parts[-2] else: raise exceptions.ScriptError( "Invalid SQL script name %s " % basename + \ - "(needs to be ###_database_operation.sql)") + "(needs to be ###_description_database_operation.sql)") # File the script into a dictionary self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)