LightRAG/examples/graph_visual_with_neo4j.py
2025-05-18 07:17:21 +08:00

187 lines
6.0 KiB
Python

import os
import json
import xml.etree.ElementTree as ET
from neo4j import GraphDatabase
# Constants
WORKING_DIR = "./dickens"
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100
# Neo4j connection credentials
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
root = tree.getroot()
# Print the root element's tag and attributes to confirm the file has been correctly loaded
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
data = {"nodes": [], "edges": []}
# Use namespace
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
for node in root.findall(".//node", namespace):
node_data = {
"id": node.get("id").strip('"'),
"entity_type": node.find("./data[@key='d1']", namespace).text.strip('"')
if node.find("./data[@key='d1']", namespace) is not None
else "",
"description": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
"source_id": node.find("./data[@key='d3']", namespace).text
if node.find("./data[@key='d3']", namespace) is not None
else "",
}
data["nodes"].append(node_data)
for edge in root.findall(".//edge", namespace):
edge_data = {
"source": edge.get("source").strip('"'),
"target": edge.get("target").strip('"'),
"weight": float(edge.find("./data[@key='d5']", namespace).text)
if edge.find("./data[@key='d5']", namespace) is not None
else 0.0,
"description": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
"keywords": edge.find("./data[@key='d7']", namespace).text
if edge.find("./data[@key='d7']", namespace) is not None
else "",
"source_id": edge.find("./data[@key='d8']", namespace).text
if edge.find("./data[@key='d8']", namespace) is not None
else "",
}
data["edges"].append(edge_data)
# Print the number of nodes and edges found
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
return data
except ET.ParseError as e:
print(f"Error parsing XML file: {e}")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
print(f"Error: File not found - {xml_path}")
return None
json_data = xml_to_json(xml_path)
if json_data:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
else:
print("Failed to create JSON data")
return None
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
batch = data[i : i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
def main():
# Paths
xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
json_file = os.path.join(WORKING_DIR, "graph_data.json")
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
if json_data is None:
return
# Load nodes and edges
nodes = json_data.get("nodes", [])
edges = json_data.get("edges", [])
# Neo4j queries
create_nodes_query = """
UNWIND $nodes AS node
MERGE (e:Entity {id: node.id})
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
RETURN count(*)
"""
create_edges_query = """
UNWIND $edges AS edge
MATCH (source {id: edge.source})
MATCH (target {id: edge.target})
WITH source, target, edge,
CASE
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
WHEN edge.keywords CONTAINS 'located' THEN 'located'
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
END AS relType
CALL apoc.create.relationship(source, relType, {
weight: edge.weight,
description: edge.description,
keywords: edge.keywords,
source_id: edge.source_id
}, target) YIELD rel
RETURN count(*)
"""
set_displayname_and_labels_query = """
MATCH (n)
SET n.displayName = n.id
WITH n
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
RETURN count(*)
"""
# Create a Neo4j driver
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
try:
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
session.execute_write(
process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
)
# Insert edges in batches
session.execute_write(
process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
)
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
finally:
driver.close()
if __name__ == "__main__":
main()