diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 86958a1a..af26b961 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -997,3 +997,60 @@ class MemgraphStorage(BaseGraphStorage): logger.warning(f"Memgraph error during subgraph query: {str(e)}") return result + + async def get_all_nodes(self) -> list[dict]: + """Get all nodes in the graph. + + Returns: + A list of all nodes, where each node is a dictionary of its properties + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (n:`{workspace_label}`) + RETURN n + """ + result = await session.run(query) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_all_edges(self) -> list[dict]: + """Get all edges in the graph. + + Returns: + A list of all edges, where each edge is a dictionary of its properties + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await session.run(query) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 14bd6633..9e2847f2 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1508,6 +1508,36 @@ class MongoGraphStorage(BaseGraphStorage): logger.debug(f"Successfully deleted edges: {edges}") + async def get_all_nodes(self) -> list[dict]: + """Get all nodes in the graph. + + Returns: + A list of all nodes, where each node is a dictionary of its properties + """ + cursor = self.collection.find({}) + nodes = [] + async for node in cursor: + node_dict = dict(node) + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("_id") + nodes.append(node_dict) + return nodes + + async def get_all_edges(self) -> list[dict]: + """Get all edges in the graph. + + Returns: + A list of all edges, where each edge is a dictionary of its properties + """ + cursor = self.edge_collection.find({}) + edges = [] + async for edge in cursor: + edge_dict = dict(edge) + edge_dict["source"] = edge_dict.get("source_node_id") + edge_dict["target"] = edge_dict.get("target_node_id") + edges.append(edge_dict) + return edges + async def drop(self) -> dict[str, str]: """Drop the storage by removing all documents in the collection. diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index d68707b0..3ae22927 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1400,6 +1400,55 @@ class Neo4JStorage(BaseGraphStorage): logger.error(f"Error during edge deletion: {str(e)}") raise + async def get_all_nodes(self) -> list[dict]: + """Get all nodes in the graph. + + Returns: + A list of all nodes, where each node is a dictionary of its properties + """ + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (n:`{workspace_label}`) + RETURN n + """ + result = await session.run(query) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_all_edges(self) -> list[dict]: + """Get all edges in the graph. + + Returns: + A list of all edges, where each edge is a dictionary of its properties + """ + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await session.run(query) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges + async def drop(self) -> dict[str, str]: """Drop all data from current workspace storage and clean up resources diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 7edfdad1..fc21a50b 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -3669,6 +3669,66 @@ class PGGraphStorage(BaseGraphStorage): return kg + async def get_all_nodes(self) -> list[dict]: + """Get all nodes in the graph. + + Returns: + A list of all nodes, where each node is a dictionary of its properties + """ + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base) + RETURN n + $$) AS (n agtype)""" + + results = await self._query(query) + nodes = [] + for result in results: + if result["n"]: + node_dict = result["n"]["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(node_dict, str): + try: + node_dict = json.loads(node_dict) + except json.JSONDecodeError: + logger.warning(f"Failed to parse node string: {node_dict}") + + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + return nodes + + async def get_all_edges(self) -> list[dict]: + """Get all edges in the graph. + + Returns: + A list of all edges, where each edge is a dictionary of its properties + """ + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (a:base)-[r]-(b:base) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + $$) AS (source text, target text, properties agtype)""" + + results = await self._query(query) + edges = [] + for result in results: + edge_properties = result["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(edge_properties, str): + try: + edge_properties = json.loads(edge_properties) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge properties string: {edge_properties}" + ) + edge_properties = {} + + edge_properties["source"] = result["source"] + edge_properties["target"] = result["target"] + edges.append(edge_properties) + return edges + async def drop(self) -> dict[str, str]: """Drop the storage""" try: