# Mirror of https://github.com/infiniflow/ragflow.git
# synced 2025-07-04 07:26:22 +00:00
#-*- coding:utf-8 -*-
|
|||
|
import sys, os, re,inspect,json,traceback,logging,argparse, copy
|
|||
|
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
|
|||
|
from tornado.web import RequestHandler,Application
|
|||
|
from tornado.ioloop import IOLoop
|
|||
|
from tornado.httpserver import HTTPServer
|
|||
|
from tornado.options import define,options
|
|||
|
from util import es_conn, setup_logging
|
|||
|
from svr import sec_search as search
|
|||
|
from svr.rpc_proxy import RPCProxy
|
|||
|
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
|||
|
from nlp import huqie
|
|||
|
from nlp import query as Query
|
|||
|
from llm import HuEmbedding, GptTurbo
|
|||
|
import numpy as np
|
|||
|
from io import BytesIO
|
|||
|
from util import config
|
|||
|
from timeit import default_timer as timer
|
|||
|
from collections import OrderedDict
|
|||
|
|
|||
|
SE = None
|
|||
|
CFIELD="content_ltks"
|
|||
|
EMBEDDING = HuEmbedding()
|
|||
|
LLM = GptTurbo()
|
|||
|
|
|||
|
def get_QA_pairs(hists):
|
|||
|
pa = []
|
|||
|
for h in hists:
|
|||
|
for k in ["user", "assistant"]:
|
|||
|
if h.get(k):
|
|||
|
pa.append({
|
|||
|
"content": h[k],
|
|||
|
"role": k,
|
|||
|
})
|
|||
|
|
|||
|
for p in pa[:-1]: assert len(p) == 2, p
|
|||
|
return pa
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def get_instruction(sres, top_i, max_len=8096 fld="content_ltks"):
|
|||
|
max_len //= len(top_i)
|
|||
|
# add instruction to prompt
|
|||
|
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
|
|||
|
if len(instructions)>2:
|
|||
|
# Said that LLM is sensitive to the first and the last one, so
|
|||
|
# rearrange the order of references
|
|||
|
instructions.append(copy.deepcopy(instructions[1]))
|
|||
|
instructions.pop(1)
|
|||
|
|
|||
|
def token_num(txt):
|
|||
|
c = 0
|
|||
|
for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
|
|||
|
if re.match(r"[a-zA-Z-]+$", tk):
|
|||
|
c += 1
|
|||
|
continue
|
|||
|
c += len(tk)
|
|||
|
return c
|
|||
|
|
|||
|
_inst = ""
|
|||
|
for ins in instructions:
|
|||
|
if token_num(_inst) > 4096:
|
|||
|
_inst += "\n知识库:" + instructions[-1][:max_len]
|
|||
|
break
|
|||
|
_inst += "\n知识库:" + ins[:max_len]
|
|||
|
return _inst
|
|||
|
|
|||
|
|
|||
|
def prompt_and_answer(history, inst):
|
|||
|
hist = get_QA_pairs(history)
|
|||
|
chks = []
|
|||
|
for s in re.split(r"[::;;。\n\r]+", inst):
|
|||
|
if s: chks.append(s)
|
|||
|
chks = len(set(chks))/(0.1+len(chks))
|
|||
|
print("Duplication portion:", chks)
|
|||
|
|
|||
|
system = """
|
|||
|
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
|
|||
|
以下是知识库:
|
|||
|
%s
|
|||
|
以上是知识库。
|
|||
|
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
|
|||
|
|
|||
|
print("【PROMPT】:", system)
|
|||
|
start = timer()
|
|||
|
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
|
|||
|
print("GENERATE: ", timer()-start)
|
|||
|
print("===>>", response)
|
|||
|
return response
|
|||
|
|
|||
|
|
|||
|
class Handler(RequestHandler):
|
|||
|
def post(self):
|
|||
|
global SE,MUST_TK_NUM
|
|||
|
param = json.loads(self.request.body.decode('utf-8'))
|
|||
|
try:
|
|||
|
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
|
|||
|
res = SE.search({
|
|||
|
"question": question,
|
|||
|
"kb_ids": param.get("kb_ids", []),
|
|||
|
"size": param.get("topn", 15)
|
|||
|
})
|
|||
|
|
|||
|
sim = SE.rerank(res, question)
|
|||
|
rk_idx = np.argsort(sim*-1)
|
|||
|
topidx = [i for i in rk_idx if sim[i] >= aram.get("similarity", 0.5)][:param.get("topn",12)]
|
|||
|
inst = get_instruction(res, topidx)
|
|||
|
|
|||
|
ans, topidx = prompt_and_answer(param["history"], inst)
|
|||
|
ans = SE.insert_citations(ans, topidx, res)
|
|||
|
|
|||
|
refer = OrderedDict()
|
|||
|
docnms = {}
|
|||
|
for i in rk_idx:
|
|||
|
did = res.field[res.ids[i]]["doc_id"])
|
|||
|
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"])
|
|||
|
if did not in refer: refer[did] = []
|
|||
|
refer[did].append({
|
|||
|
"chunk_id": res.ids[i],
|
|||
|
"content": res.field[res.ids[i]]["content_ltks"]),
|
|||
|
"image": ""
|
|||
|
})
|
|||
|
|
|||
|
print("::::::::::::::", ans)
|
|||
|
self.write(json.dumps({
|
|||
|
"code":0,
|
|||
|
"msg":"success",
|
|||
|
"data":{
|
|||
|
"uid": param["uid"],
|
|||
|
"dialog_id": param["dialog_id"],
|
|||
|
"assistant": ans
|
|||
|
"refer": [{
|
|||
|
"did": did,
|
|||
|
"doc_name": docnms[did],
|
|||
|
"chunks": chunks
|
|||
|
} for did, chunks in refer.items()]
|
|||
|
}
|
|||
|
}))
|
|||
|
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logging.error("Request 500: "+str(e))
|
|||
|
self.write(json.dumps({
|
|||
|
"code":500,
|
|||
|
"msg":str(e),
|
|||
|
"data":{}
|
|||
|
}))
|
|||
|
print(traceback.format_exc())
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
parser = argparse.ArgumentParser()
|
|||
|
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
|
|||
|
ARGS = parser.parse_args()
|
|||
|
|
|||
|
SE = search.ResearchReportSearch(es_conn.HuEs("infiniflow"), EMBEDDING)
|
|||
|
|
|||
|
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
|
|||
|
http_server = HTTPServer(app)
|
|||
|
http_server.bind(ARGS.port)
|
|||
|
http_server.start(3)
|
|||
|
|
|||
|
IOLoop.current().start()
|
|||
|
|