diff --git a/tools/scripts/mysql_migration.py b/tools/scripts/mysql_migration.py index e0811cdcd0..99711402dd 100644 --- a/tools/scripts/mysql_migration.py +++ b/tools/scripts/mysql_migration.py @@ -612,7 +612,7 @@ class TenantModelInstanceStage(MigrationStage): return rows_inserted, self.target_tables @staticmethod - def _strip_is_tools_from_api_key(api_key: str, llm_factory: str) -> str: + def _strip_is_tools_from_api_key(api_key: str) -> str: """Strip is_tools from api_key for dedup comparison. Handles three api_key formats: @@ -683,7 +683,7 @@ class TenantModelInstanceStage(MigrationStage): seen = {} # canonical_key -> first record for rec in group: _, _, api_key, _, _ = rec - canonical = self._strip_is_tools_from_api_key(api_key, llm_factory) + canonical = self._strip_is_tools_from_api_key(api_key) if canonical not in seen: seen[canonical] = rec else: @@ -799,18 +799,16 @@ class TenantModelStage(MigrationStage): status_condition = self._build_status_condition() # Check if there's data to migrate + # We cannot JOIN tenant_model_instance on api_key directly because the instance + # stage deduped api_keys (stripping is_tools), so a plain SQL equality won't + # match records whose api_key was merged. Count at the provider level instead. cursor = self.db.execute_sql( f"SELECT COUNT(*) FROM (" f" SELECT tl.id " f" FROM tenant_llm tl " f" INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory " - f" INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key " f" WHERE {status_condition} " - f" AND NOT EXISTS (" - f" SELECT 1 FROM tenant_model tm " - f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id" - f" )" - f") AS distinct_records" + f") AS source_records" ) count = cursor.fetchone()[0] @@ -855,18 +853,23 @@ class TenantModelStage(MigrationStage): status_condition = self._build_status_condition() - # Get records from tenant_llm with provider_id and instance_id lookup + # Load all tenant_model_instance records into memory for Python-level matching. + # We cannot JOIN on api_key in SQL because the instance stage deduped api_keys + # (stripping is_tools), so a plain SQL equality won't match records whose + # api_key was merged during dedup. + instance_lookup = self._build_instance_lookup() + + # Get records from tenant_llm with provider_id lookup (no instance JOIN) # Migrate status='0' records, plus status='1' for empty-llm factories cursor = self.db.execute_sql( - f"SELECT tl.id, tl.llm_name, tmp.id as provider_id, tmi.id as instance_id, " + f"SELECT tl.id, tl.llm_name, tmp.id as provider_id, " f" tl.model_type, tl.status, tl.api_key " f"FROM tenant_llm tl " f"INNER JOIN tenant_model_provider tmp ON tmp.tenant_id = tl.tenant_id AND tmp.provider_name = tl.llm_factory " - f"INNER JOIN tenant_model_instance tmi ON tmi.provider_id = tmp.id AND tmi.api_key = tl.api_key " f"WHERE {status_condition} " f"AND NOT EXISTS (" f" SELECT 1 FROM tenant_model tm " - f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name AND tm.instance_id = tmi.id" + f" WHERE tm.provider_id = tmp.id AND tm.model_name = tl.llm_name" f")" ) @@ -876,21 +879,28 @@ class TenantModelStage(MigrationStage): logger.info("No records to migrate") return 0, [] - logger.info(f"Migrating {len(records)} tenant_model records...") + # Resolve instance_id for each record using Python-level canonical matching + resolved_records = self._resolve_instance_ids(records, instance_lookup) + + if not resolved_records: + logger.info("No records with matching instance_id to migrate") + return 0, [] + + logger.info(f"Migrating {len(resolved_records)} tenant_model records...") if self.dry_run: - logger.info(f"[DRY RUN] Would insert {len(records)} records") - for source_id, llm_name, provider_id, instance_id, model_type, status, api_key in records[:5]: + logger.info(f"[DRY RUN] Would insert {len(resolved_records)} records") + for source_id, llm_name, provider_id, instance_id, model_type, status, api_key in resolved_records[:5]: logger.info(f" model_name={llm_name}, provider_id={provider_id}, " f"instance_id={instance_id}, model_type={model_type}") - if len(records) > 5: - logger.info(f" ... and {len(records) - 5} more records") - return len(records), self.target_tables + if len(resolved_records) > 5: + logger.info(f" ... and {len(resolved_records) - 5} more records") + return len(resolved_records), self.target_tables # Insert records in batches batch_size = 100 - for i in range(0, len(records), batch_size): - batch = records[i:i + batch_size] + for i in range(0, len(resolved_records), batch_size): + batch = resolved_records[i:i + batch_size] values = [] for source_id, llm_name, provider_id, instance_id, model_type, status, api_key in batch: record_id = self.generate_uuid() @@ -918,6 +928,57 @@ class TenantModelStage(MigrationStage): return rows_inserted, self.target_tables + def _build_instance_lookup(self) -> dict: + """Load all tenant_model_instance records, indexed by (provider_id, canonical_api_key). + + The canonical_api_key is computed by stripping is_tools from the stored api_key, + matching the dedup logic used during the instance migration stage. + + Returns: + dict mapping (provider_id, canonical_api_key) -> instance_id + """ + cursor = self.db.execute_sql( + "SELECT id, provider_id, api_key FROM tenant_model_instance" + ) + lookup = {} + for instance_id, provider_id, api_key in cursor.fetchall(): + canonical = TenantModelInstanceStage._strip_is_tools_from_api_key(api_key) + lookup[(provider_id, canonical)] = instance_id + logger.info(f"Loaded {len(lookup)} instance records for lookup") + return lookup + + @staticmethod + def _resolve_instance_ids(records: list, instance_lookup: dict) -> list: + """Resolve instance_id for each tenant_llm record using canonical api_key matching. + + Args: + records: list of tuples (source_id, llm_name, provider_id, model_type, status, api_key) + instance_lookup: dict mapping (provider_id, canonical_api_key) -> instance_id + + Returns: + list of tuples (source_id, llm_name, provider_id, instance_id, model_type, status, api_key) + Only records with a matching instance_id are included. + """ + resolved = [] + skipped = 0 + for source_id, llm_name, provider_id, model_type, status, api_key in records: + canonical = TenantModelInstanceStage._strip_is_tools_from_api_key(api_key) + instance_id = instance_lookup.get((provider_id, canonical)) + if instance_id: + resolved.append((source_id, llm_name, provider_id, instance_id, model_type, status, api_key)) + else: + skipped += 1 + logger.warning( + f"No matching instance for tenant_llm id={source_id}, " + f"provider_id={provider_id}, llm_name={llm_name}, " + f"canonical_api_key={canonical[:30]}..." + ) + + if skipped > 0: + logger.warning(f"Skipped {skipped} records with no matching instance_id") + + return resolved + @staticmethod def _extract_extra_from_api_key(api_key: str) -> str: """Extract is_tools from api_key JSON and return an extra JSON string for tenant_model.