mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-11-18 02:53:46 +00:00
feat: optimize node and edge queries in PostgreSQL. query tables Directly
This commit is contained in:
parent
a7da48e05c
commit
6a7e3092ea
@ -9,6 +9,7 @@ from typing import Any, Union, final
|
||||
import numpy as np
|
||||
import configparser
|
||||
import ssl
|
||||
import itertools
|
||||
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
|
||||
@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
query: str,
|
||||
readonly: bool = True,
|
||||
upsert: bool = False,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Query the graph by taking a cypher query, converting it to an
|
||||
@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
if readonly:
|
||||
data = await self.db.query(
|
||||
query,
|
||||
params=self.params,
|
||||
multirows=True,
|
||||
with_age=True,
|
||||
graph_name=self.graph_name,
|
||||
@ -3398,41 +3401,55 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(
|
||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
RETURN node_id, n
|
||||
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
|
||||
|
||||
results = await self._query(query)
|
||||
seen = set()
|
||||
unique_ids = []
|
||||
for nid in node_ids:
|
||||
nid_norm = self._normalize_node_id(nid)
|
||||
if nid_norm not in seen:
|
||||
seen.add(nid_norm)
|
||||
unique_ids.append(nid_norm)
|
||||
|
||||
# Build result dictionary
|
||||
nodes_dict = {}
|
||||
for result in results:
|
||||
if result["node_id"] and result["n"]:
|
||||
node_dict = result["n"]["properties"]
|
||||
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(node_dict, str):
|
||||
try:
|
||||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
|
||||
)
|
||||
for i in range(0, len(unique_ids), batch_size):
|
||||
batch = unique_ids[i : i + batch_size]
|
||||
|
||||
# Remove the 'base' label if present in a 'labels' property
|
||||
# if "labels" in node_dict:
|
||||
# node_dict["labels"] = [
|
||||
# label for label in node_dict["labels"] if label != "base"
|
||||
# ]
|
||||
query = f"""
|
||||
WITH input(v, ord) AS (
|
||||
SELECT v, ord
|
||||
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
|
||||
),
|
||||
ids(node_id, ord) AS (
|
||||
SELECT (to_json(v)::text)::agtype AS node_id, ord
|
||||
FROM input
|
||||
)
|
||||
SELECT i.node_id::text AS node_id,
|
||||
b.properties
|
||||
FROM {self.graph_name}.base AS b
|
||||
JOIN ids i
|
||||
ON ag_catalog.agtype_access_operator(
|
||||
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
|
||||
) = i.node_id
|
||||
ORDER BY i.ord;
|
||||
"""
|
||||
|
||||
nodes_dict[result["node_id"]] = node_dict
|
||||
results = await self._query(query, params={"ids": batch})
|
||||
|
||||
for result in results:
|
||||
if result["node_id"] and result["properties"]:
|
||||
node_dict = result["properties"]
|
||||
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(node_dict, str):
|
||||
try:
|
||||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse node string in batch: {node_dict}"
|
||||
)
|
||||
|
||||
nodes_dict[result["node_id"]] = node_dict
|
||||
|
||||
return nodes_dict
|
||||
|
||||
@ -3453,44 +3470,66 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(
|
||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
OPTIONAL MATCH (n)-[r]->(a)
|
||||
RETURN node_id, count(a) AS out_degree
|
||||
$$) AS (node_id text, out_degree bigint)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
incoming_query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
OPTIONAL MATCH (n)<-[r]-(b)
|
||||
RETURN node_id, count(b) AS in_degree
|
||||
$$) AS (node_id text, in_degree bigint)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
outgoing_results = await self._query(outgoing_query)
|
||||
incoming_results = await self._query(incoming_query)
|
||||
seen = set()
|
||||
unique_ids: list[str] = []
|
||||
for nid in node_ids:
|
||||
n = self._normalize_node_id(nid)
|
||||
if n not in seen:
|
||||
seen.add(n)
|
||||
unique_ids.append(n)
|
||||
|
||||
out_degrees = {}
|
||||
in_degrees = {}
|
||||
|
||||
for result in outgoing_results:
|
||||
if result["node_id"] is not None:
|
||||
out_degrees[result["node_id"]] = int(result["out_degree"])
|
||||
for i in range(0, len(unique_ids), batch_size):
|
||||
batch = unique_ids[i:i + batch_size]
|
||||
|
||||
for result in incoming_results:
|
||||
if result["node_id"] is not None:
|
||||
in_degrees[result["node_id"]] = int(result["in_degree"])
|
||||
query = f"""
|
||||
WITH input(v, ord) AS (
|
||||
SELECT v, ord
|
||||
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
|
||||
),
|
||||
ids(node_id, ord) AS (
|
||||
SELECT (to_json(v)::text)::agtype AS node_id, ord
|
||||
FROM input
|
||||
),
|
||||
vids AS (
|
||||
SELECT b.id AS vid, i.node_id, i.ord
|
||||
FROM {self.graph_name}.base AS b
|
||||
JOIN ids i
|
||||
ON ag_catalog.agtype_access_operator(
|
||||
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
|
||||
) = i.node_id
|
||||
),
|
||||
deg_out AS (
|
||||
SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
|
||||
FROM {self.graph_name}."DIRECTED" AS d
|
||||
JOIN vids v ON v.vid = d.start_id
|
||||
GROUP BY d.start_id
|
||||
),
|
||||
deg_in AS (
|
||||
SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
|
||||
FROM {self.graph_name}."DIRECTED" AS d
|
||||
JOIN vids v ON v.vid = d.end_id
|
||||
GROUP BY d.end_id
|
||||
)
|
||||
SELECT v.node_id::text AS node_id,
|
||||
COALESCE(o.out_degree, 0) AS out_degree,
|
||||
COALESCE(n.in_degree, 0) AS in_degree
|
||||
FROM vids v
|
||||
LEFT JOIN deg_out o ON o.vid = v.vid
|
||||
LEFT JOIN deg_in n ON n.vid = v.vid
|
||||
ORDER BY v.ord;
|
||||
"""
|
||||
|
||||
combined_results = await self._query(query, params={"ids": batch})
|
||||
|
||||
for row in combined_results:
|
||||
node_id = row["node_id"]
|
||||
if not node_id:
|
||||
continue
|
||||
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
|
||||
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
|
||||
|
||||
degrees_dict = {}
|
||||
for node_id in node_ids:
|
||||
@ -3625,22 +3664,10 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
edges_dict[(result["source"], result["target"])] = edge_props
|
||||
|
||||
for result in backward_results:
|
||||
if result["source"] and result["target"] and result["edge_properties"]:
|
||||
edge_props = result["edge_properties"]
|
||||
for result in backward_results:
|
||||
if result["source"] and result["target"] and result["edge_properties"]:
|
||||
edge_props = result["edge_properties"]
|
||||
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(edge_props, str):
|
||||
try:
|
||||
edge_props = json.loads(edge_props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
|
||||
)
|
||||
continue
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(edge_props, str):
|
||||
try:
|
||||
@ -3671,49 +3698,62 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(
|
||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
||||
)
|
||||
seen = set()
|
||||
unique_ids: list[str] = []
|
||||
for nid in node_ids:
|
||||
n = self._normalize_node_id(nid)
|
||||
if n and n not in seen:
|
||||
seen.add(n)
|
||||
unique_ids.append(n)
|
||||
|
||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
OPTIONAL MATCH (n:base)-[]->(connected:base)
|
||||
RETURN node_id, connected.entity_id AS connected_id
|
||||
$$) AS (node_id text, connected_id text)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids,
|
||||
)
|
||||
edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}
|
||||
|
||||
incoming_query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
OPTIONAL MATCH (n:base)<-[]-(connected:base)
|
||||
RETURN node_id, connected.entity_id AS connected_id
|
||||
$$) AS (node_id text, connected_id text)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids,
|
||||
)
|
||||
for i in range(0, len(unique_ids), batch_size):
|
||||
batch = unique_ids[i:i + batch_size]
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join([f'"{n}"' for n in batch])
|
||||
|
||||
outgoing_results = await self._query(outgoing_query)
|
||||
incoming_results = await self._query(incoming_query)
|
||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
OPTIONAL MATCH (n:base)-[]->(connected:base)
|
||||
RETURN node_id, connected.entity_id AS connected_id
|
||||
$$) AS (node_id text, connected_id text)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
||||
incoming_query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
OPTIONAL MATCH (n:base)<-[]-(connected:base)
|
||||
RETURN node_id, connected.entity_id AS connected_id
|
||||
$$) AS (node_id text, connected_id text)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
for result in outgoing_results:
|
||||
if result["node_id"] and result["connected_id"]:
|
||||
nodes_edges_dict[result["node_id"]].append(
|
||||
(result["node_id"], result["connected_id"])
|
||||
)
|
||||
outgoing_results = await self._query(outgoing_query)
|
||||
incoming_results = await self._query(incoming_query)
|
||||
|
||||
for result in incoming_results:
|
||||
if result["node_id"] and result["connected_id"]:
|
||||
nodes_edges_dict[result["node_id"]].append(
|
||||
(result["connected_id"], result["node_id"])
|
||||
)
|
||||
for result in outgoing_results:
|
||||
if result["node_id"] and result["connected_id"]:
|
||||
edges_norm[result["node_id"]].append(
|
||||
(result["node_id"], result["connected_id"])
|
||||
)
|
||||
|
||||
return nodes_edges_dict
|
||||
for result in incoming_results:
|
||||
if result["node_id"] and result["connected_id"]:
|
||||
edges_norm[result["node_id"]].append(
|
||||
(result["connected_id"], result["node_id"])
|
||||
)
|
||||
|
||||
out: dict[str, list[tuple[str, str]]] = {}
|
||||
for orig in node_ids:
|
||||
n = self._normalize_node_id(orig)
|
||||
out[orig] = edges_norm.get(n, [])
|
||||
|
||||
return out
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user