diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 503a913d75..74a0d83022 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -274,10 +274,12 @@ ensure_db_init if [[ "${INIT_MODEL_PROVIDER_TABLES}" -eq 1 ]]; then echo "Running model provider table migrations..." - "$PY" tools/scripts/mysql_migration.py --stages tenant_model_provider --config conf/service_conf.yaml --execute - "$PY" tools/scripts/mysql_migration.py --stages tenant_model_instance --config conf/service_conf.yaml --execute - "$PY" tools/scripts/mysql_migration.py --stages tenant_model --config conf/service_conf.yaml --execute - "$PY" tools/scripts/mysql_migration.py --stages model_id_config --config conf/service_conf.yaml --execute + "$PY" tools/scripts/mysql_migration.py \ + --stages tenant_model_provider,tenant_model_instance,tenant_model,model_id_config \ + --config conf/service_conf.yaml \ + --execute \ + --database-version "v0.26.0" \ + --mark-database-version-on-success echo "Model provider table migrations completed." fi diff --git a/docker/launch_backend_service.sh b/docker/launch_backend_service.sh index b20ddc973c..ac2ec5e4ab 100755 --- a/docker/launch_backend_service.sh +++ b/docker/launch_backend_service.sh @@ -122,10 +122,12 @@ ensure_db_init() { run_mysql_migrations() { echo "Running model provider table migrations..." - "$PY" tools/scripts/mysql_migration.py --stages tenant_model_provider --config conf/service_conf.yaml --execute - "$PY" tools/scripts/mysql_migration.py --stages tenant_model_instance --config conf/service_conf.yaml --execute - "$PY" tools/scripts/mysql_migration.py --stages tenant_model --config conf/service_conf.yaml --execute - "$PY" tools/scripts/mysql_migration.py --stages model_id_config --config conf/service_conf.yaml --execute + "$PY" tools/scripts/mysql_migration.py \ + --stages tenant_model_provider,tenant_model_instance,tenant_model,model_id_config \ + --config conf/service_conf.yaml \ + --execute \ + --database-version "v0.26.0" \ + --mark-database-version-on-success echo "Model provider table migrations completed." } diff --git a/tools/scripts/mysql_migration.py b/tools/scripts/mysql_migration.py index 6a997ec396..b84d9662a5 100644 --- a/tools/scripts/mysql_migration.py +++ b/tools/scripts/mysql_migration.py @@ -31,6 +31,7 @@ import sys import time import uuid +from packaging.version import InvalidVersion, Version from peewee import ( CharField, IntegerField, @@ -55,6 +56,9 @@ logging.basicConfig( logger = logging.getLogger(__name__) +MIGRATION_DB_VERSION_MARKER = "mysql_migration.database.version" + + class MigrationConfig: """Configuration for MySQL connection""" @@ -217,6 +221,33 @@ class MigrationDatabase: ), ) + def get_database_version(self) -> str | None: + return self.get_system_setting_value(MIGRATION_DB_VERSION_MARKER) + + def set_database_version(self, version: str): + self.upsert_system_setting(MIGRATION_DB_VERSION_MARKER, version) + + +def parse_migration_version(version: str | None) -> Version | None: + if not version: + return None + normalized = version.strip() + if normalized.startswith(("v", "V")): + normalized = normalized[1:] + try: + return Version(normalized) + except InvalidVersion: + logger.warning("Invalid migration version format: %s", version) + return None + + +def should_skip_migration(current_db_version: str | None, target_version: str) -> bool: + current = parse_migration_version(current_db_version) + target = parse_migration_version(target_version) + if current is None or target is None: + return False + return current >= target + # Define model classes for migration (not importing from api.db.db_models) class BaseModel(Model): @@ -266,9 +297,6 @@ class MigrationStage: description = "Base migration stage" source_tables = [] target_tables = [] - migration_version = None - migration_marker_prefix = "mysql_migration" - def __init__(self, db: MigrationDatabase, dry_run: bool = True, create_table_only: bool = False): self.db = db self.dry_run = dry_run @@ -287,46 +315,6 @@ class MigrationStage: """Create target table (override in subclass if needed)""" pass - def migration_marker_name(self) -> str: - return f"{self.migration_marker_prefix}.{self.name}.version" - - def is_migration_version_applied(self) -> bool: - if not self.migration_version: - return False - - marker_name = self.migration_marker_name() - current_version = self.db.get_system_setting_value(marker_name) - if current_version == self.migration_version: - logger.info( - "Stage '%s' already applied at version %s, skipping", - self.name, - self.migration_version, - ) - return True - - if current_version: - logger.info( - "Stage '%s' marker version is %s, target version is %s", - self.name, - current_version, - self.migration_version, - ) - return False - - def mark_migration_version_applied(self): - if not self.migration_version: - return - - self.db.upsert_system_setting( - self.migration_marker_name(), - self.migration_version, - ) - logger.info( - "Marked stage '%s' as applied at version %s", - self.name, - self.migration_version, - ) - def mark_noop_completes_migration(self): self._noop_completes_migration = True @@ -341,7 +329,6 @@ class TenantModelProviderStage(MigrationStage): description = "Migrate tenant_llm.llm_factory to tenant_model_provider.provider_name" source_tables = ["tenant_llm"] target_tables = ["tenant_model_provider"] - migration_version = "1" def current_timestamp(self) -> int: return int(time.time()) @@ -480,7 +467,6 @@ class TenantModelInstanceStage(MigrationStage): description = "Migrate tenant_llm to tenant_model_instance with provider_id lookup" source_tables = ["tenant_llm", "tenant_model_provider"] target_tables = ["tenant_model_instance"] - migration_version = "1" def current_timestamp(self) -> int: return int(time.time()) @@ -647,7 +633,6 @@ class TenantModelStage(MigrationStage): description = "Migrate tenant_llm to tenant_model (status='0' records, plus status='1' for empty-llm factories)" source_tables = ["tenant_llm", "tenant_model_provider", "tenant_model_instance"] target_tables = ["tenant_model"] - migration_version = "1" @staticmethod def _get_empty_llm_factories() -> list[str]: @@ -861,7 +846,6 @@ class ModelIdConfigStage(MigrationStage): name = "model_id_config" description = "Normalize stored model IDs in config columns to model@default@provider" - migration_version = "1" source_tables = [ "tenant", "knowledgebase", @@ -1117,8 +1101,14 @@ def list_available_stages(): logger.info(f" Target tables: {stage_cls.target_tables}") -def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True, - create_table_only: bool = False): +def run_migration( + config: MigrationConfig, + stages: list, + dry_run: bool = True, + create_table_only: bool = False, + database_version: str | None = None, + mark_database_version_on_success: bool = False, +): """Run migration with specified stages""" stats = MigrationStats() stats.start() @@ -1127,8 +1117,31 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True, try: db.connect() + + if database_version: + current_db_version = db.get_database_version() + if should_skip_migration(current_db_version, database_version): + logger.info( + "Database migration version is %s, target version is %s, skipping all stages", + current_db_version, + database_version, + ) + return + + if current_db_version: + logger.info( + "Current database migration version is %s, target version is %s", + current_db_version, + database_version, + ) + else: + logger.info( + "Database migration version marker is not set, target version is %s", + database_version, + ) total_stages = len(stages) + all_stages_completed = True for idx, stage_name in enumerate(stages, 1): logger.info(f"{'=' * 60}") @@ -1138,16 +1151,13 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True, if stage_name not in MIGRATION_STAGES: logger.error(f"Unknown stage: {stage_name}") stats.add_stage_stats(stage_name, [], 0, 0) + all_stages_completed = False continue stage_cls = MIGRATION_STAGES[stage_name] stage = stage_cls(db, dry_run=dry_run, create_table_only=create_table_only) stage_start = time.time() - - if not create_table_only and stage.is_migration_version_applied(): - stats.add_stage_stats(stage_name, [], 0, time.time() - stage_start) - continue # For create_table_only mode, skip check and directly execute if create_table_only: @@ -1157,27 +1167,72 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True, # Check if migration is needed if not stage.check(): logger.info(f"Stage '{stage_name}' check: no migration needed") - if not dry_run and stage.noop_completes_migration(): - stage.mark_migration_version_applied() stats.add_stage_stats(stage_name, [], 0, time.time() - stage_start) continue # Execute migration rows, tables = stage.execute() - if not dry_run: - stage.mark_migration_version_applied() stage_duration = time.time() - stage_start stats.add_stage_stats(stage_name, tables, rows, stage_duration) logger.info(f"Stage '{stage_name}' completed: {rows} rows in {stage_duration:.2f}s") - + + if ( + mark_database_version_on_success + and not dry_run + and not create_table_only + and database_version + and all_stages_completed + ): + db.set_database_version(database_version) + logger.info("Marked database migration version as %s", database_version) + finally: db.close() stats.end() stats.print_summary() +def check_database_version(config: MigrationConfig, target_version: str) -> int: + db = MigrationDatabase(config) + try: + db.connect() + current_db_version = db.get_database_version() + if should_skip_migration(current_db_version, target_version): + logger.info( + "Database migration version is %s, target version is %s, migration is not needed", + current_db_version, + target_version, + ) + return 0 + + if current_db_version: + logger.info( + "Database migration version is %s, target version is %s, migration is needed", + current_db_version, + target_version, + ) + else: + logger.info( + "Database migration version marker is not set, target version is %s, migration is needed", + target_version, + ) + return 1 + finally: + db.close() + + +def mark_database_version(config: MigrationConfig, version: str) -> None: + db = MigrationDatabase(config) + try: + db.connect() + db.set_database_version(version) + logger.info("Marked database migration version as %s", version) + finally: + db.close() + + def main(): parser = argparse.ArgumentParser( description='MySQL Data Migration Tool', @@ -1186,6 +1241,12 @@ def main(): Examples: # List available stages python mysql_migration.py --list-stages + + # Check whether migration is needed for a target version + python mysql_migration.py --check-database-version --database-version v0.26.0 --config /path/to/config.yaml + + # Mark database version separately + python mysql_migration.py --mark-database-version --database-version v0.26.0 --config /path/to/config.yaml # Dry run (default - check only, no write) with config file python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml @@ -1198,6 +1259,12 @@ Examples: # Execute full migration (create tables and migrate data) python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute + + # Execute migration only when database version is lower than v0.26.0 + python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute --database-version v0.26.0 + + # Execute migration and mark the database version when all stages succeed + python mysql_migration.py --stages tenant_model_provider,tenant_model_instance,tenant_model,model_id_config --config /path/to/config.yaml --execute --database-version v0.26.0 --mark-database-version-on-success # Normalize legacy model IDs in stored configs python mysql_migration.py --stages model_id_config --config /path/to/config.yaml --execute @@ -1225,6 +1292,14 @@ Examples: # Migration options parser.add_argument('--stages', '-s', type=str, help='Comma-separated list of stages to run') parser.add_argument('--list-stages', '-l', action='store_true', help='List available stages') + parser.add_argument('--check-database-version', action='store_true', + help='Check whether migration is needed for the target database version') + parser.add_argument('--mark-database-version', action='store_true', + help='Write the database migration version marker and exit') + parser.add_argument('--database-version', type=str, metavar='VERSION', + help='Database migration version used by check/mark commands and as the migration threshold for --stages') + parser.add_argument('--mark-database-version-on-success', action='store_true', + help='When used with --stages and --execute, write --database-version after all stages succeed') parser.add_argument('--execute', '-e', action='store_true', default=False, help='Execute full migration: create tables and migrate data') parser.add_argument('--create-table-only', action='store_true', default=False, @@ -1237,13 +1312,6 @@ Examples: list_available_stages() return - # Parse stages - if not args.stages: - logger.error("No stages specified. Use --stages to specify stages or --list-stages to see available stages.") - sys.exit(1) - - stages = [s.strip() for s in args.stages.split(',')] - # Load configuration: command line args take precedence over config file if args.config: config = MigrationConfig.from_config_file(args.config) @@ -1270,12 +1338,38 @@ Examples: logger.info(f"MySQL Configuration: host={config.host}, port={config.port}, " f"user={config.user}, database={config.database}") + + if args.check_database_version and args.mark_database_version: + logger.error("--check-database-version and --mark-database-version are mutually exclusive") + sys.exit(1) + + if args.check_database_version: + if not args.database_version: + logger.error("--check-database-version requires --database-version") + sys.exit(1) + sys.exit(check_database_version(config, args.database_version)) + + if args.mark_database_version: + if not args.database_version: + logger.error("--mark-database-version requires --database-version") + sys.exit(1) + mark_database_version(config, args.database_version) + return + + if args.mark_database_version_on_success and not args.database_version: + logger.error("--mark-database-version-on-success requires --database-version") + sys.exit(1) # Three mutually exclusive modes: dry-run (default), create-table-only, execute if args.execute and args.create_table_only: logger.error("--execute and --create-table-only are mutually exclusive") sys.exit(1) - + + if not args.stages: + logger.error("No stages specified. Use --stages to specify stages or --list-stages to see available stages.") + sys.exit(1) + + stages = [s.strip() for s in args.stages.split(',')] dry_run = True create_table_only = False @@ -1294,7 +1388,9 @@ Examples: config=config, stages=stages, dry_run=dry_run, - create_table_only=create_table_only + create_table_only=create_table_only, + database_version=args.database_version, + mark_database_version_on_success=args.mark_database_version_on_success, )