#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
from copy import deepcopy
from functools import partial
from typing import Any, Generator

import json_repair

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt


class LLMParam(ComponentParamBase):
    """
    Define the LLM component parameters.
    """

    def __init__(self):
        super().__init__()
        self.llm_id = ""
        self.sys_prompt = ""
        self.prompts = [{"role": "user", "content": "{sys.query}"}]
        self.max_tokens = 0
        self.temperature = 0
        self.top_p = 0
        self.presence_penalty = 0
        self.frequency_penalty = 0
        self.output_structure = None
        self.cite = True
        self.visual_files_var = None

    def check(self):
        self.check_decimal_float(float(self.temperature), "[Agent] Temperature")
        self.check_decimal_float(float(self.presence_penalty), "[Agent] Presence penalty")
        self.check_decimal_float(float(self.frequency_penalty), "[Agent] Frequency penalty")
        self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
        self.check_decimal_float(float(self.top_p), "[Agent] Top P")
        self.check_empty(self.llm_id, "[Agent] LLM")
        self.check_empty(self.sys_prompt, "[Agent] System prompt")
        self.check_empty(self.prompts, "[Agent] User prompt")

    def gen_conf(self):
        conf = {}

        def get_attr(nm):
            try:
                return getattr(self, nm)
            except Exception:
                pass
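
        # Forward a sampling parameter only when it is set to a positive
        # value and its corresponding UI toggle (e.g. "temperatureEnabled")
        # is present and truthy; get_attr() returns None for missing toggles.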
        if int(self.max_tokens) > 0 and get_attr("maxTokensEnabled"):
            conf["max_tokens"] = int(self.max_tokens)
        if float(self.temperature) > 0 and get_attr("temperatureEnabled"):
            conf["temperature"] = float(self.temperature)
        if float(self.top_p) > 0 and get_attr("topPEnabled"):
            conf["top_p"] = float(self.top_p)
        if float(self.presence_penalty) > 0 and get_attr("presencePenaltyEnabled"):
            conf["presence_penalty"] = float(self.presence_penalty)
        if float(self.frequency_penalty) > 0 and get_attr("frequencyPenaltyEnabled"):
            conf["frequency_penalty"] = float(self.frequency_penalty)
        return conf


class LLM(ComponentBase):
    component_name = "LLM"

    def __init__(self, canvas, id, param: ComponentParamBase):
        super().__init__(canvas, id, param)
        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
                                  self._param.llm_id, max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error
                                  )
        self.imgs = []

    def get_input_form(self) -> dict[str, dict]:
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        return res

    def get_input_elements(self) -> dict[str, Any]:
        res = self.get_input_elements_from_text(self._param.sys_prompt)
        for prompt in self._param.prompts:
            d = self.get_input_elements_from_text(prompt["content"])
            res.update(d)
        return res

    def set_debug_inputs(self, inputs: dict[str, dict]):
        self._param.debug_inputs = inputs

    def add2system_prompt(self, txt):
        self._param.sys_prompt += txt

    def _prepare_prompt_variables(self):
        if self._param.visual_files_var:
            self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
            if not self.imgs:
                self.imgs = []
            self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"]
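
        # A chat-only model cannot accept image input; switch to the
        # image-to-text bundle for the same llm_id when images are present.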
        if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
            self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
                                      self._param.llm_id, max_retries=self._param.max_retries,
                                      retry_interval=self._param.delay_after_error
                                      )

        args = {}
        vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
        sys_prompt = self._param.sys_prompt
        for k, o in vars.items():
            args[k] = o["value"]
            if not isinstance(args[k], str):
                try:
                    args[k] = json.dumps(args[k], ensure_ascii=False)
                except Exception:
                    args[k] = str(args[k])
            self.set_input_value(k, args[k])

        msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
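        # Append the configured prompts, skipping any whose role would
        # duplicate the role of the last message already in the history.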
        for p in self._param.prompts:
            if msg and msg[-1]["role"] == p["role"]:
                continue
            msg.append(deepcopy(p))

        sys_prompt = self.string_format(sys_prompt, args)
        user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
        for m in msg:
            m["content"] = self.string_format(m["content"], args)

        if self._param.cite and self._canvas.get_reference()["chunks"]:
            sys_prompt += citation_prompt(user_defined_prompt)

        return sys_prompt, msg, user_defined_prompt

    def _extract_prompts(self, sys_prompt):
        # Pull user-defined sections such as <REFLECTION>...</REFLECTION> out
        # of the system prompt, keyed by lowercase tag name, stripping the
        # tagged spans from the remaining prompt.
        pts = {}
        for tag in ["TASK_ANALYSIS", "PLAN_GENERATION", "REFLECTION", "CONTEXT_SUMMARY", "CONTEXT_RANKING", "CITATION_GUIDELINES"]:
            r = re.search(rf"<{tag}>(.*?)</{tag}>", sys_prompt, flags=re.DOTALL | re.IGNORECASE)
            if not r:
                continue
            pts[tag.lower()] = r.group(1)
            sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", "", sys_prompt, flags=re.DOTALL | re.IGNORECASE)
        return pts, sys_prompt

    def _generate(self, msg: list[dict], **kwargs) -> str:
        if not self.imgs:
            return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
        return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)

    def _generate_streamly(self, msg: list[dict], **kwargs) -> Generator[str, None, None]:
        ans = ""
        last_idx = 0
        endswith_think = False
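
        # chat_streamly yields the accumulated text on every step; delta()
        # turns that into the incremental chunk since the previous yield and
        # normalizes <think>/</think> markers so each is emitted exactly once.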
        def delta(txt):
            nonlocal ans, last_idx, endswith_think
            delta_ans = txt[last_idx:]
            ans = txt
            think_at = delta_ans.find("<think>")
            if think_at == 0:
                last_idx += len("<think>")
                return "<think>"
            elif think_at > 0:
                # Flush the text preceding the tag; the tag itself is handled
                # on the next call. (The tag offset is captured in think_at
                # before delta_ans is reassigned.)
                delta_ans = txt[last_idx: last_idx + think_at]
                last_idx += think_at
                return delta_ans
            elif delta_ans.endswith("</think>"):
                endswith_think = True
            elif endswith_think:
                endswith_think = False
                return "</think>"
            last_idx = len(ans)
            if ans.endswith("</think>"):
                last_idx -= len("</think>")
            return re.sub(r"(<think>|</think>)", "", delta_ans)

        if not self.imgs:
            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs):
                yield delta(txt)
        else:
            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
                yield delta(txt)

    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
    def _invoke(self, **kwargs):
        def clean_formatted_answer(ans: str) -> str:
            # Strip <think> reasoning and any markdown code fence around the JSON.
            ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
            ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
            return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

        prompt, msg, _ = self._prepare_prompt_variables()
        error = ""

        if self._param.output_structure:
            prompt += "\nThe output MUST follow this JSON format:\n" + json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
            prompt += "\nRedundant information is FORBIDDEN."
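            # Retry until the model returns parseable JSON or retries run out.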
            for _ in range(self._param.max_retries + 1):
                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
                error = ""
                ans = self._generate(msg)
                msg.pop(0)
                if ans.find("**ERROR**") >= 0:
                    logging.error(f"LLM response error: {ans}")
                    error = ans
                    continue
                try:
                    self.set_output("structured_content", json_repair.loads(clean_formatted_answer(ans)))
                    return
                except Exception:
                    msg.append({"role": "user", "content": "The answer can't be parsed as JSON"})
                    error = "The answer can't be parsed as JSON"
            if error:
                self.set_output("_ERROR", error)
            return

        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
        ex = self.exception_handler()
        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
            self.set_output("content", partial(self._stream_output, prompt, msg))
            return
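
        # No streaming consumer downstream: generate in one shot with the
        # same retry policy as the structured branch.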
        for _ in range(self._param.max_retries + 1):
            _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
            error = ""
            ans = self._generate(msg)
            msg.pop(0)
            if ans.find("**ERROR**") >= 0:
                logging.error(f"LLM response error: {ans}")
                error = ans
                continue
            self.set_output("content", ans)
            break
        if error:
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", error)

    def _stream_output(self, prompt, msg):
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer = ""
        for ans in self._generate_streamly(msg):
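            # An **ERROR** marker aborts the stream; fall back to the
            # configured default value if one is set.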
            if ans.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", ans)
                return
            yield ans
            answer += ans
        self.set_output("content", answer)

    def add_memory(self, user: str, assist: str, func_name: str, params: dict, results: str, user_defined_prompt: dict = {}):
        # Summarize the tool call with the LLM and store it in canvas memory.
        summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
        logging.info(f"[MEMORY]: {summ}")
        self._canvas.add_memory(user, assist, summ)

    def thoughts(self) -> str:
        _, msg, _ = self._prepare_prompt_variables()
        return "⌛Give me a moment—starting from:\n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."