Implement get_nodes_by_chunk_ids and get_edges_by_chunk_ids,

This commit is contained in:
Ken Chen 2025-06-25 22:17:17 +08:00
parent 81cff6e97f
commit a3865caaea
2 changed files with 164 additions and 30 deletions

View File

@ -17,6 +17,8 @@ from ..base import (
from ..namespace import NameSpace, is_namespace
from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm
if not pm.is_installed("pymongo"):
@ -353,33 +355,33 @@ class MongoGraphStorage(BaseGraphStorage):
self.collection = None
self.edge_collection = None
#
# -------------------------------------------------------------------------
# HELPER: $graphLookup pipeline
# -------------------------------------------------------------------------
#
# Sample entity document
# "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP
# Sample entity_relation document
# {
# "_id" : "CompanyA",
# "created_at" : 1749904575,
# "description" : "A major technology company",
# "edges" : [
# {
# "target" : "ProductX",
# "relation": "Develops", // To distinguish multiple same-target relations
# "weight" : Double("1"),
# "description" : "CompanyA develops ProductX",
# "keywords" : "develop, produce",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
# "file_path" : "custom_kg",
# "created_at" : 1749904575
# }
# ],
# "entity_id" : "CompanyA",
# "entity_type" : "Organization",
# "description" : "A major technology company",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
# "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
# "file_path" : "custom_kg",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
# "created_at" : 1749904575
# }
# Sample relation document
# {
# "_id" : ObjectId("6856ac6e7c6bad9b5470b678"), // MongoDB build-in ObjectId
# "description" : "CompanyA develops ProductX",
# "source_node_id" : "CompanyA",
# "target_node_id" : "ProductX",
# "relationship": "Develops", // To distinguish multiple same-target relations
# "weight" : Double("1"),
# "keywords" : "develop, produce",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
# "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
# "file_path" : "custom_kg",
# "created_at" : 1749904575
# }
#
@ -567,6 +569,45 @@ class MongoGraphStorage(BaseGraphStorage):
return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
if not chunk_ids:
return []
cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
return [doc async for doc in cursor]
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
if not chunk_ids:
return []
cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
edges = []
async for edge in cursor:
edge["source"] = edge["source_node_id"]
edge["target"] = edge["target_node_id"]
edges.append(edge)
return edges
#
# -------------------------------------------------------------------------
# UPSERTS
@ -578,6 +619,11 @@ class MongoGraphStorage(BaseGraphStorage):
Insert or update a node document.
"""
update_doc = {"$set": {**node_data}}
if node_data.get("source_id", ""):
update_doc["$set"]["source_ids"] = node_data["source_id"].split(
GRAPH_FIELD_SEP
)
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
async def upsert_edge(
@ -590,9 +636,15 @@ class MongoGraphStorage(BaseGraphStorage):
# Ensure source node exists
await self.upsert_node(source_node_id, {})
update_doc = {"$set": edge_data}
if edge_data.get("source_id", ""):
update_doc["$set"]["source_ids"] = edge_data["source_id"].split(
GRAPH_FIELD_SEP
)
await self.edge_collection.update_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id},
{"$set": edge_data},
update_doc,
upsert=True,
)
@ -789,14 +841,16 @@ class MongoGraphStorage(BaseGraphStorage):
if not edges:
return
await self.edge_collection.delete_many(
{
"$or": [
{"source_node_id": source_id, "target_node_id": target_id}
for source_id, target_id in edges
]
}
)
all_edge_pairs = []
for source_id, target_id in edges:
all_edge_pairs.append(
{"source_node_id": source_id, "target_node_id": target_id}
)
all_edge_pairs.append(
{"source_node_id": target_id, "target_node_id": source_id}
)
await self.edge_collection.delete_many({"$or": all_edge_pairs})
logger.debug(f"Successfully deleted edges: {edges}")

View File

@ -30,6 +30,7 @@ from lightrag.kg import (
verify_storage_implementation,
)
from lightrag.kg.shared_storage import initialize_share_data
from lightrag.constants import GRAPH_FIELD_SEP
# 模拟的嵌入函数,返回随机向量
@ -437,6 +438,9 @@ async def test_graph_batch_operations(storage):
5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
"""
try:
chunk1_id = "1"
chunk2_id = "2"
chunk3_id = "3"
# 1. 插入测试数据
# 插入节点1: 人工智能
node1_id = "人工智能"
@ -445,6 +449,7 @@ async def test_graph_batch_operations(storage):
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
"keywords": "AI,机器学习,深度学习",
"entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
}
print(f"插入节点1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
@ -456,6 +461,7 @@ async def test_graph_batch_operations(storage):
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
"keywords": "监督学习,无监督学习,强化学习",
"entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
}
print(f"插入节点2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
@ -467,6 +473,7 @@ async def test_graph_batch_operations(storage):
"description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
"keywords": "神经网络,CNN,RNN",
"entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
}
print(f"插入节点3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
@ -498,6 +505,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含机器学习这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
}
print(f"插入边1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
@ -507,6 +515,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含",
"weight": 1.0,
"description": "机器学习领域包含深度学习这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
}
print(f"插入边2: {node2_id} -> {node3_id}")
await storage.upsert_edge(node2_id, node3_id, edge2_data)
@ -516,6 +525,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含自然语言处理这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
}
print(f"插入边3: {node1_id} -> {node4_id}")
await storage.upsert_edge(node1_id, node4_id, edge3_data)
@ -748,6 +758,76 @@ async def test_graph_batch_operations(storage):
print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
# 7. 测试 get_nodes_by_chunk_ids - 批量根据 chunk_ids 获取多个节点
print("== 测试 get_nodes_by_chunk_ids")
print("== 测试单个 chunk_id匹配多个节点")
nodes = await storage.get_nodes_by_chunk_ids([chunk2_id])
assert len(nodes) == 2, f"{chunk1_id} 应有2个节点实际有 {len(nodes)}"
has_node1 = any(node["entity_id"] == node1_id for node in nodes)
has_node2 = any(node["entity_id"] == node2_id for node in nodes)
assert has_node1, f"节点 {node1_id} 应在返回结果中"
assert has_node2, f"节点 {node2_id} 应在返回结果中"
print("== 测试多个 chunk_id部分匹配多个节点")
nodes = await storage.get_nodes_by_chunk_ids([chunk2_id, chunk3_id])
assert (
len(nodes) == 3
), f"{chunk2_id}, {chunk3_id} 应有3个节点实际有 {len(nodes)}"
has_node1 = any(node["entity_id"] == node1_id for node in nodes)
has_node2 = any(node["entity_id"] == node2_id for node in nodes)
has_node3 = any(node["entity_id"] == node3_id for node in nodes)
assert has_node1, f"节点 {node1_id} 应在返回结果中"
assert has_node2, f"节点 {node2_id} 应在返回结果中"
assert has_node3, f"节点 {node3_id} 应在返回结果中"
# 8. 测试 get_edges_by_chunk_ids - 批量根据 chunk_ids 获取多条边
print("== 测试 get_edges_by_chunk_ids")
print("== 测试单个 chunk_id匹配多条边")
edges = await storage.get_edges_by_chunk_ids([chunk2_id])
assert len(edges) == 2, f"{chunk2_id} 应有2条边实际有 {len(edges)}"
has_edge_node1_node2 = any(
edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
)
has_edge_node2_node3 = any(
edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
)
assert has_edge_node1_node2, f"{chunk2_id} 应包含 {node1_id}{node2_id} 的边"
assert has_edge_node2_node3, f"{chunk2_id} 应包含 {node2_id}{node3_id} 的边"
print("== 测试多个 chunk_id部分匹配多条边")
edges = await storage.get_edges_by_chunk_ids([chunk2_id, chunk3_id])
assert (
len(edges) == 3
), f"{chunk2_id}, {chunk3_id} 应有3条边实际有 {len(edges)}"
has_edge_node1_node2 = any(
edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
)
has_edge_node2_node3 = any(
edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
)
has_edge_node1_node4 = any(
edge["source"] == node1_id and edge["target"] == node4_id for edge in edges
)
assert (
has_edge_node1_node2
), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id}{node2_id} 的边"
assert (
has_edge_node2_node3
), f"{chunk2_id}, {chunk3_id} 应包含 {node2_id}{node3_id} 的边"
assert (
has_edge_node1_node4
), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id}{node4_id} 的边"
print("\n批量操作测试完成")
return True