diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 430b7d8dc3..690d54b954 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -129,7 +129,9 @@ async def set_api_key(): except Exception as e: msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) elif not rerank_passed and llm.model_type == LLMType.RERANK.value: - assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet." + if factory not in RerankModel: + msg += f"\nRerank model from {factory} is not supported yet." + continue mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=base_url) try: arr, tc = await asyncio.wait_for( @@ -350,19 +352,21 @@ async def add_llm(): msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) case LLMType.RERANK.value: - assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." - try: - mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) - arr, tc = await asyncio.wait_for( - asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]), - timeout=timeout_seconds, - ) - if len(arr) == 0: - raise Exception("Not known.") - except KeyError: - msg += f"{factory} does not support this model({factory}/{mdl_nm})" - except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) + if factory not in RerankModel: + msg += f"\nRerank model from {factory} is not supported yet." + else: + try: + mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) + arr, tc = await asyncio.wait_for( + asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]), + timeout=timeout_seconds, + ) + if len(arr) == 0: + raise Exception("Not known.") + except KeyError: + msg += f"{factory} does not support this model({factory}/{mdl_nm})" + except Exception as e: + msg += f"\nFail to access model({factory}/{mdl_nm})." 
def _factory_llm_name(llm: dict) -> str:
    """Return the model name stored in a factory LLM entry.

    Factory entries carry the name under either ``"name"`` or ``"llm_name"``;
    ``"name"`` takes precedence whenever it holds a truthy value.

    Args:
        llm: A single model entry from a factory's model list.

    Returns:
        The model name, or ``""`` when neither key holds a truthy value.
    """
    # The trailing `or ""` keeps the declared `str` return type even when
    # "llm_name" is present but explicitly null; a `.get(..., "")` default
    # only applies when the key is absent.
    return llm.get("name") or llm.get("llm_name") or ""
+ if provider_name not in RerankModel: + unsupported_msg = f"Rerank model from {provider_name} is not supported yet." + logging.warning(unsupported_msg) + msg += f"\n{unsupported_msg}" + continue mdl = RerankModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url) try: arr, tc = await asyncio.wait_for( diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 76da621b04..bd83fce3c1 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -4097,7 +4097,46 @@ "logo": "", "tags": "LLM,TEXT EMBEDDING", "status": "1", - "llm": [] + "llm": [ + { + "llm_name": "meta/llama-4-maverick-instruct", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": "chat" + }, + { + "llm_name": "meta/llama-4-scout-instruct", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": "chat" + }, + { + "llm_name": "meta/meta-llama-3-70b-instruct", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": "chat" + }, + { + "llm_name": "meta/meta-llama-3-8b-instruct", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": "chat" + }, + { + "llm_name": "replicate/all-mpnet-base-v2:b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + "tags": "TEXT EMBEDDING", + "max_tokens": 384, + "model_type": "embedding" + }, + { + "llm_name": "ibm-granite/granite-embedding-278m-multilingual:1f76d42a05f120e12272746d5a2d86b525c13420773f795a4cbef9117d8685f1", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": "embedding" + } + ], + "rank": "987", + "url": "https://api.replicate.com" }, { "name": "Tencent Hunyuan", diff --git a/conf/models/replicate.json b/conf/models/replicate.json index 42a8255dc7..84cf3f2173 100644 --- a/conf/models/replicate.json +++ b/conf/models/replicate.json @@ -9,6 +9,20 @@ }, "class": "replicate", "models": [ + { + "name": "meta/llama-4-maverick-instruct", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "meta/llama-4-scout-instruct", + "max_tokens": 8192, + 
async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs):
    """Run the model in a worker thread and yield its output chunk by chunk.

    The blocking ``self.client.run`` call is executed through
    ``asyncio.to_thread``; its output is collected in full before anything
    is yielded.

    Args:
        system: System prompt text. It is prepended as a ``system`` message
            unless ``history`` already starts with one.
        history: Chat messages as dicts with ``"role"`` and ``"content"``.
            Only the last five messages (system messages excluded) are
            folded into the prompt.
        gen_conf: Generation options forwarded as model input.
            ``"max_tokens"`` is dropped before forwarding.
        **kwargs: Accepted but not used by this method.

    Yields:
        Each output chunk as a string, or a single ``"**ERROR**: ..."``
        string if the call raised, followed by the token count of the
        joined answer (``0`` on error).
    """
    options = dict(gen_conf or {})
    options.pop("max_tokens", None)

    def _run_blocking():
        # Work on a copy so the caller's history list is never mutated.
        messages = list(history or [])
        if system and (not messages or messages[0].get("role") != "system"):
            messages.insert(0, {"role": "system", "content": system})

        has_system = bool(messages) and messages[0].get("role") == "system"
        system_prompt = messages[0]["content"] if has_system else ""
        lines = [
            message["role"] + ":" + message["content"]
            for message in messages[-5:]
            if message.get("role") != "system"
        ]
        prompt = "\n".join(lines)

        try:
            output = self.client.run(
                self.model_name,
                input={"system_prompt": system_prompt, "prompt": prompt, **options},
            )
            pieces = [part if isinstance(part, str) else str(part) for part in output]
            text = "".join(pieces)
            return pieces, num_tokens_from_string(text), None
        except Exception as exc:
            # Hand the failure back to the event-loop side instead of raising
            # inside the worker thread.
            return [], 0, exc

    pieces, token_count, failure = await asyncio.to_thread(_run_blocking)
    if failure:
        yield f"**ERROR**: {failure}"
    else:
        for piece in pieces:
            yield piece
    yield token_count
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json


def _normalize_replicate_key(key):
    """Extract the API token from the various shapes ``key`` may take.

    * dict containing ``"api_key"``: that entry's value is returned;
    * dict without ``"api_key"``: the dict serialized with ``json.dumps``;
    * str that parses as a JSON object containing ``"api_key"``: that value;
    * anything else (plain strings, other JSON, non-str values): ``key``
      is returned unchanged.
    """
    if isinstance(key, dict):
        return key.get("api_key") if "api_key" in key else json.dumps(key)

    if not isinstance(key, str):
        return key

    try:
        decoded = json.loads(key)
    except (json.JSONDecodeError, TypeError):
        # Not JSON at all: treat the string itself as the token.
        return key

    if isinstance(decoded, dict) and "api_key" in decoded:
        return decoded["api_key"]
    return key


__all__ = ["_normalize_replicate_key"]