Files
ragflow/rag/llm/model_meta.py
buua436 6bf7056422 feat: add placeholder model metas (#15753)
### What problem does this PR solve?

add placeholder model metas

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-06-08 14:54:59 +08:00

218 lines
7.9 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import aiohttp
from abc import ABC
from common.constants import LLMType
class Base(ABC):
    """Common base for provider-specific model-metadata fetchers.

    The default implementation requests ``<base_url>/v1/models`` with a bearer
    token and returns the decoded JSON payload. Subclasses customise the
    behaviour by overriding ``_get_model_list_url``, ``_format_model_list`` or
    ``get_model_list`` itself.
    """

    def __init__(self, api_key: str, base_url: str = None):
        """Store the credential and endpoint used by later requests.

        Args:
            api_key: Credential sent as a bearer token; may be empty.
            base_url: Root URL of the provider API. When falsy, the default
                ``get_model_list`` returns an empty list.
        """
        self.api_key = api_key
        self.base_url = base_url

    def _get_api_key(self):
        """Return the configured API key unchanged."""
        return self.api_key

    def _get_auth_headers(self):
        """Build the HTTP headers that carry the bearer token.

        Returns:
            ``{"Authorization": "Bearer <key>"}`` when an API key is set,
            otherwise an empty dict so that no bogus ``Bearer None`` /
            ``Bearer `` header is sent to endpoints that need no credential.
        """
        api_key = self._get_api_key()
        if not api_key:
            return {}
        return {"Authorization": f"Bearer {api_key}"}

    def _get_model_list_url(self):
        """Derive the model-listing URL from ``base_url``.

        Everything from the first ``/v1`` onwards is dropped (if present) and
        ``/v1/models`` is appended.

        Returns:
            The listing URL, or ``None`` when no ``base_url`` is configured.
        """
        if not self.base_url:
            return None
        # ``split`` yields the whole string when "/v1" is absent, so a single
        # expression covers both the "has /v1" and "no /v1" cases.
        root = self.base_url.split("/v1", 1)[0]
        return root.rstrip("/") + "/v1/models"

    async def _get_raw_model_list(self):
        """Fetch the raw model list from the listing endpoint.

        Returns:
            The decoded JSON body, or ``None`` when no URL is available or the
            server answers with a status other than 200.
        """
        url = self._get_model_list_url()
        if not url:
            return None
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=self._get_auth_headers()) as resp:
                if resp.status != 200:
                    return None
                return await resp.json()

    def _format_model_list(self, raw_model_list):
        """Convert the raw payload into the result of ``get_model_list``.

        The base implementation returns the payload untouched.
        """
        return raw_model_list

    async def get_model_list(self):
        """Return the provider's model list.

        Returns:
            The formatted model list, or ``[]`` when nothing could be fetched.
        """
        raw_model_list = await self._get_raw_model_list()
        if not raw_model_list:
            return []
        return self._format_model_list(raw_model_list)
class VolcEngine(Base):
    """Placeholder model-metadata fetcher for VolcEngine.

    Listing models is not implemented yet.
    """

    _FACTORY_NAME = "VolcEngine"

    async def get_model_list(self):
        """Raise ``NotImplementedError``; model listing is not supported yet.

        Declared ``async`` to match ``Base.get_model_list``: calling it yields
        a coroutine and the error surfaces when that coroutine is awaited,
        instead of being raised synchronously at call time.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement access token auth
        raise NotImplementedError
class Ollama(Base):
    """Model-metadata fetcher for an Ollama server.

    Models are enumerated through ``/api/tags``; capabilities and context
    length of each model are then read through ``/api/show``.
    """

    _FACTORY_NAME = "Ollama"

    def _get_model_tags_url(self):
        """Return the URL of the ``/api/tags`` endpoint (needs ``base_url``)."""
        return self.base_url.rstrip("/") + "/api/tags"

    def _get_model_detail_url(self):
        """Return the URL of the ``/api/show`` endpoint (needs ``base_url``)."""
        return self.base_url.rstrip("/") + "/api/show"

    @staticmethod
    def _resolve_max_tokens(model_info, default=8192):
        """Extract the context length from an ``/api/show`` response.

        The value is looked up under ``"<prefix>.context_length"`` in the
        response's ``model_info`` mapping, trying the ``details.family`` value
        first and ``model_info["general.architecture"]`` second.

        Args:
            model_info: Decoded ``/api/show`` response.
            default: Value returned when no context length is found.

        Returns:
            The reported context length, or ``default``.
        """
        metadata = model_info.get("model_info")
        if not isinstance(metadata, dict):
            # Missing or malformed metadata must not abort the whole listing.
            return default
        details = model_info.get("details")
        family = details.get("family", "") if isinstance(details, dict) else ""
        for prefix in (family, metadata.get("general.architecture")):
            if prefix is None:
                continue
            value = metadata.get("{}.context_length".format(prefix))
            if value is not None:
                return value
        return default

    async def get_model_list(self):
        """List the models served by the Ollama instance.

        Returns:
            A list of dicts with the keys ``name``, ``model_types``,
            ``features`` and ``max_tokens``. The list is empty when no
            ``base_url`` is set or the tag listing fails; models whose detail
            request fails or is malformed are skipped.
        """
        if not self.base_url:
            return []
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self._get_api_key()}"

        capability_to_model_type_mapping = {
            "completion": LLMType.CHAT.value,
            "vision": LLMType.IMAGE2TEXT.value,
            "embedding": LLMType.EMBEDDING.value,
        }
        capability_to_feature_mapping = {
            "thinking": "thinking",
            "tools": "is_tools",
        }

        res = []
        async with aiohttp.ClientSession() as session:
            async with session.get(self._get_model_tags_url(), headers=headers) as resp:
                if resp.status != 200:
                    return []
                tags = await resp.json()
            models = tags.get("models") if isinstance(tags, dict) else None
            for model in models or []:
                # Skip entries without a usable name instead of raising KeyError.
                name = model.get("name") if isinstance(model, dict) else None
                if not name:
                    continue
                async with session.post(self._get_model_detail_url(), headers=headers, json={"model": name}) as resp:
                    if resp.status != 200:
                        continue
                    model_info = await resp.json()
                if not isinstance(model_info, dict):
                    continue
                # ``capabilities`` may be absent or null in the response.
                capabilities = model_info.get("capabilities") or []
                res.append({
                    "name": name,
                    "model_types": [capability_to_model_type_mapping[c] for c in capabilities if c in capability_to_model_type_mapping],
                    "features": [capability_to_feature_mapping[c] for c in capabilities if c in capability_to_feature_mapping],
                    "max_tokens": self._resolve_max_tokens(model_info),
                })
        return res
class FishAudio(Base):
    """Model-metadata fetcher for Fish Audio.

    Every listed entry is reported as a TTS model with a fixed ``max_tokens``
    of 8192.
    """

    _FACTORY_NAME = "Fish Audio"

    def _get_access_token(self):
        """Extract the bearer token from the configured API key.

        The key may be a plain token or a JSON object; for a JSON object the
        first truthy value among ``fish_audio_ak``, ``access_token`` and
        ``api_key`` is used, falling back to the raw key string.

        Returns:
            The token, or ``""`` when no API key is configured.
        """
        raw_key = self._get_api_key()
        if not raw_key:
            return ""
        try:
            decoded = json.loads(raw_key)
        except Exception:
            # Not JSON: the key itself is the token.
            return raw_key
        if not isinstance(decoded, dict):
            return raw_key
        for field in ("fish_audio_ak", "access_token", "api_key"):
            candidate = decoded.get(field)
            if candidate:
                return candidate
        return raw_key

    def _get_model_list_url(self):
        """Return the ``/model`` listing URL.

        Without a ``base_url`` the public ``https://api.fish.audio/model``
        endpoint is used; otherwise any ``/v1`` suffix or ``/v1/...`` tail is
        removed from ``base_url`` before ``/model`` is appended.
        """
        if not self.base_url:
            return "https://api.fish.audio/model"
        root = self.base_url.rstrip("/")
        if "/v1/" in root:
            root = root.partition("/v1")[0].rstrip("/")
        elif root.endswith("/v1"):
            root = root[: -len("/v1")]
        return root + "/model"

    async def get_model_list(self):
        """List the models exposed by the Fish Audio endpoint.

        Returns:
            A list of dicts with the keys ``name``, ``model_types``,
            ``features`` and ``max_tokens``; empty when no token is available,
            the request does not return 200, or the payload has an unexpected
            shape.
        """
        endpoint = self._get_model_list_url()
        token = self._get_access_token()
        if not (endpoint and token):
            return []
        auth = {"Authorization": f"Bearer {token}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(endpoint, headers=auth) as resp:
                if resp.status != 200:
                    return []
                payload = await resp.json()
        if not isinstance(payload, dict):
            return []
        items = payload.get("items") or []
        if not isinstance(items, list):
            return []
        # Prefer the title, fall back to the identifier.
        labels = [
            item.get("title") or item.get("_id")
            for item in items
            if isinstance(item, dict)
        ]
        return [
            {
                "name": label,
                "model_types": [LLMType.TTS.value],
                "features": [],
                "max_tokens": 8192,
            }
            for label in labels
            if label
        ]
class MinerU(Base):
    """Model-metadata fetcher for MinerU.

    Uses the inherited listing URL and reports every entry as an OCR model.
    """

    _FACTORY_NAME = "MinerU"

    @staticmethod
    def _first_truthy(mapping, keys, fallback):
        """Return the first truthy ``mapping[key]`` for ``keys``, else ``fallback``."""
        for key in keys:
            value = mapping.get(key)
            if value:
                return value
        return fallback

    def _get_access_token(self):
        """Extract the bearer token from the configured API key.

        The key may be a plain token or a JSON object; for a JSON object the
        first truthy value among ``access_token`` and ``api_key`` is used,
        falling back to the raw key string.

        Returns:
            The token, or ``""`` when no API key is configured.
        """
        raw_key = self._get_api_key()
        if not raw_key:
            return ""
        try:
            decoded = json.loads(raw_key)
        except Exception:
            # Not JSON: the key itself is the token.
            return raw_key
        if not isinstance(decoded, dict):
            return raw_key
        return self._first_truthy(decoded, ("access_token", "api_key"), raw_key)

    async def get_model_list(self):
        """List the models exposed by the MinerU endpoint.

        The response may be a bare list or a dict wrapping the list under
        ``data``, ``models`` or ``items``.

        Returns:
            A list of dicts with the keys ``name``, ``model_types``,
            ``features`` and ``max_tokens`` (8192 unless the entry carries its
            own ``max_tokens``); empty when the URL or token is missing, the
            request does not return 200, or the payload has an unexpected
            shape.
        """
        endpoint = self._get_model_list_url()
        token = self._get_access_token()
        if not (endpoint and token):
            return []
        auth = {"Authorization": f"Bearer {token}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(endpoint, headers=auth) as resp:
                if resp.status != 200:
                    return []
                payload = await resp.json()
        if isinstance(payload, dict):
            payload = self._first_truthy(payload, ("data", "models", "items"), [])
        if not isinstance(payload, list):
            return []
        entries = []
        for item in payload:
            if not isinstance(item, dict):
                continue
            label = self._first_truthy(item, ("title", "name", "id", "_id"), None)
            if not label:
                continue
            entries.append({
                "name": label,
                "model_types": [LLMType.OCR.value],
                "features": [],
                "max_tokens": item.get("max_tokens", 8192),
            })
        return entries