fix: replicate model provider (#15933)

### What problem does this PR solve?

FIx replicate model provider failing with valid api key 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Wang Qi <wangq8@outlook.com>
This commit is contained in:
Idriss Sbaaoui
2026-06-11 15:08:33 +08:00
committed by GitHub
parent 3f929e3904
commit 9871a7e0b6
7 changed files with 164 additions and 25 deletions

View File

@@ -129,7 +129,9 @@ async def set_api_key():
except Exception as e:
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
if factory not in RerankModel:
msg += f"\nRerank model from {factory} is not supported yet."
continue
mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=base_url)
try:
arr, tc = await asyncio.wait_for(
@@ -350,19 +352,21 @@ async def add_llm():
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.RERANK.value:
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try:
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
arr, tc = await asyncio.wait_for(
asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
timeout=timeout_seconds,
)
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
msg += f"{factory} does not support this model({factory}/{mdl_nm})"
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
if factory not in RerankModel:
msg += f"\nRerank model from {factory} is not supported yet."
else:
try:
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
arr, tc = await asyncio.wait_for(
asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
timeout=timeout_seconds,
)
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
msg += f"{factory} does not support this model({factory}/{mdl_nm})"
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.IMAGE2TEXT.value:
from rag.utils.base64_image import test_image

View File

@@ -51,6 +51,11 @@ def _normalize_provider_base_url(provider_name: str, base_url: str | None):
return base_url
def _factory_llm_name(llm: dict) -> str:
return llm.get("name") or llm.get("llm_name", "")
def list_providers(tenant_id: str, all_available: bool = False):
"""
List providers for a tenant.
@@ -206,7 +211,7 @@ async def list_provider_models(provider_name: str, api_key: str = None, base_url
if not factory_info:
return False, f"Provider '{provider_name}' not found"
static_llms = [{
"name": llm["name"],
"name": _factory_llm_name(llm),
"max_tokens": llm["max_tokens"],
"model_types": _factory_model_types(llm),
"features": (
@@ -250,13 +255,13 @@ def show_provider_model(provider_name: str, model_name: str):
llms = factory_info[0]["llm"]
if not llms:
return False, f"No models found for provider '{provider_name}'"
target_llm = [llm for llm in llms if llm["name"] == model_name]
target_llm = [llm for llm in llms if _factory_llm_name(llm) == model_name]
if not target_llm:
return False, f"Model '{model_name}' not found"
llm_info = target_llm[0]
return True, {
"name": llm_info["name"],
"name": _factory_llm_name(llm_info),
"max_tokens": llm_info["max_tokens"],
"model_types": _factory_model_types(llm_info),
"thinking": None,
@@ -465,7 +470,11 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
)
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
elif not rerank_passed and LLMType.RERANK.value in model_types:
assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet."
if provider_name not in RerankModel:
unsupported_msg = f"Rerank model from {provider_name} is not supported yet."
logging.warning(unsupported_msg)
msg += f"\n{unsupported_msg}"
continue
mdl = RerankModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
try:
arr, tc = await asyncio.wait_for(