Mirror of https://github.com/HKUDS/LightRAG.git, synced 2025-06-26 22:00:19 +00:00

- Enable the llm_cache storage to support get_by_mode_and_id, improving performance when a real KV server is used.
- Provide an option for developers to cache LLM responses while extracting entities from a document. This addresses the pain point that when extraction fails partway through, the already-processed chunks must be sent to the LLM again, wasting time and money. With the new option enabled (it is disabled by default), those responses are cached, which can save beginners significant time and cost.
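As an illustration only, a minimal sketch of how a developer might enable the new option when constructing LightRAG (the working directory path below is a placeholder, not part of this change):

    from lightrag import LightRAG

    rag = LightRAG(
        working_dir="./my_lightrag_cache",  # placeholder path for this sketch
        enable_llm_cache=True,  # default: cache LLM responses for queries
        enable_llm_cache_for_entity_extract=True,  # new option: reuse cached LLM calls during entity extraction
    )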
1040 lines
41 KiB
Python
import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast, Dict

from .llm import (
    gpt_4o_mini_complete,
    openai_embedding,
)
from .operate import (
    chunking_by_token_size,
    extract_entities,
    # local_query,global_query,hybrid_query,
    kg_query,
    naive_query,
    mix_kg_vector_query,
)

from .utils import (
    EmbeddingFunc,
    compute_mdhash_id,
    limit_async_func_call,
    convert_response_to_json,
    logger,
    set_logger,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
    DocStatus,
)

from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
    JsonDocStatusStorage,
)

from .prompt import GRAPH_FIELD_SEP

# future KG integrations

# from .kg.ArangoDB_impl import (
#     GraphStorage as ArangoDBStorage
# )

def lazy_external_import(module_name: str, class_name: str):
    """Lazily import a class from an external module based on the package of the caller."""

    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args, **kwargs):
        import importlib

        # Import the module using importlib
        module = importlib.import_module(module_name, package=package)

        # Get the class from the module and instantiate it
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class

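# Optional storage backends are bound lazily so their third-party dependencies are only imported when the selected backend is actually instantiated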
Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")

def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop

@dataclass
class LightRAG:
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    # Default not to use embedding cache
    embedding_cache_config: dict = field(
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
    graph_storage: str = field(default="NetworkXStorage")

    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # text chunking
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM
    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16
    llm_model_kwargs: dict = field(default_factory=dict)

    # storage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

    enable_llm_cache: bool = True
    # If the LLM fails partway through entity extraction, this flag lets cached responses be reused on retry so processing can continue without paying for the same LLM calls again
    enable_llm_cache_for_entity_extract: bool = False

    # extension
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    # Add new field for document status storage type
    doc_status_storage: str = field(default="JsonDocStatusStorage")

    def __post_init__(self):
        log_file = os.path.join("lightrag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # @TODO: should move all storage setup here to leverage initial start params attached to self.

        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
            self._get_storage_class()[self.kv_storage]
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
            self.graph_storage
        ]

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        self.llm_response_cache = self.key_string_value_json_storage_cls(
            namespace="llm_response_cache",
            global_config=asdict(self),
            embedding_func=None,
        )

        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        ####
        # add embedding func by walter
        ####
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        ####
        # add embedding func by walter over
        ####

        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
                **self.llm_model_kwargs,
            )
        )

        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
        self.doc_status = self.doc_status_storage_cls(
            namespace="doc_status",
            global_config=asdict(self),
            embedding_func=None,
        )

    def _get_storage_class(self) -> dict:
        return {
            # kv storage
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            "MongoKVStorage": MongoKVStorage,
            "TiDBKVStorage": TiDBKVStorage,
            # vector storage
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            "MilvusVectorDBStorge": MilvusVectorDBStorge,
            "ChromaVectorDBStorage": ChromaVectorDBStorage,
            "TiDBVectorDBStorage": TiDBVectorDBStorage,
            # graph storage
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
            "AGEStorage": AGEStorage,
            "PGGraphStorage": PGGraphStorage,
            "PGKVStorage": PGKVStorage,
            "PGDocStatusStorage": PGDocStatusStorage,
            "PGVectorStorage": PGVectorStorage,
            "TiDBGraphStorage": TiDBGraphStorage,
            "GremlinStorage": GremlinStorage,
            # "ArangoDBStorage": ArangoDBStorage
            "JsonDocStatusStorage": JsonDocStatusStorage,
        }

    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """Insert documents with checkpoint support

        Args:
            string_or_strings: Single document string or list of document strings
        """
        if isinstance(string_or_strings, str):
            string_or_strings = [string_or_strings]

        # 1. Remove duplicate contents from the list
        unique_contents = list(set(doc.strip() for doc in string_or_strings))

        # 2. Generate document IDs and initial status
        new_docs = {
            compute_mdhash_id(content, prefix="doc-"): {
                "content": content,
                "content_summary": self._get_content_summary(content),
                "content_length": len(content),
                "status": DocStatus.PENDING,
                "created_at": datetime.now().isoformat(),
                "updated_at": datetime.now().isoformat(),
            }
            for content in unique_contents
        }

        # 3. Filter out already processed documents
        _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

        if not new_docs:
            logger.info("All documents have been processed or are duplicates")
            return

        logger.info(f"Processing {len(new_docs)} new unique documents")

        # Process documents in batches
        batch_size = self.addon_params.get("insert_batch_size", 10)
        for i in range(0, len(new_docs), batch_size):
            batch_docs = dict(list(new_docs.items())[i : i + batch_size])

            for doc_id, doc in tqdm_async(
                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
            ):
                try:
                    # Update status to processing
                    doc_status = {
                        "content_summary": doc["content_summary"],
                        "content_length": doc["content_length"],
                        "status": DocStatus.PROCESSING,
                        "created_at": doc["created_at"],
                        "updated_at": datetime.now().isoformat(),
                    }
                    await self.doc_status.upsert({doc_id: doc_status})

                    # Generate chunks from document
                    chunks = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
                        }
                        for dp in chunking_by_token_size(
                            doc["content"],
                            overlap_token_size=self.chunk_overlap_token_size,
                            max_token_size=self.chunk_token_size,
                            tiktoken_model=self.tiktoken_model_name,
                        )
                    }

                    # Update status with chunks information
                    doc_status.update(
                        {
                            "chunks_count": len(chunks),
                            "updated_at": datetime.now().isoformat(),
                        }
                    )
                    await self.doc_status.upsert({doc_id: doc_status})

                    try:
                        # Store chunks in vector database
                        await self.chunks_vdb.upsert(chunks)

                        # Extract and store entities and relationships
                        maybe_new_kg = await extract_entities(
                            chunks,
                            knowledge_graph_inst=self.chunk_entity_relation_graph,
                            entity_vdb=self.entities_vdb,
                            relationships_vdb=self.relationships_vdb,
                            llm_response_cache=self.llm_response_cache,
                            global_config=asdict(self),
                        )

                        if maybe_new_kg is None:
                            raise Exception(
                                "Failed to extract entities and relationships"
                            )

                        self.chunk_entity_relation_graph = maybe_new_kg

                        # Store original document and chunks
                        await self.full_docs.upsert(
                            {doc_id: {"content": doc["content"]}}
                        )
                        await self.text_chunks.upsert(chunks)

                        # Update status to processed
                        doc_status.update(
                            {
                                "status": DocStatus.PROCESSED,
                                "updated_at": datetime.now().isoformat(),
                            }
                        )
                        await self.doc_status.upsert({doc_id: doc_status})

                    except Exception as e:
                        # Mark as failed if any step fails
                        doc_status.update(
                            {
                                "status": DocStatus.FAILED,
                                "error": str(e),
                                "updated_at": datetime.now().isoformat(),
                            }
                        )
                        await self.doc_status.upsert({doc_id: doc_status})
                        raise e

                except Exception as e:
                    import traceback

                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
                    logger.error(error_msg)
                    continue

                finally:
                    # Ensure all indexes are updated after each document
                    await self._insert_done()

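    # Flush every configured storage backend by awaiting its index_done_callback after an insert cycle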
    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def insert_custom_kg(self, custom_kg: dict):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

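    # Insert a user-supplied knowledge graph (chunks, entities, relationships) directly into the configured storages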
    async def ainsert_custom_kg(self, custom_kg: dict):
        update_storage = False
        try:
            # Insert chunks into vector storage
            all_chunks_data = {}
            chunk_to_source_map = {}
            for chunk_data in custom_kg.get("chunks", []):
                chunk_content = chunk_data["content"]
                source_id = chunk_data["source_id"]
                chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")

                chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
                all_chunks_data[chunk_id] = chunk_entry
                chunk_to_source_map[source_id] = chunk_id
                update_storage = True

            if self.chunks_vdb is not None and all_chunks_data:
                await self.chunks_vdb.upsert(all_chunks_data)
            if self.text_chunks is not None and all_chunks_data:
                await self.text_chunks.upsert(all_chunks_data)

            # Insert entities into knowledge graph
            all_entities_data = []
            for entity_data in custom_kg.get("entities", []):
                entity_name = f'"{entity_data["entity_name"].upper()}"'
                entity_type = entity_data.get("entity_type", "UNKNOWN")
                description = entity_data.get("description", "No description provided")
                # source_id = entity_data["source_id"]
                source_chunk_id = entity_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Prepare node data
                node_data = {
                    "entity_type": entity_type,
                    "description": description,
                    "source_id": source_id,
                }
                # Insert node data into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_node(
                    entity_name, node_data=node_data
                )
                node_data["entity_name"] = entity_name
                all_entities_data.append(node_data)
                update_storage = True

            # Insert relationships into knowledge graph
            all_relationships_data = []
            for relationship_data in custom_kg.get("relationships", []):
                src_id = f'"{relationship_data["src_id"].upper()}"'
                tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
                description = relationship_data["description"]
                keywords = relationship_data["keywords"]
                weight = relationship_data.get("weight", 1.0)
                # source_id = relationship_data["source_id"]
                source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Check if nodes exist in the knowledge graph
                for need_insert_id in [src_id, tgt_id]:
                    if not (
                        await self.chunk_entity_relation_graph.has_node(need_insert_id)
                    ):
                        await self.chunk_entity_relation_graph.upsert_node(
                            need_insert_id,
                            node_data={
                                "source_id": source_id,
                                "description": "UNKNOWN",
                                "entity_type": "UNKNOWN",
                            },
                        )

                # Insert edge into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_edge(
                    src_id,
                    tgt_id,
                    edge_data={
                        "weight": weight,
                        "description": description,
                        "keywords": keywords,
                        "source_id": source_id,
                    },
                )
                edge_data = {
                    "src_id": src_id,
                    "tgt_id": tgt_id,
                    "description": description,
                    "keywords": keywords,
                }
                all_relationships_data.append(edge_data)
                update_storage = True

            # Insert entities into vector storage if needed
            if self.entities_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                        "content": dp["entity_name"] + dp["description"],
                        "entity_name": dp["entity_name"],
                    }
                    for dp in all_entities_data
                }
                await self.entities_vdb.upsert(data_for_vdb)

            # Insert relationships into vector storage if needed
            if self.relationships_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                        "src_id": dp["src_id"],
                        "tgt_id": dp["tgt_id"],
                        "content": dp["keywords"]
                        + dp["src_id"]
                        + dp["tgt_id"]
                        + dp["description"],
                    }
                    for dp in all_relationships_data
                }
                await self.relationships_vdb.upsert(data_for_vdb)
        finally:
            if update_storage:
                await self._insert_done()

    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

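    # Route the query to the retrieval pipeline selected by param.mode: kg_query for local/global/hybrid, naive_query for plain vector retrieval, and mix_kg_vector_query for the combined mode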
    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        elif param.mode == "mix":
            response = await mix_kg_vector_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        entity_name = f'"{entity_name.upper()}"'

        try:
            await self.entities_vdb.delete_entity(entity_name)
            await self.relationships_vdb.delete_entity_relation(entity_name)
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        tasks = []
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
        """Get summary of document content

        Args:
            content: Original document content
            max_length: Maximum length of summary

        Returns:
            Truncated content with ellipsis if needed
        """
        content = content.strip()
        if len(content) <= max_length:
            return content
        return content[:max_length] + "..."

    async def get_processing_status(self) -> Dict[str, int]:
        """Get current document processing status counts

        Returns:
            Dict with counts for each status
        """
        return await self.doc_status.get_status_counts()

    async def adelete_by_doc_id(self, doc_id: str):
        """Delete a document and all its related data

        Args:
            doc_id: Document ID to delete
        """
        try:
            # 1. Get the document status and related data
            doc_status = await self.doc_status.get(doc_id)
            if not doc_status:
                logger.warning(f"Document {doc_id} not found")
                return

            logger.debug(f"Starting deletion for document {doc_id}")

            # 2. Get all related chunks
            chunks = await self.text_chunks.filter(
                lambda x: x.get("full_doc_id") == doc_id
            )
            chunk_ids = list(chunks.keys())
            logger.debug(f"Found {len(chunk_ids)} chunks to delete")

            # 3. Before deleting, check the related entities and relationships for these chunks
            for chunk_id in chunk_ids:
                # Check entities
                entities = [
                    dp
                    for dp in self.entities_vdb.client_storage["data"]
                    if dp.get("source_id") == chunk_id
                ]
                logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")

                # Check relationships
                relations = [
                    dp
                    for dp in self.relationships_vdb.client_storage["data"]
                    if dp.get("source_id") == chunk_id
                ]
                logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")

            # Continue with the original deletion process...

            # 4. Delete chunks from vector database
            if chunk_ids:
                await self.chunks_vdb.delete(chunk_ids)
                await self.text_chunks.delete(chunk_ids)

            # 5. Find and process entities and relationships that have these chunks as source
            # Get all nodes in the graph
            nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
            edges = self.chunk_entity_relation_graph._graph.edges(data=True)

            # Track which entities and relationships need to be deleted or updated
            entities_to_delete = set()
            entities_to_update = {}  # entity_name -> new_source_id
            relationships_to_delete = set()
            relationships_to_update = {}  # (src, tgt) -> new_source_id

            # Process entities
            for node, data in nodes:
                if "source_id" in data:
                    # Split source_id using GRAPH_FIELD_SEP
                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
                    sources.difference_update(chunk_ids)
                    if not sources:
                        entities_to_delete.add(node)
                        logger.debug(
                            f"Entity {node} marked for deletion - no remaining sources"
                        )
                    else:
                        new_source_id = GRAPH_FIELD_SEP.join(sources)
                        entities_to_update[node] = new_source_id
                        logger.debug(
                            f"Entity {node} will be updated with new source_id: {new_source_id}"
                        )

            # Process relationships
            for src, tgt, data in edges:
                if "source_id" in data:
                    # Split source_id using GRAPH_FIELD_SEP
                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
                    sources.difference_update(chunk_ids)
                    if not sources:
                        relationships_to_delete.add((src, tgt))
                        logger.debug(
                            f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
                        )
                    else:
                        new_source_id = GRAPH_FIELD_SEP.join(sources)
                        relationships_to_update[(src, tgt)] = new_source_id
                        logger.debug(
                            f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
                        )

            # Delete entities
            if entities_to_delete:
                for entity in entities_to_delete:
                    await self.entities_vdb.delete_entity(entity)
                    logger.debug(f"Deleted entity {entity} from vector DB")
                self.chunk_entity_relation_graph.remove_nodes(list(entities_to_delete))
                logger.debug(f"Deleted {len(entities_to_delete)} entities from graph")

            # Update entities
            for entity, new_source_id in entities_to_update.items():
                node_data = self.chunk_entity_relation_graph._graph.nodes[entity]
                node_data["source_id"] = new_source_id
                await self.chunk_entity_relation_graph.upsert_node(entity, node_data)
                logger.debug(
                    f"Updated entity {entity} with new source_id: {new_source_id}"
                )

            # Delete relationships
            if relationships_to_delete:
                for src, tgt in relationships_to_delete:
                    rel_id_0 = compute_mdhash_id(src + tgt, prefix="rel-")
                    rel_id_1 = compute_mdhash_id(tgt + src, prefix="rel-")
                    await self.relationships_vdb.delete([rel_id_0, rel_id_1])
                    logger.debug(f"Deleted relationship {src}-{tgt} from vector DB")
                self.chunk_entity_relation_graph.remove_edges(
                    list(relationships_to_delete)
                )
                logger.debug(
                    f"Deleted {len(relationships_to_delete)} relationships from graph"
                )

            # Update relationships
            for (src, tgt), new_source_id in relationships_to_update.items():
                edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt]
                edge_data["source_id"] = new_source_id
                await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data)
                logger.debug(
                    f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
                )

            # 6. Delete original document and status
            await self.full_docs.delete([doc_id])
            await self.doc_status.delete([doc_id])

            # 7. Ensure all indexes are updated
            await self._insert_done()

            logger.info(
                f"Successfully deleted document {doc_id} and related data. "
                f"Deleted {len(entities_to_delete)} entities and {len(relationships_to_delete)} relationships. "
                f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
            )

            # Add verification step
            async def verify_deletion():
                # Verify if the document has been deleted
                if await self.full_docs.get_by_id(doc_id):
                    logger.error(f"Document {doc_id} still exists in full_docs")

                # Verify if chunks have been deleted
                remaining_chunks = await self.text_chunks.filter(
                    lambda x: x.get("full_doc_id") == doc_id
                )
                if remaining_chunks:
                    logger.error(f"Found {len(remaining_chunks)} remaining chunks")

                # Verify entities and relationships
                for chunk_id in chunk_ids:
                    # Check entities
                    entities_with_chunk = [
                        dp
                        for dp in self.entities_vdb.client_storage["data"]
                        if chunk_id
                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                    ]
                    if entities_with_chunk:
                        logger.error(
                            f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}"
                        )

                    # Check relationships
                    relations_with_chunk = [
                        dp
                        for dp in self.relationships_vdb.client_storage["data"]
                        if chunk_id
                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                    ]
                    if relations_with_chunk:
                        logger.error(
                            f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}"
                        )

            await verify_deletion()

        except Exception as e:
            logger.error(f"Error while deleting document {doc_id}: {e}")

    def delete_by_doc_id(self, doc_id: str):
        """Synchronous version of adelete"""
        return asyncio.run(self.adelete_by_doc_id(doc_id))

    async def get_entity_info(
        self, entity_name: str, include_vector_data: bool = False
    ):
        """Get detailed information of an entity

        Args:
            entity_name: Entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing entity information, including:
                - entity_name: Entity name
                - source_id: Source document ID
                - graph_data: Complete node data from the graph database
                - vector_data: (optional) Data from the vector database
        """
        entity_name = f'"{entity_name.upper()}"'

        # Get information from the graph
        node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
        source_id = node_data.get("source_id") if node_data else None

        result = {
            "entity_name": entity_name,
            "source_id": source_id,
            "graph_data": node_data,
        }

        # Optional: Get vector database information
        if include_vector_data:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            vector_data = self.entities_vdb._client.get([entity_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
        """Synchronous version of getting entity information

        Args:
            entity_name: Entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database
        """
        try:
            import tracemalloc

            tracemalloc.start()
            return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
        finally:
            tracemalloc.stop()

    async def get_relation_info(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ):
        """Get detailed information of a relationship

        Args:
            src_entity: Source entity name (no need for quotes)
            tgt_entity: Target entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing relationship information, including:
                - src_entity: Source entity name
                - tgt_entity: Target entity name
                - source_id: Source document ID
                - graph_data: Complete edge data from the graph database
                - vector_data: (optional) Data from the vector database
        """
        src_entity = f'"{src_entity.upper()}"'
        tgt_entity = f'"{tgt_entity.upper()}"'

        # Get information from the graph
        edge_data = await self.chunk_entity_relation_graph.get_edge(
            src_entity, tgt_entity
        )
        source_id = edge_data.get("source_id") if edge_data else None

        result = {
            "src_entity": src_entity,
            "tgt_entity": tgt_entity,
            "source_id": source_id,
            "graph_data": edge_data,
        }

        # Optional: Get vector database information
        if include_vector_data:
            rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
            vector_data = self.relationships_vdb._client.get([rel_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def get_relation_info_sync(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ):
        """Synchronous version of getting relationship information

        Args:
            src_entity: Source entity name (no need for quotes)
            tgt_entity: Target entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database
        """
        try:
            import tracemalloc

            tracemalloc.start()
            return asyncio.run(
                self.get_relation_info(src_entity, tgt_entity, include_vector_data)
            )
        finally:
            tracemalloc.stop()