mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix Agent chat Minimax content in thinking (#15937)
Fix Agent chat Minimax content in thinking
This commit is contained in:
@@ -23,6 +23,7 @@ from typing import Any, AsyncGenerator
|
||||
import json_repair
|
||||
from functools import partial
|
||||
from common.constants import LLMType
|
||||
from api.db.services.dialog_service import _stream_with_think_delta
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_model_type_by_name
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
@@ -284,84 +285,23 @@ class LLM(ComponentBase):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
|
||||
async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
|
||||
async def delta_wrapper(txt_iter):
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal ans, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
ans = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(ans)
|
||||
if ans.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
async for t in txt_iter:
|
||||
yield delta(t)
|
||||
|
||||
if not self.imgs:
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
|
||||
yield t
|
||||
return
|
||||
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
|
||||
yield t
|
||||
stream_kwargs = {"images": self.imgs} if self.imgs else {}
|
||||
stream_kwargs.update(kwargs)
|
||||
stream = self.chat_mdl.async_chat_streamly_delta(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs)
|
||||
async for _, value, _ in _stream_with_think_delta(stream, min_tokens=0):
|
||||
yield value
|
||||
|
||||
async def _stream_output_async(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal answer, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
answer = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(answer)
|
||||
if answer.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
stream_kwargs = {"images": self.imgs} if self.imgs else {}
|
||||
extra_chat_kwargs = self._get_chat_template_kwargs()
|
||||
stream_kwargs.update(extra_chat_kwargs)
|
||||
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
|
||||
stream = self.chat_mdl.async_chat_streamly_delta(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs)
|
||||
async for _, ans, _ in _stream_with_think_delta(stream, min_tokens=0):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
|
||||
if isinstance(ans, int):
|
||||
continue
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
@@ -370,7 +310,8 @@ class LLM(ComponentBase):
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
yield delta(ans)
|
||||
answer += ans
|
||||
yield ans
|
||||
|
||||
self.set_output("content", answer)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user