From b45ae1567c29ebeedb3014cbfae4afc6025f43f7 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 2 Feb 2025 01:28:46 +0800
Subject: [PATCH] Refactor LLM cache handling and entity extraction

- Removed custom LLM function in entity extraction
- Simplified cache handling logic
- Added `force_llm_cache` parameter
- Updated cache handling conditions
---
 lightrag/operate.py | 28 +---------------------------
 lightrag/utils.py   |  8 +++-----
 2 files changed, 4 insertions(+), 32 deletions(-)

diff --git a/lightrag/operate.py b/lightrag/operate.py
index bc011cb9..0e1eb3f3 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -352,32 +352,6 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            custom_llm = None
-            if (
-                global_config["embedding_cache_config"]
-                and global_config["embedding_cache_config"]["enabled"]
-            ):
-                new_config = global_config.copy()
-                new_config["embedding_cache_config"] = None
-                new_config["enable_llm_cache"] = True
-
-                # create a llm function with new_config for handle_cache
-                async def custom_llm(
-                    prompt,
-                    system_prompt=None,
-                    history_messages=[],
-                    keyword_extraction=False,
-                    **kwargs,
-                ) -> str:
-                    # Merge new_config with the other kwargs so the other parameters are not overridden
-                    merged_config = {**kwargs, **new_config}
-                    return await use_llm_func(
-                        prompt,
-                        system_prompt=system_prompt,
-                        history_messages=history_messages,
-                        keyword_extraction=keyword_extraction,
-                        **merged_config,
-                    )
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
@@ -392,7 +366,7 @@ async def extract_entities(
                 _prompt,
                 "default",
                 cache_type="extract",
-                llm=custom_llm,
+                force_llm_cache=True,
             )
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
diff --git a/lightrag/utils.py b/lightrag/utils.py
index edf96dcc..1bd06e6d 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -484,10 +484,10 @@ def dequantize_embedding(
 
 async def handle_cache(
-    hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None
+    hashing_kv, args_hash, prompt, mode="default", cache_type=None, force_llm_cache=False
 ):
     """Generic cache handling function"""
-    if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
+    if hashing_kv is None or not (force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")):
         return None, None, None, None
 
     if mode != "default":
@@ -513,9 +513,7 @@ async def handle_cache(
             similarity_threshold=embedding_cache_config["similarity_threshold"],
             mode=mode,
             use_llm_check=use_llm_check,
-            llm_func=llm
-            if (use_llm_check and llm is not None)
-            else (llm_model_func if use_llm_check else None),
+            llm_func=llm_model_func if use_llm_check else None,
             original_prompt=prompt if use_llm_check else None,
             cache_type=cache_type,
         )
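
Illustration (not part of the patch): a minimal, self-contained sketch of the new gating rule. The handle_cache body below is reduced to the cache-gating step only, and MockStorage plus the main() driver are hypothetical stand-ins for LightRAG's real hashing_kv storage object and callers; they exist purely to show how force_llm_cache=True lets the entity-extraction path use the LLM response cache even when the global enable_llm_cache flag is off.

import asyncio
from dataclasses import dataclass, field


@dataclass
class MockStorage:
    # Stand-in for the real hashing_kv storage; only global_config is needed here.
    global_config: dict = field(default_factory=dict)


async def handle_cache(
    hashing_kv, args_hash, prompt, mode="default", cache_type=None, force_llm_cache=False
):
    # Same condition as the patched utils.py: skip caching unless it is
    # globally enabled or explicitly forced by the caller.
    if hashing_kv is None or not (
        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
    ):
        return None, None, None, None
    # The real implementation would look up args_hash in hashing_kv here;
    # this sketch only demonstrates the gating behaviour.
    return f"cached answer for {args_hash}", None, None, None


async def main():
    kv = MockStorage(global_config={"enable_llm_cache": False})

    # Global cache disabled and not forced -> (None, None, None, None).
    print(await handle_cache(kv, "hash-1", "some prompt"))

    # Entity extraction forces the cache on, as extract_entities now does.
    print(await handle_cache(kv, "hash-1", "some prompt", cache_type="extract", force_llm_cache=True))


asyncio.run(main())

The upshot of the change: callers no longer need to build a custom_llm wrapper with a modified global_config just to force caching; a single keyword argument carries that intent into handle_cache.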