mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix: migrate legacy model id configs (#15495)
### What problem does this PR solve? migrate legacy model id configs ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -24,6 +24,7 @@ This script provides a flexible MySQL data migration tool that supports:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -167,6 +168,14 @@ class MigrationDatabase:
|
||||
)
|
||||
return cursor.fetchone()[0] > 0
|
||||
|
||||
def column_exists(self, table_name: str, column_name: str) -> bool:
|
||||
cursor = self.execute_sql(
|
||||
"SELECT COUNT(*) FROM information_schema.columns "
|
||||
"WHERE table_schema = %s AND table_name = %s AND column_name = %s",
|
||||
(self.config.database, table_name, column_name)
|
||||
)
|
||||
return cursor.fetchone()[0] > 0
|
||||
|
||||
|
||||
# Define model classes for migration (not importing from api.db.db_models)
|
||||
class BaseModel(Model):
|
||||
@@ -725,11 +734,253 @@ class TenantModelStage(MigrationStage):
|
||||
logger.info("Created tenant_model table")
|
||||
|
||||
|
||||
class ModelIdConfigStage(MigrationStage):
|
||||
"""Normalize stored model IDs from model@provider to model@default@provider."""
|
||||
|
||||
name = "model_id_config"
|
||||
description = "Normalize stored model IDs in config columns to model@default@provider"
|
||||
source_tables = [
|
||||
"tenant",
|
||||
"knowledgebase",
|
||||
"document",
|
||||
"dialog",
|
||||
"memory",
|
||||
"search",
|
||||
"user_canvas",
|
||||
"canvas_template",
|
||||
"user_canvas_version",
|
||||
"api_4_conversation",
|
||||
"pipeline_operation_log",
|
||||
"connector",
|
||||
"evaluation_runs",
|
||||
]
|
||||
target_tables = source_tables
|
||||
|
||||
model_id_fields = {
|
||||
"llm_id",
|
||||
"embd_id",
|
||||
"embedding_model",
|
||||
"rerank_id",
|
||||
"asr_id",
|
||||
"img2txt_id",
|
||||
"tts_id",
|
||||
"ocr_id",
|
||||
}
|
||||
search_config_model_id_fields = {"chat_id"}
|
||||
scan_batch_size = 500
|
||||
string_columns = {
|
||||
"tenant": ("llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id", "ocr_id"),
|
||||
"knowledgebase": ("embd_id",),
|
||||
"dialog": ("llm_id", "rerank_id"),
|
||||
"memory": ("embd_id", "llm_id"),
|
||||
}
|
||||
json_columns = {
|
||||
"knowledgebase": ("parser_config",),
|
||||
"document": ("parser_config",),
|
||||
"search": ("search_config",),
|
||||
"user_canvas": ("dsl",),
|
||||
"canvas_template": ("dsl",),
|
||||
"user_canvas_version": ("dsl",),
|
||||
"api_4_conversation": ("dsl",),
|
||||
"pipeline_operation_log": ("dsl",),
|
||||
"connector": ("config",),
|
||||
"evaluation_runs": ("config_snapshot",),
|
||||
}
|
||||
|
||||
def normalize_model_id(self, value):
|
||||
if not isinstance(value, str) or not value:
|
||||
return value, False
|
||||
|
||||
parts = value.split("@")
|
||||
if len(parts) != 2:
|
||||
return value, False
|
||||
|
||||
model_name, provider_name = parts
|
||||
if not model_name or not provider_name:
|
||||
return value, False
|
||||
|
||||
return f"{model_name}@default@{provider_name}", True
|
||||
|
||||
def normalize_config(self, value, path=None):
|
||||
path = path or ()
|
||||
|
||||
if isinstance(value, dict):
|
||||
changed = False
|
||||
normalized = {}
|
||||
for key, item in value.items():
|
||||
key_path = path + (str(key),)
|
||||
should_normalize = key in self.model_id_fields or (
|
||||
key in self.search_config_model_id_fields and "search_config" in path
|
||||
)
|
||||
if should_normalize:
|
||||
normalized_item, item_changed = self.normalize_model_id(item)
|
||||
else:
|
||||
normalized_item, item_changed = self.normalize_config(item, key_path)
|
||||
normalized[key] = normalized_item
|
||||
changed = changed or item_changed
|
||||
return normalized, changed
|
||||
|
||||
if isinstance(value, list):
|
||||
changed = False
|
||||
normalized = []
|
||||
for index, item in enumerate(value):
|
||||
normalized_item, item_changed = self.normalize_config(item, path + (str(index),))
|
||||
normalized.append(normalized_item)
|
||||
changed = changed or item_changed
|
||||
return normalized, changed
|
||||
|
||||
return value, False
|
||||
|
||||
def existing_columns(self, table_columns):
|
||||
for table_name, columns in table_columns.items():
|
||||
if not self.db.table_exists(table_name):
|
||||
logger.info("Table '%s' does not exist, skipping", table_name)
|
||||
continue
|
||||
for column_name in columns:
|
||||
if not self.db.column_exists(table_name, column_name):
|
||||
logger.info("Column '%s.%s' does not exist, skipping", table_name, column_name)
|
||||
continue
|
||||
yield table_name, column_name
|
||||
|
||||
def load_json_value(self, raw_value, table_name, column_name, row_id):
|
||||
if raw_value in (None, ""):
|
||||
return None, False
|
||||
if isinstance(raw_value, (dict, list)):
|
||||
return raw_value, True
|
||||
try:
|
||||
return json.loads(raw_value), True
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
logger.warning(
|
||||
"Failed to parse JSON in %s.%s id=%s, skipping",
|
||||
table_name,
|
||||
column_name,
|
||||
row_id,
|
||||
)
|
||||
return None, False
|
||||
|
||||
def iter_string_changes(self):
|
||||
for table_name, column_name in self.existing_columns(self.string_columns):
|
||||
cursor = self.db.execute_sql(
|
||||
f"SELECT id, `{column_name}` FROM `{table_name}` "
|
||||
f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s",
|
||||
("%@%",),
|
||||
)
|
||||
while True:
|
||||
rows = cursor.fetchmany(self.scan_batch_size)
|
||||
if not rows:
|
||||
break
|
||||
for row_id, value in rows:
|
||||
normalized, changed = self.normalize_model_id(value)
|
||||
if changed:
|
||||
yield table_name, column_name, row_id, normalized
|
||||
|
||||
def iter_json_changes(self):
|
||||
for table_name, column_name in self.existing_columns(self.json_columns):
|
||||
cursor = self.db.execute_sql(
|
||||
f"SELECT id, `{column_name}` FROM `{table_name}` "
|
||||
f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s",
|
||||
("%@%",),
|
||||
)
|
||||
while True:
|
||||
rows = cursor.fetchmany(self.scan_batch_size)
|
||||
if not rows:
|
||||
break
|
||||
for row_id, raw_value in rows:
|
||||
config, loaded = self.load_json_value(raw_value, table_name, column_name, row_id)
|
||||
if not loaded:
|
||||
continue
|
||||
normalized, changed = self.normalize_config(config, (column_name,))
|
||||
if changed:
|
||||
normalized_json = json.dumps(
|
||||
normalized,
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
yield table_name, column_name, row_id, normalized_json
|
||||
|
||||
def count_changes(self) -> tuple[int, set]:
|
||||
rows = 0
|
||||
tables = set()
|
||||
for table_name, _, _, _ in self.iter_string_changes():
|
||||
rows += 1
|
||||
tables.add(table_name)
|
||||
for table_name, _, _, _ in self.iter_json_changes():
|
||||
rows += 1
|
||||
tables.add(table_name)
|
||||
return rows, tables
|
||||
|
||||
def check(self) -> bool:
|
||||
rows, tables = self.count_changes()
|
||||
if rows == 0:
|
||||
logger.info("No stored model IDs need normalization")
|
||||
return False
|
||||
logger.info(
|
||||
"Found %s rows to normalize across tables: %s",
|
||||
rows,
|
||||
", ".join(sorted(tables)),
|
||||
)
|
||||
return True
|
||||
|
||||
def execute(self) -> tuple[int, list]:
|
||||
if self.create_table_only:
|
||||
logger.info("[CREATE TABLE ONLY] No tables are created for this data migration")
|
||||
return 0, []
|
||||
|
||||
rows_updated = 0
|
||||
tables_operated = set()
|
||||
|
||||
for table_name, column_name, row_id, normalized in self.iter_string_changes():
|
||||
tables_operated.add(table_name)
|
||||
rows_updated += 1
|
||||
if rows_updated <= 10:
|
||||
logger.info(
|
||||
"%s %s.%s id=%s -> %s",
|
||||
"[DRY RUN] Would update" if self.dry_run else "Updating",
|
||||
table_name,
|
||||
column_name,
|
||||
row_id,
|
||||
normalized,
|
||||
)
|
||||
if not self.dry_run:
|
||||
self.db.execute_sql(
|
||||
f"UPDATE `{table_name}` SET `{column_name}` = %s WHERE id = %s",
|
||||
(normalized, row_id),
|
||||
)
|
||||
|
||||
for table_name, column_name, row_id, normalized_json in self.iter_json_changes():
|
||||
tables_operated.add(table_name)
|
||||
rows_updated += 1
|
||||
if rows_updated <= 10:
|
||||
logger.info(
|
||||
"%s %s.%s id=%s",
|
||||
"[DRY RUN] Would update" if self.dry_run else "Updating",
|
||||
table_name,
|
||||
column_name,
|
||||
row_id,
|
||||
)
|
||||
if not self.dry_run:
|
||||
self.db.execute_sql(
|
||||
f"UPDATE `{table_name}` SET `{column_name}` = %s WHERE id = %s",
|
||||
(normalized_json, row_id),
|
||||
)
|
||||
|
||||
if rows_updated > 10:
|
||||
logger.info("... and %s more row updates", rows_updated - 10)
|
||||
|
||||
if self.dry_run:
|
||||
logger.info("[DRY RUN] Would update %s rows", rows_updated)
|
||||
else:
|
||||
logger.info("Updated %s rows", rows_updated)
|
||||
|
||||
return rows_updated, sorted(tables_operated)
|
||||
|
||||
|
||||
# Registry of available migration stages
|
||||
MIGRATION_STAGES = {
|
||||
'tenant_model_provider': TenantModelProviderStage,
|
||||
'tenant_model_instance': TenantModelInstanceStage,
|
||||
'tenant_model': TenantModelStage,
|
||||
'model_id_config': ModelIdConfigStage,
|
||||
}
|
||||
|
||||
|
||||
@@ -816,6 +1067,9 @@ Examples:
|
||||
# Execute full migration (create tables and migrate data)
|
||||
python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute
|
||||
|
||||
# Normalize legacy model IDs in stored configs
|
||||
python mysql_migration.py --stages model_id_config --config /path/to/config.yaml --execute
|
||||
|
||||
# Run multiple stages
|
||||
python mysql_migration.py --stages stage1,stage2,stage3 --config /path/to/config.yaml --execute
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user