Mirror of https://github.com/HKUDS/LightRAG.git, synced 2025-07-24 09:20:17 +00:00

Before, the `finally` block would always call `_insert_done()`, which writes out the `vdb_*` and `kv_store_*` files ... even when there was nothing to insert (because all docs had already been inserted). This made skippable inserts very slow as the graph grew.
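The fix in sketch form, condensed from `ainsert` below: the flush is guarded behind a flag that is set only once new work is actually queued.

    update_storage = False
    try:
        ...  # dedup against already-stored docs; return early if nothing is new
        update_storage = True
        ...  # chunk, embed, extract entities, upsert
    finally:
        if update_storage:
            await self._insert_done()  # flush only when something changed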
396 lines
13 KiB
Python
import asyncio
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast

from .llm import (
    gpt_4o_mini_complete,
    openai_embedding,
)
from .operate import (
    chunking_by_token_size,
    extract_entities,
    local_query,
    global_query,
    hybrid_query,
    naive_query,
)
from .utils import (
    EmbeddingFunc,
    compute_mdhash_id,
    limit_async_func_call,
    convert_response_to_json,
    logger,
    set_logger,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
)
from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
)
from .kg.neo4j_impl import Neo4JStorage
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage

# future KG integrations
# from .kg.ArangoDB_impl import (
#     GraphStorage as ArangoDBStorage
# )

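# Return the current thread's event loop, creating and installing a new one
# if none exists; used by the synchronous wrappers (insert, query,
# delete_by_entity) to drive their async counterparts.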
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        logger.info("Creating a new event loop in main thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


@dataclass
class LightRAG:
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )

    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
    graph_storage: str = field(default="NetworkXStorage")

    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # text chunking
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM
    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B' # 'google/gemma-2-2b-it'
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16
    llm_model_kwargs: dict = field(default_factory=dict)

    # storage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

    enable_llm_cache: bool = True

    # extension
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        # Make sure the working directory exists before the file logger and
        # the storages try to write into it.
        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

        # @TODO: should move all storage setup here to leverage initial start params attached to self.
        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
            self._get_storage_class()[self.kv_storage]
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
            self.graph_storage
        ]

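        # KV store for cached LLM responses; left as None when the cache is disabled.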
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache",
                global_config=asdict(self),
                embedding_func=None,
            )
            if self.enable_llm_cache
            else None
        )

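        # Cap the number of concurrent embedding calls at embedding_func_max_async.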
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        # Document/chunk KV stores and the entity-relation graph, all wired to
        # the rate-limited embedding function.
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        # Vector stores for entities, relationships, and raw chunks.
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache,
                **self.llm_model_kwargs,
            )
        )

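    # Map backend names (the kv_storage / vector_storage / graph_storage
    # fields) to their implementation classes.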
    def _get_storage_class(self) -> dict:
        return {
            # kv storage
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            # vector storage
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            # graph storage
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
            # "ArangoDBStorage": ArangoDBStorage
        }

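    # Synchronous wrapper: runs ainsert to completion on an event loop.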
    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        update_storage = False
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not len(new_docs):
                logger.warning("All docs are already in the storage")
                return
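            # At least one genuinely new doc: storages will be mutated, so the
            # finally block must flush them.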
            update_storage = True
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not len(inserting_chunks):
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            await self.chunks_vdb.upsert(inserting_chunks)

            logger.info("[Entity Extraction]...")
            maybe_new_kg = await extract_entities(
                inserting_chunks,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
                global_config=asdict(self),
            )
            if maybe_new_kg is None:
                logger.warning("No new entities and relationships found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
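        # Flush only when something was actually inserted; an unconditional
        # flush here made skippable (already-inserted) calls very slow as the
        # graph grew.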
        finally:
            if update_storage:
                await self._insert_done()

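    # Ask every storage backend to persist its state; with the default JSON /
    # NanoVectorDB backends this is what writes the kv_store_* and vdb_* files.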
    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

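    # Route the query to the handler for param.mode:
    # "local", "global", "hybrid", or "naive".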
    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode == "local":
            response = await local_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "global":
            response = await global_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "hybrid":
            response = await hybrid_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        # Entity names are stored upper-cased and quoted.
        entity_name = f'"{entity_name.upper()}"'

        try:
            await self.entities_vdb.delete_entity(entity_name)
            await self.relationships_vdb.delete_relation(entity_name)
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        tasks = []
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)