# -*- coding: utf-8 -*-
import sys, os, re, inspect, json, traceback, logging, argparse, copy

sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe()))) + "/../")

from tornado.web import RequestHandler, Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define, options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict

SE = None
CFIELD = "content_ltks"
EMBEDDING = HuEmbedding()
LLM = GptTurbo()


def get_QA_pairs(hists):
    # Flatten the chat history into OpenAI-style {role, content} messages,
    # preserving the user/assistant order within each turn.
    pa = []
    for h in hists:
        for k in ["user", "assistant"]:
            if h.get(k):
                pa.append({
                    "content": h[k],
                    "role": k,
                })

    # Sanity check on the message dicts (the last turn may still be pending).
    for p in pa[:-1]:
        assert len(p) == 2, p
    return pa

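# A minimal usage sketch (hypothetical two-turn history):
#   get_QA_pairs([{"user": "What is RAG?", "assistant": "Retrieval-augmented generation."},
#                 {"user": "Cite your sources."}])
#   -> [{"content": "What is RAG?", "role": "user"},
#       {"content": "Retrieval-augmented generation.", "role": "assistant"},
#       {"content": "Cite your sources.", "role": "user"}]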


def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
    # Each reference gets an equal share of the overall length budget.
    max_len //= len(top_i)

    # Add the retrieved references to the prompt.
    instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
    if len(instructions) > 2:
        # LLMs are said to pay the most attention to the first and the last
        # items, so move the second-ranked reference to the end.
        instructions.append(copy.deepcopy(instructions[1]))
        instructions.pop(1)

    def token_num(txt):
        # Crude token count: a run of ASCII letters between punctuation marks
        # counts as one token; any other segment counts one per character.
        c = 0
        for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
            if re.match(r"[a-zA-Z-]+$", tk):
                c += 1
                continue
            c += len(tk)
        return c

    # "知识库" means "knowledge base".
    _inst = ""
    for ins in instructions:
        if token_num(_inst) > 4096:
            # Budget exhausted: still append the last reference, then stop.
            _inst += "\n知识库:" + instructions[-1][:max_len]
            break
        _inst += "\n知识库:" + ins[:max_len]
    return _inst

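# A quick check of token_num's accounting (hypothetical input): the string
# "hello,世界" splits into ["hello", "世界"]; "hello" counts as one token and
# each CJK character as one, so the total is 3.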


def prompt_and_answer(history, inst):
    hist = get_QA_pairs(history)

    # Estimate how repetitive the references are: a low ratio of unique
    # segments suggests record-like data that reads better as a table.
    chks = []
    for s in re.split(r"[::;;。\n\r]+", inst):
        if s:
            chks.append(s)
    chks = len(set(chks)) / (0.1 + len(chks))
    print("Duplication portion:", chks)

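    # A sketch of the heuristic on a hypothetical reference string:
    # "姓名:张三;姓名:李四;姓名:张三" splits into six segments of which three
    # are unique, so chks = 3 / (0.1 + 6) ≈ 0.49 < 0.6 and the prompt below
    # asks for a table.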
system = """
|
|
|
|
|
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
|
|
|
|
|
以下是知识库:
|
|
|
|
|
%s
|
|
|
|
|
以上是知识库。
|
|
|
|
|
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
|
|
|
|
|
|
|
|
|
|
print("【PROMPT】:", system)
|
|
|
|
|
start = timer()
|
|
|
|
|
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
|
|
|
|
|
print("GENERATE: ", timer()-start)
|
|
|
|
|
print("===>>", response)
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
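# Note: prompt_and_answer sends the retrieved references only through the
# system prompt; the user/assistant turns from get_QA_pairs go unchanged into
# LLM.chat. The low temperature (0.2) keeps the answer close to the references.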


class Handler(RequestHandler):
    def post(self):
        global SE
        param = json.loads(self.request.body.decode('utf-8'))
        try:
            question = param.get("history", [{"user": "Hi!"}])[-1]["user"]
            res = SE.search({
                "question": question,
                "kb_ids": param.get("kb_ids", []),
                "size": param.get("topn", 15)},
                search.index_name(param["uid"])
            )
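
            # Rerank every hit against the question. np.argsort(sim * -1)
            # below sorts indices by descending similarity, e.g. (hypothetical)
            # sim = [0.2, 0.9, 0.5] gives rk_idx = [1, 2, 0].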
            sim = SE.rerank(res, question)
            rk_idx = np.argsort(sim * -1)
            topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn", 12)]
            inst = get_instruction(res, topidx)

            ans = prompt_and_answer(param["history"], inst)
            ans = SE.insert_citations(ans, topidx, res)

            # Group the reranked chunks by source document for the "refer" field.
            refer = OrderedDict()
            docnms = {}
            for i in rk_idx:
                did = res.field[res.ids[i]]["doc_id"]
                if did not in docnms:
                    docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
                if did not in refer:
                    refer[did] = []
                refer[did].append({
                    "chunk_id": res.ids[i],
                    "content": res.field[res.ids[i]]["content_ltks"],
                    "image": ""
                })
print("::::::::::::::", ans)
|
|
|
|
|
self.write(json.dumps({
|
|
|
|
|
"code":0,
|
|
|
|
|
"msg":"success",
|
|
|
|
|
"data":{
|
|
|
|
|
"uid": param["uid"],
|
|
|
|
|
"dialog_id": param["dialog_id"],
|
2023-12-26 19:32:06 +08:00
|
|
|
|
"assistant": ans,
|
2023-12-25 19:05:59 +08:00
|
|
|
|
"refer": [{
|
|
|
|
|
"did": did,
|
|
|
|
|
"doc_name": docnms[did],
|
|
|
|
|
"chunks": chunks
|
|
|
|
|
} for did, chunks in refer.items()]
|
|
|
|
|
}
|
|
|
|
|
}))
|
|
|
|
|
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logging.error("Request 500: "+str(e))
|
|
|
|
|
self.write(json.dumps({
|
|
|
|
|
"code":500,
|
|
|
|
|
"msg":str(e),
|
|
|
|
|
"data":{}
|
|
|
|
|
}))
|
|
|
|
|
print(traceback.format_exc())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
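# A hypothetical request against this endpoint (field names follow the
# handler above; values are made up):
#   curl -X POST http://localhost:4455/v1/chat/completions \
#        -H 'Content-Type: application/json' \
#        -d '{"uid": "u1", "dialog_id": "d1", "kb_ids": [], "topn": 12,
#             "similarity": 0.5, "history": [{"user": "What is RAG?"}]}'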


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=4455, type=int, help="Port used for service")
    ARGS = parser.parse_args()

    SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)

    app = Application([(r'/v1/chat/completions', Handler)], debug=False)
    http_server = HTTPServer(app)
    http_server.bind(ARGS.port)
    # Fork three worker processes sharing the bound socket.
    http_server.start(3)

    IOLoop.current().start()