diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 3678c18e6..a1da4156b 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -140,7 +140,8 @@ def completion():
         if not conv.reference:
             conv.reference.append(ans["reference"])
         else:
             conv.reference[-1] = ans["reference"]
-        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
+        conv.message[-1] = {"role": "assistant", "content": ans["answer"],
+                            "id": message_id, "prompt": ans.get("prompt", "")}
     def stream():
         nonlocal dia, msg, req, conv
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 417d6106f..241493898 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -179,6 +179,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                 for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
     assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
+    prompt = msg[0]["content"]
 
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(
@@ -186,7 +187,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             max_tokens - used_token_count)
 
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer, idx = retr.insert_citations(answer,
@@ -210,17 +211,16 @@ def chat(dialog, messages, stream=True, **kwargs):
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs}
+        return {"answer": answer, "reference": refs, "prompt": prompt}
 
     if stream:
         answer = ""
-        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf):
+        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
             answer = ans
-            yield {"answer": answer, "reference": {}}
+            yield {"answer": answer, "reference": {}, "prompt": prompt}
         yield decorate_answer(answer)
 
     else:
-        answer = chat_mdl.chat(
-            msg[0]["content"], msg[1:], gen_conf)
+        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
         chat_logger.info("User: {}|Assistant: {}".format(
             msg[-1]["content"], answer))
         yield decorate_answer(answer)
@@ -334,7 +334,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
         chat_logger.warning("SQL missing field: " + sql)
         return {
             "answer": "\n".join([clmns, line, rows]),
-            "reference": {"chunks": [], "doc_aggs": []}
+            "reference": {"chunks": [], "doc_aggs": []},
+            "prompt": sys_prompt
         }
 
     docid_idx = list(docid_idx)[0]
@@ -348,7 +349,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     return {
         "answer": "\n".join([clmns, line, rows]),
         "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                       "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
-                                   doc_aggs.items()]}
+                                   doc_aggs.items()]},
+        "prompt": sys_prompt
     }
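
Not part of the patch: a minimal sketch of the generator contract these hunks establish. `fake_chat` below is a hypothetical stand-in for `dialog_service.chat(...)` (the real function needs a dialog, messages, and a configured model); it only illustrates that every yielded dict now carries a `"prompt"` key alongside `"answer"` and `"reference"`, which the `completion()` handler in conversation_app.py then persists via `ans.get("prompt", "")`.

```python
# Hypothetical stand-in for dialog_service.chat(dialog, messages, stream=True):
# mimics the yield contract after this change, with no model or DB required.
def fake_chat():
    sys_prompt = "You are an intelligent assistant. Answer using the knowledge base."
    # Streaming phase: partial answers, empty reference, prompt attached.
    yield {"answer": "Par", "reference": {}, "prompt": sys_prompt}
    yield {"answer": "Paris.", "reference": {}, "prompt": sys_prompt}
    # Final decorated answer: references filled in, prompt still attached.
    yield {"answer": "Paris.",
           "reference": {"chunks": [], "doc_aggs": []},
           "prompt": sys_prompt}


last = None
for ans in fake_chat():
    last = ans  # conversation_app.py calls fillin_conv(ans) at this point

# Mirrors the new assignment in completion(): the prompt that produced the
# answer is stored on the last message and can be surfaced in the UI.
message = {"role": "assistant", "content": last["answer"],
           "id": "msg-0", "prompt": last.get("prompt", "")}
print(message)
```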