Merge branch 'add-Memgraph-graph-db' into memgraph

This commit is contained in:
yangdx 2025-07-09 03:38:07 +08:00
commit 14d51518dd
2 changed files with 191 additions and 83 deletions

View File

@ -134,13 +134,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
### Graph Storage (Recommended for production deployment) ### Graph Storage (Recommended for production deployment)
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
#################################################################### ####################################################################
### Default workspace for all storage types ### Default workspace for all storage types
### For the purpose of isolation of data for each LightRAG instance ### For the purpose of isolation of data for each LightRAG instance
### Valid characters: a-z, A-Z, 0-9, and _ ### Valid characters: a-z, A-Z, 0-9, and _
#################################################################### ####################################################################
# WORKSPACE=doc— # WORKSPACE=space1
### PostgreSQL Configuration ### PostgreSQL Configuration
POSTGRES_HOST=localhost POSTGRES_HOST=localhost
@ -179,3 +180,10 @@ QDRANT_URL=http://localhost:6333
### Redis ### Redis
REDIS_URI=redis://localhost:6379 REDIS_URI=redis://localhost:6379
# REDIS_WORKSPACE=forced_workspace_name # REDIS_WORKSPACE=forced_workspace_name
### Memgraph Configuration
MEMGRAPH_URI=bolt://localhost:7687
MEMGRAPH_USERNAME=
MEMGRAPH_PASSWORD=
MEMGRAPH_DATABASE=memgraph
# MEMGRAPH_WORKSPACE=forced_workspace_name

View File

@ -31,14 +31,23 @@ config.read("config.ini", "utf-8")
@final @final
@dataclass @dataclass
class MemgraphStorage(BaseGraphStorage): class MemgraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func): def __init__(self, namespace, global_config, embedding_func, workspace=None):
memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
if memgraph_workspace and memgraph_workspace.strip():
workspace = memgraph_workspace
super().__init__( super().__init__(
namespace=namespace, namespace=namespace,
workspace=workspace or "",
global_config=global_config, global_config=global_config,
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self._driver = None self._driver = None
def _get_workspace_label(self) -> str:
"""Get workspace label, return 'base' for compatibility when workspace is empty"""
workspace = getattr(self, "workspace", None)
return workspace if workspace else "base"
async def initialize(self): async def initialize(self):
URI = os.environ.get( URI = os.environ.get(
"MEMGRAPH_URI", "MEMGRAPH_URI",
@ -63,12 +72,17 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(database=DATABASE) as session: async with self._driver.session(database=DATABASE) as session:
# Create index for base nodes on entity_id if it doesn't exist # Create index for workspace-labelled nodes on entity_id if it doesn't exist
try: try:
await session.run("""CREATE INDEX ON :base(entity_id)""") workspace_label = self._get_workspace_label()
logger.info("Created index on :base(entity_id) in Memgraph.") await session.run(
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
)
logger.info(
f"Created index on :{workspace_label}(entity_id) in Memgraph."
)
except Exception as e: except Exception as e:
# Index may already exist, which is not an error # Index may already exist, which is not an error
logger.warning( logger.warning(
f"Index creation on :base(entity_id) may have failed or already exists: {e}" f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
) )
await session.run("RETURN 1") await session.run("RETURN 1")
logger.info(f"Connected to Memgraph at {URI}") logger.info(f"Connected to Memgraph at {URI}")
@ -101,15 +115,22 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error checking the node existence. Exception: If there is an error checking the node existence.
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id) result = await session.run(query, entity_id=node_id)
single_result = await result.single() single_result = await result.single()
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
return single_result["node_exists"] return (
single_result["node_exists"] if single_result is not None else False
)
except Exception as e: except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}") logger.error(f"Error checking node existence for {node_id}: {str(e)}")
await result.consume() # Ensure the result is consumed even on error await result.consume() # Ensure the result is consumed even on error
@ -129,22 +150,29 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error checking the edge existence. Exception: If there is an error checking the edge existence.
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
workspace_label = self._get_workspace_label()
query = ( query = (
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
"RETURN COUNT(r) > 0 AS edgeExists" "RETURN COUNT(r) > 0 AS edgeExists"
) )
result = await session.run( result = await session.run(
query, query,
source_entity_id=source_node_id, source_entity_id=source_node_id,
target_entity_id=target_node_id, target_entity_id=target_node_id,
) ) # type: ignore
single_result = await result.single() single_result = await result.single()
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
return single_result["edgeExists"] return (
single_result["edgeExists"] if single_result is not None else False
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
@ -165,11 +193,18 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" workspace_label = self._get_workspace_label()
query = (
f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
)
result = await session.run(query, entity_id=node_id) result = await session.run(query, entity_id=node_id)
try: try:
records = await result.fetch( records = await result.fetch(
@ -183,12 +218,12 @@ class MemgraphStorage(BaseGraphStorage):
if records: if records:
node = records[0]["n"] node = records[0]["n"]
node_dict = dict(node) node_dict = dict(node)
# Remove base label from labels list if it exists # Remove workspace label from labels list if it exists
if "labels" in node_dict: if "labels" in node_dict:
node_dict["labels"] = [ node_dict["labels"] = [
label label
for label in node_dict["labels"] for label in node_dict["labels"]
if label != "base" if label != workspace_label
] ]
return node_dict return node_dict
return None return None
@ -212,12 +247,17 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
query = """ workspace_label = self._get_workspace_label()
MATCH (n:base {entity_id: $entity_id}) query = f"""
MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
OPTIONAL MATCH (n)-[r]-() OPTIONAL MATCH (n)-[r]-()
RETURN COUNT(r) AS degree RETURN COUNT(r) AS degree
""" """
@ -246,12 +286,17 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
query = """ workspace_label = self._get_workspace_label()
MATCH (n:base) query = f"""
MATCH (n:`{workspace_label}`)
WHERE n.entity_id IS NOT NULL WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label RETURN DISTINCT n.entity_id AS label
ORDER BY label ORDER BY label
@ -280,13 +325,18 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try: try:
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
query = """MATCH (n:base {entity_id: $entity_id}) workspace_label = self._get_workspace_label()
OPTIONAL MATCH (n)-[r]-(connected:base) query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
WHERE connected.entity_id IS NOT NULL WHERE connected.entity_id IS NOT NULL
RETURN n, r, connected""" RETURN n, r, connected"""
results = await session.run(query, entity_id=source_node_id) results = await session.run(query, entity_id=source_node_id)
@ -341,12 +391,17 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
query = """ workspace_label = self._get_workspace_label()
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) query = f"""
MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
""" """
result = await session.run( result = await session.run(
@ -386,6 +441,10 @@ class MemgraphStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (stored as its entity_id property)
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
properties = node_data properties = node_data
entity_type = properties["entity_type"] entity_type = properties["entity_type"]
if "entity_id" not in properties: if "entity_id" not in properties:
@ -393,16 +452,14 @@ class MemgraphStorage(BaseGraphStorage):
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
query = ( query = f"""
""" MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
MERGE (n:base {entity_id: $entity_id})
SET n += $properties SET n += $properties
SET n:`%s` SET n:`{entity_type}`
""" """
% entity_type
)
result = await tx.run( result = await tx.run(
query, entity_id=node_id, properties=properties query, entity_id=node_id, properties=properties
) )
@ -429,15 +486,20 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try: try:
edge_properties = edge_data edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
query = """ workspace_label = self._get_workspace_label()
MATCH (source:base {entity_id: $source_entity_id}) query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source WITH source
MATCH (target:base {entity_id: $target_entity_id}) MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target) MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties SET r += $properties
RETURN r, source, target RETURN r, source, target
@ -467,10 +529,15 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async def _do_delete(tx: AsyncManagedTransaction): async def _do_delete(tx: AsyncManagedTransaction):
query = """ workspace_label = self._get_workspace_label()
MATCH (n:base {entity_id: $entity_id}) query = f"""
MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
DETACH DELETE n DETACH DELETE n
""" """
result = await tx.run(query, entity_id=node_id) result = await tx.run(query, entity_id=node_id)
@ -490,6 +557,10 @@ class MemgraphStorage(BaseGraphStorage):
Args: Args:
nodes: List of node labels to be deleted nodes: List of node labels to be deleted
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for node in nodes: for node in nodes:
await self.delete_node(node) await self.delete_node(node)
@ -502,11 +573,16 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for source, target in edges: for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction): async def _do_delete_edge(tx: AsyncManagedTransaction):
query = """ workspace_label = self._get_workspace_label()
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
DELETE r DELETE r
""" """
result = await tx.run( result = await tx.run(
@ -523,9 +599,9 @@ class MemgraphStorage(BaseGraphStorage):
raise raise
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources """Drop all data from the current workspace and clean up resources
This method will delete all nodes and relationships in the Neo4j database. This method will delete all nodes and relationships in the current workspace of the Memgraph database.
Returns: Returns:
dict[str, str]: Operation status and message dict[str, str]: Operation status and message
@ -535,17 +611,24 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
query = "MATCH (n) DETACH DELETE n" workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query) result = await session.run(query)
await result.consume() await result.consume()
logger.info( logger.info(
f"Process {os.getpid()} drop Memgraph database {self._DATABASE}" f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "workspace data dropped"}
except Exception as e: except Exception as e:
logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") logger.error(
f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@ -558,6 +641,10 @@ class MemgraphStorage(BaseGraphStorage):
Returns: Returns:
int: Sum of the degrees of both nodes int: Sum of the degrees of both nodes
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
src_degree = await self.node_degree(src_id) src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id) trg_degree = await self.node_degree(tgt_id)
@ -578,12 +665,17 @@ class MemgraphStorage(BaseGraphStorage):
list[dict]: A list of nodes, where each node is a dictionary of its properties. list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found. An empty list if no matching nodes are found.
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
query = """ query = f"""
UNWIND $chunk_ids AS chunk_id UNWIND $chunk_ids AS chunk_id
MATCH (n:base) MATCH (n:`{workspace_label}`)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n RETURN DISTINCT n
""" """
@ -607,12 +699,17 @@ class MemgraphStorage(BaseGraphStorage):
list[dict]: A list of edges, where each edge is a dictionary of its properties. list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found. An empty list if no matching edges are found.
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
query = """ query = f"""
UNWIND $chunk_ids AS chunk_id UNWIND $chunk_ids AS chunk_id
MATCH (a:base)-[r]-(b:base) MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
// Ensure we only return each unique edge once by ordering the source and target // Ensure we only return each unique edge once by ordering the source and target
@ -652,9 +749,15 @@ class MemgraphStorage(BaseGraphStorage):
Raises: Raises:
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
workspace_label = self._get_workspace_label()
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
@ -682,19 +785,17 @@ class MemgraphStorage(BaseGraphStorage):
await count_result.consume() await count_result.consume()
# Run the main query to get nodes with highest degree # Run the main query to get nodes with highest degree
main_query = """ main_query = f"""
MATCH (n) MATCH (n:`{workspace_label}`)
OPTIONAL MATCH (n)-[r]-() OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree WITH n, COALESCE(count(r), 0) AS degree
ORDER BY degree DESC ORDER BY degree DESC
LIMIT $max_nodes LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes WITH collect(n) AS kept_nodes
UNWIND filtered_nodes AS node_info MATCH (a)-[r]-(b)
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info, RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
collect(DISTINCT r) AS relationships collect(DISTINCT r) AS relationships
""" """
result_set = None result_set = None
try: try:
@ -710,31 +811,33 @@ class MemgraphStorage(BaseGraphStorage):
await result_set.consume() await result_set.consume()
else: else:
bfs_query = """ bfs_query = f"""
MATCH (start) WHERE start.entity_id = $entity_id MATCH (start:`{workspace_label}`)
WHERE start.entity_id = $entity_id
WITH start WITH start
CALL { CALL {{
WITH start WITH start
MATCH path = (start)-[*0..$max_depth]-(node) MATCH path = (start)-[*0..{max_depth}]-(node)
WITH nodes(path) AS path_nodes, relationships(path) AS path_rels WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
UNWIND path_nodes AS n UNWIND path_nodes AS n
WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
RETURN all_nodes, all_rels RETURN all_nodes, all_rels
} }}
WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
WITH
// Apply node limiting here CASE
WITH CASE WHEN total_nodes <= {max_nodes} THEN nodes
WHEN total_nodes <= $max_nodes THEN nodes ELSE nodes[0..{max_nodes}]
ELSE nodes[0..$max_nodes]
END AS limited_nodes, END AS limited_nodes,
relationships, relationships,
total_nodes, total_nodes,
total_nodes > $max_nodes AS is_truncated total_nodes > {max_nodes} AS is_truncated
UNWIND limited_nodes AS node RETURN
WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated [node IN limited_nodes | {{node: node}}] AS node_info,
RETURN node_info, relationships, total_nodes, is_truncated relationships,
total_nodes,
is_truncated
""" """
result_set = None result_set = None
try: try:
@ -742,8 +845,6 @@ class MemgraphStorage(BaseGraphStorage):
bfs_query, bfs_query,
{ {
"entity_id": node_label, "entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
}, },
) )
record = await result_set.single() record = await result_set.single()
@ -777,22 +878,21 @@ class MemgraphStorage(BaseGraphStorage):
) )
) )
if "relationships" in record and record["relationships"]: for rel in record["relationships"]:
for rel in record["relationships"]: edge_id = rel.id
edge_id = rel.id if edge_id not in seen_edges:
if edge_id not in seen_edges: seen_edges.add(edge_id)
seen_edges.add(edge_id) start = rel.start_node
start = rel.start_node end = rel.end_node
end = rel.end_node result.edges.append(
result.edges.append( KnowledgeGraphEdge(
KnowledgeGraphEdge( id=f"{edge_id}",
id=f"{edge_id}", type=rel.type,
type=rel.type, source=f"{start.id}",
source=f"{start.id}", target=f"{end.id}",
target=f"{end.id}", properties=dict(rel),
properties=dict(rel),
)
) )
)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"