mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix: code supports matplotlib (#13724)
### What problem does this PR solve?

Code as "final" node: (screenshot not preserved)

Code as "mid" node: (screenshot not preserved)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -20,20 +20,20 @@ import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from timeit import default_timer as timer
|
||||
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
|
||||
from agent.component.llm import LLM, LLMParam
|
||||
from agent.tools.base import LLMToolPluginCallSession, ToolBase, ToolMeta, ToolParamBase
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from common.connection_utils import timeout
|
||||
from rag.prompts.generator import next_step_async, COMPLETE_TASK, \
|
||||
citation_prompt, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt
|
||||
|
||||
|
||||
class AgentParam(LLMParam, ToolParamBase):
|
||||
@@ -42,35 +42,25 @@ class AgentParam(LLMParam, ToolParamBase):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "agent",
|
||||
"description": "This is an agent for a specific task.",
|
||||
"parameters": {
|
||||
"user_prompt": {
|
||||
"type": "string",
|
||||
"description": "This is the order you need to send to the agent.",
|
||||
"default": "",
|
||||
"required": True
|
||||
},
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Supervisor's reasoning for choosing the this agent. "
|
||||
"Explain why this agent is being invoked and what is expected of it."
|
||||
),
|
||||
"required": True
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"All relevant background information, prior facts, decisions, "
|
||||
"and state needed by the agent to solve the current query. "
|
||||
"Should be as detailed and self-contained as possible."
|
||||
),
|
||||
"required": True
|
||||
},
|
||||
}
|
||||
}
|
||||
self.meta: ToolMeta = {
|
||||
"name": "agent",
|
||||
"description": "This is an agent for a specific task.",
|
||||
"parameters": {
|
||||
"user_prompt": {"type": "string", "description": "This is the order you need to send to the agent.", "default": "", "required": True},
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": ("Supervisor's reasoning for choosing the this agent. Explain why this agent is being invoked and what is expected of it."),
|
||||
"required": True,
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"All relevant background information, prior facts, decisions, and state needed by the agent to solve the current query. Should be as detailed and self-contained as possible."
|
||||
),
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
super().__init__()
|
||||
self.function_name = "agent"
|
||||
self.tools = []
|
||||
@@ -92,12 +82,14 @@ class Agent(LLM, ToolBase):
|
||||
indexed_name = f"{original_name}_{idx}"
|
||||
self.tools[indexed_name] = cpn
|
||||
chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id)
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config,
|
||||
max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error,
|
||||
max_rounds=self._param.max_rounds,
|
||||
verbose_tool_use=True
|
||||
)
|
||||
self.chat_mdl = LLMBundle(
|
||||
self._canvas.get_tenant_id(),
|
||||
chat_model_config,
|
||||
max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error,
|
||||
max_rounds=self._param.max_rounds,
|
||||
verbose_tool_use=False,
|
||||
)
|
||||
self.tool_meta = []
|
||||
for indexed_name, tool_obj in self.tools.items():
|
||||
original_meta = tool_obj.get_meta()
|
||||
@@ -114,10 +106,30 @@ class Agent(LLM, ToolBase):
|
||||
self.tools[tnm] = tool_call_session
|
||||
self.callback = partial(self._canvas.tool_use_callback, id)
|
||||
self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
|
||||
#self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)
|
||||
if self.tool_meta:
|
||||
self.chat_mdl.bind_tools(self.toolcall_session, self.tool_meta)
|
||||
|
||||
def _fit_messages(self, prompt: str, msg: list[dict]) -> list[dict]:
    """Put ``prompt`` in front of ``msg`` as a system message and fit the result.

    The combined conversation is handed to ``message_fit_in`` together with a
    budget of 97% of ``self.chat_mdl.max_length``; the second element of the
    pair it returns is passed back to the caller.
    """
    conversation = [{"role": "system", "content": prompt}, *msg]
    budget = int(self.chat_mdl.max_length * 0.97)
    _, fitted = message_fit_in(conversation, budget)
    return fitted
|
||||
|
||||
@staticmethod
def _append_system_prompt(msg: list[dict], extra_prompt: str) -> None:
    """Append ``extra_prompt`` to the leading system message, in place.

    Nothing is changed when ``extra_prompt`` is empty, when ``msg`` is empty,
    or when the first message's role is not ``"system"``.
    """
    if not extra_prompt or not msg:
        return
    head = msg[0]
    if head["role"] != "system":
        return
    # Separate the addition from the existing prompt with a single newline.
    head["content"] += "\n" + extra_prompt
|
||||
|
||||
@staticmethod
def _clean_formatted_answer(ans: str) -> str:
    """Strip reasoning and code-fence wrappers from a model reply.

    Three substitutions run in order, each with ``re.DOTALL``:

    1. everything up to and including the last ``</think>`` is removed;
    2. everything up to and including the last opening ``json`` fence is removed;
    3. a closing fence followed only by newlines at the end is removed.
    """
    wrappers = (
        r"^.*</think>",
        r"^.*```json",
        r"```\n*$",
    )
    for pattern in wrappers:
        ans = re.sub(pattern, "", ans, flags=re.DOTALL)
    return ans
|
||||
|
||||
def _load_tool_obj(self, cpn: dict) -> object:
|
||||
from agent.component import component_class
|
||||
|
||||
tool_name = cpn["component_name"]
|
||||
param = component_class(tool_name + "Param")()
|
||||
param.update(cpn["params"])
|
||||
@@ -130,7 +142,7 @@ class Agent(LLM, ToolBase):
|
||||
return component_class(cpn["component_name"])(self._canvas, cpn_id, param)
|
||||
|
||||
def get_meta(self) -> dict[str, Any]:
|
||||
self._param.function_name= self._id.split("-->")[-1]
|
||||
self._param.function_name = self._id.split("-->")[-1]
|
||||
m = super().get_meta()
|
||||
if hasattr(self._param, "user_prompt") and self._param.user_prompt:
|
||||
m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
|
||||
@@ -139,10 +151,7 @@ class Agent(LLM, ToolBase):
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
res = {}
|
||||
for k, v in self.get_input_elements().items():
|
||||
res[k] = {
|
||||
"type": "line",
|
||||
"name": v["name"]
|
||||
}
|
||||
res[k] = {"type": "line", "name": v["name"]}
|
||||
for cpn in self._param.tools:
|
||||
if not isinstance(cpn, LLM):
|
||||
continue
|
||||
@@ -175,7 +184,7 @@ class Agent(LLM, ToolBase):
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20 * 60)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
@@ -204,19 +213,17 @@ class Agent(LLM, ToolBase):
|
||||
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
|
||||
schema_prompt = structured_output_prompt(schema)
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
component = self._canvas.get_component(self._id)
|
||||
downstreams = component["downstream"] if component else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
||||
has_message_downstream = any(self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams)
|
||||
if has_message_downstream and not (ex and ex["goto"]) and not output_schema:
|
||||
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
use_tools = []
|
||||
ans = ""
|
||||
async for delta_ans, _tk in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
ans += delta_ans
|
||||
msg = self._fit_messages(prompt, msg)
|
||||
self._append_system_prompt(msg, schema_prompt)
|
||||
ans = await self._generate_async(msg)
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"Agent._chat got error. response: {ans}")
|
||||
@@ -230,14 +237,8 @@ class Agent(LLM, ToolBase):
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
try:
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
obj = json_repair.loads(clean_formated_answer(ans))
|
||||
obj = json_repair.loads(self._clean_formatted_answer(ans))
|
||||
self.set_output("structured", obj)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return obj
|
||||
except Exception:
|
||||
error = "The answer cannot be parsed as JSON"
|
||||
@@ -248,333 +249,118 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("_ERROR", error)
|
||||
return
|
||||
|
||||
attachment_content = self._collect_tool_attachment_content(existing_text=ans)
|
||||
if attachment_content:
|
||||
ans += "\n\n" + attachment_content
|
||||
artifact_md = self._collect_tool_artifact_markdown(existing_text=ans)
|
||||
if artifact_md:
|
||||
ans += "\n\n" + artifact_md
|
||||
self.set_output("content", ans)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
async for delta_ans, _ in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt):
|
||||
if len(msg) > 3:
|
||||
st = timer()
|
||||
user_request = await full_question(messages=msg, chat_mdl=self.chat_mdl)
|
||||
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer() - st)
|
||||
msg = [*msg[:-1], {"role": "user", "content": user_request}]
|
||||
|
||||
msg = self._fit_messages(prompt, msg)
|
||||
|
||||
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||
cited = False
|
||||
if need2cite and len(msg) < 7:
|
||||
self._append_system_prompt(msg, citation_prompt())
|
||||
cited = True
|
||||
|
||||
answer = ""
|
||||
async for delta in self._generate_streamly(msg):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
|
||||
if delta_ans.find("**ERROR**") >= 0:
|
||||
if delta.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", delta_ans)
|
||||
return
|
||||
answer_without_toolcall += delta_ans
|
||||
yield delta_ans
|
||||
|
||||
self.set_output("content", answer_without_toolcall)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
async def _react_with_tools_streamly_async_simple(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||
token_count = 0
|
||||
tool_metas = self.tool_meta
|
||||
hist = deepcopy(history)
|
||||
last_calling = ""
|
||||
if len(hist) > 3:
|
||||
st = timer()
|
||||
user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
|
||||
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
|
||||
else:
|
||||
user_request = history[-1]["content"]
|
||||
|
||||
def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str:
|
||||
"""Build a minimal task_desc by concatenating prompt, query, and tool schemas."""
|
||||
user_defined_prompt = user_defined_prompt or {}
|
||||
|
||||
task_desc = (
|
||||
"### Agent Prompt\n"
|
||||
f"{prompt}\n\n"
|
||||
"### User Request\n"
|
||||
f"{user_request}\n\n"
|
||||
)
|
||||
|
||||
if user_defined_prompt:
|
||||
udp_json = json.dumps(user_defined_prompt, ensure_ascii=False, indent=2)
|
||||
task_desc += "\n### User Defined Prompts\n" + udp_json + "\n"
|
||||
|
||||
return task_desc
|
||||
|
||||
|
||||
async def use_tool_async(name, args):
|
||||
nonlocal hist, use_tools, last_calling
|
||||
logging.info(f"{last_calling=} == {name=}")
|
||||
last_calling = name
|
||||
tool_response = await self.toolcall_session.tool_call_async(name, args)
|
||||
use_tools.append({
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
"results": tool_response
|
||||
})
|
||||
return name, tool_response
|
||||
|
||||
async def complete():
|
||||
nonlocal hist
|
||||
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||
if schema_prompt:
|
||||
need2cite = False
|
||||
cited = False
|
||||
if hist and hist[0]["role"] == "system":
|
||||
if schema_prompt:
|
||||
hist[0]["content"] += "\n" + schema_prompt
|
||||
if need2cite and len(hist) < 7:
|
||||
hist[0]["content"] += citation_prompt()
|
||||
cited = True
|
||||
yield "", token_count
|
||||
|
||||
_hist = hist
|
||||
if len(hist) > 12:
|
||||
_hist = [hist[0], hist[1], *hist[-10:]]
|
||||
entire_txt = ""
|
||||
async for delta_ans in self._generate_streamly(_hist):
|
||||
if not need2cite or cited:
|
||||
yield delta_ans, 0
|
||||
entire_txt += delta_ans
|
||||
if not need2cite or cited:
|
||||
self.set_output("_ERROR", delta)
|
||||
return
|
||||
if not need2cite or cited:
|
||||
yield delta
|
||||
answer += delta
|
||||
|
||||
st = timer()
|
||||
txt = ""
|
||||
async for delta_ans in self._gen_citations_async(entire_txt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
yield delta_ans, 0
|
||||
txt += delta_ans
|
||||
|
||||
self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
|
||||
|
||||
def build_observation(tool_call_res: list[tuple]) -> str:
|
||||
"""
|
||||
Build a Observation from tool call results.
|
||||
No LLM involved.
|
||||
"""
|
||||
if not tool_call_res:
|
||||
return ""
|
||||
|
||||
lines = ["Observation:"]
|
||||
for name, result in tool_call_res:
|
||||
lines.append(f"[{name} result]")
|
||||
lines.append(str(result))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def append_user_content(hist, content):
|
||||
if hist[-1]["role"] == "user":
|
||||
hist[-1]["content"] += content
|
||||
else:
|
||||
hist.append({"role": "user", "content": content})
|
||||
if not need2cite or cited:
|
||||
attachment_content = self._collect_tool_attachment_content(existing_text=answer)
|
||||
if attachment_content:
|
||||
yield "\n\n" + attachment_content
|
||||
answer += "\n\n" + attachment_content
|
||||
artifact_md = self._collect_tool_artifact_markdown(existing_text=answer)
|
||||
if artifact_md:
|
||||
yield "\n\n" + artifact_md
|
||||
answer += "\n\n" + artifact_md
|
||||
self.set_output("content", answer)
|
||||
return
|
||||
|
||||
st = timer()
|
||||
task_desc = build_task_desc(prompt, user_request, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
cited_answer = ""
|
||||
async for delta in self._gen_citations_async(answer):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
token_count += tk or 0
|
||||
hist.append({"role": "assistant", "content": response})
|
||||
try:
|
||||
# Remove markdown code fences properly
|
||||
cleaned_response = re.sub(r"^.*```json\s*", "", response, flags=re.DOTALL)
|
||||
cleaned_response = re.sub(r"```\s*$", "", cleaned_response, flags=re.DOTALL)
|
||||
functions = json_repair.loads(cleaned_response)
|
||||
if not isinstance(functions, list):
|
||||
raise TypeError(f"List should be returned, but `{functions}`")
|
||||
for f in functions:
|
||||
if not isinstance(f, dict):
|
||||
raise TypeError(f"An object type should be returned, but `{f}`")
|
||||
|
||||
tool_tasks = []
|
||||
for func in functions:
|
||||
name = func["name"]
|
||||
args = func["arguments"]
|
||||
if name == COMPLETE_TASK:
|
||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
return
|
||||
|
||||
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
|
||||
|
||||
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
|
||||
st = timer()
|
||||
reflection = build_observation(results)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
||||
e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
|
||||
append_user_content(hist, str(e))
|
||||
|
||||
logging.warning( f"Exceed max rounds: {self._param.max_rounds}")
|
||||
final_instruction = f"""
|
||||
{user_request}
|
||||
IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
|
||||
Instructions:
|
||||
1. SYNTHESIZE all information collected during this conversation
|
||||
2. Provide a COMPLETE response using existing data - do not suggest additional research
|
||||
3. Structure your response as a FINAL DELIVERABLE, not a plan
|
||||
4. If information is incomplete, state what you found and provide the best analysis possible with available data
|
||||
5. DO NOT mention conversation limits or suggest further steps
|
||||
6. Focus on delivering VALUE with the information already gathered
|
||||
Respond immediately with your final comprehensive answer.
|
||||
"""
|
||||
if self.check_if_canceled("Agent final instruction"):
|
||||
return
|
||||
append_user_content(hist, final_instruction)
|
||||
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
|
||||
# async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||
# token_count = 0
|
||||
# tool_metas = self.tool_meta
|
||||
# hist = deepcopy(history)
|
||||
# last_calling = ""
|
||||
# if len(hist) > 3:
|
||||
# st = timer()
|
||||
# user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
|
||||
# self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
|
||||
# else:
|
||||
# user_request = history[-1]["content"]
|
||||
|
||||
# async def use_tool_async(name, args):
|
||||
# nonlocal hist, use_tools, last_calling
|
||||
# logging.info(f"{last_calling=} == {name=}")
|
||||
# last_calling = name
|
||||
# tool_response = await self.toolcall_session.tool_call_async(name, args)
|
||||
# use_tools.append({
|
||||
# "name": name,
|
||||
# "arguments": args,
|
||||
# "results": tool_response
|
||||
# })
|
||||
# # self.callback("add_memory", {}, "...")
|
||||
# #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
|
||||
|
||||
# return name, tool_response
|
||||
|
||||
# async def complete():
|
||||
# nonlocal hist
|
||||
# need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||
# if schema_prompt:
|
||||
# need2cite = False
|
||||
# cited = False
|
||||
# if hist and hist[0]["role"] == "system":
|
||||
# if schema_prompt:
|
||||
# hist[0]["content"] += "\n" + schema_prompt
|
||||
# if need2cite and len(hist) < 7:
|
||||
# hist[0]["content"] += citation_prompt()
|
||||
# cited = True
|
||||
# yield "", token_count
|
||||
|
||||
# _hist = hist
|
||||
# if len(hist) > 12:
|
||||
# _hist = [hist[0], hist[1], *hist[-10:]]
|
||||
# entire_txt = ""
|
||||
# async for delta_ans in self._generate_streamly(_hist):
|
||||
# if not need2cite or cited:
|
||||
# yield delta_ans, 0
|
||||
# entire_txt += delta_ans
|
||||
# if not need2cite or cited:
|
||||
# return
|
||||
|
||||
# st = timer()
|
||||
# txt = ""
|
||||
# async for delta_ans in self._gen_citations_async(entire_txt):
|
||||
# if self.check_if_canceled("Agent streaming"):
|
||||
# return
|
||||
# yield delta_ans, 0
|
||||
# txt += delta_ans
|
||||
|
||||
# self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
|
||||
|
||||
# def append_user_content(hist, content):
|
||||
# if hist[-1]["role"] == "user":
|
||||
# hist[-1]["content"] += content
|
||||
# else:
|
||||
# hist.append({"role": "user", "content": content})
|
||||
|
||||
# st = timer()
|
||||
# task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
# self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
# for _ in range(self._param.max_rounds + 1):
|
||||
# if self.check_if_canceled("Agent streaming"):
|
||||
# return
|
||||
# response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# # self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
# token_count += tk or 0
|
||||
# hist.append({"role": "assistant", "content": response})
|
||||
# try:
|
||||
# functions = json_repair.loads(re.sub(r"```.*", "", response))
|
||||
# if not isinstance(functions, list):
|
||||
# raise TypeError(f"List should be returned, but `{functions}`")
|
||||
# for f in functions:
|
||||
# if not isinstance(f, dict):
|
||||
# raise TypeError(f"An object type should be returned, but `{f}`")
|
||||
|
||||
# tool_tasks = []
|
||||
# for func in functions:
|
||||
# name = func["name"]
|
||||
# args = func["arguments"]
|
||||
# if name == COMPLETE_TASK:
|
||||
# append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
# async for txt, tkcnt in complete():
|
||||
# yield txt, tkcnt
|
||||
# return
|
||||
|
||||
# tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
|
||||
|
||||
# results = await asyncio.gather(*tool_tasks) if tool_tasks else []
|
||||
# st = timer()
|
||||
# reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
|
||||
# append_user_content(hist, reflection)
|
||||
# self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
|
||||
# except Exception as e:
|
||||
# logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
||||
# e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
|
||||
# append_user_content(hist, str(e))
|
||||
|
||||
# logging.warning( f"Exceed max rounds: {self._param.max_rounds}")
|
||||
# final_instruction = f"""
|
||||
# {user_request}
|
||||
# IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
|
||||
# Instructions:
|
||||
# 1. SYNTHESIZE all information collected during this conversation
|
||||
# 2. Provide a COMPLETE response using existing data - do not suggest additional research
|
||||
# 3. Structure your response as a FINAL DELIVERABLE, not a plan
|
||||
# 4. If information is incomplete, state what you found and provide the best analysis possible with available data
|
||||
# 5. DO NOT mention conversation limits or suggest further steps
|
||||
# 6. Focus on delivering VALUE with the information already gathered
|
||||
# Respond immediately with your final comprehensive answer.
|
||||
# """
|
||||
# if self.check_if_canceled("Agent final instruction"):
|
||||
# return
|
||||
# append_user_content(hist, final_instruction)
|
||||
|
||||
# async for txt, tkcnt in complete():
|
||||
# yield txt, tkcnt
|
||||
yield delta
|
||||
cited_answer += delta
|
||||
attachment_content = self._collect_tool_attachment_content(existing_text=cited_answer)
|
||||
if attachment_content:
|
||||
yield "\n\n" + attachment_content
|
||||
cited_answer += "\n\n" + attachment_content
|
||||
artifact_md = self._collect_tool_artifact_markdown(existing_text=cited_answer)
|
||||
if artifact_md:
|
||||
yield "\n\n" + artifact_md
|
||||
cited_answer += "\n\n" + artifact_md
|
||||
self.callback("gen_citations", {}, cited_answer, elapsed_time=timer() - st)
|
||||
self.set_output("content", cited_answer)
|
||||
|
||||
async def _gen_citations_async(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||
async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||
{"role": "user", "content": text}
|
||||
]):
|
||||
async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, {"role": "user", "content": text}]):
|
||||
yield delta_ans
|
||||
|
||||
def _collect_tool_artifact_markdown(self, existing_text: str = "") -> str:
    """Render artifacts recorded by this agent's tools as Markdown.

    Every object in ``self.tools`` that exposes ``_param.outputs`` is checked
    for an ``"_ARTIFACTS"`` entry whose ``"value"`` holds artifact dicts. The
    keys read from each artifact are ``"url"``, ``"name"`` and ``"mime_type"``.
    Artifacts whose MIME type starts with ``image/`` become ``![name](url)``;
    all others become ``[Download name](url)``.

    Args:
        existing_text: Text that has already been produced. An artifact whose
            URL already appears there as a Markdown link or image target is
            skipped so it is not emitted twice.

    Returns:
        The generated Markdown snippets joined by blank lines, or ``""`` when
        there is nothing to add.
    """
    md_parts = []
    for tool_obj in self.tools.values():
        if not hasattr(tool_obj, "_param") or not hasattr(tool_obj._param, "outputs"):
            continue
        artifacts_meta = tool_obj._param.outputs.get("_ARTIFACTS", {})
        artifacts = artifacts_meta.get("value") if isinstance(artifacts_meta, dict) else None
        if not artifacts:
            continue
        for art in artifacts:
            if not isinstance(art, dict):
                continue
            url = art.get("url", "")
            # Match on the link target only: "](url)" is shared by both the
            # image and the download form, whatever label the text used.
            # (The previous check tested an empty string, which is contained
            # in every text and therefore dropped every artifact with a URL.)
            if url and f"]({url})" in existing_text:
                continue
            # Tolerate artifacts without a name or MIME type instead of raising.
            name = art.get("name") or "artifact"
            mime_type = art.get("mime_type") or ""
            if mime_type.startswith("image/"):
                md_parts.append(f"![{name}]({url})")
            else:
                md_parts.append(f"[Download {name}]({url})")
    return "\n\n".join(md_parts)
|
||||
|
||||
def _collect_tool_attachment_content(self, existing_text: str = "") -> str:
    """Gather attachment text recorded by this agent's tools.

    Every object in ``self.tools`` that exposes ``_param.outputs`` is checked
    for an ``"_ATTACHMENT_CONTENT"`` entry; its ``"value"`` is used only when
    it is a non-empty string. The value is stripped of surrounding whitespace
    and dropped if it is then empty or already contained in ``existing_text``.

    Returns:
        The remaining pieces joined by blank lines, or ``""`` if there are none.
    """
    collected = []
    for tool in self.tools.values():
        exposes_outputs = hasattr(tool, "_param") and hasattr(tool._param, "outputs")
        if not exposes_outputs:
            continue
        meta = tool._param.outputs.get("_ATTACHMENT_CONTENT", {})
        raw = meta.get("value") if isinstance(meta, dict) else None
        if raw and isinstance(raw, str):
            stripped = raw.strip()
            if stripped and stripped not in existing_text:
                collected.append(stripped)
    return "\n\n".join(collected)
|
||||
|
||||
def reset(self, only_output=False):
|
||||
"""
|
||||
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
|
||||
|
||||
Reference in New Issue
Block a user