mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Fix: agent session log message (#14991)
### What problem does this PR solve? agent session log message ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -108,9 +108,23 @@ def _build_sse_response(body):
|
||||
return resp
|
||||
|
||||
|
||||
def _normalize_agent_reference_entry(reference):
|
||||
if not isinstance(reference, dict):
|
||||
return {"chunks": [], "doc_aggs": []}
|
||||
if "chunks" in reference or "doc_aggs" in reference:
|
||||
return {
|
||||
"chunks": reference.get("chunks", []),
|
||||
"doc_aggs": reference.get("doc_aggs", []),
|
||||
}
|
||||
return {
|
||||
"chunks": reference.get("reference", reference.get("chunks", [])) or [],
|
||||
"doc_aggs": reference.get("doc_aggs", []) or [],
|
||||
}
|
||||
|
||||
|
||||
def _normalize_agent_session(conv):
|
||||
conv["messages"] = conv.pop("message")
|
||||
for info in conv["messages"]:
|
||||
conv["message"] = conv.get("message", [])
|
||||
for info in conv["message"]:
|
||||
if "prompt" in info:
|
||||
info.pop("prompt")
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
@@ -119,11 +133,15 @@ def _normalize_agent_session(conv):
|
||||
conv["reference"] = [conv["reference"]]
|
||||
else:
|
||||
conv["reference"] = [value for _, value in sorted(conv["reference"].items(), key=lambda item: int(item[0]))]
|
||||
elif isinstance(conv["reference"], list):
|
||||
conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in conv["reference"]]
|
||||
else:
|
||||
conv["reference"] = []
|
||||
|
||||
if conv["reference"]:
|
||||
messages = [message for i, message in enumerate(conv["messages"]) if i != 0 and message["role"] != "user"]
|
||||
messages = [message for i, message in enumerate(conv["message"]) if i != 0 and message["role"] != "user"]
|
||||
for message, reference in zip(messages, conv["reference"]):
|
||||
chunks = reference["chunks"]
|
||||
chunks = reference.get("chunks", [])
|
||||
message["reference"] = [
|
||||
{
|
||||
"id": chunk.get("chunk_id", chunk.get("id")),
|
||||
@@ -144,6 +162,171 @@ def _agent_session_list_result(data, total):
|
||||
return jsonify({"code": RetCode.SUCCESS, "message": "success", "data": data, "total": total})
|
||||
|
||||
|
||||
async def _run_workflow_session(
|
||||
tenant_id,
|
||||
agent_id,
|
||||
workflow_conv,
|
||||
canvas,
|
||||
query,
|
||||
files,
|
||||
inputs,
|
||||
user_id,
|
||||
session_id,
|
||||
custom_header,
|
||||
canvas_title,
|
||||
canvas_category,
|
||||
return_trace,
|
||||
stream,
|
||||
):
|
||||
async def commit_runtime_replica():
|
||||
commit_ok = CanvasReplicaService.commit_after_run(
|
||||
canvas_id=agent_id,
|
||||
tenant_id=str(tenant_id),
|
||||
runtime_user_id=user_id,
|
||||
dsl=json.loads(str(canvas)),
|
||||
canvas_category=canvas_category,
|
||||
title=canvas_title,
|
||||
)
|
||||
if not commit_ok:
|
||||
logging.error(
|
||||
"Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s",
|
||||
agent_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
workflow_conv.setdefault("message", [])
|
||||
if isinstance(workflow_conv.get("reference"), dict):
|
||||
if "chunks" in workflow_conv["reference"]:
|
||||
workflow_conv["reference"] = [workflow_conv["reference"]]
|
||||
else:
|
||||
workflow_conv["reference"] = [
|
||||
value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))
|
||||
]
|
||||
elif not isinstance(workflow_conv.get("reference"), list):
|
||||
workflow_conv["reference"] = []
|
||||
workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]]
|
||||
|
||||
turn_id = workflow_conv["message"][-1].get("id") if workflow_conv["message"] else get_uuid()
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = {}
|
||||
trace_items = []
|
||||
structured_output = {}
|
||||
|
||||
async def persist_workflow_session():
|
||||
if not final_ans:
|
||||
return
|
||||
workflow_conv["message"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content,
|
||||
"created_at": time.time(),
|
||||
"id": turn_id,
|
||||
}
|
||||
)
|
||||
workflow_conv["reference"].append(_normalize_agent_reference_entry(reference))
|
||||
workflow_conv["dsl"] = json.loads(str(canvas))
|
||||
workflow_conv["source"] = workflow_conv.get("source") or "workflow"
|
||||
await thread_pool_exec(API4ConversationService.append_message, session_id, workflow_conv)
|
||||
await commit_runtime_replica()
|
||||
|
||||
if stream:
|
||||
|
||||
async def sse():
|
||||
nonlocal full_content, reference, final_ans, trace_items, structured_output
|
||||
done_sent = False
|
||||
try:
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
ans["session_id"] = session_id
|
||||
if ans.get("event") == "message":
|
||||
full_content += ans.get("data", {}).get("content", "")
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
if ans.get("event") == "node_finished":
|
||||
data = ans.get("data", {})
|
||||
node_out = data.get("outputs", {})
|
||||
component_id = data.get("component_id")
|
||||
if component_id is not None and "structured" in node_out:
|
||||
structured_output[component_id] = copy.deepcopy(node_out["structured"])
|
||||
if return_trace:
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
final_ans = ans
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if final_ans:
|
||||
if "data" not in final_ans or not isinstance(final_ans["data"], dict):
|
||||
final_ans["data"] = {}
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if structured_output:
|
||||
final_ans["data"]["structured"] = structured_output
|
||||
if trace_items:
|
||||
final_ans["data"]["trace"] = trace_items
|
||||
await persist_workflow_session()
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
canvas.cancel_task()
|
||||
yield (
|
||||
"data:"
|
||||
+ json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False)
|
||||
+ "\n\n"
|
||||
)
|
||||
finally:
|
||||
if not done_sent:
|
||||
done_sent = True
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
return _build_sse_response(sse())
|
||||
|
||||
try:
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
ans["session_id"] = session_id
|
||||
if ans.get("event") == "message":
|
||||
full_content += ans.get("data", {}).get("content", "")
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
if ans.get("event") == "node_finished":
|
||||
data = ans.get("data", {})
|
||||
node_out = data.get("outputs", {})
|
||||
component_id = data.get("component_id")
|
||||
if component_id is not None and "structured" in node_out:
|
||||
structured_output[component_id] = copy.deepcopy(node_out["structured"])
|
||||
if return_trace:
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
final_ans = ans
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
canvas.cancel_task()
|
||||
return get_result(data=f"**ERROR**: {str(exc)}")
|
||||
|
||||
if not final_ans:
|
||||
await commit_runtime_replica()
|
||||
return get_result(data={})
|
||||
|
||||
if "data" not in final_ans or not isinstance(final_ans["data"], dict):
|
||||
final_ans["data"] = {}
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if structured_output:
|
||||
final_ans["data"]["structured"] = structured_output
|
||||
if trace_items:
|
||||
final_ans["data"]["trace"] = trace_items
|
||||
|
||||
await persist_workflow_session()
|
||||
return get_result(data=final_ans)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
@@ -957,6 +1140,8 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
req.pop("agent_id", None)
|
||||
req.pop("openai-compatible", None)
|
||||
session_id = req.get("session_id")
|
||||
workflow_session = False
|
||||
workflow_conv = None
|
||||
if session_id:
|
||||
exists, conv = API4ConversationService.get_by_id(session_id)
|
||||
if not exists:
|
||||
@@ -973,6 +1158,9 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
message="Only authorized users can access this agent session.",
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
)
|
||||
workflow_session = getattr(conv, "source", "") == "workflow"
|
||||
if workflow_session:
|
||||
workflow_conv = conv.to_dict()
|
||||
|
||||
if openai_compatible:
|
||||
# OpenAI-compatible mode uses a different wire format, keep it separate from regular agent events.
|
||||
@@ -1005,8 +1193,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
return jsonify(response)
|
||||
return None
|
||||
|
||||
if not session_id:
|
||||
# Without session state, run against the runtime replica that tracks draft edits.
|
||||
if workflow_session:
|
||||
query = req.get("query", "") or req.get("question", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
@@ -1014,6 +1201,64 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
user_id = str(runtime_user_id)
|
||||
custom_header = req.get("custom_header", "")
|
||||
|
||||
_, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
|
||||
if not cvs:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(workflow_conv.get("message"), list):
|
||||
workflow_conv["message"] = []
|
||||
if isinstance(workflow_conv.get("reference"), dict):
|
||||
if "chunks" in workflow_conv["reference"]:
|
||||
workflow_conv["reference"] = [workflow_conv["reference"]]
|
||||
else:
|
||||
workflow_conv["reference"] = [
|
||||
value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))
|
||||
]
|
||||
elif not isinstance(workflow_conv.get("reference"), list):
|
||||
workflow_conv["reference"] = []
|
||||
workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]]
|
||||
turn_id = get_uuid()
|
||||
workflow_conv["message"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": query,
|
||||
"id": turn_id,
|
||||
"files": files,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
)
|
||||
await thread_pool_exec(API4ConversationService.update_by_id, session_id, workflow_conv)
|
||||
|
||||
try:
|
||||
from agent.canvas import Canvas
|
||||
|
||||
workflow_dsl = workflow_conv.get("dsl", {})
|
||||
if isinstance(workflow_dsl, str):
|
||||
dsl_str = workflow_dsl
|
||||
else:
|
||||
dsl_str = json.dumps(workflow_dsl, ensure_ascii=False)
|
||||
canvas = Canvas(dsl_str, str(tenant_id), canvas_id=agent_id, custom_header=custom_header)
|
||||
except Exception as exc:
|
||||
return server_error_response(exc)
|
||||
|
||||
return await _run_workflow_session(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
workflow_conv=workflow_conv,
|
||||
canvas=canvas,
|
||||
query=query,
|
||||
files=files,
|
||||
inputs=inputs,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
custom_header=custom_header,
|
||||
canvas_title=getattr(cvs, "title", ""),
|
||||
canvas_category=getattr(cvs, "canvas_category", CanvasCategory.Agent),
|
||||
return_trace=bool(req.get("return_trace", False)),
|
||||
stream=req.get("stream", True),
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
if not UserCanvasService.accessible(agent_id, tenant_id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -1021,6 +1266,16 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
# Keep the original workflow execution path, but assign a session_id so the
|
||||
# response shape stays closer to the older agent completion contract.
|
||||
query = req.get("query", "") or req.get("question", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
runtime_user_id = req.get("user_id") or tenant_id
|
||||
user_id = str(runtime_user_id)
|
||||
custom_header = req.get("custom_header", "")
|
||||
session_id = get_uuid()
|
||||
|
||||
_, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
|
||||
if not cvs:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
@@ -1054,6 +1309,31 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
task_id = get_uuid()
|
||||
workflow_conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": user_id,
|
||||
"exp_user_id": user_id,
|
||||
"name": req.get("name", ""),
|
||||
"message": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": query,
|
||||
"id": task_id,
|
||||
"files": files,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
],
|
||||
"reference": [],
|
||||
"source": "workflow",
|
||||
"dsl": replica_dsl,
|
||||
"version_title": await thread_pool_exec(
|
||||
UserCanvasVersionService.get_latest_version_title,
|
||||
cvs.id,
|
||||
release_mode=False,
|
||||
),
|
||||
}
|
||||
await thread_pool_exec(API4ConversationService.save, **workflow_conv)
|
||||
Pipeline(
|
||||
dsl_str,
|
||||
tenant_id=str(tenant_id),
|
||||
@@ -1072,7 +1352,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
return get_json_result(data={"message_id": task_id, "session_id": session_id})
|
||||
|
||||
try:
|
||||
from agent.canvas import Canvas
|
||||
@@ -1080,88 +1360,49 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
canvas = Canvas(dsl_str, str(tenant_id), canvas_id=agent_id, custom_header=custom_header)
|
||||
except Exception as exc:
|
||||
return server_error_response(exc)
|
||||
|
||||
async def commit_runtime_replica():
|
||||
commit_ok = CanvasReplicaService.commit_after_run(
|
||||
canvas_id=agent_id,
|
||||
tenant_id=str(tenant_id),
|
||||
runtime_user_id=user_id,
|
||||
dsl=json.loads(str(canvas)),
|
||||
canvas_category=canvas_category,
|
||||
title=canvas_title,
|
||||
)
|
||||
if not commit_ok:
|
||||
logging.error(
|
||||
"Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s",
|
||||
agent_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if req.get("stream", True):
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
try:
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
await commit_runtime_replica()
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
canvas.cancel_task()
|
||||
yield (
|
||||
"data:"
|
||||
+ json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False)
|
||||
+ "\n\n"
|
||||
)
|
||||
|
||||
return _build_sse_response(sse())
|
||||
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = {}
|
||||
trace_items = []
|
||||
structured_output = {}
|
||||
try:
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
if ans.get("event") == "message":
|
||||
full_content += ans.get("data", {}).get("content", "")
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
if ans.get("event") == "node_finished":
|
||||
data = ans.get("data", {})
|
||||
node_out = data.get("outputs", {})
|
||||
component_id = data.get("component_id")
|
||||
if component_id is not None and "structured" in node_out:
|
||||
structured_output[component_id] = copy.deepcopy(node_out["structured"])
|
||||
if req.get("return_trace", False):
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
final_ans = ans
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
canvas.cancel_task()
|
||||
return get_result(data=f"**ERROR**: {str(exc)}")
|
||||
|
||||
if not final_ans:
|
||||
await commit_runtime_replica()
|
||||
return get_result(data={})
|
||||
|
||||
if "data" not in final_ans or not isinstance(final_ans["data"], dict):
|
||||
final_ans["data"] = {}
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if structured_output:
|
||||
final_ans["data"]["structured"] = structured_output
|
||||
if trace_items:
|
||||
final_ans["data"]["trace"] = trace_items
|
||||
|
||||
await commit_runtime_replica()
|
||||
return get_result(data=final_ans)
|
||||
turn_id = get_uuid()
|
||||
workflow_conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": user_id,
|
||||
"exp_user_id": user_id,
|
||||
"name": req.get("name", ""),
|
||||
"message": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": query,
|
||||
"id": turn_id,
|
||||
"files": files,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
],
|
||||
"reference": [],
|
||||
"source": "workflow",
|
||||
"dsl": replica_dsl,
|
||||
"version_title": await thread_pool_exec(
|
||||
UserCanvasVersionService.get_latest_version_title,
|
||||
cvs.id,
|
||||
release_mode=False,
|
||||
),
|
||||
}
|
||||
workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]]
|
||||
await thread_pool_exec(API4ConversationService.save, **workflow_conv)
|
||||
return await _run_workflow_session(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
workflow_conv=workflow_conv,
|
||||
canvas=canvas,
|
||||
query=query,
|
||||
files=files,
|
||||
inputs=inputs,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
custom_header=custom_header,
|
||||
canvas_title=canvas_title,
|
||||
canvas_category=canvas_category,
|
||||
return_trace=bool(req.get("return_trace", False)),
|
||||
stream=req.get("stream", True),
|
||||
)
|
||||
|
||||
return_trace = bool(req.get("return_trace", False))
|
||||
if req.get("stream", True):
|
||||
|
||||
Reference in New Issue
Block a user