# LightRAG/lightrag/lightrag.py
from __future__ import annotations
import asyncio
import configparser
import os
import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast, final
from lightrag.kg import (
STORAGE_ENV_REQUIREMENTS,
STORAGES,
verify_storage_implementation,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
QueryParam,
StorageNameSpace,
StoragesStatus,
)
from .namespace import NameSpace, make_namespace
from .operate import (
chunking_by_token_size,
extract_entities,
extract_keywords_only,
kg_query,
kg_query_with_keywords,
mix_kg_vector_query,
naive_query,
)
from .prompt import GRAPH_FIELD_SEP
from .utils import (
EmbeddingFunc,
always_get_an_event_loop,
compute_mdhash_id,
convert_response_to_json,
encode_string_by_tiktoken,
lazy_external_import,
limit_async_func_call,
logger,
)
from .types import KnowledgeGraph
from dotenv import load_dotenv
# Load environment variables
load_dotenv(override=True)
# TODO: TO REMOVE @Yannick
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class LightRAG:
"""LightRAG: Simple and Fast Retrieval-Augmented Generation."""
# Directory
# ---
working_dir: str = field(
default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
"""Directory where cache and temporary files are stored."""
# Storage
# ---
kv_storage: str = field(default="JsonKVStorage")
"""Storage backend for key-value data."""
vector_storage: str = field(default="NanoVectorDBStorage")
"""Storage backend for vector embeddings."""
graph_storage: str = field(default="NetworkXStorage")
"""Storage backend for knowledge graphs."""
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Logging (Deprecated, use setup_logger in utils.py instead)
# ---
log_level: int = field(default=logger.level)
log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
# Entity extraction
# ---
entity_extract_max_gleaning: int = field(default=1)
"""Maximum number of entity extraction attempts for ambiguous content."""
    entity_summary_to_max_tokens: int = field(
        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
    )
    """Maximum number of tokens used for summarizing extracted entities."""
# Text chunking
# ---
chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = field(
default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
)
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text."""
chunking_func: Callable[
[
str,
str | None,
bool,
int,
int,
str,
],
list[dict[str, Any]],
] = field(default_factory=lambda: chunking_by_token_size)
"""
Custom chunking function for splitting text into chunks before processing.
The function should take the following parameters:
- `content`: The text to be split into chunks.
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
- `split_by_character_only`: If True, the text is split only on the specified character.
    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
    - `chunk_token_size`: The maximum number of tokens per chunk.
- `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
The function should return a list of dictionaries, where each dictionary contains the following keys:
- `tokens`: The number of tokens in the chunk.
- `content`: The text content of the chunk.
Defaults to `chunking_by_token_size` if not specified.
"""
# Node embedding
# ---
node_embedding_algorithm: str = field(default="node2vec")
"""Algorithm used for node embedding in knowledge graphs."""
node2vec_params: dict[str, int] = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
}
)
"""Configuration for the node2vec embedding algorithm:
- dimensions: Number of dimensions for embeddings.
- num_walks: Number of random walks per node.
- walk_length: Number of steps per random walk.
- window_size: Context window size for training.
- iterations: Number of iterations for training.
- random_seed: Seed value for reproducibility.
"""
# Embedding
# ---
embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = field(default=32)
"""Batch size for embedding computations."""
embedding_func_max_async: int = field(default=16)
"""Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
"""Configuration for embedding cache.
- enabled: If True, enables caching to avoid redundant computations.
- similarity_threshold: Minimum similarity score to use cached embeddings.
- use_llm_check: If True, validates cached embeddings using an LLM.
"""
# LLM Configuration
# ---
llm_model_func: Callable[..., object] | None = field(default=None)
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = field(default="gpt-4o-mini")
"""Name of the LLM model used for generating responses."""
llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
"""Maximum number of tokens allowed per LLM response."""
llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
"""Maximum number of concurrent LLM calls."""
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function."""
# Storage
# ---
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional parameters for vector database storage."""
namespace_prefix: str = field(default="")
"""Prefix for namespacing stored data across different environments."""
enable_llm_cache: bool = field(default=True)
"""Enables caching for LLM responses to avoid redundant computations."""
enable_llm_cache_for_entity_extract: bool = field(default=True)
"""If True, enables caching for entity extraction steps to reduce LLM costs."""
# Extensions
# ---
max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
"""Maximum number of parallel insert operations."""
addon_params: dict[str, Any] = field(default_factory=dict)
# Storages Management
# ---
auto_manage_storages_states: bool = field(default=True)
"""If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""
    # Response conversion
    # ---
convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
default_factory=lambda: convert_response_to_json
)
"""
Custom function for converting LLM responses to JSON format.
The default function is :func:`.utils.convert_response_to_json`.
"""
cosine_better_than_threshold: float = field(
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
)
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self):
from lightrag.kg.shared_storage import (
initialize_share_data,
)
# Handle deprecated parameters
kwargs = self.__dict__
if "log_level" in kwargs:
warnings.warn(
"WARNING: log_level parameter is deprecated, use setup_logger in utils.py instead",
UserWarning,
stacklevel=2,
)
# Remove the attribute to prevent its use
delattr(self, "log_level")
if "log_file_path" in kwargs:
warnings.warn(
"WARNING: log_file_path parameter is deprecated, use setup_logger in utils.py instead",
UserWarning,
stacklevel=2,
)
delattr(self, "log_file_path")
initialize_share_data()
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# Verify storage implementation compatibility and environment variables
storage_configs = [
("KV_STORAGE", self.kv_storage),
("VECTOR_STORAGE", self.vector_storage),
("GRAPH_STORAGE", self.graph_storage),
("DOC_STATUS_STORAGE", self.doc_status_storage),
]
for storage_type, storage_name in storage_configs:
# Verify storage implementation compatibility
verify_storage_implementation(storage_type, storage_name)
# Check environment variables
# self.check_storage_env_vars(storage_name)
# Ensure vector_db_storage_cls_kwargs has required fields
self.vector_db_storage_cls_kwargs = {
"cosine_better_than_threshold": self.cosine_better_than_threshold,
**self.vector_db_storage_cls_kwargs,
}
# Show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
) # type: ignore
self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
) # type: ignore
self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
) # type: ignore
self.key_string_value_json_storage_cls = partial( # type: ignore
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial( # type: ignore
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial( # type: ignore
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name", "source_id", "content"},
)
self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id", "source_id", "content"},
)
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
embedding_func=self.embedding_func,
)
# Initialize document status storage
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
)
self._storages_status = StoragesStatus.CREATED
if self.auto_manage_storages_states:
self._run_async_safely(self.initialize_storages, "Storage Initialization")
def __del__(self):
if self.auto_manage_storages_states:
self._run_async_safely(self.finalize_storages, "Storage Finalization")
def _run_async_safely(self, async_func, action_name=""):
"""Safely execute an async function, avoiding event loop conflicts."""
try:
loop = always_get_an_event_loop()
if loop.is_running():
task = loop.create_task(async_func())
task.add_done_callback(
lambda t: logger.info(f"{action_name} completed!")
)
else:
loop.run_until_complete(async_func())
except RuntimeError:
logger.warning(
f"No running event loop, creating a new loop for {action_name}."
)
loop = asyncio.new_event_loop()
loop.run_until_complete(async_func())
loop.close()
async def initialize_storages(self):
"""Asynchronously initialize the storages"""
if self._storages_status == StoragesStatus.CREATED:
tasks = []
for storage in (
self.full_docs,
self.text_chunks,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
self.llm_response_cache,
self.doc_status,
):
if storage:
tasks.append(storage.initialize())
await asyncio.gather(*tasks)
self._storages_status = StoragesStatus.INITIALIZED
logger.debug("Initialized Storages")
async def finalize_storages(self):
"""Asynchronously finalize the storages"""
if self._storages_status == StoragesStatus.INITIALIZED:
tasks = []
for storage in (
self.full_docs,
self.text_chunks,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
self.llm_response_cache,
self.doc_status,
):
if storage:
tasks.append(storage.finalize())
await asyncio.gather(*tasks)
self._storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages")
async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels()
return text
async def get_knowledge_graph(
self, node_label: str, max_depth: int
) -> KnowledgeGraph:
return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label=node_label, max_depth=max_depth
)
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
@staticmethod
def clean_text(text: str) -> str:
"""Clean text by removing null bytes (0x00) and whitespace"""
return text.strip().replace("\x00", "")
def insert(
self,
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
ids: str | list[str] | None = None,
) -> None:
"""Sync Insert documents with checkpoint support
Args:
input: Single document string or list of document strings
            split_by_character: if split_by_character is not None, split the string by character; if a resulting
                chunk is longer than chunk_token_size, it is split again by token size.
            split_by_character_only: if split_by_character_only is True, split the string by character only; when
                split_by_character is None, this parameter is ignored.
            ids: single document ID string or list of unique document IDs; if not provided, MD5 hash IDs will be generated
"""
loop = always_get_an_event_loop()
loop.run_until_complete(
self.ainsert(input, split_by_character, split_by_character_only, ids)
)
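    # For instance (a sketch; the document strings and IDs are placeholders):
    #
    #     rag.insert(["first document ...", "second document ..."], ids=["doc-1", "doc-2"])
    #     rag.insert("long report ...", split_by_character="\n\n")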
async def ainsert(
self,
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
ids: str | list[str] | None = None,
) -> None:
"""Async Insert documents with checkpoint support
Args:
input: Single document string or list of document strings
            split_by_character: if split_by_character is not None, split the string by character; if a resulting
                chunk is longer than chunk_token_size, it is split again by token size.
            split_by_character_only: if split_by_character_only is True, split the string by character only; when
                split_by_character is None, this parameter is ignored.
            ids: single document ID string or list of unique document IDs; if not provided, MD5 hash IDs will be generated
"""
await self.apipeline_enqueue_documents(input, ids)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
def insert_custom_chunks(
self,
full_text: str,
text_chunks: list[str],
doc_id: str | list[str] | None = None,
) -> None:
loop = always_get_an_event_loop()
loop.run_until_complete(
self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
)
async def ainsert_custom_chunks(
self, full_text: str, text_chunks: list[str], doc_id: str | None = None
) -> None:
update_storage = False
try:
# Clean input texts
full_text = self.clean_text(full_text)
text_chunks = [self.clean_text(chunk) for chunk in text_chunks]
# Process cleaned texts
if doc_id is None:
doc_key = compute_mdhash_id(full_text, prefix="doc-")
else:
doc_key = doc_id
new_docs = {doc_key: {"content": full_text}}
_add_doc_keys = await self.full_docs.filter_keys({doc_key})
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning("This document is already in the storage.")
return
update_storage = True
logger.info(f"Inserting {len(new_docs)} docs")
inserting_chunks: dict[str, Any] = {}
for chunk_text in text_chunks:
chunk_key = compute_mdhash_id(chunk_text, prefix="chunk-")
inserting_chunks[chunk_key] = {
"content": chunk_text,
"full_doc_id": doc_key,
}
doc_ids = set(inserting_chunks.keys())
add_chunk_keys = await self.text_chunks.filter_keys(doc_ids)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in add_chunk_keys
}
if not len(inserting_chunks):
logger.warning("All chunks are already in the storage.")
return
tasks = [
self.chunks_vdb.upsert(inserting_chunks),
self._process_entity_relation_graph(inserting_chunks),
self.full_docs.upsert(new_docs),
self.text_chunks.upsert(inserting_chunks),
]
await asyncio.gather(*tasks)
finally:
if update_storage:
await self._insert_done()
async def apipeline_enqueue_documents(
self, input: str | list[str], ids: list[str] | None = None
) -> None:
"""
Pipeline for Processing Documents
1. Validate ids if provided or generate MD5 hash IDs
2. Remove duplicate contents
3. Generate document initial status
4. Filter out already processed documents
5. Enqueue document in status
"""
if isinstance(input, str):
input = [input]
if isinstance(ids, str):
ids = [ids]
# 1. Validate ids if provided or generate MD5 hash IDs
if ids is not None:
# Check if the number of IDs matches the number of documents
if len(ids) != len(input):
raise ValueError("Number of IDs must match the number of documents")
# Check if IDs are unique
if len(ids) != len(set(ids)):
raise ValueError("IDs must be unique")
# Generate contents dict of IDs provided by user and documents
contents = {id_: doc for id_, doc in zip(ids, input)}
else:
# Clean input text and remove duplicates
input = list(set(self.clean_text(doc) for doc in input))
# Generate contents dict of MD5 hash IDs and documents
contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
# 2. Remove duplicate contents
unique_contents = {
id_: content
for content, id_ in {
content: id_ for id_, content in contents.items()
}.items()
}
# 3. Generate document initial status
new_docs: dict[str, Any] = {
id_: {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
}
for id_, content in unique_contents.items()
}
# 4. Filter out already processed documents
# Get docs ids
all_new_doc_ids = set(new_docs.keys())
# Exclude IDs of documents that are already in progress
unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
# Filter new_docs to only include documents with unique IDs
new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
if not new_docs:
logger.info("No new unique documents were found.")
return
# 5. Store status document
await self.doc_status.upsert(new_docs)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_enqueue_documents(
self,
split_by_character: str | None = None,
split_by_character_only: bool = False,
) -> None:
"""
Process pending documents by splitting them into chunks, processing
each chunk for entity and relation extraction, and updating the
document status.
1. Get all pending, failed, and abnormally terminated processing documents.
2. Split document content into chunks
3. Process each chunk for entity and relation extraction
4. Update the document status
2025-02-09 14:36:49 +01:00
"""
2025-03-01 16:23:34 +08:00
from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
)
# Get pipeline status shared data and lock
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status_lock = get_pipeline_status_lock()
# Check if another process is already processing the queue
async with pipeline_status_lock:
# Ensure only one worker is processing documents
if not pipeline_status.get("busy", False):
                # First check whether there are any documents that need processing
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
)
to_process_docs: dict[str, DocProcessingStatus] = {}
to_process_docs.update(processing_docs)
to_process_docs.update(failed_docs)
to_process_docs.update(pending_docs)
                # If there are no documents to process, return and leave pipeline_status unchanged
if not to_process_docs:
logger.info("No documents to process")
return
                # There are documents to process; update pipeline_status
pipeline_status.update(
{
"busy": True,
"job_name": "indexing files",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "",
}
)
# Cleaning history_messages without breaking it as a shared list object
del pipeline_status["history_messages"][:]
else:
# Another process is busy, just set request flag and return
pipeline_status["request_pending"] = True
logger.info(
"Another process is already processing the document queue. Request queued."
)
return
try:
# Process documents until no more documents or requests
while True:
if not to_process_docs:
log_message = "All documents have been processed or are duplicates"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
break
# 2. split docs into chunks, insert chunks, update doc status
docs_batches = [
list(to_process_docs.items())[i : i + self.max_parallel_insert]
for i in range(0, len(to_process_docs), self.max_parallel_insert)
]
log_message = f"Number of batches to process: {len(docs_batches)}."
logger.info(log_message)
# Update pipeline status with current batch information
pipeline_status["docs"] += len(to_process_docs)
pipeline_status["batchs"] += len(docs_batches)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
batches: list[Any] = []
# 3. iterate over batches
for batch_idx, docs_batch in enumerate(docs_batches):
# Update current batch in pipeline status (directly, as it's atomic)
pipeline_status["cur_batch"] += 1
async def batch(
batch_idx: int,
docs_batch: list[tuple[str, DocProcessingStatus]],
size_batch: int,
) -> None:
log_message = (
f"Start processing batch {batch_idx + 1} of {size_batch}."
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# 4. iterate over batch
for doc_id_processing_status in docs_batch:
doc_id, status_doc = doc_id_processing_status
# Generate chunks from document
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in self.chunking_func(
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
}
# Process document (text chunks and full docs) in parallel
tasks = [
self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
}
}
),
self.chunks_vdb.upsert(chunks),
self._process_entity_relation_graph(chunks),
self.full_docs.upsert(
{doc_id: {"content": status_doc.content}}
),
self.text_chunks.upsert(chunks),
]
try:
await asyncio.gather(*tasks)
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
except Exception as e:
logger.error(
f"Failed to process document {doc_id}: {str(e)}"
)
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
continue
log_message = (
f"Completed batch {batch_idx + 1} of {len(docs_batches)}."
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
await asyncio.gather(*batches)
await self._insert_done()
# Check if there's a pending request to process more documents (with lock)
has_pending_request = False
async with pipeline_status_lock:
has_pending_request = pipeline_status.get("request_pending", False)
if has_pending_request:
# Clear the request flag before checking for more documents
pipeline_status["request_pending"] = False
if not has_pending_request:
break
log_message = "Processing additional documents due to pending request"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
                # Get newly pending documents
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
)
to_process_docs = {}
to_process_docs.update(processing_docs)
to_process_docs.update(failed_docs)
to_process_docs.update(pending_docs)
finally:
log_message = "Document processing pipeline completed"
logger.info(log_message)
# Always reset busy status when done or if an exception occurs (with lock)
async with pipeline_status_lock:
pipeline_status["busy"] = False
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
try:
await extract_entities(
chunk,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
except Exception as e:
logger.error("Failed to extract entities and relationships")
raise e
async def _insert_done(self) -> None:
tasks = [
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]
if storage_inst is not None
]
await asyncio.gather(*tasks)
log_message = "All Insert done"
logger.info(log_message)
        # Get pipeline_status and update latest_message and history_messages
from lightrag.kg.shared_storage import get_namespace_data
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
def insert_custom_kg(
        self, custom_kg: dict[str, Any], full_doc_id: str | None = None
) -> None:
loop = always_get_an_event_loop()
loop.run_until_complete(self.ainsert_custom_kg(custom_kg, full_doc_id))
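    # Expected `custom_kg` layout, inferred from the keys consumed in
    # ainsert_custom_kg below (all values are placeholders):
    #
    #     custom_kg = {
    #         "chunks": [
    #             {"content": "chunk text ...", "source_id": "doc-1"},
    #         ],
    #         "entities": [
    #             {"entity_name": "ACME", "entity_type": "ORGANIZATION",
    #              "description": "A company.", "source_id": "doc-1"},
    #         ],
    #         "relationships": [
    #             {"src_id": "ACME", "tgt_id": "Widget", "description": "ACME makes Widget.",
    #              "keywords": "produces", "weight": 1.0, "source_id": "doc-1"},
    #         ],
    #     }
    #     rag.insert_custom_kg(custom_kg)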
async def ainsert_custom_kg(
        self, custom_kg: dict[str, Any], full_doc_id: str | None = None
) -> None:
update_storage = False
try:
# Insert chunks into vector storage
all_chunks_data: dict[str, dict[str, str]] = {}
chunk_to_source_map: dict[str, str] = {}
for chunk_data in custom_kg.get("chunks", []):
chunk_content = self.clean_text(chunk_data["content"])
source_id = chunk_data["source_id"]
tokens = len(
encode_string_by_tiktoken(
chunk_content, model_name=self.tiktoken_model_name
)
)
chunk_order_index = (
0
if "chunk_order_index" not in chunk_data.keys()
else chunk_data["chunk_order_index"]
)
chunk_id = compute_mdhash_id(chunk_content, prefix="chunk-")
chunk_entry = {
"content": chunk_content,
"source_id": source_id,
"tokens": tokens,
"chunk_order_index": chunk_order_index,
"full_doc_id": full_doc_id
if full_doc_id is not None
else source_id,
"status": DocStatus.PROCESSED,
}
all_chunks_data[chunk_id] = chunk_entry
chunk_to_source_map[source_id] = chunk_id
update_storage = True
if all_chunks_data:
await asyncio.gather(
self.chunks_vdb.upsert(all_chunks_data),
self.text_chunks.upsert(all_chunks_data),
)
# Insert entities into knowledge graph
all_entities_data: list[dict[str, str]] = []
for entity_data in custom_kg.get("entities", []):
entity_name = entity_data["entity_name"]
entity_type = entity_data.get("entity_type", "UNKNOWN")
description = entity_data.get("description", "No description provided")
source_chunk_id = entity_data.get("source_id", "UNKNOWN")
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
# Log if source_id is UNKNOWN
if source_id == "UNKNOWN":
logger.warning(
f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
)
# Prepare node data
node_data: dict[str, str] = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
}
# Insert node data into the knowledge graph
await self.chunk_entity_relation_graph.upsert_node(
entity_name, node_data=node_data
)
node_data["entity_name"] = entity_name
all_entities_data.append(node_data)
update_storage = True
# Insert relationships into knowledge graph
all_relationships_data: list[dict[str, str]] = []
for relationship_data in custom_kg.get("relationships", []):
src_id = relationship_data["src_id"]
tgt_id = relationship_data["tgt_id"]
description = relationship_data["description"]
keywords = relationship_data["keywords"]
weight = relationship_data.get("weight", 1.0)
source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
# Log if source_id is UNKNOWN
if source_id == "UNKNOWN":
logger.warning(
f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
)
# Check if nodes exist in the knowledge graph
for need_insert_id in [src_id, tgt_id]:
if not (
await self.chunk_entity_relation_graph.has_node(need_insert_id)
):
await self.chunk_entity_relation_graph.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
"description": "UNKNOWN",
"entity_type": "UNKNOWN",
},
)
# Insert edge into the knowledge graph
await self.chunk_entity_relation_graph.upsert_edge(
src_id,
tgt_id,
edge_data={
"weight": weight,
"description": description,
"keywords": keywords,
"source_id": source_id,
},
)
edge_data: dict[str, str] = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
"keywords": keywords,
"source_id": source_id,
"weight": weight,
}
all_relationships_data.append(edge_data)
update_storage = True
# Insert entities into vector storage with consistent format
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + "\n" + dp["description"],
"entity_name": dp["entity_name"],
"source_id": dp["source_id"],
"description": dp["description"],
"entity_type": dp["entity_type"],
}
for dp in all_entities_data
}
await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage with consistent format
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"source_id": dp["source_id"],
"content": f"{dp['keywords']}\t{dp['src_id']}\n{dp['tgt_id']}\n{dp['description']}",
"keywords": dp["keywords"],
"description": dp["description"],
"weight": dp["weight"],
}
for dp in all_relationships_data
}
await self.relationships_vdb.upsert(data_for_vdb)
except Exception as e:
logger.error(f"Error in ainsert_custom_kg: {e}")
raise
finally:
if update_storage:
await self._insert_done()
def query(
self,
query: str,
param: QueryParam = QueryParam(),
system_prompt: str | None = None,
) -> str | Iterator[str]:
"""
Perform a sync query.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
            system_prompt (Optional[str]): Custom system prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
str: The result of the query execution.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param, system_prompt)) # type: ignore
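    # For instance (a sketch; the mode must be one of the values handled in aquery):
    #
    #     print(rag.query("Who founded ACME?", param=QueryParam(mode="mix")))
    #     print(rag.query("Summarize the corpus", param=QueryParam(mode="naive")))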
async def aquery(
self,
query: str,
param: QueryParam = QueryParam(),
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
"""
        Perform an async query.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
            system_prompt (Optional[str]): Custom system prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
str: The result of the query execution.
"""
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
2024-10-10 15:02:30 +08:00
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
system_prompt=system_prompt,
2024-10-10 15:02:30 +08:00
)
elif param.mode == "naive":
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
system_prompt=system_prompt,
2024-10-10 15:02:30 +08:00
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
system_prompt=system_prompt,
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
def query_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Extract keywords from the 'query' using new function in operate.py.
2. Then run the standard aquery() flow with the final prompt (formatted_question).
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.aquery_with_separate_keyword_extraction(query, prompt, param)
)
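    # For instance (a sketch; the prompt text is a placeholder):
    #
    #     rag.query_with_separate_keyword_extraction(
    #         query="What does ACME produce?",
    #         prompt="Answer strictly from the retrieved context.",
    #         param=QueryParam(mode="hybrid"),
    #     )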
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
) -> str | AsyncIterator[str]:
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
"""
# ---------------------
# STEP 1: Keyword Extraction
# ---------------------
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
param=param,
global_config=asdict(self),
hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
param.hl_keywords = hl_keywords
param.ll_keywords = ll_keywords
# ---------------------
# STEP 2: Final Query Logic
# ---------------------
# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query_with_keywords(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
response = await naive_query(
formatted_question,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
async def _query_done(self):
await self.llm_response_cache.index_done_callback()
def delete_by_entity(self, entity_name: str) -> None:
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_entity(entity_name))
async def adelete_by_entity(self, entity_name: str) -> None:
try:
await self.entities_vdb.delete_entity(entity_name)
await self.relationships_vdb.delete_entity_relation(entity_name)
await self.chunk_entity_relation_graph.delete_node(entity_name)
logger.info(
f"Entity '{entity_name}' and its relationships have been deleted."
)
await self._delete_by_entity_done()
except Exception as e:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self) -> None:
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]
]
)
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
Args:
content: Original document content
max_length: Maximum length of summary
Returns:
Truncated content with ellipsis if needed
"""
content = content.strip()
if len(content) <= max_length:
return content
return content[:max_length] + "..."
async def get_processing_status(self) -> dict[str, int]:
"""Get current document processing status counts
Returns:
Dict with counts for each status
"""
return await self.doc_status.get_status_counts()
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get documents by status
Returns:
            Dict with document IDs as keys and document statuses as values
"""
return await self.doc_status.get_docs_by_status(status)
async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data
Args:
doc_id: Document ID to delete
"""
try:
# 1. Get the document status and related data
doc_status = await self.doc_status.get_by_id(doc_id)
if not doc_status:
logger.warning(f"Document {doc_id} not found")
return
logger.debug(f"Starting deletion for document {doc_id}")
doc_to_chunk_id = doc_id.replace("doc", "chunk")
# 2. Get all related chunks
chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
if not chunks:
return
chunk_ids = {chunks["full_doc_id"].replace("doc", "chunk")}
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
# 3. Before deleting, check the related entities and relationships for these chunks
for chunk_id in chunk_ids:
# Check entities
entities_storage = await self.entities_vdb.client_storage
entities = [
dp
for dp in entities_storage["data"]
if chunk_id in dp.get("source_id")
]
logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")
# Check relationships
relationships_storage = await self.relationships_vdb.client_storage
relations = [
dp
for dp in relationships_storage["data"]
if chunk_id in dp.get("source_id")
]
logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")
# Continue with the original deletion process...
# 4. Delete chunks from vector database
if chunk_ids:
await self.chunks_vdb.delete(chunk_ids)
await self.text_chunks.delete(chunk_ids)
# 5. Find and process entities and relationships that have these chunks as source
# Get all nodes in the graph
nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
edges = self.chunk_entity_relation_graph._graph.edges(data=True)
# Track which entities and relationships need to be deleted or updated
entities_to_delete = set()
entities_to_update = {} # entity_name -> new_source_id
relationships_to_delete = set()
relationships_to_update = {} # (src, tgt) -> new_source_id
# Process entities
for node, data in nodes:
if "source_id" in data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
entities_to_delete.add(node)
logger.debug(
f"Entity {node} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
entities_to_update[node] = new_source_id
logger.debug(
f"Entity {node} will be updated with new source_id: {new_source_id}"
)
# Process relationships
for src, tgt, data in edges:
if "source_id" in data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
relationships_to_delete.add((src, tgt))
logger.debug(
f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
relationships_to_update[(src, tgt)] = new_source_id
logger.debug(
f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
)
# Delete entities
if entities_to_delete:
for entity in entities_to_delete:
await self.entities_vdb.delete_entity(entity)
logger.debug(f"Deleted entity {entity} from vector DB")
await self.chunk_entity_relation_graph.remove_nodes(
list(entities_to_delete)
)
logger.debug(f"Deleted {len(entities_to_delete)} entities from graph")
# Update entities
for entity, new_source_id in entities_to_update.items():
node_data = self.chunk_entity_relation_graph._graph.nodes[entity]
node_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_node(entity, node_data)
logger.debug(
f"Updated entity {entity} with new source_id: {new_source_id}"
)
# Delete relationships
if relationships_to_delete:
for src, tgt in relationships_to_delete:
rel_id_0 = compute_mdhash_id(src + tgt, prefix="rel-")
rel_id_1 = compute_mdhash_id(tgt + src, prefix="rel-")
await self.relationships_vdb.delete([rel_id_0, rel_id_1])
logger.debug(f"Deleted relationship {src}-{tgt} from vector DB")
await self.chunk_entity_relation_graph.remove_edges(
list(relationships_to_delete)
)
logger.debug(
f"Deleted {len(relationships_to_delete)} relationships from graph"
)
# Update relationships
for (src, tgt), new_source_id in relationships_to_update.items():
edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt]
edge_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data)
logger.debug(
f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
)

            # 6. Delete original document and status
            await self.full_docs.delete([doc_id])
            await self.doc_status.delete([doc_id])

            # 7. Ensure all indexes are updated
            await self._insert_done()

            logger.info(
                f"Successfully deleted document {doc_id} and related data. "
                f"Deleted {len(entities_to_delete)} entities and {len(relationships_to_delete)} relationships. "
                f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
            )

            async def process_data(data_type, vdb, chunk_id):
                # Check data (entities or relationships)
                storage = await vdb.client_storage
                data_with_chunk = [
                    dp
                    for dp in storage["data"]
                    if chunk_id in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                ]

                data_for_vdb = {}
                if data_with_chunk:
                    logger.warning(
                        f"Found {len(data_with_chunk)} {data_type} still referencing chunk {chunk_id}"
                    )

                    for item in data_with_chunk:
                        old_sources = item["source_id"].split(GRAPH_FIELD_SEP)
                        new_sources = [src for src in old_sources if src != chunk_id]
                        if not new_sources:
                            logger.info(
                                f"{data_type} {item.get('entity_name', 'N/A')} is deleted because it has no remaining source_id"
                            )
                            await vdb.delete_entity(item)
                        else:
                            item["source_id"] = GRAPH_FIELD_SEP.join(new_sources)
                            item_id = item["__id__"]
                            data_for_vdb[item_id] = item.copy()
                            if data_type == "entities":
                                data_for_vdb[item_id]["content"] = data_for_vdb[
                                    item_id
                                ].get("content") or (
                                    item.get("entity_name", "")
                                    + (item.get("description") or "")
                                )
                            else:  # relationships
                                data_for_vdb[item_id]["content"] = data_for_vdb[
                                    item_id
                                ].get("content") or (
                                    (item.get("keywords") or "")
                                    + (item.get("src_id") or "")
                                    + (item.get("tgt_id") or "")
                                    + (item.get("description") or "")
                                )

                    if data_for_vdb:
                        await vdb.upsert(data_for_vdb)
                        logger.info(f"Successfully updated {data_type} in vector DB")

            # Add verification step
            async def verify_deletion():
                # Verify if the document has been deleted
                if await self.full_docs.get_by_id(doc_id):
                    logger.warning(f"Document {doc_id} still exists in full_docs")

                # Verify if chunks have been deleted
                remaining_chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
                if remaining_chunks:
                    logger.warning(f"Found {len(remaining_chunks)} remaining chunks")

                # Verify entities and relationships
                for chunk_id in chunk_ids:
                    await process_data("entities", self.entities_vdb, chunk_id)
                    await process_data(
                        "relationships", self.relationships_vdb, chunk_id
                    )

            await verify_deletion()

        except Exception as e:
            logger.error(f"Error while deleting document {doc_id}: {e}")
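
    # Illustrative usage sketch (not part of the library): the deletion flow above is the
    # tail of the delete-by-document coroutine. Assuming an initialized instance named
    # `rag` and that the method is exposed as `adelete_by_doc_id` (as in recent LightRAG
    # releases; the name is not visible in this excerpt, so treat it as an assumption),
    # a caller would remove a document and its derived graph/vector data like this:
    #
    #     await rag.adelete_by_doc_id("doc-abc123")  # "doc-abc123" is a placeholder id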

    async def get_entity_info(
        self, entity_name: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of an entity

        Args:
            entity_name: Entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing entity information, including:
                - entity_name: Entity name
                - source_id: Source document ID
                - graph_data: Complete node data from the graph database
                - vector_data: (optional) Data from the vector database
        """

        # Get information from the graph
        node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
        source_id = node_data.get("source_id") if node_data else None

        result: dict[str, str | None | dict[str, str]] = {
            "entity_name": entity_name,
            "source_id": source_id,
            "graph_data": node_data,
        }

        # Optional: Get vector database information
        if include_vector_data:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            # Note: reads through the storage's private client, which assumes a
            # NanoVectorDB-style backend exposing a synchronous get()
            vector_data = self.entities_vdb._client.get([entity_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result
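
    # Illustrative usage sketch (not part of the library): assuming an initialized
    # LightRAG instance named `rag` and an entity name that already exists in the
    # graph (both are assumptions for the example), entity details can be fetched with:
    #
    #     info = await rag.get_entity_info("Alan Turing", include_vector_data=True)
    #     print(info["source_id"], info["graph_data"])
    #     print(info.get("vector_data"))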

    async def get_relation_info(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of a relationship

        Args:
            src_entity: Source entity name (no need for quotes)
            tgt_entity: Target entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing relationship information, including:
                - src_entity: Source entity name
                - tgt_entity: Target entity name
                - source_id: Source document ID
                - graph_data: Complete edge data from the graph database
                - vector_data: (optional) Data from the vector database
        """

        # Get information from the graph
        edge_data = await self.chunk_entity_relation_graph.get_edge(
            src_entity, tgt_entity
        )
        source_id = edge_data.get("source_id") if edge_data else None

        result: dict[str, str | None | dict[str, str]] = {
            "src_entity": src_entity,
            "tgt_entity": tgt_entity,
            "source_id": source_id,
            "graph_data": edge_data,
        }

        # Optional: Get vector database information
        if include_vector_data:
            rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
            # Note: reads through the storage's private client, which assumes a
            # NanoVectorDB-style backend exposing a synchronous get()
            vector_data = self.relationships_vdb._client.get([rel_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result
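
    # Illustrative usage sketch (not part of the library): assuming `rag` is an
    # initialized LightRAG instance and both entities already exist in the graph
    # (entity names below are placeholders), relation details can be fetched with:
    #
    #     info = await rag.get_relation_info(
    #         "Alan Turing", "Enigma", include_vector_data=True
    #     )
    #     print(info["graph_data"], info.get("vector_data"))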

    def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
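
    # Illustrative usage sketch (not part of the library): using the environment
    # requirements registered in STORAGE_ENV_REQUIREMENTS, a caller could validate a
    # backend before switching to it. The storage name below is only an example;
    # valid names come from the STORAGES registry.
    #
    #     rag.check_storage_env_vars("PGKVStorage")  # raises ValueError if vars are missing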

    async def aclear_cache(self, modes: list[str] | None = None) -> None:
        """Clear cache data from the LLM response cache storage.

        Args:
            modes (list[str] | None): Modes of cache to clear. Options: ["default", "naive", "local", "global", "hybrid", "mix"].
                "default" represents extraction cache.
                If None, clears all cache.

        Example:
            # Clear all cache
            await rag.aclear_cache()

            # Clear local mode cache
            await rag.aclear_cache(modes=["local"])

            # Clear extraction cache
            await rag.aclear_cache(modes=["default"])
        """
        if not self.llm_response_cache:
            logger.warning("No cache storage configured")
            return

        valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]

        # Validate input
        if modes and not all(mode in valid_modes for mode in modes):
            raise ValueError(f"Invalid mode. Valid modes are: {valid_modes}")

        try:
            # Reset the cache storage for specified mode
            if modes:
                await self.llm_response_cache.delete(modes)
                logger.info(f"Cleared cache for modes: {modes}")
            else:
                # Clear all modes
                await self.llm_response_cache.delete(valid_modes)
                logger.info("Cleared all cache")

            await self.llm_response_cache.index_done_callback()

        except Exception as e:
            logger.error(f"Error while clearing cache: {e}")

    def clear_cache(self, modes: list[str] | None = None) -> None:
        """Synchronous version of aclear_cache."""
        return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
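
    # Illustrative usage sketch (not part of the library): the synchronous wrapper just
    # drives the async version on an event loop obtained via always_get_an_event_loop(),
    # so outside of async code the cache can be cleared with, e.g.:
    #
    #     rag.clear_cache(modes=["naive"])   # clear only the naive-query cache
    #     rag.clear_cache()                  # clear every cache mode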

    async def aedit_entity(
self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
) -> dict[str, Any]:
"""Asynchronously edit entity information.
Updates entity information in the knowledge graph and re-embeds the entity in the vector database.
Args:
entity_name: Name of the entity to edit
updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"}
allow_rename: Whether to allow entity renaming, defaults to True
Returns:
Dictionary containing updated entity information
"""
try:
# 1. Get current entity information
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
if not node_data:
raise ValueError(f"Entity '{entity_name}' does not exist")
# Check if entity is being renamed
new_entity_name = updated_data.get("entity_name", entity_name)
is_renaming = new_entity_name != entity_name
# If renaming, check if new name already exists
if is_renaming:
if not allow_rename:
raise ValueError(
"Entity renaming is not allowed. Set allow_rename=True to enable this feature"
)
existing_node = await self.chunk_entity_relation_graph.get_node(
new_entity_name
)
if existing_node:
raise ValueError(
f"Entity name '{new_entity_name}' already exists, cannot rename"
)
# 2. Update entity information in the graph
new_node_data = {**node_data, **updated_data}
if "entity_name" in new_node_data:
del new_node_data[
"entity_name"
] # Node data should not contain entity_name field
# If renaming entity
if is_renaming:
logger.info(f"Renaming entity '{entity_name}' to '{new_entity_name}'")
# Create new entity
await self.chunk_entity_relation_graph.upsert_node(
new_entity_name, new_node_data
)
# Get all edges related to the original entity
edges = await self.chunk_entity_relation_graph.get_node_edges(
entity_name
)
if edges:
# Recreate edges for the new entity
for source, target in edges:
edge_data = await self.chunk_entity_relation_graph.get_edge(
source, target
)
if edge_data:
if source == entity_name:
await self.chunk_entity_relation_graph.upsert_edge(
new_entity_name, target, edge_data
)
else: # target == entity_name
await self.chunk_entity_relation_graph.upsert_edge(
source, new_entity_name, edge_data
)
# Delete old entity
await self.chunk_entity_relation_graph.delete_node(entity_name)
# Delete old entity record from vector database
old_entity_id = compute_mdhash_id(entity_name, prefix="ent-")
await self.entities_vdb.delete([old_entity_id])
# Update working entity name to new name
entity_name = new_entity_name
else:
# If not renaming, directly update node data
await self.chunk_entity_relation_graph.upsert_node(
entity_name, new_node_data
)
# 3. Recalculate entity's vector representation and update vector database
description = new_node_data.get("description", "")
source_id = new_node_data.get("source_id", "")
entity_type = new_node_data.get("entity_type", "")
content = entity_name + "\n" + description
# Calculate entity ID
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
# Prepare data for vector database update
entity_data = {
entity_id: {
"content": content,
"entity_name": entity_name,
"source_id": source_id,
"description": description,
"entity_type": entity_type,
}
}
# Update vector database
await self.entities_vdb.upsert(entity_data)
# 4. Save changes
await self._edit_entity_done()
logger.info(f"Entity '{entity_name}' successfully updated")
return await self.get_entity_info(entity_name, include_vector_data=True)
except Exception as e:
logger.error(f"Error while editing entity '{entity_name}': {e}")
raise

    def edit_entity(
self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
) -> dict[str, Any]:
"""Synchronously edit entity information.
Updates entity information in the knowledge graph and re-embeds the entity in the vector database.
Args:
entity_name: Name of the entity to edit
updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"}
allow_rename: Whether to allow entity renaming, defaults to True
Returns:
Dictionary containing updated entity information
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.aedit_entity(entity_name, updated_data, allow_rename)
)
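
    # Illustrative usage sketch (not part of the library): assuming `rag` is an
    # initialized LightRAG instance and the entity exists, an update plus rename could
    # look like this (entity names and attribute values are placeholders):
    #
    #     updated = await rag.aedit_entity(
    #         "Turing",
    #         {
    #             "entity_name": "Alan Turing",  # rename (allow_rename=True by default)
    #             "description": "British mathematician and computer scientist",
    #             "entity_type": "PERSON",
    #         },
    #     )
    #
    # Outside async code, rag.edit_entity(...) provides the same behaviour synchronously.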

    async def _edit_entity_done(self) -> None:
"""Callback after entity editing is complete, ensures updates are persisted"""
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.entities_vdb,
self.chunk_entity_relation_graph,
]
]
)

    async def aedit_relation(
self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
) -> dict[str, Any]:
"""Asynchronously edit relation information.
Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database.
Args:
source_entity: Name of the source entity
target_entity: Name of the target entity
updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "keywords": "new keywords"}
Returns:
Dictionary containing updated relation information
"""
try:
# 1. Get current relation information
edge_data = await self.chunk_entity_relation_graph.get_edge(
source_entity, target_entity
)
if not edge_data:
raise ValueError(
f"Relation from '{source_entity}' to '{target_entity}' does not exist"
)
# 2. Update relation information in the graph
new_edge_data = {**edge_data, **updated_data}
await self.chunk_entity_relation_graph.upsert_edge(
source_entity, target_entity, new_edge_data
)
# 3. Recalculate relation's vector representation and update vector database
description = new_edge_data.get("description", "")
keywords = new_edge_data.get("keywords", "")
source_id = new_edge_data.get("source_id", "")
weight = float(new_edge_data.get("weight", 1.0))
# Create content for embedding
content = f"{keywords}\t{source_entity}\n{target_entity}\n{description}"
# Calculate relation ID
relation_id = compute_mdhash_id(
source_entity + target_entity, prefix="rel-"
)
# Prepare data for vector database update
relation_data = {
relation_id: {
"content": content,
"src_id": source_entity,
"tgt_id": target_entity,
"source_id": source_id,
"description": description,
"keywords": keywords,
"weight": weight,
}
}
# Update vector database
await self.relationships_vdb.upsert(relation_data)
# 4. Save changes
await self._edit_relation_done()
logger.info(
f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
)
return await self.get_relation_info(
source_entity, target_entity, include_vector_data=True
)
except Exception as e:
logger.error(
f"Error while editing relation from '{source_entity}' to '{target_entity}': {e}"
)
raise

    def edit_relation(
self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
) -> dict[str, Any]:
"""Synchronously edit relation information.
Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database.
Args:
source_entity: Name of the source entity
target_entity: Name of the target entity
updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "keywords": "keywords"}
Returns:
Dictionary containing updated relation information
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.aedit_relation(source_entity, target_entity, updated_data)
)
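
    # Illustrative usage sketch (not part of the library): assuming both entities and
    # the relation already exist, a relation's description and keywords can be refreshed
    # with (all values below are placeholders):
    #
    #     await rag.aedit_relation(
    #         "Alan Turing",
    #         "Enigma",
    #         {"description": "Led the effort to break the Enigma cipher", "keywords": "cryptanalysis"},
    #     )
    #
    # or synchronously via rag.edit_relation(...).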

    async def _edit_relation_done(self) -> None:
"""Callback after relation editing is complete, ensures updates are persisted"""
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.relationships_vdb,
self.chunk_entity_relation_graph,
]
]
)

    async def acreate_entity(
self, entity_name: str, entity_data: dict[str, Any]
) -> dict[str, Any]:
"""Asynchronously create a new entity.
Creates a new entity in the knowledge graph and adds it to the vector database.
Args:
entity_name: Name of the new entity
entity_data: Dictionary containing entity attributes, e.g. {"description": "description", "entity_type": "type"}
Returns:
Dictionary containing created entity information
"""
try:
# Check if entity already exists
existing_node = await self.chunk_entity_relation_graph.get_node(entity_name)
if existing_node:
raise ValueError(f"Entity '{entity_name}' already exists")
# Prepare node data with defaults if missing
node_data = {
"entity_type": entity_data.get("entity_type", "UNKNOWN"),
"description": entity_data.get("description", ""),
"source_id": entity_data.get("source_id", "manual"),
}
# Add entity to knowledge graph
await self.chunk_entity_relation_graph.upsert_node(entity_name, node_data)
# Prepare content for entity
description = node_data.get("description", "")
source_id = node_data.get("source_id", "")
entity_type = node_data.get("entity_type", "")
content = entity_name + "\n" + description
# Calculate entity ID
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
# Prepare data for vector database update
entity_data_for_vdb = {
entity_id: {
"content": content,
"entity_name": entity_name,
"source_id": source_id,
"description": description,
"entity_type": entity_type,
}
}
# Update vector database
await self.entities_vdb.upsert(entity_data_for_vdb)
# Save changes
await self._edit_entity_done()
logger.info(f"Entity '{entity_name}' successfully created")
return await self.get_entity_info(entity_name, include_vector_data=True)
except Exception as e:
logger.error(f"Error while creating entity '{entity_name}': {e}")
raise

    def create_entity(
self, entity_name: str, entity_data: dict[str, Any]
) -> dict[str, Any]:
"""Synchronously create a new entity.
Creates a new entity in the knowledge graph and adds it to the vector database.
Args:
entity_name: Name of the new entity
entity_data: Dictionary containing entity attributes, e.g. {"description": "description", "entity_type": "type"}
Returns:
Dictionary containing created entity information
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.acreate_entity(entity_name, entity_data))
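
    # Illustrative usage sketch (not part of the library): assuming `rag` is initialized
    # and the entity does not exist yet, a manual node can be added with (placeholder values):
    #
    #     await rag.acreate_entity(
    #         "Bletchley Park",
    #         {"description": "WWII codebreaking site", "entity_type": "LOCATION", "source_id": "manual"},
    #     )
    #
    # rag.create_entity(...) is the synchronous equivalent.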

    async def acreate_relation(
self, source_entity: str, target_entity: str, relation_data: dict[str, Any]
) -> dict[str, Any]:
"""Asynchronously create a new relation between entities.
Creates a new relation (edge) in the knowledge graph and adds it to the vector database.
Args:
source_entity: Name of the source entity
target_entity: Name of the target entity
relation_data: Dictionary containing relation attributes, e.g. {"description": "description", "keywords": "keywords"}
Returns:
Dictionary containing created relation information
"""
try:
# Check if both entities exist
source_exists = await self.chunk_entity_relation_graph.has_node(
source_entity
)
target_exists = await self.chunk_entity_relation_graph.has_node(
target_entity
)
if not source_exists:
raise ValueError(f"Source entity '{source_entity}' does not exist")
if not target_exists:
raise ValueError(f"Target entity '{target_entity}' does not exist")
# Check if relation already exists
existing_edge = await self.chunk_entity_relation_graph.get_edge(
source_entity, target_entity
)
if existing_edge:
raise ValueError(
f"Relation from '{source_entity}' to '{target_entity}' already exists"
)
# Prepare edge data with defaults if missing
edge_data = {
"description": relation_data.get("description", ""),
"keywords": relation_data.get("keywords", ""),
"source_id": relation_data.get("source_id", "manual"),
"weight": float(relation_data.get("weight", 1.0)),
}
# Add relation to knowledge graph
await self.chunk_entity_relation_graph.upsert_edge(
source_entity, target_entity, edge_data
)
# Prepare content for embedding
description = edge_data.get("description", "")
keywords = edge_data.get("keywords", "")
source_id = edge_data.get("source_id", "")
weight = edge_data.get("weight", 1.0)
# Create content for embedding
content = f"{keywords}\t{source_entity}\n{target_entity}\n{description}"
# Calculate relation ID
relation_id = compute_mdhash_id(
source_entity + target_entity, prefix="rel-"
)
# Prepare data for vector database update
relation_data_for_vdb = {
relation_id: {
"content": content,
"src_id": source_entity,
"tgt_id": target_entity,
"source_id": source_id,
"description": description,
"keywords": keywords,
"weight": weight,
}
}
# Update vector database
await self.relationships_vdb.upsert(relation_data_for_vdb)
# Save changes
await self._edit_relation_done()
logger.info(
f"Relation from '{source_entity}' to '{target_entity}' successfully created"
)
return await self.get_relation_info(
source_entity, target_entity, include_vector_data=True
)
except Exception as e:
logger.error(
f"Error while creating relation from '{source_entity}' to '{target_entity}': {e}"
)
raise

    def create_relation(
self, source_entity: str, target_entity: str, relation_data: dict[str, Any]
) -> dict[str, Any]:
"""Synchronously create a new relation between entities.
Creates a new relation (edge) in the knowledge graph and adds it to the vector database.
Args:
source_entity: Name of the source entity
target_entity: Name of the target entity
relation_data: Dictionary containing relation attributes, e.g. {"description": "description", "keywords": "keywords"}
Returns:
Dictionary containing created relation information
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.acreate_relation(source_entity, target_entity, relation_data)
)
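
    # Illustrative usage sketch (not part of the library): creating a relation requires
    # both endpoints to exist first, so a typical manual-graph workflow chains the two
    # helpers (names and attribute values below are placeholders):
    #
    #     await rag.acreate_entity("Alan Turing", {"entity_type": "PERSON"})
    #     await rag.acreate_entity("Bletchley Park", {"entity_type": "LOCATION"})
    #     await rag.acreate_relation(
    #         "Alan Turing",
    #         "Bletchley Park",
    #         {"description": "Worked at", "keywords": "employment, codebreaking", "weight": 1.0},
    #     )
    #
    # The synchronous wrappers (create_entity / create_relation) behave the same way.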