2024-08-15 09:17:36 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
2024-09-03 19:49:14 +08:00
import binascii
2025-03-24 13:18:47 +08:00
import logging
2024-08-15 09:17:36 +08:00
import re
2025-03-24 13:18:47 +08:00
import time
2024-08-15 09:17:36 +08:00
from copy import deepcopy
2025-05-19 19:34:05 +08:00
from datetime import datetime
2025-03-24 13:18:47 +08:00
from functools import partial
2024-09-09 12:08:50 +08:00
from timeit import default_timer as timer
2025-08-19 17:25:44 +08:00
import trio
2025-03-24 13:18:47 +08:00
from langfuse import Langfuse
2025-08-06 10:33:52 +08:00
from peewee import fn
2025-02-26 15:40:52 +08:00
from agentic_reasoning import DeepResearcher
2025-11-05 08:01:39 +08:00
from common . constants import LLMType , ParserType , StatusEnum
2025-03-24 13:18:47 +08:00
from api . db . db_models import DB , Dialog
2024-08-15 09:17:36 +08:00
from api . db . services . common_service import CommonService
2025-08-12 14:12:56 +08:00
from api . db . services . document_service import DocumentService
2024-08-15 09:17:36 +08:00
from api . db . services . knowledgebase_service import KnowledgebaseService
2025-03-24 13:18:47 +08:00
from api . db . services . langfuse_service import TenantLangfuseService
2025-08-13 16:41:01 +08:00
from api . db . services . llm_service import LLMBundle
from api . db . services . tenant_llm_service import TenantLLMService
2025-10-28 19:09:14 +08:00
from common . time_utils import current_timestamp , datetime_format
2025-08-19 17:25:44 +08:00
from graphrag . general . mind_map_extractor import MindMapExtractor
2024-08-15 09:17:36 +08:00
from rag . app . resume import forbidden_select_fields4resume
2025-02-26 15:40:52 +08:00
from rag . app . tag import label_question
2024-08-15 09:17:36 +08:00
from rag . nlp . search import index_name
2025-09-23 10:19:25 +08:00
from rag . prompts . generator import chunks_format , citation_prompt , cross_languages , full_question , kb_prompt , keyword_extraction , message_fit_in , \
gen_meta_filter , PROMPT_JINJA_ENV , ASK_SUMMARY
2025-11-03 08:50:05 +08:00
from common . token_utils import num_tokens_from_string
2025-02-26 10:21:04 +08:00
from rag . utils . tavily_conn import Tavily
2025-10-28 09:46:32 +08:00
from common . string_utils import remove_redundant_spaces
2025-11-06 09:36:38 +08:00
from common import settings
2024-08-15 09:17:36 +08:00
class DialogService(CommonService):
    """Service layer for ``Dialog`` records (chat-assistant configurations)."""

    model = Dialog

    @classmethod
    def save(cls, **kwargs):
        """Insert a new dialog record.

        Forces an INSERT (never an UPDATE), so ``kwargs`` must contain a
        fresh primary key.

        Args:
            **kwargs: Field values for the new record.

        Returns:
            Whatever peewee's ``Model.save`` returns (the number of rows
            modified), NOT the model instance.
        """
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj

    @classmethod
    def update_many_by_id(cls, data_list):
        """Update multiple dialog rows atomically, one UPDATE per entry.

        Refreshes ``update_time``/``update_date`` on every entry before
        writing. NOTE: the dicts in ``data_list`` are mutated in place.

        Args:
            data_list (list): Dicts of field values; each must contain ``id``.
        """
        with DB.atomic():
            for data in data_list:
                data["update_time"] = current_timestamp()
                data["update_date"] = datetime_format(datetime.now())
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
        """Page through a tenant's VALID dialogs, optionally filtered by id/name.

        Returns:
            list[dict]: One page of dialog rows.
        """
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())
        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None):
        """List VALID dialogs visible to *user_id* across its joined tenants.

        Args:
            joined_tenant_ids: Tenant ids the user has joined.
            user_id: The user's own tenant id.
            page_number / items_per_page: Pagination; when either is falsy the
                full result set is returned.
            orderby / desc: Sort column name and direction.
            keywords: Case-insensitive substring filter on the dialog name.
            parser_id: Optional parser filter.

        Returns:
            tuple[list[dict], int]: Rows (including owner nickname and avatar
            aliased as ``tenant_avatar``) and the total matching count.
        """
        from api.db.db_models import User

        fields = [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.language,
            cls.model.llm_id,
            cls.model.llm_setting,
            cls.model.prompt_type,
            cls.model.prompt_config,
            cls.model.similarity_threshold,
            cls.model.vector_similarity_weight,
            cls.model.top_n,
            cls.model.top_k,
            cls.model.do_refer,
            cls.model.rerank_id,
            cls.model.kb_ids,
            cls.model.icon,
            cls.model.status,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
            cls.model.update_time,
            cls.model.create_time,
        ]
        # Build the common visibility filter once; layer the optional keyword
        # filter on top instead of duplicating the whole query expression.
        dialogs = (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(
                (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
            )
        )
        if keywords:
            dialogs = dialogs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        if parser_id:
            dialogs = dialogs.where(cls.model.parser_id == parser_id)
        if desc:
            dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc())
        else:
            dialogs = dialogs.order_by(cls.model.getter_by(orderby).asc())

        count = dialogs.count()
        if page_number and items_per_page:
            dialogs = dialogs.paginate(page_number, items_per_page)
        return list(dialogs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_all_dialogs_by_tenant_id(cls, tenant_id):
        """Fetch the ids of ALL dialogs owned by *tenant_id*, in batches of 100.

        Returns:
            list[dict]: ``{"id": ...}`` rows ordered by creation time.
        """
        fields = [cls.model.id]
        dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
        # BUGFIX: peewee queries are immutable — the previous code discarded
        # the result of order_by(), so rows came back in unspecified order.
        dialogs = dialogs.order_by(cls.model.create_time.asc())
        offset, limit = 0, 100
        res = []
        while True:
            d_batch = dialogs.offset(offset).limit(limit)
            _temp = list(d_batch.dicts())
            if not _temp:
                break
            res.extend(_temp)
            offset += limit
        return res
2025-08-06 10:33:52 +08:00
2025-02-21 12:24:02 +08:00
def chat_solo(dialog, messages, stream=True):
    """Chat directly with the LLM — no retrieval, no references.

    Generator. Yields dicts of the form
    ``{"answer", "reference", "audio_binary", "prompt", "created_at"}``;
    ``reference`` is always empty here since no knowledge base is consulted.

    Args:
        dialog: Dialog configuration row (tenant_id, llm_id, prompt_config,
            llm_setting).
        messages: Chat history; "system" entries are dropped and the dialog's
            own configured system prompt is used instead.
        stream: When True, yield incremental partial answers.
    """
    # Vision-capable models are bound through the IMAGE2TEXT type.
    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # Strip stale citation markers (##n$$) from the history before sending.
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        delta_ans = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            # The stream yields cumulative text; emit only once the pending
            # delta is at least ~16 tokens to avoid chatty tiny updates.
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
            delta_ans = ""
        # Flush whatever remained below the 16-token threshold.
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
2025-06-05 13:00:43 +08:00
def get_models(dialog):
    """Resolve every model bundle this dialog needs.

    Returns:
        tuple: ``(kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl)`` — any model
        that is not configured for the dialog is ``None``.

    Raises:
        Exception: when the linked knowledge bases disagree on the embedding model.
        LookupError: when the shared embedding model cannot be loaded.
    """
    embd_mdl = rerank_mdl = tts_mdl = None

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_ids = list({kb.embd_id for kb in kbs})
    if len(embedding_ids) > 1:
        raise Exception("**ERROR**: Knowledge bases use different embedding models.")
    if embedding_ids:
        embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_ids[0])
        if not embd_mdl:
            raise LookupError("Embedding model(%s) not found" % embedding_ids[0])

    # Vision-capable chat models are bound through the IMAGE2TEXT type.
    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_type = LLMType.IMAGE2TEXT
    else:
        chat_type = LLMType.CHAT
    chat_mdl = LLMBundle(dialog.tenant_id, chat_type, dialog.llm_id)

    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
    if dialog.prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
2025-05-29 10:03:51 +08:00
# Malformed citation markers occasionally emitted by LLMs, each capturing the
# referenced chunk index as group 1.
BAD_CITATION_PATTERNS = [
    re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"),  # (ID: 12)
    re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"),  # [ID: 12]
    re.compile(r"【\s*ID\s*[: ]*\s*(\d+)\s*】"),  # 【ID: 12】
    re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE),  # ref12、REF 12
]


def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    """Normalize malformed citation markers in *answer* to canonical ``[ID:n]``.

    Every marker matching one of ``BAD_CITATION_PATTERNS`` whose index refers
    to an existing chunk in ``kbinfos["chunks"]`` is rewritten and its index
    added to *idx* (mutated in place). Out-of-range markers are left verbatim.

    Returns:
        tuple: the repaired answer string and the updated index set.
    """
    chunk_count = len(kbinfos["chunks"])

    def _rewrite(match):
        # Rewrite one marker; keep it untouched on any failure or when the
        # referenced chunk does not exist.
        try:
            chunk_i = int(match.group(1))
        except Exception:
            return match.group(0)
        if 0 <= chunk_i < chunk_count:
            idx.add(chunk_i)
            return f"[ID:{chunk_i}]"
        return match.group(0)

    for pattern in BAD_CITATION_PATTERNS:
        answer = pattern.sub(_rewrite, answer)
    return answer, idx
2025-05-29 10:03:51 +08:00
2025-09-05 19:26:15 +08:00
def convert_conditions(metadata_condition):
    """Translate UI metadata-condition dicts into retriever filter dicts.

    The comparison operators ``"is"``/``"not is"`` are mapped onto ``"="``/
    ``"≠"``; any other operator string passes through unchanged. ``None``
    input yields an empty filter list.

    Returns:
        list[dict]: filters shaped ``{"op", "key", "value"}``.
    """
    if metadata_condition is None:
        metadata_condition = {}

    translated = []
    for cond in metadata_condition.get("conditions", []):
        operator = cond["comparison_operator"]
        if operator == "is":
            operator = "="
        elif operator == "not is":
            operator = "≠"
        translated.append({"op": operator, "key": cond["name"], "value": cond["value"]})
    return translated
2025-09-05 19:26:15 +08:00
2025-08-12 14:12:56 +08:00
def meta_filter(metas: dict, filters: list[dict]):
    """Return doc ids whose metadata satisfies ALL filters (set intersection).

    Args:
        metas: Mapping of metadata key -> {metadata value -> [doc_id, ...]}.
        filters: Conditions shaped ``{"key", "op", "value"}`` (see
            ``convert_conditions`` for the operator vocabulary).

    Returns:
        list: Doc ids matching every applicable filter; empty as soon as the
        intersection becomes empty. Filters whose key never appears in
        *metas* are ignored.
    """
    doc_ids = set([])

    def _match(operator, doc_value, filter_value):
        # Evaluate ONE metadata value against ONE condition.
        # For relational operators, compare numerically when both sides parse
        # as floats, otherwise compare their string forms.
        # BUGFIX: each value is normalized independently — previously the
        # float-converted filter value leaked into later iterations, so a
        # non-numeric entry was compared against e.g. "2.0" instead of "2".
        if operator in ("=", "≠", ">", "<", "≥", "≤"):
            try:
                left, right = float(doc_value), float(filter_value)
            except (TypeError, ValueError):
                left, right = str(doc_value), str(filter_value)
        else:
            left, right = doc_value, filter_value
        # BUGFIX: evaluate only the selected operator's predicate, inside the
        # try — previously every comparison was evaluated eagerly outside the
        # try, so a non-string metadata key crashed even for "contains".
        try:
            if operator == "contains":
                return str(right).lower() in str(left).lower()
            if operator == "not contains":
                return str(right).lower() not in str(left).lower()
            if operator == "start with":
                return str(left).lower().startswith(str(right).lower())
            if operator == "end with":
                return str(left).lower().endswith(str(right).lower())
            if operator == "empty":
                return not left
            if operator == "not empty":
                return bool(left)
            if operator == "=":
                return left == right
            if operator == "≠":
                return left != right
            if operator == ">":
                return left > right
            if operator == "<":
                return left < right
            if operator == "≥":
                return left >= right
            if operator == "≤":
                return left <= right
        except Exception:
            # Incomparable types count as "no match" instead of crashing.
            return False
        return False  # unknown operator matches nothing

    def filter_out(v2docs, operator, value):
        # Collect doc ids of every metadata value satisfying the condition.
        ids = []
        for doc_value, docids in v2docs.items():
            if _match(operator, doc_value, value):
                ids.extend(docids)
        return ids

    for k, v2docs in metas.items():
        for f in filters:
            if k != f["key"]:
                continue
            ids = filter_out(v2docs, f["op"], f["value"])
            if not doc_ids:
                doc_ids = set(ids)
            else:
                doc_ids = doc_ids & set(ids)
            if not doc_ids:
                return []
    return list(doc_ids)
2025-08-12 14:12:56 +08:00
2024-08-15 09:17:36 +08:00
def chat(dialog, messages, stream=True, **kwargs):
    """Answer the latest user message with retrieval-augmented generation.

    Generator. Yields dicts shaped ``{"answer", "reference", ...}``:
    intermediate partial answers while streaming, then one final answer
    decorated with citations, references and timing diagnostics.

    Args:
        dialog: Dialog configuration row (kb_ids, llm_id, prompt_config,
            llm_setting, similarity/rerank settings, meta_data_filter, ...).
        messages: Chat history; the last entry must have role "user".
        stream: When True, yield incremental partial answers.
        **kwargs: Prompt variables declared in ``prompt_config["parameters"]``
            plus optional "doc_ids" (comma-separated), "toolcall_session",
            "tools" and "quote".
    """
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."

    # No knowledge base and no web search configured: plain LLM chat.
    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return None

    chat_start_ts = timer()

    # Resolve the model config to learn the context window size (default 8192).
    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    # Optional Langfuse tracing — only enabled when the tenant's keys validate.
    langfuse_tracer = None
    trace_context = {}
    langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
    if langfuse_keys:
        langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
        if langfuse.auth_check():
            langfuse_tracer = langfuse
            trace_id = langfuse_tracer.create_trace_id()
            trace_context = {"trace_id": trace_id}

    check_langfuse_tracer_ts = timer()

    kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
    if toolcall_session and tools:
        chat_mdl.bind_tools(toolcall_session, tools)
    bind_models_ts = timer()

    retriever = settings.retriever

    # Use at most the last 3 user turns as the retrieval query basis.
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
        if ans:
            yield ans
            return None

    # Validate prompt variables; blank out optional ones that were not supplied.
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

    # Multi-turn refinement: condense the history into one standalone question.
    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    if prompt_config.get("cross_languages"):
        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

    # Metadata filtering narrows retrieval to matching doc ids; an empty match
    # sets attachments to None, which skips retrieval entirely below.
    if dialog.meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
        if dialog.meta_data_filter.get("method") == "auto":
            filters = gen_meta_filter(chat_mdl, metas, questions[-1])
            attachments.extend(meta_filter(metas, filters))
            if not attachments:
                attachments = None
        elif dialog.meta_data_filter.get("method") == "manual":
            attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
            if not attachments:
                attachments = None

    if prompt_config.get("keyword", False):
        questions[-1] += keyword_extraction(chat_mdl, questions[-1])

    refine_question_ts = timer()

    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
    knowledges = []

    if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        knowledges = []
        if prompt_config.get("reasoning", False):
            # Deep-research mode: the reasoner drives retrieval itself and
            # streams its intermediate "thinking" steps to the client.
            reasoner = DeepResearcher(
                chat_mdl,
                prompt_config,
                partial(
                    retriever.retrieval,
                    embd_mdl=embd_mdl,
                    tenant_ids=tenant_ids,
                    kb_ids=dialog.kb_ids,
                    page=1,
                    page_size=dialog.top_n,
                    similarity_threshold=0.2,
                    vector_similarity_weight=0.3,
                    doc_ids=attachments,
                ),
            )

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    # Final thought: its non-empty lines become the knowledge.
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            # Standard retrieval path: vector search, then optional TOC
            # enhancement, Tavily web search and knowledge-graph results.
            if embd_mdl:
                kbinfos = retriever.retrieval(
                    " ".join(questions),
                    embd_mdl,
                    tenant_ids,
                    dialog.kb_ids,
                    1,
                    dialog.top_n,
                    dialog.similarity_threshold,
                    dialog.vector_similarity_weight,
                    doc_ids=attachments,
                    top=dialog.top_k,
                    aggs=False,
                    rerank_mdl=rerank_mdl,
                    rank_feature=label_question(" ".join(questions), kbs),
                )
                if prompt_config.get("toc_enhance"):
                    cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
                    if cks:
                        cks = cks if isinstance(cks, list) else cks
                        kbinfos["chunks"] = cks
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
                                                     LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    # Knowledge-graph context is prepended so it ranks first.
                    kbinfos["chunks"].insert(0, ck)
            knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        # Nothing retrieved: reply with the configured canned response.
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
               "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    # Strip any stale citation markers (##n$$) from the history.
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        # Leave room in the context window for what the prompt already uses.
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

    def decorate_answer(answer):
        """Attach citations, references and timing diagnostics to the final answer."""
        nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer

        refs = []
        # Split off a leading <think>...</think> reasoning section, if any.
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            idx = set([])
            # If the LLM emitted no [ID:n] markers itself, infer citations by
            # similarity; otherwise collect the chunk indices it referenced.
            if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
                answer, idx = retriever.insert_citations(
                    answer,
                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
                    [ck["vector"] for ck in kbinfos["chunks"]],
                    embd_mdl,
                    tkweight=1 - dialog.vector_similarity_weight,
                    vtweight=dialog.vector_similarity_weight,
                )
            else:
                for match in re.finditer(r"\[ID:([0-9]+)\]", answer):
                    i = int(match.group(1))
                    if i < len(kbinfos["chunks"]):
                        idx.add(i)

            # Normalize malformed citation markers like (ID: 3) or ref3.
            answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)

            # Keep only the documents actually cited; fall back to all of them.
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs
            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
        finish_chat_ts = timer()

        # Per-stage wall-clock costs in milliseconds, appended to the prompt
        # for diagnostics display.
        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
        retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        tk_num = num_tokens_from_string(think + answer)
        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = (
            f"{prompt}\n\n"
            "## Time elapsed:\n"
            f"  - Total: {total_time_cost:.1f}ms\n"
            f"  - Check LLM: {check_llm_time_cost:.1f}ms\n"
            f"  - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
            f"  - Bind models: {bind_embedding_time_cost:.1f}ms\n"
            f"  - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
            f"  - Retrieval: {retrieval_time_cost:.1f}ms\n"
            f"  - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
            "## Token usage:\n"
            f"  - Generated tokens(approximately): {tk_num}\n"
            f"  - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
        )

        # Close the Langfuse generation only when tracing is active AND the
        # generation was actually started below (checked via locals()).
        if langfuse_tracer and "langfuse_generation" in locals():
            langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
            langfuse_output = {"time_elapsed:": re.sub(r"\n", "  \n", langfuse_output), "created_at": time.time()}
            langfuse_generation.update(output=langfuse_output)
            langfuse_generation.end()

        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

    if langfuse_tracer:
        langfuse_generation = langfuse_tracer.start_generation(
            trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
            input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
        )

    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            if thought:
                # The reasoner already streamed the thinking; drop it here.
                ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            # The stream yields cumulative text; emit only once the pending
            # delta is at least ~16 tokens to avoid chatty tiny updates.
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        # Flush whatever remained below the 16-token threshold.
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res

    return None
2024-08-15 09:17:36 +08:00
2025-09-04 16:51:13 +08:00
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
    """Answer a question by generating and executing SQL over the tenant's table index.

    Asks the chat model to write SQL for `question` against the fields in
    `field_map`, executes it through the retriever, retries once on a database
    error, and renders the result as a Markdown table with per-row citation
    markers.

    Args:
        question: Natural-language question to translate into SQL.
        field_map: Mapping of column name -> human-readable field description.
        tenant_id: Tenant whose index (``index_name(tenant_id)``) is queried.
        chat_mdl: Chat LLM used to generate (and repair) the SQL.
        quota: Kept for interface compatibility; currently has no effect
            (both historical branches produced identical output).
        kb_ids: Optional knowledge-base ids used to append a ``kb_id`` filter.

    Returns:
        A dict with ``answer`` (Markdown table), ``reference`` and ``prompt``
        keys, or None when no valid SQL / no rows could be produced.
    """
    sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
"""
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
    tried_times = 0

    def get_table():
        # Generates SQL via the LLM, normalizes it, forces citation-friendly
        # columns (doc_id, docnm_kwd), applies the kb filter, and executes it.
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        # Strip any chain-of-thought block emitted by reasoning models.
        sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        # Drop everything after a terminator (ASCII/fullwidth semicolon or a code fence).
        sql = re.sub(r"([;；]|```).*", "", sql)
        sql = re.sub(r"&", "and", sql)
        if sql[: len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[: len("select *")] != "select *":
                # Prepend citation columns so rows can be traced back to documents.
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                # Expand "select *" into an explicit, bounded column list.
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
        if kb_ids:
            # NOTE(review): kb_ids are interpolated directly into the SQL string;
            # they come from internal services, but confirm they can never carry
            # user-controlled content before trusting this concatenation.
            kb_filter = "(" + " OR ".join([f"kb_id='{kb_id}'" for kb_id in kb_ids]) + ")"
            if "where" not in sql.lower():
                sql += f" WHERE {kb_filter}"
            else:
                sql += f" AND {kb_filter}"
        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retriever.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        # One repair round: feed the failing SQL and the DB error back to the LLM.
        user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.


The SQL error you provided last time is as follows:
{}

Error issued by database as follows:
{}

Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
        tbl, sql = get_table()
        # Fix: the retry can also yield (None, None); previously tbl.get() below
        # would raise AttributeError on None.
        if tbl is None:
            return None
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    # Fix: the header/separator conditions previously tested `docid_idx and docid_idx`
    # (a tautology on docid_idx); the "Source" column only makes sense when both
    # doc_id and docnm_kwd columns are present, matching the early return below.
    columns = (
        "|"
        + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx])
        + ("|Source|" if docid_idx and doc_name_idx else "|")
    )

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    # Drop rows that are entirely blank after formatting.
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    # Fix: the former `if quota:` had two byte-identical branches, so the
    # conditional was dead; a single statement preserves behavior exactly.
    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    # Strip time-of-day suffixes from ISO timestamps so cells show dates only.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
2024-09-03 19:49:14 +08:00
def tts(tts_mdl, text):
    """Synthesize `text` with `tts_mdl` and return the audio as a hex string.

    Args:
        tts_mdl: TTS model exposing ``tts(text)`` that yields audio byte chunks.
        text: Text to synthesize.

    Returns:
        Hex-encoded audio string, or None when no model is configured or the
        text is empty.
    """
    if not tts_mdl or not text:
        return None
    # b"".join avoids the quadratic bytes-+= pattern and no longer shadows
    # the builtin `bin`.
    audio = b"".join(chunk for chunk in tts_mdl.tts(text))
    return binascii.hexlify(audio).decode("utf-8")
2025-08-19 09:33:33 +08:00
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config=None):
    """Stream an answer to `question` grounded in the given knowledge bases.

    Retrieves relevant chunks (optionally filtered by document metadata and
    re-ranked), streams the LLM answer, then yields one final message with
    citations inserted and references attached.

    Args:
        question: User question.
        kb_ids: Knowledge-base ids to search (may be overridden by
            ``search_config["kb_ids"]``).
        tenant_id: Tenant owning the models.
        chat_llm_name: Chat model name (overridden by ``search_config["chat_id"]``).
        search_config: Optional dict of retrieval settings (doc_ids, rerank_id,
            meta_data_filter, similarity_threshold, vector_similarity_weight,
            top_k, ...).

    Yields:
        Dicts of the form ``{"answer": ..., "reference": ...}``; the last one
        carries the citation-decorated answer and full references.
    """
    # Fix: avoid the mutable default argument ({}) shared across calls.
    if search_config is None:
        search_config = {}
    # Fix: copy so extending with meta-filter hits below does not mutate the
    # caller's search_config["doc_ids"] list in place.
    doc_ids = list(search_config.get("doc_ids", []))
    rerank_mdl = None
    kb_ids = search_config.get("kb_ids", kb_ids)
    chat_llm_name = search_config.get("chat_id", chat_llm_name)
    rerank_id = search_config.get("rerank_id", "")
    meta_data_filter = search_config.get("meta_data_filter")

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    # NOTE(review): assumes at least one KB exists; embedding_list[0] raises
    # IndexError otherwise — confirm callers validate kb_ids.
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    # Use the knowledge-graph retriever only when every KB is KG-parsed.
    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
    if rerank_id:
        rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))

    if meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        if meta_data_filter.get("method") == "auto":
            # LLM derives metadata filters from the question itself.
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
            if not doc_ids:
                # None disables doc filtering entirely (vs. [] matching nothing).
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
            if not doc_ids:
                doc_ids = None

    kbinfos = retriever.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
        page=1,
        page_size=12,
        similarity_threshold=search_config.get("similarity_threshold", 0.1),
        vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
        top=search_config.get("top_k", 1024),
        doc_ids=doc_ids,
        aggs=False,
        rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs)
    )
    knowledges = kb_prompt(kbinfos, max_tokens)
    sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges))
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        # Inserts citation markers and restricts doc_aggs to cited documents.
        nonlocal knowledges, kbinfos, sys_prompt
        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
                                                 embd_mdl, tkweight=0.7, vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                # Vectors are large and useless to the client; strip them.
                del c["vector"]
        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
2025-08-19 17:25:44 +08:00
def gen_mindmap(question, kb_ids, tenant_id, search_config=None):
    """Build a mind map from chunks retrieved for `question` in the given KBs.

    Args:
        question: Topic/question driving retrieval.
        kb_ids: Knowledge-base ids to search.
        tenant_id: Tenant owning the models.
        search_config: Optional dict of retrieval settings (doc_ids, chat_id,
            rerank_id, meta_data_filter, similarity_threshold,
            vector_similarity_weight, top_k, ...).

    Returns:
        The mind-map extractor output, or ``{"error": "No KB selected"}`` when
        no knowledge base matches `kb_ids`.
    """
    # Fix: avoid the mutable default argument ({}) shared across calls.
    if search_config is None:
        search_config = {}
    meta_data_filter = search_config.get("meta_data_filter", {})
    # Fix: copy so the extend() calls below don't mutate the caller's list.
    doc_ids = list(search_config.get("doc_ids", []))
    rerank_id = search_config.get("rerank_id", "")
    rerank_mdl = None
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    if not kbs:
        return {"error": "No KB selected"}
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
    if rerank_id:
        rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
    if meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        if meta_data_filter.get("method") == "auto":
            # LLM derives metadata filters from the question itself.
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
            if not doc_ids:
                # None disables doc filtering entirely (vs. [] matching nothing).
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
            if not doc_ids:
                doc_ids = None
    ranks = settings.retriever.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
        page=1,
        page_size=12,
        similarity_threshold=search_config.get("similarity_threshold", 0.2),
        vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
        top=search_config.get("top_k", 1024),
        doc_ids=doc_ids,
        aggs=False,
        rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs),
    )
    mindmap = MindMapExtractor(chat_mdl)
    # MindMapExtractor is async; trio.run drives it to completion synchronously.
    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
    return mind_map.output