2025-07-30 19:41:09 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
2025-12-03 14:19:53 +08:00
import asyncio
2025-11-27 16:00:56 +08:00
import json
2025-07-30 19:41:09 +08:00
import logging
import os
import re
from copy import deepcopy
from functools import partial
2026-03-20 20:32:00 +08:00
from timeit import default_timer as timer
2025-07-30 19:41:09 +08:00
from typing import Any
import json_repair
2026-03-20 20:32:00 +08:00
from agent . component . llm import LLM , LLMParam
from agent . tools . base import LLMToolPluginCallSession , ToolBase , ToolMeta , ToolParamBase
2026-05-29 17:39:41 +08:00
from api . db . joint_services . tenant_model_service import get_model_config_from_provider_instance , get_model_type_by_name
2025-08-13 16:41:01 +08:00
from api . db . services . llm_service import LLMBundle
2025-07-30 19:41:09 +08:00
from api . db . services . mcp_server_service import MCPServerService
2025-11-04 11:51:12 +08:00
from common . connection_utils import timeout
2026-05-14 15:28:39 +08:00
from common . mcp_tool_call_conn import MCPToolBinding , MCPToolCallSession , mcp_tool_metadata_to_openai_tool
2026-03-20 20:32:00 +08:00
from rag . prompts . generator import citation_plus , citation_prompt , full_question , kb_prompt , message_fit_in , structured_output_prompt
2025-07-30 19:41:09 +08:00
2026-06-29 09:41:16 +08:00
# Logger scoped to this module's dotted name, so agent-component records can be
# filtered or re-levelled independently of the rest of the application.
_logger = logging.getLogger(name=__name__)
2025-07-30 19:41:09 +08:00
class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.

    ``meta`` describes the agent when it is exposed as a tool to a supervising
    LLM: the caller must supply ``user_prompt``, ``reasoning`` and ``context``.
    """

    def __init__(self):
        # ``meta`` is assigned before ``super().__init__()`` on purpose, matching the
        # original ordering (the base initialisers may read it).
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {"type": "string", "description": "This is the order you need to send to the agent.", "default": "", "required": True},
                "reasoning": {
                    "type": "string",
                    # Fixed wording ("choosing the this agent" -> "choosing this agent");
                    # this text is sent to the supervising LLM as part of the tool schema.
                    "description": ("Supervisor's reasoning for choosing this agent. Explain why this agent is being invoked and what is expected of it."),
                    "required": True,
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, and state needed by the agent to solve the current query. Should be as detailed and self-contained as possible."
                    ),
                    "required": True,
                },
            },
        }
        super().__init__()
        self.function_name = "agent"
        # Raw tool component configs; instantiated by Agent._load_tool_obj.
        self.tools = []
        # MCP server entries, each with an "mcp_id" and a "tools" mapping.
        self.mcp = []
        # Passed to LLMBundle as max_rounds.
        self.max_rounds = 5
        self.description = ""
        # Extra headers forwarded to every MCPToolCallSession.
        self.custom_header = {}
2025-07-30 19:41:09 +08:00
class Agent(LLM, ToolBase):
    """LLM component that can call tools, and can itself be exposed as a tool.

    Tools come from two sources:

    * ``param.tools``: component configs instantiated through ``_load_tool_obj``;
    * ``param.mcp``: tools served by MCP servers, wrapped in ``MCPToolBinding``.

    Every tool's function name gets a numeric suffix (``<name>_<index>``) so that
    several tools sharing the same original name do not collide when they are bound
    to the chat model.
    """

    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        LLM.__init__(self, canvas, id, param)
        self.tools = {}
        # Suffix names with the tool's position so duplicate tool types stay distinct.
        for idx, cpn in enumerate(self._param.tools):
            cpn = self._load_tool_obj(cpn)
            original_name = cpn.get_meta()["function"]["name"]
            indexed_name = f"{original_name}_{idx}"
            self.tools[indexed_name] = cpn

        tenant_id = self._canvas.get_tenant_id()
        model_types = get_model_type_by_name(tenant_id, self._param.llm_id) or []
        # Prefer the "chat" type. Otherwise use the first registered type, or fall back
        # to "chat" rather than raising IndexError on an empty result.
        model_type = "chat" if ("chat" in model_types or not model_types) else model_types[0]
        chat_model_config = get_model_config_from_provider_instance(tenant_id, model_type, self._param.llm_id)
        self.chat_mdl = LLMBundle(
            tenant_id,
            chat_model_config,
            max_retries=self._param.max_retries,
            retry_interval=self._param.delay_after_error,
            max_rounds=self._param.max_rounds,
            verbose_tool_use=False,
        )

        # Tool schemas for the LLM, renamed to the indexed names used as dict keys above.
        self.tool_meta = []
        for indexed_name, tool_obj in self.tools.items():
            indexed_meta = deepcopy(tool_obj.get_meta())
            indexed_meta["function"]["name"] = indexed_name
            self.tool_meta.append(indexed_meta)

        # MCP tools continue the numbering after the component tools.
        tool_idx = len(self.tools)
        custom_header = self._param.custom_header
        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            if mcp_server is None:
                # Fail with a clear message instead of an AttributeError on None below.
                err = f"MCP server {mcp['mcp_id']} not found."
                self.set_output("_ERROR", err)
                raise ValueError(err)
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header)
            for tnm, meta in mcp["tools"].items():
                indexed_name = f"{tnm}_{tool_idx}"
                tool_idx += 1
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta, function_name=indexed_name))
                self.tools[indexed_name] = MCPToolBinding(tool_call_session, tnm)

        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        if self.tool_meta:
            self.chat_mdl.bind_tools(self.toolcall_session, self.tool_meta)

    def _fit_messages(self, prompt: str, msg: list[dict]) -> list[dict]:
        """Prepend ``prompt`` as the system message and trim to ~97% of the model context."""
        _, fitted_messages = message_fit_in(
            [{"role": "system", "content": prompt}, *msg],
            int(self.chat_mdl.max_length * 0.97),
        )
        return fitted_messages

    @staticmethod
    def _append_system_prompt(msg: list[dict], extra_prompt: str) -> None:
        """Append ``extra_prompt`` to the leading system message in place (no-op otherwise)."""
        if extra_prompt and msg and msg[0]["role"] == "system":
            msg[0]["content"] += "\n" + extra_prompt

    @staticmethod
    def _clean_formatted_answer(ans: str) -> str:
        """Drop text up to the last ``</think>`` and last ```` ```json ```` fence, and a trailing fence."""
        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

    def _load_tool_obj(self, cpn: dict) -> object:
        """Instantiate a tool component from its config dict.

        ``cpn`` must carry ``component_name`` and ``params``; its optional ``name`` becomes
        part of the tool's component id (``<agent id>--><name>``).

        Raises:
            Whatever ``param.check()`` raises; ``_ERROR`` is set on this component first.
        """
        from agent.component import component_class

        tool_name = cpn["component_name"]
        param = component_class(tool_name + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", tool_name + f" configuration error: {e}")
            raise
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(tool_name)(self._canvas, cpn_id, param)

    def get_meta(self) -> dict[str, Any]:
        """Tool schema for this agent when it is used as a tool by another agent."""
        # The function name is the last segment of the (possibly nested) component id.
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            # Keep the JSON schema valid; user_prompt is a string field, not a schema node.
            m["function"]["parameters"]["properties"]["user_prompt"]["default"] = self._param.user_prompt
        return m

    def get_input_form(self) -> dict[str, dict]:
        """Canvas input form: a line field per input element, plus nested LLM tools' forms."""
        res = {k: {"type": "line", "name": v["name"]} for k, v in self.get_input_elements().items()}
        # ``self._param.tools`` holds raw config dicts (never LLM instances), so the
        # instantiated tool components in ``self.tools`` must be inspected instead.
        for cpn in self.tools.values():
            if isinstance(cpn, LLM):
                res.update(cpn.get_input_form())
        return res

    def _get_output_schema(self):
        """Return the configured JSON schema for the ``structured`` output, or None.

        Accepts the schema directly, or nested under a ``schema``/``structured`` key; a
        schema only counts if it has a non-empty ``properties`` dict.
        """
        try:
            cand = self._param.outputs.get("structured")
        except Exception:
            return None
        if isinstance(cand, dict):
            if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
                return cand
            for k in ("schema", "structured"):
                if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
                    return cand[k]
        return None

    async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
        """Ask the model to rewrite ``text`` as bare JSON following ``schema_prompt``."""
        fmt_msgs = [
            {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
            {"role": "user", "content": text},
        ]
        _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
        return await self._generate_async(fmt_msgs)

    def _try_parse_json(self, text: str):
        """Parse ``text`` as a JSON object/array; return None when that is not possible."""
        if "**ERROR**" in text:
            return None
        try:
            obj = json_repair.loads(self._clean_formatted_answer(text))
        except Exception:
            return None
        # json_repair is lenient and may hand back "" or a bare scalar instead of
        # raising, so only real JSON containers count as a structured answer.
        return obj if isinstance(obj, (dict, list)) else None

    async def _coerce_to_structured_async(self, ans: str, schema_prompt: str):
        """Parse ``ans``; on failure re-format the ORIGINAL answer up to ``max_retries`` times.

        Returns the parsed object, or None when every attempt failed.
        """
        candidate = ans
        for attempt in range(self._param.max_retries + 1):
            obj = self._try_parse_json(candidate)
            if obj is not None:
                return obj
            if attempt == self._param.max_retries:
                break
            # Always re-format the original answer, never a previous (possibly error) rewrite.
            candidate = await self._force_format_to_schema_async(ans, schema_prompt)
        return None

    def _invoke(self, **kwargs):
        """Synchronous entry point; runs ``_invoke_async`` on a fresh event loop."""
        return asyncio.run(self._invoke_async(**kwargs))

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20 * 60)))
    async def _invoke_async(self, **kwargs):
        """Run the agent once.

        When invoked as a tool, ``user_prompt``/``reasoning``/``context`` kwargs replace the
        configured prompts. Without tools the plain LLM path is used. With a Message
        component downstream (and no exception goto / output schema) the ``content``
        output is set to a streaming generator. Otherwise the answer is produced eagerly,
        parsed into ``structured`` when an output schema is configured, or stored in
        ``content`` with tool artifact links appended.
        """
        if self.check_if_canceled("Agent processing"):
            return
        _logger.debug(
            "[Agent] _invoke_async called. Component: %s, Keys in kwargs: %s, user_prompt: %s, tools count: %d",
            self._id,
            list(kwargs.keys()),
            json.dumps(kwargs.get("user_prompt", ""), ensure_ascii=False, default=str)[:300],
            len(self.tools) if self.tools else 0,
        )
        if kwargs.get("user_prompt"):
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]
            _logger.debug("[Agent] Built user prompt with length=%d, reasoning=%s, context=%s",
                          len(usr_pmt), bool(kwargs.get("reasoning")), bool(kwargs.get("context")))

        if not self.tools:
            if self.check_if_canceled("Agent processing"):
                return
            _logger.debug("[Agent] No tools configured. Delegating to LLM._invoke_async. prompt_count=%d", len(self._param.prompts) if self._param.prompts else 0)
            return await LLM._invoke_async(self, **kwargs)

        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
        output_schema = self._get_output_schema()
        schema_prompt = ""
        if output_schema:
            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
            schema_prompt = structured_output_prompt(schema)

        component = self._canvas.get_component(self._id)
        downstreams = component["downstream"] if component else []
        ex = self.exception_handler()
        has_message_downstream = any(self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams)
        if has_message_downstream and not (ex and ex["goto"]) and not output_schema:
            _logger.debug("[Agent] Entering streaming mode (has message downstream)")
            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
            return

        msg = self._fit_messages(prompt, msg)
        self._append_system_prompt(msg, schema_prompt)
        _logger.debug("[Agent] Calling LLM with %d messages, has_schema=%s", len(msg), bool(schema_prompt))
        ans = await self._generate_async(msg)
        if ans.find("**ERROR**") >= 0:
            _logger.error("Agent._chat got error. response: %s", ans)
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", ans)
            return

        if output_schema:
            obj = await self._coerce_to_structured_async(ans, schema_prompt)
            if obj is None:
                self.set_output("_ERROR", "The answer cannot be parsed as JSON")
                return
            self.set_output("structured", obj)
            return obj

        artifact_md = self._collect_tool_artifact_markdown(existing_text=ans)
        if artifact_md:
            ans += "\n\n" + artifact_md
        _logger.debug("[Agent] Final output. content_length=%d, has_artifact=%s", len(ans), bool(artifact_md))
        self.set_output("content", ans)
        return ans

    async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt=None):
        """Stream the agent's answer as text deltas (async generator).

        Args:
            prompt: System prompt.
            msg: Conversation history as role/content dicts.
            user_defined_prompt: Unused here; kept so callers that pass it (e.g. via
                ``partial`` in ``_invoke_async``) keep working. Defaults to None rather
                than a shared mutable ``{}``.

        For conversations longer than 3 messages the last user turn is first rewritten by
        ``full_question``. When citations are required but the history is too long
        (>= 7 messages after fitting) to request them inline, the answer is buffered and
        re-streamed through ``_gen_citations_async``.
        """
        if len(msg) > 3:
            st = timer()
            user_request = await full_question(messages=msg, chat_mdl=self.chat_mdl)
            self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer() - st)
            msg = [*msg[:-1], {"role": "user", "content": user_request}]
        msg = self._fit_messages(prompt, msg)
        # Citations only for the top-level agent (nested agent ids contain "-->").
        need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
        cited = False
        if need2cite and len(msg) < 7:
            self._append_system_prompt(msg, citation_prompt())
            cited = True

        answer = ""
        async for delta in self._generate_streamly(msg):
            if self.check_if_canceled("Agent streaming"):
                return
            if delta.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    fallback = self.get_exception_default_value()
                    self.set_output("content", fallback)
                    yield fallback
                else:
                    self.set_output("_ERROR", delta)
                    self.set_output("content", delta)
                    yield delta
                return
            if not need2cite or cited:
                yield delta
            answer += delta

        if not need2cite or cited:
            artifact_md = self._collect_tool_artifact_markdown(existing_text=answer)
            if artifact_md:
                yield "\n\n" + artifact_md
                answer += "\n\n" + artifact_md
            self.set_output("content", answer)
            return

        st = timer()
        cited_answer = ""
        async for delta in self._gen_citations_async(answer):
            if self.check_if_canceled("Agent streaming"):
                return
            yield delta
            cited_answer += delta
        artifact_md = self._collect_tool_artifact_markdown(existing_text=cited_answer)
        if artifact_md:
            yield "\n\n" + artifact_md
            cited_answer += "\n\n" + artifact_md
        self.callback("gen_citations", {}, cited_answer, elapsed_time=timer() - st)
        self.set_output("content", cited_answer)

    async def _gen_citations_async(self, text):
        """Stream ``text`` rewritten with citations to the canvas' retrieved chunks."""
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formatted_refs = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formatted_refs))}, {"role": "user", "content": text}]):
            yield delta_ans

    def _collect_tool_artifact_markdown(self, existing_text: str = "") -> str:
        """Markdown links for artifacts produced by tools that are not yet in ``existing_text``.

        Artifacts are read from each tool's ``_ARTIFACTS`` output (a list of dicts with
        ``url``, ``name`` and ``mime_type``). Images become inline images, anything else
        a download link. Artifacts without a URL, or whose URL already appears in
        ``existing_text``, are skipped.
        """
        md_parts = []
        for tool_obj in self.tools.values():
            if not hasattr(tool_obj, "_param") or not hasattr(tool_obj._param, "outputs"):
                continue
            artifacts_meta = tool_obj._param.outputs.get("_ARTIFACTS", {})
            artifacts = artifacts_meta.get("value") if isinstance(artifacts_meta, dict) else None
            if not artifacts:
                continue
            for art in artifacts:
                if not isinstance(art, dict):
                    continue
                url = art.get("url") or ""
                # The previous check ``f"" in existing_text`` was always True, so every
                # artifact with a URL was dropped; dedupe on the URL itself instead.
                if not url or url in existing_text:
                    continue
                name = art.get("name") or url
                if (art.get("mime_type") or "").startswith("image/"):
                    md_parts.append(f"![{name}]({url})")
                else:
                    md_parts.append(f"[Download {name}]({url})")
        return "\n\n".join(md_parts)

    def reset(self, only_output=False):
        """
        Clear output values (and, unless ``only_output``, input values and debug inputs),
        and reset every tool that has a callable ``reset``. This avoids errors for tools
        like MCPToolCallSession.
        """
        for output in self._param.outputs.values():
            output["value"] = None

        for cpn in self.tools.values():
            if callable(getattr(cpn, "reset", None)):
                cpn.reset()

        if only_output:
            return
        for inp in self._param.inputs.values():
            inp["value"] = None
        self._param.debug_inputs = {}