mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-05 10:58:34 +08:00
@@ -25,7 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro
|
||||
from api.db.services.tenant_model_provider_service import TenantModelProviderService
|
||||
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
|
||||
from api.db.services.tenant_model_service import TenantModelService
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, ModelMeta
|
||||
from rag.llm import ChatModel, EmbeddingModel, ModelMeta, OcrModel, RerankModel
|
||||
|
||||
|
||||
def _to_int(v, default=500):
|
||||
@@ -385,9 +385,11 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
"model_type": _type,
|
||||
"llm_name": model.get("model_name", ""),
|
||||
} for model in model_info if model for _type in model.get("model_type", []) ]
|
||||
if not factory_llms:
|
||||
return False, f"No valid models found for provider '{provider_name}'"
|
||||
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
chat_passed, embd_passed, rerank_passed, ocr_passed = False, False, False, False
|
||||
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
|
||||
extra = {"provider": provider_name}
|
||||
msg = ""
|
||||
@@ -462,11 +464,29 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
llm["llm_name"],
|
||||
)
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']}) using this api key." + str(e)
|
||||
if any([embd_passed, chat_passed, rerank_passed]):
|
||||
elif not ocr_passed and llm["model_type"] == LLMType.OCR.value:
|
||||
assert provider_name in OcrModel, f"OCR model from {provider_name} is not supported yet."
|
||||
mdl = OcrModel[provider_name](key=api_key_str, model_name=llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
ok, reason = await asyncio.wait_for(
|
||||
asyncio.to_thread(mdl.check_available),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
if not ok:
|
||||
raise RuntimeError(reason or "Model not available")
|
||||
ocr_passed = True
|
||||
except Exception as e:
|
||||
logging.exception(
|
||||
"Fail to access OCR model for provider=%s model=%s",
|
||||
provider_name,
|
||||
llm["llm_name"],
|
||||
)
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']})." + str(e)
|
||||
if any([embd_passed, chat_passed, rerank_passed, ocr_passed]):
|
||||
msg = ""
|
||||
break
|
||||
|
||||
success = any([embd_passed, chat_passed, rerank_passed])
|
||||
success = any([embd_passed, chat_passed, rerank_passed, ocr_passed])
|
||||
return success, "success" if success else msg
|
||||
|
||||
|
||||
|
||||
@@ -871,7 +871,10 @@ export const ProviderConfigMap: Record<string, ProviderConfig> = {
|
||||
return {
|
||||
apiKey: cfg,
|
||||
baseUrl: values.mineru_apiserver,
|
||||
modelInfo: buildModelInfoFromValues(values),
|
||||
modelInfo: buildModelInfoFromValues({
|
||||
...values,
|
||||
model_type: ['ocr'],
|
||||
}),
|
||||
};
|
||||
},
|
||||
submitTransform: (values) => {
|
||||
@@ -887,7 +890,10 @@ export const ProviderConfigMap: Record<string, ProviderConfig> = {
|
||||
llm_factory: LLMFactory.MinerU,
|
||||
api_key: cfg,
|
||||
api_base: '',
|
||||
model_info: buildModelInfoFromValues(values),
|
||||
model_info: buildModelInfoFromValues({
|
||||
...values,
|
||||
model_type: ['ocr'],
|
||||
}),
|
||||
};
|
||||
},
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user