From b05d5a5228bc3d289f96b9c3b96be526401f92ea Mon Sep 17 00:00:00 2001 From: Lynn Date: Mon, 8 Jun 2026 11:02:40 +0800 Subject: [PATCH] Feat: get model list from remote (#15711) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Feat: - Get model list from remote provider. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/restful_apis/provider_api.py | 10 +- api/apps/services/provider_api_service.py | 54 +++++++---- rag/llm/__init__.py | 4 +- rag/llm/model_meta.py | 112 ++++++++++++++++++++++ 4 files changed, 155 insertions(+), 25 deletions(-) create mode 100644 rag/llm/model_meta.py diff --git a/api/apps/restful_apis/provider_api.py b/api/apps/restful_apis/provider_api.py index 92cc14b243..ffc06a458d 100644 --- a/api/apps/restful_apis/provider_api.py +++ b/api/apps/restful_apis/provider_api.py @@ -199,7 +199,7 @@ def delete_provider(tenant_id: str = None, provider_name: str = None): @manager.route("/providers//models", methods=["GET"]) # noqa: F821 @login_required -def list_provider_models(provider_name: str): +async def list_provider_models(provider_name: str): """ List models for a provider. 
--- @@ -230,7 +230,9 @@ def list_provider_models(provider_name: str): type: object """ try: - success, result = provider_api_service.list_provider_models(provider_name) + api_key = request.args.get("api_key") + base_url = request.args.get("base_url") + success, result = await provider_api_service.list_provider_models(provider_name, api_key, base_url) if success: return get_result(data=result) else: @@ -341,7 +343,7 @@ async def create_provider_instance(tenant_id: str = None, provider_name: str = N api_key = data["api_key"] base_url = data.get("base_url", "") region = data.get("region", "") - model_info = data.get("model_info", {}) + model_info = data.get("model_info", []) try: success, msg = await provider_api_service.create_provider_instance(tenant_id, provider_name, instance_name, api_key, base_url, region, model_info) @@ -409,7 +411,7 @@ async def verify_provider_api_key(provider_name: str = None): base_url = data.get("base_url", "") api_key = data["api_key"] region = data.get("region", "default") - model_info = data.get("model_info", {}) + model_info = data.get("model_info", []) try: success, msg = await provider_api_service.verify_api_key(provider_name, api_key, base_url, region, model_info) diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py index a8898f0915..8896558cb5 100644 --- a/api/apps/services/provider_api_service.py +++ b/api/apps/services/provider_api_service.py @@ -25,7 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro from api.db.services.tenant_model_provider_service import TenantModelProviderService from api.db.services.tenant_model_instance_service import TenantModelInstanceService from api.db.services.tenant_model_service import TenantModelService -from rag.llm import EmbeddingModel, ChatModel, RerankModel +from rag.llm import EmbeddingModel, ChatModel, RerankModel, ModelMeta def _to_int(v, default=500): @@ -168,28 +168,38 @@ def show_provider(provider_name: 
str): } -def list_provider_models(provider_name: str): +async def list_provider_models(provider_name: str, api_key: str = None, base_url: str = None): """ List all models for a provider from the LLM dictionary. :param provider_name: provider/factory name + :param api_key: api key + :param base_url: base url :return: (success, result_or_error_message) """ factory_info = [f for f in FACTORY_LLM_INFOS if f["name"]==provider_name] if not factory_info: return False, f"Provider '{provider_name}' not found" - llms = factory_info[0]["llm"] - if not llms: - return False, f"No models found for provider '{provider_name}'" - - models = [] - for llm in llms: - models.append({ + static_llms = [{ "name": llm["name"], "max_tokens": llm["max_tokens"], "model_types": [llm["model_type"]], "features": None - }) + } for llm in factory_info[0]["llm"]] + + model_base_url = base_url or factory_info[0].get("url", "") + remote_models = [] + if provider_name in ModelMeta: + remote_models = await ModelMeta[provider_name](api_key, model_base_url).get_model_list() + + if not static_llms and not remote_models: + return False, f"No models found for provider '{provider_name}'" + + # Merge static and remote models, preferring remote_models on name conflicts + merged = {m["name"]: m for m in static_llms} + merged.update({m["name"]: m for m in remote_models}) + models = list(merged.values()) + models.sort(key=lambda x: x["name"]) return True, models @@ -224,7 +234,7 @@ def show_provider_model(provider_name: str, model_name: str): } -async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str|dict, base_url: str, region: str, model_info: dict=None): +async def create_provider_instance(tenant_id: str, provider_name: str, instance_name: str, api_key: str|dict, base_url: str, region: str, model_info: list[dict]=None): """ Create a provider instance. 
@@ -237,7 +247,7 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_ :param api_key: API key :param base_url: base url :param region: region - :param model_info: model info, { + :param model_info: model info, [{ "model_type": ["chat"], # support multiple "model_name": "name", "max_tokens": 4096, @@ -245,7 +255,7 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_ "field1": "value1", "field2": "'value2" } - } + }] :return: (success, result_or_error_message) """ if not provider_name: @@ -280,8 +290,12 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_ extra_fields["region"] = region TenantModelInstanceService.create_instance(provider_id=provider_obj.id,instance_name=instance_name,api_key=api_key_str, extra=json.dumps(extra_fields)) if model_info: - success, msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model_info) - if not success: + msg = "" + for model in model_info: + success, _msg = add_model_to_instance(tenant_id, provider_name, instance_name, **model) + if not success: + msg += _msg + if msg: return False, msg return True, "success" @@ -321,7 +335,7 @@ def list_provider_instances(tenant_id: str, provider_name: str): return True, active_instances + inactive_instances -async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=None, region: str=None, model_info: dict=None): +async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=None, region: str=None, model_info: list[dict]=None): """ Verify API key for a provider. 
@@ -329,7 +343,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No :param api_key: API key :param base_url: base url :param region: region - :param model_info: model info, { + :param model_info: model info, [{ "model_type": ["chat"], # support multiple "model_name": "name", "max_tokens": 4096, @@ -337,7 +351,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No "field1": "value1", "field2": "'value2" } - } + }] :return: (success, result_or_error_message) """ if not provider_name: @@ -358,8 +372,8 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No return False, f"No models found for provider '{provider_name}'" factory_llms = [{ "model_type": _type, - "llm_name": model_info.get("model_name", ""), - } for _type in model_info.get("model_type", [])] + "llm_name": model.get("model_name", ""), + } for model in model_info if model for _type in model.get("model_type", []) ] # test if api key works chat_passed, embd_passed, rerank_passed = False, False, False diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 4e30c9f91f..6ad7941275 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -145,6 +145,7 @@ RerankModel = globals().get("RerankModel", {}) Seq2txtModel = globals().get("Seq2txtModel", {}) TTSModel = globals().get("TTSModel", {}) OcrModel = globals().get("OcrModel", {}) +ModelMeta = globals().get("ModelMeta", {}) MODULE_MAPPING = { @@ -155,6 +156,7 @@ MODULE_MAPPING = { "sequence2txt_model": Seq2txtModel, "tts_model": TTSModel, "ocr_model": OcrModel, + "model_meta": ModelMeta, } package_name = __name__ @@ -188,7 +190,6 @@ for module_name, mapping_dict in MODULE_MAPPING.items(): else: mapping_dict[obj._FACTORY_NAME] = obj - __all__ = [ "ChatModel", "CvModel", @@ -197,4 +198,5 @@ __all__ = [ "Seq2txtModel", "TTSModel", "OcrModel", + "ModelMeta", ] diff --git a/rag/llm/model_meta.py b/rag/llm/model_meta.py new file mode 100644 index 0000000000..41399456cb 
--- /dev/null
+++ b/rag/llm/model_meta.py
@@ -0,0 +1,198 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Query LLM providers over HTTP for the models they currently serve.
+
+Every ``Base`` subclass that sets ``_FACTORY_NAME`` handles one provider.
+Lookups are best-effort: an unreachable or misbehaving endpoint produces an
+empty result rather than an exception.
+"""
+import asyncio
+import logging
+from abc import ABC
+
+import aiohttp
+
+from common.constants import LLMType
+
+# aiohttp's default total timeout is 300 s, which would hold the calling
+# request open for minutes when an endpoint is unreachable.
+_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30)
+
+# Failures that mean "the remote catalogue is unavailable": connection and
+# HTTP errors, timeouts, and response bodies that are not valid JSON.
+_REMOTE_ERRORS = (aiohttp.ClientError, asyncio.TimeoutError, ValueError)
+
+
+class Base(ABC):
+    """Fetch a model catalogue from an OpenAI-style ``/v1/models`` endpoint."""
+
+    def __init__(self, api_key: str, base_url: str = None):
+        """
+        :param api_key: credential sent as a bearer token
+        :param base_url: provider endpoint, with or without a ``/v1`` segment
+        """
+        self.api_key = api_key
+        self.base_url = base_url
+
+    def _get_api_key(self):
+        """Return the bearer token; subclasses may override to derive one."""
+        return self.api_key
+
+    def _get_model_list_url(self):
+        """Return the ``/v1/models`` URL for ``base_url``, or None if unset."""
+        if not self.base_url:
+            return None
+        if "/v1" in self.base_url:
+            # Discard whatever follows the version segment.
+            return self.base_url.split("/v1")[0].rstrip("/") + "/v1/models"
+        return self.base_url.rstrip("/") + "/v1/models"
+
+    async def _get_raw_model_list(self):
+        """Return the decoded JSON response, or None if it cannot be fetched."""
+        url = self._get_model_list_url()
+        if not url:
+            return None
+        headers = {"Authorization": f"Bearer {self._get_api_key()}"}
+        try:
+            async with aiohttp.ClientSession(timeout=_REQUEST_TIMEOUT) as session:
+                async with session.get(url, headers=headers) as resp:
+                    if resp.status != 200:
+                        return None
+                    return await resp.json()
+        except _REMOTE_ERRORS as e:
+            logging.warning("Failed to fetch model list from %s: %s", url, e)
+            return None
+
+    def _format_model_list(self, raw_model_list):
+        """Convert the raw response into model entries.
+
+        The default returns the response unchanged; override it to build
+        entries shaped like the ones ``Ollama.get_model_list`` returns.
+        """
+        return raw_model_list
+
+    async def get_model_list(self):
+        """Return the provider's models, or [] when none could be fetched."""
+        raw_model_list = await self._get_raw_model_list()
+        if not raw_model_list:
+            return []
+        return self._format_model_list(raw_model_list)
+
+
+class VolcEngine(Base):
+    """VolcEngine catalogue lookup; not implemented yet."""
+
+    _FACTORY_NAME = "VolcEngine"
+
+    async def get_model_list(self):
+        """Return [] until access-token auth is implemented."""
+        # TODO: implement access token auth.
+        # This must stay a coroutine that returns a list: the caller awaits
+        # it, and raising here would also hide the static models.
+        return []
+
+
+class Ollama(Base):
+    """List the models of an Ollama server via ``/api/tags`` and ``/api/show``."""
+
+    _FACTORY_NAME = "Ollama"
+
+    # Used when the model metadata carries no context length.
+    _DEFAULT_MAX_TOKENS = 8192
+
+    def _get_model_tags_url(self):
+        """Return the URL that lists the models on the server."""
+        return self.base_url.rstrip("/") + "/api/tags"
+
+    def _get_model_detail_url(self):
+        """Return the URL that describes a single model."""
+        return self.base_url.rstrip("/") + "/api/show"
+
+    @classmethod
+    def _format_model(cls, name, detail):
+        """Build one model entry from an ``/api/show`` response.
+
+        :param name: model name as reported by ``/api/tags``
+        :param detail: decoded ``/api/show`` response for that model
+        :return: dict with "name", "model_types", "features" and "max_tokens"
+        """
+        capability_to_model_type_mapping = {
+            "completion": LLMType.CHAT.value,
+            "vision": LLMType.IMAGE2TEXT.value,
+            "embedding": LLMType.EMBEDDING.value,
+        }
+        capability_to_feature_mapping = {
+            "thinking": "thinking",
+            "tools": "is_tools",
+        }
+        capabilities = detail.get("capabilities") or []
+        metadata = detail.get("model_info") or {}
+        family = (detail.get("details") or {}).get("family", "")
+
+        # Metadata keys are prefixed with the model architecture, which may
+        # differ from the family; try the architecture first.
+        max_tokens = cls._DEFAULT_MAX_TOKENS
+        for prefix in (metadata.get("general.architecture"), family):
+            key = f"{prefix}.context_length"
+            if prefix and key in metadata:
+                max_tokens = metadata[key]
+                break
+
+        return {
+            "name": name,
+            "model_types": [capability_to_model_type_mapping[c] for c in capabilities if c in capability_to_model_type_mapping],
+            "features": [capability_to_feature_mapping[c] for c in capabilities if c in capability_to_feature_mapping],
+            "max_tokens": max_tokens,
+        }
+
+    async def get_model_list(self):
+        """Return an entry for every model the server can describe.
+
+        Models whose details cannot be fetched are skipped; if the server
+        itself cannot be reached the result is [].
+        """
+        if not self.base_url:
+            return []
+        headers = {}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self._get_api_key()}"
+
+        res = []
+        try:
+            async with aiohttp.ClientSession(timeout=_REQUEST_TIMEOUT) as session:
+                async with session.get(self._get_model_tags_url(), headers=headers) as resp:
+                    if resp.status != 200:
+                        return []
+                    tags = await resp.json()
+
+                for model in tags.get("models") or []:
+                    name = model.get("name")
+                    if not name:
+                        continue
+                    try:
+                        async with session.post(self._get_model_detail_url(), headers=headers, json={"model": name}) as resp:
+                            if resp.status != 200:
+                                continue
+                            detail = await resp.json()
+                    except _REMOTE_ERRORS as e:
+                        # One bad model must not discard the rest.
+                        logging.warning("Failed to fetch details of Ollama model %s: %s", name, e)
+                        continue
+                    res.append(self._format_model(name, detail))
+        except _REMOTE_ERRORS as e:
+            logging.warning("Failed to fetch model list from %s: %s", self.base_url, e)
+            return []
+        return res