LightRAG/lightrag/lightrag.py

import asyncio
import os
import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, Callable, Optional, Type, Union, cast
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
QueryParam,
StorageNameSpace,
)
from .namespace import NameSpace, make_namespace
from .operate import (
chunking_by_token_size,
extract_entities,
extract_keywords_only,
kg_query,
kg_query_with_keywords,
mix_kg_vector_query,
naive_query,
)
from .prompt import GRAPH_FIELD_SEP
from .utils import (
EmbeddingFunc,
compute_mdhash_id,
convert_response_to_json,
limit_async_func_call,
logger,
set_logger,
)
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorge",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"],
"required_methods": ["get_pending_docs"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorge": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
}
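# Illustrative example (not exhaustive): per the table above, selecting Neo4JStorage as the
# graph backend requires NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD to be present in the
# environment before LightRAG is constructed, e.g.
#   export NEO4J_URI="neo4j://localhost:7687"
#   export NEO4J_USERNAME="neo4j"
#   export NEO4J_PASSWORD="<your-password>"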
# Storage implementation module mapping
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.jsondocstatus_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",
}
def lazy_external_import(module_name: str, class_name: str):
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args, **kwargs):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
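# Illustrative use of lazy_external_import (a sketch): the returned factory defers the module
# import until the storage class is actually instantiated, e.g.
#   make_storage = lazy_external_import(".kg.networkx_impl", "NetworkXStorage")
#   graph_store = make_storage(namespace="...", global_config={...})  # import happens here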
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
logger.info("Creating a new event loop in main thread.")
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop
@dataclass
class LightRAG:
"""LightRAG: Simple and Fast Retrieval-Augmented Generation."""
working_dir: str = field(
default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
)
"""Directory where cache and temporary files are stored."""
embedding_cache_config: dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
"""Configuration for embedding cache.
- enabled: If True, enables caching to avoid redundant computations.
- similarity_threshold: Minimum similarity score to use cached embeddings.
- use_llm_check: If True, validates cached embeddings using an LLM.
"""
kv_storage: str = field(default="JsonKVStorage")
"""Storage backend for key-value data."""
vector_storage: str = field(default="NanoVectorDBStorage")
"""Storage backend for vector embeddings."""
graph_storage: str = field(default="NetworkXStorage")
"""Storage backend for knowledge graphs."""
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Logging
current_log_level = logger.level
log_level: int = field(default=current_log_level)
"""Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
log_dir: str = field(default=os.getcwd())
"""Directory where logs are stored. Defaults to the current working directory."""
# Text chunking
chunk_token_size: int = 1200
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = 100
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = "gpt-4o-mini"
"""Model name used for tokenization when chunking text."""
# Entity extraction
entity_extract_max_gleaning: int = 1
"""Maximum number of entity extraction attempts for ambiguous content."""
entity_summary_to_max_tokens: int = 500
"""Maximum number of tokens used for summarizing extracted entities."""
# Node embedding
node_embedding_algorithm: str = "node2vec"
"""Algorithm used for node embedding in knowledge graphs."""
node2vec_params: dict[str, int] = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
}
)
"""Configuration for the node2vec embedding algorithm:
- dimensions: Number of dimensions for embeddings.
- num_walks: Number of random walks per node.
- walk_length: Number of steps per random walk.
- window_size: Context window size for training.
- iterations: Number of iterations for training.
- random_seed: Seed value for reproducibility.
"""
embedding_func: EmbeddingFunc = None
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = 32
"""Batch size for embedding computations."""
embedding_func_max_async: int = 16
"""Maximum number of concurrent embedding function calls."""
# LLM Configuration
llm_model_func: callable = None
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
"""Name of the LLM model used for generating responses."""
llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
"""Maximum number of tokens allowed per LLM response."""
llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
"""Maximum number of concurrent LLM calls."""
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function."""
# Storage
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional parameters for vector database storage."""
namespace_prefix: str = field(default="")
"""Prefix for namespacing stored data across different environments."""
enable_llm_cache: bool = True
"""Enables caching for LLM responses to avoid redundant computations."""
enable_llm_cache_for_entity_extract: bool = True
"""If True, enables caching for entity extraction steps to reduce LLM costs."""
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
)
# Custom Chunking Function
chunking_func: Callable[
[
str,
Optional[str],
bool,
int,
int,
str,
],
list[dict[str, Any]],
] = chunking_by_token_size
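    # Note: a custom chunking_func is invoked positionally (see
    # apipeline_process_enqueue_documents below) as
    #   chunking_func(content, split_by_character, split_by_character_only,
    #                 chunk_overlap_token_size, chunk_token_size, tiktoken_model_name)
    # and is expected to return a list of dicts that each contain at least a "content" key.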
def verify_storage_implementation(
self, storage_type: str, storage_name: str
) -> None:
"""Verify if storage implementation is compatible with specified storage type
Args:
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
storage_name: Storage implementation name
Raises:
ValueError: If storage implementation is incompatible or missing required methods
"""
if storage_type not in STORAGE_IMPLEMENTATIONS:
raise ValueError(f"Unknown storage type: {storage_type}")
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
if storage_name not in storage_info["implementations"]:
raise ValueError(
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
)
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
def __post_init__(self):
os.makedirs(self.log_dir, exist_ok=True)
log_file = os.path.join(self.log_dir, "lightrag.log")
set_logger(log_file)
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# Verify storage implementation compatibility and environment variables
storage_configs = [
("KV_STORAGE", self.kv_storage),
("VECTOR_STORAGE", self.vector_storage),
("GRAPH_STORAGE", self.graph_storage),
("DOC_STATUS_STORAGE", self.doc_status_storage),
]
for storage_type, storage_name in storage_configs:
# Verify storage implementation compatibility
self.verify_storage_implementation(storage_type, storage_name)
# Check environment variables
self.check_storage_env_vars(storage_name)
# show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
        # Limit concurrency for the embedding function
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
self.entities_vdb = self.vector_db_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
embedding_func=self.embedding_func,
)
# Initialize document status storage
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
        # Reuse llm_response_cache as the hashing KV store when it already carries a
        # global_config; otherwise create a dedicated LLM-response-cache namespace.
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
)
async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels()
return text
    async def get_graps(self, node_label: str, max_depth: int):
        return await self.chunk_entity_relation_graph.get_knowledge_graph(
            node_label=node_label, max_depth=max_depth
        )
    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
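    # Backends are selected by name through the kv_storage / vector_storage / graph_storage /
    # doc_status_storage fields and resolved via the STORAGES mapping above, e.g.
    # (illustrative) LightRAG(kv_storage="MongoKVStorage", graph_storage="Neo4JStorage", ...).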
    def set_storage_client(self, db_client):
        # Inject a shared DB client into the storage implementations (only tested with Oracle Database).
        # Deprecated: set the correct *_storage values when creating LightRAG instead.
        for storage in [
            self.vector_db_storage_cls,
            self.graph_storage_cls,
            self.doc_status,
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.key_string_value_json_storage_cls,
            self.chunks_vdb,
            self.relationships_vdb,
            self.entities_vdb,
            self.chunk_entity_relation_graph,
        ]:
# set client
storage.db = db_client
def insert(
self,
string_or_strings: Union[str, list[str]],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
        """Insert documents synchronously, with checkpoint support.

        Args:
            string_or_strings: A single document string or a list of document strings.
            split_by_character: If not None, split the text by this character first; any chunk
                longer than chunk_token_size is then split again by token size.
            split_by_character_only: If True, split by character only; ignored when
                split_by_character is None.
        """
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.ainsert(string_or_strings, split_by_character, split_by_character_only)
)
async def ainsert(
self,
string_or_strings: Union[str, list[str]],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
        """Insert documents asynchronously, with checkpoint support.

        Args:
            string_or_strings: A single document string or a list of document strings.
            split_by_character: If not None, split the text by this character first; any chunk
                longer than chunk_token_size is then split again by token size.
            split_by_character_only: If True, split by character only; ignored when
                split_by_character is None.
        """
await self.apipeline_enqueue_documents(string_or_strings)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
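    # ainsert is a thin wrapper over the two pipeline stages below: apipeline_enqueue_documents
    # registers new documents in doc_status, and apipeline_process_enqueue_documents chunks them,
    # extracts entities/relations, and updates the status of every pending document.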
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.ainsert_custom_chunks(full_text, text_chunks)
)
async def ainsert_custom_chunks(self, full_text: str, text_chunks: list[str]):
update_storage = False
try:
doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-")
new_docs = {doc_key: {"content": full_text.strip()}}
            _add_doc_keys = await self.full_docs.filter_keys({doc_key})
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning("This document is already in the storage.")
return
update_storage = True
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
inserting_chunks: dict[str, Any] = {}
for chunk_text in text_chunks:
chunk_text_stripped = chunk_text.strip()
chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-")
inserting_chunks[chunk_key] = {
"content": chunk_text_stripped,
"full_doc_id": doc_key,
}
            chunk_ids = set(inserting_chunks.keys())
            add_chunk_keys = await self.text_chunks.filter_keys(chunk_ids)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in add_chunk_keys
}
if not len(inserting_chunks):
logger.warning("All chunks are already in the storage.")
return
tasks = [
self.chunks_vdb.upsert(inserting_chunks),
self._process_entity_relation_graph(inserting_chunks),
self.full_docs.upsert(new_docs),
self.text_chunks.upsert(inserting_chunks),
]
await asyncio.gather(*tasks)
finally:
if update_storage:
await self._insert_done()
async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
"""
Pipeline for Processing Documents
1. Remove duplicate contents from the list
2. Generate document IDs and initial status
3. Filter out already processed documents
4. Enqueue document in status
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
# 2. Generate document IDs and initial status
new_docs: dict[str, Any] = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
}
for content in unique_contents
}
# 3. Filter out already processed documents
# Get docs ids
all_new_doc_ids = set(new_docs.keys())
# Exclude IDs of documents that are already in progress
unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
# Filter new_docs to only include documents with unique IDs
new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
if not new_docs:
logger.info("No new unique documents were found.")
return
# 4. Store status document
await self.doc_status.upsert(new_docs)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_enqueue_documents(
self,
split_by_character: str | None = None,
split_by_character_only: bool = False,
) -> None:
"""
Process pending documents by splitting them into chunks, processing
each chunk for entity and relation extraction, and updating the
document status.
1. Get all pending, failed, and abnormally terminated processing documents.
2. Split document content into chunks
3. Process each chunk for entity and relation extraction
4. Update the document status
"""
# 1. Get all pending, failed, and abnormally terminated processing documents.
to_process_docs: dict[str, DocProcessingStatus] = {}
processing_docs = await self.doc_status.get_processing_docs()
to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_failed_docs()
to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_pending_docs()
to_process_docs.update(pendings_docs)
if not to_process_docs:
logger.info("All documents have been processed or are duplicates")
return
# 2. split docs into chunks, insert chunks, update doc status
batch_size = self.addon_params.get("insert_batch_size", 10)
docs_batches = [
list(to_process_docs.items())[i : i + batch_size]
for i in range(0, len(to_process_docs), batch_size)
]
logger.info(f"Number of batches to process: {len(docs_batches)}.")
# 3. iterate over batches
for batch_idx, docs_batch in enumerate(docs_batches):
# 4. iterate over batch
for doc_id_processing_status in docs_batch:
doc_id, status_doc = doc_id_processing_status
                # Mark the document as PROCESSING before chunking and extraction
doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
await self.doc_status.upsert(
{
doc_status_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
}
}
)
# Generate chunks from document
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in self.chunking_func(
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
}
# Process document (text chunks and full docs) in parallel
tasks = [
self.chunks_vdb.upsert(chunks),
self._process_entity_relation_graph(chunks),
self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
self.text_chunks.upsert(chunks),
]
try:
await asyncio.gather(*tasks)
await self.doc_status.update_doc_status(
{
doc_status_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
await self._insert_done()
except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}")
await self.doc_status.update_doc_status(
{
doc_status_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
continue
logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
try:
new_kg = await extract_entities(
chunk,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if new_kg is None:
logger.info("No new entities or relationships extracted.")
else:
logger.info("New entities or relationships extracted.")
self.chunk_entity_relation_graph = new_kg
except Exception as e:
logger.error("Failed to extract entities and relationships")
raise e
async def _insert_done(self):
tasks = []
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
def insert_custom_kg(self, custom_kg: dict):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
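    # Expected custom_kg layout (a sketch inferred from ainsert_custom_kg below, not a formal schema):
    #   {
    #       "chunks": [{"content": ..., "source_id": ...}, ...],
    #       "entities": [{"entity_name": ..., "entity_type": ..., "description": ..., "source_id": ...}, ...],
    #       "relationships": [{"src_id": ..., "tgt_id": ..., "description": ...,
    #                          "keywords": ..., "weight": ..., "source_id": ...}, ...],
    #   }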
async def ainsert_custom_kg(self, custom_kg: dict):
update_storage = False
try:
# Insert chunks into vector storage
all_chunks_data = {}
chunk_to_source_map = {}
for chunk_data in custom_kg.get("chunks", []):
chunk_content = chunk_data["content"]
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
all_chunks_data[chunk_id] = chunk_entry
chunk_to_source_map[source_id] = chunk_id
update_storage = True
if self.chunks_vdb is not None and all_chunks_data:
await self.chunks_vdb.upsert(all_chunks_data)
if self.text_chunks is not None and all_chunks_data:
await self.text_chunks.upsert(all_chunks_data)
# Insert entities into knowledge graph
all_entities_data = []
for entity_data in custom_kg.get("entities", []):
entity_name = f'"{entity_data["entity_name"].upper()}"'
entity_type = entity_data.get("entity_type", "UNKNOWN")
description = entity_data.get("description", "No description provided")
# source_id = entity_data["source_id"]
source_chunk_id = entity_data.get("source_id", "UNKNOWN")
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
# Log if source_id is UNKNOWN
if source_id == "UNKNOWN":
logger.warning(
f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
)
# Prepare node data
node_data = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
}
# Insert node data into the knowledge graph
await self.chunk_entity_relation_graph.upsert_node(
entity_name, node_data=node_data
)
node_data["entity_name"] = entity_name
all_entities_data.append(node_data)
update_storage = True
# Insert relationships into knowledge graph
all_relationships_data = []
for relationship_data in custom_kg.get("relationships", []):
src_id = f'"{relationship_data["src_id"].upper()}"'
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
description = relationship_data["description"]
keywords = relationship_data["keywords"]
weight = relationship_data.get("weight", 1.0)
# source_id = relationship_data["source_id"]
source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
# Log if source_id is UNKNOWN
if source_id == "UNKNOWN":
logger.warning(
f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
)
# Check if nodes exist in the knowledge graph
for need_insert_id in [src_id, tgt_id]:
if not (
await self.chunk_entity_relation_graph.has_node(need_insert_id)
):
await self.chunk_entity_relation_graph.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
"description": "UNKNOWN",
"entity_type": "UNKNOWN",
},
)
# Insert edge into the knowledge graph
await self.chunk_entity_relation_graph.upsert_edge(
src_id,
tgt_id,
edge_data={
"weight": weight,
"description": description,
"keywords": keywords,
"source_id": source_id,
},
)
edge_data = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
"keywords": keywords,
}
all_relationships_data.append(edge_data)
update_storage = True
# Insert entities into vector storage if needed
if self.entities_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
if self.relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
await self.relationships_vdb.upsert(data_for_vdb)
finally:
if update_storage:
await self._insert_done()
def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, prompt, param))
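    # Illustrative call (a sketch): param.mode selects the retrieval strategy handled in
    # aquery() below ("local", "global", "hybrid", "naive" or "mix"), e.g.
    #   rag.query("Who is the protagonist?", param=QueryParam(mode="hybrid"))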
async def aquery(
self, query: str, prompt: str = "", param: QueryParam = QueryParam()
):
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
prompt=prompt,
)
elif param.mode == "naive":
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
def query_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Extract keywords from the 'query' using new function in operate.py.
2. Then run the standard aquery() flow with the final prompt (formatted_question).
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.aquery_with_separate_keyword_extraction(query, prompt, param)
)
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
"""
# ---------------------
# STEP 1: Keyword Extraction
# ---------------------
# We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
param=param,
global_config=asdict(self),
hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
        param.hl_keywords = hl_keywords
        param.ll_keywords = ll_keywords
# ---------------------
# STEP 2: Final Query Logic
# ---------------------
# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query_with_keywords(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
                    embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
response = await naive_query(
formatted_question,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
def delete_by_entity(self, entity_name: str):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_entity(entity_name))
async def adelete_by_entity(self, entity_name: str):
entity_name = f'"{entity_name.upper()}"'
try:
await self.entities_vdb.delete_entity(entity_name)
await self.relationships_vdb.delete_entity_relation(entity_name)
await self.chunk_entity_relation_graph.delete_node(entity_name)
logger.info(
f"Entity '{entity_name}' and its relationships have been deleted."
)
await self._delete_by_entity_done()
except Exception as e:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
tasks = []
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
Args:
content: Original document content
max_length: Maximum length of summary
Returns:
Truncated content with ellipsis if needed
"""
content = content.strip()
if len(content) <= max_length:
return content
return content[:max_length] + "..."
async def get_processing_status(self) -> dict[str, int]:
"""Get current document processing status counts
Returns:
Dict with counts for each status
"""
return await self.doc_status.get_status_counts()
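    # The returned mapping is keyed by document status with integer counts per status,
    # e.g. (illustrative shape) {"pending": 2, "processing": 0, "processed": 10, "failed": 1};
    # the exact keys depend on the DocStatus values defined in .base.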
async def adelete_by_doc_id(self, doc_id: str):
"""Delete a document and all its related data
Args:
doc_id: Document ID to delete
"""
try:
# 1. Get the document status and related data
doc_status = await self.doc_status.get(doc_id)
if not doc_status:
logger.warning(f"Document {doc_id} not found")
return
logger.debug(f"Starting deletion for document {doc_id}")
# 2. Get all related chunks
chunks = await self.text_chunks.filter(
lambda x: x.get("full_doc_id") == doc_id
)
chunk_ids = list(chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
# 3. Before deleting, check the related entities and relationships for these chunks
for chunk_id in chunk_ids:
# Check entities
entities = [
dp
for dp in self.entities_vdb.client_storage["data"]
if dp.get("source_id") == chunk_id
]
logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")
# Check relationships
relations = [
dp
for dp in self.relationships_vdb.client_storage["data"]
if dp.get("source_id") == chunk_id
]
logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")
# Continue with the original deletion process...
# 4. Delete chunks from vector database
if chunk_ids:
await self.chunks_vdb.delete(chunk_ids)
await self.text_chunks.delete(chunk_ids)
# 5. Find and process entities and relationships that have these chunks as source
# Get all nodes in the graph
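            # NOTE: this block reads the underlying NetworkX graph directly via _graph, so the
            # current deletion logic assumes the NetworkXStorage graph backend.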
nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
edges = self.chunk_entity_relation_graph._graph.edges(data=True)
# Track which entities and relationships need to be deleted or updated
entities_to_delete = set()
entities_to_update = {} # entity_name -> new_source_id
relationships_to_delete = set()
relationships_to_update = {} # (src, tgt) -> new_source_id
# Process entities
for node, data in nodes:
if "source_id" in data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
entities_to_delete.add(node)
logger.debug(
f"Entity {node} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
entities_to_update[node] = new_source_id
logger.debug(
f"Entity {node} will be updated with new source_id: {new_source_id}"
)
# Process relationships
for src, tgt, data in edges:
if "source_id" in data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
relationships_to_delete.add((src, tgt))
logger.debug(
f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
relationships_to_update[(src, tgt)] = new_source_id
logger.debug(
f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
)
# Delete entities
if entities_to_delete:
for entity in entities_to_delete:
await self.entities_vdb.delete_entity(entity)
logger.debug(f"Deleted entity {entity} from vector DB")
self.chunk_entity_relation_graph.remove_nodes(list(entities_to_delete))
logger.debug(f"Deleted {len(entities_to_delete)} entities from graph")
# Update entities
for entity, new_source_id in entities_to_update.items():
node_data = self.chunk_entity_relation_graph._graph.nodes[entity]
node_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_node(entity, node_data)
logger.debug(
f"Updated entity {entity} with new source_id: {new_source_id}"
)
# Delete relationships
if relationships_to_delete:
for src, tgt in relationships_to_delete:
rel_id_0 = compute_mdhash_id(src + tgt, prefix="rel-")
rel_id_1 = compute_mdhash_id(tgt + src, prefix="rel-")
await self.relationships_vdb.delete([rel_id_0, rel_id_1])
logger.debug(f"Deleted relationship {src}-{tgt} from vector DB")
self.chunk_entity_relation_graph.remove_edges(
list(relationships_to_delete)
)
logger.debug(
f"Deleted {len(relationships_to_delete)} relationships from graph"
)
# Update relationships
for (src, tgt), new_source_id in relationships_to_update.items():
edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt]
edge_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data)
logger.debug(
f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
)
# 6. Delete original document and status
await self.full_docs.delete([doc_id])
await self.doc_status.delete([doc_id])
# 7. Ensure all indexes are updated
await self._insert_done()
logger.info(
f"Successfully deleted document {doc_id} and related data. "
f"Deleted {len(entities_to_delete)} entities and {len(relationships_to_delete)} relationships. "
f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
)
# Add verification step
async def verify_deletion():
# Verify if the document has been deleted
if await self.full_docs.get_by_id(doc_id):
logger.error(f"Document {doc_id} still exists in full_docs")
# Verify if chunks have been deleted
remaining_chunks = await self.text_chunks.filter(
lambda x: x.get("full_doc_id") == doc_id
)
if remaining_chunks:
logger.error(f"Found {len(remaining_chunks)} remaining chunks")
# Verify entities and relationships
for chunk_id in chunk_ids:
# Check entities
entities_with_chunk = [
dp
for dp in self.entities_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if entities_with_chunk:
logger.error(
f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}"
)
# Check relationships
relations_with_chunk = [
dp
for dp in self.relationships_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if relations_with_chunk:
logger.error(
f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}"
)
await verify_deletion()
except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}")
def delete_by_doc_id(self, doc_id: str):
"""Synchronous version of adelete"""
return asyncio.run(self.adelete_by_doc_id(doc_id))
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
"""Get detailed information of an entity
Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
Returns:
dict: A dictionary containing entity information, including:
- entity_name: Entity name
- source_id: Source document ID
- graph_data: Complete node data from the graph database
- vector_data: (optional) Data from the vector database
"""
entity_name = f'"{entity_name.upper()}"'
# Get information from the graph
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None
result = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": node_data,
}
# Optional: Get vector database information
if include_vector_data:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
vector_data = self.entities_vdb._client.get([entity_id])
result["vector_data"] = vector_data[0] if vector_data else None
return result
def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
"""Synchronous version of getting entity information
Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
finally:
tracemalloc.stop()
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Get detailed information of a relationship
Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
Returns:
dict: A dictionary containing relationship information, including:
- src_entity: Source entity name
- tgt_entity: Target entity name
- source_id: Source document ID
- graph_data: Complete edge data from the graph database
- vector_data: (optional) Data from the vector database
"""
src_entity = f'"{src_entity.upper()}"'
tgt_entity = f'"{tgt_entity.upper()}"'
# Get information from the graph
edge_data = await self.chunk_entity_relation_graph.get_edge(
src_entity, tgt_entity
)
source_id = edge_data.get("source_id") if edge_data else None
2024-12-31 17:15:57 +08:00
result = {
"src_entity": src_entity,
"tgt_entity": tgt_entity,
"source_id": source_id,
"graph_data": edge_data,
}
# Optional: Get vector database information
if include_vector_data:
rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
vector_data = self.relationships_vdb._client.get([rel_id])
result["vector_data"] = vector_data[0] if vector_data else None
return result
def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information
Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(
self.get_relation_info(src_entity, tgt_entity, include_vector_data)
)
finally:
tracemalloc.stop()