update chunks truncation method

This commit is contained in:
zrguo 2025-07-08 13:31:05 +08:00
parent f5c80d7cde
commit 04a57445da
5 changed files with 211 additions and 180 deletions

View File

@ -294,6 +294,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk."""

View File

@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d
python examples/lightrag_openai_demo.py
```
For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample codes LLM and embedding configurations accordingly.
For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly.
**Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
@ -300,6 +300,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk."""

View File

@ -46,7 +46,9 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# HISTORY_TURNS=3
# COSINE_THRESHOLD=0.2
# TOP_K=60
# MAX_TOKEN_TEXT_CHUNK=4000
# CHUNK_TOP_K=5
# CHUNK_RERANK_TOP_K=5
# MAX_TOKEN_TEXT_CHUNK=6000
# MAX_TOKEN_RELATION_DESC=4000
# MAX_TOKEN_ENTITY_DESC=4000

View File

@ -60,7 +60,17 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000"))
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = int(
@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
False: if the cache drop failed, or the cache mode is not supported
"""
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
# """Delete specific cache records from storage by chunk IDs
# Importance notes for in-memory storage:
# 1. Changes will be persisted to disk during the next index_done_callback
# 2. update flags to notify other processes that data persistence is needed
# Args:
# chunk_ids (list[str]): List of chunk IDs to be dropped from storage
# Returns:
# True: if the cache drop successfully
# False: if the cache drop failed, or the operation is not supported
# """
@dataclass
class BaseGraphStorage(StorageNameSpace, ABC):

View File

@ -1526,6 +1526,7 @@ async def kg_query(
# Build context
context = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
@ -1744,93 +1745,52 @@ async def _get_vector_context(
query: str,
chunks_vdb: BaseVectorStorage,
query_param: QueryParam,
tokenizer: Tokenizer,
) -> tuple[list, list, list] | None:
) -> list[dict]:
"""
Retrieve vector context from the vector database.
Retrieve text chunks from the vector database without reranking or truncation.
This function performs vector search to find relevant text chunks for a query,
formats them with file path and creation time information.
This function performs vector search to find relevant text chunks for a query.
Reranking and truncation will be handled later in the unified processing.
Args:
query: The query string to search for
chunks_vdb: Vector database containing document chunks
query_param: Query parameters including top_k and ids
tokenizer: Tokenizer for counting tokens
query_param: Query parameters including chunk_top_k and ids
Returns:
Tuple (empty_entities, empty_relations, text_units) for combine_contexts,
compatible with _get_edge_data and _get_node_data format
List of text chunks with metadata
"""
try:
results = await chunks_vdb.query(
query, top_k=query_param.top_k, ids=query_param.ids
)
# Use chunk_top_k if specified, otherwise fall back to top_k
search_top_k = query_param.chunk_top_k or query_param.top_k
results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
if not results:
return [], [], []
return []
valid_chunks = []
for result in results:
if "content" in result:
# Directly use content from chunks_vdb.query result
chunk_with_time = {
chunk_with_metadata = {
"content": result["content"],
"created_at": result.get("created_at", None),
"file_path": result.get("file_path", "unknown_source"),
"source_type": "vector", # Mark the source type
}
valid_chunks.append(chunk_with_time)
if not valid_chunks:
return [], [], []
# Apply reranking if enabled
global_config = chunks_vdb.global_config
valid_chunks = await apply_rerank_if_enabled(
query=query,
retrieved_docs=valid_chunks,
global_config=global_config,
top_k=query_param.top_k,
)
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
valid_chunks.append(chunk_with_metadata)
logger.debug(
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
)
logger.info(
f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
f"Vector search retrieved {len(valid_chunks)} chunks (top_k: {search_top_k})"
)
return valid_chunks
if not maybe_trun_chunks:
return [], [], []
# Create empty entities and relations contexts
entities_context = []
relations_context = []
# Create text_units_context directly as a list of dictionaries
text_units_context = []
for i, chunk in enumerate(maybe_trun_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk["file_path"],
}
)
return entities_context, relations_context, text_units_context
except Exception as e:
logger.error(f"Error in _get_vector_context: {e}")
return [], [], []
return []
async def _build_query_context(
query: str,
ll_keywords: str,
hl_keywords: str,
knowledge_graph_inst: BaseGraphStorage,
@ -1838,27 +1798,36 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode
chunks_vdb: BaseVectorStorage = None,
):
logger.info(f"Process {os.getpid()} building query context...")
# Handle local and global modes as before
# Collect all chunks from different sources
all_chunks = []
entities_context = []
relations_context = []
# Handle local and global modes
if query_param.mode == "local":
entities_context, relations_context, text_units_context = await _get_node_data(
entities_context, relations_context, entity_chunks = await _get_node_data(
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param,
)
all_chunks.extend(entity_chunks)
elif query_param.mode == "global":
entities_context, relations_context, text_units_context = await _get_edge_data(
entities_context, relations_context, relationship_chunks = await _get_edge_data(
hl_keywords,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param,
)
all_chunks.extend(relationship_chunks)
else: # hybrid or mix mode
ll_data = await _get_node_data(
ll_keywords,
@ -1875,61 +1844,58 @@ async def _build_query_context(
query_param,
)
(
ll_entities_context,
ll_relations_context,
ll_text_units_context,
) = ll_data
(ll_entities_context, ll_relations_context, ll_chunks) = ll_data
(hl_entities_context, hl_relations_context, hl_chunks) = hl_data
(
hl_entities_context,
hl_relations_context,
hl_text_units_context,
) = hl_data
# Collect chunks from entity and relationship sources
all_chunks.extend(ll_chunks)
all_chunks.extend(hl_chunks)
# Initialize vector data with empty lists
vector_entities_context, vector_relations_context, vector_text_units_context = (
[],
[],
[],
)
# Only get vector data if in mix mode
if query_param.mode == "mix" and hasattr(query_param, "original_query"):
# Get tokenizer from text_chunks_db
tokenizer = text_chunks_db.global_config.get("tokenizer")
# Get vector context in triple format
vector_data = await _get_vector_context(
query_param.original_query, # We need to pass the original query
# Get vector chunks if in mix mode
if query_param.mode == "mix" and chunks_vdb:
vector_chunks = await _get_vector_context(
query,
chunks_vdb,
query_param,
tokenizer,
)
all_chunks.extend(vector_chunks)
# If vector_data is not None, unpack it
if vector_data is not None:
(
vector_entities_context,
vector_relations_context,
vector_text_units_context,
) = vector_data
# Combine and deduplicate the entities, relationships, and sources
# Combine entities and relations contexts
entities_context = process_combine_contexts(
hl_entities_context, ll_entities_context, vector_entities_context
hl_entities_context, ll_entities_context
)
relations_context = process_combine_contexts(
hl_relations_context, ll_relations_context, vector_relations_context
hl_relations_context, ll_relations_context
)
text_units_context = process_combine_contexts(
hl_text_units_context, ll_text_units_context, vector_text_units_context
# Process all chunks uniformly: deduplication, reranking, and token truncation
processed_chunks = await process_chunks_unified(
query=query,
chunks=all_chunks,
query_param=query_param,
global_config=text_chunks_db.global_config,
source_type="mixed",
)
# Build final text_units_context from processed chunks
text_units_context = []
for i, chunk in enumerate(processed_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
}
)
logger.info(
f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
)
# not necessary to use LLM to generate a response
if not entities_context and not relations_context:
return None
# 转换为 JSON 字符串
entities_str = json.dumps(entities_context, ensure_ascii=False)
relations_str = json.dumps(relations_context, ensure_ascii=False)
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
@ -1975,15 +1941,6 @@ async def _get_node_data(
if not len(results):
return "", "", ""
# Apply reranking if enabled for entity results
global_config = entities_vdb.global_config
results = await apply_rerank_if_enabled(
query=query,
retrieved_docs=results,
global_config=global_config,
top_k=query_param.top_k,
)
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
@ -2085,16 +2042,7 @@ async def _get_node_data(
}
)
text_units_context = []
for i, t in enumerate(use_text_units):
text_units_context.append(
{
"id": i + 1,
"content": t["content"],
"file_path": t.get("file_path", "unknown_source"),
}
)
return entities_context, relations_context, text_units_context
return entities_context, relations_context, use_text_units
async def _find_most_related_text_unit_from_entities(
@ -2183,23 +2131,21 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found")
return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
# Sort by relation counts and order, but don't truncate
all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
)
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
)
logger.debug(f"Found {len(all_text_units)} entity-related chunks")
all_text_units = [t["data"] for t in all_text_units]
return all_text_units
# Add source type marking and return chunk data
result_chunks = []
for t in all_text_units:
chunk_data = t["data"].copy()
chunk_data["source_type"] = "entity"
result_chunks.append(chunk_data)
return result_chunks
async def _find_most_related_edges_from_entities(
@ -2287,15 +2233,6 @@ async def _get_edge_data(
if not len(results):
return "", "", ""
# Apply reranking if enabled for relationship results
global_config = relationships_vdb.global_config
results = await apply_rerank_if_enabled(
query=keywords,
retrieved_docs=results,
global_config=global_config,
top_k=query_param.top_k,
)
# Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts.
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
@ -2510,21 +2447,16 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering")
return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(f"Found {len(valid_text_units)} relationship-related chunks")
logger.debug(
f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
)
# Add source type marking and return chunk data
result_chunks = []
for t in valid_text_units:
chunk_data = t["data"].copy()
chunk_data["source_type"] = "relationship"
result_chunks.append(chunk_data)
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
return all_text_units
return result_chunks
async def naive_query(
@ -2552,12 +2484,30 @@ async def naive_query(
tokenizer: Tokenizer = global_config["tokenizer"]
_, _, text_units_context = await _get_vector_context(
query, chunks_vdb, query_param, tokenizer
chunks = await _get_vector_context(query, chunks_vdb, query_param)
if chunks is None or len(chunks) == 0:
return PROMPTS["fail_response"]
# Process chunks using unified processing
processed_chunks = await process_chunks_unified(
query=query,
chunks=chunks,
query_param=query_param,
global_config=global_config,
source_type="vector",
)
if text_units_context is None or len(text_units_context) == 0:
return PROMPTS["fail_response"]
# Build text_units_context from processed chunks
text_units_context = []
for i, chunk in enumerate(processed_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
}
)
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
if query_param.only_need_context:
@ -2683,6 +2633,7 @@ async def kg_query_with_keywords(
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
context = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
@ -2805,8 +2756,6 @@ async def query_with_keywords(
f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
)
param.original_query = query
# Use appropriate query method based on mode
if param.mode in ["local", "global", "hybrid", "mix"]:
return await kg_query_with_keywords(
@ -2887,3 +2836,68 @@ async def apply_rerank_if_enabled(
except Exception as e:
logger.error(f"Error during reranking: {e}, using original documents")
return retrieved_docs
async def process_chunks_unified(
query: str,
chunks: list[dict],
query_param: QueryParam,
global_config: dict,
source_type: str = "mixed",
) -> list[dict]:
"""
Unified processing for text chunks: deduplication, reranking, and token truncation.
Args:
query: Search query for reranking
chunks: List of text chunks to process
query_param: Query parameters containing configuration
global_config: Global configuration dictionary
source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
Returns:
Processed and filtered list of text chunks
"""
if not chunks:
return []
# 1. Deduplication based on content
seen_content = set()
unique_chunks = []
for chunk in chunks:
content = chunk.get("content", "")
if content and content not in seen_content:
seen_content.add(content)
unique_chunks.append(chunk)
logger.debug(
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
)
# 2. Apply reranking if enabled and query is provided
if global_config.get("enable_rerank", False) and query and unique_chunks:
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled(
query=query,
retrieved_docs=unique_chunks,
global_config=global_config,
top_k=rerank_top_k,
)
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
# 3. Token-based final truncation
tokenizer = global_config.get("tokenizer")
if tokenizer and unique_chunks:
original_count = len(unique_chunks)
unique_chunks = truncate_list_by_token_size(
unique_chunks,
key=lambda x: x.get("content", ""),
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})"
)
return unique_chunks