feat: optimize node and edge queries in PostgreSQL. query tables Directly

This commit is contained in:
Matt23-star 2025-08-16 22:37:48 +08:00
parent a7da48e05c
commit 6a7e3092ea

View File

@ -9,6 +9,7 @@ from typing import Any, Union, final
import numpy as np import numpy as np
import configparser import configparser
import ssl import ssl
import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage):
query: str, query: str,
readonly: bool = True, readonly: bool = True,
upsert: bool = False, upsert: bool = False,
params: dict[str, Any] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Query the graph by taking a cypher query, converting it to an Query the graph by taking a cypher query, converting it to an
@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage):
if readonly: if readonly:
data = await self.db.query( data = await self.db.query(
query, query,
params=self.params,
multirows=True, multirows=True,
with_age=True, with_age=True,
graph_name=self.graph_name, graph_name=self.graph_name,
@ -3398,24 +3401,44 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) nid_norm = self._normalize_node_id(nid)
if nid_norm not in seen:
query = """SELECT * FROM cypher('%s', $$ seen.add(nid_norm)
UNWIND [%s] AS node_id unique_ids.append(nid_norm)
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
results = await self._query(query)
# Build result dictionary # Build result dictionary
nodes_dict = {} nodes_dict = {}
for i in range(0, len(unique_ids), batch_size):
batch = unique_ids[i : i + batch_size]
query = f"""
WITH input(v, ord) AS (
SELECT v, ord
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
)
SELECT i.node_id::text AS node_id,
b.properties
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
ORDER BY i.ord;
"""
results = await self._query(query, params={"ids": batch})
for result in results: for result in results:
if result["node_id"] and result["n"]: if result["node_id"] and result["properties"]:
node_dict = result["n"]["properties"] node_dict = result["properties"]
# Process string result, parse it to JSON dictionary # Process string result, parse it to JSON dictionary
if isinstance(node_dict, str): if isinstance(node_dict, str):
@ -3423,15 +3446,9 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict) node_dict = json.loads(node_dict)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning( logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" f"Failed to parse node string in batch: {node_dict}"
) )
# Remove the 'base' label if present in a 'labels' property
# if "labels" in node_dict:
# node_dict["labels"] = [
# label for label in node_dict["labels"] if label != "base"
# ]
nodes_dict[result["node_id"]] = node_dict nodes_dict[result["node_id"]] = node_dict
return nodes_dict return nodes_dict
@ -3453,44 +3470,66 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids: list[str] = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) n = self._normalize_node_id(nid)
if n not in seen:
outgoing_query = """SELECT * FROM cypher('%s', $$ seen.add(n)
UNWIND [%s] AS node_id unique_ids.append(n)
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[r]->(a)
RETURN node_id, count(a) AS out_degree
$$) AS (node_id text, out_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)<-[r]-(b)
RETURN node_id, count(b) AS in_degree
$$) AS (node_id text, in_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
out_degrees = {} out_degrees = {}
in_degrees = {} in_degrees = {}
for result in outgoing_results: for i in range(0, len(unique_ids), batch_size):
if result["node_id"] is not None: batch = unique_ids[i:i + batch_size]
out_degrees[result["node_id"]] = int(result["out_degree"])
for result in incoming_results: query = f"""
if result["node_id"] is not None: WITH input(v, ord) AS (
in_degrees[result["node_id"]] = int(result["in_degree"]) SELECT v, ord
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
),
vids AS (
SELECT b.id AS vid, i.node_id, i.ord
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
),
deg_out AS (
SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.start_id
GROUP BY d.start_id
),
deg_in AS (
SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.end_id
GROUP BY d.end_id
)
SELECT v.node_id::text AS node_id,
COALESCE(o.out_degree, 0) AS out_degree,
COALESCE(n.in_degree, 0) AS in_degree
FROM vids v
LEFT JOIN deg_out o ON o.vid = v.vid
LEFT JOIN deg_in n ON n.vid = v.vid
ORDER BY v.ord;
"""
combined_results = await self._query(query, params={"ids": batch})
for row in combined_results:
node_id = row["node_id"]
if not node_id:
continue
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
degrees_dict = {} degrees_dict = {}
for node_id in node_ids: for node_id in node_ids:
@ -3625,23 +3664,11 @@ class PGGraphStorage(BaseGraphStorage):
edges_dict[(result["source"], result["target"])] = edge_props edges_dict[(result["source"], result["target"])] = edge_props
for result in backward_results:
if result["source"] and result["target"] and result["edge_properties"]:
edge_props = result["edge_properties"]
for result in backward_results: for result in backward_results:
if result["source"] and result["target"] and result["edge_properties"]: if result["source"] and result["target"] and result["edge_properties"]:
edge_props = result["edge_properties"] edge_props = result["edge_properties"]
# Process string result, parse it to JSON dictionary # Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
# Process string result, parse it to JSON dictionary
if isinstance(edge_props, str): if isinstance(edge_props, str):
try: try:
edge_props = json.loads(edge_props) edge_props = json.loads(edge_props)
@ -3671,10 +3698,20 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
seen = set()
unique_ids: list[str] = []
for nid in node_ids:
n = self._normalize_node_id(nid)
if n and n not in seen:
seen.add(n)
unique_ids.append(n)
edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}
for i in range(0, len(unique_ids), batch_size):
batch = unique_ids[i:i + batch_size]
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join( formatted_ids = ", ".join([f'"{n}"' for n in batch])
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
)
outgoing_query = """SELECT * FROM cypher('%s', $$ outgoing_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
@ -3699,21 +3736,24 @@ class PGGraphStorage(BaseGraphStorage):
outgoing_results = await self._query(outgoing_query) outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query) incoming_results = await self._query(incoming_query)
nodes_edges_dict = {node_id: [] for node_id in node_ids}
for result in outgoing_results: for result in outgoing_results:
if result["node_id"] and result["connected_id"]: if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append( edges_norm[result["node_id"]].append(
(result["node_id"], result["connected_id"]) (result["node_id"], result["connected_id"])
) )
for result in incoming_results: for result in incoming_results:
if result["node_id"] and result["connected_id"]: if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append( edges_norm[result["node_id"]].append(
(result["connected_id"], result["node_id"]) (result["connected_id"], result["node_id"])
) )
return nodes_edges_dict out: dict[str, list[tuple[str, str]]] = {}
for orig in node_ids:
n = self._normalize_node_id(orig)
out[orig] = edges_norm.get(n, [])
return out
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """