#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from functools import partial
import json
import os
import re
from abc import ABC

from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from common.constants import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter


class RetrievalParam(ToolParamBase):
    """
    Define the Retrieval component parameters.
    """

    def __init__(self):
        self.meta: ToolMeta = {
            "name": "search_my_dateset",
            "description": "This tool can be utilized for relevant content searching in the datasets.",
            "parameters": {
                "query": {
                    "type": "string",
                    "description": "The keywords to search the dataset. The keywords should be the most important words/terms (including synonyms) from the original request.",
                    "default": "",
                    "required": True
                }
            }
        }
        super().__init__()
        self.function_name = "search_my_dateset"
        self.description = "This tool can be utilized for relevant content searching in the datasets."
        self.similarity_threshold = 0.2
        self.keywords_similarity_weight = 0.5
        self.top_n = 8
        self.top_k = 1024
        self.kb_ids = []
        self.kb_vars = []
        self.rerank_id = ""
        self.empty_response = ""
        self.use_kg = False
        self.cross_languages = []
        self.toc_enhance = False
        self.meta_data_filter = {}
    def check(self):
        self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
        self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keyword similarity weight")
        self.check_positive_number(self.top_n, "[Retrieval] Top N")

    def get_input_form(self) -> dict[str, dict]:
        return {
            "query": {
                "name": "Query",
                "type": "line"
            }
        }


class Retrieval(ToolBase, ABC):
    component_name = "Retrieval"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
    def _invoke(self, **kwargs):
        if not kwargs.get("query"):
            self.set_output("formalized_content", self._param.empty_response)
            return

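        # Dataset ids may be given literally or as canvas variable references
        # (ids containing "@"); resolve references to concrete dataset ids.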
        kb_ids: list[str] = []
        for kb_id in self._param.kb_ids:
            if kb_id.find("@") < 0:
                kb_ids.append(kb_id)
                continue
            kb_nm = self._canvas.get_variable_value(kb_id)
            # The variable may hold a single name/id or a list of them.
            kb_nm_list = kb_nm if isinstance(kb_nm, list) else [kb_nm]
            for nm_or_id in kb_nm_list:
                e, kb = KnowledgebaseService.get_by_name(nm_or_id, self._canvas._tenant_id)
                if not e:
                    e, kb = KnowledgebaseService.get_by_id(nm_or_id)
                if not e:
                    raise Exception(f"Dataset({nm_or_id}) does not exist.")
                kb_ids.append(kb.id)
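        # All selected datasets must share one embedding model, otherwise their
        # vectors live in incompatible spaces and similarity scores are meaningless.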
        filtered_kb_ids: list[str] = list(set([kb_id for kb_id in kb_ids if kb_id]))
        kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
        if not kbs:
            raise Exception("No dataset is selected.")

        embd_nms = list(set([kb.embd_id for kb in kbs]))
        assert len(embd_nms) == 1, "Knowledge bases use different embedding models."

        embd_mdl = None
        if embd_nms:
            embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0])

        rerank_mdl = None
        if self._param.rerank_id:
            rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
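        # Substitute canvas variable values into the query template before searching.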
        vars = self.get_input_elements_from_text(kwargs["query"])
        vars = {k: o["value"] for k, o in vars.items()}
        query = self.string_format(kwargs["query"], vars)
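        # Optionally restrict retrieval to documents whose metadata matches a filter,
        # either generated by an LLM ("auto") or configured by the user ("manual").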
        doc_ids = []
        if self._param.meta_data_filter:
            metas = DocumentService.get_meta_by_kbs(kb_ids)
            if self._param.meta_data_filter.get("method") == "auto":
                chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
                filters = gen_meta_filter(chat_mdl, metas, query)
                doc_ids.extend(meta_filter(metas, filters))
                if not doc_ids:
                    doc_ids = None
            elif self._param.meta_data_filter.get("method") == "manual":
                filters = self._param.meta_data_filter["manual"]
                # Matches "{component_id@variable}"- or "{sys.*}"-style references,
                # optionally wrapped in extra braces.
                pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*")
                for flt in filters:
                    s = flt["value"]
                    out_parts = []
                    last = 0
                    for m in pat.finditer(s):
                        out_parts.append(s[last:m.start()])
                        key = m.group(1)
                        v = self._canvas.get_variable_value(key)
                        if v is None:
                            rep = ""
                        elif isinstance(v, partial):
                            # Streamed values are produced lazily; drain the generator.
                            buf = []
                            for chunk in v():
                                buf.append(chunk)
                            rep = "".join(buf)
                        elif isinstance(v, str):
                            rep = v
                        else:
                            rep = json.dumps(v, ensure_ascii=False)
                        out_parts.append(rep)
                        last = m.end()
                    out_parts.append(s[last:])
                    flt["value"] = "".join(out_parts)
                doc_ids.extend(meta_filter(metas, filters))
                if not doc_ids:
                    doc_ids = None
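        # Optionally translate the query across languages, then strip any leading
        # "user:" role prefix before running hybrid (keyword + vector) retrieval.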
        if self._param.cross_languages:
            query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)

        if kbs:
            query = re.sub(r"^user[:：\s]*", "", query, flags=re.IGNORECASE)
            kbinfos = settings.retriever.retrieval(
                query,
                embd_mdl,
                [kb.tenant_id for kb in kbs],
                filtered_kb_ids,
                1,
                self._param.top_n,
                self._param.similarity_threshold,
                1 - self._param.keywords_similarity_weight,
                doc_ids=doc_ids,
                aggs=False,
                rerank_mdl=rerank_mdl,
                rank_feature=label_question(query, kbs),
            )
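            # TOC enhancement re-ranks the retrieved chunks against the documents'
            # tables of contents with a chat model.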
            if self._param.toc_enhance:
                chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
                cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
                if cks:
                    kbinfos["chunks"] = cks
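            # Knowledge-graph retrieval prepends a graph-derived chunk when one is available.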
            if self._param.use_kg:
                ck = settings.kg_retriever.retrieval(query,
                                                     [kb.tenant_id for kb in kbs],
                                                     kb_ids,
                                                     embd_mdl,
                                                     LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)
        else:
            kbinfos = {"chunks": [], "doc_aggs": []}
            if self._param.use_kg and kbs:
                ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    ck["content"] = ck["content_with_weight"]
                    del ck["content_with_weight"]
                    kbinfos["chunks"].insert(0, ck)
        for ck in kbinfos["chunks"]:
            if "vector" in ck:
                del ck["vector"]
            if "content_ltks" in ck:
                del ck["content_ltks"]

        if not kbinfos["chunks"]:
            self.set_output("formalized_content", self._param.empty_response)
            return

        # Format the chunks for JSON output (similar to how other tools do it)
        json_output = kbinfos["chunks"].copy()

        self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
        form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))

        # Set both formalized content and JSON output
        self.set_output("formalized_content", form_cnt)
        self.set_output("json", json_output)

        return form_cnt

    def thoughts(self) -> str:
        return """
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))