diff --git a/Dockerfile b/Dockerfile index 58d25d3ca0..2929c2dc26 100644 --- a/Dockerfile +++ b/Dockerfile @@ -189,6 +189,7 @@ COPY mcp mcp COPY common common COPY memory memory COPY bin bin +COPY tools/scripts tools/scripts COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template COPY docker/entrypoint.sh ./ diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 83c3e27e53..57dbaeaa65 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -27,10 +27,9 @@ import json_repair from agent.component.llm import LLM, LLMParam from agent.tools.base import LLMToolPluginCallSession, ToolBase, ToolMeta, ToolParamBase -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name from api.db.services.llm_service import LLMBundle from api.db.services.mcp_server_service import MCPServerService -from api.db.services.tenant_llm_service import TenantLLMService from common.connection_utils import timeout from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, mcp_tool_metadata_to_openai_tool from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt @@ -81,7 +80,9 @@ class Agent(LLM, ToolBase): original_name = cpn.get_meta()["function"]["name"] indexed_name = f"{original_name}_{idx}" self.tools[indexed_name] = cpn - chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id) + model_types = get_model_type_by_name(self._canvas.get_tenant_id(), self._param.llm_id) + model_type = "chat" if "chat" in model_types else model_types[0] + chat_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), model_type, self._param.llm_id) 
self.chat_mdl = LLMBundle( self._canvas.get_tenant_id(), chat_model_config, diff --git a/agent/component/browser.py b/agent/component/browser.py index 132cac283c..7734722a43 100644 --- a/agent/component/browser.py +++ b/agent/component/browser.py @@ -33,10 +33,9 @@ from urllib.request import Request, urlopen from agent.component.base import ComponentBase from agent.component.llm import LLMParam from api.db import FileType -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name from api.db.services import duplicate_name from api.db.services.file_service import FileService -from api.db.services.tenant_llm_service import TenantLLMService from api.utils.file_utils import filename_type from common import settings from common.connection_utils import timeout @@ -394,9 +393,9 @@ class Browser(ComponentBase, ABC): def _build_browser_llm(self): from browser_use.llm import ChatBrowserUse, ChatOpenAI - chat_model_config = get_model_config_by_type_and_name( + chat_model_config = get_model_config_from_provider_instance( self._canvas.get_tenant_id(), - TenantLLMService.llm_id2llm_type(self._param.llm_id), + get_model_type_by_name(self._param.llm_id), self._param.llm_id, ) cfg = self._as_model_config_dict(chat_model_config) diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 708ce142fe..3bc48e22db 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -21,7 +21,7 @@ from abc import ABC from common.constants import LLMType from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance from agent.component.llm import LLMParam, LLM from common.connection_utils import timeout from rag.llm.chat_model import 
ERROR_PREFIX @@ -123,7 +123,7 @@ class Categorize(LLM, ABC): msg[-1]["content"] = query_value self.set_input_value(query_key, msg[-1]["content"]) self._param.update_prompt() - chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) + chat_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config) user_prompt = """ diff --git a/agent/component/llm.py b/agent/component/llm.py index 126a5a0e8d..b4b2ee038b 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -24,8 +24,7 @@ import json_repair from functools import partial from common.constants import LLMType from api.db.services.llm_service import LLMBundle -from api.db.services.tenant_llm_service import TenantLLMService -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name from agent.component.base import ComponentBase, ComponentParamBase from common.connection_utils import timeout from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt, structured_output_prompt @@ -85,7 +84,9 @@ class LLM(ComponentBase): def __init__(self, canvas, component_id, param: ComponentParamBase): super().__init__(canvas, component_id, param) - chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id) + model_types = get_model_type_by_name(self._canvas.get_tenant_id(), self._param.llm_id) + model_type = "chat" if "chat" in model_types else model_types[0] + chat_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), model_type, self._param.llm_id) self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config, 
max_retries=self._param.max_retries, retry_interval=self._param.delay_after_error) @@ -247,9 +248,11 @@ class LLM(ComponentBase): self.set_input_value(k, args[k]) self.imgs = self._uniq_images(self.imgs + extracted_imgs) - if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value: - self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value, - self._param.llm_id, max_retries=self._param.max_retries, + model_types = get_model_type_by_name(self._canvas.get_tenant_id(), self._param.llm_id) + model_type = LLMType.CHAT.value if LLMType.CHAT.value in model_types else model_types[0] + model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), model_type, self._param.llm_id) + if self.imgs: + self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), model_config, max_retries=self._param.max_retries, retry_interval=self._param.delay_after_error ) diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 02cb3e2ce6..0d31490b52 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -27,7 +27,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.memory_service import MemoryService from api.db.joint_services import memory_message_service -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common import settings from common.connection_utils import timeout from rag.app.tag import label_question @@ -121,12 +121,12 @@ class Retrieval(ToolBase, ABC): embd_mdl = None if embd_nms: tenant_id = self._canvas.get_tenant_id() - embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_nms[0]) + embd_model_config = 
get_model_config_from_provider_instance(tenant_id, LLMType.EMBEDDING, embd_nms[0]) embd_mdl = LLMBundle(tenant_id, embd_model_config) rerank_mdl = None if self._param.rerank_id: - rerank_model_config = get_model_config_by_type_and_name(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) + rerank_model_config = get_model_config_from_provider_instance(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) rerank_mdl = LLMBundle(kbs[0].tenant_id, rerank_model_config) vars = self.get_input_elements_from_text(query_text) diff --git a/api/apps/restful_apis/bot_api.py b/api/apps/restful_apis/bot_api.py index 9e96a06931..c2d77a1060 100644 --- a/api/apps/restful_apis/bot_api.py +++ b/api/apps/restful_apis/bot_api.py @@ -35,8 +35,7 @@ from api.db.services.llm_service import LLMBundle from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService -from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \ - get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.misc_utils import get_uuid, thread_pool_exec from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \ get_result, get_request_json, server_error_response, token_required, validate_request @@ -435,7 +434,6 @@ async def retrieval_test_embedded(): return get_error_data_result("`top_k` must be greater than 0") langs = req.get("cross_languages", []) rerank_id = req.get("rerank_id", "") - tenant_rerank_id = req.get("tenant_rerank_id", "") tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") @@ -458,7 +456,7 @@ async def retrieval_test_embedded(): if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", 
"") if chat_id: - chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id) + chat_model_config = await thread_pool_exec(get_model_config_from_provider_instance, tenant_id, LLMType.CHAT, chat_id) else: chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) @@ -504,24 +502,12 @@ async def retrieval_test_embedded(): if langs: _question = await cross_languages(kb.tenant_id, None, _question, langs) - if kb.tenant_embd_id: - embd_model_config = await thread_pool_exec(get_model_config_by_id, kb.tenant_embd_id) - else: - embd_model_config = await thread_pool_exec(get_model_config_by_type_and_name, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = await thread_pool_exec(get_model_config_from_provider_instance, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) rerank_mdl = None - if tenant_rerank_id: - allowed_rerank_tenant_ids = {tenant_id, *tenant_ids} - rerank_model_config = await thread_pool_exec( - get_model_config_by_id, - tenant_rerank_id, - allowed_rerank_tenant_ids, - tenant_id, - ) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - elif rerank_id: - rerank_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.RERANK, rerank_id) + if rerank_id: + rerank_model_config = await thread_pool_exec(get_model_config_from_provider_instance, tenant_id, LLMType.RERANK, rerank_id) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) if req.get("keyword", False): @@ -586,7 +572,7 @@ async def related_questions_embedded(): chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id) + chat_model_config = await thread_pool_exec(get_model_config_from_provider_instance, tenant_id, LLMType.CHAT, chat_id) else: 
chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index a46b93d423..ed1805be36 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -26,8 +26,7 @@ from quart import Response, request from api.apps import current_user, login_required from api.db.joint_services.tenant_model_service import ( - get_model_config_by_type_and_name, - get_tenant_default_model_by_type, + get_tenant_default_model_by_type, get_model_config_from_provider_instance, get_api_key, split_model_name ) from api.db.services.chunk_feedback_service import ChunkFeedbackService from api.db.services.conversation_service import ConversationService, structure_answer @@ -35,7 +34,6 @@ from api.db.services.dialog_service import DialogService, async_chat, gen_mindma from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService -from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import ( check_duplicate_ids, @@ -46,7 +44,6 @@ from api.utils.api_utils import ( validate_request, ) from api.utils.pagination_utils import validate_rest_api_page_size -from api.utils.tenant_utils import ensure_tenant_model_id_for_params from common.constants import LLMType, RetCode, StatusEnum from common import settings from common.misc_utils import get_uuid, thread_pool_exec @@ -241,37 +238,44 @@ async def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None - llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id) - model_type = (llm_setting or {}).get("model_type") - if model_type not in {"chat", "image2text"}: + conf_model_type = 
(llm_setting or {}).get("model_type") + if isinstance(conf_model_type, str): + model_type = conf_model_type if conf_model_type in {"chat", "image2text"} else "chat" + elif isinstance(conf_model_type, list): + model_type = "image2text" if "image2text" in conf_model_type else "chat" + else: model_type = "chat" - - if not await thread_pool_exec( - TenantLLMService.query, - tenant_id=tenant_id, - llm_name=llm_name, - llm_factory=llm_factory, - model_type=model_type, - ): + try: + await thread_pool_exec( + get_model_config_from_provider_instance, + tenant_id=tenant_id, + model_name=llm_id, + model_type=model_type, + ) + except Exception as e: + logging.error(f"Fail to get model config for {llm_id}: {e}") return f"`llm_id` {llm_id} doesn't exist" - return None + return None async def _validate_rerank_id(rerank_id, tenant_id): if not rerank_id: return None - llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(rerank_id) + parts = rerank_id.split('@') + llm_name = parts[0] if llm_name in _DEFAULT_RERANK_MODELS: return None - if await thread_pool_exec( - TenantLLMService.query, - tenant_id=tenant_id, - llm_name=llm_name, - llm_factory=llm_factory, - model_type="rerank", - ): - return None - return f"`rerank_id` {rerank_id} doesn't exist" + try: + await thread_pool_exec( + get_model_config_from_provider_instance, + tenant_id=tenant_id, + model_name=rerank_id, + model_type="rerank", + ) + except Exception as e: + logging.error(f"Fail to get model config for {rerank_id}: {e}") + return f"`rerank_id` {rerank_id} doesn't exist" + return None # def _validate_prompt_config(prompt_config): @@ -302,7 +306,7 @@ async def _validate_dataset_ids(dataset_ids, tenant_id): return f"The dataset {dataset_id} doesn't own parsed file" kbs.append(kb) - embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] + embd_ids = [split_model_name(kb.embd_id)[0] for kb in kbs] if len(set(embd_ids)) > 1: return f'Datasets use different embedding models: 
{[kb.embd_id for kb in kbs]}' @@ -380,7 +384,6 @@ async def create(): # if err: # return get_data_error_result(message=err) - req = ensure_tenant_model_id_for_params(current_user.id, req) req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} for field in _READONLY_FIELDS: req.pop(field, None) @@ -540,8 +543,6 @@ async def update_chat(chat_id): # kb_ids = req.get("kb_ids", current_chat.get("kb_ids", [])) # if not kb_ids and not prompt_config.get("tavily_api_key") and _has_knowledge_placeholder(prompt_config): # return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") - - req = ensure_tenant_model_id_for_params(current_user.id, req) req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} for field in _READONLY_FIELDS: req.pop(field, None) @@ -632,7 +633,6 @@ async def patch_chat(chat_id): # if not kb_ids and not prompt_config.get("tavily_api_key") and _has_knowledge_placeholder(prompt_config): # return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") - req = ensure_tenant_model_id_for_params(current_user.id, req) req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} for field in _READONLY_FIELDS: req.pop(field, None) @@ -1101,7 +1101,7 @@ async def recommendation(): chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(current_user.id, LLMType.CHAT, chat_id) + chat_model_config = get_model_config_from_provider_instance(current_user.id, LLMType.CHAT, chat_id) else: chat_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.CHAT) chat_mdl = LLMBundle(current_user.id, chat_model_config) @@ -1199,10 +1199,16 @@ async def session_completion(chat_id_in_arg=""): conv.reference.append({"chunks": [], "doc_aggs": []}) if chat_model_id: - if not await 
thread_pool_exec(TenantLLMService.get_api_key, tenant_id=dia.tenant_id, model_name=chat_model_id): + if not await thread_pool_exec(get_api_key, tenant_id=dia.tenant_id, model_name=chat_model_id): return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.") dia.llm_id = chat_model_id dia.llm_setting = chat_model_config + else: + logging.info("empty chat_model_id in req, use default chat model.") + _, tenant_info = TenantService.get_by_id(dia.tenant_id) + if not tenant_info or not tenant_info.llm_id: + raise LookupError("No default chat model for tenant.") + dia.llm_id = tenant_info.llm_id stream_mode = req.pop("stream", True) diff --git a/api/apps/restful_apis/chunk_api.py b/api/apps/restful_apis/chunk_api.py index 8bb677372a..8924d62ed6 100644 --- a/api/apps/restful_apis/chunk_api.py +++ b/api/apps/restful_apis/chunk_api.py @@ -24,8 +24,7 @@ from quart import request from api.apps import login_required from api.db.joint_services.tenant_model_service import ( - get_model_config_by_id, - get_model_config_by_type_and_name, + get_model_config_from_provider_instance, get_tenant_default_model_by_type, ) from api.db.db_models import Document, Task @@ -291,16 +290,12 @@ async def retrieval_test(tenant_id): e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: return get_error_data_result(message="Dataset not found!") - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) if kb.tenant_embd_id else get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = get_model_config_from_provider_instance(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) rerank_mdl = None - if req.get("tenant_rerank_id"): - allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"], allowed_tenant_ids=allowed_rerank_tenant_ids, requester_tenant_id=tenant_id) - rerank_mdl = 
LLMBundle(kb.tenant_id, rerank_model_config) - elif req.get("rerank_id"): - rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) + if req.get("rerank_id"): + rerank_model_config = get_model_config_from_provider_instance(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) if langs: @@ -517,12 +512,8 @@ async def add_chunk(tenant_id, dataset_id, document_id): d["img_id"] = f"{dataset_id}-{chunk_id}" d["doc_type_kwd"] = "image" - tenant_embd_id = DocumentService.get_tenant_embd_id(document_id) - if tenant_embd_id: - model_config = get_model_config_by_id(tenant_embd_id) - else: - embd_id = DocumentService.get_embd_id(document_id) - model_config = get_model_config_by_type_and_name(dataset_tenant_id, LLMType.EMBEDDING.value, embd_id) + embd_id = DocumentService.get_embd_id(document_id) + model_config = get_model_config_from_provider_instance(dataset_tenant_id, LLMType.EMBEDDING.value, embd_id) embd_mdl = TenantLLMService.model_instance(model_config) v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] @@ -661,12 +652,8 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): d["img_id"] = f"{dataset_id}-{chunk_id}" d["doc_type_kwd"] = "image" - tenant_embd_id = DocumentService.get_tenant_embd_id(document_id) - if tenant_embd_id: - model_config = get_model_config_by_id(tenant_embd_id) - else: - embd_id = DocumentService.get_embd_id(document_id) - model_config = get_model_config_by_type_and_name(dataset_tenant_id, LLMType.EMBEDDING.value, embd_id) + embd_id = DocumentService.get_embd_id(document_id) + model_config = get_model_config_from_provider_instance(dataset_tenant_id, LLMType.EMBEDDING.value, embd_id) embd_mdl = TenantLLMService.model_instance(model_config) if doc.parser_id == ParserType.QA: arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) 
> 1] diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py index f1c5797c34..70fbcc0777 100644 --- a/api/apps/restful_apis/dataset_api.py +++ b/api/apps/restful_apis/dataset_api.py @@ -149,6 +149,8 @@ async def create(tenant_id: str = None): return get_result(data=result) else: return get_error_data_result(message=result) + except LookupError as e: + return get_error_argument_result(str(e)) except ValueError as e: return get_error_argument_result(str(e)) except Exception as e: diff --git a/api/apps/restful_apis/dify_retrieval_api.py b/api/apps/restful_apis/dify_retrieval_api.py index ffe9f247f9..a2b19dea9d 100644 --- a/api/apps/restful_apis/dify_retrieval_api.py +++ b/api/apps/restful_apis/dify_retrieval_api.py @@ -27,7 +27,7 @@ from api.db.services.document_service import DocumentService from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.metadata_utils import meta_filter, convert_conditions from api.utils.api_utils import apikey_required, build_error_result, get_request_json, get_json_result from rag.app.tag import label_question @@ -260,10 +260,7 @@ async def retrieval(tenant_id): kb_id, ) return build_error_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - if kb.tenant_embd_id: - model_config = get_model_config_by_id(kb.tenant_embd_id) - else: - model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + model_config = get_model_config_from_provider_instance(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, model_config) 
if metadata_condition: doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 53319bb6dd..e080f82c45 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -18,12 +18,11 @@ import os import time from quart import request, g -from common.constants import LLMType, RetCode +from common.constants import RetCode from common.exceptions import ArgumentException, NotFoundException from api.apps import login_required, current_user from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result from api.apps.services import memory_api_service -from api.utils.tenant_utils import ensure_tenant_model_id_for_params from api.utils.pagination_utils import validate_rest_api_page_size @@ -36,18 +35,11 @@ async def create_memory(): req = await get_request_json() t_parsed = time.perf_counter() if timing_enabled else None try: - req = ensure_tenant_model_id_for_params(current_user.id, req) - if not req.get("tenant_llm_id"): - raise ArgumentException( - f"Tenant Model with name {req['llm_id']} and type {LLMType.CHAT.value} not found" - ) memory_info = { "name": req["name"], "memory_type": req["memory_type"], "embd_id": req["embd_id"], - "llm_id": req["llm_id"], - "tenant_embd_id": req["tenant_embd_id"], - "tenant_llm_id": req["tenant_llm_id"], + "llm_id": req["llm_id"] } success, res = await memory_api_service.create_memory(memory_info) if timing_enabled: diff --git a/api/apps/restful_apis/models_api.py b/api/apps/restful_apis/models_api.py new file mode 100644 index 0000000000..cac0a1cf9e --- /dev/null +++ b/api/apps/restful_apis/models_api.py @@ -0,0 +1,201 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +from quart import request + +from api.apps import login_required +from api.apps.services import models_api_service +from api.utils.api_utils import ( + add_tenant_id_to_kwargs, + get_error_argument_result, + get_error_data_result, + get_result, +) + + +@manager.route("/models", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_added_models(tenant_id: str): + """ + List tenant all added models. + --- + tags: + - Models + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: List of added models. 
+ schema: + type: object + properties: + data: + type: object + properties: + models: + type: array + items: + type: object + properties: + model_provider: + type: string + model_instance: + type: string + model_name: + type: string + model_type: + type: string + enable: + type: boolean + """ + model_type_filter = request.args.get("type") + try: + success, result = models_api_service.list_tenant_added_models(tenant_id, model_type_filter) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/models/default", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_default_models(tenant_id: str): + """ + List tenant default models. + --- + tags: + - Models + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: List of default models. + schema: + type: object + properties: + data: + type: object + properties: + models: + type: array + items: + type: object + properties: + model_provider: + type: string + model_instance: + type: string + model_name: + type: string + model_type: + type: string + enable: + type: boolean + """ + try: + success, result = models_api_service.list_tenant_default_models(tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/models/default", methods=["PATCH"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def set_default_models(tenant_id: str): + """ + Set or clear a tenant default model. 
+ --- + tags: + - Models + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Model configuration. + required: true + schema: + type: object + required: + - model_type + properties: + model_provider: + type: string + description: Provider name. Required when setting a model; omit to clear. + model_instance: + type: string + description: Instance name. Required when setting a model; omit to clear. + model_name: + type: string + description: Model name. Required when setting a model; omit to clear. + model_type: + type: string + description: "Model type: chat, embedding, rerank, asr, vision, tts, ocr" + responses: + 200: + description: Default model updated. + schema: + type: object + """ + data = await request.get_json() + if not data or "model_type" not in data: + return get_error_argument_result(message="model_type is required") + + model_provider = data.get("model_provider", "") + model_instance = data.get("model_instance", "") + model_name = data.get("model_name", "") + model_type = data["model_type"] + + try: + success, msg = models_api_service.set_tenant_default_models( + tenant_id, model_provider, model_instance, model_name, model_type + ) + if success: + logging.info(f"success: {success}, msg: {msg}") + return get_result(message=msg) + else: + return get_error_data_result(message=msg) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py index baa011f32a..8882e2c8bf 100644 --- a/api/apps/restful_apis/openai_api.py +++ b/api/apps/restful_apis/openai_api.py @@ -22,7 +22,7 @@ from quart import Response, jsonify from api.apps import current_user, login_required from api.db.services.dialog_service import DialogService, async_chat from api.db.services.doc_metadata_service 
import DocMetadataService -from api.db.services.tenant_llm_service import TenantLLMService +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_api_key from api.utils.api_utils import get_error_data_result, get_request_json, validate_request from common.constants import RetCode, StatusEnum from common.metadata_utils import convert_conditions, meta_filter @@ -33,17 +33,18 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None - llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id) model_type = (llm_setting or {}).get("model_type") if model_type not in {"chat", "image2text"}: model_type = "chat" - if not TenantLLMService.query( - tenant_id=tenant_id, - llm_name=llm_name, - llm_factory=llm_factory, - model_type=model_type, - ): + try: + get_model_config_from_provider_instance( + tenant_id=tenant_id, + model_name=llm_id, + model_type=model_type, + ) + except Exception as e: + logging.error(f"Fail to get model config for {llm_id}: {e}") return f"`llm_id` {llm_id} doesn't exist" return None @@ -133,7 +134,7 @@ async def openai_chat_completions(chat_id): if llm_id_error: return get_error_data_result(message=llm_id_error, code=RetCode.ARGUMENT_ERROR) dia.llm_id = requested_model - if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=requested_model): + if not get_api_key(tenant_id=dia.tenant_id, model_name=requested_model): return get_error_data_result(message=f"Cannot use specified model {requested_model}.") metadata_condition = extra_body.get("metadata_condition") or {} diff --git a/api/apps/restful_apis/provider_api.py b/api/apps/restful_apis/provider_api.py new file mode 100644 index 0000000000..6a9806b47b --- /dev/null +++ b/api/apps/restful_apis/provider_api.py @@ -0,0 +1,796 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +from quart import request + +from api.apps import login_required +from api.utils.api_utils import ( + add_tenant_id_to_kwargs, + get_error_argument_result, + get_error_data_result, + get_result, +) +from api.apps.services import provider_api_service + + +@manager.route("/providers", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_providers(tenant_id: str = None): + """ + List providers. + --- + parameters: + - in: query + name: available + type: string + required: false + description: "If 'true', list all available system providers; otherwise list tenant-configured providers." + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: List of providers. 
+ schema: + type: object + properties: + data: + type: array + items: + type: object + """ + available_only = request.args.get("available", "").lower() == "true" + try: + success, result = provider_api_service.list_providers(tenant_id, available_only) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers", methods=["PUT"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def add_provider(tenant_id: str = None): + """ + Add a provider for the tenant. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Provider creation parameters. + required: true + schema: + type: object + required: + - provider_name + properties: + provider_name: + type: string + description: Provider/factory name. + responses: + 200: + description: Provider added successfully. + schema: + type: object + """ + data = await request.get_json() + if not data or "provider_name" not in data: + return get_error_argument_result(message="provider_name is required") + + provider_name = data["provider_name"] + + try: + success, msg = provider_api_service.add_provider(tenant_id, provider_name) + if success: + return get_result(message=msg) + else: + return get_error_data_result(message=msg) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers/", methods=["GET"]) # noqa: F821 +@login_required +def show_provider(provider_name: str): + """ + Show provider details. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. 
+ - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Provider details. + schema: + type: object + """ + try: + success, result = provider_api_service.show_provider(provider_name) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers/", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def delete_provider(tenant_id: str = None, provider_name: str = None): + """ + Delete a provider and all its models for the tenant. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Provider deleted successfully. + schema: + type: object + """ + try: + success, msg = provider_api_service.delete_provider(tenant_id, provider_name) + if success: + return get_result(message=msg) + else: + return get_error_data_result(message=msg) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//models", methods=["GET"]) # noqa: F821 +@login_required +def list_provider_models(provider_name: str): + """ + List models for a provider. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: List of models for the provider. 
+ schema: + type: object + properties: + data: + type: array + items: + type: object + """ + try: + success, result = provider_api_service.list_provider_models(provider_name) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//models/", methods=["GET"]) # noqa: F821 +@login_required +def show_provider_model(provider_name: str, model_name: str): + """ + Show a specific model for a provider. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: path + name: model_name + type: string + required: true + description: Model name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Model details. + schema: + type: object + """ + try: + success, result = provider_api_service.show_provider_model(provider_name, model_name) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def create_provider_instance(tenant_id: str = None, provider_name: str = None): + """ + Create a provider instance. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Instance creation parameters. 
+ required: true + schema: + type: object + required: + - instance_name + - api_key + properties: + instance_name: + type: string + description: Instance name. + api_key: + type: string + description: API key. + responses: + 200: + description: Instance created successfully. + schema: + type: object + """ + data = await request.get_json() + if not data or "instance_name" not in data or "api_key" not in data: + return get_error_argument_result(message="instance_name and api_key are required") + + instance_name = data["instance_name"] + api_key = data["api_key"] + base_url = data.get("base_url", "") + region = data.get("region", "") + + try: + success, msg = provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region) + if success: + return get_result(message=msg) + else: + return get_error_data_result(message=msg) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_provider_instances(tenant_id: str = None, provider_name: str = None): + """ + List provider instances. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: List of provider instances. 
+ schema: + type: object + properties: + data: + type: array + items: + type: object + """ + try: + success, result = provider_api_service.list_provider_instances(tenant_id, provider_name) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def show_provider_instance(tenant_id: str = None, provider_name: str = None, instance_name: str = None): + """ + Show a provider instance. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: path + name: instance_name + type: string + required: true + description: Instance name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Instance details. + schema: + type: object + """ + try: + success, result = provider_api_service.show_provider_instance(tenant_id, provider_name, instance_name) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def drop_provider_instances(tenant_id: str = None, provider_name: str = None): + """ + Drop provider instances. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. 
+ - in: body + name: body + description: Instance deletion parameters. + required: true + schema: + type: object + required: + - instances + properties: + instances: + type: array + items: + type: string + description: List of instance names to drop. + responses: + 200: + description: Instances dropped successfully. + schema: + type: object + """ + data = await request.get_json() + if not data or "instances" not in data: + return get_error_argument_result(message="instances is required") + + instances = data["instances"] + if not instances: + return get_error_argument_result(message="instances is required") + + try: + success, msg = provider_api_service.drop_provider_instances(tenant_id, provider_name, instances) + if success: + return get_result(message=msg) + else: + return get_error_data_result(message=msg) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances//models", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_instance_models(tenant_id: str = None, provider_name: str = None, instance_name: str = None): + """ + List models for a provider instance. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: path + name: instance_name + type: string + required: true + description: Instance name. + - in: query + name: supported + type: string + required: false + description: "If 'true', list only supported models." + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: List of models. 
+ schema: + type: object + properties: + data: + type: array + items: + type: object + """ + supported_only = request.args.get("supported", "").lower() == "true" + try: + success, result = provider_api_service.list_instance_models( + tenant_id, provider_name, instance_name, supported_only + ) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances//models", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str): + """ + Add a model to an instance. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: path + name: instance_name + type: string + required: true + description: Instance name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Model details. + required: true + schema: + type: object + required: + - model_name + - model_type + properties: + model_name: + type: string + description: Model name. + model_type: + type: string + description: Model type. + max_tokens: + type: integer + description: Maximum number of tokens. + extra: + type: object + description: Extra model details. + responses: + 200: + description: Model added successfully. 
+ """ + data = await request.get_json() + if not data or "model_name" not in data or "model_type" not in data: + return get_error_argument_result(message="model_name and model_type are required") + + model_name = data["model_name"] + model_type = data["model_type"] + max_tokens = data.get("max_tokens", 8192) + extra = data.get("extra", {}) + + try: + success, result = provider_api_service.add_model_to_instance( + tenant_id, provider_name, instance_name, model_name, model_type, max_tokens, extra + ) + if success: + return get_result(message=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances//models/", methods=["PATCH"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def enable_or_disable_model(tenant_id: str = None, provider_name: str = None, instance_name: str = None, model_name: str = None): + """ + Enable or disable a model. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: path + name: instance_name + type: string + required: true + description: Instance name. + - in: path + name: model_name + type: string + required: true + description: Model name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Model status update. + required: true + schema: + type: object + required: + - status + properties: + status: + type: string + enum: ["active", "inactive"] + description: Model status. + responses: + 200: + description: Model status updated. 
+ schema: + type: object + """ + data = await request.get_json() + if not data or "status" not in data: + return get_error_argument_result(message="status is required") + + status = data["status"] + if status not in ("active", "inactive"): + return get_error_argument_result(message="status must be 'active' or 'inactive'") + + try: + success, msg = provider_api_service.update_model_status(tenant_id, provider_name, instance_name, model_name, status) + if success: + return get_result(message=msg) + else: + return get_error_data_result(message=msg) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/providers//instances//models/", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def chat_to_model(tenant_id: str = None, provider_name: str = None, instance_name: str = None, model_name: str = None): + """ + Chat to a model. + --- + tags: + - Providers + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: provider_name + type: string + required: true + description: Provider name. + - in: path + name: instance_name + type: string + required: true + description: Instance name. + - in: path + name: model_name + type: string + required: true + description: Model name. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Chat request. + required: true + schema: + type: object + required: + - message + properties: + message: + type: string + description: Chat message. + stream: + type: boolean + description: Whether to stream the response. + thinking: + type: boolean + description: Whether to enable thinking/reasoning. + responses: + 200: + description: Chat response. 
+ schema: + type: object + """ + data = await request.get_json() + if not data or "message" not in data: + return get_error_argument_result(message="message is required") + + message = data["message"] + stream = data.get("stream", False) + thinking = data.get("thinking", False) + + try: + success, result = await provider_api_service.chat_to_model( + tenant_id, provider_name, instance_name, model_name, message, stream, thinking + ) + if not success: + return get_error_data_result(message=result) + + if stream and isinstance(result, dict) and result.get("type") == "stream": + # Streaming response using SSE + from quart import Response + llm = result["llm"] + + async def generate(): + async for chunk in llm.async_chat_streamly( + None, + [{"role": "user", "content": message}], + {"temperature": 0.9}, + ): + if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0: + yield f"data: [MESSAGE]{chunk}\n\n" + yield "data: [DONE]\n\n" + + return Response(generate(), mimetype="text/event-stream", headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }) + + # Non-streaming response + return get_result(data=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") diff --git a/api/apps/restful_apis/user_api.py b/api/apps/restful_apis/user_api.py index e33b51c6a9..8362e329dd 100644 --- a/api/apps/restful_apis/user_api.py +++ b/api/apps/restful_apis/user_api.py @@ -44,7 +44,6 @@ from api.utils.api_utils import ( validate_request, ) from api.utils.crypt import decrypt -from api.utils.tenant_utils import ensure_tenant_model_id_for_params from rag.utils.redis_conn import REDIS_CONN from api.apps import login_required, current_user, login_user, logout_user from api.utils.web_utils import ( @@ -631,8 +630,7 @@ async def set_tenant_info(): req = await get_request_json() try: tid = req.pop("tenant_id") - update_dict = ensure_tenant_model_id_for_params(tid, req) - TenantService.update_by_id(tid, 
update_dict) + TenantService.update_by_id(tid, req) return get_json_result(data=True) except Exception as e: return server_error_response(e) diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 42c39364c5..3994316123 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -17,6 +17,8 @@ import logging import json import os import re + +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance from common.constants import PAGERANK_FLD from common import settings from api.db.db_models import File @@ -912,11 +914,7 @@ async def search(dataset_id: str, tenant_id: str, req: dict): :param req: search request :return: (success, result) or (success, error_message) """ - from api.db.joint_services.tenant_model_service import ( - get_model_config_by_id, - get_model_config_by_type_and_name, - get_tenant_default_model_by_type, - ) + from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService @@ -985,7 +983,7 @@ async def search(dataset_id: str, tenant_id: str, req: dict): if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"]) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, search_config["chat_id"]) else: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) @@ -1018,28 +1016,17 @@ async def search(dataset_id: str, tenant_id: str, req: dict): _question = question if langs: _question = await cross_languages(kb.tenant_id, None, _question, langs) - if kb.tenant_embd_id: - 
embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - elif kb.embd_id: - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + if kb.embd_id: + embd_model_config = get_model_config_from_provider_instance(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) else: embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING) embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) rerank_mdl = None - if req.get("tenant_rerank_id"): - allowed_rerank_tenant_ids = {tenant_id, kb.tenant_id} - rerank_model_config = get_model_config_by_id( - req["tenant_rerank_id"], - allowed_tenant_ids=allowed_rerank_tenant_ids, - requester_tenant_id=tenant_id, - ) + rerank_id = search_config.get("rerank_id") or req.get("rerank_id") + if rerank_id: + rerank_model_config = get_model_config_from_provider_instance(kb.tenant_id, LLMType.RERANK.value, rerank_id) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - else: - rerank_id = search_config.get("rerank_id") or req.get("rerank_id") - if rerank_id: - rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) if search_config.get("keyword", req.get("keyword", False)): default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) @@ -1099,9 +1086,6 @@ def check_embedding(dataset_id: str, tenant_id: str, req: dict): from common.doc_store.doc_store_base import OrderByExpr from rag.nlp import search - from api.db.joint_services.tenant_model_service import ( - get_model_config_by_type_and_name, - ) from api.db.services.llm_service import LLMBundle from common.constants import LLMType @@ -1214,7 +1198,7 @@ def check_embedding(dataset_id: str, tenant_id: str, req: dict): if not ok: return False, err - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, embd_id) + embd_model_config = 
get_model_config_from_provider_instance(kb.tenant_id, LLMType.EMBEDDING, embd_id) emb_mdl = LLMBundle(kb.tenant_id, embd_model_config) n = int(req.get("check_num", 5)) @@ -1293,11 +1277,7 @@ async def search_datasets(tenant_id: str, req: dict): :param req: search request containing dataset_ids and other params :return: (success, result) or (success, error_message) """ - from api.db.joint_services.tenant_model_service import ( - get_model_config_by_id, - get_model_config_by_type_and_name, - get_tenant_default_model_by_type, - ) + from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService @@ -1373,7 +1353,7 @@ async def search_datasets(tenant_id: str, req: dict): if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"]) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, search_config["chat_id"]) else: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) @@ -1408,28 +1388,17 @@ async def search_datasets(tenant_id: str, req: dict): _question = question if langs: _question = await cross_languages(kb.tenant_id, None, _question, langs) - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - elif kb.embd_id: - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + if kb.embd_id: + embd_model_config = get_model_config_from_provider_instance(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) else: embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING) embd_mdl = LLMBundle(kb.tenant_id, 
embd_model_config) rerank_mdl = None - if req.get("tenant_rerank_id"): - allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} - rerank_model_config = get_model_config_by_id( - req["tenant_rerank_id"], - allowed_tenant_ids=allowed_rerank_tenant_ids, - requester_tenant_id=tenant_id, - ) + rerank_id = search_config.get("rerank_id") or req.get("rerank_id") + if rerank_id: + rerank_model_config = get_model_config_from_provider_instance(kb.tenant_id, LLMType.RERANK.value, rerank_id) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - else: - rerank_id = search_config.get("rerank_id") or req.get("rerank_id") - if rerank_id: - rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) if search_config.get("keyword", req.get("keyword", False)): default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 53cff623ce..1555ce97dc 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -21,11 +21,10 @@ from api.db.services.canvas_service import UserCanvasService from api.db.services.task_service import TaskService from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human -from api.utils.tenant_utils import ensure_tenant_model_id_for_params from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from memory.services.messages import MessageService from memory.utils.prompt_util import PromptAssembler -from common.constants import MemoryType, ForgettingPolicy, LLMType +from common.constants import MemoryType, ForgettingPolicy from common.exceptions import ArgumentException, 
NotFoundException from common.time_utils import current_timestamp, timestamp_to_date @@ -104,9 +103,7 @@ async def create_memory(memory_info: dict): name=memory_name, memory_type=memory_type, embd_id=memory_info["embd_id"], - llm_id=memory_info["llm_id"], - tenant_llm_id=memory_info["tenant_llm_id"], - tenant_embd_id=memory_info["tenant_embd_id"] + llm_id=memory_info["llm_id"] ) if success: return True, format_ret_data_from_memory(res) @@ -133,7 +130,6 @@ async def update_memory(memory_id: str, new_memory_setting: dict): } """ current_memory = _require_memory_access(memory_id) - owner_tenant_id = current_memory.tenant_id update_dict = {} # check name length @@ -150,32 +146,15 @@ async def update_memory(memory_id: str, new_memory_setting: dict): if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]: raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.") update_dict["permissions"] = new_memory_setting["permissions"] - if ("tenant_llm_id" in new_memory_setting or "tenant_embd_id" in new_memory_setting) and not ( - new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id") - ): - raise ArgumentException( - "Do not set tenant_llm_id or tenant_embd_id directly; update llm_id and/or embd_id instead." 
- ) if new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id"): merged = { "llm_id": new_memory_setting.get("llm_id") or current_memory.llm_id, "embd_id": new_memory_setting.get("embd_id") or current_memory.embd_id, } - merged = ensure_tenant_model_id_for_params(owner_tenant_id, merged) - if not merged.get("tenant_llm_id"): - raise ArgumentException( - f"Tenant Model with name {merged['llm_id']} and type {LLMType.CHAT.value} not found" - ) - if new_memory_setting.get("embd_id") and not merged.get("tenant_embd_id"): - raise ArgumentException( - f"Tenant Model with name {merged['embd_id']} and type {LLMType.EMBEDDING.value} not found" - ) if new_memory_setting.get("llm_id"): update_dict["llm_id"] = merged["llm_id"] if new_memory_setting.get("embd_id"): update_dict["embd_id"] = merged["embd_id"] - update_dict["tenant_llm_id"] = merged["tenant_llm_id"] - update_dict["tenant_embd_id"] = merged.get("tenant_embd_id") if new_memory_setting.get("memory_type"): memory_type = set(new_memory_setting["memory_type"]) invalid_type = memory_type - {e.name.lower() for e in MemoryType} @@ -216,7 +195,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict): return True, memory_dict # check memory empty when update embd_id, memory_type memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id) - not_allowed_update = [f for f in ["tenant_embd_id", "embd_id", "memory_type"] if f in to_update and memory_size > 0] + not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0] if not_allowed_update: raise ArgumentException(f"Can't update {not_allowed_update} when memory isn't empty.") if "memory_type" in to_update: diff --git a/api/apps/services/models_api_service.py b/api/apps/services/models_api_service.py new file mode 100644 index 0000000000..043957bd8d --- /dev/null +++ b/api/apps/services/models_api_service.py @@ -0,0 +1,360 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import logging + +from common.constants import ActiveStatusEnum, LLMType +from common.settings import FACTORY_LLM_INFOS +from api.db.services.tenant_model_provider_service import TenantModelProviderService +from api.db.services.tenant_model_instance_service import TenantModelInstanceService +from api.db.services.tenant_model_service import TenantModelService +from api.db.services.user_service import TenantService + +# Mapping from model_type string to Tenant model field name +MODEL_TYPE_TO_FIELD = { + "chat": "llm_id", + "embedding": "embd_id", + "rerank": "rerank_id", + "asr": "asr_id", + "vision": "img2txt_id", + "tts": "tts_id", + "ocr": "ocr_id", +} + +MODEL_TAG_TO_TYPE = { + "chat": "chat", + "embedding": "embedding", + "rerank": "rerank", + "asr": "speech2text", + "vision": "image2text", + "tts": "tts", + "ocr": "ocr", +} + + +def _get_model_info(tenant_id: str, default_model: str, model_type: str): + """ + Parse a composite model string (modelName@instanceName@providerName or modelName@providerName) + and validate that the provider, instance, and model exist. + + Returns a dict with model info or None on error. 
+ """ + if not default_model: + return None + + parts = default_model.split("@") + if len(parts) == 3: + model_name, instance_name, provider_name = parts + elif len(parts) == 2: + model_name, provider_name = parts + instance_name = "default" + else: + logging.warning(f"Invalid model string: {default_model}") + return None + + model_type = MODEL_TAG_TO_TYPE.get(model_type, model_type) + # Special case: OCR with infiniflow@default@deepdoc is always enabled + if model_type == "ocr" and provider_name == "infiniflow" and instance_name == "default" and model_name == "deepdoc": + return { + "model_provider": provider_name, + "model_instance": instance_name, + "model_name": model_name, + "model_type": model_type, + "enable": True, + } + + # Check if the provider exists for the tenant + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + logging.warning(f"Provider '{provider_name}' not found for tenant '{tenant_id}'") + return None + + # Check if the instance exists + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + logging.warning(f"Instance '{instance_name}' not found for provider '{provider_name}'") + return None + + # Check if model is enabled (no TenantModel record or status != inactive means enabled) + model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name( + provider_obj.id, instance_obj.id, model_type, model_name + ) + enable = model_entity is None or model_entity.status != ActiveStatusEnum.INACTIVE.value + + if not enable: + return None + + if model_entity: + return { + "model_provider": provider_name, + "model_instance": instance_name, + "model_name": model_name, + "model_type": model_type, + "enable": enable, + } + + # Check if model is in the LLM factory info + factory_info = [f for f in (FACTORY_LLM_INFOS or []) if f["name"] == provider_name] + if not 
factory_info: + logging.warning(f"Provider '{provider_name}' not found in factory info") + return None + + llms = factory_info[0].get("llm", []) + target_llm = [llm for llm in llms if llm["llm_name"] == model_name] + if not target_llm: + logging.warning(f"Model '{model_name}' not found for provider '{provider_name}'") + return None + + # Check if the model_type matches + if target_llm[0].get("model_type") != model_type: + logging.warning(f"Model '{model_name}' isn't a {model_type} model") + return None + + return { + "model_provider": provider_name, + "model_instance": instance_name, + "model_name": model_name, + "model_type": model_type, + "enable": enable, + } + + +def _check_model_available(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str): + """ + Validate that a model is available for the tenant: + - Provider exists for the tenant + - Instance exists under the provider + - Model is in the LLM factory info for the provider + - Model type matches + - Model is not disabled in TenantModel table + + Returns (success, error_message). 
+ """ + if provider_name == "infiniflow" and instance_name == "default" and model_name == "deepdoc": + return True, None + + if model_type == "ocr" and provider_name == "infiniflow" and instance_name == "default" and model_name == "deepdoc": + return True, None + + compose_profiles = os.getenv("COMPOSE_PROFILES", "") + is_tei_builtin_embedding = ( + model_type == LLMType.EMBEDDING.value + and "tei-" in compose_profiles + and model_name == os.getenv("TEI_MODEL", "") + and (provider_name == "Builtin" or provider_name is None) + ) + if is_tei_builtin_embedding: + return True, None + + # Check provider + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"Provider '{provider_name}' not found" + + # Check instance + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + return False, f"Instance '{instance_name}' not found for provider '{provider_name}'" + + # Check model schema + factory_info = [f for f in (FACTORY_LLM_INFOS or []) if f["name"] == provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found in factory info" + model_type = MODEL_TAG_TO_TYPE.get(model_type, model_type) + # Check if model is disabled + model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name( + provider_obj.id, instance_obj.id, model_type, model_name + ) + if model_entity: + if model_entity.status == "inactive": + return False, f"Model '{model_name}' isn't available" + return True, None + + llms = factory_info[0].get("llm", []) + target_llm = [llm for llm in llms if llm["llm_name"] == model_name] + if not target_llm and not model_entity: + return False, f"Model '{model_name}' not found for provider '{provider_name}'" + + if target_llm: + if target_llm[0].get("model_type") != model_type: + return False, f"Model '{model_name}' isn't a {model_type} 
model" + + return True, None + + +def list_tenant_default_models(tenant_id: str): + """ + List all default models for a tenant. + + For each model type (chat, embedding, rerank, asr, vision, tts, ocr), + reads the composite model ID string from the Tenant record and resolves + it into provider/instance/name components. + + :param tenant_id: tenant ID + :return: (success, result_or_error_message) + """ + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + return False, "Tenant not found" + + models = [] + + for model_type, field_name in MODEL_TYPE_TO_FIELD.items(): + default_model = getattr(tenant, field_name, None) + if not default_model: + continue + model_info = _get_model_info(tenant_id, default_model, model_type) + if model_info: + models.append(model_info) + + return True, {"models": models} + + +def set_tenant_default_models(tenant_id: str, model_provider: str, model_instance: str, model_name: str, model_type: str): + """ + Set or clear a tenant default model. + + If model_provider, model_instance, and model_name are all provided, + validates the model and sets it as the default. + If all three are empty, clears the default for the given model type. 
+ + :param tenant_id: tenant ID + :param model_provider: provider name + :param model_instance: instance name + :param model_name: model name + :param model_type: model type (chat, embedding, rerank, asr, vision, tts, ocr) + :return: (success, result_or_error_message) + """ + field_name = MODEL_TYPE_TO_FIELD.get(model_type) + if not field_name: + return False, f"model type '{model_type}' is invalid" + + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + return False, "Tenant not found" + + if not model_provider and not model_instance and not model_name: + # Clear the default model + default_model = "" + elif model_provider and model_instance and model_name: + # Validate and set the default model + success, msg = _check_model_available(tenant_id, model_provider, model_instance, model_name, model_type) + if not success: + return False, msg + default_model = f"{model_name}@{model_instance}@{model_provider}" + else: + return False, "model_provider, model_instance and model_name must be specified together" + + TenantService.update_by_id(tenant_id, {field_name: default_model}) + return True, "success" + + +def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): + """ + List all added models for a tenant. 
+ + :param tenant_id: tenant ID + :param model_type_filter: model type filter (chat, embedding, rerank, asr, vision, tts, ocr) + :return: (success, result_or_error_message) + """ + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + return False, "Tenant not found" + + if model_type_filter: + model_type_filter = model_type_filter.lower() + + providers = TenantModelProviderService.get_by_tenant_id(tenant_id) + if not providers: + return True, [] + + provider_ids = [provider.id for provider in providers] + instances = TenantModelInstanceService.get_by_provider_ids(provider_ids) + if not instances: + return True, [] + provider_instance_map: dict = {} + provider_info_map = {provider.id: provider for provider in providers} + for provider_instance_record in instances: + provider_name = provider_info_map[provider_instance_record.provider_id].provider_name if provider_info_map.get(provider_instance_record.provider_id) else "" + if provider_instance_map.get(provider_name): + provider_instance_map[provider_name].append(provider_instance_record) + else: + provider_instance_map[provider_name] = [provider_instance_record] + + model_records = TenantModelService.get_models_by_provider_ids_and_instance_ids(provider_ids, list({instance.id for instance in instances})) + target_type_records = [record for record in model_records if record.model_type == model_type_filter] if model_type_filter else model_records + model_record_map = {} + for model in target_type_records: + instance_model_key = f"{model.provider_id}_{model.instance_id}_{model.model_name}" + if model_record_map.get(instance_model_key): + model_record_map[instance_model_key].append(model) + else: + model_record_map[instance_model_key] = [model] + + added_models = [] + model_key_in_factory = [] + provider_names = [provider.provider_name for provider in providers] + for factory in FACTORY_LLM_INFOS: + if factory["name"] not in provider_names: + continue + factory_instances = 
provider_instance_map.get(factory["name"]) + if not factory_instances: + continue + for llm in factory["llm"]: + if model_type_filter and llm["model_type"] != model_type_filter: + continue + + for factory_instance in factory_instances: + model_record_key = f"{factory_instance.provider_id}_{factory_instance.id}_{llm['llm_name']}" + model_key_in_factory.append(model_record_key) + manual_modified_models = model_record_map.get(model_record_key, []) + active_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.ACTIVE.value] + inactive_model_types = [manual_model.model_type for manual_model in manual_modified_models if manual_model.status == ActiveStatusEnum.INACTIVE.value] + model_types = list(set([llm["model_type"]] + active_model_types) - set(inactive_model_types)) + if not model_types: + continue + + added_models.append({ + "model_type": model_types, + "name": llm["llm_name"], + "provider_id": factory_instance.provider_id, + "provider_name": provider_info_map[factory_instance.provider_id].provider_name if provider_info_map.get(factory_instance.provider_id) else "", + "instance_id": factory_instance.id, + "instance_name": factory_instance.instance_name + }) + + manual_added_model_record_keys = list(set(model_record_map.keys()) - set(model_key_in_factory)) + if manual_added_model_record_keys: + instance_info_map = {instance.id: instance for instance in instances} + for model_record_key in manual_added_model_record_keys: + model_records = model_record_map.get(model_record_key, []) + if not model_records: + continue + provider_id, instance_id, model_name = model_record_key.split("_") + model_types = [model.model_type for model in model_records if model.status == ActiveStatusEnum.ACTIVE.value] + if not model_types: + continue + + added_models.append({ + "model_type": model_types, + "name": model_name, + "provider_id": provider_id, + "provider_name": provider_info_map[provider_id].provider_name if 
provider_info_map.get(provider_id) else "", + "instance_id": instance_id, + "instance_name": instance_info_map[instance_id].instance_name if instance_info_map.get(instance_id) else "" + }) + + return True, added_models diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py new file mode 100644 index 0000000000..c094674d4d --- /dev/null +++ b/api/apps/services/provider_api_service.py @@ -0,0 +1,558 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging + +from common.constants import LLMType, ActiveStatusEnum +from common.misc_utils import get_uuid +from common.settings import FACTORY_LLM_INFOS +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, delete_models_by_instance_ids, delete_instances_by_provider_ids +from api.db.services.tenant_model_provider_service import TenantModelProviderService +from api.db.services.tenant_model_instance_service import TenantModelInstanceService +from api.db.services.tenant_model_service import TenantModelService + + +def list_providers(tenant_id: str, all_available: bool = False): + """ + List providers for a tenant. + + If all_available is True, list all system-wide providers (pool providers). + Otherwise, list providers that the tenant has configured.
+ + :param tenant_id: tenant ID + :param all_available: whether to list all available providers + :return: (success, result) + """ + if not FACTORY_LLM_INFOS: + return False, [] + + if all_available: + providers = [] + for factory_info in FACTORY_LLM_INFOS: + model_types = sorted(set( + llm["model_type"] + for llm in factory_info.get("llm", []) + if llm.get("model_type") + )) + providers.append({ + "model_types": model_types, + "name": factory_info["name"], + "url": { + "default": factory_info.get("url", "") + } + }) + return True, providers + + # List tenant-configured providers + factory_names = TenantModelProviderService.list_provider_names_by_tenant_id(tenant_id) + + providers = [] + factory_info_mapping = {f["name"]: f for f in FACTORY_LLM_INFOS} + for name in factory_names: + if factory_info_mapping.get(name): + factory_info = factory_info_mapping[name] + model_types = sorted(set( + llm["model_type"] + for llm in factory_info.get("llm", []) + if llm.get("model_type") + )) + providers.append({ + "model_types": model_types, + "name": factory_info["name"], + "url": { + "default": factory_info.get("url", "") + } + }) + + return True, providers + + +def add_provider(tenant_id: str, provider_name: str): + """ + Add a provider (factory) for a tenant. 
+ + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :return: (success, result_or_error_message) + """ + if not FACTORY_LLM_INFOS: + return False, "No providers found" + # Check if factory is allowed + allowed_factories = [f["name"] for f in FACTORY_LLM_INFOS] + if provider_name not in allowed_factories: + return False, f"Provider '{provider_name}' is not allowed" + + existing = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if existing: + return False, f"Provider {provider_name} already exists" + + TenantModelProviderService.insert( + tenant_id=tenant_id, + provider_name=provider_name + ) + return True, "success" + + +def delete_provider(tenant_id: str, provider_name: str): + """ + Delete all instances and models for a provider. + + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :return: (success, result_or_error_message) + """ + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"Provider {provider_name} not found" + instance_objs = TenantModelInstanceService.get_by_provider_id(provider_obj.id) + if not instance_objs: + return False, f"No instances found for provider {provider_name}" + instance_ids = [instance_obj.id for instance_obj in instance_objs] + delete_models_by_instance_ids(instance_ids) + delete_instances_by_provider_ids([provider_obj.id]) + TenantModelProviderService.delete_by_tenant_id_and_provider_name(tenant_id, provider_name) + return True, "success" + + +def show_provider(provider_name: str): + """ + Show provider details from LLMFactories. 
+ + :param provider_name: provider/factory name + :return: (success, result_or_error_message) + """ + fac_list = [f for f in FACTORY_LLM_INFOS if f["name"]==provider_name] + if not fac_list: + return False, f"Provider '{provider_name}' not found" + factory_info = fac_list[0] + return True, { + "base_url": { + "default": factory_info.get("url", "") + }, + "name": factory_info["name"], + "total_models": len(factory_info.get("llm", [])) + } + + +def list_provider_models(provider_name: str): + """ + List all models for a provider from the LLM dictionary. + + :param provider_name: provider/factory name + :return: (success, result_or_error_message) + """ + factory_info = [f for f in FACTORY_LLM_INFOS if f["name"]==provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found" + llms = factory_info[0]["llm"] + if not llms: + return False, f"No models found for provider '{provider_name}'" + + models = [] + for llm in llms: + models.append({ + "name": llm["name"], + "max_tokens": llm["max_tokens"], + "model_types": [llm["model_type"]], + "features": None + }) + return True, models + + +def show_provider_model(provider_name: str, model_name: str): + """ + Show a specific model for a provider. 
+ + :param provider_name: provider/factory name + :param model_name: model name + :return: (success, result_or_error_message) + """ + factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found" + llms = factory_info[0]["llm"] + if not llms: + return False, f"No models found for provider '{provider_name}'" + target_llm = [llm for llm in llms if llm["name"] == model_name] + if not target_llm: + return False, f"Model '{model_name}' not found" + llm_info = target_llm[0] + + return True, { + "name": llm_info["name"], + "max_tokens": llm_info["max_tokens"], + "model_types": [llm_info["model_type"]], + "thinking": None, + "model_type_map": { + llm_info["model_type"]: True + } + } + + +def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str, base_url: str, region: str): + """ + Create a provider instance. + + The provider must already be configured for the tenant; the instance name + "default" and an API key already used under that provider are rejected.
+ + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :param instance_name: instance name (used as a logical identifier) + :param api_key: API key + :param base_url: base url + :param region: region + :return: (success, result_or_error_message) + """ + if not provider_name: + return False, "Provider name is required" + + if instance_name == "default": + return False, "Instance name cannot be 'default'" + + # Check if provider exists in the system + allowed_factories = [f["name"] for f in FACTORY_LLM_INFOS] + if provider_name not in allowed_factories: + return False, f"Provider '{provider_name}' is not allowed" + + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"Provider '{provider_name}' does not exist" + + if api_key: + same_key_instance = TenantModelInstanceService.get_by_provider_id_and_api_key(provider_obj.id, api_key) + if same_key_instance: + return False, f"Already exist instance: {same_key_instance.instance_name} with api_key {api_key}" + + import json + extra_fields = {} + if base_url: + extra_fields["base_url"] = base_url + if region: + extra_fields["region"] = region + TenantModelInstanceService.create_instance(provider_id=provider_obj.id,instance_name=instance_name,api_key=api_key, extra=json.dumps(extra_fields)) + + return True, "success" + + +def list_provider_instances(tenant_id: str, provider_name: str): + """ + List provider instances for a tenant. 
+ + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :return: (success, result_or_error_message) + """ + import json + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"No provider found for provider '{provider_name}'" + provider_id = provider_obj.id + instance_objs = TenantModelInstanceService.get_all_by_provider_id(provider_id) + if not instance_objs: + return True, [] + instances = [] + for instance_obj in instance_objs: + extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {} + instances.append({ + "api_key": instance_obj.api_key, + "id": instance_obj.id, + "instance_name": instance_obj.instance_name, + "provider_id": provider_id, + "region": extra_fields.get("region", ""), + "status": instance_obj.status, + }) + + return True, instances + + +def show_provider_instance(tenant_id: str, provider_name: str, instance_name: str): + """ + Show a specific provider instance. 
+ + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :param instance_name: instance name + :return: (success, result_or_error_message) + """ + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"No provider found for provider '{provider_name}'" + provider_id = provider_obj.id + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_id, instance_name) + if not instance_obj: + return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'" + + import json + extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {} + return True, { + "id": instance_obj.id, + "instance_name": instance_obj.instance_name, + "provider_id": provider_id, + "region": extra_fields.get("region", ""), + "status": instance_obj.status + } + + +def drop_provider_instances(tenant_id: str, provider_name: str, instance_names: list): + """ + Drop the named provider instances, + deleting the models of those instances first.
+ + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :param instance_names: list of instance names to drop + :return: (success, result_or_error_message) + """ + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"No provider found for provider '{provider_name}'" + provider_id = provider_obj.id + not_exist_instances = [] + instance_ids = [] + for instance_name in instance_names: + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_id, instance_name) + if not instance_obj: + not_exist_instances.append(instance_name) + continue + instance_ids.append(instance_obj.id) + if not_exist_instances: + return False, f"No instance found for provider '{provider_name}' and instance '{not_exist_instances}'" + delete_models_by_instance_ids(instance_ids) + TenantModelInstanceService.delete_by_ids(instance_ids) + return True, None + + +def list_instance_models(tenant_id: str, provider_name: str, instance_name: str, supported_only: bool = False): + """ + List models for a provider instance. + + Follows the Go version's logic: + - Reads the instance's tenant_model records; each record carries its own status and model type. + - Lists all models from the LLM dictionary for the provider, then appends recorded models that are not in the dictionary. + - Dictionary models without a tenant_model record default to "active"; models with a record report the record's status.
+ + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :param instance_name: instance name + :param supported_only: if True, only list supported models (from LLM dictionary) + :return: (success, result_or_error_message) + """ + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"No provider found for provider '{provider_name}'" + + if supported_only: + # List all models supported by this provider from the LLM dictionary + factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found" + llms = factory_info[0].get("llm", []) + models = [{"name": llm["llm_name"]} for llm in llms] + return True, models + + # Get instance + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'" + + # Get model records for this instance from tenant_model table + model_records = TenantModelService.get_models_by_instance_id(instance_obj.id) + # Build a map of model_name -> status, type + model_info_map: dict = {} + for model_record in model_records: + if model_info_map.get(model_record.model_name): + model_info_map[model_record.model_name]["model_type"].append(model_record.model_type) + else: + model_info_map[model_record.model_name] = { + "status": model_record.status, + "model_type": [model_record.model_type], + "extra": model_record.extra + } + + # List all models from the LLM dictionary for this provider + factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found" + + llms = factory_info[0].get("llm", []) + models = [] + for llm in llms: + models.append({ + "name": llm["llm_name"], + "model_type": 
[llm["model_type"]] + model_info_map.get(llm["llm_name"], {}).get("model_type", []), + "max_tokens": llm.get("max_tokens"), + "status": model_info_map.get(llm["llm_name"], {}).get("status", "active"), + }) + factory_models = [m["name"] for m in models] + for model_name, model_info_dict in model_info_map.items(): + if model_name not in factory_models: + extra_fields = json.loads(model_info_dict["extra"]) if model_info_dict["extra"] else {} + models.append({ + "name": model_name, + "model_type": model_info_dict["model_type"], + "max_tokens": extra_fields.get("max_tokens", 8192), + "status": model_info_dict["status"], + }) + + return True, models + + +def add_model_to_instance(tenant_id: str, provider_name: str, instance_name: str, model_name: str, model_type: str|list[str], max_tokens: int, extra: dict): + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"No provider found for provider '{provider_name}'" + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'" + model_obj = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, model_name) + if model_obj: + return False, f"Model '{model_name}' already exists for provider '{provider_name}' and instance '{instance_name}'" + factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found" + llms = factory_info[0].get("llm", []) + if isinstance(model_type, str): + model_type = [model_type] + + import json + + for _type in model_type: + extra_fields = {"max_tokens": max_tokens} + target_model = [llm for llm in llms if llm["model_type"] == _type and llm["llm_name"] == model_name] + if target_model: + 
extra_fields.update({"is_tool": target_model[0].get("is_tool", False)}) + extra_fields.update(extra) + TenantModelService.insert( + model_name=model_name, + provider_id=provider_obj.id, + instance_id=instance_obj.id, + model_type=_type, + extra=json.dumps(extra_fields) + ) + + return True, "success" + + +def update_model_status(tenant_id: str, provider_name: str, instance_name: str, model_name: str, status: str): + """ + Enable or disable a model for a provider instance. + + - If the model record exists in tenant_model, update its status. + - If the model record does not exist: + - status="active": no need to add a record (default is active/enabled). + - status="inactive": create a record with status="inactive". + + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :param instance_name: instance name + :param model_name: model name + :param status: "active" or "inactive" (ActiveStatusEnum values) + :return: (success, result_or_error_message) + """ + if status not in (ActiveStatusEnum.ACTIVE.value, ActiveStatusEnum.INACTIVE.value): + return False, f"status must be '{ActiveStatusEnum.ACTIVE.value}' or '{ActiveStatusEnum.INACTIVE.value}'" + + # Check if provider exists for this tenant + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + return False, f"No provider found for provider '{provider_name}'" + + # Check if instance exists + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + return False, f"No instance found for provider '{provider_name}' and instance '{instance_name}'" + + # Check if model record already exists in tenant_model table + model_obj_list = TenantModelService.get_by_provider_id_and_instance_id_and_model_name( + provider_obj.id, instance_obj.id, model_name + ) + + if model_obj_list: + # Model record exists — update its status + 
TenantModelService.batch_update_model_status([m.id for m in model_obj_list], status) + else: + # Model record does not exist + if status == ActiveStatusEnum.ACTIVE.value: + # Default is active, no need to add a record + return True, None + # status is "inactive" — create a record with inactive status + # Look up model schema from FACTORY_LLM_INFOS + factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name] + if not factory_info: + return False, f"Provider '{provider_name}' not found" + llms = factory_info[0].get("llm", []) + target_llm = [llm for llm in llms if llm["llm_name"] == model_name] + if not target_llm: + return False, f"provider {provider_name} model {model_name} not found" + + TenantModelService.insert( + id=get_uuid(), + model_name=model_name, + model_type=target_llm[0]["model_type"], + provider_id=provider_obj.id, + instance_id=instance_obj.id, + status=status, + ) + + return True, None + + +async def chat_to_model(tenant_id: str, provider_name: str, instance_name: str, model_name: str, message: str, stream: bool = False, thinking: bool = False): + """ + Chat to a model. 
+ + :param tenant_id: tenant ID + :param provider_name: provider/factory name + :param instance_name: instance name + :param model_name: model name + :param message: chat message + :param stream: whether to stream the response + :param thinking: whether to enable thinking/reasoning + :return: (success, result_or_error_message) + """ + from api.db.services.llm_service import LLMBundle + + # Get model config + composite_name = f"{model_name}@{instance_name}@{provider_name}" + try: + model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT.value, composite_name) + except LookupError: + return False, f"Model '{composite_name}' not authorized" + + if not model_config: + return False, f"Model '{composite_name}' not found" + + llm = LLMBundle(tenant_id, model_config) + + if stream: + return True, {"type": "stream", "llm": llm, "model_config": model_config} + + # Non-streaming chat + try: + response = await llm.async_chat( + None, + [{"role": "user", "content": message}], + {"temperature": 0.9}, + ) + result = { + "answer": response, + "reasoning_content": "", + } + return True, result + except Exception as e: + logging.exception(f"Chat to model failed: {e}") + return False, str(e) diff --git a/api/db/db_models.py b/api/db/db_models.py index 70c2a0ee94..e676ae97cd 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -763,6 +763,7 @@ class Tenant(DataBaseModel): tenant_rerank_id = IntegerField(null=True, help_text="id in tenant_llm", index=True) tts_id = CharField(max_length=256, null=True, help_text="default tts model ID", index=True) tenant_tts_id = IntegerField(null=True, help_text="id in tenant_llm", index=True) + ocr_id = CharField(max_length=256, null=True, help_text="default OCR model ID", index=True) parser_ids = CharField(max_length=256, null=False, help_text="document processors", index=True) credit = IntegerField(default=512, index=True) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", 
default="1", index=True) @@ -1350,6 +1351,67 @@ class SystemSettings(DataBaseModel): class Meta: db_table = "system_settings" +class TenantModelProvider(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + provider_name = CharField(max_length=128, null=False, index=False, help_text="LLM provider name") + tenant_id = CharField(max_length=32, null=False, index=True) + + class Meta: + db_table = "tenant_model_provider" + indexes = ( + (("tenant_id", "provider_name"), True), + ) + +class TenantModelInstance(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + instance_name = CharField(max_length=128, null=False, index=False, help_text="Model instance name") + provider_id = CharField(max_length=32, null=False, index=False) + api_key = CharField(max_length=512, null=False, index=False, help_text="API key") + status = CharField(max_length=32, default="active", index=False) + extra = CharField(max_length=512, default="{}", index=False) + + class Meta: + db_table = "tenant_model_instance" + indexes = ( + (("api_key", "provider_id"), True), + ) + + +class TenantModel(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + model_name = CharField(max_length=128, null=True, index=False, help_text="Model name") + provider_id = CharField(max_length=32, null=False, index=False) + instance_id = CharField(max_length=32, null=False, index=True) + model_type = CharField(max_length=32, null=False, index=False, help_text="Model type") + status = CharField(max_length=32, default="active", index=False) + extra = CharField(max_length=1024, default="{}", index=False) + + class Meta: + db_table = "tenant_model" + + +class TenantModelGroup(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + group_type = CharField(max_length=32, null=False, index=False, help_text="Group type") + model_name = CharField(max_length=128, null=True, index=False, help_text="Model name") + strategy = CharField(max_length=32, default="weighted", index=False, 
help_text="Routing strategy") + + class Meta: + db_table = "tenant_model_group" + +class TenantModelGroupMapping(DataBaseModel): + group_id = CharField(max_length=32, null=False, index=True, help_text="Group ID") + provider_id = CharField(max_length=32, null=False, index=False) + instance_id = CharField(max_length=32, null=False, index=False) + model_id = CharField(max_length=32, null=False, index=True) + weight = IntegerField(default=100, index=False, help_text="Routing weight") + status = CharField(max_length=32, default="active", index=False) + + class Meta: + db_table = "tenant_model_group_mapping" + primary_key = CompositeKey("group_id", "provider_id", "instance_id", "model_id") + + def alter_db_add_column(migrator, table_name, column_name, column_type): try: migrate(migrator.add_column(table_name, column_name, column_type)) @@ -1668,6 +1730,7 @@ def migrate_db(): alter_db_add_column(migrator, "api_4_conversation", "version_title", CharField(max_length=255, null=True, help_text="canvas version title when session created", index=False)) alter_db_column_type(migrator, "document", "size", BigIntegerField(default=0, index=True)) alter_db_column_type(migrator, "file", "size", BigIntegerField(default=0, index=True)) + alter_db_add_column(migrator, "tenant", "ocr_id", CharField(max_length=128, null=True, help_text="default ocr model ID", index=True)) logging.disable(logging.NOTSET) # this is after re-enabling logging to allow logging changed user emails migrate_add_unique_email(migrator) diff --git a/api/db/init_data.py b/api/db/init_data.py index 5bd5225999..ec200a931d 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -24,17 +24,15 @@ from copy import deepcopy from peewee import IntegrityError from api.db import UserTenantRole -from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM, Knowledgebase, Dialog, Memory +from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM from 
api.db.services import UserService from api.db.services.canvas_service import CanvasTemplateService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.memory_service import MemoryService from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm from api.db.services.user_service import TenantService, UserTenantService from api.db.services.system_settings_service import SystemSettingsService -from api.db.services.dialog_service import DialogService from api.db.template_utils import normalize_canvas_template_categories from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type @@ -154,6 +152,8 @@ def init_llm_factory(): except Exception: pass break + +def update_document_number_in_init(): doc_count = DocumentService.get_all_kb_doc_count() for kb_id in KnowledgebaseService.get_all_ids(): KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=doc_count.get(kb_id, 0)) @@ -189,7 +189,8 @@ def init_web_data(): init_table() - init_llm_factory() + # init_llm_factory() + update_document_number_in_init() # if not UserService.get_all().count(): # init_superuser() @@ -197,7 +198,6 @@ def init_web_data(): init_message_id_sequence() init_memory_size_cache() fix_missing_tokenized_memory() - fix_empty_tenant_model_id() logging.info("init web data success:{}".format(time.time() - start_time)) def init_table(): @@ -226,105 +226,6 @@ def init_table(): raise e -def fix_empty_tenant_model_id(): - # knowledgebase - empty_tenant_embd_id_kbs = KnowledgebaseService.get_null_tenant_embd_id_row() - if empty_tenant_embd_id_kbs: - logging.info(f"Found {len(empty_tenant_embd_id_kbs)} empty tenant_embd_id 
knowledgebase.") - kb_groups: dict = {} - for obj in empty_tenant_embd_id_kbs: - if kb_groups.get((obj.tenant_id, obj.embd_id)): - kb_groups[(obj.tenant_id, obj.embd_id)].append(obj.id) - else: - kb_groups[(obj.tenant_id, obj.embd_id)] = [obj.id] - update_cnt = 0 - for k, v in kb_groups.items(): - tenant_llm = TenantLLMService.get_api_key(k[0], k[1]) - if tenant_llm: - update_cnt += KnowledgebaseService.filter_update([Knowledgebase.id.in_(v)], {"tenant_embd_id": tenant_llm.id}) - logging.info(f"Update {update_cnt} tenant_embd_id in table knowledgebase.") - # dialog - empty_tenant_llm_id_dialog = DialogService.get_null_tenant_llm_id_row() - if empty_tenant_llm_id_dialog: - logging.info(f"Found {len(empty_tenant_llm_id_dialog)} empty tenant_llm_id dialogs.") - dialog_groups: dict = {} - for obj in empty_tenant_llm_id_dialog: - if dialog_groups.get((obj.tenant_id, obj.llm_id)): - dialog_groups[(obj.tenant_id, obj.llm_id)].append(obj.id) - else: - dialog_groups[(obj.tenant_id, obj.llm_id)] = [obj.id] - update_cnt = 0 - for k, v in dialog_groups.items(): - tenant_llm = TenantLLMService.get_api_key(k[0], k[1]) - if tenant_llm: - update_cnt += DialogService.filter_update([Dialog.id.in_(v)], {"tenant_llm_id": tenant_llm.id}) - logging.info(f"Update {update_cnt} tenant_llm_id in table dialog.") - - empty_tenant_rerank_id_dialog = DialogService.get_null_tenant_rerank_id_row() - if empty_tenant_rerank_id_dialog: - logging.info(f"Found {len(empty_tenant_rerank_id_dialog)} empty tenant_rerank_id dialogs.") - dialog_groups: dict = {} - for obj in empty_tenant_rerank_id_dialog: - if dialog_groups.get((obj.tenant_id, obj.rerank_id)): - dialog_groups[(obj.tenant_id, obj.rerank_id)].append(obj.id) - else: - dialog_groups[(obj.tenant_id, obj.rerank_id)] = [obj.id] - update_cnt = 0 - for k, v in dialog_groups.items(): - tenant_llm = TenantLLMService.get_api_key(k[0], k[1]) - if tenant_llm: - update_cnt += DialogService.filter_update([Dialog.id.in_(v)], {"tenant_rerank_id": 
tenant_llm.id}) - logging.info(f"Update {update_cnt} tenant_rerank_id in table dialog.") - # memory - empty_tenant_embd_id_memories = MemoryService.get_null_tenant_embd_id_row() - if empty_tenant_embd_id_memories: - logging.info(f"Found {len(empty_tenant_embd_id_memories)} empty tenant_embd_id memories.") - memory_groups: dict = {} - for obj in empty_tenant_embd_id_memories: - if memory_groups.get((obj.tenant_id, obj.embd_id)): - memory_groups[(obj.tenant_id, obj.embd_id)].append(obj.id) - else: - memory_groups[(obj.tenant_id, obj.embd_id)] = [obj.id] - update_cnt = 0 - for k, v in memory_groups.items(): - tenant_llm = TenantLLMService.get_api_key(k[0], k[1]) - if tenant_llm: - update_cnt += MemoryService.filter_update([Memory.id.in_(v)], {"tenant_embd_id": tenant_llm.id}) - logging.info(f"Update {update_cnt} tenant_embd_id in table memory.") - - empty_tenant_llm_id_memories = MemoryService.get_null_tenant_llm_id_row() - if empty_tenant_llm_id_memories: - logging.info(f"Found {len(empty_tenant_llm_id_memories)} empty tenant_llm_id memories.") - memory_groups: dict = {} - for obj in empty_tenant_llm_id_memories: - if memory_groups.get((obj.tenant_id, obj.llm_id)): - memory_groups[(obj.tenant_id, obj.llm_id)].append(obj.id) - else: - memory_groups[(obj.tenant_id, obj.llm_id)] = [obj.id] - update_cnt = 0 - for k, v in memory_groups.items(): - tenant_llm = TenantLLMService.get_api_key(k[0], k[1]) - if tenant_llm: - update_cnt += MemoryService.filter_update([Memory.id.in_(v)], {"tenant_llm_id": tenant_llm.id}) - logging.info(f"Update {update_cnt} tenant_llm_id in table memory.") - # tenant - empty_tenant_model_id_tenants = TenantService.get_null_tenant_model_id_rows() - if empty_tenant_model_id_tenants: - logging.info(f"Found {len(empty_tenant_model_id_tenants)} empty tenant_model_id tenants.") - update_cnt = 0 - for obj in empty_tenant_model_id_tenants: - tenant_dict = obj.to_dict() - update_dict = {} - for key in ["llm_id", "embd_id", "asr_id", "img2txt_id", 
"rerank_id", "tts_id"]: - if tenant_dict.get(key) and not tenant_dict.get(f"tenant_{key}"): - tenant_model = TenantLLMService.get_api_key(tenant_dict["id"], tenant_dict[key]) - if tenant_model: - update_dict.update({f"tenant_{key}": tenant_model.id}) - if update_dict: - update_cnt += TenantService.update_by_id(tenant_dict["id"], update_dict) - logging.info(f"Update {update_cnt} tenant_model_id in table tenant.") - logging.info("Fix empty tenant_model_id done.") - if __name__ == '__main__': init_web_db() init_web_data() diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index fd86342181..a25ed86a37 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -26,7 +26,7 @@ from api.db.db_models import Task from api.db.services.task_service import TaskService from api.db.services.memory_service import MemoryService from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance from api.utils.memory_utils import get_memory_type_human from memory.services.messages import MessageService from memory.services.query import MsgTextQuery, get_vector @@ -153,14 +153,7 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"}) else: user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) - if tenant_llm_id: - llm_config = get_model_config_by_id( - tenant_llm_id, - allowed_tenant_ids=tenant_id, - requester_tenant_id=tenant_id, - ) - else: - llm_config = 
get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) + llm_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, llm_id) with LLMBundle(tenant_id, llm_config) as llm: if task_id: TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."}) @@ -177,14 +170,7 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, async def embed_and_save(memory, message_list: list[dict], task_id: str=None): - if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id( - memory.tenant_embd_id, - allowed_tenant_ids=memory.tenant_id, - requester_tenant_id=memory.tenant_id, - ) - else: - embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) + embd_model_config = get_model_config_from_provider_instance(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) with LLMBundle(memory.tenant_id, embd_model_config) as embedding_model: if task_id: TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."}) @@ -255,14 +241,7 @@ def query_message(filter_dict: dict, params: dict): question = params["query"] question = question.strip() memory = memory_list[0] - if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id( - memory.tenant_embd_id, - allowed_tenant_ids=memory.tenant_id, - requester_tenant_id=memory.tenant_id, - ) - else: - embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) + embd_model_config = get_model_config_from_provider_instance(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embd_model = LLMBundle(memory.tenant_id, embd_model_config) match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"]) match_text, _ = MsgTextQuery().question(question, min_match=params["similarity_threshold"]) 
diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py index 677bfcaaaf..b52a1377bf 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -17,131 +17,15 @@ import logging import os import enum from common import settings -from common.constants import LLMType -from api.db.services.llm_service import LLMService +from common.constants import LLMType, ActiveStatusEnum from api.db.services.tenant_llm_service import TenantLLMService, TenantService +from api.db.services.tenant_model_provider_service import TenantModelProviderService +from api.db.services.tenant_model_instance_service import TenantModelInstanceService +from api.db.services.tenant_model_service import TenantModelService logger = logging.getLogger(__name__) -def get_model_config_by_id( - tenant_model_id: int, - allowed_tenant_ids: str | list[str] | set[str] | tuple[str, ...] | None = None, - requester_tenant_id: str | None = None, -) -> dict: - found, model_config = TenantLLMService.get_by_id(tenant_model_id) - if not found: - raise LookupError(f"Tenant Model with id {tenant_model_id} not found") - if allowed_tenant_ids is not None: - if isinstance(allowed_tenant_ids, str): - allowed_tenant_ids = {allowed_tenant_ids} - else: - allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} - if str(model_config.tenant_id) not in allowed_tenant_ids: - logger.warning( - "Denied tenant model access: tenant_model_id=%s model_tenant_id=%s " - "allowed_tenant_ids=%s requester_tenant_id=%s", - tenant_model_id, - model_config.tenant_id, - sorted(allowed_tenant_ids), - requester_tenant_id, - ) - raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") - config_dict = model_config.to_dict() - api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", "")) - config_dict["api_key"] = api_key - if api_key_payload is not None: 
- config_dict["api_key_payload"] = api_key_payload - if is_tools is not None: - config_dict["is_tools"] = is_tools - llm = LLMService.query(llm_name=config_dict["llm_name"]) - if "is_tools" not in config_dict and llm: - config_dict["is_tools"] = llm[0].is_tools - return config_dict - - -def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): - if not model_name: - raise Exception("Model Name is required") - model_type_val = model_type.value if hasattr(model_type, "value") else model_type - model_config = TenantLLMService.get_api_key(tenant_id, model_name, model_type_val) - if not model_config: - # model_name in format 'name@factory', split model_name and try again - pure_model_name, fid = TenantLLMService.split_model_name_and_factory(model_name) - compose_profiles = os.getenv("COMPOSE_PROFILES", "") - is_tei_builtin_embedding = ( - model_type_val == LLMType.EMBEDDING.value - and "tei-" in compose_profiles - and pure_model_name == os.getenv("TEI_MODEL", "") - and (fid == "Builtin" or fid is None) - ) - if is_tei_builtin_embedding: - # configured local embedding model - embedding_cfg = settings.EMBEDDING_CFG - config_dict = { - "llm_factory": "Builtin", - "api_key": embedding_cfg["api_key"], - "llm_name": pure_model_name, - "api_base": embedding_cfg["base_url"], - "model_type": LLMType.EMBEDDING.value, - } - elif model_type_val == LLMType.CHAT.value: - # Retry as CHAT with pure_model_name first; then fall back to a multimodal model registered under IMAGE2TEXT. 
- model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.CHAT.value) - if not model_config: - model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.IMAGE2TEXT.value) - if not model_config: - raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found") - config_dict = model_config.to_dict() - elif model_type_val == LLMType.IMAGE2TEXT.value: - model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.IMAGE2TEXT.value) - if not model_config: - # Fall back to a chat model only if it has declared IMAGE2TEXT capability (tag check via llm table) - chat_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.CHAT.value) - logger.debug("IMAGE2TEXT config not found for %s; chat_config found: %s", pure_model_name, chat_config is not None) - if chat_config: - llm_entry = LLMService.query(fid=chat_config.llm_factory, llm_name=chat_config.llm_name) - tags = [t.strip() for t in (llm_entry[0].tags or "").split(",")] if llm_entry else [] - logger.debug("LLM tags for %s/%s: %s", chat_config.llm_factory, chat_config.llm_name, tags) - if "IMAGE2TEXT" in tags: - logger.debug("Promoting chat config to IMAGE2TEXT for %s", pure_model_name) - model_config = chat_config - if not model_config: - raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found") - config_dict = model_config.to_dict() - config_dict["model_type"] = LLMType.IMAGE2TEXT.value - else: - model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, model_type_val) - if not model_config: - raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found") - config_dict = model_config.to_dict() - else: - # model_name without @factory - config_dict = model_config.to_dict() - api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", "")) - config_dict["api_key"] = api_key - if 
api_key_payload is not None: - config_dict["api_key_payload"] = api_key_payload - if is_tools is not None: - config_dict["is_tools"] = is_tools - config_model_type = config_dict.get("model_type") - config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type - if config_model_type != model_type_val and not ( - model_type_val == LLMType.CHAT.value - and config_model_type == LLMType.IMAGE2TEXT.value - ) and not ( - model_type_val == LLMType.IMAGE2TEXT.value - and config_model_type == LLMType.CHAT.value - ): - raise LookupError( - f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}" - ) - llm = LLMService.query(llm_name=config_dict["llm_name"]) - if "is_tools" not in config_dict and llm: - config_dict["is_tools"] = llm[0].is_tools - return config_dict - - def get_tenant_default_model_by_type(tenant_id: str, model_type: str|enum.Enum): exist, tenant = TenantService.get_by_id(tenant_id) if not exist: @@ -167,4 +51,135 @@ def get_tenant_default_model_by_type(tenant_id: str, model_type: str|enum.Enum): raise Exception(f"Unknown model type {model_type}") if not model_name: raise Exception(f"No default {model_type} model is set.") - return get_model_config_by_type_and_name(tenant_id, model_type, model_name) + return get_model_config_from_provider_instance(tenant_id, model_type, model_name) + + +def split_model_name(model_name: str): + # Parse model_name: {model_name} or {model_name}@{factory_name} or {model_name}@{instance_name}@{factory_name} + parts = model_name.split("@") + if len(parts) == 1: + pure_model_name = parts[0] + provider_name = "" + instance_name = "" + elif len(parts) == 2: + pure_model_name = parts[0] + provider_name = parts[1] + instance_name = "default" + else: + pure_model_name = parts[0] + instance_name = parts[1] + provider_name = parts[2] + return pure_model_name, instance_name, provider_name + + +def get_model_config_from_provider_instance(tenant_id, model_type: 
str|enum.Enum, model_name: str): + pure_model_name, instance_name, provider_name = split_model_name(model_name) + model_type_val = model_type if isinstance(model_type, str) else model_type.value + # Builtin embedding model + compose_profiles = os.getenv("COMPOSE_PROFILES", "") + is_tei_builtin_embedding = ( + model_type_val == LLMType.EMBEDDING.value + and "tei-" in compose_profiles + and pure_model_name == os.getenv("TEI_MODEL", "") + and (provider_name == "Builtin" or provider_name is None) + ) + if is_tei_builtin_embedding: + # configured local embedding model + embedding_cfg = settings.EMBEDDING_CFG + return { + "llm_factory": "Builtin", + "api_key": embedding_cfg["api_key"], + "llm_name": pure_model_name, + "api_base": embedding_cfg["base_url"], + "model_type": LLMType.EMBEDDING.value, + } + + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + raise LookupError(f"Provider {provider_name} not found for model {model_name}.") + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + raise LookupError(f"Instance {instance_name} not found for model {model_name}.") + model_obj = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name(provider_obj.id, instance_obj.id, model_type_val, pure_model_name) + + import json + api_key, is_tool, api_key_payload = TenantLLMService._decode_api_key_config(instance_obj.api_key) + extra_fields = json.loads(instance_obj.extra) if instance_obj.extra else {} + + if model_obj: + if model_obj.status == ActiveStatusEnum.INACTIVE.value: + raise LookupError(f"Model {model_name} is disabled.") + + model_config = { + "llm_factory": provider_obj.provider_name, + "api_key": api_key, + "llm_name": model_obj.model_name, + "api_base": extra_fields.get("base_url", ""), + "model_type": model_obj.model_type, + "is_tool": extra_fields.get("is_tool", is_tool) + } + if 
api_key_payload is not None: + model_config["api_key_payload"] = api_key_payload + + return model_config + else: + fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == provider_name] + if not fac_list: + raise LookupError(f"Model provider config not found: {provider_name}") + llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name] + if not llm_list: + raise LookupError(f"Model config not found: {model_name}") + llm_info = llm_list[0] + model_config = { + "llm_factory": provider_obj.provider_name, + "api_key": api_key, + "llm_name": llm_info["llm_name"], + "api_base": extra_fields.get("base_url", ""), + "model_type": llm_info["model_type"], + "is_tool": llm_info.get("is_tool", is_tool) + } + if api_key_payload is not None: + model_config["api_key_payload"] = api_key_payload + return model_config + + +def get_api_key(tenant_id: str, model_name: str): + _, instance_name, provider_name = split_model_name(model_name) + + if not provider_name: + raise LookupError("Provider name is required.") + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + raise LookupError(f"Provider {provider_name} not found.") + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + raise LookupError(f"Instance {instance_name} not found.") + return instance_obj.api_key + + +def get_model_type_by_name(tenant_id: str, model_name: str): + pure_model_name, instance_name, provider_name = split_model_name(model_name) + provider_obj = TenantModelProviderService.get_by_tenant_id_and_provider_name(tenant_id, provider_name) + if not provider_obj: + raise LookupError(f"Provider {provider_name} not found for model {model_name}.") + instance_obj = TenantModelInstanceService.get_by_provider_id_and_instance_name(provider_obj.id, instance_name) + if not instance_obj: + raise LookupError(f"Instance {instance_name} 
not found for model {model_name}.") + model_objs = TenantModelService.get_by_provider_id_and_instance_id_and_model_name(provider_obj.id, instance_obj.id, pure_model_name) + if not model_objs: + fac_list = [f for f in settings.FACTORY_LLM_INFOS if f["name"] == provider_name] + if not fac_list: + raise LookupError(f"Model provider config not found: {provider_name}") + llm_list = [llm for llm in fac_list[0]["llm"] if llm["llm_name"] == pure_model_name] + if not llm_list: + raise LookupError(f"Model {pure_model_name} not found for model {model_name}.") + return [llm_list[0]["model_type"]] + return [model_obj.model_type for model_obj in model_objs] + + +def delete_models_by_instance_ids(instance_ids: list[str]): + return TenantModelService.delete_by_instance_ids(instance_ids) + + +def delete_instances_by_provider_ids(provider_ids: list[str]): + return TenantModelInstanceService.delete_by_provider_ids(provider_ids) diff --git a/api/db/services/__init__.py b/api/db/services/__init__.py index ce937911fb..a5e83ea0e4 100644 --- a/api/db/services/__init__.py +++ b/api/db/services/__init__.py @@ -42,7 +42,7 @@ def _split_name_counter(filename: str) -> tuple[str, int | None]: return filename, None -def duplicate_name(query_func, **kwargs) -> str: +def duplicate_name(query_func, name_field: str="name", **kwargs) -> str: """ Generates a unique filename by appending/incrementing a counter when duplicates exist. @@ -54,6 +54,7 @@ def duplicate_name(query_func, **kwargs) -> str: query_func: Callable that accepts keyword arguments and returns: - True if name exists (should be modified) - False if name is available + name_field: the field name of name in db. 
default to 'name' **kwargs: Must contain 'name' key with original filename to check Returns: @@ -72,10 +73,10 @@ def duplicate_name(query_func, **kwargs) -> str: """ MAX_RETRIES = 1000 - if "name" not in kwargs: - raise KeyError("Arguments must contain 'name' key") + if name_field not in kwargs: + raise KeyError(f"Arguments must contain '{name_field}' key") - original_name = kwargs["name"] + original_name = kwargs[name_field] current_name = original_name retries = 0 @@ -92,7 +93,7 @@ def duplicate_name(query_func, **kwargs) -> str: new_name = f"{main_part}({counter}){suffix}" - kwargs["name"] = new_name + kwargs[name_field] = new_name current_name = new_name retries += 1 diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 4084fb80fe..cd6186e69a 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -39,8 +39,7 @@ from api.utils.reference_metadata_utils import ( enrich_chunks_with_document_metadata, resolve_reference_metadata_preferences, ) -from api.db.services.tenant_llm_service import TenantLLMService -from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance, get_model_type_by_name from common.time_utils import current_timestamp, datetime_format from common.text_utils import normalize_arabic_digits from rag.graphrag.general.mind_map_extractor import MindMapExtractor @@ -288,21 +287,19 @@ class DialogService(CommonService): async def async_chat_solo(dialog, messages, stream=True): - llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id) + llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id) attachments = "" image_attachments = [] image_files = [] if "files" in messages[-1]: - if llm_type == "chat": + if "chat" in llm_types: text_attachments, image_attachments 
= split_file_attachments(messages[-1]["files"]) else: text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) attachments = "\n\n".join(text_attachments) if dialog.llm_id: - model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) - elif dialog.tenant_llm_id: - model_config = get_model_config_by_id(dialog.tenant_llm_id) + model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) else: model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) @@ -317,10 +314,10 @@ async def async_chat_solo(dialog, messages, stream=True): msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if attachments and msg: msg[-1]["content"] += attachments - if llm_type == "chat" and image_attachments: + if "chat" in llm_types and image_attachments: convert_last_user_msg_to_multimodal(msg, image_attachments, factory) if stream: - if llm_type == "chat": + if "chat" in llm_types: stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting) else: stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files) @@ -331,7 +328,7 @@ async def async_chat_solo(dialog, messages, stream=True): continue yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False} else: - if llm_type == "chat": + if "chat" in llm_types: answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) else: answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files) @@ -349,22 +346,20 @@ def get_models(dialog): if embedding_list: embd_owner_tenant_id = kbs[0].tenant_id - embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, 
LLMType.EMBEDDING, embedding_list[0]) + embd_model_config = get_model_config_from_provider_instance(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0]) embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config) if not embd_mdl: raise LookupError("Embedding model(%s) not found" % embedding_list[0]) if dialog.llm_id: - chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) - elif dialog.tenant_llm_id: - chat_model_config = get_model_config_by_id(dialog.tenant_llm_id) + chat_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) else: chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config) if dialog.rerank_id: - rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) + rerank_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config) if dialog.prompt_config.get("tts"): @@ -554,11 +549,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): return chat_start_ts = timer() - llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id) - if llm_type == "image2text": - llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) + if dialog.llm_id: + llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id) + if "image2text" in llm_types: + llm_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) + else: + llm_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) else: - llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + llm_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) factory 
= llm_model_config.get("llm_factory", "") if llm_model_config else "" max_tokens = llm_model_config.get("max_tokens", 8192) @@ -598,7 +596,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): if "doc_ids" in messages[-1]: attachments = [doc_id for doc_id in messages[-1]["doc_ids"] if doc_id] if "files" in messages[-1]: - if llm_type == "chat": + if llm_model_config["model_type"] == "chat": text_attachments, image_attachments = split_file_attachments(messages[-1]["files"]) else: text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) @@ -769,7 +767,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): prompt4citation = citation_prompt() msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]) used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95)) - if llm_type == "chat" and image_attachments: + if llm_model_config["model_type"] == "chat" and image_attachments: convert_last_user_msg_to_multimodal(msg, image_attachments, factory) assert len(msg) >= 2, f"message_fit_in has bug: {msg}" prompt = msg[0]["content"] @@ -881,7 +879,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): langfuse_generation = None if stream: - if llm_type == "chat": + if llm_model_config["model_type"] == "chat": stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf) else: stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf, images=image_files) @@ -900,7 +898,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): final["audio_binary"] = None yield final else: - if llm_type == "chat": + if llm_model_config["model_type"] == "chat": answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf) else: answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf, images=image_files) @@ -1540,12 +1538,12 @@ async def 
async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever embd_owner_tenant_id = kbs[0].tenant_id - embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0]) + embd_model_config = get_model_config_from_provider_instance(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0]) embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config) - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_llm_name) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, chat_llm_name) chat_mdl = LLMBundle(tenant_id, chat_model_config) if rerank_id: - rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id) + rerank_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.RERANK, rerank_id) rerank_mdl = LLMBundle(tenant_id, rerank_model_config) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) @@ -1649,23 +1647,18 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): kbs = KnowledgebaseService.get_by_ids(kb_ids) if not kbs: return {"error": "No KB selected"} - tenant_embedding_list = list(set([kb.tenant_embd_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs])) - if tenant_embedding_list[0]: - embd_model_config = get_model_config_by_id(tenant_embedding_list[0]) - embd_owner_tenant_id = kbs[0].tenant_id - else: - embd_owner_tenant_id = kbs[0].tenant_id - embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, kbs[0].embd_id) + embd_owner_tenant_id = kbs[0].tenant_id + embd_model_config = get_model_config_from_provider_instance(embd_owner_tenant_id, LLMType.EMBEDDING, kbs[0].embd_id) embd_mdl = LLMBundle(embd_owner_tenant_id, 
embd_model_config) chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, chat_id) else: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) if rerank_id: - rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id) + rerank_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.RERANK, rerank_id) rerank_mdl = LLMBundle(tenant_id, rerank_model_config) if meta_data_filter: diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index acd35bfe6e..c808d4add1 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -24,7 +24,7 @@ from typing import Generator from api.db.db_models import LLM from api.db.services.common_service import CommonService -from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService +from api.db.services.tenant_llm_service import LLM4Tenant from common.constants import LLMType from common.token_utils import num_tokens_from_string @@ -137,9 +137,9 @@ class LLMBundle(LLM4Tenant): embeddings, used_tokens = self.mdl.encode(safe_texts) if self.model_config["llm_factory"] == "Builtin": - logging.debug("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(texts, len(embeddings), used_tokens)) - elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.encode can't update token usage for /EMBEDDING used_tokens: {}".format(used_tokens)) + logging.debug("LLMBundle.encode query: {}, emd len: {}, used_tokens: {}. 
Builtin model don't need to update token usage".format(texts, len(embeddings), used_tokens)) + else: + logging.info("LLMBundle.encode used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update(usage_details={"total_tokens": used_tokens}) @@ -162,8 +162,8 @@ class LLMBundle(LLM4Tenant): emd, used_tokens = self.mdl.encode_queries(query) if self.model_config["llm_factory"] == "Builtin": logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(query, len(emd), used_tokens)) - elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.encode_queries can't update token usage for /EMBEDDING used_tokens: {}".format(used_tokens)) + else: + logging.info("LLMBundle.encode_queries used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update(usage_details={"total_tokens": used_tokens}) @@ -176,8 +176,7 @@ class LLMBundle(LLM4Tenant): generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}) sim, used_tokens = self.mdl.similarity(query, texts) - if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.info("LLMBundle.similarity used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update(usage_details={"total_tokens": used_tokens}) @@ -190,8 +189,7 @@ class LLMBundle(LLM4Tenant): generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]}) txt, used_tokens = 
self.mdl.describe(image) - if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.info("LLMBundle.describe used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) @@ -204,8 +202,7 @@ class LLMBundle(LLM4Tenant): generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}) txt, used_tokens = self.mdl.describe_with_prompt(image, prompt) - if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.info("LLMBundle.describe_with_prompt used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) @@ -218,8 +215,7 @@ class LLMBundle(LLM4Tenant): generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]}) txt, used_tokens = self.mdl.transcription(audio) - if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.info("LLMBundle.transcription used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) @@ -254,7 +250,7 @@ class 
LLMBundle(LLM4Tenant): finally: if final_text: used_tokens = num_tokens_from_string(final_text) - TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens) + logging.info("LLMBundle.stream_transcription used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update( @@ -273,8 +269,7 @@ class LLMBundle(LLM4Tenant): ) full_text, used_tokens = mdl.transcription(audio) - if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}") + logging.info("LLMBundle.stream_transcription used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if self.langfuse: generation.update( @@ -295,8 +290,7 @@ class LLMBundle(LLM4Tenant): for chunk in self.mdl.tts(text): if isinstance(chunk, int): - if not TenantLLMService.increase_usage_by_id(self.model_config["id"], chunk): - logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id)) + logging.info("LLMBundle.tts used_tokens: {}, model_name: {}".format(chunk, self.model_config["llm_name"])) return yield chunk @@ -431,8 +425,8 @@ class LLMBundle(LLM4Tenant): if not self.verbose_tool_use: txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - if used_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], used_tokens)) + if used_tokens: + logging.info("LLMBundle.async_chat used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) if generation: generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) @@ -479,8 +473,8 @@ class LLMBundle(LLM4Tenant): generation.update(output={"error": str(e)}) generation.end() raise 
- if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens): - logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens)) + if total_tokens: + logging.info("LLMBundle.async_chat_streamly used_tokens: {}, llm_name: {}".format(total_tokens, self.model_config["llm_name"])) if generation: generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) generation.end() @@ -522,8 +516,8 @@ class LLMBundle(LLM4Tenant): generation.update(output={"error": str(e)}) generation.end() raise - if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens): - logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens)) + if total_tokens: + logging.info("LLMBundle.async_chat_streamly_delta used_tokens: {}, llm_name: {}".format(total_tokens, self.model_config["llm_name"])) if generation: generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) generation.end() diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py index 530fc5ad9e..a565709cf3 100644 --- a/api/db/services/memory_service.py +++ b/api/db/services/memory_service.py @@ -112,7 +112,7 @@ class MemoryService(CommonService): @classmethod @DB.connection_context() - def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, tenant_embd_id: int, llm_id: str, tenant_llm_id: int): + def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str): # Deduplicate name within tenant memory_name = duplicate_name( cls.query, @@ -131,9 +131,7 @@ class MemoryService(CommonService): "memory_type": calculate_memory_type(memory_type), "tenant_id": tenant_id, "embd_id": 
embd_id, - "tenant_embd_id": tenant_embd_id, "llm_id": llm_id, - "tenant_llm_id": tenant_llm_id, "system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}), "create_time": timestamp, "create_date": format_time, diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index 7fae4571cc..c28dd72efb 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -188,30 +188,36 @@ class TenantLLMService(CommonService): api_key = model_config.get("api_key_payload", model_config["api_key"]) if model_config["model_type"] == LLMType.EMBEDDING.value: if model_config["llm_factory"] not in EmbeddingModel: + logging.error(f"Factory {model_config['llm_factory']} not in embedding model. Supported factories: {EmbeddingModel.keys()}") return None return EmbeddingModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"]) - elif model_config["model_type"] == LLMType.RERANK: + elif model_config["model_type"] == LLMType.RERANK.value: if model_config["llm_factory"] not in RerankModel: + logging.error(f"Factory {model_config['llm_factory']} not in rerank model. Supported factories: {RerankModel.keys()}") return None return RerankModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"]) elif model_config["model_type"] == LLMType.IMAGE2TEXT.value: if model_config["llm_factory"] not in CvModel: + logging.error(f"Factory {model_config['llm_factory']} not in cv model. Supported factories: {CvModel.keys()}") return None return CvModel[model_config["llm_factory"]](api_key, model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) elif model_config["model_type"] == LLMType.CHAT.value: if model_config["llm_factory"] not in ChatModel: + logging.error(f"Factory {model_config['llm_factory']} not in chat model. 
Supported factories: {ChatModel.keys()}") return None return ChatModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"], **kwargs) - elif model_config["model_type"] == LLMType.SPEECH2TEXT: + elif model_config["model_type"] == LLMType.SPEECH2TEXT.value: if model_config["llm_factory"] not in Seq2txtModel: + logging.error(f"Factory {model_config['llm_factory']} not in speech2text model. Supported factories: {Seq2txtModel.keys()}") return None return Seq2txtModel[model_config["llm_factory"]](key=api_key, model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) - elif model_config["model_type"] == LLMType.TTS: + elif model_config["model_type"] == LLMType.TTS.value: if model_config["llm_factory"] not in TTSModel: + logging.error(f"Factory {model_config['llm_factory']} not in tts model. Supported factories: {TTSModel.keys()}") return None return TTSModel[model_config["llm_factory"]]( api_key, @@ -219,8 +225,9 @@ class TenantLLMService(CommonService): base_url=model_config["api_base"], ) - elif model_config["model_type"] == LLMType.OCR: + elif model_config["model_type"] == LLMType.OCR.value: if model_config["llm_factory"] not in OcrModel: + logging.error(f"Factory {model_config['llm_factory']} not in ocr model. Supported factories: {OcrModel.keys()}") return None return OcrModel[model_config["llm_factory"]]( key=api_key, diff --git a/api/db/services/tenant_model_group_mapping_service.py b/api/db/services/tenant_model_group_mapping_service.py new file mode 100644 index 0000000000..590c65129f --- /dev/null +++ b/api/db/services/tenant_model_group_mapping_service.py @@ -0,0 +1,31 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from api.db.db_models import DB, TenantModelGroupMapping +from api.db.services.common_service import CommonService + + +class TenantModelGroupMappingService(CommonService): + model = TenantModelGroupMapping + + @classmethod + @DB.connection_context() + def get_by_composite_id(cls, group_id, provider_id, instance_id, model_id): + return cls.model.get_or_none( + cls.model.group_id == group_id, + cls.model.provider_id == provider_id, + cls.model.instance_id == instance_id, + cls.model.model_id == model_id, + ) \ No newline at end of file diff --git a/api/db/services/tenant_model_group_service.py b/api/db/services/tenant_model_group_service.py new file mode 100644 index 0000000000..88781eb17e --- /dev/null +++ b/api/db/services/tenant_model_group_service.py @@ -0,0 +1,21 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from api.db.db_models import TenantModelGroup +from api.db.services.common_service import CommonService + + +class TenantModelGroupService(CommonService): + model = TenantModelGroup \ No newline at end of file diff --git a/api/db/services/tenant_model_instance_service.py b/api/db/services/tenant_model_instance_service.py new file mode 100644 index 0000000000..0f44de8962 --- /dev/null +++ b/api/db/services/tenant_model_instance_service.py @@ -0,0 +1,69 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from common.misc_utils import get_uuid +from api.db.db_models import DB, TenantModelInstance +from api.db.services.common_service import CommonService +from api.db.services import duplicate_name + +class TenantModelInstanceService(CommonService): + model = TenantModelInstance + + @classmethod + @DB.connection_context() + def create_instance(cls, provider_id: str, instance_name: str, api_key: str, extra: str): + unique_instance_name = duplicate_name(cls.query, name_field="instance_name", provider_id=provider_id, instance_name=instance_name) + return cls.insert(id=get_uuid(), provider_id=provider_id, instance_name=unique_instance_name, api_key=api_key, extra=extra) + + @classmethod + @DB.connection_context() + def get_all_by_provider_id(cls, provider_id): + return list(cls.model.select().where(cls.model.provider_id == provider_id)) + + @classmethod + @DB.connection_context() + def get_by_provider_ids(cls, provider_ids): + return list(cls.model.select().where(cls.model.provider_id.in_(provider_ids))) + + @classmethod + @DB.connection_context() + def get_by_provider_id_and_instance_name(cls, provider_id, instance_name): + return cls.model.get_or_none( + cls.model.provider_id == provider_id, + cls.model.instance_name == instance_name, + ) + + @classmethod + @DB.connection_context() + def get_by_provider_id_and_api_key(cls, provider_id, api_key): + return cls.model.get_or_none( + cls.model.provider_id == provider_id, + cls.model.api_key == api_key + ) + + @classmethod + @DB.connection_context() + def delete_by_provider_id_and_instance_name(cls, provider_id, instance_name): + return cls.model.delete().where( + cls.model.provider_id == provider_id, + cls.model.instance_name == instance_name, + ).execute() + + @classmethod + @DB.connection_context() + def delete_by_provider_ids(cls, provider_ids): + return cls.model.delete().where( + cls.model.provider_id.in_(provider_ids) + ).execute() diff --git a/api/db/services/tenant_model_provider_service.py 
b/api/db/services/tenant_model_provider_service.py new file mode 100644 index 0000000000..1472175947 --- /dev/null +++ b/api/db/services/tenant_model_provider_service.py @@ -0,0 +1,52 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from api.db.db_models import DB, TenantModelProvider +from api.db.services.common_service import CommonService + + +class TenantModelProviderService(CommonService): + model = TenantModelProvider + + @classmethod + @DB.connection_context() + def get_by_tenant_id_and_provider_name(cls, tenant_id, provider_name): + return cls.model.get_or_none( + cls.model.tenant_id == tenant_id, + cls.model.provider_name == provider_name, + ) + + @classmethod + @DB.connection_context() + def get_by_tenant_id(cls, tenant_id): + return list(cls.model.select().where(cls.model.tenant_id == tenant_id)) + + @classmethod + @DB.connection_context() + def delete_by_tenant_id(cls, tenant_id): + return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute() + + @classmethod + @DB.connection_context() + def delete_by_tenant_id_and_provider_name(cls, tenant_id, provider_name): + return cls.model.delete().where( + cls.model.tenant_id == tenant_id, + cls.model.provider_name == provider_name, + ).execute() + + @classmethod + @DB.connection_context() + def list_provider_names_by_tenant_id(cls, tenant_id): + return [row.provider_name for row in 
cls.model.select(cls.model.provider_name).where(cls.model.tenant_id == tenant_id)] \ No newline at end of file diff --git a/api/db/services/tenant_model_service.py b/api/db/services/tenant_model_service.py new file mode 100644 index 0000000000..00aa0e659b --- /dev/null +++ b/api/db/services/tenant_model_service.py @@ -0,0 +1,61 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from api.db.db_models import DB, TenantModel +from api.db.services.common_service import CommonService + + +class TenantModelService(CommonService): + model = TenantModel + + @classmethod + @DB.connection_context() + def get_by_provider_id_and_instance_id_and_model_name(cls, provider_id, instance_id, model_name): + return list(cls.model.select().where(cls.model.provider_id == provider_id, cls.model.instance_id == instance_id, cls.model.model_name == model_name)) + + @classmethod + @DB.connection_context() + def get_by_provider_id_and_instance_id_and_model_type_and_model_name(cls, provider_id, instance_id, model_type, model_name): + return cls.model.get_or_none( + cls.model.provider_id == provider_id, + cls.model.instance_id == instance_id, + cls.model.model_type == model_type, + cls.model.model_name == model_name + ) + + @classmethod + @DB.connection_context() + def get_models_by_instance_id(cls, instance_id): + return list(cls.model.select().where(cls.model.instance_id == instance_id)) + + @classmethod + 
@DB.connection_context() + def get_models_by_provider_ids_and_instance_ids(cls, provider_ids, instance_ids): + return list(cls.model.select().where(cls.model.provider_id.in_(provider_ids), cls.model.instance_id.in_(instance_ids))) + + @classmethod + @DB.connection_context() + def batch_update_model_status(cls, model_ids, status): + return cls.model.update(status=status).where(cls.model.id.in_(model_ids)).execute() + + @classmethod + @DB.connection_context() + def delete_by_id(cls, model_id): + return cls.model.delete().where(cls.model.id == model_id).execute() + + @classmethod + @DB.connection_context() + def delete_by_instance_ids(cls, instance_ids): + return cls.model.delete().where(cls.model.instance_id.in_(instance_ids)).execute() diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 6804dbd445..d6b985dd47 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -191,6 +191,7 @@ class TenantService(CommonService): cls.model.asr_id, cls.model.img2txt_id, cls.model.tts_id, + cls.model.ocr_id, cls.model.parser_ids, UserTenant.role] return list(cls.model.select(*fields) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 8c6abca7d5..4712d9504f 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -41,7 +41,7 @@ except ImportError: # pragma: no cover - optional dependency from peewee import OperationalError -from common.constants import ActiveEnum +from common.constants import ActiveEnum, LLMType from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions @@ -576,8 +576,7 @@ def check_duplicate_ids(ids, id_type="item"): def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, str | None]: - from api.db.services.llm_service import LLMService - from api.db.services.tenant_llm_service import TenantLLMService + from 
api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance """ Verifies availability of an embedding model for a specific tenant. @@ -613,18 +612,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, s (False, {'code': 101, 'message': "Unsupported model: "}) """ try: - llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id) - in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")) - - tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id) - is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms) - - is_builtin_model = llm_factory == "Builtin" - if not (is_builtin_model or is_tenant_model or in_llm_service): - return False, f"Unsupported model: <{embd_id}>" - - if not (is_builtin_model or is_tenant_model): - return False, f"Unauthorized model: <{embd_id}>" + get_model_config_from_provider_instance(tenant_id, LLMType.EMBEDDING, embd_id) + except LookupError as e: + return False, str(e) except OperationalError as e: logging.exception(e) return False, "Database operation failed" diff --git a/api/utils/tenant_utils.py b/api/utils/tenant_utils.py deleted file mode 100644 index 80f75b6fd6..0000000000 --- a/api/utils/tenant_utils.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# -from common.constants import LLMType -from common.exceptions import ArgumentException -from api.db.services.tenant_llm_service import TenantLLMService - -_KEY_TO_MODEL_TYPE = { - "llm_id": LLMType.CHAT, - "embd_id": LLMType.EMBEDDING, - "asr_id": LLMType.SPEECH2TEXT, - "img2txt_id": LLMType.IMAGE2TEXT, - "rerank_id": LLMType.RERANK, - "tts_id": LLMType.TTS, -} - -def ensure_tenant_model_id_for_params(tenant_id: str, param_dict: dict, *, strict: bool = False) -> dict: - for key in ["llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"]: - if param_dict.get(key) and not param_dict.get(f"tenant_{key}"): - model_type = _KEY_TO_MODEL_TYPE.get(key) - tenant_model = TenantLLMService.get_api_key(tenant_id, param_dict[key], model_type) - if not tenant_model and model_type == LLMType.CHAT: - tenant_model = TenantLLMService.get_api_key(tenant_id, param_dict[key]) - if tenant_model: - param_dict.update({f"tenant_{key}": tenant_model.id}) - else: - if strict: - model_type_val = model_type.value if hasattr(model_type, "value") else model_type - raise ArgumentException( - f"Tenant Model with name {param_dict[key]} and type {model_type_val} not found" - ) - param_dict.update({f"tenant_{key}": 0}) - return param_dict diff --git a/common/constants.py b/common/constants.py index b222c4caf7..a766a21710 100644 --- a/common/constants.py +++ b/common/constants.py @@ -66,6 +66,11 @@ class StatusEnum(Enum): INVALID = "0" +class ActiveStatusEnum(Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + class ActiveEnum(Enum): ACTIVE = "1" INACTIVE = "0" diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 4a98f2ccc5..b8da976234 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -4376,21 +4376,21 @@ "llm_name": "Qwen3-Reranker-8B", "tags": "TEXT RE-RANK,32K", "max_tokens": 32768, - "model_type": "reranker", + "model_type": "rerank", 
"is_tools": false }, { "llm_name": "Qwen3-Reranker-4B", "tags": "TEXT RE-RANK,32K", "max_tokens": 32768, - "model_type": "reranker", + "model_type": "rerank", "is_tools": false }, { "llm_name": "Qwen3-Reranker-0.6B", "tags": "TEXT RE-RANK,32K", "max_tokens": 32768, - "model_type": "reranker", + "model_type": "rerank", "is_tools": false }, { @@ -4432,7 +4432,7 @@ "llm_name": "jina-reranker-m0", "tags": "TEXT RE-RANK,10K", "max_tokens": 10240, - "model_type": "reranker", + "model_type": "rerank", "is_tools": false }, { @@ -4446,7 +4446,7 @@ "llm_name": "bce-reranker-base_v1", "tags": "TEXT RE-RANK", "max_tokens": 512, - "model_type": "reranker", + "model_type": "rerank", "is_tools": false }, { @@ -4460,7 +4460,7 @@ "llm_name": "bge-reranker-v2-m3", "tags": "TEXT RE-RANK", "max_tokens": 8192, - "model_type": "reranker", + "model_type": "rerank", "is_tools": false }, { @@ -6158,13 +6158,13 @@ "llm_name": "baai/bge-reranker-v2-m3", "tags": "RE-RANK,8K", "max_tokens": 8000, - "model_type": "reranker" + "model_type": "rerank" }, { "llm_name": "qwen/qwen3-reranker-8b", "tags": "RE-RANK,32K", "max_tokens": 32768, - "model_type": "reranker" + "model_type": "rerank" } ] }, diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 6eba5825d6..03066a87e3 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -28,6 +28,7 @@ services: # Example configuration to start Admin server: command: - --enable-adminserver + - --init-model-provider-tables ports: - ${SVR_WEB_HTTP_PORT}:80 - ${SVR_WEB_HTTPS_PORT}:443 diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 99ae05fb6b..4584b7159e 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -16,6 +16,7 @@ function usage() { echo " --disable-datasync Disables synchronization of datasource workers." echo " --enable-mcpserver Enables the MCP server." echo " --enable-adminserver Enables the Admin server." + echo " --init-model-provider-tables Run model provider table migrations and exit." 
echo " --init-superuser Initializes the superuser." echo " --consumer-no-beg= Start range for consumers (if using range-based)." echo " --consumer-no-end= End range for consumers (if using range-based)." @@ -38,6 +39,7 @@ ENABLE_DATASYNC=1 ENABLE_MCP_SERVER=0 ENABLE_ADMIN_SERVER=0 # Default close admin server INIT_SUPERUSER_ARGS="" # Default to not initialize superuser +INIT_MODEL_PROVIDER_TABLES=0 CONSUMER_NO_BEG=0 CONSUMER_NO_END=0 WORKERS=1 @@ -89,6 +91,10 @@ for arg in "$@"; do ENABLE_ADMIN_SERVER=1 shift ;; + --init-model-provider-tables) + INIT_MODEL_PROVIDER_TABLES=1 + shift + ;; --init-superuser) INIT_SUPERUSER_ARGS="--init-superuser" shift @@ -266,6 +272,14 @@ function wait_for_server() { ensure_docling ensure_db_init +if [[ "${INIT_MODEL_PROVIDER_TABLES}" -eq 1 ]]; then + echo "Running model provider table migrations..." + "$PY" tools/scripts/mysql_migration.py --stages tenant_model_provider --config conf/service_conf.yaml --execute + "$PY" tools/scripts/mysql_migration.py --stages tenant_model_instance --config conf/service_conf.yaml --execute + "$PY" tools/scripts/mysql_migration.py --stages tenant_model --config conf/service_conf.yaml --execute + echo "Model provider table migrations completed." +fi + if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then echo "Starting nginx..." /usr/sbin/nginx @@ -280,7 +294,7 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then if [[ "${API_PROXY_SCHEME}" == "hybrid" ]]; then while true; do echo "Attempt to start RAGFlow go server..." - wait_for_server "http://127.0.0.1:9380/healthz" "ragflow_server" + wait_for_server "http://127.0.0.1:9380/api/v1/system/healthz" "ragflow_server" echo "Starting RAGFlow go server..." 
bin/server_main sleep 1; diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 004649c7c6..b78e4708da 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -263,7 +263,7 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName InstanceName: instanceName, ProviderID: provider.ID, APIKey: apiKey, - Status: "enable", + Status: "active", Extra: extraStr, } err = m.modelInstanceDAO.Create(tenantModelProvider) diff --git a/rag/app/naive.py b/rag/app/naive.py index 8b50464a5b..1dac71e107 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -31,7 +31,7 @@ from common.token_utils import num_tokens_from_string from common.constants import LLMType, MAXIMUM_PAGE_NUMBER from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html from deepdoc.parser import DocxParser, EpubParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_docx_wrapper_naive, vision_figure_parser_pdf_wrapper @@ -128,7 +128,7 @@ def by_mineru( if mineru_llm_name: try: - ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, mineru_llm_name) + ocr_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.OCR, mineru_llm_name) ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang) pdf_parser = ocr_model.mdl @@ -211,7 +211,7 @@ def by_opendataloader( if opendataloader_llm_name: try: - ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, 
opendataloader_llm_name) + ocr_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.OCR, opendataloader_llm_name) ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang) pdf_parser = ocr_model.mdl parse_options = {k: kwargs[k] for k in ("hybrid", "image_output", "sanitize") if k in kwargs} @@ -272,7 +272,7 @@ def by_paddleocr( if paddleocr_llm_name: try: - ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, paddleocr_llm_name) + ocr_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.OCR, paddleocr_llm_name) ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang) pdf_parser = ocr_model.mdl sections, tables = pdf_parser.parse_pdf( @@ -301,7 +301,7 @@ def by_plaintext(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER tenant_id = kwargs.get("tenant_id") if not tenant_id: raise ValueError("tenant_id is required when using vision layout recognizer") - vision_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.IMAGE2TEXT, layout_recognizer) + vision_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.IMAGE2TEXT, layout_recognizer) vision_model = LLMBundle( tenant_id, model_config=vision_model_config, diff --git a/rag/benchmark.py b/rag/benchmark.py index 0e3b256f87..9b069c1d3a 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -25,7 +25,7 @@ from common import settings from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance from common.misc_utils import get_uuid from rag.nlp import tokenize, search from ranx import evaluate @@ -43,10 +43,7 @@ class Benchmark: e, self.kb = 
KnowledgebaseService.get_by_id(kb_id) self.similarity_threshold = self.kb.similarity_threshold self.vector_similarity_weight = self.kb.vector_similarity_weight - if self.kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(self.kb.tenant_embd_id) - else: - embd_model_config = get_model_config_by_type_and_name(self.kb.tenant_id, LLMType.EMBEDDING, self.kb.embd_id) + embd_model_config = get_model_config_from_provider_instance(self.kb.tenant_id, LLMType.EMBEDDING, self.kb.embd_id) self.embd_mdl = LLMBundle(self.kb.tenant_id, embd_model_config, lang=self.kb.language) self.tenant_id = '' self.index_name = '' diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index e4f66a554f..8721f2b222 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -27,7 +27,7 @@ from PIL import Image from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common import settings from common.constants import LLMType from common.misc_utils import get_uuid, thread_pool_exec @@ -381,7 +381,7 @@ class Parser(ProcessBase): raise RuntimeError("MinerU model not configured. Please add MinerU in Model Providers or set MINERU_* env.") tenant_id = self._canvas._tenant_id - ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, parser_model_name) + ocr_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.OCR, parser_model_name) ocr_model = LLMBundle(tenant_id, ocr_model_config, lang=conf.get("lang", "Chinese")) pdf_parser = ocr_model.mdl @@ -458,7 +458,7 @@ class Parser(ProcessBase): raise RuntimeError("OpenDataLoader model not configured. 
Please add OpenDataLoader in Model Providers.") tenant_id = self._canvas._tenant_id - ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, parser_model_name) + ocr_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.OCR, parser_model_name) ocr_model = LLMBundle(tenant_id, ocr_model_config) pdf_parser = ocr_model.mdl @@ -563,7 +563,7 @@ class Parser(ProcessBase): raise RuntimeError("PaddleOCR model not configured. Please add PaddleOCR in Model Providers or set PADDLEOCR_* env.") tenant_id = self._canvas._tenant_id - ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, parser_model_name) + ocr_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.OCR, parser_model_name) ocr_model = LLMBundle(tenant_id, ocr_model_config) pdf_parser = ocr_model.mdl @@ -593,7 +593,7 @@ class Parser(ProcessBase): # Vision parser treats each page as a large image block. else: if conf.get("parse_method"): - vision_model_config = get_model_config_by_type_and_name(self._canvas._tenant_id, LLMType.IMAGE2TEXT, conf["parse_method"]) + vision_model_config = get_model_config_from_provider_instance(self._canvas._tenant_id, LLMType.IMAGE2TEXT, conf["parse_method"]) else: vision_model_config = get_tenant_default_model_by_type(self._canvas._tenant_id, LLMType.IMAGE2TEXT) vision_model = LLMBundle(self._canvas._tenant_id, vision_model_config, lang=self._param.setups["pdf"].get("lang")) @@ -1067,7 +1067,7 @@ class Parser(ProcessBase): else: lang = conf["lang"] # use VLM to describe the picture - cv_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, conf["parse_method"]) + cv_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, conf["parse_method"]) cv_model = LLMBundle(self._canvas.get_tenant_id(), cv_model_config, lang=lang) img_binary = io.BytesIO() img.save(img_binary, format="JPEG") @@ -1103,7 +1103,7 @@ class 
Parser(ProcessBase): tmpf.write(blob) tmpf.flush() tmp_path = os.path.abspath(tmpf.name) - seq2txt_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, vlm["llm_id"]) + seq2txt_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, vlm["llm_id"]) seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), seq2txt_model_config) txt = seq2txt_mdl.transcription(tmp_path) @@ -1116,7 +1116,7 @@ class Parser(ProcessBase): conf = self._param.setups["video"] vlm = conf.get("vlm") self.set_output("output_format", conf["output_format"]) - cv_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, vlm["llm_id"]) + cv_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, vlm["llm_id"]) cv_mdl = LLMBundle(self._canvas.get_tenant_id(), cv_model_config) video_prompt = str(conf.get("prompt", "") or "") txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name, video_prompt=video_prompt)) diff --git a/rag/flow/parser/utils.py b/rag/flow/parser/utils.py index 93bc1652a8..5246acd2fd 100644 --- a/rag/flow/parser/utils.py +++ b/rag/flow/parser/utils.py @@ -20,8 +20,7 @@ from bs4 import BeautifulSoup from docx import Document from api.db.services.llm_service import LLMBundle from api.db.joint_services.tenant_model_service import ( - get_model_config_by_type_and_name, - get_tenant_default_model_by_type, + get_tenant_default_model_by_type, get_model_config_from_provider_instance, ) from common.constants import LLMType from deepdoc.parser.figure_parser import VisionFigureParser @@ -173,7 +172,7 @@ def enhance_media_sections_with_vision( try: try: - vision_model_config = get_model_config_by_type_and_name( + vision_model_config = get_model_config_from_provider_instance( tenant_id, LLMType.IMAGE2TEXT, vlm_conf["llm_id"] ) except Exception: diff --git 
a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index e4be09f005..cdf5ed0071 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -21,7 +21,7 @@ import numpy as np from common.constants import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.connection_utils import timeout from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.parser.pdf_chunk_metadata import finalize_pdf_chunk @@ -60,10 +60,7 @@ class Tokenizer(ProcessBase): token_count = 0 if self._canvas._kb_id: e, kb = KnowledgebaseService.get_by_id(self._canvas._kb_id) - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - else: - embd_model_config = get_model_config_by_type_and_name(self._canvas._tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = get_model_config_from_provider_instance(self._canvas._tenant_id, LLMType.EMBEDDING, kb.embd_id) else: embd_model_config = get_tenant_default_model_by_type(self._canvas._tenant_id, LLMType.EMBEDDING) embedding_model = LLMBundle(self._canvas._tenant_id, embd_model_config) diff --git a/rag/graphrag/general/smoke.py b/rag/graphrag/general/smoke.py index 02c1ab5cf6..7c8ee2de18 100644 --- a/rag/graphrag/general/smoke.py +++ b/rag/graphrag/general/smoke.py @@ -24,7 +24,7 @@ from common.constants import LLMType from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, 
get_model_config_by_id, get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from rag.graphrag.general.graph_extractor import GraphExtractor from rag.graphrag.general.index import update_graph, with_resolution, with_community from common import settings @@ -74,10 +74,7 @@ async def main(): llm_config = get_tenant_default_model_by_type(args.tenant_id, LLMType.CHAT) llm_bdl = LLMBundle(args.tenant_id, llm_config) _, kb = KnowledgebaseService.get_by_id(kb_id) - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - else: - embd_model_config = get_model_config_by_type_and_name(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = get_model_config_from_provider_instance(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, embd_model_config) graph, doc_ids = await update_graph( diff --git a/rag/graphrag/light/smoke.py b/rag/graphrag/light/smoke.py index 18af251518..4eeb254a6a 100644 --- a/rag/graphrag/light/smoke.py +++ b/rag/graphrag/light/smoke.py @@ -24,7 +24,7 @@ from common.constants import LLMType from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from rag.graphrag.general.index import update_graph from rag.graphrag.light.graph_extractor import GraphExtractor from common import settings @@ -75,10 +75,7 @@ async def main(): llm_config = get_tenant_default_model_by_type(args.tenant_id, LLMType.CHAT) llm_bdl = LLMBundle(args.tenant_id, llm_config) _, kb = 
KnowledgebaseService.get_by_id(kb_id) - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - else: - embd_model_config = get_model_config_by_type_and_name(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = get_model_config_from_provider_instance(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, embd_model_config) graph, doc_ids = await update_graph( diff --git a/rag/graphrag/search.py b/rag/graphrag/search.py index aec96ecf21..0c8a37a565 100644 --- a/rag/graphrag/search.py +++ b/rag/graphrag/search.py @@ -318,7 +318,7 @@ if __name__ == "__main__": from common.constants import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle - from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, get_model_config_by_type_and_name + from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from rag.nlp import search settings.init_settings() @@ -332,10 +332,7 @@ if __name__ == "__main__": llm_config = get_tenant_default_model_by_type(args.tenant_id, LLMType.CHAT) llm_bdl = LLMBundle(args.tenant_id, llm_config) _, kb = KnowledgebaseService.get_by_id(kb_id) - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - else: - embd_model_config = get_model_config_by_type_and_name(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = get_model_config_from_provider_instance(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, embd_model_config) kg = KGSearch(settings.docStoreConn) diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 83e277d4b8..83309cc76f 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -22,6 +22,7 @@ from copy import deepcopy from typing import Tuple from 
jinja2.sandbox import SandboxedEnvironment import json_repair + from common.misc_utils import hash_str2int from rag.nlp import rag_tokenizer from rag.prompts.template import load_prompt @@ -244,14 +245,14 @@ async def question_proposal(chat_mdl, content, topn=3): async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): from common.constants import LLMType from api.db.services.llm_service import LLMBundle - from api.db.services.tenant_llm_service import TenantLLMService - from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name + from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name if not chat_mdl: - if TenantLLMService.llm_id2llm_type(llm_id) == "image2text": - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.IMAGE2TEXT, llm_id) + model_types = get_model_type_by_name(tenant_id, llm_id) + if "image2text" in model_types: + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.IMAGE2TEXT, llm_id) else: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, llm_id) chat_mdl = LLMBundle(tenant_id, chat_model_config) conv = [] for m in messages: @@ -280,16 +281,15 @@ async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, async def cross_languages(tenant_id, llm_id, query, languages=[]): from common.constants import LLMType from api.db.services.llm_service import LLMBundle - from api.db.services.tenant_llm_service import TenantLLMService - from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type + from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_tenant_default_model_by_type, get_model_type_by_name - if llm_id and 
TenantLLMService.llm_id2llm_type(llm_id) == "image2text": - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.IMAGE2TEXT, llm_id) + if llm_id and "image2text" in get_model_type_by_name(tenant_id, llm_id) : + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.IMAGE2TEXT, llm_id) else: if not llm_id: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) else: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, llm_id) chat_mdl = LLMBundle(tenant_id, chat_model_config) rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render() rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index ded3a0141a..3a4290d579 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -79,7 +79,7 @@ from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID from api.db.services.file2document_service import File2DocumentService -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.versions import get_ragflow_version from api.db.db_models import close_connection from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ @@ -423,7 +423,7 @@ async def build_chunks(task, progress_callback): if task["parser_config"].get("auto_keywords", 0): st = timer() progress_callback(msg="Start to generate 
keywords for every chunk ...") - chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) + chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"]) chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) async def doc_keyword_extraction(chat_mdl, d, topn): @@ -461,7 +461,7 @@ async def build_chunks(task, progress_callback): if task["parser_config"].get("auto_questions", 0): st = timer() progress_callback(msg="Start to generate questions for every chunk ...") - chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) + chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"]) chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) async def doc_question_proposal(chat_mdl, d, topn): @@ -498,7 +498,7 @@ async def build_chunks(task, progress_callback): if task["parser_config"].get("enable_metadata", False) and (task["parser_config"].get("metadata") or task["parser_config"].get("built_in_metadata")): st = timer() progress_callback(msg="Start to generate meta-data for every chunk ...") - chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) + chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"]) chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) async def gen_metadata_task(chat_mdl, d): @@ -575,7 +575,7 @@ async def build_chunks(task, progress_callback): set_tags_to_cache(kb_ids, all_tags) else: all_tags = json.loads(all_tags) - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, task["llm_id"]) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, task["llm_id"]) chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, 
lang=task["language"]) docs_to_tag = [] @@ -641,7 +641,7 @@ async def build_chunks(task, progress_callback): @timed_with_recording def build_TOC(task, docs, progress_callback): progress_callback(msg="Start to generate table of content ...") - chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) + chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"]) chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) docs = sorted(docs, key=lambda d: ( d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), @@ -817,7 +817,7 @@ async def run_dataflow(task: dict): set_progress(task_id, prog=0.82, msg="\n-------------------------------------\nStart to embedding...") e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) embedding_id = kb.embd_id - embd_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.EMBEDDING, embedding_id) + embd_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.EMBEDDING, embedding_id) embedding_model = LLMBundle(task["tenant_id"], embd_model_config) @timeout(60) @@ -1404,7 +1404,7 @@ async def do_handle_task(task): try: # bind embedding model if task_embedding_id: - embd_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.EMBEDDING, task_embedding_id) + embd_model_config = get_model_config_from_provider_instance(task_tenant_id, LLMType.EMBEDDING, task_embedding_id) else: embd_model_config = get_tenant_default_model_by_type(task_tenant_id, LLMType.EMBEDDING) embedding_model = LLMBundle(task_tenant_id, embd_model_config, lang=task_language) @@ -1452,7 +1452,7 @@ async def do_handle_task(task): return # bind LLM for raptor - chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id) + chat_model_config = get_model_config_from_provider_instance(task_tenant_id, 
LLMType.CHAT, kb_task_llm_id) chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language) # run RAPTOR async with kg_limiter: @@ -1511,7 +1511,7 @@ async def do_handle_task(task): graphrag_conf = kb_parser_config.get("graphrag", {}) start_ts = timer() - chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id) + chat_model_config = get_model_config_from_provider_instance(task_tenant_id, LLMType.CHAT, kb_task_llm_id) chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language) with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) diff --git a/rag/svr/task_executor_refactor/chunk_post_processor.py b/rag/svr/task_executor_refactor/chunk_post_processor.py index fc12453357..5677c9bf53 100644 --- a/rag/svr/task_executor_refactor/chunk_post_processor.py +++ b/rag/svr/task_executor_refactor/chunk_post_processor.py @@ -39,7 +39,7 @@ from rag.svr.task_executor_refactor.task_context import TaskContext from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.llm_service import LLMBundle -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance from rag.prompts.generator import gen_metadata, keyword_extraction, question_proposal, content_tagging from rag.graphrag.utils import get_llm_cache, set_llm_cache @@ -55,7 +55,7 @@ async def extract_keywords(docs: List[Dict], ctx: TaskContext) -> None: st = timer() ctx.progress_cb(msg="Start to generate keywords for every chunk ...") - chat_model_config = get_model_config_by_type_and_name(ctx.tenant_id, LLMType.CHAT, ctx.llm_id) + chat_model_config = get_model_config_from_provider_instance(ctx.tenant_id, LLMType.CHAT, ctx.llm_id) with LLMBundle(ctx.tenant_id, chat_model_config, lang=ctx.language) as chat_model: async def 
doc_keyword_extraction(chat_mdl, d, topn): @@ -98,7 +98,7 @@ async def generate_questions(docs: List[Dict], ctx: TaskContext) -> None: st = timer() ctx.progress_cb(msg="Start to generate questions for every chunk ...") - chat_model_config = get_model_config_by_type_and_name(ctx.tenant_id, LLMType.CHAT, ctx.llm_id) + chat_model_config = get_model_config_from_provider_instance(ctx.tenant_id, LLMType.CHAT, ctx.llm_id) with LLMBundle(ctx.tenant_id, chat_model_config, lang=ctx.language) as chat_model: async def doc_question_proposal(chat_mdl, d, topn): @@ -179,7 +179,7 @@ async def generate_metadata(docs: List[Dict], ctx: TaskContext) -> None: st = timer() ctx.progress_cb(msg="Start to generate meta-data for every chunk ...") - chat_model_config = get_model_config_by_type_and_name(ctx.tenant_id, LLMType.CHAT, ctx.llm_id) + chat_model_config = get_model_config_from_provider_instance(ctx.tenant_id, LLMType.CHAT, ctx.llm_id) with LLMBundle(ctx.tenant_id, chat_model_config, lang=ctx.language) as chat_model: metadata_conf = build_metadata_config(ctx.parser_config) @@ -244,7 +244,7 @@ async def apply_tags(docs: List[Dict], ctx: TaskContext) -> None: st = timer() examples = [] all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S) - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, ctx.llm_id) + chat_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, ctx.llm_id) with LLMBundle(ctx.tenant_id, chat_model_config, lang=ctx.language) as chat_model: docs_to_tag = [] diff --git a/rag/svr/task_executor_refactor/dataflow_service.py b/rag/svr/task_executor_refactor/dataflow_service.py index 006c6edff3..d67c6cb266 100644 --- a/rag/svr/task_executor_refactor/dataflow_service.py +++ b/rag/svr/task_executor_refactor/dataflow_service.py @@ -38,7 +38,7 @@ from api.db.services.canvas_service import UserCanvasService from api.db.services.document_service import DocumentService from api.db.services.doc_metadata_service 
import DocMetadataService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name +from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance from common.constants import LLMType, PipelineTaskType from common.metadata_utils import update_metadata_to from common.misc_utils import thread_pool_exec @@ -244,7 +244,7 @@ class DataflowService: self._progress(prog=0.82, msg="\n-------------------------------------\nStart to embedding...") e, kb = self._get_kb_by_id(ctx.kb_id) embedding_id = kb.embd_id - embd_model_config = get_model_config_by_type_and_name( + embd_model_config = get_model_config_from_provider_instance( ctx.tenant_id, LLMType.EMBEDDING, embedding_id ) from api.db.services.llm_service import LLMBundle diff --git a/rag/svr/task_executor_refactor/task_handler.py b/rag/svr/task_executor_refactor/task_handler.py index deee1b4b36..ebcd92c039 100644 --- a/rag/svr/task_executor_refactor/task_handler.py +++ b/rag/svr/task_executor_refactor/task_handler.py @@ -32,8 +32,8 @@ from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.joint_services.memory_message_service import handle_save_to_memory_task from api.db.joint_services.tenant_model_service import ( - get_model_config_by_type_and_name, get_tenant_default_model_by_type, + get_model_config_from_provider_instance ) from api.db.services.llm_service import LLMBundle from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID @@ -212,7 +212,7 @@ class TaskHandler: try: if task_embedding_id: - embd_model_config = get_model_config_by_type_and_name( + embd_model_config = get_model_config_from_provider_instance( task_tenant_id, LLMType.EMBEDDING, task_embedding_id ) else: @@ -269,7 +269,7 @@ class TaskHandler: return # Bind LLM for raptor - chat_model_config = 
get_model_config_by_type_and_name( + chat_model_config = get_model_config_from_provider_instance( task_tenant_id, LLMType.CHAT, kb_task_llm_id ) with LLMBundle(task_tenant_id, chat_model_config, lang=ctx.language) as chat_model: @@ -354,7 +354,7 @@ class TaskHandler: graphrag_conf = kb_parser_config.get("graphrag", {}) start_ts = timer() - chat_model_config = get_model_config_by_type_and_name( + chat_model_config = get_model_config_from_provider_instance( task_tenant_id, LLMType.CHAT, kb_task_llm_id ) with LLMBundle(task_tenant_id, chat_model_config, lang=task_language) as chat_model: @@ -512,7 +512,7 @@ class TaskHandler: def _build_toc(cls, ctx: TaskContext, docs: List[Dict], progress_cb: Callable) -> Optional[Dict]: """Build table of contents.""" progress_cb(msg="Start to generate table of content ...") - chat_model_config = get_model_config_by_type_and_name( + chat_model_config = get_model_config_from_provider_instance( ctx.tenant_id, LLMType.CHAT, ctx.llm_id ) with LLMBundle(ctx.tenant_id, chat_model_config, lang=ctx.language) as chat_mdl: diff --git a/sdk/python/test/conftest.py b/sdk/python/test/conftest.py index 682a715923..ac21be8d1b 100644 --- a/sdk/python/test/conftest.py +++ b/sdk/python/test/conftest.py @@ -92,6 +92,7 @@ def get_email(): def get_my_llms(auth, name): + # todo deprecated url = HOST_ADDRESS + "/v1/llm/my_llms" authorization = {"Authorization": auth} response = requests.get(url=url, headers=authorization) @@ -103,7 +104,20 @@ def get_my_llms(auth, name): return False +def get_added_models(auth, factory_name): + url = HOST_ADDRESS + "/api/v1/models" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + added_factory = {model["provider_name"] for model in res.get("data", [])} + if factory_name in added_factory: + return True + return False + def add_models(auth): + # todo deprecated url = HOST_ADDRESS + 
"/v1/llm/set_api_key" authorization = {"Authorization": auth} models_info = { @@ -118,7 +132,32 @@ def add_models(auth): pytest.exit(f"Critical error in add_models: {res.get('message')}") +def add_model_instance(auth): + add_provider_api = HOST_ADDRESS + "/api/v1/providers" + authorization = {"Authorization": auth} + add_provider_response = requests.put(url=add_provider_api, headers=authorization, json={"provider_name": "ZHIPU-AI"}) + add_provider_res = add_provider_response.json() + if add_provider_res.get("code") != 0: + pytest.exit(f"Critical error in add model provider: {add_provider_res.get('message')}") + + add_instance_api = HOST_ADDRESS + "/api/v1/providers/ZHIPU-AI/instances" + add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={ + "instance_name": "CI", + "api_key": ZHIPU_AI_API_KEY, + "region": "default", + "base_url": "" + }) + add_instance_res = add_instance_response.json() + if add_instance_res.get("code") != 0: + pytest.exit(f"Critical error in add model instance: {add_instance_res.get('message')}") + + add_success = get_added_models(auth, "ZHIPU-AI") + if not add_success: + pytest.exit("Critical error in check added model: add model failed") + + def get_tenant_info(auth): + # todo deprecated url = HOST_ADDRESS + "/api/v1/users/me/models" authorization = {"Authorization": auth} response = requests.get(url=url, headers=authorization) @@ -131,22 +170,49 @@ def get_tenant_info(auth): @pytest.fixture(scope="session", autouse=True) def set_tenant_info(get_auth): auth = get_auth - try: - add_models(auth) - tenant_id = get_tenant_info(auth) - except Exception as e: - pytest.exit(f"Error in set_tenant_info: {str(e)}") - url = HOST_ADDRESS + "/api/v1/users/me/models" + if not get_added_models(auth, "ZHIPU-AI"): + try: + add_model_instance(auth) + except Exception as e: + pytest.exit(f"Error in set_tenant_info: {str(e)}") + url = HOST_ADDRESS + "/api/v1/models/default" authorization = {"Authorization": get_auth} - 
tenant_info = { - "tenant_id": tenant_id, - "llm_id": "glm-4-flash@ZHIPU-AI", - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "img2txt_id": "glm-4v@ZHIPU-AI", - "asr_id": "", - "tts_id": None, - } - response = requests.patch(url=url, headers=authorization, json=tenant_info) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) + # set chat model + set_default_llm_response = requests.patch( + url=url, + headers=authorization, + json={ + "model_provider": "ZHIPU-AI", + "model_instance": "CI", + "model_type": "chat", + "model_name": "glm-4-flash" + }) + llm_res = set_default_llm_response.json() + if llm_res.get("code") != 0: + raise Exception(llm_res.get("message")) + # set embedding model + set_default_embedding_response = requests.patch( + url=url, + headers=authorization, + json={ + "model_provider": "Builtin", + "model_instance": "Local", + "model_type": "embedding", + "model_name": "BAAI/bge-small-en-v1.5" + }) + embd_res = set_default_embedding_response.json() + if embd_res.get("code") != 0: + raise Exception(embd_res.get("message")) + # set image to text model + set_default_img2txt_response = requests.patch( + url=url, + headers=authorization, + json={ + "model_provider": "ZHIPU-AI", + "model_instance": "CI", + "model_type": "vision", + "model_name": "glm-4v" + }) + img2txt_res = set_default_img2txt_response.json() + if img2txt_res.get("code") != 0: + raise Exception(img2txt_res.get("message")) diff --git a/test/testcases/configs.py b/test/testcases/configs.py index 3f093bd814..d6428d5527 100644 --- a/test/testcases/configs.py +++ b/test/testcases/configs.py @@ -46,7 +46,7 @@ DEFAULT_PARSER_CONFIG = { "image_context_size": 0, "table_context_size": 0, "topn_tags": 3, - "llm_id": "glm-4-flash@ZHIPU-AI", + "llm_id": "glm-4-flash@CI@ZHIPU-AI", "raptor": { "use_raptor": True, "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. 
Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.", diff --git a/test/testcases/conftest.py b/test/testcases/conftest.py index a4de7aebc8..41e9fc62c3 100644 --- a/test/testcases/conftest.py +++ b/test/testcases/conftest.py @@ -171,6 +171,7 @@ def token(auth): def get_my_llms(auth, name): + # todo deprecated url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms" authorization = {"Authorization": auth} response = requests.get(url=url, headers=authorization) @@ -182,7 +183,21 @@ def get_my_llms(auth, name): return False +def get_added_models(auth, factory_name): + url = HOST_ADDRESS + "/api/v1/models" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + added_factory = {model["provider_name"] for model in res.get("data", [])} + if factory_name in added_factory: + return True + return False + + def add_models(auth): + # todo deprecated url = HOST_ADDRESS + f"/{VERSION}/llm/set_api_key" authorization = {"Authorization": auth} models_info = { @@ -197,7 +212,32 @@ def add_models(auth): pytest.exit(f"Critical error in add_models: {res.get('message')}") +def add_model_instance(auth): + add_provider_api = HOST_ADDRESS + "/api/v1/providers" + authorization = {"Authorization": auth} + add_provider_response = requests.put(url=add_provider_api, headers=authorization, json={"provider_name": "ZHIPU-AI"}) + add_provider_res = add_provider_response.json() + if add_provider_res.get("code") != 0: + pytest.exit(f"Critical error in add model provider: {add_provider_res.get('message')}") + + add_instance_api = HOST_ADDRESS + "/api/v1/providers/ZHIPU-AI/instances" + add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={ + "instance_name": "CI", + "api_key": ZHIPU_AI_API_KEY, + "region": "default", + "base_url": "" + }) + add_instance_res = add_instance_response.json() + if 
add_instance_res.get("code") != 0: + pytest.exit(f"Critical error in add model instance: {add_instance_res.get('message')}") + + add_success = get_added_models(auth, "ZHIPU-AI") + if not add_success: + pytest.exit("Critical error in check added model: add model failed") + + def get_tenant_info(auth): + # todo deprecated url = HOST_ADDRESS + f"/api/{VERSION}/users/me/models" authorization = {"Authorization": auth} response = requests.get(url=url, headers=authorization) @@ -209,23 +249,36 @@ def get_tenant_info(auth): @pytest.fixture(scope="session", autouse=True) def set_tenant_info(auth): - tenant_id = None - try: - add_models(auth) - tenant_id = get_tenant_info(auth) - except Exception as e: - pytest.exit(f"Error in set_tenant_info: {str(e)}") - url = HOST_ADDRESS + f"/api/{VERSION}/users/me/models" + if not get_added_models(auth, "ZHIPU-AI"): + try: + add_model_instance(auth) + except Exception as e: + pytest.exit(f"Error in set_tenant_info: {str(e)}") + url = HOST_ADDRESS + "/api/v1/models/default" authorization = {"Authorization": auth} - tenant_info = { - "tenant_id": tenant_id, - "llm_id": "glm-4-flash@ZHIPU-AI", - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "img2txt_id": "", - "asr_id": "", - "tts_id": None, - } - response = requests.patch(url=url, headers=authorization, json=tenant_info) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) + # set chat model + set_default_llm_response = requests.patch( + url=url, + headers=authorization, + json={ + "model_provider": "ZHIPU-AI", + "model_instance": "CI", + "model_type": "chat", + "model_name": "glm-4-flash" + }) + llm_res = set_default_llm_response.json() + if llm_res.get("code") != 0: + raise Exception(llm_res.get("message")) + # set embedding model + set_default_embedding_response = requests.patch( + url=url, + headers=authorization, + json={ + "model_provider": "Builtin", + "model_instance": "Local", + "model_type": "embedding", + "model_name": 
"BAAI/bge-small-en-v1.5" + }) + embd_res = set_default_embedding_response.json() + if embd_res.get("code") != 0: + raise Exception(embd_res.get("message")) diff --git a/test/testcases/restful_api/test_chats.py b/test/testcases/restful_api/test_chats.py index eaf94a13c9..a097c1c1fe 100644 --- a/test/testcases/restful_api/test_chats.py +++ b/test/testcases/restful_api/test_chats.py @@ -753,26 +753,6 @@ def _load_chat_routes_unit_module(monkeypatch): kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) - tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") - - class _StubTenantLLMService: - @staticmethod - def split_model_name_and_factory(model_name): - if model_name and "@" in model_name: - return tuple(model_name.split("@", 1)) - return model_name, None - - @staticmethod - def query(**_kwargs): - return [] - - @staticmethod - def get_api_key(*_args, **_kwargs): - return SimpleNamespace(id=1) - - tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService - monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) - llm_service_mod = ModuleType("api.db.services.llm_service") llm_service_mod.LLMBundle = lambda *_args, **_kwargs: None monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) @@ -782,8 +762,10 @@ def _load_chat_routes_unit_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") - tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {} + tenant_model_service_mod.get_model_config_from_provider_instance = lambda *_args, **_kwargs: {} tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} + tenant_model_service_mod.get_api_key = lambda *_args, **_kwargs: 
SimpleNamespace(id=1) + tenant_model_service_mod.split_model_name = lambda model: (model.split("@")[0],"default", "factory") monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) user_service_mod = ModuleType("api.db.services.user_service") @@ -832,10 +814,6 @@ def _load_chat_routes_unit_module(monkeypatch): api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) - tenant_utils_mod = ModuleType("api.utils.tenant_utils") - tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req - monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) - rag_pkg = ModuleType("rag") rag_pkg.__path__ = [str(repo_root / "rag")] monkeypatch.setitem(sys.modules, "rag", rag_pkg) @@ -1151,7 +1129,7 @@ def test_chat_create_accepts_provider_scoped_rerank_id_unit(monkeypatch): "name": "chat-a", "icon": "icon.png", "dataset_ids": ["kb-1"], - "llm_id": "glm-4@ZHIPU-AI", + "llm_id": "glm-4@@CI@ZHIPU-AI", "llm_setting": {"temperature": 0.8}, "prompt_config": { "system": "Answer with {knowledge}", @@ -1162,7 +1140,7 @@ def test_chat_create_accepts_provider_scoped_rerank_id_unit(monkeypatch): "vector_similarity_weight": 0.25, }, ) - monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4@ZHIPU-AI"))) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4@CI@ZHIPU-AI"))) monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) @@ -1170,30 +1148,18 @@ def test_chat_create_accepts_provider_scoped_rerank_id_unit(monkeypatch): def _split_model_name_and_factory(model_name): return { - "glm-4@ZHIPU-AI": 
("glm-4", "ZHIPU-AI"), - "custom-reranker@OpenAI": ("custom-reranker", "OpenAI"), + "glm-4@ZHIPU-AI": ("glm-4", "default", "ZHIPU-AI"), + "glm-4@CI@ZHIPU-AI": ("glm-4", "CI", "ZHIPU-AI"), + "custom-reranker@OpenAI": ("custom-reranker", "default", "OpenAI") }.get(model_name, (model_name, None)) - def _query(**kwargs): - query_calls.append(kwargs) - if kwargs == { - "tenant_id": "tenant-1", - "llm_name": "glm-4", - "llm_factory": "ZHIPU-AI", - "model_type": "chat", - }: - return [SimpleNamespace(id="llm-1")] - if kwargs == { - "tenant_id": "tenant-1", - "llm_name": "custom-reranker", - "llm_factory": "OpenAI", - "model_type": "rerank", - }: - return [SimpleNamespace(id="rerank-1")] - return [] + monkeypatch.setattr(module, "split_model_name", _split_model_name_and_factory) - monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", _split_model_name_and_factory) - monkeypatch.setattr(module.TenantLLMService, "query", _query) + def _get_model_config_from_provider_instance(**kwargs): + query_calls.append(kwargs) + return {} + + monkeypatch.setattr(module, "get_model_config_from_provider_instance", _get_model_config_from_provider_instance) def _save(**kwargs): saved.update(kwargs) @@ -1207,8 +1173,7 @@ def test_chat_create_accepts_provider_scoped_rerank_id_unit(monkeypatch): assert saved["rerank_id"] == "custom-reranker@OpenAI" assert { "tenant_id": "tenant-1", - "llm_name": "custom-reranker", - "llm_factory": "OpenAI", + "model_name": "custom-reranker@OpenAI", "model_type": "rerank", } in query_calls @@ -1220,7 +1185,7 @@ def test_chat_create_allows_default_knowledge_placeholder_without_sources_unit(m _set_route_unit_request_json(monkeypatch, module, {"name": "chat-a"}) monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda *_args, **_kwargs: 
SimpleNamespace(id=1)) + monkeypatch.setattr(module, "get_api_key", lambda *_args, **_kwargs: SimpleNamespace(id=1)) def _save(**kwargs): saved.update(kwargs) @@ -1262,8 +1227,7 @@ def test_chat_create_uses_direct_chat_fields_unit(monkeypatch): monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) - monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + monkeypatch.setattr(module, "split_model_name", lambda model: (model.split("@")[0],"default", "factory")) def _save(**kwargs): saved.update(kwargs) @@ -1407,8 +1371,8 @@ def test_patch_chat_drops_response_only_fields_before_update_unit(monkeypatch): monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) - monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) - monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + monkeypatch.setattr(module, "split_model_name", lambda model: (model.split("@")[0],"default", "factory")) + monkeypatch.setattr(module, "get_api_key", lambda *args, **kwargs: SimpleNamespace(id=1)) def _update(_chat_id, req): updated.update(req) @@ -1477,8 +1441,7 @@ def test_update_chat_allows_knowledge_placeholder_without_sources_unit(monkeypat monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: 
[SimpleNamespace(id="chat-1")]) monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing))) monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) - monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) - monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + monkeypatch.setattr(module, "split_model_name", lambda model: (model.split("@")[0], "default", "factory")) updated = {} def _update(_chat_id, payload): @@ -1534,35 +1497,35 @@ def test_chat_create_avatar_contract(rest_client, clear_chats, tmp_path): def test_chat_create_llm_contract(rest_client, clear_chats, ensure_parsed_document): dataset_id, _ = ensure_parsed_document() cases = [ - ("default llm", {}, 0, "", "glm-4-flash@ZHIPU-AI", {}), - ("explicit llm_id", {"llm_id": "glm-4"}, 0, "", "glm-4", {}), + ("default llm", {}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {}), + ("explicit llm_id", {"llm_id": "glm-4"}, 102, "`llm_id` glm-4 doesn't exist", None, None), ("unknown llm_id", {"llm_id": "unknown"}, 102, "`llm_id` unknown doesn't exist", None, None), - ("temperature zero", {"llm_setting": {"temperature": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 0}), - ("temperature one", {"llm_setting": {"temperature": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 1}), - ("temperature negative one", {"llm_setting": {"temperature": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": -1}), - ("temperature ten", {"llm_setting": {"temperature": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 10}), - ("temperature string", {"llm_setting": {"temperature": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": "a"}), - ("top_p zero", {"llm_setting": {"top_p": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 0}), - ("top_p one", {"llm_setting": {"top_p": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 1}), 
- ("top_p negative one", {"llm_setting": {"top_p": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": -1}), - ("top_p ten", {"llm_setting": {"top_p": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 10}), - ("top_p string", {"llm_setting": {"top_p": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": "a"}), - ("presence_penalty zero", {"llm_setting": {"presence_penalty": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 0}), - ("presence_penalty one", {"llm_setting": {"presence_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 1}), - ("presence_penalty negative one", {"llm_setting": {"presence_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": -1}), - ("presence_penalty ten", {"llm_setting": {"presence_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 10}), - ("presence_penalty string", {"llm_setting": {"presence_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": "a"}), - ("frequency_penalty zero", {"llm_setting": {"frequency_penalty": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 0}), - ("frequency_penalty one", {"llm_setting": {"frequency_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 1}), - ("frequency_penalty negative one", {"llm_setting": {"frequency_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": -1}), - ("frequency_penalty ten", {"llm_setting": {"frequency_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 10}), - ("frequency_penalty string", {"llm_setting": {"frequency_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": "a"}), - ("max_token zero", {"llm_setting": {"max_token": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 0}), - ("max_token 1024", {"llm_setting": {"max_token": 1024}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 1024}), - ("max_token negative one", {"llm_setting": {"max_token": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": -1}), - ("max_token ten", {"llm_setting": {"max_token": 
10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 10}), - ("max_token string", {"llm_setting": {"max_token": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": "a"}), - ("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"unknown": "unknown"}), + ("temperature zero", {"llm_setting": {"temperature": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": 0}), + ("temperature one", {"llm_setting": {"temperature": 1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": 1}), + ("temperature negative one", {"llm_setting": {"temperature": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": -1}), + ("temperature ten", {"llm_setting": {"temperature": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": 10}), + ("temperature string", {"llm_setting": {"temperature": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": "a"}), + ("top_p zero", {"llm_setting": {"top_p": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": 0}), + ("top_p one", {"llm_setting": {"top_p": 1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": 1}), + ("top_p negative one", {"llm_setting": {"top_p": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": -1}), + ("top_p ten", {"llm_setting": {"top_p": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": 10}), + ("top_p string", {"llm_setting": {"top_p": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": "a"}), + ("presence_penalty zero", {"llm_setting": {"presence_penalty": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": 0}), + ("presence_penalty one", {"llm_setting": {"presence_penalty": 1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": 1}), + ("presence_penalty negative one", {"llm_setting": {"presence_penalty": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": -1}), + ("presence_penalty ten", {"llm_setting": {"presence_penalty": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": 10}), + ("presence_penalty string", {"llm_setting": {"presence_penalty": "a"}}, 0, "", 
"glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": "a"}), + ("frequency_penalty zero", {"llm_setting": {"frequency_penalty": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": 0}), + ("frequency_penalty one", {"llm_setting": {"frequency_penalty": 1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": 1}), + ("frequency_penalty negative one", {"llm_setting": {"frequency_penalty": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": -1}), + ("frequency_penalty ten", {"llm_setting": {"frequency_penalty": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": 10}), + ("frequency_penalty string", {"llm_setting": {"frequency_penalty": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": "a"}), + ("max_token zero", {"llm_setting": {"max_token": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": 0}), + ("max_token 1024", {"llm_setting": {"max_token": 1024}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": 1024}), + ("max_token negative one", {"llm_setting": {"max_token": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": -1}), + ("max_token ten", {"llm_setting": {"max_token": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": 10}), + ("max_token string", {"llm_setting": {"max_token": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": "a"}), + ("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"unknown": "unknown"}), ] for index, (scenario_name, extra_payload, expected_code, expected_message, expected_llm_id, expected_llm_setting) in enumerate(cases, start=1): @@ -1801,35 +1764,35 @@ def test_chat_update_avatar_contract(rest_client, clear_chats, ensure_parsed_doc def test_chat_update_llm_contract(rest_client, clear_chats, ensure_parsed_document): dataset_id, _ = ensure_parsed_document() cases = [ - ("default llm", {}, 0, "", "glm-4-flash@ZHIPU-AI", {}), - ("explicit llm_id", {"llm_id": "glm-4"}, 0, "", "glm-4", {}), + ("default llm", {}, 0, "", "glm-4-flash@CI@ZHIPU-AI", 
{}), + ("explicit llm_id", {"llm_id": "glm-4"}, 102, "`llm_id` glm-4 doesn't exist", None, None), ("unknown llm_id", {"llm_id": "unknown"}, 102, "`llm_id` unknown doesn't exist", None, None), - ("temperature zero", {"llm_setting": {"temperature": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 0}), - ("temperature one", {"llm_setting": {"temperature": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 1}), - ("temperature negative one", {"llm_setting": {"temperature": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": -1}), - ("temperature ten", {"llm_setting": {"temperature": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 10}), - ("temperature string", {"llm_setting": {"temperature": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": "a"}), - ("top_p zero", {"llm_setting": {"top_p": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 0}), - ("top_p one", {"llm_setting": {"top_p": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 1}), - ("top_p negative one", {"llm_setting": {"top_p": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": -1}), - ("top_p ten", {"llm_setting": {"top_p": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 10}), - ("top_p string", {"llm_setting": {"top_p": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": "a"}), - ("presence_penalty zero", {"llm_setting": {"presence_penalty": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 0}), - ("presence_penalty one", {"llm_setting": {"presence_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 1}), - ("presence_penalty negative one", {"llm_setting": {"presence_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": -1}), - ("presence_penalty ten", {"llm_setting": {"presence_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 10}), - ("presence_penalty string", {"llm_setting": {"presence_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": "a"}), - ("frequency_penalty zero", {"llm_setting": {"frequency_penalty": 0}}, 0, "", 
"glm-4-flash@ZHIPU-AI", {"frequency_penalty": 0}), - ("frequency_penalty one", {"llm_setting": {"frequency_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 1}), - ("frequency_penalty negative one", {"llm_setting": {"frequency_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": -1}), - ("frequency_penalty ten", {"llm_setting": {"frequency_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 10}), - ("frequency_penalty string", {"llm_setting": {"frequency_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": "a"}), - ("max_token zero", {"llm_setting": {"max_token": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 0}), - ("max_token 1024", {"llm_setting": {"max_token": 1024}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 1024}), - ("max_token negative one", {"llm_setting": {"max_token": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": -1}), - ("max_token ten", {"llm_setting": {"max_token": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 10}), - ("max_token string", {"llm_setting": {"max_token": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": "a"}), - ("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"unknown": "unknown"}), + ("temperature zero", {"llm_setting": {"temperature": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": 0}), + ("temperature one", {"llm_setting": {"temperature": 1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": 1}), + ("temperature negative one", {"llm_setting": {"temperature": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": -1}), + ("temperature ten", {"llm_setting": {"temperature": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": 10}), + ("temperature string", {"llm_setting": {"temperature": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"temperature": "a"}), + ("top_p zero", {"llm_setting": {"top_p": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": 0}), + ("top_p one", {"llm_setting": {"top_p": 
1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": 1}), + ("top_p negative one", {"llm_setting": {"top_p": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": -1}), + ("top_p ten", {"llm_setting": {"top_p": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": 10}), + ("top_p string", {"llm_setting": {"top_p": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"top_p": "a"}), + ("presence_penalty zero", {"llm_setting": {"presence_penalty": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": 0}), + ("presence_penalty one", {"llm_setting": {"presence_penalty": 1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": 1}), + ("presence_penalty negative one", {"llm_setting": {"presence_penalty": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": -1}), + ("presence_penalty ten", {"llm_setting": {"presence_penalty": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": 10}), + ("presence_penalty string", {"llm_setting": {"presence_penalty": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"presence_penalty": "a"}), + ("frequency_penalty zero", {"llm_setting": {"frequency_penalty": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": 0}), + ("frequency_penalty one", {"llm_setting": {"frequency_penalty": 1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": 1}), + ("frequency_penalty negative one", {"llm_setting": {"frequency_penalty": -1}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": -1}), + ("frequency_penalty ten", {"llm_setting": {"frequency_penalty": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": 10}), + ("frequency_penalty string", {"llm_setting": {"frequency_penalty": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"frequency_penalty": "a"}), + ("max_token zero", {"llm_setting": {"max_token": 0}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": 0}), + ("max_token 1024", {"llm_setting": {"max_token": 1024}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": 1024}), + ("max_token negative one", {"llm_setting": {"max_token": -1}}, 
0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": -1}), + ("max_token ten", {"llm_setting": {"max_token": 10}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": 10}), + ("max_token string", {"llm_setting": {"max_token": "a"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"max_token": "a"}), + ("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", "glm-4-flash@CI@ZHIPU-AI", {"unknown": "unknown"}), ] for index, (scenario_name, extra_payload, expected_code, expected_message, expected_llm_id, expected_llm_setting) in enumerate(cases, start=1): @@ -2033,13 +1996,13 @@ def test_chat_update_mapping_and_validation_branches_p2(rest_client, clear_chats def test_chat_update_rejects_unparsed_document(rest_client, clear_chats, create_document): dataset_id, _ = create_document() create_res = rest_client.post("/chats", json={"name": "restful_chat_update_unparsed_target", "dataset_ids": []}) - assert create_res.status_code == 200 + assert create_res.status_code == 200, create_res.text create_payload = create_res.json() assert create_payload["code"] == 0, create_payload chat_id = create_payload["data"]["id"] res = rest_client.patch(f"/chats/{chat_id}", json={"dataset_ids": [dataset_id]}) - assert res.status_code == 200 + assert res.status_code == 200, res.text payload = res.json() assert payload["code"] == 102, payload assert "doesn't own parsed file" in payload["message"], payload diff --git a/test/testcases/restful_api/test_datasets.py b/test/testcases/restful_api/test_datasets.py index fb4fe73c24..4050d9b347 100644 --- a/test/testcases/restful_api/test_datasets.py +++ b/test/testcases/restful_api/test_datasets.py @@ -459,7 +459,7 @@ def test_dataset_update_embedding_model_with_existing_chunks_contract(rest_clien assert dataset_payload["code"] == 0, dataset_payload current_embedding = dataset_payload["data"]["embedding_model"] - candidates = ["embedding-3@ZHIPU-AI", "BAAI/bge-small-en-v1.5@Builtin"] + candidates = ["embedding-3@CI@ZHIPU-AI", 
"BAAI/bge-small-en-v1.5@Local@Builtin"] last_payload = None for candidate in candidates: if candidate == current_embedding: @@ -810,10 +810,10 @@ def test_dataset_update_embedding_model_invalid_and_none_contract(rest_client, c dataset_id = create_payload["data"]["id"] invalid_cases = [ - ("unknown@ZHIPU-AI", "Unsupported model: "), - ("embedding-3@unknown", "Unsupported model: "), - ("text-embedding-v3@Tongyi-Qianwen", "Unauthorized model: "), - ("text-embedding-3-small@OpenAI", "Unauthorized model: "), + ("unknown@ZHIPU-AI", "Instance default not found for model unknown@ZHIPU-AI."), + ("embedding-3@unknown", "Provider unknown not found for model embedding-3@unknown."), + ("text-embedding-v3@Tongyi-Qianwen", "Provider Tongyi-Qianwen not found for model text-embedding-v3@Tongyi-Qianwen."), + ("text-embedding-3-small@OpenAI", "Provider OpenAI not found for model text-embedding-3-small@OpenAI."), ] for embedding_model, expected_message in invalid_cases: res = rest_client.put( @@ -834,7 +834,7 @@ def test_dataset_update_embedding_model_invalid_and_none_contract(rest_client, c assert list_res.status_code == 200 list_payload = list_res.json() assert list_payload["code"] == 0, list_payload - assert list_payload["data"][0]["embedding_model"] == "BAAI/bge-small-en-v1.5@Builtin", list_payload + assert list_payload["data"][0]["embedding_model"] == "BAAI/bge-small-en-v1.5@Local@Builtin", list_payload @pytest.mark.p2 @@ -1164,21 +1164,21 @@ def test_dataset_create_permission_contract(rest_client, clear_datasets, name, p @pytest.mark.parametrize( "name, embedding_model, expected_code, expected_embedding_model, expected_message, unauthorized_is_xfail", [ - ("builtin_baai", "BAAI/bge-small-en-v1.5@Builtin", 0, "BAAI/bge-small-en-v1.5@Builtin", None, False), - ("tenant_zhipu", "embedding-3@ZHIPU-AI", 0, "embedding-3@ZHIPU-AI", None, True), - ("embedding_model_unset", "__UNSET__", 0, "BAAI/bge-small-en-v1.5@Builtin", None, False), - ("embedding_model_none", None, 0, 
"BAAI/bge-small-en-v1.5@Builtin", None, False), - ("unknown_llm_name", "unknown@ZHIPU-AI", 102, None, "Unsupported model: ", False), - ("unknown_llm_factory", "embedding-3@unknown", 102, None, "Unsupported model: ", False), + ("builtin_baai", "BAAI/bge-small-en-v1.5@Local@Builtin", 0, "BAAI/bge-small-en-v1.5@Local@Builtin", None, False), + ("tenant_zhipu", "embedding-3@CI@ZHIPU-AI", 0, "embedding-3@CI@ZHIPU-AI", None, True), + ("embedding_model_unset", "__UNSET__", 0, "BAAI/bge-small-en-v1.5@Local@Builtin", None, False), + ("embedding_model_none", None, 0, "BAAI/bge-small-en-v1.5@Local@Builtin", None, False), + ("unknown_llm_name", "unknown@ZHIPU-AI", 102, None, "Instance default not found for model unknown@ZHIPU-AI.", False), + ("unknown_llm_factory", "embedding-3@unknown", 102, None, "Provider unknown not found for model embedding-3@unknown.", False), ( "tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen", 102, None, - "Unauthorized model: ", + "Provider Tongyi-Qianwen not found for model text-embedding-v3@Tongyi-Qianwen.", False, ), - ("tenant_no_auth", "text-embedding-3-small@OpenAI", 102, None, "Unauthorized model: ", False), + ("tenant_no_auth", "text-embedding-3-small@OpenAI", 102, None, "Provider OpenAI not found for model text-embedding-3-small@OpenAI.", False), ], ids=[ "builtin_baai", diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py index 1fad7ebae5..90ea5b781d 100644 --- a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -103,8 +103,6 @@ def _load_dify_retrieval_module(monkeypatch): monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils) monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) - tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") - class _MockModelConfig: def __init__(self, tenant_id, 
model_name): self.tenant_id = tenant_id @@ -157,15 +155,6 @@ def _load_dify_retrieval_module(monkeypatch): return parts[0], parts[1] return model_name, None - tenant_llm_service_mod.TenantService = _StubTenantService - tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService - - class _StubLLMFactoriesService: - pass - - tenant_llm_service_mod.LLMFactoriesService = _StubLLMFactoriesService - monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) - llm_service_mod = ModuleType("api.db.services.llm_service") class _StubLLM: @@ -233,12 +222,17 @@ def _load_dify_retrieval_module(monkeypatch): raise Exception("Model Name is required") return _MockModelConfig2(tenant_id, model_name).to_dict() + def _get_model_config_from_provider_instance(tenant_id: str, model_type: str, model_name: str): + if not model_name: + raise Exception("Model Name is required") + return _MockModelConfig2(tenant_id, model_name).to_dict() + def _get_tenant_default_model_by_type(tenant_id: str, model_type): return _MockModelConfig2(tenant_id, "chat-model").to_dict() tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id - tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type + tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) module_name = "test_dify_retrieval_routes_unit_module" diff --git a/test/testcases/restful_api/test_openai_compatible.py b/test/testcases/restful_api/test_openai_compatible.py index 49e2c55ca5..d753842e3c 100644 --- a/test/testcases/restful_api/test_openai_compatible.py +++ b/test/testcases/restful_api/test_openai_compatible.py @@ -124,7 +124,7 @@ def test_openai_compatible_nonstream_shape(rest_client, create_chat): 
assert res.status_code == 200 payload = res.json() - assert payload["object"] == "chat.completion", payload + assert payload.get("object") == "chat.completion", payload assert isinstance(payload["choices"], list) and payload["choices"], payload first_choice = payload["choices"][0] assert first_choice.get("finish_reason") == "stop", payload diff --git a/test/testcases/restful_api/test_user_tenant_routes_unit.py b/test/testcases/restful_api/test_user_tenant_routes_unit.py index 0bec8fe582..ec9af40790 100644 --- a/test/testcases/restful_api/test_user_tenant_routes_unit.py +++ b/test/testcases/restful_api/test_user_tenant_routes_unit.py @@ -620,10 +620,6 @@ def _load_user_app(monkeypatch): api_utils_mod.validate_request = _validate_request monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) - tenant_utils_mod = ModuleType("api.utils.tenant_utils") - tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, params: params - monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) - crypt_mod = ModuleType("api.utils.crypt") crypt_mod.decrypt = lambda value: value monkeypatch.setitem(sys.modules, "api.utils.crypt", crypt_mod) @@ -1505,13 +1501,20 @@ def _load_chat_routes_unit_module(monkeypatch): }) monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) - tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") - tenant_llm_service_mod.TenantLLMService = type('TenantLLMService', (), { - 'split_model_name_and_factory': staticmethod(lambda model: (model.split('@', 1)[0], model.split('@', 1)[1] if '@' in model else None)), - 'query': staticmethod(lambda **_kwargs: [SimpleNamespace(id='llm-1')]), - 'get_api_key': staticmethod(lambda *_args, **_kwargs: SimpleNamespace(id=1)), - }) - monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + tenant_model_provider_mod = ModuleType("api.db.joint_services.tenant_model_service") + 
tenant_model_provider_mod.get_model_config_from_provider_instance = lambda *_args, **_kwargs: {} + tenant_model_provider_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} + def _split_model_name(model_name): + parts = model_name.split("@") + if len(parts) == 1: + return parts[0], "", "" + elif len(parts) == 2: + return parts[0], "default", parts[1] + else: + return parts[0], parts[1], parts[2] + tenant_model_provider_mod.split_model_name = staticmethod(_split_model_name) + tenant_model_provider_mod.get_api_key = lambda *_args, **_kwargs: SimpleNamespace(id=1) + monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_provider_mod) llm_service_mod = ModuleType("api.db.services.llm_service") llm_service_mod.LLMBundle = lambda *_args, **_kwargs: None @@ -1521,11 +1524,6 @@ def _load_chat_routes_unit_module(monkeypatch): search_service_mod.SearchService = SimpleNamespace() monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) - tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") - tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {} - tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} - monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) - user_service_mod = ModuleType("api.db.services.user_service") user_service_mod.UserService = type('UserService', (), {}) user_service_mod.TenantService = type('TenantService', (), { @@ -1548,10 +1546,6 @@ def _load_chat_routes_unit_module(monkeypatch): api_utils_mod.get_request_json = lambda: _AwaitableValue({}) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) - tenant_utils_mod = ModuleType("api.utils.tenant_utils") - tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req - monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", 
tenant_utils_mod) - rag_pkg = ModuleType("rag") rag_pkg.__path__ = [str(repo_root / 'rag')] monkeypatch.setitem(sys.modules, 'rag', rag_pkg) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index fa0894f142..18df0f9804 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -366,7 +366,7 @@ def _load_chat_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") - tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {} + tenant_model_service_mod.get_model_config_from_provider_instance = lambda *_args, **_kwargs: {} tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) @@ -418,10 +418,6 @@ def _load_chat_module(monkeypatch): api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) - tenant_utils_mod = ModuleType("api.utils.tenant_utils") - tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req - monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) - rag_pkg = ModuleType("rag") rag_pkg.__path__ = [str(repo_root / "rag")] monkeypatch.setitem(sys.modules, "rag", rag_pkg) diff --git a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py index e73f18959c..c08dd482d0 100644 --- 
a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py +++ b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py @@ -238,7 +238,7 @@ def _load_dify_retrieval_module(monkeypatch): raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() - def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): + def _get_model_config_from_provider_instance(tenant_id: str, model_type: str, model_name: str): if not model_name: raise Exception("Model Name is required") return _MockModelConfig2(tenant_id, model_name).to_dict() @@ -248,7 +248,7 @@ def _load_dify_retrieval_module(monkeypatch): return _MockModelConfig2(tenant_id, "chat-model").to_dict() tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id - tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name + tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index 5b994ea525..3d0a5b52e6 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -478,7 +478,7 @@ def _load_doc_module(monkeypatch): raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() - def _get_model_config_by_type_and_name(tenant_id: str, 
model_type: str, model_name: str): + def _get_model_config_from_provider_instance(tenant_id: str, model_type: str, model_name: str): if not model_name: raise Exception("Model Name is required") return _MockModelConfig2(tenant_id, model_name).to_dict() @@ -488,7 +488,7 @@ def _load_doc_module(monkeypatch): return _MockModelConfig2(tenant_id, "chat-model").to_dict() tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id - tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name + tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 5cf30117b7..cdb2736e12 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -481,7 +481,7 @@ def _load_session_module(monkeypatch): raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() - def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): + def _get_model_config_from_provider_instance(tenant_id: str, model_type: str, model_name: str): if not model_name: raise Exception("Model Name is required") return _MockModelConfig2(tenant_id, model_name, model_type).to_dict() @@ -525,7 +525,7 @@ def _load_session_module(monkeypatch): return _MockModelConfig2(tenant_id, model_name, model_type_val).to_dict() tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id - 
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name + tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) @@ -2095,7 +2095,6 @@ def _load_chat_api_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.joint_services", joint_pkg) tenant_model_svc = ModuleType("api.db.joint_services.tenant_model_service") - tenant_model_svc.get_model_config_by_type_and_name = lambda *_a, **_k: {} tenant_model_svc.get_tenant_default_model_by_type = lambda *_a, **_k: {} monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_svc) @@ -2185,10 +2184,6 @@ def _load_chat_api_module(monkeypatch): api_utils_mod.validate_request = lambda *_a, **_k: (lambda func: func) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) - tenant_utils_mod = ModuleType("api.utils.tenant_utils") - tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req - monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) - rag_gen_mod = ModuleType("rag.prompts.generator") rag_gen_mod.chunks_format = lambda chunks: chunks monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_gen_mod) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 3f32de77cc..8146b76cae 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -184,7 +184,7 @@ class TestDatasetCreate: "name, embedding_model", [ ("BAAI/bge-small-en-v1.5@Builtin", "BAAI/bge-small-en-v1.5@Builtin"), - ("embedding-3@ZHIPU-AI", 
"embedding-3@ZHIPU-AI"), + ("embedding-3@ZHIPU-AI", "embedding-3@CI@ZHIPU-AI"), ], ids=["builtin_baai", "tenant_zhipu"], ) @@ -197,7 +197,7 @@ class TestDatasetCreate: @pytest.mark.parametrize( "name, embedding_model", [ - ("unknown_llm_name", "unknown@ZHIPU-AI"), + ("unknown_llm_name", "unknown@CI@ZHIPU-AI"), ("unknown_llm_factory", "embedding-3@unknown"), ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), ("tenant_no_auth", "text-embedding-3-small@OpenAI"), @@ -208,10 +208,7 @@ class TestDatasetCreate: payload = {"name": name, "embedding_model": embedding_model} with pytest.raises(Exception) as exception_info: client.create_dataset(**payload) - if "tenant_no_auth" in name: - assert str(exception_info.value) == f"Unauthorized model: <{embedding_model}>", str(exception_info.value) - else: - assert str(exception_info.value) == f"Unsupported model: <{embedding_model}>", str(exception_info.value) + assert "not found" in str(exception_info.value), str(exception_info.value) @pytest.mark.p2 @pytest.mark.parametrize( @@ -240,13 +237,13 @@ class TestDatasetCreate: def test_embedding_model_unset(self, client): payload = {"name": "embedding_model_unset"} dataset = client.create_dataset(**payload) - assert dataset.embedding_model == "BAAI/bge-small-en-v1.5@Builtin", str(dataset) + assert dataset.embedding_model == "BAAI/bge-small-en-v1.5@Local@Builtin", str(dataset) @pytest.mark.p2 def test_embedding_model_none(self, client): payload = {"name": "embedding_model_none", "embedding_model": None} dataset = client.create_dataset(**payload) - assert dataset.embedding_model == "BAAI/bge-small-en-v1.5@Builtin", str(dataset) + assert dataset.embedding_model == "BAAI/bge-small-en-v1.5@Local@Builtin", str(dataset) @pytest.mark.p2 @pytest.mark.parametrize( diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py index d32d8fd9b3..f29ce98892 100644 --- 
a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py @@ -170,7 +170,7 @@ class TestDatasetUpdate: "embedding_model", [ "BAAI/bge-small-en-v1.5@Builtin", - "embedding-3@ZHIPU-AI", + "embedding-3@CI@ZHIPU-AI", ], ids=["builtin_baai", "tenant_zhipu"], ) @@ -198,10 +198,7 @@ class TestDatasetUpdate: with pytest.raises(Exception) as exception_info: dataset.update({"name": name, "embedding_model": embedding_model}) error_msg = str(exception_info.value) - if "tenant_no_auth" in name: - assert error_msg == f"Unauthorized model: <{embedding_model}>", error_msg - else: - assert error_msg == f"Unsupported model: <{embedding_model}>", error_msg + assert "not found" in error_msg, error_msg @pytest.mark.p2 @pytest.mark.parametrize( @@ -231,10 +228,10 @@ class TestDatasetUpdate: def test_embedding_model_none(self, client, add_dataset_func): dataset = add_dataset_func dataset.update({"embedding_model": None}) - assert dataset.embedding_model == "BAAI/bge-small-en-v1.5@Builtin", str(dataset) + assert dataset.embedding_model == "BAAI/bge-small-en-v1.5@Local@Builtin", str(dataset) retrieved_dataset = client.get_dataset(name=dataset.name) - assert retrieved_dataset.embedding_model == "BAAI/bge-small-en-v1.5@Builtin", str(retrieved_dataset) + assert retrieved_dataset.embedding_model == "BAAI/bge-small-en-v1.5@Local@Builtin", str(retrieved_dataset) @pytest.mark.p2 @pytest.mark.parametrize( diff --git a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py index 52c1ea5de6..221a56d142 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py +++ b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py @@ -309,7 +309,7 @@ def _load_chunk_module(monkeypatch): tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") 
tenant_model_service_mod.get_model_config_by_id = lambda *_args, **_kwargs: {"llm_name": "embed", "model_type": "embedding"} - tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {"llm_name": "embed", "model_type": "embedding"} + tenant_model_service_mod.get_model_config_from_provider_instance = lambda *_args, **_kwargs: {"llm_name": "embed", "model_type": "embedding"} tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {"llm_name": "chat", "model_type": "chat"} monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) diff --git a/test/testcases/test_web_api/test_user_app/test_user_app_unit.py b/test/testcases/test_web_api/test_user_app/test_user_app_unit.py index fb576799e9..91f7f716bc 100644 --- a/test/testcases/test_web_api/test_user_app/test_user_app_unit.py +++ b/test/testcases/test_web_api/test_user_app/test_user_app_unit.py @@ -344,10 +344,6 @@ def _load_user_app(monkeypatch): api_utils_mod.validate_request = _validate_request monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) - tenant_utils_mod = ModuleType("api.utils.tenant_utils") - tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, params: params - monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) - crypt_mod = ModuleType("api.utils.crypt") crypt_mod.decrypt = lambda value: value monkeypatch.setitem(sys.modules, "api.utils.crypt", crypt_mod) diff --git a/test/unit_test/api/apps/sdk/test_dify_retrieval.py b/test/unit_test/api/apps/sdk/test_dify_retrieval.py index 72715d72f9..0b880308bd 100644 --- a/test/unit_test/api/apps/sdk/test_dify_retrieval.py +++ b/test/unit_test/api/apps/sdk/test_dify_retrieval.py @@ -120,9 +120,8 @@ def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=No _stub( monkeypatch, "api.db.joint_services.tenant_model_service", - get_model_config_by_id=lambda *_a, **_k: {}, - 
get_model_config_by_type_and_name=lambda *_a, **_k: {}, get_tenant_default_model_by_type=lambda *_a, **_k: {}, + get_model_config_from_provider_instance=lambda *_a, **_k: {}, ) _stub( diff --git a/test/unit_test/api/db/services/test_dialog_service_final_answer.py b/test/unit_test/api/db/services/test_dialog_service_final_answer.py index 30fb1e4c30..44b2bf570a 100644 --- a/test/unit_test/api/db/services/test_dialog_service_final_answer.py +++ b/test/unit_test/api/db/services/test_dialog_service_final_answer.py @@ -203,7 +203,7 @@ def test_async_ask_final_event_carries_decorated_answer(monkeypatch): dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] ) monkeypatch.setattr( - dialog_service, "get_model_config_by_type_and_name", + dialog_service, "get_model_config_from_provider_instance", lambda _tid, _type, _name: _LLM_CONFIG, ) monkeypatch.setattr(dialog_service, "LLMBundle", lambda _tid, _cfg: chat_mdl) @@ -257,7 +257,7 @@ def test_async_ask_delta_events_carry_incremental_text_only(monkeypatch): dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] ) monkeypatch.setattr( - dialog_service, "get_model_config_by_type_and_name", + dialog_service, "get_model_config_from_provider_instance", lambda _tid, _type, _name: _LLM_CONFIG, ) monkeypatch.setattr(dialog_service, "LLMBundle", lambda _tid, _cfg: chat_mdl) @@ -345,10 +345,11 @@ def test_async_chat_final_event_carries_decorated_answer(monkeypatch): # Stub out the heavy service/model calls monkeypatch.setattr( - dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat" + dialog_service, "get_model_type_by_name", + lambda _tid, _llm_id: ["chat"] ) monkeypatch.setattr( - dialog_service.TenantLLMService, "get_model_config", + dialog_service, "get_model_config_from_provider_instance", lambda _tid, _type, _llm_id: _LLM_CONFIG, ) monkeypatch.setattr( @@ -406,10 +407,11 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch): retriever = _StubRetriever() 
monkeypatch.setattr( - dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat" + dialog_service, "get_model_type_by_name", + lambda _tid, _llm_id: ["chat"] ) monkeypatch.setattr( - dialog_service.TenantLLMService, "get_model_config", + dialog_service, "get_model_config_from_provider_instance", lambda _tid, _type, _llm_id: _LLM_CONFIG, ) monkeypatch.setattr( @@ -475,10 +477,11 @@ def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch) retriever = _StubRetriever() monkeypatch.setattr( - dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat" + dialog_service, "get_model_type_by_name", + lambda _tid, _llm_id: ["chat"] ) monkeypatch.setattr( - dialog_service.TenantLLMService, "get_model_config", + dialog_service, "get_model_config_from_provider_instance", lambda _tid, _type, _llm_id: _LLM_CONFIG, ) monkeypatch.setattr( diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py index 5910781be4..6109010b33 100644 --- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -282,11 +282,14 @@ def test_async_chat_uses_all_docs_when_no_doc_ids_selected(monkeypatch): ) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) - monkeypatch.setattr(dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat") monkeypatch.setattr( - dialog_service.TenantLLMService, - "get_model_config", - lambda *_args, **_kwargs: {"llm_factory": "unit", "max_tokens": 4096}, + dialog_service, "get_model_type_by_name", + lambda _tid, _llm_id: ["chat"] + ) + monkeypatch.setattr( + dialog_service, + "get_model_config_from_provider_instance", + lambda *_args, **_kwargs: {"llm_factory": "unit", "max_tokens": 4096, "model_type": "chat"}, ) 
monkeypatch.setattr(dialog_service.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None) monkeypatch.setattr( diff --git a/test/unit_test/rag/svr/task_executor_refactor/conftest.py b/test/unit_test/rag/svr/task_executor_refactor/conftest.py index cb20ba9998..64da81a0db 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/conftest.py +++ b/test/unit_test/rag/svr/task_executor_refactor/conftest.py @@ -312,7 +312,7 @@ def create_patch_embedding_model(vectors=None, vector_size=128): mock_model.__exit__ = MagicMock(return_value=False) return patch( - "rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name", + "rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance", return_value=MagicMock(), ), patch( "rag.svr.task_executor_refactor.task_handler.LLMBundle", diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_chunk_post_processor.py b/test/unit_test/rag/svr/task_executor_refactor/test_chunk_post_processor.py index 1d684fc0ef..015016c4a5 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_chunk_post_processor.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_chunk_post_processor.py @@ -56,7 +56,7 @@ class TestExtractKeywords: {"content_with_weight": "This is test content two"}, ] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -84,7 +84,7 @@ class TestExtractKeywords: ctx.has_canceled_func = MagicMock(return_value=True) docs = [{"content_with_weight": "This is test content"}] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with 
patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -106,7 +106,7 @@ class TestExtractKeywords: ctx = self._create_mock_context() docs = [] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -146,7 +146,7 @@ class TestGenerateQuestions: {"content_with_weight": "This is test content one"}, ] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -174,7 +174,7 @@ class TestGenerateQuestions: ctx.has_canceled_func = MagicMock(return_value=True) docs = [{"content_with_weight": "This is test content"}] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -223,7 +223,7 @@ class TestGenerateMetadata: {"content_with_weight": "This is test content", "metadata_obj": {"category": "test"}}, ] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with 
patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -256,7 +256,7 @@ class TestGenerateMetadata: {"content_with_weight": "This is test content", "metadata_obj": {"category": "test"}}, ] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -307,7 +307,7 @@ class TestApplyTags: {"content_with_weight": "This is test content"}, ] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: @@ -337,7 +337,7 @@ class TestApplyTags: {"content_with_weight": "This is test content"}, ] - with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_by_type_and_name") as mock_config: + with patch("rag.svr.task_executor_refactor.chunk_post_processor.get_model_config_from_provider_instance") as mock_config: mock_config.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.chunk_post_processor.LLMBundle") as mock_llm: diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_task_handler.py b/test/unit_test/rag/svr/task_executor_refactor/test_task_handler.py index d8149867a9..fed3eb06af 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_task_handler.py +++ 
b/test/unit_test/rag/svr/task_executor_refactor/test_task_handler.py @@ -247,7 +247,7 @@ class TestTaskHandlerBuildToc: coro.close() return [] - with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_cfg: + with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_cfg: mock_cfg.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle: mock_msg = MagicMock() @@ -271,7 +271,7 @@ class TestTaskHandlerBuildToc: coro.close() return toc_result - with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_cfg: + with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_cfg: mock_cfg.return_value = MagicMock() with patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle: mock_msg = MagicMock() diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_task_handler_integration.py b/test/unit_test/rag/svr/task_executor_refactor/test_task_handler_integration.py index dbacf251ae..dc0fa0ace4 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_task_handler_integration.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_task_handler_integration.py @@ -132,7 +132,7 @@ class TestStandardChunkingPipelineIntegration: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ 
patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -177,7 +177,7 @@ class TestStandardChunkingPipelineIntegration: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -223,7 +223,7 @@ class TestStandardChunkingPipelineIntegration: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -271,7 +271,7 @@ class TestStandardChunkingPipelineIntegration: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + 
patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -320,7 +320,7 @@ class TestStandardChunkingPipelineIntegration: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -457,7 +457,7 @@ class TestTaskCancellationCleanupIntegration: with patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.task_handler.search.index_name") as mock_index_name, \ patch("rag.svr.task_executor_refactor.task_handler.thread_pool_exec") as mock_thread_exec, \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default: @@ -514,7 +514,7 @@ class TestRaptorPipelineIntegration: with patch_task_handler_settings(mock_settings), \ 
patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.search.index_name") as mock_index_name, \ @@ -559,7 +559,7 @@ class TestRaptorPipelineIntegration: with patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.search.index_name") as mock_index_name, \ @@ -625,7 +625,7 @@ class TestEmbeddingModelBindingFailureIntegration: task_dict = self._create_standard_task_dict() ctx = create_task_context(task_dict) - with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default: mock_get_config.side_effect = Exception("Model not found") @@ -642,7 +642,7 @@ class TestEmbeddingModelBindingFailureIntegration: task_dict = self._create_standard_task_dict() ctx = 
create_task_context(task_dict) - with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + with patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default: mock_get_config.side_effect = Exception("Model not found") @@ -755,7 +755,7 @@ class TestTocAsyncFlowIntegration: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -821,7 +821,7 @@ class TestTocAsyncFlowIntegration: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -910,7 +910,7 @@ class TestRecordingContextDataFlowAssertions: with patch_get_storage_binary(), 
\ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ @@ -956,7 +956,7 @@ class TestRecordingContextDataFlowAssertions: with patch_get_storage_binary(), \ patch_task_handler_settings(mock_settings), \ patch("rag.svr.task_executor_refactor.chunk_service.settings", mock_settings), \ - patch("rag.svr.task_executor_refactor.task_handler.get_model_config_by_type_and_name") as mock_get_config, \ + patch("rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance") as mock_get_config, \ patch("rag.svr.task_executor_refactor.task_handler.LLMBundle") as mock_bundle, \ patch("rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type") as mock_get_default, \ patch("rag.svr.task_executor_refactor.task_handler.File2DocumentService") as mock_file_service, \ diff --git a/tools/scripts/mysql_migration.py b/tools/scripts/mysql_migration.py index 2c1232d86a..4fea8e676a 100644 --- a/tools/scripts/mysql_migration.py +++ b/tools/scripts/mysql_migration.py @@ -331,8 +331,8 @@ class TenantModelProviderStage(MigrationStage): for tenant_id, llm_factory in batch: record_id = self.generate_uuid() values.append(f"('{record_id}', '{llm_factory}', '{tenant_id}', " - f"{current_ts}, FROM_UNIXTIME({current_ts}), " - f"{current_ts}, FROM_UNIXTIME({current_ts}))") + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}), " + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}))") 
insert_sql = f""" INSERT INTO tenant_model_provider @@ -478,7 +478,7 @@ class TenantModelInstanceStage(MigrationStage): if self.dry_run: logger.info(f"[DRY RUN] Would insert {len(records)} records") for tenant_id, llm_factory, api_key, status, provider_id in records[:5]: - logger.info(f" instance_name={llm_factory}, provider_id={provider_id}, api_key=***") + logger.info(f" instance_name=default, provider_id={provider_id}, api_key=***") if len(records) > 5: logger.info(f" ... and {len(records) - 5} more records") return len(records), self.target_tables @@ -490,13 +490,13 @@ class TenantModelInstanceStage(MigrationStage): values = [] for tenant_id, llm_factory, api_key, status, provider_id in batch: record_id = self.generate_uuid() - instance_name = llm_factory.replace("'", "''") if llm_factory else "" + instance_name = "default" api_key_escaped = api_key.replace("'", "''") if api_key else "" - status_val = status if status else "active" + status_val = "active" if status in ["1", "active", "enable"] else "inactive" values.append(f"('{record_id}', '{instance_name}', '{provider_id}', " f"'{api_key_escaped}', '{status_val}', " - f"{current_ts}, FROM_UNIXTIME({current_ts}), " - f"{current_ts}, FROM_UNIXTIME({current_ts}))") + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}), " + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}))") insert_sql = f""" INSERT INTO tenant_model_instance @@ -518,6 +518,7 @@ class TenantModelInstanceStage(MigrationStage): provider_id VARCHAR(32) NOT NULL, api_key VARCHAR(512) NOT NULL, status VARCHAR(32) DEFAULT 'active', + extra VARCHAR(512) DEFAULT '{}', create_time BIGINT, create_date DATETIME, update_time BIGINT, @@ -676,11 +677,11 @@ class TenantModelStage(MigrationStage): record_id = self.generate_uuid() model_name_escaped = llm_name.replace("'", "''") if llm_name else "" model_type_escaped = model_type.replace("'", "''") if model_type else "" - status_val = status if status else "active" + status_val = "active" if status in ["1", 
"active", "enable"] else "inactive" values.append(f"('{record_id}', '{model_name_escaped}', '{provider_id}', " f"'{instance_id}', '{model_type_escaped}', '{status_val}', " - f"{current_ts}, FROM_UNIXTIME({current_ts}), " - f"{current_ts}, FROM_UNIXTIME({current_ts}))") + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}), " + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}))") insert_sql = f""" INSERT INTO tenant_model @@ -704,6 +705,7 @@ class TenantModelStage(MigrationStage): instance_id VARCHAR(32) NOT NULL, model_type VARCHAR(32) NOT NULL, status VARCHAR(32) DEFAULT 'active', + extra VARCHAR(1024) DEFAULT '{}', create_time BIGINT, create_date DATETIME, update_time BIGINT, diff --git a/web/CLAUDE.md b/web/CLAUDE.md index 822689cd09..9a78fdec54 100644 --- a/web/CLAUDE.md +++ b/web/CLAUDE.md @@ -38,9 +38,42 @@ For translation tasks, add keys **only to the explicitly requested language file When refactoring or extracting components, **verify layout behavior after each structural change** (especially `flex-1`, conditional rendering, or flex direction changes). Check that existing buttons, alignment, and responsive behavior remain intact. After extraction, verify: (1) all original props and behavior are preserved, (2) layout in parent contexts is identical, and (3) no syntax or type errors were introduced. ### State Management and Data Fetching + +#### Query Key Factory (Mandatory) +**Never write raw `queryKey` arrays inline.** Always use a query key factory object that returns `as const` tuples. Raw arrays duplicated across `useQuery` and `invalidateQueries` are brittle, unreadable, and cause stale-cache bugs when key structures drift. 
+ +```ts +// ❌ Bad — raw array, hard to match with useQuery +queryClient.invalidateQueries({ + queryKey: [LLMApiAction.AddedProviders, params.provider_name, params.instance_name, 'models'], +}); + +// ✅ Good — factory reference, self-documenting +queryClient.invalidateQueries({ + queryKey: LlmKeys.instanceModels(params.provider_name, params.instance_name), +}); +``` + +- Place the factory in the same file as the hooks, named `{Domain}Keys` (e.g., `LlmKeys`, `DatasetKeys`). +- Every `useQuery` and every `invalidateQueries` must reference the same factory function. +- Use `as const` on each factory return value for type-safe readonly tuples. + +#### Cache Debugging For React Query / cache invalidation bugs, **carefully compare query keys across all consuming components and mutation hooks**. Mismatched keys (e.g., with/without `refreshCount`) are a common root cause of stale data or duplicate requests. - Systematically: (1) list every component/hook that calls `useQuery` for this data, (2) compare their query keys character-for-character, (3) check every mutation's `onSuccess` for cache invalidation, and (4) verify no parent re-renders are remounting the observer. +### Network Request Layering +HTTP requests are organized in three layers. **Never import `@/utils/request`, `@/utils/next-request`, or `@/utils/api` directly inside a hook**: +1. `src/hooks/use-xx-request.ts(x)` — React Query hooks; only call the service layer. +2. `src/services/xx-service.ts` — Register endpoints via `registerNextServer`, all going through `@/utils/next-request`. +3. `src/utils/next-request.ts` — The single axios instance; handles token, 401 redirects, and error notifications. + +Interface types are split between two folders: +- Response/data shape → `src/interfaces/database/xx.ts` +- Request params/body → `src/interfaces/request/xx.ts` + +Model-related endpoints (LLM provider / factory / my LLM, etc.) are consolidated in `src/services/llm-service.ts` rather than scattered across hooks. 
For GET endpoints, register with `method: 'get'` in the service, and on the call site pass `true` as the second argument to use the native axios config (e.g., `service.listProviders({ params: { available: true } }, true)`). + ### Shared UI Component Lock The folder `src/components/ui/` is the project's **shared UI library** — it contains both official shadcn/ui primitives and project-authored common components built on top of shadcn. Both kinds are intended to be reused across the app and **must not be modified casually**. @@ -55,3 +88,11 @@ The folder `src/components/ui/` is the project's **shared UI library** — it co - Extract complex logic into hooks or utils; keep components lean. - Use `PascalCase` for constants and component names. - Avoid duplicating component structures in JSX; favor render props or reusable components. + +### Utility Libraries and Reuse +- **Time/date handling**: Use `dayjs` for all date/time formatting, parsing, and manipulation. +- **Utility hooks**: Prefer `ahooks` for common reusable hooks (e.g., `useDebounce`, `useSetState`). +- **General utilities**: Lodash is available for utility functions when needed. +- **Project utilities first**: Before reaching for a third-party library, check if the project already has an existing utility or hook that covers the need. +- **Extract and share**: If repeated logic cannot be satisfied by an existing project utility or a third-party library, extract it into an appropriate shared hook (`src/hooks/`) or utility file (`src/utils/`). +- **Check for duplicate patterns before adding**: When asked to add logic (validation, existence checks, API calls, etc.), first search for existing hooks/functions that do the same or similar thing — especially in the same file or sibling hooks. If a hook already does X, call it instead of re-implementing X inline. This applies to mutations calling mutations, utility wrappers, and boilerplate around API calls. 
diff --git a/web/src/components/layout-recognize-form-field.tsx b/web/src/components/layout-recognize-form-field.tsx index 8ab9089173..01a5974478 100644 --- a/web/src/components/layout-recognize-form-field.tsx +++ b/web/src/components/layout-recognize-form-field.tsx @@ -1,13 +1,13 @@ -import { LlmModelType } from '@/constants/knowledge'; import { useTranslate } from '@/hooks/common-hooks'; -import { useSelectLlmOptionsByModelType } from '@/hooks/use-llm-request'; +import { useFetchAllAddedModels } from '@/hooks/use-llm-request'; import { cn } from '@/lib/utils'; import { camelCase } from 'lodash'; import { ReactNode, useMemo } from 'react'; import { useFormContext } from 'react-hook-form'; import { MinerUOptionsFormField } from './mineru-options-form-field'; -import { SelectWithSearch } from './originui/select-with-search'; +import { buildModelTree } from './model-tree-select'; import { PaddleOCROptionsFormField } from './paddleocr-options-form-field'; +import { TreeSelect, TreeSelectNode } from './tree-select'; import { FormControl, FormField, @@ -44,9 +44,9 @@ export function LayoutRecognizeFormField({ const form = useFormContext(); const { t } = useTranslate('knowledgeDetails'); - const allOptions = useSelectLlmOptionsByModelType(); + const { data: allAddedModels } = useFetchAllAddedModels(); - const options = useMemo(() => { + const treeData = useMemo(() => { const list = optionsWithoutLLM ? optionsWithoutLLM : [ @@ -60,28 +60,28 @@ export function LayoutRecognizeFormField({ value: x, })); - const image2TextList = [ - ...allOptions[LlmModelType.Image2text], - ...allOptions[LlmModelType.Ocr], - ].map((x) => { - return { - ...x, - options: x.options.map((y) => { - return { - ...y, - label: ( -
- {y.label} - Experimental -
- ), - }; - }), - }; - }); + const prependNodes: TreeSelectNode[] = list.map((x) => ({ + id: x.value, + title: x.label, + })); - return [...list, ...image2TextList]; - }, [allOptions, optionsWithoutLLM, t]); + const modelTree = buildModelTree( + allAddedModels, + ['image2text', 'ocr'], + (node) => ( +
+ + {node.label} + + + Experimental + +
+ ), + ); + + return [...prependNodes, ...modelTree]; + }, [allAddedModels, optionsWithoutLLM, t]); return (
- + showSearch + defaultExpandAll + renderSelected={(node) => { + if (!node) return null; + return node.label ?? node.title; + }} + />
diff --git a/web/src/components/llm-select/llm-label.tsx b/web/src/components/llm-select/llm-label.tsx index 97c59b7c35..8c02a7ee06 100644 --- a/web/src/components/llm-select/llm-label.tsx +++ b/web/src/components/llm-select/llm-label.tsx @@ -1,21 +1,33 @@ -import { getLLMIconName, getLlmNameAndFIdByLlmId } from '@/utils/llm-util'; +import { parseModelValue } from '@/utils/llm-util'; import { memo } from 'react'; import { LlmIcon } from '../svg-icon'; interface IProps { - id?: string; value?: string; - onChange?: (value: string) => void; - disabled?: boolean; } -const LLMLabel = ({ value }: IProps) => { - const { llmName, fId } = getLlmNameAndFIdByLlmId(value); +export const LLMLabel = ({ value }: IProps) => { + const parsed = value ? parseModelValue(value) : null; + const modelName = parsed?.model_name; + const instanceName = parsed?.model_instance; + const iconName = parsed ? parsed.model_provider : ''; + + if (!modelName) return null; return ( -
- - {llmName} +
+ + {modelName} + {instanceName && ( + + {instanceName} + + )}
); }; diff --git a/web/src/components/llm-select/next.tsx b/web/src/components/llm-select/next.tsx index e8eed362bb..9662048e2f 100644 --- a/web/src/components/llm-select/next.tsx +++ b/web/src/components/llm-select/next.tsx @@ -1,11 +1,11 @@ import { LlmModelType } from '@/constants/knowledge'; -import { useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; import * as SelectPrimitive from '@radix-ui/react-select'; import { forwardRef, memo, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { LlmSettingFieldItems } from '../llm-setting-items/next'; import { Popover, PopoverContent, PopoverTrigger } from '../ui/popover'; import { Select, SelectTrigger, SelectValue } from '../ui/select'; +import LLMLabel from './llm-label'; export interface NextInnerLLMSelectProps { id?: string; @@ -51,8 +51,6 @@ const NextInnerLLMSelect = forwardRef< } }, [filter, ttsModel]); - const modelOptions = useComposeLlmOptionsByModelTypes(modelTypes); - return ( setSearchTerm(e.target.value)} + /> +
+ )} +
e.stopPropagation()} + > + {filteredData.length > 0 ? ( + renderTree(filteredData) + ) : ( +
+ {t('common.noData')} +
+ )} +
+ + + ); +} diff --git a/web/src/constants/llm.ts b/web/src/constants/llm.ts index 17fcc0620b..4157de1fbe 100644 --- a/web/src/constants/llm.ts +++ b/web/src/constants/llm.ts @@ -1,3 +1,8 @@ +export enum ModelStatus { + Active = 'active', + Inactive = 'inactive', +} + export enum LLMFactory { TongYiQianWen = 'Tongyi-Qianwen', Moonshot = 'Moonshot', @@ -140,6 +145,24 @@ export const IconMap = { [LLMFactory.Perplexity]: 'perplexity', }; +export const ModelTypeToField: Record = { + chat: 'llm_id', + embedding: 'embd_id', + image2text: 'img2txt_id', + speech2text: 'asr_id', + rerank: 'rerank_id', + tts: 'tts_id', +}; + +export const FieldToModelType: Record = { + llm_id: 'chat', + embd_id: 'embedding', + img2txt_id: 'vision', + asr_id: 'asr', + rerank_id: 'rerank', + tts_id: 'tts', +}; + export const APIMapUrl = { [LLMFactory.OpenAI]: 'https://platform.openai.com/api-keys', [LLMFactory.Anthropic]: 'https://console.anthropic.com/settings/keys', diff --git a/web/src/hooks/logic-hooks.ts b/web/src/hooks/logic-hooks.ts index bbd0220244..7a4e7509a1 100644 --- a/web/src/hooks/logic-hooks.ts +++ b/web/src/hooks/logic-hooks.ts @@ -29,7 +29,7 @@ import { import { v4 as uuid } from 'uuid'; import { useTranslate } from './common-hooks'; import { useSetPaginationParams } from './route-hook'; -import { useFetchTenantInfo, useSaveSetting } from './use-user-setting-request'; +import { useSaveSetting } from './use-user-setting-request'; export function usePrevious(value: T) { const ref = useRef(); @@ -748,12 +748,6 @@ export const useSelectItem = (defaultId?: string) => { return { selectedId, handleItemClick }; }; -export const useFetchModelId = () => { - const { data: tenantInfo } = useFetchTenantInfo(true); - - return tenantInfo?.llm_id ?? 
''; -}; - const ChunkTokenNumMap = { naive: 128, knowledge_graph: 8192, diff --git a/web/src/hooks/logic-hooks/navigate-hooks.ts b/web/src/hooks/logic-hooks/navigate-hooks.ts index 55f9e1a299..b8bca526f9 100644 --- a/web/src/hooks/logic-hooks/navigate-hooks.ts +++ b/web/src/hooks/logic-hooks/navigate-hooks.ts @@ -197,6 +197,10 @@ export const useNavigatePage = () => { [navigate], ); + const navigateToModelSetting = useCallback(() => { + navigate(`${Routes.UserSetting}${Routes.Model}`); + }, [navigate]); + return { navigateToDatasetList, navigateToDataset, @@ -223,5 +227,6 @@ export const useNavigatePage = () => { navigateToDataSourceDetail, navigateToMemory, navigateToMemoryList, + navigateToModelSetting, }; }; diff --git a/web/src/hooks/use-llm-request.tsx b/web/src/hooks/use-llm-request.tsx index c28479b04f..b7038f697a 100644 --- a/web/src/hooks/use-llm-request.tsx +++ b/web/src/hooks/use-llm-request.tsx @@ -1,197 +1,68 @@ -import { LlmIcon } from '@/components/svg-icon'; import message from '@/components/ui/message'; -import { LlmModelType } from '@/constants/knowledge'; -import { DefaultOptionType } from '@/interfaces/antd-compat'; -import { ResponseGetType } from '@/interfaces/database/base'; +import { ModelTypeToField } from '@/constants/llm'; import { - IFactory, + IAddedModel, + IAvailableProvider, + IDefaultModel, + IInstanceModel, IMyLlmValue, - IThirdOAIModelCollection as IThirdAiModelCollection, - IThirdOAIModel, - IThirdOAIModelCollection, + IProviderInstance, } from '@/interfaces/database/llm'; import { - IAddLlmRequestBody, - IDeleteLlmRequestBody, + IAddInstanceModelRequestBody, + IAddProviderInstanceRequestBody, + IAddProviderRequestBody, + IDeleteProviderInstanceRequestBody, + IListAllModelsRequestParams, + IListProvidersRequestParams, + ISetDefaultModelRequestBody, + IUpdateModelStatusRequestBody, } from '@/interfaces/request/llm'; -import userService from '@/services/user-service'; -import { getLLMIconName, getRealModelName } from 
'@/utils/llm-util'; +import llmService from '@/services/llm-service'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; -import { useCallback, useMemo } from 'react'; +import { useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { buildLlmUuid } from '@/utils/llm-util'; +import { buildModelValue, parseModelValue } from '@/utils/llm-util'; +import { useWarnEmptyModel } from './use-warn-empty-model'; export const enum LLMApiAction { - LlmList = 'llmList', - MyLlmList = 'myLlmList', - MyLlmListDetailed = 'myLlmListDetailed', - FactoryList = 'factoryList', - SaveApiKey = 'saveApiKey', - SaveTenantInfo = 'saveTenantInfo', - AddLlm = 'addLlm', - DeleteLlm = 'deleteLlm', - EnableLlm = 'enableLlm', - DeleteFactory = 'deleteFactory', + AllModels = 'allModels', + AvailableProviders = 'availableProviders', + AddedProviders = 'addedProviders', + AddProvider = 'addProvider', + AddProviderInstance = 'addProviderInstance', + AddInstanceModel = 'addInstanceModel', + DeleteProviderInstance = 'deleteProviderInstance', + ListDefaultModels = 'listDefaultModels', + SetDefaultModel = 'setDefaultModel', } -export const useFetchLlmList = (modelType?: LlmModelType) => { - const { data } = useQuery({ - queryKey: [LLMApiAction.LlmList], - initialData: {}, - queryFn: async () => { - const { data } = await userService.llmList({ model_type: modelType }); - - return data?.data ?? 
{}; - }, - }); - - return data; +export const LlmKeys = { + availableProviders: () => [LLMApiAction.AvailableProviders] as const, + addedProviders: () => [LLMApiAction.AddedProviders] as const, + allModels: (modelType?: string) => + [LLMApiAction.AllModels, modelType] as const, + providerInstances: (providerName: string) => + [LLMApiAction.AddedProviders, providerName, 'instances'] as const, + instanceModels: (providerName: string, instanceName: string) => + [ + LLMApiAction.AddedProviders, + providerName, + instanceName, + 'models', + ] as const, + defaultModels: () => [LLMApiAction.ListDefaultModels] as const, }; -type IThirdOAIModelWithUuid = IThirdOAIModel & { uuid: string }; - -export function useSelectFlatLlmList(modelType?: LlmModelType) { - const llmList = useFetchLlmList(modelType); - - return Object.values(llmList).reduce((pre, cur) => { - pre.push(...cur.map((x) => ({ ...x, uuid: buildLlmUuid(x) }))); - - return pre; - }, []); -} - -export function useFindLlmByUuid(modelType?: LlmModelType) { - const flatList = useSelectFlatLlmList(modelType); - - return (uuid: string) => { - return flatList.find((x) => x.uuid === uuid); - }; -} - -function buildLlmOptionsWithIcon(x: IThirdOAIModel) { - return { - label: ( -
- - {getRealModelName(x.llm_name)} -
- ), - value: `${x.llm_name}@${x.fid}`, - disabled: !x.available, - is_tools: x.is_tools, - }; -} - -export const useSelectLlmOptionsByModelType = () => { - const llmInfo: IThirdOAIModelCollection = useFetchLlmList(); - - const groupImage2TextOptions = useCallback(() => { - const modelType = LlmModelType.Image2text; - const modelTag = modelType.toUpperCase(); - return Object.entries(llmInfo) - .map(([key, value]) => { - return { - label: key, - options: value - .filter( - (x) => - (x.model_type.includes(modelType) || - (x.tags && x.tags.includes(modelTag))) && - x.available && - x.status === '1', - ) - .map(buildLlmOptionsWithIcon), - }; - }) - .filter((x) => x.options.length > 0); - }, [llmInfo]); - - const groupOptionsByModelType = useCallback( - (modelType: LlmModelType) => { - return Object.entries(llmInfo) - .filter(([, value]) => - modelType - ? value.some((x) => x.model_type.includes(modelType)) - : true, - ) - .map(([key, value]) => { - return { - label: key, - options: value - .filter( - (x) => - (modelType ? 
x.model_type.includes(modelType) : true) && - x.available, - ) - .map(buildLlmOptionsWithIcon), - }; - }) - .filter((x) => x.options.length > 0); - }, - [llmInfo], - ); - - return { - [LlmModelType.Chat]: groupOptionsByModelType(LlmModelType.Chat), - [LlmModelType.Embedding]: groupOptionsByModelType(LlmModelType.Embedding), - [LlmModelType.Image2text]: groupImage2TextOptions(), - [LlmModelType.Speech2text]: groupOptionsByModelType( - LlmModelType.Speech2text, - ), - [LlmModelType.Rerank]: groupOptionsByModelType(LlmModelType.Rerank), - [LlmModelType.TTS]: groupOptionsByModelType(LlmModelType.TTS), - [LlmModelType.Ocr]: groupOptionsByModelType(LlmModelType.Ocr), - }; -}; - -// Merge different types of models from the same manufacturer under one manufacturer -export const useComposeLlmOptionsByModelTypes = ( - modelTypes: LlmModelType[], -) => { - const allOptions = useSelectLlmOptionsByModelType(); - return modelTypes.reduce< - (DefaultOptionType & { - options: { - label: JSX.Element; - value: string; - disabled: boolean; - is_tools: boolean; - }[]; - })[] - >((pre, cur) => { - const options = allOptions[cur]; - options.forEach((x) => { - const item = pre.find((y) => y.label === x.label); - if (item) { - x.options.forEach((y) => { - // A model that is both an image2text and speech2text model - if (!item.options.some((z) => z.value === y.value)) { - item.options.push(y); - } - }); - } else { - pre.push(x); - } - }); - - return pre; - }, []); -}; - -export const useFetchLlmFactoryList = (): ResponseGetType => { - const { data, isFetching: loading } = useQuery({ - queryKey: [LLMApiAction.FactoryList], +export const useFetchAvailableProviders = () => { + const { data, isFetching: loading } = useQuery({ + queryKey: LlmKeys.availableProviders(), initialData: [], gcTime: 0, queryFn: async () => { - const { data } = await userService.factoriesList(); + const params: IListProvidersRequestParams = { available: true }; + const { data } = await llmService.listProviders({ params 
}, true); return data?.data ?? []; }, @@ -200,190 +71,198 @@ export const useFetchLlmFactoryList = (): ResponseGetType => { return { data, loading }; }; +export const useFetchAddedProviders = () => { + const { data, isFetching: loading } = useQuery({ + queryKey: LlmKeys.addedProviders(), + initialData: [], + gcTime: 0, + queryFn: async () => { + const { data } = await llmService.listProviders({ params: {} }, true); + + return data?.data ?? []; + }, + }); + + return { data, loading }; +}; + +export const useFetchAllAddedModels = (modelType?: string) => { + const { data, isFetching: loading } = useQuery({ + queryKey: LlmKeys.allModels(modelType), + initialData: [], + gcTime: 0, + queryFn: async () => { + const params: IListAllModelsRequestParams = {}; + if (modelType) { + params.type = modelType; + } + const { data } = await llmService.listAllAddedModels({ params }, true); + + return data?.data ?? []; + }, + }); + + return { data, loading }; +}; + +export function useFindLlmByUuid() { + const { data: models } = useFetchAllAddedModels(); + + return (uuid: string) => { + const parsed = parseModelValue(uuid); + if (parsed) { + return models.find( + (m) => + m.name === parsed.model_name && + m.instance_name === parsed.model_instance && + m.provider_name === parsed.model_provider, + ); + } + return undefined; + }; +} + +export const useFetchProviderInstances = (providerName: string) => { + const { data, isFetching: loading } = useQuery({ + queryKey: LlmKeys.providerInstances(providerName), + initialData: [], + gcTime: 0, + enabled: !!providerName, + queryFn: async () => { + const { data } = await llmService.listProviderInstances( + { provider_name: providerName }, + true, + ); + return data?.data ?? 
[]; + }, + }); + + return { data, loading }; +}; + +export const useFetchInstanceModels = ( + providerName: string, + instanceName: string, +) => { + const { data, isFetching: loading } = useQuery({ + queryKey: LlmKeys.instanceModels(providerName, instanceName), + initialData: [], + gcTime: 0, + enabled: !!providerName && !!instanceName, + queryFn: async () => { + const { data } = await llmService.listInstanceModels( + { provider_name: providerName, instance_name: instanceName }, + true, + ); + return data?.data ?? []; + }, + }); + + return { data, loading }; +}; + export type LlmItem = { name: string; logo: string } & IMyLlmValue; -export const useFetchMyLlmList = (): ResponseGetType< - Record -> => { - const { data, isFetching: loading } = useQuery({ - queryKey: [LLMApiAction.MyLlmList], - initialData: {}, - gcTime: 0, - queryFn: async () => { - const { data } = await userService.myLlm(); - - return data?.data ?? {}; - }, - }); - - return { data, loading }; -}; - -export const useFetchMyLlmListDetailed = (): ResponseGetType< - Record -> => { - const { data, isFetching: loading } = useQuery({ - queryKey: [LLMApiAction.MyLlmListDetailed], - initialData: {}, - gcTime: 0, - queryFn: async () => { - const { data } = await userService.myLlm({ include_details: true }); - - return data?.data ?? {}; - }, - }); - - return { data, loading }; -}; - -export const useSelectLlmList = () => { - const { data: myLlmList, loading: myLlmListLoading } = useFetchMyLlmList(); - const { data: factoryList, loading: factoryListLoading } = - useFetchLlmFactoryList(); - - const nextMyLlmList: Array = useMemo(() => { - return Object.entries(myLlmList).map(([key, value]) => ({ - name: key, - logo: factoryList.find((x) => x.name === key)?.logo ?? 
'', - ...value, - llm: value.llm?.map((x) => ({ ...x, name: x.name })), - })); - }, [myLlmList, factoryList]); - - const nextFactoryList = useMemo(() => { - const currentList = factoryList.filter((x) => - Object.keys(myLlmList).every((y) => y !== x.name), - ); - return currentList; - // return sortLLmFactoryListBySpecifiedOrder(currentList); - }, [factoryList, myLlmList]); - - return { - myLlmList: nextMyLlmList, - factoryList: nextFactoryList, - loading: myLlmListLoading || factoryListLoading, - }; -}; - -export interface IApiKeySavingParams { - llm_factory: string; - api_key: string; - llm_name?: string; - model_type?: string; - base_url?: string; - source_fid?: string; - verify?: boolean; -} - -export const useSaveApiKey = () => { - const queryClient = useQueryClient(); - // const { t } = useTranslation(); +export const useAddProvider = () => { const { data, isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [LLMApiAction.SaveApiKey], - mutationFn: async (params: IApiKeySavingParams) => { - const { data } = await userService.setApiKey(params); - if (data.code === 0) { - // message.success(t('message.modified')); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); - queryClient.invalidateQueries({ - queryKey: [LLMApiAction.MyLlmListDetailed], - }); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); + mutationKey: [LLMApiAction.AddProvider], + mutationFn: async (params: IAddProviderRequestBody) => { + try { + const { data: listRes } = await llmService.listProviders( + { params: {} }, + true, + ); + const isProviderAdded = listRes?.data?.some( + (p: IAvailableProvider) => p.name === params.provider_name, + ); + if (isProviderAdded) { + return { code: 0, data: null }; + } + } catch { + // ignore list failure and proceed to add } + const { data } = await llmService.addProvider(params); return data; }, }); - return { data, loading, saveApiKey: mutateAsync }; + return { data, loading, addProvider: mutateAsync 
}; }; -export interface ISystemModelSettingSavingParams { - tenant_id: string; - name?: string; - asr_id: string; - embd_id: string; - img2txt_id: string; - llm_id: string; -} - -export const useSaveTenantInfo = () => { - const { t } = useTranslation(); - const { - data, - isPending: loading, - mutateAsync, - } = useMutation({ - mutationKey: [LLMApiAction.SaveTenantInfo], - mutationFn: async (params: ISystemModelSettingSavingParams) => { - const { data } = await userService.setTenantInfo(params); - if (data.code === 0) { - message.success(t('message.modified')); - } - return data.code; - }, - }); - - return { data, loading, saveTenantInfo: mutateAsync }; -}; - -export const useAddLlm = () => { +export const useAddProviderInstance = () => { + const { addProvider } = useAddProvider(); const queryClient = useQueryClient(); - // const { t } = useTranslation(); const { data, isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [LLMApiAction.AddLlm], - mutationFn: async (params: IAddLlmRequestBody & { verify?: boolean }) => { - const { data } = await userService.addLlm(params); + mutationKey: [LLMApiAction.AddProviderInstance], + mutationFn: async ( + params: IAddProviderInstanceRequestBody & { verify?: boolean }, + ) => { + try { + await addProvider({ provider_name: params.llm_factory }); + + const { data: instancesRes } = await llmService.listProviderInstances( + { provider_name: params.llm_factory }, + true, + ); + const instanceExists = instancesRes?.data?.some( + (i: IProviderInstance) => i.instance_name === params.instance_name, + ); + if (instanceExists && !params.verify) { + return { code: 0, data: null }; + } + } catch { + // ignore list failure and proceed to add + } + + const { data } = await llmService.addProviderInstance(params); if (data.code === 0 && !params.verify) { - queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); queryClient.invalidateQueries({ - queryKey: [LLMApiAction.MyLlmListDetailed], + queryKey: 
LlmKeys.addedProviders(), }); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.LlmList] }); - // message.success(t('message.modified')); } return data; }, }); - return { data, loading, addLlm: mutateAsync }; + return { data, loading, addProviderInstance: mutateAsync }; }; -export const useDeleteLlm = () => { +export const useAddInstanceModel = () => { const queryClient = useQueryClient(); - const { t } = useTranslation(); const { data, isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [LLMApiAction.DeleteLlm], - mutationFn: async (params: IDeleteLlmRequestBody) => { - const { data } = await userService.deleteLlm(params); + mutationKey: [LLMApiAction.AddInstanceModel], + mutationFn: async ( + params: { + provider_name: string; + instance_name: string; + } & IAddInstanceModelRequestBody, + ) => { + const { data } = await llmService.addInstanceModel(params); if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); queryClient.invalidateQueries({ - queryKey: [LLMApiAction.MyLlmListDetailed], + queryKey: LlmKeys.addedProviders(), }); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); - message.success(t('message.deleted')); } - return data.code; + return data; }, }); - return { data, loading, deleteLlm: mutateAsync }; + return { data, loading, addInstanceModel: mutateAsync }; }; -export const useEnableLlm = () => { +export const useDeleteProviderInstance = () => { const queryClient = useQueryClient(); const { t } = useTranslation(); const { @@ -391,52 +270,104 @@ export const useEnableLlm = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [LLMApiAction.EnableLlm], - mutationFn: async (params: IDeleteLlmRequestBody & { enable: boolean }) => { - const reqParam: IDeleteLlmRequestBody & { - enable?: boolean; - status?: 1 | 0; - } = { ...params, status: params.enable ? 
1 : 0 }; - delete reqParam.enable; - const { data } = await userService.enableLlm(reqParam); + mutationKey: [LLMApiAction.DeleteProviderInstance], + mutationFn: async (params: IDeleteProviderInstanceRequestBody) => { + const { data } = await llmService.deleteProviderInstance(params); if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); queryClient.invalidateQueries({ - queryKey: [LLMApiAction.MyLlmListDetailed], + queryKey: LlmKeys.addedProviders(), + exact: true, }); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); + queryClient.invalidateQueries({ + queryKey: LlmKeys.providerInstances(params.provider_name), + }); + queryClient.invalidateQueries({ + queryKey: LlmKeys.defaultModels(), + }); + + message.success(t('message.deleted')); + } + return data; + }, + }); + + return { data, loading, deleteProviderInstance: mutateAsync }; +}; + +export const useUpdateModelStatus = () => { + const queryClient = useQueryClient(); + const { t } = useTranslation(); + const { isPending: loading, mutateAsync } = useMutation({ + mutationKey: [LLMApiAction.AddedProviders, 'updateModelStatus'], + mutationFn: async (params: IUpdateModelStatusRequestBody) => { + const { data } = await llmService.updateModelStatus(params); + if (data.code === 0) { message.success(t('message.modified')); + queryClient.invalidateQueries({ + queryKey: LlmKeys.defaultModels(), + }); + queryClient.invalidateQueries({ + queryKey: LlmKeys.instanceModels( + params.provider_name, + params.instance_name, + ), + }); } - return data.code; + return data; }, }); - return { data, loading, enableLlm: mutateAsync }; + return { loading, updateModelStatus: mutateAsync }; }; -export const useDeleteFactory = () => { +export const useFetchDefaultModels = () => { + const { data, isFetching: loading } = useQuery({ + queryKey: LlmKeys.defaultModels(), + initialData: [], + gcTime: 0, + queryFn: async () => { + const { data } = await 
llmService.listDefaultModels({}, true); + return data?.data?.models ?? []; + }, + }); + + return { data, loading }; +}; + +export const useFetchDefaultModelDictionary = (showEmptyModelWarn = false) => { + const { data: defaultModels } = useFetchDefaultModels(); + + const result = useMemo(() => { + const dict: Record = {}; + Object.entries(ModelTypeToField).forEach(([key, field]) => { + const model = defaultModels.find((m) => m.model_type === key); + dict[field] = model && model.enable ? buildModelValue(model) : ''; + }); + return dict; + }, [defaultModels]); + + useWarnEmptyModel(showEmptyModelWarn, result.embd_id, result.llm_id); + + return result; +}; + +export const useSetDefaultModel = () => { const queryClient = useQueryClient(); const { t } = useTranslation(); - const { - data, - isPending: loading, - mutateAsync, - } = useMutation({ - mutationKey: [LLMApiAction.DeleteFactory], - mutationFn: async (params: IDeleteLlmRequestBody) => { - const { data } = await userService.deleteFactory(params); + + const { isPending: loading, mutateAsync } = useMutation({ + mutationKey: [LLMApiAction.SetDefaultModel], + mutationFn: async (params: ISetDefaultModelRequestBody) => { + const { data } = await llmService.setDefaultModel(params); if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); + message.success(t('message.modified')); queryClient.invalidateQueries({ - queryKey: [LLMApiAction.MyLlmListDetailed], + queryKey: LlmKeys.defaultModels(), }); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); - queryClient.invalidateQueries({ queryKey: [LLMApiAction.LlmList] }); - message.success(t('message.deleted')); } - return data.code; + return data; }, }); - return { data, loading, deleteFactory: mutateAsync }; + return { loading, setDefaultModel: mutateAsync }; }; diff --git a/web/src/hooks/use-user-setting-request.tsx b/web/src/hooks/use-user-setting-request.tsx index b7d6eb6bca..6bb272ebd4 100644 --- 
a/web/src/hooks/use-user-setting-request.tsx +++ b/web/src/hooks/use-user-setting-request.tsx @@ -1,5 +1,4 @@ import message from '@/components/ui/message'; -import { Modal } from '@/components/ui/modal/modal'; import { ResponseGetType } from '@/interfaces/database/base'; import { IToken } from '@/interfaces/database/chat'; import { ITenantInfo } from '@/interfaces/database/dataset'; @@ -11,7 +10,6 @@ import { } from '@/interfaces/database/user-setting'; import { ISetLangfuseConfigRequestBody } from '@/interfaces/request/system'; import { DEFAULT_LANGUAGE_CODE, supportedLanguages } from '@/locales/config'; -import { Routes } from '@/routes'; import userService, { addTenantUser, agreeTenant, @@ -20,11 +18,10 @@ import userService, { listTenantUser, } from '@/services/user-service'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; -import DOMPurify from 'dompurify'; -import { isEmpty } from 'lodash'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useNavigate } from 'react-router'; + +import { useWarnEmptyModel } from './use-warn-empty-model'; export const enum UserSettingApiAction { UserInfo = 'userInfo', @@ -69,11 +66,10 @@ export const useFetchUserInfo = (): ResponseGetType => { return { data, loading }; }; +// Stop using this interface to retrieve the default model; instead, directly call `useFetchDefaultModelDictionary`. export const useFetchTenantInfo = ( showEmptyModelWarn = false, ): ResponseGetType => { - const { t } = useTranslation(); - const navigate = useNavigate(); const { data, isFetching: loading } = useQuery({ queryKey: [UserSettingApiAction.TenantInfo, showEmptyModelWarn], initialData: {}, @@ -84,27 +80,6 @@ export const useFetchTenantInfo = ( // llm_id is chat_id // asr_id is speech2txt const { data } = res; - if ( - showEmptyModelWarn && - (isEmpty(data.embd_id) || isEmpty(data.llm_id)) - ) { - Modal.warning({ - title: t('common.warn'), - content: ( -
- ), - closable: false, - showCancel: false, - onOk() { - // window.open('/user-setting/model', '_self'); - navigate(`${Routes.UserSetting}${Routes.Model}`); - }, - }); - } data.chat_id = data.llm_id; data.speech2text_id = data.asr_id; @@ -115,6 +90,8 @@ export const useFetchTenantInfo = ( }, }); + useWarnEmptyModel(showEmptyModelWarn, data?.embd_id, data?.llm_id); + return { data, loading }; }; diff --git a/web/src/hooks/use-warn-empty-model.tsx b/web/src/hooks/use-warn-empty-model.tsx new file mode 100644 index 0000000000..51cf606d39 --- /dev/null +++ b/web/src/hooks/use-warn-empty-model.tsx @@ -0,0 +1,43 @@ +import { Modal } from '@/components/ui/modal/modal'; +import DOMPurify from 'dompurify'; +import { isEmpty } from 'lodash'; +import { useEffect, useRef } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useNavigatePage } from './logic-hooks/navigate-hooks'; + +export const useWarnEmptyModel = ( + showEmptyModelWarn: boolean, + embdId?: string, + llmId?: string, +) => { + const { t } = useTranslation(); + const warnedRef = useRef(false); + const { navigateToModelSetting } = useNavigatePage(); + + useEffect(() => { + if ( + showEmptyModelWarn && + !warnedRef.current && + (isEmpty(embdId) || isEmpty(llmId)) && + typeof embdId === 'string' && + typeof llmId === 'string' + ) { + warnedRef.current = true; + Modal.warning({ + title: t('common.warn'), + content: ( +
+ ), + closable: false, + showCancel: false, + onOk() { + navigateToModelSetting(); + }, + }); + } + }, [showEmptyModelWarn, embdId, llmId, navigateToModelSetting, t]); +}; diff --git a/web/src/interfaces/database/llm.ts b/web/src/interfaces/database/llm.ts index e97cac609a..98be63eb2c 100644 --- a/web/src/interfaces/database/llm.ts +++ b/web/src/interfaces/database/llm.ts @@ -40,3 +40,41 @@ export interface Llm { status: '0' | '1'; used_token: number; } + +export interface IAvailableProvider { + name: string; + model_types: string[]; + url: { default?: string; [key: string]: string | undefined }; +} + +export interface IProviderInstance { + api_key: string; + id: string; + instance_name: string; + provider_id: string; + region: string; + status: string; +} +export interface IAddedModel { + model_type: string[]; + name: string; + provider_id: string; + provider_name: string; + instance_id: string; + instance_name: string; +} + +export interface IInstanceModel { + max_tokens: number; + model_type: string[]; + name: string; + status: string; +} + +export interface IDefaultModel { + enable: boolean; + model_instance: string; + model_name: string; + model_provider: string; + model_type: string; +} diff --git a/web/src/interfaces/request/llm.ts b/web/src/interfaces/request/llm.ts index f8690784e5..932a5b8ef2 100644 --- a/web/src/interfaces/request/llm.ts +++ b/web/src/interfaces/request/llm.ts @@ -1,7 +1,7 @@ export interface IAddLlmRequestBody { llm_factory: string; // Ollama llm_name: string; - model_type: string; + model_type: string | string[]; api_base?: string; // chat|embedding|speech2text|image2text api_key?: string | Record; max_tokens: number; @@ -12,3 +12,50 @@ export interface IDeleteLlmRequestBody { llm_factory: string; // Ollama llm_name?: string; } + +export interface IListProvidersRequestParams { + available?: boolean; +} + +export interface IAddProviderRequestBody { + provider_name: string; +} + +export type IAddProviderInstanceRequestBody = 
IAddLlmRequestBody & { + instance_name: string; +}; + +export interface IDeleteProviderInstanceRequestBody { + provider_name: string; + instances: string[]; +} + +export interface IShowProviderInstanceRequestParams { + provider_name: string; + instance_name: string; +} + +export interface IAddInstanceModelRequestBody { + model_name: string; + model_type: string[]; + max_tokens: number; + extra?: Record; +} + +export interface IListAllModelsRequestParams { + type?: string; +} + +export interface IUpdateModelStatusRequestBody { + provider_name: string; + instance_name: string; + model_name: string; + status: 'active' | 'inactive'; +} + +export interface ISetDefaultModelRequestBody { + model_provider: string; + model_instance: string; + model_type: string; + model_name: string; +} diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index d0efc988aa..82649c7309 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1514,6 +1514,10 @@ Example: Virtual Hosted Style`, addLlmTitle: 'Add LLM', editLlmTitle: 'Edit {{name}} model', editModel: 'Edit model', + instanceName: 'Instance name', + instanceNameMessage: 'Please input the instance name!', + instanceNameTip: + 'A unique name to identify this provider instance under the same factory.', modelName: 'Model name', modelID: 'Model ID', modelUid: 'Model UID', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 473781af87..454bece5da 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -1099,7 +1099,8 @@ NER:使用 spaCy NER 和基于规则的关键词提取来抽取实体和关系 gmailTokenTip: '请上传由 Google Console 生成的 OAuth JSON。如果仅包含 client credentials,请通过浏览器授权一次以获取长期有效的刷新 Token。', dropboxDescription: '连接 Dropbox,同步指定账号下的文件与文件夹。', - teamsDescription: '通过 Microsoft Graph 连接 Microsoft Teams,同步频道帖子与回复。', + teamsDescription: + '通过 Microsoft Graph 连接 Microsoft Teams,同步频道帖子与回复。', teamsTenantIdTip: 'Azure AD 租户 ID。需要具备 Team.ReadBasic.All 与 ChannelMessage.Read.All 应用权限(管理员同意)的应用。', slackDescription: '连接你的 Slack 工作区,同步频道消息与讨论串。', 
@@ -1107,7 +1108,8 @@ NER:使用 spaCy NER 和基于规则的关键词提取来抽取实体和关系 'Slack 机器人用户 OAuth Token(以 xoxb- 开头)。应用需具备 channels:read、channels:history 和 users:read 权限。', slackChannelsTip: '可选:需要同步的频道名称(例如 general)。留空则同步所有可访问的频道。', - sharepointDescription: '通过 Microsoft Graph 连接 SharePoint 站点,同步其文档库。', + sharepointDescription: + '通过 Microsoft Graph 连接 SharePoint 站点,同步其文档库。', sharepointSiteUrlTip: '要索引的 SharePoint 站点完整 URL,例如 https://contoso.sharepoint.com/sites/MySite。需要具备 Sites.Read.All 与 Files.Read.All 应用权限(管理员同意)的 Azure AD 应用。', boxDescription: '连接你的 Box 云盘以同步文件和文件夹。', @@ -1238,6 +1240,9 @@ NER:使用 spaCy NER 和基于规则的关键词提取来抽取实体和关系 addLlmTitle: '添加 LLM', editLlmTitle: '编辑 {{name}} 模型', editModel: '编辑模型', + instanceName: '实例名称', + instanceNameMessage: '请输入实例名称!', + instanceNameTip: '用于在同一厂商下唯一标识该实例的名称。', modelName: '模型名称', modelID: '模型ID', modelUid: '模型UID', diff --git a/web/src/pages/agent/canvas/node/card.tsx b/web/src/pages/agent/canvas/node/card.tsx index 82e324d1a0..a11d33446b 100644 --- a/web/src/pages/agent/canvas/node/card.tsx +++ b/web/src/pages/agent/canvas/node/card.tsx @@ -17,8 +17,9 @@ import { SelectTrigger, SelectValue, } from '@/components/ui/select'; -import { useSelectFlatLlmList } from '@/hooks/use-llm-request'; +import { useFetchAllAddedModels } from '@/hooks/use-llm-request'; import { cn } from '@/lib/utils'; +import { parseModelValue } from '@/utils/llm-util'; import { PropsWithChildren, useMemo } from 'react'; export function CardWithForm() { @@ -80,12 +81,23 @@ export function LabelCard({ children, className, ...props }: LabelCardProps) { } export function LLMLabelCard({ llmId }: { llmId?: string }) { - const flatLlmList = useSelectFlatLlmList(); + const { data: allAddedModels } = useFetchAllAddedModels(); const isValidLlm = useMemo(() => { if (!llmId) return false; - return flatLlmList.some((llm) => llm.uuid === llmId); - }, [flatLlmList, llmId]); + + const parsed = parseModelValue(llmId); + if (parsed) { + return allAddedModels.some( + (m) => + m.name === 
parsed.model_name && + m.instance_name === parsed.model_instance && + m.provider_name === parsed.model_provider, + ); + } + + return false; + }, [allAddedModels, llmId]); return ( {isSubAgent && } - {findLlmByUuid(llmId)?.tags?.includes('IMAGE2TEXT') && ( + {findLlmByUuid(llmId)?.model_type?.includes( + LlmModelType.Image2text, + ) && ( ) { } export function useValues(node?: RAGFlowNodeType) { - const llmId = useFetchModelId(); + const defaultModelDictionary = useFetchDefaultModelDictionary(); const defaultValues = useMemo( () => ({ ...omitToolsAndMcp(initialAgentValues), - llm_id: llmId, + llm_id: defaultModelDictionary.llm_id, prompts: '', }), - [llmId], + [defaultModelDictionary], ); const values = useMemo(() => { diff --git a/web/src/pages/agent/form/parser-form/common-form-fields.tsx b/web/src/pages/agent/form/parser-form/common-form-fields.tsx index c857f4733f..c51cd4d837 100644 --- a/web/src/pages/agent/form/parser-form/common-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/common-form-fields.tsx @@ -1,9 +1,5 @@ import { crossLanguageOptions } from '@/components/cross-language-form-field'; import { LayoutRecognizeFormField } from '@/components/layout-recognize-form-field'; -import { - LLMFormField, - LLMFormFieldProps, -} from '@/components/llm-setting-items/llm-form-field'; import { SelectWithSearch, SelectWithSearchFlagOptionType, @@ -75,19 +71,6 @@ export function ParserMethodFormField({ ); } -export function LargeModelFormField({ - prefix, - options, -}: CommonProps & Pick) { - return ( - - ); -} - export function FlattenMediaToTextFormField({ prefix }: CommonProps) { const { t } = useTranslation(); return ( diff --git a/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx b/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx index f00930d24a..fd60be7cc0 100644 --- a/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx @@ -1,11 +1,13 @@ import { ParseDocumentType } 
from '@/components/layout-recognize-form-field'; +import { + ModelTreeSelectFormField, + ModelTypeMap, +} from '@/components/model-tree-select'; import { SelectWithSearch, SelectWithSearchFlagOptionType, } from '@/components/originui/select-with-search'; import { RAGFlowFormItem } from '@/components/ragflow-form'; -import { LlmModelType } from '@/constants/knowledge'; -import { useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; import { isEmpty } from 'lodash'; import { useEffect, useMemo } from 'react'; import { useFormContext, useWatch } from 'react-hook-form'; @@ -13,7 +15,6 @@ import { useTranslation } from 'react-i18next'; import { FlattenMediaToTextFormField, LanguageFormField, - LargeModelFormField, ParserMethodFormField, RemoveHeaderFooterFormField, RmdirFormField, @@ -38,9 +39,6 @@ export function PdfFormFields({ prefix }: CommonProps) { const form = useFormContext(); const parseMethodName = buildFieldNameWithPrefix('parse_method', prefix); - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Image2text, - ]); const parseMethod = useWatch({ name: parseMethodName, }); @@ -109,10 +107,12 @@ export function PdfFormFields({ prefix }: CommonProps) { {!flattenMediaToText && ( - + )} {languageShown && } {tcadpOptionsShown && ( diff --git a/web/src/pages/agent/form/parser-form/spreadsheet-form-fields.tsx b/web/src/pages/agent/form/parser-form/spreadsheet-form-fields.tsx index d7566c3b05..f257a25f51 100644 --- a/web/src/pages/agent/form/parser-form/spreadsheet-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/spreadsheet-form-fields.tsx @@ -1,18 +1,19 @@ import { ParseDocumentType } from '@/components/layout-recognize-form-field'; +import { + ModelTreeSelectFormField, + ModelTypeMap, +} from '@/components/model-tree-select'; import { SelectWithSearch, SelectWithSearchFlagOptionType, } from '@/components/originui/select-with-search'; import { RAGFlowFormItem } from '@/components/ragflow-form'; -import { LlmModelType } 
from '@/constants/knowledge'; -import { useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; import { isEmpty } from 'lodash'; import { useEffect, useMemo } from 'react'; import { useFormContext, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { FlattenMediaToTextFormField, - LargeModelFormField, ParserMethodFormField, } from './common-form-fields'; import { CommonProps } from './interface'; @@ -31,9 +32,6 @@ const markdownImageResponseTypeOptions: SelectWithSearchFlagOptionType[] = [ export function SpreadsheetFormFields({ prefix }: CommonProps) { const { t } = useTranslation(); const form = useFormContext(); - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Image2text, - ]); const parseMethodName = buildFieldNameWithPrefix('parse_method', prefix); @@ -103,10 +101,12 @@ export function SpreadsheetFormFields({ prefix }: CommonProps) { > {!flattenMediaToText && ( - + )} {tcadpOptionsShown && ( <> diff --git a/web/src/pages/agent/form/parser-form/text-html-form-fields.tsx b/web/src/pages/agent/form/parser-form/text-html-form-fields.tsx index 1bd500466f..76e477c74c 100644 --- a/web/src/pages/agent/form/parser-form/text-html-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/text-html-form-fields.tsx @@ -1,9 +1,11 @@ -import { LlmModelType } from '@/constants/knowledge'; -import { useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; +import { + ModelTreeSelectFormField, + ModelTypeMap, +} from '@/components/model-tree-select'; import { useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; import { FlattenMediaToTextFormField, - LargeModelFormField, RemoveHeaderFooterFormField, RmdirFormField, } from './common-form-fields'; @@ -11,9 +13,7 @@ import { CommonProps } from './interface'; import { buildFieldNameWithPrefix } from './utils'; export function TextMarkdownFormFields({ prefix }: CommonProps) { - const modelOptions = 
useComposeLlmOptionsByModelTypes([ - LlmModelType.Image2text, - ]); + const { t } = useTranslation(); const flattenMediaToText = useWatch({ name: buildFieldNameWithPrefix('flatten_media_to_text', prefix), }); @@ -23,10 +23,12 @@ export function TextMarkdownFormFields({ prefix }: CommonProps) { {!flattenMediaToText && ( - + )} ); diff --git a/web/src/pages/agent/form/parser-form/video-form-fields.tsx b/web/src/pages/agent/form/parser-form/video-form-fields.tsx index 3f37a9864e..b24e703e7c 100644 --- a/web/src/pages/agent/form/parser-form/video-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/video-form-fields.tsx @@ -1,38 +1,39 @@ -import { LlmModelType } from '@/constants/knowledge'; -import { useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; import { - LargeModelFormField, - OutputFormatFormFieldProps, -} from './common-form-fields'; + ModelTreeSelectFormField, + ModelTypeMap, +} from '@/components/model-tree-select'; +import { useTranslation } from 'react-i18next'; +import { OutputFormatFormFieldProps } from './common-form-fields'; +import { buildFieldNameWithPrefix } from './utils'; export function AudioFormFields({ prefix }: OutputFormatFormFieldProps) { - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Speech2text, - ]); + const { t } = useTranslation(); return ( <> {/* Multimodal Model */} - + ); } export function VideoFormFields({ prefix }: OutputFormatFormFieldProps) { - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Image2text, - ]); + const { t } = useTranslation(); return ( <> {/* Multimodal Model */} - + ); } diff --git a/web/src/pages/agent/form/parser-form/word-form-fields.tsx b/web/src/pages/agent/form/parser-form/word-form-fields.tsx index a6bd900305..41de221691 100644 --- a/web/src/pages/agent/form/parser-form/word-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/word-form-fields.tsx @@ -1,9 +1,11 @@ -import { LlmModelType } from '@/constants/knowledge'; -import { 
useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; +import { + ModelTreeSelectFormField, + ModelTypeMap, +} from '@/components/model-tree-select'; import { useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; import { FlattenMediaToTextFormField, - LargeModelFormField, OutputFormatFormFieldProps, RemoveHeaderFooterFormField, RmdirFormField, @@ -11,9 +13,7 @@ import { import { buildFieldNameWithPrefix } from './utils'; export function WordFormFields({ prefix }: OutputFormatFormFieldProps) { - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Image2text, - ]); + const { t } = useTranslation(); const flattenMediaToText = useWatch({ name: buildFieldNameWithPrefix('flatten_media_to_text', prefix), }); @@ -24,10 +24,12 @@ export function WordFormFields({ prefix }: OutputFormatFormFieldProps) { {!flattenMediaToText && ( - + )} ); diff --git a/web/src/pages/agent/form/rewrite-question-form/index.tsx b/web/src/pages/agent/form/rewrite-question-form/index.tsx index 4c94661764..bf55c9310b 100644 --- a/web/src/pages/agent/form/rewrite-question-form/index.tsx +++ b/web/src/pages/agent/form/rewrite-question-form/index.tsx @@ -1,5 +1,5 @@ -import { NextLLMSelect } from '@/components/llm-select/next'; import { MessageHistoryWindowSizeFormField } from '@/components/message-history-window-size-item'; +import { ModelTreeSelectFormField } from '@/components/model-tree-select'; import { Form, FormControl, @@ -24,20 +24,10 @@ const RewriteQuestionForm = ({ form }: INextOperatorForm) => { e.preventDefault(); }} > - ( - - - {t('chat.model')} - - - - - - - )} + label={t('chat.model')} + tooltip={t('chat.modelTip')} /> { - const llmId = useFetchModelId(); + const defaultModelDictionary = useFetchDefaultModelDictionary(); + const llmId = defaultModelDictionary.llm_id; const initialFormValuesMap = useMemo(() => { return { [Operator.Begin]: initialBeginValues, [Operator.Retrieval]: initialRetrievalValues, - 
[Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId }, + [Operator.Categorize]: { + ...initialCategorizeValues, + llm_id: llmId, + }, [Operator.RewriteQuestion]: { ...initialRewriteQuestionValues, llm_id: llmId, diff --git a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx index 1115a547db..f3409d4d16 100644 --- a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx +++ b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx @@ -3,10 +3,8 @@ import { FormFieldType, RenderField, } from '@/components/dynamic-form'; -import { - SelectWithSearch, - SelectWithSearchFlagOptionType, -} from '@/components/originui/select-with-search'; +import { ModelTreeSelect, ModelTypeMap } from '@/components/model-tree-select'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; import { SliderInputFormField } from '@/components/slider-input-form-field'; import { Button } from '@/components/ui/button'; import { @@ -19,9 +17,8 @@ import { import { Radio } from '@/components/ui/radio'; import { Spin } from '@/components/ui/spin'; import { Switch } from '@/components/ui/switch'; -import { LlmModelType, ParseType } from '@/constants/knowledge'; +import { ParseType } from '@/constants/knowledge'; import { useTranslate } from '@/hooks/common-hooks'; -import { useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; import { cn } from '@/lib/utils'; import { history } from '@/utils/simple-history-util'; import { t } from 'i18next'; @@ -51,7 +48,6 @@ import { useHandleKbEmbedding, useHasParsedDocument, useSelectChunkMethodList, - useSelectEmbeddingModelOptions, } from '../hooks'; interface IProps { line?: 1 | 2; @@ -117,13 +113,12 @@ export const EmbeddingSelect = ({ }) => { const { t } = useTranslate('knowledgeConfiguration'); const form = useFormContext(); - const embeddingModelOptions = useSelectEmbeddingModelOptions(); const 
{ handleChange } = useHandleKbEmbedding(); const oldValue = useMemo(() => { const embdStr = form.getValues(name || 'embedding_model'); return embdStr || ''; - }, [form]); + }, [form, name]); const [loading, setLoading] = useState(false); return ( - { field.onChange(value); if (isEdit && disabled) { setLoading(true); const res = await handleChange({ embed_id: value, - // callback: field.onChange, }); if (res.code !== 0) { field.onChange(oldValue); @@ -149,7 +144,6 @@ export const EmbeddingSelect = ({ }} disabled={disabled && !isEdit} value={field.value} - options={embeddingModelOptions} placeholder={t('embeddingModelPlaceholder')} testId={testId} /> @@ -544,18 +538,14 @@ export const LLMSelect = ({ disabled?: boolean; }) => { const { t } = useTranslate('knowledgeConfiguration'); - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Chat, - LlmModelType.Image2text, - ]); return ( - { + { field.onChange(value); }} disabled={disabled && !isEdit} value={field.value} - options={modelOptions as SelectWithSearchFlagOptionType[]} placeholder={t('embeddingModelPlaceholder')} /> ); diff --git a/web/src/pages/dataset/dataset-setting/hooks.ts b/web/src/pages/dataset/dataset-setting/hooks.ts index a276b6dceb..3389470d0f 100644 --- a/web/src/pages/dataset/dataset-setting/hooks.ts +++ b/web/src/pages/dataset/dataset-setting/hooks.ts @@ -1,8 +1,6 @@ -import { LlmModelType } from '@/constants/knowledge'; import { useSetModalState } from '@/hooks/common-hooks'; import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request'; -import { useSelectLlmOptionsByModelType } from '@/hooks/use-llm-request'; import { useSelectParserList } from '@/hooks/use-user-setting-request'; import { checkEmbedding } from '@/services/knowledge-service'; import { useIsFetching } from '@tanstack/react-query'; @@ -22,11 +20,6 @@ export function useSelectChunkMethodList() { return parserList.filter((x) => !HiddenFields.some((y) => y === x.value)); } -export function 
useSelectEmbeddingModelOptions() { - const allOptions = useSelectLlmOptionsByModelType(); - return allOptions[LlmModelType.Embedding]; -} - export function useHasParsedDocument(isEdit?: boolean) { const { data: knowledgeDetails } = useFetchKnowledgeBaseConfiguration({ isEdit, diff --git a/web/src/pages/datasets/dataset-creating-dialog.tsx b/web/src/pages/datasets/dataset-creating-dialog.tsx index 9f3a83d6be..f460f757db 100644 --- a/web/src/pages/datasets/dataset-creating-dialog.tsx +++ b/web/src/pages/datasets/dataset-creating-dialog.tsx @@ -19,7 +19,7 @@ import { import { Input } from '@/components/ui/input'; import { FormLayout } from '@/constants/form'; import { ParseType } from '@/constants/knowledge'; -import { useFetchTenantInfo } from '@/hooks/use-user-setting-request'; +import { useFetchDefaultModelDictionary } from '@/hooks/use-llm-request'; import { IModalProps } from '@/interfaces/common'; import { zodResolver } from '@hookform/resolvers/zod'; import { omit } from 'lodash'; @@ -39,7 +39,7 @@ const ChunkMethodName = 'chunk_method'; export function InputForm({ onOk }: IModalProps) { const { t } = useTranslation(); - const { data: tenantInfo } = useFetchTenantInfo(); + const defaultModelDictionary = useFetchDefaultModelDictionary(); const FormSchema = z .object({ @@ -87,7 +87,7 @@ export function InputForm({ onOk }: IModalProps) { name: '', parseType: ParseType.BuiltIn, [ChunkMethodName]: '', - embedding_model: tenantInfo?.embd_id, + embedding_model: defaultModelDictionary?.embd_id, }, }); diff --git a/web/src/pages/memories/hooks.ts b/web/src/pages/memories/hooks.ts index f2bd4b64d5..5a4c86cc16 100644 --- a/web/src/pages/memories/hooks.ts +++ b/web/src/pages/memories/hooks.ts @@ -5,7 +5,7 @@ import { useHandleFilterSubmit } from '@/components/list-filter-bar/use-handle-f import message from '@/components/ui/message'; import { useSetModalState } from '@/hooks/common-hooks'; import { useHandleSearchChange } from '@/hooks/logic-hooks'; -import { 
useFetchTenantInfo } from '@/hooks/use-user-setting-request'; +import { useFetchDefaultModelDictionary } from '@/hooks/use-llm-request'; import memoryService, { updateMemoryById } from '@/services/memory-service'; import { buildOwnersFilter, @@ -243,13 +243,13 @@ export const useRenameMemory = () => { const { updateMemory } = useUpdateMemory(); const { createMemory } = useCreateMemory(); const [loading, setLoading] = useState(false); - const { data: tenantInfo } = useFetchTenantInfo(); + const defaultModelDictionary = useFetchDefaultModelDictionary(); const handleShowChatRenameModal = useCallback( (record?: IMemory) => { if (record) { - const embd_id = record.embd_id || tenantInfo?.embd_id; - const llm_id = record.llm_id || tenantInfo?.llm_id; + const embd_id = record.embd_id || defaultModelDictionary?.embd_id; + const llm_id = record.llm_id || defaultModelDictionary?.llm_id; setMemory({ ...record, embd_id, @@ -258,7 +258,7 @@ export const useRenameMemory = () => { } showChatRenameModal(); }, - [showChatRenameModal, tenantInfo], + [showChatRenameModal, defaultModelDictionary], ); const handleHideModal = useCallback(() => { diff --git a/web/src/pages/memory/memory-setting/memory-model-form.tsx b/web/src/pages/memory/memory-setting/memory-model-form.tsx index 70427504c1..455c4f8b34 100644 --- a/web/src/pages/memory/memory-setting/memory-model-form.tsx +++ b/web/src/pages/memory/memory-setting/memory-model-form.tsx @@ -1,5 +1,5 @@ import { FormFieldType, RenderField } from '@/components/dynamic-form'; -import { useModelOptions } from '@/components/llm-setting-items/llm-form-field'; +import { ModelTreeSelect } from '@/components/model-tree-select'; import { EmbeddingSelect } from '@/pages/dataset/dataset-setting/configuration/common-item'; import { MemoryOptions, MemoryType } from '@/pages/memories/constants'; import { TFunction } from 'i18next'; @@ -28,7 +28,6 @@ export const defaultMemoryModelForm = { memory_size: 0, }; export const MemoryModelForm = () => { - const 
{ modelOptions } = useModelOptions(); const { t } = useTranslation(); const { data } = useFetchMemoryMessageList(); return ( @@ -40,7 +39,6 @@ export const MemoryModelForm = () => { placeholder: t('memories.selectModel'), required: true, horizontal: true, - // hideLabel: true, type: FormFieldType.Custom, disabled: true, render: (field) => ( @@ -58,12 +56,17 @@ export const MemoryModelForm = () => { field={{ name: 'llm_id', label: t('memories.llm'), - placeholder: t('memories.selectModel'), required: true, horizontal: true, - type: FormFieldType.Select, + type: FormFieldType.Custom, disabled: data?.messages?.total_count > 0, - options: modelOptions as { value: string; label: string }[], + render: (field) => ( + + ), tooltip: t('memories.llmTooltip'), }} /> @@ -93,7 +96,6 @@ export const MemoryModelForm = () => { type: FormFieldType.Number, horizontal: true, tooltip: t('memory.config.memorySizeTooltip'), - // placeholder: t('memory.config.memorySizePlaceholder'), required: false, }} /> diff --git a/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx b/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx index 666d4c4cf3..96480af0c9 100644 --- a/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx +++ b/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx @@ -157,7 +157,7 @@ const ChatCard = forwardRef(function ChatCard( }, }, }); - }, [currentDialog, dialogId, form, patchChat]); + }, [currentDialog, dialogId, form, patchChat, findLlmByUuid]); useImperativeHandle( ref, diff --git a/web/src/pages/next-chats/chat/chat-box/use-set-default-model.ts b/web/src/pages/next-chats/chat/chat-box/use-set-default-model.ts index 661f5a72de..6ea8807ae0 100644 --- a/web/src/pages/next-chats/chat/chat-box/use-set-default-model.ts +++ b/web/src/pages/next-chats/chat/chat-box/use-set-default-model.ts @@ -1,18 +1,26 @@ -import { LlmModelType } from '@/constants/knowledge'; -import { useComposeLlmOptionsByModelTypes } from 
'@/hooks/use-llm-request'; -import { useMount } from 'ahooks'; +import { ModelTypeMap } from '@/components/model-tree-select'; +import { useFetchAllAddedModels } from '@/hooks/use-llm-request'; +import { getRealModelName } from '@/utils/llm-util'; +import { useEffect, useRef } from 'react'; import { UseFormReturn } from 'react-hook-form'; export function useSetDefaultModel(form: UseFormReturn) { - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Chat, - LlmModelType.Image2text, - ]); + const { data: allAddedModels } = useFetchAllAddedModels(); + const hasSet = useRef(false); - useMount(() => { - const firstModel = modelOptions.at(0)?.options.at(0)?.value; - if (firstModel) { - form.setValue('llm_id', firstModel); + useEffect(() => { + if (hasSet.current || !allAddedModels.length) return; + const chatModels = allAddedModels.filter((m) => + m.model_type?.some((t) => ModelTypeMap.llm_id.includes(t)), + ); + const first = chatModels[0]; + if (first) { + const modelName = getRealModelName(first.name); + form.setValue( + 'llm_id', + `${modelName}@${first.instance_name}@${first.provider_name}`, + ); + hasSet.current = true; } - }); + }, [allAddedModels, form]); } diff --git a/web/src/pages/next-chats/hooks/use-create-chat.ts b/web/src/pages/next-chats/hooks/use-create-chat.ts new file mode 100644 index 0000000000..b9ebe2f9bd --- /dev/null +++ b/web/src/pages/next-chats/hooks/use-create-chat.ts @@ -0,0 +1,65 @@ +import { useSetModalState } from '@/hooks/common-hooks'; +import { useCreateChat } from '@/hooks/use-chat-request'; +import { useFetchDefaultModelDictionary } from '@/hooks/use-llm-request'; +import { useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +export const useCreateChatDialog = () => { + const { + visible: createChatVisible, + hideModal: hideCreateChatModal, + showModal: showCreateChatModal, + } = useSetModalState(); + const { createChat, loading: createLoading } = useCreateChat(); + const { t } = 
useTranslation(); + const defaultModelDictionary = + useFetchDefaultModelDictionary(createChatVisible); + + const InitialData = useMemo( + () => ({ + name: '', + icon: '', + language: 'English', + description: '', + dataset_ids: [], + prompt_config: { + empty_response: '', + prologue: t('chat.setAnOpenerInitial'), + quote: true, + keyword: false, + tts: false, + system: t('chat.systemInitialValue'), + refine_multiturn: false, + use_kg: false, + reasoning: false, + parameters: [{ key: 'knowledge', optional: false }], + toc_enhance: false, + }, + llm_id: defaultModelDictionary?.llm_id, + llm_setting: {}, + similarity_threshold: 0.2, + vector_similarity_weight: 0.3, + top_n: 8, + top_k: 1024, + }), + [t, defaultModelDictionary?.llm_id], + ); + + const onCreateChatOk = useCallback( + async (name: string) => { + const ret = await createChat({ ...InitialData, name }); + if (ret === 0) { + hideCreateChatModal(); + } + }, + [InitialData, createChat, hideCreateChatModal], + ); + + return { + createChatLoading: createLoading, + onCreateChatOk, + createChatVisible, + hideCreateChatModal, + showCreateChatModal, + }; +}; diff --git a/web/src/pages/next-chats/hooks/use-rename-chat.ts b/web/src/pages/next-chats/hooks/use-rename-chat.ts index 65329e1938..7c46d6d534 100644 --- a/web/src/pages/next-chats/hooks/use-rename-chat.ts +++ b/web/src/pages/next-chats/hooks/use-rename-chat.ts @@ -1,11 +1,7 @@ import { useSetModalState } from '@/hooks/common-hooks'; -import { useCreateChat, usePatchChat } from '@/hooks/use-chat-request'; -import { useFindLlmByUuid } from '@/hooks/use-llm-request'; -import { useFetchTenantInfo } from '@/hooks/use-user-setting-request'; +import { usePatchChat } from '@/hooks/use-chat-request'; import { IDialog } from '@/interfaces/database/chat'; -import { isEmpty } from 'lodash'; -import { useCallback, useMemo, useState } from 'react'; -import { useTranslation } from 'react-i18next'; +import { useCallback, useState } from 'react'; export const useRenameChat = 
() => { const [chat, setChat] = useState({} as IDialog); @@ -14,70 +10,25 @@ export const useRenameChat = () => { hideModal: hideChatRenameModal, showModal: showChatRenameModal, } = useSetModalState(); - const { createChat, loading: createLoading } = useCreateChat(); const { patchChat, loading: patchLoading } = usePatchChat(); - const { t } = useTranslation(); - const tenantInfo = useFetchTenantInfo(); - const findLlmByUuid = useFindLlmByUuid(); - - const InitialData = useMemo( - () => ({ - name: '', - icon: '', - language: 'English', - description: '', - dataset_ids: [], - prompt_config: { - empty_response: '', - prologue: t('chat.setAnOpenerInitial'), - quote: true, - keyword: false, - tts: false, - system: t('chat.systemInitialValue'), - refine_multiturn: false, - use_kg: false, - reasoning: false, - parameters: [{ key: 'knowledge', optional: false }], - toc_enhance: false, - }, - llm_id: tenantInfo.data.llm_id, - llm_setting: { - model_type: findLlmByUuid(tenantInfo.data.llm_id)?.model_type || 'chat', - }, - similarity_threshold: 0.2, - vector_similarity_weight: 0.3, - top_n: 8, - top_k: 1024, - }), - [t, tenantInfo.data.llm_id, findLlmByUuid], - ); const onChatRenameOk = useCallback( async (name: string) => { - let ret: number | undefined; - if (isEmpty(chat)) { - ret = await createChat({ ...InitialData, name }); - } else { - ret = await patchChat({ - chatId: chat.id, - params: { name }, - }); - } + const ret = await patchChat({ + chatId: chat.id, + params: { name }, + }); if (ret === 0) { hideChatRenameModal(); } }, - [chat, InitialData, createChat, patchChat, hideChatRenameModal], + [chat.id, patchChat, hideChatRenameModal], ); const handleShowChatRenameModal = useCallback( - (record?: IDialog) => { - if (record) { - setChat(record); - } else { - setChat({} as IDialog); - } + (record: IDialog) => { + setChat(record); showChatRenameModal(); }, [showChatRenameModal], @@ -89,7 +40,7 @@ export const useRenameChat = () => { }, [hideChatRenameModal]); return { - 
chatRenameLoading: createLoading || patchLoading, + chatRenameLoading: patchLoading, initialChatName: chat?.name, onChatRenameOk, chatRenameVisible, diff --git a/web/src/pages/next-chats/index.tsx b/web/src/pages/next-chats/index.tsx index fb49f00249..ed10d2df93 100644 --- a/web/src/pages/next-chats/index.tsx +++ b/web/src/pages/next-chats/index.tsx @@ -8,10 +8,11 @@ import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; import { useFetchChatList } from '@/hooks/use-chat-request'; import { pick } from 'lodash'; import { Plus } from 'lucide-react'; -import { useCallback, useEffect } from 'react'; +import { useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useSearchParams } from 'react-router'; import { ChatCard } from './chat-card'; +import { useCreateChatDialog } from './hooks/use-create-chat'; import { useRenameChat } from './hooks/use-rename-chat'; export default function ChatList() { @@ -26,6 +27,13 @@ export default function ChatList() { onChatRenameOk, chatRenameLoading, } = useRenameChat(); + const { + createChatVisible, + showCreateChatModal, + hideCreateChatModal, + onCreateChatOk, + createChatLoading, + } = useCreateChatDialog(); const handlePageChange = useCallback( (page: number, pageSize?: number) => { @@ -35,8 +43,8 @@ export default function ChatList() { ); const handleShowCreateModal = useCallback(() => { - showChatRenameModal(); - }, [showChatRenameModal]); + showCreateChatModal(); + }, [showCreateChatModal]); const [searchParams, setSearchParams] = useSearchParams(); const isCreate = searchParams.get('isCreate') === 'true'; @@ -48,6 +56,39 @@ export default function ChatList() { } }, [isCreate, handleShowCreateModal, searchParams, setSearchParams]); + const renameDialogProps = useMemo(() => { + if (chatRenameVisible) { + return { + hideModal: hideChatRenameModal, + onOk: onChatRenameOk, + initialName: initialChatName, + loading: chatRenameLoading, + title: initialChatName, + }; 
+ } + if (createChatVisible) { + return { + hideModal: hideCreateChatModal, + onOk: onCreateChatOk, + initialName: '', + loading: createChatLoading, + title: t('chat.createChat'), + }; + } + return null; + }, [ + chatRenameVisible, + createChatVisible, + hideChatRenameModal, + onChatRenameOk, + initialChatName, + chatRenameLoading, + hideCreateChatModal, + onCreateChatOk, + createChatLoading, + t, + ]); + return ( <> {data.chats?.length || searchString ? ( @@ -115,14 +156,8 @@ export default function ChatList() { )} - {chatRenameVisible && ( - + {renameDialogProps && ( + )} ); diff --git a/web/src/pages/next-search/search-setting.tsx b/web/src/pages/next-search/search-setting.tsx index 3c3dc6c600..29fd50ada9 100644 --- a/web/src/pages/next-search/search-setting.tsx +++ b/web/src/pages/next-search/search-setting.tsx @@ -10,6 +10,7 @@ import { MetadataFilter, MetadataFilterSchema, } from '@/components/metadata-filter'; +import { ModelTreeSelect } from '@/components/model-tree-select'; import { SimilaritySliderFormField } from '@/components/similarity-slider'; import { Button } from '@/components/ui/button'; import { SingleFormSlider } from '@/components/ui/dual-range-slider'; @@ -23,14 +24,9 @@ import { } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; import { MultiSelect } from '@/components/ui/multi-select'; -import { RAGFlowSelect } from '@/components/ui/select'; import { Spin } from '@/components/ui/spin'; import { Switch } from '@/components/ui/switch'; import { useFetchKnowledgeMetadataKeys } from '@/hooks/use-knowledge-request'; -import { - useComposeLlmOptionsByModelTypes, - useSelectLlmOptionsByModelType, -} from '@/hooks/use-llm-request'; import { useFetchTenantInfo } from '@/hooks/use-user-setting-request'; import { cn } from '@/lib/utils'; import { zodResolver } from '@hookform/resolvers/zod'; @@ -39,7 +35,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useForm, useWatch } from 
'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; -import { LlmModelType } from '../dataset/dataset/constant'; import { ISearchAppDetailProps, IUpdateSearchProps, @@ -188,16 +183,6 @@ const SearchSetting: React.FC = ({ } }, [open]); - const allOptions = useSelectLlmOptionsByModelType(); - const rerankModelOptions = useMemo(() => { - return allOptions[LlmModelType.Rerank]; - }, [allOptions]); - - const aiSummeryModelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Chat, - LlmModelType.Image2text, - ]); - const rerankModelDisabled = useWatch({ control: formMethods.control, name: 'search_config.use_rerank', @@ -344,15 +329,9 @@ const SearchSetting: React.FC = ({ >
{ - console.log('Form submitted with data:', data); - onSubmit(data as unknown as IUpdateSearchProps); - }, - (errors) => { - console.log('Validation errors:', errors); - }, - )} + onSubmit={formMethods.handleSubmit((data) => { + onSubmit(data as unknown as IUpdateSearchProps); + })} className="space-y-6" > @@ -452,11 +431,9 @@ const SearchSetting: React.FC = ({ {t('chat.model')} - @@ -524,7 +501,6 @@ const SearchSetting: React.FC = ({ // > = ({ const [saving, setSaving] = useState(false); const [reindexing, setReindexing] = useState(false); - // Get embedding model options from user's configured LLMs - const llmOptions = useSelectLlmOptionsByModelType(); - const embeddingModelOptions = useMemo(() => { - return llmOptions[ - LlmModelType.Embedding - ] as SelectWithSearchFlagOptionType[]; - }, [llmOptions]); - useEffect(() => { if (open) { if (config) { @@ -166,12 +153,12 @@ export const SearchConfigModal: React.FC = ({ {/* Embedding Model */}
- setValue('embd_id', value, { shouldDirty: true }) } - options={embeddingModelOptions} placeholder={t('skillSearch.embeddingModelPlaceholder')} />
diff --git a/web/src/pages/user-setting/data-source/constant/index.tsx b/web/src/pages/user-setting/data-source/constant/index.tsx index a3480b35e3..6f55db8da1 100644 --- a/web/src/pages/user-setting/data-source/constant/index.tsx +++ b/web/src/pages/user-setting/data-source/constant/index.tsx @@ -1605,8 +1605,8 @@ export const DataSourceFormDefaultValues = { tenant_id: '', client_id: '', client_secret: '', - } - } + }, + }, }, [DataSourceKey.SLACK]: { name: '', diff --git a/web/src/pages/user-setting/interface.ts b/web/src/pages/user-setting/interface.ts index 1c9114ee80..f6fe4375ef 100644 --- a/web/src/pages/user-setting/interface.ts +++ b/web/src/pages/user-setting/interface.ts @@ -1,4 +1,5 @@ export interface ApiKeyPostBody { + instance_name: string; api_key: string; base_url: string; group_id?: string; diff --git a/web/src/pages/user-setting/setting-model/components/modal-card.tsx b/web/src/pages/user-setting/setting-model/components/modal-card.tsx deleted file mode 100644 index f688cc9f6c..0000000000 --- a/web/src/pages/user-setting/setting-model/components/modal-card.tsx +++ /dev/null @@ -1,212 +0,0 @@ -// src/components/ModelProviderCard.tsx -import { - ConfirmDeleteDialog, - ConfirmDeleteDialogNode, -} from '@/components/confirm-delete-dialog'; -import { LlmIcon } from '@/components/svg-icon'; -import { Button } from '@/components/ui/button'; -import { Switch } from '@/components/ui/switch'; -import { useSetModalState, useTranslate } from '@/hooks/common-hooks'; -import { LlmItem } from '@/hooks/use-llm-request'; -import { getRealModelName } from '@/utils/llm-util'; -import { EditOutlined, SettingOutlined } from '@ant-design/icons'; -import { ChevronsDown, ChevronsUp, Trash2 } from 'lucide-react'; -import { FC } from 'react'; -import { isLocalLlmFactory } from '../../utils'; -import { - useHandleDeleteFactory, - useHandleDeleteLlm, - useHandleEnableLlm, -} from '../hooks'; -import { mapModelKey } from './un-add-model'; - -interface IModelCardProps { - 
item: LlmItem; - clickApiKey: (llmFactory: string) => void; - handleEditModel: (model: any, factory: LlmItem) => void; -} - -type TagType = - | 'LLM' - | 'TEXT EMBEDDING' - | 'TEXT RE-RANK' - | 'TTS' - | 'SPEECH2TEXT' - | 'IMAGE2TEXT' - | 'MODERATION'; - -const sortTags = (tags: string) => { - const orderMap: Record = { - LLM: 1, - 'TEXT EMBEDDING': 2, - 'TEXT RE-RANK': 3, - TTS: 4, - SPEECH2TEXT: 5, - IMAGE2TEXT: 6, - MODERATION: 7, - }; - - return tags - .split(',') - .map((tag) => tag.trim()) - .sort( - (a, b) => - (orderMap[a as TagType] || 999) - (orderMap[b as TagType] || 999), - ); -}; - -export const ModelProviderCard: FC = ({ - item, - clickApiKey, - handleEditModel, -}) => { - const { visible, switchVisible } = useSetModalState(); - const { t } = useTranslate('setting'); - const { handleEnableLlm } = useHandleEnableLlm(item.name); - const { deleteFactory } = useHandleDeleteFactory(item.name); - const { handleDeleteLlm } = useHandleDeleteLlm(item.name); - - const handleApiKeyClick = () => { - clickApiKey(item.name); - }; - - const handleShowMoreClick = () => { - switchVisible(); - }; - - return ( -
- {/* Header */} -
-
- -
-
- {item.name} -
-
-
- -
- - - - - deleteFactory({ llm_factory: item.name })} - title={t('deleteModel')} - content={{ - node: ( - -
- - {item.name} -
-
- ), - }} - > - -
-
-
- - {/* Content */} - {visible && ( -
-
- {sortTags(item.tags).map((tag, index) => ( - - {mapModelKey[tag.trim() as keyof typeof mapModelKey] || - tag.trim()} - - ))} -
-
-
    - {item.llm.map((model) => ( -
  • -
    - - {getRealModelName(model.name)} - - - {model.type} - -
    - -
    - {isLocalLlmFactory(item.name) && ( - - )} - - { - handleEnableLlm(model.name, value); - }} - /> - - -
    -
  • - ))} -
-
-
- )} -
- ); -}; diff --git a/web/src/pages/user-setting/setting-model/components/system-setting.tsx b/web/src/pages/user-setting/setting-model/components/system-setting.tsx index f1fbdaaf51..99f130ef3e 100644 --- a/web/src/pages/user-setting/setting-model/components/system-setting.tsx +++ b/web/src/pages/user-setting/setting-model/components/system-setting.tsx @@ -1,176 +1,136 @@ -import { - SelectWithSearch, - SelectWithSearchFlagOptionType, -} from '@/components/originui/select-with-search'; +import { ModelTreeSelect, ModelTypeMap } from '@/components/model-tree-select'; import { Tooltip, TooltipContent, TooltipTrigger, } from '@/components/ui/tooltip'; -import { LlmModelType } from '@/constants/knowledge'; +import { FieldToModelType } from '@/constants/llm'; import { useTranslate } from '@/hooks/common-hooks'; import { - ISystemModelSettingSavingParams, - useComposeLlmOptionsByModelTypes, + useFetchDefaultModelDictionary, + useSetDefaultModel, } from '@/hooks/use-llm-request'; +import { parseModelValue } from '@/utils/llm-util'; import { CircleQuestionMark } from 'lucide-react'; -import { useCallback, useEffect, useMemo, useState } from 'react'; -import { useFetchSystemModelSettingOnMount } from '../hooks'; +import { useCallback, useMemo } from 'react'; -interface IProps { - loading: boolean; - onOk: ( - payload: Omit, - ) => void; +interface ModelFieldItemProps { + id: string; + label: string; + value: string; + tooltip?: string; + isRequired?: boolean; + onChange: (id: string, value: string) => void; } -const SystemSetting = ({ onOk, loading }: IProps) => { - const { systemSetting: initialValues, allOptions } = - useFetchSystemModelSettingOnMount(); +function ModelFieldItem({ + label, + value, + tooltip, + id, + isRequired, + onChange, +}: ModelFieldItemProps) { const { t } = useTranslate('setting'); - const [formData, setFormData] = useState({ - llm_id: '', - embd_id: '', - img2txt_id: '', - asr_id: '', - rerank_id: '', - tts_id: '', - }); + return ( +
+ +
+ onChange(id, val)} + placeholder={t('selectModelPlaceholder')} + showSearch + allowClear={id !== 'llm_id'} + /> +
+
+ ); +} + +function SystemSetting() { + const { t } = useTranslate('setting'); + const defaultModelDictionary = useFetchDefaultModelDictionary(); + const { setDefaultModel } = useSetDefaultModel(); const handleFieldChange = useCallback( - (field: string, value: string) => { - const updatedData = { ...formData, [field]: value || '' }; - setFormData(updatedData); - console.log('updatedData', updatedData); - onOk(updatedData); + async (field: string, value: string) => { + const modelType = FieldToModelType[field]; + if (!modelType) return; + + if (!value) { + await setDefaultModel({ + model_provider: '', + model_instance: '', + model_name: '', + model_type: modelType, + }); + } else { + const parsed = parseModelValue(value); + if (!parsed) return; + await setDefaultModel({ ...parsed, model_type: modelType }); + } }, - [formData, onOk], + [setDefaultModel], ); - useEffect(() => { - setFormData({ - llm_id: initialValues.llm_id ?? '', - embd_id: initialValues.embd_id ?? '', - img2txt_id: initialValues.img2txt_id ?? '', - asr_id: initialValues.asr_id ?? '', - rerank_id: initialValues.rerank_id ?? '', - tts_id: initialValues.tts_id ?? 
'', - }); - }, [initialValues]); - - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Chat, - LlmModelType.Image2text, - ]); - const llmList = useMemo(() => { return [ { id: 'llm_id', label: t('chatModel'), isRequired: true, - value: formData.llm_id, - options: modelOptions as SelectWithSearchFlagOptionType[], + value: defaultModelDictionary.llm_id, tooltip: t('chatModelTip'), - testId: 'default-llm-combobox', }, { id: 'embd_id', label: t('embeddingModel'), - value: formData.embd_id, - options: allOptions[ - LlmModelType.Embedding - ] as SelectWithSearchFlagOptionType[], + value: defaultModelDictionary.embd_id, tooltip: t('embeddingModelTip'), - testId: 'default-embedding-combobox', }, { id: 'img2txt_id', label: t('img2txtModel'), - value: formData.img2txt_id, - options: allOptions[ - LlmModelType.Image2text - ] as SelectWithSearchFlagOptionType[], + value: defaultModelDictionary.img2txt_id, tooltip: t('img2txtModelTip'), }, { id: 'asr_id', label: t('sequence2txtModel'), - value: formData.asr_id, - options: allOptions[ - LlmModelType.Speech2text - ] as SelectWithSearchFlagOptionType[], + value: defaultModelDictionary.asr_id, tooltip: t('sequence2txtModelTip'), }, { id: 'rerank_id', label: t('rerankModel'), - value: formData.rerank_id, - options: allOptions[ - LlmModelType.Rerank - ] as SelectWithSearchFlagOptionType[], + value: defaultModelDictionary.rerank_id, tooltip: t('rerankModelTip'), }, { id: 'tts_id', label: t('ttsModel'), - value: formData.tts_id, - options: allOptions[ - LlmModelType.TTS - ] as SelectWithSearchFlagOptionType[], + value: defaultModelDictionary.tts_id, tooltip: t('ttsModelTip'), }, ]; - }, [formData, modelOptions, t, allOptions]); - - const Items = ({ - label, - value, - options, - tooltip, - id, - isRequired, - testId, - }: { - id: string; - label: string; - value: string; - options: SelectWithSearchFlagOptionType[]; - tooltip?: string; - isRequired?: boolean; - testId?: string; - }) => { - return ( -
- - handleFieldChange(id, value)} - placeholder={t('selectModelPlaceholder')} - emptyData={t('modelEmptyTip')} - testId={testId} - /> -
- ); - }; + }, [defaultModelDictionary, t]); return (
@@ -185,20 +145,15 @@ const SystemSetting = ({ onOk, loading }: IProps) => {
{llmList.map((item) => ( - + ))}
- {/*
- -
*/}
); -}; +} export default SystemSetting; diff --git a/web/src/pages/user-setting/setting-model/components/un-add-model.tsx b/web/src/pages/user-setting/setting-model/components/un-add-model.tsx index a78aa68628..41691b3a4e 100644 --- a/web/src/pages/user-setting/setting-model/components/un-add-model.tsx +++ b/web/src/pages/user-setting/setting-model/components/un-add-model.tsx @@ -3,49 +3,56 @@ import { LlmIcon } from '@/components/svg-icon'; import { Button } from '@/components/ui/button'; import { SearchInput } from '@/components/ui/input'; import { APIMapUrl } from '@/constants/llm'; -import { useTranslate } from '@/hooks/common-hooks'; -import { useSelectLlmList } from '@/hooks/use-llm-request'; +import { useFetchAvailableProviders } from '@/hooks/use-llm-request'; import { ArrowUpRight, Plus } from 'lucide-react'; import { FC, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + export const mapModelKey = { IMAGE2TEXT: 'VLM', 'TEXT EMBEDDING': 'Embedding', SPEECH2TEXT: 'ASR', 'TEXT RE-RANK': 'Rerank', + chat: 'LLM', + vision: 'VLM', + embedding: 'Embedding', + asr: 'ASR', + rerank: 'Rerank', + tts: 'TTS', + ocr: 'OCR', }; -const orderMap: Record = { - LLM: 1, - 'TEXT EMBEDDING': 2, - 'TEXT RE-RANK': 3, - TTS: 4, - SPEECH2TEXT: 5, - IMAGE2TEXT: 6, - MODERATION: 7, -}; -type TagType = - | 'LLM' - | 'TEXT EMBEDDING' - | 'TEXT RE-RANK' - | 'TTS' - | 'SPEECH2TEXT' - | 'IMAGE2TEXT' - | 'MODERATION'; -const sortTags = (tags: string) => { - return tags - .split(',') - .map((tag) => tag.trim()) - .sort( - (a, b) => - (orderMap[a as TagType] || 999) - (orderMap[b as TagType] || 999), - ); +const orderMap: Record = { + chat: 1, + embedding: 2, + rerank: 3, + tts: 4, + asr: 5, + vision: 6, + ocr: 7, +}; + +type ModelType = + | 'chat' + | 'embedding' + | 'rerank' + | 'tts' + | 'asr' + | 'vision' + | 'ocr'; + +const sortModelTypes = (modelTypes: string[]) => { + return [...modelTypes].sort( + (a, b) => + (orderMap[a as ModelType] || 999) - 
(orderMap[b as ModelType] || 999), + ); }; export const AvailableModels: FC<{ handleAddModel: (factory: string) => void; }> = ({ handleAddModel }) => { - const { t } = useTranslate('setting'); - const { factoryList } = useSelectLlmList(); + const { t } = useTranslation(); + const { data: factoryList } = useFetchAvailableProviders(); const [searchTerm, setSearchTerm] = useState(''); const [selectedTag, setSelectedTag] = useState(null); @@ -57,7 +64,7 @@ export const AvailableModels: FC<{ .includes(searchTerm.toLowerCase()); const matchesTag = selectedTag === null || - model.tags.split(',').some((tag) => tag.trim() === selectedTag); + model.model_types.some((type) => type === selectedTag); return matchesSearch && matchesTag; }); return models; @@ -66,11 +73,11 @@ export const AvailableModels: FC<{ const allTags = useMemo(() => { const tagsSet = new Set(); factoryList.forEach((model) => { - model.tags.split(',').forEach((tag) => tagsSet.add(tag.trim())); + model.model_types.forEach((type) => tagsSet.add(type)); }); return Array.from(tagsSet).sort( (a, b) => - (orderMap[a as TagType] || 999) - (orderMap[b as TagType] || 999), + (orderMap[a as ModelType] || 999) - (orderMap[b as ModelType] || 999), ); }, [factoryList]); @@ -84,14 +91,16 @@ export const AvailableModels: FC<{ data-testid="available-models-section" >
-

{t('availableModels')}

+

+ {t('setting.availableModels')} +

{/* Search Bar */}
{/*
*/} setSearchTerm(e.target.value)} className="w-full px-4 py-2 pl-10 bg-bg-input border border-border-default rounded-lg focus:outline-none focus:ring-1 focus:ring-border-button transition-colors" @@ -165,19 +174,17 @@ export const AvailableModels: FC<{ className="px-2 opacity-0 transition-all group-hover:opacity-100 group-focus-within:opacity-100" > - {t('addTheModel')} + {t('setting.addTheModel')}
- {sortTags(model.tags).map((tag, index) => ( + {sortModelTypes(model.model_types).map((type, index) => ( - {/* {tag} */} - {mapModelKey[tag.trim() as keyof typeof mapModelKey] || - tag.trim()} + {mapModelKey[type as keyof typeof mapModelKey] || type} ))}
diff --git a/web/src/pages/user-setting/setting-model/components/used-model.tsx b/web/src/pages/user-setting/setting-model/components/used-model.tsx index 38209569a3..0c8c83201d 100644 --- a/web/src/pages/user-setting/setting-model/components/used-model.tsx +++ b/web/src/pages/user-setting/setting-model/components/used-model.tsx @@ -1,15 +1,38 @@ -import { LlmItem, useSelectLlmList } from '@/hooks/use-llm-request'; -import { t } from 'i18next'; -import { ModelProviderCard } from './modal-card'; +import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog'; +import { LlmIcon } from '@/components/svg-icon'; +import { Button } from '@/components/ui/button'; +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from '@/components/ui/collapsible'; +import { Switch } from '@/components/ui/switch'; +import { ModelStatus } from '@/constants/llm'; +import { + useDeleteProviderInstance, + useFetchAddedProviders, + useFetchInstanceModels, + useFetchProviderInstances, + useUpdateModelStatus, +} from '@/hooks/use-llm-request'; +import { + IAvailableProvider, + IInstanceModel, + IProviderInstance, +} from '@/interfaces/database/llm'; +import { ChevronsDown, ChevronsUp, Trash2 } from 'lucide-react'; +import { useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { mapModelKey } from './un-add-model'; -export const UsedModel = ({ +export function UsedModel({ handleAddModel, - handleEditModel, }: { handleAddModel: (factory: string) => void; - handleEditModel: (model: any, factory: LlmItem) => void; -}) => { - const { myLlmList: llmList } = useSelectLlmList(); +}) { + const { t } = useTranslation(); + const { data: providerList } = useFetchAddedProviders(); + return (
{t('setting.addedModels')}
- {llmList.map((llm) => { - return ( - - ); - })} + {providerList.map((provider) => ( + + ))}
); -}; +} + +function ProviderCard({ + provider, + handleAddModel, +}: { + provider: IAvailableProvider; + handleAddModel: (factory: string) => void; +}) { + const { data: instances } = useFetchProviderInstances(provider.name); + + return ( +
+ {/* Provider header */} +
+
+ +
+ {provider.name} +
+
+
+ + {/* Instances */} + {instances.length > 0 && ( +
+ {instances.map((instance) => ( + + ))} +
+ )} +
+ ); +} + +function InstanceRow({ + instance, + providerName, + // handleAddModel, +}: { + instance: IProviderInstance; + providerName: string; + handleAddModel: (factory: string) => void; +}) { + const { t } = useTranslation(); + const [visible, setVisible] = useState(false); + const { deleteProviderInstance } = useDeleteProviderInstance(); + + const handleDelete = async () => { + await deleteProviderInstance({ + provider_name: providerName, + instances: [instance.instance_name], + }); + }; + + return ( + +
+ {/* Instance header */} +
+ + {instance.instance_name} + +
+ {/* */} + + + + + + +
+
+ + {/* Models */} + + + +
+
+ ); +} + +function InstanceModelList({ + providerName, + instanceName, +}: { + providerName: string; + instanceName: string; +}) { + const { data: models } = useFetchInstanceModels(providerName, instanceName); + + const modelTypes = useMemo(() => { + const types = new Set(); + models.forEach((m) => { + if (m.model_type) { + m.model_type.forEach((type) => types.add(type)); + } + }); + return Array.from(types); + }, [models]); + + return ( +
+ {/* Model type tags */} + {modelTypes.length > 0 && ( +
+ {modelTypes.map((type) => ( + + {mapModelKey[type.trim() as keyof typeof mapModelKey] || type} + + ))} +
+ )} + + {/* Model list */} +
+
    + {models.map((model) => ( + + ))} +
+
+
+ ); +} + +function ModelListItem({ + model, + providerName, + instanceName, +}: { + model: IInstanceModel; + providerName: string; + instanceName: string; +}) { + const { updateModelStatus } = useUpdateModelStatus(); + + const handleStatusChange = (checked: boolean) => { + updateModelStatus({ + provider_name: providerName, + instance_name: instanceName, + model_name: model.name, + status: checked ? ModelStatus.Active : ModelStatus.Inactive, + }); + }; + + return ( +
  • +
    + {model.name} + {model.model_type.map((modelType) => ( + + {modelType} + + ))} +
    + +
  • + ); +} diff --git a/web/src/pages/user-setting/setting-model/hooks.tsx b/web/src/pages/user-setting/setting-model/hooks.tsx index 1ddf3ac773..5a530cb462 100644 --- a/web/src/pages/user-setting/setting-model/hooks.tsx +++ b/web/src/pages/user-setting/setting-model/hooks.tsx @@ -1,55 +1,114 @@ import { LLMFactory } from '@/constants/llm'; -import { useSetModalState, useShowDeleteConfirm } from '@/hooks/common-hooks'; +import { useSetModalState } from '@/hooks/common-hooks'; import { - IApiKeySavingParams, - ISystemModelSettingSavingParams, - useAddLlm, - useDeleteFactory, - useDeleteLlm, - useEnableLlm, - useSaveApiKey, - useSaveTenantInfo, - useSelectLlmOptionsByModelType, + useAddInstanceModel, + useAddProviderInstance, + useFetchAddedProviders, + useFetchProviderInstances, } from '@/hooks/use-llm-request'; -import { useFetchTenantInfo } from '@/hooks/use-user-setting-request'; -import { IAddLlmRequestBody } from '@/interfaces/request/llm'; +import { IAddProviderInstanceRequestBody } from '@/interfaces/request/llm'; import { getRealModelName } from '@/utils/llm-util'; -import { useQueryClient } from '@tanstack/react-query'; -import { useCallback, useState } from 'react'; +import { useCallback, useMemo, useState } from 'react'; import { ApiKeyPostBody } from '../interface'; import { MinerUFormValues } from './modal/mineru-modal'; +import { splitProviderPayload } from './payload-utils'; -type SavingParamsState = Omit; +type SavingParamsState = { + llm_factory: string; + llm_name?: string; + model_type?: string; + instance_name?: string; + base_url?: string; +}; export type VerifyResult = { isValid: boolean | null; logs: string; }; + +const useSubmitProviderInstance = () => { + const { addProviderInstance } = useAddProviderInstance(); + const { addInstanceModel } = useAddInstanceModel(); + + return useCallback( + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { + if (isVerify) { + return addProviderInstance({ ...payload, verify: true }); + 
} + + const { instancePayload, modelPayload } = splitProviderPayload(payload); + const hasModelPayload = + !!modelPayload.model_name && !!modelPayload.model_type; + + const instanceRet = await addProviderInstance({ + ...instancePayload, + llm_factory: payload.llm_factory, + instance_name: payload.instance_name, + } as IAddProviderInstanceRequestBody); + if (instanceRet.code !== 0 || !hasModelPayload) { + return instanceRet; + } + + if (!hasModelPayload) { + return { code: 0, data: null } as any; + } + + return addInstanceModel({ + provider_name: payload.llm_factory, + instance_name: payload.instance_name, + ...modelPayload, + }); + }, + [addProviderInstance, addInstanceModel], + ); +}; + +export const useFetchInstanceNameSet = (providerName: string) => { + const { data: addedProviders } = useFetchAddedProviders(); + const providerExists = useMemo( + () => addedProviders.some((p) => p.name === providerName), + [addedProviders, providerName], + ); + const { data: instances } = useFetchProviderInstances( + providerExists ? 
providerName : '', + ); + const instanceNameSet = useMemo( + () => new Set(instances.map((i) => i.instance_name)), + [instances], + ); + return { instanceNameSet, providerExists }; +}; + +export const useHideWhenInstanceExists = (instanceNameSet: Set) => { + return useCallback( + (formValues: any) => { + const name = ((formValues?.instance_name as string) || '').trim(); + return !(name && instanceNameSet.has(name)); + }, + [instanceNameSet], + ); +}; export const useSubmitApiKey = () => { const [savingParams, setSavingParams] = useState( {} as SavingParamsState, ); const [editMode, setEditMode] = useState(false); - const { saveApiKey } = useSaveApiKey(); + const submitProviderInstance = useSubmitProviderInstance(); const [saveLoading, setSaveLoading] = useState(false); const { visible: apiKeyVisible, hideModal: hideApiKeyModal, showModal: showApiKeyModal, } = useSetModalState(); - const queryClient = useQueryClient(); const onApiKeySavingOk = useCallback( async (postBody: ApiKeyPostBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const payload: IApiKeySavingParams = { - ...savingParams, - ...postBody, - verify: isVerify, - }; + let apiKey: string | Record = postBody.api_key || ''; + if (savingParams.llm_factory === LLMFactory.SILICONFLOW) { - let sourceFid = LLMFactory.SILICONFLOW; + let sourceFid: string = LLMFactory.SILICONFLOW; const baseUrl = postBody.base_url; if (baseUrl) { try { @@ -65,14 +124,24 @@ export const useSubmitApiKey = () => { // ignore invalid URL and keep default sourceFid } } - payload.source_fid = sourceFid; + apiKey = { api_key: postBody.api_key, source_fid: sourceFid }; } - const ret = await saveApiKey(payload); + const req: IAddProviderInstanceRequestBody = { + instance_name: + postBody.instance_name || savingParams.instance_name || '', + llm_factory: savingParams.llm_factory, + llm_name: savingParams.llm_name || '', + model_type: savingParams.model_type || '', + api_key: apiKey, + api_base: postBody.base_url || '', + 
max_tokens: 0, + }; + + const ret = await submitProviderInstance(req, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { - queryClient.invalidateQueries({ queryKey: ['llmList'] }); hideApiKeyModal(); setEditMode(false); } @@ -93,7 +162,7 @@ export const useSubmitApiKey = () => { return res; } }, - [hideApiKeyModal, saveApiKey, savingParams, queryClient], + [hideApiKeyModal, submitProviderInstance, savingParams], ); const onShowApiKeyModal = useCallback( @@ -117,57 +186,14 @@ export const useSubmitApiKey = () => { }; }; -export const useSubmitSystemModelSetting = () => { - const { data: systemSetting } = useFetchTenantInfo(); - const { saveTenantInfo: saveSystemModelSetting, loading } = - useSaveTenantInfo(); - const { - visible: systemSettingVisible, - hideModal: hideSystemSettingModal, - showModal: showSystemSettingModal, - } = useSetModalState(); - - const onSystemSettingSavingOk = useCallback( - async ( - payload: Omit, - ) => { - const ret = await saveSystemModelSetting({ - tenant_id: systemSetting.tenant_id, - name: systemSetting.name, - ...payload, - }); - - if (ret === 0) { - hideSystemSettingModal(); - } - }, - [hideSystemSettingModal, saveSystemModelSetting, systemSetting], - ); - - return { - saveSystemModelSettingLoading: loading, - onSystemSettingSavingOk, - systemSettingVisible, - hideSystemSettingModal, - showSystemSettingModal, - }; -}; - -export const useFetchSystemModelSettingOnMount = () => { - const { data: systemSetting } = useFetchTenantInfo(); - const allOptions = useSelectLlmOptionsByModelType(); - - return { systemSetting, allOptions }; -}; - export const useSubmitOllama = () => { const [selectedLlmFactory, setSelectedLlmFactory] = useState(''); const [editMode, setEditMode] = useState(false); const [initialValues, setInitialValues] = useState< - Partial & { provider_order?: string } + Partial & { provider_order?: string } >(); const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + 
const submitProviderInstance = useSubmitProviderInstance(); const { visible: llmAddingVisible, hideModal: hideLlmAddingModal, @@ -175,16 +201,20 @@ export const useSubmitOllama = () => { } = useSetModalState(); const onLlmAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } const cleanedPayload = { ...payload }; - if (!cleanedPayload.api_key || cleanedPayload.api_key.trim() === '') { - delete cleanedPayload.api_key; - } + // if ( + // !cleanedPayload.api_key || + // (typeof cleanedPayload.api_key === 'string' && + // cleanedPayload.api_key.trim() === '') + // ) { + // delete cleanedPayload.api_key; + // } - const ret = await addLlm({ ...cleanedPayload, verify: isVerify }); + const ret = await submitProviderInstance(cleanedPayload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -209,7 +239,7 @@ export const useSubmitOllama = () => { return res; } }, - [hideLlmAddingModal, addLlm, setSaveLoading], + [hideLlmAddingModal, submitProviderInstance, setSaveLoading], ); const handleShowLlmAddingModal = ( @@ -223,6 +253,8 @@ export const useSubmitOllama = () => { if (isEdit && detailedData) { const initialVals = { + instance_name: + detailedData.instance_name || getRealModelName(detailedData.name), llm_name: getRealModelName(detailedData.name), model_type: detailedData.type, api_base: detailedData.api_base || '', @@ -251,7 +283,7 @@ export const useSubmitOllama = () => { export const useSubmitVolcEngine = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: volcAddingVisible, hideModal: hideVolcAddingModal, @@ -259,11 +291,11 @@ export const useSubmitVolcEngine = () => { } = useSetModalState(); const onVolcAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = 
false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -286,7 +318,7 @@ export const useSubmitVolcEngine = () => { return res; } }, - [hideVolcAddingModal, addLlm, setSaveLoading], + [hideVolcAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -300,7 +332,7 @@ export const useSubmitVolcEngine = () => { export const useSubmitTencentCloud = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: TencentCloudAddingVisible, hideModal: hideTencentCloudAddingModal, @@ -308,11 +340,11 @@ export const useSubmitTencentCloud = () => { } = useSetModalState(); const onTencentCloudAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -335,7 +367,7 @@ export const useSubmitTencentCloud = () => { return res; } }, - [hideTencentCloudAddingModal, addLlm, setSaveLoading], + [hideTencentCloudAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -349,7 +381,7 @@ export const useSubmitTencentCloud = () => { export const useSubmitSpark = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: SparkAddingVisible, hideModal: hideSparkAddingModal, @@ -357,11 +389,11 @@ export const useSubmitSpark = () => { } = useSetModalState(); const 
onSparkAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -384,7 +416,7 @@ export const useSubmitSpark = () => { return res; } }, - [hideSparkAddingModal, addLlm, setSaveLoading], + [hideSparkAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -398,7 +430,7 @@ export const useSubmitSpark = () => { export const useSubmityiyan = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: yiyanAddingVisible, hideModal: hideyiyanAddingModal, @@ -406,11 +438,11 @@ export const useSubmityiyan = () => { } = useSetModalState(); const onyiyanAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -433,7 +465,7 @@ export const useSubmityiyan = () => { return res; } }, - [hideyiyanAddingModal, addLlm, setSaveLoading], + [hideyiyanAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -447,7 +479,7 @@ export const useSubmityiyan = () => { export const useSubmitFishAudio = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: FishAudioAddingVisible, hideModal: hideFishAudioAddingModal, @@ -455,11 +487,11 @@ export const useSubmitFishAudio = () => { } = 
useSetModalState(); const onFishAudioAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -482,7 +514,7 @@ export const useSubmitFishAudio = () => { return res; } }, - [hideFishAudioAddingModal, addLlm, setSaveLoading], + [hideFishAudioAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -496,7 +528,7 @@ export const useSubmitFishAudio = () => { export const useSubmitGoogle = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: GoogleAddingVisible, hideModal: hideGoogleAddingModal, @@ -504,11 +536,11 @@ export const useSubmitGoogle = () => { } = useSetModalState(); const onGoogleAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -531,7 +563,7 @@ export const useSubmitGoogle = () => { return res; } }, - [hideGoogleAddingModal, addLlm, setSaveLoading], + [hideGoogleAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -545,7 +577,7 @@ export const useSubmitGoogle = () => { export const useSubmitBedrock = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: bedrockAddingVisible, hideModal: hideBedrockAddingModal, @@ -553,11 +585,11 
@@ export const useSubmitBedrock = () => { } = useSetModalState(); const onBedrockAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -580,7 +612,7 @@ export const useSubmitBedrock = () => { return res; } }, - [hideBedrockAddingModal, addLlm, setSaveLoading], + [hideBedrockAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -594,7 +626,7 @@ export const useSubmitBedrock = () => { export const useSubmitAzure = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: AzureAddingVisible, hideModal: hideAzureAddingModal, @@ -602,11 +634,11 @@ export const useSubmitAzure = () => { } = useSetModalState(); const onAzureAddingOk = useCallback( - async (payload: IAddLlmRequestBody, isVerify = false) => { + async (payload: IAddProviderInstanceRequestBody, isVerify = false) => { if (!isVerify) { setSaveLoading(true); } - const ret = await addLlm({ ...payload, verify: isVerify }); + const ret = await submitProviderInstance(payload, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -629,7 +661,7 @@ export const useSubmitAzure = () => { return res; } }, - [hideAzureAddingModal, addLlm, setSaveLoading], + [hideAzureAddingModal, submitProviderInstance, setSaveLoading], ); return { @@ -641,49 +673,9 @@ export const useSubmitAzure = () => { }; }; -export const useHandleDeleteLlm = (llmFactory: string) => { - const { deleteLlm } = useDeleteLlm(); - const showDeleteConfirm = useShowDeleteConfirm(); - - const handleDeleteLlm = (name: string) => { - showDeleteConfirm({ - onOk: async () => { - 
deleteLlm({ llm_factory: llmFactory, llm_name: name }); - }, - }); - }; - - return { handleDeleteLlm }; -}; - -export const useHandleEnableLlm = (llmFactory: string) => { - const { enableLlm } = useEnableLlm(); - - const handleEnableLlm = (name: string, enable: boolean) => { - enableLlm({ llm_factory: llmFactory, llm_name: name, enable }); - }; - - return { handleEnableLlm }; -}; - -export const useHandleDeleteFactory = (llmFactory: string) => { - const { deleteFactory } = useDeleteFactory(); - const showDeleteConfirm = useShowDeleteConfirm(); - - const handleDeleteFactory = () => { - showDeleteConfirm({ - onOk: async () => { - deleteFactory({ llm_factory: llmFactory }); - }, - }); - }; - - return { handleDeleteFactory, deleteFactory }; -}; - export const useSubmitMinerU = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: mineruVisible, hideModal: hideMineruModal, @@ -691,7 +683,10 @@ export const useSubmitMinerU = () => { } = useSetModalState(); const onMineruOk = useCallback( - async (payload: MinerUFormValues, isVerify = false) => { + async ( + payload: MinerUFormValues & { instance_name: string }, + isVerify = false, + ) => { if (!isVerify) { setSaveLoading(true); } @@ -700,10 +695,12 @@ export const useSubmitMinerU = () => { mineru_delete_output: (payload.mineru_delete_output ?? true) ? 
'1' : '0', }; + delete cfg.instance_name; if (payload.mineru_backend !== 'vlm-http-client') { delete cfg.mineru_server_url; } - const req: IAddLlmRequestBody = { + const req: IAddProviderInstanceRequestBody = { + instance_name: payload.instance_name, llm_factory: LLMFactory.MinerU, llm_name: payload.llm_name, model_type: 'ocr', @@ -711,7 +708,7 @@ export const useSubmitMinerU = () => { api_base: '', max_tokens: 0, }; - const ret = await addLlm({ ...req, verify: isVerify }); + const ret = await submitProviderInstance(req, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -734,7 +731,7 @@ export const useSubmitMinerU = () => { return res; } }, - [addLlm, hideMineruModal, setSaveLoading], + [submitProviderInstance, hideMineruModal, setSaveLoading], ); return { @@ -748,7 +745,7 @@ export const useSubmitMinerU = () => { export const useSubmitPaddleOCR = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: paddleocrVisible, hideModal: hidePaddleOCRModal, @@ -763,7 +760,9 @@ export const useSubmitPaddleOCR = () => { const cfg: any = { ...payload, }; - const req: IAddLlmRequestBody = { + delete cfg.instance_name; + const req: IAddProviderInstanceRequestBody = { + instance_name: payload.instance_name, llm_factory: LLMFactory.PaddleOCR, llm_name: payload.llm_name, model_type: 'ocr', @@ -771,7 +770,7 @@ export const useSubmitPaddleOCR = () => { api_base: '', max_tokens: 0, }; - const ret = await addLlm({ ...req, verify: isVerify }); + const ret = await submitProviderInstance(req, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -796,7 +795,7 @@ export const useSubmitPaddleOCR = () => { } return false; }, - [addLlm, hidePaddleOCRModal, setSaveLoading], + [submitProviderInstance, hidePaddleOCRModal, setSaveLoading], ); return { @@ -810,7 +809,7 @@ export const useSubmitPaddleOCR = () => { export const 
useSubmitOpenDataLoader = () => { const [saveLoading, setSaveLoading] = useState(false); - const { addLlm } = useAddLlm(); + const submitProviderInstance = useSubmitProviderInstance(); const { visible: opendataloaderVisible, hideModal: hideOpenDataLoaderModal, @@ -822,15 +821,18 @@ export const useSubmitOpenDataLoader = () => { if (!isVerify) { setSaveLoading(true); } - const req: IAddLlmRequestBody = { + const cfg: any = { ...payload }; + delete cfg.instance_name; + const req: IAddProviderInstanceRequestBody = { + instance_name: payload.instance_name, llm_factory: LLMFactory.OpenDataLoader, llm_name: payload.llm_name, model_type: 'ocr', - api_key: { ...payload }, + api_key: cfg, api_base: '', max_tokens: 0, }; - const ret = await addLlm({ ...req, verify: isVerify }); + const ret = await submitProviderInstance(req, isVerify); if (!isVerify) { setSaveLoading(false); if (ret.code === 0) { @@ -846,7 +848,7 @@ export const useSubmitOpenDataLoader = () => { } return false; }, - [addLlm, hideOpenDataLoaderModal, setSaveLoading], + [submitProviderInstance, hideOpenDataLoaderModal, setSaveLoading], ); return { @@ -867,7 +869,7 @@ export const useVerifySettings = ({ isVerify?: boolean, ) => Promise) | (( - payload: IAddLlmRequestBody, + payload: IAddProviderInstanceRequestBody, isVerify?: boolean, ) => Promise) | (( diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index 39f490febd..f32e8fd26c 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -1,6 +1,6 @@ import Spotlight from '@/components/spotlight'; import { LLMFactory } from '@/constants/llm'; -import { LlmItem, useFetchMyLlmListDetailed } from '@/hooks/use-llm-request'; +// import { LlmItem, useFetchMyLlmListDetailed } from '@/hooks/use-llm-request'; import { useCallback, useMemo } from 'react'; import { isLocalLlmFactory } from '../utils'; import SystemSetting from 
'./components/system-setting'; @@ -17,7 +17,6 @@ import { useSubmitOpenDataLoader, useSubmitPaddleOCR, useSubmitSpark, - useSubmitSystemModelSetting, useSubmitTencentCloud, useSubmitVolcEngine, useSubmityiyan, @@ -37,9 +36,7 @@ import SparkModal from './modal/spark-modal'; import VolcEngineModal from './modal/volcengine-modal'; import YiyanModal from './modal/yiyan-modal'; const ModelProviders = () => { - const { saveSystemModelSettingLoading, onSystemSettingSavingOk } = - useSubmitSystemModelSetting(); - const { data: detailedLlmList } = useFetchMyLlmListDetailed(); + // const { data: detailedLlmList } = useFetchMyLlmListDetailed(); const { saveApiKeyLoading, initialApiKey, @@ -192,31 +189,31 @@ const ModelProviders = () => { [showApiKeyModal, showLlmAddingModal, ModalMap], ); - const handleEditModel = useCallback( - (model: any, factory: LlmItem) => { - if (factory) { - const detailedFactory = detailedLlmList[factory.name]; - const detailedModel = detailedFactory?.llm?.find( - (m: any) => m.name === model.name, - ); + // const handleEditModel = useCallback( + // (model: any, factory: LlmItem) => { + // if (factory) { + // const detailedFactory = detailedLlmList[factory.name]; + // const detailedModel = detailedFactory?.llm?.find( + // (m: any) => m.name === model.name, + // ); - const editData = { - llm_factory: factory.name, - llm_name: model.name, - model_type: model.type, - }; + // const editData = { + // llm_factory: factory.name, + // llm_name: model.name, + // model_type: model.type, + // }; - if (isLocalLlmFactory(factory.name)) { - showLlmAddingModal(factory.name, true, editData, detailedModel); - } else if (factory.name in ModalMap) { - ModalMap[factory.name as keyof typeof ModalMap](); - } else { - showApiKeyModal(editData, true); - } - } - }, - [showApiKeyModal, showLlmAddingModal, ModalMap, detailedLlmList], - ); + // if (isLocalLlmFactory(factory.name)) { + // showLlmAddingModal(factory.name, true, editData, detailedModel); + // } else if 
(factory.name in ModalMap) { + // ModalMap[factory.name as keyof typeof ModalMap](); + // } else { + // showApiKeyModal(editData, true); + // } + // } + // }, + // [showApiKeyModal, showLlmAddingModal, ModalMap, detailedLlmList], + // ); const handleOk = useMemo(() => { if (apiKeyVisible) { @@ -296,14 +293,8 @@ const ModelProviders = () => {
    - - + +
    diff --git a/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx index ff7a559e26..af92c74f9c 100644 --- a/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/api-key-modal/index.tsx @@ -1,12 +1,6 @@ import { IModalManagerChildrenProps } from '@/components/modal-manager'; -import { - Form, - FormControl, - FormField, - FormItem, - FormLabel, - FormMessage, -} from '@/components/ui/form'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Form } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; import { Modal } from '@/components/ui/modal/modal'; import { LLMFactory } from '@/constants/llm'; @@ -31,6 +25,7 @@ interface IProps extends Omit { } type FieldType = { + instance_name?: string; api_key?: string; base_url?: string; group_id?: string; @@ -50,7 +45,7 @@ const ApiKeyModal = ({ llmFactory, loading, initialValue, - editMode = false, + // editMode = false, onOk, onVerify, }: IProps) => { @@ -92,108 +87,100 @@ const ApiKeyModal = ({ >
    - ( - - - {t('apiKey')} - - - - - - + + {(field) => ( + )} - /> + + + + {(field) => ( + + )} + {modelsWithBaseUrl.some((x) => x === llmFactory) && ( - ( - - - {t('baseUrl')} - - - - - - + label={t('baseUrl')} + tooltip={ + llmFactory === LLMFactory.MiniMax + ? t('minimaxBaseUrlTip') + : llmFactory === LLMFactory.TongYiQianWen + ? t('tongyiBaseUrlTip') + : llmFactory === LLMFactory.SILICONFLOW + ? t('siliconBaseUrlTip') + : t('baseUrlTip') + } + labelClassName="text-sm font-medium text-text-primary" + > + {(field) => ( + )} - /> + )} {llmFactory?.toLowerCase() === 'Anthropic'.toLowerCase() && ( - ( - - - {t('baseUrl')} - - - - - - + label={t('baseUrl')} + labelClassName="text-sm font-medium text-text-primary" + > + {(field) => ( + )} - /> + )} {llmFactory?.toLowerCase() === 'Minimax'.toLowerCase() && ( - ( - - - Group ID - - - - - - - )} - /> + label="Group ID" + labelClassName="text-sm font-medium text-text-primary" + > + {(field) => } + )} diff --git a/web/src/pages/user-setting/setting-model/modal/azure-openai-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/azure-openai-modal/index.tsx index b8442d29a9..3ed6e54373 100644 --- a/web/src/pages/user-setting/setting-model/modal/azure-openai-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/azure-openai-modal/index.tsx @@ -8,9 +8,13 @@ import { Modal } from '@/components/ui/modal/modal'; import { useCommonTranslation, useTranslate } from '@/hooks/common-hooks'; import { useBuildModelTypeOptions } from '@/hooks/logic-hooks/use-build-options'; import { IModalProps } from '@/interfaces/common'; -import { IAddLlmRequestBody } from '@/interfaces/request/llm'; -import { VerifyResult } from '@/pages/user-setting/setting-model/hooks'; -import { memo, useCallback, useRef } from 'react'; +import { IAddProviderInstanceRequestBody } from '@/interfaces/request/llm'; +import { + useFetchInstanceNameSet, + useHideWhenInstanceExists, + VerifyResult, +} from 
'@/pages/user-setting/setting-model/hooks'; +import { memo, useCallback, useMemo, useRef } from 'react'; import { FieldValues } from 'react-hook-form'; import { LLMHeader } from '../../components/llm-header'; import VerifyButton from '../../modal/verify-button'; @@ -22,7 +26,7 @@ const AzureOpenAIModal = ({ onVerify, loading, llmFactory, -}: IModalProps & { +}: IModalProps & { llmFactory: string; onVerify?: ( postBody: any, @@ -32,87 +36,107 @@ const AzureOpenAIModal = ({ const { t: tg } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); const formRef = useRef(null); + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); - const fields: FormFieldConfig[] = [ - { - name: 'model_type', - label: t('modelType'), - type: FormFieldType.Select, - required: true, - options: buildModelTypeOptions(['chat', 'embedding', 'image2text']), - defaultValue: 'embedding', - validation: { - message: t('modelTypeMessage'), + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); + + const fields: FormFieldConfig[] = useMemo( + () => [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + tooltip: t('instanceNameTip'), + validation: { + message: t('instanceNameMessage'), + }, }, - }, - { - name: 'api_base', - label: t('addLlmBaseUrl'), - type: FormFieldType.Text, - required: true, - placeholder: t('baseUrlNameMessage'), - validation: { - message: t('baseUrlNameMessage'), + { + name: 'model_type', + label: t('modelType'), + type: FormFieldType.MultiSelect, + required: true, + options: buildModelTypeOptions(['chat', 'embedding', 'image2text']), + defaultValue: ['embedding'], }, - }, - { - name: 'api_key', - label: t('apiKey'), - type: FormFieldType.Text, - required: false, - placeholder: t('apiKeyMessage'), - }, - { - name: 'llm_name', - label: t('modelName'), - type: FormFieldType.Text, - required: true, - placeholder: 
t('modelNameMessage'), - defaultValue: 'gpt-3.5-turbo', - validation: { - message: t('modelNameMessage'), + { + name: 'api_base', + label: t('addLlmBaseUrl'), + type: FormFieldType.Text, + required: true, + placeholder: t('baseUrlNameMessage'), + validation: { + message: t('baseUrlNameMessage'), + }, + shouldRender: hideWhenInstanceExists, }, - }, - { - name: 'api_version', - label: t('apiVersion'), - type: FormFieldType.Text, - required: false, - placeholder: t('apiVersionMessage'), - defaultValue: '2024-02-01', - }, - { - name: 'max_tokens', - label: t('maxTokens'), - type: FormFieldType.Number, - required: true, - placeholder: t('maxTokensTip'), - validation: { - min: 0, - message: t('maxTokensMessage'), + { + name: 'api_key', + label: t('apiKey'), + type: FormFieldType.Text, + required: false, + placeholder: t('apiKeyMessage'), + shouldRender: hideWhenInstanceExists, }, - }, - { - name: 'vision', - label: t('vision'), - type: FormFieldType.Switch, - defaultValue: false, - dependencies: ['model_type'], - shouldRender: (formValues: any) => { - return formValues?.model_type === 'chat'; + { + name: 'llm_name', + label: t('modelName'), + type: FormFieldType.Text, + required: true, + placeholder: t('modelNameMessage'), + defaultValue: 'gpt-3.5-turbo', + validation: { + message: t('modelNameMessage'), + }, }, - }, - ]; + { + name: 'api_version', + label: t('apiVersion'), + type: FormFieldType.Text, + required: false, + placeholder: t('apiVersionMessage'), + defaultValue: '2024-02-01', + }, + { + name: 'max_tokens', + label: t('maxTokens'), + type: FormFieldType.Number, + required: true, + placeholder: t('maxTokensTip'), + validation: { + min: 0, + message: t('maxTokensMessage'), + }, + }, + { + name: 'vision', + label: t('vision'), + type: FormFieldType.Switch, + defaultValue: false, + dependencies: ['model_type'], + shouldRender: (formValues: any) => { + const modelType = formValues?.model_type; + if (Array.isArray(modelType)) { + return modelType.includes('chat'); + 
} + return modelType === 'chat'; + }, + }, + ], + [t, buildModelTypeOptions, hideWhenInstanceExists], + ); const handleOk = async (values?: FieldValues) => { if (!values) return; - const modelType = - values.model_type === 'chat' && values.vision - ? 'image2text' - : values.model_type; + const modelType = values.model_type.map((t: string) => + t === 'chat' && values.vision ? 'image2text' : t, + ); - const data: IAddLlmRequestBody & { api_version?: string } = { + const data: IAddProviderInstanceRequestBody & { api_version?: string } = { + instance_name: values.instance_name as string, llm_factory: llmFactory, llm_name: values.llm_name as string, model_type: modelType, @@ -127,13 +151,11 @@ const AzureOpenAIModal = ({ const verifyParamsFunc = useCallback(() => { const values = formRef.current?.getValues(); - const modelType = - values.model_type === 'chat' && values.vision - ? 'image2text' - : values.model_type; return { llm_factory: llmFactory, - model_type: modelType, + model_type: values.model_type.map((t: string) => + t === 'chat' && values.vision ? 
'image2text' : t, + ), }; }, [llmFactory]); @@ -162,10 +184,12 @@ const AzureOpenAIModal = ({ ref={formRef} defaultValues={ { - model_type: 'embedding', + instance_name: '', + model_type: ['embedding'], llm_name: 'gpt-3.5-turbo', api_version: '2024-02-01', vision: false, + max_tokens: 8192, } as FieldValues } labelClassName="font-normal" diff --git a/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx index f9d4f2b3fe..7103103878 100644 --- a/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx @@ -4,26 +4,31 @@ import { ButtonLoading } from '@/components/ui/button'; import { Form } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; import { Modal } from '@/components/ui/modal/modal'; +import { MultiSelect } from '@/components/ui/multi-select'; import { Segmented } from '@/components/ui/segmented'; import { useCommonTranslation, useTranslate } from '@/hooks/common-hooks'; import { useBuildModelTypeOptions } from '@/hooks/logic-hooks/use-build-options'; import { IModalProps } from '@/interfaces/common'; -import { IAddLlmRequestBody } from '@/interfaces/request/llm'; -import { VerifyResult } from '@/pages/user-setting/setting-model/hooks'; +import { IAddProviderInstanceRequestBody } from '@/interfaces/request/llm'; +import { + useFetchInstanceNameSet, + VerifyResult, +} from '@/pages/user-setting/setting-model/hooks'; import { zodResolver } from '@hookform/resolvers/zod'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback, useMemo, useRef } from 'react'; import { useForm, useWatch } from 'react-hook-form'; import { z } from 'zod'; import { LLMHeader } from '../../components/llm-header'; import { BedrockRegionList } from '../../constant'; import VerifyButton from '../../modal/verify-button'; -type FieldType = IAddLlmRequestBody & 
{ +type FieldType = Omit & { auth_mode?: 'access_key_secret' | 'iam_role' | 'assume_role'; bedrock_ak: string; bedrock_sk: string; bedrock_region: string; aws_role_arn?: string; + model_type: string[]; }; const BedrockModal = ({ @@ -33,7 +38,7 @@ const BedrockModal = ({ onVerify, loading, llmFactory, -}: IModalProps & { +}: IModalProps & { llmFactory: string; onVerify?: ( postBody: any, @@ -42,61 +47,84 @@ const BedrockModal = ({ const { t } = useTranslate('setting'); const { t: ct } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); + const instanceExistsRef = useRef(false); - const FormSchema = z - .object({ - model_type: z.enum(['chat', 'embedding'], { - required_error: t('modelTypeMessage'), - }), - llm_name: z.string().min(1, { message: t('bedrockModelNameMessage') }), - bedrock_region: z.string().min(1, { message: t('bedrockRegionMessage') }), - max_tokens: z - .number({ - required_error: t('maxTokensMessage'), - invalid_type_error: t('maxTokensInvalidMessage'), + const FormSchema = useMemo( + () => + z + .object({ + instance_name: z + .string() + .min(1, { message: t('instanceNameMessage') }), + model_type: z + .array(z.enum(['chat', 'embedding'])) + .min(1, { message: t('modelTypeMessage') }), + llm_name: z + .string() + .min(1, { message: t('bedrockModelNameMessage') }), + bedrock_region: z.string().optional(), + max_tokens: z + .number({ + required_error: t('maxTokensMessage'), + invalid_type_error: t('maxTokensInvalidMessage'), + }) + .nonnegative({ message: t('maxTokensMinMessage') }), + auth_mode: z + .enum(['access_key_secret', 'iam_role', 'assume_role']) + .default('access_key_secret'), + bedrock_ak: z.string().optional(), + bedrock_sk: z.string().optional(), + aws_role_arn: z.string().optional(), }) - .nonnegative({ message: t('maxTokensMinMessage') }), - auth_mode: z - .enum(['access_key_secret', 'iam_role', 'assume_role']) - 
.default('access_key_secret'), - bedrock_ak: z.string().optional(), - bedrock_sk: z.string().optional(), - aws_role_arn: z.string().optional(), - }) - .superRefine((data, ctx) => { - if (data.auth_mode === 'access_key_secret') { - if (!data.bedrock_ak || data.bedrock_ak.trim() === '') { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: t('bedrockAKMessage'), - path: ['bedrock_ak'], - }); - } - if (!data.bedrock_sk || data.bedrock_sk.trim() === '') { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: t('bedrockSKMessage'), - path: ['bedrock_sk'], - }); - } - } + .superRefine((data, ctx) => { + if (instanceExistsRef.current) return; - if (data.auth_mode === 'iam_role') { - if (!data.aws_role_arn || data.aws_role_arn.trim() === '') { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: t('awsRoleArnMessage'), - path: ['aws_role_arn'], - }); - } - } - }); + if (!data.bedrock_region || data.bedrock_region.trim() === '') { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('bedrockRegionMessage'), + path: ['bedrock_region'], + }); + } + + if (data.auth_mode === 'access_key_secret') { + if (!data.bedrock_ak || data.bedrock_ak.trim() === '') { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('bedrockAKMessage'), + path: ['bedrock_ak'], + }); + } + if (!data.bedrock_sk || data.bedrock_sk.trim() === '') { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('bedrockSKMessage'), + path: ['bedrock_sk'], + }); + } + } + + if (data.auth_mode === 'iam_role') { + if (!data.aws_role_arn || data.aws_role_arn.trim() === '') { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('awsRoleArnMessage'), + path: ['aws_role_arn'], + }); + } + } + }), + [t], + ); const form = useForm({ resolver: zodResolver(FormSchema), defaultValues: { - model_type: 'chat', + instance_name: '', + model_type: ['chat'], auth_mode: 'access_key_secret', + max_tokens: 8192, }, }); @@ -105,6 +133,17 @@ const BedrockModal = ({ name: 'auth_mode', }); + 
const instanceName = useWatch({ + control: form.control, + name: 'instance_name', + }); + + const instanceExists = useMemo(() => { + const trimmed = (instanceName || '').trim(); + return !!trimmed && instanceNameSet.has(trimmed); + }, [instanceName, instanceNameSet]); + instanceExistsRef.current = instanceExists; + const options = useMemo( () => BedrockRegionList.map((x) => ({ value: x, label: t(x) })), [t], @@ -133,9 +172,10 @@ const BedrockModal = ({ ...cleanedValues, llm_factory: llmFactory, max_tokens: values.max_tokens, + model_type: values.model_type, }; - onOk?.(data as unknown as IAddLlmRequestBody); + onOk?.(data as unknown as IAddProviderInstanceRequestBody); }; const verifyParamsFunc = useCallback(() => { @@ -160,6 +200,7 @@ const BedrockModal = ({ ...cleanedValues, llm_factory: llmFactory, max_tokens: values.max_tokens, + model_type: values.model_type, }; }, [llmFactory, authMode, form]); @@ -199,13 +240,24 @@ const BedrockModal = ({ className="space-y-6" id="bedrock-form" > + + + + {(field) => ( - )} @@ -214,86 +266,93 @@ const BedrockModal = ({ -
    - - {(field) => ( - { - // Clear non-active fields so they won't be validated/submitted by accident. - if (value !== 'access_key_secret') { - form.setValue('bedrock_ak', ''); - form.setValue('bedrock_sk', ''); - } - if (value !== 'iam_role') { - form.setValue('aws_role_arn', ''); - } - field.onChange(value); - }} - options={[ - { - label: t('awsAuthModeAccessKeySecret'), - value: 'access_key_secret', - }, - { label: t('awsAuthModeIamRole'), value: 'iam_role' }, - { label: t('awsAuthModeAssumeRole'), value: 'assume_role' }, - ]} - /> - )} - -
    - - {authMode === 'access_key_secret' && ( + {!instanceExists && ( <> +
    + + {(field) => ( + { + // Clear non-active fields so they won't be validated/submitted by accident. + if (value !== 'access_key_secret') { + form.setValue('bedrock_ak', ''); + form.setValue('bedrock_sk', ''); + } + if (value !== 'iam_role') { + form.setValue('aws_role_arn', ''); + } + field.onChange(value); + }} + options={[ + { + label: t('awsAuthModeAccessKeySecret'), + value: 'access_key_secret', + }, + { label: t('awsAuthModeIamRole'), value: 'iam_role' }, + { + label: t('awsAuthModeAssumeRole'), + value: 'assume_role', + }, + ]} + /> + )} + +
    + + {authMode === 'access_key_secret' && ( + <> + + + + + + + + )} + + {authMode === 'iam_role' && ( + + + + )} + + {authMode === 'assume_role' && ( +
    + {t('awsAssumeRoleTip')} +
    + )} + - - - - + {(field) => ( + + )} )} - {authMode === 'iam_role' && ( - - - - )} - - {authMode === 'assume_role' && ( -
    - {t('awsAssumeRoleTip')} -
    - )} - - - {(field) => ( - - )} - - {(field) => ( & { +}: IModalProps & { llmFactory: string; onVerify?: ( postBody: any, @@ -30,58 +34,78 @@ const FishAudioModal = ({ const { t } = useTranslate('setting'); const { t: tc } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); - const fields: FormFieldConfig[] = [ - { - name: 'model_type', - label: t('modelType'), - type: FormFieldType.Select, - required: true, - options: buildModelTypeOptions(['tts']), - defaultValue: 'tts', - validation: { message: t('modelTypeMessage') }, - }, - { - name: 'llm_name', - label: t('modelName'), - type: FormFieldType.Text, - required: true, - placeholder: t('FishAudioModelNameMessage'), - validation: { message: t('FishAudioModelNameMessage') }, - }, - { - name: 'fish_audio_ak', - label: t('addFishAudioAK'), - type: FormFieldType.Text, - required: true, - placeholder: t('FishAudioAKMessage'), - validation: { message: t('FishAudioAKMessage') }, - }, - { - name: 'fish_audio_refid', - label: t('addFishAudioRefID'), - type: FormFieldType.Text, - required: true, - placeholder: t('FishAudioRefIDMessage'), - validation: { message: t('FishAudioRefIDMessage') }, - }, - { - name: 'max_tokens', - label: t('maxTokens'), - type: FormFieldType.Number, - required: true, - placeholder: t('maxTokensTip'), - validation: { - min: 0, - message: t('maxTokensInvalidMessage'), + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); + + const fields: FormFieldConfig[] = useMemo( + () => [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + tooltip: t('instanceNameTip'), + validation: { message: t('instanceNameMessage') }, }, - }, - ]; + { + name: 'model_type', + label: t('modelType'), + type: FormFieldType.MultiSelect, + required: true, + options: buildModelTypeOptions(['tts']), + 
defaultValue: ['tts'], + }, + { + name: 'llm_name', + label: t('modelName'), + type: FormFieldType.Text, + required: true, + placeholder: t('FishAudioModelNameMessage'), + validation: { message: t('FishAudioModelNameMessage') }, + }, + { + name: 'fish_audio_ak', + label: t('addFishAudioAK'), + type: FormFieldType.Text, + required: true, + placeholder: t('FishAudioAKMessage'), + validation: { message: t('FishAudioAKMessage') }, + shouldRender: hideWhenInstanceExists, + }, + { + name: 'fish_audio_refid', + label: t('addFishAudioRefID'), + type: FormFieldType.Text, + required: true, + placeholder: t('FishAudioRefIDMessage'), + validation: { message: t('FishAudioRefIDMessage') }, + shouldRender: hideWhenInstanceExists, + }, + { + name: 'max_tokens', + label: t('maxTokens'), + type: FormFieldType.Number, + required: true, + placeholder: t('maxTokensTip'), + validation: { + min: 0, + message: t('maxTokensInvalidMessage'), + }, + }, + ], + [t, buildModelTypeOptions, hideWhenInstanceExists], + ); const handleOk = async (values?: FieldValues) => { if (!values) return; - const data: Record = { + const data: IAddProviderInstanceRequestBody & { + fish_audio_ak: string; + fish_audio_refid: string; + } = { + instance_name: values.instance_name as string, llm_factory: llmFactory, llm_name: values.llm_name as string, model_type: values.model_type, @@ -90,8 +114,7 @@ const FishAudioModal = ({ max_tokens: values.max_tokens as number, }; - console.info(data); - await onOk?.(data as IAddLlmRequestBody); + await onOk?.(data); }; const handleVerify = useCallback( @@ -114,7 +137,11 @@ const FishAudioModal = ({ console.log(data)} - defaultValues={{ model_type: 'tts' }} + defaultValues={{ + instance_name: '', + model_type: ['tts'], + max_tokens: 8192, + }} labelClassName="font-normal" > {onVerify && ( diff --git a/web/src/pages/user-setting/setting-model/modal/google-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/google-modal/index.tsx index 322668e732..289d995f14 100644 
--- a/web/src/pages/user-setting/setting-model/modal/google-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/google-modal/index.tsx @@ -7,9 +7,13 @@ import { Modal } from '@/components/ui/modal/modal'; import { useCommonTranslation, useTranslate } from '@/hooks/common-hooks'; import { useBuildModelTypeOptions } from '@/hooks/logic-hooks/use-build-options'; import { IModalProps } from '@/interfaces/common'; -import { IAddLlmRequestBody } from '@/interfaces/request/llm'; -import { VerifyResult } from '@/pages/user-setting/setting-model/hooks'; -import { memo, useCallback } from 'react'; +import { IAddProviderInstanceRequestBody } from '@/interfaces/request/llm'; +import { + useFetchInstanceNameSet, + useHideWhenInstanceExists, + VerifyResult, +} from '@/pages/user-setting/setting-model/hooks'; +import { memo, useCallback, useMemo } from 'react'; import { FieldValues } from 'react-hook-form'; import { LLMHeader } from '../../components/llm-header'; import VerifyButton from '../../modal/verify-button'; @@ -21,7 +25,7 @@ const GoogleModal = ({ onVerify, loading, llmFactory, -}: IModalProps & { +}: IModalProps & { llmFactory: string; onVerify?: ( postBody: any, @@ -30,85 +34,101 @@ const GoogleModal = ({ const { t } = useTranslate('setting'); const { t: tc } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); - const fields: FormFieldConfig[] = [ - { - name: 'model_type', - label: t('modelType'), - type: FormFieldType.Select, - required: true, - options: buildModelTypeOptions(['chat', 'image2text']), - defaultValue: 'chat', - validation: { - message: t('modelTypeMessage'), + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); + + const fields: FormFieldConfig[] = useMemo( + () => [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + 
tooltip: t('instanceNameTip'), + validation: { message: t('instanceNameMessage') }, }, - }, - { - name: 'llm_name', - label: t('modelID'), - type: FormFieldType.Text, - required: true, - placeholder: t('GoogleModelIDMessage'), - validation: { - message: t('GoogleModelIDMessage'), + { + name: 'model_type', + label: t('modelType'), + type: FormFieldType.MultiSelect, + required: true, + options: buildModelTypeOptions(['chat', 'image2text']), + defaultValue: ['chat'], }, - }, - { - name: 'google_project_id', - label: t('addGoogleProjectID'), - type: FormFieldType.Text, - required: true, - placeholder: t('GoogleProjectIDMessage'), - validation: { - message: t('GoogleProjectIDMessage'), + { + name: 'llm_name', + label: t('modelID'), + type: FormFieldType.Text, + required: true, + placeholder: t('GoogleModelIDMessage'), + validation: { + message: t('GoogleModelIDMessage'), + }, }, - }, - { - name: 'google_region', - label: t('addGoogleRegion'), - type: FormFieldType.Text, - required: true, - placeholder: t('GoogleRegionMessage'), - validation: { - message: t('GoogleRegionMessage'), + { + name: 'google_project_id', + label: t('addGoogleProjectID'), + type: FormFieldType.Text, + required: true, + placeholder: t('GoogleProjectIDMessage'), + validation: { + message: t('GoogleProjectIDMessage'), + }, + shouldRender: hideWhenInstanceExists, }, - }, - { - name: 'google_service_account_key', - label: t('addGoogleServiceAccountKey'), - type: FormFieldType.Text, - required: true, - placeholder: t('GoogleServiceAccountKeyMessage'), - validation: { - message: t('GoogleServiceAccountKeyMessage'), + { + name: 'google_region', + label: t('addGoogleRegion'), + type: FormFieldType.Text, + required: true, + placeholder: t('GoogleRegionMessage'), + validation: { + message: t('GoogleRegionMessage'), + }, + shouldRender: hideWhenInstanceExists, }, - }, - { - name: 'max_tokens', - label: t('maxTokens'), - type: FormFieldType.Number, - required: true, - placeholder: t('maxTokensTip'), - 
validation: { - min: 0, - message: t('maxTokensMinMessage'), + { + name: 'google_service_account_key', + label: t('addGoogleServiceAccountKey'), + type: FormFieldType.Text, + required: true, + placeholder: t('GoogleServiceAccountKeyMessage'), + validation: { + message: t('GoogleServiceAccountKeyMessage'), + }, + shouldRender: hideWhenInstanceExists, }, - customValidate: (value: any) => { - if (value === undefined || value === null || value === '') { - return t('maxTokensMessage'); - } - if (value < 0) { - return t('maxTokensMinMessage'); - } - return true; + { + name: 'max_tokens', + label: t('maxTokens'), + type: FormFieldType.Number, + required: true, + placeholder: t('maxTokensTip'), + validation: { + min: 0, + message: t('maxTokensMinMessage'), + }, + customValidate: (value: any) => { + if (value === undefined || value === null || value === '') { + return t('maxTokensMessage'); + } + if (value < 0) { + return t('maxTokensMinMessage'); + } + return true; + }, }, - }, - ]; + ], + [t, buildModelTypeOptions, hideWhenInstanceExists], + ); const handleOk = async (values?: FieldValues) => { if (!values) return; const data = { + instance_name: values.instance_name as string, llm_factory: llmFactory, model_type: values.model_type, llm_name: values.llm_name, @@ -116,7 +136,7 @@ const GoogleModal = ({ google_region: values.google_region, google_service_account_key: values.google_service_account_key, max_tokens: values.max_tokens, - } as IAddLlmRequestBody; + } as IAddProviderInstanceRequestBody; await onOk?.(data); }; @@ -150,7 +170,9 @@ const GoogleModal = ({ }} defaultValues={ { - model_type: 'chat', + instance_name: '', + model_type: ['chat'], + max_tokens: 8192, } as FieldValues } labelClassName="font-normal" diff --git a/web/src/pages/user-setting/setting-model/modal/mineru-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/mineru-modal/index.tsx index 83b4ed7121..cc94eb40a6 100644 --- 
a/web/src/pages/user-setting/setting-model/modal/mineru-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/mineru-modal/index.tsx @@ -25,6 +25,9 @@ import { LLMHeader } from '../../components/llm-header'; import VerifyButton from '../verify-button'; const FormSchema = z.object({ + instance_name: z.string().min(1, { + message: t('setting.instanceNameMessage'), + }), llm_name: z.string().min(1, { message: t('setting.mineru.modelNameRequired'), }), @@ -71,6 +74,7 @@ const MinerUModal = ({ const form = useForm({ resolver: zodResolver(FormSchema), defaultValues: { + instance_name: '', mineru_backend: 'pipeline', mineru_delete_output: true, }, @@ -102,6 +106,14 @@ const MinerUModal = ({ className="space-y-6" id="mineru-form" > + + + > & { +}: IModalProps> & { llmFactory: string; onVerify?: ( postBody: any, @@ -30,87 +34,100 @@ const TencentCloudModal = ({ const { t } = useTranslate('setting'); const { t: tc } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); - const fields: FormFieldConfig[] = [ - { - name: 'model_type', - label: t('modelType'), - type: FormFieldType.Select, - required: true, - options: buildModelTypeOptions(['speech2text']), - defaultValue: 'speech2text', - validation: { - message: t('modelTypeMessage'), + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); + + const fields: FormFieldConfig[] = useMemo( + () => [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + tooltip: t('instanceNameTip'), + validation: { message: t('instanceNameMessage') }, }, - }, - { - name: 'llm_name', - label: t('modelName'), - type: FormFieldType.Select, - required: true, - options: [ - { label: '16k_zh', value: '16k_zh' }, - { label: '16k_zh_large', value: '16k_zh_large' }, - { label: '16k_multi_lang', value: '16k_multi_lang' }, - { label: 
'16k_zh_dialect', value: '16k_zh_dialect' }, - { label: '16k_en', value: '16k_en' }, - { label: '16k_yue', value: '16k_yue' }, - { label: '16k_zh-PY', value: '16k_zh-PY' }, - { label: '16k_ja', value: '16k_ja' }, - { label: '16k_ko', value: '16k_ko' }, - { label: '16k_vi', value: '16k_vi' }, - { label: '16k_ms', value: '16k_ms' }, - { label: '16k_id', value: '16k_id' }, - { label: '16k_fil', value: '16k_fil' }, - { label: '16k_th', value: '16k_th' }, - { label: '16k_pt', value: '16k_pt' }, - { label: '16k_tr', value: '16k_tr' }, - { label: '16k_ar', value: '16k_ar' }, - { label: '16k_es', value: '16k_es' }, - { label: '16k_hi', value: '16k_hi' }, - { label: '16k_fr', value: '16k_fr' }, - { label: '16k_zh_medical', value: '16k_zh_medical' }, - { label: '16k_de', value: '16k_de' }, - ], - defaultValue: '16k_zh', - validation: { - message: t('SparkModelNameMessage'), + { + name: 'model_type', + label: t('modelType'), + type: FormFieldType.MultiSelect, + required: true, + options: buildModelTypeOptions(['speech2text']), + defaultValue: ['speech2text'], }, - }, - { - name: 'TencentCloud_sid', - label: t('addTencentCloudSID'), - type: FormFieldType.Text, - required: true, - placeholder: t('TencentCloudSIDMessage'), - validation: { - message: t('TencentCloudSIDMessage'), + { + name: 'llm_name', + label: t('modelName'), + type: FormFieldType.Select, + required: true, + options: [ + { label: '16k_zh', value: '16k_zh' }, + { label: '16k_zh_large', value: '16k_zh_large' }, + { label: '16k_multi_lang', value: '16k_multi_lang' }, + { label: '16k_zh_dialect', value: '16k_zh_dialect' }, + { label: '16k_en', value: '16k_en' }, + { label: '16k_yue', value: '16k_yue' }, + { label: '16k_zh-PY', value: '16k_zh-PY' }, + { label: '16k_ja', value: '16k_ja' }, + { label: '16k_ko', value: '16k_ko' }, + { label: '16k_vi', value: '16k_vi' }, + { label: '16k_ms', value: '16k_ms' }, + { label: '16k_id', value: '16k_id' }, + { label: '16k_fil', value: '16k_fil' }, + { label: '16k_th', value: 
'16k_th' }, + { label: '16k_pt', value: '16k_pt' }, + { label: '16k_tr', value: '16k_tr' }, + { label: '16k_ar', value: '16k_ar' }, + { label: '16k_es', value: '16k_es' }, + { label: '16k_hi', value: '16k_hi' }, + { label: '16k_fr', value: '16k_fr' }, + { label: '16k_zh_medical', value: '16k_zh_medical' }, + { label: '16k_de', value: '16k_de' }, + ], + defaultValue: '16k_zh', + validation: { + message: t('SparkModelNameMessage'), + }, }, - }, - { - name: 'TencentCloud_sk', - label: t('addTencentCloudSK'), - type: FormFieldType.Text, - required: true, - placeholder: t('TencentCloudSKMessage'), - validation: { - message: t('TencentCloudSKMessage'), + { + name: 'TencentCloud_sid', + label: t('addTencentCloudSID'), + type: FormFieldType.Text, + required: true, + placeholder: t('TencentCloudSIDMessage'), + validation: { + message: t('TencentCloudSIDMessage'), + }, + shouldRender: hideWhenInstanceExists, }, - }, - ]; + { + name: 'TencentCloud_sk', + label: t('addTencentCloudSK'), + type: FormFieldType.Text, + required: true, + placeholder: t('TencentCloudSKMessage'), + validation: { + message: t('TencentCloudSKMessage'), + }, + shouldRender: hideWhenInstanceExists, + }, + ], + [t, buildModelTypeOptions, hideWhenInstanceExists], + ); const handleOk = async (values?: FieldValues) => { if (!values) return; - const modelType = values.model_type; - const data = { - model_type: modelType, + instance_name: values.instance_name as string, + model_type: values.model_type, llm_name: values.llm_name as string, TencentCloud_sid: values.TencentCloud_sid as string, TencentCloud_sk: values.TencentCloud_sk as string, llm_factory: llmFactory, - } as Omit; + } as Omit; await onOk?.(data); }; @@ -143,7 +160,8 @@ const TencentCloudModal = ({ onSubmit={() => {}} defaultValues={ { - model_type: 'speech2text', + instance_name: '', + model_type: ['speech2text'], llm_name: '16k_zh', } as FieldValues } diff --git a/web/src/pages/user-setting/setting-model/modal/ollama-modal/index.tsx 
b/web/src/pages/user-setting/setting-model/modal/ollama-modal/index.tsx index 1b59ecb422..b345039c96 100644 --- a/web/src/pages/user-setting/setting-model/modal/ollama-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/ollama-modal/index.tsx @@ -9,8 +9,12 @@ import { LLMFactory } from '@/constants/llm'; import { useCommonTranslation, useTranslate } from '@/hooks/common-hooks'; import { useBuildModelTypeOptions } from '@/hooks/logic-hooks/use-build-options'; import { IModalProps } from '@/interfaces/common'; -import { IAddLlmRequestBody } from '@/interfaces/request/llm'; -import { VerifyResult } from '@/pages/user-setting/setting-model/hooks'; +import { IAddProviderInstanceRequestBody } from '@/interfaces/request/llm'; +import { + useFetchInstanceNameSet, + useHideWhenInstanceExists, + VerifyResult, +} from '@/pages/user-setting/setting-model/hooks'; import { memo, useCallback, useMemo, useRef } from 'react'; import { FieldValues } from 'react-hook-form'; import { LLMHeader } from '../../components/llm-header'; @@ -47,7 +51,9 @@ const OllamaModal = ({ llmFactory, editMode = false, initialValues, -}: IModalProps & { provider_order?: string }> & { +}: IModalProps< + Partial & { provider_order?: string } +> & { llmFactory: string; editMode?: boolean; onVerify?: ( @@ -58,6 +64,9 @@ const OllamaModal = ({ const { t: tc } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); const formRef = useRef(null); + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); + + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); const optionsMap: Partial< Record @@ -118,15 +127,23 @@ const OllamaModal = ({ const defaultToolCallEnabled = initialValues?.is_tools ?? 
false; const baseFields: FormFieldConfig[] = [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + tooltip: t('instanceNameTip'), + validation: { + message: t('instanceNameMessage'), + }, + }, { name: 'model_type', label: t('modelType'), - type: FormFieldType.Select, + type: FormFieldType.MultiSelect, required: true, options: getOptions(llmFactory), - validation: { - message: t('modelTypeMessage'), - }, }, { name: 'llm_name', @@ -147,6 +164,7 @@ const OllamaModal = ({ validation: { message: t('baseUrlNameMessage'), }, + shouldRender: hideWhenInstanceExists, }, { name: 'api_key', @@ -154,6 +172,7 @@ const OllamaModal = ({ type: FormFieldType.Text, required: false, placeholder: t('apiKeyMessage'), + shouldRender: hideWhenInstanceExists, }, { name: 'max_tokens', @@ -186,6 +205,9 @@ const OllamaModal = ({ dependencies: ['model_type'], shouldRender: (formValues: any) => { const modelType = formValues?.model_type; + if (Array.isArray(modelType)) { + return modelType.includes('chat') || modelType.includes('image2text'); + } return modelType === 'chat' || modelType === 'image2text'; }, tooltip: t('enableToolCallTip'), @@ -212,18 +234,25 @@ const OllamaModal = ({ required: false, dependencies: ['model_type'], shouldRender: (formValues: any) => { - return formValues?.model_type === 'chat'; + const modelType = formValues?.model_type; + if (Array.isArray(modelType)) { + return modelType.includes('chat'); + } + return modelType === 'chat'; }, }); return baseFields; - }, [llmFactory, t]); + }, [llmFactory, t, hideWhenInstanceExists, initialValues?.is_tools]); const defaultValues: FieldValues = useMemo(() => { if (editMode && initialValues) { return { + instance_name: initialValues.instance_name || '', llm_name: initialValues.llm_name || '', - model_type: initialValues.model_type || 'chat', + model_type: initialValues.model_type + ? 
initialValues.model_type.split(',').filter(Boolean) + : ['chat'], api_base: initialValues.api_base || '', max_tokens: initialValues.max_tokens || 8192, api_key: '', @@ -233,34 +262,42 @@ const OllamaModal = ({ }; } return { - model_type: + instance_name: '', + model_type: [ llmFactory === LLMFactory.Ollama || llmFactory === LLMFactory.VLLM ? 'chat' : llmFactory in optionsMap ? optionsMap[llmFactory as LLMFactory]?.at(0)?.value : 'embedding', + ], vision: false, is_tools: false, + max_tokens: 8192, }; }, [editMode, initialValues, llmFactory]); const handleOk = async (values?: FieldValues) => { if (!values) return; - const modelType = - values.model_type === 'chat' && values.vision - ? 'image2text' - : values.model_type; - const supportsToolCall = modelType === 'chat' || modelType === 'image2text'; + const modelType = values.model_type.map((t: string) => + t === 'chat' && values.vision ? 'image2text' : t, + ); + const modelTypeArray: string[] = Array.isArray(values.model_type) + ? values.model_type + : [values.model_type]; + const supportsToolCall = + modelTypeArray.includes('chat') || modelTypeArray.includes('image2text'); - const data: IAddLlmRequestBody & { provider_order?: string } = { - llm_factory: llmFactory, - llm_name: values.llm_name as string, - model_type: modelType, - api_base: values.api_base as string, - api_key: values.api_key as string, - max_tokens: values.max_tokens as number, - }; + const data: IAddProviderInstanceRequestBody & { provider_order?: string } = + { + instance_name: values.instance_name as string, + llm_factory: llmFactory, + llm_name: values.llm_name as string, + model_type: modelType, + api_base: values.api_base as string, + api_key: values.api_key as string, + max_tokens: values.max_tokens as number, + }; if (supportsToolCall) { data.is_tools = Boolean(values.is_tools); } @@ -275,13 +312,11 @@ const OllamaModal = ({ const verifyParamsFunc = useCallback(() => { const values = formRef.current?.getValues(); - const modelType = - 
values.model_type === 'chat' && values.vision - ? 'image2text' - : values.model_type; return { llm_factory: llmFactory, - model_type: modelType, + model_type: values.model_type.map((t: string) => + t === 'chat' && values.vision ? 'image2text' : t, + ), }; }, [llmFactory]); diff --git a/web/src/pages/user-setting/setting-model/modal/opendataloader-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/opendataloader-modal/index.tsx index 8d94219176..7d3c6e7933 100644 --- a/web/src/pages/user-setting/setting-model/modal/opendataloader-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/opendataloader-modal/index.tsx @@ -20,6 +20,7 @@ import { LLMHeader } from '../../components/llm-header'; import VerifyButton from '../verify-button'; export type OpenDataLoaderFormValues = { + instance_name: string; llm_name: string; opendataloader_apiserver: string; opendataloader_api_key?: string; @@ -47,6 +48,9 @@ const OpenDataLoaderModal = ({ const FormSchema = useMemo( () => z.object({ + instance_name: z.string().min(1, { + message: t('setting.instanceNameMessage'), + }), llm_name: z.string().min(1, { message: t('setting.modelNameMessage'), }), @@ -61,6 +65,7 @@ const OpenDataLoaderModal = ({ const form = useForm({ resolver: zodResolver(FormSchema), defaultValues: { + instance_name: '', opendataloader_apiserver: '', opendataloader_api_key: '', }, @@ -87,6 +92,14 @@ const OpenDataLoaderModal = ({ className="space-y-6" id="opendataloader-form" > + + + ({ resolver: zodResolver(FormSchema), defaultValues: { + instance_name: '', paddleocr_algorithm: 'PaddleOCR-VL', }, }); @@ -88,6 +92,14 @@ const PaddleOCRModal = ({ className="space-y-6" id="paddleocr-form" > + + + & { +}: IModalProps & { llmFactory: string; onVerify?: ( postBody: any, @@ -33,120 +36,140 @@ const SparkModal = ({ const { t: tc } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); const formRef = useRef(null); - const fields: FormFieldConfig[] = [ - { - 
name: 'model_type', - label: t('modelType'), - type: FormFieldType.Select, - required: true, - options: buildModelTypeOptions(['chat', 'tts']), - defaultValue: 'chat', - validation: { - message: t('modelTypeMessage'), + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); + + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); + + const fields: FormFieldConfig[] = useMemo( + () => [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + tooltip: t('instanceNameTip'), + validation: { message: t('instanceNameMessage') }, }, - }, - { - name: 'llm_name', - label: t('modelName'), - type: FormFieldType.Text, - required: true, - placeholder: t('modelNameMessage'), - validation: { - message: t('SparkModelNameMessage'), + { + name: 'model_type', + label: t('modelType'), + type: FormFieldType.MultiSelect, + required: true, + options: buildModelTypeOptions(['chat', 'tts']), + defaultValue: ['chat'], }, - }, - { - name: 'spark_api_password', - label: t('addSparkAPIPassword'), - type: FormFieldType.Text, - required: true, - placeholder: t('SparkAPIPasswordMessage'), - validation: { - message: t('SparkAPIPasswordMessage'), + { + name: 'llm_name', + label: t('modelName'), + type: FormFieldType.Text, + required: true, + placeholder: t('modelNameMessage'), + validation: { + message: t('SparkModelNameMessage'), + }, }, - }, - { - name: 'spark_app_id', - label: t('addSparkAPPID'), - type: FormFieldType.Text, - required: true, - placeholder: t('SparkAPPIDMessage'), - validation: { - message: t('SparkAPPIDMessage'), + { + name: 'spark_api_password', + label: t('addSparkAPIPassword'), + type: FormFieldType.Text, + required: true, + placeholder: t('SparkAPIPasswordMessage'), + validation: { + message: t('SparkAPIPasswordMessage'), + }, + shouldRender: hideWhenInstanceExists, }, - dependencies: ['model_type'], - shouldRender: (formValues: any) => { - return 
formValues?.model_type === 'tts'; + { + name: 'spark_app_id', + label: t('addSparkAPPID'), + type: FormFieldType.Text, + required: true, + placeholder: t('SparkAPPIDMessage'), + validation: { + message: t('SparkAPPIDMessage'), + }, + dependencies: ['model_type', 'instance_name'], + shouldRender: (formValues: any) => { + if (!hideWhenInstanceExists(formValues)) return false; + const modelType = formValues?.model_type; + if (Array.isArray(modelType)) { + return modelType.includes('tts'); + } + return modelType === 'tts'; + }, }, - }, - { - name: 'spark_api_secret', - label: t('addSparkAPISecret'), - type: FormFieldType.Text, - required: true, - placeholder: t('SparkAPISecretMessage'), - validation: { - message: t('SparkAPISecretMessage'), + { + name: 'spark_api_secret', + label: t('addSparkAPISecret'), + type: FormFieldType.Text, + required: true, + placeholder: t('SparkAPISecretMessage'), + validation: { + message: t('SparkAPISecretMessage'), + }, + dependencies: ['model_type', 'instance_name'], + shouldRender: (formValues: any) => { + if (!hideWhenInstanceExists(formValues)) return false; + const modelType = formValues?.model_type; + if (Array.isArray(modelType)) { + return modelType.includes('tts'); + } + return modelType === 'tts'; + }, }, - dependencies: ['model_type'], - shouldRender: (formValues: any) => { - return formValues?.model_type === 'tts'; + { + name: 'spark_api_key', + label: t('addSparkAPIKey'), + type: FormFieldType.Text, + required: true, + placeholder: t('SparkAPIKeyMessage'), + validation: { + message: t('SparkAPIKeyMessage'), + }, + dependencies: ['model_type', 'instance_name'], + shouldRender: (formValues: any) => { + if (!hideWhenInstanceExists(formValues)) return false; + const modelType = formValues?.model_type; + if (Array.isArray(modelType)) { + return modelType.includes('tts'); + } + return modelType === 'tts'; + }, }, - }, - { - name: 'spark_api_key', - label: t('addSparkAPIKey'), - type: FormFieldType.Text, - required: true, - 
placeholder: t('SparkAPIKeyMessage'), - validation: { - message: t('SparkAPIKeyMessage'), + { + name: 'max_tokens', + label: t('maxTokens'), + type: FormFieldType.Number, + required: true, + placeholder: t('maxTokensTip'), + validation: { + min: 0, + message: t('maxTokensInvalidMessage'), + }, }, - dependencies: ['model_type'], - shouldRender: (formValues: any) => { - return formValues?.model_type === 'tts'; - }, - }, - { - name: 'max_tokens', - label: t('maxTokens'), - type: FormFieldType.Number, - required: true, - placeholder: t('maxTokensTip'), - validation: { - min: 0, - message: t('maxTokensInvalidMessage'), - }, - }, - ]; + ], + [t, buildModelTypeOptions, hideWhenInstanceExists], + ); const handleOk = async (values?: FieldValues) => { if (!values) return; - const modelType = - values.model_type === 'chat' && values.vision - ? 'image2text' - : values.model_type; - const data = { - ...omit(values, ['vision']), - model_type: modelType, + instance_name: values.instance_name as string, + model_type: values.model_type, llm_factory: llmFactory, max_tokens: values.max_tokens, }; - await onOk?.(data as IAddLlmRequestBody); + await onOk?.(data as IAddProviderInstanceRequestBody); }; const verifyParamsFunc = useCallback(() => { const values = formRef.current?.getValues(); - const modelType = - values.model_type === 'chat' && values.vision - ? 
'image2text' - : values.model_type; return { llm_factory: llmFactory, - model_type: modelType, + model_type: values.model_type, }; }, [llmFactory]); @@ -175,8 +198,9 @@ const SparkModal = ({ ref={formRef} defaultValues={ { - model_type: 'chat', - vision: false, + instance_name: '', + model_type: ['chat'], + max_tokens: 8192, } as FieldValues } labelClassName="font-normal" diff --git a/web/src/pages/user-setting/setting-model/modal/volcengine-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/volcengine-modal/index.tsx index 0332be298a..cf6f4426c6 100644 --- a/web/src/pages/user-setting/setting-model/modal/volcengine-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/volcengine-modal/index.tsx @@ -8,14 +8,18 @@ import { Modal } from '@/components/ui/modal/modal'; import { useCommonTranslation, useTranslate } from '@/hooks/common-hooks'; import { useBuildModelTypeOptions } from '@/hooks/logic-hooks/use-build-options'; import { IModalProps } from '@/interfaces/common'; -import { IAddLlmRequestBody } from '@/interfaces/request/llm'; -import { VerifyResult } from '@/pages/user-setting/setting-model/hooks'; -import { memo, useCallback, useRef } from 'react'; +import { IAddProviderInstanceRequestBody } from '@/interfaces/request/llm'; +import { + useFetchInstanceNameSet, + useHideWhenInstanceExists, + VerifyResult, +} from '@/pages/user-setting/setting-model/hooks'; +import { memo, useCallback, useMemo, useRef } from 'react'; import { FieldValues } from 'react-hook-form'; import { LLMHeader } from '../../components/llm-header'; import VerifyButton from '../../modal/verify-button'; -type VolcEngineLlmRequest = IAddLlmRequestBody & { +type VolcEngineLlmRequest = IAddProviderInstanceRequestBody & { endpoint_id: string; ark_api_key: string; }; @@ -27,7 +31,7 @@ const VolcEngineModal = ({ onVerify, loading, llmFactory, -}: IModalProps & { +}: IModalProps & { llmFactory: string; onVerify?: ( postBody: any, @@ -37,79 +41,87 @@ const 
VolcEngineModal = ({ const { t: tc } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); const formRef = useRef(null); - const fields: FormFieldConfig[] = [ - { - name: 'model_type', - label: t('modelType'), - type: FormFieldType.Select, - required: true, - options: buildModelTypeOptions(['chat', 'embedding', 'image2text']), - defaultValue: 'chat', - }, - { - name: 'llm_name', - label: t('modelName'), - type: FormFieldType.Text, - required: true, - placeholder: t('volcModelNameMessage'), - }, - { - name: 'endpoint_id', - label: t('addEndpointID'), - type: FormFieldType.Text, - required: true, - placeholder: t('endpointIDMessage'), - }, - { - name: 'ark_api_key', - label: t('addArkApiKey'), - type: FormFieldType.Text, - required: true, - placeholder: t('ArkApiKeyMessage'), - }, - { - name: 'max_tokens', - label: t('maxTokens'), - type: FormFieldType.Number, - required: true, - placeholder: t('maxTokensTip'), - validation: { - min: 0, + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); + + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); + + const fields: FormFieldConfig[] = useMemo( + () => [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + tooltip: t('instanceNameTip'), + validation: { message: t('instanceNameMessage') }, }, - }, - ]; + { + name: 'model_type', + label: t('modelType'), + type: FormFieldType.MultiSelect, + required: true, + options: buildModelTypeOptions(['chat', 'embedding', 'image2text']), + defaultValue: ['chat'], + }, + { + name: 'llm_name', + label: t('modelName'), + type: FormFieldType.Text, + required: true, + placeholder: t('volcModelNameMessage'), + }, + { + name: 'endpoint_id', + label: t('addEndpointID'), + type: FormFieldType.Text, + required: true, + placeholder: t('endpointIDMessage'), + shouldRender: hideWhenInstanceExists, + }, + { + name: 'ark_api_key', + label: 
t('addArkApiKey'), + type: FormFieldType.Text, + required: true, + placeholder: t('ArkApiKeyMessage'), + shouldRender: hideWhenInstanceExists, + }, + { + name: 'max_tokens', + label: t('maxTokens'), + type: FormFieldType.Number, + required: true, + placeholder: t('maxTokensTip'), + validation: { + min: 0, + }, + }, + ], + [t, buildModelTypeOptions, hideWhenInstanceExists], + ); const handleOk = async (values?: FieldValues) => { if (!values) return; - const modelType = - values.model_type === 'chat' && values.vision - ? 'image2text' - : values.model_type; - const data: VolcEngineLlmRequest = { + instance_name: values.instance_name as string, llm_factory: llmFactory, llm_name: values.llm_name as string, - model_type: modelType, + model_type: values.model_type, endpoint_id: values.endpoint_id as string, ark_api_key: values.ark_api_key as string, max_tokens: values.max_tokens as number, }; - console.info(data); - await onOk?.(data); }; const verifyParamsFunc = useCallback(() => { const values = formRef.current?.getValues(); - const modelType = - values.model_type === 'chat' && values.vision - ? 
'image2text' - : values.model_type; return { llm_factory: llmFactory, - model_type: modelType, + model_type: values.model_type, }; }, [llmFactory]); @@ -138,8 +150,9 @@ const VolcEngineModal = ({ ref={formRef} defaultValues={ { - model_type: 'chat', - vision: false, + instance_name: '', + model_type: ['chat'], + max_tokens: 8192, } as FieldValues } labelClassName="font-normal" diff --git a/web/src/pages/user-setting/setting-model/modal/yiyan-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/yiyan-modal/index.tsx index c5ab31b74f..7babf637c5 100644 --- a/web/src/pages/user-setting/setting-model/modal/yiyan-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/yiyan-modal/index.tsx @@ -8,8 +8,12 @@ import { Modal } from '@/components/ui/modal/modal'; import { useCommonTranslation, useTranslate } from '@/hooks/common-hooks'; import { useBuildModelTypeOptions } from '@/hooks/logic-hooks/use-build-options'; import { IModalProps } from '@/interfaces/common'; -import { IAddLlmRequestBody } from '@/interfaces/request/llm'; -import { VerifyResult } from '@/pages/user-setting/setting-model/hooks'; +import { IAddProviderInstanceRequestBody } from '@/interfaces/request/llm'; +import { + useFetchInstanceNameSet, + useHideWhenInstanceExists, + VerifyResult, +} from '@/pages/user-setting/setting-model/hooks'; import { memo, useCallback, useMemo, useRef } from 'react'; import { FieldValues } from 'react-hook-form'; import { LLMHeader } from '../../components/llm-header'; @@ -22,7 +26,7 @@ const YiyanModal = ({ onVerify, loading, llmFactory, -}: IModalProps & { +}: IModalProps & { llmFactory: string; onVerify?: ( postBody: any, @@ -32,16 +36,28 @@ const YiyanModal = ({ const { t: tc } = useCommonTranslation(); const { buildModelTypeOptions } = useBuildModelTypeOptions(); const formRef = useRef(null); + const { instanceNameSet } = useFetchInstanceNameSet(llmFactory); + + const hideWhenInstanceExists = useHideWhenInstanceExists(instanceNameSet); const 
fields = useMemo( () => [ + { + name: 'instance_name', + label: t('instanceName'), + type: FormFieldType.Text, + required: true, + placeholder: t('instanceNameMessage'), + tooltip: t('instanceNameTip'), + validation: { message: t('instanceNameMessage') }, + }, { name: 'model_type', label: t('modelType'), - type: FormFieldType.Select, + type: FormFieldType.MultiSelect, required: true, options: buildModelTypeOptions(['chat', 'embedding', 'rerank']), - defaultValue: 'chat', + defaultValue: ['chat'], }, { name: 'llm_name', @@ -56,6 +72,7 @@ const YiyanModal = ({ type: FormFieldType.Text, required: true, placeholder: t('yiyanAKMessage'), + shouldRender: hideWhenInstanceExists, }, { name: 'yiyan_sk', @@ -63,6 +80,7 @@ const YiyanModal = ({ type: FormFieldType.Text, required: true, placeholder: t('yiyanSKMessage'), + shouldRender: hideWhenInstanceExists, }, { name: 'max_tokens', @@ -75,21 +93,17 @@ const YiyanModal = ({ }, }, ], - [t, buildModelTypeOptions], + [t, buildModelTypeOptions, hideWhenInstanceExists], ); const handleOk = async (values?: FieldValues) => { if (!values) return; - const modelType = - values.model_type === 'chat' && values.vision - ? 'image2text' - : values.model_type; - - const data: IAddLlmRequestBody = { + const data: IAddProviderInstanceRequestBody = { + instance_name: values.instance_name as string, llm_factory: llmFactory, llm_name: values.llm_name as string, - model_type: modelType, + model_type: values.model_type, api_key: { yiyan_ak: values.yiyan_ak, yiyan_sk: values.yiyan_sk, @@ -97,21 +111,15 @@ const YiyanModal = ({ max_tokens: values.max_tokens as number, }; - console.info(data); - await onOk?.(data); }; const verifyParamsFunc = useCallback(() => { const values = formRef.current?.getValues(); - const modelType = - values.model_type === 'chat' && values.vision - ? 
'image2text' - : values.model_type; return { llm_factory: llmFactory, llm_name: values.llm_name as string, - model_type: modelType, + model_type: values.model_type, api_key: { yiyan_ak: values.yiyan_ak, yiyan_sk: values.yiyan_sk, @@ -148,8 +156,9 @@ const YiyanModal = ({ }} defaultValues={ { - model_type: 'chat', - vision: false, + instance_name: '', + model_type: ['chat'], + max_tokens: 8192, } as FieldValues } labelClassName="font-normal" diff --git a/web/src/pages/user-setting/setting-model/payload-utils.ts b/web/src/pages/user-setting/setting-model/payload-utils.ts new file mode 100644 index 0000000000..10c654921a --- /dev/null +++ b/web/src/pages/user-setting/setting-model/payload-utils.ts @@ -0,0 +1,109 @@ +import { + IAddInstanceModelRequestBody, + IAddProviderInstanceRequestBody, +} from '@/interfaces/request/llm'; + +const MODEL_RESERVED_KEYS = new Set([ + 'llm_name', + 'model_name', + 'model_type', + 'max_tokens', +]); + +const INSTANCE_RESERVED_KEYS = new Set([ + 'instance_name', + 'llm_factory', + 'provider_name', + 'api_base', + 'base_url', + 'region', + 'verify', +]); + +export const MODEL_EXTRA_KEYS = new Set([ + 'is_tools', + 'vision', + 'provider_order', + 'api_version', +]); + +export const MODEL_FIELD_NAMES = new Set([ + ...MODEL_RESERVED_KEYS, + ...MODEL_EXTRA_KEYS, +]); + +export const isModelField = (fieldName: string) => + MODEL_FIELD_NAMES.has(fieldName); + +type FlatPayload = Record; + +export type SplitResult = { + instancePayload: Omit< + IAddProviderInstanceRequestBody, + 'llm_name' | 'model_type' | 'max_tokens' + > & { + base_url?: string; + region?: string; + }; + modelPayload: IAddInstanceModelRequestBody; +}; + +const collectApiKeyExtras = (payload: FlatPayload) => { + const extras: Record = {}; + let apiKeyValue: any = undefined; + for (const [key, value] of Object.entries(payload)) { + if (value === undefined) continue; + if (key === 'api_key') { + apiKeyValue = value; + continue; + } + if (INSTANCE_RESERVED_KEYS.has(key)) 
continue; + if (MODEL_RESERVED_KEYS.has(key)) continue; + if (MODEL_EXTRA_KEYS.has(key)) continue; + extras[key] = value; + } + if (apiKeyValue && typeof apiKeyValue === 'object') { + return { ...apiKeyValue, ...extras }; + } + if (Object.keys(extras).length === 0) { + return apiKeyValue ?? ''; + } + if (apiKeyValue !== undefined && apiKeyValue !== '') { + return { api_key: apiKeyValue, ...extras }; + } + return extras; +}; + +const collectModelExtras = (payload: FlatPayload) => { + const extras: Record = {}; + for (const key of MODEL_EXTRA_KEYS) { + if (payload[key] !== undefined && payload[key] !== '') { + extras[key] = payload[key]; + } + } + return extras; +}; + +export const splitProviderPayload = (payload: FlatPayload): SplitResult => { + const instancePayload = { + instance_name: payload.instance_name as string, + llm_factory: payload.llm_factory as string, + api_key: collectApiKeyExtras(payload), + base_url: (payload.base_url ?? payload.api_base) as string | undefined, + region: (payload.region as string | undefined) || 'default', + }; + + const modelExtra = collectModelExtras(payload); + + const modelPayload = { + model_name: (payload.model_name ?? payload.llm_name) as string, + model_type: payload.model_type, + max_tokens: payload.max_tokens as number, + ...(Object.keys(modelExtra).length > 0 ? 
{ extra: modelExtra } : {}), + }; + + return { + instancePayload: instancePayload as SplitResult['instancePayload'], + modelPayload, + }; +}; diff --git a/web/src/services/llm-service.ts b/web/src/services/llm-service.ts new file mode 100644 index 0000000000..e14a6920df --- /dev/null +++ b/web/src/services/llm-service.ts @@ -0,0 +1,71 @@ +import api from '@/utils/api'; +import { registerNextServer } from '@/utils/register-server'; + +const { + listAllAddedModels, + defaultModel, + listProviders, + addProvider, + addProviderInstance, + listProviderInstances, + listInstanceModels, + showProviderInstance, + addInstanceModel, + deleteProviderInstance, + updateModelStatus, +} = api; + +const methods = { + listAllAddedModels: { + url: listAllAddedModels, + method: 'get', + }, + listDefaultModels: { + url: defaultModel, + method: 'get', + }, + setDefaultModel: { + url: defaultModel, + method: 'patch', + }, + listProviders: { + url: listProviders, + method: 'get', + }, + addProvider: { + url: addProvider, + method: 'put', + }, + addProviderInstance: { + url: addProviderInstance, + method: 'post', + }, + listProviderInstances: { + url: listProviderInstances, + method: 'get', + }, + listInstanceModels: { + url: listInstanceModels, + method: 'get', + }, + showProviderInstance: { + url: showProviderInstance, + method: 'get', + }, + addInstanceModel: { + url: addInstanceModel, + method: 'post', + }, + deleteProviderInstance: { + url: deleteProviderInstance, + method: 'delete', + }, + updateModelStatus: { + url: updateModelStatus, + method: 'patch', + }, +} as const; + +const llmService = registerNextServer(methods); + +export default llmService; diff --git a/web/src/services/user-service.ts b/web/src/services/user-service.ts index 1637dcfe16..5b81967a51 100644 --- a/web/src/services/user-service.ts +++ b/web/src/services/user-service.ts @@ -9,15 +9,6 @@ const { setting, userInfo, tenantInfo, - factoriesList, - llmList, - myLlm, - setApiKey, - setTenantInfo, - addLlm, - 
deleteLlm, - enableLlm, - deleteFactory, getSystemVersion, getSystemTokenList, removeSystemToken, @@ -51,46 +42,10 @@ const methods = { url: tenantInfo, method: 'get', }, - setTenantInfo: { - url: setTenantInfo, - method: 'patch', - }, - factoriesList: { - url: factoriesList, - method: 'get', - }, - llmList: { - url: llmList, - method: 'get', - }, - myLlm: { - url: myLlm, - method: 'get', - }, - setApiKey: { - url: setApiKey, - method: 'post', - }, - addLlm: { - url: addLlm, - method: 'post', - }, - deleteLlm: { - url: deleteLlm, - method: 'post', - }, - enableLlm: { - url: enableLlm, - method: 'post', - }, getSystemVersion: { url: getSystemVersion, method: 'get', }, - deleteFactory: { - url: deleteFactory, - method: 'post', - }, listToken: { url: getSystemTokenList, method: 'get', diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index 7a25ea1495..b1376b55e1 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -11,7 +11,6 @@ export default { setting: `${restAPIv1}/users/me`, userInfo: `${restAPIv1}/users/me`, tenantInfo: `${restAPIv1}/users/me/models`, - setTenantInfo: `${restAPIv1}/users/me/models`, loginChannels: `${restAPIv1}/auth/login/channels`, loginChannel: (channel: string) => `${restAPIv1}/auth/login/${channel}`, @@ -25,14 +24,49 @@ export default { agreeTenant: (tenantId: string) => `${restAPIv1}/tenants/${tenantId}`, // llm model - factoriesList: `${webAPI}/llm/factories`, - llmList: `${webAPI}/llm/list`, - myLlm: `${webAPI}/llm/my_llms`, - setApiKey: `${webAPI}/llm/set_api_key`, - addLlm: `${webAPI}/llm/add_llm`, - deleteLlm: `${webAPI}/llm/delete_llm`, - enableLlm: `${webAPI}/llm/enable_llm`, - deleteFactory: `${webAPI}/llm/delete_factory`, + listAllAddedModels: `${restAPIv1}/models`, + defaultModel: `${restAPIv1}/models/default`, + listProviders: `${restAPIv1}/providers`, + addProvider: `${restAPIv1}/providers/`, + addProviderInstance: ({ llm_factory }: { llm_factory: string }) => + `${restAPIv1}/providers/${llm_factory}/instances`, 
+ listProviderInstances: ({ provider_name }: { provider_name: string }) => + `${restAPIv1}/providers/${provider_name}/instances`, + listInstanceModels: ({ + provider_name, + instance_name, + }: { + provider_name: string; + instance_name: string; + }) => + `${restAPIv1}/providers/${provider_name}/instances/${instance_name}/models`, + showProviderInstance: ({ + provider_name, + instance_name, + }: { + provider_name: string; + instance_name: string; + }) => `${restAPIv1}/providers/${provider_name}/instances/${instance_name}`, + addInstanceModel: ({ + provider_name, + instance_name, + }: { + provider_name: string; + instance_name: string; + }) => + `${restAPIv1}/providers/${provider_name}/instances/${instance_name}/models`, + deleteProviderInstance: ({ provider_name }: { provider_name: string }) => + `${restAPIv1}/providers/${provider_name}/instances`, + updateModelStatus: ({ + provider_name, + instance_name, + model_name, + }: { + provider_name: string; + instance_name: string; + model_name: string; + }) => + `${restAPIv1}/providers/${provider_name}/instances/${instance_name}/models/${model_name}`, // data source dataSourceUpdate: (id: string) => `${restAPIv1}/connectors/${id}`, diff --git a/web/src/utils/llm-util.ts b/web/src/utils/llm-util.ts index daf9c0d586..ff53dd280e 100644 --- a/web/src/utils/llm-util.ts +++ b/web/src/utils/llm-util.ts @@ -1,29 +1,10 @@ -import { IThirdOAIModel } from '@/interfaces/database/llm'; import { getCachedLlmList } from './llm-cache'; -export const getLLMIconName = (fid: string, llm_name: string) => { - if (fid === 'FastEmbed') { - return llm_name.split('/').at(0) ?? ''; - } - - return fid; -}; - -export const getLlmNameAndFIdByLlmId = (llmId?: string) => { - const [llmName, fId] = llmId?.split('@') || []; - - return { fId, llmName }; -}; - // The names of the large models returned by the interface are similar to "deepseek-r1___OpenAI-API" export function getRealModelName(llmName: string) { return llmName.split('__').at(0) ?? 
''; } -export function buildLlmUuid(llm: IThirdOAIModel) { - return `${llm.llm_name}@${llm.fid}`; -} - // Get tenant model ID from LLM list by model name and factory ID export function getTenantModelId( llmList: Record, @@ -53,12 +34,37 @@ export function getTenantModelId( return ''; } -// Extract model name and factory ID from a model UUID (e.g., "model_name@factory_id") +/** Build "modelName@instanceName@providerName" */ +export function buildModelValue(model: { + model_name: string; + model_instance: string; + model_provider: string; +}) { + return `${model.model_name}@${model.model_instance}@${model.model_provider}`; +} + +/** Parse "modelName@instanceName@providerName" */ +export function parseModelValue(val: string) { + if (!val) return null; + const firstAt = val.indexOf('@'); + const lastAt = val.lastIndexOf('@'); + if (firstAt === -1 || firstAt === lastAt) return null; + return { + model_name: val.substring(0, firstAt), + model_instance: val.substring(firstAt + 1, lastAt), + model_provider: val.substring(lastAt + 1), + }; +} + +// Extract model name and factory ID from a model UUID +// Supports both "model_name@factory_id" and "model_name@factory_id#instance_name" export function parseModelUuid(uuid: string): { modelName: string; factoryId: string; } { - const [modelName, factoryId] = uuid.split('@'); + const hashIndex = uuid.indexOf('#'); + const core = hashIndex === -1 ? uuid : uuid.slice(0, hashIndex); + const [modelName, factoryId] = core.split('@'); return { modelName, factoryId }; } diff --git a/web/src/utils/next-request.ts b/web/src/utils/next-request.ts index 63bc2fefe4..2b474e8366 100644 --- a/web/src/utils/next-request.ts +++ b/web/src/utils/next-request.ts @@ -148,8 +148,6 @@ request.interceptors.response.use( return response; }, function (error) { - console.log('🚀 ~ error:', error); - // Handle HTTP 401 (token expired / invalid) const status = error?.response?.status; if (status === 401) {