2025-07-30 19:41:09 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
2025-12-03 14:19:53 +08:00
import asyncio
2025-07-30 19:41:09 +08:00
import json
import logging
import os
import re
2025-09-02 11:06:17 +08:00
from copy import deepcopy
2025-12-11 17:38:17 +08:00
from typing import Any , AsyncGenerator
2025-07-30 19:41:09 +08:00
import json_repair
from functools import partial
2025-11-05 08:01:39 +08:00
from common . constants import LLMType
2026-06-11 14:09:57 +08:00
from api . db . services . dialog_service import _stream_with_think_delta
2025-08-13 16:41:01 +08:00
from api . db . services . llm_service import LLMBundle
2026-05-29 17:39:41 +08:00
from api . db . joint_services . tenant_model_service import get_model_config_from_provider_instance , get_model_type_by_name
2025-07-30 19:41:09 +08:00
from agent . component . base import ComponentBase , ComponentParamBase
2025-11-04 11:51:12 +08:00
from common . connection_utils import timeout
2025-11-03 09:39:53 +08:00
from rag . prompts . generator import tool_call_summary , message_fit_in , citation_prompt , structured_output_prompt
2025-07-30 19:41:09 +08:00
class LLMParam(ComponentParamBase):
    """
    Define the LLM component parameters.

    Holds the model id, the prompts and the sampling settings of one LLM
    component. The sampling settings are only forwarded to the model when
    they are positive AND their companion ``*Enabled`` flag is truthy (see
    ``gen_conf``).
    """

    # (generation-config key / attribute name, cast, name of the enabling flag).
    # The ``*Enabled`` flags are not initialised in ``__init__``; presumably
    # they are populated from the canvas DSL -- TODO confirm.
    _GEN_CONF_FIELDS = (
        ("max_tokens", int, "maxTokensEnabled"),
        ("temperature", float, "temperatureEnabled"),
        ("top_p", float, "topPEnabled"),
        ("presence_penalty", float, "presencePenaltyEnabled"),
        ("frequency_penalty", float, "frequencyPenaltyEnabled"),
    )

    def __init__(self):
        super().__init__()
        self.llm_id = ""
        self.sys_prompt = ""
        # Default conversation: a single user turn holding the query variable.
        self.prompts = [{"role": "user", "content": "{sys.query}"}]
        self.max_tokens = 0
        self.temperature = 0
        self.top_p = 0
        self.presence_penalty = 0
        self.frequency_penalty = 0
        self.output_structure = None
        self.cite = True
        # Name of a canvas variable whose value may carry ``data:image/`` URIs.
        self.visual_files_var = None

    def check(self):
        """
        Validate the parameters through the base-class ``check_*`` helpers.

        Numeric settings are cast with ``float``/``int`` first, so a value
        that cannot be converted raises from the cast itself.
        """
        self.check_decimal_float(float(self.temperature), "[Agent] Temperature")
        self.check_decimal_float(float(self.presence_penalty), "[Agent] Presence penalty")
        self.check_decimal_float(float(self.frequency_penalty), "[Agent] Frequency penalty")
        self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
        self.check_decimal_float(float(self.top_p), "[Agent] Top P")
        self.check_empty(self.llm_id, "[Agent] LLM")
        self.check_empty(self.prompts, "[Agent] User prompt")

    def gen_conf(self):
        """
        Build the generation-config dict handed to the chat model.

        Returns:
            A dict containing only the settings whose value is greater than
            zero and whose ``*Enabled`` flag exists and is truthy. A missing
            flag counts as disabled.
        """
        conf = {}
        for key, cast, flag in self._GEN_CONF_FIELDS:
            # Cast unconditionally so malformed values still raise, exactly
            # as they would for a setting that ends up disabled.
            value = cast(getattr(self, key))
            if value > 0 and getattr(self, flag, None):
                conf[key] = value
        return conf
class LLM(ComponentBase):
    """
    Canvas component that renders its prompts and queries a chat model.

    The component resolves prompt variables, separates inline
    ``data:image/`` payloads from the text, and then either
    * parses the reply as JSON when a structured output schema is configured,
    * hands a streaming generator to a downstream "message" component, or
    * stores the plain reply in the ``content`` output.
    """

    component_name = "LLM"

    def __init__(self, canvas, component_id, param: ComponentParamBase):
        """Create the component and bind the default chat model bundle."""
        super().__init__(canvas, component_id, param)
        tenant_id = self._canvas.get_tenant_id()
        model_types = get_model_type_by_name(tenant_id, self._param.llm_id)
        # Prefer the chat flavour of the model; otherwise use the first type reported.
        model_type = "chat" if "chat" in model_types else model_types[0]
        chat_model_config = get_model_config_from_provider_instance(tenant_id, model_type, self._param.llm_id)
        self.chat_mdl = LLMBundle(
            tenant_id,
            chat_model_config,
            max_retries=self._param.max_retries,
            retry_interval=self._param.delay_after_error,
        )
        # ``data:image/`` URIs collected for the current invocation.
        self.imgs = []

    def get_input_form(self) -> dict[str, dict]:
        """Describe every prompt variable as a single-line input field."""
        return {
            key: {"type": "line", "name": element["name"]}
            for key, element in self.get_input_elements().items()
        }

    def _normalized_prompts(self) -> list:
        """
        Return ``self._param.prompts`` as a list of message dicts.

        A bare string is wrapped into one user message, and the wrapped form
        is written back to the parameter object.
        """
        if isinstance(self._param.prompts, str):
            self._param.prompts = [{"role": "user", "content": self._param.prompts}]
        return self._param.prompts

    def get_input_elements(self) -> dict[str, Any]:
        """Collect the input elements referenced by the system prompt and every prompt message."""
        res = self.get_input_elements_from_text(self._param.sys_prompt)
        for prompt in self._normalized_prompts():
            res.update(self.get_input_elements_from_text(prompt["content"]))
        return res

    def set_debug_inputs(self, inputs: dict[str, dict]):
        """Store inputs that replace the resolved prompt variables on the next run."""
        self._param.debug_inputs = inputs

    def add2system_prompt(self, txt):
        """Append ``txt`` to the configured system prompt."""
        self._param.sys_prompt += txt

    def _sys_prompt_and_msg(self, msg, args):
        """
        Append the formatted prompt messages to ``msg`` and format the system prompt.

        Args:
            msg: Message list to extend; it is mutated in place.
            args: Variable values substituted through ``string_format``.

        Returns:
            ``(msg, formatted_system_prompt)``.
        """
        for p in self._normalized_prompts():
            # Skip a prompt whose role repeats the role of the last message.
            if msg and msg[-1]["role"] == p["role"]:
                continue
            p = deepcopy(p)  # never format the stored template in place
            p["content"] = self.string_format(p["content"], args)
            msg.append(p)
        return msg, self.string_format(self._param.sys_prompt, args)

    @staticmethod
    def _extract_data_images(value) -> list[str]:
        """
        Recursively collect every ``data:image/`` string found in ``value``.

        Strings are stripped before the prefix test. Lists, tuples and sets
        are walked element by element. For a dict only its ``"content"``
        entry is walked when that key exists; otherwise all values are.
        Any other type is ignored.
        """
        imgs = []

        def walk(v):
            if v is None:
                return
            if isinstance(v, str):
                v = v.strip()
                if v.startswith("data:image/"):
                    imgs.append(v)
                return
            if isinstance(v, (list, tuple, set)):
                for item in v:
                    walk(item)
                return
            if isinstance(v, dict):
                if "content" in v:
                    walk(v.get("content"))
                else:
                    for item in v.values():
                        walk(item)

        walk(value)
        return imgs

    @staticmethod
    def _uniq_images(images: list[str]) -> list[str]:
        """Drop entries that are not ``data:image/`` strings and de-duplicate, keeping first-seen order."""
        return list(
            dict.fromkeys(
                img for img in images if isinstance(img, str) and img.startswith("data:image/")
            )
        )

    @staticmethod
    def _is_empty_after_cleaning(value) -> bool:
        """Tell whether a cleaned value is ``None`` or an empty list/tuple/set/dict."""
        return value is None or (isinstance(value, (list, tuple, set, dict)) and not value)

    @classmethod
    def _clean_items(cls, items) -> list:
        """Clean every element of ``items`` and keep the ones that are not empty afterwards."""
        cleaned = []
        for item in items:
            v = cls._remove_data_images(item)
            if cls._is_empty_after_cleaning(v):
                continue
            cleaned.append(v)
        return cleaned

    @classmethod
    def _remove_data_images(cls, value):
        """
        Return ``value`` with every ``data:image/`` payload removed.

        * a string that is a data-image URI becomes ``None``;
        * list -> list, tuple -> tuple, set -> list, minus removed/empty items;
        * a dict typed ``image_url`` / ``input_image`` / ``image`` that holds
          a data-image becomes ``None``; other dicts lose their removed/empty
          entries;
        * anything else is returned unchanged.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return None if value.strip().startswith("data:image/") else value
        if isinstance(value, (list, set)):
            # A set intentionally comes back as a list.
            return cls._clean_items(value)
        if isinstance(value, tuple):
            return tuple(cls._clean_items(value))
        if isinstance(value, dict):
            if value.get("type") in {"image_url", "input_image", "image"} and cls._extract_data_images(value):
                return None
            cleaned = {}
            for k, item in value.items():
                v = cls._remove_data_images(item)
                if cls._is_empty_after_cleaning(v):
                    continue
                cleaned[k] = v
            return cleaned
        return value

    def _prepare_prompt_variables(self):
        """
        Resolve prompt variables, collect images and build the prompt.

        Side effects: resets ``self.imgs``, records each resolved variable
        through ``set_input_value`` and, when images are present, replaces
        ``self.chat_mdl``.

        Returns:
            ``(sys_prompt, msg, user_defined_prompt)``.
        """
        self.imgs = []
        if self._param.visual_files_var:
            visual_val = self._canvas.get_variable_value(self._param.visual_files_var)
            self.imgs.extend(self._extract_data_images(visual_val))

        args = {}
        variables = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
        extracted_imgs = []
        for k, o in variables.items():
            raw_value = o["value"]
            extracted_imgs.extend(self._extract_data_images(raw_value))
            # Images travel separately, so strip them from the textual value.
            args[k] = self._remove_data_images(raw_value)
            if args[k] is None:
                args[k] = ""
            if not isinstance(args[k], str):
                try:
                    args[k] = json.dumps(args[k], ensure_ascii=False)
                except Exception:
                    # Best effort: fall back to str() for non-serialisable values.
                    args[k] = str(args[k])
            self.set_input_value(k, args[k])
        self.imgs = self._uniq_images(self.imgs + extracted_imgs)

        tenant_id = self._canvas.get_tenant_id()
        model_types = get_model_type_by_name(tenant_id, self._param.llm_id)
        if self.imgs and LLMType.IMAGE2TEXT.value in model_types:
            model_type = LLMType.IMAGE2TEXT.value
        elif LLMType.CHAT.value in model_types:
            model_type = LLMType.CHAT.value
        else:
            model_type = model_types[0]
        model_config = get_model_config_from_provider_instance(tenant_id, model_type, self._param.llm_id)
        if self.imgs:
            # NOTE(review): the bundle is only swapped when images are present and
            # is never swapped back on a later image-free run -- confirm intended.
            self.chat_mdl = LLMBundle(
                tenant_id,
                model_config,
                max_retries=self._param.max_retries,
                retry_interval=self._param.delay_after_error,
            )

        # The last history entry is dropped before the prompt messages are appended.
        msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
        user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
        if self._param.cite and self._canvas.get_reference()["chunks"]:
            sys_prompt += citation_prompt(user_defined_prompt)

        return sys_prompt, msg, user_defined_prompt

    def _extract_prompts(self, sys_prompt):
        """
        Pull the tagged sections out of the system prompt.

        For each known tag the text of the first ``<TAG>...</TAG>`` section
        (case-insensitive, may span lines) is stored under the lower-cased
        tag name, and every such section is removed from the prompt.

        Returns:
            ``(sections, remaining_sys_prompt)``.
        """
        pts = {}
        for tag in ["TASK_ANALYSIS", "PLAN_GENERATION", "REFLECTION", "CONTEXT_SUMMARY", "CONTEXT_RANKING", "CITATION_GUIDELINES"]:
            pattern = re.compile(rf"<{tag}>(.*?)</{tag}>", flags=re.DOTALL | re.IGNORECASE)
            r = pattern.search(sys_prompt)
            if not r:
                continue
            pts[tag.lower()] = r.group(1)
            sys_prompt = pattern.sub("", sys_prompt)
        return pts, sys_prompt

    async def _generate_async(self, msg: list[dict], **kwargs) -> str:
        """Run one non-streaming chat call; ``msg[0]`` is the system message, the rest is history."""
        if not self.imgs:
            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
        return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)

    async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
        """Stream a chat call, yielding the middle element of each streamed triple."""
        stream_kwargs = {"images": self.imgs} if self.imgs else {}
        stream_kwargs.update(kwargs)
        stream = self.chat_mdl.async_chat_streamly_delta(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs)
        async for _, value, _ in _stream_with_think_delta(stream, min_tokens=0):
            yield value

    async def _stream_output_async(self, prompt, msg):
        """
        Stream the answer piece by piece and record the full text at the end.

        Stops silently when the run is canceled. When a piece contains
        ``**ERROR**`` the stream ends: the exception default value (if any)
        is stored in ``content`` and yielded, otherwise the piece is stored
        in ``_ERROR``.
        """
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer = ""
        stream_kwargs = {"images": self.imgs} if self.imgs else {}
        stream_kwargs.update(self._get_chat_template_kwargs())
        stream = self.chat_mdl.async_chat_streamly_delta(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs)
        async for _, ans, _ in _stream_with_think_delta(stream, min_tokens=0):
            if self.check_if_canceled("LLM streaming"):
                return

            if "**ERROR**" in ans:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", ans)
                return

            answer += ans
            yield ans

        self.set_output("content", answer)

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
    async def _invoke_async(self, **kwargs):
        """
        Execute the component once.

        Outputs written:
            * ``structured`` -- parsed JSON, when a structured output schema
              with properties is configured;
            * ``content`` -- the plain answer, a streaming callable (when a
              downstream "message" component exists and no exception
              ``goto`` is configured), or the exception default value;
            * ``_ERROR`` -- the last error when every attempt failed and no
              default value applies.
        """
        if self.check_if_canceled("LLM processing"):
            return

        def clean_formated_answer(ans: str) -> str:
            # Drop everything up to a closing think tag and up to an opening
            # JSON fence, then drop the trailing fence.
            ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
            ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
            return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

        prompt, msg, _ = self._prepare_prompt_variables()
        extra_chat_kwargs = self._get_chat_template_kwargs()
        error: str = ""
        max_attempts = self._param.max_retries + 1

        output_structure = None
        try:
            output_structure = self._param.outputs["structured"]
        except (AttributeError, KeyError, TypeError):
            # No structured output configured for this component.
            pass

        if isinstance(output_structure, dict) and output_structure.get("properties"):
            parse_error = "The answer can't be parsed as JSON"
            schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
            prompt_with_schema = prompt + structured_output_prompt(schema)
            for _ in range(max_attempts):
                if self.check_if_canceled("LLM processing"):
                    return

                # Rebuilt on every attempt from a copy of ``msg``.
                _, msg_fit = message_fit_in(
                    [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
                    int(self.chat_mdl.max_length * 0.97),
                )
                error = ""
                ans = await self._generate_async(msg_fit, **extra_chat_kwargs)
                if "**ERROR**" in ans:
                    logging.error("LLM response error: %s", ans)
                    error = ans
                    continue
                try:
                    self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                    return
                except Exception:
                    # Record the feedback on ``msg`` itself: ``msg_fit`` is
                    # discarded at the next attempt, so appending there would
                    # never reach the model.
                    msg.append({"role": "user", "content": parse_error})
                    error = parse_error

            if error:
                self.set_output("_ERROR", error)
            return

        component = self._canvas.get_component(self._id)
        downstreams = component["downstream"] if component else []
        ex = self.exception_handler()
        feeds_message = any(
            self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams
        )
        if feeds_message and not (ex and ex["goto"]):
            # Defer generation: the consumer calls this partial to stream the answer.
            self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
            return

        error = ""
        for _ in range(max_attempts):
            if self.check_if_canceled("LLM processing"):
                return

            _, msg_fit = message_fit_in(
                [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
            )
            error = ""
            ans = await self._generate_async(msg_fit, **extra_chat_kwargs)
            if "**ERROR**" in ans:
                logging.error("LLM response error: %s", ans)
                error = ans
                continue
            self.set_output("content", ans)
            break

        if error:
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", error)

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
    def _invoke(self, **kwargs):
        """Synchronous entry point: run ``_invoke_async`` in a fresh event loop."""
        return asyncio.run(self._invoke_async(**kwargs))

    def _get_chat_template_kwargs(self) -> dict[str, Any]:
        """
        Read ``sys.chat_template_kwargs`` from the canvas globals.

        Returns:
            ``{"chat_template_kwargs": <dict>}`` when the global holds a dict
            (or a JSON string that decodes to one); ``{}`` when it is unset
            or invalid.
        """
        chat_template_kwargs = self._canvas.globals.get("sys.chat_template_kwargs")
        if chat_template_kwargs is None:
            return {}

        # The API should pass this as a JSON object, but accept a JSON string for compatibility.
        if isinstance(chat_template_kwargs, str):
            try:
                chat_template_kwargs = json_repair.loads(chat_template_kwargs)
            except Exception:
                logging.warning("Ignore invalid sys.chat_template_kwargs: expected JSON object or JSON string object.")
                return {}

        if not isinstance(chat_template_kwargs, dict):
            logging.warning("Ignore invalid sys.chat_template_kwargs type: %s", type(chat_template_kwargs).__name__)
            return {}

        return {"chat_template_kwargs": chat_template_kwargs}

    async def add_memory(self, user: str, assist: str, func_name: str, params: dict, results: str, user_defined_prompt: "dict | None" = None):
        """
        Summarise a tool call and store the summary in the canvas memory.

        ``user_defined_prompt`` defaults to a fresh empty dict per call, so
        no dict instance is shared between calls.
        """
        if user_defined_prompt is None:
            user_defined_prompt = {}
        summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
        logging.info("[MEMORY]: %s", summ)
        self._canvas.add_memory(user, assist, summ)

    def thoughts(self) -> str:
        """
        Return a short progress message that quotes the last prompt message.

        Note that this calls ``_prepare_prompt_variables`` and therefore has
        the same side effects as preparing a real invocation.
        """
        _, msg, _ = self._prepare_prompt_variables()
        return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."