Refa: refine mysql migration version workflow (#15549)

### What problem does this PR solve?

Refine the MySQL migration version workflow.

### Type of change

- [x] Refactoring
This commit is contained in:
buua436
2026-06-03 11:51:42 +08:00
committed by GitHub
parent a75ea7ba7c
commit 76fc1d547f
3 changed files with 175 additions and 75 deletions

View File

@@ -31,6 +31,7 @@ import sys
import time
import uuid
from packaging.version import InvalidVersion, Version
from peewee import (
CharField,
IntegerField,
@@ -55,6 +56,9 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
MIGRATION_DB_VERSION_MARKER = "mysql_migration.database.version"
class MigrationConfig:
"""Configuration for MySQL connection"""
@@ -217,6 +221,33 @@ class MigrationDatabase:
),
)
def get_database_version(self) -> str | None:
    """Return the value stored under the database migration version marker.

    The lookup is delegated to ``get_system_setting_value`` using the
    ``MIGRATION_DB_VERSION_MARKER`` setting name; its result is returned
    unchanged.
    """
    marker_value = self.get_system_setting_value(MIGRATION_DB_VERSION_MARKER)
    return marker_value
def set_database_version(self, version: str):
    """Write *version* under the database migration version marker.

    The write is delegated to ``upsert_system_setting`` using the
    ``MIGRATION_DB_VERSION_MARKER`` setting name.
    """
    self.upsert_system_setting(
        MIGRATION_DB_VERSION_MARKER,
        version,
    )
def parse_migration_version(version: str | None) -> Version | None:
    """Parse a migration version string into a comparable ``Version``.

    Surrounding whitespace and one leading ``v``/``V`` are removed before
    the string is handed to ``Version``.

    Args:
        version: Raw version text such as ``"v0.26.0"``, or ``None``.

    Returns:
        The parsed ``Version``, or ``None`` when *version* is ``None``/empty
        or when ``Version`` rejects it with ``InvalidVersion`` (a warning
        carrying the original text is logged in that case).
    """
    if not version:
        return None

    candidate = version.strip()
    # Slice instead of index so an all-whitespace input (now "") is safe.
    if candidate[:1] in ("v", "V"):
        candidate = candidate[1:]

    try:
        parsed = Version(candidate)
    except InvalidVersion:
        logger.warning("Invalid migration version format: %s", version)
        return None
    return parsed
def should_skip_migration(current_db_version: str | None, target_version: str) -> bool:
    """Decide whether the database is already at or beyond the target version.

    Both arguments are parsed with ``parse_migration_version``.

    Returns:
        ``True`` only when both versions parse and the current version is
        greater than or equal to the target; ``False`` when either one is
        missing or unparseable.
    """
    # Parse both up front, in this order, so any warnings are always emitted.
    parsed_current = parse_migration_version(current_db_version)
    parsed_target = parse_migration_version(target_version)

    if parsed_current is not None and parsed_target is not None:
        return parsed_current >= parsed_target
    return False
# Define model classes for migration (not importing from api.db.db_models)
class BaseModel(Model):
@@ -266,9 +297,6 @@ class MigrationStage:
description = "Base migration stage"
source_tables = []
target_tables = []
migration_version = None
migration_marker_prefix = "mysql_migration"
def __init__(self, db: MigrationDatabase, dry_run: bool = True, create_table_only: bool = False):
self.db = db
self.dry_run = dry_run
@@ -287,46 +315,6 @@ class MigrationStage:
"""Create target table (override in subclass if needed)"""
pass
def migration_marker_name(self) -> str:
    """Return this stage's marker name, ``<prefix>.<stage name>.version``.

    The prefix comes from ``self.migration_marker_prefix`` and the stage
    name from ``self.name``.
    """
    return "{}.{}.version".format(self.migration_marker_prefix, self.name)
def is_migration_version_applied(self) -> bool:
    """Report whether the stored stage marker equals ``self.migration_version``.

    Returns:
        ``False`` immediately when the stage has no ``migration_version``.
        Otherwise the marker named by ``migration_marker_name()`` is read
        through ``self.db.get_system_setting_value``: ``True`` when it
        equals the stage version, ``False`` otherwise. Both outcomes are
        logged at info level, except when no marker value is stored.
    """
    target_version = self.migration_version
    if not target_version:
        return False

    stored_version = self.db.get_system_setting_value(self.migration_marker_name())

    if stored_version == target_version:
        logger.info(
            "Stage '%s' already applied at version %s, skipping",
            self.name,
            target_version,
        )
        return True

    # A different, non-empty marker is only reported; the stage still runs.
    if stored_version:
        logger.info(
            "Stage '%s' marker version is %s, target version is %s",
            self.name,
            stored_version,
            target_version,
        )
    return False
def mark_migration_version_applied(self):
    """Persist ``self.migration_version`` under this stage's marker name.

    Does nothing when the stage has no ``migration_version``. Otherwise the
    value is written with ``self.db.upsert_system_setting`` and the write is
    logged at info level.
    """
    applied_version = self.migration_version
    if applied_version:
        marker_name = self.migration_marker_name()
        self.db.upsert_system_setting(marker_name, applied_version)
        logger.info(
            "Marked stage '%s' as applied at version %s",
            self.name,
            applied_version,
        )
def mark_noop_completes_migration(self):
self._noop_completes_migration = True
@@ -341,7 +329,6 @@ class TenantModelProviderStage(MigrationStage):
description = "Migrate tenant_llm.llm_factory to tenant_model_provider.provider_name"
source_tables = ["tenant_llm"]
target_tables = ["tenant_model_provider"]
migration_version = "1"
def current_timestamp(self) -> int:
return int(time.time())
@@ -480,7 +467,6 @@ class TenantModelInstanceStage(MigrationStage):
description = "Migrate tenant_llm to tenant_model_instance with provider_id lookup"
source_tables = ["tenant_llm", "tenant_model_provider"]
target_tables = ["tenant_model_instance"]
migration_version = "1"
def current_timestamp(self) -> int:
return int(time.time())
@@ -647,7 +633,6 @@ class TenantModelStage(MigrationStage):
description = "Migrate tenant_llm to tenant_model (status='0' records, plus status='1' for empty-llm factories)"
source_tables = ["tenant_llm", "tenant_model_provider", "tenant_model_instance"]
target_tables = ["tenant_model"]
migration_version = "1"
@staticmethod
def _get_empty_llm_factories() -> list[str]:
@@ -861,7 +846,6 @@ class ModelIdConfigStage(MigrationStage):
name = "model_id_config"
description = "Normalize stored model IDs in config columns to model@default@provider"
migration_version = "1"
source_tables = [
"tenant",
"knowledgebase",
@@ -1117,8 +1101,14 @@ def list_available_stages():
logger.info(f" Target tables: {stage_cls.target_tables}")
def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
create_table_only: bool = False):
def run_migration(
config: MigrationConfig,
stages: list,
dry_run: bool = True,
create_table_only: bool = False,
database_version: str | None = None,
mark_database_version_on_success: bool = False,
):
"""Run migration with specified stages"""
stats = MigrationStats()
stats.start()
@@ -1127,8 +1117,31 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
try:
db.connect()
if database_version:
current_db_version = db.get_database_version()
if should_skip_migration(current_db_version, database_version):
logger.info(
"Database migration version is %s, target version is %s, skipping all stages",
current_db_version,
database_version,
)
return
if current_db_version:
logger.info(
"Current database migration version is %s, target version is %s",
current_db_version,
database_version,
)
else:
logger.info(
"Database migration version marker is not set, target version is %s",
database_version,
)
total_stages = len(stages)
all_stages_completed = True
for idx, stage_name in enumerate(stages, 1):
logger.info(f"{'=' * 60}")
@@ -1138,16 +1151,13 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
if stage_name not in MIGRATION_STAGES:
logger.error(f"Unknown stage: {stage_name}")
stats.add_stage_stats(stage_name, [], 0, 0)
all_stages_completed = False
continue
stage_cls = MIGRATION_STAGES[stage_name]
stage = stage_cls(db, dry_run=dry_run, create_table_only=create_table_only)
stage_start = time.time()
if not create_table_only and stage.is_migration_version_applied():
stats.add_stage_stats(stage_name, [], 0, time.time() - stage_start)
continue
# For create_table_only mode, skip check and directly execute
if create_table_only:
@@ -1157,27 +1167,72 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
# Check if migration is needed
if not stage.check():
logger.info(f"Stage '{stage_name}' check: no migration needed")
if not dry_run and stage.noop_completes_migration():
stage.mark_migration_version_applied()
stats.add_stage_stats(stage_name, [], 0, time.time() - stage_start)
continue
# Execute migration
rows, tables = stage.execute()
if not dry_run:
stage.mark_migration_version_applied()
stage_duration = time.time() - stage_start
stats.add_stage_stats(stage_name, tables, rows, stage_duration)
logger.info(f"Stage '{stage_name}' completed: {rows} rows in {stage_duration:.2f}s")
if (
mark_database_version_on_success
and not dry_run
and not create_table_only
and database_version
and all_stages_completed
):
db.set_database_version(database_version)
logger.info("Marked database migration version as %s", database_version)
finally:
db.close()
stats.end()
stats.print_summary()
def check_database_version(config: MigrationConfig, target_version: str) -> int:
    """Check whether the database needs migrating to *target_version*.

    Connects with *config*, reads the stored database version marker and
    compares it with *target_version* via ``should_skip_migration``. The
    connection is closed in all cases.

    Returns:
        ``0`` when no migration is needed, ``1`` when it is. ``main`` passes
        this value to ``sys.exit``.
    """
    db = MigrationDatabase(config)
    try:
        db.connect()
        stored_version = db.get_database_version()

        if should_skip_migration(stored_version, target_version):
            logger.info(
                "Database migration version is %s, target version is %s, migration is not needed",
                stored_version,
                target_version,
            )
            return 0

        if not stored_version:
            logger.info(
                "Database migration version marker is not set, target version is %s, migration is needed",
                target_version,
            )
        else:
            logger.info(
                "Database migration version is %s, target version is %s, migration is needed",
                stored_version,
                target_version,
            )
        return 1
    finally:
        db.close()
def mark_database_version(config: MigrationConfig, version: str) -> None:
    """Write *version* as the database migration version marker.

    Connects with *config*, stores the marker through
    ``set_database_version``, logs the write at info level, and closes the
    connection whether or not the write succeeded.
    """
    migration_db = MigrationDatabase(config)
    try:
        migration_db.connect()
        migration_db.set_database_version(version)
        logger.info(
            "Marked database migration version as %s",
            version,
        )
    finally:
        migration_db.close()
def main():
parser = argparse.ArgumentParser(
description='MySQL Data Migration Tool',
@@ -1186,6 +1241,12 @@ def main():
Examples:
# List available stages
python mysql_migration.py --list-stages
# Check whether migration is needed for a target version
python mysql_migration.py --check-database-version --database-version v0.26.0 --config /path/to/config.yaml
# Mark database version separately
python mysql_migration.py --mark-database-version --database-version v0.26.0 --config /path/to/config.yaml
# Dry run (default - check only, no write) with config file
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml
@@ -1198,6 +1259,12 @@ Examples:
# Execute full migration (create tables and migrate data)
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute
# Execute migration only when database version is lower than v0.26.0
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute --database-version v0.26.0
# Execute migration and mark the database version when all stages succeed
python mysql_migration.py --stages tenant_model_provider,tenant_model_instance,tenant_model,model_id_config --config /path/to/config.yaml --execute --database-version v0.26.0 --mark-database-version-on-success
# Normalize legacy model IDs in stored configs
python mysql_migration.py --stages model_id_config --config /path/to/config.yaml --execute
@@ -1225,6 +1292,14 @@ Examples:
# Migration options
parser.add_argument('--stages', '-s', type=str, help='Comma-separated list of stages to run')
parser.add_argument('--list-stages', '-l', action='store_true', help='List available stages')
parser.add_argument('--check-database-version', action='store_true',
help='Check whether migration is needed for the target database version')
parser.add_argument('--mark-database-version', action='store_true',
help='Write the database migration version marker and exit')
parser.add_argument('--database-version', type=str, metavar='VERSION',
help='Database migration version used by check/mark commands and as the migration threshold for --stages')
parser.add_argument('--mark-database-version-on-success', action='store_true',
help='When used with --stages and --execute, write --database-version after all stages succeed')
parser.add_argument('--execute', '-e', action='store_true', default=False,
help='Execute full migration: create tables and migrate data')
parser.add_argument('--create-table-only', action='store_true', default=False,
@@ -1237,13 +1312,6 @@ Examples:
list_available_stages()
return
# Parse stages
if not args.stages:
logger.error("No stages specified. Use --stages to specify stages or --list-stages to see available stages.")
sys.exit(1)
stages = [s.strip() for s in args.stages.split(',')]
# Load configuration: command line args take precedence over config file
if args.config:
config = MigrationConfig.from_config_file(args.config)
@@ -1270,12 +1338,38 @@ Examples:
logger.info(f"MySQL Configuration: host={config.host}, port={config.port}, "
f"user={config.user}, database={config.database}")
if args.check_database_version and args.mark_database_version:
logger.error("--check-database-version and --mark-database-version are mutually exclusive")
sys.exit(1)
if args.check_database_version:
if not args.database_version:
logger.error("--check-database-version requires --database-version")
sys.exit(1)
sys.exit(check_database_version(config, args.database_version))
if args.mark_database_version:
if not args.database_version:
logger.error("--mark-database-version requires --database-version")
sys.exit(1)
mark_database_version(config, args.database_version)
return
if args.mark_database_version_on_success and not args.database_version:
logger.error("--mark-database-version-on-success requires --database-version")
sys.exit(1)
# Three mutually exclusive modes: dry-run (default), create-table-only, execute
if args.execute and args.create_table_only:
logger.error("--execute and --create-table-only are mutually exclusive")
sys.exit(1)
if not args.stages:
logger.error("No stages specified. Use --stages to specify stages or --list-stages to see available stages.")
sys.exit(1)
stages = [s.strip() for s in args.stages.split(',')]
dry_run = True
create_table_only = False
@@ -1294,7 +1388,9 @@ Examples:
config=config,
stages=stages,
dry_run=dry_run,
create_table_only=create_table_only
create_table_only=create_table_only,
database_version=args.database_version,
mark_database_version_on_success=args.mark_database_version_on_success,
)