diff --git a/lightrag/operate.py b/lightrag/operate.py
index b7319bcb..dbd4402c 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -27,6 +27,7 @@ from .utils import (
     update_chunk_cache_list,
     remove_think_tags,
     linear_gradient_weighted_polling,
+    process_chunks_unified,
 )
 from .base import (
     BaseGraphStorage,
@@ -3215,128 +3216,3 @@ async def query_with_keywords(
         )
     else:
         raise ValueError(f"Unknown mode {param.mode}")
-
-
-async def apply_rerank_if_enabled(
-    query: str,
-    retrieved_docs: list[dict],
-    global_config: dict,
-    enable_rerank: bool = True,
-    top_n: int = None,
-) -> list[dict]:
-    """
-    Apply reranking to retrieved documents if rerank is enabled.
-
-    Args:
-        query: The search query
-        retrieved_docs: List of retrieved documents
-        global_config: Global configuration containing rerank settings
-        enable_rerank: Whether to enable reranking from query parameter
-        top_n: Number of top documents to return after reranking
-
-    Returns:
-        Reranked documents if rerank is enabled, otherwise original documents
-    """
-    if not enable_rerank or not retrieved_docs:
-        return retrieved_docs
-
-    rerank_func = global_config.get("rerank_model_func")
-    if not rerank_func:
-        logger.warning(
-            "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters."
-        )
-        return retrieved_docs
-
-    try:
-        # Apply reranking - let rerank_model_func handle top_k internally
-        reranked_docs = await rerank_func(
-            query=query,
-            documents=retrieved_docs,
-            top_n=top_n,
-        )
-        if reranked_docs and len(reranked_docs) > 0:
-            if len(reranked_docs) > top_n:
-                reranked_docs = reranked_docs[:top_n]
-            logger.info(
-                f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
-            )
-            return reranked_docs
-        else:
-            logger.warning("Rerank returned empty results, using original documents")
-            return retrieved_docs
-
-    except Exception as e:
-        logger.error(f"Error during reranking: {e}, using original documents")
-        return retrieved_docs
-
-
-async def process_chunks_unified(
-    query: str,
-    unique_chunks: list[dict],
-    query_param: QueryParam,
-    global_config: dict,
-    source_type: str = "mixed",
-    chunk_token_limit: int = None,  # Add parameter for dynamic token limit
-) -> list[dict]:
-    """
-    Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
-
-    Args:
-        query: Search query for reranking
-        chunks: List of text chunks to process
-        query_param: Query parameters containing configuration
-        global_config: Global configuration dictionary
-        source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
-        chunk_token_limit: Dynamic token limit for chunks (if None, uses default)
-
-    Returns:
-        Processed and filtered list of text chunks
-    """
-    if not unique_chunks:
-        return []
-
-    # 1. Apply reranking if enabled and query is provided
-    if query_param.enable_rerank and query and unique_chunks:
-        rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
-        unique_chunks = await apply_rerank_if_enabled(
-            query=query,
-            retrieved_docs=unique_chunks,
-            global_config=global_config,
-            enable_rerank=query_param.enable_rerank,
-            top_n=rerank_top_k,
-        )
-        logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
-
-    # 2. Apply chunk_top_k limiting if specified
-    if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
-        if len(unique_chunks) > query_param.chunk_top_k:
-            unique_chunks = unique_chunks[: query_param.chunk_top_k]
-        logger.debug(
-            f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
-        )
-
-    # 3. Token-based final truncation
-    tokenizer = global_config.get("tokenizer")
-    if tokenizer and unique_chunks:
-        # Set default chunk_token_limit if not provided
-        if chunk_token_limit is None:
-            # Get default from query_param or global_config
-            chunk_token_limit = getattr(
-                query_param,
-                "max_total_tokens",
-                global_config.get("MAX_TOTAL_TOKENS", 32000),
-            )
-
-        original_count = len(unique_chunks)
-        unique_chunks = truncate_list_by_token_size(
-            unique_chunks,
-            key=lambda x: x.get("content", ""),
-            max_token_size=chunk_token_limit,
-            tokenizer=tokenizer,
-        )
-        logger.debug(
-            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
-            f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
-        )
-
-    return unique_chunks
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 875d2b0f..eb462946 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -55,7 +55,7 @@ def get_env_value(
 
 # Use TYPE_CHECKING to avoid circular imports
 if TYPE_CHECKING:
-    from lightrag.base import BaseKVStorage
+    from lightrag.base import BaseKVStorage, QueryParam
 
 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
@@ -1777,3 +1777,123 @@ class TokenTracker:
             f"Completion tokens: {usage['completion_tokens']}, "
             f"Total tokens: {usage['total_tokens']}"
         )
+
+
+async def apply_rerank_if_enabled(
+    query: str,
+    retrieved_docs: list[dict],
+    global_config: dict,
+    enable_rerank: bool = True,
+    top_n: int = None,
+) -> list[dict]:
+    """
+    Apply reranking to retrieved documents if rerank is enabled.
+
+    Args:
+        query: The search query
+        retrieved_docs: List of retrieved documents
+        global_config: Global configuration containing rerank settings
+        enable_rerank: Whether to enable reranking from query parameter
+        top_n: Number of top documents to return after reranking
+
+    Returns:
+        Reranked documents if rerank is enabled, otherwise original documents
+    """
+    if not enable_rerank or not retrieved_docs:
+        return retrieved_docs
+
+    rerank_func = global_config.get("rerank_model_func")
+    if not rerank_func:
+        logger.warning(
+            "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters."
+        )
+        return retrieved_docs
+
+    try:
+        # Apply reranking - let rerank_model_func handle top_k internally
+        reranked_docs = await rerank_func(
+            query=query,
+            documents=retrieved_docs,
+            top_n=top_n,
+        )
+        if reranked_docs and len(reranked_docs) > 0:
+            if len(reranked_docs) > top_n:
+                reranked_docs = reranked_docs[:top_n]
+            logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
+            return reranked_docs
+        else:
+            logger.warning("Rerank returned empty results, using original chunks")
+            return retrieved_docs
+
+    except Exception as e:
+        logger.error(f"Error during reranking: {e}, using original chunks")
+        return retrieved_docs
+
+
+async def process_chunks_unified(
+    query: str,
+    unique_chunks: list[dict],
+    query_param: "QueryParam",
+    global_config: dict,
+    source_type: str = "mixed",
+    chunk_token_limit: int = None,  # Add parameter for dynamic token limit
+) -> list[dict]:
+    """
+    Unified processing for text chunks: reranking, chunk_top_k limiting, and token truncation.
+
+    Args:
+        query: Search query for reranking
+        unique_chunks: List of deduplicated text chunks to process
+        query_param: Query parameters containing configuration
+        global_config: Global configuration dictionary
+        source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
+        chunk_token_limit: Dynamic token limit for chunks (if None, uses default)
+
+    Returns:
+        Processed and filtered list of text chunks
+    """
+    if not unique_chunks:
+        return []
+
+    # 1. Apply reranking if enabled and query is provided
+    if query_param.enable_rerank and query and unique_chunks:
+        rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
+        unique_chunks = await apply_rerank_if_enabled(
+            query=query,
+            retrieved_docs=unique_chunks,
+            global_config=global_config,
+            enable_rerank=query_param.enable_rerank,
+            top_n=rerank_top_k,
+        )
+
+    # 2. Apply chunk_top_k limiting if specified
+    if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
+        if len(unique_chunks) > query_param.chunk_top_k:
+            unique_chunks = unique_chunks[: query_param.chunk_top_k]
+        logger.info(f"Kept chunk_top-k: {len(unique_chunks)} chunks")
+
+    # 3. Token-based final truncation
+    tokenizer = global_config.get("tokenizer")
+    if tokenizer and unique_chunks:
+        # Set default chunk_token_limit if not provided
+        if chunk_token_limit is None:
+            # Get default from query_param or global_config
+            chunk_token_limit = getattr(
+                query_param,
+                "max_total_tokens",
+                global_config.get("MAX_TOTAL_TOKENS", 32000),
+            )
+
+        original_count = len(unique_chunks)
+        unique_chunks = truncate_list_by_token_size(
+            unique_chunks,
+            key=lambda x: x.get("content", ""),
+            max_token_size=chunk_token_limit,
+            tokenizer=tokenizer,
+        )
+        logger.debug(
+            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
+            f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
+        )
+
+    return unique_chunks
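For reviewers, a minimal self-contained usage sketch of the relocated helper (not part of the patch). It imports process_chunks_unified from its new lightrag.utils home as introduced above; the stub_rerank scorer and the SimpleNamespace standing in for QueryParam are illustrative assumptions only, not LightRAG APIs.

# usage_sketch.py -- hedged example, assumes the post-patch import path
import asyncio
from types import SimpleNamespace

from lightrag.utils import process_chunks_unified


async def stub_rerank(query: str, documents: list[dict], top_n: int | None = None) -> list[dict]:
    # Toy reranker: order documents by naive keyword overlap with the query.
    terms = set(query.lower().split())
    scored = sorted(
        documents,
        key=lambda d: len(terms & set(d.get("content", "").lower().split())),
        reverse=True,
    )
    return scored[:top_n] if top_n else scored


async def main() -> None:
    chunks = [
        {"content": "Paris is the capital of France."},
        {"content": "The Eiffel Tower is in Paris."},
    ]
    # SimpleNamespace is a stand-in for lightrag.base.QueryParam here.
    query_param = SimpleNamespace(
        enable_rerank=True,
        chunk_top_k=1,          # step 2 keeps only the best chunk
        max_total_tokens=32000,
    )
    # No "tokenizer" key, so step 3 (token truncation) is skipped in this run.
    global_config = {"rerank_model_func": stub_rerank}

    kept = await process_chunks_unified(
        query="Where is the Eiffel Tower?",
        unique_chunks=chunks,
        query_param=query_param,
        global_config=global_config,
        source_type="vector",
    )
    print(kept)  # -> [{'content': 'The Eiffel Tower is in Paris.'}]


asyncio.run(main())

Note the ordering the patch establishes: the reranker sees the full candidate list and is asked for chunk_top_k results, so the slice in step 2 is usually a no-op unless the reranker returns more than requested, and token truncation in step 3 only engages when a tokenizer is present in global_config.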