diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 4584b7159e..503a913d75 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -277,6 +277,7 @@ if [[ "${INIT_MODEL_PROVIDER_TABLES}" -eq 1 ]]; then "$PY" tools/scripts/mysql_migration.py --stages tenant_model_provider --config conf/service_conf.yaml --execute "$PY" tools/scripts/mysql_migration.py --stages tenant_model_instance --config conf/service_conf.yaml --execute "$PY" tools/scripts/mysql_migration.py --stages tenant_model --config conf/service_conf.yaml --execute + "$PY" tools/scripts/mysql_migration.py --stages model_id_config --config conf/service_conf.yaml --execute echo "Model provider table migrations completed." fi diff --git a/tools/scripts/mysql_migration.py b/tools/scripts/mysql_migration.py index 709b5d6e9c..5079592523 100644 --- a/tools/scripts/mysql_migration.py +++ b/tools/scripts/mysql_migration.py @@ -24,6 +24,7 @@ This script provides a flexible MySQL data migration tool that supports: """ import argparse +import json import logging import os import sys @@ -167,6 +168,14 @@ class MigrationDatabase: ) return cursor.fetchone()[0] > 0 + def column_exists(self, table_name: str, column_name: str) -> bool: + cursor = self.execute_sql( + "SELECT COUNT(*) FROM information_schema.columns " + "WHERE table_schema = %s AND table_name = %s AND column_name = %s", + (self.config.database, table_name, column_name) + ) + return cursor.fetchone()[0] > 0 + # Define model classes for migration (not importing from api.db.db_models) class BaseModel(Model): @@ -725,11 +734,253 @@ class TenantModelStage(MigrationStage): logger.info("Created tenant_model table") +class ModelIdConfigStage(MigrationStage): + """Normalize stored model IDs from model@provider to model@default@provider.""" + + name = "model_id_config" + description = "Normalize stored model IDs in config columns to model@default@provider" + source_tables = [ + "tenant", + "knowledgebase", + "document", + "dialog", + "memory", + "search", + "user_canvas", + "canvas_template", + "user_canvas_version", + "api_4_conversation", + "pipeline_operation_log", + "connector", + "evaluation_runs", + ] + target_tables = source_tables + + model_id_fields = { + "llm_id", + "embd_id", + "embedding_model", + "rerank_id", + "asr_id", + "img2txt_id", + "tts_id", + "ocr_id", + } + search_config_model_id_fields = {"chat_id"} + scan_batch_size = 500 + string_columns = { + "tenant": ("llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id", "ocr_id"), + "knowledgebase": ("embd_id",), + "dialog": ("llm_id", "rerank_id"), + "memory": ("embd_id", "llm_id"), + } + json_columns = { + "knowledgebase": ("parser_config",), + "document": ("parser_config",), + "search": ("search_config",), + "user_canvas": ("dsl",), + "canvas_template": ("dsl",), + "user_canvas_version": ("dsl",), + "api_4_conversation": ("dsl",), + "pipeline_operation_log": ("dsl",), + "connector": ("config",), + "evaluation_runs": ("config_snapshot",), + } + + def normalize_model_id(self, value): + if not isinstance(value, str) or not value: + return value, False + + parts = value.split("@") + if len(parts) != 2: + return value, False + + model_name, provider_name = parts + if not model_name or not provider_name: + return value, False + + return f"{model_name}@default@{provider_name}", True + + def normalize_config(self, value, path=None): + path = path or () + + if isinstance(value, dict): + changed = False + normalized = {} + for key, item in value.items(): + key_path = path + (str(key),) + should_normalize = key in self.model_id_fields or ( + key in self.search_config_model_id_fields and "search_config" in path + ) + if should_normalize: + normalized_item, item_changed = self.normalize_model_id(item) + else: + normalized_item, item_changed = self.normalize_config(item, key_path) + normalized[key] = normalized_item + changed = changed or item_changed + return normalized, changed + + if isinstance(value, list): + changed = False + normalized = [] + for index, item in enumerate(value): + normalized_item, item_changed = self.normalize_config(item, path + (str(index),)) + normalized.append(normalized_item) + changed = changed or item_changed + return normalized, changed + + return value, False + + def existing_columns(self, table_columns): + for table_name, columns in table_columns.items(): + if not self.db.table_exists(table_name): + logger.info("Table '%s' does not exist, skipping", table_name) + continue + for column_name in columns: + if not self.db.column_exists(table_name, column_name): + logger.info("Column '%s.%s' does not exist, skipping", table_name, column_name) + continue + yield table_name, column_name + + def load_json_value(self, raw_value, table_name, column_name, row_id): + if raw_value in (None, ""): + return None, False + if isinstance(raw_value, (dict, list)): + return raw_value, True + try: + return json.loads(raw_value), True + except (TypeError, json.JSONDecodeError): + logger.warning( + "Failed to parse JSON in %s.%s id=%s, skipping", + table_name, + column_name, + row_id, + ) + return None, False + + def iter_string_changes(self): + for table_name, column_name in self.existing_columns(self.string_columns): + cursor = self.db.execute_sql( + f"SELECT id, `{column_name}` FROM `{table_name}` " + f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s", + ("%@%",), + ) + while True: + rows = cursor.fetchmany(self.scan_batch_size) + if not rows: + break + for row_id, value in rows: + normalized, changed = self.normalize_model_id(value) + if changed: + yield table_name, column_name, row_id, normalized + + def iter_json_changes(self): + for table_name, column_name in self.existing_columns(self.json_columns): + cursor = self.db.execute_sql( + f"SELECT id, `{column_name}` FROM `{table_name}` " + f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s", + ("%@%",), + ) + while True: + rows = cursor.fetchmany(self.scan_batch_size) + if not rows: + break + for row_id, raw_value in rows: + config, loaded = self.load_json_value(raw_value, table_name, column_name, row_id) + if not loaded: + continue + normalized, changed = self.normalize_config(config, (column_name,)) + if changed: + normalized_json = json.dumps( + normalized, + ensure_ascii=False, + separators=(",", ":"), + ) + yield table_name, column_name, row_id, normalized_json + + def count_changes(self) -> tuple[int, set]: + rows = 0 + tables = set() + for table_name, _, _, _ in self.iter_string_changes(): + rows += 1 + tables.add(table_name) + for table_name, _, _, _ in self.iter_json_changes(): + rows += 1 + tables.add(table_name) + return rows, tables + + def check(self) -> bool: + rows, tables = self.count_changes() + if rows == 0: + logger.info("No stored model IDs need normalization") + return False + logger.info( + "Found %s rows to normalize across tables: %s", + rows, + ", ".join(sorted(tables)), + ) + return True + + def execute(self) -> tuple[int, list]: + if self.create_table_only: + logger.info("[CREATE TABLE ONLY] No tables are created for this data migration") + return 0, [] + + rows_updated = 0 + tables_operated = set() + + for table_name, column_name, row_id, normalized in self.iter_string_changes(): + tables_operated.add(table_name) + rows_updated += 1 + if rows_updated <= 10: + logger.info( + "%s %s.%s id=%s -> %s", + "[DRY RUN] Would update" if self.dry_run else "Updating", + table_name, + column_name, + row_id, + normalized, + ) + if not self.dry_run: + self.db.execute_sql( + f"UPDATE `{table_name}` SET `{column_name}` = %s WHERE id = %s", + (normalized, row_id), + ) + + for table_name, column_name, row_id, normalized_json in self.iter_json_changes(): + tables_operated.add(table_name) + rows_updated += 1 + if rows_updated <= 10: + logger.info( + "%s %s.%s id=%s", + "[DRY RUN] Would update" if self.dry_run else "Updating", + table_name, + column_name, + row_id, + ) + if not self.dry_run: + self.db.execute_sql( + f"UPDATE `{table_name}` SET `{column_name}` = %s WHERE id = %s", + (normalized_json, row_id), + ) + + if rows_updated > 10: + logger.info("... and %s more row updates", rows_updated - 10) + + if self.dry_run: + logger.info("[DRY RUN] Would update %s rows", rows_updated) + else: + logger.info("Updated %s rows", rows_updated) + + return rows_updated, sorted(tables_operated) + + # Registry of available migration stages MIGRATION_STAGES = { 'tenant_model_provider': TenantModelProviderStage, 'tenant_model_instance': TenantModelInstanceStage, 'tenant_model': TenantModelStage, + 'model_id_config': ModelIdConfigStage, } @@ -816,6 +1067,9 @@ Examples: # Execute full migration (create tables and migrate data) python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute + # Normalize legacy model IDs in stored configs + python mysql_migration.py --stages model_id_config --config /path/to/config.yaml --execute + # Run multiple stages python mysql_migration.py --stages stage1,stage2,stage3 --config /path/to/config.yaml --execute """