mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00
fix(solver): add component name (#480)
* add ai search example * bugfix reporter name * update version * fix ci
This commit is contained in:
parent
d842662e5c
commit
a7fd51d138
@ -1 +1 @@
|
|||||||
0.7.0b2
|
0.7.0b2
|
@ -351,6 +351,7 @@ def get_recall_node_label(label_set):
|
|||||||
for l in label_set:
|
for l in label_set:
|
||||||
if l != "Entity":
|
if l != "Entity":
|
||||||
return l
|
return l
|
||||||
|
return "Entity"
|
||||||
|
|
||||||
|
|
||||||
def node_2_doc(node: dict):
|
def node_2_doc(node: dict):
|
||||||
|
@ -76,9 +76,7 @@ class KagMerger(FlowComponent):
|
|||||||
top_k,
|
top_k,
|
||||||
llm_module: LLMClient = None,
|
llm_module: LLMClient = None,
|
||||||
summary_prompt: PromptABC = None,
|
summary_prompt: PromptABC = None,
|
||||||
vector_chunk_retriever: VectorChunkRetriever = None,
|
|
||||||
vectorize_model: VectorizeModelABC = None,
|
vectorize_model: VectorizeModelABC = None,
|
||||||
search_api: SearchApiABC = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -96,22 +94,6 @@ class KagMerger(FlowComponent):
|
|||||||
)
|
)
|
||||||
self.text_similarity = TextSimilarity(vectorize_model)
|
self.text_similarity = TextSimilarity(vectorize_model)
|
||||||
|
|
||||||
self.search_api = search_api or SearchApiABC.from_config(
|
|
||||||
{"type": "openspg_search_api"}
|
|
||||||
)
|
|
||||||
self.vector_chunk_retriever = vector_chunk_retriever or VectorChunkRetriever(
|
|
||||||
vectorize_model=self.vectorize_model, search_api=self.search_api
|
|
||||||
)
|
|
||||||
|
|
||||||
def recall_query(self, query):
|
|
||||||
sim_scores_start_time = time.time()
|
|
||||||
"""Process a single query for similarity scores in parallel."""
|
|
||||||
query_sim_scores = self.vector_chunk_retriever.invoke(query, self.top_k * 20)
|
|
||||||
logger.info(
|
|
||||||
f"`{query}` Similarity scores calculation completed in {time.time() - sim_scores_start_time:.2f} seconds."
|
|
||||||
)
|
|
||||||
return query_sim_scores
|
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
cur_task: FlowComponentTask,
|
cur_task: FlowComponentTask,
|
||||||
@ -157,7 +139,11 @@ class KagMerger(FlowComponent):
|
|||||||
"FINISH",
|
"FINISH",
|
||||||
component_name=self.name,
|
component_name=self.name,
|
||||||
chunk_num=len(limited_merged_chunks),
|
chunk_num=len(limited_merged_chunks),
|
||||||
desc="kag_merger_digest",
|
desc=(
|
||||||
|
"kag_merger_digest"
|
||||||
|
if len(limited_merged_chunks) > 0
|
||||||
|
else "kag_merger_digest_failed"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# summary
|
# summary
|
||||||
formatted_docs = []
|
formatted_docs = []
|
||||||
|
@ -29,7 +29,7 @@ class KgConstrainRetrieverWithOpenSPG(KagLogicalFormComponent):
|
|||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.name = "kg_cs"
|
self.name = kwargs.get("name", "kg_cs")
|
||||||
self.llm = llm or LLMClient.from_config(get_default_chat_llm_config())
|
self.llm = llm or LLMClient.from_config(get_default_chat_llm_config())
|
||||||
self.path_select = path_select or PathSelect.from_config(
|
self.path_select = path_select or PathSelect.from_config(
|
||||||
{"type": "exact_one_hop_select"}
|
{"type": "exact_one_hop_select"}
|
||||||
|
@ -2,7 +2,7 @@ import logging
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from kag.common.config import get_default_chat_llm_config
|
from kag.common.config import get_default_chat_llm_config
|
||||||
from kag.interface import LLMClient, Task
|
from kag.interface import LLMClient, Task, ToolABC
|
||||||
from kag.interface.solver.base_model import LogicNode
|
from kag.interface.solver.base_model import LogicNode
|
||||||
from kag.interface.solver.model.one_hop_graph import RetrievedData
|
from kag.interface.solver.model.one_hop_graph import RetrievedData
|
||||||
from kag.interface.solver.reporter_abc import ReporterABC
|
from kag.interface.solver.reporter_abc import ReporterABC
|
||||||
@ -37,12 +37,12 @@ class KgFreeRetrieverWithOpenSPG(KagLogicalFormComponent):
|
|||||||
path_select: PathSelect = None,
|
path_select: PathSelect = None,
|
||||||
entity_linking=None,
|
entity_linking=None,
|
||||||
llm: LLMClient = None,
|
llm: LLMClient = None,
|
||||||
ppr_chunk_retriever_tool: PprChunkRetriever = None,
|
ppr_chunk_retriever_tool: ToolABC = None,
|
||||||
top_k=10,
|
top_k=10,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.name = "kg_fr"
|
self.name = kwargs.get("name", "kg_fr")
|
||||||
self.llm = llm or LLMClient.from_config(get_default_chat_llm_config())
|
self.llm = llm or LLMClient.from_config(get_default_chat_llm_config())
|
||||||
self.path_select = path_select or PathSelect.from_config(
|
self.path_select = path_select or PathSelect.from_config(
|
||||||
{"type": "fuzzy_one_hop_select"}
|
{"type": "fuzzy_one_hop_select"}
|
||||||
@ -122,7 +122,7 @@ class KgFreeRetrieverWithOpenSPG(KagLogicalFormComponent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"`{query}` Retrieved chunks num: {len(chunks)}")
|
logger.info(f"`{query}` Retrieved chunks num: {len(chunks)}")
|
||||||
cur_task.logical_node.get_fl_node_result().spo = match_spo
|
cur_task.logical_node.get_fl_node_result().spo = match_spo + selected_rel
|
||||||
cur_task.logical_node.get_fl_node_result().chunks = chunks
|
cur_task.logical_node.get_fl_node_result().chunks = chunks
|
||||||
cur_task.logical_node.get_fl_node_result().sub_question = ppr_sub_query
|
cur_task.logical_node.get_fl_node_result().sub_question = ppr_sub_query
|
||||||
if reporter:
|
if reporter:
|
||||||
|
@ -45,7 +45,7 @@ class RCRetrieverOnOpenSPG(KagLogicalFormComponent):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.name = "kg_rc"
|
self.name = kwargs.get("name", "kg_rc")
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.vectorize_model = vectorize_model or VectorizeModelABC.from_config(
|
self.vectorize_model = vectorize_model or VectorizeModelABC.from_config(
|
||||||
KAG_CONFIG.all_config["vectorize_model"]
|
KAG_CONFIG.all_config["vectorize_model"]
|
||||||
|
@ -143,6 +143,10 @@ class OpenSPGReporter(ReporterABC):
|
|||||||
self.report_sub_segment = {}
|
self.report_sub_segment = {}
|
||||||
self.thinking_enabled = kwargs.get("thinking_enabled", True)
|
self.thinking_enabled = kwargs.get("thinking_enabled", True)
|
||||||
self.word_mapping = {
|
self.word_mapping = {
|
||||||
|
"kag_merger_digest_failed": {
|
||||||
|
"zh": "未检索到相关信息。",
|
||||||
|
"en": "No relevant information was found.",
|
||||||
|
},
|
||||||
"kag_merger_digest": {
|
"kag_merger_digest": {
|
||||||
"zh": "排序文档后,输出{chunk_num}篇文档, 检索信息已足够回答问题。",
|
"zh": "排序文档后,输出{chunk_num}篇文档, 检索信息已足够回答问题。",
|
||||||
"en": "{chunk_num} documents were output, sufficient information retrieved to answer the question.",
|
"en": "{chunk_num} documents were output, sufficient information retrieved to answer the question.",
|
||||||
|
@ -369,7 +369,7 @@ class PprChunkRetriever(ToolABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return matched_docs, self._convert_relation_datas(
|
return matched_docs, self._convert_relation_datas(
|
||||||
chunk_docs=matched_docs, matched_entities=matched_entities[:top_k]
|
chunk_docs=matched_docs[:top_k], matched_entities=matched_entities
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_relation_datas(self, chunk_docs, matched_entities):
|
def _convert_relation_datas(self, chunk_docs, matched_entities):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user