From 6a7e3092ea0c491d358d16b1e0e27038c80b58d3 Mon Sep 17 00:00:00 2001 From: Matt23-star Date: Sat, 16 Aug 2025 22:37:48 +0800 Subject: [PATCH] feat: optimize node and edge queries in PostgreSQL. query tables Directly --- lightrag/kg/postgres_impl.py | 260 ++++++++++++++++++++--------------- 1 file changed, 150 insertions(+), 110 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 18957699..74034877 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -9,6 +9,7 @@ from typing import Any, Union, final import numpy as np import configparser import ssl +import itertools from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge @@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage): query: str, readonly: bool = True, upsert: bool = False, + params: dict[str, Any] | None = None, ) -> list[dict[str, Any]]: """ Query the graph by taking a cypher query, converting it to an @@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage): if readonly: data = await self.db.query( query, + params=self.params, multirows=True, with_age=True, graph_name=self.graph_name, @@ -3398,41 +3401,55 @@ class PGGraphStorage(BaseGraphStorage): if not node_ids: return {} - # Format node IDs for the query - formatted_ids = ", ".join( - ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] - ) - - query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - RETURN node_id, n - $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids) - - results = await self._query(query) + seen = set() + unique_ids = [] + for nid in node_ids: + nid_norm = self._normalize_node_id(nid) + if nid_norm not in seen: + seen.add(nid_norm) + unique_ids.append(nid_norm) # Build result dictionary nodes_dict = {} - for result in results: - if result["node_id"] and result["n"]: - node_dict = result["n"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" - ) + for i in range(0, len(unique_ids), batch_size): + batch = unique_ids[i : i + batch_size] - # Remove the 'base' label if present in a 'labels' property - # if "labels" in node_dict: - # node_dict["labels"] = [ - # label for label in node_dict["labels"] if label != "base" - # ] + query = f""" + WITH input(v, ord) AS ( + SELECT v, ord + FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord) + ), + ids(node_id, ord) AS ( + SELECT (to_json(v)::text)::agtype AS node_id, ord + FROM input + ) + SELECT i.node_id::text AS node_id, + b.properties + FROM {self.graph_name}.base AS b + JOIN ids i + ON ag_catalog.agtype_access_operator( + VARIADIC ARRAY[b.properties, '"entity_id"'::agtype] + ) = i.node_id + ORDER BY i.ord; + """ - nodes_dict[result["node_id"]] = node_dict + results = await self._query(query, params={"ids": batch}) + + for result in results: + if result["node_id"] and result["properties"]: + node_dict = result["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(node_dict, str): + try: + node_dict = json.loads(node_dict) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {node_dict}" + ) + + nodes_dict[result["node_id"]] = node_dict return nodes_dict @@ -3453,44 +3470,66 @@ class PGGraphStorage(BaseGraphStorage): if not node_ids: return {} - # Format node IDs for the query - formatted_ids = ", ".join( - ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] - ) - - outgoing_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n)-[r]->(a) - RETURN node_id, count(a) AS out_degree - $$) AS (node_id text, out_degree bigint)""" % ( - self.graph_name, - formatted_ids, - ) - - incoming_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n)<-[r]-(b) - RETURN node_id, count(b) AS in_degree - $$) AS (node_id text, in_degree bigint)""" % ( - self.graph_name, - formatted_ids, - ) - - outgoing_results = await self._query(outgoing_query) - incoming_results = await self._query(incoming_query) + seen = set() + unique_ids: list[str] = [] + for nid in node_ids: + n = self._normalize_node_id(nid) + if n not in seen: + seen.add(n) + unique_ids.append(n) out_degrees = {} in_degrees = {} - for result in outgoing_results: - if result["node_id"] is not None: - out_degrees[result["node_id"]] = int(result["out_degree"]) + for i in range(0, len(unique_ids), batch_size): + batch = unique_ids[i:i + batch_size] - for result in incoming_results: - if result["node_id"] is not None: - in_degrees[result["node_id"]] = int(result["in_degree"]) + query = f""" + WITH input(v, ord) AS ( + SELECT v, ord + FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord) + ), + ids(node_id, ord) AS ( + SELECT (to_json(v)::text)::agtype AS node_id, ord + FROM input + ), + vids AS ( + SELECT b.id AS vid, i.node_id, i.ord + FROM {self.graph_name}.base AS b + JOIN ids i + ON ag_catalog.agtype_access_operator( + VARIADIC ARRAY[b.properties, '"entity_id"'::agtype] + ) = i.node_id + ), + deg_out AS ( + SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree + FROM {self.graph_name}."DIRECTED" AS d + JOIN vids v ON v.vid = d.start_id + GROUP BY d.start_id + ), + deg_in AS ( + SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree + FROM {self.graph_name}."DIRECTED" AS d + JOIN vids v ON v.vid = d.end_id + GROUP BY d.end_id + ) + SELECT v.node_id::text AS node_id, + COALESCE(o.out_degree, 0) AS out_degree, + COALESCE(n.in_degree, 0) AS in_degree + FROM vids v + LEFT JOIN deg_out o ON o.vid = v.vid + LEFT JOIN deg_in n ON n.vid = v.vid + ORDER BY v.ord; + """ + + combined_results = await self._query(query, params={"ids": batch}) + + for row in combined_results: + node_id = row["node_id"] + if not node_id: + continue + out_degrees[node_id] = int(row.get("out_degree", 0) or 0) + in_degrees[node_id] = int(row.get("in_degree", 0) or 0) degrees_dict = {} for node_id in node_ids: @@ -3625,22 +3664,10 @@ class PGGraphStorage(BaseGraphStorage): edges_dict[(result["source"], result["target"])] = edge_props - for result in backward_results: - if result["source"] and result["target"] and result["edge_properties"]: - edge_props = result["edge_properties"] for result in backward_results: if result["source"] and result["target"] and result["edge_properties"]: edge_props = result["edge_properties"] - # Process string result, parse it to JSON dictionary - if isinstance(edge_props, str): - try: - edge_props = json.loads(edge_props) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse edge properties string: {edge_props}" - ) - continue # Process string result, parse it to JSON dictionary if isinstance(edge_props, str): try: @@ -3671,49 +3698,62 @@ class PGGraphStorage(BaseGraphStorage): if not node_ids: return {} - # Format node IDs for the query - formatted_ids = ", ".join( - ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] - ) + seen = set() + unique_ids: list[str] = [] + for nid in node_ids: + n = self._normalize_node_id(nid) + if n and n not in seen: + seen.add(n) + unique_ids.append(n) - outgoing_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n:base)-[]->(connected:base) - RETURN node_id, connected.entity_id AS connected_id - $$) AS (node_id text, connected_id text)""" % ( - self.graph_name, - formatted_ids, - ) + edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids} - incoming_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n:base)<-[]-(connected:base) - RETURN node_id, connected.entity_id AS connected_id - $$) AS (node_id text, connected_id text)""" % ( - self.graph_name, - formatted_ids, - ) + for i in range(0, len(unique_ids), batch_size): + batch = unique_ids[i:i + batch_size] + # Format node IDs for the query + formatted_ids = ", ".join([f'"{n}"' for n in batch]) - outgoing_results = await self._query(outgoing_query) - incoming_results = await self._query(incoming_query) + outgoing_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n:base)-[]->(connected:base) + RETURN node_id, connected.entity_id AS connected_id + $$) AS (node_id text, connected_id text)""" % ( + self.graph_name, + formatted_ids, + ) - nodes_edges_dict = {node_id: [] for node_id in node_ids} + incoming_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n:base)<-[]-(connected:base) + RETURN node_id, connected.entity_id AS connected_id + $$) AS (node_id text, connected_id text)""" % ( + self.graph_name, + formatted_ids, + ) - for result in outgoing_results: - if result["node_id"] and result["connected_id"]: - nodes_edges_dict[result["node_id"]].append( - (result["node_id"], result["connected_id"]) - ) + outgoing_results = await self._query(outgoing_query) + incoming_results = await self._query(incoming_query) - for result in incoming_results: - if result["node_id"] and result["connected_id"]: - nodes_edges_dict[result["node_id"]].append( - (result["connected_id"], result["node_id"]) - ) + for result in outgoing_results: + if result["node_id"] and result["connected_id"]: + edges_norm[result["node_id"]].append( + (result["node_id"], result["connected_id"]) + ) - return nodes_edges_dict + for result in incoming_results: + if result["node_id"] and result["connected_id"]: + edges_norm[result["node_id"]].append( + (result["connected_id"], result["node_id"]) + ) + + out: dict[str, list[tuple[str, str]]] = {} + for orig in node_ids: + n = self._normalize_node_id(orig) + out[orig] = edges_norm.get(n, []) + + return out async def get_all_labels(self) -> list[str]: """