Fix the Postgres get_all_labels and get_knowledge_graph implementations

This commit is contained in:
Samuel Chan 2025-03-08 11:45:59 +08:00
parent 27ab894d00
commit b7f67eda21
2 changed files with 160 additions and 40 deletions

51
examples/test_postgres.py Normal file
View File

@ -0,0 +1,51 @@
import os
import asyncio
from lightrag.kg.postgres_impl import PGGraphStorage
from lightrag.llm.ollama import ollama_embedding
from lightrag.utils import EmbeddingFunc
#########
# Uncomment the two lines below if running in a Jupyter notebook, where an event loop is already running and asyncio.run() cannot be called directly
# import nest_asyncio
# nest_asyncio.apply()
#########
# Scratch directory created at import time.
# NOTE(review): the name mentions neo4j although this script targets
# Postgres/AGE; the path is kept unchanged.
WORKING_DIR = "./local_neo4jWorkDir"
# exist_ok=True avoids the check-then-create race of
# os.path.exists() followed by os.mkdir().
os.makedirs(WORKING_DIR, exist_ok=True)

# AGE
os.environ["AGE_GRAPH_NAME"] = "dickens"

# Postgres connection settings (these overwrite any values already set in
# the environment).
os.environ["POSTGRES_HOST"] = "localhost"
os.environ["POSTGRES_PORT"] = "15432"
os.environ["POSTGRES_USER"] = "rag"
os.environ["POSTGRES_PASSWORD"] = "rag"
os.environ["POSTGRES_DATABASE"] = "rag"
async def main():
    """Exercise PGGraphStorage label and knowledge-graph queries.

    Builds a PGGraphStorage for the "dickens" namespace, initializes it,
    prints the result of get_all_labels(), prints the result of
    get_knowledge_graph("FEZZIWIG"), and finalizes the storage.
    """
    graph_db = PGGraphStorage(
        namespace="dickens",
        embedding_func=EmbeddingFunc(
            embedding_dim=1024,
            max_token_size=8192,
            func=lambda texts: ollama_embedding(
                texts, embed_model="bge-m3", host="http://localhost:11434"
            ),
        ),
        global_config={},
    )
    await graph_db.initialize()
    try:
        labels = await graph_db.get_all_labels()
        print("all labels", labels)
        res = await graph_db.get_knowledge_graph("FEZZIWIG")
        print("knowledge graphs", res)
    finally:
        # Always finalize the storage, even if one of the queries above raises.
        await graph_db.finalize()


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -810,42 +810,85 @@ class PGGraphStorage(BaseGraphStorage):
v = record[k] v = record[k]
# agtype comes back '{key: value}::type' which must be parsed # agtype comes back '{key: value}::type' which must be parsed
if isinstance(v, str) and "::" in v: if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1] if v.startswith("[") and v.endswith("]"):
v = v.split("::")[0] if "::vertex" not in v:
if dtype == "vertex": continue
vertex = json.loads(v) v = v.replace("::vertex", "")
vertices[vertex["id"]] = vertex.get("properties") vertexes = json.loads(v)
for vertex in vertexes:
vertices[vertex["id"]] = vertex.get("properties")
else:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
# iterate returned fields and parse appropriately # iterate returned fields and parse appropriately
for k in record.keys(): for k in record.keys():
v = record[k] v = record[k]
if isinstance(v, str) and "::" in v: if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1] if v.startswith("[") and v.endswith("]"):
v = v.split("::")[0] if "::vertex" in v:
else: v = v.replace("::vertex", "")
dtype = "" vertexes = json.loads(v)
dl = []
for vertex in vertexes:
prop = vertex.get("properties")
if not prop:
prop = {}
prop["label"] = PGGraphStorage._decode_graph_label(
prop["node_id"]
)
dl.append(prop)
d[k] = dl
if dtype == "vertex": elif "::edge" in v:
vertex = json.loads(v) v = v.replace("::edge", "")
field = vertex.get("properties") edges = json.loads(v)
if not field: dl = []
field = {} for edge in edges:
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"]) dl.append(
d[k] = field (
# convert edge from id-label->id by replacing id with node information vertices[edge["start_id"]],
# we only do this if the vertex was also returned in the query edge["label"],
# this is an attempt to be consistent with neo4j implementation vertices[edge["end_id"]],
elif dtype == "edge": )
edge = json.loads(v) )
d[k] = ( d[k] = dl
vertices.get(edge["start_id"], {}), else:
edge[ print("WARNING: unsupported type")
"label" continue
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}), else:
) dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
field = vertex.get("properties")
if not field:
field = {}
field["label"] = PGGraphStorage._decode_graph_label(
field["node_id"]
)
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
else: else:
d[k] = json.loads(v) if isinstance(v, str) else v if v is None or (v.count("{") < 1 and v.count("[") < 1):
d[k] = v
else:
d[k] = json.loads(v) if isinstance(v, str) else v
return d return d
@ -1319,7 +1362,7 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH p = (n)-[*..%d]-(m) OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % ( $$) AS (nodes agtype, relationships agtype)""" % (
self.graph_name, self.graph_name,
encoded_node_label, encoded_node_label,
max_depth, max_depth,
@ -1328,17 +1371,23 @@ class PGGraphStorage(BaseGraphStorage):
results = await self._query(query) results = await self._query(query)
nodes = set() nodes = {}
edges = [] edges = []
unique_edge_ids = set()
for result in results: for result in results:
if node_label == "*": if node_label == "*":
if result["n"]: if result["n"]:
node = result["n"] node = result["n"]
nodes.add(self._decode_graph_label(node["node_id"])) node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["m"]: if result["m"]:
node = result["m"] node = result["m"]
nodes.add(self._decode_graph_label(node["node_id"])) node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["r"]: if result["r"]:
edge = result["r"] edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"]) src_id = self._decode_graph_label(edge["start_id"])
@ -1347,16 +1396,36 @@ class PGGraphStorage(BaseGraphStorage):
else: else:
if result["nodes"]: if result["nodes"]:
for node in result["nodes"]: for node in result["nodes"]:
nodes.add(self._decode_graph_label(node["node_id"])) node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["relationships"]: if result["relationships"]:
for edge in result["relationships"]: for edge in result["relationships"]: # src --DIRECTED--> target
src_id = self._decode_graph_label(edge["start_id"]) src_id = self._decode_graph_label(edge[0]["node_id"])
tgt_id = self._decode_graph_label(edge["end_id"]) tgt_id = self._decode_graph_label(edge[2]["node_id"])
edges.append((src_id, tgt_id)) id = src_id + "," + tgt_id
if id in unique_edge_ids:
continue
else:
unique_edge_ids.add(id)
edges.append(
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
)
kg = KnowledgeGraph( kg = KnowledgeGraph(
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes], nodes=[
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges], KnowledgeGraphNode(
id=node_id, labels=[node_id], properties=nodes[node_id]
)
for node_id in nodes
],
edges=[
KnowledgeGraphEdge(
id=id, type="DIRECTED", source=src, target=tgt, properties=props
)
for id, src, tgt, props in edges
],
) )
return kg return kg