mirror of
				https://github.com/HKUDS/LightRAG.git
				synced 2025-11-03 19:29:38 +00:00 
			
		
		
		
	Remove unused node embedding functionality from graph storage
- Deleted embed_nodes() method implementations
This commit is contained in:
		
							parent
							
								
									c084358dc9
								
							
						
					
					
						commit
						83353ab9a6
					
				@ -6,7 +6,6 @@ import sys
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pipmaster as pm
 | 
			
		||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 | 
			
		||||
 | 
			
		||||
@ -668,21 +667,6 @@ class AGEStorage(BaseGraphStorage):
 | 
			
		||||
                logger.error(f"Error during edge deletion: {str(e)}")
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
    async def embed_nodes(
 | 
			
		||||
        self, algorithm: str
 | 
			
		||||
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
 | 
			
		||||
        """Embed nodes using the specified algorithm
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            algorithm: Name of the embedding algorithm
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tuple: (embedding matrix, list of node identifiers)
 | 
			
		||||
        """
 | 
			
		||||
        if algorithm not in self._node_embed_algorithms:
 | 
			
		||||
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
 | 
			
		||||
        return await self._node_embed_algorithms[algorithm]()
 | 
			
		||||
 | 
			
		||||
    async def get_all_labels(self) -> list[str]:
 | 
			
		||||
        """Get all node labels in the database
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,9 +6,6 @@ import pipmaster as pm
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, Dict, List, final
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from tenacity import (
 | 
			
		||||
    retry,
 | 
			
		||||
    retry_if_exception_type,
 | 
			
		||||
@ -419,27 +416,6 @@ class GremlinStorage(BaseGraphStorage):
 | 
			
		||||
            logger.error(f"Error during node deletion: {str(e)}")
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    async def embed_nodes(
 | 
			
		||||
        self, algorithm: str
 | 
			
		||||
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
 | 
			
		||||
        """
 | 
			
		||||
        Embed nodes using the specified algorithm.
 | 
			
		||||
        Currently, only node2vec is supported but never called.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            algorithm: The name of the embedding algorithm to use
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of (embeddings, node_ids)
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            NotImplementedError: If the specified algorithm is not supported
 | 
			
		||||
            ValueError: If the algorithm is not supported
 | 
			
		||||
        """
 | 
			
		||||
        if algorithm not in self._node_embed_algorithms:
 | 
			
		||||
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
 | 
			
		||||
        return await self._node_embed_algorithms[algorithm]()
 | 
			
		||||
 | 
			
		||||
    async def get_all_labels(self) -> list[str]:
 | 
			
		||||
        """
 | 
			
		||||
        Get all node entity_names in the graph
 | 
			
		||||
 | 
			
		||||
@ -663,20 +663,6 @@ class MongoGraphStorage(BaseGraphStorage):
 | 
			
		||||
        # Remove the node doc
 | 
			
		||||
        await self.collection.delete_one({"_id": node_id})
 | 
			
		||||
 | 
			
		||||
    #
 | 
			
		||||
    # -------------------------------------------------------------------------
 | 
			
		||||
    # EMBEDDINGS (NOT IMPLEMENTED)
 | 
			
		||||
    # -------------------------------------------------------------------------
 | 
			
		||||
    #
 | 
			
		||||
 | 
			
		||||
    async def embed_nodes(
 | 
			
		||||
        self, algorithm: str
 | 
			
		||||
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
 | 
			
		||||
        """
 | 
			
		||||
        Placeholder for demonstration, raises NotImplementedError.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError("Node embedding is not used in lightrag.")
 | 
			
		||||
 | 
			
		||||
    #
 | 
			
		||||
    # -------------------------------------------------------------------------
 | 
			
		||||
    # QUERY
 | 
			
		||||
 | 
			
		||||
@ -2,8 +2,7 @@ import inspect
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, final
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import final
 | 
			
		||||
import configparser
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1126,11 +1125,6 @@ class Neo4JStorage(BaseGraphStorage):
 | 
			
		||||
                logger.error(f"Error during edge deletion: {str(e)}")
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
    async def embed_nodes(
 | 
			
		||||
        self, algorithm: str
 | 
			
		||||
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    async def drop(self) -> dict[str, str]:
 | 
			
		||||
        """Drop all data from storage and clean up resources
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
import os
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, final
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import final
 | 
			
		||||
 | 
			
		||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 | 
			
		||||
from lightrag.utils import logger
 | 
			
		||||
@ -16,7 +15,6 @@ if not pm.is_installed("graspologic"):
 | 
			
		||||
    pm.install("graspologic")
 | 
			
		||||
 | 
			
		||||
import networkx as nx
 | 
			
		||||
from graspologic import embed
 | 
			
		||||
from .shared_storage import (
 | 
			
		||||
    get_storage_lock,
 | 
			
		||||
    get_update_flag,
 | 
			
		||||
@ -42,40 +40,6 @@ class NetworkXStorage(BaseGraphStorage):
 | 
			
		||||
        )
 | 
			
		||||
        nx.write_graphml(graph, file_name)
 | 
			
		||||
 | 
			
		||||
    # TODO:deprecated, remove later
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
 | 
			
		||||
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
 | 
			
		||||
        Ensure an undirected graph with the same relationships will always be read the same way.
 | 
			
		||||
        """
 | 
			
		||||
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
 | 
			
		||||
 | 
			
		||||
        sorted_nodes = graph.nodes(data=True)
 | 
			
		||||
        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
 | 
			
		||||
 | 
			
		||||
        fixed_graph.add_nodes_from(sorted_nodes)
 | 
			
		||||
        edges = list(graph.edges(data=True))
 | 
			
		||||
 | 
			
		||||
        if not graph.is_directed():
 | 
			
		||||
 | 
			
		||||
            def _sort_source_target(edge):
 | 
			
		||||
                source, target, edge_data = edge
 | 
			
		||||
                if source > target:
 | 
			
		||||
                    temp = source
 | 
			
		||||
                    source = target
 | 
			
		||||
                    target = temp
 | 
			
		||||
                return source, target, edge_data
 | 
			
		||||
 | 
			
		||||
            edges = [_sort_source_target(edge) for edge in edges]
 | 
			
		||||
 | 
			
		||||
        def _get_edge_key(source: Any, target: Any) -> str:
 | 
			
		||||
            return f"{source} -> {target}"
 | 
			
		||||
 | 
			
		||||
        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
 | 
			
		||||
 | 
			
		||||
        fixed_graph.add_edges_from(edges)
 | 
			
		||||
        return fixed_graph
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self._graphml_xml_file = os.path.join(
 | 
			
		||||
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
 | 
			
		||||
@ -191,24 +155,6 @@ class NetworkXStorage(BaseGraphStorage):
 | 
			
		||||
        else:
 | 
			
		||||
            logger.warning(f"Node {node_id} not found in the graph for deletion.")
 | 
			
		||||
 | 
			
		||||
    # TODO: NOT USED
 | 
			
		||||
    async def embed_nodes(
 | 
			
		||||
        self, algorithm: str
 | 
			
		||||
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
 | 
			
		||||
        if algorithm not in self._node_embed_algorithms:
 | 
			
		||||
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
 | 
			
		||||
        return await self._node_embed_algorithms[algorithm]()
 | 
			
		||||
 | 
			
		||||
    # TODO: NOT USED
 | 
			
		||||
    async def _node2vec_embed(self):
 | 
			
		||||
        graph = await self._get_graph()
 | 
			
		||||
        embeddings, nodes = embed.node2vec_embed(
 | 
			
		||||
            graph,
 | 
			
		||||
            **self.global_config["node2vec_params"],
 | 
			
		||||
        )
 | 
			
		||||
        nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
 | 
			
		||||
        return embeddings, nodes_ids
 | 
			
		||||
 | 
			
		||||
    async def remove_nodes(self, nodes: list[str]):
 | 
			
		||||
        """Delete multiple nodes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1485,24 +1485,6 @@ class PGGraphStorage(BaseGraphStorage):
 | 
			
		||||
        labels = [result["label"] for result in results]
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
    async def embed_nodes(
 | 
			
		||||
        self, algorithm: str
 | 
			
		||||
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
 | 
			
		||||
        """
 | 
			
		||||
        Generate node embeddings using the specified algorithm.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            algorithm (str): The name of the embedding algorithm to use.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs.
 | 
			
		||||
        """
 | 
			
		||||
        if algorithm not in self._node_embed_algorithms:
 | 
			
		||||
            raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
 | 
			
		||||
 | 
			
		||||
        embed_func = self._node_embed_algorithms[algorithm]
 | 
			
		||||
        return await embed_func()
 | 
			
		||||
 | 
			
		||||
    async def get_knowledge_graph(
 | 
			
		||||
        self,
 | 
			
		||||
        node_label: str,
 | 
			
		||||
 | 
			
		||||
@ -800,13 +800,6 @@ class TiDBGraphStorage(BaseGraphStorage):
 | 
			
		||||
        }
 | 
			
		||||
        await self.db.execute(merge_sql, data)
 | 
			
		||||
 | 
			
		||||
    async def embed_nodes(
 | 
			
		||||
        self, algorithm: str
 | 
			
		||||
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
 | 
			
		||||
        if algorithm not in self._node_embed_algorithms:
 | 
			
		||||
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
 | 
			
		||||
        return await self._node_embed_algorithms[algorithm]()
 | 
			
		||||
 | 
			
		||||
    # Query
 | 
			
		||||
 | 
			
		||||
    async def has_node(self, node_id: str) -> bool:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user