From 08eb68b8ed0f774843fd682e3de3cc4749b5b6e8 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 8 Jul 2025 20:21:20 +0200 Subject: [PATCH] run pre-commit --- lightrag/kg/memgraph_impl.py | 113 +++++++++++++++++++++++++---------- 1 file changed, 82 insertions(+), 31 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 4c16b843..8c6d6574 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -73,8 +73,12 @@ class MemgraphStorage(BaseGraphStorage): # Create index for base nodes on entity_id if it doesn't exist try: workspace_label = self._get_workspace_label() - await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""") - logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.") + await session.run( + f"""CREATE INDEX ON :{workspace_label}(entity_id)""" + ) + logger.info( + f"Created index on :{workspace_label}(entity_id) in Memgraph." + ) except Exception as e: # Index may already exist, which is not an error logger.warning( @@ -112,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error checking the node existence. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -122,7 +128,9 @@ class MemgraphStorage(BaseGraphStorage): result = await session.run(query, entity_id=node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["node_exists"] if single_result is not None else False + return ( + single_result["node_exists"] if single_result is not None else False + ) except Exception as e: logger.error(f"Error checking node existence for {node_id}: {str(e)}") await result.consume() # Ensure the result is consumed even on error @@ -143,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error checking the edge existence. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -153,10 +163,16 @@ class MemgraphStorage(BaseGraphStorage): f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) # type: ignore single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["edgeExists"] if single_result is not None else False + return ( + single_result["edgeExists"] if single_result is not None else False + ) except Exception as e: logger.error( f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" @@ -178,13 +194,17 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: workspace_label = self._get_workspace_label() - query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + query = ( + f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + ) result = await session.run(query, entity_id=node_id) try: records = await result.fetch( @@ -228,7 +248,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -265,7 +287,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -302,7 +326,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) try: async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -366,7 +392,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -414,7 +442,9 @@ class MemgraphStorage(BaseGraphStorage): node_data: Dictionary of node properties """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) properties = node_data entity_type = properties["entity_type"] if "entity_id" not in properties: @@ -423,14 +453,13 @@ class MemgraphStorage(BaseGraphStorage): try: async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() + async def execute_upsert(tx: AsyncManagedTransaction): - query = ( - f""" + query = f""" MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) SET n += $properties SET n:`{entity_type}` """ - ) result = await tx.run( query, entity_id=node_id, properties=properties ) @@ -458,7 +487,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) try: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: @@ -499,7 +530,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async def _do_delete(tx: AsyncManagedTransaction): workspace_label = self._get_workspace_label() @@ -525,7 +558,9 @@ class MemgraphStorage(BaseGraphStorage): nodes: List of node labels to be deleted """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) for node in nodes: await self.delete_node(node) @@ -539,7 +574,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): @@ -575,17 +612,23 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) try: async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}") + logger.info( + f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + ) return {"status": "success", "message": "workspace data dropped"} except Exception as e: - logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}") + logger.error( + f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + ) return {"status": "error", "message": str(e)} async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -599,7 +642,9 @@ class MemgraphStorage(BaseGraphStorage): int: Sum of the degrees of both nodes """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) @@ -621,7 +666,9 @@ class MemgraphStorage(BaseGraphStorage): An empty list if no matching nodes are found. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -653,7 +700,9 @@ class MemgraphStorage(BaseGraphStorage): An empty list if no matching edges are found. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -701,7 +750,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) result = KnowledgeGraph() seen_nodes = set() @@ -761,7 +812,7 @@ class MemgraphStorage(BaseGraphStorage): else: bfs_query = f""" - MATCH (start:`{workspace_label}`) + MATCH (start:`{workspace_label}`) WHERE start.entity_id = $entity_id WITH start CALL {{ @@ -774,7 +825,7 @@ class MemgraphStorage(BaseGraphStorage): RETURN all_nodes, all_rels }} WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes - WITH + WITH CASE WHEN total_nodes <= {max_nodes} THEN nodes ELSE nodes[0..{max_nodes}] @@ -782,7 +833,7 @@ class MemgraphStorage(BaseGraphStorage): relationships, total_nodes, total_nodes > {max_nodes} AS is_truncated - RETURN + RETURN [node IN limited_nodes | {{node: node}}] AS node_info, relationships, total_nodes,