From aec2ef42326087d17afa13c33ad057ccb4d7dd52 Mon Sep 17 00:00:00 2001 From: Stephen Hu Date: Sat, 28 Feb 2026 10:18:00 +0800 Subject: [PATCH] refactor:improve tts model's codes (#13137) ### What problem does this PR solve? improve tts model's codes ### Type of change - [x] Refactoring --- rag/llm/tts_model.py | 177 +++++++++++++++++++++++++------------------ 1 file changed, 105 insertions(+), 72 deletions(-) diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 035d8412b4..602ea165a1 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -79,6 +79,68 @@ class Base(ABC): return re.sub(r"(\*\*|##\d+\$\$|#)", "", text) +class HTTPBasedTTS(Base): + """ + Base class for HTTP-based TTS services. + Provides common HTTP request handling and response processing. + """ + + def __init__(self, key, model_name, base_url, **kwargs): + self.model_name = model_name + self.base_url = base_url + self.api_key = key + self.headers = { + "Content-Type": "application/json" + } + if key and key != "x": + self.headers["Authorization"] = f"Bearer {self.api_key}" + + def _build_payload(self, text, voice, **kwargs): + """ + Build payload for TTS request. + Subclasses should override this method if they need custom payload structure. + """ + return { + "model": self.model_name, + "voice": voice, + "input": text + } + + def _send_request(self, endpoint, payload, stream=True): + """ + Send HTTP request to TTS service. + """ + url = f"{self.base_url}{endpoint}" + response = requests.post( + url, + headers=self.headers, + json=payload, + stream=stream + ) + + if response.status_code != 200: + raise Exception(f"**Error**: {response.status_code}, {response.text}") + + return response + + def _process_response(self, response): + """ + Process streaming response from TTS service. + """ + for chunk in response.iter_content(): + if chunk: + yield chunk + + def tts(self, text, voice="alloy"): + """ + Generate speech from text. + """ + text = self.normalize_text(text) + payload = self._build_payload(text, voice) + response = self._send_request("/audio/speech", payload) + return self._process_response(response) + + class FishAudioTTS(Base): _FACTORY_NAME = "Fish Audio" @@ -178,28 +240,13 @@ class QwenTTS(Base): raise RuntimeError(f"**ERROR**: {e}") -class OpenAITTS(Base): +class OpenAITTS(HTTPBasedTTS): _FACTORY_NAME = "OpenAI" def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): if not base_url: base_url = "https://api.openai.com/v1" - self.api_key = key - self.model_name = model_name - self.base_url = base_url - self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} - - def tts(self, text, voice="alloy"): - text = self.normalize_text(text) - payload = {"model": self.model_name, "voice": voice, "input": text} - - response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True) - - if response.status_code != 200: - raise Exception(f"**Error**: {response.status_code}, {response.text}") - for chunk in response.iter_content(): - if chunk: - yield chunk + super().__init__(key, model_name, base_url) class SparkTTS(Base): @@ -291,86 +338,74 @@ class SparkTTS(Base): yield audio_chunk -class XinferenceTTS(Base): +class XinferenceTTS(HTTPBasedTTS): _FACTORY_NAME = "Xinference" def __init__(self, key, model_name, **kwargs): - self.base_url = kwargs.get("base_url", None) - self.model_name = model_name + base_url = kwargs.get("base_url", None) + super().__init__(key, model_name, base_url) + # Override headers to remove Authorization self.headers = {"accept": "application/json", "Content-Type": "application/json"} - def tts(self, text, voice="中文女", stream=True): - payload = {"model": self.model_name, "input": text, "voice": voice} - - response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream) - - if response.status_code != 200: - raise Exception(f"**Error**: {response.status_code}, {response.text}") - + def _process_response(self, response): + # Use chunk_size=1024 for processing response for chunk in response.iter_content(chunk_size=1024): if chunk: yield chunk + def tts(self, text, voice="中文女", stream=True): + text = self.normalize_text(text) + payload = self._build_payload(text, voice) + response = self._send_request("/v1/audio/speech", payload, stream=stream) + return self._process_response(response) -class OllamaTTS(Base): + +class OllamaTTS(HTTPBasedTTS): def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"): if not base_url: base_url = "https://api.ollama.ai/v1" - self.model_name = model_name - self.base_url = base_url - self.headers = {"Content-Type": "application/json"} - if key and key != "x": - self.headers["Authorization"] = f"Bearer {key}" + super().__init__(key, model_name, base_url) def tts(self, text, voice="standard-voice"): - payload = {"model": self.model_name, "voice": voice, "input": text} - - response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True) - - if response.status_code != 200: - raise Exception(f"**Error**: {response.status_code}, {response.text}") - - for chunk in response.iter_content(): - if chunk: - yield chunk + text = self.normalize_text(text) + payload = self._build_payload(text, voice) + response = self._send_request("/audio/tts", payload) + return self._process_response(response) -class GPUStackTTS(Base): +class GPUStackTTS(HTTPBasedTTS): _FACTORY_NAME = "GPUStack" def __init__(self, key, model_name, **kwargs): - self.base_url = kwargs.get("base_url", None) - self.api_key = key - self.model_name = model_name - self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} - - def tts(self, text, voice="Chinese Female", stream=True): - payload = {"model": self.model_name, "input": text, "voice": voice} - - response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream) - - if response.status_code != 200: - raise Exception(f"**Error**: {response.status_code}, {response.text}") + base_url = kwargs.get("base_url", None) + super().__init__(key, model_name, base_url) + # Add accept header + self.headers["accept"] = "application/json" + def _process_response(self, response): + # Use chunk_size=1024 for processing response for chunk in response.iter_content(chunk_size=1024): if chunk: yield chunk + def tts(self, text, voice="Chinese Female", stream=True): + text = self.normalize_text(text) + payload = self._build_payload(text, voice) + response = self._send_request("/v1/audio/speech", payload, stream=stream) + return self._process_response(response) -class SILICONFLOWTTS(Base): + +class SILICONFLOWTTS(HTTPBasedTTS): _FACTORY_NAME = "SILICONFLOW" def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"): if not base_url: base_url = "https://api.siliconflow.cn/v1" - self.api_key = key - self.model_name = model_name - self.base_url = base_url - self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + super().__init__(key, model_name, base_url) - def tts(self, text, voice="anna"): - text = self.normalize_text(text) - payload = { + def _build_payload(self, text, voice, **kwargs): + # Custom payload structure for SILICONFLOW + return { "model": self.model_name, "input": text, "voice": f"{self.model_name}:{voice}", @@ -381,13 +416,11 @@ class SILICONFLOWTTS(Base): "gain": 0, } - response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload) - - if response.status_code != 200: - raise Exception(f"**Error**: {response.status_code}, {response.text}") - for chunk in response.iter_content(): - if chunk: - yield chunk + def tts(self, text, voice="anna"): + text = self.normalize_text(text) + payload = self._build_payload(text, voice) + response = self._send_request("/audio/speech", payload) + return self._process_response(response) class DeepInfraTTS(OpenAITTS):