mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
952 lines
41 KiB
Python
952 lines
41 KiB
Python
#
|
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import os
|
|
import json
|
|
import logging
|
|
import asyncio
|
|
|
|
from common.constants import LLMType, ActiveStatusEnum
|
|
from common.misc_utils import get_uuid
|
|
from common.settings import FACTORY_LLM_INFOS
|
|
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, delete_models_by_instance_ids, delete_instances_by_provider_ids
|
|
from api.db.services.tenant_model_provider_service import TenantModelProviderService
|
|
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
|
|
from api.db.services.tenant_model_service import TenantModelService
|
|
from rag.llm import ChatModel, EmbeddingModel, ModelMeta, OcrModel, RerankModel, TTSModel
|
|
|
|
|
|
def _to_int(v, default=500):
|
|
try:
|
|
return int(v)
|
|
except (TypeError, ValueError):
|
|
return default
|
|
|
|
|
|
def _factory_model_types(llm: dict) -> list[str]:
|
|
model_type = llm.get("model_type")
|
|
if isinstance(model_type, list):
|
|
return model_type
|
|
return [model_type] if model_type else []
|
|
|
|
|
|
def _normalize_provider_base_url(provider_name: str, base_url: str | None):
|
|
if provider_name != "VLLM" or not base_url:
|
|
return base_url
|
|
base_url = base_url.strip().rstrip("/")
|
|
if not base_url.endswith("/v1"):
|
|
base_url += "/v1"
|
|
return base_url
|
|
|
|
|
|
|
|
def _factory_llm_name(llm: dict) -> str:
|
|
return llm.get("name") or llm.get("llm_name", "")
|
|
|
|
|
|
# Factories that are never exposed as selectable providers.
_HIDDEN_PROVIDER_NAMES = frozenset({"Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"})
# Document-parsing factories that always advertise an "ocr" capability.
_OCR_PROVIDER_NAMES = frozenset({"MinerU", "PaddleOCR", "OpenDataLoader"})


def _build_provider_entry(factory_info: dict, factory_info_map: dict) -> dict:
    """Build the public description (model types and URLs) of one factory."""
    name = factory_info["name"]
    model_types = sorted({
        model_type
        for llm in factory_info.get("llm") or []
        for model_type in _factory_model_types(llm)
    })
    # Avoid a duplicate "ocr" when the catalogue already declares it.
    if name in _OCR_PROVIDER_NAMES and "ocr" not in model_types:
        model_types.append("ocr")

    urls = {"default": factory_info.get("url", "")}
    if name.lower() == "siliconflow":
        urls["intl"] = factory_info_map.get("siliconflow_intl", {}).get("url", "https://api.siliconflow.com/v1")
    elif name == "Tongyi-Qianwen":
        urls["intl"] = "https://dashscope-intl.aliyuncs.com/compatible-model/v1"
    return {"model_types": model_types, "name": name, "url": urls}


def list_providers(tenant_id: str, all_available: bool = False):
    """
    List providers for a tenant.

    If all_available is True, list every system-wide provider (pool providers).
    Otherwise, list only the providers the tenant has configured. Internal
    factories (Youdao, FastEmbed, BAAI, Builtin, siliconflow_intl) are never listed.
    Results are ordered by descending factory rank, then by name.

    :param tenant_id: tenant ID
    :param all_available: whether to list all available providers
    :return: (success, providers); (False, []) when no factories are configured
    """
    if not FACTORY_LLM_INFOS:
        return False, []

    factory_info_map = {f["name"]: f for f in FACTORY_LLM_INFOS}
    # Negated so that a higher rank sorts first; missing/invalid ranks count as 500.
    factory_rank_mapping = {f["name"]: -_to_int(f.get("rank", "500")) for f in FACTORY_LLM_INFOS}

    if all_available:
        candidates = FACTORY_LLM_INFOS
    else:
        # Names the tenant configured that are unknown to the catalogue are ignored.
        candidates = [
            factory_info_map[name]
            for name in TenantModelProviderService.list_provider_names_by_tenant_id(tenant_id)
            if name in factory_info_map
        ]

    providers = [
        _build_provider_entry(factory_info, factory_info_map)
        for factory_info in candidates
        if factory_info["name"] not in _HIDDEN_PROVIDER_NAMES
    ]
    providers.sort(key=lambda p: (factory_rank_mapping.get(p["name"]), p["name"]))
    return True, providers
|
|
|
|
|
|
def add_provider(tenant_id: str, provider_name: str):
    """
    Register a provider (factory) for a tenant.

    The provider must be one of the factories in the catalogue and must not
    already be registered for the tenant.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :return: (success, "success" or an error message)
    """
    if not FACTORY_LLM_INFOS:
        return False, "No providers found"

    is_known = any(factory["name"] == provider_name for factory in FACTORY_LLM_INFOS)
    if not is_known:
        return False, f"Provider '{provider_name}' is not allowed"

    if TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name):
        return False, f"Provider {provider_name} already exists"

    TenantModelProviderService.insert(tenant_id=tenant_id, provider_name=provider_name)
    return True, "success"
|
|
|
|
|
|
def delete_provider(tenant_id: str, provider_id_or_name: str):
    """
    Remove a tenant's provider together with all of its instances and their models.

    The provider is resolved by ID first and, failing that, by name.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider ID or provider/factory name
    :return: (success, "success" or an error message)
    """
    provider_obj = (
        TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
        or TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    )
    if not provider_obj:
        return False, f"Provider {provider_id_or_name} not found"

    instances = TenantModelInstanceService.get_all_by_provider_id(provider_obj.id)
    if instances:
        # Models hang off instances, so they are removed before the instances themselves.
        delete_models_by_instance_ids([inst.id for inst in instances])
        delete_instances_by_provider_ids([provider_obj.id])

    TenantModelProviderService.delete_by_tenant_id_and_provider_name(tenant_id, provider_obj.provider_name)
    return True, "success"
|
|
|
|
|
|
def show_provider(provider_id_or_name: str):
    """
    Describe a provider using its entry in the factory catalogue.

    ``provider_id_or_name`` is first looked up as a provider record ID; when no
    record matches, it is used directly as the factory name.

    :param provider_id_or_name: provider/factory ID or name
    :return: (success, {"base_url", "name", "total_models"} or an error message)
    """
    provider_name = provider_id_or_name
    if provider_id_or_name:
        _, record = TenantModelProviderService.get_by_id(provider_id_or_name)
        if record:
            provider_name = record.provider_name

    factory_info = next((f for f in FACTORY_LLM_INFOS if f["name"] == provider_name), None)
    if factory_info is None:
        return False, f"Provider '{provider_id_or_name}' not found"

    return True, {
        "base_url": {"default": factory_info.get("url", "")},
        "name": factory_info["name"],
        "total_models": len(factory_info.get("llm", [])),
    }
|
|
|
|
|
|
def _llm_features(llm: dict) -> list:
    """Return an LLM entry's feature list, derived from legacy boolean flags when ``features`` is absent."""
    features = llm.get("features")
    if features is not None:
        return features
    return (["is_tools"] if llm.get("is_tools") else []) + (["thinking"] if llm.get("thinking") else [])


async def list_provider_models(provider_id_or_name: str, api_key: str = None, base_url: str = None):
    """
    List all models for a provider.

    Static models come from the factory catalogue. When the provider has a
    ``ModelMeta`` driver, the live model list is fetched as well and merged in,
    with remote entries winning on name conflicts. Errors raised by the remote
    fetch propagate to the caller.

    :param provider_id_or_name: provider ID or provider/factory name
    :param api_key: api key used for the remote model listing
    :param base_url: base url; defaults to the catalogue URL
    :return: (success, models sorted by name, or an error message)
    """
    provider_obj = None
    if provider_id_or_name:
        _, provider_obj = TenantModelProviderService.get_by_id(provider_id_or_name)
    provider_name = provider_obj.provider_name if provider_obj else provider_id_or_name

    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factory_info:
        return False, f"Provider '{provider_id_or_name}' not found"
    factory = factory_info[0]

    static_llms = [
        {
            "name": _factory_llm_name(llm),
            "max_tokens": llm.get("max_tokens"),
            "model_types": _factory_model_types(llm),
            "features": _llm_features(llm),
        }
        for llm in factory.get("llm") or []
    ]

    model_base_url = _normalize_provider_base_url(provider_name, base_url) or factory.get("url", "")
    remote_models = []
    if provider_name in ModelMeta:
        # Guard against drivers that return None instead of an empty list.
        remote_models = await ModelMeta[provider_name](api_key, model_base_url).get_model_list() or []

    if not static_llms and not remote_models:
        return False, f"No models found for provider '{provider_id_or_name}'"

    # Merge static and remote models, preferring remote_models on name conflicts
    merged = {m["name"]: m for m in static_llms}
    merged.update({m["name"]: m for m in remote_models})
    return True, sorted(merged.values(), key=lambda m: m["name"])
|
|
|
|
|
|
def show_provider_model(provider_id_or_name: str, model_name: str):
    """
    Show a specific catalogue model for a provider.

    :param provider_id_or_name: provider/factory ID or name
    :param model_name: model name, matched against the entry's ``name``/``llm_name``
    :return: (success, model description or an error message)
    """
    provider_obj = None
    if provider_id_or_name:
        _, provider_obj = TenantModelProviderService.get_by_id(provider_id_or_name)
    provider_name = provider_obj.provider_name if provider_obj else provider_id_or_name

    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factory_info:
        return False, f"Provider '{provider_id_or_name}' not found"

    llms = factory_info[0].get("llm") or []
    if not llms:
        return False, f"No models found for provider '{provider_id_or_name}'"

    llm_info = next((llm for llm in llms if _factory_llm_name(llm) == model_name), None)
    if llm_info is None:
        return False, f"Model '{model_name}' not found"

    model_types = _factory_model_types(llm_info)
    return True, {
        "name": _factory_llm_name(llm_info),
        "max_tokens": llm_info.get("max_tokens"),
        "model_types": model_types,
        # NOTE(review): always None; the catalogue's "thinking" flag is not consulted here — confirm intent.
        "thinking": None,
        "model_type_map": {model_type: True for model_type in model_types},
    }
|
|
|
|
|
|
async def create_provider_instance(tenant_id: str, provider_id_or_name: str, instance_name: str, api_key: str | dict, base_url: str, region: str, model_info: list[dict] = None):
    """
    Create a provider instance.

    When an API key is supplied it is verified with ``verify_api_key`` before the
    instance is stored. Each entry of ``model_info`` is then registered on the new
    instance via ``add_model_to_instance``. The instance is kept even if some of
    those registrations fail; the failures are reported in the message.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :param instance_name: instance name; must not be "default" and must be unique per provider
    :param api_key: API key (a dict is stored as JSON)
    :param base_url: base url
    :param region: region
    :param model_info: model info, [{
            "model_type": ["chat"], # support multiple
            "model_name": "name",
            "max_tokens": 4096,
            "extra": {
                "field1": "value1",
                "field2": "'value2"
            }
        }]
    :return: (success, result_or_error_message)
    """
    if not provider_id_or_name:
        return False, "Provider ID or name is required"

    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
    if not provider_obj:
        provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    if not provider_obj:
        return False, f"Provider '{provider_id_or_name}' does not exist"

    provider_name = provider_obj.provider_name
    base_url = _normalize_provider_base_url(provider_name, base_url)

    if instance_name == "default":
        return False, "Instance name cannot be 'default'"

    # Check if provider exists in the system
    if not any(f["name"] == provider_name for f in FACTORY_LLM_INFOS):
        return False, f"Provider '{provider_name}' is not allowed"

    # Instances are later resolved by name, so a duplicate would be ambiguous.
    if TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name):
        return False, f"Instance '{instance_name}' already exists for provider '{provider_name}'"

    api_key_str = ""
    if api_key:
        api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)
        success, msg = await verify_api_key(provider_name, api_key, base_url, region, model_info)
        if not success:
            return False, msg

    extra_fields = {}
    if base_url:
        extra_fields["base_url"] = base_url
    if region:
        extra_fields["region"] = region
    TenantModelInstanceService.create_instance(provider_id=provider_obj.id, instance_name=instance_name, api_key=api_key_str, extra=json.dumps(extra_fields))

    if model_info:
        failures = []
        for model in model_info:
            added, add_msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model)
            if not added:
                failures.append(add_msg)
        if failures:
            # One line per failed model instead of running the messages together.
            return False, "\n".join(failures)

    return True, "success"
|
|
|
|
|
|
def list_provider_instances(tenant_id: str, provider_id_or_name: str):
    """
    List the instances configured under one of a tenant's providers.

    The provider is resolved by ID first, then by name.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :return: (success, list of instance descriptions or an error message)
    """
    provider_obj = (
        TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
        or TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    )
    if not provider_obj:
        return False, f"No provider found for provider '{provider_id_or_name}'"

    provider_id = provider_obj.id
    instance_objs = TenantModelInstanceService.get_all_by_provider_id(provider_id)
    if not instance_objs:
        return True, []

    def _describe(inst):
        # "extra" holds a JSON object string; the region is read from it.
        extra = json.loads(inst.extra) if inst.extra else {}
        return {
            "id": inst.id,
            "instance_name": inst.instance_name,
            "provider_id": provider_id,
            "region": extra.get("region", ""),
            "status": inst.status,
        }

    return True, [_describe(inst) for inst in instance_objs]
|
|
|
|
|
|
def _unsupported_model_msg(kind: str, provider_name: str) -> str:
    """Log that ``provider_name`` has no ``kind`` model driver and return the message to report."""
    unsupported_msg = f"{kind} model from {provider_name} is not supported yet."
    logging.warning(unsupported_msg)
    return f"\n{unsupported_msg}"


async def _chat_responds(mdl) -> bool:
    """Return True as soon as the chat model streams a non-error text chunk."""
    async for chunk in mdl.async_chat_streamly(
        None,
        [{"role": "user", "content": "Hi"}],
        {"temperature": 0.9},
    ):
        if chunk and isinstance(chunk, str) and "**ERROR**" not in chunk:
            return True
    return False


def _drain_tts(mdl) -> None:
    """Consume the whole TTS stream so that synthesis errors surface."""
    for _ in mdl.tts("Hello~ RAGFlower!"):
        pass


async def verify_api_key(provider_id_or_name: str, api_key: str | dict, base_url: str = None, region: str = None, model_info: list[dict] = None):
    """
    Verify API key for a provider.

    The provider's models are probed one at a time. For each model, only the
    first capability it declares among embedding, chat, rerank, OCR and TTS is
    exercised, with a tiny request bounded by ``LLM_TIMEOUT_SECONDS`` (default 10).
    Verification stops at the first successful probe. A capability without a
    driver for this provider is reported in the message instead of raising.

    :param provider_id_or_name: provider/factory ID or name
    :param api_key: API key; a dict is serialized to JSON before use
    :param base_url: base url
    :param region: region; "intl" selects the siliconflow_intl catalogue entry for SiliconFlow
    :param model_info: models to probe when the catalogue lists none, [{
            "model_type": ["chat"], # support multiple; a single string is accepted
            "model_name": "name",
            "max_tokens": 4096,
            "extra": {
                "field1": "value1",
                "field2": "'value2"
            }
        }]
    :return: (success, "success" or the accumulated failure messages)
    """
    if not provider_id_or_name:
        return False, "Provider ID or name is required"

    _, provider_obj = TenantModelProviderService.get_by_id(provider_id_or_name)
    provider_name = provider_obj.provider_name if provider_obj else provider_id_or_name

    base_url = _normalize_provider_base_url(provider_name, base_url)

    if region == "intl" and provider_name.lower() == "siliconflow":
        target_factory_name = "siliconflow_intl"
    else:
        target_factory_name = provider_name

    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == target_factory_name]
    if not factory_info:
        return False, f"Provider '{provider_id_or_name}' not found"

    factory_llms = factory_info[0].get("llm") or []
    if not factory_llms:
        if not model_info:
            return False, f"No models found for provider '{provider_id_or_name}'"
        # _factory_model_types keeps a single type string whole instead of iterating its characters.
        factory_llms = [
            {"model_type": _type, "llm_name": model.get("model_name", "")}
            for model in model_info if model
            for _type in _factory_model_types(model)
        ]
        if not factory_llms:
            return False, f"No valid models found for provider '{provider_id_or_name}'"

    # test if api key works
    timeout_seconds = _to_int(os.environ.get("LLM_TIMEOUT_SECONDS", 10), 10)
    extra = {"provider": provider_name}
    if provider_name == "BaiduYiyan" and isinstance(api_key, str):
        # A plain (non-JSON) key is treated as the access key with an empty secret key.
        try:
            json.loads(api_key)
        except (json.JSONDecodeError, TypeError):
            api_key = {"yiyan_ak": api_key, "yiyan_sk": ""}
    api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)

    msg = ""
    passed = False
    for llm in factory_llms:
        model_types = _factory_model_types(llm)
        # Prefer the model identifier; fall back to the display name for name-only entries.
        llm_name = llm.get("llm_name") or _factory_llm_name(llm)

        if LLMType.EMBEDDING.value in model_types:
            if provider_name not in EmbeddingModel:
                msg += _unsupported_model_msg("Embedding", provider_name)
                continue
            try:
                mdl = EmbeddingModel[provider_name](api_key_str, llm_name, base_url=base_url)
                arr, _ = await asyncio.wait_for(
                    asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
                    timeout=timeout_seconds,
                )
                if len(arr[0]) == 0:
                    raise Exception("Fail")
                passed = True
            except Exception as e:
                logging.exception("Fail to access embedding model for provider=%s model=%s", provider_name, llm_name)
                msg += f"\nFail to access embedding model({llm_name}) using this api key." + str(e)

        elif LLMType.CHAT.value in model_types:
            if provider_name not in ChatModel:
                msg += _unsupported_model_msg("Chat", provider_name)
                continue
            try:
                mdl = ChatModel[provider_name](api_key_str, llm_name, base_url=base_url, **extra)
                if not await asyncio.wait_for(_chat_responds(mdl), timeout=timeout_seconds):
                    raise Exception("No valid response received")
                passed = True
            except Exception as e:
                logging.exception("Fail to access chat model for provider=%s model=%s", provider_name, llm_name)
                msg += f"\nFail to access model({provider_name}/{llm_name}) using this api key." + str(e)

        elif LLMType.RERANK.value in model_types:
            if provider_name not in RerankModel:
                msg += _unsupported_model_msg("Rerank", provider_name)
                continue
            try:
                mdl = RerankModel[provider_name](api_key_str, llm_name, base_url=base_url)
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
                    timeout=timeout_seconds,
                )
                if len(arr) == 0 or tc == 0:
                    raise Exception("Fail")
                passed = True
                logging.debug("passed model rerank %s", llm_name)
            except Exception as e:
                logging.exception("Fail to access rerank model for provider=%s model=%s", provider_name, llm_name)
                msg += f"\nFail to access model({provider_name}/{llm_name}) using this api key." + str(e)

        elif LLMType.OCR.value in model_types:
            if provider_name not in OcrModel:
                msg += _unsupported_model_msg("OCR", provider_name)
                continue
            try:
                mdl = OcrModel[provider_name](key=api_key_str, model_name=llm_name, base_url=base_url)
                ok, reason = await asyncio.wait_for(
                    asyncio.to_thread(mdl.check_available),
                    timeout=timeout_seconds,
                )
                if not ok:
                    raise RuntimeError(reason or "Model not available")
                passed = True
            except Exception as e:
                logging.exception("Fail to access OCR model for provider=%s model=%s", provider_name, llm_name)
                msg += f"\nFail to access model({provider_name}/{llm_name})." + str(e)

        elif LLMType.TTS.value in model_types:
            if provider_name not in TTSModel:
                msg += _unsupported_model_msg("TTS", provider_name)
                continue
            try:
                mdl = TTSModel[provider_name](key=api_key_str, model_name=llm_name, base_url=base_url)
                await asyncio.wait_for(
                    asyncio.to_thread(_drain_tts, mdl),
                    timeout=timeout_seconds,
                )
                passed = True
            except Exception as e:
                logging.exception("Fail to access TTS model for provider=%s model=%s", provider_name, llm_name)
                msg += f"\nFail to access model({provider_name}/{llm_name})." + str(e)

        if passed:
            break

    return passed, "success" if passed else msg
|
|
|
|
|
|
def show_provider_instance(tenant_id: str, provider_id_or_name: str, instance_id_or_name: str):
    """
    Describe one instance of a tenant's provider.

    The provider is resolved by ID, then by name. The instance is resolved by ID
    (only if it belongs to that provider), then by name within the provider.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :param instance_id_or_name: instance ID or name
    :return: (success, instance description or an error message)
    """
    provider_obj = (
        TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
        or TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    )
    if not provider_obj:
        return False, f"No provider found for provider '{provider_id_or_name}'"
    provider_id = provider_obj.id

    instance_obj = None
    if instance_id_or_name:
        _, candidate = TenantModelInstanceService.get_by_id(instance_id_or_name)
        # An ID that belongs to a different provider must not match.
        if candidate and candidate.provider_id == provider_id:
            instance_obj = candidate
    if not instance_obj:
        instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_id, instance_id_or_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"

    extra = json.loads(instance_obj.extra) if instance_obj.extra else {}
    return True, {
        "id": instance_obj.id,
        "instance_name": instance_obj.instance_name,
        "provider_id": provider_id,
        "region": extra.get("region", ""),
        "status": instance_obj.status,
    }
|
|
|
|
|
|
def drop_provider_instances(tenant_id: str, provider_id_or_name: str, instance_id_or_names: list):
    """
    Drop provider instances together with the models registered under them.

    Each entry is resolved by instance ID (only if it belongs to the provider),
    then by instance name. Nothing is deleted unless every entry is found.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :param instance_id_or_names: list of instance IDs or names to drop
    :return: (True, None) on success, otherwise (False, error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
    if not provider_obj:
        provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_id_or_name}'"
    provider_id = provider_obj.id

    not_exist_instances = []
    instance_ids = []
    for instance_id_or_name in instance_id_or_names or []:
        instance_obj = None
        if instance_id_or_name:
            _, instance_obj = TenantModelInstanceService.get_by_id(instance_id_or_name)
            if instance_obj and instance_obj.provider_id != provider_id:
                instance_obj = None
        if not instance_obj:
            instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_id, instance_id_or_name)
        if not instance_obj:
            not_exist_instances.append(instance_id_or_name)
            continue
        instance_ids.append(instance_obj.id)

    if not_exist_instances:
        return False, f"No instance found for provider '{provider_id_or_name}' and instance '{not_exist_instances}'"

    # The same instance may be named both by ID and by name; delete it once.
    instance_ids = list(dict.fromkeys(instance_ids))
    if instance_ids:
        delete_models_by_instance_ids(instance_ids)
        TenantModelInstanceService.delete_by_ids(instance_ids)
    return True, None
|
|
|
|
|
|
def _hybrid_get_instance_models(provider_name: str, instance_id: str):
    """
    Merge a provider's catalogue models with an instance's tenant_model records.

    Every catalogue model is listed. Its types are the catalogue types plus the
    types recorded for the instance, minus the types recorded as UNSUPPORTED.
    Models that exist only in the instance records (custom models) are appended,
    with ``max_tokens`` read from their JSON ``extra`` (default 8192).
    ``model_type`` is always a sorted list, so the result is JSON-serializable.

    :param provider_name: provider/factory name
    :param instance_id: tenant model instance ID
    :return: (True, [{"name", "model_type", "max_tokens", "status"}, ...]) or (False, error_message)
    """
    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factory_info:
        return False, f"Provider '{provider_name}' not found"

    # model_name -> {"status", "model_type": [...], "extra"} for usable records;
    # status and extra come from the first record seen for that model.
    model_info_map: dict = {}
    # model_name -> [model_type, ...] for records marked UNSUPPORTED.
    unsupported_types: dict = {}
    for record in TenantModelService.get_models_by_instance_id(instance_id):
        if record.status == ActiveStatusEnum.UNSUPPORTED.value:
            unsupported_types.setdefault(record.model_name, []).append(record.model_type)
            continue
        entry = model_info_map.setdefault(
            record.model_name,
            {"status": record.status, "model_type": [], "extra": record.extra},
        )
        entry["model_type"].append(record.model_type)

    models = []
    for llm in factory_info[0].get("llm") or []:
        # Prefer the model identifier; fall back to the display name for name-only entries.
        name = llm.get("llm_name") or _factory_llm_name(llm)
        recorded = model_info_map.get(name, {})
        types = set(_factory_model_types(llm)) | set(recorded.get("model_type", []))
        types -= set(unsupported_types.get(name, []))
        models.append({
            "name": name,
            "model_type": sorted(types),
            "max_tokens": llm.get("max_tokens"),
            "status": recorded.get("status", "active"),
        })

    factory_model_names = {m["name"] for m in models}
    for name, recorded in model_info_map.items():
        if name in factory_model_names:
            continue
        extra_fields = json.loads(recorded["extra"]) if recorded["extra"] else {}
        models.append({
            "name": name,
            "model_type": sorted(set(recorded["model_type"]) - set(unsupported_types.get(name, []))),
            "max_tokens": extra_fields.get("max_tokens", 8192),
            "status": recorded["status"],
        })
    return True, models
|
|
|
|
|
|
def list_instance_models(tenant_id: str, provider_id_or_name: str, instance_id_or_name: str, supported_only: bool = False):
    """
    List models for a provider instance.

    With ``supported_only``, return only the names of the models the provider
    declares in the LLM catalogue, sorted by name; the instance is not consulted.
    Otherwise, resolve the instance (by ID within the provider, then by name) and
    return the catalogue models merged with the instance's tenant_model records,
    as produced by ``_hybrid_get_instance_models``.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :param instance_id_or_name: instance ID or name
    :param supported_only: if True, only list supported models (from LLM dictionary)
    :return: (success, result_or_error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
    if not provider_obj:
        provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_id_or_name}'"

    if supported_only:
        factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_obj.provider_name]
        if not factory_info:
            return False, f"Provider '{provider_id_or_name}' not found"
        # Prefer the model identifier; fall back to the display name for name-only entries.
        models = [
            {"name": llm.get("llm_name") or _factory_llm_name(llm)}
            for llm in factory_info[0].get("llm") or []
        ]
        models.sort(key=lambda m: m["name"])
        return True, models

    instance_obj = None
    if instance_id_or_name:
        _, instance_obj = TenantModelInstanceService.get_by_id(instance_id_or_name)
        # An ID that belongs to a different provider must not match.
        if instance_obj and instance_obj.provider_id != provider_obj.id:
            instance_obj = None
    if not instance_obj:
        instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_id_or_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"

    return _hybrid_get_instance_models(provider_obj.provider_name, instance_obj.id)
|
|
|
|
|
|
def update_instance_models(tenant_id: str, provider_id_or_name: str, instance_id_or_name: str, model_names: list, model_types: list):
    """
    Set the enabled model types of one or more models of a provider instance.

    For every model in ``model_names``, ``TenantModelService.upsert_model_type`` is
    asked to add the types in ``model_types`` the model lacks and to delete the
    types it currently has that are not in ``model_types``. All models must exist
    (in the catalogue or the instance records) before anything is changed.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :param instance_id_or_name: instance ID or name
    :param model_names: model name, or list of model names, to update
    :param model_types: model type, or list of model types, the models should have
    :return: (success, "success" or an error message)
    """
    # A bare string would otherwise be split into single characters by set().
    if isinstance(model_names, str):
        model_names = [model_names] if model_names else []
    if isinstance(model_types, str):
        model_types = [model_types] if model_types else []
    if not model_names or not model_types:
        return False, "model_name and model_type are required"

    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
    if not provider_obj:
        provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_id_or_name}'"

    instance_obj = None
    if instance_id_or_name:
        _, instance_obj = TenantModelInstanceService.get_by_id(instance_id_or_name)
        # An ID that belongs to a different provider must not match.
        if instance_obj and instance_obj.provider_id != provider_obj.id:
            instance_obj = None
    if not instance_obj:
        instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_id_or_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"

    found, models = _hybrid_get_instance_models(provider_obj.provider_name, instance_obj.id)
    if not found:
        return False, models

    model_info_map = {model["name"]: model for model in models}
    not_exist_models = set(model_names) - set(model_info_map)
    if not_exist_models:
        return False, f"Models {not_exist_models} not found for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"

    wanted_types = set(model_types)
    # dict.fromkeys drops duplicate names while keeping the caller's order.
    for model_name in dict.fromkeys(model_names):
        current_types = set(model_info_map[model_name]["model_type"])
        TenantModelService.upsert_model_type(
            provider_obj.id,
            instance_obj.id,
            model_name,
            {
                "add": sorted(wanted_types - current_types),
                "delete": sorted(current_types - wanted_types),
            },
        )
    return True, "success"
|
|
|
|
|
|
def add_model_to_instance(tenant_id: str, provider_id_or_name: str, instance_id_or_name: str, model_name: str, model_type: str|list[str], max_tokens: int=8192, extra: dict=None):
|
|
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
|
|
if not provider_obj:
|
|
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
|
|
if not provider_obj:
|
|
return False, f"No provider found for provider '{provider_id_or_name}'"
|
|
instance_obj = None
|
|
if instance_id_or_name:
|
|
_, instance_obj = TenantModelInstanceService.get_by_id(instance_id_or_name)
|
|
if instance_obj and instance_obj.provider_id != provider_obj.id:
|
|
instance_obj = None
|
|
if not instance_obj:
|
|
instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_id_or_name)
|
|
if not instance_obj:
|
|
return False, f"No instance found for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"
|
|
model_obj = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, model_name)
|
|
if model_obj:
|
|
return False, f"Model '{model_name}' already exists for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"
|
|
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_obj.provider_name]
|
|
if not factory_info:
|
|
return False, f"Provider '{provider_id_or_name}' not found"
|
|
llms = factory_info[0].get("llm", [])
|
|
if isinstance(model_type, str):
|
|
model_type = [model_type]
|
|
|
|
for _type in model_type:
|
|
extra_fields = {"max_tokens": max_tokens}
|
|
target_model = [llm for llm in llms if _type in _factory_model_types(llm) and llm["llm_name"] == model_name]
|
|
if target_model:
|
|
extra_fields.update({"is_tools": target_model[0].get("is_tools", False)})
|
|
if extra:
|
|
extra_fields.update(extra)
|
|
TenantModelService.insert(
|
|
model_name=model_name,
|
|
provider_id=provider_obj.id,
|
|
instance_id=instance_obj.id,
|
|
model_type=_type,
|
|
extra=json.dumps(extra_fields)
|
|
)
|
|
|
|
return True, "success"
|
|
|
|
|
|
def update_model_status(tenant_id: str, provider_id_or_name: str, instance_id_or_name: str, model_name: str, status: str):
    """
    Enable or disable a model for a provider instance.

    - If model records exist in tenant_model, update the status of every record
      whose status is not "unsupported" (unsupported records are left as they are).
    - If no model record exists:
        - status="active": nothing is written (the default is active/enabled).
        - status="inactive": one inactive record is created for each model type
          that the factory catalogue (FACTORY_LLM_INFOS) declares for the model.
          Every catalogue entry carrying that model name is considered.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :param instance_id_or_name: instance ID or name
    :param model_name: model name
    :param status: "active" or "inactive" (ActiveStatusEnum values)
    :return: (success, None_or_error_message)
    """
    if status not in (ActiveStatusEnum.ACTIVE.value, ActiveStatusEnum.INACTIVE.value):
        return False, f"status must be '{ActiveStatusEnum.ACTIVE.value}' or '{ActiveStatusEnum.INACTIVE.value}'"

    # Check if provider exists for this tenant (by ID, then by name).
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
    if not provider_obj:
        provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_id_or_name}'"

    # Check if instance exists: by ID (only if it belongs to this provider), then by name.
    instance_obj = None
    if instance_id_or_name:
        _, instance_obj = TenantModelInstanceService.get_by_id(instance_id_or_name)
        if instance_obj and instance_obj.provider_id != provider_obj.id:
            instance_obj = None
    if not instance_obj:
        instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_id_or_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"

    # Check if model records already exist in the tenant_model table.
    model_obj_list = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(
        provider_obj.id, instance_obj.id, model_name
    )

    if model_obj_list:
        # Records marked unsupported keep that status; only the others are toggled.
        updatable_ids = [m.id for m in model_obj_list if m.status != ActiveStatusEnum.UNSUPPORTED.value]
        # Skip the write entirely when every record is unsupported.
        if updatable_ids:
            TenantModelService.batch_update_model_status(updatable_ids, status)
        return True, None

    # No record yet: "active" is the default, so nothing needs to be stored.
    if status == ActiveStatusEnum.ACTIVE.value:
        return True, None

    # status is "inactive": create inactive records using the catalogue's model schema.
    factory_info = next((f for f in FACTORY_LLM_INFOS if f["name"] == provider_obj.provider_name), None)
    if factory_info is None:
        return False, f"Provider '{provider_id_or_name}' not found"
    llms = factory_info.get("llm", [])
    matching_llms = [llm for llm in llms if llm["llm_name"] == model_name]
    if not matching_llms:
        return False, f"provider {provider_obj.provider_name} model {model_name} not found"

    # The catalogue may list the same model name in more than one entry. Collect
    # the types from all of them, not just the first, so that the model is disabled
    # for every declared type. The first entry declaring a type supplies that type's
    # max_tokens / is_tools.
    entry_by_type = {}
    for llm in matching_llms:
        for model_type in _factory_model_types(llm):
            entry_by_type.setdefault(model_type, llm)

    for model_type, llm in entry_by_type.items():
        TenantModelService.insert(
            id=get_uuid(),
            model_name=model_name,
            model_type=model_type,
            provider_id=provider_obj.id,
            instance_id=instance_obj.id,
            status=status,
            extra=json.dumps({"max_tokens": llm.get("max_tokens", 8192), "is_tools": llm.get("is_tools", False)}),
        )

    return True, None
|
|
|
|
|
|
async def chat_to_model(tenant_id: str, provider_id_or_name: str, instance_id_or_name: str, model_name: str, message: str, stream: bool = False, thinking: bool = False):
    """
    Chat to a model.

    The chat model is addressed as ``"{model_name}@{instance_name}@{provider_name}"``
    and wrapped in an ``LLMBundle``.
    - With ``stream=True``, nothing is sent: the bundle and its config are returned
      so the caller can drive the stream.
    - Otherwise, ``message`` is sent as a single user turn with temperature 0.9, and
      the reply is returned.

    :param tenant_id: tenant ID
    :param provider_id_or_name: provider/factory ID or name
    :param instance_id_or_name: instance ID or name
    :param model_name: model name
    :param message: chat message (unused when ``stream`` is True)
    :param stream: whether to stream the response
    :param thinking: intended to enable thinking/reasoning.
        NOTE(review): this flag is accepted but never forwarded to the model —
        confirm how reasoning should be wired through ``LLMBundle``.
    :return: (success, result_or_error_message). On success the result is either
        ``{"type": "stream", "llm": ..., "model_config": ...}`` or
        ``{"answer": ..., "reasoning_content": ""}``.
    """
    from api.db.services.llm_service import LLMBundle

    # Resolve the provider by ID first, then fall back to its name.
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_id(tenant_id, provider_id_or_name)
    if not provider_obj:
        provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_id_or_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_id_or_name}'"

    # Try the argument as an instance ID (ignoring instances of other providers),
    # then fall back to looking it up as an instance name.
    instance_obj = None
    if instance_id_or_name:
        _, instance_obj = TenantModelInstanceService.get_by_id(instance_id_or_name)
        if instance_obj and instance_obj.provider_id != provider_obj.id:
            instance_obj = None
    if not instance_obj:
        instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_id_or_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_id_or_name}' and instance '{instance_id_or_name}'"

    # Get model config
    composite_name = f"{model_name}@{instance_obj.instance_name}@{provider_obj.provider_name}"
    try:
        model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT.value, composite_name)
    except LookupError:
        return False, f"Model '{composite_name}' not authorized"
    if not model_config:
        return False, f"Model '{composite_name}' not found"

    llm = LLMBundle(tenant_id, model_config)

    if stream:
        return True, {"type": "stream", "llm": llm, "model_config": model_config}

    # Non-streaming chat: only the model call itself is guarded.
    try:
        response = await llm.async_chat(
            None,
            [{"role": "user", "content": message}],
            {"temperature": 0.9},
        )
    except Exception as e:
        # API boundary: log with traceback and hand the error text back to the caller.
        logging.exception("Chat to model failed: %s", e)
        return False, str(e)

    return True, {"answer": response, "reasoning_content": ""}
|