Feat: get model list from remote (#15711)

### What problem does this PR solve?

Feat:
- Get model list from remote provider. 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Lynn
2026-06-08 11:02:40 +08:00
committed by GitHub
parent b0a45809ff
commit b05d5a5228
4 changed files with 155 additions and 25 deletions

View File

@@ -199,7 +199,7 @@ def delete_provider(tenant_id: str = None, provider_name: str = None):
@manager.route("/providers/<provider_name>/models", methods=["GET"]) # noqa: F821
@login_required
def list_provider_models(provider_name: str):
async def list_provider_models(provider_name: str):
"""
List models for a provider.
---
@@ -230,7 +230,9 @@ def list_provider_models(provider_name: str):
type: object
"""
try:
success, result = provider_api_service.list_provider_models(provider_name)
api_key = request.args.get("api_key")
base_url = request.args.get("base_url")
success, result = await provider_api_service.list_provider_models(provider_name, api_key, base_url)
if success:
return get_result(data=result)
else:
@@ -341,7 +343,7 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N
api_key = data["api_key"]
base_url = data.get("base_url", "")
region = data.get("region", "")
model_info = data.get("model_info", {})
model_info = data.get("model_info", [])
try:
success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, model_info)
@@ -409,7 +411,7 @@ async def verify_provider_api_key(provider_name: str = None):
base_url = data.get("base_url", "")
api_key = data["api_key"]
region = data.get("region", "default")
model_info = data.get("model_info", {})
model_info = data.get("model_info", [])
try:
success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region, model_info)

View File

@@ -25,7 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro
from api.db.services.tenant_model_provider_service import TenantModelProviderService
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
from api.db.services.tenant_model_service import TenantModelService
from rag.llm import EmbeddingModel, ChatModel, RerankModel
from rag.llm import EmbeddingModel, ChatModel, RerankModel, ModelMeta
def _to_int(v, default=500):
@@ -168,28 +168,38 @@ def show_provider(provider_name: str):
}
def list_provider_models(provider_name: str):
async def list_provider_models(provider_name: str, api_key: str = None, base_url: str = None):
"""
List all models for a provider from the LLM dictionary.
:param provider_name: provider/factory name
:param api_key: api key
:param base_url: base url
:return: (success, result_or_error_message)
"""
factory_info = [f for f in FACTORY_LLM_INFOS if f["name"]==provider_name]
if not factory_info:
return False, f"Provider '{provider_name}' not found"
llms = factory_info[0]["llm"]
if not llms:
return False, f"No models found for provider '{provider_name}'"
models = []
for llm in llms:
models.append({
static_llms = [{
"name": llm["name"],
"max_tokens": llm["max_tokens"],
"model_types": [llm["model_type"]],
"features": None
})
} for llm in factory_info[0]["llm"]]
model_base_url = base_url or factory_info[0].get("url", "")
remote_models = []
if provider_name in ModelMeta:
remote_models = await ModelMeta[provider_name](api_key, model_base_url).get_model_list()
if not static_llms and not remote_models:
return False, f"No models found for provider '{provider_name}'"
# Merge static and remote models, preferring remote_models on name conflicts
merged = {m["name"]: m for m in static_llms}
merged.update({m["name"]: m for m in remote_models})
models = list(merged.values())
models.sort(key=lambda x: x["name"])
return True, models
@@ -224,7 +234,7 @@ def show_provider_model(provider_name: str, model_name: str):
}
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str|dict, base_url: str, region: str, model_info: dict=None):
async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str|dict, base_url: str, region: str, model_info: list[dict]=None):
"""
Create a provider instance.
@@ -237,7 +247,7 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
:param api_key: API key
:param base_url: base url
:param region: region
:param model_info: model info, {
:param model_info: model info, [{
"model_type": ["chat"], # support multiple
"model_name": "name"
"max_tokens": 4096,
@@ -245,7 +255,7 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
"field1": "value1",
"field2": "'value2"
}
}
}]
:return: (success, result_or_error_message)
"""
if not provider_name:
@@ -280,8 +290,12 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
extra_fields["region"] = region
TenantModelInstanceService.create_instance(provider_id=provider_obj.id,instance_name=instance_name,api_key=api_key_str, extra=json.dumps(extra_fields))
if model_info:
success, msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model_info)
if not success:
msg = ""
for model in model_info:
success, _msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model)
if not success:
msg += _msg
if msg:
return False, msg
return True, "success"
@@ -321,7 +335,7 @@ def list_provider_instances(tenant_id: str, provider_name: str):
return True, active_instances + inactive_instances
async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=None, region: str=None, model_info: dict=None):
async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=None, region: str=None, model_info: list[dict]=None):
"""
Verify API key for a provider.
@@ -329,7 +343,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
:param api_key: API key
:param base_url: base url
:param region: region
:param model_info: model info, {
:param model_info: model info, [{
"model_type": ["chat"], # support multiple
"model_name": "name"
"max_tokens": 4096,
@@ -337,7 +351,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
"field1": "value1",
"field2": "'value2"
}
}
}]
:return: (success, result_or_error_message)
"""
if not provider_name:
@@ -358,8 +372,8 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
return False, f"No models found for provider '{provider_name}'"
factory_llms = [{
"model_type": _type,
"llm_name": model_info.get("model_name", ""),
} for _type in model_info.get("model_type", [])]
"llm_name": model.get("model_name", ""),
} for model in model_info if model for _type in model.get("model_type", []) ]
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False

View File

@@ -145,6 +145,7 @@ RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})
OcrModel = globals().get("OcrModel", {})
ModelMeta = globals().get("ModelMeta", {})
MODULE_MAPPING = {
@@ -155,6 +156,7 @@ MODULE_MAPPING = {
"sequence2txt_model": Seq2txtModel,
"tts_model": TTSModel,
"ocr_model": OcrModel,
"model_meta": ModelMeta,
}
package_name = __name__
@@ -188,7 +190,6 @@ for module_name, mapping_dict in MODULE_MAPPING.items():
else:
mapping_dict[obj._FACTORY_NAME] = obj
__all__ = [
"ChatModel",
"CvModel",
@@ -197,4 +198,5 @@ __all__ = [
"Seq2txtModel",
"TTSModel",
"OcrModel",
"ModelMeta",
]

112
rag/llm/model_meta.py Normal file
View File

@@ -0,0 +1,112 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import aiohttp
from abc import ABC
from common.constants import LLMType
class Base(ABC):
    """Fetch the list of models a remote provider exposes.

    The default implementation issues a ``GET <base_url>/v1/models`` request
    with bearer-token authentication. Subclasses customise behaviour by
    overriding the ``_get_*`` / ``_format_model_list`` hooks or
    ``get_model_list`` itself.
    """

    def __init__(self, api_key: str, base_url: str = None):
        """
        :param api_key: credential sent as a bearer token; may be empty/None
        :param base_url: provider endpoint; may be empty/None
        """
        self.api_key = api_key
        self.base_url = base_url

    def _get_api_key(self):
        """Return the credential used for the ``Authorization`` header."""
        return self.api_key

    def _get_headers(self):
        """Build request headers, omitting ``Authorization`` when no key is set.

        Sending ``Bearer None`` for a missing key would put a bogus credential
        on the wire, so the header is only added for a truthy key.
        """
        api_key = self._get_api_key()
        if not api_key:
            return {}
        return {"Authorization": f"Bearer {api_key}"}

    def _get_model_list_url(self):
        """Derive the ``/v1/models`` URL from ``base_url``.

        Everything from the first ``/v1`` path marker onwards is dropped so
        that URLs such as ``http://host/v1/chat/completions`` still resolve to
        ``http://host/v1/models``.

        :return: the model-list URL, or None when no base URL is configured
        """
        if not self.base_url:
            return None
        base = self.base_url
        # Start searching after the scheme separator: otherwise the "//v1" in
        # "http://v1.example.com" would be mistaken for the version segment.
        scheme_end = base.find("://")
        search_from = scheme_end + 3 if scheme_end != -1 else 0
        version_at = base.find("/v1", search_from)
        root = base if version_at == -1 else base[:version_at]
        return root.rstrip("/") + "/v1/models"

    async def _get_raw_model_list(self):
        """Request the model list and return the decoded JSON body.

        :return: the parsed JSON payload, or None when there is no URL or the
                 server answers with a status other than 200
        """
        url = self._get_model_list_url()
        if not url:
            return None
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=self._get_headers()) as resp:
                if resp.status != 200:
                    return None
                return await resp.json()

    def _format_model_list(self, raw_model_list):
        """Hook for converting the raw payload; the default returns it unchanged."""
        return raw_model_list

    async def get_model_list(self):
        """Return the formatted remote model list, or ``[]`` when none is available."""
        raw_model_list = await self._get_raw_model_list()
        if not raw_model_list:
            return []
        return self._format_model_list(raw_model_list)
class VolcEngine(Base):
    """Model-list provider for VolcEngine (remote listing not implemented yet)."""

    _FACTORY_NAME = "VolcEngine"

    async def get_model_list(self):
        """Return no remote models until access-token auth is implemented.

        This override is a coroutine like ``Base.get_model_list`` so callers
        can ``await`` it. It returns an empty list rather than raising, so a
        caller that merges remote models with a static catalogue still gets
        the static entries.
        """
        # TODO: implement access token auth and query the remote model list.
        return []
class Ollama(Base):
    """List the models served by an Ollama instance.

    Model names come from ``GET /api/tags``; each model is then described with
    ``POST /api/show`` to obtain its capabilities and context length.
    """

    _FACTORY_NAME = "Ollama"

    def _get_model_tags_url(self):
        """Return the URL of the endpoint that lists models."""
        return self.base_url.rstrip("/") + "/api/tags"

    def _get_model_detail_url(self):
        """Return the URL of the endpoint that describes a single model."""
        return self.base_url.rstrip("/") + "/api/show"

    @staticmethod
    def _resolve_max_tokens(detail, default=8192):
        """Extract the context length from an ``/api/show`` payload.

        The value is stored under ``"<prefix>.context_length"`` in the
        ``model_info`` section. The family name is tried first, then the
        ``general.architecture`` value (the prefix shown in Ollama's API
        documentation), and finally any key with the ``.context_length``
        suffix.

        :param detail: decoded ``/api/show`` response
        :param default: value returned when no context length is present
        """
        # A missing or null section must not raise; treat it as empty.
        meta = detail.get("model_info") or {}
        family = (detail.get("details") or {}).get("family", "")
        candidates = ["{}.context_length".format(family)]
        architecture = meta.get("general.architecture")
        if architecture:
            candidates.append("{}.context_length".format(architecture))
        for key in candidates:
            if key in meta:
                return meta[key]
        for key, value in meta.items():
            if key.endswith(".context_length"):
                return value
        return default

    async def get_model_list(self):
        """Return one entry per model that could be described.

        Each entry has the keys ``name``, ``model_types``, ``features`` and
        ``max_tokens``. Models whose detail request does not answer with
        status 200 are skipped.

        :return: list of model dicts; ``[]`` when no base URL is set, the tag
                 request fails, or the server reports no models
        """
        if not self.base_url:
            return []

        # Only authenticate when a key was supplied.
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self._get_api_key()}"

        capability_to_model_type_mapping = {
            "completion": LLMType.CHAT.value,
            "vision": LLMType.IMAGE2TEXT.value,
            "embedding": LLMType.EMBEDDING.value,
        }
        capability_to_feature_mapping = {
            "thinking": "thinking",
            "tools": "is_tools",
        }

        async with aiohttp.ClientSession() as session:
            async with session.get(self._get_model_tags_url(), headers=headers) as resp:
                if resp.status != 200:
                    return []
                tags = await resp.json()

            models = tags.get("models") or []
            if not models:
                return []

            res = []
            for model in models:
                async with session.post(
                    self._get_model_detail_url(),
                    headers=headers,
                    json={"model": model["name"]},
                ) as resp:
                    if resp.status != 200:
                        continue
                    model_info = await resp.json()

                capabilities = model_info.get("capabilities") or []
                res.append({
                    "name": model["name"],
                    "model_types": [
                        capability_to_model_type_mapping[c]
                        for c in capabilities
                        if c in capability_to_model_type_mapping
                    ],
                    "features": [
                        capability_to_feature_mapping[c]
                        for c in capabilities
                        if c in capability_to_feature_mapping
                    ],
                    "max_tokens": self._resolve_max_tokens(model_info),
                })
            return res