mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 16:25:44 +08:00
refactor:improve tts model's codes (#13137)
### What problem does this PR solve? improve tts model's codes ### Type of change - [x] Refactoring
This commit is contained in:
@@ -79,6 +79,68 @@ class Base(ABC):
|
||||
return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
|
||||
|
||||
|
||||
class HTTPBasedTTS(Base):
|
||||
"""
|
||||
Base class for HTTP-based TTS services.
|
||||
Provides common HTTP request handling and response processing.
|
||||
"""
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.api_key = key
|
||||
self.headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
if key and key != "x":
|
||||
self.headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
def _build_payload(self, text, voice, **kwargs):
|
||||
"""
|
||||
Build payload for TTS request.
|
||||
Subclasses should override this method if they need custom payload structure.
|
||||
"""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"voice": voice,
|
||||
"input": text
|
||||
}
|
||||
|
||||
def _send_request(self, endpoint, payload, stream=True):
|
||||
"""
|
||||
Send HTTP request to TTS service.
|
||||
"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
|
||||
return response
|
||||
|
||||
def _process_response(self, response):
|
||||
"""
|
||||
Process streaming response from TTS service.
|
||||
"""
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
def tts(self, text, voice="alloy"):
|
||||
"""
|
||||
Generate speech from text.
|
||||
"""
|
||||
text = self.normalize_text(text)
|
||||
payload = self._build_payload(text, voice)
|
||||
response = self._send_request("/audio/speech", payload)
|
||||
return self._process_response(response)
|
||||
|
||||
|
||||
class FishAudioTTS(Base):
|
||||
_FACTORY_NAME = "Fish Audio"
|
||||
|
||||
@@ -178,28 +240,13 @@ class QwenTTS(Base):
|
||||
raise RuntimeError(f"**ERROR**: {e}")
|
||||
|
||||
|
||||
class OpenAITTS(Base):
|
||||
class OpenAITTS(HTTPBasedTTS):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
def tts(self, text, voice="alloy"):
|
||||
text = self.normalize_text(text)
|
||||
payload = {"model": self.model_name, "voice": voice, "input": text}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class SparkTTS(Base):
|
||||
@@ -291,86 +338,74 @@ class SparkTTS(Base):
|
||||
yield audio_chunk
|
||||
|
||||
|
||||
class XinferenceTTS(Base):
|
||||
class XinferenceTTS(HTTPBasedTTS):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
self.model_name = model_name
|
||||
base_url = kwargs.get("base_url", None)
|
||||
super().__init__(key, model_name, base_url)
|
||||
# Override headers to remove Authorization
|
||||
self.headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
def tts(self, text, voice="中文女", stream=True):
|
||||
payload = {"model": self.model_name, "input": text, "voice": voice}
|
||||
|
||||
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
|
||||
def _process_response(self, response):
|
||||
# Use chunk_size=1024 for processing response
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
def tts(self, text, voice="中文女", stream=True):
|
||||
text = self.normalize_text(text)
|
||||
payload = self._build_payload(text, voice)
|
||||
response = self._send_request("/v1/audio/speech", payload, stream=stream)
|
||||
return self._process_response(response)
|
||||
|
||||
class OllamaTTS(Base):
|
||||
|
||||
class OllamaTTS(HTTPBasedTTS):
|
||||
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.ollama.ai/v1"
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
if key and key != "x":
|
||||
self.headers["Authorization"] = f"Bearer {key}"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
def tts(self, text, voice="standard-voice"):
|
||||
payload = {"model": self.model_name, "voice": voice, "input": text}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
text = self.normalize_text(text)
|
||||
payload = self._build_payload(text, voice)
|
||||
response = self._send_request("/audio/tts", payload)
|
||||
return self._process_response(response)
|
||||
|
||||
|
||||
class GPUStackTTS(Base):
|
||||
class GPUStackTTS(HTTPBasedTTS):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def tts(self, text, voice="Chinese Female", stream=True):
|
||||
payload = {"model": self.model_name, "input": text, "voice": voice}
|
||||
|
||||
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
base_url = kwargs.get("base_url", None)
|
||||
super().__init__(key, model_name, base_url)
|
||||
# Add accept header
|
||||
self.headers["accept"] = "application/json"
|
||||
|
||||
def _process_response(self, response):
|
||||
# Use chunk_size=1024 for processing response
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
def tts(self, text, voice="Chinese Female", stream=True):
|
||||
text = self.normalize_text(text)
|
||||
payload = self._build_payload(text, voice)
|
||||
response = self._send_request("/v1/audio/speech", payload, stream=stream)
|
||||
return self._process_response(response)
|
||||
|
||||
class SILICONFLOWTTS(Base):
|
||||
|
||||
class SILICONFLOWTTS(HTTPBasedTTS):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
def tts(self, text, voice="anna"):
|
||||
text = self.normalize_text(text)
|
||||
payload = {
|
||||
def _build_payload(self, text, voice, **kwargs):
|
||||
# Custom payload structure for SILICONFLOW
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"input": text,
|
||||
"voice": f"{self.model_name}:{voice}",
|
||||
@@ -381,13 +416,11 @@ class SILICONFLOWTTS(Base):
|
||||
"gain": 0,
|
||||
}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
def tts(self, text, voice="anna"):
|
||||
text = self.normalize_text(text)
|
||||
payload = self._build_payload(text, voice)
|
||||
response = self._send_request("/audio/speech", payload)
|
||||
return self._process_response(response)
|
||||
|
||||
|
||||
class DeepInfraTTS(OpenAITTS):
|
||||
|
||||
Reference in New Issue
Block a user