# LightRAG/lightrag/lightrag.py

import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast
from .llm import (
gpt_4o_mini_complete,
openai_embedding,
)
from .operate import (
chunking_by_token_size,
extract_entities,
    # local_query, global_query and hybrid_query have been unified into kg_query
    kg_query,
naive_query,
)
from .utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
convert_response_to_json,
logger,
set_logger,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
StorageNameSpace,
QueryParam,
)
from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
)
from .kg.neo4j_impl import Neo4JStorage
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage

# future KG integrations
# from .kg.ArangoDB_impl import (
#     GraphStorage as ArangoDBStorage
# )


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop
    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop


@dataclass
class LightRAG:
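    """Retrieval-augmented generation over a knowledge graph.

    All fields are configuration; storage backends are selected by name
    (kv_storage / vector_storage / graph_storage) and instantiated in
    __post_init__ via _get_storage_class().
    """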
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
    graph_storage: str = field(default="NetworkXStorage")

    # inherit the level currently set on the shared logger (an int, e.g. logging.INFO)
    log_level: int = field(default=logger.level)
# text chunking
chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100
tiktoken_model_name: str = "gpt-4o-mini"
# entity extraction
entity_extract_max_gleaning: int = 1
entity_summary_to_max_tokens: int = 500
# node embedding
node_embedding_algorithm: str = "node2vec"
node2vec_params: dict = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
}
)
    # embedding
    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)  # HuggingFace alternative
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# LLM
    llm_model_func: callable = gpt_4o_mini_complete  # alternative: hf_model_complete
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # alternatives: 'meta-llama/Llama-3.2-1B', 'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
# storage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
enable_llm_cache: bool = True
# extension
addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        log_file = "lightrag.log"  # written to the current working directory
        set_logger(log_file)
logger.setLevel(self.log_level)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # @TODO: should move all storage setup here to leverage initial start params attached to self.
        storage_classes = self._get_storage_class()
        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = storage_classes[
            self.kv_storage
        ]
        self.vector_db_storage_cls: Type[BaseVectorStorage] = storage_classes[
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = storage_classes[
            self.graph_storage
        ]
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache",
                global_config=asdict(self),
                embedding_func=None,
            )
            if self.enable_llm_cache
            else None
        )

        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        # KV and graph storages also receive the rate-limited embedding function
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache,
                **self.llm_model_kwargs,
            )
        )

    def _get_storage_class(self) -> dict:
        # Map a storage backend name (kv_storage / vector_storage / graph_storage)
        # to its implementation class.
        return {
            # kv storage
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            # vector storage
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            # graph storage
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
            # "ArangoDBStorage": ArangoDBStorage
        }

    def insert(self, string_or_strings):
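        """Synchronous wrapper around ainsert()."""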
loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
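        """Insert one document or a list of documents: de-duplicate by content
        hash, chunk, extract entities and relations, and index the results."""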
update_storage = False
try:
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
new_docs = {
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
for c in string_or_strings
}
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not new_docs:
logger.warning("All docs are already in the storage")
return
update_storage = True
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
inserting_chunks = {}
            for doc_key, doc in tqdm_async(
                new_docs.items(), desc="Chunking documents", unit="doc"
            ):
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_key,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
inserting_chunks.update(chunks)
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
            if not inserting_chunks:
logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
await self.chunks_vdb.upsert(inserting_chunks)
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities and relationships found")
return
self.chunk_entity_relation_graph = maybe_new_kg
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
if update_storage:
                await self._insert_done()

    async def _insert_done(self):
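        """Notify all storages touched by an insert so they can persist their state."""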
tasks = []
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def insert_custom_kg(self, custom_kg: dict):
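        """Synchronous wrapper around ainsert_custom_kg()."""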
loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

    async def ainsert_custom_kg(self, custom_kg: dict):
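        """Insert a pre-built knowledge graph.

        custom_kg is a dict with optional "entities" and "relationships" lists;
        the required keys of each item follow from the accesses below.
        """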
update_storage = False
try:
# Insert entities into knowledge graph
all_entities_data = []
for entity_data in custom_kg.get("entities", []):
entity_name = f'"{entity_data["entity_name"].upper()}"'
entity_type = entity_data.get("entity_type", "UNKNOWN")
description = entity_data.get("description", "No description provided")
source_id = entity_data["source_id"]
# Prepare node data
node_data = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
}
# Insert node data into the knowledge graph
await self.chunk_entity_relation_graph.upsert_node(
entity_name, node_data=node_data
)
node_data["entity_name"] = entity_name
all_entities_data.append(node_data)
update_storage = True
# Insert relationships into knowledge graph
all_relationships_data = []
for relationship_data in custom_kg.get("relationships", []):
src_id = f'"{relationship_data["src_id"].upper()}"'
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
description = relationship_data["description"]
keywords = relationship_data["keywords"]
weight = relationship_data.get("weight", 1.0)
source_id = relationship_data["source_id"]
# Check if nodes exist in the knowledge graph
for need_insert_id in [src_id, tgt_id]:
if not (
await self.chunk_entity_relation_graph.has_node(need_insert_id)
):
await self.chunk_entity_relation_graph.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
"description": "UNKNOWN",
"entity_type": "UNKNOWN",
},
)
# Insert edge into the knowledge graph
await self.chunk_entity_relation_graph.upsert_edge(
src_id,
tgt_id,
edge_data={
"weight": weight,
"description": description,
"keywords": keywords,
"source_id": source_id,
},
)
edge_data = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
"keywords": keywords,
}
all_relationships_data.append(edge_data)
update_storage = True
# Insert entities into vector storage if needed
if self.entities_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
if self.relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
await self.relationships_vdb.upsert(data_for_vdb)
finally:
if update_storage:
                await self._insert_done()

    def query(self, query: str, param: QueryParam = QueryParam()):
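        """Synchronous wrapper around aquery()."""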
loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
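        """Answer a query: "local", "global" and "hybrid" modes retrieve via the
        knowledge graph; "naive" retrieves raw chunks from the vector store."""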
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "naive":
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
        return response

    async def _query_done(self):
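        """Notify the LLM response cache that a query pass finished so it can persist."""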
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
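        """Synchronous wrapper around adelete_by_entity()."""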
loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
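        """Delete an entity, its vector-store entries, and its graph node
        (entity names are matched upper-cased and quoted)."""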
entity_name = f'"{entity_name.upper()}"'
try:
await self.entities_vdb.delete_entity(entity_name)
await self.relationships_vdb.delete_relation(entity_name)
await self.chunk_entity_relation_graph.delete_node(entity_name)
            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
await self._delete_by_entity_done()
except Exception as e:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
2024-11-11 17:54:22 +08:00
2024-11-11 17:48:40 +08:00
async def _delete_by_entity_done(self):
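        """Notify the storages touched by an entity deletion so they can persist."""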
tasks = []
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
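

# Minimal usage sketch (an illustration, not part of the module's API):
# assumes OPENAI_API_KEY is set for the default gpt_4o_mini_complete /
# openai_embedding functions, and "./my_cache" is a hypothetical directory.
#
#     rag = LightRAG(working_dir="./my_cache")
#     rag.insert("Charles Dickens published A Christmas Carol in 1843.")
#     print(rag.query("Who wrote A Christmas Carol?", param=QueryParam(mode="hybrid")))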