import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast

from .llm import (
    gpt_4o_mini_complete,
    openai_embedding,
)
from .operate import (
    chunking_by_token_size,
    extract_entities,
    # local_query, global_query, hybrid_query,
    kg_query,
    naive_query,
)

from .utils import (
    EmbeddingFunc,
    compute_mdhash_id,
    limit_async_func_call,
    convert_response_to_json,
    logger,
    set_logger,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
)

from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
)

# future KG integrations

# from .kg.ArangoDB_impl import (
#     GraphStorage as ArangoDBStorage
# )

def lazy_external_import(module_name: str, class_name: str):
    """Lazily import a class from an external module, deferring the import
    until the returned factory is first called."""
    # Resolve the caller's package so relative module names such as
    # ".kg.neo4j_impl" can be imported from inside this package
    import inspect

    caller_module = inspect.getmodule(inspect.currentframe().f_back)
    package = caller_module.__package__ if caller_module else None

    def import_class(*args, **kwargs):
        import importlib

        # Import the module using importlib, only when it is actually needed
        module = importlib.import_module(module_name, package=package)

        # Get the class from the module and instantiate it with the caller's arguments
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    # Return the import_class function itself, not its result
    return import_class


Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")

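# Note: each binding above is a deferred factory, so the backend module (and
# whatever external driver it needs) is imported only when the storage class is
# actually instantiated; the factory forwards its arguments to the concrete
# class, e.g.:
#     graph_storage = Neo4JStorage(namespace="chunk_entity_relation", global_config={...})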
					
						
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event
    loop is closed or does not exist, it creates a new event loop and sets it
    as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop
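
# Usage sketch: the synchronous wrappers below (insert, query, delete_by_entity)
# all rely on this helper to drive their async counterparts:
#     loop = always_get_an_event_loop()
#     result = loop.run_until_complete(some_coroutine())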
					
						
@dataclass
class LightRAG:
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    # Default not to use embedding cache
    embedding_cache_config: dict = field(
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
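    # Usage sketch: to turn the embedding cache on, pass a dict with the same
    # keys at construction time, e.g.:
    #     LightRAG(embedding_cache_config={
    #         "enabled": True, "similarity_threshold": 0.95, "use_llm_check": False
    #     })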
					
						
    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
    graph_storage: str = field(default="NetworkXStorage")

    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # text chunking
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )
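    # Usage sketch: like any dataclass field, the node2vec defaults above can
    # be overridden at construction, e.g.:
    #     LightRAG(node2vec_params={"dimensions": 768, "num_walks": 20,
    #                               "walk_length": 40, "window_size": 2,
    #                               "iterations": 3, "random_seed": 3})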
					
						
    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM
    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'  # 'google/gemma-2-2b-it'
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16
    llm_model_kwargs: dict = field(default_factory=dict)
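    # Usage sketch: the LLM backend is configured through these fields, e.g.
    # keeping the default model but lowering concurrency and passing extra
    # keyword arguments through to the completion function:
    #     LightRAG(llm_model_func=gpt_4o_mini_complete,
    #              llm_model_max_async=4,
    #              llm_model_kwargs={"temperature": 0.0})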
					
						
    # storage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

    enable_llm_cache: bool = True

    # extension
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        log_file = "lightrag.log"
        set_logger(log_file)
        logger.setLevel(self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

        # @TODO: should move all storage setup here to leverage initial start params attached to self.

        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
            self._get_storage_class()[self.kv_storage]
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
            self.graph_storage
        ]

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache",
                global_config=asdict(self),
                embedding_func=None,
            )
            if self.enable_llm_cache
            else None
        )
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
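        # Note: limit_async_func_call bounds how many calls to the wrapped
        # coroutine can be in flight at once (embedding_func_max_async here,
        # llm_model_max_async for the LLM below), throttling outbound requests.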
					
						
        ####
        # add embedding func by walter
        ####
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        ####
        # add embedding func by walter over
        ####

        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache,
                **self.llm_model_kwargs,
            )
        )

    def _get_storage_class(self) -> dict:
        return {
            # kv storage
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            "MongoKVStorage": MongoKVStorage,
            # vector storage
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            "MilvusVectorDBStorge": MilvusVectorDBStorge,
            # graph storage
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
            # "ArangoDBStorage": ArangoDBStorage
        }
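    # Usage sketch: the kv_storage / vector_storage / graph_storage string
    # fields select entries from this mapping, e.g. (assuming the optional
    # backend dependencies are installed):
    #     rag = LightRAG(kv_storage="MongoKVStorage", graph_storage="Neo4JStorage")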
					
						
    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))
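    # Usage sketch: `insert` accepts a single string or a list of strings, e.g.:
    #     rag.insert("some document text")
    #     rag.insert(["first document", "second document"])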
					
						
    async def ainsert(self, string_or_strings):
        update_storage = False
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not len(new_docs):
                logger.warning("All docs are already in the storage")
                return
            update_storage = True
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            inserting_chunks = {}
            for doc_key, doc in tqdm_async(
                new_docs.items(), desc="Chunking documents", unit="doc"
            ):
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not len(inserting_chunks):
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            await self.chunks_vdb.upsert(inserting_chunks)

            logger.info("[Entity Extraction]...")
            maybe_new_kg = await extract_entities(
                inserting_chunks,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
                global_config=asdict(self),
            )
            if maybe_new_kg is None:
                logger.warning("No new entities and relationships found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            if update_storage:
                await self._insert_done()

    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
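        # Note: gathering the index_done_callback() coroutines lets every
        # storage backend finalize its state (e.g. flush to the working
        # directory) concurrently after an insert.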
					
						
    def insert_custom_kg(self, custom_kg: dict):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
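    # Usage sketch: `custom_kg` carries optional "chunks", "entities" and
    # "relationships" lists, matching the keys read in `ainsert_custom_kg`
    # below; the field values here are illustrative:
    #     rag.insert_custom_kg({
    #         "chunks": [{"content": "...", "source_id": "doc-1"}],
    #         "entities": [{"entity_name": "Alice", "entity_type": "PERSON",
    #                       "description": "...", "source_id": "doc-1"}],
    #         "relationships": [{"src_id": "Alice", "tgt_id": "Bob",
    #                            "description": "...", "keywords": "...",
    #                            "weight": 1.0, "source_id": "doc-1"}],
    #     })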
					
						
    async def ainsert_custom_kg(self, custom_kg: dict):
        update_storage = False
        try:
            # Insert chunks into vector storage
            all_chunks_data = {}
            chunk_to_source_map = {}
            for chunk_data in custom_kg.get("chunks", []):
                chunk_content = chunk_data["content"]
                source_id = chunk_data["source_id"]
                chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")

                chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
                all_chunks_data[chunk_id] = chunk_entry
                chunk_to_source_map[source_id] = chunk_id
                update_storage = True

            if self.chunks_vdb is not None and all_chunks_data:
                await self.chunks_vdb.upsert(all_chunks_data)
            if self.text_chunks is not None and all_chunks_data:
                await self.text_chunks.upsert(all_chunks_data)

            # Insert entities into knowledge graph
            all_entities_data = []
            for entity_data in custom_kg.get("entities", []):
                entity_name = f'"{entity_data["entity_name"].upper()}"'
                entity_type = entity_data.get("entity_type", "UNKNOWN")
                description = entity_data.get("description", "No description provided")
                source_chunk_id = entity_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Prepare node data
                node_data = {
                    "entity_type": entity_type,
                    "description": description,
                    "source_id": source_id,
                }
                # Insert node data into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_node(
                    entity_name, node_data=node_data
                )
                node_data["entity_name"] = entity_name
                all_entities_data.append(node_data)
                update_storage = True

            # Insert relationships into knowledge graph
            all_relationships_data = []
            for relationship_data in custom_kg.get("relationships", []):
                src_id = f'"{relationship_data["src_id"].upper()}"'
                tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
                description = relationship_data["description"]
                keywords = relationship_data["keywords"]
                weight = relationship_data.get("weight", 1.0)
                source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                # Log if source_id is UNKNOWN
                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                # Check if nodes exist in the knowledge graph
                for need_insert_id in [src_id, tgt_id]:
                    if not (
                        await self.chunk_entity_relation_graph.has_node(need_insert_id)
                    ):
                        await self.chunk_entity_relation_graph.upsert_node(
                            need_insert_id,
                            node_data={
                                "source_id": source_id,
                                "description": "UNKNOWN",
                                "entity_type": "UNKNOWN",
                            },
                        )

                # Insert edge into the knowledge graph
                await self.chunk_entity_relation_graph.upsert_edge(
                    src_id,
                    tgt_id,
                    edge_data={
                        "weight": weight,
                        "description": description,
                        "keywords": keywords,
                        "source_id": source_id,
                    },
                )
                edge_data = {
                    "src_id": src_id,
                    "tgt_id": tgt_id,
                    "description": description,
                    "keywords": keywords,
                }
                all_relationships_data.append(edge_data)
                update_storage = True

            # Insert entities into vector storage if needed
            if self.entities_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                        "content": dp["entity_name"] + dp["description"],
                        "entity_name": dp["entity_name"],
                    }
                    for dp in all_entities_data
                }
                await self.entities_vdb.upsert(data_for_vdb)

            # Insert relationships into vector storage if needed
            if self.relationships_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                        "src_id": dp["src_id"],
                        "tgt_id": dp["tgt_id"],
                        "content": dp["keywords"]
                        + dp["src_id"]
                        + dp["tgt_id"]
                        + dp["description"],
                    }
                    for dp in all_relationships_data
                }
                await self.relationships_vdb.upsert(data_for_vdb)
        finally:
            if update_storage:
                await self._insert_done()

    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))
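    # Usage sketch: QueryParam.mode selects the retrieval strategy dispatched
    # in `aquery` below, one of "local", "global", "hybrid" or "naive", e.g.:
    #     answer = rag.query("What is this corpus about?", param=QueryParam(mode="hybrid"))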
					
						
    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache,
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache,
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))
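    # Usage sketch: entity names are upper-cased and quoted internally (see
    # `adelete_by_entity` below), so callers pass the plain name, e.g.:
    #     rag.delete_by_entity("Alice")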
					
						
    async def adelete_by_entity(self, entity_name: str):
        entity_name = f'"{entity_name.upper()}"'

        try:
            await self.entities_vdb.delete_entity(entity_name)
            await self.relationships_vdb.delete_relation(entity_name)
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        tasks = []
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)