mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve?
When setting the API key for the BaiduYiyan provider, all model
validations fail with the error "Fail to access model using this api
key. No valid response received".
**Root cause:**
1. `BaiduYiyanChat` in `rag/llm/chat_model.py` does not override
`async_chat_streamly()`. The `verify_api_key()` function uses
`mdl.async_chat_streamly()` to validate, but `BaiduYiyanChat` inherits
`Base.async_chat_streamly()` which uses the OpenAI client, not the Baidu
Qianfan SDK (qianfan). Since BaiduYiyan has no OpenAI-compatible
base_url, validation always fails.
2. `verify_api_key()` in `provider_api_service.py` does not format the
raw API key string into the JSON format (`{"yiyan_ak": "...",
"yiyan_sk": "..."}`) that `BaiduYiyanChat.__init__()` expects via
`json.loads(key)`.
**Fix:**
1. Add `async_chat_streamly()` method to `BaiduYiyanChat` using the
qianfan SDK, consistent with the existing `chat_streamly()` method.
2. Add BaiduYiyan API key formatting in `provider_api_service.py`
`verify_api_key()` to match the format expected by
`BaiduYiyanChat.__init__()`.
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
753 lines
31 KiB
Python
753 lines
31 KiB
Python
#
|
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import os
|
|
import json
|
|
import logging
|
|
import asyncio
|
|
|
|
from common.constants import LLMType, ActiveStatusEnum
|
|
from common.misc_utils import get_uuid
|
|
from common.settings import FACTORY_LLM_INFOS
|
|
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, delete_models_by_instance_ids, delete_instances_by_provider_ids
|
|
from api.db.services.tenant_model_provider_service import TenantModelProviderService
|
|
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
|
|
from api.db.services.tenant_model_service import TenantModelService
|
|
from rag.llm import EmbeddingModel, ChatModel, RerankModel, ModelMeta
|
|
|
|
|
|
def _to_int(v, default=500):
|
|
try:
|
|
return int(v)
|
|
except (TypeError, ValueError):
|
|
return default
|
|
|
|
|
|
def _build_provider_entry(factory_info: dict, factory_info_map: dict) -> dict:
    """
    Build the summary dict returned by list_providers for one factory entry.

    :param factory_info: one element of FACTORY_LLM_INFOS
    :param factory_info_map: mapping of factory name -> factory entry, used to
        look up the "siliconflow_intl" URL
    :return: dict with "model_types", "name" and "url" keys
    """
    model_types = sorted({
        llm["model_type"]
        for llm in factory_info.get("llm", [])
        if llm.get("model_type")
    })
    provider = {
        "model_types": model_types,
        "name": factory_info["name"],
        "url": {
            "default": factory_info.get("url", "")
        }
    }
    # Two providers additionally expose an international endpoint.
    if factory_info["name"].lower() == "siliconflow":
        provider["url"]["intl"] = factory_info_map.get("siliconflow_intl", {}).get("url", "https://api.siliconflow.com/v1")
    elif factory_info["name"] == "Tongyi-Qianwen":
        provider["url"]["intl"] = "https://dashscope-intl.aliyuncs.com/compatible-model/v1"
    return provider


def list_providers(tenant_id: str, all_available: bool = False):
    """
    List providers for a tenant.

    If all_available is True, list every provider declared in FACTORY_LLM_INFOS.
    Otherwise, list only the providers that the tenant has configured.
    In both cases a fixed set of provider names is hidden from the result.

    :param tenant_id: tenant ID
    :param all_available: whether to list all available providers
    :return: (success, result); (False, []) when FACTORY_LLM_INFOS is empty
    """
    if not FACTORY_LLM_INFOS:
        return False, []

    # Provider names that are never returned to the caller.
    hidden_names = ("Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl")

    # Rank is negated so that a higher rank sorts first.
    factory_rank_mapping = {factory["name"]: -_to_int(factory.get("rank", "500")) for factory in FACTORY_LLM_INFOS}
    factory_info_map = {f["name"]: f for f in FACTORY_LLM_INFOS}

    if all_available:
        selected = [f for f in FACTORY_LLM_INFOS if f["name"] not in hidden_names]
    else:
        # List tenant-configured providers
        factory_names = TenantModelProviderService.list_provider_names_by_tenant_id(tenant_id)
        selected = [
            factory_info_map[name]
            for name in factory_names
            if name not in hidden_names and factory_info_map.get(name)
        ]

    providers = [_build_provider_entry(f, factory_info_map) for f in selected]
    providers.sort(key=lambda x: (factory_rank_mapping.get(x["name"]), x["name"]))
    return True, providers
|
|
|
|
|
|
def add_provider(tenant_id: str, provider_name: str):
    """
    Register a provider (factory) for a tenant.

    The provider must be declared in FACTORY_LLM_INFOS and must not already be
    registered for the tenant.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :return: (success, result_or_error_message)
    """
    if not FACTORY_LLM_INFOS:
        return False, "No providers found"

    # Only factories known to the system may be registered.
    known_names = tuple(factory["name"] for factory in FACTORY_LLM_INFOS)
    if provider_name not in known_names:
        return False, f"Provider '{provider_name}' is not allowed"

    already_registered = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if already_registered:
        return False, f"Provider {provider_name} already exists"

    TenantModelProviderService.insert(tenant_id=tenant_id, provider_name=provider_name)
    return True, "success"
|
|
|
|
|
|
def delete_provider(tenant_id: str, provider_name: str):
    """
    Delete a tenant's provider together with all of its instances and models.

    A provider that has no instances (for example one that was just added via
    add_provider) is still deleted; only the instance/model cleanup is skipped.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :return: (success, result_or_error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"Provider {provider_name} not found"

    instance_objs = TenantModelInstanceService.get_all_by_provider_id(provider_obj.id)
    if instance_objs:
        # Remove models first, then the instances that own them.
        instance_ids = [instance_obj.id for instance_obj in instance_objs]
        delete_models_by_instance_ids(instance_ids)
        delete_instances_by_provider_ids([provider_obj.id])

    TenantModelProviderService.delete_by_tenant_id_and_provider_name(tenant_id, provider_name)
    return True, "success"
|
|
|
|
|
|
def show_provider(provider_name: str):
    """
    Return basic details of a provider as declared in FACTORY_LLM_INFOS.

    :param provider_name: provider/factory name
    :return: (success, result_or_error_message); on success the result holds
        the default base URL, the provider name and the number of listed models
    """
    matches = [entry for entry in FACTORY_LLM_INFOS if entry["name"] == provider_name]
    if not matches:
        return False, f"Provider '{provider_name}' not found"

    entry = matches[0]
    default_url = entry.get("url", "")
    model_count = len(entry.get("llm", []))
    details = {
        "base_url": {"default": default_url},
        "name": entry["name"],
        "total_models": model_count,
    }
    return True, details
|
|
|
|
|
|
async def list_provider_models(provider_name: str, api_key: str = None, base_url: str = None):
    """
    List the models of a provider.

    Combines the models declared for the provider in FACTORY_LLM_INFOS with the
    models reported by the provider's ModelMeta client (when one is registered).
    When both sources contain the same model name, the remote entry wins.

    :param provider_name: provider/factory name
    :param api_key: api key passed to the ModelMeta client
    :param base_url: base url; falls back to the factory's "url" when empty
    :return: (success, result_or_error_message); models are sorted by name
    """
    matches = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not matches:
        return False, f"Provider '{provider_name}' not found"
    factory = matches[0]

    # NOTE(review): this function reads llm["name"], while other functions in this
    # file read llm["llm_name"] from the same structure -- confirm which key is right.
    static_llms = []
    for llm in factory["llm"]:
        features = llm.get("features")
        if features is None:
            # No explicit feature list: derive it from the boolean flags.
            features = (["is_tools"] if llm.get("is_tools") else []) + (["thinking"] if llm.get("thinking") else [])
        static_llms.append({
            "name": llm["name"],
            "max_tokens": llm["max_tokens"],
            "model_types": [llm["model_type"]],
            "features": features,
        })

    model_base_url = base_url or factory.get("url", "")
    remote_models = []
    if provider_name in ModelMeta:
        meta_client = ModelMeta[provider_name](api_key, model_base_url)
        remote_models = await meta_client.get_model_list()

    if not static_llms and not remote_models:
        return False, f"No models found for provider '{provider_name}'"

    # Later entries overwrite earlier ones, so remote models take precedence.
    by_name = {m["name"]: m for m in [*static_llms, *remote_models]}
    models = sorted(by_name.values(), key=lambda x: x["name"])
    return True, models
|
|
|
|
|
|
def show_provider_model(provider_name: str, model_name: str):
    """
    Return the FACTORY_LLM_INFOS description of one model of a provider.

    :param provider_name: provider/factory name
    :param model_name: model name
    :return: (success, result_or_error_message)
    """
    providers = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not providers:
        return False, f"Provider '{provider_name}' not found"

    provider_llms = providers[0]["llm"]
    if not provider_llms:
        return False, f"No models found for provider '{provider_name}'"

    candidates = [llm for llm in provider_llms if llm["name"] == model_name]
    if not candidates:
        return False, f"Model '{model_name}' not found"

    llm_info = candidates[0]
    model_type = llm_info["model_type"]
    description = {
        "name": llm_info["name"],
        "max_tokens": llm_info["max_tokens"],
        "model_types": [model_type],
        # Always reported as None by this function.
        "thinking": None,
        "model_type_map": {model_type: True},
    }
    return True, description
|
|
|
|
|
|
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str|dict, base_url: str, region: str, model_info: list[dict]=None):
    """
    Create a provider instance for a tenant.

    When an API key is given it must not already be used by another instance of
    the same provider, and it is verified through verify_api_key before the
    instance is stored. Models listed in model_info are then attached to the
    new instance.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :param instance_name: instance name; "default" is rejected
    :param api_key: API key (a dict is stored as its JSON serialization)
    :param base_url: base url, stored in the instance's extra fields when set
    :param region: region, stored in the instance's extra fields when set
    :param model_info: model info, [{
        "model_type": ["chat"],  # support multiple
        "model_name": "name",
        "max_tokens": 4096,
        "extra": {
            "field1": "value1",
            "field2": "value2"
        }
    }]
    :return: (success, result_or_error_message)
    """
    if not provider_name:
        return False, "Provider name is required"
    if instance_name == "default":
        return False, "Instance name cannot be 'default'"

    # The provider must be known to the system ...
    known_factories = [f["name"] for f in FACTORY_LLM_INFOS]
    if provider_name not in known_factories:
        return False, f"Provider '{provider_name}' is not allowed"

    # ... and already registered for this tenant.
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"Provider '{provider_name}' does not exist"

    api_key_str = ""
    if api_key:
        if isinstance(api_key, str):
            api_key_str = api_key
        else:
            api_key_str = json.dumps(api_key)

        duplicate = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key_str)
        if duplicate:
            # NOTE(review): this message echoes the API key back to the caller -- confirm that is intended.
            return False, f"Already exist instance: {duplicate.instance_name} with api_key {api_key}"

        verified, verify_msg = await verify_api_key(provider_name, api_key, base_url, region, model_info)
        if not verified:
            return False, verify_msg

    # Only non-empty connection settings are persisted.
    extra_fields = {key: value for key, value in (("base_url", base_url), ("region", region)) if value}
    TenantModelInstanceService.create_instance(
        provider_id=provider_obj.id,
        instance_name=instance_name,
        api_key=api_key_str,
        extra=json.dumps(extra_fields),
    )

    # Every model is attempted; failure messages are concatenated.
    # NOTE(review): the instance created above is kept even when a model fails to be added.
    failures = []
    for model in model_info or []:
        added, add_msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model)
        if not added:
            failures.append(add_msg)
    combined = "".join(failures)
    if combined:
        return False, combined

    return True, "success"
|
|
|
|
|
|
def list_provider_instances(tenant_id: str, provider_name: str):
    """
    List the instances of a tenant's provider.

    Active instances come first, then inactive ones; each group is sorted by
    instance name. Instances with any other status are not returned.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :return: (success, result_or_error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_name}'"

    provider_id = provider_obj.id
    instance_objs = TenantModelInstanceService.get_all_by_provider_id(provider_id)
    if not instance_objs:
        return True, []

    def _summarize(instance_obj):
        # "extra" is stored as a JSON string; the region lives inside it.
        extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
        return {
            "id": instance_obj.id,
            "instance_name": instance_obj.instance_name,
            "provider_id": provider_id,
            "region": extra_fields.get("region", ""),
            "status": instance_obj.status,
        }

    instances = [_summarize(instance_obj) for instance_obj in instance_objs]

    ordered = []
    for wanted_status in (ActiveStatusEnum.ACTIVE.value, ActiveStatusEnum.INACTIVE.value):
        group = [instance for instance in instances if instance["status"] == wanted_status]
        ordered.extend(sorted(group, key=lambda x: x["instance_name"]))

    return True, ordered
|
|
|
|
|
|
async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=None, region: str=None, model_info: list[dict]=None):
    """
    Verify an API key for a provider by calling one of its models.

    Candidate models come from the provider's FACTORY_LLM_INFOS entry, or from
    model_info when the factory declares none. Embedding, chat and rerank models
    are tried in list order; verification succeeds as soon as one call succeeds.
    Each call is bounded by the LLM_TIMEOUT_SECONDS environment variable
    (default 10; a non-numeric value also falls back to 10).

    :param provider_name: provider/factory name
    :param api_key: API key
    :param base_url: base url
    :param region: region; "intl" selects the "siliconflow_intl" factory for SiliconFlow
    :param model_info: model info, [{
        "model_type": ["chat"],  # support multiple
        "model_name": "name",
        "max_tokens": 4096,
        "extra": {
            "field1": "value1",
            "field2": "value2"
        }
    }]
    :return: (success, result_or_error_message)
    :raises AssertionError: when the provider has no registered class for a
        model type that is about to be tested
    """
    if not provider_name:
        return False, "Provider name is required"

    if region and region == "intl" and provider_name.lower() == "siliconflow":
        target_factory_name = "siliconflow_intl"
    else:
        target_factory_name = provider_name

    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == target_factory_name]
    if not factory_info:
        return False, f"Provider '{provider_name}' not found"

    # .get() keeps a factory entry without an "llm" key on the model_info fallback
    # path instead of raising KeyError (matches the other functions in this file).
    factory_llms = factory_info[0].get("llm", [])
    if not factory_llms:
        if not model_info:
            return False, f"No models found for provider '{provider_name}'"
        # One candidate per (model, type) pair supplied by the caller.
        factory_llms = [{
            "model_type": _type,
            "llm_name": model.get("model_name", ""),
        } for model in model_info if model for _type in model.get("model_type", [])]

    # test if api key works
    chat_passed, embd_passed, rerank_passed = False, False, False
    # A malformed LLM_TIMEOUT_SECONDS falls back to 10 instead of raising.
    timeout_seconds = _to_int(os.environ.get("LLM_TIMEOUT_SECONDS"), 10)
    extra = {"provider": provider_name}
    msg = ""

    if provider_name == "BaiduYiyan" and isinstance(api_key, str):
        # The BaiduYiyan model classes are given the key as a JSON object string
        # with "yiyan_ak"/"yiyan_sk". A raw key is wrapped into that shape. A raw
        # key that happens to be valid JSON but not an object (e.g. all digits)
        # must be wrapped too.
        try:
            parsed_key = json.loads(api_key)
        except (json.JSONDecodeError, TypeError):
            parsed_key = None
        if not isinstance(parsed_key, dict):
            api_key = {"yiyan_ak": api_key, "yiyan_sk": ""}

    api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)

    for llm in factory_llms:
        # NOTE(review): the asserts below are stripped under "python -O"; kept as-is
        # so the exception type seen by callers does not change.
        if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value:
            assert provider_name in EmbeddingModel, f"Embedding model from {provider_name} is not supported yet."
            mdl = EmbeddingModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
            try:
                # encode() is synchronous, so run it in a worker thread under the timeout.
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
                    timeout=timeout_seconds,
                )
                if len(arr[0]) == 0:
                    raise Exception("Fail")
                embd_passed = True
            except Exception as e:
                logging.exception(
                    "Fail to access embedding model for provider=%s model=%s",
                    provider_name,
                    llm["llm_name"],
                )
                msg += f"\nFail to access embedding model({llm['llm_name']}) using this api key." + str(e)
        elif not chat_passed and llm["model_type"] == LLMType.CHAT.value:
            assert provider_name in ChatModel, f"Chat model from {provider_name} is not supported yet."
            mdl = ChatModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url, **extra)
            try:
                async def check_streamly():
                    # The first non-empty text chunk without an error marker counts as success.
                    async for chunk in mdl.async_chat_streamly(
                        None,
                        [{"role": "user", "content": "Hi"}],
                        {"temperature": 0.9},
                    ):
                        if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0:
                            return True
                    return False

                result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)
                if result:
                    chat_passed = True
                else:
                    raise Exception("No valid response received")
            except Exception as e:
                logging.exception(
                    "Fail to access chat model for provider=%s model=%s",
                    provider_name,
                    llm["llm_name"],
                )
                msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
        elif not rerank_passed and llm["model_type"] == LLMType.RERANK.value:
            assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet."
            mdl = RerankModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
            try:
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
                    timeout=timeout_seconds,
                )
                if len(arr) == 0 or tc == 0:
                    raise Exception("Fail")
                rerank_passed = True
                logging.debug(f"passed model rerank {llm['llm_name']}")
            except Exception as e:
                logging.exception(
                    "Fail to access rerank model for provider=%s model=%s",
                    provider_name,
                    llm["llm_name"],
                )
                msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)

        # One working model is enough; earlier failures are discarded.
        if any([embd_passed, chat_passed, rerank_passed]):
            msg = ""
            break

    success = any([embd_passed, chat_passed, rerank_passed])
    return success, "success" if success else msg
|
|
|
|
|
|
def show_provider_instance(tenant_id: str, provider_name: str, instance_name: str):
    """
    Return the details of one instance of a tenant's provider.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :param instance_name: instance name
    :return: (success, result_or_error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_name}'"

    provider_id = provider_obj.id
    instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_id, instance_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"

    # "extra" is a JSON string; an empty value means no extra fields.
    if instance_obj.extra:
        extra_fields = json.loads(instance_obj.extra)
    else:
        extra_fields = {}

    details = {
        "id": instance_obj.id,
        "instance_name": instance_obj.instance_name,
        "provider_id": provider_id,
        "region": extra_fields.get("region", ""),
        "status": instance_obj.status,
    }
    return True, details
|
|
|
|
|
|
def drop_provider_instances(tenant_id: str, provider_name: str, instance_names: list):
    """
    Drop the named instances of a tenant's provider, together with their models.

    Nothing is deleted unless every requested instance name exists.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :param instance_names: list of instance names to drop
    :return: (success, result_or_error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_name}'"
    provider_id = provider_obj.id

    # Resolve every name first so unknown names can be reported together.
    resolved = [
        (name, TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_id, name))
        for name in instance_names
    ]
    not_exist_instances = [name for name, found in resolved if not found]
    if not_exist_instances:
        return False, f"No instance found for provider '{provider_name}' and instance '{not_exist_instances}'"

    instance_ids = [found.id for _, found in resolved]
    delete_models_by_instance_ids(instance_ids)
    TenantModelInstanceService.delete_by_ids(instance_ids)
    return True, None
|
|
|
|
|
|
def list_instance_models(tenant_id: str, provider_name: str, instance_name: str, supported_only: bool = False):
    """
    List models for a provider instance.

    - With supported_only=True, only the model names declared for the provider in
      FACTORY_LLM_INFOS are returned (the instance is not looked up).
    - Otherwise every model from FACTORY_LLM_INFOS is listed. Its status is taken
      from the instance's tenant_model record when one exists and defaults to
      "active" when none does.
    - Models that only exist as tenant_model records are appended as well.
    - Active models come first, then inactive ones, each group sorted by name.
      Models with any other status are not returned.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :param instance_name: instance name
    :param supported_only: if True, only list supported models (from LLM dictionary)
    :return: (success, result_or_error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_name}'"

    if supported_only:
        # List all models supported by this provider from the LLM dictionary
        factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
        if not factory_info:
            return False, f"Provider '{provider_name}' not found"
        llms = factory_info[0].get("llm", [])
        models = [{"name": llm["llm_name"]} for llm in llms]
        models.sort(key=lambda x: x["name"])
        return True, models

    # Get instance
    instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"

    # Get model records for this instance from tenant_model table
    model_records = TenantModelService.get_models_by_instance_id(instance_obj.id)

    # Build a map of model_name -> status, types, extra. Status and extra come
    # from the first record of a name; later records only contribute their type.
    model_info_map: dict = {}
    for model_record in model_records:
        known = model_info_map.get(model_record.model_name)
        if known:
            known["model_type"].append(model_record.model_type)
        else:
            model_info_map[model_record.model_name] = {
                "status": model_record.status,
                "model_type": [model_record.model_type],
                "extra": model_record.extra
            }

    # List all models from the LLM dictionary for this provider
    factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factory_info:
        return False, f"Provider '{provider_name}' not found"

    llms = factory_info[0].get("llm", [])
    models = []
    for llm in llms:
        recorded = model_info_map.get(llm["llm_name"], {})
        # Merge the dictionary type with the recorded types, dropping duplicates
        # while keeping order (a record usually repeats the dictionary type).
        merged_types = list(dict.fromkeys([llm["model_type"], *recorded.get("model_type", [])]))
        models.append({
            "name": llm["llm_name"],
            "model_type": merged_types,
            "max_tokens": llm.get("max_tokens"),
            "status": recorded.get("status", "active"),
        })

    # Append models that exist only as tenant_model records.
    factory_models = {m["name"] for m in models}
    for model_name, model_info_dict in model_info_map.items():
        if model_name not in factory_models:
            extra_fields = json.loads(model_info_dict["extra"]) if model_info_dict["extra"] else {}
            models.append({
                "name": model_name,
                "model_type": model_info_dict["model_type"],
                "max_tokens": extra_fields.get("max_tokens", 8192),
                "status": model_info_dict["status"],
            })

    active_models = [model for model in models if model["status"] == ActiveStatusEnum.ACTIVE.value]
    inactive_models = [model for model in models if model["status"] == ActiveStatusEnum.INACTIVE.value]
    active_models.sort(key=lambda x: x["name"])
    inactive_models.sort(key=lambda x: x["name"])
    return True, active_models + inactive_models
|
|
|
|
|
|
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int=8192, extra: dict=None):
    """
    Attach a model to a provider instance, one tenant_model record per model type.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :param instance_name: instance name
    :param model_name: model name; must not already exist on the instance
    :param model_type: a single model type or a list of model types
    :param max_tokens: stored in the record's extra fields
    :param extra: additional extra fields; these override the computed ones
    :return: (success, result_or_error_message)
    """
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_name}'"

    instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"

    existing = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, model_name)
    if existing:
        return False, f"Model '{model_name}' already exists for provider '{provider_name}' and instance '{instance_name}'"

    factories = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factories:
        return False, f"Provider '{provider_name}' not found"
    declared_llms = factories[0].get("llm", [])

    requested_types = [model_type] if isinstance(model_type, str) else model_type

    for _type in requested_types:
        record_extra = {"max_tokens": max_tokens}
        # Copy the tool-support flag when the model is declared for this type.
        declared = [llm for llm in declared_llms if llm["model_type"] == _type and llm["llm_name"] == model_name]
        if declared:
            record_extra["is_tools"] = declared[0].get("is_tools", False)
        if extra:
            record_extra.update(extra)
        TenantModelService.insert(
            model_name=model_name,
            provider_id=provider_obj.id,
            instance_id=instance_obj.id,
            model_type=_type,
            extra=json.dumps(record_extra),
        )

    return True, "success"
|
|
|
|
|
|
def update_model_status(tenant_id: str, provider_name: str, instance_name: str, model_name: str, status: str):
    """
    Enable or disable a model for a provider instance.

    - Existing tenant_model records for the model get their status updated.
    - Without a record, "active" is a no-op (models are active by default) and
      "inactive" creates a record from the model's FACTORY_LLM_INFOS entry.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :param instance_name: instance name
    :param model_name: model name
    :param status: "active" or "inactive" (ActiveStatusEnum values)
    :return: (success, result_or_error_message)
    """
    active_value = ActiveStatusEnum.ACTIVE.value
    inactive_value = ActiveStatusEnum.INACTIVE.value
    if status not in (active_value, inactive_value):
        return False, f"status must be '{ActiveStatusEnum.ACTIVE.value}' or '{ActiveStatusEnum.INACTIVE.value}'"

    # The provider must be registered for this tenant.
    provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
    if not provider_obj:
        return False, f"No provider found for provider '{provider_name}'"

    # The instance must exist under that provider.
    instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name)
    if not instance_obj:
        return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"

    existing_records = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(
        provider_obj.id, instance_obj.id, model_name
    )
    if existing_records:
        # Records exist: update all of them in one batch.
        record_ids = [record.id for record in existing_records]
        TenantModelService.batch_update_model_status(record_ids, status)
        return True, None

    if status == active_value:
        # No record and the default is already active: nothing to store.
        return True, None

    # Disabling a model that has no record yet: create one from the factory entry.
    factories = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
    if not factories:
        return False, f"Provider '{provider_name}' not found"
    declared = [llm for llm in factories[0].get("llm", []) if llm["llm_name"] == model_name]
    if not declared:
        return False, f"provider {provider_name} model {model_name} not found"

    schema = declared[0]
    record_extra = {"max_tokens": schema.get("max_tokens", 8192), "is_tools": schema.get("is_tools", False)}
    TenantModelService.insert(
        id=get_uuid(),
        model_name=model_name,
        model_type=schema["model_type"],
        provider_id=provider_obj.id,
        instance_id=instance_obj.id,
        status=status,
        extra=json.dumps(record_extra),
    )
    return True, None
|
|
|
|
|
|
async def chat_to_model(tenant_id: str, provider_name: str, instance_name: str, model_name: str, message: str, stream: bool = False, thinking: bool = False):
    """
    Send a single user message to a chat model of a provider instance.

    :param tenant_id: tenant ID
    :param provider_name: provider/factory name
    :param instance_name: instance name
    :param model_name: model name
    :param message: chat message
    :param stream: when True, no request is made here; the LLM bundle and model
        config are returned so the caller can stream
    :param thinking: whether to enable thinking/reasoning (not read by this function)
    :return: (success, result_or_error_message)
    """
    from api.db.services.llm_service import LLMBundle

    # Models are addressed as "<model>@<instance>@<provider>".
    composite_name = "@".join((f"{model_name}", f"{instance_name}", f"{provider_name}"))
    try:
        model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT.value, composite_name)
    except LookupError:
        return False, f"Model '{composite_name}' not authorized"
    if not model_config:
        return False, f"Model '{composite_name}' not found"

    llm = LLMBundle(tenant_id, model_config)
    if stream:
        return True, {"type": "stream", "llm": llm, "model_config": model_config}

    # Non-streaming chat
    history = [{"role": "user", "content": message}]
    gen_conf = {"temperature": 0.9}
    try:
        answer = await llm.async_chat(None, history, gen_conf)
    except Exception as e:
        logging.exception(f"Chat to model failed: {e}")
        return False, str(e)
    return True, {"answer": answer, "reasoning_content": ""}
|