feat: implement get_all_nodes and get_all_edges methods for graph storage backends

Add get_all_nodes() and get_all_edges() methods to Neo4JStorage, PGGraphStorage, MongoGraphStorage, and MemgraphStorage classes. These methods return all nodes and edges in the graph with consistent formatting matching NetworkXStorage for compatibility across different storage backends.
This commit is contained in:
yangdx 2025-08-03 11:02:37 +08:00
parent bfe6657b31
commit d2dd137f83
4 changed files with 196 additions and 0 deletions

View File

@ -997,3 +997,60 @@ class MemgraphStorage(BaseGraphStorage):
logger.warning(f"Memgraph error during subgraph query: {str(e)}")
return result
async def get_all_nodes(self) -> list[dict]:
    """Return every node stored under this workspace's label.

    Each node's properties are copied into a plain dictionary, and the
    value of ``entity_id`` is duplicated under an ``id`` key (``None`` when
    the node has no ``entity_id``).

    Returns:
        A list with one property dictionary per node.

    Raises:
        RuntimeError: If the Memgraph driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    label = self._get_workspace_label()
    cypher = f"""
        MATCH (n:`{label}`)
        RETURN n
        """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        cursor = await session.run(cypher)
        collected = []
        async for row in cursor:
            properties = dict(row["n"])
            # Mirror entity_id under "id" so callers have a uniform key.
            properties.update(id=properties.get("entity_id"))
            collected.append(properties)
        await cursor.consume()
        return collected
async def get_all_edges(self) -> list[dict]:
    """Get all edges in the graph.

    Every relationship between two nodes carrying the workspace label is
    returned once, oriented from the relationship's start node (``source``)
    to its end node (``target``).

    Returns:
        A list of all edges, where each edge is a dictionary of its
        properties plus ``source`` and ``target`` keys holding the
        ``entity_id`` of the two endpoints.

    Raises:
        RuntimeError: If the Memgraph driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # The pattern is directed on purpose: an undirected (a)-[r]-(b) match
        # yields each relationship twice (a->b and b->a), and DISTINCT cannot
        # collapse those rows because source and target are swapped.
        query = f"""
            MATCH (a:`{workspace_label}`)-[r]->(b:`{workspace_label}`)
            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
            """
        result = await session.run(query)
        edges = []
        async for record in result:
            # Copy so the record's own property map is left untouched.
            edge_properties = dict(record["properties"])
            edge_properties["source"] = record["source"]
            edge_properties["target"] = record["target"]
            edges.append(edge_properties)
        await result.consume()
        return edges

View File

@ -1508,6 +1508,36 @@ class MongoGraphStorage(BaseGraphStorage):
logger.debug(f"Successfully deleted edges: {edges}")
async def get_all_nodes(self) -> list[dict]:
    """Fetch every node document from the node collection.

    Each document is copied into a plain dictionary, and its ``_id`` value
    is duplicated under an ``id`` key (``None`` when ``_id`` is missing).

    Returns:
        A list with one dictionary per stored node document.
    """
    return [
        {**document, "id": document.get("_id")}
        async for document in self.collection.find({})
    ]
async def get_all_edges(self) -> list[dict]:
    """Fetch every edge document from the edge collection.

    Each document is copied into a plain dictionary; ``source_node_id`` and
    ``target_node_id`` are additionally exposed under ``source`` and
    ``target`` (``None`` when the original key is missing).

    Returns:
        A list with one dictionary per stored edge document.
    """
    return [
        {
            **document,
            "source": document.get("source_node_id"),
            "target": document.get("target_node_id"),
        }
        async for document in self.edge_collection.find({})
    ]
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.

View File

@ -1400,6 +1400,55 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def get_all_nodes(self) -> list[dict]:
    """Return every node stored under this workspace's label.

    Each node's properties are copied into a plain dictionary, and the
    value of ``entity_id`` is duplicated under an ``id`` key (``None`` when
    the node has no ``entity_id``).

    Returns:
        A list with one property dictionary per node.
    """
    label = self._get_workspace_label()
    read_session = self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    )
    collected = []
    async with read_session as session:
        cursor = await session.run(
            f"""
            MATCH (n:`{label}`)
            RETURN n
            """
        )
        async for row in cursor:
            entry = dict(row["n"])
            # Mirror entity_id under "id" so callers have a uniform key.
            entry.update(id=entry.get("entity_id"))
            collected.append(entry)
        await cursor.consume()
    return collected
async def get_all_edges(self) -> list[dict]:
    """Get all edges in the graph.

    Every relationship between two nodes carrying the workspace label is
    returned once, oriented from the relationship's start node (``source``)
    to its end node (``target``).

    Returns:
        A list of all edges, where each edge is a dictionary of its
        properties plus ``source`` and ``target`` keys holding the
        ``entity_id`` of the two endpoints.
    """
    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # The pattern is directed on purpose: an undirected (a)-[r]-(b) match
        # yields each relationship twice (a->b and b->a), and DISTINCT cannot
        # collapse those rows because source and target are swapped.
        query = f"""
            MATCH (a:`{workspace_label}`)-[r]->(b:`{workspace_label}`)
            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
            """
        result = await session.run(query)
        edges = []
        async for record in result:
            # Copy so the record's own property map is left untouched.
            edge_properties = dict(record["properties"])
            edge_properties["source"] = record["source"]
            edge_properties["target"] = record["target"]
            edges.append(edge_properties)
        await result.consume()
        return edges
async def drop(self) -> dict[str, str]:
"""Drop all data from current workspace storage and clean up resources

View File

@ -3669,6 +3669,66 @@ class PGGraphStorage(BaseGraphStorage):
return kg
async def get_all_nodes(self) -> list[dict]:
    """Get all nodes in the graph.

    Queries every vertex labelled ``base`` in ``self.graph_name`` and
    returns its property map. Property maps that arrive as JSON strings are
    decoded; a map that cannot be decoded is logged and skipped, because
    without properties the node cannot be identified.

    Returns:
        A list of all nodes, where each node is a dictionary of its
        properties plus an ``id`` key mirroring ``entity_id`` (``None``
        when the node has no ``entity_id``).
    """
    query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                 MATCH (n:base)
                 RETURN n
               $$) AS (n agtype)"""

    results = await self._query(query)
    nodes = []
    for result in results:
        if not result["n"]:
            continue
        node_dict = result["n"]["properties"]

        # Process string result, parse it to JSON dictionary
        if isinstance(node_dict, str):
            try:
                node_dict = json.loads(node_dict)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse node string: {node_dict}")
                # Falling through would call .get() on a str and raise
                # AttributeError, aborting the whole listing.
                continue

        # Add node id (entity_id) to the dictionary for easier access
        node_dict["id"] = node_dict.get("entity_id")
        nodes.append(node_dict)
    return nodes
async def get_all_edges(self) -> list[dict]:
    """Get all edges in the graph.

    Every relationship between two ``base`` vertices in ``self.graph_name``
    is returned once, oriented from the relationship's start vertex
    (``source``) to its end vertex (``target``). Property maps that arrive
    as JSON strings are decoded; a map that cannot be decoded is logged and
    replaced by an empty dictionary so the edge is still reported.

    Returns:
        A list of all edges, where each edge is a dictionary of its
        properties plus ``source`` and ``target`` keys holding the
        ``entity_id`` of the two endpoints.
    """
    # The pattern is directed on purpose: an undirected (a)-[r]-(b) match
    # yields each relationship twice (a->b and b->a), and DISTINCT cannot
    # collapse those rows because source and target are swapped.
    query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                 MATCH (a:base)-[r]->(b:base)
                 RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
               $$) AS (source text, target text, properties agtype)"""

    results = await self._query(query)
    edges = []
    for result in results:
        edge_properties = result["properties"]

        # Process string result, parse it to JSON dictionary
        if isinstance(edge_properties, str):
            try:
                edge_properties = json.loads(edge_properties)
            except json.JSONDecodeError:
                logger.warning(
                    f"Failed to parse edge properties string: {edge_properties}"
                )
                edge_properties = {}

        edge_properties["source"] = result["source"]
        edge_properties["target"] = result["target"]
        edges.append(edge_properties)
    return edges
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try: