ragflow/rag/prompts.py
Kevin Hu fbd115773b
Perf: set timeout of some steps in KG. (#8873)
### What problem does this PR solve?

### Type of change


- [x] Performance Improvement
2025-07-16 18:06:03 +08:00

291 lines
11 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import re
from collections import defaultdict
import jinja2
import json_repair
from rag.prompt_template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string
def chunks_format(reference):
    """Normalise the chunks of a retrieval reference into a uniform list of dicts.

    Each chunk may use either of two key spellings for several fields; the
    first spelling wins whenever that key is present in the chunk (even if
    its value is ``None``), otherwise the second spelling is looked up.

    Args:
        reference: mapping with an optional ``"chunks"`` entry holding
            dict-like chunks.

    Returns:
        A list with one dict per chunk, always carrying the same twelve keys
        in the same order. Missing source fields become ``None``.
    """
    # (output key, preferred source key, fallback source key)
    aliased_fields = (
        ("id", "chunk_id", "id"),
        ("content", "content", "content_with_weight"),
        ("document_id", "doc_id", "document_id"),
        ("document_name", "docnm_kwd", "document_name"),
        ("dataset_id", "kb_id", "dataset_id"),
        ("image_id", "image_id", "img_id"),
        ("positions", "positions", "position_int"),
    )
    # (output key, single source key)
    plain_fields = (
        ("url", "url"),
        ("similarity", "similarity"),
        ("vector_similarity", "vector_similarity"),
        ("term_similarity", "term_similarity"),
        ("doc_type", "doc_type_kwd"),
    )

    formatted = []
    for chunk in reference.get("chunks", []):
        entry = {}
        for out_key, preferred, fallback in aliased_fields:
            entry[out_key] = chunk.get(preferred, chunk.get(fallback))
        for out_key, source_key in plain_fields:
            entry[out_key] = chunk.get(source_key)
        formatted.append(entry)
    return formatted
def message_fit_in(msg, max_length=4000):
    """Shrink a chat message list so that its token count fits ``max_length``.

    Steps:
      1. If the whole list is already under ``max_length`` tokens it is
         returned unchanged.
      2. Otherwise only the system messages plus the final message are kept.
      3. If that is still too long, one message is truncated: the first kept
         message when it accounts for more than 80% of the first+last token
         count, otherwise the last one. The truncated message receives
         whatever budget the other kept messages leave over.

    Args:
        msg: list of dicts with ``"role"`` and ``"content"`` keys.
        max_length: token budget for the whole list.

    Returns:
        ``(token_count, messages)``. When a truncation was required the first
        element is ``max_length`` (not a recount), as before.

    Note:
        Truncation rewrites ``"content"`` on the message dict itself, so the
        caller's dict object is modified in place.
    """

    def total_tokens(messages):
        return sum(num_tokens_from_string(m["content"]) for m in messages)

    def truncate(message, budget):
        # Clamp at zero: a negative bound would otherwise slice from the end
        # and keep most of the text instead of dropping it.
        tokens = encoder.encode(message["content"])
        message["content"] = encoder.decode(tokens[: max(budget, 0)])

    c = total_tokens(msg)
    if c < max_length:
        return c, msg

    # Keep system messages plus the final message. The final message is only
    # appended when it is not already present as a system message, and it is
    # kept even when it is the only message in the list.
    kept = [m for m in msg if m["role"] == "system"]
    if msg and msg[-1]["role"] != "system":
        kept.append(msg[-1])
    msg = kept
    c = total_tokens(msg)
    if c < max_length or not msg:
        return c, msg

    head, tail = msg[0], msg[-1]
    if head is tail:
        # Only one message survived: it gets the whole budget.
        truncate(head, max_length)
        return max_length, msg

    ll = num_tokens_from_string(head["content"])
    ll2 = num_tokens_from_string(tail["content"])
    if ll + ll2 > 0 and ll / (ll + ll2) > 0.8:
        # The first message dominates: cut it down to what the rest leaves.
        truncate(head, max_length - (c - ll))
    else:
        # Otherwise cut the last message, budgeting against everything else
        # (not against its own length).
        truncate(tail, max_length - (c - ll2))
    return max_length, msg
def kb_prompt(kbinfos, max_tokens):
    """Render retrieved chunks into per-document knowledge strings for a prompt.

    Chunks are taken in order from ``kbinfos["chunks"]`` for as long as their
    cumulative token count stays within 97% of ``max_tokens``. The first chunk
    that would exceed that budget, and everything after it, is left out and a
    warning is logged.

    Args:
        kbinfos: mapping with a ``"chunks"`` sequence; each chunk is read for
            ``"content_with_weight"``, ``"doc_id"``, ``"docnm_kwd"`` and,
            optionally, ``"url"``.
        max_tokens: token budget for the knowledge section.

    Returns:
        A list of strings, one per document name, each containing the
        document's metadata lines followed by its selected chunks. Every
        chunk is labelled with its index within ``kbinfos["chunks"]``.
    """
    from api.db.services.document_service import DocumentService

    chunks = kbinfos["chunks"]
    budget = max_tokens * 0.97
    used_token_count = 0
    chunks_num = 0
    for ck in chunks:
        used_token_count += num_tokens_from_string(ck["content_with_weight"])
        if budget < used_token_count:
            # Count a chunk only after it passed the budget check, so the
            # overflowing chunk is really excluded and the log is accurate.
            logging.warning(f"Not all the retrieval into prompt: {chunks_num}/{len(chunks)}")
            break
        chunks_num += 1

    selected = chunks[:chunks_num]
    if not selected:
        return []

    docs = DocumentService.get_by_ids([ck["doc_id"] for ck in selected])
    docs = {d.id: d.meta_fields for d in docs}

    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": {}})
    for i, ck in enumerate(selected):
        cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
        # Strip inline styles and document-level HTML wrappers from the chunk text.
        cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL | re.IGNORECASE)
        group = doc2chunks[ck["docnm_kwd"]]
        group["chunks"].append(cnt)
        # Fall back to an empty mapping for unknown documents or null metadata.
        group["meta"] = docs.get(ck["doc_id"]) or {}

    knowledges = []
    for nm, cks_meta in doc2chunks.items():
        parts = [f"\nDocument: {nm} \n"]
        parts.extend(f"{k}: {v}\n" for k, v in cks_meta["meta"].items())
        parts.append("Relevant fragments as following:\n")
        parts.extend(f"{chunk}\n" for chunk in cks_meta["chunks"])
        knowledges.append("".join(parts))
    return knowledges
# Prompt template sources, each obtained by name through load_prompt()
# (imported from rag.prompt_template; where the templates are stored is not
# visible in this module). The functions below render them with
# PROMPT_JINJA_ENV.from_string(...).
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
# Shared Jinja2 environment for all prompt rendering in this module:
# autoescaping is off, and trim_blocks / lstrip_blocks are on so that block
# tags in the templates do not leave stray newlines or indentation behind.
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
def citation_prompt() -> str:
    """Return the citation prompt, rendered from its template with no variables."""
    return PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE).render()
def keyword_extraction(chat_mdl, content, topn=3):
    """Ask a chat model for keywords describing ``content``.

    Args:
        chat_mdl: object exposing ``max_length`` and
            ``chat(system, history, gen_conf)``; the reply may be a string or
            a tuple whose first item is the string.
        content: text passed to the keyword prompt template.
        topn: value passed to the template as ``topn``.

    Returns:
        The model reply with everything up to the last ``</think>`` removed,
        or ``""`` when the reply contains ``**ERROR**``.
    """
    template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
    rendered_prompt = template.render(content=content, topn=topn)
    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted by message_fit_in; passing the original
    # rendered string would silently discard any truncation it applied.
    kwd = chat_mdl.chat(msg[0]["content"], msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if "**ERROR**" in kwd:
        return ""
    return kwd
def question_proposal(chat_mdl, content, topn=3):
    """Ask a chat model to propose questions about ``content``.

    Args:
        chat_mdl: object exposing ``max_length`` and
            ``chat(system, history, gen_conf)``; the reply may be a string or
            a tuple whose first item is the string.
        content: text passed to the question prompt template.
        topn: value passed to the template as ``topn``.

    Returns:
        The model reply with everything up to the last ``</think>`` removed,
        or ``""`` when the reply contains ``**ERROR**``.
    """
    template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(content=content, topn=topn)
    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted by message_fit_in; passing the original
    # rendered string would silently discard any truncation it applied.
    kwd = chat_mdl.chat(msg[0]["content"], msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if "**ERROR**" in kwd:
        return ""
    return kwd
def full_question(tenant_id, llm_id, messages, language=None):
    """Rewrite the latest user turn as a self-contained question via an LLM.

    The user/assistant turns of ``messages`` are flattened into a
    ``ROLE: content`` transcript and rendered into the full-question prompt
    together with today's, yesterday's and tomorrow's ISO dates.

    Args:
        tenant_id: tenant passed to ``LLMBundle``.
        llm_id: model identifier; an ``image2text`` model is bundled as
            ``LLMType.IMAGE2TEXT``, anything else as ``LLMType.CHAT``.
        messages: list of dicts with ``"role"`` and ``"content"`` keys.
        language: value passed to the template as ``language``.

    Returns:
        The model reply with everything up to the last ``</think>`` removed,
        or the content of the last message when the reply contains
        ``**ERROR**``.
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle, TenantLLMService

    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    conversation = "\n".join(
        "{}: {}".format(m["role"].upper(), m["content"])
        for m in messages
        if m["role"] in ("user", "assistant")
    )

    # Read the clock once so the three dates stay consistent even if the call
    # happens to straddle midnight.
    current_day = datetime.date.today()
    one_day = datetime.timedelta(days=1)

    template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(
        today=current_day.isoformat(),
        yesterday=(current_day - one_day).isoformat(),
        tomorrow=(current_day + one_day).isoformat(),
        conversation=conversation,
        language=language,
    )

    ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    # Accept a (text, ...) tuple reply, as the other prompt helpers in this module do.
    if isinstance(ans, tuple):
        ans = ans[0]
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    return ans if "**ERROR**" not in ans else messages[-1]["content"]
def cross_languages(tenant_id, llm_id, query, languages=None):
    """Translate ``query`` into the requested languages via an LLM.

    Args:
        tenant_id: tenant passed to ``LLMBundle``.
        llm_id: model identifier; when set and of type ``image2text`` it is
            bundled as ``LLMType.IMAGE2TEXT``, otherwise as ``LLMType.CHAT``.
        query: text to translate.
        languages: target languages passed to the user prompt template;
            ``None`` (the default) is treated as an empty list.

    Returns:
        The non-blank ``===``-separated segments of the model reply joined by
        newlines, or ``query`` unchanged when the reply contains ``**ERROR**``.
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle, TenantLLMService

    # Avoid a shared mutable default: build a fresh list per call.
    if languages is None:
        languages = []

    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)

    ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
    # Accept a (text, ...) tuple reply, as the other prompt helpers in this module do.
    if isinstance(ans, tuple):
        ans = ans[0]
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if "**ERROR**" in ans:
        return query

    # Strip surrounding whitespace first: "^Output:" only matches at the very
    # start of the string, so a newline left behind by the </think> removal
    # would otherwise let the "Output:" label leak into the result.
    cleaned = re.sub(r"(^Output:|\n+)", "", ans.strip(), flags=re.DOTALL)
    return "\n".join([a for a in cleaned.split("===") if a.strip()])
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    """Ask a chat model to score tags for ``content`` and return the positive ones.

    Args:
        chat_mdl: object exposing ``max_length`` and
            ``chat(system, history, gen_conf)``; the reply may be a string or
            a tuple whose first item is the string.
        content: text to tag.
        all_tags: candidate tags passed to the template.
        examples: iterable of mappings, each holding its tags under
            ``TAG_FLD``. They are not modified; a JSON rendering of the tags
            is added to per-call copies as ``"tags_json"``.
        topn: value passed to the template as ``topn``.

    Returns:
        ``{tag: score}`` with string keys and integer scores, limited to
        entries whose score converts to an integer greater than zero.

    Raises:
        Exception: when the reply contains ``**ERROR**`` (the reply text is
            the message), or when no JSON object can be recovered from it.
    """
    template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)

    # Work on shallow copies so the caller's example dicts are left untouched.
    examples = [
        {**ex, "tags_json": json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)}
        for ex in examples
    ]

    rendered_prompt = template.render(
        topn=topn,
        all_tags=all_tags,
        examples=examples,
        content=content,
    )
    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted by message_fit_in; passing the original
    # rendered string would silently discard any truncation it applied.
    kwd = chat_mdl.chat(msg[0]["content"], msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if "**ERROR**" in kwd:
        raise Exception(kwd)

    try:
        obj = json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        # Fallback: drop an echoed prompt and role markers, then parse the
        # first {...} span found in what remains.
        result = kwd
        try:
            result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            obj = json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise

    res = {}
    for k, v in obj.items():
        try:
            score = int(v)
        except (TypeError, ValueError, OverflowError):
            # Best effort: skip scores that are not integer-like.
            continue
        if score > 0:
            res[str(k)] = score
    return res
def vision_llm_describe_prompt(page=None) -> str:
    """Return the vision-LLM describe prompt rendered with ``page`` as its only variable."""
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT).render(page=page)
def vision_llm_figure_describe_prompt() -> str:
    """Return the vision-LLM figure describe prompt, rendered with no variables."""
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT).render()
if __name__ == "__main__":
    # When run as a script, print every raw (unrendered) prompt template in
    # the order they are defined above.
    for _template_source in (
        CITATION_PROMPT_TEMPLATE,
        CONTENT_TAGGING_PROMPT_TEMPLATE,
        CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE,
        CROSS_LANGUAGES_USER_PROMPT_TEMPLATE,
        FULL_QUESTION_PROMPT_TEMPLATE,
        KEYWORD_PROMPT_TEMPLATE,
        QUESTION_PROMPT_TEMPLATE,
        VISION_LLM_DESCRIBE_PROMPT,
        VISION_LLM_FIGURE_DESCRIBE_PROMPT,
    ):
        print(_template_source)