diff --git a/examples/test_postgres.py b/examples/test_postgres.py
new file mode 100644
index 00000000..e1f796c6
--- /dev/null
+++ b/examples/test_postgres.py
@@ -0,0 +1,54 @@
+import os
+import asyncio
+from lightrag.kg.postgres_impl import PGGraphStorage
+from lightrag.llm.ollama import ollama_embedding
+from lightrag.utils import EmbeddingFunc
+
+#########
+# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
+# import nest_asyncio
+# nest_asyncio.apply()
+#########
+
+WORKING_DIR = "./local_neo4jWorkDir"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+# AGE
+os.environ["AGE_GRAPH_NAME"] = "dickens"
+
+os.environ["POSTGRES_HOST"] = "localhost"
+os.environ["POSTGRES_PORT"] = "15432"
+os.environ["POSTGRES_USER"] = "rag"
+os.environ["POSTGRES_PASSWORD"] = "rag"
+os.environ["POSTGRES_DATABASE"] = "rag"
+
+
+async def main():
+    """Print every label and the FEZZIWIG subgraph of the Postgres graph."""
+    graph_db = PGGraphStorage(
+        namespace="dickens",
+        embedding_func=EmbeddingFunc(
+            embedding_dim=1024,
+            max_token_size=8192,
+            func=lambda texts: ollama_embedding(
+                texts, embed_model="bge-m3", host="http://localhost:11434"
+            ),
+        ),
+        global_config={},
+    )
+    await graph_db.initialize()
+    try:
+        labels = await graph_db.get_all_labels()
+        print("all labels", labels)
+
+        res = await graph_db.get_knowledge_graph("FEZZIWIG")
+        print("knowledge graphs", res)
+    finally:
+        # always finalize the storage, even if a query above fails
+        await graph_db.finalize()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 644c47cd..3a636e6a 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -810,42 +810,91 @@ class PGGraphStorage(BaseGraphStorage):
             v = record[k]
             # agtype comes back '{key: value}::type' which must be parsed
             if isinstance(v, str) and "::" in v:
-                dtype = v.split("::")[-1]
-                v = v.split("::")[0]
-                if dtype == "vertex":
-                    vertex = json.loads(v)
-                    vertices[vertex["id"]] = vertex.get("properties")
+                if v.startswith("[") and v.endswith("]"):
+                    if "::vertex" not in v:
+                        continue
+                    v = v.replace("::vertex", "")
+                    vertexes = json.loads(v)
+                    for vertex in vertexes:
+                        vertices[vertex["id"]] = vertex.get("properties")
+                else:
+                    dtype = v.split("::")[-1]
+                    v = v.split("::")[0]
+                    if dtype == "vertex":
+                        vertex = json.loads(v)
+                        vertices[vertex["id"]] = vertex.get("properties")
 
         # iterate returned fields and parse appropriately
         for k in record.keys():
             v = record[k]
             if isinstance(v, str) and "::" in v:
-                dtype = v.split("::")[-1]
-                v = v.split("::")[0]
-            else:
-                dtype = ""
+                if v.startswith("[") and v.endswith("]"):
+                    if "::vertex" in v:
+                        v = v.replace("::vertex", "")
+                        vertexes = json.loads(v)
+                        dl = []
+                        for vertex in vertexes:
+                            prop = vertex.get("properties")
+                            if not prop:
+                                prop = {}
+                            prop["label"] = PGGraphStorage._decode_graph_label(
+                                prop["node_id"]
+                            )
+                            dl.append(prop)
+                        d[k] = dl
 
-            if dtype == "vertex":
-                vertex = json.loads(v)
-                field = vertex.get("properties")
-                if not field:
-                    field = {}
-                field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
-                d[k] = field
-            # convert edge from id-label->id by replacing id with node information
-            # we only do this if the vertex was also returned in the query
-            # this is an attempt to be consistent with neo4j implementation
-            elif dtype == "edge":
-                edge = json.loads(v)
-                d[k] = (
-                    vertices.get(edge["start_id"], {}),
-                    edge[
-                        "label"
-                    ],  # we don't use decode_graph_label(), since edge label is always "DIRECTED"
-                    vertices.get(edge["end_id"], {}),
-                )
+                    elif "::edge" in v:
+                        v = v.replace("::edge", "")
+                        edges = json.loads(v)
+                        dl = []
+                        for edge in edges:
+                            # .get() mirrors the scalar edge branch below: an endpoint
+                            # the query did not return maps to {} instead of raising
+                            dl.append(
+                                (
+                                    vertices.get(edge["start_id"], {}),
+                                    edge["label"],
+                                    vertices.get(edge["end_id"], {}),
+                                )
+                            )
+                        d[k] = dl
+                    else:
+                        print("WARNING: unsupported type")
+                        continue
+
+                else:
+                    dtype = v.split("::")[-1]
+                    v = v.split("::")[0]
+                    if dtype == "vertex":
+                        vertex = json.loads(v)
+                        field = vertex.get("properties")
+                        if not field:
+                            field = {}
+                        field["label"] = PGGraphStorage._decode_graph_label(
+                            field["node_id"]
+                        )
+                        d[k] = field
+                    # convert edge from id-label->id by replacing id with node information
+                    # we only do this if the vertex was also returned in the query
+                    # this is an attempt to be consistent with neo4j implementation
+                    elif dtype == "edge":
+                        edge = json.loads(v)
+                        d[k] = (
+                            vertices.get(edge["start_id"], {}),
+                            edge[
+                                "label"
+                            ],  # we don't use decode_graph_label(), since edge label is always "DIRECTED"
+                            vertices.get(edge["end_id"], {}),
+                        )
+                    # NOTE(review): any other annotation leaves d[k] unset here;
+                    # confirm that dropping the field is intended
             else:
-                d[k] = json.loads(v) if isinstance(v, str) else v
+                # decode only JSON containers; None, plain scalar strings and
+                # non-str values pass through untouched
+                if isinstance(v, str) and ("{" in v or "[" in v):
+                    d[k] = json.loads(v)
+                else:
+                    d[k] = v
 
         return d
 
@@ -1319,7 +1368,7 @@ class PGGraphStorage(BaseGraphStorage):
                          OPTIONAL MATCH p = (n)-[*..%d]-(m)
                          RETURN nodes(p) AS nodes, relationships(p) AS relationships
                          LIMIT %d
-                       $$) AS (nodes agtype[], relationships agtype[])""" % (
+                       $$) AS (nodes agtype, relationships agtype)""" % (
                 self.graph_name,
                 encoded_node_label,
                 max_depth,
@@ -1328,17 +1377,23 @@ class PGGraphStorage(BaseGraphStorage):
 
         results = await self._query(query)
 
-        nodes = set()
+        nodes = {}
         edges = []
+        unique_edge_ids = set()
 
         for result in results:
             if node_label == "*":
                 if result["n"]:
                     node = result["n"]
-                    nodes.add(self._decode_graph_label(node["node_id"]))
+                    node_id = self._decode_graph_label(node["node_id"])
+                    if node_id not in nodes:
+                        nodes[node_id] = node
+
                 if result["m"]:
                     node = result["m"]
-                    nodes.add(self._decode_graph_label(node["node_id"]))
+                    node_id = self._decode_graph_label(node["node_id"])
+                    if node_id not in nodes:
+                        nodes[node_id] = node
                 if result["r"]:
                     edge = result["r"]
                     src_id = self._decode_graph_label(edge["start_id"])
@@ -1347,16 +1402,47 @@ class PGGraphStorage(BaseGraphStorage):
             else:
                 if result["nodes"]:
                     for node in result["nodes"]:
-                        nodes.add(self._decode_graph_label(node["node_id"]))
+                        node_id = self._decode_graph_label(node["node_id"])
+                        if node_id not in nodes:
+                            nodes[node_id] = node
+
                 if result["relationships"]:
-                    for edge in result["relationships"]:
-                        src_id = self._decode_graph_label(edge["start_id"])
-                        tgt_id = self._decode_graph_label(edge["end_id"])
-                        edges.append((src_id, tgt_id))
+                    for edge in result["relationships"]:  # src --DIRECTED--> target
+                        src_id = self._decode_graph_label(edge[0]["node_id"])
+                        tgt_id = self._decode_graph_label(edge[2]["node_id"])
+                        # named edge_id so the builtin id() is not shadowed
+                        edge_id = src_id + "," + tgt_id
+                        if edge_id in unique_edge_ids:
+                            continue
+                        unique_edge_ids.add(edge_id)
+                        edges.append(
+                            (
+                                edge_id,
+                                src_id,
+                                tgt_id,
+                                {"source": edge[0], "target": edge[2]},
+                            )
+                        )
 
+        # NOTE(review): the unpacking below needs 4-tuples in `edges`; confirm the
+        # wildcard ("*") branch, which lies outside this hunk, appends that shape
         kg = KnowledgeGraph(
-            nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
-            edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
+            nodes=[
+                KnowledgeGraphNode(
+                    id=node_id, labels=[node_id], properties=nodes[node_id]
+                )
+                for node_id in nodes
+            ],
+            edges=[
+                KnowledgeGraphEdge(
+                    id=edge_id,
+                    type="DIRECTED",
+                    source=src,
+                    target=tgt,
+                    properties=props,
+                )
+                for edge_id, src, tgt, props in edges
+            ],
         )
 
         return kg