mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Feat: tenant llm provider (#14595)
### What problem does this PR solve?

Python implementation of the Go-based model_provider API suite.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: bill <yibie_jingnian@163.com>
This commit is contained in:
@@ -27,10 +27,9 @@ import json_repair
|
||||
|
||||
from agent.component.llm import LLM, LLMParam
|
||||
from agent.tools.base import LLMToolPluginCallSession, ToolBase, ToolMeta, ToolParamBase
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from common.connection_utils import timeout
|
||||
from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt
|
||||
@@ -81,7 +80,9 @@ class Agent(LLM, ToolBase):
|
||||
original_name = cpn.get_meta()["function"]["name"]
|
||||
indexed_name = f"{original_name}_{idx}"
|
||||
self.tools[indexed_name] = cpn
|
||||
chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id)
|
||||
model_types = get_model_type_by_name(self._canvas.get_tenant_id(), self._param.llm_id)
|
||||
model_type = "chat" if "chat" in model_types else model_types[0]
|
||||
chat_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), model_type, self._param.llm_id)
|
||||
self.chat_mdl = LLMBundle(
|
||||
self._canvas.get_tenant_id(),
|
||||
chat_model_config,
|
||||
|
||||
@@ -33,10 +33,9 @@ from urllib.request import Request, urlopen
|
||||
from agent.component.base import ComponentBase
|
||||
from agent.component.llm import LLMParam
|
||||
from api.db import FileType
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils.file_utils import filename_type
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
@@ -394,9 +393,9 @@ class Browser(ComponentBase, ABC):
|
||||
def _build_browser_llm(self):
|
||||
from browser_use.llm import ChatBrowserUse, ChatOpenAI
|
||||
|
||||
chat_model_config = get_model_config_by_type_and_name(
|
||||
chat_model_config = get_model_config_from_provider_instance(
|
||||
self._canvas.get_tenant_id(),
|
||||
TenantLLMService.llm_id2llm_type(self._param.llm_id),
|
||||
get_model_type_by_name(self._param.llm_id),
|
||||
self._param.llm_id,
|
||||
)
|
||||
cfg = self._as_model_config_dict(chat_model_config)
|
||||
|
||||
@@ -21,7 +21,7 @@ from abc import ABC
|
||||
|
||||
from common.constants import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
from common.connection_utils import timeout
|
||||
from rag.llm.chat_model import ERROR_PREFIX
|
||||
@@ -123,7 +123,7 @@ class Categorize(LLM, ABC):
|
||||
msg[-1]["content"] = query_value
|
||||
self.set_input_value(query_key, msg[-1]["content"])
|
||||
self._param.update_prompt()
|
||||
chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
chat_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config)
|
||||
|
||||
user_prompt = """
|
||||
|
||||
@@ -24,8 +24,7 @@ import json_repair
|
||||
from functools import partial
|
||||
from common.constants import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from common.connection_utils import timeout
|
||||
from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt, structured_output_prompt
|
||||
@@ -85,7 +84,9 @@ class LLM(ComponentBase):
|
||||
|
||||
def __init__(self, canvas, component_id, param: ComponentParamBase):
|
||||
super().__init__(canvas, component_id, param)
|
||||
chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id)
|
||||
model_types = get_model_type_by_name(self._canvas.get_tenant_id(), self._param.llm_id)
|
||||
model_type = "chat" if "chat" in model_types else model_types[0]
|
||||
chat_model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), model_type, self._param.llm_id)
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config,
|
||||
max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error)
|
||||
@@ -247,9 +248,11 @@ class LLM(ComponentBase):
|
||||
self.set_input_value(k, args[k])
|
||||
|
||||
self.imgs = self._uniq_images(self.imgs + extracted_imgs)
|
||||
if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
|
||||
self._param.llm_id, max_retries=self._param.max_retries,
|
||||
model_types = get_model_type_by_name(self._canvas.get_tenant_id(), self._param.llm_id)
|
||||
model_type = LLMType.CHAT.value if LLMType.CHAT.value in model_types else model_types[0]
|
||||
model_config = get_model_config_from_provider_instance(self._canvas.get_tenant_id(), model_type, self._param.llm_id)
|
||||
if self.imgs:
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), model_config, max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error
|
||||
)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.joint_services import memory_message_service
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type
|
||||
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
@@ -121,12 +121,12 @@ class Retrieval(ToolBase, ABC):
|
||||
embd_mdl = None
|
||||
if embd_nms:
|
||||
tenant_id = self._canvas.get_tenant_id()
|
||||
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||
embd_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||
embd_mdl = LLMBundle(tenant_id, embd_model_config)
|
||||
|
||||
rerank_mdl = None
|
||||
if self._param.rerank_id:
|
||||
rerank_model_config = get_model_config_by_type_and_name(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||
rerank_model_config = get_model_config_from_provider_instance(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, rerank_model_config)
|
||||
|
||||
vars = self.get_input_elements_from_text(query_text)
|
||||
|
||||
Reference in New Issue
Block a user