from __future__ import annotations

import asyncio
import configparser
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, cast, final

from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    DocProcessingStatus,
    DocStatus,
    DocStatusStorage,
    QueryParam,
    StorageNameSpace,
    StoragesStatus,
)
from .namespace import NameSpace, make_namespace
from .operate import (
    chunking_by_token_size,
    extract_entities,
    extract_keywords_only,
    kg_query,
    kg_query_with_keywords,
    mix_kg_vector_query,
    naive_query,
)
from .prompt import GRAPH_FIELD_SEP
from .utils import (
    EmbeddingFunc,
    always_get_an_event_loop,
    compute_mdhash_id,
    convert_response_to_json,
    encode_string_by_tiktoken,
    lazy_external_import,
    limit_async_func_call,
    logger,
    set_logger,
)

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
    "KV_STORAGE": {
        "implementations": [
            "JsonKVStorage",
            "MongoKVStorage",
            "RedisKVStorage",
            "TiDBKVStorage",
            "PGKVStorage",
            "OracleKVStorage",
        ],
        "required_methods": ["get_by_id", "upsert"],
    },
    "GRAPH_STORAGE": {
        "implementations": [
            "NetworkXStorage",
            "Neo4JStorage",
            "MongoGraphStorage",
            "TiDBGraphStorage",
            "AGEStorage",
            "GremlinStorage",
            "PGGraphStorage",
            "OracleGraphStorage",
        ],
        "required_methods": ["upsert_node", "upsert_edge"],
    },
    "VECTOR_STORAGE": {
        "implementations": [
            "NanoVectorDBStorage",
            "MilvusVectorDBStorage",
            "ChromaVectorDBStorage",
            "TiDBVectorDBStorage",
            "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "OracleVectorDBStorage",
            "MongoVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
    },
    "DOC_STATUS_STORAGE": {
        "implementations": [
            "JsonDocStatusStorage",
            "PGDocStatusStorage",
            "MongoDocStatusStorage",
        ],
        "required_methods": ["get_docs_by_status"],
    },
}

# Required environment variables (without default values) per storage implementation
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
    # KV Storage Implementations
    "JsonKVStorage": [],
    "MongoKVStorage": [],
    "RedisKVStorage": ["REDIS_URI"],
    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "OracleKVStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Graph Storage Implementations
    "NetworkXStorage": [],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
    "MongoGraphStorage": [],
    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "AGEStorage": [
        "AGE_POSTGRES_DB",
        "AGE_POSTGRES_USER",
        "AGE_POSTGRES_PASSWORD",
    ],
    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
    "PGGraphStorage": [
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_DATABASE",
    ],
    "OracleGraphStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    # Vector Storage Implementations
    "NanoVectorDBStorage": [],
    "MilvusVectorDBStorage": [],
    "ChromaVectorDBStorage": [],
    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "FaissVectorDBStorage": [],
    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
    "OracleVectorDBStorage": [
        "ORACLE_DSN",
        "ORACLE_USER",
        "ORACLE_PASSWORD",
        "ORACLE_CONFIG_DIR",
    ],
    "MongoVectorDBStorage": [],
    # Document Status Storage Implementations
    "JsonDocStatusStorage": [],
    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "MongoDocStatusStorage": [],
}

# Storage implementation module mapping
STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
    "Neo4JStorage": ".kg.neo4j_impl",
    "OracleKVStorage": ".kg.oracle_impl",
    "OracleGraphStorage": ".kg.oracle_impl",
    "OracleVectorDBStorage": ".kg.oracle_impl",
    "MilvusVectorDBStorage": ".kg.milvus_impl",
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoDocStatusStorage": ".kg.mongo_impl",
    "MongoGraphStorage": ".kg.mongo_impl",
    "MongoVectorDBStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "TiDBKVStorage": ".kg.tidb_impl",
    "TiDBVectorDBStorage": ".kg.tidb_impl",
    "TiDBGraphStorage": ".kg.tidb_impl",
    "PGKVStorage": ".kg.postgres_impl",
    "PGVectorStorage": ".kg.postgres_impl",
    "AGEStorage": ".kg.age_impl",
    "PGGraphStorage": ".kg.postgres_impl",
    "GremlinStorage": ".kg.gremlin_impl",
    "PGDocStatusStorage": ".kg.postgres_impl",
    "FaissVectorDBStorage": ".kg.faiss_impl",
    "QdrantVectorDBStorage": ".kg.qdrant_impl",
}

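
# Illustrative sketch (not part of the library): how the three tables above fit
# together when validating a backend choice. The chosen slot and backend below
# are arbitrary examples; any key from STORAGE_IMPLEMENTATIONS works the same way.
def _example_storage_backend_check() -> None:
    """Validate one hypothetical backend choice and report missing env vars."""
    slot, name = "KV_STORAGE", "PGKVStorage"  # hypothetical choice
    assert name in STORAGE_IMPLEMENTATIONS[slot]["implementations"]
    missing = [var for var in STORAGE_ENV_REQUIREMENTS[name] if var not in os.environ]
    print(f"{name} (module {STORAGES[name]}): missing env vars: {missing}")

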
@final
@dataclass
class LightRAG:
    """LightRAG: Simple and Fast Retrieval-Augmented Generation."""

    # Directory
    # ---

    working_dir: str = field(
        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    """Directory where cache and temporary files are stored."""

    # Storage
    # ---

    kv_storage: str = field(default="JsonKVStorage")
    """Storage backend for key-value data."""

    vector_storage: str = field(default="NanoVectorDBStorage")
    """Storage backend for vector embeddings."""

    graph_storage: str = field(default="NetworkXStorage")
    """Storage backend for knowledge graphs."""

    doc_status_storage: str = field(default="JsonDocStatusStorage")
    """Storage backend for tracking document processing statuses."""

    # Logging
    # ---

    log_level: int = field(default=logger.level)
    """Logging level for the system (e.g., logging.DEBUG, logging.INFO, logging.WARNING)."""

    log_dir: str = field(default=os.getcwd())
    """Directory where logs are stored. Defaults to the current working directory."""

    # Entity extraction
    # ---

    entity_extract_max_gleaning: int = field(default=1)
    """Maximum number of additional "gleaning" passes to re-extract entities from ambiguous content."""

    entity_summary_to_max_tokens: int = field(
        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
    )
    """Maximum number of tokens used for summarizing extracted entities."""

    # Text chunking
    # ---

    chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
    """Maximum number of tokens per text chunk when splitting documents."""

    chunk_overlap_token_size: int = field(
        default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
    )
    """Number of overlapping tokens between consecutive text chunks to preserve context."""

    tiktoken_model_name: str = field(default="gpt-4o-mini")
    """Model name used for tokenization when chunking text."""

    chunking_func: Callable[
        [
            str,
            str | None,
            bool,
            int,
            int,
            str,
        ],
        list[dict[str, Any]],
    ] = field(default_factory=lambda: chunking_by_token_size)
    """
    Custom chunking function for splitting text into chunks before processing.

    The function is called positionally with the following parameters:

    - `content`: The text to be split into chunks.
    - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
    - `split_by_character_only`: If True, the text is split only on the specified character.
    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
    - `chunk_token_size`: The maximum number of tokens per chunk.
    - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.

    The function should return a list of dictionaries, where each dictionary contains the following keys:
    - `tokens`: The number of tokens in the chunk.
    - `content`: The text content of the chunk.

    Defaults to `chunking_by_token_size` if not specified. A sketch of a custom
    implementation follows this field.
    """

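    # Illustrative sketch (not executed): a custom chunking function matching the
    # positional signature documented above. The paragraph-splitting strategy is
    # an assumption for illustration only; it returns just the documented
    # `tokens` and `content` keys.
    #
    #     def chunk_by_paragraph(content, split_by_character, split_by_character_only,
    #                            overlap_token_size, max_token_size, model_name):
    #         paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    #         return [
    #             {
    #                 "tokens": len(encode_string_by_tiktoken(p, model_name=model_name)),
    #                 "content": p,
    #             }
    #             for p in paragraphs
    #         ]
    #
    #     rag = LightRAG(chunking_func=chunk_by_paragraph)  # hypothetical usage
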
    # Node embedding
    # ---

    node_embedding_algorithm: str = field(default="node2vec")
    """Algorithm used for node embedding in knowledge graphs."""

    node2vec_params: dict[str, int] = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )
    """Configuration for the node2vec embedding algorithm:
    - dimensions: Number of dimensions for embeddings.
    - num_walks: Number of random walks per node.
    - walk_length: Number of steps per random walk.
    - window_size: Context window size for training.
    - iterations: Number of iterations for training.
    - random_seed: Seed value for reproducibility.
    """

    # Embedding
    # ---

    embedding_func: EmbeddingFunc | None = field(default=None)
    """Function for computing text embeddings. Must be set before use."""

    embedding_batch_num: int = field(default=32)
    """Batch size for embedding computations."""

    embedding_func_max_async: int = field(default=16)
    """Maximum number of concurrent embedding function calls."""

    embedding_cache_config: dict[str, Any] = field(
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    """Configuration for the embedding cache.
    - enabled: If True, enables caching to avoid redundant computations.
    - similarity_threshold: Minimum similarity score to use cached embeddings.
    - use_llm_check: If True, validates cached embeddings using an LLM.
    """

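    # Illustrative sketch: enabling the embedding cache (values are examples only).
    #
    #     rag = LightRAG(
    #         embedding_cache_config={
    #             "enabled": True,
    #             "similarity_threshold": 0.95,
    #             "use_llm_check": False,
    #         },
    #     )
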
    # LLM Configuration
    # ---

    llm_model_func: Callable[..., object] | None = field(default=None)
    """Function for interacting with the large language model (LLM). Must be set before use."""

    llm_model_name: str = field(default="gpt-4o-mini")
    """Name of the LLM model used for generating responses."""

    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
    """Maximum number of tokens allowed per LLM response."""

    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
    """Maximum number of concurrent LLM calls."""

    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments passed to the LLM model function."""

    # Storage
    # ---

    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional parameters for vector database storage."""

    namespace_prefix: str = field(default="")
    """Prefix for namespacing stored data across different environments."""

    enable_llm_cache: bool = field(default=True)
    """Enables caching for LLM responses to avoid redundant computations."""

    enable_llm_cache_for_entity_extract: bool = field(default=True)
    """If True, enables caching for entity extraction steps to reduce LLM costs."""

    # Extensions
    # ---

    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
    """Maximum number of parallel insert operations."""

    addon_params: dict[str, Any] = field(default_factory=dict)
    """Additional parameters for extensions."""

    # Storages Management
    # ---

    auto_manage_storages_states: bool = field(default=True)
    """If True, LightRAG automatically calls initialize_storages and finalize_storages at the appropriate times."""

    # Response conversion
    # ---

    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
        default_factory=lambda: convert_response_to_json
    )
    """
    Custom function for converting LLM responses to JSON format.

    The default function is :func:`.utils.convert_response_to_json`.
    """

    def __post_init__(self):
        os.makedirs(self.log_dir, exist_ok=True)
        log_file = os.path.join(self.log_dir, "lightrag.log")
        set_logger(log_file)

        logger.setLevel(self.log_level)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")
        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        # Verify storage implementation compatibility and environment variables
        storage_configs = [
            ("KV_STORAGE", self.kv_storage),
            ("VECTOR_STORAGE", self.vector_storage),
            ("GRAPH_STORAGE", self.graph_storage),
            ("DOC_STATUS_STORAGE", self.doc_status_storage),
        ]

        for storage_type, storage_name in storage_configs:
            # Verify storage implementation compatibility
            self.verify_storage_implementation(storage_type, storage_name)
            # Environment variable checks are currently disabled:
            # self.check_storage_env_vars(storage_name)

        # Ensure vector_db_storage_cls_kwargs has required fields
        default_vector_db_kwargs = {
            "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
        }
        self.vector_db_storage_cls_kwargs = {
            **default_vector_db_kwargs,
            **self.vector_db_storage_cls_kwargs,
        }

        # Life cycle
        self.storages_status = StoragesStatus.NOT_CREATED

        # Lock guarding in-place replacement of the entity-relation graph
        # (used by _process_entity_relation_graph)
        self._entity_lock = asyncio.Lock()

        # Show config
        global_config = asdict(self)
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

        # Wrap the embedding function with a concurrency limiter
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
            self.embedding_func
        )

        # Initialize all storages
        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
            self._get_storage_class(self.kv_storage)
        )  # type: ignore
        self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
            self.vector_storage
        )  # type: ignore
        self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
            self.graph_storage
        )  # type: ignore
        self.key_string_value_json_storage_cls = partial(  # type: ignore
            self.key_string_value_json_storage_cls, global_config=global_config
        )
        self.vector_db_storage_cls = partial(  # type: ignore
            self.vector_db_storage_cls, global_config=global_config
        )
        self.graph_storage_cls = partial(  # type: ignore
            self.graph_storage_cls, global_config=global_config
        )

        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

        self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
            ),
            embedding_func=self.embedding_func,
        )

        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
            ),
            embedding_func=self.embedding_func,
        )
        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
            ),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
            ),
            embedding_func=self.embedding_func,
        )

        self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
            ),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
            ),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
            ),
            embedding_func=self.embedding_func,
        )

        # Initialize document status storage
        self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
            namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
            global_config=global_config,
            embedding_func=None,
        )

        if self.llm_response_cache and hasattr(
            self.llm_response_cache, "global_config"
        ):
            hashing_kv = self.llm_response_cache
        else:
            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                namespace=make_namespace(
                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                ),
                embedding_func=self.embedding_func,
            )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,  # type: ignore
                hashing_kv=hashing_kv,
                **self.llm_model_kwargs,
            )
        )

        self.storages_status = StoragesStatus.CREATED

        # Initialize storages
        if self.auto_manage_storages_states:
            loop = always_get_an_event_loop()
            loop.run_until_complete(self.initialize_storages())

    def __del__(self):
        # Finalize storages
        if self.auto_manage_storages_states:
            loop = always_get_an_event_loop()
            loop.run_until_complete(self.finalize_storages())

    async def initialize_storages(self):
        """Asynchronously initialize the storages"""
        if self.storages_status == StoragesStatus.CREATED:
            tasks = []

            for storage in (
                self.full_docs,
                self.text_chunks,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.chunk_entity_relation_graph,
                self.llm_response_cache,
                self.doc_status,
            ):
                if storage:
                    tasks.append(storage.initialize())

            await asyncio.gather(*tasks)

            self.storages_status = StoragesStatus.INITIALIZED
            logger.debug("Initialized Storages")

    async def finalize_storages(self):
        """Asynchronously finalize the storages"""
        if self.storages_status == StoragesStatus.INITIALIZED:
            tasks = []

            for storage in (
                self.full_docs,
                self.text_chunks,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.chunk_entity_relation_graph,
                self.llm_response_cache,
                self.doc_status,
            ):
                if storage:
                    tasks.append(storage.finalize())

            await asyncio.gather(*tasks)

            self.storages_status = StoragesStatus.FINALIZED
            logger.debug("Finalized Storages")

    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
        import_path = STORAGES[storage_name]
        storage_class = lazy_external_import(import_path, storage_name)
        return storage_class

    def insert(
        self,
        input: str | list[str],
        split_by_character: str | None = None,
        split_by_character_only: bool = False,
    ) -> None:
        """Sync insert documents with checkpoint support.

        Args:
            input: Single document string or list of document strings.
            split_by_character: If not None, split the text on this character first;
                any chunk still longer than chunk_token_size is split again by token size.
            split_by_character_only: If True, split only on the specified character;
                when split_by_character is None, this parameter is ignored.
        """
        loop = always_get_an_event_loop()
        loop.run_until_complete(
            self.ainsert(input, split_by_character, split_by_character_only)
        )

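    # Illustrative sketch: a synchronous insert that first splits on blank lines,
    # then re-splits oversized chunks by token size (document strings are
    # hypothetical).
    #
    #     rag.insert(
    #         ["first document ...", "second document ..."],
    #         split_by_character="\n\n",
    #         split_by_character_only=False,
    #     )
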
    async def ainsert(
        self,
        input: str | list[str],
        split_by_character: str | None = None,
        split_by_character_only: bool = False,
    ) -> None:
        """Async insert documents with checkpoint support.

        Args:
            input: Single document string or list of document strings.
            split_by_character: If not None, split the text on this character first;
                any chunk still longer than chunk_token_size is split again by token size.
            split_by_character_only: If True, split only on the specified character;
                when split_by_character is None, this parameter is ignored.
        """
        await self.apipeline_enqueue_documents(input)
        await self.apipeline_process_enqueue_documents(
            split_by_character, split_by_character_only
        )

    def insert_custom_chunks(self, full_text: str, text_chunks: list[str]) -> None:
        """Sync insert a document together with user-supplied, pre-split chunks."""
        loop = always_get_an_event_loop()
        loop.run_until_complete(self.ainsert_custom_chunks(full_text, text_chunks))

    async def ainsert_custom_chunks(
        self, full_text: str, text_chunks: list[str]
    ) -> None:
        """Async insert a document together with user-supplied, pre-split chunks,
        bypassing the built-in chunking."""
        update_storage = False
        try:
            doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-")
            new_docs = {doc_key: {"content": full_text.strip()}}

            _add_doc_keys = await self.full_docs.filter_keys({doc_key})
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not len(new_docs):
                logger.warning("This document is already in the storage.")
                return

            update_storage = True
            logger.info(f"Inserting {len(new_docs)} docs")

            inserting_chunks: dict[str, Any] = {}
            for chunk_text in text_chunks:
                chunk_text_stripped = chunk_text.strip()
                chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-")

                inserting_chunks[chunk_key] = {
                    "content": chunk_text_stripped,
                    "full_doc_id": doc_key,
                }

            chunk_keys = set(inserting_chunks.keys())
            add_chunk_keys = await self.text_chunks.filter_keys(chunk_keys)
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in add_chunk_keys
            }
            if not len(inserting_chunks):
                logger.warning("All chunks are already in the storage.")
                return

            tasks = [
                self.chunks_vdb.upsert(inserting_chunks),
                self._process_entity_relation_graph(inserting_chunks),
                self.full_docs.upsert(new_docs),
                self.text_chunks.upsert(inserting_chunks),
            ]
            await asyncio.gather(*tasks)

        finally:
            if update_storage:
                await self._insert_done()

    async def apipeline_enqueue_documents(self, input: str | list[str]) -> None:
        """
        Pipeline for enqueueing documents:

        1. Remove duplicate contents from the list.
        2. Generate document IDs and initial status.
        3. Filter out already processed documents.
        4. Enqueue the documents in the status storage.
        """
        if isinstance(input, str):
            input = [input]

        # 1. Remove duplicate contents from the list
        unique_contents = list(set(doc.strip() for doc in input))

        # 2. Generate document IDs and initial status
        new_docs: dict[str, Any] = {
            compute_mdhash_id(content, prefix="doc-"): {
                "content": content,
                "content_summary": self._get_content_summary(content),
                "content_length": len(content),
                "status": DocStatus.PENDING,
                "created_at": datetime.now().isoformat(),
                "updated_at": datetime.now().isoformat(),
            }
            for content in unique_contents
        }

        # 3. Filter out already processed documents
        # Get docs ids
        all_new_doc_ids = set(new_docs.keys())
        # Exclude IDs of documents that are already in progress
        unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
        # Filter new_docs to only include documents with unique IDs
        new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}

        if not new_docs:
            logger.info("No new unique documents were found.")
            return

        # 4. Store status document
        await self.doc_status.upsert(new_docs)
        logger.info(f"Stored {len(new_docs)} new unique documents")

    async def apipeline_process_enqueue_documents(
        self,
        split_by_character: str | None = None,
        split_by_character_only: bool = False,
    ) -> None:
        """
        Process pending documents by splitting them into chunks, processing
        each chunk for entity and relation extraction, and updating the
        document status.

        1. Get all pending, failed, and abnormally terminated processing documents.
        2. Split document content into chunks.
        3. Process each chunk for entity and relation extraction.
        4. Update the document status.
        """
        # 1. Get all pending, failed, and abnormally terminated processing documents.
        # Run the asynchronous status retrievals in parallel using asyncio.gather
        processing_docs, failed_docs, pending_docs = await asyncio.gather(
            self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
            self.doc_status.get_docs_by_status(DocStatus.FAILED),
            self.doc_status.get_docs_by_status(DocStatus.PENDING),
        )

        to_process_docs: dict[str, DocProcessingStatus] = {}
        to_process_docs.update(processing_docs)
        to_process_docs.update(failed_docs)
        to_process_docs.update(pending_docs)

        if not to_process_docs:
            logger.info("All documents have been processed or are duplicates")
            return

        # 2. Split the documents into batches of at most max_parallel_insert
        docs_batches = [
            list(to_process_docs.items())[i : i + self.max_parallel_insert]
            for i in range(0, len(to_process_docs), self.max_parallel_insert)
        ]

        logger.info(f"Number of batches to process: {len(docs_batches)}.")

        batches: list[Any] = []
        # 3. Iterate over batches
        for batch_idx, docs_batch in enumerate(docs_batches):

            async def batch(
                batch_idx: int,
                docs_batch: list[tuple[str, DocProcessingStatus]],
                size_batch: int,
            ) -> None:
                logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
                # 4. Iterate over the documents in the batch
                for doc_id_processing_status in docs_batch:
                    doc_id, status_doc = doc_id_processing_status
                    # ID under which the processing status is tracked
                    doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                    # Generate chunks from document
                    chunks: dict[str, Any] = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
                        }
                        for dp in self.chunking_func(
                            status_doc.content,
                            split_by_character,
                            split_by_character_only,
                            self.chunk_overlap_token_size,
                            self.chunk_token_size,
                            self.tiktoken_model_name,
                        )
                    }
                    # Process document (text chunks and full docs) in parallel
                    tasks = [
                        self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.PROCESSING,
                                    "updated_at": datetime.now().isoformat(),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                }
                            }
                        ),
                        self.chunks_vdb.upsert(chunks),
                        self._process_entity_relation_graph(chunks),
                        self.full_docs.upsert(
                            {doc_id: {"content": status_doc.content}}
                        ),
                        self.text_chunks.upsert(chunks),
                    ]
                    try:
                        await asyncio.gather(*tasks)
                        await self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.PROCESSED,
                                    "chunks_count": len(chunks),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                }
                            }
                        )
                    except Exception as e:
                        logger.error(f"Failed to process document {doc_id}: {str(e)}")
                        await self.doc_status.upsert(
                            {
                                doc_status_id: {
                                    "status": DocStatus.FAILED,
                                    "error": str(e),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                }
                            }
                        )
                        continue
                logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")

            batches.append(batch(batch_idx, docs_batch, len(docs_batches)))

        await asyncio.gather(*batches)
        await self._insert_done()

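    # Illustrative sketch: running the two pipeline stages explicitly instead of
    # calling ainsert(); useful when enqueueing and processing happen at
    # different times.
    #
    #     await rag.apipeline_enqueue_documents(["doc one ...", "doc two ..."])
    #     await rag.apipeline_process_enqueue_documents()
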
    async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
        try:
            new_kg = await extract_entities(
                chunk,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
                llm_response_cache=self.llm_response_cache,
                global_config=asdict(self),
            )
            if new_kg is None:
                logger.info("No new entities or relationships extracted.")
            else:
                async with self._entity_lock:
                    logger.info("New entities or relationships extracted.")
                    self.chunk_entity_relation_graph = new_kg

        except Exception as e:
            logger.error(f"Failed to extract entities and relationships: {e}")
            raise

    async def _insert_done(self) -> None:
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in [  # type: ignore
                self.full_docs,
                self.text_chunks,
                self.llm_response_cache,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.chunk_entity_relation_graph,
            ]
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)
        logger.info("All Insert done")

    def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
        """Sync insert a pre-built knowledge graph. See ainsert_custom_kg."""
        loop = always_get_an_event_loop()
        loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
        """Async insert a pre-built knowledge graph.

        `custom_kg` may contain three lists: `chunks` (each with `content` and
        `source_id`), `entities` (each with `entity_name` and optional
        `entity_type`, `description`, `source_id`), and `relationships` (each
        with `src_id`, `tgt_id`, `description`, `keywords`, and optional
        `weight`, `source_id`).
        """
        update_storage = False
        try:
            # Insert chunks into vector storage
            all_chunks_data: dict[str, dict[str, str]] = {}
            chunk_to_source_map: dict[str, str] = {}
            for chunk_data in custom_kg.get("chunks", {}):
                chunk_content = chunk_data["content"].strip()
                source_id = chunk_data["source_id"]
                tokens = len(
                    encode_string_by_tiktoken(
                        chunk_content, model_name=self.tiktoken_model_name
                    )
                )
                chunk_order_index = (
                    0
                    if "chunk_order_index" not in chunk_data.keys()
                    else chunk_data["chunk_order_index"]
                )
                chunk_id = compute_mdhash_id(chunk_content, prefix="chunk-")

                chunk_entry = {
                    "content": chunk_content,
                    "source_id": source_id,
                    "tokens": tokens,
                    "chunk_order_index": chunk_order_index,
                    "full_doc_id": source_id,
                    "status": DocStatus.PROCESSED,
                }
                all_chunks_data[chunk_id] = chunk_entry
                chunk_to_source_map[source_id] = chunk_id
                update_storage = True

            if all_chunks_data:
                await self.chunks_vdb.upsert(all_chunks_data)
                await self.text_chunks.upsert(all_chunks_data)

            # Insert entities into knowledge graph
            all_entities_data: list[dict[str, str]] = []
            for entity_data in custom_kg.get("entities", []):
                entity_name = f'"{entity_data["entity_name"].upper()}"'
                entity_type = entity_data.get("entity_type", "UNKNOWN")
                description = entity_data.get("description", "No description provided")
                source_chunk_id = entity_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Prepare node data
                node_data: dict[str, str] = {
                    "entity_type": entity_type,
                    "description": description,
                    "source_id": source_id,
                }
                # Insert node data into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_node(
                    entity_name, node_data=node_data
                )
                node_data["entity_name"] = entity_name
                all_entities_data.append(node_data)
                update_storage = True

            # Insert relationships into knowledge graph
            all_relationships_data: list[dict[str, str]] = []
            for relationship_data in custom_kg.get("relationships", []):
                src_id = f'"{relationship_data["src_id"].upper()}"'
                tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
                description = relationship_data["description"]
                keywords = relationship_data["keywords"]
                weight = relationship_data.get("weight", 1.0)
                source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Check if nodes exist in the knowledge graph
                for need_insert_id in [src_id, tgt_id]:
                    if not (
                        await self.chunk_entity_relation_graph.has_node(need_insert_id)
                    ):
                        await self.chunk_entity_relation_graph.upsert_node(
                            need_insert_id,
                            node_data={
                                "source_id": source_id,
                                "description": "UNKNOWN",
                                "entity_type": "UNKNOWN",
                            },
                        )

                # Insert edge into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_edge(
                    src_id,
                    tgt_id,
                    edge_data={
                        "weight": weight,
                        "description": description,
                        "keywords": keywords,
                        "source_id": source_id,
                    },
                )
                edge_data: dict[str, str] = {
                    "src_id": src_id,
                    "tgt_id": tgt_id,
                    "description": description,
                    "keywords": keywords,
                }
                all_relationships_data.append(edge_data)
                update_storage = True

            # Insert entities into vector storage if needed
            data_for_vdb = {
                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                    "content": dp["entity_name"] + dp["description"],
                    "entity_name": dp["entity_name"],
                }
                for dp in all_entities_data
            }
            await self.entities_vdb.upsert(data_for_vdb)

            # Insert relationships into vector storage if needed
            data_for_vdb = {
                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                    "src_id": dp["src_id"],
                    "tgt_id": dp["tgt_id"],
                    "content": dp["keywords"]
                    + dp["src_id"]
                    + dp["tgt_id"]
                    + dp["description"],
                }
                for dp in all_relationships_data
            }
            await self.relationships_vdb.upsert(data_for_vdb)

        finally:
            if update_storage:
                await self._insert_done()

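    # Illustrative sketch: the custom_kg shape this method consumes (all values
    # are hypothetical; the keys mirror those read above).
    #
    #     rag.insert_custom_kg(
    #         {
    #             "chunks": [{"content": "Alice met Bob.", "source_id": "doc-1"}],
    #             "entities": [
    #                 {
    #                     "entity_name": "Alice",
    #                     "entity_type": "Person",
    #                     "description": "A researcher.",
    #                     "source_id": "doc-1",
    #                 }
    #             ],
    #             "relationships": [
    #                 {
    #                     "src_id": "Alice",
    #                     "tgt_id": "Bob",
    #                     "description": "Alice knows Bob.",
    #                     "keywords": "acquaintance",
    #                     "weight": 1.0,
    #                     "source_id": "doc-1",
    #                 }
    #             ],
    #         }
    #     )
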
    def query(
        self,
        query: str,
        param: QueryParam = QueryParam(),
        system_prompt: str | None = None,
    ) -> str | Iterator[str]:
        """
        Perform a sync query.

        Args:
            query (str): The query to be executed.
            param (QueryParam): Configuration parameters for query execution.
            system_prompt (str | None): Custom system prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

        Returns:
            str: The result of the query execution.
        """
        loop = always_get_an_event_loop()

        return loop.run_until_complete(self.aquery(query, param, system_prompt))  # type: ignore

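    # Illustrative sketch: the same question issued in two retrieval modes
    # (query text is hypothetical).
    #
    #     answer = rag.query(
    #         "Who founded the company?", param=QueryParam(mode="hybrid")
    #     )
    #     answer = await rag.aquery(
    #         "Who founded the company?", param=QueryParam(mode="mix")
    #     )
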
    async def aquery(
        self,
        query: str,
        param: QueryParam = QueryParam(),
        system_prompt: str | None = None,
    ) -> str | AsyncIterator[str]:
        """
        Perform an async query.

        Args:
            query (str): The query to be executed.
            param (QueryParam): Configuration parameters for query execution.
            system_prompt (str | None): Custom system prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

        Returns:
            str: The result of the query execution.
        """
        # Reuse the LLM response cache when it is fully configured; otherwise
        # fall back to a freshly constructed KV storage for hashing.
        hashing_kv = (
            self.llm_response_cache
            if self.llm_response_cache
            and hasattr(self.llm_response_cache, "global_config")
            else self.key_string_value_json_storage_cls(
                namespace=make_namespace(
                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                ),
                global_config=asdict(self),
                embedding_func=self.embedding_func,
            )
        )

        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=hashing_kv,
                system_prompt=system_prompt,
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=hashing_kv,
                system_prompt=system_prompt,
            )
        elif param.mode == "mix":
            response = await mix_kg_vector_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=hashing_kv,
                system_prompt=system_prompt,
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    def query_with_separate_keyword_extraction(
        self, query: str, prompt: str, param: QueryParam = QueryParam()
    ):
        """
        1. Extract keywords from the 'query' using extract_keywords_only in operate.py.
        2. Then run the standard aquery() flow with the final prompt (formatted_question).
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(
            self.aquery_with_separate_keyword_extraction(query, prompt, param)
        )

    async def aquery_with_separate_keyword_extraction(
        self, query: str, prompt: str, param: QueryParam = QueryParam()
    ) -> str | AsyncIterator[str]:
        """
        1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
        2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
        """
        # Reuse the LLM response cache when it is fully configured; otherwise
        # fall back to a freshly constructed KV storage for hashing.
        hashing_kv = (
            self.llm_response_cache
            if self.llm_response_cache
            and hasattr(self.llm_response_cache, "global_config")
            else self.key_string_value_json_storage_cls(
                namespace=make_namespace(
                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                ),
                global_config=asdict(self),
                embedding_func=self.embedding_func,
            )
        )

        # ---------------------
        # STEP 1: Keyword Extraction
        # ---------------------
        hl_keywords, ll_keywords = await extract_keywords_only(
            text=query,
            param=param,
            global_config=asdict(self),
            hashing_kv=hashing_kv,
        )

        param.hl_keywords = hl_keywords
        param.ll_keywords = ll_keywords

        # ---------------------
        # STEP 2: Final Query Logic
        # ---------------------

        # Create a new string with the prompt and the keywords
        ll_keywords_str = ", ".join(ll_keywords)
        hl_keywords_str = ", ".join(hl_keywords)
        formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"

        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query_with_keywords(
                formatted_question,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=hashing_kv,
            )
        elif param.mode == "naive":
            response = await naive_query(
                formatted_question,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=hashing_kv,
            )
        elif param.mode == "mix":
            response = await mix_kg_vector_query(
                formatted_question,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=hashing_kv,
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")

        await self._query_done()
        return response

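    # Illustrative sketch: keyword extraction decoupled from answering, with a
    # custom instruction prompt (all strings are hypothetical).
    #
    #     answer = rag.query_with_separate_keyword_extraction(
    #         query="What changed in the Q3 report?",
    #         prompt="Answer concisely for an executive audience.",
    #         param=QueryParam(mode="hybrid"),
    #     )
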
    async def _query_done(self):
        await self.llm_response_cache.index_done_callback()

    def delete_by_entity(self, entity_name: str) -> None:
        """Sync delete an entity and its relationships. See adelete_by_entity."""
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str) -> None:
        entity_name = f'"{entity_name.upper()}"'

        try:
            await self.entities_vdb.delete_entity(entity_name)
            await self.relationships_vdb.delete_entity_relation(entity_name)
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self) -> None:
        await asyncio.gather(
            *[
                cast(StorageNameSpace, storage_inst).index_done_callback()
                for storage_inst in [  # type: ignore
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.chunk_entity_relation_graph,
                ]
            ]
        )

    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
        """Get a short summary of document content.

        Args:
            content: Original document content
            max_length: Maximum length of the summary

        Returns:
            Truncated content with an ellipsis if needed
        """
        content = content.strip()
        if len(content) <= max_length:
            return content
        return content[:max_length] + "..."

    async def get_processing_status(self) -> dict[str, int]:
        """Get current document processing status counts

        Returns:
            Dict with counts for each status
        """
        return await self.doc_status.get_status_counts()

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get documents by status

        Returns:
            Dict with document IDs as keys and document statuses as values
        """
        return await self.doc_status.get_docs_by_status(status)

    async def adelete_by_doc_id(self, doc_id: str) -> None:
        """Delete a document and all its related data

        Args:
            doc_id: Document ID to delete
        """
        try:
            # 1. Get the document status and related data
            doc_status = await self.doc_status.get_by_id(doc_id)
            if not doc_status:
                logger.warning(f"Document {doc_id} not found")
                return

            logger.debug(f"Starting deletion for document {doc_id}")

            # 2. Get all related chunks
            chunks = await self.text_chunks.get_by_id(doc_id)
            if not chunks:
                return

            chunk_ids = list(chunks.keys())
            logger.debug(f"Found {len(chunk_ids)} chunks to delete")

            # 3. Before deleting, check the related entities and relationships for these chunks
            # NOTE: client_storage is specific to the NanoVectorDB backend; this
            # inspection (and the _graph access below) assumes the default
            # NanoVectorDBStorage / NetworkXStorage implementations.
            for chunk_id in chunk_ids:
                # Check entities
                entities = [
                    dp
                    for dp in self.entities_vdb.client_storage["data"]
                    if dp.get("source_id") == chunk_id
                ]
                logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")

                # Check relationships
                relations = [
                    dp
                    for dp in self.relationships_vdb.client_storage["data"]
                    if dp.get("source_id") == chunk_id
                ]
                logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")

            # Continue with the original deletion process...

            # 4. Delete chunks from vector database
            if chunk_ids:
                await self.chunks_vdb.delete(chunk_ids)
                await self.text_chunks.delete(chunk_ids)

            # 5. Find and process entities and relationships that have these chunks as source
            # Get all nodes and edges in the graph
            nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
            edges = self.chunk_entity_relation_graph._graph.edges(data=True)

            # Track which entities and relationships need to be deleted or updated
            entities_to_delete = set()
            entities_to_update = {}  # entity_name -> new_source_id
            relationships_to_delete = set()
            relationships_to_update = {}  # (src, tgt) -> new_source_id

            # Process entities
            for node, data in nodes:
                if "source_id" in data:
                    # Split source_id using GRAPH_FIELD_SEP
                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
                    sources.difference_update(chunk_ids)
                    if not sources:
                        entities_to_delete.add(node)
                        logger.debug(
                            f"Entity {node} marked for deletion - no remaining sources"
                        )
                    else:
                        new_source_id = GRAPH_FIELD_SEP.join(sources)
                        entities_to_update[node] = new_source_id
                        logger.debug(
                            f"Entity {node} will be updated with new source_id: {new_source_id}"
                        )

            # Process relationships
            for src, tgt, data in edges:
                if "source_id" in data:
                    # Split source_id using GRAPH_FIELD_SEP
                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
                    sources.difference_update(chunk_ids)
                    if not sources:
                        relationships_to_delete.add((src, tgt))
                        logger.debug(
                            f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
                        )
                    else:
                        new_source_id = GRAPH_FIELD_SEP.join(sources)
                        relationships_to_update[(src, tgt)] = new_source_id
                        logger.debug(
                            f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
                        )

            # Delete entities
            if entities_to_delete:
                for entity in entities_to_delete:
                    await self.entities_vdb.delete_entity(entity)
                    logger.debug(f"Deleted entity {entity} from vector DB")
                self.chunk_entity_relation_graph.remove_nodes(list(entities_to_delete))
                logger.debug(f"Deleted {len(entities_to_delete)} entities from graph")

            # Update entities
            for entity, new_source_id in entities_to_update.items():
                node_data = self.chunk_entity_relation_graph._graph.nodes[entity]
                node_data["source_id"] = new_source_id
                await self.chunk_entity_relation_graph.upsert_node(entity, node_data)
                logger.debug(
                    f"Updated entity {entity} with new source_id: {new_source_id}"
                )

            # Delete relationships
            if relationships_to_delete:
                for src, tgt in relationships_to_delete:
                    rel_id_0 = compute_mdhash_id(src + tgt, prefix="rel-")
                    rel_id_1 = compute_mdhash_id(tgt + src, prefix="rel-")
                    await self.relationships_vdb.delete([rel_id_0, rel_id_1])
                    logger.debug(f"Deleted relationship {src}-{tgt} from vector DB")
                self.chunk_entity_relation_graph.remove_edges(
                    list(relationships_to_delete)
                )
                logger.debug(
                    f"Deleted {len(relationships_to_delete)} relationships from graph"
                )

            # Update relationships
            for (src, tgt), new_source_id in relationships_to_update.items():
                edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt]
                edge_data["source_id"] = new_source_id
                await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data)
                logger.debug(
                    f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
                )

            # 6. Delete original document and status
            await self.full_docs.delete([doc_id])
            await self.doc_status.delete([doc_id])

            # 7. Ensure all indexes are updated
            await self._insert_done()

            logger.info(
                f"Successfully deleted document {doc_id} and related data. "
                f"Deleted {len(entities_to_delete)} entities and {len(relationships_to_delete)} relationships. "
                f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
            )

            # Add verification step
            async def verify_deletion():
                # Verify if the document has been deleted
                if await self.full_docs.get_by_id(doc_id):
                    logger.error(f"Document {doc_id} still exists in full_docs")

                # Verify if chunks have been deleted
                remaining_chunks = await self.text_chunks.get_by_id(doc_id)
                if remaining_chunks:
                    logger.error(f"Found {len(remaining_chunks)} remaining chunks")

                # Verify entities and relationships
                for chunk_id in chunk_ids:
                    # Check entities
                    entities_with_chunk = [
                        dp
                        for dp in self.entities_vdb.client_storage["data"]
                        if chunk_id
                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                    ]
                    if entities_with_chunk:
                        logger.error(
                            f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}"
                        )

                    # Check relationships
                    relations_with_chunk = [
                        dp
                        for dp in self.relationships_vdb.client_storage["data"]
                        if chunk_id
                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                    ]
                    if relations_with_chunk:
                        logger.error(
                            f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}"
                        )

            await verify_deletion()

        except Exception as e:
            logger.error(f"Error while deleting document {doc_id}: {e}")

    async def get_entity_info(
        self, entity_name: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of an entity

        Args:
            entity_name: Entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing entity information, including:
                - entity_name: Entity name
                - source_id: Source document ID
                - graph_data: Complete node data from the graph database
                - vector_data: (optional) Data from the vector database
        """
        entity_name = f'"{entity_name.upper()}"'

        # Get information from the graph
        node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
        source_id = node_data.get("source_id") if node_data else None

        result: dict[str, str | None | dict[str, str]] = {
            "entity_name": entity_name,
            "source_id": source_id,
            "graph_data": node_data,
        }

        # Optional: Get vector database information
        # NOTE: the _client access below assumes the NanoVectorDB backend.
        if include_vector_data:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            vector_data = self.entities_vdb._client.get([entity_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result

    async def get_relation_info(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of a relationship

        Args:
            src_entity: Source entity name (no need for quotes)
            tgt_entity: Target entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database

        Returns:
            dict: A dictionary containing relationship information, including:
                - src_entity: Source entity name
                - tgt_entity: Target entity name
                - source_id: Source document ID
                - graph_data: Complete edge data from the graph database
                - vector_data: (optional) Data from the vector database
        """
        src_entity = f'"{src_entity.upper()}"'
        tgt_entity = f'"{tgt_entity.upper()}"'

        # Get information from the graph
        edge_data = await self.chunk_entity_relation_graph.get_edge(
            src_entity, tgt_entity
        )
        source_id = edge_data.get("source_id") if edge_data else None

        result: dict[str, str | None | dict[str, str]] = {
            "src_entity": src_entity,
            "tgt_entity": tgt_entity,
            "source_id": source_id,
            "graph_data": edge_data,
        }

        # Optional: Get vector database information
        # NOTE: the _client access below assumes the NanoVectorDB backend.
        if include_vector_data:
            rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
            vector_data = self.relationships_vdb._client.get([rel_id])
            result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def verify_storage_implementation(
        self, storage_type: str, storage_name: str
    ) -> None:
        """Verify that a storage implementation is compatible with the specified storage type

        Args:
            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE, etc.)
            storage_name: Storage implementation name

        Raises:
            ValueError: If the storage implementation is incompatible or missing required methods
        """
        if storage_type not in STORAGE_IMPLEMENTATIONS:
            raise ValueError(f"Unknown storage type: {storage_type}")

        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
        if storage_name not in storage_info["implementations"]:
            raise ValueError(
                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
            )

    def check_storage_env_vars(self, storage_name: str) -> None:
        """Check that all environment variables required by a storage implementation exist

        Args:
            storage_name: Storage implementation name

        Raises:
            ValueError: If required environment variables are missing
        """
        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
        missing_vars = [var for var in required_vars if var not in os.environ]

        if missing_vars:
            raise ValueError(
                f"Storage implementation '{storage_name}' requires the following "
                f"environment variables: {', '.join(missing_vars)}"
            )
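

# Illustrative end-to-end sketch (not part of the library). `my_llm_complete`
# and `my_embedding_func` are placeholders for real model bindings (e.g. from
# lightrag.llm) and must be supplied by the caller.
#
#     rag = LightRAG(
#         working_dir="./lightrag_cache",
#         llm_model_func=my_llm_complete,
#         embedding_func=my_embedding_func,
#     )
#     rag.insert("LightRAG combines knowledge-graph and vector retrieval.")
#     print(rag.query("What does LightRAG combine?", param=QueryParam(mode="hybrid")))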