Fix: search bot and verify model instance (#15588)

### What problem does this PR solve?

Fix:
- Verify providers whose llm list in llm_factories.json is empty by using
the model_info supplied in the request
- Set the search bot's chat_llm_name, falling back to the tenant's default
chat model

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Lynn
2026-06-04 11:59:55 +08:00
committed by GitHub
parent bbacb31226
commit 597ac1e900
4 changed files with 71 additions and 35 deletions

View File

@@ -31,6 +31,7 @@ from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
@@ -262,10 +263,15 @@ async def ask_about_embedded(tenant_id=None):
if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
search_config = search_app.get("search_config", {})
chat_llm_name = ""
if not search_config or not search_config.get("chat_id"):
_, tenant_info = TenantService.get_by_id(uid)
chat_llm_name = tenant_info.llm_id
async def stream():
nonlocal req, uid
try:
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
async for ans in async_ask(req["question"], req["kb_ids"], uid, chat_llm_name=chat_llm_name, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps(

View File

@@ -321,6 +321,12 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
api_key:
type: string
description: API key.
region:
type: string
description: Region.
model_info:
type: object
description: Model info.
responses:
200:
description: Instance created successfully.
@@ -335,9 +341,10 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
api_key = data["api_key"]
base_url = data.get("base_url", "")
region = data.get("region", "")
model_info = data.get("model_info", {})
try:
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region)
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, model_info)
if success:
return get_result(message=msg)
else:
@@ -383,6 +390,12 @@ async def verify_provider_api_key(provider_name: str = None):
base_url:
type: string
description: Base URL.
region:
type: string
description: Region.
model_info:
type: object
description: Model info.
responses:
200:
description: API key verified successfully.
@@ -396,9 +409,10 @@ async def verify_provider_api_key(provider_name: str = None):
base_url = data.get("base_url", "")
api_key = data["api_key"]
region = data.get("region", "default")
model_info = data.get("model_info", {})
try:
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region)
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region, model_info)
if success:
return get_result(message=msg)
else:

View File

@@ -14,7 +14,6 @@
# limitations under the License.
#
import os
import aiohttp
import json
import logging
import asyncio
@@ -225,7 +224,7 @@ def show_provider_model(provider_name: str, model_name: str):
}
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str):
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str, model_info: dict=None):
"""
Create a provider instance.
@@ -238,6 +237,15 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
:param api_key: API key
:param base_url: base url
:param region: region
:param model_info: model info, {
"model_type": ["chat"], # support multiple
"model_name": "name",
"max_tokens": 4096,
"extra": {
"field1": "value1",
"field2": "value2"
}
}
:return: (success, result_or_error_message)
"""
if not provider_name:
@@ -259,17 +267,20 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key)
if same_key_instance:
return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}"
success, msg = await verify_api_key(provider_name, api_key, base_url, region)
success, msg = await verify_api_key(provider_name, api_key, base_url, region, model_info)
if not success:
return False, msg
import json
extra_fields = {}
if base_url:
extra_fields["base_url"] = base_url
if region:
extra_fields["region"] = region
TenantModelInstanceService.create_instance(provider_id=provider_obj.id,instance_name=instance_name,api_key=api_key, extra=json.dumps(extra_fields))
if model_info:
success, msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model_info)
if not success:
return False, msg
return True, "success"
@@ -282,7 +293,6 @@ def list_provider_instances(tenant_id: str, provider_name: str):
:param provider_name: provider/factory name
:return: (success, result_or_error_message)
"""
import json
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
if not provider_obj:
return False, f"No provider found for provider '{provider_name}'"
@@ -309,7 +319,7 @@ def list_provider_instances(tenant_id: str, provider_name: str):
return True, active_instances + inactive_instances
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, region: str=None):
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, region: str=None, model_info: dict=None):
"""
Verify API key for a provider.
@@ -317,6 +327,15 @@ async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, r
:param api_key: API key
:param base_url: base url
:param region: region
:param model_info: model info, {
"model_type": ["chat"], # support multiple
"model_name": "name",
"max_tokens": 4096,
"extra": {
"field1": "value1",
"field2": "value2"
}
}
:return: (success, result_or_error_message)
"""
if not provider_name:
@@ -333,24 +352,12 @@ async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, r
factory_llms = factory_info[0]["llm"]
if not factory_llms:
url = base_url or factory_info[0].get("url")
if not url:
if not model_info:
return False, f"No models found for provider '{provider_name}'"
v1_index = url.find("/v1")
if v1_index >= 0:
models_url = url[: v1_index + 3] + "/models"
else:
models_url = url.rstrip("/") + "/v1/models"
try:
async with aiohttp.ClientSession() as session:
async with session.get(models_url, headers={"Authorization": f"Bearer {api_key}"}) as resp:
if resp.status == 200:
return True, "success"
else:
return False, f"Fail to access {models_url} using this api key."
except Exception as e:
logging.error(f"Fail to access {models_url} using this api key.", exc_info=e)
return False, f"Fail to access {models_url} using this api key."
factory_llms = [{
"model_type": _type,
"llm_name": model_info.get("model_name", ""),
} for _type in model_info.get("model_type", [])]
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False
@@ -430,7 +437,6 @@ def show_provider_instance(tenant_id: str, provider_name: str, instance_name: st
if not instance_obj:
return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"
import json
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
return True, {
"id": instance_obj.id,
@@ -549,7 +555,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
return True, active_models + inactive_models
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int, extra: dict):
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int=8192, extra: dict=None):
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
if not provider_obj:
return False, f"No provider found for provider '{provider_name}'"
@@ -566,14 +572,13 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str
if isinstance(model_type, str):
model_type = [model_type]
import json
for _type in model_type:
extra_fields = {"max_tokens": max_tokens}
target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name]
if target_model:
extra_fields.update({"is_tools": target_model[0].get("is_tools", False)})
extra_fields.update(extra)
if extra:
extra_fields.update(extra)
TenantModelService.insert(
model_name=model_name,
provider_id=provider_obj.id,

View File

@@ -57,10 +57,16 @@ def _load_document_api(
async def _make_response(payload):
return SimpleNamespace(payload=payload, headers={})
def _login_required(func=None, **_kw):
if func is not None:
return func
return lambda f: f
_stub(
monkeypatch, "api.apps",
AUTH_JWT="JWT", AUTH_API="API", AUTH_BETA="BETA",
current_user=SimpleNamespace(id="caller-tenant"),
login_required=lambda func: func,
login_required=_login_required,
)
_stub(monkeypatch, "api.constants", FILE_NAME_LEN_LIMIT=128, IMG_BASE64_PREFIX="data:image/")
_stub(
@@ -141,6 +147,11 @@ def _load_document_api(
apply_safe_file_response_headers=lambda *_a, **_k: None,
)
_stub(monkeypatch, "common.ssrf_guard", assert_url_is_safe=lambda *_a, **_k: None)
_stub(
monkeypatch, "rag.nlp",
search=SimpleNamespace(index_name=lambda *_a, **_k: "index"),
)
_stub(monkeypatch, "rag.nlp.search", index_name=lambda *_a, **_k: "index")
quart_stub = ModuleType("quart")
quart_stub.request = SimpleNamespace(method="GET", args={})
@@ -208,13 +219,13 @@ class TestDocumentPreviewAccessCheck:
def test_missing_doc_returns_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Missing-doc behaviour is unchanged: same 'Document not found!' shape."""
def _accessible_should_not_be_called(*_a, **_k):
raise AssertionError("accessible() must not be called for a missing doc")
def _accessible_returns_false(*_a, **_k):
return False
module = _load_document_api(
monkeypatch,
doc_get_by_id=(False, None),
accessible_fn=_accessible_should_not_be_called,
accessible_fn=_accessible_returns_false,
storage_get=lambda *_a, **_k: b"",
)