diff --git a/env.example b/env.example
index f759ea92..5e43cf5e 100644
--- a/env.example
+++ b/env.example
@@ -134,13 +134,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
 # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
 ### Graph Storage (Recommended for production deployment)
 # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
+# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
 
 ####################################################################
 ### Default workspace for all storage types
 ### For the purpose of isolation of data for each LightRAG instance
 ### Valid characters: a-z, A-Z, 0-9, and _
 ####################################################################
-# WORKSPACE=doc—
+# WORKSPACE=space1
 
 ### PostgreSQL Configuration
 POSTGRES_HOST=localhost
@@ -179,3 +180,10 @@ QDRANT_URL=http://localhost:6333
 ### Redis
 REDIS_URI=redis://localhost:6379
 # REDIS_WORKSPACE=forced_workspace_name
+
+### Memgraph Configuration
+MEMGRAPH_URI=bolt://localhost:7687
+MEMGRAPH_USERNAME=
+MEMGRAPH_PASSWORD=
+MEMGRAPH_DATABASE=memgraph
+# MEMGRAPH_WORKSPACE=forced_workspace_name
diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py
index 397e5a99..8c6d6574 100644
--- a/lightrag/kg/memgraph_impl.py
+++ b/lightrag/kg/memgraph_impl.py
@@ -31,14 +31,23 @@ config.read("config.ini", "utf-8")
 @final
 @dataclass
 class MemgraphStorage(BaseGraphStorage):
-    def __init__(self, namespace, global_config, embedding_func):
+    def __init__(self, namespace, global_config, embedding_func, workspace=None):
+        memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
+        if memgraph_workspace and memgraph_workspace.strip():
+            workspace = memgraph_workspace
         super().__init__(
             namespace=namespace,
+            workspace=workspace or "",
             global_config=global_config,
             embedding_func=embedding_func,
         )
         self._driver = None
 
+    def _get_workspace_label(self) -> str:
+        """Get workspace label, return 'base' for compatibility when workspace is empty"""
+        workspace = getattr(self, "workspace", None)
+        return workspace if workspace else "base"
+
     async def initialize(self):
         URI = os.environ.get(
             "MEMGRAPH_URI",
@@ -63,12 +72,17 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(database=DATABASE) as session:
             # Create index for base nodes on entity_id if it doesn't exist
             try:
-                await session.run("""CREATE INDEX ON :base(entity_id)""")
-                logger.info("Created index on :base(entity_id) in Memgraph.")
+                workspace_label = self._get_workspace_label()
+                await session.run(
+                    f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
+                )
+                logger.info(
+                    f"Created index on :{workspace_label}(entity_id) in Memgraph."
+                )
             except Exception as e:
                 # Index may already exist, which is not an error
                 logger.warning(
-                    f"Index creation on :base(entity_id) may have failed or already exists: {e}"
+                    f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
                 )
             await session.run("RETURN 1")
             logger.info(f"Connected to Memgraph at {URI}")
@@ -101,15 +115,22 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error checking the node existence.
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
+                workspace_label = self._get_workspace_label()
+                query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
                 result = await session.run(query, entity_id=node_id)
                 single_result = await result.single()
                 await result.consume()  # Ensure result is fully consumed
-                return single_result["node_exists"]
+                return (
+                    single_result["node_exists"] if single_result is not None else False
+                )
             except Exception as e:
                 logger.error(f"Error checking node existence for {node_id}: {str(e)}")
                 await result.consume()  # Ensure the result is consumed even on error
@@ -129,22 +150,29 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error checking the edge existence.
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
+                workspace_label = self._get_workspace_label()
                 query = (
-                    "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
+                    f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
                     "RETURN COUNT(r) > 0 AS edgeExists"
                 )
                 result = await session.run(
                     query,
                     source_entity_id=source_node_id,
                     target_entity_id=target_node_id,
-                )
+                )  # type: ignore
                 single_result = await result.single()
                 await result.consume()  # Ensure result is fully consumed
-                return single_result["edgeExists"]
+                return (
+                    single_result["edgeExists"] if single_result is not None else False
+                )
             except Exception as e:
                 logger.error(
                     f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
@@ -165,11 +193,18 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
+                workspace_label = self._get_workspace_label()
+                query = (
+                    f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
+                )
                 result = await session.run(query, entity_id=node_id)
                 try:
                     records = await result.fetch(
@@ -183,12 +218,12 @@ class MemgraphStorage(BaseGraphStorage):
                 if records:
                     node = records[0]["n"]
                     node_dict = dict(node)
-                    # Remove base label from labels list if it exists
+                    # Remove workspace label from labels list if it exists
                     if "labels" in node_dict:
                         node_dict["labels"] = [
                             label
                             for label in node_dict["labels"]
-                            if label != "base"
+                            if label != workspace_label
                         ]
                     return node_dict
                 return None
@@ -212,12 +247,17 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = """
-                    MATCH (n:base {entity_id: $entity_id})
+                workspace_label = self._get_workspace_label()
+                query = f"""
+                    MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
                     OPTIONAL MATCH (n)-[r]-()
                     RETURN COUNT(r) AS degree
                 """
@@ -246,12 +286,17 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = """
-                    MATCH (n:base)
+                workspace_label = self._get_workspace_label()
+                query = f"""
+                    MATCH (n:`{workspace_label}`)
                     WHERE n.entity_id IS NOT NULL
                     RETURN DISTINCT n.entity_id AS label
                     ORDER BY label
@@ -280,13 +325,18 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         try:
             async with self._driver.session(
                 database=self._DATABASE, default_access_mode="READ"
             ) as session:
                 try:
-                    query = """MATCH (n:base {entity_id: $entity_id})
-                            OPTIONAL MATCH (n)-[r]-(connected:base)
+                    workspace_label = self._get_workspace_label()
+                    query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+                            OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
                             WHERE connected.entity_id IS NOT NULL
                             RETURN n, r, connected"""
                     results = await session.run(query, entity_id=source_node_id)
@@ -341,12 +391,17 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = """
-                    MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
+                workspace_label = self._get_workspace_label()
+                query = f"""
+                    MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
                     RETURN properties(r) as edge_properties
                 """
                 result = await session.run(
@@ -386,6 +441,10 @@ class MemgraphStorage(BaseGraphStorage):
             node_id: The unique identifier for the node (used as label)
             node_data: Dictionary of node properties
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         properties = node_data
         entity_type = properties["entity_type"]
         if "entity_id" not in properties:
@@ -393,16 +452,14 @@ class MemgraphStorage(BaseGraphStorage):
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
+                workspace_label = self._get_workspace_label()
 
                 async def execute_upsert(tx: AsyncManagedTransaction):
-                    query = (
-                        """
-                    MERGE (n:base {entity_id: $entity_id})
+                    query = f"""
+                    MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
                     SET n += $properties
-                    SET n:`%s`
+                    SET n:`{entity_type}`
                     """
-                        % entity_type
-                    )
                     result = await tx.run(
                         query, entity_id=node_id, properties=properties
                     )
@@ -429,15 +486,20 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         try:
             edge_properties = edge_data
             async with self._driver.session(database=self._DATABASE) as session:
 
                 async def execute_upsert(tx: AsyncManagedTransaction):
-                    query = """
-                    MATCH (source:base {entity_id: $source_entity_id})
+                    workspace_label = self._get_workspace_label()
+                    query = f"""
+                    MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
                     WITH source
-                    MATCH (target:base {entity_id: $target_entity_id})
+                    MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
                     MERGE (source)-[r:DIRECTED]-(target)
                     SET r += $properties
                     RETURN r, source, target
@@ -467,10 +529,15 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
 
         async def _do_delete(tx: AsyncManagedTransaction):
-            query = """
-            MATCH (n:base {entity_id: $entity_id})
+            workspace_label = self._get_workspace_label()
+            query = f"""
+            MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
             DETACH DELETE n
             """
             result = await tx.run(query, entity_id=node_id)
@@ -490,6 +557,10 @@ class MemgraphStorage(BaseGraphStorage):
         Args:
             nodes: List of node labels to be deleted
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         for node in nodes:
             await self.delete_node(node)
 
@@ -502,11 +573,16 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         for source, target in edges:
 
             async def _do_delete_edge(tx: AsyncManagedTransaction):
-                query = """
-                MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
+                workspace_label = self._get_workspace_label()
+                query = f"""
+                MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
                 DELETE r
                 """
                 result = await tx.run(
@@ -523,9 +599,9 @@ class MemgraphStorage(BaseGraphStorage):
                     raise
 
     async def drop(self) -> dict[str, str]:
-        """Drop all data from storage and clean up resources
+        """Drop all data from the current workspace and clean up resources
 
-        This method will delete all nodes and relationships in the Neo4j database.
+        This method will delete all nodes and relationships in the Memgraph database.
 
         Returns:
             dict[str, str]: Operation status and message
@@ -535,17 +611,24 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
+        workspace_label = self._get_workspace_label()
         try:
             async with self._driver.session(database=self._DATABASE) as session:
-                query = "MATCH (n) DETACH DELETE n"
+                query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
                 result = await session.run(query)
                 await result.consume()
                 logger.info(
-                    f"Process {os.getpid()} drop Memgraph database {self._DATABASE}"
+                    f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
                 )
-                return {"status": "success", "message": "data dropped"}
+                return {"status": "success", "message": "workspace data dropped"}
         except Exception as e:
-            logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
+            logger.error(
+                f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
+            )
             return {"status": "error", "message": str(e)}
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -558,6 +641,10 @@ class MemgraphStorage(BaseGraphStorage):
         Returns:
             int: Sum of the degrees of both nodes
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
         src_degree = await self.node_degree(src_id)
         trg_degree = await self.node_degree(tgt_id)
 
@@ -578,12 +665,17 @@ class MemgraphStorage(BaseGraphStorage):
             list[dict]: A list of nodes, where each node is a dictionary of its properties.
                 An empty list if no matching nodes are found.
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
+            query = f"""
             UNWIND $chunk_ids AS chunk_id
-            MATCH (n:base)
+            MATCH (n:`{workspace_label}`)
             WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
             RETURN DISTINCT n
             """
@@ -607,12 +699,17 @@ class MemgraphStorage(BaseGraphStorage):
             list[dict]: A list of edges, where each edge is a dictionary of its properties.
                 An empty list if no matching edges are found.
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
+            query = f"""
             UNWIND $chunk_ids AS chunk_id
-            MATCH (a:base)-[r]-(b:base)
+            MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
             WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
             WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
             // Ensure we only return each unique edge once by ordering the source and target
@@ -652,9 +749,15 @@ class MemgraphStorage(BaseGraphStorage):
         Raises:
             Exception: If there is an error executing the query
         """
+        if self._driver is None:
+            raise RuntimeError(
+                "Memgraph driver is not initialized. Call 'await initialize()' first."
+            )
+
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
@@ -682,19 +785,17 @@ class MemgraphStorage(BaseGraphStorage):
                     await count_result.consume()
 
                     # Run the main query to get nodes with highest degree
-                    main_query = """
-                    MATCH (n)
+                    main_query = f"""
+                    MATCH (n:`{workspace_label}`)
                     OPTIONAL MATCH (n)-[r]-()
                     WITH n, COALESCE(count(r), 0) AS degree
                     ORDER BY degree DESC
                     LIMIT $max_nodes
-                    WITH collect({node: n}) AS filtered_nodes
-                    UNWIND filtered_nodes AS node_info
-                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
-                    OPTIONAL MATCH (a)-[r]-(b)
+                    WITH collect(n) AS kept_nodes
+                    MATCH (a)-[r]-(b)
                     WHERE a IN kept_nodes AND b IN kept_nodes
-                    RETURN filtered_nodes AS node_info,
-                           collect(DISTINCT r) AS relationships
+                    RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = None
                     try:
@@ -710,31 +811,33 @@ class MemgraphStorage(BaseGraphStorage):
                            await result_set.consume()
 
                 else:
-                    bfs_query = """
-                    MATCH (start) WHERE start.entity_id = $entity_id
+                    bfs_query = f"""
+                    MATCH (start:`{workspace_label}`)
+                    WHERE start.entity_id = $entity_id
                     WITH start
-                    CALL {
+                    CALL {{
                         WITH start
-                        MATCH path = (start)-[*0..$max_depth]-(node)
+                        MATCH path = (start)-[*0..{max_depth}]-(node)
                         WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
                         UNWIND path_nodes AS n
                         WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
                         WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
                         RETURN all_nodes, all_rels
-                    }
+                    }}
                     WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
-
-                    // Apply node limiting here
-                    WITH CASE
-                        WHEN total_nodes <= $max_nodes THEN nodes
-                        ELSE nodes[0..$max_nodes]
+                    WITH
+                        CASE
+                            WHEN total_nodes <= {max_nodes} THEN nodes
+                            ELSE nodes[0..{max_nodes}]
                         END AS limited_nodes,
                         relationships,
                         total_nodes,
-                        total_nodes > $max_nodes AS is_truncated
-                    UNWIND limited_nodes AS node
-                    WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated
-                    RETURN node_info, relationships, total_nodes, is_truncated
+                        total_nodes > {max_nodes} AS is_truncated
+                    RETURN
+                        [node IN limited_nodes | {{node: node}}] AS node_info,
+                        relationships,
+                        total_nodes,
+                        is_truncated
                     """
                     result_set = None
                     try:
@@ -742,8 +845,6 @@ class MemgraphStorage(BaseGraphStorage):
                             bfs_query,
                             {
                                 "entity_id": node_label,
-                                "max_depth": max_depth,
-                                "max_nodes": max_nodes,
                             },
                         )
                         record = await result_set.single()
@@ -777,22 +878,21 @@ class MemgraphStorage(BaseGraphStorage):
                                 )
                             )
 
-                if "relationships" in record and record["relationships"]:
-                    for rel in record["relationships"]:
-                        edge_id = rel.id
-                        if edge_id not in seen_edges:
-                            seen_edges.add(edge_id)
-                            start = rel.start_node
-                            end = rel.end_node
-                            result.edges.append(
-                                KnowledgeGraphEdge(
-                                    id=f"{edge_id}",
-                                    type=rel.type,
-                                    source=f"{start.id}",
-                                    target=f"{end.id}",
-                                    properties=dict(rel),
-                                )
+                for rel in record["relationships"]:
+                    edge_id = rel.id
+                    if edge_id not in seen_edges:
+                        seen_edges.add(edge_id)
+                        start = rel.start_node
+                        end = rel.end_node
+                        result.edges.append(
+                            KnowledgeGraphEdge(
+                                id=f"{edge_id}",
+                                type=rel.type,
+                                source=f"{start.id}",
+                                target=f"{end.id}",
+                                properties=dict(rel),
                             )
+                        )
 
         logger.info(
             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
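For reference, a minimal sketch (not part of the diff) of how the workspace label resolution introduced above behaves, assuming only what the diff shows: MEMGRAPH_WORKSPACE overrides the workspace passed to __init__, and an empty workspace falls back to the legacy "base" label so existing data keeps matching. resolve_workspace_label is a hypothetical helper, not part of MemgraphStorage.

# Illustrative only: mirrors __init__'s MEMGRAPH_WORKSPACE override and
# _get_workspace_label() from the diff above.
import os


def resolve_workspace_label(workspace: str | None) -> str:
    forced = os.environ.get("MEMGRAPH_WORKSPACE")
    if forced and forced.strip():
        workspace = forced  # forced workspace wins, as in __init__
    workspace = workspace or ""  # matches workspace=workspace or ""
    return workspace if workspace else "base"  # legacy label when empty


# resolve_workspace_label(None)       -> "base"   (queries match (n:`base` ...))
# with MEMGRAPH_WORKSPACE=space1 set  -> "space1" (queries match (n:`space1` ...))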