feat: optimize node and edge queries in PostgreSQL. query tables Directly

This commit is contained in:
Matt23-star 2025-08-16 22:37:48 +08:00
parent a7da48e05c
commit 6a7e3092ea

View File

@ -9,6 +9,7 @@ from typing import Any, Union, final
import numpy as np import numpy as np
import configparser import configparser
import ssl import ssl
import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage):
query: str, query: str,
readonly: bool = True, readonly: bool = True,
upsert: bool = False, upsert: bool = False,
params: dict[str, Any] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Query the graph by taking a cypher query, converting it to an Query the graph by taking a cypher query, converting it to an
@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage):
if readonly: if readonly:
data = await self.db.query( data = await self.db.query(
query, query,
params=self.params,
multirows=True, multirows=True,
with_age=True, with_age=True,
graph_name=self.graph_name, graph_name=self.graph_name,
@ -3398,41 +3401,55 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) nid_norm = self._normalize_node_id(nid)
if nid_norm not in seen:
query = """SELECT * FROM cypher('%s', $$ seen.add(nid_norm)
UNWIND [%s] AS node_id unique_ids.append(nid_norm)
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
results = await self._query(query)
# Build result dictionary # Build result dictionary
nodes_dict = {} nodes_dict = {}
for result in results:
if result["node_id"] and result["n"]:
node_dict = result["n"]["properties"]
# Process string result, parse it to JSON dictionary for i in range(0, len(unique_ids), batch_size):
if isinstance(node_dict, str): batch = unique_ids[i : i + batch_size]
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
# Remove the 'base' label if present in a 'labels' property query = f"""
# if "labels" in node_dict: WITH input(v, ord) AS (
# node_dict["labels"] = [ SELECT v, ord
# label for label in node_dict["labels"] if label != "base" FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
# ] ),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
)
SELECT i.node_id::text AS node_id,
b.properties
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
ORDER BY i.ord;
"""
nodes_dict[result["node_id"]] = node_dict results = await self._query(query, params={"ids": batch})
for result in results:
if result["node_id"] and result["properties"]:
node_dict = result["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
)
nodes_dict[result["node_id"]] = node_dict
return nodes_dict return nodes_dict
@ -3453,44 +3470,66 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids: list[str] = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) n = self._normalize_node_id(nid)
if n not in seen:
outgoing_query = """SELECT * FROM cypher('%s', $$ seen.add(n)
UNWIND [%s] AS node_id unique_ids.append(n)
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[r]->(a)
RETURN node_id, count(a) AS out_degree
$$) AS (node_id text, out_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)<-[r]-(b)
RETURN node_id, count(b) AS in_degree
$$) AS (node_id text, in_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
out_degrees = {} out_degrees = {}
in_degrees = {} in_degrees = {}
for result in outgoing_results: for i in range(0, len(unique_ids), batch_size):
if result["node_id"] is not None: batch = unique_ids[i:i + batch_size]
out_degrees[result["node_id"]] = int(result["out_degree"])
for result in incoming_results: query = f"""
if result["node_id"] is not None: WITH input(v, ord) AS (
in_degrees[result["node_id"]] = int(result["in_degree"]) SELECT v, ord
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
),
vids AS (
SELECT b.id AS vid, i.node_id, i.ord
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
),
deg_out AS (
SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.start_id
GROUP BY d.start_id
),
deg_in AS (
SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.end_id
GROUP BY d.end_id
)
SELECT v.node_id::text AS node_id,
COALESCE(o.out_degree, 0) AS out_degree,
COALESCE(n.in_degree, 0) AS in_degree
FROM vids v
LEFT JOIN deg_out o ON o.vid = v.vid
LEFT JOIN deg_in n ON n.vid = v.vid
ORDER BY v.ord;
"""
combined_results = await self._query(query, params={"ids": batch})
for row in combined_results:
node_id = row["node_id"]
if not node_id:
continue
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
degrees_dict = {} degrees_dict = {}
for node_id in node_ids: for node_id in node_ids:
@ -3625,22 +3664,10 @@ class PGGraphStorage(BaseGraphStorage):
edges_dict[(result["source"], result["target"])] = edge_props edges_dict[(result["source"], result["target"])] = edge_props
for result in backward_results:
if result["source"] and result["target"] and result["edge_properties"]:
edge_props = result["edge_properties"]
for result in backward_results: for result in backward_results:
if result["source"] and result["target"] and result["edge_properties"]: if result["source"] and result["target"] and result["edge_properties"]:
edge_props = result["edge_properties"] edge_props = result["edge_properties"]
# Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
# Process string result, parse it to JSON dictionary # Process string result, parse it to JSON dictionary
if isinstance(edge_props, str): if isinstance(edge_props, str):
try: try:
@ -3671,49 +3698,62 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids: list[str] = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) n = self._normalize_node_id(nid)
if n and n not in seen:
seen.add(n)
unique_ids.append(n)
outgoing_query = """SELECT * FROM cypher('%s', $$ edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)-[]->(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$ for i in range(0, len(unique_ids), batch_size):
UNWIND [%s] AS node_id batch = unique_ids[i:i + batch_size]
MATCH (n:base {entity_id: node_id}) # Format node IDs for the query
OPTIONAL MATCH (n:base)<-[]-(connected:base) formatted_ids = ", ".join([f'"{n}"' for n in batch])
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query) outgoing_query = """SELECT * FROM cypher('%s', $$
incoming_results = await self._query(incoming_query) UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)-[]->(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
nodes_edges_dict = {node_id: [] for node_id in node_ids} incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)<-[]-(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
for result in outgoing_results: outgoing_results = await self._query(outgoing_query)
if result["node_id"] and result["connected_id"]: incoming_results = await self._query(incoming_query)
nodes_edges_dict[result["node_id"]].append(
(result["node_id"], result["connected_id"])
)
for result in incoming_results: for result in outgoing_results:
if result["node_id"] and result["connected_id"]: if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append( edges_norm[result["node_id"]].append(
(result["connected_id"], result["node_id"]) (result["node_id"], result["connected_id"])
) )
return nodes_edges_dict for result in incoming_results:
if result["node_id"] and result["connected_id"]:
edges_norm[result["node_id"]].append(
(result["connected_id"], result["node_id"])
)
out: dict[str, list[tuple[str, str]]] = {}
for orig in node_ids:
n = self._normalize_node_id(orig)
out[orig] = edges_norm.get(n, [])
return out
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """