diff --git a/api/apps/restful_apis/bot_api.py b/api/apps/restful_apis/bot_api.py index bbb79e7d54..19e894ddb8 100644 --- a/api/apps/restful_apis/bot_api.py +++ b/api/apps/restful_apis/bot_api.py @@ -31,6 +31,7 @@ from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from api.db.services.user_service import TenantService from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService @@ -262,10 +263,15 @@ async def ask_about_embedded(tenant_id=None): if search_app := await thread_pool_exec(SearchService.get_detail, search_id): search_config = search_app.get("search_config", {}) + chat_llm_name = "" + if not search_config or not search_config.get("chat_id"): + _, tenant_info = TenantService.get_by_id(uid) + chat_llm_name = tenant_info.llm_id + async def stream(): nonlocal req, uid try: - async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, chat_llm_name=chat_llm_name, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( diff --git a/api/apps/restful_apis/provider_api.py b/api/apps/restful_apis/provider_api.py index 1292677154..92cc14b243 100644 --- a/api/apps/restful_apis/provider_api.py +++ b/api/apps/restful_apis/provider_api.py @@ -321,6 +321,12 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N api_key: type: string description: API key. + region: + type: string + description: Region. + model_info: + type: object + description: Model info. responses: 200: description: Instance created successfully. @@ -335,9 +341,10 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N api_key = data["api_key"] base_url = data.get("base_url", "") region = data.get("region", "") + model_info = data.get("model_info", {}) try: - success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region) + success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, model_info) if success: return get_result(message=msg) else: @@ -383,6 +390,12 @@ async def verify_provider_api_key(provider_name: str = None): base_url: type: string description: Base URL. + region: + type: string + description: Region. + model_info: + type: object + description: Model info. responses: 200: description: Instance created successfully. @@ -396,9 +409,10 @@ async def verify_provider_api_key(provider_name: str = None): base_url = data.get("base_url", "") api_key = data["api_key"] region = data.get("region", "default") + model_info = data.get("model_info", {}) try: - success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region) + success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region, model_info) if success: return get_result(message=msg) else: diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py index 5e795266fd..b1b0df2da4 100644 --- a/api/apps/services/provider_api_service.py +++ b/api/apps/services/provider_api_service.py @@ -14,7 +14,6 @@ # limitations under the License. # import os -import aiohttp import json import logging import asyncio @@ -225,7 +224,7 @@ def show_provider_model(provider_name: str, model_name: str): } -async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str): +async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str, model_info: dict=None): """ Create a provider instance. @@ -238,6 +237,15 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_ :param api_key: API key :param base_url: base url :param region: region + :param model_info: model info, { + "model_type": ["chat"], # support multiple + "model_name": "name", + "max_tokens": 4096, + "extra": { + "field1": "value1", + "field2": "'value2" + } + } :return: (success, result_or_error_message) """ if not provider_name: @@ -259,17 +267,20 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_ same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key) if same_key_instance: return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}" - success, msg = await verify_api_key(provider_name, api_key, base_url, region) + success, msg = await verify_api_key(provider_name, api_key, base_url, region, model_info) if not success: return False, msg - import json extra_fields = {} if base_url: extra_fields["base_url"] = base_url if region: extra_fields["region"] = region TenantModelInstanceService.create_instance(provider_id=provider_obj.id,instance_name=instance_name,api_key=api_key, extra=json.dumps(extra_fields)) + if model_info: + success, msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model_info) + if not success: + return False, msg return True, "success" @@ -282,7 +293,6 @@ def list_provider_instances(tenant_id: str, provider_name: str): :param provider_name: provider/factory name :return: (success, result_or_error_message) """ - import json provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) if not provider_obj: return False, f"No provider found for provider '{provider_name}'" @@ -309,7 +319,7 @@ def list_provider_instances(tenant_id: str, provider_name: str): return True, active_instances + inactive_instances -async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, region: str=None): +async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, region: str=None, model_info: dict=None): """ Verify API key for a provider. @@ -317,6 +327,15 @@ async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, r :param api_key: API key :param base_url: base url :param region: region + :param model_info: model info, { + "model_type": ["chat"], # support multiple + "model_name": "name", + "max_tokens": 4096, + "extra": { + "field1": "value1", + "field2": "'value2" + } + } :return: (success, result_or_error_message) """ if not provider_name: @@ -333,24 +352,12 @@ async def verify_api_key(provider_name: str, api_key: str, base_url: str=None, r factory_llms = factory_info[0]["llm"] if not factory_llms: - url = base_url or factory_info[0].get("url") - if not url: + if not model_info: return False, f"No models found for provider '{provider_name}'" - v1_index = url.find("/v1") - if v1_index >= 0: - models_url = url[: v1_index + 3] + "/models" - else: - models_url = url.rstrip("/") + "/v1/models" - try: - async with aiohttp.ClientSession() as session: - async with session.get(models_url, headers={"Authorization": f"Bearer {api_key}"}) as resp: - if resp.status == 200: - return True, "success" - else: - return False, f"Fail to access {models_url} using this api key." - except Exception as e: - logging.error(f"Fail to access {models_url} using this api key.", exc_info=e) - return False, f"Fail to access {models_url} using this api key." + factory_llms = [{ + "model_type": _type, + "llm_name": model_info.get("model_name", ""), + } for _type in model_info.get("model_type", [])] # test if api key works chat_passed, embd_passed, rerank_passed = False, False, False @@ -430,7 +437,6 @@ def show_provider_instance(tenant_id: str, provider_name: str, instance_name: st if not instance_obj: return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'" - import json extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {} return True, { "id": instance_obj.id, @@ -549,7 +555,7 @@ def list_instance_models(tenant_id: str, provider_name: str, instance_name: str, return True, active_models + inactive_models -def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int, extra: dict): +def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int=8192, extra: dict=None): provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) if not provider_obj: return False, f"No provider found for provider '{provider_name}'" @@ -566,14 +572,13 @@ def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str if isinstance(model_type, str): model_type = [model_type] - import json - for _type in model_type: extra_fields = {"max_tokens": max_tokens} target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name] if target_model: extra_fields.update({"is_tools": target_model[0].get("is_tools", False)}) - extra_fields.update(extra) + if extra: + extra_fields.update(extra) TenantModelService.insert( model_name=model_name, provider_id=provider_obj.id, diff --git a/test/unit_test/api/apps/restful_apis/test_document_preview_accessible.py b/test/unit_test/api/apps/restful_apis/test_document_preview_accessible.py index 590e9eed3a..f29d3610c7 100644 --- a/test/unit_test/api/apps/restful_apis/test_document_preview_accessible.py +++ b/test/unit_test/api/apps/restful_apis/test_document_preview_accessible.py @@ -57,10 +57,16 @@ def _load_document_api( async def _make_response(payload): return SimpleNamespace(payload=payload, headers={}) + def _login_required(func=None, **_kw): + if func is not None: + return func + return lambda f: f + _stub( monkeypatch, "api.apps", + AUTH_JWT="JWT", AUTH_API="API", AUTH_BETA="BETA", current_user=SimpleNamespace(id="caller-tenant"), - login_required=lambda func: func, + login_required=_login_required, ) _stub(monkeypatch, "api.constants", FILE_NAME_LEN_LIMIT=128, IMG_BASE64_PREFIX="data:image/") _stub( @@ -141,6 +147,11 @@ def _load_document_api( apply_safe_file_response_headers=lambda *_a, **_k: None, ) _stub(monkeypatch, "common.ssrf_guard", assert_url_is_safe=lambda *_a, **_k: None) + _stub( + monkeypatch, "rag.nlp", + search=SimpleNamespace(index_name=lambda *_a, **_k: "index"), + ) + _stub(monkeypatch, "rag.nlp.search", index_name=lambda *_a, **_k: "index") quart_stub = ModuleType("quart") quart_stub.request = SimpleNamespace(method="GET", args={}) @@ -208,13 +219,13 @@ class TestDocumentPreviewAccessCheck: def test_missing_doc_returns_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: """Missing-doc behaviour is unchanged: same 'Document not found!' shape.""" - def _accessible_should_not_be_called(*_a, **_k): - raise AssertionError("accessible() must not be called for a missing doc") + def _accessible_returns_false(*_a, **_k): + return False module = _load_document_api( monkeypatch, doc_get_by_id=(False, None), - accessible_fn=_accessible_should_not_be_called, + accessible_fn=_accessible_returns_false, storage_get=lambda *_a, **_k: b"", )