import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast, Dict

from .llm import (
    gpt_4o_mini_complete,
    openai_embedding,
)
from .operate import (
    chunking_by_token_size,
    extract_entities,
    # local_query,global_query,hybrid_query,
    kg_query,
    naive_query,
    mix_kg_vector_query,
)

from .utils import (
    EmbeddingFunc,
    compute_mdhash_id,
    limit_async_func_call,
    convert_response_to_json,
    logger,
    set_logger,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
    DocStatus,
)

from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
    JsonDocStatusStorage,
)

from .prompt import GRAPH_FIELD_SEP

# future KG integrations
# from .kg.ArangoDB_impl import (
#     GraphStorage as ArangoDBStorage
# )

def lazy_external_import(module_name: str, class_name: str):
    """Lazily import a class from an external module based on the package of the caller."""

    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args, **kwargs):
        import importlib

        # Import the module using importlib
        module = importlib.import_module(module_name, package=package)

        # Get the class from the module and instantiate it
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class

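# Illustrative sketch of how the lazy bindings below behave (assumes the matching
# optional backend package, e.g. the neo4j driver, is installed): the heavy import
# only happens when the storage class is instantiated, so importing lightrag does
# not pull in every backend dependency.
#
#   Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
#   storage = Neo4JStorage(
#       namespace="chunk_entity_relation", global_config={}, embedding_func=None
#   )
#   # ^ .kg.neo4j_impl is imported only at this call, not at module import time.
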
Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")

def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop

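# Illustrative pattern (this is exactly how the synchronous wrappers below use it):
#
#   loop = always_get_an_event_loop()
#   result = loop.run_until_complete(some_coroutine())
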
@dataclass
class LightRAG:
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    # Embedding cache is disabled by default
    embedding_cache_config: dict = field(
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
    graph_storage: str = field(default="NetworkXStorage")

    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # text chunking
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM
    llm_model_func: callable = gpt_4o_mini_complete  # alternative: hf_model_complete
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # alternatives: 'meta-llama/Llama-3.2-1B', 'google/gemma-2-2b-it'
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16
    llm_model_kwargs: dict = field(default_factory=dict)

    # storage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

    enable_llm_cache: bool = True
    # Keep cached LLM responses for entity extraction, so a run that failed during
    # extraction can be resumed without paying the LLM cost again
    enable_llm_cache_for_entity_extract: bool = True

    # extension
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    # Document status storage type
    doc_status_storage: str = field(default="JsonDocStatusStorage")

    def __post_init__(self):
        log_file = os.path.join("lightrag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

        # @TODO: should move all storage setup here to leverage initial start params attached to self.

        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
            self._get_storage_class()[self.kv_storage]
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
            self.graph_storage
        ]

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        self.llm_response_cache = self.key_string_value_json_storage_cls(
            namespace="llm_response_cache",
            global_config=asdict(self),
            embedding_func=None,
        )

        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        # KV and graph storages share the rate-limited embedding function
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
                **self.llm_model_kwargs,
            )
        )

        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
        self.doc_status = self.doc_status_storage_cls(
            namespace="doc_status",
            global_config=asdict(self),
            embedding_func=None,
        )

    def _get_storage_class(self) -> dict:
        return {
            # kv storage
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            "MongoKVStorage": MongoKVStorage,
            "TiDBKVStorage": TiDBKVStorage,
            # vector storage
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            "MilvusVectorDBStorge": MilvusVectorDBStorge,
            "ChromaVectorDBStorage": ChromaVectorDBStorage,
            "TiDBVectorDBStorage": TiDBVectorDBStorage,
            # graph storage
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
            "AGEStorage": AGEStorage,
            "PGGraphStorage": PGGraphStorage,
            "PGKVStorage": PGKVStorage,
            "PGDocStatusStorage": PGDocStatusStorage,
            "PGVectorStorage": PGVectorStorage,
            "TiDBGraphStorage": TiDBGraphStorage,
            "GremlinStorage": GremlinStorage,
            # "ArangoDBStorage": ArangoDBStorage
            "JsonDocStatusStorage": JsonDocStatusStorage,
        }

    def insert(self, string_or_strings, split_by_character=None):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(
            self.ainsert(string_or_strings, split_by_character)
        )

    async def ainsert(self, string_or_strings, split_by_character):
        """Insert documents with checkpoint support

        Args:
            string_or_strings: Single document string or list of document strings
            split_by_character: if not None, split each document on this character
                before token-based chunking
        """
        if isinstance(string_or_strings, str):
            string_or_strings = [string_or_strings]

        # 1. Remove duplicate contents from the list
        unique_contents = list(set(doc.strip() for doc in string_or_strings))

        # 2. Generate document IDs and initial status
        new_docs = {
            compute_mdhash_id(content, prefix="doc-"): {
                "content": content,
                "content_summary": self._get_content_summary(content),
                "content_length": len(content),
                "status": DocStatus.PENDING,
                "created_at": datetime.now().isoformat(),
                "updated_at": datetime.now().isoformat(),
            }
            for content in unique_contents
        }

        # 3. Filter out already processed documents
        _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

        if not new_docs:
            logger.info("All documents have been processed or are duplicates")
            return

        logger.info(f"Processing {len(new_docs)} new unique documents")

        # Process documents in batches
        batch_size = self.addon_params.get("insert_batch_size", 10)
        for i in range(0, len(new_docs), batch_size):
            batch_docs = dict(list(new_docs.items())[i : i + batch_size])

            for doc_id, doc in tqdm_async(
                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
            ):
                try:
                    # Update status to processing
                    doc_status = {
                        "content_summary": doc["content_summary"],
                        "content_length": doc["content_length"],
                        "status": DocStatus.PROCESSING,
                        "created_at": doc["created_at"],
                        "updated_at": datetime.now().isoformat(),
                    }
                    await self.doc_status.upsert({doc_id: doc_status})

                    # Generate chunks from document
                    chunks = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
                        }
                        for dp in chunking_by_token_size(
                            doc["content"],
                            split_by_character=split_by_character,
                            overlap_token_size=self.chunk_overlap_token_size,
                            max_token_size=self.chunk_token_size,
                            tiktoken_model=self.tiktoken_model_name,
                        )
                    }

                    # Update status with chunks information
                    doc_status.update(
                        {
                            "chunks_count": len(chunks),
                            "updated_at": datetime.now().isoformat(),
                        }
                    )
                    await self.doc_status.upsert({doc_id: doc_status})

                    try:
                        # Store chunks in vector database
                        await self.chunks_vdb.upsert(chunks)

                        # Extract and store entities and relationships
                        maybe_new_kg = await extract_entities(
                            chunks,
                            knowledge_graph_inst=self.chunk_entity_relation_graph,
                            entity_vdb=self.entities_vdb,
                            relationships_vdb=self.relationships_vdb,
                            llm_response_cache=self.llm_response_cache,
                            global_config=asdict(self),
                        )

                        if maybe_new_kg is None:
                            raise Exception(
                                "Failed to extract entities and relationships"
                            )

                        self.chunk_entity_relation_graph = maybe_new_kg

                        # Store original document and chunks
                        await self.full_docs.upsert(
                            {doc_id: {"content": doc["content"]}}
                        )
                        await self.text_chunks.upsert(chunks)

                        # Update status to processed
                        doc_status.update(
                            {
                                "status": DocStatus.PROCESSED,
                                "updated_at": datetime.now().isoformat(),
                            }
                        )
                        await self.doc_status.upsert({doc_id: doc_status})

                    except Exception as e:
                        # Mark as failed if any step fails
                        doc_status.update(
                            {
                                "status": DocStatus.FAILED,
                                "error": str(e),
                                "updated_at": datetime.now().isoformat(),
                            }
                        )
                        await self.doc_status.upsert({doc_id: doc_status})
                        raise e

                except Exception as e:
                    import traceback

                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
                    logger.error(error_msg)
                    continue

                finally:
                    # Ensure all indexes are updated after each document
                    await self._insert_done()

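    # Usage sketch (illustrative; assumes the default OpenAI-backed functions are
    # configured via environment credentials). "insert_batch_size" is read from
    # addon_params above, so batching can be tuned per instance:
    #
    #   rag = LightRAG(working_dir="./lightrag_cache_demo",
    #                  addon_params={"insert_batch_size": 4})
    #   rag.insert(["doc one ...", "doc two ..."])          # list of documents
    #   rag.insert("another doc", split_by_character="\n")  # optional pre-split
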
    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def insert_custom_kg(self, custom_kg: dict):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

    async def ainsert_custom_kg(self, custom_kg: dict):
        update_storage = False
        try:
            # Insert chunks into vector storage
            all_chunks_data = {}
            chunk_to_source_map = {}
            for chunk_data in custom_kg.get("chunks", []):
                chunk_content = chunk_data["content"]
                source_id = chunk_data["source_id"]
                chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")

                chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
                all_chunks_data[chunk_id] = chunk_entry
                chunk_to_source_map[source_id] = chunk_id
                update_storage = True

            if self.chunks_vdb is not None and all_chunks_data:
                await self.chunks_vdb.upsert(all_chunks_data)
            if self.text_chunks is not None and all_chunks_data:
                await self.text_chunks.upsert(all_chunks_data)

            # Insert entities into knowledge graph
            all_entities_data = []
            for entity_data in custom_kg.get("entities", []):
                entity_name = f'"{entity_data["entity_name"].upper()}"'
                entity_type = entity_data.get("entity_type", "UNKNOWN")
                description = entity_data.get("description", "No description provided")
                # source_id = entity_data["source_id"]
                source_chunk_id = entity_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Prepare node data
                node_data = {
                    "entity_type": entity_type,
                    "description": description,
                    "source_id": source_id,
                }
                # Insert node data into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_node(
                    entity_name, node_data=node_data
                )
                node_data["entity_name"] = entity_name
                all_entities_data.append(node_data)
                update_storage = True

            # Insert relationships into knowledge graph
            all_relationships_data = []
            for relationship_data in custom_kg.get("relationships", []):
                src_id = f'"{relationship_data["src_id"].upper()}"'
                tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
                description = relationship_data["description"]
                keywords = relationship_data["keywords"]
                weight = relationship_data.get("weight", 1.0)
                # source_id = relationship_data["source_id"]
                source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Check if nodes exist in the knowledge graph
                for need_insert_id in [src_id, tgt_id]:
                    if not (
                        await self.chunk_entity_relation_graph.has_node(need_insert_id)
                    ):
                        await self.chunk_entity_relation_graph.upsert_node(
                            need_insert_id,
                            node_data={
                                "source_id": source_id,
                                "description": "UNKNOWN",
                                "entity_type": "UNKNOWN",
                            },
                        )

                # Insert edge into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_edge(
                    src_id,
                    tgt_id,
                    edge_data={
                        "weight": weight,
                        "description": description,
                        "keywords": keywords,
                        "source_id": source_id,
                    },
                )
                edge_data = {
                    "src_id": src_id,
                    "tgt_id": tgt_id,
                    "description": description,
                    "keywords": keywords,
                }
                all_relationships_data.append(edge_data)
                update_storage = True

            # Insert entities into vector storage if needed
            if self.entities_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                        "content": dp["entity_name"] + dp["description"],
                        "entity_name": dp["entity_name"],
                    }
                    for dp in all_entities_data
                }
                await self.entities_vdb.upsert(data_for_vdb)

            # Insert relationships into vector storage if needed
            if self.relationships_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                        "src_id": dp["src_id"],
                        "tgt_id": dp["tgt_id"],
                        "content": dp["keywords"]
                        + dp["src_id"]
                        + dp["tgt_id"]
                        + dp["description"],
                    }
                    for dp in all_relationships_data
                }
                await self.relationships_vdb.upsert(data_for_vdb)
        finally:
            if update_storage:
                await self._insert_done()

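    # Shape of the custom_kg payload consumed above (illustrative values; the
    # "source_id" on entities/relationships refers to a chunk's source_id):
    #
    #   custom_kg = {
    #       "chunks": [{"content": "Alice works at Acme.", "source_id": "doc-1"}],
    #       "entities": [{"entity_name": "Alice", "entity_type": "PERSON",
    #                     "description": "An engineer", "source_id": "doc-1"}],
    #       "relationships": [{"src_id": "Alice", "tgt_id": "Acme",
    #                          "description": "employment", "keywords": "works at",
    #                          "weight": 1.0, "source_id": "doc-1"}],
    #   }
    #   rag.insert_custom_kg(custom_kg)
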
    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        elif param.mode == "mix":
            response = await mix_kg_vector_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    namespace="llm_response_cache",
                    global_config=asdict(self),
                    embedding_func=None,
                ),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

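    # Query usage sketch (illustrative): the retrieval mode is selected via
    # QueryParam.mode, which aquery() dispatches on above.
    #
    #   print(rag.query("Who is Alice?", param=QueryParam(mode="hybrid")))
    #   print(rag.query("Who is Alice?", param=QueryParam(mode="naive")))
    #   print(rag.query("Who is Alice?", param=QueryParam(mode="mix")))
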
    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        entity_name = f'"{entity_name.upper()}"'

        try:
            await self.entities_vdb.delete_entity(entity_name)
            await self.relationships_vdb.delete_entity_relation(entity_name)
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        tasks = []
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
        """Get summary of document content

        Args:
            content: Original document content
            max_length: Maximum length of summary

        Returns:
            Truncated content with ellipsis if needed
        """
        content = content.strip()
        if len(content) <= max_length:
            return content
        return content[:max_length] + "..."

    async def get_processing_status(self) -> Dict[str, int]:
        """Get current document processing status counts

        Returns:
            Dict with counts for each status
        """
        return await self.doc_status.get_status_counts()

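    # Illustrative progress check (the keys mirror the DocStatus values written
    # during ainsert: pending, processing, processed, failed):
    #
    #   counts = asyncio.run(rag.get_processing_status())
    #   print(counts)
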
    async def adelete_by_doc_id(self, doc_id: str):
        """Delete a document and all its related data

        Args:
            doc_id: Document ID to delete
        """
        try:
            # 1. Get the document status and related data
            doc_status = await self.doc_status.get(doc_id)
            if not doc_status:
                logger.warning(f"Document {doc_id} not found")
                return

            logger.debug(f"Starting deletion for document {doc_id}")

            # 2. Get all related chunks
            chunks = await self.text_chunks.filter(
                lambda x: x.get("full_doc_id") == doc_id
            )
            chunk_ids = list(chunks.keys())
            logger.debug(f"Found {len(chunk_ids)} chunks to delete")

            # 3. Before deleting, check the related entities and relationships for these chunks
            for chunk_id in chunk_ids:
                # Check entities
                entities = [
                    dp
                    for dp in self.entities_vdb.client_storage["data"]
                    if dp.get("source_id") == chunk_id
                ]
                logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")

                # Check relationships
                relations = [
                    dp
                    for dp in self.relationships_vdb.client_storage["data"]
                    if dp.get("source_id") == chunk_id
                ]
                logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")

            # Continue with the original deletion process...

            # 4. Delete chunks from vector database
            if chunk_ids:
                await self.chunks_vdb.delete(chunk_ids)
                await self.text_chunks.delete(chunk_ids)

            # 5. Find and process entities and relationships that have these chunks as source
            # Get all nodes and edges in the graph
            nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
            edges = self.chunk_entity_relation_graph._graph.edges(data=True)

            # Track which entities and relationships need to be deleted or updated
            entities_to_delete = set()
            entities_to_update = {}  # entity_name -> new_source_id
            relationships_to_delete = set()
            relationships_to_update = {}  # (src, tgt) -> new_source_id

            # Process entities
            for node, data in nodes:
                if "source_id" in data:
                    # Split source_id using GRAPH_FIELD_SEP
                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
                    sources.difference_update(chunk_ids)
                    if not sources:
                        entities_to_delete.add(node)
                        logger.debug(
                            f"Entity {node} marked for deletion - no remaining sources"
                        )
                    else:
                        new_source_id = GRAPH_FIELD_SEP.join(sources)
                        entities_to_update[node] = new_source_id
                        logger.debug(
                            f"Entity {node} will be updated with new source_id: {new_source_id}"
                        )

            # Process relationships
            for src, tgt, data in edges:
                if "source_id" in data:
                    # Split source_id using GRAPH_FIELD_SEP
                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
                    sources.difference_update(chunk_ids)
                    if not sources:
                        relationships_to_delete.add((src, tgt))
                        logger.debug(
                            f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
                        )
                    else:
                        new_source_id = GRAPH_FIELD_SEP.join(sources)
                        relationships_to_update[(src, tgt)] = new_source_id
                        logger.debug(
                            f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
                        )

            # Delete entities
            if entities_to_delete:
                for entity in entities_to_delete:
                    await self.entities_vdb.delete_entity(entity)
                    logger.debug(f"Deleted entity {entity} from vector DB")
                self.chunk_entity_relation_graph.remove_nodes(list(entities_to_delete))
                logger.debug(f"Deleted {len(entities_to_delete)} entities from graph")

            # Update entities
            for entity, new_source_id in entities_to_update.items():
                node_data = self.chunk_entity_relation_graph._graph.nodes[entity]
                node_data["source_id"] = new_source_id
                await self.chunk_entity_relation_graph.upsert_node(entity, node_data)
                logger.debug(
                    f"Updated entity {entity} with new source_id: {new_source_id}"
                )

            # Delete relationships
            if relationships_to_delete:
                for src, tgt in relationships_to_delete:
                    rel_id_0 = compute_mdhash_id(src + tgt, prefix="rel-")
                    rel_id_1 = compute_mdhash_id(tgt + src, prefix="rel-")
                    await self.relationships_vdb.delete([rel_id_0, rel_id_1])
                    logger.debug(f"Deleted relationship {src}-{tgt} from vector DB")
                self.chunk_entity_relation_graph.remove_edges(
                    list(relationships_to_delete)
                )
                logger.debug(
                    f"Deleted {len(relationships_to_delete)} relationships from graph"
                )

            # Update relationships
            for (src, tgt), new_source_id in relationships_to_update.items():
                edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt]
                edge_data["source_id"] = new_source_id
                await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data)
                logger.debug(
                    f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
                )

            # 6. Delete original document and status
            await self.full_docs.delete([doc_id])
            await self.doc_status.delete([doc_id])

            # 7. Ensure all indexes are updated
            await self._insert_done()

            logger.info(
                f"Successfully deleted document {doc_id} and related data. "
                f"Deleted {len(entities_to_delete)} entities and {len(relationships_to_delete)} relationships. "
                f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
            )

            # Add verification step
            async def verify_deletion():
                # Verify if the document has been deleted
                if await self.full_docs.get_by_id(doc_id):
                    logger.error(f"Document {doc_id} still exists in full_docs")

                # Verify if chunks have been deleted
                remaining_chunks = await self.text_chunks.filter(
                    lambda x: x.get("full_doc_id") == doc_id
                )
                if remaining_chunks:
                    logger.error(f"Found {len(remaining_chunks)} remaining chunks")

                # Verify entities and relationships
                for chunk_id in chunk_ids:
                    # Check entities
                    entities_with_chunk = [
                        dp
                        for dp in self.entities_vdb.client_storage["data"]
                        if chunk_id
                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                    ]
                    if entities_with_chunk:
                        logger.error(
                            f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}"
                        )

                    # Check relationships
                    relations_with_chunk = [
                        dp
                        for dp in self.relationships_vdb.client_storage["data"]
                        if chunk_id
                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                    ]
                    if relations_with_chunk:
                        logger.error(
                            f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}"
                        )

            await verify_deletion()

        except Exception as e:
            logger.error(f"Error while deleting document {doc_id}: {e}")

    def delete_by_doc_id(self, doc_id: str):
        """Synchronous version of adelete_by_doc_id"""
        return asyncio.run(self.adelete_by_doc_id(doc_id))

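    # Illustrative deletion flow: document IDs are the md5-based hashes that
    # ainsert() assigns, so an ID can be recomputed from the original content.
    #
    #   doc_id = compute_mdhash_id("doc one ...".strip(), prefix="doc-")
    #   rag.delete_by_doc_id(doc_id)
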
    async def get_entity_info(
        self, entity_name: str, include_vector_data: bool = False
    ):
        """Get detailed information of an entity

        Args:
            entity_name: Entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing entity information, including:
                - entity_name: Entity name
                - source_id: Source document ID
                - graph_data: Complete node data from the graph database
                - vector_data: (optional) Data from the vector database
        """
        entity_name = f'"{entity_name.upper()}"'

        # Get information from the graph
        node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
        source_id = node_data.get("source_id") if node_data else None

        result = {
            "entity_name": entity_name,
            "source_id": source_id,
            "graph_data": node_data,
        }

        # Optional: Get vector database information
        if include_vector_data:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            vector_data = self.entities_vdb._client.get([entity_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
        """Synchronous version of getting entity information

        Args:
            entity_name: Entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database
        """
        try:
            import tracemalloc

            tracemalloc.start()
            return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
        finally:
            tracemalloc.stop()

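    # Illustrative inspection call; the name is upper-cased and quoted internally,
    # so pass it without quotes:
    #
    #   info = rag.get_entity_info_sync("Alice", include_vector_data=True)
    #   print(info["graph_data"], info["source_id"])
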
    async def get_relation_info(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ):
        """Get detailed information of a relationship

        Args:
            src_entity: Source entity name (no need for quotes)
            tgt_entity: Target entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing relationship information, including:
                - src_entity: Source entity name
                - tgt_entity: Target entity name
                - source_id: Source document ID
                - graph_data: Complete edge data from the graph database
                - vector_data: (optional) Data from the vector database
        """
        src_entity = f'"{src_entity.upper()}"'
        tgt_entity = f'"{tgt_entity.upper()}"'

        # Get information from the graph
        edge_data = await self.chunk_entity_relation_graph.get_edge(
            src_entity, tgt_entity
        )
        source_id = edge_data.get("source_id") if edge_data else None

        result = {
            "src_entity": src_entity,
            "tgt_entity": tgt_entity,
            "source_id": source_id,
            "graph_data": edge_data,
        }

        # Optional: Get vector database information
        if include_vector_data:
            rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
            vector_data = self.relationships_vdb._client.get([rel_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def get_relation_info_sync(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ):
        """Synchronous version of getting relationship information

        Args:
            src_entity: Source entity name (no need for quotes)
            tgt_entity: Target entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database
        """
        try:
            import tracemalloc

            tracemalloc.start()
            return asyncio.run(
                self.get_relation_info(src_entity, tgt_entity, include_vector_data)
            )
        finally:
            tracemalloc.stop()