ragflow/rag/llm/tts_model.py
Tim Wang ca96d61e73 Feat: Add New API model provider for OpenAI-compatible gateways (#15991)
## Summary

Add support for **"New API"** as a model provider, enabling connection
to [New API](https://github.com/QuantumNous/new-api) /
[one-api](https://github.com/songquanpeng/one-api) compatible gateways
that aggregate multiple LLM backends behind a unified OpenAI-compatible
`/v1` endpoint.
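Because every model type sits behind the same OpenAI-compatible base URL, each RAGFlow model type resolves to a conventional path under that base. The sketch below shows that mapping. It assumes the configured base URL already ends in `/v1`. Only the `/rerank` path is stated in this PR; the other paths are the usual OpenAI-style routes. The table and `endpoint_for` are hypothetical illustrations, not code from this PR.

```python
# Paths are relative to a base URL that already ends in /v1 (assumption),
# e.g. "http://localhost:3000/v1".
OPENAI_COMPATIBLE_PATHS = {
    "chat": "/chat/completions",
    "image2text": "/chat/completions",  # vision models receive images inside chat messages
    "embedding": "/embeddings",
    "rerank": "/rerank",
    "tts": "/audio/speech",
    "speech2text": "/audio/transcriptions",
}


def endpoint_for(base_url: str, model_type: str) -> str:
    """Return the full gateway URL a given model type would call (sketch)."""
    try:
        path = OPENAI_COMPATIBLE_PATHS[model_type]
    except KeyError:
        raise ValueError(f"unknown model type: {model_type!r}") from None
    # Tolerate a trailing slash on the configured base URL
    return base_url.rstrip("/") + path
```

For example, `endpoint_for("http://localhost:3000/v1", "tts")` yields the same `/v1/audio/speech` URL that the TTS class posts to.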

### Features

- **All model types**: Chat, Embedding, Rerank, Image2Text, TTS,
Speech2Text
- **List Models discovery**: `NewAPI(OpenAIAPICompatible)` class in
`model_meta.py` queries the gateway's `/v1/models` to auto-discover
available models, which are exposed through the native
`GET /api/v1/providers/<name>/models` endpoint
- **Model parameter editing**: Pencil icon on each discovered model row
to edit `model_type`, `max_tokens`, and `features` (e.g. tool call
support) before submitting
- **Custom model addition**: "Add Custom Model" button at the bottom of
the List Models dropdown for models not returned by the API
- **Gear icon settings**: Enabled the Settings gear button on provider
instances to manage models on existing instances (viewMode)
- **viewMode credential passthrough**: Fixed List Models in viewMode —
merges `initialValues` credentials when `api_key`/`base_url` fields are
hidden by `hideWhenInstanceExists`
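Model discovery relies on the OpenAI-style `GET /v1/models` listing, which returns model objects under a `data` array. The sketch below shows that call. It assumes the gateway follows that response shape and accepts a Bearer token. `list_gateway_models` and `parse_model_ids` are hypothetical helpers, not the actual `model_meta.py` implementation.

```python
import json
import urllib.request


def parse_model_ids(payload: dict) -> list[str]:
    """Extract sorted, de-duplicated model ids from an OpenAI-style /models body."""
    # Assumed shape: {"object": "list", "data": [{"id": "...", "object": "model"}, ...]}
    return sorted({item["id"] for item in payload.get("data", []) if "id" in item})


def list_gateway_models(base_url: str, api_key: str, timeout: float = 10.0) -> list[str]:
    """Query <base_url>/models on an OpenAI-compatible gateway (hypothetical helper)."""
    # base_url is assumed to already end in /v1, e.g. "http://localhost:3000/v1"
    url = base_url.rstrip("/") + "/models"
    request = urllib.request.Request(url, headers={"Authorization": f"Bearer {api_key}"})
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return parse_model_ids(json.load(response))
```

The parsing is split out from the network call so that it can be checked offline against a captured response body.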

### Changes

**Backend** (8 files):
- `rag/llm/chat_model.py` — `NewAPIChat(Base)` class
- `rag/llm/embedding_model.py` — `NewAPIEmbed(OpenAIEmbed)` class (no
auto `/v1` append)
- `rag/llm/rerank_model.py` — `NewAPIRerank(Base)` class (uses `/rerank`
endpoint)
- `rag/llm/cv_model.py` — `NewAPICv(GptV4)` class
- `rag/llm/tts_model.py` — `NewAPITTS(OpenAITTS)` class
- `rag/llm/sequence2txt_model.py` — `NewAPISeq2txt(GPTSeq2txt)` class
- `rag/llm/model_meta.py` — `NewAPI(OpenAIAPICompatible)` class for List
Models discovery
- `conf/llm_factories.json` — New API factory entry with all model type
tags
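On the TTS side, `NewAPITTS` inherits from `OpenAITTS`. Each request therefore goes to `<base_url>/audio/speech` with the `model`/`voice`/`input` payload built by `HTTPBasedTTS._build_payload`. Before that, the text passes through `normalize_text`, and any `___` suffix is removed from the model name. The standalone sketch below mirrors that logic so it can be checked without a gateway. `build_newapi_speech_request` is a hypothetical name, and the `tts-1___New API` value in the comment is only an illustrative suffixed name.

```python
import re


def normalize_text(text: str) -> str:
    # Same pattern as Base.normalize_text: drop "**", "##<n>$$" markers and "#"
    return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)


def build_newapi_speech_request(base_url: str, model_name: str, text: str, voice: str = "alloy") -> tuple[str, dict]:
    """Return the (url, json_payload) pair NewAPITTS would POST (sketch)."""
    if not base_url:
        # NewAPITTS has no default endpoint to fall back on
        raise ValueError("base_url is required for New API")
    # Keep only the part before "___", e.g. "tts-1___New API" -> "tts-1"
    model = model_name.split("___")[0]
    # HTTPBasedTTS concatenates base_url and "/audio/speech" as-is
    url = f"{base_url}/audio/speech"
    payload = {"model": model, "voice": voice, "input": normalize_text(text)}
    return url, payload
```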

**Frontend** (8 files + 1 new SVG):
- `web/src/assets/svg/llm/new-api.svg` — New API logo icon
- `web/src/constants/llm.ts` — `LLMFactory.NewAPI` enum + `IconMap`
entry
- `web/src/components/svg-icon.tsx` — `NewAPI` added to `svgIcons`
- `web/src/pages/user-setting/setting-model/modal/provider-modal/field-config/local-llm-configs.ts` — New API `buildLocalConfig`
- `web/src/pages/user-setting/setting-model/modal/provider-modal/constants.ts` — `LIST_MODEL_PROVIDERS` includes NewAPI
- `web/src/pages/user-setting/setting-model/components/used-model.tsx` — Enable Settings gear button
- `web/src/pages/user-setting/setting-model/modal/provider-modal/hooks/use-list-models-picker.ts` — viewMode credential merge + model editing state/handlers
- `web/src/pages/user-setting/setting-model/modal/provider-modal/hooks/use-list-models-options.tsx` — Pencil edit icon per model row
- `web/src/pages/user-setting/setting-model/modal/provider-modal/index.tsx` — `AddCustomModelDialog` import + edit dialog rendering
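The viewMode fix amounts to filling the hidden credential fields from the instance's stored values before List Models is called. The sketch below expresses that merge rule in Python for illustration only. The real change lives in the TypeScript hook `use-list-models-picker.ts`. Treating empty strings as missing, and giving explicit form values precedence, are assumptions. `merge_credentials` is a hypothetical name.

```python
# Fields hidden by hideWhenInstanceExists when an instance already exists
CREDENTIAL_FIELDS = ("api_key", "base_url")


def merge_credentials(form_values: dict, initial_values: dict) -> dict:
    """Fill missing or empty credential fields from the existing instance (sketch)."""
    merged = dict(form_values)
    for field in CREDENTIAL_FIELDS:
        # Hidden fields never reach the submitted form, so fall back to stored values
        if not merged.get(field) and initial_values.get(field):
            merged[field] = initial_values[field]
    return merged
```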

**Note on Go implementation**: A Go model driver (`NewAPIModel`
delegating to `OpenAIModel`) has been prepared but is deferred until the
Go runtime is enabled in a future release (current v0.26.0 images use
`API_PROXY_SCHEME=python` and do not compile Go binaries). Will submit
as a follow-up PR.

## Related

- Depends on: #15996 (provider instance API improvements — server-side
credential lookup, idempotent `add_model`, security fixes — required for
viewMode gear icon and batch model submission)

## Test plan

- [ ] Add New API provider with api_key and base_url pointing to an
OpenAI-compatible gateway
- [ ] Click "List Models" — should discover and display available models
from `/v1/models`
- [ ] Click pencil icon on a model — should open edit dialog to change
model_type, max_tokens, features
- [ ] Select multiple models and click OK — should add all selected
models
- [ ] Click gear icon on the added instance — should open viewMode with
List Models working
- [ ] In viewMode, select new models including pre-existing ones, click
OK — should succeed (requires #15996)
- [ ] Verify all model types work: create a Chat assistant, Embedding
KB, Rerank setting
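When checking the TTS path manually, note that the `tts()` generators in `tts_model.py` yield raw audio `bytes`. Some providers (Fish Audio, Qwen, StepFun) also yield a trailing `int` token count. The classes built on `HTTPBasedTTS`, including `NewAPITTS`, yield only bytes. The hypothetical helper below separates the two kinds of output for a quick inspection.

```python
from typing import Iterable, Optional, Union


def collect_tts_output(chunks: Iterable[Union[bytes, int]]) -> tuple[bytes, Optional[int]]:
    """Join streamed audio bytes and capture an optional trailing token count."""
    audio = bytearray()
    tokens: Optional[int] = None
    for chunk in chunks:
        if isinstance(chunk, int):
            # Fish Audio, Qwen and StepFun finish with num_tokens_from_string(text)
            tokens = chunk
        else:
            audio.extend(chunk)
    return bytes(audio), tokens
```

For a listening check, you could pass any model instance's `tts("hello")` generator to this helper and write the returned audio bytes to a file.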

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Tim Wang <wanghualoong@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-06-26 18:47:20 +08:00

556 lines
18 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import _thread as thread
import base64
import hashlib
import hmac
import json
import logging
import os
import queue
import re
import ssl
import time
from abc import ABC
from datetime import datetime
from time import mktime
from typing import Annotated, Literal
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import httpx
import ormsgpack
import requests
import websocket
from pydantic import BaseModel, conint

from common.http_client import sync_request
from common.token_utils import num_tokens_from_string


class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str


class ServeTTSRequest(BaseModel):
    text: str
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "mp3"
    mp3_bitrate: Literal[64, 128, 192] = 128
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # Balanced mode reduces latency to about 300 ms, but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"


class Base(ABC):
    def __init__(self, key, model_name, base_url, **kwargs):
        """
        Abstract base class constructor.
        Parameters are not stored; subclasses should handle their own initialization.
        """
        pass

    def tts(self, text):
        pass

    def normalize_text(self, text):
        # Strip Markdown bold/heading marks and "##<n>$$" citation markers
        return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)


class HTTPBasedTTS(Base):
    """
    Base class for HTTP-based TTS services.
    Provides common HTTP request handling and response processing.
    """

    def __init__(self, key, model_name, base_url, **kwargs):
        self.model_name = model_name
        self.base_url = base_url
        self.api_key = key
        self.headers = {
            "Content-Type": "application/json"
        }
        # Skip auth when no key (or the "x" placeholder) is configured
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {self.api_key}"

    def _build_payload(self, text, voice, **kwargs):
        """
        Build payload for TTS request.
        Subclasses should override this method if they need custom payload structure.
        """
        return {
            "model": self.model_name,
            "voice": voice,
            "input": text
        }

    def _send_request(self, endpoint, payload, stream=True):
        """
        Send HTTP request to TTS service.
        """
        url = f"{self.base_url}{endpoint}"
        response = requests.post(
            url,
            headers=self.headers,
            json=payload,
            stream=stream,
            timeout=60,
        )
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        return response

    def _process_response(self, response):
        """
        Process streaming response from TTS service.
        """
        # Read in 1 KiB chunks; iter_content() defaults to 1-byte chunks
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk

    def tts(self, text, voice="alloy"):
        """
        Generate speech from text.
        """
        text = self.normalize_text(text)
        payload = self._build_payload(text, voice)
        response = self._send_request("/audio/speech", payload)
        return self._process_response(response)


class FishAudioTTS(Base):
    _FACTORY_NAME = "Fish Audio"

    def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
        if not base_url:
            base_url = "https://api.fish.audio/v1/tts"
        key = json.loads(key)
        self.headers = {
            "api-key": key.get("fish_audio_ak"),
            "content-type": "application/msgpack",
        }
        self.ref_id = key.get("fish_audio_refid")
        self.base_url = base_url

    def tts(self, text):
        from http import HTTPStatus

        text = self.normalize_text(text)
        request = ServeTTSRequest(text=text, reference_id=self.ref_id)

        with httpx.Client() as client:
            try:
                with client.stream(
                    method="POST",
                    url=self.base_url,
                    content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
                    headers=self.headers,
                    timeout=None,
                ) as response:
                    if response.status_code == HTTPStatus.OK:
                        for chunk in response.iter_bytes():
                            yield chunk
                    else:
                        response.raise_for_status()

                yield num_tokens_from_string(text)

            except httpx.HTTPStatusError as e:
                raise RuntimeError(f"**ERROR**: {e}")


class QwenTTS(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name, base_url=""):
        import dashscope

        self.model_name = model_name
        dashscope.api_key = key

    def tts(self, text):
        from collections import deque

        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer

        class Callback(ResultCallback):
            def __init__(self) -> None:
                self.dque = deque()

            def _run(self):
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break

            def on_open(self):
                pass

            def on_complete(self):
                self.dque.append(None)

            def on_error(self, response: SpeechSynthesisResponse):
                raise RuntimeError(str(response))

            def on_close(self):
                pass

            def on_event(self, result: SpeechSynthesisResult):
                if result.get_audio_frame() is not None:
                    self.dque.append(result.get_audio_frame())

        text = self.normalize_text(text)
        callback = Callback()
        SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
        try:
            for data in callback._run():
                yield data
            yield num_tokens_from_string(text)

        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")


class OpenAITTS(HTTPBasedTTS):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        # Accept and forward extra kwargs: subclasses call super().__init__(..., **kwargs)
        super().__init__(key, model_name, base_url, **kwargs)


class SparkTTS(Base):
    _FACTORY_NAME = "XunFei Spark"

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Queue used to store the received audio data
        self.audio_queue = queue.Queue()

    # Build the signed WebSocket URL
    def create_url(self):
        url = "wss://tts-api.xfyun.cn/v2/tts"
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
        url = url + "?" + urlencode(v)
        return url

    def tts(self, text):
        BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
        CommonArgs = {"app_id": self.APPID}
        audio_queue = self.audio_queue
        model_name = self.model_name

        class Callback:
            def __init__(self):
                self.audio_queue = audio_queue

            def on_message(self, ws, message):
                message = json.loads(message)
                code = message["code"]
                sid = message["sid"]
                audio = message["data"]["audio"]
                audio = base64.b64decode(audio)
                status = message["data"]["status"]
                if status == 2:
                    ws.close()
                if code != 0:
                    errMsg = message["message"]
                    raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
                else:
                    self.audio_queue.put(audio)

            def on_error(self, ws, error):
                raise Exception(error)

            def on_close(self, ws, close_status_code, close_msg):
                self.audio_queue.put(None)  # None is terminator

            def on_open(self, ws):
                def run(*args):
                    d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
                    ws.send(json.dumps(d))

                thread.start_new_thread(run, ())

        wsUrl = self.create_url()
        websocket.enableTrace(False)
        a = Callback()
        ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
        status_code = 0
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        while True:
            audio_chunk = self.audio_queue.get()
            if audio_chunk is None:
                if status_code == 0:
                    raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
                else:
                    break
            status_code = 1
            yield audio_chunk


class XinferenceTTS(HTTPBasedTTS):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name, **kwargs):
        base_url = kwargs.get("base_url", None)
        super().__init__(key, model_name, base_url)
        # Override headers to remove Authorization
        self.headers = {"accept": "application/json", "Content-Type": "application/json"}

    def _process_response(self, response):
        # Use chunk_size=1024 for processing response
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk

    def tts(self, text, voice="中文女", stream=True):
        text = self.normalize_text(text)
        payload = self._build_payload(text, voice)
        response = self._send_request("/v1/audio/speech", payload, stream=stream)
        return self._process_response(response)


class OllamaTTS(HTTPBasedTTS):
    def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
        if not base_url:
            base_url = "https://api.ollama.ai/v1"
        super().__init__(key, model_name, base_url)

    def tts(self, text, voice="standard-voice"):
        text = self.normalize_text(text)
        payload = self._build_payload(text, voice)
        response = self._send_request("/audio/tts", payload)
        return self._process_response(response)


class GPUStackTTS(HTTPBasedTTS):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, **kwargs):
        base_url = kwargs.get("base_url", None)
        super().__init__(key, model_name, base_url)
        # Add accept header
        self.headers["accept"] = "application/json"

    def _process_response(self, response):
        # Use chunk_size=1024 for processing response
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk

    def tts(self, text, voice="Chinese Female", stream=True):
        text = self.normalize_text(text)
        payload = self._build_payload(text, voice)
        response = self._send_request("/v1/audio/speech", payload, stream=stream)
        return self._process_response(response)


class SILICONFLOWTTS(HTTPBasedTTS):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, base_url)

    def _build_payload(self, text, voice, **kwargs):
        # Custom payload structure for SILICONFLOW
        return {
            "model": self.model_name,
            "input": text,
            "voice": f"{self.model_name}:{voice}",
            "response_format": "mp3",
            "sample_rate": 123,
            "stream": True,
            "speed": 1,
            "gain": 0,
        }

    def tts(self, text, voice="anna"):
        text = self.normalize_text(text)
        payload = self._build_payload(text, voice)
        response = self._send_request("/audio/speech", payload)
        return self._process_response(response)


class DeepInfraTTS(OpenAITTS):
    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
        if not base_url:
            base_url = "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, base_url, **kwargs)


class CometAPITTS(OpenAITTS):
    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.cometapi.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class DeerAPITTS(OpenAITTS):
    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.deerapi.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class FuturMixTTS(OpenAITTS):
    _FACTORY_NAME = "FuturMix"

    def __init__(self, key, model_name, base_url="https://futurmix.ai/v1", **kwargs):
        if not base_url:
            base_url = "https://futurmix.ai/v1"
        super().__init__(key, model_name, base_url, **kwargs)
        logging.info("[FuturMix] TTS initialized with model %s", model_name)


class StepFunTTS(OpenAITTS):
    _FACTORY_NAME = "StepFun"
    _SUPPORTED_RESPONSE_FORMATS = {"wav", "mp3", "flac", "opus", "pcm"}

    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        self.default_voice = os.environ.get("STEPFUN_TTS_VOICE") or "cixingnansheng"
        super().__init__(key, model_name, base_url, **kwargs)

    def tts(self, text, voice=None, response_format: Literal["wav", "mp3", "flac", "opus", "pcm"] = "mp3"):
        text = self.normalize_text(text)
        if response_format not in self._SUPPORTED_RESPONSE_FORMATS:
            raise ValueError(f"Unsupported response_format={response_format!r}. Supported: {sorted(self._SUPPORTED_RESPONSE_FORMATS)}")
        payload = {
            "model": self.model_name,
            "voice": voice or self.default_voice,
            "input": text,
            "response_format": response_format,
        }
        response = sync_request("POST", f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        for chunk in response.iter_bytes():
            if chunk:
                yield chunk
        yield num_tokens_from_string(text)


class RAGconTTS(Base):
    """
    RAGcon TTS Provider - routes through LiteLLM proxy
    Text-to-speech models routed through LiteLLM.
    Default Base URL: https://connect.ragcon.ai/v1
    """

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        if not base_url:
            base_url = "https://connect.ragcon.com/v1"
        self.base_url = base_url
        self.api_key = key
        self.model_name = model_name
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

    def tts(self, text, voice="English Female", stream=True):
        """
        Uses LiteLLM's /v1/audio/speech endpoint
        """
        payload = {
            "model": self.model_name,
            "input": text,
            "voice": voice
        }
        response = requests.post(
            f"{self.base_url}/audio/speech",
            headers=self.headers,
            json=payload,
            stream=stream,
            timeout=60,
        )
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk
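

class NewAPITTS(OpenAITTS):
    _FACTORY_NAME = "New API"

    def __init__(self, key, model_name, base_url="", **kwargs):
        # A New API gateway has no public default endpoint, so base_url is mandatory
        if not base_url:
            raise ValueError("base_url cannot be empty for New API")
        # Keep only the part before "___" (the stored model name may carry a suffix)
        model_name = model_name.split("___")[0]
        super().__init__(key, model_name, base_url, **kwargs)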