From f169ab4b3980c78a26e8da8b8c6ffda258d35b04 Mon Sep 17 00:00:00 2001 From: plind <59729252+plind-junior@users.noreply.github.com> Date: Mon, 18 May 2026 23:20:40 -0700 Subject: [PATCH] feat(tts): cache synthesized speech in Redis to avoid redundant calls (#14851) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What problem does this PR solve? Closes #12017. TTS output is deterministic for a given `(model, text)` pair, so re-running the same text through the same TTS model produces the same bytes — yet `Canvas.tts` and `dialog_service.tts` re-synthesized on every request. That's slow and wastes provider quota whenever the same assistant response is replayed, shared across users, or repeated within a session. ### Change New helper `rag/utils/tts_cache.py` with `synthesize_with_cache(tts_mdl, cleaned_text)`: - **Key:** `tts:cache:{model_id}:{sha256(text)}` — separate namespace per model, identical cleaned text reuses a single entry across both call sites. - **Value:** the hex-encoded audio blob both call sites already returned. No format change for downstream consumers. - **TTL:** 7 days by default, configurable via `RAGFLOW_TTS_CACHE_TTL_SECONDS`. - **Failure modes:** a Redis hiccup falls back to direct synthesis; a failed synthesis still returns `None` (existing contract preserved). [`Canvas.tts`](https://github.com/infiniflow/ragflow/blob/main/agent/canvas.py#L683-L724) and [`dialog_service.tts`](https://github.com/infiniflow/ragflow/blob/main/api/db/services/dialog_service.py#L1367-L1380) now route through the helper; the per-file bytes-accumulation/hex-encode loop has been removed in favor of one shared implementation. ## Type of change - [x] New Feature (non-breaking change which adds functionality) ## Test plan - [ ] **Cache hit, chat path:** Configure a dialog with TTS enabled, ask the same question twice with `stream=false`. 
Verify the second response returns the same `audio_binary` and that the second invocation doesn't hit the TTS provider (e.g., observe provider-side logs / usage counters; check no `LLMBundle.tts can't update token usage` log line on the second run). - [ ] **Cache hit, agent path:** Same exercise via a Conversational Agent that includes a Message component playing back the answer. - [ ] **Cache isolation per model:** Switch tenant's `tts_id` between two models, run the same text against each — confirm the second model's first synthesis still happens (no cross-model hits). - [ ] **TTL override:** Set `RAGFLOW_TTS_CACHE_TTL_SECONDS=120`, confirm the entry expires after 2 minutes. - [ ] **Redis unavailable:** Stop Redis (or break the connection). Verify the TTS endpoint still works — synthesis falls back to direct calls, with a `TTS cache lookup failed` / `TTS cache store failed` warning logged. - [ ] **Failure path:** Configure a TTS model with an invalid API key, ensure the response still returns successfully with `audio_binary=None` (no regression vs. current behavior). 
--- agent/canvas.py | 11 +-- api/db/services/dialog_service.py | 11 +-- rag/utils/tts_cache.py | 120 ++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 18 deletions(-) create mode 100644 rag/utils/tts_cache.py diff --git a/agent/canvas.py b/agent/canvas.py index bbd06facbb..3421d207ed 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -17,7 +17,6 @@ import asyncio import base64 import datetime import inspect -import binascii import json import logging import re @@ -39,6 +38,7 @@ from common.misc_utils import get_uuid, hash_str2int from common.exceptions import TaskCanceledException from rag.prompts.generator import chunks_format from rag.utils.redis_conn import REDIS_CONN +from rag.utils.tts_cache import synthesize_with_cache class Graph: """ @@ -714,14 +714,7 @@ class Canvas(Graph): text = clean_tts_text(text) if not text: return None - bin = b"" - try: - for chunk in tts_mdl.tts(text): - bin += chunk - except Exception as e: - logging.error(f"TTS failed: {e}, text={text!r}") - return None - return binascii.hexlify(bin).decode("utf-8") + return synthesize_with_cache(tts_mdl, text) def get_history(self, window_size): convs = [] diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index f7a5befc3f..aa6d255097 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -14,7 +14,6 @@ # limitations under the License. 
# import asyncio -import binascii import logging import re import time @@ -51,6 +50,7 @@ from rag.nlp.search import index_name from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily +from rag.utils.tts_cache import synthesize_with_cache from common.string_utils import remove_redundant_spaces from common import settings @@ -1427,14 +1427,7 @@ def tts(tts_mdl, text): text = clean_tts_text(text) if not text: return None - bin = b"" - try: - for chunk in tts_mdl.tts(text): - bin += chunk - except Exception as e: - logging.error(f"TTS failed: {e}, text={text!r}") - return None - return binascii.hexlify(bin).decode("utf-8") + return synthesize_with_cache(tts_mdl, text) class _ThinkStreamState: diff --git a/rag/utils/tts_cache.py b/rag/utils/tts_cache.py new file mode 100644 index 0000000000..a96f192528 --- /dev/null +++ b/rag/utils/tts_cache.py @@ -0,0 +1,120 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#
import binascii
import hashlib
import logging
import os
from typing import Any, Optional

from rag.utils.redis_conn import REDIS_CONN

# Cached audio lives for one week unless overridden through the environment.
_DEFAULT_TTL_SECONDS = 7 * 24 * 60 * 60
# Namespace prefix shared by every TTS cache key.
_KEY_PREFIX = "tts:cache:"


def _ttl_seconds() -> int:
    """Return the cache TTL in seconds, read from the environment on each call.

    ``RAGFLOW_TTS_CACHE_TTL_SECONDS`` overrides the default. An unset or empty
    variable, or one that is not an integer, yields the default (a warning is
    logged in the non-integer case). Zero or a negative value yields ``0``,
    which ``synthesize_with_cache`` treats as "do not store".
    """
    raw = os.environ.get("RAGFLOW_TTS_CACHE_TTL_SECONDS")
    if not raw:
        return _DEFAULT_TTL_SECONDS
    try:
        ttl = int(raw)
    except ValueError:
        logging.warning("Invalid RAGFLOW_TTS_CACHE_TTL_SECONDS=%r, using default", raw)
        return _DEFAULT_TTL_SECONDS
    return max(ttl, 0)


def _model_id(tts_mdl: Any) -> Optional[str]:
    """Return a string identifying ``tts_mdl``, or ``None`` if none is found.

    Reads the ``model_config`` attribute when it is a dict: the ``id`` entry
    wins if it is not ``None``; otherwise the first truthy value among
    ``llm_name`` and ``model_name`` is used.
    """
    cfg = getattr(tts_mdl, "model_config", None)
    if not isinstance(cfg, dict):
        return None
    mid = cfg.get("id")
    if mid is not None:
        return str(mid)
    name = cfg.get("llm_name") or cfg.get("model_name")
    return str(name) if name else None


def _build_key(tts_mdl: Any, text: str) -> Optional[str]:
    """Return the cache key for ``(tts_mdl, text)``.

    The key is ``tts:cache:{model id}:{sha256 hex digest of text}``. ``None``
    is returned when the model exposes no usable identifier, in which case the
    caller bypasses the cache.
    """
    mid = _model_id(tts_mdl)
    if not mid:
        return None
    # "surrogatepass" keeps lone surrogates in the hashed bytes. With "ignore"
    # they were dropped, so two different texts could collapse onto one key
    # and one of them would be served the other's audio. Encoding of ordinary
    # text is unchanged, so keys for such text stay the same.
    digest = hashlib.sha256(text.encode("utf-8", "surrogatepass")).hexdigest()
    return f"{_KEY_PREFIX}{mid}:{digest}"


def _to_hex_string(value: Any) -> Optional[str]:
    """Normalize a cached value to ``str``.

    ``bytes`` are decoded as UTF-8 and ``str`` is returned unchanged. ``None``,
    undecodable bytes, and any other type yield ``None``.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except UnicodeDecodeError:
            return None
    return None


def synthesize_with_cache(tts_mdl: Any, cleaned_text: str) -> Optional[str]:
    """
    Synthesize ``cleaned_text`` through ``tts_mdl`` and return a hex-encoded
    audio blob, reusing a Redis-cached result when available.

    The cache key is derived from the TTS model identifier and a SHA-256 of the
    text, so different models keep separate caches and the same text on the
    same model resolves to the same key regardless of call site.

    Args:
        tts_mdl: Object exposing ``tts(text)``, an iterable of audio chunks.
        cleaned_text: Text to synthesize.

    Returns:
        The hex-encoded audio, or ``None`` when either argument is falsy, when
        synthesis raises, or when synthesis yields no audio bytes. Cache
        lookup/store errors are logged and never propagate.
    """
    if not tts_mdl or not cleaned_text:
        return None

    key = _build_key(tts_mdl, cleaned_text)

    if key:
        try:
            cached = REDIS_CONN.get(key)
        except Exception as e:
            # Best effort: a cache outage must not break synthesis.
            logging.warning("TTS cache lookup failed: %s", e)
            cached = None
        hex_cached = _to_hex_string(cached)
        if hex_cached:
            return hex_cached

    # Collect chunks and join once; repeated ``bytes +=`` re-copies the whole
    # buffer on every chunk.
    chunks = []
    try:
        for chunk in tts_mdl.tts(cleaned_text):
            # Only bytes/bytearray chunks are kept; anything else is skipped.
            if isinstance(chunk, (bytes, bytearray)):
                # ``bytes(...)`` snapshots a bytearray in case the producer
                # reuses its buffer; for ``bytes`` it is a no-op.
                chunks.append(bytes(chunk))
    except Exception as e:
        logging.error("TTS failed: %s (text length=%d)", e, len(cleaned_text))
        return None

    audio = b"".join(chunks)
    if not audio:
        return None

    hex_value = binascii.hexlify(audio).decode("utf-8")

    ttl = _ttl_seconds()
    if key and ttl > 0:
        try:
            REDIS_CONN.set(key, hex_value, exp=ttl)
        except Exception as e:
            # Best effort: the freshly synthesized audio is still returned.
            logging.warning("TTS cache store failed: %s", e)

    return hex_value