mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-07-29 11:50:37 +00:00
Fix the PostgreSQL graph storage's get_all_labels and get_knowledge_graph
This commit is contained in:
parent
27ab894d00
commit
b7f67eda21
51
examples/test_postgres.py
Normal file
51
examples/test_postgres.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from lightrag.kg.postgres_impl import PGGraphStorage
|
||||||
|
from lightrag.llm.ollama import ollama_embedding
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
#########
|
||||||
|
# Uncomment the two lines below when running inside a Jupyter notebook, where an event loop is already running and asyncio.run() cannot be called directly
|
||||||
|
# import nest_asyncio
|
||||||
|
# nest_asyncio.apply()
|
||||||
|
#########
|
||||||
|
|
||||||
|
# Scratch directory created by this example script.
# NOTE(review): the name mentions neo4j although this script exercises the
# PostgreSQL storage -- presumably copied from another example; confirm
# before renaming, since other scripts may share the path.
WORKING_DIR = "./local_neo4jWorkDir"

# exist_ok=True makes creation idempotent and avoids the check-then-create
# race of `if not os.path.exists(...): os.mkdir(...)`.
os.makedirs(WORKING_DIR, exist_ok=True)

# AGE
os.environ["AGE_GRAPH_NAME"] = "dickens"

# PostgreSQL connection settings for the local test database.
os.environ["POSTGRES_HOST"] = "localhost"
os.environ["POSTGRES_PORT"] = "15432"
os.environ["POSTGRES_USER"] = "rag"
os.environ["POSTGRES_PASSWORD"] = "rag"
os.environ["POSTGRES_DATABASE"] = "rag"
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Query a PGGraphStorage instance and print what it returns.

    Builds a ``PGGraphStorage`` for the ``dickens`` namespace, initializes
    it, prints the result of ``get_all_labels()`` and of
    ``get_knowledge_graph("FEZZIWIG")``, and finalizes the storage.

    Presumably the storage picks up its connection settings from the
    ``POSTGRES_*`` / ``AGE_GRAPH_NAME`` environment variables set at module
    level -- confirm against ``PGGraphStorage``.
    """
    graph_db = PGGraphStorage(
        namespace="dickens",
        embedding_func=EmbeddingFunc(
            embedding_dim=1024,
            max_token_size=8192,
            func=lambda texts: ollama_embedding(
                texts, embed_model="bge-m3", host="http://localhost:11434"
            ),
        ),
        global_config={},
    )
    await graph_db.initialize()
    try:
        labels = await graph_db.get_all_labels()
        print("all labels", labels)

        res = await graph_db.get_knowledge_graph("FEZZIWIG")
        print("knowledge graphs", res)
    finally:
        # Finalize even when a query raises, so the storage is not left
        # un-finalized after a failed run.
        await graph_db.finalize()


if __name__ == "__main__":
    asyncio.run(main())
|
@ -810,42 +810,85 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
v = record[k]
|
v = record[k]
|
||||||
# agtype comes back '{key: value}::type' which must be parsed
|
# agtype comes back '{key: value}::type' which must be parsed
|
||||||
if isinstance(v, str) and "::" in v:
|
if isinstance(v, str) and "::" in v:
|
||||||
dtype = v.split("::")[-1]
|
if v.startswith("[") and v.endswith("]"):
|
||||||
v = v.split("::")[0]
|
if "::vertex" not in v:
|
||||||
if dtype == "vertex":
|
continue
|
||||||
vertex = json.loads(v)
|
v = v.replace("::vertex", "")
|
||||||
vertices[vertex["id"]] = vertex.get("properties")
|
vertexes = json.loads(v)
|
||||||
|
for vertex in vertexes:
|
||||||
|
vertices[vertex["id"]] = vertex.get("properties")
|
||||||
|
else:
|
||||||
|
dtype = v.split("::")[-1]
|
||||||
|
v = v.split("::")[0]
|
||||||
|
if dtype == "vertex":
|
||||||
|
vertex = json.loads(v)
|
||||||
|
vertices[vertex["id"]] = vertex.get("properties")
|
||||||
|
|
||||||
# iterate returned fields and parse appropriately
|
# iterate returned fields and parse appropriately
|
||||||
for k in record.keys():
|
for k in record.keys():
|
||||||
v = record[k]
|
v = record[k]
|
||||||
if isinstance(v, str) and "::" in v:
|
if isinstance(v, str) and "::" in v:
|
||||||
dtype = v.split("::")[-1]
|
if v.startswith("[") and v.endswith("]"):
|
||||||
v = v.split("::")[0]
|
if "::vertex" in v:
|
||||||
else:
|
v = v.replace("::vertex", "")
|
||||||
dtype = ""
|
vertexes = json.loads(v)
|
||||||
|
dl = []
|
||||||
|
for vertex in vertexes:
|
||||||
|
prop = vertex.get("properties")
|
||||||
|
if not prop:
|
||||||
|
prop = {}
|
||||||
|
prop["label"] = PGGraphStorage._decode_graph_label(
|
||||||
|
prop["node_id"]
|
||||||
|
)
|
||||||
|
dl.append(prop)
|
||||||
|
d[k] = dl
|
||||||
|
|
||||||
if dtype == "vertex":
|
elif "::edge" in v:
|
||||||
vertex = json.loads(v)
|
v = v.replace("::edge", "")
|
||||||
field = vertex.get("properties")
|
edges = json.loads(v)
|
||||||
if not field:
|
dl = []
|
||||||
field = {}
|
for edge in edges:
|
||||||
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
|
dl.append(
|
||||||
d[k] = field
|
(
|
||||||
# convert edge from id-label->id by replacing id with node information
|
vertices[edge["start_id"]],
|
||||||
# we only do this if the vertex was also returned in the query
|
edge["label"],
|
||||||
# this is an attempt to be consistent with neo4j implementation
|
vertices[edge["end_id"]],
|
||||||
elif dtype == "edge":
|
)
|
||||||
edge = json.loads(v)
|
)
|
||||||
d[k] = (
|
d[k] = dl
|
||||||
vertices.get(edge["start_id"], {}),
|
else:
|
||||||
edge[
|
print("WARNING: unsupported type")
|
||||||
"label"
|
continue
|
||||||
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
|
||||||
vertices.get(edge["end_id"], {}),
|
else:
|
||||||
)
|
dtype = v.split("::")[-1]
|
||||||
|
v = v.split("::")[0]
|
||||||
|
if dtype == "vertex":
|
||||||
|
vertex = json.loads(v)
|
||||||
|
field = vertex.get("properties")
|
||||||
|
if not field:
|
||||||
|
field = {}
|
||||||
|
field["label"] = PGGraphStorage._decode_graph_label(
|
||||||
|
field["node_id"]
|
||||||
|
)
|
||||||
|
d[k] = field
|
||||||
|
# convert edge from id-label->id by replacing id with node information
|
||||||
|
# we only do this if the vertex was also returned in the query
|
||||||
|
# this is an attempt to be consistent with neo4j implementation
|
||||||
|
elif dtype == "edge":
|
||||||
|
edge = json.loads(v)
|
||||||
|
d[k] = (
|
||||||
|
vertices.get(edge["start_id"], {}),
|
||||||
|
edge[
|
||||||
|
"label"
|
||||||
|
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
||||||
|
vertices.get(edge["end_id"], {}),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
d[k] = json.loads(v) if isinstance(v, str) else v
|
if v is None or (v.count("{") < 1 and v.count("[") < 1):
|
||||||
|
d[k] = v
|
||||||
|
else:
|
||||||
|
d[k] = json.loads(v) if isinstance(v, str) else v
|
||||||
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -1319,7 +1362,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
OPTIONAL MATCH p = (n)-[*..%d]-(m)
|
OPTIONAL MATCH p = (n)-[*..%d]-(m)
|
||||||
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
||||||
LIMIT %d
|
LIMIT %d
|
||||||
$$) AS (nodes agtype[], relationships agtype[])""" % (
|
$$) AS (nodes agtype, relationships agtype)""" % (
|
||||||
self.graph_name,
|
self.graph_name,
|
||||||
encoded_node_label,
|
encoded_node_label,
|
||||||
max_depth,
|
max_depth,
|
||||||
@ -1328,17 +1371,23 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
results = await self._query(query)
|
results = await self._query(query)
|
||||||
|
|
||||||
nodes = set()
|
nodes = {}
|
||||||
edges = []
|
edges = []
|
||||||
|
unique_edge_ids = set()
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
if result["n"]:
|
if result["n"]:
|
||||||
node = result["n"]
|
node = result["n"]
|
||||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
node_id = self._decode_graph_label(node["node_id"])
|
||||||
|
if node_id not in nodes:
|
||||||
|
nodes[node_id] = node
|
||||||
|
|
||||||
if result["m"]:
|
if result["m"]:
|
||||||
node = result["m"]
|
node = result["m"]
|
||||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
node_id = self._decode_graph_label(node["node_id"])
|
||||||
|
if node_id not in nodes:
|
||||||
|
nodes[node_id] = node
|
||||||
if result["r"]:
|
if result["r"]:
|
||||||
edge = result["r"]
|
edge = result["r"]
|
||||||
src_id = self._decode_graph_label(edge["start_id"])
|
src_id = self._decode_graph_label(edge["start_id"])
|
||||||
@ -1347,16 +1396,36 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
else:
|
else:
|
||||||
if result["nodes"]:
|
if result["nodes"]:
|
||||||
for node in result["nodes"]:
|
for node in result["nodes"]:
|
||||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
node_id = self._decode_graph_label(node["node_id"])
|
||||||
|
if node_id not in nodes:
|
||||||
|
nodes[node_id] = node
|
||||||
|
|
||||||
if result["relationships"]:
|
if result["relationships"]:
|
||||||
for edge in result["relationships"]:
|
for edge in result["relationships"]: # src --DIRECTED--> target
|
||||||
src_id = self._decode_graph_label(edge["start_id"])
|
src_id = self._decode_graph_label(edge[0]["node_id"])
|
||||||
tgt_id = self._decode_graph_label(edge["end_id"])
|
tgt_id = self._decode_graph_label(edge[2]["node_id"])
|
||||||
edges.append((src_id, tgt_id))
|
id = src_id + "," + tgt_id
|
||||||
|
if id in unique_edge_ids:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
unique_edge_ids.add(id)
|
||||||
|
edges.append(
|
||||||
|
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
|
||||||
|
)
|
||||||
|
|
||||||
kg = KnowledgeGraph(
|
kg = KnowledgeGraph(
|
||||||
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
|
nodes=[
|
||||||
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
|
KnowledgeGraphNode(
|
||||||
|
id=node_id, labels=[node_id], properties=nodes[node_id]
|
||||||
|
)
|
||||||
|
for node_id in nodes
|
||||||
|
],
|
||||||
|
edges=[
|
||||||
|
KnowledgeGraphEdge(
|
||||||
|
id=id, type="DIRECTED", source=src, target=tgt, properties=props
|
||||||
|
)
|
||||||
|
for id, src, tgt, props in edges
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return kg
|
return kg
|
||||||
|
Loading…
x
Reference in New Issue
Block a user