""" NetworkX Storage Module ======================= This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. Author: lightrag team Created: 2024-01-25 License: MIT Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. Version: 1.0.0 Dependencies: - NetworkX - NumPy - LightRAG - graspologic Features: - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - Query graph nodes and edges - Calculate node and edge degrees - Embed nodes using various algorithms (e.g., Node2Vec) - Remove nodes and edges from the graph Usage: from lightrag.storage.networkx_storage import NetworkXStorage """ import html import os from dataclasses import dataclass from typing import Any, Union, cast import networkx as nx import numpy as np from lightrag.utils import ( logger, ) from lightrag.base import ( BaseGraphStorage, ) @dataclass class NetworkXStorage(BaseGraphStorage): @staticmethod def load_nx_graph(file_name) -> nx.Graph: if os.path.exists(file_name): return nx.read_graphml(file_name) return None @staticmethod def write_nx_graph(graph: nx.Graph, file_name): logger.info( f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" ) nx.write_graphml(graph, file_name) @staticmethod def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py Return the largest connected component of the graph, with nodes and edges sorted in a stable way. """ from graspologic.utils import largest_connected_component graph = graph.copy() graph = cast(nx.Graph, largest_connected_component(graph)) node_mapping = { node: html.unescape(node.upper().strip()) for node in graph.nodes() } # type: ignore graph = nx.relabel_nodes(graph, node_mapping) return NetworkXStorage._stabilize_graph(graph) @staticmethod def _stabilize_graph(graph: nx.Graph) -> nx.Graph: """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py Ensure an undirected graph with the same relationships will always be read the same way. """ fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() sorted_nodes = graph.nodes(data=True) sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) fixed_graph.add_nodes_from(sorted_nodes) edges = list(graph.edges(data=True)) if not graph.is_directed(): def _sort_source_target(edge): source, target, edge_data = edge if source > target: temp = source source = target target = temp return source, target, edge_data edges = [_sort_source_target(edge) for edge in edges] def _get_edge_key(source: Any, target: Any) -> str: return f"{source} -> {target}" edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) fixed_graph.add_edges_from(edges) return fixed_graph def __post_init__(self): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) if preloaded_graph is not None: logger.info( f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) self._graph = preloaded_graph or nx.Graph() self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } async def index_done_callback(self): NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: return self._graph.has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: return self._graph.has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> Union[dict, None]: return self._graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: return self._graph.degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: return self._graph.degree(src_id) + self._graph.degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: return self._graph.edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str): if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id)) return None async def upsert_node(self, node_id: str, node_data: dict[str, str]): self._graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str): """ Delete a node from the graph based on the specified node_id. :param node_id: The node_id to delete """ if self._graph.has_node(node_id): self._graph.remove_node(node_id) logger.info(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() # @TODO: NOT USED async def _node2vec_embed(self): from graspologic import embed embeddings, nodes = embed.node2vec_embed( self._graph, **self.global_config["node2vec_params"], ) nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids def remove_nodes(self, nodes: list[str]): """Delete multiple nodes Args: nodes: List of node IDs to be deleted """ for node in nodes: if self._graph.has_node(node): self._graph.remove_node(node) def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ for source, target in edges: if self._graph.has_edge(source, target): self._graph.remove_edge(source, target)