mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Refa: refine mysql migration version workflow (#15549)
### What problem does this PR solve? refine mysql migration version workflow ### Type of change - [x] Refactoring
This commit is contained in:
@@ -31,6 +31,7 @@ import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from peewee import (
|
||||
CharField,
|
||||
IntegerField,
|
||||
@@ -55,6 +56,9 @@ logging.basicConfig(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MIGRATION_DB_VERSION_MARKER = "mysql_migration.database.version"
|
||||
|
||||
|
||||
class MigrationConfig:
|
||||
"""Configuration for MySQL connection"""
|
||||
|
||||
@@ -217,6 +221,33 @@ class MigrationDatabase:
|
||||
),
|
||||
)
|
||||
|
||||
def get_database_version(self) -> str | None:
|
||||
return self.get_system_setting_value(MIGRATION_DB_VERSION_MARKER)
|
||||
|
||||
def set_database_version(self, version: str):
|
||||
self.upsert_system_setting(MIGRATION_DB_VERSION_MARKER, version)
|
||||
|
||||
|
||||
def parse_migration_version(version: str | None) -> Version | None:
|
||||
if not version:
|
||||
return None
|
||||
normalized = version.strip()
|
||||
if normalized.startswith(("v", "V")):
|
||||
normalized = normalized[1:]
|
||||
try:
|
||||
return Version(normalized)
|
||||
except InvalidVersion:
|
||||
logger.warning("Invalid migration version format: %s", version)
|
||||
return None
|
||||
|
||||
|
||||
def should_skip_migration(current_db_version: str | None, target_version: str) -> bool:
|
||||
current = parse_migration_version(current_db_version)
|
||||
target = parse_migration_version(target_version)
|
||||
if current is None or target is None:
|
||||
return False
|
||||
return current >= target
|
||||
|
||||
|
||||
# Define model classes for migration (not importing from api.db.db_models)
|
||||
class BaseModel(Model):
|
||||
@@ -266,9 +297,6 @@ class MigrationStage:
|
||||
description = "Base migration stage"
|
||||
source_tables = []
|
||||
target_tables = []
|
||||
migration_version = None
|
||||
migration_marker_prefix = "mysql_migration"
|
||||
|
||||
def __init__(self, db: MigrationDatabase, dry_run: bool = True, create_table_only: bool = False):
|
||||
self.db = db
|
||||
self.dry_run = dry_run
|
||||
@@ -287,46 +315,6 @@ class MigrationStage:
|
||||
"""Create target table (override in subclass if needed)"""
|
||||
pass
|
||||
|
||||
def migration_marker_name(self) -> str:
|
||||
return f"{self.migration_marker_prefix}.{self.name}.version"
|
||||
|
||||
def is_migration_version_applied(self) -> bool:
|
||||
if not self.migration_version:
|
||||
return False
|
||||
|
||||
marker_name = self.migration_marker_name()
|
||||
current_version = self.db.get_system_setting_value(marker_name)
|
||||
if current_version == self.migration_version:
|
||||
logger.info(
|
||||
"Stage '%s' already applied at version %s, skipping",
|
||||
self.name,
|
||||
self.migration_version,
|
||||
)
|
||||
return True
|
||||
|
||||
if current_version:
|
||||
logger.info(
|
||||
"Stage '%s' marker version is %s, target version is %s",
|
||||
self.name,
|
||||
current_version,
|
||||
self.migration_version,
|
||||
)
|
||||
return False
|
||||
|
||||
def mark_migration_version_applied(self):
|
||||
if not self.migration_version:
|
||||
return
|
||||
|
||||
self.db.upsert_system_setting(
|
||||
self.migration_marker_name(),
|
||||
self.migration_version,
|
||||
)
|
||||
logger.info(
|
||||
"Marked stage '%s' as applied at version %s",
|
||||
self.name,
|
||||
self.migration_version,
|
||||
)
|
||||
|
||||
def mark_noop_completes_migration(self):
|
||||
self._noop_completes_migration = True
|
||||
|
||||
@@ -341,7 +329,6 @@ class TenantModelProviderStage(MigrationStage):
|
||||
description = "Migrate tenant_llm.llm_factory to tenant_model_provider.provider_name"
|
||||
source_tables = ["tenant_llm"]
|
||||
target_tables = ["tenant_model_provider"]
|
||||
migration_version = "1"
|
||||
|
||||
def current_timestamp(self) -> int:
|
||||
return int(time.time())
|
||||
@@ -480,7 +467,6 @@ class TenantModelInstanceStage(MigrationStage):
|
||||
description = "Migrate tenant_llm to tenant_model_instance with provider_id lookup"
|
||||
source_tables = ["tenant_llm", "tenant_model_provider"]
|
||||
target_tables = ["tenant_model_instance"]
|
||||
migration_version = "1"
|
||||
|
||||
def current_timestamp(self) -> int:
|
||||
return int(time.time())
|
||||
@@ -647,7 +633,6 @@ class TenantModelStage(MigrationStage):
|
||||
description = "Migrate tenant_llm to tenant_model (status='0' records, plus status='1' for empty-llm factories)"
|
||||
source_tables = ["tenant_llm", "tenant_model_provider", "tenant_model_instance"]
|
||||
target_tables = ["tenant_model"]
|
||||
migration_version = "1"
|
||||
|
||||
@staticmethod
|
||||
def _get_empty_llm_factories() -> list[str]:
|
||||
@@ -861,7 +846,6 @@ class ModelIdConfigStage(MigrationStage):
|
||||
|
||||
name = "model_id_config"
|
||||
description = "Normalize stored model IDs in config columns to model@default@provider"
|
||||
migration_version = "1"
|
||||
source_tables = [
|
||||
"tenant",
|
||||
"knowledgebase",
|
||||
@@ -1117,8 +1101,14 @@ def list_available_stages():
|
||||
logger.info(f" Target tables: {stage_cls.target_tables}")
|
||||
|
||||
|
||||
def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
|
||||
create_table_only: bool = False):
|
||||
def run_migration(
|
||||
config: MigrationConfig,
|
||||
stages: list,
|
||||
dry_run: bool = True,
|
||||
create_table_only: bool = False,
|
||||
database_version: str | None = None,
|
||||
mark_database_version_on_success: bool = False,
|
||||
):
|
||||
"""Run migration with specified stages"""
|
||||
stats = MigrationStats()
|
||||
stats.start()
|
||||
@@ -1127,8 +1117,31 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
|
||||
|
||||
try:
|
||||
db.connect()
|
||||
|
||||
if database_version:
|
||||
current_db_version = db.get_database_version()
|
||||
if should_skip_migration(current_db_version, database_version):
|
||||
logger.info(
|
||||
"Database migration version is %s, target version is %s, skipping all stages",
|
||||
current_db_version,
|
||||
database_version,
|
||||
)
|
||||
return
|
||||
|
||||
if current_db_version:
|
||||
logger.info(
|
||||
"Current database migration version is %s, target version is %s",
|
||||
current_db_version,
|
||||
database_version,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Database migration version marker is not set, target version is %s",
|
||||
database_version,
|
||||
)
|
||||
|
||||
total_stages = len(stages)
|
||||
all_stages_completed = True
|
||||
|
||||
for idx, stage_name in enumerate(stages, 1):
|
||||
logger.info(f"{'=' * 60}")
|
||||
@@ -1138,16 +1151,13 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
|
||||
if stage_name not in MIGRATION_STAGES:
|
||||
logger.error(f"Unknown stage: {stage_name}")
|
||||
stats.add_stage_stats(stage_name, [], 0, 0)
|
||||
all_stages_completed = False
|
||||
continue
|
||||
|
||||
stage_cls = MIGRATION_STAGES[stage_name]
|
||||
stage = stage_cls(db, dry_run=dry_run, create_table_only=create_table_only)
|
||||
|
||||
stage_start = time.time()
|
||||
|
||||
if not create_table_only and stage.is_migration_version_applied():
|
||||
stats.add_stage_stats(stage_name, [], 0, time.time() - stage_start)
|
||||
continue
|
||||
|
||||
# For create_table_only mode, skip check and directly execute
|
||||
if create_table_only:
|
||||
@@ -1157,27 +1167,72 @@ def run_migration(config: MigrationConfig, stages: list, dry_run: bool = True,
|
||||
# Check if migration is needed
|
||||
if not stage.check():
|
||||
logger.info(f"Stage '{stage_name}' check: no migration needed")
|
||||
if not dry_run and stage.noop_completes_migration():
|
||||
stage.mark_migration_version_applied()
|
||||
stats.add_stage_stats(stage_name, [], 0, time.time() - stage_start)
|
||||
continue
|
||||
|
||||
# Execute migration
|
||||
rows, tables = stage.execute()
|
||||
if not dry_run:
|
||||
stage.mark_migration_version_applied()
|
||||
|
||||
stage_duration = time.time() - stage_start
|
||||
|
||||
stats.add_stage_stats(stage_name, tables, rows, stage_duration)
|
||||
logger.info(f"Stage '{stage_name}' completed: {rows} rows in {stage_duration:.2f}s")
|
||||
|
||||
|
||||
if (
|
||||
mark_database_version_on_success
|
||||
and not dry_run
|
||||
and not create_table_only
|
||||
and database_version
|
||||
and all_stages_completed
|
||||
):
|
||||
db.set_database_version(database_version)
|
||||
logger.info("Marked database migration version as %s", database_version)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
stats.end()
|
||||
stats.print_summary()
|
||||
|
||||
|
||||
def check_database_version(config: MigrationConfig, target_version: str) -> int:
|
||||
db = MigrationDatabase(config)
|
||||
try:
|
||||
db.connect()
|
||||
current_db_version = db.get_database_version()
|
||||
if should_skip_migration(current_db_version, target_version):
|
||||
logger.info(
|
||||
"Database migration version is %s, target version is %s, migration is not needed",
|
||||
current_db_version,
|
||||
target_version,
|
||||
)
|
||||
return 0
|
||||
|
||||
if current_db_version:
|
||||
logger.info(
|
||||
"Database migration version is %s, target version is %s, migration is needed",
|
||||
current_db_version,
|
||||
target_version,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Database migration version marker is not set, target version is %s, migration is needed",
|
||||
target_version,
|
||||
)
|
||||
return 1
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def mark_database_version(config: MigrationConfig, version: str) -> None:
|
||||
db = MigrationDatabase(config)
|
||||
try:
|
||||
db.connect()
|
||||
db.set_database_version(version)
|
||||
logger.info("Marked database migration version as %s", version)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='MySQL Data Migration Tool',
|
||||
@@ -1186,6 +1241,12 @@ def main():
|
||||
Examples:
|
||||
# List available stages
|
||||
python mysql_migration.py --list-stages
|
||||
|
||||
# Check whether migration is needed for a target version
|
||||
python mysql_migration.py --check-database-version --database-version v0.26.0 --config /path/to/config.yaml
|
||||
|
||||
# Mark database version separately
|
||||
python mysql_migration.py --mark-database-version --database-version v0.26.0 --config /path/to/config.yaml
|
||||
|
||||
# Dry run (default - check only, no write) with config file
|
||||
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml
|
||||
@@ -1198,6 +1259,12 @@ Examples:
|
||||
|
||||
# Execute full migration (create tables and migrate data)
|
||||
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute
|
||||
|
||||
# Execute migration only when database version is lower than v0.26.0
|
||||
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute --database-version v0.26.0
|
||||
|
||||
# Execute migration and mark the database version when all stages succeed
|
||||
python mysql_migration.py --stages tenant_model_provider,tenant_model_instance,tenant_model,model_id_config --config /path/to/config.yaml --execute --database-version v0.26.0 --mark-database-version-on-success
|
||||
|
||||
# Normalize legacy model IDs in stored configs
|
||||
python mysql_migration.py --stages model_id_config --config /path/to/config.yaml --execute
|
||||
@@ -1225,6 +1292,14 @@ Examples:
|
||||
# Migration options
|
||||
parser.add_argument('--stages', '-s', type=str, help='Comma-separated list of stages to run')
|
||||
parser.add_argument('--list-stages', '-l', action='store_true', help='List available stages')
|
||||
parser.add_argument('--check-database-version', action='store_true',
|
||||
help='Check whether migration is needed for the target database version')
|
||||
parser.add_argument('--mark-database-version', action='store_true',
|
||||
help='Write the database migration version marker and exit')
|
||||
parser.add_argument('--database-version', type=str, metavar='VERSION',
|
||||
help='Database migration version used by check/mark commands and as the migration threshold for --stages')
|
||||
parser.add_argument('--mark-database-version-on-success', action='store_true',
|
||||
help='When used with --stages and --execute, write --database-version after all stages succeed')
|
||||
parser.add_argument('--execute', '-e', action='store_true', default=False,
|
||||
help='Execute full migration: create tables and migrate data')
|
||||
parser.add_argument('--create-table-only', action='store_true', default=False,
|
||||
@@ -1237,13 +1312,6 @@ Examples:
|
||||
list_available_stages()
|
||||
return
|
||||
|
||||
# Parse stages
|
||||
if not args.stages:
|
||||
logger.error("No stages specified. Use --stages to specify stages or --list-stages to see available stages.")
|
||||
sys.exit(1)
|
||||
|
||||
stages = [s.strip() for s in args.stages.split(',')]
|
||||
|
||||
# Load configuration: command line args take precedence over config file
|
||||
if args.config:
|
||||
config = MigrationConfig.from_config_file(args.config)
|
||||
@@ -1270,12 +1338,38 @@ Examples:
|
||||
|
||||
logger.info(f"MySQL Configuration: host={config.host}, port={config.port}, "
|
||||
f"user={config.user}, database={config.database}")
|
||||
|
||||
if args.check_database_version and args.mark_database_version:
|
||||
logger.error("--check-database-version and --mark-database-version are mutually exclusive")
|
||||
sys.exit(1)
|
||||
|
||||
if args.check_database_version:
|
||||
if not args.database_version:
|
||||
logger.error("--check-database-version requires --database-version")
|
||||
sys.exit(1)
|
||||
sys.exit(check_database_version(config, args.database_version))
|
||||
|
||||
if args.mark_database_version:
|
||||
if not args.database_version:
|
||||
logger.error("--mark-database-version requires --database-version")
|
||||
sys.exit(1)
|
||||
mark_database_version(config, args.database_version)
|
||||
return
|
||||
|
||||
if args.mark_database_version_on_success and not args.database_version:
|
||||
logger.error("--mark-database-version-on-success requires --database-version")
|
||||
sys.exit(1)
|
||||
|
||||
# Three mutually exclusive modes: dry-run (default), create-table-only, execute
|
||||
if args.execute and args.create_table_only:
|
||||
logger.error("--execute and --create-table-only are mutually exclusive")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if not args.stages:
|
||||
logger.error("No stages specified. Use --stages to specify stages or --list-stages to see available stages.")
|
||||
sys.exit(1)
|
||||
|
||||
stages = [s.strip() for s in args.stages.split(',')]
|
||||
dry_run = True
|
||||
create_table_only = False
|
||||
|
||||
@@ -1294,7 +1388,9 @@ Examples:
|
||||
config=config,
|
||||
stages=stages,
|
||||
dry_run=dry_run,
|
||||
create_table_only=create_table_only
|
||||
create_table_only=create_table_only,
|
||||
database_version=args.database_version,
|
||||
mark_database_version_on_success=args.mark_database_version_on_success,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user