import logging
import threading
from typing import Generator

from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization import (
    DecomposeQuestionPipeline,
    RewriteQuestionPipeline,
)
from ktem.utils.plantuml import PlantUML
from ktem.utils.render import Render
from ktem.utils.visualize_cited import CreateCitationVizPipeline
from plotly.io import to_json

from kotaemon.base import (
    AIMessage,
    BaseComponent,
    Document,
    HumanMessage,
    Node,
    RetrievedDocument,
    SystemMessage,
)
from kotaemon.indices.qa.citation_qa import (
    CONTEXT_RELEVANT_WARNING_SCORE,
    DEFAULT_QA_TEXT_PROMPT,
    AnswerWithContextPipeline,
)
from kotaemon.indices.qa.citation_qa_inline import AnswerWithInlineCitation
from kotaemon.indices.qa.format_context import PrepareEvidencePipeline
from kotaemon.llms import ChatLLM

from ..utils import SUPPORTED_LANGUAGE_MAP
from .base import BaseReasoning

logger = logging.getLogger(__name__)


class AddQueryContextPipeline(BaseComponent):
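    """Add conversation context to a user question before retrieval.

    A few-shot prompt asks the LLM to turn the recent chat history plus the
    new question into a standalone search query. By convention the LLM returns
    "1" when the question already contains enough information and "0" when no
    search is needed at all.
    """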

    n_last_interactions: int = 5
    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())

    def run(self, question: str, history: list) -> Document:
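        """Return the rewritten search query, the original question, or an empty query."""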
        messages = [
            SystemMessage(
                content="Below is a history of the conversation so far, and a new "
                "question asked by the user that needs to be answered by searching "
                "in a knowledge base.\nYou have access to a Search index "
                "with 100's of documents.\nGenerate a search query based on the "
                "conversation and the new question.\nDo not include cited source "
                "filenames and document names e.g. info.txt or doc.pdf in the search "
                "query terms.\nDo not include any text inside [] or <<>> in the "
                "search query terms.\nDo not include any special characters like "
                "'+'.\nIf the question is not in English, rewrite the query in "
                "the language used in the question.\nIf the question contains enough "
                "information, return just the number 1.\nIf it's unnecessary to do "
                "the searching, return just the number 0."
            ),
            # few-shot examples of question -> search query rewriting
            HumanMessage(content="How did crypto do last year?"),
            AIMessage(
                content="Summarize Cryptocurrency Market Dynamics from last year"
            ),
            HumanMessage(content="What are my health plans?"),
            AIMessage(content="Show available health plans"),
        ]
        for human, ai in history[-self.n_last_interactions :]:
            messages.append(HumanMessage(content=human))
            messages.append(AIMessage(content=ai))

        messages.append(HumanMessage(content=f"Generate search query for: {question}"))

        resp = self.llm(messages).text

        # "0": no search needed; "1": the question is already a usable query
        if resp == "0":
            return Document(content="")

        if resp == "1":
            return Document(content=question)

        return Document(content=resp)


class FullQAPipeline(BaseReasoning):
    """Question answering pipeline that handles the full flow from question to answer."""

    class Config:
        allow_extra = True

    # configuration parameters
    trigger_context: int = 150
    use_rewrite: bool = False

    retrievers: list[BaseComponent]

    evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
    answering_pipeline: AnswerWithContextPipeline
    rewrite_pipeline: RewriteQuestionPipeline | None = None
    create_citation_viz_pipeline: CreateCitationVizPipeline = Node(
        default_callback=lambda _: CreateCitationVizPipeline(
            embedding=embeddings.get_default()
        )
    )
    add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()

    def retrieve(
        self, message: str, history: list
    ) -> tuple[list[RetrievedDocument], list[Document]]:
        """Retrieve the documents based on the message"""
        # if len(message) < self.trigger_context:
        #     # prefer adding context for short user questions, avoid adding context
        #     # for long questions, as they are likely to contain enough information
        #     # plus, avoid the situation where the original message is already too
        #     # long for the model to handle
        #     query = self.add_query_context(message, history).content
        # else:
        #     query = message
        # print(f"Rewritten query: {query}")

        query = None
        if not query:
            # TODO: previously return [], [] because we think this message as something
            # like "Hello", "I need help"...
            query = message

        docs, doc_ids = [], []
        plot_docs = []

        for idx, retriever in enumerate(self.retrievers):
            retriever_node = self._prepare_child(retriever, f"retriever_{idx}")
            retriever_docs = retriever_node(text=query)

            # separate plain-text chunks from pre-rendered plot documents
            retriever_docs_text = []
            retriever_docs_plot = []

            for doc in retriever_docs:
                if doc.metadata.get("type", "") == "plot":
                    retriever_docs_plot.append(doc)
                else:
                    retriever_docs_text.append(doc)

            # deduplicate text documents across retrievers by doc_id
            for doc in retriever_docs_text:
                if doc.doc_id not in doc_ids:
                    docs.append(doc)
                    doc_ids.append(doc.doc_id)

            plot_docs.extend(retriever_docs_plot)

        info = [
            Document(
                channel="info",
                content=Render.collapsible_with_header(doc, open_collapsible=True),
            )
            for doc in docs
        ] + [
            Document(
                channel="plot",
                content=doc.metadata.get("data", ""),
            )
            for doc in plot_docs
        ]

        return docs, info

    def prepare_mindmap(self, answer) -> Document | None:
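        """Render the mindmap from the answer metadata as a collapsible SVG panel."""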
        mindmap = answer.metadata["mindmap"]
        if mindmap:
            mindmap_text = mindmap.text
            uml_renderer = PlantUML()

            try:
                mindmap_svg = uml_renderer.process(mindmap_text)
            except Exception as e:
                print("Failed to process mindmap:", e)
                mindmap_svg = "<svg></svg>"

            # post-process the mindmap SVG
            mindmap_svg = (
                mindmap_svg.replace("sans-serif", "Quicksand, sans-serif")
                .replace("#181818", "#cecece")
                .replace("background:#FFFFF", "background:none")
                .replace("stroke-width:1", "stroke-width:2")
            )

            mindmap_content = Document(
                channel="info",
                content=Render.collapsible(
                    header="""
                    <i>Mindmap</i>
                    <a href="#" id='mindmap-toggle'>
                    [Expand]</a>
                    <a href="#" id='mindmap-export'>
                    [Export]</a>""",
                    content=mindmap_svg,
                    open=True,
                ),
            )
        else:
            mindmap_content = None

        return mindmap_content

    def prepare_citation_viz(self, answer, question, docs) -> Document | None:
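        """Build the Plotly citation/embeddings visualization for the retrieved docs."""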
        doc_texts = [doc.text for doc in docs]
        citation_plot = None
        plot_content = None

        if answer.metadata["citation_viz"] and len(docs) > 1:
            try:
                citation_plot = self.create_citation_viz_pipeline(doc_texts, question)
            except Exception as e:
                print("Failed to create citation plot:", e)

            if citation_plot:
                plot = to_json(citation_plot)
                plot_content = Document(channel="plot", content=plot)

        return plot_content

    def show_citations_and_addons(self, answer, docs, question):
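        """Yield evidence, mindmap, citation plot and score warnings to the UI."""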
        # show the evidence
        with_citation, without_citation = self.answering_pipeline.prepare_citations(
            answer, docs
        )
        mindmap_output = self.prepare_mindmap(answer)
        citation_plot_output = self.prepare_citation_viz(answer, question, docs)

        if not with_citation and not without_citation:
            yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
        else:
            max_llm_rerank_score = max(
                doc.metadata.get("llm_trulens_score", 0.0) for doc in docs
            )
            has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)

            # clear previous info
            yield Document(channel="info", content=None)

            # yield mindmap output
            if mindmap_output:
                yield mindmap_output

            # yield citation plot output
            if citation_plot_output:
                yield citation_plot_output

            # yield warning message
            if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
                yield Document(
                    channel="info",
                    content=(
                        "<h5>WARNING! Context relevance score is low. "
                        "Double check the model answer for correctness.</h5>"
                    ),
                )

            # show QA score
            qa_score = (
                round(answer.metadata["qa_score"], 2)
                if answer.metadata.get("qa_score")
                else None
            )
            if qa_score:
                yield Document(
                    channel="info",
                    content=f"<h5>Answer confidence: {qa_score}</h5>",
                )

            yield from with_citation
            if without_citation:
                yield from without_citation

    async def ainvoke(  # type: ignore
        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Document:  # type: ignore
        raise NotImplementedError

    def stream(  # type: ignore
        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Generator[Document, None, Document]:
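        """Retrieve evidence, stream the answer, then show citations and add-ons."""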
        if self.use_rewrite and self.rewrite_pipeline:
            print("Chosen rewrite pipeline", self.rewrite_pipeline)
            message = self.rewrite_pipeline(question=message).text
            print("Rewrite result", message)

        print(f"Retrievers {self.retrievers}")
        # should populate the context
        docs, infos = self.retrieve(message, history)
        print(f"Got {len(docs)} retrieved documents")
        yield from infos

        evidence_mode, evidence, images = self.evidence_pipeline(docs).content

        def generate_relevant_scores():
            nonlocal docs
            docs = self.retrievers[0].generate_relevant_scores(message, docs)

        # generate relevance scores in the background while the answer streams
        if evidence and self.retrievers:
            scoring_thread = threading.Thread(target=generate_relevant_scores)
            scoring_thread.start()
        else:
            scoring_thread = None

        answer = yield from self.answering_pipeline.stream(
            question=message,
            history=history,
            evidence=evidence,
            evidence_mode=evidence_mode,
            images=images,
            conv_id=conv_id,
            **kwargs,
        )

        # wait for the scoring thread, then show the evidence
        if scoring_thread:
            scoring_thread.join()

        yield from self.show_citations_and_addons(answer, docs, message)

        return answer

    @classmethod
    def prepare_pipeline_instance(cls, settings, retrievers):
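        """Create the pipeline instance; per-user settings are applied in get_pipeline."""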
        return cls(
            retrievers=retrievers,
            rewrite_pipeline=RewriteQuestionPipeline(),
        )

    @classmethod
    def get_pipeline(cls, settings, states, retrievers):
        """Get the reasoning pipeline

        Args:
            settings: the settings for the pipeline
            states: the app states, used to decide whether to rewrite the question
            retrievers: the retrievers to use
        """
        max_context_length_setting = settings.get("reasoning.max_context_length", 32000)

        pipeline = cls.prepare_pipeline_instance(settings, retrievers)

        prefix = f"reasoning.options.{cls.get_info()['id']}"
        llm_name = settings.get(f"{prefix}.llm", None)
        llm = llms.get(llm_name, llms.get_default())

        # prepare evidence pipeline configuration
        evidence_pipeline = pipeline.evidence_pipeline
        evidence_pipeline.max_context_length = max_context_length_setting

        # answering pipeline configuration
        use_inline_citation = settings[f"{prefix}.highlight_citation"] == "inline"

        if use_inline_citation:
            answer_pipeline = pipeline.answering_pipeline = AnswerWithInlineCitation()
        else:
            answer_pipeline = pipeline.answering_pipeline = AnswerWithContextPipeline()

        answer_pipeline.llm = llm
        answer_pipeline.citation_pipeline.llm = llm
        answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
        answer_pipeline.enable_citation = (
            settings[f"{prefix}.highlight_citation"] != "off"
        )
        answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
        answer_pipeline.enable_citation_viz = settings[f"{prefix}.create_citation_viz"]
        answer_pipeline.use_multimodal = settings[f"{prefix}.use_multimodal"]
        answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
        answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
        answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
            settings["reasoning.lang"], "English"
        )

        pipeline.add_query_context.llm = llm
        pipeline.add_query_context.n_last_interactions = settings[
            f"{prefix}.n_last_interactions"
        ]

        pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
        if pipeline.rewrite_pipeline:
            pipeline.rewrite_pipeline.llm = llm
            pipeline.rewrite_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
                settings["reasoning.lang"], "English"
            )
        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:
        from ktem.llms.manager import llms

        llm = ""
        choices = [("(default)", "")]
        try:
            choices += [(_, _) for _ in llms.options().keys()]
        except Exception as e:
            logger.exception(f"Failed to get LLM options: {e}")

        return {
            "llm": {
                "name": "Language model",
                "value": llm,
                "component": "dropdown",
                "choices": choices,
                "special_type": "llm",
                "info": (
                    "The language model to use for generating the answer. If None, "
                    "the application default language model will be used."
                ),
            },
            "highlight_citation": {
                "name": "Citation style",
                "value": "highlight",
                "component": "radio",
                "choices": [
                    ("highlight (verbose)", "highlight"),
                    ("inline (concise)", "inline"),
                    ("no citation", "off"),
                ],
            },
            "create_mindmap": {
                "name": "Create Mindmap",
                "value": False,
                "component": "checkbox",
            },
            "create_citation_viz": {
                "name": "Create Embeddings Visualization",
                "value": False,
                "component": "checkbox",
            },
            "use_multimodal": {
                "name": "Use Multimodal Input",
                "value": False,
                "component": "checkbox",
            },
            "system_prompt": {
                "name": "System Prompt",
                "value": "This is a question answering system",
            },
            "qa_prompt": {
                "name": "QA Prompt (contains {context}, {question}, {lang})",
                "value": DEFAULT_QA_TEXT_PROMPT,
            },
            "n_last_interactions": {
                "name": "Number of interactions to include",
                "value": 5,
                "component": "number",
                "info": "The maximum number of chat interactions to include in the LLM prompt",
            },
            "trigger_context": {
                "name": "Maximum message length for context rewriting",
                "value": 150,
                "component": "number",
                "info": (
                    "The maximum message length that triggers context addition. "
                    "If the message exceeds this length, it is used as is."
                ),
            },
        }

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "simple",
            "name": "Simple QA",
            "description": (
                "Simple RAG-based question answering pipeline. This pipeline can "
                "perform both keyword search and similarity search to retrieve the "
                "context, which is then used to generate the answer."
            ),
        }


class FullDecomposeQAPipeline(FullQAPipeline):
    def answer_sub_questions(
        self, messages: list, conv_id: str, history: list, **kwargs
    ):
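        """Answer each sub-question in turn and return the combined Q/A text."""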
        output_str = ""
        for idx, message in enumerate(messages):
            yield Document(
                channel="chat",
                content=f"<br><b>Sub-question {idx + 1}</b>"
                f"<br>{message}<br><b>Answer</b><br>",
            )
            # should populate the context
            docs, infos = self.retrieve(message, history)
            print(f"Got {len(docs)} retrieved documents")

            yield from infos

            evidence_mode, evidence, images = self.evidence_pipeline(docs).content
            answer = yield from self.answering_pipeline.stream(
                question=message,
                history=history,
                evidence=evidence,
                evidence_mode=evidence_mode,
                images=images,
                conv_id=conv_id,
                **kwargs,
            )

            output_str += (
                f"Sub-question {idx + 1}: '{message}'\nAnswer: '{answer.text}'\n\n"
            )

        return output_str

    def stream(  # type: ignore
        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Generator[Document, None, Document]:
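        """Decompose the question, answer the sub-questions, then the main question."""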
        sub_question_answer_output = ""
        if self.rewrite_pipeline:
            print("Chosen rewrite pipeline", self.rewrite_pipeline)
            result = self.rewrite_pipeline(question=message)
            print("Rewrite result", result)
            if isinstance(result, Document):
                message = result.text
            elif (
                isinstance(result, list)
                and len(result) > 0
                and isinstance(result[0], Document)
            ):
                yield Document(
                    channel="chat",
                    content="<h4>Sub questions and their answers</h4>",
                )
                sub_question_answer_output = yield from self.answer_sub_questions(
                    [r.text for r in result], conv_id, history, **kwargs
                )

        yield Document(
            channel="chat",
            content=f"<h4>Main question</h4>{message}<br><b>Answer</b><br>",
        )

        # should populate the context
        docs, infos = self.retrieve(message, history)
        print(f"Got {len(docs)} retrieved documents")
        yield from infos

        evidence_mode, evidence, images = self.evidence_pipeline(docs).content
        answer = yield from self.answering_pipeline.stream(
            question=message,
            history=history,
            evidence=evidence + "\n" + sub_question_answer_output,
            evidence_mode=evidence_mode,
            images=images,
            conv_id=conv_id,
            **kwargs,
        )

        # show the evidence
        with_citation, without_citation = self.answering_pipeline.prepare_citations(
            answer, docs
        )
        if not with_citation and not without_citation:
            yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
        else:
            # clear previous info
            yield Document(channel="info", content=None)
            yield from with_citation
            yield from without_citation

        return answer

    @classmethod
    def get_user_settings(cls) -> dict:
        user_settings = super().get_user_settings()
        user_settings["decompose_prompt"] = {
            "name": "Decompose Prompt",
            "value": DecomposeQuestionPipeline.DECOMPOSE_SYSTEM_PROMPT_TEMPLATE,
        }
        return user_settings

    @classmethod
    def prepare_pipeline_instance(cls, settings, retrievers):
        prefix = f"reasoning.options.{cls.get_info()['id']}"
        pipeline = cls(
            retrievers=retrievers,
            rewrite_pipeline=DecomposeQuestionPipeline(
                prompt_template=settings.get(f"{prefix}.decompose_prompt")
            ),
        )
        return pipeline

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "complex",
            "name": "Complex QA",
            "description": (
                "Use multi-step reasoning to decompose a complex question into "
                "multiple sub-questions. This pipeline can "
                "perform both keyword search and similarity search to retrieve the "
                "context, which is then used to generate the answer."
            ),
        }