mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? Add placeholder model metadata. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
218 lines
7.9 KiB
Python
218 lines
7.9 KiB
Python
#
|
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import json
|
|
import aiohttp
|
|
from abc import ABC
|
|
|
|
from common.constants import LLMType
|
|
|
|
|
|
class Base(ABC):
    """Shared helpers for fetching a provider's model list over HTTP.

    Subclasses set ``_FACTORY_NAME`` and either override ``get_model_list``
    entirely or customise the smaller hooks (``_get_model_list_url``,
    ``_format_model_list``).
    """

    def __init__(self, api_key: str, base_url: str = None):
        # api_key: credential sent as a Bearer token (only when non-empty).
        # base_url: provider endpoint; may or may not already contain "/v1".
        self.api_key = api_key
        self.base_url = base_url

    def _get_api_key(self):
        """Return the credential used to authenticate requests."""
        return self.api_key

    def _get_auth_headers(self):
        """Return headers carrying the Bearer token, or ``{}`` when no key is set.

        Skipping the header for an empty/``None`` key avoids sending a
        malformed ``"Bearer None"`` / ``"Bearer "`` value.
        """
        api_key = self._get_api_key()
        if not api_key:
            return {}
        return {"Authorization": f"Bearer {api_key}"}

    def _get_model_list_url(self):
        """Build the ``/v1/models`` listing URL from ``base_url``.

        Everything from the first ``"/v1"`` onward is dropped, so both
        ``http://host`` and ``http://host/v1/chat`` map to
        ``http://host/v1/models``. Returns ``None`` when ``base_url`` is unset.
        NOTE(review): the substring test also matches paths such as ``/v1beta``.
        """
        if not self.base_url:
            return None
        if "/v1" in self.base_url:
            return self.base_url.split("/v1")[0].rstrip("/") + "/v1/models"
        return self.base_url.rstrip("/") + "/v1/models"

    async def _get_raw_model_list(self):
        """Fetch and decode the JSON model listing.

        Returns ``None`` when no URL can be built or the response status is not
        200. Connection errors raised by aiohttp are not caught here.
        """
        url = self._get_model_list_url()
        if not url:
            return None
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=self._get_auth_headers()) as resp:
                if resp.status != 200:
                    return None
                return await resp.json()

    def _format_model_list(self, raw_model_list):
        """Convert the decoded listing into the caller-facing shape (identity by default)."""
        return raw_model_list

    async def get_model_list(self):
        """Return the formatted model list, or ``[]`` when nothing could be fetched."""
        raw_model_list = await self._get_raw_model_list()
        if not raw_model_list:
            return []
        return self._format_model_list(raw_model_list)
|
|
|
|
|
|
class VolcEngine(Base):
    """VolcEngine provider stub: model listing is not implemented yet."""

    _FACTORY_NAME = "VolcEngine"

    async def get_model_list(self):
        """Always raise ``NotImplementedError``.

        Declared ``async`` to match ``Base.get_model_list``, so callers that
        ``await`` it (or gather it with other providers) receive the error
        from the coroutine rather than at call time.
        """
        # TODO: implement access-token auth before listing VolcEngine models.
        raise NotImplementedError("VolcEngine model listing is not implemented yet")
|
|
|
|
|
|
class Ollama(Base):
    """Model-list provider for an Ollama server.

    Uses ``GET /api/tags`` to enumerate models, then ``POST /api/show`` per
    model to read its capabilities and context length.
    """

    _FACTORY_NAME = "Ollama"

    # Ollama capability name -> model type reported to callers.
    _CAPABILITY_TO_MODEL_TYPE = {
        "completion": LLMType.CHAT.value,
        "vision": LLMType.IMAGE2TEXT.value,
        "embedding": LLMType.EMBEDDING.value,
    }
    # Ollama capability name -> feature flag reported to callers.
    _CAPABILITY_TO_FEATURE = {
        "thinking": "thinking",
        "tools": "is_tools",
    }
    # Used when /api/show exposes no context length for the model.
    _DEFAULT_MAX_TOKENS = 8192

    def _get_model_tags_url(self):
        """Return the ``/api/tags`` URL."""
        return self.base_url.rstrip("/") + "/api/tags"

    def _get_model_detail_url(self):
        """Return the ``/api/show`` URL."""
        return self.base_url.rstrip("/") + "/api/show"

    def _get_context_length(self, model_info):
        """Read ``<prefix>.context_length`` from an ``/api/show`` response.

        The prefix is ``model_info["general.architecture"]`` when present,
        falling back to ``details.family``. A missing ``model_info`` section
        or key yields the default instead of raising.
        """
        info = model_info.get("model_info") or {}
        details = model_info.get("details") or {}
        for prefix in (info.get("general.architecture"), details.get("family")):
            if prefix:
                key = f"{prefix}.context_length"
                if key in info:
                    return info[key]
        return self._DEFAULT_MAX_TOKENS

    def _build_model_entry(self, name, model_info):
        """Map one ``/api/show`` response to a model-list entry dict."""
        capabilities = model_info.get("capabilities") or []
        return {
            "name": name,
            "model_types": [self._CAPABILITY_TO_MODEL_TYPE[c] for c in capabilities if c in self._CAPABILITY_TO_MODEL_TYPE],
            "features": [self._CAPABILITY_TO_FEATURE[c] for c in capabilities if c in self._CAPABILITY_TO_FEATURE],
            "max_tokens": self._get_context_length(model_info),
        }

    async def get_model_list(self):
        """Return one entry per model reported by ``/api/tags``.

        Returns ``[]`` when ``base_url`` is unset or the tag listing fails;
        models whose ``/api/show`` request fails are skipped.
        """
        if not self.base_url:
            return []

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self._get_api_key()}"

        async with aiohttp.ClientSession() as session:
            async with session.get(self._get_model_tags_url(), headers=headers) as tags_resp:
                if tags_resp.status != 200:
                    return []
                tags = await tags_resp.json()

            models = tags.get("models") if isinstance(tags, dict) else None
            if not models or not isinstance(models, list):
                return []

            res = []
            for model in models:
                name = model.get("name") if isinstance(model, dict) else None
                if not name:
                    continue
                async with session.post(self._get_model_detail_url(), headers=headers, json={"model": name}) as detail_resp:
                    if detail_resp.status != 200:
                        continue
                    model_info = await detail_resp.json()
                if not isinstance(model_info, dict):
                    continue
                res.append(self._build_model_entry(name, model_info))
        return res
|
|
|
|
|
|
class FishAudio(Base):
    """Lists Fish Audio models as TTS entries.

    The credential may be a bare token, or a JSON object carrying one under
    ``fish_audio_ak``, ``access_token`` or ``api_key`` (checked in that order).
    """

    _FACTORY_NAME = "Fish Audio"

    def _get_access_token(self):
        """Resolve the Bearer token from the configured credential ("" when unset)."""
        raw_key = self._get_api_key()
        if not raw_key:
            return ""
        try:
            parsed = json.loads(raw_key)
        except Exception:
            # Not JSON: the credential itself is the token.
            return raw_key
        if not isinstance(parsed, dict):
            return raw_key
        for field in ("fish_audio_ak", "access_token", "api_key"):
            candidate = parsed.get(field)
            if candidate:
                return candidate
        return raw_key

    def _get_model_list_url(self):
        """Return the ``/model`` listing URL, defaulting to ``https://api.fish.audio``.

        A ``/v1`` segment in ``base_url`` (and anything after it) is dropped.
        """
        if not self.base_url:
            return "https://api.fish.audio/model"
        trimmed = self.base_url.rstrip("/")
        if "/v1/" in trimmed:
            root = trimmed.split("/v1")[0].rstrip("/")
        elif trimmed.endswith("/v1"):
            root = trimmed[:-3]
        else:
            root = trimmed
        return f"{root}/model"

    async def get_model_list(self):
        """Fetch models; return ``[]`` on a missing token, non-200 status or unexpected payload.

        NOTE(review): no paging or filter query parameters are sent; if the
        endpoint paginates, only its default page is listed — confirm intent.
        """
        url = self._get_model_list_url()
        token = self._get_access_token()
        if not (url and token):
            return []

        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers={"Authorization": f"Bearer {token}"}) as resp:
                if resp.status != 200:
                    return []
                payload = await resp.json()

        if not isinstance(payload, dict):
            return []
        items = payload.get("items") or []
        if not isinstance(items, list):
            return []

        entries = []
        for item in filter(lambda obj: isinstance(obj, dict), items):
            name = item.get("title") or item.get("_id")
            if not name:
                continue
            entries.append({
                "name": name,
                "model_types": [LLMType.TTS.value],
                "features": [],
                "max_tokens": 8192,
            })
        return entries
|
|
|
|
class MinerU(Base):
    """Lists MinerU models as OCR entries.

    The listing URL comes from the inherited ``Base._get_model_list_url``.
    The credential may be a bare token, or a JSON object carrying one under
    ``access_token`` or ``api_key``.
    """

    _FACTORY_NAME = "MinerU"
    # Fallback when an entry has no usable max_tokens value.
    _DEFAULT_MAX_TOKENS = 8192

    def _get_access_token(self):
        """Resolve the Bearer token from the configured credential ("" when unset)."""
        api_key = self._get_api_key()
        if not api_key:
            return ""
        try:
            payload = json.loads(api_key)
        except Exception:
            # Not JSON: the credential itself is the token.
            return api_key
        if isinstance(payload, dict):
            return payload.get("access_token") or payload.get("api_key") or api_key
        return api_key

    @classmethod
    def _coerce_max_tokens(cls, value):
        """Return ``value`` as a positive int, or the default when missing/null/malformed."""
        if isinstance(value, str) and value.strip().isdigit():
            value = int(value.strip())
        # bool is an int subclass; reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            return cls._DEFAULT_MAX_TOKENS
        return value

    async def get_model_list(self):
        """Fetch models; return ``[]`` on a missing URL/token, non-200 status or unexpected payload.

        A dict response is unwrapped via its ``data``, ``models`` or ``items``
        key (first non-empty wins).
        """
        url = self._get_model_list_url()
        access_token = self._get_access_token()
        if not url or not access_token:
            return []

        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers={"Authorization": f"Bearer {access_token}"}) as resp:
                if resp.status != 200:
                    return []
                raw_model_list = await resp.json()

        if isinstance(raw_model_list, dict):
            raw_model_list = raw_model_list.get("data") or raw_model_list.get("models") or raw_model_list.get("items") or []
        if not isinstance(raw_model_list, list):
            return []

        model_list = []
        for model in raw_model_list:
            if not isinstance(model, dict):
                continue
            model_name = model.get("title") or model.get("name") or model.get("id") or model.get("_id")
            if not model_name:
                continue
            model_list.append({
                "name": model_name,
                "model_types": [LLMType.OCR.value],
                "features": [],
                "max_tokens": self._coerce_max_tokens(model.get("max_tokens")),
            })
        return model_list
|