Fix: replace tenant_llm apis (#16131)

Replace tenant_llm apis with provider-instance apis.
This commit is contained in:
Lynn
2026-06-18 16:38:32 +08:00
committed by GitHub
parent 72db9044e2
commit 47bd9dd049
8 changed files with 244 additions and 379 deletions

View File

@@ -727,45 +727,133 @@ class RAGFlowClient:
def create_model_provider(self, command):
if self.server_type != "user":
print("This command is only allowed in USER mode")
llm_factory: str = command["provider_name"]
return
provider_name: str = command["provider_name"]
api_key: str = command["provider_key"]
payload = {"api_key": api_key, "llm_factory": llm_factory}
response = self.http_client.request("POST", "/llm/set_api_key", json_body=payload, use_api_base=False,
auth_kind="web")
res_json = response.json()
if response.status_code == 200 and res_json["code"] == 0:
print(f"Success to add model provider {llm_factory}")
# Step 1: Add provider
provider_payload = {"provider_name": provider_name}
provider_response = self.http_client.request("PUT", "/providers", json_body=provider_payload,
use_api_base=True, auth_kind="web")
provider_res = provider_response.json()
if provider_response.status_code == 200 and provider_res.get("code") == 0:
print(f"Success to add provider {provider_name}")
else:
print(f"Fail to add model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}")
msg = provider_res.get("message", "")
if "duplicated" in msg.lower() or "already exist" in msg.lower():
print(f"Note: provider {provider_name} already exists, continuing to add instance")
else:
print(f"Fail to add provider {provider_name}, code: {provider_res.get('code')}, message: {msg}")
return
# Step 2: Add instance
instance_payload = {
"instance_name": "default",
"api_key": api_key,
"region": "default",
"base_url": ""
}
instance_response = self.http_client.request("POST", f"/providers/{provider_name}/instances",
json_body=instance_payload, use_api_base=True,
auth_kind="web")
instance_res = instance_response.json()
if instance_response.status_code == 200 and instance_res.get("code") == 0:
print(f"Success to add instance for provider {provider_name}")
else:
msg = instance_res.get("message", "")
if "already exist" in msg.lower():
print(f"Note: instance for provider {provider_name} already exists, skipping")
else:
print(f"Fail to add instance for provider {provider_name}, code: {instance_res.get('code')}, message: {msg}")
def drop_model_provider(self, command):
if self.server_type != "user":
print("This command is only allowed in USER mode")
llm_factory: str = command["provider_name"]
payload = {"llm_factory": llm_factory}
response = self.http_client.request("POST", "/llm/delete_factory", json_body=payload, use_api_base=False,
return
provider_name: str = command["provider_name"]
response = self.http_client.request("DELETE", f"/providers/{provider_name}", use_api_base=True,
auth_kind="web")
res_json = response.json()
if response.status_code == 200 and res_json["code"] == 0:
print(f"Success to drop model provider {llm_factory}")
if response.status_code == 200 and res_json.get("code") == 0:
print(f"Success to drop model provider {provider_name}")
else:
print(
f"Fail to drop model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}")
print(f"Fail to drop model provider {provider_name}, code: {res_json.get('code')}, message: {res_json.get('message')}")
# Translation table used by set_default_model and reset_default_model.
# Keys are the legacy model_type identifiers accepted in the command
# ("llm_id", "embd_id", ...); values are the model_type strings sent in the
# PATCH /models/default payload. A key that is missing from this table is
# reported by those methods as "Unknown model type".
_MODEL_TYPE_MAP = {
"llm_id": "chat",
"embd_id": "embedding",
"img2txt_id": "vision",
"reranker_id": "rerank",
"asr_id": "asr",
"tts_id": "tts",
}
def set_default_model(self, command):
if self.server_type != "user":
print("This command is only allowed in USER mode")
return
model_type: str = command["model_type"]
model_type_key: str = command["model_type"]
model_id: str = command["model_id"]
self._set_default_models(model_type, model_id)
model_type = self._MODEL_TYPE_MAP.get(model_type_key)
if model_type is None:
print(f"Unknown model type: {model_type_key}")
return
model_name, model_instance, model_provider = self._parse_model_id(model_id)
payload = {
"model_provider": model_provider,
"model_instance": model_instance,
"model_type": model_type,
"model_name": model_name,
}
response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True,
auth_kind="web")
res_json = response.json()
if response.status_code == 200 and res_json.get("code") == 0:
print(f"Success to set default {model_type} to {model_id}")
else:
print(f"Fail to set default {model_type}, code: {res_json.get('code')}, message: {res_json.get('message')}")
def reset_default_model(self, command):
if self.server_type != "user":
print("This command is only allowed in USER mode")
return
model_type: str = command["model_type"]
self._set_default_models(model_type, "")
model_type_key: str = command["model_type"]
model_type = self._MODEL_TYPE_MAP.get(model_type_key)
if model_type is None:
print(f"Unknown model type: {model_type_key}")
return
payload = {"model_type": model_type}
response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True,
auth_kind="web")
res_json = response.json()
if response.status_code == 200 and res_json.get("code") == 0:
print(f"Success to reset default {model_type}")
else:
print(f"Fail to reset default {model_type}, code: {res_json.get('code')}, message: {res_json.get('message')}")
@staticmethod
def _parse_model_id(model_id: str):
"""Parse model_id into (model_name, model_instance, model_provider).
Accepted formats:
- model_name@instance@provider -> (model_name, instance, provider)
- model_name@provider -> (model_name, "default", provider)
- model_name -> (model_name, "default", "")
"""
parts = model_id.split("@")
if len(parts) >= 3:
return parts[0], parts[1], parts[-1]
elif len(parts) == 2:
return parts[0], "default", parts[1]
else:
return model_id, "default", ""
def list_user_datasets(self, command):
if self.server_type != "user":
@@ -1825,41 +1913,6 @@ class RAGFlowClient:
print(f"Fail to list chats, code: {res_json['code']}, message: {res_json['message']}")
return None
def _get_default_models(self):
response = self.http_client.request("GET", "/users/me/models", use_api_base=True, auth_kind="web")
res_json = response.json()
if response.status_code == 200:
if res_json["code"] == 0:
return res_json["data"]
else:
print(f"Fail to list user default models, code: {res_json['code']}, message: {res_json['message']}")
return None
else:
print(f"Fail to list user default models, HTTP code: {response.status_code}, message: {res_json}")
return None
def _set_default_models(self, model_type, model_id):
current_payload = self._get_default_models()
if current_payload is None:
return
else:
current_payload.update({model_type: model_id})
payload = {
"tenant_id": current_payload["tenant_id"],
"llm_id": current_payload["llm_id"],
"embd_id": current_payload["embd_id"],
"img2txt_id": current_payload["img2txt_id"],
"asr_id": current_payload["asr_id"],
"tts_id": current_payload["tts_id"],
}
response = self.http_client.request("PATCH", "/users/me/models", json_body=payload, use_api_base=True,
auth_kind="web")
res_json = response.json()
if response.status_code == 200 and res_json["code"] == 0:
print(f"Success to set default llm to {model_type}")
else:
print(f"Fail to set default llm to {model_type}, code: {res_json['code']}, message: {res_json['message']}")
def _format_service_detail_table(self, data):
if isinstance(data, list):
return data

View File

@@ -19,18 +19,16 @@ import json
import os
import time
import uuid
from copy import deepcopy
from peewee import IntegrityError
from api.db import UserTenantRole
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
from api.db.db_models import init_database_tables as init_web_db
from api.db.services import UserService
from api.db.services.canvas_service import CanvasTemplateService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.db.services.system_settings_service import SystemSettingsService
from api.db.template_utils import normalize_canvas_template_categories
@@ -76,8 +74,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
"role": role
}
tenant_llm = get_init_tenant_llm(user_info["id"])
try:
if not UserService.save(**user_info):
logging.error("can't init admin.")
@@ -87,7 +83,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
return
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
logging.info(
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")
@@ -106,54 +101,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
logging.error("'{}' doesn't work!".format(tenant["embd_id"]))
def init_llm_factory():
# todo deprecated
LLMFactoriesService.filter_delete([1 == 1])
factory_llm_infos = settings.FACTORY_LLM_INFOS
for factory_llm_info in factory_llm_infos:
info = deepcopy(factory_llm_info)
llm_infos = info.pop("llm")
try:
LLMFactoriesService.save(**info)
except Exception:
pass
LLMService.filter_delete([LLM.fid == factory_llm_info["name"]])
for llm_info in llm_infos:
llm_info["fid"] = factory_llm_info["name"]
try:
LLMService.save(**llm_info)
except Exception:
pass
LLMFactoriesService.filter_delete([(LLMFactories.name == "Local") | (LLMFactories.name == "novita.ai")])
LLMService.filter_delete([LLM.fid == "Local"])
LLMService.filter_delete([LLM.llm_name == "qwen-vl-max"])
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
TenantService.filter_update([1 == 1], {
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"})
## insert openai two embedding models to the current openai user.
# print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
for tid in tenant_ids:
for row in TenantLLMService.query(llm_factory="OpenAI", tenant_id=tid):
row = row.to_dict()
row["model_type"] = LLMType.EMBEDDING.value
row["llm_name"] = "text-embedding-3-small"
row["used_tokens"] = 0
try:
TenantLLMService.save(**row)
row = deepcopy(row)
row["llm_name"] = "text-embedding-3-large"
TenantLLMService.save(**row)
except Exception:
pass
break
def update_document_number_in_init():
doc_count = DocumentService.get_all_kb_doc_count()
for kb_id in KnowledgebaseService.get_all_ids():

View File

@@ -27,7 +27,6 @@ from langfuse import propagate_attributes
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant
from common.constants import LLMType
from common.token_utils import num_tokens_from_string
@@ -35,55 +34,6 @@ class LLMService(CommonService):
model = LLM
def get_init_tenant_llm(user_id):
from common import settings
tenant_llm = []
model_configs = {
LLMType.CHAT: settings.CHAT_CFG,
LLMType.EMBEDDING: settings.EMBEDDING_CFG,
LLMType.SPEECH2TEXT: settings.ASR_CFG,
LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
LLMType.RERANK: settings.RERANK_CFG,
}
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG,
settings.EMBEDDING_CFG,
settings.ASR_CFG,
settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG,
]:
factory_name = factory_config["factory"]
if factory_name not in seen:
seen.add(factory_name)
factory_configs.append(factory_config)
for factory_config in factory_configs:
for llm in LLMService.query(fid=factory_config["factory"]):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
"api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
unique = {}
for item in tenant_llm:
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
if key not in unique:
unique[key] = item
return list(unique.values())
class LLMBundle(LLM4Tenant):
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
super().__init__(tenant_id, model_config, lang, **kwargs)

View File

@@ -91,19 +91,6 @@ def get_email():
return EMAIL
def get_my_llms(auth, name):
# todo deprecated
url = HOST_ADDRESS + "/v1/llm/my_llms"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
if name in res.get("data"):
return True
return False
def get_added_models(auth, factory_name):
url = HOST_ADDRESS + "/api/v1/models"
authorization = {"Authorization": auth}
@@ -116,21 +103,6 @@ def get_added_models(auth, factory_name):
return True
return False
def add_models(auth):
# todo deprecated
url = HOST_ADDRESS + "/v1/llm/set_api_key"
authorization = {"Authorization": auth}
models_info = {
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
}
for name, model_info in models_info.items():
if not get_my_llms(auth, name):
response = requests.post(url=url, headers=authorization, json=model_info)
res = response.json()
if res.get("code") != 0:
pytest.exit(f"Critical error in add_models: {res.get('message')}")
def add_model_instance(auth):
add_provider_api = HOST_ADDRESS + "/api/v1/providers"
@@ -156,17 +128,6 @@ def add_model_instance(auth):
pytest.exit("Critical error in check added model: add model failed")
def get_tenant_info(auth):
# todo deprecated
url = HOST_ADDRESS + "/api/v1/users/me/models"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
return res["data"].get("tenant_id")
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info(get_auth):
auth = get_auth

View File

@@ -59,7 +59,7 @@ Auth and bootstrap flags (used when --api-key is not provided)
--token-name
Optional API token name for /system/new_token.
--bootstrap-llm
Ensure LLM factory API key is configured via /llm/set_api_key.
Ensure LLM factory API key is configured via /providers + /providers/{name}/instances.
--llm-factory
LLM factory name for bootstrap.
Env: RAGFLOW_LLM_FACTORY
@@ -70,7 +70,7 @@ Auth and bootstrap flags (used when --api-key is not provided)
Optional LLM API base URL.
Env: RAGFLOW_LLM_API_BASE
--set-tenant-info
Set tenant defaults via /users/me/models.
Set tenant defaults via /models/default.
--tenant-llm-id
Tenant chat model ID.
Env: RAGFLOW_TENANT_LLM_ID

View File

@@ -55,10 +55,12 @@ def create_api_token(client: HttpClient, login_token: str, token_name: Optional[
def get_my_llms(client: HttpClient) -> Dict[str, Any]:
res = client.request_json("GET", "/llm/my_llms", use_api_base=False, auth_kind="login")
"""List tenant-configured providers. Returns a dict keyed by provider name."""
res = client.request_json("GET", "/providers", use_api_base=True, auth_kind="login")
if res.get("code") != 0:
raise AuthError(f"Failed to list LLMs: {res.get('message')}")
return res.get("data", {})
raise AuthError(f"Failed to list providers: {res.get('message')}")
providers = res.get("data", [])
return {p.get("name", ""): p for p in providers} if isinstance(providers, list) else {}
def set_llm_api_key(
@@ -67,22 +69,48 @@ def set_llm_api_key(
api_key: str,
base_url: Optional[str] = None,
) -> None:
payload = {"llm_factory": llm_factory, "api_key": api_key}
if base_url:
payload["base_url"] = base_url
res = client.request_json("POST", "/llm/set_api_key", use_api_base=False, auth_kind="login", json_body=payload)
if res.get("code") != 0:
raise AuthError(f"Failed to set LLM API key: {res.get('message')}")
"""Add a provider (PUT /providers) and create a default instance (POST /providers/{name}/instances)."""
provider_payload = {"provider_name": llm_factory}
provider_res = client.request_json("PUT", "/providers", use_api_base=True, auth_kind="login", json_body=provider_payload)
provider_msg = provider_res.get("message", "")
if provider_res.get("code") != 0 and "duplicated" not in provider_msg.lower() and "already exist" not in provider_msg.lower():
raise AuthError(f"Failed to add provider: {provider_msg}")
instance_payload = {
"instance_name": "default",
"api_key": api_key,
"region": "default",
"base_url": base_url or "",
}
instance_res = client.request_json("POST", f"/providers/{llm_factory}/instances", use_api_base=True,
auth_kind="login", json_body=instance_payload)
instance_msg = instance_res.get("message", "")
if instance_res.get("code") != 0 and "already exist" not in instance_msg.lower():
raise AuthError(f"Failed to add instance: {instance_msg}")
def get_tenant_info(client: HttpClient) -> Dict[str, Any]:
res = client.request_json("GET", "/users/me/models", use_api_base=True, auth_kind="login")
def get_default_models(client: HttpClient) -> Dict[str, Any]:
"""List tenant default models."""
res = client.request_json("GET", "/models/default", use_api_base=True, auth_kind="login")
if res.get("code") != 0:
raise AuthError(f"Failed to get tenant info: {res.get('message')}")
raise AuthError(f"Failed to get default models: {res.get('message')}")
return res.get("data", {})
def set_tenant_info(client: HttpClient, payload: Dict[str, Any]) -> None:
res = client.request_json("PATCH", "/users/me/models", use_api_base=True, auth_kind="login", json_body=payload)
def set_default_model(
client: HttpClient,
model_provider: str,
model_instance: str,
model_name: str,
model_type: str,
) -> None:
"""Set a tenant default model via PATCH /models/default."""
payload = {
"model_provider": model_provider,
"model_instance": model_instance,
"model_name": model_name,
"model_type": model_type,
}
res = client.request_json("PATCH", "/models/default", use_api_base=True, auth_kind="login", json_body=payload)
if res.get("code") != 0:
raise AuthError(f"Failed to set tenant info: {res.get('message')}")
raise AuthError(f"Failed to set default model: {res.get('message')}")

View File

@@ -199,6 +199,50 @@ def _retrieval_worker(
return run_retrieval_request(client, payload)
def _parse_model_id(model_id: str):
"""Parse model_id into (model_name, model_instance, model_provider).
Accepted formats:
- model_name@instance@provider -> (model_name, instance, provider)
- model_name@provider -> (model_name, "default", provider)
- model_name -> (model_name, "default", "")
"""
parts = model_id.split("@")
if len(parts) >= 3:
return parts[0], parts[1], parts[-1]
elif len(parts) == 2:
return parts[0], "default", parts[1]
else:
return model_id, "default", ""
# Translation table used by _set_default_models_from_args.
# Keys are the legacy identifiers paired with each tenant_*_id CLI argument
# ("llm_id", "embd_id", ...); values are the model_type strings passed to
# auth.set_default_model.
_MODEL_TYPE_MAP = {
"llm_id": "chat",
"embd_id": "embedding",
"img2txt_id": "vision",
"asr_id": "asr",
"tts_id": "tts",
}
def _set_default_models_from_args(client: HttpClient, args: argparse.Namespace) -> None:
    """Apply the tenant default models requested through the CLI arguments.

    Every non-empty tenant_*_id argument is parsed with _parse_model_id and
    forwarded to auth.set_default_model together with its mapped model type;
    arguments that were left empty are skipped.
    """
    requested = {
        "llm_id": args.tenant_llm_id,
        "embd_id": args.tenant_embd_id,
        "img2txt_id": args.tenant_img2txt_id,
        "asr_id": args.tenant_asr_id,
        "tts_id": args.tenant_tts_id,
    }
    for type_key, model_id in requested.items():
        if model_id:
            name, instance, provider = _parse_model_id(model_id)
            api_model_type = _MODEL_TYPE_MAP[type_key]
            auth.set_default_model(client, provider, instance, name, api_model_type)
def _ensure_auth(client: HttpClient, args: argparse.Namespace) -> None:
if args.api_key:
client.api_key = args.api_key
@@ -232,19 +276,7 @@ def _ensure_auth(client: HttpClient, args: argparse.Namespace) -> None:
if args.set_tenant_info:
if not args.tenant_llm_id or not args.tenant_embd_id:
raise AuthError("Missing --tenant-llm-id or --tenant-embd-id for tenant setup")
tenant = auth.get_tenant_info(client)
tenant_id = tenant.get("tenant_id")
if not tenant_id:
raise AuthError("Tenant info missing tenant_id")
payload = {
"tenant_id": tenant_id,
"llm_id": args.tenant_llm_id,
"embd_id": args.tenant_embd_id,
"img2txt_id": args.tenant_img2txt_id or "",
"asr_id": args.tenant_asr_id or "",
"tts_id": args.tenant_tts_id,
}
auth.set_tenant_info(client, payload)
_set_default_models_from_args(client, args)
api_key = auth.create_api_token(client, login_token, args.token_name)
client.api_key = api_key

View File

@@ -170,19 +170,6 @@ def token(auth):
return res["data"].get("token")
def get_my_llms(auth, name):
# todo deprecated
url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
if name in res.get("data"):
return True
return False
def get_added_models(auth, factory_name):
url = HOST_ADDRESS + "/api/v1/models"
authorization = {"Authorization": auth}
@@ -202,49 +189,6 @@ def get_added_models(auth, factory_name):
return False
def get_tenant_llm_added(auth, factory_name, model_name, model_type="rerank"):
"""
Check whether a specific (factory, model_name, model_type) tenant_llm row exists.
Legacy /v1/llm/my_llms response shape:
{
"ZHIPU-AI": {"tags": ..., "llm": [{"name": ..., "type": ...}, ...]},
"SILICONFLOW": {"tags": ..., "llm": [{"name": ..., "type": ...}, ...]},
}
so we navigate by factory key first, then look through its llm list.
"""
url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
return False
data = res.get("data") or {}
factory_data = data.get(factory_name) or {}
for m in factory_data.get("llm", []) or []:
if m.get("name") != model_name:
continue
if model_type is None or m.get("type") == model_type:
return True
return False
def add_models(auth):
# todo deprecated
url = HOST_ADDRESS + f"/{VERSION}/llm/set_api_key"
authorization = {"Authorization": auth}
models_info = {
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
}
for name, model_info in models_info.items():
if not get_my_llms(auth, name):
response = requests.post(url=url, headers=authorization, json=model_info)
res = response.json()
if res.get("code") != 0:
pytest.exit(f"Critical error in add_models: {res.get('message')}")
def add_model_instance(auth):
add_provider_api = HOST_ADDRESS + "/api/v1/providers"
authorization = {"Authorization": auth}
@@ -277,34 +221,34 @@ def add_model_instance(auth):
else:
pytest.exit(f"Critical error in add model provider: {msg}")
# Register both "CI" (used by glm-4-flash@CI@ZHIPU-AI in configs.py
# and BAAI/bge-reranker-v2-m3@CI@SILICONFLOW) and "default".
for instance_name in ("CI", "default"):
add_instance_api = HOST_ADDRESS + f"/api/v1/providers/{provider_name}/instances"
add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={
"instance_name": instance_name,
"api_key": api_key,
"region": "default",
"base_url": ""
})
add_instance_res = add_instance_response.json()
if add_instance_res.get("code") != 0:
msg = add_instance_res.get("message", "")
# Instance may already exist with a different API key from a
# prior test run; that's fine — skip instead of failing.
if "Already exist instance" in msg or "already exist" in msg.lower():
print(f"Note: {provider_name}/{instance_name} already exists, skipping")
continue
# Python API blocks creating instances named "default".
# The test_retrieval_parity test handles this by inserting
# "default" directly into the DB for SILICONFLOW.
if "cannot be 'default'" in msg:
print(f"Note: {provider_name}/{instance_name} blocked by API (name reserved), skipping")
continue
pytest.exit(
f"Critical error in add model instance {provider_name}/{instance_name}: "
f"{msg}"
)
# Register "CI" (used by glm-4-flash@CI@ZHIPU-AI in configs.py
# and BAAI/bge-reranker-v2-m3@CI@SILICONFLOW).
instance_name = "CI"
add_instance_api = HOST_ADDRESS + f"/api/v1/providers/{provider_name}/instances"
add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={
"instance_name": instance_name,
"api_key": api_key,
"region": "default",
"base_url": ""
})
add_instance_res = add_instance_response.json()
if add_instance_res.get("code") != 0:
msg = add_instance_res.get("message", "")
# Instance may already exist with a different API key from a
# prior test run; that's fine — skip instead of failing.
if "Already exist instance" in msg or "already exist" in msg.lower():
print(f"Note: {provider_name}/{instance_name} already exists, skipping")
continue
# Python API blocks creating instances named "default".
# The test_retrieval_parity test handles this by inserting
# "default" directly into the DB for SILICONFLOW.
if "cannot be 'default'" in msg:
print(f"Note: {provider_name}/{instance_name} blocked by API (name reserved), skipping")
continue
pytest.exit(
f"Critical error in add model instance {provider_name}/{instance_name}: "
f"{msg}"
)
add_success = get_added_models(auth, provider_name)
if not add_success:
@@ -324,52 +268,6 @@ def add_model_instance(auth):
pytest.exit(f"Critical error in check added model: {provider_name} add model failed")
def add_siliconflow_rerank_llm(auth):
"""
Register the BAAI/bge-reranker-v2-m3 rerank model under factory=SILICONFLOW / instance=CI.
This is the model referenced as `BAAI/bge-reranker-v2-m3@CI@SILICONFLOW` in
test_retrieval_parity.py. The /v1/llm/add_llm endpoint validates the key by
issuing a real rerank request, so the call requires network access to SiliconFlow
and a valid SILICONFLOW_API_KEY.
"""
factory = "SILICONFLOW"
model_name = "BAAI/bge-reranker-v2-m3"
if get_tenant_llm_added(auth, factory, model_name, "rerank"):
return
url = HOST_ADDRESS + f"/{VERSION}/llm/add_llm"
authorization = {"Authorization": auth}
payload = {
"llm_factory": factory,
"llm_name": model_name,
"model_type": "rerank",
"api_key": SILICONFLOW_API_KEY,
"api_base": "",
}
response = requests.post(url=url, headers=authorization, json=payload)
res = response.json()
if res.get("code") != 0:
pytest.exit(
f"Critical error adding {factory} rerank model {model_name}: "
f"code={res.get('code')} message={res.get('message')} data={res.get('data')}"
)
if not get_tenant_llm_added(auth, factory, model_name, "rerank"):
pytest.exit(f"Failed to confirm {factory}/{model_name} rerank row was added")
def get_tenant_info(auth):
# todo deprecated
url = HOST_ADDRESS + f"/api/{VERSION}/users/me/models"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
return res["data"].get("tenant_id")
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info(auth):
if not get_added_models(auth, "ZHIPU-AI") or not get_added_models(auth, "SILICONFLOW"):
@@ -416,24 +314,20 @@ def set_tenant_info(auth):
f"WARNING: failed to set default embedding LLM via {url}: "
f"{embd_res.get('message')!r}. Continuing."
)
@pytest.fixture(scope="session", autouse=True)
def set_tenant_siliconflow_rerank(auth):
"""
Ensure the SiliconFlow BAAI/bge-reranker-v2-m3 rerank model is registered
for the test tenant. Used by test_retrieval_parity.py as
`BAAI/bge-reranker-v2-m3@CI@SILICONFLOW`.
Runs after `set_tenant_info` so the SILICONFLOW provider+CI instance
already exist when the /add_llm call is made.
If /add_llm is blocked (e.g. factory not in allowed list), the rerank
model config is resolved from FACTORY_LLM_INFOS at search time, so the
test can still proceed.
"""
try:
add_siliconflow_rerank_llm(auth)
except Exception as e:
print(f"Note: Could not register SILICONFLOW rerank model via /add_llm: {e}")
print("The model config will be resolved from FACTORY_LLM_INFOS at runtime.")
# set rerank model
set_default_rerank_response = requests.patch(
url=url,
headers=authorization,
json={
"model_provider": "SILICONFLOW",
"model_instance": "CI",
"model_type": "rerank",
"model_name": "BAAI/bge-reranker-v2-m3"
}
)
rerank_res = set_default_rerank_response.json()
if rerank_res.get("code") != 0:
print(
f"WARNING: failed to set default rerank LLM via {url}: "
f"{rerank_res.get('message')!r}. Continuing."
)