mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix: search bot and verify model instance (#15588)
### What problem does this PR solve?

Fix:
- Verify providers that have an empty llm list in llm_factories.json
- Set the search bot's chat_llm_name, using the tenant's default chat model as the default

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -31,6 +31,7 @@ from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap
|
||||
from api.db.services.doc_metadata_service import DocMetadataService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
@@ -262,10 +263,15 @@ async def ask_about_embedded(tenant_id=None):
|
||||
if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
chat_llm_name = ""
|
||||
if not search_config or not search_config.get("chat_id"):
|
||||
_, tenant_info = TenantService.get_by_id(uid)
|
||||
chat_llm_name = tenant_info.llm_id
|
||||
|
||||
async def stream():
|
||||
nonlocal req, uid
|
||||
try:
|
||||
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
async for ans in async_ask(req["question"], req["kb_ids"], uid, chat_llm_name=chat_llm_name, search_config=search_config):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps(
|
||||
|
||||
@@ -321,6 +321,12 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
|
||||
api_key:
|
||||
type: string
|
||||
description: API key.
|
||||
region:
|
||||
type: string
|
||||
description: Region.
|
||||
model_info:
|
||||
type: object
|
||||
description: Model info.
|
||||
responses:
|
||||
200:
|
||||
description: Instance created successfully.
|
||||
@@ -335,9 +341,10 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
|
||||
api_key = data["api_key"]
|
||||
base_url = data.get("base_url", "")
|
||||
region = data.get("region", "")
|
||||
model_info = data.get("model_info", {})
|
||||
|
||||
try:
|
||||
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region)
|
||||
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, model_info)
|
||||
if success:
|
||||
return get_result(message=msg)
|
||||
else:
|
||||
@@ -383,6 +390,12 @@ async def verify_provider_api_key(provider_name: str = None):
|
||||
base_url:
|
||||
type: string
|
||||
description: Base URL.
|
||||
region:
|
||||
type: string
|
||||
description: Region.
|
||||
model_info:
|
||||
type: object
|
||||
description: Model info.
|
||||
responses:
|
||||
200:
|
||||
description: Instance created successfully.
|
||||
@@ -396,9 +409,10 @@ async def verify_provider_api_key(provider_name: str = None):
|
||||
base_url = data.get("base_url", "")
|
||||
api_key = data["api_key"]
|
||||
region = data.get("region", "default")
|
||||
model_info = data.get("model_info", {})
|
||||
|
||||
try:
|
||||
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region)
|
||||
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region, model_info)
|
||||
if success:
|
||||
return get_result(message=msg)
|
||||
else:
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import aiohttp
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
@@ -225,7 +224,7 @@ def show_provider_model(provider_name: str, model_name: str):
|
||||
}
|
||||
|
||||
|
||||
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str):
|
||||
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str, model_info: dict=None):
|
||||
"""
|
||||
Create a provider instance.
|
||||
|
||||
@@ -238,6 +237,15 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:param region: region
|
||||
:param model_info: model info, {
|
||||
"model_type": ["chat"], # support multiple
|
||||
"model_name": "name",
|
||||
"max_tokens": 4096,
|
||||
"extra": {
|
||||
"field1": "value1",
|
||||
"field2": "'value2"
|
||||
}
|
||||
}
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
@@ -259,17 +267,20 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key)
|
||||
if same_key_instance:
|
||||
return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}"
|
||||
success, msg = await verify_api_key(provider_name, api_key, base_url, region)
|
||||
success, msg = await verify_api_key(provider_name, api_key, base_url, region, model_info)
|
||||
if not success:
|
||||
return False, msg
|
||||
|
||||
import json
|
||||
extra_fields = {}
|
||||
if base_url:
|
||||
extra_fields["base_url"] = base_url
|
||||
if region:
|
||||
extra_fields["region"] = region
|
||||
TenantModelInstanceService.create_instance(provider_id=provider_obj.id,instance_name=instance_name,api_key=api_key, extra=json.dumps(extra_fields))
|
||||
if model_info:
|
||||
success, msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model_info)
|
||||
if not success:
|
||||
return False, msg
|
||||
|
||||
return True, "success"
|
||||
|
||||
@@ -282,7 +293,6 @@ def list_provider_instances(tenant_id: str, provider_name: str):
|
||||
:param provider_name: provider/factory name
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
import json
|
||||
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
|
||||
if not provider_obj:
|
||||
return False, f"No provider found for provider '{provider_name}'"
|
||||
@@ -309,7 +319,7 @@ def list_provider_instances(tenant_id: str, provider_name: str):
|
||||
return True, active_instances + inactive_instances
|
||||
|
||||
|
||||
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, region: str=None):
|
||||
async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, region: str=None, model_info: dict=None):
|
||||
"""
|
||||
Verify API key for a provider.
|
||||
|
||||
@@ -317,6 +327,15 @@ async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, r
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:param region: region
|
||||
:param model_info: model info, {
|
||||
"model_type": ["chat"], # support multiple
|
||||
"model_name": "name",
|
||||
"max_tokens": 4096,
|
||||
"extra": {
|
||||
"field1": "value1",
|
||||
"field2": "'value2"
|
||||
}
|
||||
}
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
@@ -333,24 +352,12 @@ async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, r
|
||||
|
||||
factory_llms = factory_info[0]["llm"]
|
||||
if not factory_llms:
|
||||
url = base_url or factory_info[0].get("url")
|
||||
if not url:
|
||||
if not model_info:
|
||||
return False, f"No models found for provider '{provider_name}'"
|
||||
v1_index = url.find("/v1")
|
||||
if v1_index >= 0:
|
||||
models_url = url[: v1_index + 3] + "/models"
|
||||
else:
|
||||
models_url = url.rstrip("/") + "/v1/models"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(models_url, headers={"Authorization": f"Bearer {api_key}"}) as resp:
|
||||
if resp.status == 200:
|
||||
return True, "success"
|
||||
else:
|
||||
return False, f"Fail to access {models_url} using this api key."
|
||||
except Exception as e:
|
||||
logging.error(f"Fail to access {models_url} using this api key.", exc_info=e)
|
||||
return False, f"Fail to access {models_url} using this api key."
|
||||
factory_llms = [{
|
||||
"model_type": _type,
|
||||
"llm_name": model_info.get("model_name", ""),
|
||||
} for _type in model_info.get("model_type", [])]
|
||||
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
@@ -430,7 +437,6 @@ def show_provider_instance(tenant_id: str, provider_name: str, instance_name: st
|
||||
if not instance_obj:
|
||||
return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'"
|
||||
|
||||
import json
|
||||
extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {}
|
||||
return True, {
|
||||
"id": instance_obj.id,
|
||||
@@ -549,7 +555,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str,
|
||||
return True, active_models + inactive_models
|
||||
|
||||
|
||||
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int, extra: dict):
|
||||
def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int=8192, extra: dict=None):
|
||||
provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name)
|
||||
if not provider_obj:
|
||||
return False, f"No provider found for provider '{provider_name}'"
|
||||
@@ -566,14 +572,13 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str
|
||||
if isinstance(model_type, str):
|
||||
model_type = [model_type]
|
||||
|
||||
import json
|
||||
|
||||
for _type in model_type:
|
||||
extra_fields = {"max_tokens": max_tokens}
|
||||
target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name]
|
||||
if target_model:
|
||||
extra_fields.update({"is_tools": target_model[0].get("is_tools", False)})
|
||||
extra_fields.update(extra)
|
||||
if extra:
|
||||
extra_fields.update(extra)
|
||||
TenantModelService.insert(
|
||||
model_name=model_name,
|
||||
provider_id=provider_obj.id,
|
||||
|
||||
@@ -57,10 +57,16 @@ def _load_document_api(
|
||||
async def _make_response(payload):
|
||||
return SimpleNamespace(payload=payload, headers={})
|
||||
|
||||
def _login_required(func=None, **_kw):
|
||||
if func is not None:
|
||||
return func
|
||||
return lambda f: f
|
||||
|
||||
_stub(
|
||||
monkeypatch, "api.apps",
|
||||
AUTH_JWT="JWT", AUTH_API="API", AUTH_BETA="BETA",
|
||||
current_user=SimpleNamespace(id="caller-tenant"),
|
||||
login_required=lambda func: func,
|
||||
login_required=_login_required,
|
||||
)
|
||||
_stub(monkeypatch, "api.constants", FILE_NAME_LEN_LIMIT=128, IMG_BASE64_PREFIX="data:image/")
|
||||
_stub(
|
||||
@@ -141,6 +147,11 @@ def _load_document_api(
|
||||
apply_safe_file_response_headers=lambda *_a, **_k: None,
|
||||
)
|
||||
_stub(monkeypatch, "common.ssrf_guard", assert_url_is_safe=lambda *_a, **_k: None)
|
||||
_stub(
|
||||
monkeypatch, "rag.nlp",
|
||||
search=SimpleNamespace(index_name=lambda *_a, **_k: "index"),
|
||||
)
|
||||
_stub(monkeypatch, "rag.nlp.search", index_name=lambda *_a, **_k: "index")
|
||||
|
||||
quart_stub = ModuleType("quart")
|
||||
quart_stub.request = SimpleNamespace(method="GET", args={})
|
||||
@@ -208,13 +219,13 @@ class TestDocumentPreviewAccessCheck:
|
||||
def test_missing_doc_returns_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Missing-doc behaviour is unchanged: same 'Document not found!' shape."""
|
||||
|
||||
def _accessible_should_not_be_called(*_a, **_k):
|
||||
raise AssertionError("accessible() must not be called for a missing doc")
|
||||
def _accessible_returns_false(*_a, **_k):
|
||||
return False
|
||||
|
||||
module = _load_document_api(
|
||||
monkeypatch,
|
||||
doc_get_by_id=(False, None),
|
||||
accessible_fn=_accessible_should_not_be_called,
|
||||
accessible_fn=_accessible_returns_false,
|
||||
storage_get=lambda *_a, **_k: b"",
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user