mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix: replicate model provider (#15933)
### What problem does this PR solve? FIx replicate model provider failing with valid api key ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Wang Qi <wangq8@outlook.com>
This commit is contained in:
@@ -129,7 +129,9 @@ async def set_api_key():
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
|
||||
elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
|
||||
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
|
||||
if factory not in RerankModel:
|
||||
msg += f"\nRerank model from {factory} is not supported yet."
|
||||
continue
|
||||
mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=base_url)
|
||||
try:
|
||||
arr, tc = await asyncio.wait_for(
|
||||
@@ -350,19 +352,21 @@ async def add_llm():
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
|
||||
case LLMType.RERANK.value:
|
||||
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
|
||||
try:
|
||||
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
arr, tc = await asyncio.wait_for(
|
||||
asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
if len(arr) == 0:
|
||||
raise Exception("Not known.")
|
||||
except KeyError:
|
||||
msg += f"{factory} does not support this model({factory}/{mdl_nm})"
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
if factory not in RerankModel:
|
||||
msg += f"\nRerank model from {factory} is not supported yet."
|
||||
else:
|
||||
try:
|
||||
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
arr, tc = await asyncio.wait_for(
|
||||
asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
if len(arr) == 0:
|
||||
raise Exception("Not known.")
|
||||
except KeyError:
|
||||
msg += f"{factory} does not support this model({factory}/{mdl_nm})"
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
|
||||
case LLMType.IMAGE2TEXT.value:
|
||||
from rag.utils.base64_image import test_image
|
||||
|
||||
@@ -51,6 +51,11 @@ def _normalize_provider_base_url(provider_name: str, base_url: str | None):
|
||||
return base_url
|
||||
|
||||
|
||||
|
||||
def _factory_llm_name(llm: dict) -> str:
|
||||
return llm.get("name") or llm.get("llm_name", "")
|
||||
|
||||
|
||||
def list_providers(tenant_id: str, all_available: bool = False):
|
||||
"""
|
||||
List providers for a tenant.
|
||||
@@ -206,7 +211,7 @@ async def list_provider_models(provider_name: str, api_key: str = None, base_url
|
||||
if not factory_info:
|
||||
return False, f"Provider '{provider_name}' not found"
|
||||
static_llms = [{
|
||||
"name": llm["name"],
|
||||
"name": _factory_llm_name(llm),
|
||||
"max_tokens": llm["max_tokens"],
|
||||
"model_types": _factory_model_types(llm),
|
||||
"features": (
|
||||
@@ -250,13 +255,13 @@ def show_provider_model(provider_name: str, model_name: str):
|
||||
llms = factory_info[0]["llm"]
|
||||
if not llms:
|
||||
return False, f"No models found for provider '{provider_name}'"
|
||||
target_llm = [llm for llm in llms if llm["name"] == model_name]
|
||||
target_llm = [llm for llm in llms if _factory_llm_name(llm) == model_name]
|
||||
if not target_llm:
|
||||
return False, f"Model '{model_name}' not found"
|
||||
llm_info = target_llm[0]
|
||||
|
||||
return True, {
|
||||
"name": llm_info["name"],
|
||||
"name": _factory_llm_name(llm_info),
|
||||
"max_tokens": llm_info["max_tokens"],
|
||||
"model_types": _factory_model_types(llm_info),
|
||||
"thinking": None,
|
||||
@@ -465,7 +470,11 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
)
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
|
||||
elif not rerank_passed and LLMType.RERANK.value in model_types:
|
||||
assert provider_name in RerankModel, f"Rerank model from {provider_name} is not supported yet."
|
||||
if provider_name not in RerankModel:
|
||||
unsupported_msg = f"Rerank model from {provider_name} is not supported yet."
|
||||
logging.warning(unsupported_msg)
|
||||
msg += f"\n{unsupported_msg}"
|
||||
continue
|
||||
mdl = RerankModel[provider_name](api_key_str, llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
arr, tc = await asyncio.wait_for(
|
||||
|
||||
Reference in New Issue
Block a user