mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Fix: model provider (#15517)
### What problem does this PR solve? Fix: - Handle siliconflow and siliconflow_intl api_key ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -395,7 +395,7 @@ class Browser(ComponentBase, ABC):
|
||||
|
||||
chat_model_config = get_model_config_from_provider_instance(
|
||||
self._canvas.get_tenant_id(),
|
||||
get_model_type_by_name(self._param.llm_id),
|
||||
get_model_type_by_name(self._canvas.get_tenant_id(), self._param.llm_id),
|
||||
self._param.llm_id,
|
||||
)
|
||||
cfg = self._as_model_config_dict(chat_model_config)
|
||||
|
||||
@@ -335,10 +335,9 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
|
||||
api_key = data["api_key"]
|
||||
base_url = data.get("base_url", "")
|
||||
region = data.get("region", "")
|
||||
verify = data.get("verify", False)
|
||||
|
||||
try:
|
||||
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, verify)
|
||||
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region)
|
||||
if success:
|
||||
return get_result(message=msg)
|
||||
else:
|
||||
@@ -396,8 +395,10 @@ async def verify_provider_api_key(provider_name: str = None):
|
||||
|
||||
base_url = data.get("base_url", "")
|
||||
api_key = data["api_key"]
|
||||
region = data.get("region", "default")
|
||||
|
||||
try:
|
||||
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url)
|
||||
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region)
|
||||
if success:
|
||||
return get_result(message=msg)
|
||||
else:
|
||||
|
||||
@@ -45,6 +45,8 @@ def list_providers(tenant_id: str, all_available: bool = False):
|
||||
if all_available:
|
||||
providers = []
|
||||
for factory_info in FACTORY_LLM_INFOS:
|
||||
if factory_info["name"] in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"]:
|
||||
continue
|
||||
model_types = sorted(set(
|
||||
llm["model_type"]
|
||||
for llm in factory_info.get("llm", [])
|
||||
@@ -66,7 +68,7 @@ def list_providers(tenant_id: str, all_available: bool = False):
|
||||
providers = []
|
||||
factory_info_mapping = {f["name"]: f for f in FACTORY_LLM_INFOS}
|
||||
for name in factory_names:
|
||||
if factory_info_mapping.get(name):
|
||||
if name not in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"] and factory_info_mapping.get(name):
|
||||
factory_info = factory_info_mapping[name]
|
||||
model_types = sorted(set(
|
||||
llm["model_type"]
|
||||
@@ -207,7 +209,7 @@ def show_provider_model(provider_name: str, model_name: str):
|
||||
}
|
||||
|
||||
|
||||
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str, verify: bool):
|
||||
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str):
|
||||
"""
|
||||
Create a provider instance.
|
||||
|
||||
@@ -220,7 +222,6 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:param region: region
|
||||
:param verify: verify
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
@@ -242,10 +243,9 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key)
|
||||
if same_key_instance:
|
||||
return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}"
|
||||
if verify:
|
||||
success, msg = await verify_api_key(provider_name, api_key, base_url)
|
||||
if not success:
|
||||
return False, msg
|
||||
success, msg = await verify_api_key(provider_name, api_key, base_url, region)
|
||||
if not success:
|
||||
return False, msg
|
||||
|
||||
import json
|
||||
extra_fields = {}
|
||||
@@ -293,19 +293,25 @@ def list_provider_instances(tenant_id: str, provider_name: str):
|
||||
return True, active_instances + inactive_instances
|
||||
|
||||
|
||||
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None):
|
||||
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, region: str=None):
|
||||
"""
|
||||
Verify API key for a provider.
|
||||
|
||||
:param provider_name: provider/factory name
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:param region: region
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
return False, "Provider name is required"
|
||||
|
||||
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name]
|
||||
if region and region == "intl" and provider_name.lower() == "siliconflow":
|
||||
target_factory_name = "siliconflow_intl"
|
||||
else:
|
||||
target_factory_name = provider_name
|
||||
|
||||
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == target_factory_name]
|
||||
if not factory_info:
|
||||
return False, f"Provider '{provider_name}' not found"
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import os
|
||||
import enum
|
||||
import json
|
||||
from common import settings
|
||||
from common.constants import LLMType, ActiveStatusEnum
|
||||
from api.db.services.tenant_llm_service import TenantLLMService, TenantService
|
||||
@@ -102,7 +103,6 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
|
||||
raise LookupError(f"Instance {instance_name} not found for model {model_name}.")
|
||||
model_obj = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name(provider_obj.id, instance_obj.id, model_type_val, pure_model_name)
|
||||
|
||||
import json
|
||||
api_key, is_tool, api_key_payload = TenantLLMService._decode_api_key_config(instance_obj.api_key)
|
||||
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
|
||||
|
||||
@@ -123,7 +123,12 @@ def get_model_config_from_provider_instance(tenant_id, model_type: str|enum.Enum
|
||||
|
||||
return model_config
|
||||
else:
|
||||
fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == provider_name]
|
||||
region = extra_fields.get("region", "default")
|
||||
if region == "intl" and provider_name.lower() == "siliconflow":
|
||||
target_factory_name = "siliconflow_intl"
|
||||
else:
|
||||
target_factory_name = provider_name
|
||||
fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == target_factory_name]
|
||||
if not fac_list:
|
||||
raise LookupError(f"Model provider config not found: {provider_name}")
|
||||
llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name]
|
||||
@@ -167,7 +172,13 @@ def get_model_type_by_name(tenant_id: str, model_name: str):
|
||||
raise LookupError(f"Instance {instance_name} not found for model {model_name}.")
|
||||
model_objs = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, pure_model_name)
|
||||
if not model_objs:
|
||||
fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == provider_name]
|
||||
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
|
||||
region = extra_fields.get("region", "default")
|
||||
if region == "intl" and provider_name.lower() == "siliconflow":
|
||||
target_factory_name = "siliconflow_intl"
|
||||
else:
|
||||
target_factory_name = provider_name
|
||||
fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == target_factory_name]
|
||||
if not fac_list:
|
||||
raise LookupError(f"Model provider config not found: {provider_name}")
|
||||
llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name]
|
||||
|
||||
@@ -6,6 +6,7 @@ export interface IAddLlmRequestBody {
|
||||
api_key?: string | Record<string, any>;
|
||||
max_tokens: number;
|
||||
is_tools?: boolean;
|
||||
region?: string;
|
||||
}
|
||||
|
||||
export interface IDeleteLlmRequestBody {
|
||||
|
||||
@@ -105,10 +105,10 @@ export const useSubmitApiKey = () => {
|
||||
if (!isVerify) {
|
||||
setSaveLoading(true);
|
||||
}
|
||||
let apiKey: string | Record<string, any> = postBody.api_key || '';
|
||||
const apiKey: string = postBody.api_key || '';
|
||||
|
||||
let region: string | undefined;
|
||||
if (savingParams.llm_factory === LLMFactory.SILICONFLOW) {
|
||||
let sourceFid: string = LLMFactory.SILICONFLOW;
|
||||
const baseUrl = postBody.base_url;
|
||||
if (baseUrl) {
|
||||
try {
|
||||
@@ -118,13 +118,12 @@ export const useSubmitApiKey = () => {
|
||||
host === 'api.siliconflow.com' ||
|
||||
host.endsWith('.api.siliconflow.com')
|
||||
) {
|
||||
sourceFid = 'siliconflow_intl';
|
||||
region = 'intl';
|
||||
}
|
||||
} catch {
|
||||
// ignore invalid URL and keep default sourceFid
|
||||
// ignore invalid URL
|
||||
}
|
||||
}
|
||||
apiKey = { api_key: postBody.api_key, source_fid: sourceFid };
|
||||
}
|
||||
|
||||
const req: IAddProviderInstanceRequestBody = {
|
||||
@@ -136,6 +135,7 @@ export const useSubmitApiKey = () => {
|
||||
api_key: apiKey,
|
||||
api_base: postBody.base_url || '',
|
||||
max_tokens: 0,
|
||||
...(region ? { region } : {}),
|
||||
};
|
||||
|
||||
const ret = await submitProviderInstance(req, isVerify);
|
||||
|
||||
Reference in New Issue
Block a user