Fix: migrate legacy model id configs (#15495)

### What problem does this PR solve?

migrate legacy model id configs

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
buua436
2026-06-02 13:08:58 +08:00
committed by GitHub
parent 33ef724b5f
commit 2e02bf7ba4
2 changed files with 255 additions and 0 deletions

View File

@@ -24,6 +24,7 @@ This script provides a flexible MySQL data migration tool that supports:
"""
import argparse
import json
import logging
import os
import sys
@@ -167,6 +168,14 @@ class MigrationDatabase:
)
return cursor.fetchone()[0] > 0
def column_exists(self, table_name: str, column_name: str) -> bool:
cursor = self.execute_sql(
"SELECT COUNT(*) FROM information_schema.columns "
"WHERE table_schema = %s AND table_name = %s AND column_name = %s",
(self.config.database, table_name, column_name)
)
return cursor.fetchone()[0] > 0
# Define model classes for migration (not importing from api.db.db_models)
class BaseModel(Model):
@@ -725,11 +734,253 @@ class TenantModelStage(MigrationStage):
logger.info("Created tenant_model table")
class ModelIdConfigStage(MigrationStage):
"""Normalize stored model IDs from model@provider to model@default@provider."""
name = "model_id_config"
description = "Normalize stored model IDs in config columns to model@default@provider"
source_tables = [
"tenant",
"knowledgebase",
"document",
"dialog",
"memory",
"search",
"user_canvas",
"canvas_template",
"user_canvas_version",
"api_4_conversation",
"pipeline_operation_log",
"connector",
"evaluation_runs",
]
target_tables = source_tables
model_id_fields = {
"llm_id",
"embd_id",
"embedding_model",
"rerank_id",
"asr_id",
"img2txt_id",
"tts_id",
"ocr_id",
}
search_config_model_id_fields = {"chat_id"}
scan_batch_size = 500
string_columns = {
"tenant": ("llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id", "ocr_id"),
"knowledgebase": ("embd_id",),
"dialog": ("llm_id", "rerank_id"),
"memory": ("embd_id", "llm_id"),
}
json_columns = {
"knowledgebase": ("parser_config",),
"document": ("parser_config",),
"search": ("search_config",),
"user_canvas": ("dsl",),
"canvas_template": ("dsl",),
"user_canvas_version": ("dsl",),
"api_4_conversation": ("dsl",),
"pipeline_operation_log": ("dsl",),
"connector": ("config",),
"evaluation_runs": ("config_snapshot",),
}
def normalize_model_id(self, value):
if not isinstance(value, str) or not value:
return value, False
parts = value.split("@")
if len(parts) != 2:
return value, False
model_name, provider_name = parts
if not model_name or not provider_name:
return value, False
return f"{model_name}@default@{provider_name}", True
def normalize_config(self, value, path=None):
path = path or ()
if isinstance(value, dict):
changed = False
normalized = {}
for key, item in value.items():
key_path = path + (str(key),)
should_normalize = key in self.model_id_fields or (
key in self.search_config_model_id_fields and "search_config" in path
)
if should_normalize:
normalized_item, item_changed = self.normalize_model_id(item)
else:
normalized_item, item_changed = self.normalize_config(item, key_path)
normalized[key] = normalized_item
changed = changed or item_changed
return normalized, changed
if isinstance(value, list):
changed = False
normalized = []
for index, item in enumerate(value):
normalized_item, item_changed = self.normalize_config(item, path + (str(index),))
normalized.append(normalized_item)
changed = changed or item_changed
return normalized, changed
return value, False
def existing_columns(self, table_columns):
for table_name, columns in table_columns.items():
if not self.db.table_exists(table_name):
logger.info("Table '%s' does not exist, skipping", table_name)
continue
for column_name in columns:
if not self.db.column_exists(table_name, column_name):
logger.info("Column '%s.%s' does not exist, skipping", table_name, column_name)
continue
yield table_name, column_name
def load_json_value(self, raw_value, table_name, column_name, row_id):
if raw_value in (None, ""):
return None, False
if isinstance(raw_value, (dict, list)):
return raw_value, True
try:
return json.loads(raw_value), True
except (TypeError, json.JSONDecodeError):
logger.warning(
"Failed to parse JSON in %s.%s id=%s, skipping",
table_name,
column_name,
row_id,
)
return None, False
def iter_string_changes(self):
for table_name, column_name in self.existing_columns(self.string_columns):
cursor = self.db.execute_sql(
f"SELECT id, `{column_name}` FROM `{table_name}` "
f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s",
("%@%",),
)
while True:
rows = cursor.fetchmany(self.scan_batch_size)
if not rows:
break
for row_id, value in rows:
normalized, changed = self.normalize_model_id(value)
if changed:
yield table_name, column_name, row_id, normalized
def iter_json_changes(self):
for table_name, column_name in self.existing_columns(self.json_columns):
cursor = self.db.execute_sql(
f"SELECT id, `{column_name}` FROM `{table_name}` "
f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s",
("%@%",),
)
while True:
rows = cursor.fetchmany(self.scan_batch_size)
if not rows:
break
for row_id, raw_value in rows:
config, loaded = self.load_json_value(raw_value, table_name, column_name, row_id)
if not loaded:
continue
normalized, changed = self.normalize_config(config, (column_name,))
if changed:
normalized_json = json.dumps(
normalized,
ensure_ascii=False,
separators=(",", ":"),
)
yield table_name, column_name, row_id, normalized_json
def count_changes(self) -> tuple[int, set]:
rows = 0
tables = set()
for table_name, _, _, _ in self.iter_string_changes():
rows += 1
tables.add(table_name)
for table_name, _, _, _ in self.iter_json_changes():
rows += 1
tables.add(table_name)
return rows, tables
def check(self) -> bool:
rows, tables = self.count_changes()
if rows == 0:
logger.info("No stored model IDs need normalization")
return False
logger.info(
"Found %s rows to normalize across tables: %s",
rows,
", ".join(sorted(tables)),
)
return True
def execute(self) -> tuple[int, list]:
if self.create_table_only:
logger.info("[CREATE TABLE ONLY] No tables are created for this data migration")
return 0, []
rows_updated = 0
tables_operated = set()
for table_name, column_name, row_id, normalized in self.iter_string_changes():
tables_operated.add(table_name)
rows_updated += 1
if rows_updated <= 10:
logger.info(
"%s %s.%s id=%s -> %s",
"[DRY RUN] Would update" if self.dry_run else "Updating",
table_name,
column_name,
row_id,
normalized,
)
if not self.dry_run:
self.db.execute_sql(
f"UPDATE `{table_name}` SET `{column_name}` = %s WHERE id = %s",
(normalized, row_id),
)
for table_name, column_name, row_id, normalized_json in self.iter_json_changes():
tables_operated.add(table_name)
rows_updated += 1
if rows_updated <= 10:
logger.info(
"%s %s.%s id=%s",
"[DRY RUN] Would update" if self.dry_run else "Updating",
table_name,
column_name,
row_id,
)
if not self.dry_run:
self.db.execute_sql(
f"UPDATE `{table_name}` SET `{column_name}` = %s WHERE id = %s",
(normalized_json, row_id),
)
if rows_updated > 10:
logger.info("... and %s more row updates", rows_updated - 10)
if self.dry_run:
logger.info("[DRY RUN] Would update %s rows", rows_updated)
else:
logger.info("Updated %s rows", rows_updated)
return rows_updated, sorted(tables_operated)
# Registry of available migration stages
MIGRATION_STAGES = {
'tenant_model_provider': TenantModelProviderStage,
'tenant_model_instance': TenantModelInstanceStage,
'tenant_model': TenantModelStage,
'model_id_config': ModelIdConfigStage,
}
@@ -816,6 +1067,9 @@ Examples:
# Execute full migration (create tables and migrate data)
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute
# Normalize legacy model IDs in stored configs
python mysql_migration.py --stages model_id_config --config /path/to/config.yaml --execute
# Run multiple stages
python mysql_migration.py --stages stage1,stage2,stage3 --config /path/to/config.yaml --execute
"""