#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import binascii
import time
import re
from copy import deepcopy
from functools import partial
from timeit import default_timer as timer

from agentic_reasoning import DeepResearcher
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
        """List a tenant's valid dialogs, optionally filtered by id/name, ordered and paginated."""
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())
        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())
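

# A minimal usage sketch (comments only, not executed): fetch the first page of
# a tenant's dialogs ordered by creation time; the tenant id is a hypothetical
# placeholder.
#
#   dialogs = DialogService.get_list(
#       "tenant_0", page_number=1, items_per_page=20,
#       orderby="create_time", desc=True, id=None, name=None)
#   for d in dialogs:
#       print(d["id"], d["name"])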


def chat_solo(dialog, messages, stream=True):
    """Chat without a knowledge base: forward the history straight to the LLM."""
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # Strip citation markers (##N$$) from the history before resending it.
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
           for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            # Batch tiny deltas so TTS and the client are not hit for every token.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # Flush the tail that stayed below the 16-token threshold.
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."

    if not dialog.kb_ids:
        # No knowledge base is bound to this dialog: fall back to a plain LLM chat.
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        # Cross-KB retrieval is only meaningful when all KBs share one embedding space.
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # Try to answer with SQL if a field mapping is available (e.g. resume KBs).
    if field_map:
        logging.debug("Use SQL to retrieve: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Missing parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        # Condense the multi-turn history into one self-contained question.
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()
2024-09-20 17:25:55 +08:00
2024-08-15 09:17:36 +08:00
rerank_mdl = None
if dialog . rerank_id :
rerank_mdl = LLMBundle ( dialog . tenant_id , LLMType . RERANK , dialog . rerank_id )
2024-12-19 18:13:33 +08:00
bind_reranker_ts = timer ( )
generate_keyword_ts = bind_reranker_ts
2025-02-20 17:41:01 +08:00
thought = " "
kbinfos = { " total " : 0 , " chunks " : [ ] , " doc_aggs " : [ ] }
2024-12-19 18:13:33 +08:00
2024-08-15 09:17:36 +08:00
if " knowledge " not in [ p [ " key " ] for p in prompt_config [ " parameters " ] ] :
2025-02-20 17:41:01 +08:00
knowledges = [ ]
2024-08-15 09:17:36 +08:00
else :
if prompt_config . get ( " keyword " , False ) :
questions [ - 1 ] + = keyword_extraction ( chat_mdl , questions [ - 1 ] )
2024-12-19 18:13:33 +08:00
generate_keyword_ts = timer ( )
2024-10-29 13:19:01 +08:00
tenant_ids = list ( set ( [ kb . tenant_id for kb in kbs ] ) )
2025-01-09 17:07:21 +08:00
2025-02-20 17:41:01 +08:00
knowledges = [ ]
if prompt_config . get ( " reasoning " , False ) :
2025-02-26 15:40:52 +08:00
reasoner = DeepResearcher ( chat_mdl ,
prompt_config ,
partial ( retriever . retrieval , embd_mdl = embd_mdl , tenant_ids = tenant_ids , kb_ids = dialog . kb_ids , page = 1 , page_size = dialog . top_n , similarity_threshold = 0.2 , vector_similarity_weight = 0.3 ) )
for think in reasoner . thinking ( kbinfos , " " . join ( questions ) ) :
2025-02-20 17:41:01 +08:00
if isinstance ( think , str ) :
thought = think
knowledges = [ t for t in think . split ( " \n " ) if t ]
else :
yield think
else :
kbinfos = retriever . retrieval ( " " . join ( questions ) , embd_mdl , tenant_ids , dialog . kb_ids , 1 , dialog . top_n ,
dialog . similarity_threshold ,
dialog . vector_similarity_weight ,
doc_ids = attachments ,
top = dialog . top_k , aggs = False , rerank_mdl = rerank_mdl ,
rank_feature = label_question ( " " . join ( questions ) , kbs )
)
2025-02-26 10:21:04 +08:00
if prompt_config . get ( " tavily_api_key " ) :
tav = Tavily ( prompt_config [ " tavily_api_key " ] )
tav_res = tav . retrieve_chunks ( " " . join ( questions ) )
kbinfos [ " chunks " ] . extend ( tav_res [ " chunks " ] )
kbinfos [ " doc_aggs " ] . extend ( tav_res [ " doc_aggs " ] )
2025-02-20 17:41:01 +08:00
if prompt_config . get ( " use_kg " ) :
ck = settings . kg_retrievaler . retrieval ( " " . join ( questions ) ,
tenant_ids ,
dialog . kb_ids ,
embd_mdl ,
LLMBundle ( dialog . tenant_id , LLMType . CHAT ) )
if ck [ " content_with_weight " ] :
kbinfos [ " chunks " ] . insert ( 0 , ck )
knowledges = kb_prompt ( kbinfos , max_tokens )

    logging.debug(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
    retrieval_ts = timer()

    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]
    prompt += "\n\n### Query:\n%s" % " ".join(questions)

    if "max_tokens" in gen_conf:
        # Leave room for the prompt: cap generation at what remains of the window.
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retriever.insert_citations(answer,
                                                     [ck["content_ltks"]
                                                      for ck in kbinfos["chunks"]],
                                                     [ck["vector"]
                                                      for ck in kbinfos["chunks"]],
                                                     embd_mdl,
                                                     tkweight=1 - dialog.vector_similarity_weight,
                                                     vtweight=dialog.vector_similarity_weight)
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs
            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        # Append a per-stage latency breakdown to the prompt for debugging.
        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n  - Check LLM: {check_llm_time_cost:.1f}ms\n  - Create retriever: {create_retriever_time_cost:.1f}ms\n  - Bind embedding: {bind_embedding_time_cost:.1f}ms\n  - Bind LLM: {bind_llm_time_cost:.1f}ms\n  - Tune question: {refine_question_time_cost:.1f}ms\n  - Bind reranker: {bind_reranker_time_cost:.1f}ms\n  - Generate keyword: {generate_keyword_time_cost:.1f}ms\n  - Retrieval: {retrieval_time_cost:.1f}ms\n  - Generate answer: {generate_result_time_cost:.1f}ms"
        # Two trailing spaces before "\n" render as hard line breaks in Markdown.
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
            if thought:
                # The reasoning trace was already streamed; drop it from the answer body.
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res
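
# A minimal usage sketch (comments only, not executed): chat() yields cumulative
# partial answers and ends with a decorated result carrying the references and
# the timing-annotated prompt.
#
#   final = None
#   for delta in chat(dialog, [{"role": "user", "content": "Summarize the doc."}]):
#       final = delta
#   print(final["reference"], final["prompt"])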


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
The fields of the table are as follows:
{}

The questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        # Normalize the model output down to a bare SELECT statement.
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[:len("select")] != "select":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            # Plain row queries must also return doc_id/docnm_kwd for citations.
            if sql[:len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        # Feed the database error back to the model for one retry.
        user_prompt = """
Table name: {};
The fields of the table are as follows:
{}

The questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.

The SQL you provided last time is as follows:
{}

The error issued by the database is as follows:
{}

Please correct the error and write the SQL again, only SQL, without any other explanations or text.
""".format(
            index_name(tenant_id),
            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
            question, sql, tbl["error"]
        )
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))

    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # Compose a Markdown table from the result set.
    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                        tbl["columns"][i]["name"])) for i in
                              column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    # Row-level citation markers (##N$$) are appended regardless of `quota`.
    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([columns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }
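
# Shape of a successful use_sql() result (illustrative values only): "answer"
# is a Markdown table whose rows end in ##N$$ citation markers.
#
#   {
#       "answer": "|Name|Phone|Source|\n|------|------|------|\n|Alice|555-0100| ##0$$ |",
#       "reference": {"chunks": [...], "doc_aggs": [...]},
#       "prompt": sys_prompt,
#   }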


def tts(tts_mdl, text):
    """Synthesize `text` and return the audio as a hex string, or None if TTS is off."""
    if not tts_mdl or not text:
        return
    bin = b""
    for chunk in tts_mdl.tts(text):
        bin += chunk
    return binascii.hexlify(bin).decode("utf-8")


def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
                                  1, 12, 0.1, 0.3, aggs=False,
                                  rank_feature=label_question(question, kbs)
                                  )
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
    Role: You're a smart assistant. Your name is Miss R.
    Task: Summarize the information from knowledge bases and answer the user's question.
    Requirements and restrictions:
      - DO NOT make things up, especially for numbers.
      - If the information from the knowledge bases is irrelevant to the user's question, JUST SAY: Sorry, no relevant information provided.
      - Answer in markdown format text.
      - Answer in the language of the user's question.

    ### Information from knowledge bases
    %s

    The above is information from knowledge bases.

    """ % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer,
                                                 [ck["content_ltks"]
                                                  for ck in kbinfos["chunks"]],
                                                 [ck["vector"]
                                                  for ck in kbinfos["chunks"]],
                                                 embd_mdl,
                                                 tkweight=0.7,
                                                 vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        return {"answer": answer, "reference": chunks_format(refs)}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
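
# A minimal usage sketch (comments only, not executed): one-shot Q&A over a set
# of knowledge bases; the ids are hypothetical placeholders.
#
#   for delta in ask("What is covered in the handbook?", ["kb_42"], "tenant_0"):
#       answer = delta["answer"]  # cumulative; the final delta also has "reference"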