from __future__ import annotations

from functools import partial
import asyncio
import json
import re
import os
from typing import Any, AsyncIterator
from collections import Counter, defaultdict

from .utils import (
    logger,
    clean_str,
    compute_mdhash_id,
    Tokenizer,
    is_float_regex,
    normalize_extracted_info,
    pack_user_ass_to_openai_messages,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
    process_combine_contexts,
    compute_args_hash,
    handle_cache,
    save_to_cache,
    CacheData,
    get_conversation_turns,
    use_llm_func_with_cache,
    update_chunk_cache_list,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    TextChunkSchema,
    QueryParam,
)
from .prompt import PROMPTS
from .constants import GRAPH_FIELD_SEP
import time
from dotenv import load_dotenv

# use the .env that is inside the current folder
# allows to use a different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)


def chunking_by_token_size(
    tokenizer: Tokenizer,
    content: str,
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
    overlap_token_size: int = 128,
    max_token_size: int = 1024,
) -> list[dict[str, Any]]:
    tokens = tokenizer.encode(content)
    results: list[dict[str, Any]] = []
    if split_by_character:
        raw_chunks = content.split(split_by_character)
        new_chunks = []
        if split_by_character_only:
            for chunk in raw_chunks:
                _tokens = tokenizer.encode(chunk)
                new_chunks.append((len(_tokens), chunk))
        else:
            for chunk in raw_chunks:
                _tokens = tokenizer.encode(chunk)
                if len(_tokens) > max_token_size:
                    for start in range(
                        0, len(_tokens), max_token_size - overlap_token_size
                    ):
                        chunk_content = tokenizer.decode(
                            _tokens[start : start + max_token_size]
                        )
                        new_chunks.append(
                            (min(max_token_size, len(_tokens) - start), chunk_content)
                        )
                else:
                    new_chunks.append((len(_tokens), chunk))
        for index, (_len, chunk) in enumerate(new_chunks):
            results.append(
                {
                    "tokens": _len,
                    "content": chunk.strip(),
                    "chunk_order_index": index,
                }
            )
    else:
        for index, start in enumerate(
            range(0, len(tokens), max_token_size - overlap_token_size)
        ):
            chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
            results.append(
                {
                    "tokens": min(max_token_size, len(tokens) - start),
                    "content": chunk_content.strip(),
                    "chunk_order_index": index,
                }
            )
    return results


async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
    llm_response_cache: BaseKVStorage | None = None,
) -> str:
    """Handle entity or relation summary.

    For each entity or relation, the input is the already existing description
    combined with the newly extracted description. If it is too long, use the
    LLM to summarize it.
    """
    use_llm_func: callable = global_config["llm_model_func"]
    # Apply higher priority (8) to entity/relation summary tasks
    use_llm_func = partial(use_llm_func, _priority=8)

    tokenizer: Tokenizer = global_config["tokenizer"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
    summary_max_tokens = global_config["summary_to_max_tokens"]

    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    tokens = tokenizer.encode(description)

    ### summarize is not determined here anymore (It's determined by num_fragment now)
    # if len(tokens) < summary_max_tokens:  # No need for summary
    #     return description

    prompt_template = PROMPTS["summarize_entity_descriptions"]
    use_description = tokenizer.decode(tokens[:llm_max_tokens])
    context_base = dict(
        entity_name=entity_or_relation_name,
        description_list=use_description.split(GRAPH_FIELD_SEP),
        language=language,
    )
    use_prompt = prompt_template.format(**context_base)
    logger.debug(f"Trigger summary: {entity_or_relation_name}")

    # Use LLM function with cache (higher priority for summary generation)
    summary = await use_llm_func_with_cache(
        use_prompt,
        use_llm_func,
        llm_response_cache=llm_response_cache,
        max_tokens=summary_max_tokens,
        cache_type="extract",
    )
    return summary


async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
    file_path: str = "unknown_source",
):
    if len(record_attributes) < 4 or '"entity"' not in record_attributes[0]:
        return None

    # Clean and validate entity name
    entity_name = clean_str(record_attributes[1]).strip()
    if not entity_name:
        logger.warning(
            f"Entity extraction error: empty entity name in: {record_attributes}"
        )
        return None

    # Normalize entity name
    entity_name = normalize_extracted_info(entity_name, is_entity=True)

    # Clean and validate entity type
    entity_type = clean_str(record_attributes[2]).strip('"')
    if not entity_type.strip() or entity_type.startswith('("'):
        logger.warning(
            f"Entity extraction error: invalid entity type in: {record_attributes}"
        )
        return None

    # Clean and validate description
    entity_description = clean_str(record_attributes[3])
    entity_description = normalize_extracted_info(entity_description)

    if not entity_description.strip():
        logger.warning(
            f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
        )
        return None

    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=chunk_key,
        file_path=file_path,
    )


async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
    file_path: str = "unknown_source",
):
    if len(record_attributes) < 5 or '"relationship"' not in record_attributes[0]:
        return None

    # add this record as edge
    source = clean_str(record_attributes[1])
    target = clean_str(record_attributes[2])

    # Normalize source and target entity names
    source = normalize_extracted_info(source, is_entity=True)
    target = normalize_extracted_info(target, is_entity=True)
    if source == target:
        logger.debug(
            f"Relationship source and target are the same in: {record_attributes}"
        )
        return None

    edge_description = clean_str(record_attributes[3])
    edge_description = normalize_extracted_info(edge_description)

    edge_keywords = normalize_extracted_info(
        clean_str(record_attributes[4]), is_entity=True
    )
    edge_keywords = edge_keywords.replace("，", ",")

    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1].strip('"').strip("'"))
        if is_float_regex(record_attributes[-1].strip('"').strip("'"))
        else 1.0
    )
    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
description=edge_description, keywords=edge_keywords, source_id=edge_source_id, file_path=file_path, ) async def _rebuild_knowledge_from_chunks( entities_to_rebuild: dict[str, set[str]], relationships_to_rebuild: dict[tuple[str, str], set[str]], knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_storage: BaseKVStorage, llm_response_cache: BaseKVStorage, global_config: dict[str, str], pipeline_status: dict | None = None, pipeline_status_lock=None, ) -> None: """Rebuild entity and relationship descriptions from cached extraction results This method uses cached LLM extraction results instead of calling LLM again, following the same approach as the insert process. Args: entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data} """ if not entities_to_rebuild and not relationships_to_rebuild: return rebuilt_entities_count = 0 rebuilt_relationships_count = 0 # Get all referenced chunk IDs all_referenced_chunk_ids = set() for chunk_ids in entities_to_rebuild.values(): all_referenced_chunk_ids.update(chunk_ids) for chunk_ids in relationships_to_rebuild.values(): all_referenced_chunk_ids.update(chunk_ids) status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) # Get cached extraction results for these chunks using storage # cached_results: chunk_id -> [list of extraction result from LLM cache sorted by created_at] cached_results = await _get_cached_extraction_results( llm_response_cache, all_referenced_chunk_ids, text_chunks_storage=text_chunks_storage, ) if not cached_results: status_message = "No cached extraction results found, cannot rebuild" logger.warning(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) return # Process cached results to get entities and relationships for each chunk chunk_entities = {} # chunk_id -> {entity_name: [entity_data]} chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]} for chunk_id, extraction_results in cached_results.items(): try: # Handle multiple extraction results per chunk chunk_entities[chunk_id] = defaultdict(list) chunk_relationships[chunk_id] = defaultdict(list) # process multiple LLM extraction results for a single chunk_id for extraction_result in extraction_results: entities, relationships = await _parse_extraction_result( text_chunks_storage=text_chunks_storage, extraction_result=extraction_result, chunk_id=chunk_id, ) # Merge entities and relationships from this extraction result # Only keep the first occurrence of each entity_name in the same chunk_id for entity_name, entity_list in entities.items(): if ( entity_name not in chunk_entities[chunk_id] or len(chunk_entities[chunk_id][entity_name]) == 0 ): chunk_entities[chunk_id][entity_name].extend(entity_list) # Only keep the first occurrence of each rel_key in the same chunk_id for rel_key, rel_list in relationships.items(): if ( rel_key not in chunk_relationships[chunk_id] or 
len(chunk_relationships[chunk_id][rel_key]) == 0 ): chunk_relationships[chunk_id][rel_key].extend(rel_list) except Exception as e: status_message = ( f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" ) logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) continue # Rebuild entities for entity_name, chunk_ids in entities_to_rebuild.items(): try: await _rebuild_single_entity( knowledge_graph_inst=knowledge_graph_inst, entities_vdb=entities_vdb, entity_name=entity_name, chunk_ids=chunk_ids, chunk_entities=chunk_entities, llm_response_cache=llm_response_cache, global_config=global_config, ) rebuilt_entities_count += 1 status_message = ( f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks" ) logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) except Exception as e: status_message = f"Failed to rebuild entity {entity_name}: {e}" logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) # Rebuild relationships for (src, tgt), chunk_ids in relationships_to_rebuild.items(): try: await _rebuild_single_relationship( knowledge_graph_inst=knowledge_graph_inst, relationships_vdb=relationships_vdb, src=src, tgt=tgt, chunk_ids=chunk_ids, chunk_relationships=chunk_relationships, llm_response_cache=llm_response_cache, global_config=global_config, ) rebuilt_relationships_count += 1 status_message = ( f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks" ) logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) except Exception as e: status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships." 
logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) async def _get_cached_extraction_results( llm_response_cache: BaseKVStorage, chunk_ids: set[str], text_chunks_storage: BaseKVStorage, ) -> dict[str, list[str]]: """Get cached extraction results for specific chunk IDs Args: llm_response_cache: LLM response cache storage chunk_ids: Set of chunk IDs to get cached results for text_chunks_data: Pre-loaded chunk data (optional, for performance) text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None) Returns: Dict mapping chunk_id -> list of extraction_result_text """ cached_results = {} # Collect all LLM cache IDs from chunks all_cache_ids = set() # Read from storage chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list): if chunk_data and isinstance(chunk_data, dict): llm_cache_list = chunk_data.get("llm_cache_list", []) if llm_cache_list: all_cache_ids.update(llm_cache_list) else: logger.warning( f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}" ) if not all_cache_ids: logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") return cached_results # Batch get LLM cache entries cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids)) # Process cache entries and group by chunk_id valid_entries = 0 for cache_id, cache_entry in zip(all_cache_ids, cache_data_list): if ( cache_entry is not None and isinstance(cache_entry, dict) and cache_entry.get("cache_type") == "extract" and cache_entry.get("chunk_id") in chunk_ids ): chunk_id = cache_entry["chunk_id"] extraction_result = cache_entry["return"] create_time = cache_entry.get( "create_time", 0 ) # Get creation time, default to 0 valid_entries += 1 # Support multiple LLM caches per chunk if chunk_id not in cached_results: cached_results[chunk_id] = [] # Store tuple with extraction result and creation time for sorting cached_results[chunk_id].append((extraction_result, create_time)) # Sort extraction results by create_time for each chunk for chunk_id in cached_results: # Sort by create_time (x[1]), then extract only extraction_result (x[0]) cached_results[chunk_id].sort(key=lambda x: x[1]) cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]] logger.info( f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results" ) return cached_results async def _parse_extraction_result( text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str ) -> tuple[dict, dict]: """Parse cached extraction result using the same logic as extract_entities Args: text_chunks_storage: Text chunks storage to get chunk data extraction_result: The cached LLM extraction result chunk_id: The chunk ID for source tracking Returns: Tuple of (entities_dict, relationships_dict) """ # Get chunk data for file_path from storage chunk_data = await text_chunks_storage.get_by_id(chunk_id) file_path = ( chunk_data.get("file_path", "unknown_source") if chunk_data else "unknown_source" ) context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], ) maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) # Parse the extraction result using the same logic as in 
extract_entities records = split_string_by_multi_markers( extraction_result, [context_base["record_delimiter"], context_base["completion_delimiter"]], ) for record in records: record = re.search(r"\((.*)\)", record) if record is None: continue record = record.group(1) record_attributes = split_string_by_multi_markers( record, [context_base["tuple_delimiter"]] ) # Try to parse as entity entity_data = await _handle_single_entity_extraction( record_attributes, chunk_id, file_path ) if entity_data is not None: maybe_nodes[entity_data["entity_name"]].append(entity_data) continue # Try to parse as relationship relationship_data = await _handle_single_relationship_extraction( record_attributes, chunk_id, file_path ) if relationship_data is not None: maybe_edges[ (relationship_data["src_id"], relationship_data["tgt_id"]) ].append(relationship_data) return dict(maybe_nodes), dict(maybe_edges) async def _rebuild_single_entity( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, entity_name: str, chunk_ids: set[str], chunk_entities: dict, llm_response_cache: BaseKVStorage, global_config: dict[str, str], ) -> None: """Rebuild a single entity from cached extraction results""" # Get current entity data current_entity = await knowledge_graph_inst.get_node(entity_name) if not current_entity: return # Helper function to update entity in both graph and vector storage async def _update_entity_storage( final_description: str, entity_type: str, file_paths: set[str] ): # Update entity in graph storage updated_entity_data = { **current_entity, "description": final_description, "entity_type": entity_type, "source_id": GRAPH_FIELD_SEP.join(chunk_ids), "file_path": GRAPH_FIELD_SEP.join(file_paths) if file_paths else current_entity.get("file_path", "unknown_source"), } await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data) # Update entity in vector database entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-") # Delete old vector record first try: await entities_vdb.delete([entity_vdb_id]) except Exception as e: logger.debug( f"Could not delete old entity vector record {entity_vdb_id}: {e}" ) # Insert new vector record entity_content = f"{entity_name}\n{final_description}" await entities_vdb.upsert( { entity_vdb_id: { "content": entity_content, "entity_name": entity_name, "source_id": updated_entity_data["source_id"], "description": final_description, "entity_type": entity_type, "file_path": updated_entity_data["file_path"], } } ) # Helper function to generate final description with optional LLM summary async def _generate_final_description(combined_description: str) -> str: if len(combined_description) > global_config["summary_to_max_tokens"]: return await _handle_entity_relation_summary( entity_name, combined_description, global_config, llm_response_cache=llm_response_cache, ) else: return combined_description # Collect all entity data from relevant chunks all_entity_data = [] for chunk_id in chunk_ids: if chunk_id in chunk_entities and entity_name in chunk_entities[chunk_id]: all_entity_data.extend(chunk_entities[chunk_id][entity_name]) if not all_entity_data: logger.warning( f"No cached entity data found for {entity_name}, trying to rebuild from relationships" ) # Get all edges connected to this entity edges = await knowledge_graph_inst.get_node_edges(entity_name) if not edges: logger.warning(f"No relationships found for entity {entity_name}") return # Collect relationship data to extract entity information relationship_descriptions = [] file_paths = set() # Get 
edge data for all connected relationships for src_id, tgt_id in edges: edge_data = await knowledge_graph_inst.get_edge(src_id, tgt_id) if edge_data: if edge_data.get("description"): relationship_descriptions.append(edge_data["description"]) if edge_data.get("file_path"): edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP) file_paths.update(edge_file_paths) # Generate description from relationships or fallback to current if relationship_descriptions: combined_description = GRAPH_FIELD_SEP.join(relationship_descriptions) final_description = await _generate_final_description(combined_description) else: final_description = current_entity.get("description", "") entity_type = current_entity.get("entity_type", "UNKNOWN") await _update_entity_storage(final_description, entity_type, file_paths) return # Process cached entity data descriptions = [] entity_types = [] file_paths = set() for entity_data in all_entity_data: if entity_data.get("description"): descriptions.append(entity_data["description"]) if entity_data.get("entity_type"): entity_types.append(entity_data["entity_type"]) if entity_data.get("file_path"): file_paths.add(entity_data["file_path"]) # Combine all descriptions combined_description = ( GRAPH_FIELD_SEP.join(descriptions) if descriptions else current_entity.get("description", "") ) # Get most common entity type entity_type = ( max(set(entity_types), key=entity_types.count) if entity_types else current_entity.get("entity_type", "UNKNOWN") ) # Generate final description and update storage final_description = await _generate_final_description(combined_description) await _update_entity_storage(final_description, entity_type, file_paths) async def _rebuild_single_relationship( knowledge_graph_inst: BaseGraphStorage, relationships_vdb: BaseVectorStorage, src: str, tgt: str, chunk_ids: set[str], chunk_relationships: dict, llm_response_cache: BaseKVStorage, global_config: dict[str, str], ) -> None: """Rebuild a single relationship from cached extraction results""" # Get current relationship data current_relationship = await knowledge_graph_inst.get_edge(src, tgt) if not current_relationship: return # Collect all relationship data from relevant chunks all_relationship_data = [] for chunk_id in chunk_ids: if chunk_id in chunk_relationships: # Check both (src, tgt) and (tgt, src) since relationships can be bidirectional for edge_key in [(src, tgt), (tgt, src)]: if edge_key in chunk_relationships[chunk_id]: all_relationship_data.extend( chunk_relationships[chunk_id][edge_key] ) if not all_relationship_data: logger.warning(f"No cached relationship data found for {src}-{tgt}") return # Merge descriptions and keywords descriptions = [] keywords = [] weights = [] file_paths = set() for rel_data in all_relationship_data: if rel_data.get("description"): descriptions.append(rel_data["description"]) if rel_data.get("keywords"): keywords.append(rel_data["keywords"]) if rel_data.get("weight"): weights.append(rel_data["weight"]) if rel_data.get("file_path"): file_paths.add(rel_data["file_path"]) # Combine descriptions and keywords combined_description = ( GRAPH_FIELD_SEP.join(descriptions) if descriptions else current_relationship.get("description", "") ) combined_keywords = ( ", ".join(set(keywords)) if keywords else current_relationship.get("keywords", "") ) # weight = ( # sum(weights) / len(weights) # if weights # else current_relationship.get("weight", 1.0) # ) weight = sum(weights) if weights else current_relationship.get("weight", 1.0) # Use summary if description is too long if 
len(combined_description) > global_config["summary_to_max_tokens"]: final_description = await _handle_entity_relation_summary( f"{src}-{tgt}", combined_description, global_config, llm_response_cache=llm_response_cache, ) else: final_description = combined_description # Update relationship in graph storage updated_relationship_data = { **current_relationship, "description": final_description, "keywords": combined_keywords, "weight": weight, "source_id": GRAPH_FIELD_SEP.join(chunk_ids), "file_path": GRAPH_FIELD_SEP.join(file_paths) if file_paths else current_relationship.get("file_path", "unknown_source"), } await knowledge_graph_inst.upsert_edge(src, tgt, updated_relationship_data) # Update relationship in vector database rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-") rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-") # Delete old vector records first (both directions to be safe) try: await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse]) except Exception as e: logger.debug( f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}" ) # Insert new vector record rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}" await relationships_vdb.upsert( { rel_vdb_id: { "src_id": src, "tgt_id": tgt, "source_id": updated_relationship_data["source_id"], "content": rel_content, "keywords": combined_keywords, "description": final_description, "weight": weight, "file_path": updated_relationship_data["file_path"], } } ) async def _merge_nodes_then_upsert( entity_name: str, nodes_data: list[dict], knowledge_graph_inst: BaseGraphStorage, global_config: dict, pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ): """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] already_source_ids = [] already_description = [] already_file_paths = [] already_node = await knowledge_graph_inst.get_node(entity_name) if already_node: already_entity_types.append(already_node["entity_type"]) already_source_ids.extend( split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP]) ) already_file_paths.extend( split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP]) ) already_description.append(already_node["description"]) entity_type = sorted( Counter( [dp["entity_type"] for dp in nodes_data] + already_entity_types ).items(), key=lambda x: x[1], reverse=True, )[0][0] description = GRAPH_FIELD_SEP.join( sorted(set([dp["description"] for dp in nodes_data] + already_description)) ) source_id = GRAPH_FIELD_SEP.join( set([dp["source_id"] for dp in nodes_data] + already_source_ids) ) file_path = GRAPH_FIELD_SEP.join( set([dp["file_path"] for dp in nodes_data] + already_file_paths) ) force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] num_fragment = description.count(GRAPH_FIELD_SEP) + 1 num_new_fragment = len(set([dp["description"] for dp in nodes_data])) if num_fragment > 1: if num_fragment >= force_llm_summary_on_merge: status_message = f"LLM merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) description = await _handle_entity_relation_summary( entity_name, description, global_config, llm_response_cache, ) 
else: status_message = f"Merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) node_data = dict( entity_id=entity_name, entity_type=entity_type, description=description, source_id=source_id, file_path=file_path, created_at=int(time.time()), ) await knowledge_graph_inst.upsert_node( entity_name, node_data=node_data, ) node_data["entity_name"] = entity_name return node_data async def _merge_edges_then_upsert( src_id: str, tgt_id: str, edges_data: list[dict], knowledge_graph_inst: BaseGraphStorage, global_config: dict, pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ): if src_id == tgt_id: return None already_weights = [] already_source_ids = [] already_description = [] already_keywords = [] already_file_paths = [] if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) # Handle the case where get_edge returns None or missing fields if already_edge: # Get weight with default 0.0 if missing already_weights.append(already_edge.get("weight", 0.0)) # Get source_id with empty string default if missing or None if already_edge.get("source_id") is not None: already_source_ids.extend( split_string_by_multi_markers( already_edge["source_id"], [GRAPH_FIELD_SEP] ) ) # Get file_path with empty string default if missing or None if already_edge.get("file_path") is not None: already_file_paths.extend( split_string_by_multi_markers( already_edge["file_path"], [GRAPH_FIELD_SEP] ) ) # Get description with empty string default if missing or None if already_edge.get("description") is not None: already_description.append(already_edge["description"]) # Get keywords with empty string default if missing or None if already_edge.get("keywords") is not None: already_keywords.extend( split_string_by_multi_markers( already_edge["keywords"], [GRAPH_FIELD_SEP] ) ) # Process edges_data with None checks weight = sum([dp["weight"] for dp in edges_data] + already_weights) description = GRAPH_FIELD_SEP.join( sorted( set( [dp["description"] for dp in edges_data if dp.get("description")] + already_description ) ) ) # Split all existing and new keywords into individual terms, then combine and deduplicate all_keywords = set() # Process already_keywords (which are comma-separated) for keyword_str in already_keywords: if keyword_str: # Skip empty strings all_keywords.update(k.strip() for k in keyword_str.split(",") if k.strip()) # Process new keywords from edges_data for edge in edges_data: if edge.get("keywords"): all_keywords.update( k.strip() for k in edge["keywords"].split(",") if k.strip() ) # Join all unique keywords with commas keywords = ",".join(sorted(all_keywords)) source_id = GRAPH_FIELD_SEP.join( set( [dp["source_id"] for dp in edges_data if dp.get("source_id")] + already_source_ids ) ) file_path = GRAPH_FIELD_SEP.join( set( [dp["file_path"] for dp in edges_data if dp.get("file_path")] + already_file_paths ) ) for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): # # Discard this edge if the node does not exist # if need_insert_id == src_id: # logger.warning( # f"Discard edge: {src_id} - {tgt_id} | Source node missing" # ) # else: # logger.warning( # f"Discard edge: {src_id} - {tgt_id} | Target 
node missing" # ) # return None await knowledge_graph_inst.upsert_node( need_insert_id, node_data={ "entity_id": need_insert_id, "source_id": source_id, "description": description, "entity_type": "UNKNOWN", "file_path": file_path, "created_at": int(time.time()), }, ) force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] num_fragment = description.count(GRAPH_FIELD_SEP) + 1 num_new_fragment = len( set([dp["description"] for dp in edges_data if dp.get("description")]) ) if num_fragment > 1: if num_fragment >= force_llm_summary_on_merge: status_message = f"LLM merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) description = await _handle_entity_relation_summary( f"({src_id}, {tgt_id})", description, global_config, llm_response_cache, ) else: status_message = f"Merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) await knowledge_graph_inst.upsert_edge( src_id, tgt_id, edge_data=dict( weight=weight, description=description, keywords=keywords, source_id=source_id, file_path=file_path, created_at=int(time.time()), ), ) edge_data = dict( src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, source_id=source_id, file_path=file_path, created_at=int(time.time()), ) return edge_data async def merge_nodes_and_edges( chunk_results: list, knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, global_config: dict[str, str], pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, current_file_number: int = 0, total_files: int = 0, file_path: str = "unknown_source", ) -> None: """Merge nodes and edges from extraction results Args: chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships knowledge_graph_inst: Knowledge graph storage entity_vdb: Entity vector database relationships_vdb: Relationship vector database global_config: Global configuration pipeline_status: Pipeline status dictionary pipeline_status_lock: Lock for pipeline status llm_response_cache: LLM response cache """ # Get lock manager from shared storage from .kg.shared_storage import get_graph_db_lock # Collect all nodes and edges from all chunks all_nodes = defaultdict(list) all_edges = defaultdict(list) for maybe_nodes, maybe_edges in chunk_results: # Collect nodes for entity_name, entities in maybe_nodes.items(): all_nodes[entity_name].extend(entities) # Collect edges with sorted keys for undirected graph for edge_key, edges in maybe_edges.items(): sorted_edge_key = tuple(sorted(edge_key)) all_edges[sorted_edge_key].extend(edges) # Centralized processing of all nodes and edges entities_data = [] relationships_data = [] # Merge nodes and edges # Use graph database lock to ensure atomic merges and updates graph_db_lock = get_graph_db_lock(enable_logging=False) async with graph_db_lock: async with pipeline_status_lock: log_message = ( f"Merging stage {current_file_number}/{total_files}: {file_path}" ) 
logger.info(log_message) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) # Process and update all entities at once for entity_name, entities in all_nodes.items(): entity_data = await _merge_nodes_then_upsert( entity_name, entities, knowledge_graph_inst, global_config, pipeline_status, pipeline_status_lock, llm_response_cache, ) entities_data.append(entity_data) # Process and update all relationships at once for edge_key, edges in all_edges.items(): edge_data = await _merge_edges_then_upsert( edge_key[0], edge_key[1], edges, knowledge_graph_inst, global_config, pipeline_status, pipeline_status_lock, llm_response_cache, ) if edge_data is not None: relationships_data.append(edge_data) # Update total counts total_entities_count = len(entities_data) total_relations_count = len(relationships_data) log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}" logger.info(log_message) if pipeline_status is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) # Update vector databases with all collected data if entity_vdb is not None and entities_data: data_for_vdb = { compute_mdhash_id(dp["entity_name"], prefix="ent-"): { "entity_name": dp["entity_name"], "entity_type": dp["entity_type"], "content": f"{dp['entity_name']}\n{dp['description']}", "source_id": dp["source_id"], "file_path": dp.get("file_path", "unknown_source"), } for dp in entities_data } await entity_vdb.upsert(data_for_vdb) log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}" logger.info(log_message) if pipeline_status is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) if relationships_vdb is not None and relationships_data: data_for_vdb = { compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], "keywords": dp["keywords"], "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", "source_id": dp["source_id"], "file_path": dp.get("file_path", "unknown_source"), } for dp in relationships_data } await relationships_vdb.upsert(data_for_vdb) async def extract_entities( chunks: dict[str, TextChunkSchema], global_config: dict[str, str], pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, text_chunks_storage: BaseKVStorage | None = None, ) -> list: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] ordered_chunks = list(chunks.items()) # add language and example number params to prompt language = global_config["addon_params"].get( "language", PROMPTS["DEFAULT_LANGUAGE"] ) entity_types = global_config["addon_params"].get( "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"] ) example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len(PROMPTS["entity_extraction_examples"]): examples = "\n".join( PROMPTS["entity_extraction_examples"][: int(example_number)] ) else: examples = "\n".join(PROMPTS["entity_extraction_examples"]) example_context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], entity_types=", 
".join(entity_types), language=language, ) # add example's format examples = examples.format(**example_context_base) entity_extract_prompt = PROMPTS["entity_extraction"] context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], entity_types=",".join(entity_types), examples=examples, language=language, ) continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base) if_loop_prompt = PROMPTS["entity_if_loop_extraction"] processed_chunks = 0 total_chunks = len(ordered_chunks) async def _process_extraction_result( result: str, chunk_key: str, file_path: str = "unknown_source" ): """Process a single extraction result (either initial or gleaning) Args: result (str): The extraction result to process chunk_key (str): The chunk key for source tracking file_path (str): The file path for citation Returns: tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships """ maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) records = split_string_by_multi_markers( result, [context_base["record_delimiter"], context_base["completion_delimiter"]], ) for record in records: record = re.search(r"\((.*)\)", record) if record is None: continue record = record.group(1) record_attributes = split_string_by_multi_markers( record, [context_base["tuple_delimiter"]] ) if_entities = await _handle_single_entity_extraction( record_attributes, chunk_key, file_path ) if if_entities is not None: maybe_nodes[if_entities["entity_name"]].append(if_entities) continue if_relation = await _handle_single_relationship_extraction( record_attributes, chunk_key, file_path ) if if_relation is not None: maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( if_relation ) return maybe_nodes, maybe_edges async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): """Process a single chunk Args: chunk_key_dp (tuple[str, TextChunkSchema]): ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) Returns: tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships """ nonlocal processed_chunks chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] # Get file path from chunk data or use default file_path = chunk_dp.get("file_path", "unknown_source") # Create cache keys collector for batch processing cache_keys_collector = [] # Get initial extraction hint_prompt = entity_extract_prompt.format( **{**context_base, "input_text": content} ) final_result = await use_llm_func_with_cache( hint_prompt, use_llm_func, llm_response_cache=llm_response_cache, cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, ) # Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache) history = pack_user_ass_to_openai_messages(hint_prompt, final_result) # Process initial extraction with file path maybe_nodes, maybe_edges = await _process_extraction_result( final_result, chunk_key, file_path ) # Process additional gleaning results for now_glean_index in range(entity_extract_max_gleaning): glean_result = await use_llm_func_with_cache( continue_prompt, use_llm_func, llm_response_cache=llm_response_cache, history_messages=history, cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, ) history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) # Process gleaning result 
            # separately with file path
            glean_nodes, glean_edges = await _process_extraction_result(
                glean_result, chunk_key, file_path
            )

            # Merge results - only add entities and edges with new names
            for entity_name, entities in glean_nodes.items():
                if (
                    entity_name not in maybe_nodes
                ):  # Only accept entities with new names in the gleaning stage
                    maybe_nodes[entity_name].extend(entities)
            for edge_key, edges in glean_edges.items():
                if (
                    edge_key not in maybe_edges
                ):  # Only accept edges with new names in the gleaning stage
                    maybe_edges[edge_key].extend(edges)

            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func_with_cache(
                if_loop_prompt,
                use_llm_func,
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
                cache_keys_collector=cache_keys_collector,
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        # Batch update chunk's llm_cache_list with all collected cache keys
        if cache_keys_collector and text_chunks_storage:
            await update_chunk_cache_list(
                chunk_key,
                text_chunks_storage,
                cache_keys_collector,
                "entity_extraction",
            )

        processed_chunks += 1
        entities_count = len(maybe_nodes)
        relations_count = len(maybe_edges)
        log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel"
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)

        # Return the extracted nodes and edges for centralized processing
        return maybe_nodes, maybe_edges

    # Get max async tasks limit from global_config
    llm_model_max_async = global_config.get("llm_model_max_async", 4)
    semaphore = asyncio.Semaphore(llm_model_max_async)

    async def _process_with_semaphore(chunk):
        async with semaphore:
            return await _process_single_content(chunk)

    tasks = []
    for c in ordered_chunks:
        task = asyncio.create_task(_process_with_semaphore(c))
        tasks.append(task)

    # Wait for tasks to complete or for the first exception to occur
    # This allows us to cancel remaining tasks if any task fails
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

    # Check if any task raised an exception
    for task in done:
        if task.exception():
            # If a task failed, cancel all pending tasks
            # This prevents unnecessary processing since the parent function will abort anyway
            for pending_task in pending:
                pending_task.cancel()

            # Wait for cancellation to complete
            if pending:
                await asyncio.wait(pending)

            # Re-raise the exception to notify the caller
            raise task.exception()

    # If all tasks completed successfully, collect results
    chunk_results = [task.result() for task in tasks]

    # Return the chunk_results for later processing in merge_nodes_and_edges
    return chunk_results


async def kg_query(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
    chunks_vdb: BaseVectorStorage = None,
) -> str | AsyncIterator[str]:
    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    # Handle cache
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val
= await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) if cached_response is not None: return cached_response hl_keywords, ll_keywords = await get_keywords_from_query( query, query_param, global_config, hashing_kv ) logger.debug(f"High-level keywords: {hl_keywords}") logger.debug(f"Low-level keywords: {ll_keywords}") # Handle empty keywords if hl_keywords == [] and ll_keywords == []: logger.warning("low_level_keywords and high_level_keywords is empty") return PROMPTS["fail_response"] if ll_keywords == [] and query_param.mode in ["local", "hybrid"]: logger.warning( "low_level_keywords is empty, switching from %s mode to global mode", query_param.mode, ) query_param.mode = "global" if hl_keywords == [] and query_param.mode in ["global", "hybrid"]: logger.warning( "high_level_keywords is empty, switching from %s mode to local mode", query_param.mode, ) query_param.mode = "local" ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" # Build context context = await _build_query_context( ll_keywords_str, hl_keywords_str, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, query_param, chunks_vdb, ) if query_param.only_need_context: return context if context is not None else PROMPTS["fail_response"] if context is None: return PROMPTS["fail_response"] # Process conversation history history_context = "" if query_param.conversation_history: history_context = get_conversation_turns( query_param.conversation_history, query_param.history_turns ) # Build system prompt user_prompt = ( query_param.user_prompt if query_param.user_prompt else PROMPTS["DEFAULT_USER_PROMPT"] ) sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type, history=history_context, user_prompt=user_prompt, ) if query_param.only_need_prompt: return sys_prompt tokenizer: Tokenizer = global_config["tokenizer"] len_of_prompts = len(tokenizer.encode(query + sys_prompt)) logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}") response = await use_model_func( query, system_prompt=sys_prompt, stream=query_param.stream, ) if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) if hashing_kv.global_config.get("enable_llm_cache"): # Save to cache await save_to_cache( hashing_kv, CacheData( args_hash=args_hash, content=response, prompt=query, quantized=quantized, min_val=min_val, max_val=max_val, mode=query_param.mode, cache_type="query", ), ) return response async def get_keywords_from_query( query: str, query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ) -> tuple[list[str], list[str]]: """ Retrieves high-level and low-level keywords for RAG operations. This function checks if keywords are already provided in query parameters, and if not, extracts them from the query text using LLM. 
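    If query_param already carries hl_keywords/ll_keywords, they are returned
    unchanged; otherwise extraction is delegated to extract_keywords_only, which
    also handles caching and conversation history.
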
    Args:
        query: The user's query text
        query_param: Query parameters that may contain pre-defined keywords
        global_config: Global configuration dictionary
        hashing_kv: Optional key-value storage for caching results

    Returns:
        A tuple containing (high_level_keywords, low_level_keywords)
    """
    # Check if pre-defined keywords are already provided
    if query_param.hl_keywords or query_param.ll_keywords:
        return query_param.hl_keywords, query_param.ll_keywords

    # Extract keywords using extract_keywords_only function which already supports conversation history
    hl_keywords, ll_keywords = await extract_keywords_only(
        query, query_param, global_config, hashing_kv
    )
    return hl_keywords, ll_keywords


async def extract_keywords_only(
    text: str,
    param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Extract high-level and low-level keywords from the given 'text' using the LLM.
    This method does NOT build the final RAG context or provide a final answer.
    It ONLY extracts keywords (hl_keywords, ll_keywords).
    """
    # 1. Handle cache if needed - add cache type for keywords
    args_hash = compute_args_hash(param.mode, text)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, text, param.mode, cache_type="keywords"
    )
    if cached_response is not None:
        try:
            keywords_data = json.loads(cached_response)
            return keywords_data["high_level_keywords"], keywords_data[
                "low_level_keywords"
            ]
        except (json.JSONDecodeError, KeyError):
            logger.warning(
                "Invalid cache format for keywords, proceeding with extraction"
            )

    # 2. Build the examples
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    # 3. Process conversation history
    history_context = ""
    if param.conversation_history:
        history_context = get_conversation_turns(
            param.conversation_history, param.history_turns
        )

    # 4. Build the keyword-extraction prompt
    kw_prompt = PROMPTS["keywords_extraction"].format(
        query=text, examples=examples, language=language, history=history_context
    )

    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(kw_prompt))
    logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

    # 5. Call the LLM for keyword extraction
    if param.model_func:
        use_model_func = param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    result = await use_model_func(kw_prompt, keyword_extraction=True)

    # 6. Parse out JSON from the LLM response
    match = re.search(r"\{.*\}", result, re.DOTALL)
    if not match:
        logger.error("No JSON-like structure found in the LLM response.")
        return [], []

    try:
        keywords_data = json.loads(match.group(0))
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        return [], []

    hl_keywords = keywords_data.get("high_level_keywords", [])
    ll_keywords = keywords_data.get("low_level_keywords", [])

    # 7.
Cache only the processed keywords with cache type if hl_keywords or ll_keywords: cache_data = { "high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords, } if hashing_kv.global_config.get("enable_llm_cache"): await save_to_cache( hashing_kv, CacheData( args_hash=args_hash, content=json.dumps(cache_data), prompt=text, quantized=quantized, min_val=min_val, max_val=max_val, mode=param.mode, cache_type="keywords", ), ) return hl_keywords, ll_keywords async def _get_vector_context( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, tokenizer: Tokenizer, ) -> tuple[list, list, list] | None: """ Retrieve vector context from the vector database. This function performs vector search to find relevant text chunks for a query, formats them with file path and creation time information. Args: query: The query string to search for chunks_vdb: Vector database containing document chunks query_param: Query parameters including top_k and ids tokenizer: Tokenizer for counting tokens Returns: Tuple (empty_entities, empty_relations, text_units) for combine_contexts, compatible with _get_edge_data and _get_node_data format """ try: results = await chunks_vdb.query( query, top_k=query_param.top_k, ids=query_param.ids ) if not results: return [], [], [] valid_chunks = [] for result in results: if "content" in result: # Directly use content from chunks_vdb.query result chunk_with_time = { "content": result["content"], "created_at": result.get("created_at", None), "file_path": result.get("file_path", "unknown_source"), } valid_chunks.append(chunk_with_time) if not valid_chunks: return [], [], [] maybe_trun_chunks = truncate_list_by_token_size( valid_chunks, key=lambda x: x["content"], max_token_size=query_param.max_token_for_text_unit, tokenizer=tokenizer, ) logger.debug( f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" ) logger.info( f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" ) if not maybe_trun_chunks: return [], [], [] # Create empty entities and relations contexts entities_context = [] relations_context = [] # Create text_units_context directly as a list of dictionaries text_units_context = [] for i, chunk in enumerate(maybe_trun_chunks): text_units_context.append( { "id": i + 1, "content": chunk["content"], "file_path": chunk["file_path"], } ) return entities_context, relations_context, text_units_context except Exception as e: logger.error(f"Error in _get_vector_context: {e}") return [], [], [] async def _build_query_context( ll_keywords: str, hl_keywords: str, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode ): logger.info(f"Process {os.getpid()} building query context...") # Handle local and global modes as before if query_param.mode == "local": entities_context, relations_context, text_units_context = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) elif query_param.mode == "global": entities_context, relations_context, text_units_context = await _get_edge_data( hl_keywords, knowledge_graph_inst, relationships_vdb, text_chunks_db, query_param, ) else: # hybrid or mix mode ll_data = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) hl_data = await _get_edge_data( hl_keywords, 
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )

        (
            ll_entities_context,
            ll_relations_context,
            ll_text_units_context,
        ) = ll_data
        (
            hl_entities_context,
            hl_relations_context,
            hl_text_units_context,
        ) = hl_data

        # Initialize vector data with empty lists
        vector_entities_context, vector_relations_context, vector_text_units_context = (
            [],
            [],
            [],
        )

        # Only get vector data if in mix mode
        if query_param.mode == "mix" and hasattr(query_param, "original_query"):
            # Get tokenizer from text_chunks_db
            tokenizer = text_chunks_db.global_config.get("tokenizer")

            # Get vector context in triple format
            vector_data = await _get_vector_context(
                query_param.original_query,  # We need to pass the original query
                chunks_vdb,
                query_param,
                tokenizer,
            )

            # If vector_data is not None, unpack it
            if vector_data is not None:
                (
                    vector_entities_context,
                    vector_relations_context,
                    vector_text_units_context,
                ) = vector_data

        # Combine and deduplicate the entities, relationships, and sources
        entities_context = process_combine_contexts(
            hl_entities_context, ll_entities_context, vector_entities_context
        )
        relations_context = process_combine_contexts(
            hl_relations_context, ll_relations_context, vector_relations_context
        )
        text_units_context = process_combine_contexts(
            hl_text_units_context, ll_text_units_context, vector_text_units_context
        )

    # not necessary to use LLM to generate a response
    if not entities_context and not relations_context:
        return None

    # Convert the contexts to JSON strings
    entities_str = json.dumps(entities_context, ensure_ascii=False)
    relations_str = json.dumps(relations_context, ensure_ascii=False)
    text_units_str = json.dumps(text_units_context, ensure_ascii=False)

    result = f"""-----Entities(KG)-----

```json
{entities_str}
```

-----Relationships(KG)-----

```json
{relations_str}
```

-----Document Chunks(DC)-----

```json
{text_units_str}
```

"""
    return result


async def _get_node_data(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
):
    # get similar entities
    logger.info(
        f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
    )

    results = await entities_vdb.query(
        query, top_k=query_param.top_k, ids=query_param.ids
    )

    if not len(results):
        return "", "", ""

    # Extract all entity IDs from your results list
    node_ids = [r["entity_name"] for r in results]

    # Call the batch node retrieval and degree functions concurrently.
    nodes_dict, degrees_dict = await asyncio.gather(
        knowledge_graph_inst.get_nodes_batch(node_ids),
        knowledge_graph_inst.node_degrees_batch(node_ids),
    )

    # Now, if you need the node data and degree in order:
    node_datas = [nodes_dict.get(nid) for nid in node_ids]
    node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]

    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")

    node_datas = [
        {
            **n,
            "entity_name": k["entity_name"],
            "rank": d,
            "created_at": k.get("created_at"),
        }
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
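    # Illustrative sketch (assumed shape, not part of the original source): after the
    # merge above, each surviving node_datas entry is expected to look roughly like
    #     {
    #         "entity_name": "Alan Turing",        # hypothetical value from the VDB hit
    #         "entity_type": "PERSON",
    #         "description": "...",
    #         "source_id": "chunk-1<SEP>chunk-2",  # GRAPH_FIELD_SEP-joined chunk ids
    #         "file_path": "paper.pdf",
    #         "rank": 7,                           # node degree from node_degrees_batch
    #         "created_at": 1718000000,
    #     }
    # i.e. the stored node fields plus the query hit's entity_name/created_at and the degree.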
# get text chunks related to the entities use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst, ) use_relations = await _find_most_related_edges_from_entities( node_datas, query_param, knowledge_graph_inst, ) tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") len_node_datas = len(node_datas) node_datas = truncate_list_by_token_size( node_datas, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_local_context, tokenizer=tokenizer, ) logger.debug( f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})" ) logger.info( f"Local query: {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} chunks" ) # build prompt entities_context = [] for i, n in enumerate(node_datas): created_at = n.get("created_at", "UNKNOWN") if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from node data file_path = n.get("file_path", "unknown_source") entities_context.append( { "id": i + 1, "entity": n["entity_name"], "type": n.get("entity_type", "UNKNOWN"), "description": n.get("description", "UNKNOWN"), "rank": n["rank"], "created_at": created_at, "file_path": file_path, } ) relations_context = [] for i, e in enumerate(use_relations): created_at = e.get("created_at", "UNKNOWN") # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from edge data file_path = e.get("file_path", "unknown_source") relations_context.append( { "id": i + 1, "entity1": e["src_tgt"][0], "entity2": e["src_tgt"][1], "description": e["description"], "keywords": e["keywords"], "weight": e["weight"], "rank": e["rank"], "created_at": created_at, "file_path": file_path, } ) text_units_context = [] for i, t in enumerate(use_text_units): text_units_context.append( { "id": i + 1, "content": t["content"], "file_path": t.get("file_path", "unknown_source"), } ) return entities_context, relations_context, text_units_context async def _find_most_related_text_unit_from_entities( node_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in node_datas if dp["source_id"] is not None ] node_names = [dp["entity_name"] for dp in node_datas] batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) # Build the edges list in the same order as node_datas.
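    # Ranking sketch (summarising the logic below): each chunk id keeps the
    # position ("order") of the first entity that cites it, and gains one
    # "relation_counts" point for every edge of that entity whose one-hop
    # neighbour also cites the same chunk. Chunks are later sorted by
    # (order ascending, relation_counts descending) before token truncation.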
edges = [batch_edges_dict.get(name, []) for name in node_names] all_one_hop_nodes = set() for this_edges in edges: if not this_edges: continue all_one_hop_nodes.update([e[1] for e in this_edges]) all_one_hop_nodes = list(all_one_hop_nodes) # Batch retrieve one-hop node data using get_nodes_batch all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch( all_one_hop_nodes ) all_one_hop_nodes_data = [ all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes ] # Add null check for node data all_one_hop_text_units_lookup = { k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP])) for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data) if v is not None and "source_id" in v # Add source_id check } all_text_units_lookup = {} tasks = [] for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): for c_id in this_text_units: if c_id not in all_text_units_lookup: all_text_units_lookup[c_id] = index tasks.append((c_id, index, this_edges)) # Process in batches tasks at a time to avoid overwhelming resources batch_size = 5 results = [] for i in range(0, len(tasks), batch_size): batch_tasks = tasks[i : i + batch_size] batch_results = await asyncio.gather( *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks] ) results.extend(batch_results) for (c_id, index, this_edges), data in zip(tasks, results): all_text_units_lookup[c_id] = { "data": data, "order": index, "relation_counts": 0, } if this_edges: for e in this_edges: if ( e[1] in all_one_hop_text_units_lookup and c_id in all_one_hop_text_units_lookup[e[1]] ): all_text_units_lookup[c_id]["relation_counts"] += 1 # Filter out None values and ensure data has content all_text_units = [ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None and v.get("data") is not None and "content" in v["data"] ] if not all_text_units: logger.warning("No valid text units found") return [] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") all_text_units = sorted( all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) ) all_text_units = truncate_list_by_token_size( all_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, tokenizer=tokenizer, ) logger.debug( f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})" ) all_text_units = [t["data"] for t in all_text_units] return all_text_units async def _find_most_related_edges_from_entities( node_datas: list[dict], query_param: QueryParam, knowledge_graph_inst: BaseGraphStorage, ): node_names = [dp["entity_name"] for dp in node_datas] batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) all_edges = [] seen = set() for node_name in node_names: this_edges = batch_edges_dict.get(node_name, []) for e in this_edges: sorted_edge = tuple(sorted(e)) if sorted_edge not in seen: seen.add(sorted_edge) all_edges.append(sorted_edge) # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges] # For edge degrees, use tuples. edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples # Call the batched functions concurrently. edge_data_dict, edge_degrees_dict = await asyncio.gather( knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), ) # Reconstruct edge_datas list in the same order as the deduplicated results. 
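    # Each deduplicated edge below is merged with its stored properties and a
    # "rank" taken from edge_degrees_batch; a missing weight falls back to 0.0.
    # The combined list is then sorted by (rank, weight) descending and truncated
    # on description tokens against query_param.max_token_for_global_context.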
all_edges_data = [] for pair in all_edges: edge_props = edge_data_dict.get(pair) if edge_props is not None: if "weight" not in edge_props: logger.warning( f"Edge {pair} missing 'weight' attribute, using default value 0.0" ) edge_props["weight"] = 0.0 combined = { "src_tgt": pair, "rank": edge_degrees_dict.get(pair, 0), **edge_props, } all_edges_data.append(combined) tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer") all_edges_data = sorted( all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True ) all_edges_data = truncate_list_by_token_size( all_edges_data, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_global_context, tokenizer=tokenizer, ) logger.debug( f"Truncate relations from {len(all_edges)} to {len(all_edges_data)} (max tokens:{query_param.max_token_for_global_context})" ) return all_edges_data async def _get_edge_data( keywords, knowledge_graph_inst: BaseGraphStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, ): logger.info( f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" ) results = await relationships_vdb.query( keywords, top_k=query_param.top_k, ids=query_param.ids ) if not len(results): return "", "", "" # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] # For edge degrees, use tuples. edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results] # Call the batched functions concurrently. edge_data_dict, edge_degrees_dict = await asyncio.gather( knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), ) # Reconstruct edge_datas list in the same order as results. edge_datas = [] for k in results: pair = (k["src_id"], k["tgt_id"]) edge_props = edge_data_dict.get(pair) if edge_props is not None: if "weight" not in edge_props: logger.warning( f"Edge {pair} missing 'weight' attribute, using default value 0.0" ) edge_props["weight"] = 0.0 # Use edge degree from the batch as rank. 
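            # The rank falls back to any "rank" already present on the VDB result
            # and finally to 0. "created_at" defaults to the VDB result's value but
            # is overridden by the edge properties when present, since **edge_props
            # is spread last.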
combined = { "src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": edge_degrees_dict.get(pair, k.get("rank", 0)), "created_at": k.get("created_at", None), **edge_props, } edge_datas.append(combined) tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") edge_datas = sorted( edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True ) edge_datas = truncate_list_by_token_size( edge_datas, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_global_context, tokenizer=tokenizer, ) use_entities, use_text_units = await asyncio.gather( _find_most_related_entities_from_relationships( edge_datas, query_param, knowledge_graph_inst, ), _find_related_text_unit_from_relationships( edge_datas, query_param, text_chunks_db, knowledge_graph_inst, ), ) logger.info( f"Global query: {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} chunks" ) relations_context = [] for i, e in enumerate(edge_datas): created_at = e.get("created_at", "UNKNOWN") # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from edge data file_path = e.get("file_path", "unknown_source") relations_context.append( { "id": i + 1, "entity1": e["src_id"], "entity2": e["tgt_id"], "description": e["description"], "keywords": e["keywords"], "weight": e["weight"], "rank": e["rank"], "created_at": created_at, "file_path": file_path, } ) entities_context = [] for i, n in enumerate(use_entities): created_at = n.get("created_at", "UNKNOWN") # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from node data file_path = n.get("file_path", "unknown_source") entities_context.append( { "id": i + 1, "entity": n["entity_name"], "type": n.get("entity_type", "UNKNOWN"), "description": n.get("description", "UNKNOWN"), "rank": n["rank"], "created_at": created_at, "file_path": file_path, } ) text_units_context = [] for i, t in enumerate(use_text_units): text_units_context.append( { "id": i + 1, "content": t["content"], "file_path": t.get("file_path", "unknown"), } ) return entities_context, relations_context, text_units_context async def _find_most_related_entities_from_relationships( edge_datas: list[dict], query_param: QueryParam, knowledge_graph_inst: BaseGraphStorage, ): entity_names = [] seen = set() for e in edge_datas: if e["src_id"] not in seen: entity_names.append(e["src_id"]) seen.add(e["src_id"]) if e["tgt_id"] not in seen: entity_names.append(e["tgt_id"]) seen.add(e["tgt_id"]) # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
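    # Illustrative shapes only (hypothetical values): get_nodes_batch(["A", "B"])
    # is expected to yield {"A": {...node props...}, "B": {...}} and
    # node_degrees_batch(["A", "B"]) something like {"A": 3, "B": 1}. Entities
    # missing from the batch result are skipped below with a warning.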
nodes_dict, degrees_dict = await asyncio.gather( knowledge_graph_inst.get_nodes_batch(entity_names), knowledge_graph_inst.node_degrees_batch(entity_names), ) # Rebuild the list in the same order as entity_names node_datas = [] for entity_name in entity_names: node = nodes_dict.get(entity_name) degree = degrees_dict.get(entity_name, 0) if node is None: logger.warning(f"Node '{entity_name}' not found in batch retrieval.") continue # Combine the node data with the entity name and computed degree (as rank) combined = {**node, "entity_name": entity_name, "rank": degree} node_datas.append(combined) tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer") len_node_datas = len(node_datas) node_datas = truncate_list_by_token_size( node_datas, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_local_context, tokenizer=tokenizer, ) logger.debug( f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})" ) return node_datas async def _find_related_text_unit_from_relationships( edge_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in edge_datas if dp["source_id"] is not None ] all_text_units_lookup = {} async def fetch_chunk_data(c_id, index): if c_id not in all_text_units_lookup: chunk_data = await text_chunks_db.get_by_id(c_id) # Only store valid data if chunk_data is not None and "content" in chunk_data: all_text_units_lookup[c_id] = { "data": chunk_data, "order": index, } tasks = [] for index, unit_list in enumerate(text_units): for c_id in unit_list: tasks.append(fetch_chunk_data(c_id, index)) await asyncio.gather(*tasks) if not all_text_units_lookup: logger.warning("No valid text chunks found") return [] all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()] all_text_units = sorted(all_text_units, key=lambda x: x["order"]) # Ensure all text chunks have content valid_text_units = [ t for t in all_text_units if t["data"] is not None and "content" in t["data"] ] if not valid_text_units: logger.warning("No valid text chunks after filtering") return [] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") truncated_text_units = truncate_list_by_token_size( valid_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, tokenizer=tokenizer, ) logger.debug( f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})" ) all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] return all_text_units async def naive_query( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, ) -> str | AsyncIterator[str]: if query_param.model_func: use_model_func = query_param.model_func else: use_model_func = global_config["llm_model_func"] # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) # Handle cache args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) if cached_response is not None: return cached_response tokenizer: Tokenizer = 
global_config["tokenizer"] _, _, text_units_context = await _get_vector_context( query, chunks_vdb, query_param, tokenizer ) if text_units_context is None or len(text_units_context) == 0: return PROMPTS["fail_response"] text_units_str = json.dumps(text_units_context, ensure_ascii=False) if query_param.only_need_context: return f""" ---Document Chunks--- ```json {text_units_str} ``` """ # Process conversation history history_context = "" if query_param.conversation_history: history_context = get_conversation_turns( query_param.conversation_history, query_param.history_turns ) # Build system prompt user_prompt = ( query_param.user_prompt if query_param.user_prompt else PROMPTS["DEFAULT_USER_PROMPT"] ) sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"] sys_prompt = sys_prompt_temp.format( content_data=text_units_str, response_type=query_param.response_type, history=history_context, user_prompt=user_prompt, ) if query_param.only_need_prompt: return sys_prompt len_of_prompts = len(tokenizer.encode(query + sys_prompt)) logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}") response = await use_model_func( query, system_prompt=sys_prompt, stream=query_param.stream, ) if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response[len(sys_prompt) :] .replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) if hashing_kv.global_config.get("enable_llm_cache"): # Save to cache await save_to_cache( hashing_kv, CacheData( args_hash=args_hash, content=response, prompt=query, quantized=quantized, min_val=min_val, max_val=max_val, mode=query_param.mode, cache_type="query", ), ) return response # TODO: Deprecated, use user_prompt in QueryParam instead async def kg_query_with_keywords( query: str, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ll_keywords: list[str] = [], hl_keywords: list[str] = [], chunks_vdb: BaseVectorStorage | None = None, ) -> str | AsyncIterator[str]: """ Refactored kg_query that does NOT extract keywords by itself. It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. Then it uses those to build context and produce a final LLM response. """ if query_param.model_func: use_model_func = query_param.model_func else: use_model_func = global_config["llm_model_func"] # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) if cached_response is not None: return cached_response # If neither has any keywords, you could handle that logic here. if not hl_keywords and not ll_keywords: logger.warning( "No keywords found in query_param. Could default to global mode or fail." 
) return PROMPTS["fail_response"] if not ll_keywords and query_param.mode in ["local", "hybrid"]: logger.warning("low_level_keywords is empty, switching to global mode.") query_param.mode = "global" if not hl_keywords and query_param.mode in ["global", "hybrid"]: logger.warning("high_level_keywords is empty, switching to local mode.") query_param.mode = "local" ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" context = await _build_query_context( ll_keywords_str, hl_keywords_str, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, query_param, chunks_vdb=chunks_vdb, ) if not context: return PROMPTS["fail_response"] if query_param.only_need_context: return context # Process conversation history history_context = "" if query_param.conversation_history: history_context = get_conversation_turns( query_param.conversation_history, query_param.history_turns ) sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type, history=history_context, ) if query_param.only_need_prompt: return sys_prompt tokenizer: Tokenizer = global_config["tokenizer"] len_of_prompts = len(tokenizer.encode(query + sys_prompt)) logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}") # 6. Generate response response = await use_model_func( query, system_prompt=sys_prompt, stream=query_param.stream, ) # Clean up response content if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) if hashing_kv.global_config.get("enable_llm_cache"): await save_to_cache( hashing_kv, CacheData( args_hash=args_hash, content=response, prompt=query, quantized=quantized, min_val=min_val, max_val=max_val, mode=query_param.mode, cache_type="query", ), ) return response # TODO: Deprecated, use user_prompt in QueryParam instead async def query_with_keywords( query: str, prompt: str, param: QueryParam, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, chunks_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ) -> str | AsyncIterator[str]: """ Extract keywords from the query and then use them for retrieving information. 1. Extracts high-level and low-level keywords from the query 2. Formats the query with the extracted keywords and prompt 3. 
Uses the appropriate query method based on param.mode Args: query: The user's query prompt: Additional prompt to prepend to the query param: Query parameters knowledge_graph_inst: Knowledge graph storage entities_vdb: Entities vector database relationships_vdb: Relationships vector database chunks_vdb: Document chunks vector database text_chunks_db: Text chunks storage global_config: Global configuration hashing_kv: Cache storage Returns: Query response or async iterator """ # Extract keywords hl_keywords, ll_keywords = await get_keywords_from_query( query=query, query_param=param, global_config=global_config, hashing_kv=hashing_kv, ) # Create a new string with the prompt and the keywords keywords_str = ", ".join(ll_keywords + hl_keywords) formatted_question = ( f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}" ) param.original_query = query # Use appropriate query method based on mode if param.mode in ["local", "global", "hybrid", "mix"]: return await kg_query_with_keywords( formatted_question, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, param, global_config, hashing_kv=hashing_kv, hl_keywords=hl_keywords, ll_keywords=ll_keywords, chunks_vdb=chunks_vdb, ) elif param.mode == "naive": return await naive_query( formatted_question, chunks_vdb, param, global_config, hashing_kv=hashing_kv, ) else: raise ValueError(f"Unknown mode {param.mode}")
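# Usage sketch (illustrative only): the storage objects and global_config are
# assumed to be initialised elsewhere, e.g. by a lightrag instance; the variable
# names below are hypothetical.
#
#     param = QueryParam(mode="mix", top_k=20, response_type="Multiple Paragraphs")
#     answer = await query_with_keywords(
#         query="How does entity extraction work?",
#         prompt="Answer using the knowledge graph context.",
#         param=param,
#         knowledge_graph_inst=kg_storage,
#         entities_vdb=entities_vdb,
#         relationships_vdb=relationships_vdb,
#         chunks_vdb=chunks_vdb,
#         text_chunks_db=text_chunks_db,
#         global_config=global_config,
#         hashing_kv=llm_cache,
#     )
#
# "local"/"global"/"hybrid"/"mix" route through kg_query_with_keywords (keyword-driven
# graph retrieval, with raw chunk vectors added in "mix" mode), while "naive" bypasses
# the graph and searches chunks_vdb directly via naive_query.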