mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Feat: model list (#15774)
### What problem does this PR solve? Support model list for VolcEngine. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -256,7 +256,7 @@ async def create_provider_instance(tenant_id: str, provider_name: str, instance_
|
||||
:param region: region
|
||||
:param model_info: model info, [{
|
||||
"model_type": ["chat"], # support multiple
|
||||
"model_name": "name",
|
||||
"model_name": "name",
|
||||
"max_tokens": 4096,
|
||||
"extra": {
|
||||
"field1": "value1",
|
||||
@@ -352,7 +352,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
:param region: region
|
||||
:param model_info: model info, [{
|
||||
"model_type": ["chat"], # support multiple
|
||||
"model_name": "name",
|
||||
"model_name": "name",
|
||||
"max_tokens": 4096,
|
||||
"extra": {
|
||||
"field1": "value1",
|
||||
|
||||
@@ -1175,6 +1175,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM, TEXT EMBEDDING, IMAGE2TEXT",
|
||||
"status": "1",
|
||||
"url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
import json
|
||||
import aiohttp
|
||||
from abc import ABC
|
||||
from urllib.parse import urlparse
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from common.constants import LLMType
|
||||
|
||||
@@ -58,9 +60,61 @@ class Base(ABC):
|
||||
class VolcEngine(Base):
|
||||
_FACTORY_NAME = "VolcEngine"
|
||||
|
||||
def get_model_list(self):
|
||||
# todo implement access token auth
|
||||
raise NotImplementedError
|
||||
def _get_api_key(self):
|
||||
try:
|
||||
api_key = json.loads(self.api_key).get("ark_api_key", "")
|
||||
except JSONDecodeError:
|
||||
api_key = self.api_key
|
||||
return api_key
|
||||
|
||||
def _get_model_list_url(self):
|
||||
if not self.base_url:
|
||||
self.base_url = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
parsed = urlparse(self.base_url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}/api/v3/models"
|
||||
|
||||
def _format_model_list(self, raw_model_list):
|
||||
serving_model = [model for model in raw_model_list["data"] if model.get("status", "") != "Shutdown"]
|
||||
res = []
|
||||
for model in serving_model:
|
||||
|
||||
model_types = []
|
||||
|
||||
if model.get("domain", "") == "Embedding":
|
||||
model_types.append(LLMType.EMBEDDING.value)
|
||||
else:
|
||||
modalities = model.get("modalities", {})
|
||||
input_modalities = modalities.get("input_modalities", [])
|
||||
output_modalities = modalities.get("output_modalities", [])
|
||||
|
||||
if "text" in output_modalities:
|
||||
model_types.append(LLMType.CHAT.value)
|
||||
if "embeddings" in output_modalities:
|
||||
model_types.append(LLMType.EMBEDDING.value)
|
||||
if "image" in input_modalities and "text" in output_modalities:
|
||||
model_types.append(LLMType.IMAGE2TEXT.value)
|
||||
if "audio" in input_modalities and "text" in output_modalities:
|
||||
model_types.append(LLMType.SPEECH2TEXT.value)
|
||||
if "audio" in output_modalities:
|
||||
model_types.append(LLMType.TTS.value)
|
||||
|
||||
if not model_types:
|
||||
continue
|
||||
|
||||
features = []
|
||||
if model.get("features", {}).get("tools", {}).get("function_calling", False):
|
||||
features.append("is_tools")
|
||||
if model.get("token_limits", {}).get("max_reasoning_token_length", 0) > 0:
|
||||
features.append("thinking")
|
||||
|
||||
res.append({
|
||||
"name": model["id"],
|
||||
"model_types": model_types,
|
||||
"features": features,
|
||||
"max_tokens": model.get("token_limits", {}).get("max_input_token_length", 8192),
|
||||
"status": model.get("status")
|
||||
})
|
||||
return res
|
||||
|
||||
|
||||
class Ollama(Base):
|
||||
|
||||
Reference in New Issue
Block a user