mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Fix: replace tenant_llm apis (#16131)
Replace tenant_llm apis with provider-instance apis.
This commit is contained in:
@@ -727,45 +727,133 @@ class RAGFlowClient:
|
||||
def create_model_provider(self, command):
|
||||
if self.server_type != "user":
|
||||
print("This command is only allowed in USER mode")
|
||||
llm_factory: str = command["provider_name"]
|
||||
return
|
||||
provider_name: str = command["provider_name"]
|
||||
api_key: str = command["provider_key"]
|
||||
payload = {"api_key": api_key, "llm_factory": llm_factory}
|
||||
response = self.http_client.request("POST", "/llm/set_api_key", json_body=payload, use_api_base=False,
|
||||
auth_kind="web")
|
||||
res_json = response.json()
|
||||
if response.status_code == 200 and res_json["code"] == 0:
|
||||
print(f"Success to add model provider {llm_factory}")
|
||||
|
||||
# Step 1: Add provider
|
||||
provider_payload = {"provider_name": provider_name}
|
||||
provider_response = self.http_client.request("PUT", "/providers", json_body=provider_payload,
|
||||
use_api_base=True, auth_kind="web")
|
||||
provider_res = provider_response.json()
|
||||
if provider_response.status_code == 200 and provider_res.get("code") == 0:
|
||||
print(f"Success to add provider {provider_name}")
|
||||
else:
|
||||
print(f"Fail to add model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
msg = provider_res.get("message", "")
|
||||
if "duplicated" in msg.lower() or "already exist" in msg.lower():
|
||||
print(f"Note: provider {provider_name} already exists, continuing to add instance")
|
||||
else:
|
||||
print(f"Fail to add provider {provider_name}, code: {provider_res.get('code')}, message: {msg}")
|
||||
return
|
||||
|
||||
# Step 2: Add instance
|
||||
instance_payload = {
|
||||
"instance_name": "default",
|
||||
"api_key": api_key,
|
||||
"region": "default",
|
||||
"base_url": ""
|
||||
}
|
||||
instance_response = self.http_client.request("POST", f"/providers/{provider_name}/instances",
|
||||
json_body=instance_payload, use_api_base=True,
|
||||
auth_kind="web")
|
||||
instance_res = instance_response.json()
|
||||
if instance_response.status_code == 200 and instance_res.get("code") == 0:
|
||||
print(f"Success to add instance for provider {provider_name}")
|
||||
else:
|
||||
msg = instance_res.get("message", "")
|
||||
if "already exist" in msg.lower():
|
||||
print(f"Note: instance for provider {provider_name} already exists, skipping")
|
||||
else:
|
||||
print(f"Fail to add instance for provider {provider_name}, code: {instance_res.get('code')}, message: {msg}")
|
||||
|
||||
def drop_model_provider(self, command):
|
||||
if self.server_type != "user":
|
||||
print("This command is only allowed in USER mode")
|
||||
llm_factory: str = command["provider_name"]
|
||||
payload = {"llm_factory": llm_factory}
|
||||
response = self.http_client.request("POST", "/llm/delete_factory", json_body=payload, use_api_base=False,
|
||||
return
|
||||
provider_name: str = command["provider_name"]
|
||||
response = self.http_client.request("DELETE", f"/providers/{provider_name}", use_api_base=True,
|
||||
auth_kind="web")
|
||||
res_json = response.json()
|
||||
if response.status_code == 200 and res_json["code"] == 0:
|
||||
print(f"Success to drop model provider {llm_factory}")
|
||||
if response.status_code == 200 and res_json.get("code") == 0:
|
||||
print(f"Success to drop model provider {provider_name}")
|
||||
else:
|
||||
print(
|
||||
f"Fail to drop model provider {llm_factory}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
print(f"Fail to drop model provider {provider_name}, code: {res_json.get('code')}, message: {res_json.get('message')}")
|
||||
|
||||
# Mapping from legacy model_type keys to API model_type values
|
||||
_MODEL_TYPE_MAP = {
|
||||
"llm_id": "chat",
|
||||
"embd_id": "embedding",
|
||||
"img2txt_id": "vision",
|
||||
"reranker_id": "rerank",
|
||||
"asr_id": "asr",
|
||||
"tts_id": "tts",
|
||||
}
|
||||
|
||||
def set_default_model(self, command):
|
||||
if self.server_type != "user":
|
||||
print("This command is only allowed in USER mode")
|
||||
return
|
||||
|
||||
model_type: str = command["model_type"]
|
||||
model_type_key: str = command["model_type"]
|
||||
model_id: str = command["model_id"]
|
||||
self._set_default_models(model_type, model_id)
|
||||
|
||||
model_type = self._MODEL_TYPE_MAP.get(model_type_key)
|
||||
if model_type is None:
|
||||
print(f"Unknown model type: {model_type_key}")
|
||||
return
|
||||
|
||||
model_name, model_instance, model_provider = self._parse_model_id(model_id)
|
||||
|
||||
payload = {
|
||||
"model_provider": model_provider,
|
||||
"model_instance": model_instance,
|
||||
"model_type": model_type,
|
||||
"model_name": model_name,
|
||||
}
|
||||
response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True,
|
||||
auth_kind="web")
|
||||
res_json = response.json()
|
||||
if response.status_code == 200 and res_json.get("code") == 0:
|
||||
print(f"Success to set default {model_type} to {model_id}")
|
||||
else:
|
||||
print(f"Fail to set default {model_type}, code: {res_json.get('code')}, message: {res_json.get('message')}")
|
||||
|
||||
def reset_default_model(self, command):
|
||||
if self.server_type != "user":
|
||||
print("This command is only allowed in USER mode")
|
||||
return
|
||||
|
||||
model_type: str = command["model_type"]
|
||||
self._set_default_models(model_type, "")
|
||||
model_type_key: str = command["model_type"]
|
||||
model_type = self._MODEL_TYPE_MAP.get(model_type_key)
|
||||
if model_type is None:
|
||||
print(f"Unknown model type: {model_type_key}")
|
||||
return
|
||||
|
||||
payload = {"model_type": model_type}
|
||||
response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True,
|
||||
auth_kind="web")
|
||||
res_json = response.json()
|
||||
if response.status_code == 200 and res_json.get("code") == 0:
|
||||
print(f"Success to reset default {model_type}")
|
||||
else:
|
||||
print(f"Fail to reset default {model_type}, code: {res_json.get('code')}, message: {res_json.get('message')}")
|
||||
|
||||
@staticmethod
|
||||
def _parse_model_id(model_id: str):
|
||||
"""Parse model_id into (model_name, model_instance, model_provider).
|
||||
|
||||
Accepted formats:
|
||||
- model_name@instance@provider -> (model_name, instance, provider)
|
||||
- model_name@provider -> (model_name, "default", provider)
|
||||
- model_name -> (model_name, "default", "")
|
||||
"""
|
||||
parts = model_id.split("@")
|
||||
if len(parts) >= 3:
|
||||
return parts[0], parts[1], parts[-1]
|
||||
elif len(parts) == 2:
|
||||
return parts[0], "default", parts[1]
|
||||
else:
|
||||
return model_id, "default", ""
|
||||
|
||||
def list_user_datasets(self, command):
|
||||
if self.server_type != "user":
|
||||
@@ -1825,41 +1913,6 @@ class RAGFlowClient:
|
||||
print(f"Fail to list chats, code: {res_json['code']}, message: {res_json['message']}")
|
||||
return None
|
||||
|
||||
def _get_default_models(self):
|
||||
response = self.http_client.request("GET", "/users/me/models", use_api_base=True, auth_kind="web")
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
if res_json["code"] == 0:
|
||||
return res_json["data"]
|
||||
else:
|
||||
print(f"Fail to list user default models, code: {res_json['code']}, message: {res_json['message']}")
|
||||
return None
|
||||
else:
|
||||
print(f"Fail to list user default models, HTTP code: {response.status_code}, message: {res_json}")
|
||||
return None
|
||||
|
||||
def _set_default_models(self, model_type, model_id):
|
||||
current_payload = self._get_default_models()
|
||||
if current_payload is None:
|
||||
return
|
||||
else:
|
||||
current_payload.update({model_type: model_id})
|
||||
payload = {
|
||||
"tenant_id": current_payload["tenant_id"],
|
||||
"llm_id": current_payload["llm_id"],
|
||||
"embd_id": current_payload["embd_id"],
|
||||
"img2txt_id": current_payload["img2txt_id"],
|
||||
"asr_id": current_payload["asr_id"],
|
||||
"tts_id": current_payload["tts_id"],
|
||||
}
|
||||
response = self.http_client.request("PATCH", "/users/me/models", json_body=payload, use_api_base=True,
|
||||
auth_kind="web")
|
||||
res_json = response.json()
|
||||
if response.status_code == 200 and res_json["code"] == 0:
|
||||
print(f"Success to set default llm to {model_type}")
|
||||
else:
|
||||
print(f"Fail to set default llm to {model_type}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _format_service_detail_table(self, data):
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
|
||||
@@ -19,18 +19,16 @@ import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
|
||||
from peewee import IntegrityError
|
||||
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
||||
from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.services import UserService
|
||||
from api.db.services.canvas_service import CanvasTemplateService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.db.services.system_settings_service import SystemSettingsService
|
||||
from api.db.template_utils import normalize_canvas_template_categories
|
||||
@@ -76,8 +74,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
|
||||
"role": role
|
||||
}
|
||||
|
||||
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||
|
||||
try:
|
||||
if not UserService.save(**user_info):
|
||||
logging.error("can't init admin.")
|
||||
@@ -87,7 +83,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
|
||||
return
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
logging.info(
|
||||
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")
|
||||
|
||||
@@ -106,54 +101,6 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
|
||||
logging.error("'{}' doesn't work!".format(tenant["embd_id"]))
|
||||
|
||||
|
||||
def init_llm_factory():
|
||||
# todo deprecated
|
||||
LLMFactoriesService.filter_delete([1 == 1])
|
||||
factory_llm_infos = settings.FACTORY_LLM_INFOS
|
||||
for factory_llm_info in factory_llm_infos:
|
||||
info = deepcopy(factory_llm_info)
|
||||
llm_infos = info.pop("llm")
|
||||
try:
|
||||
LLMFactoriesService.save(**info)
|
||||
except Exception:
|
||||
pass
|
||||
LLMService.filter_delete([LLM.fid == factory_llm_info["name"]])
|
||||
for llm_info in llm_infos:
|
||||
llm_info["fid"] = factory_llm_info["name"]
|
||||
try:
|
||||
LLMService.save(**llm_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
LLMFactoriesService.filter_delete([(LLMFactories.name == "Local") | (LLMFactories.name == "novita.ai")])
|
||||
LLMService.filter_delete([LLM.fid == "Local"])
|
||||
LLMService.filter_delete([LLM.llm_name == "qwen-vl-max"])
|
||||
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
|
||||
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
|
||||
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
|
||||
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
|
||||
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
||||
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
|
||||
TenantService.filter_update([1 == 1], {
|
||||
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"})
|
||||
## insert openai two embedding models to the current openai user.
|
||||
# print("Start to insert 2 OpenAI embedding models...")
|
||||
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
|
||||
for tid in tenant_ids:
|
||||
for row in TenantLLMService.query(llm_factory="OpenAI", tenant_id=tid):
|
||||
row = row.to_dict()
|
||||
row["model_type"] = LLMType.EMBEDDING.value
|
||||
row["llm_name"] = "text-embedding-3-small"
|
||||
row["used_tokens"] = 0
|
||||
try:
|
||||
TenantLLMService.save(**row)
|
||||
row = deepcopy(row)
|
||||
row["llm_name"] = "text-embedding-3-large"
|
||||
TenantLLMService.save(**row)
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
|
||||
def update_document_number_in_init():
|
||||
doc_count = DocumentService.get_all_kb_doc_count()
|
||||
for kb_id in KnowledgebaseService.get_all_ids():
|
||||
|
||||
@@ -27,7 +27,6 @@ from langfuse import propagate_attributes
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant
|
||||
from common.constants import LLMType
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
@@ -35,55 +34,6 @@ class LLMService(CommonService):
|
||||
model = LLM
|
||||
|
||||
|
||||
def get_init_tenant_llm(user_id):
|
||||
from common import settings
|
||||
|
||||
tenant_llm = []
|
||||
|
||||
model_configs = {
|
||||
LLMType.CHAT: settings.CHAT_CFG,
|
||||
LLMType.EMBEDDING: settings.EMBEDDING_CFG,
|
||||
LLMType.SPEECH2TEXT: settings.ASR_CFG,
|
||||
LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
|
||||
LLMType.RERANK: settings.RERANK_CFG,
|
||||
}
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
settings.CHAT_CFG,
|
||||
settings.EMBEDDING_CFG,
|
||||
settings.ASR_CFG,
|
||||
settings.IMAGE2TEXT_CFG,
|
||||
settings.RERANK_CFG,
|
||||
]:
|
||||
factory_name = factory_config["factory"]
|
||||
if factory_name not in seen:
|
||||
seen.add(factory_name)
|
||||
factory_configs.append(factory_config)
|
||||
|
||||
for factory_config in factory_configs:
|
||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
|
||||
"api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
|
||||
unique = {}
|
||||
for item in tenant_llm:
|
||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||
if key not in unique:
|
||||
unique[key] = item
|
||||
return list(unique.values())
|
||||
|
||||
|
||||
class LLMBundle(LLM4Tenant):
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
super().__init__(tenant_id, model_config, lang, **kwargs)
|
||||
|
||||
@@ -91,19 +91,6 @@ def get_email():
|
||||
return EMAIL
|
||||
|
||||
|
||||
def get_my_llms(auth, name):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + "/v1/llm/my_llms"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
if name in res.get("data"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_added_models(auth, factory_name):
|
||||
url = HOST_ADDRESS + "/api/v1/models"
|
||||
authorization = {"Authorization": auth}
|
||||
@@ -116,21 +103,6 @@ def get_added_models(auth, factory_name):
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_models(auth):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + "/v1/llm/set_api_key"
|
||||
authorization = {"Authorization": auth}
|
||||
models_info = {
|
||||
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
|
||||
}
|
||||
|
||||
for name, model_info in models_info.items():
|
||||
if not get_my_llms(auth, name):
|
||||
response = requests.post(url=url, headers=authorization, json=model_info)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
pytest.exit(f"Critical error in add_models: {res.get('message')}")
|
||||
|
||||
|
||||
def add_model_instance(auth):
|
||||
add_provider_api = HOST_ADDRESS + "/api/v1/providers"
|
||||
@@ -156,17 +128,6 @@ def add_model_instance(auth):
|
||||
pytest.exit("Critical error in check added model: add model failed")
|
||||
|
||||
|
||||
def get_tenant_info(auth):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + "/api/v1/users/me/models"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
return res["data"].get("tenant_id")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info(get_auth):
|
||||
auth = get_auth
|
||||
|
||||
@@ -59,7 +59,7 @@ Auth and bootstrap flags (used when --api-key is not provided)
|
||||
--token-name
|
||||
Optional API token name for /system/new_token.
|
||||
--bootstrap-llm
|
||||
Ensure LLM factory API key is configured via /llm/set_api_key.
|
||||
Ensure LLM factory API key is configured via /providers + /providers/{name}/instances.
|
||||
--llm-factory
|
||||
LLM factory name for bootstrap.
|
||||
Env: RAGFLOW_LLM_FACTORY
|
||||
@@ -70,7 +70,7 @@ Auth and bootstrap flags (used when --api-key is not provided)
|
||||
Optional LLM API base URL.
|
||||
Env: RAGFLOW_LLM_API_BASE
|
||||
--set-tenant-info
|
||||
Set tenant defaults via /users/me/models.
|
||||
Set tenant defaults via /models/default.
|
||||
--tenant-llm-id
|
||||
Tenant chat model ID.
|
||||
Env: RAGFLOW_TENANT_LLM_ID
|
||||
|
||||
@@ -55,10 +55,12 @@ def create_api_token(client: HttpClient, login_token: str, token_name: Optional[
|
||||
|
||||
|
||||
def get_my_llms(client: HttpClient) -> Dict[str, Any]:
|
||||
res = client.request_json("GET", "/llm/my_llms", use_api_base=False, auth_kind="login")
|
||||
"""List tenant-configured providers. Returns a dict keyed by provider name."""
|
||||
res = client.request_json("GET", "/providers", use_api_base=True, auth_kind="login")
|
||||
if res.get("code") != 0:
|
||||
raise AuthError(f"Failed to list LLMs: {res.get('message')}")
|
||||
return res.get("data", {})
|
||||
raise AuthError(f"Failed to list providers: {res.get('message')}")
|
||||
providers = res.get("data", [])
|
||||
return {p.get("name", ""): p for p in providers} if isinstance(providers, list) else {}
|
||||
|
||||
|
||||
def set_llm_api_key(
|
||||
@@ -67,22 +69,48 @@ def set_llm_api_key(
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
) -> None:
|
||||
payload = {"llm_factory": llm_factory, "api_key": api_key}
|
||||
if base_url:
|
||||
payload["base_url"] = base_url
|
||||
res = client.request_json("POST", "/llm/set_api_key", use_api_base=False, auth_kind="login", json_body=payload)
|
||||
if res.get("code") != 0:
|
||||
raise AuthError(f"Failed to set LLM API key: {res.get('message')}")
|
||||
"""Add a provider (PUT /providers) and create a default instance (POST /providers/{name}/instances)."""
|
||||
provider_payload = {"provider_name": llm_factory}
|
||||
provider_res = client.request_json("PUT", "/providers", use_api_base=True, auth_kind="login", json_body=provider_payload)
|
||||
provider_msg = provider_res.get("message", "")
|
||||
if provider_res.get("code") != 0 and "duplicated" not in provider_msg.lower() and "already exist" not in provider_msg.lower():
|
||||
raise AuthError(f"Failed to add provider: {provider_msg}")
|
||||
|
||||
instance_payload = {
|
||||
"instance_name": "default",
|
||||
"api_key": api_key,
|
||||
"region": "default",
|
||||
"base_url": base_url or "",
|
||||
}
|
||||
instance_res = client.request_json("POST", f"/providers/{llm_factory}/instances", use_api_base=True,
|
||||
auth_kind="login", json_body=instance_payload)
|
||||
instance_msg = instance_res.get("message", "")
|
||||
if instance_res.get("code") != 0 and "already exist" not in instance_msg.lower():
|
||||
raise AuthError(f"Failed to add instance: {instance_msg}")
|
||||
|
||||
|
||||
def get_tenant_info(client: HttpClient) -> Dict[str, Any]:
|
||||
res = client.request_json("GET", "/users/me/models", use_api_base=True, auth_kind="login")
|
||||
def get_default_models(client: HttpClient) -> Dict[str, Any]:
|
||||
"""List tenant default models."""
|
||||
res = client.request_json("GET", "/models/default", use_api_base=True, auth_kind="login")
|
||||
if res.get("code") != 0:
|
||||
raise AuthError(f"Failed to get tenant info: {res.get('message')}")
|
||||
raise AuthError(f"Failed to get default models: {res.get('message')}")
|
||||
return res.get("data", {})
|
||||
|
||||
|
||||
def set_tenant_info(client: HttpClient, payload: Dict[str, Any]) -> None:
|
||||
res = client.request_json("PATCH", "/users/me/models", use_api_base=True, auth_kind="login", json_body=payload)
|
||||
def set_default_model(
|
||||
client: HttpClient,
|
||||
model_provider: str,
|
||||
model_instance: str,
|
||||
model_name: str,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
"""Set a tenant default model via PATCH /models/default."""
|
||||
payload = {
|
||||
"model_provider": model_provider,
|
||||
"model_instance": model_instance,
|
||||
"model_name": model_name,
|
||||
"model_type": model_type,
|
||||
}
|
||||
res = client.request_json("PATCH", "/models/default", use_api_base=True, auth_kind="login", json_body=payload)
|
||||
if res.get("code") != 0:
|
||||
raise AuthError(f"Failed to set tenant info: {res.get('message')}")
|
||||
raise AuthError(f"Failed to set default model: {res.get('message')}")
|
||||
|
||||
@@ -199,6 +199,50 @@ def _retrieval_worker(
|
||||
return run_retrieval_request(client, payload)
|
||||
|
||||
|
||||
def _parse_model_id(model_id: str):
|
||||
"""Parse model_id into (model_name, model_instance, model_provider).
|
||||
|
||||
Accepted formats:
|
||||
- model_name@instance@provider -> (model_name, instance, provider)
|
||||
- model_name@provider -> (model_name, "default", provider)
|
||||
- model_name -> (model_name, "default", "")
|
||||
"""
|
||||
parts = model_id.split("@")
|
||||
if len(parts) >= 3:
|
||||
return parts[0], parts[1], parts[-1]
|
||||
elif len(parts) == 2:
|
||||
return parts[0], "default", parts[1]
|
||||
else:
|
||||
return model_id, "default", ""
|
||||
|
||||
|
||||
# Mapping from legacy CLI arg keys to API model_type values
|
||||
_MODEL_TYPE_MAP = {
|
||||
"llm_id": "chat",
|
||||
"embd_id": "embedding",
|
||||
"img2txt_id": "vision",
|
||||
"asr_id": "asr",
|
||||
"tts_id": "tts",
|
||||
}
|
||||
|
||||
|
||||
def _set_default_models_from_args(client: HttpClient, args: argparse.Namespace) -> None:
|
||||
"""Set tenant default models from CLI arguments using the new PATCH /models/default API."""
|
||||
model_specs = [
|
||||
(args.tenant_llm_id, "llm_id"),
|
||||
(args.tenant_embd_id, "embd_id"),
|
||||
(args.tenant_img2txt_id, "img2txt_id"),
|
||||
(args.tenant_asr_id, "asr_id"),
|
||||
(args.tenant_tts_id, "tts_id"),
|
||||
]
|
||||
for model_id, type_key in model_specs:
|
||||
if not model_id:
|
||||
continue
|
||||
model_name, model_instance, model_provider = _parse_model_id(model_id)
|
||||
model_type = _MODEL_TYPE_MAP[type_key]
|
||||
auth.set_default_model(client, model_provider, model_instance, model_name, model_type)
|
||||
|
||||
|
||||
def _ensure_auth(client: HttpClient, args: argparse.Namespace) -> None:
|
||||
if args.api_key:
|
||||
client.api_key = args.api_key
|
||||
@@ -232,19 +276,7 @@ def _ensure_auth(client: HttpClient, args: argparse.Namespace) -> None:
|
||||
if args.set_tenant_info:
|
||||
if not args.tenant_llm_id or not args.tenant_embd_id:
|
||||
raise AuthError("Missing --tenant-llm-id or --tenant-embd-id for tenant setup")
|
||||
tenant = auth.get_tenant_info(client)
|
||||
tenant_id = tenant.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise AuthError("Tenant info missing tenant_id")
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"llm_id": args.tenant_llm_id,
|
||||
"embd_id": args.tenant_embd_id,
|
||||
"img2txt_id": args.tenant_img2txt_id or "",
|
||||
"asr_id": args.tenant_asr_id or "",
|
||||
"tts_id": args.tenant_tts_id,
|
||||
}
|
||||
auth.set_tenant_info(client, payload)
|
||||
_set_default_models_from_args(client, args)
|
||||
|
||||
api_key = auth.create_api_token(client, login_token, args.token_name)
|
||||
client.api_key = api_key
|
||||
|
||||
@@ -170,19 +170,6 @@ def token(auth):
|
||||
return res["data"].get("token")
|
||||
|
||||
|
||||
def get_my_llms(auth, name):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
if name in res.get("data"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_added_models(auth, factory_name):
|
||||
url = HOST_ADDRESS + "/api/v1/models"
|
||||
authorization = {"Authorization": auth}
|
||||
@@ -202,49 +189,6 @@ def get_added_models(auth, factory_name):
|
||||
return False
|
||||
|
||||
|
||||
def get_tenant_llm_added(auth, factory_name, model_name, model_type="rerank"):
|
||||
"""
|
||||
Check whether a specific (factory, model_name, model_type) tenant_llm row exists.
|
||||
|
||||
Legacy /v1/llm/my_llms response shape:
|
||||
{
|
||||
"ZHIPU-AI": {"tags": ..., "llm": [{"name": ..., "type": ...}, ...]},
|
||||
"SILICONFLOW": {"tags": ..., "llm": [{"name": ..., "type": ...}, ...]},
|
||||
}
|
||||
so we navigate by factory key first, then look through its llm list.
|
||||
"""
|
||||
url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
return False
|
||||
data = res.get("data") or {}
|
||||
factory_data = data.get(factory_name) or {}
|
||||
for m in factory_data.get("llm", []) or []:
|
||||
if m.get("name") != model_name:
|
||||
continue
|
||||
if model_type is None or m.get("type") == model_type:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def add_models(auth):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + f"/{VERSION}/llm/set_api_key"
|
||||
authorization = {"Authorization": auth}
|
||||
models_info = {
|
||||
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
|
||||
}
|
||||
|
||||
for name, model_info in models_info.items():
|
||||
if not get_my_llms(auth, name):
|
||||
response = requests.post(url=url, headers=authorization, json=model_info)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
pytest.exit(f"Critical error in add_models: {res.get('message')}")
|
||||
|
||||
|
||||
def add_model_instance(auth):
|
||||
add_provider_api = HOST_ADDRESS + "/api/v1/providers"
|
||||
authorization = {"Authorization": auth}
|
||||
@@ -277,34 +221,34 @@ def add_model_instance(auth):
|
||||
else:
|
||||
pytest.exit(f"Critical error in add model provider: {msg}")
|
||||
|
||||
# Register both "CI" (used by glm-4-flash@CI@ZHIPU-AI in configs.py
|
||||
# and BAAI/bge-reranker-v2-m3@CI@SILICONFLOW) and "default".
|
||||
for instance_name in ("CI", "default"):
|
||||
add_instance_api = HOST_ADDRESS + f"/api/v1/providers/{provider_name}/instances"
|
||||
add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={
|
||||
"instance_name": instance_name,
|
||||
"api_key": api_key,
|
||||
"region": "default",
|
||||
"base_url": ""
|
||||
})
|
||||
add_instance_res = add_instance_response.json()
|
||||
if add_instance_res.get("code") != 0:
|
||||
msg = add_instance_res.get("message", "")
|
||||
# Instance may already exist with a different API key from a
|
||||
# prior test run; that's fine — skip instead of failing.
|
||||
if "Already exist instance" in msg or "already exist" in msg.lower():
|
||||
print(f"Note: {provider_name}/{instance_name} already exists, skipping")
|
||||
continue
|
||||
# Python API blocks creating instances named "default".
|
||||
# The test_retrieval_parity test handles this by inserting
|
||||
# "default" directly into the DB for SILICONFLOW.
|
||||
if "cannot be 'default'" in msg:
|
||||
print(f"Note: {provider_name}/{instance_name} blocked by API (name reserved), skipping")
|
||||
continue
|
||||
pytest.exit(
|
||||
f"Critical error in add model instance {provider_name}/{instance_name}: "
|
||||
f"{msg}"
|
||||
)
|
||||
# Register "CI" (used by glm-4-flash@CI@ZHIPU-AI in configs.py
|
||||
# and BAAI/bge-reranker-v2-m3@CI@SILICONFLOW).
|
||||
instance_name = "CI"
|
||||
add_instance_api = HOST_ADDRESS + f"/api/v1/providers/{provider_name}/instances"
|
||||
add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={
|
||||
"instance_name": instance_name,
|
||||
"api_key": api_key,
|
||||
"region": "default",
|
||||
"base_url": ""
|
||||
})
|
||||
add_instance_res = add_instance_response.json()
|
||||
if add_instance_res.get("code") != 0:
|
||||
msg = add_instance_res.get("message", "")
|
||||
# Instance may already exist with a different API key from a
|
||||
# prior test run; that's fine — skip instead of failing.
|
||||
if "Already exist instance" in msg or "already exist" in msg.lower():
|
||||
print(f"Note: {provider_name}/{instance_name} already exists, skipping")
|
||||
continue
|
||||
# Python API blocks creating instances named "default".
|
||||
# The test_retrieval_parity test handles this by inserting
|
||||
# "default" directly into the DB for SILICONFLOW.
|
||||
if "cannot be 'default'" in msg:
|
||||
print(f"Note: {provider_name}/{instance_name} blocked by API (name reserved), skipping")
|
||||
continue
|
||||
pytest.exit(
|
||||
f"Critical error in add model instance {provider_name}/{instance_name}: "
|
||||
f"{msg}"
|
||||
)
|
||||
|
||||
add_success = get_added_models(auth, provider_name)
|
||||
if not add_success:
|
||||
@@ -324,52 +268,6 @@ def add_model_instance(auth):
|
||||
pytest.exit(f"Critical error in check added model: {provider_name} add model failed")
|
||||
|
||||
|
||||
def add_siliconflow_rerank_llm(auth):
|
||||
"""
|
||||
Register the BAAI/bge-reranker-v2-m3 rerank model under factory=SILICONFLOW / instance=CI.
|
||||
|
||||
This is the model referenced as `BAAI/bge-reranker-v2-m3@CI@SILICONFLOW` in
|
||||
test_retrieval_parity.py. The /v1/llm/add_llm endpoint validates the key by
|
||||
issuing a real rerank request, so the call requires network access to SiliconFlow
|
||||
and a valid SILICONFLOW_API_KEY.
|
||||
"""
|
||||
factory = "SILICONFLOW"
|
||||
model_name = "BAAI/bge-reranker-v2-m3"
|
||||
if get_tenant_llm_added(auth, factory, model_name, "rerank"):
|
||||
return
|
||||
|
||||
url = HOST_ADDRESS + f"/{VERSION}/llm/add_llm"
|
||||
authorization = {"Authorization": auth}
|
||||
payload = {
|
||||
"llm_factory": factory,
|
||||
"llm_name": model_name,
|
||||
"model_type": "rerank",
|
||||
"api_key": SILICONFLOW_API_KEY,
|
||||
"api_base": "",
|
||||
}
|
||||
response = requests.post(url=url, headers=authorization, json=payload)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
pytest.exit(
|
||||
f"Critical error adding {factory} rerank model {model_name}: "
|
||||
f"code={res.get('code')} message={res.get('message')} data={res.get('data')}"
|
||||
)
|
||||
|
||||
if not get_tenant_llm_added(auth, factory, model_name, "rerank"):
|
||||
pytest.exit(f"Failed to confirm {factory}/{model_name} rerank row was added")
|
||||
|
||||
|
||||
def get_tenant_info(auth):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + f"/api/{VERSION}/users/me/models"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
return res["data"].get("tenant_id")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info(auth):
|
||||
if not get_added_models(auth, "ZHIPU-AI") or not get_added_models(auth, "SILICONFLOW"):
|
||||
@@ -416,24 +314,20 @@ def set_tenant_info(auth):
|
||||
f"WARNING: failed to set default embedding LLM via {url}: "
|
||||
f"{embd_res.get('message')!r}. Continuing."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_siliconflow_rerank(auth):
|
||||
"""
|
||||
Ensure the SiliconFlow BAAI/bge-reranker-v2-m3 rerank model is registered
|
||||
for the test tenant. Used by test_retrieval_parity.py as
|
||||
`BAAI/bge-reranker-v2-m3@CI@SILICONFLOW`.
|
||||
|
||||
Runs after `set_tenant_info` so the SILICONFLOW provider+CI instance
|
||||
already exist when the /add_llm call is made.
|
||||
|
||||
If /add_llm is blocked (e.g. factory not in allowed list), the rerank
|
||||
model config is resolved from FACTORY_LLM_INFOS at search time, so the
|
||||
test can still proceed.
|
||||
"""
|
||||
try:
|
||||
add_siliconflow_rerank_llm(auth)
|
||||
except Exception as e:
|
||||
print(f"Note: Could not register SILICONFLOW rerank model via /add_llm: {e}")
|
||||
print("The model config will be resolved from FACTORY_LLM_INFOS at runtime.")
|
||||
# set rerank model
|
||||
set_default_rerank_response = requests.patch(
|
||||
url=url,
|
||||
headers=authorization,
|
||||
json={
|
||||
"model_provider": "SILICONFLOW",
|
||||
"model_instance": "CI",
|
||||
"model_type": "rerank",
|
||||
"model_name": "BAAI/bge-reranker-v2-m3"
|
||||
}
|
||||
)
|
||||
rerank_res = set_default_rerank_response.json()
|
||||
if rerank_res.get("code") != 0:
|
||||
print(
|
||||
f"WARNING: failed to set default rerank LLM via {url}: "
|
||||
f"{rerank_res.get('message')!r}. Continuing."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user