#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import binascii
import os
import json
import re
from collections import defaultdict
from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta

from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())
        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())
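
# Usage sketch (illustrative values; the peewee connection context is supplied
# by the decorator above):
#
#   chats = DialogService.get_list("tenant_0", page_number=1, items_per_page=20,
#                                  orderby="create_time", desc=True, id=None, name=None)
#   # -> list of dialog dicts belonging to the tenant, newest first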


def message_fit_in(msg, max_length=4000):
    def count():
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append(
                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    if c < max_length:
        return c, msg

    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    ll = num_tokens_from_string(msg_[0]["content"])
    ll2 = num_tokens_from_string(msg_[-1]["content"])
    if ll / (ll + ll2) > 0.8:
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[:max_length - ll2])
        msg[0]["content"] = m
        return max_length, msg

    m = msg_[1]["content"]
    m = encoder.decode(encoder.encode(m)[:max_length - ll2])
    msg[1]["content"] = m
    return max_length, msg
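
# A minimal sketch of what message_fit_in does (numbers are illustrative): when
# the conversation exceeds `max_length` tokens, only the system message(s) plus
# the latest message are kept, and the longer of the two is truncated so the
# total fits the budget.
#
#   msg = [{"role": "system", "content": system_prompt},
#          {"role": "user", "content": "What is RAGFlow?"}]
#   used_tokens, msg = message_fit_in(msg, max_length=4000)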


def llm_id2llm_type(llm_id):
    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
    fnm = os.path.join(get_project_base_directory(), "conf")
    with open(os.path.join(fnm, "llm_factories.json"), "r") as f:
        llm_factories = json.load(f)
    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # "model_type" may be a comma-separated list; the last entry is
                # the concrete type (e.g. "chat", "image2text").
                return llm["model_type"].split(",")[-1]


def kb_prompt(kbinfos, max_tokens):
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    used_token_count = 0
    chunks_num = 0
    for i, c in enumerate(knowledges):
        used_token_count += num_tokens_from_string(c)
        chunks_num += 1
        if max_tokens * 0.97 < used_token_count:
            knowledges = knowledges[:i]
            break

    doc2chunks = defaultdict(list)
    for i, ck in enumerate(kbinfos["chunks"]):
        if i >= chunks_num:
            break
        doc2chunks[ck["docnm_kwd"]].append(ck["content_with_weight"])

    knowledges = []
    for nm, chunks in doc2chunks.items():
        txt = f"Document: {nm} \nContains the following relevant fragments:\n"
        for i, chunk in enumerate(chunks, 1):
            txt += f"{i}. {chunk}\n"
        knowledges.append(txt)
    return knowledges
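
# Shape of the data kb_prompt consumes and produces (contents are made up;
# field names follow the chunk schema used above):
#
#   kbinfos = {"chunks": [{"docnm_kwd": "a.pdf", "content_with_weight": "..."}]}
#   kb_prompt(kbinfos, max_tokens=8192)
#   # -> ["Document: a.pdf \nContains the following relevant fragments:\n1. ...\n"]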


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."

    chat_start_ts = timer()

    # Get llm model name and model provider name
    llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)

    # Get llm model instance by model and provider name
    llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)

    if not llm:
        # Model name is provided by tenant, but not system built-in
        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
        if not llm:
            raise LookupError("LLM(%s) not found" % dialog.llm_id)
        max_tokens = 8192
    else:
        max_tokens = llm[0].max_tokens

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]
        for m in messages[:-1]:
            if "doc_ids" in m:
                attachments.extend(m["doc_ids"])

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)

    # Try to use SQL if the field mapping is good to go
    if field_map:
        logging.debug("Use SQL for retrieval: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Missing parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()

    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                      dialog.similarity_threshold,
                                      dialog.vector_similarity_weight,
                                      doc_ids=attachments,
                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)

    retrieval_ts = timer()

    knowledges = kb_prompt(kbinfos, max_tokens)
    logging.debug(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]
    prompt += "\n\n### Query:\n%s" % " ".join(questions)

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)
    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts

        refs = []
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retriever.insert_citations(answer,
                                                     [ck["content_ltks"]
                                                      for ck in kbinfos["chunks"]],
                                                     [ck["vector"]
                                                      for ck in kbinfos["chunks"]],
                                                     embd_mdl,
                                                     tkweight=1 - dialog.vector_similarity_weight,
                                                     vtweight=dialog.vector_similarity_weight)
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": answer, "reference": refs, "prompt": prompt}
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(answer)
    else:
        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
        logging.debug("User: {}|Assistant: {}".format(
            msg[-1]["content"], answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res
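
# chat() is a generator: callers typically drain it and keep the last item,
# which carries the citation references and the timing-annotated prompt.
# Hedged sketch, assuming `dialog` is a row loaded via DialogService:
#
#   for delta in chat(dialog, [{"role": "user", "content": "Hi"}], stream=True):
#       final = delta                      # {"answer": ..., "reference": ...}
#   print(final["reference"], final.get("prompt"))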


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
The database fields of the table are as follows:
{}

Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[:len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[:len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
The database fields of the table are as follows:
{}

Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.

The SQL you provided last time is as follows:
{}

The error issued by the database is as follows:
{}

Please correct the error and write the SQL again, only SQL, without any other explanations or text.
""".format(
            index_name(tenant_id),
            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
            question, sql, tbl["error"]
        )
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                          tbl["columns"][i]["name"])) for i in
                              column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        # attach per-row citation markers only when quoting is enabled
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([columns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }
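
# use_sql() returns None whenever no valid SELECT could be produced or the
# query yielded no rows, so chat() can silently fall back to vector retrieval.
# Illustrative call (the field_map entry is hypothetical):
#
#   ans = use_sql("Who has the highest score?", {"score_int": "score"},
#                 tenant_id, chat_mdl, quota=True)
#   if ans:   # {"answer": <markdown table>, "reference": {...}, "prompt": ...}
#       ...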


def relevant(tenant_id, llm_id, question, contents: list):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    prompt = """
        You are a grader assessing the relevance of a retrieved document to a user question.
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
        Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question.
        No other words needed except 'yes' or 'no'.
    """
    if not contents:
        return False
    contents = "Documents: \n" + " - ".join(contents)
    contents = f"Question: {question}\n" + contents
    if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
        contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
    if ans.lower().find("yes") >= 0:
        return True
    return False
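
# Illustrative call: True only if the grader model answers "yes" (the llm_id
# below is hypothetical):
#
#   relevant(tenant_id, "some-chat-model", "What is RAGFlow?",
#            ["RAGFlow is an open-source RAG engine."])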


def rewrite(tenant_id, llm_id, question):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    prompt = """
        You are an expert at query expansion to generate a paraphrasing of a question.
        I can't retrieve relevant information from the knowledge base by using the user's question directly.
        You need to expand or paraphrase the user's question in multiple ways, such as using synonymous words/phrases,
        writing an abbreviation in its entirety, adding some extra descriptions or explanations,
        changing the way of expression, translating the original question into another language (English/Chinese), etc.
        Return 5 versions of the question, one of which is from translation.
        Just list the questions. No other words are needed.
    """
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
    return ans
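
# Illustrative call: the model is asked for five rewrites, one of them a
# translation, returned as plain text one per line:
#
#   variants = rewrite(tenant_id, llm_id, "What's RAG?")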


def keyword_extraction(chat_mdl, content, topn=3):
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
  - Summarize the text content, and give the top {topn} important keywords/phrases.
  - The keywords MUST be in the language of the given piece of text content.
  - The keywords are delimited by ENGLISH COMMA.
  - Keywords ONLY in output.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
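
# keyword_extraction() and question_proposal() below share the same call shape:
# the instructions ride in the system prompt and a stub "Output: " user turn
# triggers generation. Illustrative call (the output is an example, not
# guaranteed):
#
#   keyword_extraction(chat_mdl, "RAGFlow is an open-source RAG engine.", topn=3)
#   # -> "RAGFlow, RAG engine, open-source"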


def question_proposal(chat_mdl, content, topn=3):
    prompt = f"""
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
  - Understand and summarize the text content, and propose the top {topn} important questions.
  - The questions SHOULD NOT have overlapping meanings.
  - The questions SHOULD cover the main content of the text as much as possible.
  - The questions MUST be in the language of the given piece of text content.
  - One question per line.
  - Questions ONLY in output.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd


def full_question(tenant_id, llm_id, messages):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    conv = []
    for m in messages:
        if m["role"] not in ["user", "assistant"]:
            continue
        conv.append("{}: {}".format(m["role"].upper(), m["content"]))
    conv = "\n".join(conv)
    today = datetime.date.today().isoformat()
    yesterday = (datetime.date.today() - timedelta(days=1)).isoformat()
    tomorrow = (datetime.date.today() + timedelta(days=1)).isoformat()
    prompt = f"""
Role: A helpful assistant

Task and steps:
    1. Generate a full user question that would follow the conversation.
    2. If the user's question involves a relative date, you need to convert it into an absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.

Requirements & Restrictions:
  - Text generated MUST be in the same language as the original user's question.
  - If the user's latest question is already complete, don't do anything, just return the original question.
  - DON'T generate anything except a refined question.

######################
-Examples-
######################

# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?

------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
USER: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?

------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What's about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?

######################
# Real Data
## Conversation
{conv}
###############
    """
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
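
# Illustrative multi-turn condensation ({today}/{tomorrow} in the real prompt
# are filled from datetime.date.today()):
#
#   msgs = [{"role": "user", "content": "Who is Donald Trump's father?"},
#           {"role": "assistant", "content": "Fred Trump."},
#           {"role": "user", "content": "And his mother?"}]
#   full_question(tenant_id, llm_id, msgs)
#   # -> "What's the name of Donald Trump's mother?"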


def tts(tts_mdl, text):
    if not tts_mdl or not text:
        return
    bin_data = b""
    for chunk in tts_mdl.tts(text):
        bin_data += chunk
    return binascii.hexlify(bin_data).decode("utf-8")


def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
    Role: You're a smart assistant. Your name is Miss R.
    Task: Summarize the information from knowledge bases and answer user's question.
    Requirements and restrictions:
      - DO NOT make things up, especially for numbers.
      - If the information from the knowledge bases is irrelevant to the user's question, JUST SAY: Sorry, no relevant information provided.
      - Answer with markdown format text.
      - Answer in the language of the user's question.
      - DO NOT make things up, especially for numbers.

    ### Information from knowledge bases
    %s

    The above is information from knowledge bases.

    """ % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer,
                                                 [ck["content_ltks"]
                                                  for ck in kbinfos["chunks"]],
                                                 [ck["vector"]
                                                  for ck in kbinfos["chunks"]],
                                                 embd_mdl,
                                                 tkweight=0.7,
                                                 vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
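
# ask() streams like chat() but with fixed retrieval settings (page size 12,
# similarity threshold 0.1, vector weight 0.3) and no per-dialog config.
# Hedged usage sketch (ids are illustrative):
#
#   for delta in ask("What is RAGFlow?", kb_ids=["kb_0"], tenant_id="tenant_0"):
#       final = delta
#   # final -> {"answer": "...", "reference": {...}} with citations inserted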