diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 8d282baa77..33f21a717a 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -1436,12 +1436,12 @@ class _ThinkStreamState: def __init__(self) -> None: self.full_text = "" self.last_idx = 0 - self.endswith_think = False - self.last_full = "" self.last_model_full = "" self.in_think = False - self.buffer = "" - self.post_think_text = "" + self.close_pending = False + self.pending_after_close = "" + self.think_buffer = "" + self.answer_buffer = "" def _extract_visible_answer(text: str) -> str: @@ -1457,38 +1457,40 @@ def _extract_visible_answer(text: str) -> str: return f"{thought}{answer}" -def _next_think_delta(state: _ThinkStreamState) -> str: - full_text = state.full_text - if full_text == state.last_full: - return "" - state.last_full = full_text - delta_ans = full_text[state.last_idx :] - - if delta_ans.find("") == 0: - state.last_idx += len("") - return "" - if delta_ans.find("") > 0: - delta_text = full_text[state.last_idx : state.last_idx + delta_ans.find("")] - state.last_idx += delta_ans.find("") - return delta_text - if delta_ans.endswith(""): - state.endswith_think = True - elif state.endswith_think: - state.endswith_think = False - remainder = delta_ans[len("") :] - if remainder: - state.post_think_text = remainder - state.last_idx = len(full_text) - return "" - - state.last_idx = len(full_text) - if full_text.endswith(""): - state.last_idx -= len("") - return re.sub(r"(|)", "", delta_ans) - - async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): state = _ThinkStreamState() + + def _emit_text(section: str, text: str): + if not text: + return None + if section == "think": + state.think_buffer += text + if num_tokens_from_string(state.think_buffer) >= min_tokens: + out = state.think_buffer + state.think_buffer = "" + return out + return None + state.answer_buffer += text + if num_tokens_from_string(state.answer_buffer) >= min_tokens: + out = state.answer_buffer + state.answer_buffer = "" + return out + return None + + def _flush_think_buffer(): + if not state.think_buffer: + return None + out = state.think_buffer + state.think_buffer = "" + return out + + def _flush_answer_buffer(): + if not state.answer_buffer: + return None + out = state.answer_buffer + state.answer_buffer = "" + return out + async for chunk in stream_iter: if not chunk: continue @@ -1501,40 +1503,96 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if not new_part: continue state.full_text += new_part - delta = _next_think_delta(state) - if not delta: - continue - if delta in ("", ""): - if delta == "" and state.in_think: - continue - if delta == "" and not state.in_think: - continue - if state.buffer: - yield ("text", state.buffer, state) - state.buffer = "" - state.in_think = delta == "" - yield ("marker", delta, state) - if delta == "" and state.post_think_text: - state.buffer += state.post_think_text - state.post_think_text = "" - if num_tokens_from_string(state.buffer) >= min_tokens: - yield ("text", state.buffer, state) - state.buffer = "" - continue - state.buffer += delta - if num_tokens_from_string(state.buffer) < min_tokens: - continue - yield ("text", state.buffer, state) - state.buffer = "" + pending = new_part - if state.buffer: - yield ("text", state.buffer, state) - state.buffer = "" - if state.post_think_text: - yield ("text", state.post_think_text, state) - state.post_think_text = "" - if state.endswith_think: + if state.close_pending and "" not in pending: + state.close_pending = False + think_piece = _flush_think_buffer() + if think_piece is not None: + yield ("text", think_piece, state) + state.in_think = False + yield ("marker", "", state) + if state.pending_after_close: + answer_piece = state.pending_after_close + state.pending_after_close = "" + out = _emit_text("answer", answer_piece) + if out is not None: + yield ("text", out, state) + answer_piece = re.sub(r"", "", pending or "") + if answer_piece: + out = _emit_text("answer", answer_piece) + if out is not None: + yield ("text", out, state) + continue + + while pending: + open_idx = pending.find("") + close_idx = pending.find("") + + if open_idx == -1 and close_idx == -1: + piece = re.sub(r"", "", pending or "") + if piece: + section = "think" if state.in_think else "answer" + out = _emit_text(section, piece) + if out is not None: + yield ("text", out, state) + break + + if open_idx != -1 and (close_idx == -1 or open_idx < close_idx): + before = pending[:open_idx] + if before: + piece = re.sub(r"", "", before or "") + section = "think" if state.in_think else "answer" + out = _emit_text(section, piece) + if out is not None: + yield ("text", out, state) + pending = pending[open_idx + len("") :] + if not state.in_think: + answer_piece = _flush_answer_buffer() + if answer_piece is not None: + yield ("text", answer_piece, state) + think_piece = _flush_think_buffer() + if think_piece is not None: + yield ("text", think_piece, state) + state.in_think = True + yield ("marker", "", state) + continue + + before = pending[:close_idx] + after = pending[close_idx + len("") :] + if before: + piece = re.sub(r"", "", before or "") + section = "think" if state.in_think else "answer" + out = _emit_text(section, piece) + if out is not None: + yield ("text", out, state) + after_visible = re.sub(r"", "", after or "") + if after_visible.strip(): + think_piece = _flush_think_buffer() + if think_piece is not None: + yield ("text", think_piece, state) + state.in_think = False + yield ("marker", "", state) + pending = after_visible + continue + state.close_pending = True + if after_visible: + state.pending_after_close += after_visible + pending = "" + break + + if state.think_buffer: + yield ("text", state.think_buffer, state) + state.think_buffer = "" + if state.close_pending: + state.in_think = False yield ("marker", "", state) + if state.answer_buffer: + yield ("text", state.answer_buffer, state) + state.answer_buffer = "" + if state.pending_after_close: + yield ("text", state.pending_after_close, state) + state.pending_after_close = "" async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}, search_id=None): diff --git a/test/testcases/unit/test_think_stream_parser.py b/test/testcases/unit/test_think_stream_parser.py new file mode 100644 index 0000000000..e98b57006e --- /dev/null +++ b/test/testcases/unit/test_think_stream_parser.py @@ -0,0 +1,353 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import io +import sys +import unittest +from contextlib import redirect_stdout +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from api.db.services.dialog_service import _stream_with_think_delta + + +CASES = [ + ( + "minimax", + { + "min_tokens": 16, + "chunks": [ + 'The user has sent a simple greeting "hello". I should respond in a friendly and helpful manner.Hello!', + "\n\n How can I help", + " you today?", + ], + "expected": { + "think": 'The user has sent a simple greeting "hello". I should respond in a friendly and helpful manner.Hello!', + "answer": "\n\n How can I help you today?", + }, + }, + ), + ( + "deepseek", + { + "min_tokens": 16, + "chunks": [ + "We", + " need", + " to", + " respond", + " to", + " the", + " user", + "'s", + " greeting", + ' "', + "hello", + '".', + " The", + " assistant", + " should", + " be", + " friendly", + " and", + " helpful", + ".", + " A", + " simple", + " greeting", + " back", + " is", + " appropriate", + ",", + " perhaps", + " with", + " an", + " offer", + " of", + " assistance", + ".", + "Hello", + "!", + " How", + " can", + " I", + " assist", + " you", + " today", + "?", + ], + "expected": { + "think": 'We need to respond to the user\'s greeting "hello". The assistant should be friendly and helpful. A simple greeting back is appropriate, perhaps with an offer of assistance.', + "answer": "Hello! How can I assist you today?", + }, + }, + ), + ( + "deepseek_repeat", + { + "min_tokens": 16, + "chunks": [ + "We", + " need", + " to", + " respond", + " to", + " the", + " user", + "'s", + ' "', + "hello", + '"', + " again", + ".", + " The", + " user", + " just", + " said", + ' "', + "hello", + '"', + " after", + " I", + " already", + " responded", + ".", + " Possibly", + " they", + "'re", + " testing", + " or", + " just", + " greeting", + " again", + ".", + " I", + "'ll", + " respond", + " in", + " a", + " friendly", + " manner", + ",", + " perhaps", + " acknowledging", + " the", + " repeated", + " greeting", + " and", + " inviting", + " them", + " to", + " ask", + " something", + ".", + "Hello", + " again", + "!", + " How", + " can", + " I", + " help", + " you", + " today", + "?", + ], + "expected": { + "think": 'We need to respond to the user\'s "hello" again. The user just said "hello" after I already responded. Possibly they\'re testing or just greeting again. I\'ll respond in a friendly manner, perhaps acknowledging the repeated greeting and inviting them to ask something.', + "answer": "Hello again! How can I help you today?", + }, + }, + ), + ( + "answer_then_think", + { + "min_tokens": 16, + "chunks": [ + "前言", + " ", + "内部推理一", + "最终回答", + "。", + ], + "expected": { + "think": "内部推理一", + "answer": "前言 最终回答。", + "markers": ["", ""], + }, + }, + ), + ( + "close_pending_eof", + { + "min_tokens": 16, + "chunks": [ + "先思考完毕答案在这里", + ], + "expected": { + "think": "先思考完毕", + "answer": "答案在这里", + "markers": ["", ""], + }, + }, + ), + ( + "mixed_boundary", + { + "min_tokens": 16, + "chunks": [ + "前缀", + "理由A答案A", + " 后缀", + ], + "expected": { + "think": "理由A", + "answer": "前缀答案A 后缀", + "markers": ["", ""], + }, + }, + ), + ( + "think_only_eof", + { + "min_tokens": 16, + "chunks": [ + "只输出思考,不输出最终答案", + ",并且流在这里结束", + ], + "expected": { + "think": "只输出思考,不输出最终答案,并且流在这里结束", + "answer": "", + "markers": [""], + }, + }, + ), + ( + "double_think_blocks", + { + "min_tokens": 16, + "chunks": [ + "第一段推理答案A", + " 第二段推理答案B", + ], + "expected": { + "think": "第一段推理第二段推理", + "answer": "答案A 答案B", + "markers": ["", "", "", ""], + }, + }, + ), + ( + "nested_or_malformed_tags", + { + "min_tokens": 16, + "chunks": [ + "重复开始", + "答案", + "", + "尾巴", + ], + "expected": { + "think": "重复开始", + "answer": "答案尾巴", + "markers": ["", "", ""], + }, + }, + ), + ( + "tiny_think_chunks", + { + "min_tokens": 16, + "chunks": [ + "", + "A", + "B", + "C", + "D", + "E", + "", + "答", + "案", + "输", + "出", + ], + "expected": { + "think": "ABCDE", + "answer": "答案输出", + "markers": ["", ""], + }, + }, + ), + ( + "think_then_answer_then_think", + { + "min_tokens": 16, + "chunks": [ + "第一轮推理第一轮答案", + " 第二轮推理第二轮答案", + ], + "expected": { + "think": "第一轮推理第二轮推理", + "answer": "第一轮答案 第二轮答案", + "markers": ["", "", "", ""], + }, + }, + ), +] + + +async def _iter_chunks(chunks): + for chunk in chunks: + yield chunk + + +async def _collect_case(chunks, min_tokens): + think_parts = [] + answer_parts = [] + markers = [] + section = "answer" + + async for kind, value, _state in _stream_with_think_delta(_iter_chunks(chunks), min_tokens=min_tokens): + if kind == "marker": + markers.append(value) + section = "think" if value == "" else "answer" + continue + if section == "think": + think_parts.append(value) + else: + answer_parts.append(value) + + return "".join(think_parts), "".join(answer_parts), markers + + +class TestThinkStreamParser(unittest.TestCase): + def test_think_stream_parser_cases(self): + for case_name, case in CASES: + with self.subTest(case=case_name): + buf = io.StringIO() + with redirect_stdout(buf): + think_text, answer_text, markers = asyncio.run( + _collect_case(case["chunks"], case["min_tokens"]) + ) + + expected = case["expected"] + self.assertEqual(think_text, expected["think"], case_name) + self.assertEqual(answer_text, expected["answer"], case_name) + if "markers" in expected: + self.assertEqual(markers, expected["markers"], case_name)