From 47bd9dd049575645ee74c92e7238a73b0e945618 Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 18 Jun 2026 16:38:32 +0800 Subject: [PATCH] Fix: replace tenant_llm apis (#16131) Replace tenant_llm apis with provider-instance apis. --- admin/client/ragflow_client.py | 161 ++++++++++++++++++--------- api/db/init_data.py | 57 +--------- api/db/services/llm_service.py | 50 --------- sdk/python/test/conftest.py | 39 ------- test/benchmark/README.md | 4 +- test/benchmark/auth.py | 58 +++++++--- test/benchmark/cli.py | 58 +++++++--- test/testcases/conftest.py | 196 ++++++++------------------------- 8 files changed, 244 insertions(+), 379 deletions(-) diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index fc939c3c51..b8102520ad 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -727,45 +727,133 @@ class RAGFlowClient: def create_model_provider(self, command): if self.server_type != "user": print("This command is only allowed in USER mode") - llm_factory: str = command["provider_name"] + return + provider_name: str = command["provider_name"] api_key: str = command["provider_key"] - payload = {"api_key": api_key, "llm_factory": llm_factory} - response = self.http_client.request("POST", "/llm/set_api_key", json_body=payload, use_api_base=False, - auth_kind="web") - res_json = response.json() - if response.status_code == 200 and res_json["code"] == 0: - print(f"Success to add model provider {llm_factory}") + + # Step 1: Add provider + provider_payload = {"provider_name": provider_name} + provider_response = self.http_client.request("PUT", "/providers", json_body=provider_payload, + use_api_base=True, auth_kind="web") + provider_res = provider_response.json() + if provider_response.status_code == 200 and provider_res.get("code") == 0: + print(f"Success to add provider {provider_name}") else: - print(f"Fail to add model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}") + msg = provider_res.get("message", "") + if "duplicated" in msg.lower() or "already exist" in msg.lower(): + print(f"Note: provider {provider_name} already exists, continuing to add instance") + else: + print(f"Fail to add provider {provider_name}, code: {provider_res.get('code')}, message: {msg}") + return + + # Step 2: Add instance + instance_payload = { + "instance_name": "default", + "api_key": api_key, + "region": "default", + "base_url": "" + } + instance_response = self.http_client.request("POST", f"/providers/{provider_name}/instances", + json_body=instance_payload, use_api_base=True, + auth_kind="web") + instance_res = instance_response.json() + if instance_response.status_code == 200 and instance_res.get("code") == 0: + print(f"Success to add instance for provider {provider_name}") + else: + msg = instance_res.get("message", "") + if "already exist" in msg.lower(): + print(f"Note: instance for provider {provider_name} already exists, skipping") + else: + print(f"Fail to add instance for provider {provider_name}, code: {instance_res.get('code')}, message: {msg}") def drop_model_provider(self, command): if self.server_type != "user": print("This command is only allowed in USER mode") - llm_factory: str = command["provider_name"] - payload = {"llm_factory": llm_factory} - response = self.http_client.request("POST", "/llm/delete_factory", json_body=payload, use_api_base=False, + return + provider_name: str = command["provider_name"] + response = self.http_client.request("DELETE", f"/providers/{provider_name}", use_api_base=True, auth_kind="web") res_json = response.json() - if response.status_code == 200 and res_json["code"] == 0: - print(f"Success to drop model provider {llm_factory}") + if response.status_code == 200 and res_json.get("code") == 0: + print(f"Success to drop model provider {provider_name}") else: - print( - f"Fail to drop model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to drop model provider {provider_name}, code: {res_json.get('code')}, message: {res_json.get('message')}") + + # Mapping from legacy model_type keys to API model_type values + _MODEL_TYPE_MAP = { + "llm_id": "chat", + "embd_id": "embedding", + "img2txt_id": "vision", + "reranker_id": "rerank", + "asr_id": "asr", + "tts_id": "tts", + } def set_default_model(self, command): if self.server_type != "user": print("This command is only allowed in USER mode") + return - model_type: str = command["model_type"] + model_type_key: str = command["model_type"] model_id: str = command["model_id"] - self._set_default_models(model_type, model_id) + + model_type = self._MODEL_TYPE_MAP.get(model_type_key) + if model_type is None: + print(f"Unknown model type: {model_type_key}") + return + + model_name, model_instance, model_provider = self._parse_model_id(model_id) + + payload = { + "model_provider": model_provider, + "model_instance": model_instance, + "model_type": model_type, + "model_name": model_name, + } + response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True, + auth_kind="web") + res_json = response.json() + if response.status_code == 200 and res_json.get("code") == 0: + print(f"Success to set default {model_type} to {model_id}") + else: + print(f"Fail to set default {model_type}, code: {res_json.get('code')}, message: {res_json.get('message')}") def reset_default_model(self, command): if self.server_type != "user": print("This command is only allowed in USER mode") + return - model_type: str = command["model_type"] - self._set_default_models(model_type, "") + model_type_key: str = command["model_type"] + model_type = self._MODEL_TYPE_MAP.get(model_type_key) + if model_type is None: + print(f"Unknown model type: {model_type_key}") + return + + payload = {"model_type": model_type} + response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True, + auth_kind="web") + res_json = response.json() + if response.status_code == 200 and res_json.get("code") == 0: + print(f"Success to reset default {model_type}") + else: + print(f"Fail to reset default {model_type}, code: {res_json.get('code')}, message: {res_json.get('message')}") + + @staticmethod + def _parse_model_id(model_id: str): + """Parse model_id into (model_name, model_instance, model_provider). + + Accepted formats: + - model_name@instance@provider -> (model_name, instance, provider) + - model_name@provider -> (model_name, "default", provider) + - model_name -> (model_name, "default", "") + """ + parts = model_id.split("@") + if len(parts) >= 3: + return parts[0], parts[1], parts[-1] + elif len(parts) == 2: + return parts[0], "default", parts[1] + else: + return model_id, "default", "" def list_user_datasets(self, command): if self.server_type != "user": @@ -1825,41 +1913,6 @@ class RAGFlowClient: print(f"Fail to list chats, code: {res_json['code']}, message: {res_json['message']}") return None - def _get_default_models(self): - response = self.http_client.request("GET", "/users/me/models", use_api_base=True, auth_kind="web") - res_json = response.json() - if response.status_code == 200: - if res_json["code"] == 0: - return res_json["data"] - else: - print(f"Fail to list user default models, code: {res_json['code']}, message: {res_json['message']}") - return None - else: - print(f"Fail to list user default models, HTTP code: {response.status_code}, message: {res_json}") - return None - - def _set_default_models(self, model_type, model_id): - current_payload = self._get_default_models() - if current_payload is None: - return - else: - current_payload.update({model_type: model_id}) - payload = { - "tenant_id": current_payload["tenant_id"], - "llm_id": current_payload["llm_id"], - "embd_id": current_payload["embd_id"], - "img2txt_id": current_payload["img2txt_id"], - "asr_id": current_payload["asr_id"], - "tts_id": current_payload["tts_id"], - } - response = self.http_client.request("PATCH", "/users/me/models", json_body=payload, use_api_base=True, - auth_kind="web") - res_json = response.json() - if response.status_code == 200 and res_json["code"] == 0: - print(f"Success to set default llm to {model_type}") - else: - print(f"Fail to set default llm to {model_type}, code: {res_json['code']}, message: {res_json['message']}") - def _format_service_detail_table(self, data): if isinstance(data, list): return data diff --git a/api/db/init_data.py b/api/db/init_data.py index 93c92cc64f..c4caf17e99 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -19,18 +19,16 @@ import json import os import time import uuid -from copy import deepcopy from peewee import IntegrityError from api.db import UserTenantRole -from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM +from api.db.db_models import init_database_tables as init_web_db from api.db.services import UserService from api.db.services.canvas_service import CanvasTemplateService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService -from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm +from api.db.services.llm_service import LLMBundle from api.db.services.user_service import TenantService, UserTenantService from api.db.services.system_settings_service import SystemSettingsService from api.db.template_utils import normalize_canvas_template_categories @@ -76,8 +74,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ "role": role } - tenant_llm = get_init_tenant_llm(user_info["id"]) - try: if not UserService.save(**user_info): logging.error("can't init admin.") @@ -87,7 +83,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ return TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) - TenantLLMService.insert_many(tenant_llm) logging.info( f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.") @@ -106,54 +101,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ logging.error("'{}' doesn't work!".format(tenant["embd_id"])) -def init_llm_factory(): - # todo deprecated - LLMFactoriesService.filter_delete([1 == 1]) - factory_llm_infos = settings.FACTORY_LLM_INFOS - for factory_llm_info in factory_llm_infos: - info = deepcopy(factory_llm_info) - llm_infos = info.pop("llm") - try: - LLMFactoriesService.save(**info) - except Exception: - pass - LLMService.filter_delete([LLM.fid == factory_llm_info["name"]]) - for llm_info in llm_infos: - llm_info["fid"] = factory_llm_info["name"] - try: - LLMService.save(**llm_info) - except Exception: - pass - - LLMFactoriesService.filter_delete([(LLMFactories.name == "Local") | (LLMFactories.name == "novita.ai")]) - LLMService.filter_delete([LLM.fid == "Local"]) - LLMService.filter_delete([LLM.llm_name == "qwen-vl-max"]) - LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"]) - TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"]) - LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"]) - LLMService.filter_delete([LLMService.model.fid == "QAnything"]) - TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) - TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"}) - TenantService.filter_update([1 == 1], { - "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"}) - ## insert openai two embedding models to the current openai user. - # print("Start to insert 2 OpenAI embedding models...") - tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()]) - for tid in tenant_ids: - for row in TenantLLMService.query(llm_factory="OpenAI", tenant_id=tid): - row = row.to_dict() - row["model_type"] = LLMType.EMBEDDING.value - row["llm_name"] = "text-embedding-3-small" - row["used_tokens"] = 0 - try: - TenantLLMService.save(**row) - row = deepcopy(row) - row["llm_name"] = "text-embedding-3-large" - TenantLLMService.save(**row) - except Exception: - pass - break - def update_document_number_in_init(): doc_count = DocumentService.get_all_kb_doc_count() for kb_id in KnowledgebaseService.get_all_ids(): diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index f651729806..6aeb94a4c8 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -27,7 +27,6 @@ from langfuse import propagate_attributes from api.db.db_models import LLM from api.db.services.common_service import CommonService from api.db.services.tenant_llm_service import LLM4Tenant -from common.constants import LLMType from common.token_utils import num_tokens_from_string @@ -35,55 +34,6 @@ class LLMService(CommonService): model = LLM -def get_init_tenant_llm(user_id): - from common import settings - - tenant_llm = [] - - model_configs = { - LLMType.CHAT: settings.CHAT_CFG, - LLMType.EMBEDDING: settings.EMBEDDING_CFG, - LLMType.SPEECH2TEXT: settings.ASR_CFG, - LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG, - LLMType.RERANK: settings.RERANK_CFG, - } - - seen = set() - factory_configs = [] - for factory_config in [ - settings.CHAT_CFG, - settings.EMBEDDING_CFG, - settings.ASR_CFG, - settings.IMAGE2TEXT_CFG, - settings.RERANK_CFG, - ]: - factory_name = factory_config["factory"] - if factory_name not in seen: - seen.add(factory_name) - factory_configs.append(factory_config) - - for factory_config in factory_configs: - for llm in LLMService.query(fid=factory_config["factory"]): - tenant_llm.append( - { - "tenant_id": user_id, - "llm_factory": factory_config["factory"], - "llm_name": llm.llm_name, - "model_type": llm.model_type, - "api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]), - "api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]), - "max_tokens": llm.max_tokens if llm.max_tokens else 8192, - } - ) - - unique = {} - for item in tenant_llm: - key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) - if key not in unique: - unique[key] = item - return list(unique.values()) - - class LLMBundle(LLM4Tenant): def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): super().__init__(tenant_id, model_config, lang, **kwargs) diff --git a/sdk/python/test/conftest.py b/sdk/python/test/conftest.py index ac21be8d1b..ef6dd44338 100644 --- a/sdk/python/test/conftest.py +++ b/sdk/python/test/conftest.py @@ -91,19 +91,6 @@ def get_email(): return EMAIL -def get_my_llms(auth, name): - # todo deprecated - url = HOST_ADDRESS + "/v1/llm/my_llms" - authorization = {"Authorization": auth} - response = requests.get(url=url, headers=authorization) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) - if name in res.get("data"): - return True - return False - - def get_added_models(auth, factory_name): url = HOST_ADDRESS + "/api/v1/models" authorization = {"Authorization": auth} @@ -116,21 +103,6 @@ def get_added_models(auth, factory_name): return True return False -def add_models(auth): - # todo deprecated - url = HOST_ADDRESS + "/v1/llm/set_api_key" - authorization = {"Authorization": auth} - models_info = { - "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY}, - } - - for name, model_info in models_info.items(): - if not get_my_llms(auth, name): - response = requests.post(url=url, headers=authorization, json=model_info) - res = response.json() - if res.get("code") != 0: - pytest.exit(f"Critical error in add_models: {res.get('message')}") - def add_model_instance(auth): add_provider_api = HOST_ADDRESS + "/api/v1/providers" @@ -156,17 +128,6 @@ def add_model_instance(auth): pytest.exit("Critical error in check added model: add model failed") -def get_tenant_info(auth): - # todo deprecated - url = HOST_ADDRESS + "/api/v1/users/me/models" - authorization = {"Authorization": auth} - response = requests.get(url=url, headers=authorization) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) - return res["data"].get("tenant_id") - - @pytest.fixture(scope="session", autouse=True) def set_tenant_info(get_auth): auth = get_auth diff --git a/test/benchmark/README.md b/test/benchmark/README.md index 6f1ca539de..c21bce7bba 100644 --- a/test/benchmark/README.md +++ b/test/benchmark/README.md @@ -59,7 +59,7 @@ Auth and bootstrap flags (used when --api-key is not provided) --token-name Optional API token name for /system/new_token. --bootstrap-llm - Ensure LLM factory API key is configured via /llm/set_api_key. + Ensure LLM factory API key is configured via /providers + /providers/{name}/instances. --llm-factory LLM factory name for bootstrap. Env: RAGFLOW_LLM_FACTORY @@ -70,7 +70,7 @@ Auth and bootstrap flags (used when --api-key is not provided) Optional LLM API base URL. Env: RAGFLOW_LLM_API_BASE --set-tenant-info - Set tenant defaults via /users/me/models. + Set tenant defaults via /models/default. --tenant-llm-id Tenant chat model ID. Env: RAGFLOW_TENANT_LLM_ID diff --git a/test/benchmark/auth.py b/test/benchmark/auth.py index 49c6efdcfb..407ef59b9e 100644 --- a/test/benchmark/auth.py +++ b/test/benchmark/auth.py @@ -55,10 +55,12 @@ def create_api_token(client: HttpClient, login_token: str, token_name: Optional[ def get_my_llms(client: HttpClient) -> Dict[str, Any]: - res = client.request_json("GET", "/llm/my_llms", use_api_base=False, auth_kind="login") + """List tenant-configured providers. Returns a dict keyed by provider name.""" + res = client.request_json("GET", "/providers", use_api_base=True, auth_kind="login") if res.get("code") != 0: - raise AuthError(f"Failed to list LLMs: {res.get('message')}") - return res.get("data", {}) + raise AuthError(f"Failed to list providers: {res.get('message')}") + providers = res.get("data", []) + return {p.get("name", ""): p for p in providers} if isinstance(providers, list) else {} def set_llm_api_key( @@ -67,22 +69,48 @@ def set_llm_api_key( api_key: str, base_url: Optional[str] = None, ) -> None: - payload = {"llm_factory": llm_factory, "api_key": api_key} - if base_url: - payload["base_url"] = base_url - res = client.request_json("POST", "/llm/set_api_key", use_api_base=False, auth_kind="login", json_body=payload) - if res.get("code") != 0: - raise AuthError(f"Failed to set LLM API key: {res.get('message')}") + """Add a provider (PUT /providers) and create a default instance (POST /providers/{name}/instances).""" + provider_payload = {"provider_name": llm_factory} + provider_res = client.request_json("PUT", "/providers", use_api_base=True, auth_kind="login", json_body=provider_payload) + provider_msg = provider_res.get("message", "") + if provider_res.get("code") != 0 and "duplicated" not in provider_msg.lower() and "already exist" not in provider_msg.lower(): + raise AuthError(f"Failed to add provider: {provider_msg}") + + instance_payload = { + "instance_name": "default", + "api_key": api_key, + "region": "default", + "base_url": base_url or "", + } + instance_res = client.request_json("POST", f"/providers/{llm_factory}/instances", use_api_base=True, + auth_kind="login", json_body=instance_payload) + instance_msg = instance_res.get("message", "") + if instance_res.get("code") != 0 and "already exist" not in instance_msg.lower(): + raise AuthError(f"Failed to add instance: {instance_msg}") -def get_tenant_info(client: HttpClient) -> Dict[str, Any]: - res = client.request_json("GET", "/users/me/models", use_api_base=True, auth_kind="login") +def get_default_models(client: HttpClient) -> Dict[str, Any]: + """List tenant default models.""" + res = client.request_json("GET", "/models/default", use_api_base=True, auth_kind="login") if res.get("code") != 0: - raise AuthError(f"Failed to get tenant info: {res.get('message')}") + raise AuthError(f"Failed to get default models: {res.get('message')}") return res.get("data", {}) -def set_tenant_info(client: HttpClient, payload: Dict[str, Any]) -> None: - res = client.request_json("PATCH", "/users/me/models", use_api_base=True, auth_kind="login", json_body=payload) +def set_default_model( + client: HttpClient, + model_provider: str, + model_instance: str, + model_name: str, + model_type: str, +) -> None: + """Set a tenant default model via PATCH /models/default.""" + payload = { + "model_provider": model_provider, + "model_instance": model_instance, + "model_name": model_name, + "model_type": model_type, + } + res = client.request_json("PATCH", "/models/default", use_api_base=True, auth_kind="login", json_body=payload) if res.get("code") != 0: - raise AuthError(f"Failed to set tenant info: {res.get('message')}") + raise AuthError(f"Failed to set default model: {res.get('message')}") diff --git a/test/benchmark/cli.py b/test/benchmark/cli.py index 971540aab3..e0e9352451 100644 --- a/test/benchmark/cli.py +++ b/test/benchmark/cli.py @@ -199,6 +199,50 @@ def _retrieval_worker( return run_retrieval_request(client, payload) +def _parse_model_id(model_id: str): + """Parse model_id into (model_name, model_instance, model_provider). + + Accepted formats: + - model_name@instance@provider -> (model_name, instance, provider) + - model_name@provider -> (model_name, "default", provider) + - model_name -> (model_name, "default", "") + """ + parts = model_id.split("@") + if len(parts) >= 3: + return parts[0], parts[1], parts[-1] + elif len(parts) == 2: + return parts[0], "default", parts[1] + else: + return model_id, "default", "" + + +# Mapping from legacy CLI arg keys to API model_type values +_MODEL_TYPE_MAP = { + "llm_id": "chat", + "embd_id": "embedding", + "img2txt_id": "vision", + "asr_id": "asr", + "tts_id": "tts", +} + + +def _set_default_models_from_args(client: HttpClient, args: argparse.Namespace) -> None: + """Set tenant default models from CLI arguments using the new PATCH /models/default API.""" + model_specs = [ + (args.tenant_llm_id, "llm_id"), + (args.tenant_embd_id, "embd_id"), + (args.tenant_img2txt_id, "img2txt_id"), + (args.tenant_asr_id, "asr_id"), + (args.tenant_tts_id, "tts_id"), + ] + for model_id, type_key in model_specs: + if not model_id: + continue + model_name, model_instance, model_provider = _parse_model_id(model_id) + model_type = _MODEL_TYPE_MAP[type_key] + auth.set_default_model(client, model_provider, model_instance, model_name, model_type) + + def _ensure_auth(client: HttpClient, args: argparse.Namespace) -> None: if args.api_key: client.api_key = args.api_key @@ -232,19 +276,7 @@ def _ensure_auth(client: HttpClient, args: argparse.Namespace) -> None: if args.set_tenant_info: if not args.tenant_llm_id or not args.tenant_embd_id: raise AuthError("Missing --tenant-llm-id or --tenant-embd-id for tenant setup") - tenant = auth.get_tenant_info(client) - tenant_id = tenant.get("tenant_id") - if not tenant_id: - raise AuthError("Tenant info missing tenant_id") - payload = { - "tenant_id": tenant_id, - "llm_id": args.tenant_llm_id, - "embd_id": args.tenant_embd_id, - "img2txt_id": args.tenant_img2txt_id or "", - "asr_id": args.tenant_asr_id or "", - "tts_id": args.tenant_tts_id, - } - auth.set_tenant_info(client, payload) + _set_default_models_from_args(client, args) api_key = auth.create_api_token(client, login_token, args.token_name) client.api_key = api_key diff --git a/test/testcases/conftest.py b/test/testcases/conftest.py index f50f7fcc3f..4735a389a3 100644 --- a/test/testcases/conftest.py +++ b/test/testcases/conftest.py @@ -170,19 +170,6 @@ def token(auth): return res["data"].get("token") -def get_my_llms(auth, name): - # todo deprecated - url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms" - authorization = {"Authorization": auth} - response = requests.get(url=url, headers=authorization) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) - if name in res.get("data"): - return True - return False - - def get_added_models(auth, factory_name): url = HOST_ADDRESS + "/api/v1/models" authorization = {"Authorization": auth} @@ -202,49 +189,6 @@ def get_added_models(auth, factory_name): return False -def get_tenant_llm_added(auth, factory_name, model_name, model_type="rerank"): - """ - Check whether a specific (factory, model_name, model_type) tenant_llm row exists. - - Legacy /v1/llm/my_llms response shape: - { - "ZHIPU-AI": {"tags": ..., "llm": [{"name": ..., "type": ...}, ...]}, - "SILICONFLOW": {"tags": ..., "llm": [{"name": ..., "type": ...}, ...]}, - } - so we navigate by factory key first, then look through its llm list. - """ - url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms" - authorization = {"Authorization": auth} - response = requests.get(url=url, headers=authorization) - res = response.json() - if res.get("code") != 0: - return False - data = res.get("data") or {} - factory_data = data.get(factory_name) or {} - for m in factory_data.get("llm", []) or []: - if m.get("name") != model_name: - continue - if model_type is None or m.get("type") == model_type: - return True - return False - - -def add_models(auth): - # todo deprecated - url = HOST_ADDRESS + f"/{VERSION}/llm/set_api_key" - authorization = {"Authorization": auth} - models_info = { - "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY}, - } - - for name, model_info in models_info.items(): - if not get_my_llms(auth, name): - response = requests.post(url=url, headers=authorization, json=model_info) - res = response.json() - if res.get("code") != 0: - pytest.exit(f"Critical error in add_models: {res.get('message')}") - - def add_model_instance(auth): add_provider_api = HOST_ADDRESS + "/api/v1/providers" authorization = {"Authorization": auth} @@ -277,34 +221,34 @@ def add_model_instance(auth): else: pytest.exit(f"Critical error in add model provider: {msg}") - # Register both "CI" (used by glm-4-flash@CI@ZHIPU-AI in configs.py - # and BAAI/bge-reranker-v2-m3@CI@SILICONFLOW) and "default". - for instance_name in ("CI", "default"): - add_instance_api = HOST_ADDRESS + f"/api/v1/providers/{provider_name}/instances" - add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={ - "instance_name": instance_name, - "api_key": api_key, - "region": "default", - "base_url": "" - }) - add_instance_res = add_instance_response.json() - if add_instance_res.get("code") != 0: - msg = add_instance_res.get("message", "") - # Instance may already exist with a different API key from a - # prior test run; that's fine — skip instead of failing. - if "Already exist instance" in msg or "already exist" in msg.lower(): - print(f"Note: {provider_name}/{instance_name} already exists, skipping") - continue - # Python API blocks creating instances named "default". - # The test_retrieval_parity test handles this by inserting - # "default" directly into the DB for SILICONFLOW. - if "cannot be 'default'" in msg: - print(f"Note: {provider_name}/{instance_name} blocked by API (name reserved), skipping") - continue - pytest.exit( - f"Critical error in add model instance {provider_name}/{instance_name}: " - f"{msg}" - ) + # Register "CI" (used by glm-4-flash@CI@ZHIPU-AI in configs.py + # and BAAI/bge-reranker-v2-m3@CI@SILICONFLOW). + instance_name = "CI" + add_instance_api = HOST_ADDRESS + f"/api/v1/providers/{provider_name}/instances" + add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={ + "instance_name": instance_name, + "api_key": api_key, + "region": "default", + "base_url": "" + }) + add_instance_res = add_instance_response.json() + if add_instance_res.get("code") != 0: + msg = add_instance_res.get("message", "") + # Instance may already exist with a different API key from a + # prior test run; that's fine — skip instead of failing. + if "Already exist instance" in msg or "already exist" in msg.lower(): + print(f"Note: {provider_name}/{instance_name} already exists, skipping") + continue + # Python API blocks creating instances named "default". + # The test_retrieval_parity test handles this by inserting + # "default" directly into the DB for SILICONFLOW. + if "cannot be 'default'" in msg: + print(f"Note: {provider_name}/{instance_name} blocked by API (name reserved), skipping") + continue + pytest.exit( + f"Critical error in add model instance {provider_name}/{instance_name}: " + f"{msg}" + ) add_success = get_added_models(auth, provider_name) if not add_success: @@ -324,52 +268,6 @@ def add_model_instance(auth): pytest.exit(f"Critical error in check added model: {provider_name} add model failed") -def add_siliconflow_rerank_llm(auth): - """ - Register the BAAI/bge-reranker-v2-m3 rerank model under factory=SILICONFLOW / instance=CI. - - This is the model referenced as `BAAI/bge-reranker-v2-m3@CI@SILICONFLOW` in - test_retrieval_parity.py. The /v1/llm/add_llm endpoint validates the key by - issuing a real rerank request, so the call requires network access to SiliconFlow - and a valid SILICONFLOW_API_KEY. - """ - factory = "SILICONFLOW" - model_name = "BAAI/bge-reranker-v2-m3" - if get_tenant_llm_added(auth, factory, model_name, "rerank"): - return - - url = HOST_ADDRESS + f"/{VERSION}/llm/add_llm" - authorization = {"Authorization": auth} - payload = { - "llm_factory": factory, - "llm_name": model_name, - "model_type": "rerank", - "api_key": SILICONFLOW_API_KEY, - "api_base": "", - } - response = requests.post(url=url, headers=authorization, json=payload) - res = response.json() - if res.get("code") != 0: - pytest.exit( - f"Critical error adding {factory} rerank model {model_name}: " - f"code={res.get('code')} message={res.get('message')} data={res.get('data')}" - ) - - if not get_tenant_llm_added(auth, factory, model_name, "rerank"): - pytest.exit(f"Failed to confirm {factory}/{model_name} rerank row was added") - - -def get_tenant_info(auth): - # todo deprecated - url = HOST_ADDRESS + f"/api/{VERSION}/users/me/models" - authorization = {"Authorization": auth} - response = requests.get(url=url, headers=authorization) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) - return res["data"].get("tenant_id") - - @pytest.fixture(scope="session", autouse=True) def set_tenant_info(auth): if not get_added_models(auth, "ZHIPU-AI") or not get_added_models(auth, "SILICONFLOW"): @@ -416,24 +314,20 @@ def set_tenant_info(auth): f"WARNING: failed to set default embedding LLM via {url}: " f"{embd_res.get('message')!r}. Continuing." ) - - -@pytest.fixture(scope="session", autouse=True) -def set_tenant_siliconflow_rerank(auth): - """ - Ensure the SiliconFlow BAAI/bge-reranker-v2-m3 rerank model is registered - for the test tenant. Used by test_retrieval_parity.py as - `BAAI/bge-reranker-v2-m3@CI@SILICONFLOW`. - - Runs after `set_tenant_info` so the SILICONFLOW provider+CI instance - already exist when the /add_llm call is made. - - If /add_llm is blocked (e.g. factory not in allowed list), the rerank - model config is resolved from FACTORY_LLM_INFOS at search time, so the - test can still proceed. - """ - try: - add_siliconflow_rerank_llm(auth) - except Exception as e: - print(f"Note: Could not register SILICONFLOW rerank model via /add_llm: {e}") - print("The model config will be resolved from FACTORY_LLM_INFOS at runtime.") + # set rerank model + set_default_rerank_response = requests.patch( + url=url, + headers=authorization, + json={ + "model_provider": "SILICONFLOW", + "model_instance": "CI", + "model_type": "rerank", + "model_name": "BAAI/bge-reranker-v2-m3" + } + ) + rerank_res = set_default_rerank_response.json() + if rerank_res.get("code") != 0: + print( + f"WARNING: failed to set default rerank LLM via {url}: " + f"{rerank_res.get('message')!r}. Continuing." + )