mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-11-22 21:15:52 +00:00
feat: optimize node and edge queries in PostgreSQL; query tables directly
This commit is contained in:
parent
a7da48e05c
commit
6a7e3092ea
@ -9,6 +9,7 @@ from typing import Any, Union, final
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
import ssl
|
import ssl
|
||||||
|
import itertools
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
query: str,
|
query: str,
|
||||||
readonly: bool = True,
|
readonly: bool = True,
|
||||||
upsert: bool = False,
|
upsert: bool = False,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Query the graph by taking a cypher query, converting it to an
|
Query the graph by taking a cypher query, converting it to an
|
||||||
@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
if readonly:
|
if readonly:
|
||||||
data = await self.db.query(
|
data = await self.db.query(
|
||||||
query,
|
query,
|
||||||
|
params=self.params,
|
||||||
multirows=True,
|
multirows=True,
|
||||||
with_age=True,
|
with_age=True,
|
||||||
graph_name=self.graph_name,
|
graph_name=self.graph_name,
|
||||||
@ -3398,24 +3401,44 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
if not node_ids:
|
if not node_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Format node IDs for the query
|
seen = set()
|
||||||
formatted_ids = ", ".join(
|
unique_ids = []
|
||||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
for nid in node_ids:
|
||||||
)
|
nid_norm = self._normalize_node_id(nid)
|
||||||
|
if nid_norm not in seen:
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
seen.add(nid_norm)
|
||||||
UNWIND [%s] AS node_id
|
unique_ids.append(nid_norm)
|
||||||
MATCH (n:base {entity_id: node_id})
|
|
||||||
RETURN node_id, n
|
|
||||||
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
|
|
||||||
|
|
||||||
results = await self._query(query)
|
|
||||||
|
|
||||||
# Build result dictionary
|
# Build result dictionary
|
||||||
nodes_dict = {}
|
nodes_dict = {}
|
||||||
|
|
||||||
|
for i in range(0, len(unique_ids), batch_size):
|
||||||
|
batch = unique_ids[i : i + batch_size]
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
WITH input(v, ord) AS (
|
||||||
|
SELECT v, ord
|
||||||
|
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
|
||||||
|
),
|
||||||
|
ids(node_id, ord) AS (
|
||||||
|
SELECT (to_json(v)::text)::agtype AS node_id, ord
|
||||||
|
FROM input
|
||||||
|
)
|
||||||
|
SELECT i.node_id::text AS node_id,
|
||||||
|
b.properties
|
||||||
|
FROM {self.graph_name}.base AS b
|
||||||
|
JOIN ids i
|
||||||
|
ON ag_catalog.agtype_access_operator(
|
||||||
|
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
|
||||||
|
) = i.node_id
|
||||||
|
ORDER BY i.ord;
|
||||||
|
"""
|
||||||
|
|
||||||
|
results = await self._query(query, params={"ids": batch})
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
if result["node_id"] and result["n"]:
|
if result["node_id"] and result["properties"]:
|
||||||
node_dict = result["n"]["properties"]
|
node_dict = result["properties"]
|
||||||
|
|
||||||
# Process string result, parse it to JSON dictionary
|
# Process string result, parse it to JSON dictionary
|
||||||
if isinstance(node_dict, str):
|
if isinstance(node_dict, str):
|
||||||
@ -3423,15 +3446,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
node_dict = json.loads(node_dict)
|
node_dict = json.loads(node_dict)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
|
f"Failed to parse node string in batch: {node_dict}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove the 'base' label if present in a 'labels' property
|
|
||||||
# if "labels" in node_dict:
|
|
||||||
# node_dict["labels"] = [
|
|
||||||
# label for label in node_dict["labels"] if label != "base"
|
|
||||||
# ]
|
|
||||||
|
|
||||||
nodes_dict[result["node_id"]] = node_dict
|
nodes_dict[result["node_id"]] = node_dict
|
||||||
|
|
||||||
return nodes_dict
|
return nodes_dict
|
||||||
@ -3453,44 +3470,66 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
if not node_ids:
|
if not node_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Format node IDs for the query
|
seen = set()
|
||||||
formatted_ids = ", ".join(
|
unique_ids: list[str] = []
|
||||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
for nid in node_ids:
|
||||||
)
|
n = self._normalize_node_id(nid)
|
||||||
|
if n not in seen:
|
||||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
seen.add(n)
|
||||||
UNWIND [%s] AS node_id
|
unique_ids.append(n)
|
||||||
MATCH (n:base {entity_id: node_id})
|
|
||||||
OPTIONAL MATCH (n)-[r]->(a)
|
|
||||||
RETURN node_id, count(a) AS out_degree
|
|
||||||
$$) AS (node_id text, out_degree bigint)""" % (
|
|
||||||
self.graph_name,
|
|
||||||
formatted_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
incoming_query = """SELECT * FROM cypher('%s', $$
|
|
||||||
UNWIND [%s] AS node_id
|
|
||||||
MATCH (n:base {entity_id: node_id})
|
|
||||||
OPTIONAL MATCH (n)<-[r]-(b)
|
|
||||||
RETURN node_id, count(b) AS in_degree
|
|
||||||
$$) AS (node_id text, in_degree bigint)""" % (
|
|
||||||
self.graph_name,
|
|
||||||
formatted_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
outgoing_results = await self._query(outgoing_query)
|
|
||||||
incoming_results = await self._query(incoming_query)
|
|
||||||
|
|
||||||
out_degrees = {}
|
out_degrees = {}
|
||||||
in_degrees = {}
|
in_degrees = {}
|
||||||
|
|
||||||
for result in outgoing_results:
|
for i in range(0, len(unique_ids), batch_size):
|
||||||
if result["node_id"] is not None:
|
batch = unique_ids[i:i + batch_size]
|
||||||
out_degrees[result["node_id"]] = int(result["out_degree"])
|
|
||||||
|
|
||||||
for result in incoming_results:
|
query = f"""
|
||||||
if result["node_id"] is not None:
|
WITH input(v, ord) AS (
|
||||||
in_degrees[result["node_id"]] = int(result["in_degree"])
|
SELECT v, ord
|
||||||
|
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
|
||||||
|
),
|
||||||
|
ids(node_id, ord) AS (
|
||||||
|
SELECT (to_json(v)::text)::agtype AS node_id, ord
|
||||||
|
FROM input
|
||||||
|
),
|
||||||
|
vids AS (
|
||||||
|
SELECT b.id AS vid, i.node_id, i.ord
|
||||||
|
FROM {self.graph_name}.base AS b
|
||||||
|
JOIN ids i
|
||||||
|
ON ag_catalog.agtype_access_operator(
|
||||||
|
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
|
||||||
|
) = i.node_id
|
||||||
|
),
|
||||||
|
deg_out AS (
|
||||||
|
SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
|
||||||
|
FROM {self.graph_name}."DIRECTED" AS d
|
||||||
|
JOIN vids v ON v.vid = d.start_id
|
||||||
|
GROUP BY d.start_id
|
||||||
|
),
|
||||||
|
deg_in AS (
|
||||||
|
SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
|
||||||
|
FROM {self.graph_name}."DIRECTED" AS d
|
||||||
|
JOIN vids v ON v.vid = d.end_id
|
||||||
|
GROUP BY d.end_id
|
||||||
|
)
|
||||||
|
SELECT v.node_id::text AS node_id,
|
||||||
|
COALESCE(o.out_degree, 0) AS out_degree,
|
||||||
|
COALESCE(n.in_degree, 0) AS in_degree
|
||||||
|
FROM vids v
|
||||||
|
LEFT JOIN deg_out o ON o.vid = v.vid
|
||||||
|
LEFT JOIN deg_in n ON n.vid = v.vid
|
||||||
|
ORDER BY v.ord;
|
||||||
|
"""
|
||||||
|
|
||||||
|
combined_results = await self._query(query, params={"ids": batch})
|
||||||
|
|
||||||
|
for row in combined_results:
|
||||||
|
node_id = row["node_id"]
|
||||||
|
if not node_id:
|
||||||
|
continue
|
||||||
|
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
|
||||||
|
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
|
||||||
|
|
||||||
degrees_dict = {}
|
degrees_dict = {}
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
@ -3625,23 +3664,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
edges_dict[(result["source"], result["target"])] = edge_props
|
edges_dict[(result["source"], result["target"])] = edge_props
|
||||||
|
|
||||||
for result in backward_results:
|
|
||||||
if result["source"] and result["target"] and result["edge_properties"]:
|
|
||||||
edge_props = result["edge_properties"]
|
|
||||||
for result in backward_results:
|
for result in backward_results:
|
||||||
if result["source"] and result["target"] and result["edge_properties"]:
|
if result["source"] and result["target"] and result["edge_properties"]:
|
||||||
edge_props = result["edge_properties"]
|
edge_props = result["edge_properties"]
|
||||||
|
|
||||||
# Process string result, parse it to JSON dictionary
|
# Process string result, parse it to JSON dictionary
|
||||||
if isinstance(edge_props, str):
|
|
||||||
try:
|
|
||||||
edge_props = json.loads(edge_props)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning(
|
|
||||||
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
# Process string result, parse it to JSON dictionary
|
|
||||||
if isinstance(edge_props, str):
|
if isinstance(edge_props, str):
|
||||||
try:
|
try:
|
||||||
edge_props = json.loads(edge_props)
|
edge_props = json.loads(edge_props)
|
||||||
@ -3671,10 +3698,20 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
if not node_ids:
|
if not node_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
seen = set()
|
||||||
|
unique_ids: list[str] = []
|
||||||
|
for nid in node_ids:
|
||||||
|
n = self._normalize_node_id(nid)
|
||||||
|
if n and n not in seen:
|
||||||
|
seen.add(n)
|
||||||
|
unique_ids.append(n)
|
||||||
|
|
||||||
|
edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}
|
||||||
|
|
||||||
|
for i in range(0, len(unique_ids), batch_size):
|
||||||
|
batch = unique_ids[i:i + batch_size]
|
||||||
# Format node IDs for the query
|
# Format node IDs for the query
|
||||||
formatted_ids = ", ".join(
|
formatted_ids = ", ".join([f'"{n}"' for n in batch])
|
||||||
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
|
|
||||||
)
|
|
||||||
|
|
||||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||||
UNWIND [%s] AS node_id
|
UNWIND [%s] AS node_id
|
||||||
@ -3699,21 +3736,24 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
outgoing_results = await self._query(outgoing_query)
|
outgoing_results = await self._query(outgoing_query)
|
||||||
incoming_results = await self._query(incoming_query)
|
incoming_results = await self._query(incoming_query)
|
||||||
|
|
||||||
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
|
||||||
|
|
||||||
for result in outgoing_results:
|
for result in outgoing_results:
|
||||||
if result["node_id"] and result["connected_id"]:
|
if result["node_id"] and result["connected_id"]:
|
||||||
nodes_edges_dict[result["node_id"]].append(
|
edges_norm[result["node_id"]].append(
|
||||||
(result["node_id"], result["connected_id"])
|
(result["node_id"], result["connected_id"])
|
||||||
)
|
)
|
||||||
|
|
||||||
for result in incoming_results:
|
for result in incoming_results:
|
||||||
if result["node_id"] and result["connected_id"]:
|
if result["node_id"] and result["connected_id"]:
|
||||||
nodes_edges_dict[result["node_id"]].append(
|
edges_norm[result["node_id"]].append(
|
||||||
(result["connected_id"], result["node_id"])
|
(result["connected_id"], result["node_id"])
|
||||||
)
|
)
|
||||||
|
|
||||||
return nodes_edges_dict
|
out: dict[str, list[tuple[str, str]]] = {}
|
||||||
|
for orig in node_ids:
|
||||||
|
n = self._normalize_node_id(orig)
|
||||||
|
out[orig] = edges_norm.get(n, [])
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
async def get_all_labels(self) -> list[str]:
|
async def get_all_labels(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user