LightRAG/lightrag/lightrag.py

import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast, Dict
from .operate import (
chunking_by_token_size,
extract_entities,
# local_query,global_query,hybrid_query,
kg_query,
naive_query,
mix_kg_vector_query,
extract_keywords_only,
kg_query_with_keywords,
)
from .utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
convert_response_to_json,
logger,
set_logger,
statistic_data,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
StorageNameSpace,
QueryParam,
DocStatus,
)
from .prompt import GRAPH_FIELD_SEP
STORAGES = {
"NetworkXStorage": ".storage.networkx_storage",
"JsonKVStorage": ".storage.json_kv_storage",
"NanoVectorDBStorage": ".storage.nano_vector_db",
"JsonDocStatusStorage": ".storage.jsondocstatus_storage",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
}
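# NOTE: each entry maps a storage class name to the module that provides it; the class is only
# imported on first use via lazy_external_import() below. Illustrative sketch (assumes the
# backend's extra dependencies are installed):
#   rag = LightRAG(working_dir="./cache", kv_storage="MongoKVStorage")
# resolves "MongoKVStorage" to lightrag.kg.mongo_impl.MongoKVStorage when it is instantiated.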
def lazy_external_import(module_name: str, class_name: str):
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args, **kwargs):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
logger.info("Creating a new event loop in main thread.")
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop
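# Illustrative sketch: the synchronous wrappers below (insert(), query(), ...) rely on this
# helper, roughly:
#   loop = always_get_an_event_loop()
#   loop.run_until_complete(rag.ainsert("some text"))  # 'rag' is an assumed LightRAG instance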
@dataclass
class LightRAG:
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
# Default not to use embedding cache
embedding_cache_config: dict = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
kv_storage: str = field(default="JsonKVStorage")
vector_storage: str = field(default="NanoVectorDBStorage")
graph_storage: str = field(default="NetworkXStorage")
current_log_level = logger.level
log_level: str = field(default=current_log_level)
# text chunking
chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100
tiktoken_model_name: str = "gpt-4o-mini"
# entity extraction
entity_extract_max_gleaning: int = 1
entity_summary_to_max_tokens: int = 500
# node embedding
node_embedding_algorithm: str = "node2vec"
node2vec_params: dict = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
}
)
    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
    embedding_func: EmbeddingFunc = None  # Must be set by the caller (the LLM layer is kept separate from the core, so there is no default)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# LLM
    llm_model_func: callable = None  # Must be set by the caller (the LLM layer is kept separate from the core, so there is no default)
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # e.g. 'meta-llama/Llama-3.2-1B' or 'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
# storage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
enable_llm_cache: bool = True
    # If the LLM fails while extracting entities, caching its extraction responses lets the run
    # be retried without incurring the LLM cost again; this flag enables that cache.
enable_llm_cache_for_entity_extract: bool = True
# extension
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
# Add new field for document status storage type
doc_status_storage: str = field(default="JsonDocStatusStorage")
# Custom Chunking Function
chunking_func: callable = chunking_by_token_size
chunking_func_kwargs: dict = field(default_factory=dict)
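    # Minimal construction sketch (illustrative; the two callables are assumptions, any
    # embedding/LLM functions with compatible signatures work):
    #   rag = LightRAG(
    #       working_dir="./lightrag_cache",
    #       embedding_func=my_embedding_func,  # an EmbeddingFunc
    #       llm_model_func=my_llm_complete,    # an async completion callable
    #   )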
def __post_init__(self):
log_file = os.path.join("lightrag.log")
set_logger(log_file)
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls, global_config=global_config
)
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace="json_doc_status_storage",
embedding_func=None,
)
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
####
# add embedding func by walter
####
self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs",
embedding_func=self.embedding_func,
)
self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks",
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation",
embedding_func=self.embedding_func,
)
####
# add embedding func by walter over
####
self.entities_vdb = self.vector_db_storage_cls(
namespace="entities",
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships",
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks",
embedding_func=self.embedding_func,
)
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status = self.doc_status_storage_cls(
namespace="doc_status",
global_config=global_config,
embedding_func=None,
)
    async def get_graph_labels(self):
        return await self.chunk_entity_relation_graph.get_all_labels()

    async def get_graphs(self, node_label: str, max_depth: int):
        return await self.chunk_entity_relation_graph.get_knowledge_graph(
            node_label=node_label, max_depth=max_depth
        )
    def _get_storage_class(self, storage_name: str):
        import_path = STORAGES[storage_name]
        return lazy_external_import(import_path, storage_name)
    def set_storage_client(self, db_client):
        # Share a single database client across all storages (currently only tested on Oracle Database)
        for storage in [
            self.vector_db_storage_cls,
            self.graph_storage_cls,
            self.doc_status,
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.key_string_value_json_storage_cls,
            self.chunks_vdb,
            self.relationships_vdb,
            self.entities_vdb,
            self.chunk_entity_relation_graph,
        ]:
            # set client
            storage.db = db_client
def insert(
self, string_or_strings, split_by_character=None, split_by_character_only=False
):
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.ainsert(string_or_strings, split_by_character, split_by_character_only)
)
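    # Usage sketch (illustrative): a plain call indexes one or more documents with the default
    # token-based chunking; split_by_character pre-splits on that character first.
    #   rag.insert("full document text")
    #   rag.insert(["doc one", "doc two"], split_by_character="\n\n")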
async def ainsert(
self, string_or_strings, split_by_character=None, split_by_character_only=False
):
"""Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
            split_by_character: if not None, split the text on this character first; any chunk
                longer than chunk_token_size is further split by token size.
            split_by_character_only: if True, split only on the character; ignored when
                split_by_character is None.
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
# 2. Generate document IDs and initial status
new_docs = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
}
for content in unique_contents
}
# 3. Filter out already processed documents
# _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
_add_doc_keys = {
doc_id
for doc_id in new_docs.keys()
if (current_doc := await self.doc_status.get_by_id(doc_id)) is None
or current_doc["status"] == DocStatus.FAILED
}
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return
logger.info(f"Processing {len(new_docs)} new unique documents")
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
):
try:
# Update status to processing
doc_status = {
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
"status": DocStatus.PROCESSING,
"created_at": doc["created_at"],
"updated_at": datetime.now().isoformat(),
}
await self.doc_status.upsert({doc_id: doc_status})
# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in self.chunking_func(
doc["content"],
split_by_character=split_by_character,
split_by_character_only=split_by_character_only,
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
**self.chunking_func_kwargs,
)
}
# Update status with chunks information
doc_status.update(
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)
# Extract and store entities and relationships
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if maybe_new_kg is None:
raise Exception(
"Failed to extract entities and relationships"
)
self.chunk_entity_relation_graph = maybe_new_kg
# Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
await self.text_chunks.upsert(chunks)
# Update status to processed
doc_status.update(
{
"status": DocStatus.PROCESSED,
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
except Exception as e:
# Mark as failed if any step fails
doc_status.update(
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue
else:
# Only update index when processing succeeds
await self._insert_done()
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.ainsert_custom_chunks(full_text, text_chunks)
)
async def ainsert_custom_chunks(self, full_text: str, text_chunks: list[str]):
update_storage = False
try:
doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-")
new_docs = {doc_key: {"content": full_text.strip()}}
_add_doc_keys = await self.full_docs.filter_keys([doc_key])
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning("This document is already in the storage.")
return
update_storage = True
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
inserting_chunks = {}
for chunk_text in text_chunks:
chunk_text_stripped = chunk_text.strip()
chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-")
inserting_chunks[chunk_key] = {
"content": chunk_text_stripped,
"full_doc_id": doc_key,
}
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning("All chunks are already in the storage.")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
await self.chunks_vdb.upsert(inserting_chunks)
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities and relationships found")
return
else:
self.chunk_entity_relation_graph = maybe_new_kg
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
if update_storage:
await self._insert_done()
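    # Usage sketch (illustrative): insert_custom_chunks() bypasses the built-in chunker and indexes
    # caller-provided chunks against one full document, e.g.
    #   rag.insert_custom_chunks(full_text=document, text_chunks=my_precomputed_chunks)
    # where my_precomputed_chunks is an assumed list[str] prepared by the caller.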
async def apipeline_process_documents(self, string_or_strings):
"""Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs
Args:
string_or_strings: Single document string or list of document strings
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
        logger.info(
            f"Received {len(string_or_strings)} docs, {len(unique_contents)} of them are unique"
        )
# 2. Generate document IDs and initial status
new_docs = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": None,
}
for content in unique_contents
}
# 3. Filter out already processed documents
_not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
if len(_not_stored_doc_keys) < len(new_docs):
logger.info(
f"Skipping {len(new_docs) - len(_not_stored_doc_keys)} already existing documents"
)
new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return None
# 4. Store original document
for doc_id, doc in new_docs.items():
await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
await self.full_docs.change_status(doc_id, DocStatus.PENDING)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_chunks(self):
"""Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents
_todo_doc_keys = []
_failed_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.FAILED, ids=None
)
_pendding_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.PENDING, ids=None
)
if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc:
_todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
if not _todo_doc_keys:
logger.info("All documents have been processed or are duplicates")
return None
else:
logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
new_docs = {
doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
}
# 2. split docs into chunks, insert chunks, update doc status
chunk_cnt = 0
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(),
desc=f"Level 1 - Spliting doc in batch {i // batch_size + 1}",
):
try:
# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
"status": DocStatus.PENDING,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
chunk_cnt += len(chunks)
await self.text_chunks.upsert(chunks)
await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)
# Update doc status
await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
except Exception as e:
# Mark as failed if any step fails
await self.full_docs.change_status(doc_id, DocStatus.FAILED)
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue
logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
async def apipeline_process_extract_graph(self):
"""Get pendding or failed chunks, extract entities and relationships from each chunk"""
# 1. get all pending and failed chunks
_todo_chunk_keys = []
_failed_chunks = await self.text_chunks.get_by_status_and_ids(
status=DocStatus.FAILED, ids=None
)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(
status=DocStatus.PENDING, ids=None
)
if _failed_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
if _pendding_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
if not _todo_chunk_keys:
logger.info("All chunks have been processed or are duplicates")
return None
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
semaphore = asyncio.Semaphore(
batch_size
) # Control the number of tasks that are processed simultaneously
async def process_chunk(chunk_id):
async with semaphore:
chunks = {
i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
}
# Extract and store entities and relationships
try:
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.info("No entities or relationships extracted!")
# Update status to processed
await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
except Exception as e:
logger.error("Failed to extract entities and relationships")
# Mark as failed if any step fails
await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
raise e
with tqdm_async(
total=len(_todo_chunk_keys),
desc="\nLevel 1 - Processing chunks",
unit="chunk",
position=0,
) as progress:
tasks = []
for chunk_id in _todo_chunk_keys:
task = asyncio.create_task(process_chunk(chunk_id))
tasks.append(task)
for future in asyncio.as_completed(tasks):
await future
progress.update(1)
progress.set_postfix(
{
"LLM call": statistic_data["llm_call"],
"LLM cache": statistic_data["llm_cache"],
}
)
# Ensure all indexes are updated after each document
await self._insert_done()
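    # Note (sketch): chunks that raise are left as DocStatus.FAILED, so re-running
    # apipeline_process_extract_graph() retries only the unfinished chunks, with at most
    # insert_batch_size extractions in flight at once (the semaphore above).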
async def _insert_done(self):
tasks = []
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
def insert_custom_kg(self, custom_kg: dict):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
async def ainsert_custom_kg(self, custom_kg: dict):
update_storage = False
try:
# Insert chunks into vector storage
all_chunks_data = {}
chunk_to_source_map = {}
for chunk_data in custom_kg.get("chunks", []):
chunk_content = chunk_data["content"]
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
all_chunks_data[chunk_id] = chunk_entry
chunk_to_source_map[source_id] = chunk_id
update_storage = True
if self.chunks_vdb is not None and all_chunks_data:
await self.chunks_vdb.upsert(all_chunks_data)
if self.text_chunks is not None and all_chunks_data:
await self.text_chunks.upsert(all_chunks_data)
# Insert entities into knowledge graph
all_entities_data = []
for entity_data in custom_kg.get("entities", []):
entity_name = f'"{entity_data["entity_name"].upper()}"'
entity_type = entity_data.get("entity_type", "UNKNOWN")
description = entity_data.get("description", "No description provided")
# source_id = entity_data["source_id"]
source_chunk_id = entity_data.get("source_id", "UNKNOWN")
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
# Log if source_id is UNKNOWN
if source_id == "UNKNOWN":
logger.warning(
f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
)
# Prepare node data
node_data = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
}
# Insert node data into the knowledge graph
await self.chunk_entity_relation_graph.upsert_node(
entity_name, node_data=node_data
)
node_data["entity_name"] = entity_name
all_entities_data.append(node_data)
update_storage = True
# Insert relationships into knowledge graph
all_relationships_data = []
for relationship_data in custom_kg.get("relationships", []):
src_id = f'"{relationship_data["src_id"].upper()}"'
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
description = relationship_data["description"]
keywords = relationship_data["keywords"]
weight = relationship_data.get("weight", 1.0)
# source_id = relationship_data["source_id"]
source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
# Log if source_id is UNKNOWN
if source_id == "UNKNOWN":
logger.warning(
f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
)
# Check if nodes exist in the knowledge graph
for need_insert_id in [src_id, tgt_id]:
if not (
await self.chunk_entity_relation_graph.has_node(need_insert_id)
):
await self.chunk_entity_relation_graph.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
"description": "UNKNOWN",
"entity_type": "UNKNOWN",
},
)
# Insert edge into the knowledge graph
await self.chunk_entity_relation_graph.upsert_edge(
src_id,
tgt_id,
edge_data={
"weight": weight,
"description": description,
"keywords": keywords,
"source_id": source_id,
},
)
edge_data = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
"keywords": keywords,
}
all_relationships_data.append(edge_data)
update_storage = True
# Insert entities into vector storage if needed
if self.entities_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
if self.relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
await self.relationships_vdb.upsert(data_for_vdb)
finally:
if update_storage:
await self._insert_done()
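    # Expected custom_kg shape (a sketch derived from the reads above; keys accessed via .get()
    # are optional):
    #   {
    #       "chunks":        [{"content": str, "source_id": str}, ...],
    #       "entities":      [{"entity_name": str, "entity_type": str, "description": str,
    #                          "source_id": str}, ...],
    #       "relationships": [{"src_id": str, "tgt_id": str, "description": str, "keywords": str,
    #                          "weight": float, "source_id": str}, ...],
    #   }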
def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, prompt, param))
async def aquery(
self, query: str, prompt: str = "", param: QueryParam = QueryParam()
):
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
prompt=prompt,
)
elif param.mode == "naive":
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
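    # Query sketch (illustrative): QueryParam.mode selects the retrieval path handled above;
    # "local", "global" and "hybrid" go through kg_query, "naive" through naive_query, and "mix"
    # through mix_kg_vector_query, e.g.
    #   rag.query("What did the author conclude?", param=QueryParam(mode="hybrid"))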
def query_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Extract keywords from the 'query' using new function in operate.py.
2. Then run the standard aquery() flow with the final prompt (formatted_question).
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.aquery_with_separate_keyword_extraction(query, prompt, param)
)
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
"""
# ---------------------
# STEP 1: Keyword Extraction
# ---------------------
# We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
param=param,
global_config=asdict(self),
hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
        param.hl_keywords = hl_keywords
        param.ll_keywords = ll_keywords
# ---------------------
# STEP 2: Final Query Logic
# ---------------------
# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query_with_keywords(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
elif param.mode == "naive":
response = await naive_query(
formatted_question,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
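    # Sketch: unlike aquery(), this variant first extracts high-level/low-level keywords and
    # prepends them, together with the caller's prompt, to the query it sends on, e.g.
    #   rag.query_with_separate_keyword_extraction("What changed?", prompt="Answer briefly.")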
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
def delete_by_entity(self, entity_name: str):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_entity(entity_name))
async def adelete_by_entity(self, entity_name: str):
entity_name = f'"{entity_name.upper()}"'
try:
await self.entities_vdb.delete_entity(entity_name)
await self.relationships_vdb.delete_entity_relation(entity_name)
await self.chunk_entity_relation_graph.delete_node(entity_name)
logger.info(
f"Entity '{entity_name}' and its relationships have been deleted."
)
await self._delete_by_entity_done()
except Exception as e:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
tasks = []
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
Args:
content: Original document content
max_length: Maximum length of summary
Returns:
Truncated content with ellipsis if needed
"""
content = content.strip()
if len(content) <= max_length:
return content
return content[:max_length] + "..."
async def get_processing_status(self) -> Dict[str, int]:
"""Get current document processing status counts
Returns:
Dict with counts for each status
"""
return await self.doc_status.get_status_counts()
async def adelete_by_doc_id(self, doc_id: str):
"""Delete a document and all its related data
Args:
doc_id: Document ID to delete
"""
try:
# 1. Get the document status and related data
doc_status = await self.doc_status.get(doc_id)
if not doc_status:
logger.warning(f"Document {doc_id} not found")
return
logger.debug(f"Starting deletion for document {doc_id}")
# 2. Get all related chunks
chunks = await self.text_chunks.filter(
lambda x: x.get("full_doc_id") == doc_id
)
chunk_ids = list(chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
# 3. Before deleting, check the related entities and relationships for these chunks
for chunk_id in chunk_ids:
# Check entities
entities = [
dp
for dp in self.entities_vdb.client_storage["data"]
if dp.get("source_id") == chunk_id
]
logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")
# Check relationships
relations = [
dp
for dp in self.relationships_vdb.client_storage["data"]
if dp.get("source_id") == chunk_id
]
logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")
# Continue with the original deletion process...
# 4. Delete chunks from vector database
if chunk_ids:
await self.chunks_vdb.delete(chunk_ids)
await self.text_chunks.delete(chunk_ids)
# 5. Find and process entities and relationships that have these chunks as source
# Get all nodes in the graph
nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
edges = self.chunk_entity_relation_graph._graph.edges(data=True)
# Track which entities and relationships need to be deleted or updated
entities_to_delete = set()
entities_to_update = {} # entity_name -> new_source_id
relationships_to_delete = set()
relationships_to_update = {} # (src, tgt) -> new_source_id
# Process entities
for node, data in nodes:
if "source_id" in data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
entities_to_delete.add(node)
logger.debug(
f"Entity {node} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
entities_to_update[node] = new_source_id
logger.debug(
f"Entity {node} will be updated with new source_id: {new_source_id}"
)
# Process relationships
for src, tgt, data in edges:
if "source_id" in data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
relationships_to_delete.add((src, tgt))
logger.debug(
f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
relationships_to_update[(src, tgt)] = new_source_id
logger.debug(
f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
)
# Delete entities
if entities_to_delete:
for entity in entities_to_delete:
await self.entities_vdb.delete_entity(entity)
logger.debug(f"Deleted entity {entity} from vector DB")
self.chunk_entity_relation_graph.remove_nodes(list(entities_to_delete))
logger.debug(f"Deleted {len(entities_to_delete)} entities from graph")
# Update entities
for entity, new_source_id in entities_to_update.items():
node_data = self.chunk_entity_relation_graph._graph.nodes[entity]
node_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_node(entity, node_data)
logger.debug(
f"Updated entity {entity} with new source_id: {new_source_id}"
)
# Delete relationships
if relationships_to_delete:
for src, tgt in relationships_to_delete:
rel_id_0 = compute_mdhash_id(src + tgt, prefix="rel-")
rel_id_1 = compute_mdhash_id(tgt + src, prefix="rel-")
await self.relationships_vdb.delete([rel_id_0, rel_id_1])
logger.debug(f"Deleted relationship {src}-{tgt} from vector DB")
self.chunk_entity_relation_graph.remove_edges(
list(relationships_to_delete)
)
logger.debug(
f"Deleted {len(relationships_to_delete)} relationships from graph"
)
# Update relationships
for (src, tgt), new_source_id in relationships_to_update.items():
edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt]
edge_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data)
logger.debug(
f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
)
# 6. Delete original document and status
await self.full_docs.delete([doc_id])
await self.doc_status.delete([doc_id])
# 7. Ensure all indexes are updated
await self._insert_done()
logger.info(
f"Successfully deleted document {doc_id} and related data. "
f"Deleted {len(entities_to_delete)} entities and {len(relationships_to_delete)} relationships. "
f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
)
# Add verification step
async def verify_deletion():
# Verify if the document has been deleted
if await self.full_docs.get_by_id(doc_id):
logger.error(f"Document {doc_id} still exists in full_docs")
# Verify if chunks have been deleted
remaining_chunks = await self.text_chunks.filter(
lambda x: x.get("full_doc_id") == doc_id
)
if remaining_chunks:
logger.error(f"Found {len(remaining_chunks)} remaining chunks")
# Verify entities and relationships
for chunk_id in chunk_ids:
# Check entities
entities_with_chunk = [
dp
for dp in self.entities_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if entities_with_chunk:
logger.error(
f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}"
)
# Check relationships
relations_with_chunk = [
dp
for dp in self.relationships_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if relations_with_chunk:
logger.error(
f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}"
)
await verify_deletion()
except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}")
def delete_by_doc_id(self, doc_id: str):
"""Synchronous version of adelete"""
return asyncio.run(self.adelete_by_doc_id(doc_id))
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
"""Get detailed information of an entity
Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
Returns:
dict: A dictionary containing entity information, including:
- entity_name: Entity name
- source_id: Source document ID
- graph_data: Complete node data from the graph database
- vector_data: (optional) Data from the vector database
"""
entity_name = f'"{entity_name.upper()}"'
# Get information from the graph
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None
result = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": node_data,
}
# Optional: Get vector database information
if include_vector_data:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
vector_data = self.entities_vdb._client.get([entity_id])
result["vector_data"] = vector_data[0] if vector_data else None
return result
def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
"""Synchronous version of getting entity information
Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
finally:
tracemalloc.stop()
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Get detailed information of a relationship
Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
Returns:
dict: A dictionary containing relationship information, including:
- src_entity: Source entity name
- tgt_entity: Target entity name
- source_id: Source document ID
- graph_data: Complete edge data from the graph database
- vector_data: (optional) Data from the vector database
"""
src_entity = f'"{src_entity.upper()}"'
tgt_entity = f'"{tgt_entity.upper()}"'
# Get information from the graph
edge_data = await self.chunk_entity_relation_graph.get_edge(
src_entity, tgt_entity
)
source_id = edge_data.get("source_id") if edge_data else None
result = {
"src_entity": src_entity,
"tgt_entity": tgt_entity,
"source_id": source_id,
"graph_data": edge_data,
}
# Optional: Get vector database information
if include_vector_data:
rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
vector_data = self.relationships_vdb._client.get([rel_id])
result["vector_data"] = vector_data[0] if vector_data else None
return result
def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information
Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(
self.get_relation_info(src_entity, tgt_entity, include_vector_data)
)
finally:
tracemalloc.stop()