from __future__ import annotations from functools import partial import asyncio import json import re import os import json_repair from typing import Any, AsyncIterator from collections import Counter, defaultdict from .utils import ( logger, clean_str, compute_mdhash_id, Tokenizer, is_float_regex, normalize_extracted_info, pack_user_ass_to_openai_messages, split_string_by_multi_markers, truncate_list_by_token_size, compute_args_hash, handle_cache, save_to_cache, CacheData, get_conversation_turns, use_llm_func_with_cache, update_chunk_cache_list, remove_think_tags, pick_by_weighted_polling, pick_by_vector_similarity, process_chunks_unified, build_file_path, ) from .base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, TextChunkSchema, QueryParam, ) from .prompt import PROMPTS from .constants import ( GRAPH_FIELD_SEP, DEFAULT_MAX_ENTITY_TOKENS, DEFAULT_MAX_RELATION_TOKENS, DEFAULT_MAX_TOTAL_TOKENS, DEFAULT_RELATED_CHUNK_NUMBER, DEFAULT_KG_CHUNK_PICK_METHOD, DEFAULT_ENTITY_TYPES, DEFAULT_SUMMARY_LANGUAGE, ) from .kg.shared_storage import get_storage_keyed_lock import time from dotenv import load_dotenv # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) def chunking_by_token_size( tokenizer: Tokenizer, content: str, split_by_character: str | None = None, split_by_character_only: bool = False, overlap_token_size: int = 128, max_token_size: int = 1024, ) -> list[dict[str, Any]]: tokens = tokenizer.encode(content) results: list[dict[str, Any]] = [] if split_by_character: raw_chunks = content.split(split_by_character) new_chunks = [] if split_by_character_only: for chunk in raw_chunks: _tokens = tokenizer.encode(chunk) new_chunks.append((len(_tokens), chunk)) else: for chunk in raw_chunks: _tokens = tokenizer.encode(chunk) if len(_tokens) > max_token_size: for start in range( 0, 
len(_tokens), max_token_size - overlap_token_size ): chunk_content = tokenizer.decode( _tokens[start : start + max_token_size] ) new_chunks.append( (min(max_token_size, len(_tokens) - start), chunk_content) ) else: new_chunks.append((len(_tokens), chunk)) for index, (_len, chunk) in enumerate(new_chunks): results.append( { "tokens": _len, "content": chunk.strip(), "chunk_order_index": index, } ) else: for index, start in enumerate( range(0, len(tokens), max_token_size - overlap_token_size) ): chunk_content = tokenizer.decode(tokens[start : start + max_token_size]) results.append( { "tokens": min(max_token_size, len(tokens) - start), "content": chunk_content.strip(), "chunk_order_index": index, } ) return results async def _handle_entity_relation_summary( description_type: str, entity_or_relation_name: str, description_list: list[str], seperator: str, global_config: dict, llm_response_cache: BaseKVStorage | None = None, ) -> tuple[str, bool]: """Handle entity relation description summary using map-reduce approach. This function summarizes a list of descriptions using a map-reduce strategy: 1. If total tokens < summary_context_size and len(description_list) < force_llm_summary_on_merge, no need to summarize 2. If total tokens < summary_max_tokens, summarize with LLM directly 3. Otherwise, split descriptions into chunks that fit within token limits 4. Summarize each chunk, then recursively process the summaries 5. 
Continue until we get a final summary within token limits or num of descriptions is less than force_llm_summary_on_merge Args: entity_or_relation_name: Name of the entity or relation being summarized description_list: List of description strings to summarize global_config: Global configuration containing tokenizer and limits llm_response_cache: Optional cache for LLM responses Returns: Tuple of (final_summarized_description_string, llm_was_used_boolean) """ # Handle empty input if not description_list: return "", False # If only one description, return it directly (no need for LLM call) if len(description_list) == 1: return description_list[0], False # Get configuration tokenizer: Tokenizer = global_config["tokenizer"] summary_context_size = global_config["summary_context_size"] summary_max_tokens = global_config["summary_max_tokens"] force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] current_list = description_list[:] # Copy the list to avoid modifying original llm_was_used = False # Track whether LLM was used during the entire process # Iterative map-reduce process while True: # Calculate total tokens in current list total_tokens = sum(len(tokenizer.encode(desc)) for desc in current_list) # If total length is within limits, perform final summarization if total_tokens <= summary_context_size or len(current_list) <= 2: if ( len(current_list) < force_llm_summary_on_merge and total_tokens < summary_max_tokens ): # no LLM needed, just join the descriptions final_description = seperator.join(current_list) return final_description if final_description else "", llm_was_used else: if total_tokens > summary_context_size and len(current_list) <= 2: logger.warning( f"Summarizing {entity_or_relation_name}: Oversize descpriton found" ) # Final summarization of remaining descriptions - LLM will be used final_summary = await _summarize_descriptions( description_type, entity_or_relation_name, current_list, global_config, llm_response_cache, ) return 
final_summary, True # LLM was used for final summarization # Need to split into chunks - Map phase # Ensure each chunk has minimum 2 descriptions to guarantee progress chunks = [] current_chunk = [] current_tokens = 0 # Currently least 3 descriptions in current_list for i, desc in enumerate(current_list): desc_tokens = len(tokenizer.encode(desc)) # If adding current description would exceed limit, finalize current chunk if current_tokens + desc_tokens > summary_context_size and current_chunk: # Ensure we have at least 2 descriptions in the chunk (when possible) if len(current_chunk) == 1: # Force add one more description to ensure minimum 2 per chunk current_chunk.append(desc) chunks.append(current_chunk) logger.warning( f"Summarizing {entity_or_relation_name}: Oversize descpriton found" ) current_chunk = [] # next group is empty current_tokens = 0 else: # curren_chunk is ready for summary in reduce phase chunks.append(current_chunk) current_chunk = [desc] # leave it for next group current_tokens = desc_tokens else: current_chunk.append(desc) current_tokens += desc_tokens # Add the last chunk if it exists if current_chunk: chunks.append(current_chunk) logger.info( f" Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups" ) # Reduce phase: summarize each group from chunks new_summaries = [] for chunk in chunks: if len(chunk) == 1: # Optimization: single description chunks don't need LLM summarization new_summaries.append(chunk[0]) else: # Multiple descriptions need LLM summarization summary = await _summarize_descriptions( description_type, entity_or_relation_name, chunk, global_config, llm_response_cache, ) new_summaries.append(summary) llm_was_used = True # Mark that LLM was used in reduce phase # Update current list with new summaries for next iteration current_list = new_summaries async def _summarize_descriptions( description_type: str, description_name: str, description_list: list[str], global_config: dict, 
llm_response_cache: BaseKVStorage | None = None, ) -> str: """Helper function to summarize a list of descriptions using LLM. Args: entity_or_relation_name: Name of the entity or relation being summarized descriptions: List of description strings to summarize global_config: Global configuration containing LLM function and settings llm_response_cache: Optional cache for LLM responses Returns: Summarized description string """ use_llm_func: callable = global_config["llm_model_func"] # Apply higher priority (8) to entity/relation summary tasks use_llm_func = partial(use_llm_func, _priority=8) language = global_config["addon_params"].get( "language", DEFAULT_SUMMARY_LANGUAGE] ) summary_length_recommended = global_config["summary_length_recommended"] prompt_template = PROMPTS["summarize_entity_descriptions"] # Join descriptions and apply token-based truncation if necessary joined_descriptions = "\n\n".join(description_list) tokenizer = global_config["tokenizer"] summary_context_size = global_config["summary_context_size"] # Token-based truncation to ensure input fits within limits tokens = tokenizer.encode(joined_descriptions) if len(tokens) > summary_context_size: truncated_tokens = tokens[:summary_context_size] joined_descriptions = tokenizer.decode(truncated_tokens) # Prepare context for the prompt context_base = dict( description_type=description_type, description_name=description_name, description_list=joined_descriptions, summary_length=summary_length_recommended, language=language, ) use_prompt = prompt_template.format(**context_base) logger.debug( f"Summarizing {len(description_list)} descriptions for: {description_name}" ) # Use LLM function with cache (higher priority for summary generation) summary = await use_llm_func_with_cache( use_prompt, use_llm_func, llm_response_cache=llm_response_cache, cache_type="extract", ) return summary async def _handle_single_entity_extraction( record_attributes: list[str], chunk_key: str, file_path: str = "unknown_source", ): 
if len(record_attributes) < 4 or '"entity"' not in record_attributes[0]: return None # Clean and validate entity name entity_name = clean_str(record_attributes[1]).strip() if not entity_name: logger.warning( f"Entity extraction error: empty entity name in: {record_attributes}" ) return None # Normalize entity name entity_name = normalize_extracted_info(entity_name, is_entity=True) # Check if entity name became empty after normalization if not entity_name or not entity_name.strip(): logger.warning( f"Entity extraction error: entity name became empty after normalization. Original: '{record_attributes[1]}'" ) return None # Clean and validate entity type entity_type = clean_str(record_attributes[2]).strip('"') if not entity_type.strip() or entity_type.startswith('("'): logger.warning( f"Entity extraction error: invalid entity type in: {record_attributes}" ) return None # Clean and validate description entity_description = clean_str(record_attributes[3]) entity_description = normalize_extracted_info(entity_description) if not entity_description.strip(): logger.warning( f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'" ) return None return dict( entity_name=entity_name, entity_type=entity_type, description=entity_description, source_id=chunk_key, file_path=file_path, ) async def _handle_single_relationship_extraction( record_attributes: list[str], chunk_key: str, file_path: str = "unknown_source", ): if len(record_attributes) < 5 or '"relationship"' not in record_attributes[0]: return None # add this record as edge source = clean_str(record_attributes[1]) target = clean_str(record_attributes[2]) # Normalize source and target entity names source = normalize_extracted_info(source, is_entity=True) target = normalize_extracted_info(target, is_entity=True) # Check if source or target became empty after normalization if not source or not source.strip(): logger.warning( f"Relationship extraction error: source entity became empty 
after normalization. Original: '{record_attributes[1]}'" ) return None if not target or not target.strip(): logger.warning( f"Relationship extraction error: target entity became empty after normalization. Original: '{record_attributes[2]}'" ) return None if source == target: logger.debug( f"Relationship source and target are the same in: {record_attributes}" ) return None edge_description = clean_str(record_attributes[3]) edge_description = normalize_extracted_info(edge_description) edge_keywords = normalize_extracted_info( clean_str(record_attributes[4]), is_entity=True ) edge_keywords = edge_keywords.replace(",", ",") edge_source_id = chunk_key weight = ( float(record_attributes[-1].strip('"').strip("'")) if is_float_regex(record_attributes[-1].strip('"').strip("'")) else 1.0 ) return dict( src_id=source, tgt_id=target, weight=weight, description=edge_description, keywords=edge_keywords, source_id=edge_source_id, file_path=file_path, ) async def _rebuild_knowledge_from_chunks( entities_to_rebuild: dict[str, set[str]], relationships_to_rebuild: dict[tuple[str, str], set[str]], knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_storage: BaseKVStorage, llm_response_cache: BaseKVStorage, global_config: dict[str, str], pipeline_status: dict | None = None, pipeline_status_lock=None, ) -> None: """Rebuild entity and relationship descriptions from cached extraction results with parallel processing This method uses cached LLM extraction results instead of calling LLM again, following the same approach as the insert process. Now with parallel processing controlled by llm_model_max_async and using get_storage_keyed_lock for data consistency. 
Args: entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids knowledge_graph_inst: Knowledge graph storage entities_vdb: Entity vector database relationships_vdb: Relationship vector database text_chunks_storage: Text chunks storage llm_response_cache: LLM response cache global_config: Global configuration containing llm_model_max_async pipeline_status: Pipeline status dictionary pipeline_status_lock: Lock for pipeline status """ if not entities_to_rebuild and not relationships_to_rebuild: return # Get all referenced chunk IDs all_referenced_chunk_ids = set() for chunk_ids in entities_to_rebuild.values(): all_referenced_chunk_ids.update(chunk_ids) for chunk_ids in relationships_to_rebuild.values(): all_referenced_chunk_ids.update(chunk_ids) status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) # Get cached extraction results for these chunks using storage # cached_results: chunk_id -> [list of extraction result from LLM cache sorted by created_at] cached_results = await _get_cached_extraction_results( llm_response_cache, all_referenced_chunk_ids, text_chunks_storage=text_chunks_storage, ) if not cached_results: status_message = "No cached extraction results found, cannot rebuild" logger.warning(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) return # Process cached results to get entities and relationships for each chunk chunk_entities = {} # chunk_id -> {entity_name: [entity_data]} 
chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]} for chunk_id, extraction_results in cached_results.items(): try: # Handle multiple extraction results per chunk chunk_entities[chunk_id] = defaultdict(list) chunk_relationships[chunk_id] = defaultdict(list) # process multiple LLM extraction results for a single chunk_id for extraction_result in extraction_results: entities, relationships = await _parse_extraction_result( text_chunks_storage=text_chunks_storage, extraction_result=extraction_result, chunk_id=chunk_id, ) # Merge entities and relationships from this extraction result # Only keep the first occurrence of each entity_name in the same chunk_id for entity_name, entity_list in entities.items(): if ( entity_name not in chunk_entities[chunk_id] or len(chunk_entities[chunk_id][entity_name]) == 0 ): chunk_entities[chunk_id][entity_name].extend(entity_list) # Only keep the first occurrence of each rel_key in the same chunk_id for rel_key, rel_list in relationships.items(): if ( rel_key not in chunk_relationships[chunk_id] or len(chunk_relationships[chunk_id][rel_key]) == 0 ): chunk_relationships[chunk_id][rel_key].extend(rel_list) except Exception as e: status_message = ( f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" ) logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) continue # Get max async tasks limit from global_config for semaphore control graph_max_async = global_config.get("llm_model_max_async", 4) * 2 semaphore = asyncio.Semaphore(graph_max_async) # Counters for tracking progress rebuilt_entities_count = 0 rebuilt_relationships_count = 0 failed_entities_count = 0 failed_relationships_count = 0 async def _locked_rebuild_entity(entity_name, chunk_ids): nonlocal rebuilt_entities_count, 
failed_entities_count async with semaphore: workspace = global_config.get("workspace", "") namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" async with get_storage_keyed_lock( [entity_name], namespace=namespace, enable_logging=False ): try: await _rebuild_single_entity( knowledge_graph_inst=knowledge_graph_inst, entities_vdb=entities_vdb, entity_name=entity_name, chunk_ids=chunk_ids, chunk_entities=chunk_entities, llm_response_cache=llm_response_cache, global_config=global_config, ) rebuilt_entities_count += 1 status_message = ( f"Rebuilt `{entity_name}` from {len(chunk_ids)} chunks" ) logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) except Exception as e: failed_entities_count += 1 status_message = f"Failed to rebuild `{entity_name}`: {e}" logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) async def _locked_rebuild_relationship(src, tgt, chunk_ids): nonlocal rebuilt_relationships_count, failed_relationships_count async with semaphore: workspace = global_config.get("workspace", "") namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" # Sort src and tgt to ensure order-independent lock key generation sorted_key_parts = sorted([src, tgt]) async with get_storage_keyed_lock( sorted_key_parts, namespace=namespace, enable_logging=False, ): try: await _rebuild_single_relationship( knowledge_graph_inst=knowledge_graph_inst, relationships_vdb=relationships_vdb, src=src, tgt=tgt, chunk_ids=chunk_ids, chunk_relationships=chunk_relationships, llm_response_cache=llm_response_cache, global_config=global_config, ) rebuilt_relationships_count += 1 
status_message = ( f"Rebuilt `{src} - {tgt}` from {len(chunk_ids)} chunks" ) logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) except Exception as e: failed_relationships_count += 1 status_message = f"Failed to rebuild `{src} - {tgt}`: {e}" logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) # Create tasks for parallel processing tasks = [] # Add entity rebuilding tasks for entity_name, chunk_ids in entities_to_rebuild.items(): task = asyncio.create_task(_locked_rebuild_entity(entity_name, chunk_ids)) tasks.append(task) # Add relationship rebuilding tasks for (src, tgt), chunk_ids in relationships_to_rebuild.items(): task = asyncio.create_task(_locked_rebuild_relationship(src, tgt, chunk_ids)) tasks.append(task) # Log parallel processing start status_message = f"Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) # Execute all tasks in parallel with semaphore control and early failure detection done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) # Check if any task raised an exception for task in done: if task.exception(): # If a task failed, cancel all pending tasks for pending_task in pending: pending_task.cancel() # Wait for cancellation to complete if pending: await asyncio.wait(pending) # Re-raise the 
exception to notify the caller raise task.exception() # Final status report status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully." if failed_entities_count > 0 or failed_relationships_count > 0: status_message += f" Failed: {failed_entities_count} entities, {failed_relationships_count} relationships." logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) async def _get_cached_extraction_results( llm_response_cache: BaseKVStorage, chunk_ids: set[str], text_chunks_storage: BaseKVStorage, ) -> dict[str, list[str]]: """Get cached extraction results for specific chunk IDs This function retrieves cached LLM extraction results for the given chunk IDs and returns them sorted by creation time. The results are sorted at two levels: 1. Individual extraction results within each chunk are sorted by create_time (earliest first) 2. 
Chunks themselves are sorted by the create_time of their earliest extraction result Args: llm_response_cache: LLM response cache storage chunk_ids: Set of chunk IDs to get cached results for text_chunks_storage: Text chunks storage for retrieving chunk data and LLM cache references Returns: Dict mapping chunk_id -> list of extraction_result_text, where: - Keys (chunk_ids) are ordered by the create_time of their first extraction result - Values (extraction results) are ordered by create_time within each chunk """ cached_results = {} # Collect all LLM cache IDs from chunks all_cache_ids = set() # Read from storage chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) for chunk_data in chunk_data_list: if chunk_data and isinstance(chunk_data, dict): llm_cache_list = chunk_data.get("llm_cache_list", []) if llm_cache_list: all_cache_ids.update(llm_cache_list) else: logger.warning(f"Chunk data is invalid or None: {chunk_data}") if not all_cache_ids: logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") return cached_results # Batch get LLM cache entries cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids)) # Process cache entries and group by chunk_id valid_entries = 0 for cache_entry in cache_data_list: if ( cache_entry is not None and isinstance(cache_entry, dict) and cache_entry.get("cache_type") == "extract" and cache_entry.get("chunk_id") in chunk_ids ): chunk_id = cache_entry["chunk_id"] extraction_result = cache_entry["return"] create_time = cache_entry.get( "create_time", 0 ) # Get creation time, default to 0 valid_entries += 1 # Support multiple LLM caches per chunk if chunk_id not in cached_results: cached_results[chunk_id] = [] # Store tuple with extraction result and creation time for sorting cached_results[chunk_id].append((extraction_result, create_time)) # Sort extraction results by create_time for each chunk and collect earliest times chunk_earliest_times = {} for chunk_id in cached_results: # Sort 
by create_time (x[1]), then extract only extraction_result (x[0]) cached_results[chunk_id].sort(key=lambda x: x[1]) # Store the earliest create_time for this chunk (first item after sorting) chunk_earliest_times[chunk_id] = cached_results[chunk_id][0][1] # Extract only extraction_result (x[0]) cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]] # Sort cached_results by the earliest create_time of each chunk sorted_chunk_ids = sorted( chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id] ) # Rebuild cached_results in sorted order sorted_cached_results = {} for chunk_id in sorted_chunk_ids: sorted_cached_results[chunk_id] = cached_results[chunk_id] logger.info( f"Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results" ) return sorted_cached_results async def _parse_extraction_result( text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str ) -> tuple[dict, dict]: """Parse cached extraction result using the same logic as extract_entities Args: text_chunks_storage: Text chunks storage to get chunk data extraction_result: The cached LLM extraction result chunk_id: The chunk ID for source tracking Returns: Tuple of (entities_dict, relationships_dict) """ # Get chunk data for file_path from storage chunk_data = await text_chunks_storage.get_by_id(chunk_id) file_path = ( chunk_data.get("file_path", "unknown_source") if chunk_data else "unknown_source" ) context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], ) maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) # Parse the extraction result using the same logic as in extract_entities records = split_string_by_multi_markers( extraction_result, [context_base["record_delimiter"], context_base["completion_delimiter"]], ) for record in records: record = re.search(r"\((.*)\)", record) 
if record is None: continue record = record.group(1) record_attributes = split_string_by_multi_markers( record, [context_base["tuple_delimiter"]] ) # Try to parse as entity entity_data = await _handle_single_entity_extraction( record_attributes, chunk_id, file_path ) if entity_data is not None: maybe_nodes[entity_data["entity_name"]].append(entity_data) continue # Try to parse as relationship relationship_data = await _handle_single_relationship_extraction( record_attributes, chunk_id, file_path ) if relationship_data is not None: maybe_edges[ (relationship_data["src_id"], relationship_data["tgt_id"]) ].append(relationship_data) return dict(maybe_nodes), dict(maybe_edges) async def _rebuild_single_entity( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, entity_name: str, chunk_ids: set[str], chunk_entities: dict, llm_response_cache: BaseKVStorage, global_config: dict[str, str], ) -> None: """Rebuild a single entity from cached extraction results""" # Get current entity data current_entity = await knowledge_graph_inst.get_node(entity_name) if not current_entity: return # Helper function to update entity in both graph and vector storage async def _update_entity_storage( final_description: str, entity_type: str, file_paths: set[str] ): # Update entity in graph storage updated_entity_data = { **current_entity, "description": final_description, "entity_type": entity_type, "source_id": GRAPH_FIELD_SEP.join(chunk_ids), "file_path": GRAPH_FIELD_SEP.join(file_paths) if file_paths else current_entity.get("file_path", "unknown_source"), } await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data) # Update entity in vector database entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-") entity_content = f"{entity_name}\n{final_description}" await entities_vdb.upsert( { entity_vdb_id: { "content": entity_content, "entity_name": entity_name, "source_id": updated_entity_data["source_id"], "description": final_description, "entity_type": 
entity_type, "file_path": updated_entity_data["file_path"], } } ) # Collect all entity data from relevant chunks all_entity_data = [] for chunk_id in chunk_ids: if chunk_id in chunk_entities and entity_name in chunk_entities[chunk_id]: all_entity_data.extend(chunk_entities[chunk_id][entity_name]) if not all_entity_data: logger.warning( f"No entity data found for `{entity_name}`, trying to rebuild from relationships" ) # Get all edges connected to this entity edges = await knowledge_graph_inst.get_node_edges(entity_name) if not edges: logger.warning(f"No relations attached to entity `{entity_name}`") return # Collect relationship data to extract entity information relationship_descriptions = [] file_paths = set() # Get edge data for all connected relationships for src_id, tgt_id in edges: edge_data = await knowledge_graph_inst.get_edge(src_id, tgt_id) if edge_data: if edge_data.get("description"): relationship_descriptions.append(edge_data["description"]) if edge_data.get("file_path"): edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP) file_paths.update(edge_file_paths) # deduplicate descriptions description_list = list(dict.fromkeys(relationship_descriptions)) # Generate final description from relationships or fallback to current if description_list: final_description, _ = await _handle_entity_relation_summary( "Entity", entity_name, description_list, GRAPH_FIELD_SEP, global_config, llm_response_cache=llm_response_cache, ) else: final_description = current_entity.get("description", "") entity_type = current_entity.get("entity_type", "UNKNOWN") await _update_entity_storage(final_description, entity_type, file_paths) return # Process cached entity data descriptions = [] entity_types = [] file_paths = set() for entity_data in all_entity_data: if entity_data.get("description"): descriptions.append(entity_data["description"]) if entity_data.get("entity_type"): entity_types.append(entity_data["entity_type"]) if entity_data.get("file_path"): 
file_paths.add(entity_data["file_path"]) # Remove duplicates while preserving order description_list = list(dict.fromkeys(descriptions)) entity_types = list(dict.fromkeys(entity_types)) # Get most common entity type entity_type = ( max(set(entity_types), key=entity_types.count) if entity_types else current_entity.get("entity_type", "UNKNOWN") ) # Generate final description from entities or fallback to current if description_list: final_description, _ = await _handle_entity_relation_summary( "Entity", entity_name, description_list, GRAPH_FIELD_SEP, global_config, llm_response_cache=llm_response_cache, ) else: final_description = current_entity.get("description", "") await _update_entity_storage(final_description, entity_type, file_paths) async def _rebuild_single_relationship( knowledge_graph_inst: BaseGraphStorage, relationships_vdb: BaseVectorStorage, src: str, tgt: str, chunk_ids: set[str], chunk_relationships: dict, llm_response_cache: BaseKVStorage, global_config: dict[str, str], ) -> None: """Rebuild a single relationship from cached extraction results Note: This function assumes the caller has already acquired the appropriate keyed lock for the relationship pair to ensure thread safety. 
""" # Get current relationship data current_relationship = await knowledge_graph_inst.get_edge(src, tgt) if not current_relationship: return # Collect all relationship data from relevant chunks all_relationship_data = [] for chunk_id in chunk_ids: if chunk_id in chunk_relationships: # Check both (src, tgt) and (tgt, src) since relationships can be bidirectional for edge_key in [(src, tgt), (tgt, src)]: if edge_key in chunk_relationships[chunk_id]: all_relationship_data.extend( chunk_relationships[chunk_id][edge_key] ) if not all_relationship_data: logger.warning(f"No relation data found for `{src}-{tgt}`") return # Merge descriptions and keywords descriptions = [] keywords = [] weights = [] file_paths = set() for rel_data in all_relationship_data: if rel_data.get("description"): descriptions.append(rel_data["description"]) if rel_data.get("keywords"): keywords.append(rel_data["keywords"]) if rel_data.get("weight"): weights.append(rel_data["weight"]) if rel_data.get("file_path"): file_paths.add(rel_data["file_path"]) # Remove duplicates while preserving order description_list = list(dict.fromkeys(descriptions)) keywords = list(dict.fromkeys(keywords)) combined_keywords = ( ", ".join(set(keywords)) if keywords else current_relationship.get("keywords", "") ) weight = sum(weights) if weights else current_relationship.get("weight", 1.0) # Generate final description from relations or fallback to current if description_list: final_description, _ = await _handle_entity_relation_summary( "Relation", f"{src}-{tgt}", description_list, GRAPH_FIELD_SEP, global_config, llm_response_cache=llm_response_cache, ) else: # fallback to keep current(unchanged) final_description = current_relationship.get("description", "") # Update relationship in graph storage updated_relationship_data = { **current_relationship, "description": final_description if final_description else current_relationship.get("description", ""), "keywords": combined_keywords, "weight": weight, "source_id": 
async def _merge_nodes_then_upsert(
    entity_name: str,
    nodes_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
    pipeline_status: dict = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
):
    """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.

    Args:
        entity_name: Name (and graph id) of the entity.
        nodes_data: Newly extracted records for this entity; each record is read
            for ``entity_type``, ``source_id`` and (optionally) ``description``.
        knowledge_graph_inst: Graph storage to read the existing node from and
            to upsert the merged node into.
        global_config: Configuration handed to the description summarizer.
        pipeline_status: Optional shared status dict; updated only when
            ``pipeline_status_lock`` is also given.
        pipeline_status_lock: Optional async lock guarding ``pipeline_status``.
        llm_response_cache: Optional cache handed to the description summarizer.

    Returns:
        The upserted node data with an extra ``entity_name`` key.
    """
    already_entity_types = []
    already_source_ids = []
    already_description = []
    already_file_paths = []

    already_node = await knowledge_graph_inst.get_node(entity_name)
    if already_node:
        # Tolerate missing/None fields on the stored node, the same way the
        # edge merger does, instead of failing on a KeyError/AttributeError.
        if already_node.get("entity_type") is not None:
            already_entity_types.append(already_node["entity_type"])
        if already_node.get("source_id") is not None:
            already_source_ids.extend(already_node["source_id"].split(GRAPH_FIELD_SEP))
        if already_node.get("file_path") is not None:
            already_file_paths.extend(already_node["file_path"].split(GRAPH_FIELD_SEP))
        if already_node.get("description") is not None:
            already_description.extend(
                already_node["description"].split(GRAPH_FIELD_SEP)
            )

    # Get the entity type with the highest count (first seen wins on a tie)
    type_counts = Counter(
        [dp["entity_type"] for dp in nodes_data] + already_entity_types
    )
    entity_type = type_counts.most_common(1)[0][0] if type_counts else "UNKNOWN"

    # merge and deduplicate description
    description_list = list(
        dict.fromkeys(
            already_description
            + [dp["description"] for dp in nodes_data if dp.get("description")]
        )
    )

    num_fragment = len(description_list)
    already_fragment = len(already_description)
    deduplicated_num = already_fragment + len(nodes_data) - num_fragment
    if deduplicated_num > 0:
        dd_message = f"(dd:{deduplicated_num})"
    else:
        dd_message = ""

    if num_fragment > 0:
        # Get summary and LLM usage status
        description, llm_was_used = await _handle_entity_relation_summary(
            "Entity",
            entity_name,
            description_list,
            GRAPH_FIELD_SEP,
            global_config,
            llm_response_cache,
        )

        # Log based on actual LLM usage
        if llm_was_used:
            status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
        else:
            status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"

        logger.info(status_message)
        if pipeline_status is not None and pipeline_status_lock is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = status_message
                pipeline_status["history_messages"].append(status_message)
    else:
        logger.error(f"Entity {entity_name} has no description")
        description = "(no description)"

    # De-duplicate while keeping a stable order (existing ids first, then new
    # ones); joining a set() would give a different order from run to run.
    source_id = GRAPH_FIELD_SEP.join(
        dict.fromkeys(already_source_ids + [dp["source_id"] for dp in nodes_data])
    )
    file_path = build_file_path(already_file_paths, nodes_data, entity_name)

    node_data = dict(
        entity_id=entity_name,
        entity_type=entity_type,
        description=description,
        source_id=source_id,
        file_path=file_path,
        created_at=int(time.time()),
    )
    await knowledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    # Added after the upsert, so the stored node does not carry this key.
    node_data["entity_name"] = entity_name
    return node_data
async def _merge_edges_then_upsert(
    src_id: str,
    tgt_id: str,
    edges_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
    pipeline_status: dict = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
    added_entities: list = None,  # New parameter to track entities added during edge processing
):
    """Merge newly extracted edge records with the stored edge, then upsert.

    Args:
        src_id: Source entity name.
        tgt_id: Target entity name.
        edges_data: Newly extracted records for this edge; each record is read
            for ``weight`` and (optionally) ``description``, ``keywords`` and
            ``source_id``.
        knowledge_graph_inst: Graph storage to read from and write to.
        global_config: Configuration handed to the description summarizer.
        pipeline_status: Optional shared status dict; updated only when
            ``pipeline_status_lock`` is also given.
        pipeline_status_lock: Optional async lock guarding ``pipeline_status``.
        llm_response_cache: Optional cache handed to the description summarizer.
        added_entities: Optional list; a record is appended for every endpoint
            node that had to be created here because it was missing.

    Returns:
        ``None`` for a self-loop (``src_id == tgt_id``), otherwise a dict with
        ``src_id``, ``tgt_id``, ``description``, ``keywords``, ``weight``,
        ``source_id``, ``file_path`` and ``created_at``.
    """
    if src_id == tgt_id:
        return None

    already_weights = []
    already_source_ids = []
    already_description = []
    already_keywords = []
    already_file_paths = []

    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        # Handle the case where get_edge returns None or missing fields
        if already_edge:
            # Get weight with default 1.0 if missing
            already_weights.append(already_edge.get("weight", 1.0))

            # Fields that are missing or None are simply skipped
            if already_edge.get("source_id") is not None:
                already_source_ids.extend(
                    already_edge["source_id"].split(GRAPH_FIELD_SEP)
                )

            if already_edge.get("file_path") is not None:
                already_file_paths.extend(
                    already_edge["file_path"].split(GRAPH_FIELD_SEP)
                )

            if already_edge.get("description") is not None:
                already_description.extend(
                    already_edge["description"].split(GRAPH_FIELD_SEP)
                )

            if already_edge.get("keywords") is not None:
                already_keywords.extend(
                    split_string_by_multi_markers(
                        already_edge["keywords"], [GRAPH_FIELD_SEP]
                    )
                )

    # Total weight is the stored weight plus the weight of every new record
    weight = sum([dp["weight"] for dp in edges_data] + already_weights)

    description_list = list(
        dict.fromkeys(
            already_description
            + [dp["description"] for dp in edges_data if dp.get("description")]
        )
    )

    num_fragment = len(description_list)
    already_fragment = len(already_description)
    deduplicated_num = already_fragment + len(edges_data) - num_fragment
    if deduplicated_num > 0:
        dd_message = f"(dd:{deduplicated_num})"
    else:
        dd_message = ""

    if num_fragment > 0:
        # Get summary and LLM usage status
        description, llm_was_used = await _handle_entity_relation_summary(
            "Relation",
            f"({src_id}, {tgt_id})",
            description_list,
            GRAPH_FIELD_SEP,
            global_config,
            llm_response_cache,
        )

        # Log based on actual LLM usage
        if llm_was_used:
            status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
        else:
            status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"

        logger.info(status_message)
        if pipeline_status is not None and pipeline_status_lock is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = status_message
                pipeline_status["history_messages"].append(status_message)
    else:
        logger.error(f"Edge {src_id} - {tgt_id} has no description")
        description = "(no description)"

    # Split all existing and new keywords into individual terms, then combine and deduplicate
    all_keywords = set()
    # Process already_keywords (which are comma-separated)
    for keyword_str in already_keywords:
        if keyword_str:  # Skip empty strings
            all_keywords.update(k.strip() for k in keyword_str.split(",") if k.strip())
    # Process new keywords from edges_data
    for edge in edges_data:
        if edge.get("keywords"):
            all_keywords.update(
                k.strip() for k in edge["keywords"].split(",") if k.strip()
            )
    # Join all unique keywords with commas (sorted, so the result is stable)
    keywords = ",".join(sorted(all_keywords))

    # De-duplicate while keeping a stable order (existing ids first, then new
    # ones); joining a set() would give a different order from run to run.
    source_id = GRAPH_FIELD_SEP.join(
        dict.fromkeys(
            already_source_ids
            + [dp["source_id"] for dp in edges_data if dp.get("source_id")]
        )
    )
    file_path = build_file_path(already_file_paths, edges_data, f"{src_id}-{tgt_id}")

    # Make sure both endpoints exist; a missing one is created as a placeholder
    # node that borrows the edge's description and provenance.
    for need_insert_id in [src_id, tgt_id]:
        if not (await knowledge_graph_inst.has_node(need_insert_id)):
            node_data = {
                "entity_id": need_insert_id,
                "source_id": source_id,
                "description": description,
                "entity_type": "UNKNOWN",
                "file_path": file_path,
                "created_at": int(time.time()),
            }
            await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)

            # Track entities added during edge processing
            if added_entities is not None:
                entity_data = {
                    "entity_name": need_insert_id,
                    "entity_type": "UNKNOWN",
                    "description": description,
                    "source_id": source_id,
                    "file_path": file_path,
                    "created_at": int(time.time()),
                }
                added_entities.append(entity_data)

    await knowledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight,
            description=description,
            keywords=keywords,
            source_id=source_id,
            file_path=file_path,
            created_at=int(time.time()),
        ),
    )

    edge_data = dict(
        src_id=src_id,
        tgt_id=tgt_id,
        description=description,
        keywords=keywords,
        # Return the merged weight too, so a caller that reads
        # edge_data.get("weight", 1.0) sees the real value, not the default.
        weight=weight,
        source_id=source_id,
        file_path=file_path,
        created_at=int(time.time()),
    )

    return edge_data
async def merge_nodes_and_edges(
    chunk_results: list,
    knowledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    global_config: dict[str, str],
    full_entities_storage: BaseKVStorage = None,
    full_relations_storage: BaseKVStorage = None,
    doc_id: str = None,
    pipeline_status: dict = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
    current_file_number: int = 0,
    total_files: int = 0,
    file_path: str = "unknown_source",
) -> None:
    """Two-phase merge: process all entities first, then all relationships

    This approach ensures data consistency by:
    1. Phase 1: Process all entities concurrently
    2. Phase 2: Process all relationships concurrently (may add missing entities)
    3. Phase 3: Update full_entities and full_relations storage with final results

    Args:
        chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
        knowledge_graph_inst: Knowledge graph storage
        entity_vdb: Entity vector database
        relationships_vdb: Relationship vector database
        global_config: Global configuration
        full_entities_storage: Storage for document entity lists
        full_relations_storage: Storage for document relation lists
        doc_id: Document ID for storage indexing
        pipeline_status: Pipeline status dictionary (optional)
        pipeline_status_lock: Lock for pipeline status (optional)
        llm_response_cache: LLM response cache
        current_file_number: Current file number for logging
        total_files: Total files for logging
        file_path: File path for logging

    Raises:
        Exception: The first exception raised by any entity or relationship
            task is re-raised after the remaining tasks have been cancelled.
    """

    async def _report(message: str) -> None:
        # Log, and mirror into pipeline_status only when both the dict and its
        # lock were supplied (both default to None).
        logger.info(message)
        if pipeline_status is not None and pipeline_status_lock is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = message
                pipeline_status["history_messages"].append(message)

    async def _run_all(tasks: list) -> list:
        # Wait for every task; on the first failure cancel the rest, wait for
        # the cancellation to finish, and re-raise to notify the caller.
        if not tasks:
            return []
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in done:
            if task.exception():
                for pending_task in pending:
                    pending_task.cancel()
                if pending:
                    await asyncio.wait(pending)
                raise task.exception()
        # All tasks completed successfully: results in task creation order
        return [task.result() for task in tasks]

    # Collect all nodes and edges from all chunks
    all_nodes = defaultdict(list)
    all_edges = defaultdict(list)

    for maybe_nodes, maybe_edges in chunk_results:
        # Collect nodes
        for entity_name, entities in maybe_nodes.items():
            all_nodes[entity_name].extend(entities)

        # Collect edges with sorted keys for undirected graph
        for edge_key, edges in maybe_edges.items():
            sorted_edge_key = tuple(sorted(edge_key))
            all_edges[sorted_edge_key].extend(edges)

    total_entities_count = len(all_nodes)
    total_relations_count = len(all_edges)

    await _report(f"Merging stage {current_file_number}/{total_files}: {file_path}")

    # Get max async tasks limit from global_config for semaphore control
    graph_max_async = global_config.get("llm_model_max_async", 4) * 2
    semaphore = asyncio.Semaphore(graph_max_async)

    # Namespace for the keyed locks; identical for every entity and edge task
    workspace = global_config.get("workspace", "")
    namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"

    # ===== Phase 1: Process all entities concurrently =====
    await _report(
        f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})"
    )

    async def _locked_process_entity_name(entity_name, entities):
        async with semaphore:
            async with get_storage_keyed_lock(
                [entity_name], namespace=namespace, enable_logging=False
            ):
                entity_data = await _merge_nodes_then_upsert(
                    entity_name,
                    entities,
                    knowledge_graph_inst,
                    global_config,
                    pipeline_status,
                    pipeline_status_lock,
                    llm_response_cache,
                )
                if entity_vdb is not None:
                    data_for_vdb = {
                        compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
                            "entity_name": entity_data["entity_name"],
                            "entity_type": entity_data["entity_type"],
                            "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
                            "source_id": entity_data["source_id"],
                            "file_path": entity_data.get("file_path", "unknown_source"),
                        }
                    }
                    await entity_vdb.upsert(data_for_vdb)
                return entity_data

    # Create and execute entity processing tasks
    entity_tasks = [
        asyncio.create_task(_locked_process_entity_name(entity_name, entities))
        for entity_name, entities in all_nodes.items()
    ]
    processed_entities = await _run_all(entity_tasks)

    # ===== Phase 2: Process all relationships concurrently =====
    await _report(
        f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
    )

    async def _locked_process_edges(edge_key, edges):
        async with semaphore:
            sorted_edge_key = sorted([edge_key[0], edge_key[1]])

            async with get_storage_keyed_lock(
                sorted_edge_key,
                namespace=namespace,
                enable_logging=False,
            ):
                added_entities = []  # Track entities added during edge processing
                edge_data = await _merge_edges_then_upsert(
                    edge_key[0],
                    edge_key[1],
                    edges,
                    knowledge_graph_inst,
                    global_config,
                    pipeline_status,
                    pipeline_status_lock,
                    llm_response_cache,
                    added_entities,  # Pass list to collect added entities
                )

                if edge_data is None:
                    return None, []

                if relationships_vdb is not None:
                    data_for_vdb = {
                        compute_mdhash_id(
                            edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"
                        ): {
                            "src_id": edge_data["src_id"],
                            "tgt_id": edge_data["tgt_id"],
                            "keywords": edge_data["keywords"],
                            "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
                            "source_id": edge_data["source_id"],
                            "file_path": edge_data.get("file_path", "unknown_source"),
                            "weight": edge_data.get("weight", 1.0),
                        }
                    }
                    await relationships_vdb.upsert(data_for_vdb)
                return edge_data, added_entities

    # Create and execute relationship processing tasks
    edge_tasks = [
        asyncio.create_task(_locked_process_edges(edge_key, edges))
        for edge_key, edges in all_edges.items()
    ]

    processed_edges = []
    all_added_entities = []
    for edge_data, added_entities in await _run_all(edge_tasks):
        if edge_data is not None:
            processed_edges.append(edge_data)
        all_added_entities.extend(added_entities)

    # ===== Phase 3: Update full_entities and full_relations storage =====
    if full_entities_storage and full_relations_storage and doc_id:
        try:
            # Merge all entities: original entities + entities added during edge processing
            final_entity_names = set()

            # Add original processed entities
            for entity_data in processed_entities:
                if entity_data and entity_data.get("entity_name"):
                    final_entity_names.add(entity_data["entity_name"])

            # Add entities that were added during relationship processing
            for added_entity in all_added_entities:
                if added_entity and added_entity.get("entity_name"):
                    final_entity_names.add(added_entity["entity_name"])

            # Collect all relation pairs
            final_relation_pairs = set()
            for edge_data in processed_edges:
                if edge_data:
                    src_id = edge_data.get("src_id")
                    tgt_id = edge_data.get("tgt_id")
                    if src_id and tgt_id:
                        relation_pair = tuple(sorted([src_id, tgt_id]))
                        final_relation_pairs.add(relation_pair)

            await _report(
                f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}"
            )

            # Update storage
            if final_entity_names:
                await full_entities_storage.upsert(
                    {
                        doc_id: {
                            "entity_names": list(final_entity_names),
                            "count": len(final_entity_names),
                        }
                    }
                )

            if final_relation_pairs:
                await full_relations_storage.upsert(
                    {
                        doc_id: {
                            "relation_pairs": [
                                list(pair) for pair in final_relation_pairs
                            ],
                            "count": len(final_relation_pairs),
                        }
                    }
                )

            logger.debug(
                f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations"
            )

        except Exception as e:
            logger.error(
                f"Failed to update entity-relation index for document {doc_id}: {e}"
            )
            # Don't raise exception to avoid affecting main flow

    await _report(
        f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations"
    )
async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    global_config: dict[str, str],
    pipeline_status: dict = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
    text_chunks_storage: BaseKVStorage | None = None,
) -> list:
    """Extract entities and relationships from every chunk with the LLM.

    Each chunk gets one initial extraction call followed by up to
    ``entity_extract_max_gleaning`` "continue" rounds. Chunks are processed
    concurrently, limited by ``llm_model_max_async`` (default 4).

    Args:
        chunks: Mapping of chunk key to chunk data; ``content`` is required,
            ``file_path`` is optional.
        global_config: Must provide ``llm_model_func``,
            ``entity_extract_max_gleaning`` and ``addon_params``.
        pipeline_status: Optional shared status dict; updated only when
            ``pipeline_status_lock`` is also given.
        pipeline_status_lock: Optional async lock guarding ``pipeline_status``.
        llm_response_cache: Optional cache passed to the cached LLM helper.
        text_chunks_storage: Optional storage whose per-chunk cache list is
            updated with the cache keys collected for the chunk.

    Returns:
        One ``(maybe_nodes, maybe_edges)`` tuple per chunk, in chunk order.
        An empty list when ``chunks`` is empty.

    Raises:
        Exception: The first exception raised by any chunk task is re-raised
            after the remaining tasks have been cancelled.
    """
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    ordered_chunks = list(chunks.items())
    # add language and example number params to prompt
    language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE)
    entity_types = global_config["addon_params"].get(
        "entity_types", DEFAULT_ENTITY_TYPES
    )
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["entity_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["entity_extraction_examples"])

    example_context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=", ".join(entity_types),
        language=language,
    )
    # add example's format
    examples = examples.format(**example_context_base)

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(entity_types),
        examples=examples,
        language=language,
    )

    continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
    if_loop_prompt = PROMPTS["entity_if_loop_extraction"]

    processed_chunks = 0
    total_chunks = len(ordered_chunks)

    async def _process_extraction_result(
        result: str, chunk_key: str, file_path: str = "unknown_source"
    ):
        """Process a single extraction result (either initial or gleaning)

        Args:
            result (str): The extraction result to process
            chunk_key (str): The chunk key for source tracking
            file_path (str): The file path for citation

        Returns:
            tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
        """
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)

        records = split_string_by_multi_markers(
            result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        for record in records:
            # Keep only what sits between the outermost parentheses
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )

            # A record is tried as an entity first, then as a relationship
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key, file_path
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key, file_path
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )

        return maybe_nodes, maybe_edges

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        """Process a single chunk

        Args:
            chunk_key_dp (tuple[str, TextChunkSchema]):
                ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})

        Returns:
            tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships
        """
        nonlocal processed_chunks
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        # Get file path from chunk data or use default
        file_path = chunk_dp.get("file_path", "unknown_source")

        # Create cache keys collector for batch processing
        cache_keys_collector = []

        # Get initial extraction
        hint_prompt = entity_extract_prompt.format(
            **{**context_base, "input_text": content}
        )

        final_result = await use_llm_func_with_cache(
            hint_prompt,
            use_llm_func,
            llm_response_cache=llm_response_cache,
            cache_type="extract",
            chunk_id=chunk_key,
            cache_keys_collector=cache_keys_collector,
        )

        # Conversation so far; extended after every gleaning round
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

        # Process initial extraction with file path
        maybe_nodes, maybe_edges = await _process_extraction_result(
            final_result, chunk_key, file_path
        )

        # Process additional gleaning results
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func_with_cache(
                continue_prompt,
                use_llm_func,
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
                chunk_id=chunk_key,
                cache_keys_collector=cache_keys_collector,
            )

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)

            # Process gleaning result separately with file path
            glean_nodes, glean_edges = await _process_extraction_result(
                glean_result, chunk_key, file_path
            )

            # Merge results - only add entities and edges with new names
            for entity_name, entities in glean_nodes.items():
                if entity_name not in maybe_nodes:
                    # Only accept entities with new name in gleaning stage
                    maybe_nodes[entity_name].extend(entities)
            for edge_key, edges in glean_edges.items():
                if edge_key not in maybe_edges:
                    # Only accept edges with new name in gleaning stage
                    maybe_edges[edge_key].extend(edges)

            # Last allowed round: no need to ask whether to continue
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func_with_cache(
                if_loop_prompt,
                use_llm_func,
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
                cache_keys_collector=cache_keys_collector,
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        # Batch update chunk's llm_cache_list with all collected cache keys
        if cache_keys_collector and text_chunks_storage:
            await update_chunk_cache_list(
                chunk_key,
                text_chunks_storage,
                cache_keys_collector,
                "entity_extraction",
            )

        processed_chunks += 1
        entities_count = len(maybe_nodes)
        relations_count = len(maybe_edges)
        log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel"
        logger.info(log_message)
        # Both the dict and its lock are optional; only update when both exist
        if pipeline_status is not None and pipeline_status_lock is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)

        # Return the extracted nodes and edges for centralized processing
        return maybe_nodes, maybe_edges

    # Get max async tasks limit from global_config
    chunk_max_async = global_config.get("llm_model_max_async", 4)
    semaphore = asyncio.Semaphore(chunk_max_async)

    async def _process_with_semaphore(chunk):
        async with semaphore:
            return await _process_single_content(chunk)

    tasks = [
        asyncio.create_task(_process_with_semaphore(c)) for c in ordered_chunks
    ]

    # asyncio.wait() rejects an empty collection, so no chunks means no results
    if not tasks:
        return []

    # Wait for tasks to complete or for the first exception to occur
    # This allows us to cancel remaining tasks if any task fails
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

    # Check if any task raised an exception
    for task in done:
        if task.exception():
            # If a task failed, cancel all pending tasks
            # This prevents unnecessary processing since the parent function will abort anyway
            for pending_task in pending:
                pending_task.cancel()

            # Wait for cancellation to complete
            if pending:
                await asyncio.wait(pending)

            # Re-raise the exception to notify the caller
            raise task.exception()

    # If all tasks completed successfully, collect results
    chunk_results = [task.result() for task in tasks]

    # Return the chunk_results for later processing in merge_nodes_and_edges
    return chunk_results
async def kg_query(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
    chunks_vdb: BaseVectorStorage = None,
) -> str | AsyncIterator[str]:
    """Answer a query from the knowledge graph.

    Steps: look up the query cache, obtain high/low-level keywords, build the
    retrieval context, format the system prompt, call the LLM and (when the
    LLM cache is enabled) store the response.

    Args:
        query: The user's query text.
        knowledge_graph_inst: Graph storage used to build the context.
        entities_vdb: Entity vector storage used to build the context.
        relationships_vdb: Relationship vector storage used to build the context.
        text_chunks_db: Text chunk storage used to build the context.
        query_param: Query options (mode, token limits, keywords, flags, ...).
        global_config: Must provide ``tokenizer`` and, unless
            ``query_param.model_func`` is set, ``llm_model_func``.
        hashing_kv: Optional cache storage; when ``None`` nothing is saved.
        system_prompt: Optional template replacing ``PROMPTS["rag_response"]``.
        chunks_vdb: Optional chunk vector storage passed to the context builder.

    Returns:
        The LLM response, or: the cached response on a cache hit, the context
        when ``only_need_context`` is set, the system prompt when
        ``only_need_prompt`` is set, or ``PROMPTS["fail_response"]`` when the
        query is empty, no keywords could be obtained, or no context was built.
    """
    if not query:
        return PROMPTS["fail_response"]

    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    # Handle cache
    args_hash = compute_args_hash(
        query_param.mode,
        query,
        query_param.response_type,
        query_param.top_k,
        query_param.chunk_top_k,
        query_param.max_entity_tokens,
        query_param.max_relation_tokens,
        query_param.max_total_tokens,
        query_param.hl_keywords or [],
        query_param.ll_keywords or [],
        query_param.user_prompt or "",
        query_param.enable_rerank,
    )
    cached_response = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    hl_keywords, ll_keywords = await get_keywords_from_query(
        query, query_param, global_config, hashing_kv
    )

    logger.debug(f"High-level keywords: {hl_keywords}")
    logger.debug(f"Low-level  keywords: {ll_keywords}")

    # Handle empty keywords
    if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]:
        logger.warning("low_level_keywords is empty")
    if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]:
        logger.warning("high_level_keywords is empty")
    if hl_keywords == [] and ll_keywords == []:
        # Short queries are used verbatim as a single low-level keyword
        if len(query) < 50:
            logger.warning(f"Forced low_level_keywords to origin query: {query}")
            ll_keywords = [query]
        else:
            return PROMPTS["fail_response"]

    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

    # Build context
    context = await _build_query_context(
        query,
        ll_keywords_str,
        hl_keywords_str,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
        chunks_vdb,
    )

    if query_param.only_need_context:
        return context if context is not None else PROMPTS["fail_response"]
    if context is None:
        return PROMPTS["fail_response"]

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    # Build system prompt
    user_prompt = (
        query_param.user_prompt
        if query_param.user_prompt
        else PROMPTS["DEFAULT_USER_PROMPT"]
    )
    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context,
        response_type=query_param.response_type,
        history=history_context,
        user_prompt=user_prompt,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(
        f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
    )

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    # Strip an echoed prompt/query from a plain-string response
    if isinstance(response, str) and len(response) > len(sys_prompt):
        # NOTE(review): the two replace("", "") calls are no-ops as written;
        # presumably they once held markup tokens - confirm and restore or drop.
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("", "")
            .replace("", "")
            .strip()
        )

    # hashing_kv defaults to None, so guard before reading its global_config
    if hashing_kv is not None and hashing_kv.global_config.get("enable_llm_cache"):
        # Save to cache with query parameters
        queryparam_dict = {
            "mode": query_param.mode,
            "response_type": query_param.response_type,
            "top_k": query_param.top_k,
            "chunk_top_k": query_param.chunk_top_k,
            "max_entity_tokens": query_param.max_entity_tokens,
            "max_relation_tokens": query_param.max_relation_tokens,
            "max_total_tokens": query_param.max_total_tokens,
            "hl_keywords": query_param.hl_keywords or [],
            "ll_keywords": query_param.ll_keywords or [],
            "user_prompt": query_param.user_prompt or "",
            "enable_rerank": query_param.enable_rerank,
        }
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                mode=query_param.mode,
                cache_type="query",
                queryparam=queryparam_dict,
            ),
        )

    return response
async def get_keywords_from_query(
    query: str,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Retrieves high-level and low-level keywords for RAG operations.

    This function checks if keywords are already provided in query parameters,
    and if not, extracts them from the query text using LLM.

    Args:
        query: The user's query text
        query_param: Query parameters that may contain pre-defined keywords
        global_config: Global configuration dictionary
        hashing_kv: Optional key-value storage for caching results

    Returns:
        A tuple containing (high_level_keywords, low_level_keywords). When only
        one of the two is pre-defined, the other is returned as an empty list.
    """
    # Check if pre-defined keywords are already provided
    if query_param.hl_keywords or query_param.ll_keywords:
        # Normalize a missing (None) side to [] so callers always get lists
        return query_param.hl_keywords or [], query_param.ll_keywords or []

    # Extract keywords using extract_keywords_only function which already supports conversation history
    hl_keywords, ll_keywords = await extract_keywords_only(
        query, query_param, global_config, hashing_kv
    )
    return hl_keywords, ll_keywords
async def extract_keywords_only(
    text: str,
    param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Extract high-level and low-level keywords from the given 'text' using the LLM.
    This method does NOT build the final RAG context or provide a final answer.
    It ONLY extracts keywords (hl_keywords, ll_keywords).

    Args:
        text: The text to extract keywords from.
        param: Query parameters; ``mode``, keywords and ``model_func`` are read
            here, the remaining fields are only stored alongside the cache entry.
        global_config: Must provide ``addon_params``, ``tokenizer`` and, unless
            ``param.model_func`` is set, ``llm_model_func``.
        hashing_kv: Optional cache storage; when ``None`` nothing is saved.

    Returns:
        ``(high_level_keywords, low_level_keywords)``; two empty lists when the
        LLM response does not contain a usable JSON object.
    """

    # 1. Handle cache if needed - add cache type for keywords
    args_hash = compute_args_hash(
        param.mode,
        text,
        param.hl_keywords or [],
        param.ll_keywords or [],
    )
    cached_response = await handle_cache(
        hashing_kv, args_hash, text, param.mode, cache_type="keywords"
    )
    if cached_response is not None:
        try:
            keywords_data = json_repair.loads(cached_response)
        except (json.JSONDecodeError, KeyError):
            keywords_data = None
        # Only a JSON object is a valid cache entry; anything else (string,
        # list, ...) has no .get() and falls through to a fresh extraction.
        if isinstance(keywords_data, dict):
            return keywords_data.get("high_level_keywords", []), keywords_data.get(
                "low_level_keywords", []
            )
        logger.warning("Invalid cache format for keywords, proceeding with extraction")

    # 2. Build the examples
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])

    language = global_config["addon_params"].get(
        "language", DEFAULT_SUMMARY_LANGUAGE
    )

    # 3. Build the keyword-extraction prompt
    kw_prompt = PROMPTS["keywords_extraction"].format(
        query=text,
        examples=examples,
        language=language,
    )

    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(kw_prompt))
    logger.debug(
        f"[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})"
    )

    # 4. Call the LLM for keyword extraction
    if param.model_func:
        use_model_func = param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    result = await use_model_func(kw_prompt, keyword_extraction=True)

    # 5. Parse out JSON from the LLM response
    result = remove_think_tags(result)
    try:
        keywords_data = json_repair.loads(result)
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        logger.error(f"LLM respond: {result}")
        return [], []
    # Empty results and non-object results (list, string, ...) are both unusable
    if not keywords_data or not isinstance(keywords_data, dict):
        logger.error("No JSON-like structure found in the LLM respond.")
        return [], []

    hl_keywords = keywords_data.get("high_level_keywords", [])
    ll_keywords = keywords_data.get("low_level_keywords", [])

    # 6. Cache only the processed keywords with cache type
    if hl_keywords or ll_keywords:
        cache_data = {
            "high_level_keywords": hl_keywords,
            "low_level_keywords": ll_keywords,
        }
        # hashing_kv defaults to None, so guard before reading its global_config
        if hashing_kv is not None and hashing_kv.global_config.get(
            "enable_llm_cache"
        ):
            # Save to cache with query parameters
            queryparam_dict = {
                "mode": param.mode,
                "response_type": param.response_type,
                "top_k": param.top_k,
                "chunk_top_k": param.chunk_top_k,
                "max_entity_tokens": param.max_entity_tokens,
                "max_relation_tokens": param.max_relation_tokens,
                "max_total_tokens": param.max_total_tokens,
                "hl_keywords": param.hl_keywords or [],
                "ll_keywords": param.ll_keywords or [],
                "user_prompt": param.user_prompt or "",
                "enable_rerank": param.enable_rerank,
            }
            await save_to_cache(
                hashing_kv,
                CacheData(
                    args_hash=args_hash,
                    content=json.dumps(cache_data),
                    prompt=text,
                    mode=param.mode,
                    cache_type="keywords",
                    queryparam=queryparam_dict,
                ),
            )

    return hl_keywords, ll_keywords
Parse out JSON from the LLM response result = remove_think_tags(result) try: keywords_data = json_repair.loads(result) if not keywords_data: logger.error("No JSON-like structure found in the LLM respond.") return [], [] except json.JSONDecodeError as e: logger.error(f"JSON parsing error: {e}") logger.error(f"LLM respond: {result}") return [], [] hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", []) # 7. Cache only the processed keywords with cache type if hl_keywords or ll_keywords: cache_data = { "high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords, } if hashing_kv.global_config.get("enable_llm_cache"): # Save to cache with query parameters queryparam_dict = { "mode": param.mode, "response_type": param.response_type, "top_k": param.top_k, "chunk_top_k": param.chunk_top_k, "max_entity_tokens": param.max_entity_tokens, "max_relation_tokens": param.max_relation_tokens, "max_total_tokens": param.max_total_tokens, "hl_keywords": param.hl_keywords or [], "ll_keywords": param.ll_keywords or [], "user_prompt": param.user_prompt or "", "enable_rerank": param.enable_rerank, } await save_to_cache( hashing_kv, CacheData( args_hash=args_hash, content=json.dumps(cache_data), prompt=text, mode=param.mode, cache_type="keywords", queryparam=queryparam_dict, ), ) return hl_keywords, ll_keywords async def _get_vector_context( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, ) -> list[dict]: """ Retrieve text chunks from the vector database without reranking or truncation. This function performs vector search to find relevant text chunks for a query. Reranking and truncation will be handled later in the unified processing. 
def _relation_endpoints(relation: dict) -> tuple:
    """Return the (source, target) pair of a relation, preferring ``src_tgt``."""
    if "src_tgt" in relation:
        src, tgt = relation["src_tgt"]
        return src, tgt
    return relation.get("src_id"), relation.get("tgt_id")


def _chunk_identifier(chunk: dict):
    """Return a chunk's id (``chunk_id`` first, then ``id``), or None if it has none."""
    return chunk.get("chunk_id") or chunk.get("id") or None


def _round_robin_unique(sources: list[list[dict]], key_of) -> list[dict]:
    """Interleave ``sources`` position by position, keeping the first item per key.

    Items for which ``key_of`` returns None are skipped.
    """
    merged = []
    seen = set()
    for position in range(max((len(source) for source in sources), default=0)):
        for source in sources:
            if position >= len(source):
                continue
            item = source[position]
            key = key_of(item)
            if key is None or key in seen:
                continue
            seen.add(key)
            merged.append(item)
    return merged


def _resolve_token_limit(query_param, global_config: dict, name: str, default: int):
    """Read a token limit from ``query_param``, falling back to ``global_config``.

    The fallback is used both when the attribute is missing and when it is None.
    """
    value = getattr(query_param, name, None)
    if value is None:
        value = global_config.get(name, default)
    return value


def _render_kg_context(entities_str: str, relations_str: str, text_units_str: str) -> str:
    """Render the three JSON sections into the context text handed to the LLM."""
    return f"""-----Entities(KG)-----

```json
{entities_str}
```

-----Relationships(KG)-----

```json
{relations_str}
```

-----Document Chunks(DC)-----

```json
{text_units_str}
```

"""


async def _build_query_context(
    query: str,
    ll_keywords: str,
    hl_keywords: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    chunks_vdb: BaseVectorStorage = None,
):
    """
    Build the knowledge-graph context string for a query.

    The function retrieves entities/relations according to ``query_param.mode``,
    merges local and global results round-robin, truncates them to their token
    budgets, gathers the related text chunks and finally renders everything as
    three JSON sections (entities, relationships, document chunks).

    Args:
        query: The user's query text.
        ll_keywords: Low-level keywords used for the entity (local) search.
        hl_keywords: High-level keywords used for the relation (global) search.
        knowledge_graph_inst: Graph storage holding nodes and edges.
        entities_vdb: Vector storage searched for entities.
        relationships_vdb: Vector storage searched for relationships.
        text_chunks_db: KV storage of text chunks; its ``global_config`` and
            ``embedding_func`` are also read here.
        query_param: Query parameters (mode, token limits, prompts, ...).
        chunks_vdb: Optional vector storage of chunks, used in "mix" mode and
            for vector-similarity chunk picking.

    Returns:
        The rendered context string, or None when neither entities nor
        relations survive retrieval and truncation.
    """
    logger.info(f"Process {os.getpid()} building query context...")

    global_config = text_chunks_db.global_config

    # Chunks from the different sources are collected separately
    vector_chunks = []
    entity_chunks = []
    relation_chunks = []

    # Original graph data, kept for the later text chunk retrieval
    local_entities = []
    local_relations = []
    global_entities = []
    global_relations = []

    # Track chunk sources and metadata for final logging
    chunk_tracking = {}  # chunk_id -> {source, frequency, order}

    # Pre-compute query embedding if vector similarity method is used
    kg_chunk_pick_method = global_config.get(
        "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
    )
    query_embedding = None
    if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
        embedding_func_config = text_chunks_db.embedding_func
        if embedding_func_config and embedding_func_config.func:
            try:
                query_embedding = await embedding_func_config.func([query])
                # Extract first embedding from batch result
                query_embedding = query_embedding[0]
                logger.debug(
                    "Pre-computed query embedding for vector similarity chunk selection"
                )
            except Exception as e:
                logger.warning(f"Failed to pre-compute query embedding: {e}")
                query_embedding = None

    # Handle local and global modes
    if query_param.mode == "local" and len(ll_keywords) > 0:
        local_entities, local_relations = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            query_param,
        )
    elif query_param.mode == "global" and len(hl_keywords) > 0:
        global_relations, global_entities = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            query_param,
        )
    else:  # hybrid or mix mode
        if len(ll_keywords) > 0:
            local_entities, local_relations = await _get_node_data(
                ll_keywords,
                knowledge_graph_inst,
                entities_vdb,
                query_param,
            )
        if len(hl_keywords) > 0:
            global_relations, global_entities = await _get_edge_data(
                hl_keywords,
                knowledge_graph_inst,
                relationships_vdb,
                query_param,
            )

        # Get vector chunks first if in mix mode
        if query_param.mode == "mix" and chunks_vdb:
            vector_chunks = await _get_vector_context(
                query,
                chunks_vdb,
                query_param,
            )
            # Track vector chunks with source metadata
            for i, chunk in enumerate(vector_chunks):
                chunk_id = chunk.get("chunk_id") or chunk.get("id")
                if chunk_id:
                    chunk_tracking[chunk_id] = {
                        "source": "C",
                        "frequency": 1,  # Vector chunks always have frequency 1
                        "order": i + 1,  # 1-based order in vector search results
                    }
                else:
                    logger.warning(f"Vector chunk missing chunk_id: {chunk}")

    # Use round-robin merge to combine local and global data fairly
    final_entities = _round_robin_unique(
        [local_entities, global_entities],
        lambda entity: entity.get("entity_name") or None,
    )
    # Relations are undirected for deduplication, hence the sorted endpoint pair
    final_relations = _round_robin_unique(
        [local_relations, global_relations],
        lambda relation: tuple(sorted(_relation_endpoints(relation))),
    )

    # Entities/relations context sent to the LLM. file_path and created_at are
    # deliberately left out of these records.
    entities_context = [
        {
            "id": i + 1,
            "entity": n["entity_name"],
            "type": n.get("entity_type", "UNKNOWN"),
            "description": n.get("description", "UNKNOWN"),
        }
        for i, n in enumerate(final_entities)
    ]

    relations_context = []
    for i, e in enumerate(final_relations):
        entity1, entity2 = _relation_endpoints(e)
        relations_context.append(
            {
                "id": i + 1,
                "entity1": entity1,
                "entity2": entity2,
                "description": e.get("description", "UNKNOWN"),
            }
        )

    logger.debug(
        f"Initial KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
    )

    # Unified token control system - Apply precise token limits to entities and relations
    tokenizer = global_config.get("tokenizer")

    # Token limits come from query_param, with fallback to global_config
    max_entity_tokens = _resolve_token_limit(
        query_param, global_config, "max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS
    )
    max_relation_tokens = _resolve_token_limit(
        query_param, global_config, "max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS
    )
    max_total_tokens = _resolve_token_limit(
        query_param, global_config, "max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS
    )

    # Truncate entities based on complete JSON serialization
    if entities_context:
        entities_context = truncate_list_by_token_size(
            entities_context,
            key=lambda x: json.dumps(x, ensure_ascii=False),
            max_token_size=max_entity_tokens,
            tokenizer=tokenizer,
        )

    # Truncate relations based on complete JSON serialization
    if relations_context:
        relations_context = truncate_list_by_token_size(
            relations_context,
            key=lambda x: json.dumps(x, ensure_ascii=False),
            max_token_size=max_relation_tokens,
            tokenizer=tokenizer,
        )

    # After truncation, get text chunks based on final entities and relations
    logger.info(
        f"Truncated KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
    )

    # Keep only the graph records that survived truncation. final_entities and
    # final_relations are already unique, so a membership test is sufficient.
    kept_entity_names = {e["entity"] for e in entities_context}
    final_node_datas = [
        node for node in final_entities if node.get("entity_name") in kept_entity_names
    ]

    # Endpoints are derived exactly as for relations_context above, so every
    # surviving context entry matches the relation it was built from.
    kept_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
    final_edge_datas = [
        edge
        for edge in final_relations
        if _relation_endpoints(edge) in kept_relation_pairs
    ]

    # Get text chunks based on final filtered data
    # To preserve the influence of entity order, entity-based chunks are not
    # deduplicated against vector_chunks
    if final_node_datas:
        entity_chunks = await _find_related_text_unit_from_entities(
            final_node_datas,
            query_param,
            text_chunks_db,
            knowledge_graph_inst,
            query,
            chunks_vdb,
            chunk_tracking=chunk_tracking,
            query_embedding=query_embedding,
        )

    # Relation chunks are deduplicated against entity_chunks, so chunks that are
    # solely relation-based get prioritized when re-ranking is disabled
    if final_edge_datas:
        relation_chunks = await _find_related_text_unit_from_relations(
            final_edge_datas,
            query_param,
            text_chunks_db,
            entity_chunks,
            query,
            chunks_vdb,
            chunk_tracking=chunk_tracking,
            query_embedding=query_embedding,
        )

    # Round-robin merge chunks from different sources with deduplication by
    # chunk_id: vector (naive) first, then entity (local), then relation (global)
    origin_len = len(vector_chunks) + len(entity_chunks) + len(relation_chunks)
    merged_chunks = [
        {
            "content": chunk["content"],
            "file_path": chunk.get("file_path", "unknown_source"),
            "chunk_id": _chunk_identifier(chunk),
        }
        for chunk in _round_robin_unique(
            [vector_chunks, entity_chunks, relation_chunks], _chunk_identifier
        )
    ]

    logger.info(
        f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
    )

    entities_str = json.dumps(entities_context, ensure_ascii=False)
    relations_str = json.dumps(relations_context, ensure_ascii=False)

    # Apply token processing to merged chunks
    text_units_context = []
    truncated_chunks = []
    if merged_chunks:
        # Base context tokens: entities + relations + template with no chunks
        kg_context = _render_kg_context(entities_str, relations_str, "[]")
        kg_context_tokens = len(tokenizer.encode(kg_context))

        # System prompt overhead; conversation history is not included in the
        # context length calculation
        history_context = ""
        user_prompt = query_param.user_prompt if query_param.user_prompt else ""
        response_type = (
            query_param.response_type
            if query_param.response_type
            else "Multiple Paragraphs"
        )

        # Get the system prompt template from PROMPTS
        sys_prompt_template = global_config.get(
            "system_prompt_template", PROMPTS["rag_response"]
        )

        # Sample system prompt with placeholders filled (excluding context_data)
        sample_sys_prompt = sys_prompt_template.format(
            history=history_context,
            context_data="",  # Empty for overhead calculation
            response_type=response_type,
            user_prompt=user_prompt,
        )
        sys_prompt_template_tokens = len(tokenizer.encode(sample_sys_prompt))

        # Total system prompt overhead = template + query tokens
        query_tokens = len(tokenizer.encode(query))
        sys_prompt_overhead = sys_prompt_template_tokens + query_tokens

        buffer_tokens = 100  # Safety buffer

        # Calculate available tokens for text chunks
        used_tokens = kg_context_tokens + sys_prompt_overhead + buffer_tokens
        available_chunk_tokens = max_total_tokens - used_tokens

        logger.debug(
            f"Token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_overhead}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
        )

        # Apply token truncation to chunks using the dynamic limit
        truncated_chunks = await process_chunks_unified(
            query=query,
            unique_chunks=merged_chunks,
            query_param=query_param,
            global_config=global_config,
            source_type=query_param.mode,
            chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
        )

        # Rebuild text_units_context with truncated chunks
        text_units_context = [
            {
                "id": i + 1,
                "content": chunk["content"],
                "file_path": chunk.get("file_path", "unknown_source"),
            }
            for i, chunk in enumerate(truncated_chunks)
        ]

        logger.debug(
            f"Final chunk processing: {len(merged_chunks)} -> {len(text_units_context)} (chunk available tokens: {available_chunk_tokens})"
        )

    logger.info(
        f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
    )

    # not necessary to use LLM to generate a response
    if not entities_context and not relations_context:
        return None

    # Log where each delivered chunk came from.
    # Entry format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
    if truncated_chunks and chunk_tracking:
        chunk_tracking_log = []
        for chunk in truncated_chunks:
            chunk_id = chunk.get("chunk_id")
            if chunk_id and chunk_id in chunk_tracking:
                tracking_info = chunk_tracking[chunk_id]
                source = tracking_info["source"]
                frequency = tracking_info["frequency"]
                order = tracking_info["order"]
                chunk_tracking_log.append(f"{source}{frequency}/{order}")
            else:
                chunk_tracking_log.append("?0/0")

        if chunk_tracking_log:
            logger.info(f"chunks: {' '.join(chunk_tracking_log)}")

    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
    return _render_kg_context(entities_str, relations_str, text_units_str)
async def _find_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
    query: str = None,
    chunks_vdb: BaseVectorStorage = None,
    chunk_tracking: dict = None,
    query_embedding=None,
):
    """
    Find text chunks related to entities using configurable chunk selection method.

    This function supports two chunk selection strategies:
    1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
    2. VECTOR: Vector similarity-based selection using embedding cosine similarity

    Args:
        node_datas: Entity records; ``source_id`` holds the chunk ids joined by
            GRAPH_FIELD_SEP.
        query_param: Query parameters (not read by this function).
        text_chunks_db: KV storage the chunk contents are fetched from; its
            ``global_config`` and ``embedding_func`` are read as well.
        knowledge_graph_inst: Graph storage (not read by this function).
        query: Query text, required for the VECTOR method.
        chunks_vdb: Chunk vector storage, required for the VECTOR method.
        chunk_tracking: Optional dict updated in place with
            ``chunk_id -> {source, frequency, order}`` for every returned chunk.
        query_embedding: Optional pre-computed embedding of ``query``.

    Returns:
        A list of chunk dicts (copies of the stored data) tagged with
        ``source_type="entity"`` and their ``chunk_id``; empty when nothing
        is found or selected.
    """
    logger.debug(f"Finding text chunks from {len(node_datas)} entities")

    if not node_datas:
        return []

    # Step 1: Collect all text chunks for each entity
    entities_with_chunks = []
    for entity in node_datas:
        if not entity.get("source_id"):
            continue
        chunks = split_string_by_multi_markers(entity["source_id"], [GRAPH_FIELD_SEP])
        if chunks:
            entities_with_chunks.append(
                {
                    "entity_name": entity["entity_name"],
                    "chunks": chunks,
                    "entity_data": entity,
                }
            )

    if not entities_with_chunks:
        logger.warning("No entities with text chunks found")
        return []

    kg_chunk_pick_method = text_chunks_db.global_config.get(
        "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
    )
    max_related_chunks = text_chunks_db.global_config.get(
        "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
    )

    # Step 2: Count chunk occurrences and deduplicate (a chunk stays with the
    # earliest positioned entity that references it)
    chunk_occurrence_count = {}
    for entity_info in entities_with_chunks:
        deduplicated_chunks = []
        for chunk_id in entity_info["chunks"]:
            chunk_occurrence_count[chunk_id] = (
                chunk_occurrence_count.get(chunk_id, 0) + 1
            )
            # count == 1 means first occurrence; later occurrences are skipped
            if chunk_occurrence_count[chunk_id] == 1:
                deduplicated_chunks.append(chunk_id)
        entity_info["chunks"] = deduplicated_chunks

    # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
    total_entity_chunks = 0
    for entity_info in entities_with_chunks:
        sorted_chunks = sorted(
            entity_info["chunks"],
            key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
            reverse=True,
        )
        entity_info["sorted_chunks"] = sorted_chunks
        total_entity_chunks += len(sorted_chunks)

    selected_chunk_ids = []

    # Step 4: Apply the selected chunk selection algorithm
    # Pick by vector similarity:
    #     The order of text chunks aligns with the naive retrieval's destination.
    #     When reranking is disabled, the text chunks delivered to the LLM tend to favor naive retrieval.
    # NOTE(review): if VECTOR is configured but `query` or `chunks_vdb` is
    # missing, neither method runs and an empty list is returned - confirm
    # this is intended.
    if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
        num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)

        embedding_func_config = text_chunks_db.embedding_func
        actual_embedding_func = (
            getattr(embedding_func_config, "func", None)
            if embedding_func_config
            else None
        )
        if not actual_embedding_func:
            # Explicit fallback instead of relying on a later TypeError
            logger.warning("No embedding function found, falling back to WEIGHT method")
            kg_chunk_pick_method = "WEIGHT"
        else:
            try:
                selected_chunk_ids = await pick_by_vector_similarity(
                    query=query,
                    text_chunks_storage=text_chunks_db,
                    chunks_vdb=chunks_vdb,
                    num_of_chunks=num_of_chunks,
                    entity_info=entities_with_chunks,
                    embedding_func=actual_embedding_func,
                    query_embedding=query_embedding,
                )
            except Exception as e:
                logger.error(
                    f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
                )
                selected_chunk_ids = []
                kg_chunk_pick_method = "WEIGHT"
            else:
                if not selected_chunk_ids:
                    kg_chunk_pick_method = "WEIGHT"
                    logger.warning(
                        "No entity-related chunks selected by vector similarity, falling back to WEIGHT method"
                    )
                else:
                    logger.info(
                        f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity"
                    )

    if kg_chunk_pick_method == "WEIGHT":
        # Pick by entity and chunk weight:
        #     When reranking is disabled, delivers more solely KG related chunks to the LLM
        selected_chunk_ids = pick_by_weighted_polling(
            entities_with_chunks, max_related_chunks, min_related_chunks=1
        )

        logger.info(
            f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling"
        )

    if not selected_chunk_ids:
        return []

    # Step 5: Batch retrieve chunk data
    # Remove duplicates while preserving order
    unique_chunk_ids = list(dict.fromkeys(selected_chunk_ids))
    chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)

    # Step 6: Build result chunks with valid data and update chunk tracking
    result_chunks = []
    for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)):
        if chunk_data is None or "content" not in chunk_data:
            continue
        chunk_data_copy = chunk_data.copy()
        chunk_data_copy["source_type"] = "entity"
        chunk_data_copy["chunk_id"] = chunk_id  # Add chunk_id for deduplication
        result_chunks.append(chunk_data_copy)

        # Update chunk tracking if provided
        if chunk_tracking is not None:
            chunk_tracking[chunk_id] = {
                "source": "E",
                "frequency": chunk_occurrence_count.get(chunk_id, 1),
                "order": i + 1,  # 1-based position among the selected chunk ids
            }

    return result_chunks
edge_datas = [] for k in results: pair = (k["src_id"], k["tgt_id"]) edge_props = edge_data_dict.get(pair) if edge_props is not None: if "weight" not in edge_props: logger.warning( f"Edge {pair} missing 'weight' attribute, using default value 1.0" ) edge_props["weight"] = 1.0 # Keep edge data without rank, maintain vector search order combined = { "src_id": k["src_id"], "tgt_id": k["tgt_id"], "created_at": k.get("created_at", None), **edge_props, } edge_datas.append(combined) # Relations maintain vector search order (sorted by similarity) use_entities = await _find_most_related_entities_from_relationships( edge_datas, query_param, knowledge_graph_inst, ) logger.info( f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations" ) return edge_datas, use_entities async def _find_most_related_entities_from_relationships( edge_datas: list[dict], query_param: QueryParam, knowledge_graph_inst: BaseGraphStorage, ): entity_names = [] seen = set() for e in edge_datas: if e["src_id"] not in seen: entity_names.append(e["src_id"]) seen.add(e["src_id"]) if e["tgt_id"] not in seen: entity_names.append(e["tgt_id"]) seen.add(e["tgt_id"]) # Only get nodes data, no need for node degrees nodes_dict = await knowledge_graph_inst.get_nodes_batch(entity_names) # Rebuild the list in the same order as entity_names node_datas = [] for entity_name in entity_names: node = nodes_dict.get(entity_name) if node is None: logger.warning(f"Node '{entity_name}' not found in batch retrieval.") continue # Combine the node data with the entity name, no rank needed combined = {**node, "entity_name": entity_name} node_datas.append(combined) return node_datas async def _find_related_text_unit_from_relations( edge_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, entity_chunks: list[dict] = None, query: str = None, chunks_vdb: BaseVectorStorage = None, chunk_tracking: dict = None, query_embedding=None, ): """ Find text chunks related to relationships using configurable 
chunk selection method. This function supports two chunk selection strategies: 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count 2. VECTOR: Vector similarity-based selection using embedding cosine similarity """ logger.debug(f"Finding text chunks from {len(edge_datas)} relations") if not edge_datas: return [] # Step 1: Collect all text chunks for each relationship relations_with_chunks = [] for relation in edge_datas: if relation.get("source_id"): chunks = split_string_by_multi_markers( relation["source_id"], [GRAPH_FIELD_SEP] ) if chunks: # Build relation identifier if "src_tgt" in relation: rel_key = tuple(sorted(relation["src_tgt"])) else: rel_key = tuple( sorted([relation.get("src_id"), relation.get("tgt_id")]) ) relations_with_chunks.append( { "relation_key": rel_key, "chunks": chunks, "relation_data": relation, } ) if not relations_with_chunks: logger.warning("No relation-related chunks found") return [] kg_chunk_pick_method = text_chunks_db.global_config.get( "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD ) max_related_chunks = text_chunks_db.global_config.get( "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER ) # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships) # Also remove duplicates with entity_chunks # Extract chunk IDs from entity_chunks for deduplication entity_chunk_ids = set() if entity_chunks: for chunk in entity_chunks: chunk_id = chunk.get("chunk_id") if chunk_id: entity_chunk_ids.add(chunk_id) chunk_occurrence_count = {} # Track unique chunk_ids that have been removed to avoid double counting removed_entity_chunk_ids = set() for relation_info in relations_with_chunks: deduplicated_chunks = [] for chunk_id in relation_info["chunks"]: # Skip chunks that already exist in entity_chunks if chunk_id in entity_chunk_ids: # Only count each unique chunk_id once removed_entity_chunk_ids.add(chunk_id) continue chunk_occurrence_count[chunk_id] = ( 
chunk_occurrence_count.get(chunk_id, 0) + 1 ) # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) if chunk_occurrence_count[chunk_id] == 1: deduplicated_chunks.append(chunk_id) # count > 1 means this chunk appeared in an earlier relationship, so skip it # Update relationship's chunks to deduplicated chunks relation_info["chunks"] = deduplicated_chunks # Check if any relations still have chunks after deduplication relations_with_chunks = [ relation_info for relation_info in relations_with_chunks if relation_info["chunks"] ] if not relations_with_chunks: logger.info( f"Find no additional relations-related chunks from {len(edge_datas)} relations" ) return [] # Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority) total_relation_chunks = 0 for relation_info in relations_with_chunks: sorted_chunks = sorted( relation_info["chunks"], key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), reverse=True, ) relation_info["sorted_chunks"] = sorted_chunks total_relation_chunks += len(sorted_chunks) logger.info( f"Find {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations ({len(removed_entity_chunk_ids)} duplicated chunks removed)" ) # Step 4: Apply the selected chunk selection algorithm selected_chunk_ids = [] # Initialize to avoid UnboundLocalError if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2) # Get embedding function from global config embedding_func_config = text_chunks_db.embedding_func if not embedding_func_config: logger.warning("No embedding function found, falling back to WEIGHT method") kg_chunk_pick_method = "WEIGHT" else: try: actual_embedding_func = embedding_func_config.func if actual_embedding_func: selected_chunk_ids = await pick_by_vector_similarity( query=query, text_chunks_storage=text_chunks_db, chunks_vdb=chunks_vdb, 
async def naive_query(
    query: str,
    chunks_vdb: BaseVectorStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    """Answer ``query`` from vector-retrieved text chunks only.

    The flow is:

    1. Look the request up in the query cache and return a hit immediately.
    2. Retrieve candidate chunks via ``_get_vector_context``; if none are
       found, return ``PROMPTS["fail_response"]``.
    3. Derive the token budget left for chunks: ``max_total_tokens`` minus
       the rendered prompt template (with empty ``content_data``), the query
       tokens, and a fixed 100-token safety buffer.
    4. Hand the chunks to ``process_chunks_unified`` with that budget and
       serialize the survivors as a JSON list of
       ``{"id", "content", "file_path"}`` objects.
    5. Depending on ``query_param``, return the context only, the system
       prompt only, or the LLM response (which is also written to the cache
       when caching is enabled).

    Args:
        query: The user question; passed to the LLM as the user message.
        chunks_vdb: Vector storage searched for candidate chunks.
        query_param: Per-query options (mode, model override, token limits,
            conversation history, ``only_need_context``/``only_need_prompt``
            flags, streaming flag, ...).
        global_config: Must provide ``"tokenizer"`` and, unless
            ``query_param.model_func`` is set, ``"llm_model_func"``.
        hashing_kv: Optional KV storage backing the LLM response cache. When
            ``None`` the cache-save step is skipped.
        system_prompt: Optional prompt template overriding
            ``PROMPTS["naive_rag_response"]``. It is rendered with
            ``str.format`` using the fields ``content_data``,
            ``response_type``, ``history`` and ``user_prompt``.

    Returns:
        The cached or freshly generated response, the fail response, the
        context block, or the system prompt, as described above. The value
        is whatever the model function returns, so it may be an async
        iterator when ``query_param.stream`` is set.
    """
    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(global_config["llm_model_func"], _priority=5)

    # Handle cache
    args_hash = compute_args_hash(
        query_param.mode,
        query,
        query_param.response_type,
        query_param.top_k,
        query_param.chunk_top_k,
        query_param.max_entity_tokens,
        query_param.max_relation_tokens,
        query_param.max_total_tokens,
        query_param.hl_keywords or [],
        query_param.ll_keywords or [],
        query_param.user_prompt or "",
        query_param.enable_rerank,
    )
    cached_response = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    tokenizer: Tokenizer = global_config["tokenizer"]

    chunks = await _get_vector_context(query, chunks_vdb, query_param)
    if not chunks:
        return PROMPTS["fail_response"]

    # Calculate dynamic token limit for chunks.
    # The global_config value is only consulted when query_param has no
    # ``max_total_tokens`` attribute at all.
    max_total_tokens = getattr(
        query_param,
        "max_total_tokens",
        global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
    )

    # Conversation history is built once and reused for both the overhead
    # estimate and the final prompt.
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )
    history_tokens = len(tokenizer.encode(history_context)) if history_context else 0

    # The same response-type fallback is used for the overhead estimate and
    # the final prompt, so an unset value never renders as "None".
    response_type = query_param.response_type or "Multiple Paragraphs"

    # Use the provided system prompt or default
    sys_prompt_template = system_prompt or PROMPTS["naive_rag_response"]

    # Render the template with empty content_data to measure its fixed cost.
    # NOTE(review): the estimate uses "" when no user prompt is given, while
    # the final prompt falls back to PROMPTS["DEFAULT_USER_PROMPT"]; the
    # safety buffer below presumably absorbs the difference - confirm.
    sample_sys_prompt = sys_prompt_template.format(
        content_data="",
        response_type=response_type,
        history=history_context,
        user_prompt=query_param.user_prompt or "",
    )
    sys_prompt_template_tokens = len(tokenizer.encode(sample_sys_prompt))

    # Total system prompt overhead = template + query tokens
    query_tokens = len(tokenizer.encode(query))
    sys_prompt_overhead = sys_prompt_template_tokens + query_tokens

    buffer_tokens = 100  # Safety buffer

    # Calculate available tokens for chunks
    available_chunk_tokens = max_total_tokens - (sys_prompt_overhead + buffer_tokens)

    logger.debug(
        f"Naive query token allocation - Total: {max_total_tokens}, History: {history_tokens}, SysPrompt: {sys_prompt_overhead}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
    )

    # Process chunks using unified processing with dynamic token limit
    processed_chunks = await process_chunks_unified(
        query=query,
        unique_chunks=chunks,
        query_param=query_param,
        global_config=global_config,
        source_type="vector",
        chunk_token_limit=available_chunk_tokens,
    )

    logger.info(f"Final context: {len(processed_chunks)} chunks")

    # Build text_units_context from processed chunks (ids are 1-based)
    text_units_context = [
        {
            "id": position,
            "content": chunk["content"],
            "file_path": chunk.get("file_path", "unknown_source"),
        }
        for position, chunk in enumerate(processed_chunks, start=1)
    ]
    text_units_str = json.dumps(text_units_context, ensure_ascii=False)

    if query_param.only_need_context:
        return f"""
---Document Chunks(DC)---

```json
{text_units_str}
```

"""

    # Build system prompt
    user_prompt = query_param.user_prompt or PROMPTS["DEFAULT_USER_PROMPT"]
    sys_prompt = sys_prompt_template.format(
        content_data=text_units_str,
        response_type=response_type,
        history=history_context,
        user_prompt=user_prompt,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(
        f"[naive_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
    )

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    # Strip an echoed prompt from plain-string responses.
    # NOTE(review): this slices off len(sys_prompt) characters without
    # checking that the response actually starts with the prompt, and the
    # two trailing replace("", "") calls are no-ops as written (presumably
    # markup tokens were lost at some point) - confirm the intended markers.
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response[len(sys_prompt) :]
            .replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("", "")
            .replace("", "")
            .strip()
        )

    # hashing_kv defaults to None, so guard before touching its config.
    if hashing_kv is not None and hashing_kv.global_config.get("enable_llm_cache"):
        # Save to cache with query parameters
        queryparam_dict = {
            "mode": query_param.mode,
            "response_type": query_param.response_type,
            "top_k": query_param.top_k,
            "chunk_top_k": query_param.chunk_top_k,
            "max_entity_tokens": query_param.max_entity_tokens,
            "max_relation_tokens": query_param.max_relation_tokens,
            "max_total_tokens": query_param.max_total_tokens,
            "hl_keywords": query_param.hl_keywords or [],
            "ll_keywords": query_param.ll_keywords or [],
            "user_prompt": query_param.user_prompt or "",
            "enable_rerank": query_param.enable_rerank,
        }
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                mode=query_param.mode,
                cache_type="query",
                queryparam=queryparam_dict,
            ),
        )

    return response