mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Feat: get model list from remote (#15711)
### What problem does this PR solve? Feat: - Get model list from remote provider. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -199,7 +199,7 @@ def delete_provider(tenant_id: str = None, provider_name: str = None):
|
||||
|
||||
@manager.route("/providers/<provider_name>/models", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_provider_models(provider_name: str):
|
||||
async def list_provider_models(provider_name: str):
|
||||
"""
|
||||
List models for a provider.
|
||||
---
|
||||
@@ -230,7 +230,9 @@ def list_provider_models(provider_name: str):
|
||||
type: object
|
||||
"""
|
||||
try:
|
||||
success, result = provider_api_service.list_provider_models(provider_name)
|
||||
api_key = request.args.get("api_key")
|
||||
base_url = request.args.get("base_url")
|
||||
success, result = await provider_api_service.list_provider_models(provider_name, api_key, base_url)
|
||||
if success:
|
||||
return get_result(data=result)
|
||||
else:
|
||||
@@ -341,7 +343,7 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
|
||||
api_key = data["api_key"]
|
||||
base_url = data.get("base_url", "")
|
||||
region = data.get("region", "")
|
||||
model_info = data.get("model_info", {})
|
||||
model_info = data.get("model_info", [])
|
||||
|
||||
try:
|
||||
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, model_info)
|
||||
@@ -409,7 +411,7 @@ async def verify_provider_api_key(provider_name: str = None):
|
||||
base_url = data.get("base_url", "")
|
||||
api_key = data["api_key"]
|
||||
region = data.get("region", "default")
|
||||
model_info = data.get("model_info", {})
|
||||
model_info = data.get("model_info", [])
|
||||
|
||||
try:
|
||||
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region, model_info)
|
||||
|
||||
@@ -25,7 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro
|
||||
from api.db.services.tenant_model_provider_service import TenantModelProviderService
|
||||
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
|
||||
from api.db.services.tenant_model_service import TenantModelService
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, ModelMeta
|
||||
|
||||
|
||||
def _to_int(v, default=500):
|
||||
@@ -168,28 +168,38 @@ def show_provider(provider_name: str):
|
||||
}
|
||||
|
||||
|
||||
def list_provider_models(provider_name: str):
|
||||
async def list_provider_models(provider_name: str, api_key: str = None, base_url: str = None):
|
||||
"""
|
||||
List all models for a provider from the LLM dictionary.
|
||||
|
||||
:param provider_name: provider/factory name
|
||||
:param api_key: api key
|
||||
:param base_url: base url
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"]==provider_name]
|
||||
if not factory_info:
|
||||
return False, f"Provider '{provider_name}' not found"
|
||||
llms = factory_info[0]["llm"]
|
||||
if not llms:
|
||||
return False, f"No models found for provider '{provider_name}'"
|
||||
|
||||
models = []
|
||||
for llm in llms:
|
||||
models.append({
|
||||
static_llms = [{
|
||||
"name": llm["name"],
|
||||
"max_tokens": llm["max_tokens"],
|
||||
"model_types": [llm["model_type"]],
|
||||
"features": None
|
||||
})
|
||||
} for llm in factory_info[0]["llm"]]
|
||||
|
||||
model_base_url = base_url or factory_info[0].get("url", "")
|
||||
remote_models = []
|
||||
if provider_name in ModelMeta:
|
||||
remote_models = await ModelMeta[provider_name](api_key, model_base_url).get_model_list()
|
||||
|
||||
if not static_llms and not remote_models:
|
||||
return False, f"No models found for provider '{provider_name}'"
|
||||
|
||||
# Merge static and remote models, preferring remote_models on name conflicts
|
||||
merged = {m["name"]: m for m in static_llms}
|
||||
merged.update({m["name"]: m for m in remote_models})
|
||||
models = list(merged.values())
|
||||
|
||||
models.sort(key=lambda x: x["name"])
|
||||
return True, models
|
||||
|
||||
@@ -224,7 +234,7 @@ def show_provider_model(provider_name: str, model_name: str):
|
||||
}
|
||||
|
||||
|
||||
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str|dict, base_url: str, region: str, model_info: dict=None):
|
||||
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str|dict, base_url: str, region: str, model_info: list[dict]=None):
|
||||
"""
|
||||
Create a provider instance.
|
||||
|
||||
@@ -237,7 +247,7 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:param region: region
|
||||
:param model_info: model info, {
|
||||
:param model_info: model info, [{
|
||||
"model_type": ["chat"], # support multiple
|
||||
"model_name": "name",
|
||||
"max_tokens": 4096,
|
||||
@@ -245,7 +255,7 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
"field1": "value1",
|
||||
"field2": "'value2"
|
||||
}
|
||||
}
|
||||
}]
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
@@ -280,8 +290,12 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
extra_fields["region"] = region
|
||||
TenantModelInstanceService.create_instance(provider_id=provider_obj.id,instance_name=instance_name,api_key=api_key_str, extra=json.dumps(extra_fields))
|
||||
if model_info:
|
||||
success, msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model_info)
|
||||
if not success:
|
||||
msg = ""
|
||||
for model in model_info:
|
||||
success, _msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model)
|
||||
if not success:
|
||||
msg += _msg
|
||||
if msg:
|
||||
return False, msg
|
||||
|
||||
return True, "success"
|
||||
@@ -321,7 +335,7 @@ def list_provider_instances(tenant_id: str, provider_name: str):
|
||||
return True, active_instances + inactive_instances
|
||||
|
||||
|
||||
async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=None, region: str=None, model_info: dict=None):
|
||||
async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=None, region: str=None, model_info: list[dict]=None):
|
||||
"""
|
||||
Verify API key for a provider.
|
||||
|
||||
@@ -329,7 +343,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
:param api_key: API key
|
||||
:param base_url: base url
|
||||
:param region: region
|
||||
:param model_info: model info, {
|
||||
:param model_info: model info, [{
|
||||
"model_type": ["chat"], # support multiple
|
||||
"model_name": "name",
|
||||
"max_tokens": 4096,
|
||||
@@ -337,7 +351,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
"field1": "value1",
|
||||
"field2": "'value2"
|
||||
}
|
||||
}
|
||||
}]
|
||||
:return: (success, result_or_error_message)
|
||||
"""
|
||||
if not provider_name:
|
||||
@@ -358,8 +372,8 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
return False, f"No models found for provider '{provider_name}'"
|
||||
factory_llms = [{
|
||||
"model_type": _type,
|
||||
"llm_name": model_info.get("model_name", ""),
|
||||
} for _type in model_info.get("model_type", [])]
|
||||
"llm_name": model.get("model_name", ""),
|
||||
} for model in model_info if model for _type in model.get("model_type", []) ]
|
||||
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
|
||||
Reference in New Issue
Block a user