mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00
add protect data
This commit is contained in:
parent
17a6cacfdd
commit
224620ad43
@ -232,8 +232,10 @@ class MemoryGraph:
|
||||
entity.prop = Prop.from_dict(attributes, None, None)
|
||||
entity.biz_id = attributes.get("id", biz_id)
|
||||
entity.name = attributes.get("name", "")
|
||||
entity.description = attributes.get("description", "")
|
||||
entity.description = attributes.get('content', attributes.get("description", attributes.get("desc", "")))
|
||||
entity.type = attributes.get("label", label)
|
||||
entity.name_vec = attributes.get("_name_vector")
|
||||
entity.content_vec = attributes.get("_content_vector")
|
||||
entity.type_zh = None
|
||||
entity.score = 1.0
|
||||
return entity
|
||||
@ -304,6 +306,9 @@ class MemoryGraph:
|
||||
node_unique_id = get_node_unique_id(node_biz_id, node_type)
|
||||
else:
|
||||
continue
|
||||
if node_unique_id not in self.name2id:
|
||||
logger.warning(f"{node_unique_id} not found")
|
||||
continue
|
||||
node_id = self.name2id[node_unique_id]
|
||||
reset_prob[node_id] = 1
|
||||
scores = self._backend_graph.personalized_pagerank(
|
||||
@ -328,7 +333,7 @@ class MemoryGraph:
|
||||
output = []
|
||||
for idx in top_indices[::-1]:
|
||||
node_attributes = self._backend_graph.vs[idx].attributes()
|
||||
node_attributes["__labels__"] = [node_attributes.pop("label")]
|
||||
node_attributes["__labels__"] = [node_attributes["label"]]
|
||||
output.append(
|
||||
{
|
||||
"score": ppr_scores[idx],
|
||||
@ -429,7 +434,7 @@ class MemoryGraph:
|
||||
for index, score in zip(top_indices, top_values):
|
||||
node = nodes[index.item()]
|
||||
node_attributes = node.attributes()
|
||||
node_attributes["__labels__"] = [node_attributes.pop("label")]
|
||||
node_attributes["__labels__"] = [node_attributes["label"]]
|
||||
items.append({"node": node_attributes, "score": score.item()})
|
||||
return items
|
||||
|
||||
|
@ -72,6 +72,7 @@ class KgFreeRetrieverWithOpenSPG(KagLogicalFormComponent):
|
||||
)
|
||||
)
|
||||
self.top_k = top_k
|
||||
self.disable_chunk = kwargs.get("disable_chunk", False)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
@ -92,12 +93,6 @@ class KgFreeRetrieverWithOpenSPG(KagLogicalFormComponent):
|
||||
name=self.name,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ppr_sub_query = generate_step_query(
|
||||
logical_node=cur_task.logical_node,
|
||||
processed_logical_nodes=processed_logical_nodes,
|
||||
)
|
||||
|
||||
entities = []
|
||||
selected_rel = []
|
||||
if graph_data is not None:
|
||||
@ -114,6 +109,17 @@ class KgFreeRetrieverWithOpenSPG(KagLogicalFormComponent):
|
||||
selected_rel = graph_data.get_all_spo()
|
||||
entities = list(set(entities))
|
||||
|
||||
|
||||
ppr_sub_query = generate_step_query(
|
||||
logical_node=cur_task.logical_node,
|
||||
processed_logical_nodes=processed_logical_nodes,
|
||||
)
|
||||
|
||||
if self.disable_chunk:
|
||||
cur_task.logical_node.get_fl_node_result().spo = selected_rel
|
||||
cur_task.logical_node.get_fl_node_result().sub_question = ppr_sub_query
|
||||
return [graph_data]
|
||||
|
||||
ppr_queries = [query, ppr_sub_query]
|
||||
ppr_queries = list(set(ppr_queries))
|
||||
start_time = time.time()
|
||||
|
Loading…
x
Reference in New Issue
Block a user