from __future__ import annotations import asyncio import traceback import json import re import os from typing import Any, AsyncIterator from collections import Counter, defaultdict from .utils import ( logger, clean_str, compute_mdhash_id, Tokenizer, is_float_regex, normalize_extracted_info, pack_user_ass_to_openai_messages, split_string_by_multi_markers, truncate_list_by_token_size, process_combine_contexts, compute_args_hash, handle_cache, save_to_cache, CacheData, get_conversation_turns, use_llm_func_with_cache, list_of_list_to_json, ) from .base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, TextChunkSchema, QueryParam, ) from .prompt import GRAPH_FIELD_SEP, PROMPTS import time from dotenv import load_dotenv # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) def chunking_by_token_size( tokenizer: Tokenizer, content: str, split_by_character: str | None = None, split_by_character_only: bool = False, overlap_token_size: int = 128, max_token_size: int = 1024, ) -> list[dict[str, Any]]: tokens = tokenizer.encode(content) results: list[dict[str, Any]] = [] if split_by_character: raw_chunks = content.split(split_by_character) new_chunks = [] if split_by_character_only: for chunk in raw_chunks: _tokens = tokenizer.encode(chunk) new_chunks.append((len(_tokens), chunk)) else: for chunk in raw_chunks: _tokens = tokenizer.encode(chunk) if len(_tokens) > max_token_size: for start in range( 0, len(_tokens), max_token_size - overlap_token_size ): chunk_content = tokenizer.decode( _tokens[start : start + max_token_size] ) new_chunks.append( (min(max_token_size, len(_tokens) - start), chunk_content) ) else: new_chunks.append((len(_tokens), chunk)) for index, (_len, chunk) in enumerate(new_chunks): results.append( { "tokens": _len, "content": chunk.strip(), "chunk_order_index": index, } ) else: for index, start in enumerate( range(0, len(tokens), max_token_size - overlap_token_size) ): chunk_content = tokenizer.decode(tokens[start : start + max_token_size]) results.append( { "tokens": min(max_token_size, len(tokens) - start), "content": chunk_content.strip(), "chunk_order_index": index, } ) return results async def _handle_entity_relation_summary( entity_or_relation_name: str, description: str, global_config: dict, pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ) -> str: """Handle entity relation summary For each entity or relation, input is the combined description of already existing description and new description. If too long, use LLM to summarize. 
""" use_llm_func: callable = global_config["llm_model_func"] tokenizer: Tokenizer = global_config["tokenizer"] llm_max_tokens = global_config["llm_model_max_token_size"] summary_max_tokens = global_config["summary_to_max_tokens"] language = global_config["addon_params"].get( "language", PROMPTS["DEFAULT_LANGUAGE"] ) tokens = tokenizer.encode(description) ### summarize is not determined here anymore (It's determined by num_fragment now) # if len(tokens) < summary_max_tokens: # No need for summary # return description prompt_template = PROMPTS["summarize_entity_descriptions"] use_description = tokenizer.decode(tokens[:llm_max_tokens]) context_base = dict( entity_name=entity_or_relation_name, description_list=use_description.split(GRAPH_FIELD_SEP), language=language, ) use_prompt = prompt_template.format(**context_base) logger.debug(f"Trigger summary: {entity_or_relation_name}") # Use LLM function with cache summary = await use_llm_func_with_cache( use_prompt, use_llm_func, llm_response_cache=llm_response_cache, max_tokens=summary_max_tokens, cache_type="extract", ) return summary async def _handle_single_entity_extraction( record_attributes: list[str], chunk_key: str, file_path: str = "unknown_source", ): if len(record_attributes) < 4 or '"entity"' not in record_attributes[0]: return None # Clean and validate entity name entity_name = clean_str(record_attributes[1]).strip() if not entity_name: logger.warning( f"Entity extraction error: empty entity name in: {record_attributes}" ) return None # Normalize entity name entity_name = normalize_extracted_info(entity_name, is_entity=True) # Clean and validate entity type entity_type = clean_str(record_attributes[2]).strip('"') if not entity_type.strip() or entity_type.startswith('("'): logger.warning( f"Entity extraction error: invalid entity type in: {record_attributes}" ) return None # Clean and validate description entity_description = clean_str(record_attributes[3]) entity_description = normalize_extracted_info(entity_description) if not entity_description.strip(): logger.warning( f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'" ) return None return dict( entity_name=entity_name, entity_type=entity_type, description=entity_description, source_id=chunk_key, file_path=file_path, ) async def _handle_single_relationship_extraction( record_attributes: list[str], chunk_key: str, file_path: str = "unknown_source", ): if len(record_attributes) < 5 or '"relationship"' not in record_attributes[0]: return None # add this record as edge source = clean_str(record_attributes[1]) target = clean_str(record_attributes[2]) # Normalize source and target entity names source = normalize_extracted_info(source, is_entity=True) target = normalize_extracted_info(target, is_entity=True) edge_description = clean_str(record_attributes[3]) edge_description = normalize_extracted_info(edge_description) edge_keywords = clean_str(record_attributes[4]).strip('"').strip("'") edge_source_id = chunk_key weight = ( float(record_attributes[-1].strip('"').strip("'")) if is_float_regex(record_attributes[-1].strip('"').strip("'")) else 1.0 ) return dict( src_id=source, tgt_id=target, weight=weight, description=edge_description, keywords=edge_keywords, source_id=edge_source_id, file_path=file_path, ) async def _merge_nodes_then_upsert( entity_name: str, nodes_data: list[dict], knowledge_graph_inst: BaseGraphStorage, global_config: dict, pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = 
None, ): """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] already_source_ids = [] already_description = [] already_file_paths = [] already_node = await knowledge_graph_inst.get_node(entity_name) if already_node is not None: already_entity_types.append(already_node["entity_type"]) already_source_ids.extend( split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP]) ) already_file_paths.extend( split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP]) ) already_description.append(already_node["description"]) entity_type = sorted( Counter( [dp["entity_type"] for dp in nodes_data] + already_entity_types ).items(), key=lambda x: x[1], reverse=True, )[0][0] description = GRAPH_FIELD_SEP.join( sorted(set([dp["description"] for dp in nodes_data] + already_description)) ) source_id = GRAPH_FIELD_SEP.join( set([dp["source_id"] for dp in nodes_data] + already_source_ids) ) file_path = GRAPH_FIELD_SEP.join( set([dp["file_path"] for dp in nodes_data] + already_file_paths) ) force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] num_fragment = description.count(GRAPH_FIELD_SEP) + 1 num_new_fragment = len(set([dp["description"] for dp in nodes_data])) if num_fragment > 1: if num_fragment >= force_llm_summary_on_merge: status_message = f"LLM merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) description = await _handle_entity_relation_summary( entity_name, description, global_config, pipeline_status, pipeline_status_lock, llm_response_cache, ) else: status_message = f"Merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) node_data = dict( entity_id=entity_name, entity_type=entity_type, description=description, source_id=source_id, file_path=file_path, ) await knowledge_graph_inst.upsert_node( entity_name, node_data=node_data, ) node_data["entity_name"] = entity_name return node_data async def _merge_edges_then_upsert( src_id: str, tgt_id: str, edges_data: list[dict], knowledge_graph_inst: BaseGraphStorage, global_config: dict, pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ): already_weights = [] already_source_ids = [] already_description = [] already_keywords = [] already_file_paths = [] if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) # Handle the case where get_edge returns None or missing fields if already_edge: # Get weight with default 0.0 if missing already_weights.append(already_edge.get("weight", 0.0)) # Get source_id with empty string default if missing or None if already_edge.get("source_id") is not None: already_source_ids.extend( split_string_by_multi_markers( already_edge["source_id"], [GRAPH_FIELD_SEP] ) ) # Get file_path with empty string default if missing or None if already_edge.get("file_path") is not None: already_file_paths.extend( split_string_by_multi_markers( already_edge["file_path"], 
[GRAPH_FIELD_SEP] ) ) # Get description with empty string default if missing or None if already_edge.get("description") is not None: already_description.append(already_edge["description"]) # Get keywords with empty string default if missing or None if already_edge.get("keywords") is not None: already_keywords.extend( split_string_by_multi_markers( already_edge["keywords"], [GRAPH_FIELD_SEP] ) ) # Process edges_data with None checks weight = sum([dp["weight"] for dp in edges_data] + already_weights) description = GRAPH_FIELD_SEP.join( sorted( set( [dp["description"] for dp in edges_data if dp.get("description")] + already_description ) ) ) keywords = GRAPH_FIELD_SEP.join( sorted( set( [dp["keywords"] for dp in edges_data if dp.get("keywords")] + already_keywords ) ) ) source_id = GRAPH_FIELD_SEP.join( set( [dp["source_id"] for dp in edges_data if dp.get("source_id")] + already_source_ids ) ) file_path = GRAPH_FIELD_SEP.join( set( [dp["file_path"] for dp in edges_data if dp.get("file_path")] + already_file_paths ) ) for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): # # Discard this edge if the node does not exist # if need_insert_id == src_id: # logger.warning( # f"Discard edge: {src_id} - {tgt_id} | Source node missing" # ) # else: # logger.warning( # f"Discard edge: {src_id} - {tgt_id} | Target node missing" # ) # return None await knowledge_graph_inst.upsert_node( need_insert_id, node_data={ "entity_id": need_insert_id, "source_id": source_id, "description": description, "entity_type": "UNKNOWN", "file_path": file_path, }, ) force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] num_fragment = description.count(GRAPH_FIELD_SEP) + 1 num_new_fragment = len( set([dp["description"] for dp in edges_data if dp.get("description")]) ) if num_fragment > 1: if num_fragment >= force_llm_summary_on_merge: status_message = f"LLM merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) description = await _handle_entity_relation_summary( f"({src_id}, {tgt_id})", description, global_config, pipeline_status, pipeline_status_lock, llm_response_cache, ) else: status_message = f"Merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) await knowledge_graph_inst.upsert_edge( src_id, tgt_id, edge_data=dict( weight=weight, description=description, keywords=keywords, source_id=source_id, file_path=file_path, ), ) edge_data = dict( src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, source_id=source_id, file_path=file_path, ) return edge_data async def extract_entities( chunks: dict[str, TextChunkSchema], knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, global_config: dict[str, str], pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ) -> None: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] 
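    # Illustrative note (added comment, not from the original code): with the default
    # prompt delimiters, one extraction record that _handle_single_entity_extraction /
    # _handle_single_relationship_extraction can parse looks roughly like:
    #
    #   ("entity"<|>"Tesla"<|>"organization"<|>"Electric vehicle maker ...")##
    #   ("relationship"<|>"Tesla"<|>"Elon Musk"<|>"Founded and led by ..."<|>"leadership"<|>8)<|COMPLETE|>
    #
    # The actual delimiter strings come from PROMPTS["DEFAULT_TUPLE_DELIMITER"],
    # PROMPTS["DEFAULT_RECORD_DELIMITER"] and PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
    # so treat the tokens shown above as placeholders rather than exact values.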
ordered_chunks = list(chunks.items()) # add language and example number params to prompt language = global_config["addon_params"].get( "language", PROMPTS["DEFAULT_LANGUAGE"] ) entity_types = global_config["addon_params"].get( "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"] ) example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len(PROMPTS["entity_extraction_examples"]): examples = "\n".join( PROMPTS["entity_extraction_examples"][: int(example_number)] ) else: examples = "\n".join(PROMPTS["entity_extraction_examples"]) example_context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], entity_types=", ".join(entity_types), language=language, ) # add example's format examples = examples.format(**example_context_base) entity_extract_prompt = PROMPTS["entity_extraction"] context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], entity_types=",".join(entity_types), examples=examples, language=language, ) continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base) if_loop_prompt = PROMPTS["entity_if_loop_extraction"] processed_chunks = 0 total_chunks = len(ordered_chunks) total_entities_count = 0 total_relations_count = 0 # Get lock manager from shared storage from .kg.shared_storage import get_graph_db_lock graph_db_lock = get_graph_db_lock(enable_logging=False) # Use the global use_llm_func_with_cache function from utils.py async def _process_extraction_result( result: str, chunk_key: str, file_path: str = "unknown_source" ): """Process a single extraction result (either initial or gleaning) Args: result (str): The extraction result to process chunk_key (str): The chunk key for source tracking file_path (str): The file path for citation Returns: tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships """ maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) records = split_string_by_multi_markers( result, [context_base["record_delimiter"], context_base["completion_delimiter"]], ) for record in records: record = re.search(r"\((.*)\)", record) if record is None: continue record = record.group(1) record_attributes = split_string_by_multi_markers( record, [context_base["tuple_delimiter"]] ) if_entities = await _handle_single_entity_extraction( record_attributes, chunk_key, file_path ) if if_entities is not None: maybe_nodes[if_entities["entity_name"]].append(if_entities) continue if_relation = await _handle_single_relationship_extraction( record_attributes, chunk_key, file_path ) if if_relation is not None: maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( if_relation ) return maybe_nodes, maybe_edges async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): """Process a single chunk Args: chunk_key_dp (tuple[str, TextChunkSchema]): ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) Returns: tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships """ nonlocal processed_chunks chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] # Get file path from chunk data or use default file_path = chunk_dp.get("file_path", "unknown_source") # Get initial extraction hint_prompt = 
entity_extract_prompt.format(
            **context_base, input_text="{input_text}"
        ).format(**context_base, input_text=content)

        final_result = await use_llm_func_with_cache(
            hint_prompt,
            use_llm_func,
            llm_response_cache=llm_response_cache,
            cache_type="extract",
        )
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

        # Process initial extraction with file path
        maybe_nodes, maybe_edges = await _process_extraction_result(
            final_result, chunk_key, file_path
        )

        # Process additional gleaning results
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func_with_cache(
                continue_prompt,
                use_llm_func,
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
            )

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)

            # Process gleaning result separately with file path
            glean_nodes, glean_edges = await _process_extraction_result(
                glean_result, chunk_key, file_path
            )

            # Merge results - only add entities and edges with new names
            for entity_name, entities in glean_nodes.items():
                if (
                    entity_name not in maybe_nodes
                ):  # Only accept entities with new names in the gleaning stage
                    maybe_nodes[entity_name].extend(entities)
            for edge_key, edges in glean_edges.items():
                if (
                    edge_key not in maybe_edges
                ):  # Only accept edges with new keys in the gleaning stage
                    maybe_edges[edge_key].extend(edges)

            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func_with_cache(
                if_loop_prompt,
                use_llm_func,
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        processed_chunks += 1
        entities_count = len(maybe_nodes)
        relations_count = len(maybe_edges)
        log_message = f"Chk {processed_chunks}/{total_chunks}: extracted {entities_count} Ent + {relations_count} Rel"
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)

        # Return the extracted nodes and edges for centralized processing
        return maybe_nodes, maybe_edges

    # Get max async tasks limit from global_config
    llm_model_max_async = global_config.get("llm_model_max_async", 4)
    semaphore = asyncio.Semaphore(llm_model_max_async)

    async def _process_with_semaphore(chunk):
        async with semaphore:
            return await _process_single_content(chunk)

    tasks = []
    for c in ordered_chunks:
        task = asyncio.create_task(_process_with_semaphore(c))
        tasks.append(task)

    # Wait for tasks to complete or for the first exception to occur
    # This allows us to cancel remaining tasks if any task fails
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

    # Check if any task raised an exception
    for task in done:
        if task.exception():
            # If a task failed, cancel all pending tasks
            # This prevents unnecessary processing since the parent function will abort anyway
            for pending_task in pending:
                pending_task.cancel()

            # Wait for cancellation to complete
            if pending:
                await asyncio.wait(pending)

            # Re-raise the exception to notify the caller
            raise task.exception()

    # If all tasks completed successfully, collect results
    chunk_results = [task.result() for task in tasks]

    # Collect all nodes and edges from all chunks
    all_nodes = defaultdict(list)
    all_edges = defaultdict(list)

    for maybe_nodes, maybe_edges in chunk_results:
        # Collect nodes
        for entity_name, entities in maybe_nodes.items():
            all_nodes[entity_name].extend(entities)
        #
Collect edges with sorted keys for undirected graph for edge_key, edges in maybe_edges.items(): sorted_edge_key = tuple(sorted(edge_key)) all_edges[sorted_edge_key].extend(edges) # Centralized processing of all nodes and edges entities_data = [] relationships_data = [] # Use graph database lock to ensure atomic merges and updates async with graph_db_lock: # Process and update all entities at once for entity_name, entities in all_nodes.items(): entity_data = await _merge_nodes_then_upsert( entity_name, entities, knowledge_graph_inst, global_config, pipeline_status, pipeline_status_lock, llm_response_cache, ) entities_data.append(entity_data) # Process and update all relationships at once for edge_key, edges in all_edges.items(): edge_data = await _merge_edges_then_upsert( edge_key[0], edge_key[1], edges, knowledge_graph_inst, global_config, pipeline_status, pipeline_status_lock, llm_response_cache, ) if edge_data is not None: relationships_data.append(edge_data) # Update total counts total_entities_count = len(entities_data) total_relations_count = len(relationships_data) log_message = f"Updating vector storage: {total_entities_count} entities..." logger.info(log_message) if pipeline_status is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) # Update vector databases with all collected data if entity_vdb is not None and entities_data: data_for_vdb = { compute_mdhash_id(dp["entity_name"], prefix="ent-"): { "entity_name": dp["entity_name"], "entity_type": dp["entity_type"], "content": f"{dp['entity_name']}\n{dp['description']}", "source_id": dp["source_id"], "file_path": dp.get("file_path", "unknown_source"), } for dp in entities_data } await entity_vdb.upsert(data_for_vdb) log_message = ( f"Updating vector storage: {total_relations_count} relationships..." 
) logger.info(log_message) if pipeline_status is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) if relationships_vdb is not None and relationships_data: data_for_vdb = { compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], "keywords": dp["keywords"], "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", "source_id": dp["source_id"], "file_path": dp.get("file_path", "unknown_source"), } for dp in relationships_data } await relationships_vdb.upsert(data_for_vdb) async def kg_query( query: str, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, ) -> str | AsyncIterator[str]: # Handle cache use_model_func = ( query_param.model_func if query_param.model_func else global_config["llm_model_func"] ) args_hash = compute_args_hash(query_param.mode, query, cache_type="query") cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) if cached_response is not None: return cached_response hl_keywords, ll_keywords = await get_keywords_from_query( query, query_param, global_config, hashing_kv ) logger.debug(f"High-level keywords: {hl_keywords}") logger.debug(f"Low-level keywords: {ll_keywords}") # Handle empty keywords if hl_keywords == [] and ll_keywords == []: logger.warning("low_level_keywords and high_level_keywords is empty") return PROMPTS["fail_response"] if ll_keywords == [] and query_param.mode in ["local", "hybrid"]: logger.warning( "low_level_keywords is empty, switching from %s mode to global mode", query_param.mode, ) query_param.mode = "global" if hl_keywords == [] and query_param.mode in ["global", "hybrid"]: logger.warning( "high_level_keywords is empty, switching from %s mode to local mode", query_param.mode, ) query_param.mode = "local" ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" # Build context context = await _build_query_context( ll_keywords_str, hl_keywords_str, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, query_param, ) if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] # Process conversation history history_context = "" if query_param.conversation_history: history_context = get_conversation_turns( query_param.conversation_history, query_param.history_turns ) sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type, history=history_context, ) if query_param.only_need_prompt: return sys_prompt tokenizer: Tokenizer = global_config["tokenizer"] len_of_prompts = len(tokenizer.encode(query + sys_prompt)) logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}") response = await use_model_func( query, system_prompt=sys_prompt, stream=query_param.stream, ) if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) if hashing_kv.global_config.get("enable_llm_cache"): # Save 
to cache await save_to_cache( hashing_kv, CacheData( args_hash=args_hash, content=response, prompt=query, quantized=quantized, min_val=min_val, max_val=max_val, mode=query_param.mode, cache_type="query", ), ) return response async def get_keywords_from_query( query: str, query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ) -> tuple[list[str], list[str]]: """ Retrieves high-level and low-level keywords for RAG operations. This function checks if keywords are already provided in query parameters, and if not, extracts them from the query text using LLM. Args: query: The user's query text query_param: Query parameters that may contain pre-defined keywords global_config: Global configuration dictionary hashing_kv: Optional key-value storage for caching results Returns: A tuple containing (high_level_keywords, low_level_keywords) """ # Check if pre-defined keywords are already provided if query_param.hl_keywords or query_param.ll_keywords: return query_param.hl_keywords, query_param.ll_keywords # Extract keywords using extract_keywords_only function which already supports conversation history hl_keywords, ll_keywords = await extract_keywords_only( query, query_param, global_config, hashing_kv ) return hl_keywords, ll_keywords async def extract_keywords_only( text: str, param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ) -> tuple[list[str], list[str]]: """ Extract high-level and low-level keywords from the given 'text' using the LLM. This method does NOT build the final RAG context or provide a final answer. It ONLY extracts keywords (hl_keywords, ll_keywords). """ # 1. Handle cache if needed - add cache type for keywords args_hash = compute_args_hash(param.mode, text, cache_type="keywords") cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, text, param.mode, cache_type="keywords" ) if cached_response is not None: try: keywords_data = json.loads(cached_response) return keywords_data["high_level_keywords"], keywords_data[ "low_level_keywords" ] except (json.JSONDecodeError, KeyError): logger.warning( "Invalid cache format for keywords, proceeding with extraction" ) # 2. Build the examples example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): examples = "\n".join( PROMPTS["keywords_extraction_examples"][: int(example_number)] ) else: examples = "\n".join(PROMPTS["keywords_extraction_examples"]) language = global_config["addon_params"].get( "language", PROMPTS["DEFAULT_LANGUAGE"] ) # 3. Process conversation history history_context = "" if param.conversation_history: history_context = get_conversation_turns( param.conversation_history, param.history_turns ) # 4. Build the keyword-extraction prompt kw_prompt = PROMPTS["keywords_extraction"].format( query=text, examples=examples, language=language, history=history_context ) tokenizer: Tokenizer = global_config["tokenizer"] len_of_prompts = len(tokenizer.encode(kw_prompt)) logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}") # 5. Call the LLM for keyword extraction use_model_func = ( param.model_func if param.model_func else global_config["llm_model_func"] ) result = await use_model_func(kw_prompt, keyword_extraction=True) # 6. 
# Parse out JSON from the LLM response
    match = re.search(r"\{.*\}", result, re.DOTALL)
    if not match:
        logger.error("No JSON-like structure found in the LLM response.")
        return [], []

    try:
        keywords_data = json.loads(match.group(0))
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        return [], []

    hl_keywords = keywords_data.get("high_level_keywords", [])
    ll_keywords = keywords_data.get("low_level_keywords", [])

    # 7. Cache only the processed keywords with cache type
    if hl_keywords or ll_keywords:
        cache_data = {
            "high_level_keywords": hl_keywords,
            "low_level_keywords": ll_keywords,
        }
        if hashing_kv.global_config.get("enable_llm_cache"):
            await save_to_cache(
                hashing_kv,
                CacheData(
                    args_hash=args_hash,
                    content=json.dumps(cache_data),
                    prompt=text,
                    quantized=quantized,
                    min_val=min_val,
                    max_val=max_val,
                    mode=param.mode,
                    cache_type="keywords",
                ),
            )
    return hl_keywords, ll_keywords


async def mix_kg_vector_query(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    """
    Hybrid retrieval implementation combining knowledge graph and vector search.

    This function performs a hybrid search by:
    1. Extracting semantic information from the knowledge graph
    2. Retrieving relevant text chunks through vector similarity
    3. Combining both results for comprehensive answer generation
    """
    # get tokenizer
    tokenizer: Tokenizer = global_config["tokenizer"]

    # 1. Cache handling
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash("mix", query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, "mix", cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    # 2.
Execute knowledge graph and vector searches in parallel async def get_kg_context(): try: hl_keywords, ll_keywords = await get_keywords_from_query( query, query_param, global_config, hashing_kv ) if not hl_keywords and not ll_keywords: logger.warning("Both high-level and low-level keywords are empty") return None # Convert keyword lists to strings ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" # Set query mode based on available keywords if not ll_keywords_str and not hl_keywords_str: return None elif not ll_keywords_str: query_param.mode = "global" elif not hl_keywords_str: query_param.mode = "local" else: query_param.mode = "hybrid" # Build knowledge graph context context = await _build_query_context( ll_keywords_str, hl_keywords_str, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, query_param, ) return context except Exception as e: logger.error(f"Error in get_kg_context: {str(e)}") traceback.print_exc() return None async def get_vector_context(): # Consider conversation history in vector search augmented_query = query if history_context: augmented_query = f"{history_context}\n{query}" try: # Reduce top_k for vector search in hybrid mode since we have structured information from KG mix_topk = min(10, query_param.top_k) results = await chunks_vdb.query( augmented_query, top_k=mix_topk, ids=query_param.ids ) if not results: return None chunks_ids = [r["id"] for r in results] chunks = await text_chunks_db.get_by_ids(chunks_ids) valid_chunks = [] for chunk, result in zip(chunks, results): if chunk is not None and "content" in chunk: # Merge chunk content and time metadata chunk_with_time = { "content": chunk["content"], "created_at": result.get("created_at", None), "file_path": result.get("file_path", None), } valid_chunks.append(chunk_with_time) if not valid_chunks: return None maybe_trun_chunks = truncate_list_by_token_size( valid_chunks, key=lambda x: x["content"], max_token_size=query_param.max_token_for_text_unit, tokenizer=tokenizer, ) if not maybe_trun_chunks: return None # Include time information in content formatted_chunks = [] for c in maybe_trun_chunks: chunk_text = "File path: " + c["file_path"] + "\n" + c["content"] if c["created_at"]: chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}" formatted_chunks.append(chunk_text) logger.debug( f"Truncate chunks from {len(chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})" ) return "\n--New Chunk--\n".join(formatted_chunks) except Exception as e: logger.error(f"Error in get_vector_context: {e}") return None # 3. Execute both retrievals in parallel kg_context, vector_context = await asyncio.gather( get_kg_context(), get_vector_context() ) # 4. Merge contexts if kg_context is None and vector_context is None: return PROMPTS["fail_response"] if query_param.only_need_context: context_str = f""" -----Knowledge Graph Context----- {kg_context if kg_context else "No relevant knowledge graph information found"} -----Vector Context----- {vector_context if vector_context else "No relevant text information found"} """.strip() return context_str # 5. 
# Construct hybrid prompt
    sys_prompt = (
        system_prompt if system_prompt else PROMPTS["mix_rag_response"]
    ).format(
        kg_context=kg_context
        if kg_context
        else "No relevant knowledge graph information found",
        vector_context=vector_context
        if vector_context
        else "No relevant text information found",
        response_type=query_param.response_type,
        history=history_context,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")

    # 6. Generate response
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    # Clean up response content
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("", "")
            .replace("", "")
            .strip()
        )

    if hashing_kv.global_config.get("enable_llm_cache"):
        # 7. Save cache - only cache after collecting the complete response
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode="mix",
                cache_type="query",
            ),
        )

    return response


async def _build_query_context(
    ll_keywords: str,
    hl_keywords: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
):
    logger.info(f"Process {os.getpid()} building query context...")
    if query_param.mode == "local":
        entities_context, relations_context, text_units_context = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    elif query_param.mode == "global":
        entities_context, relations_context, text_units_context = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
    else:  # hybrid mode
        ll_data = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
        hl_data = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )

        (
            ll_entities_context,
            ll_relations_context,
            ll_text_units_context,
        ) = ll_data

        (
            hl_entities_context,
            hl_relations_context,
            hl_text_units_context,
        ) = hl_data

        entities_context, relations_context, text_units_context = combine_contexts(
            [hl_entities_context, ll_entities_context],
            [hl_relations_context, ll_relations_context],
            [hl_text_units_context, ll_text_units_context],
        )

    # not necessary to use LLM to generate a response
    if not entities_context and not relations_context:
        return None

    # Convert the contexts to JSON strings
    entities_str = json.dumps(entities_context, ensure_ascii=False)
    relations_str = json.dumps(relations_context, ensure_ascii=False)
    text_units_str = json.dumps(text_units_context, ensure_ascii=False)

    result = f"""
-----Entities-----

```json
{entities_str}
```

-----Relationships-----

```json
{relations_str}
```

-----Sources-----

```json
{text_units_str}
```

""".strip()
    return result


async def _get_node_data(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
):
    # get similar entities
    logger.info(
        f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
    )

    results = await entities_vdb.query(
        query, top_k=query_param.top_k, ids=query_param.ids
    )

    if not len(results):
        return "", "", ""

    # Extract all entity IDs from your results list
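    # Illustrative note (added comment, not from the original code): each item
    # returned by entities_vdb.query() is expected to carry at least an
    # "entity_name" field; the vector-store key itself is derived at upsert time
    # via compute_mdhash_id(entity_name, prefix="ent-") in extract_entities.
    # Any other fields depend on the vector storage backend.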
node_ids = [r["entity_name"] for r in results] # Call the batch node retrieval and degree functions concurrently. nodes_dict, degrees_dict = await asyncio.gather( knowledge_graph_inst.get_nodes_batch(node_ids), knowledge_graph_inst.node_degrees_batch(node_ids), ) # Now, if you need the node data and degree in order: node_datas = [nodes_dict.get(nid) for nid in node_ids] node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids] if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") node_datas = [ {**n, "entity_name": k["entity_name"], "rank": d} for k, n, d in zip(results, node_datas, node_degrees) if n is not None ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. # get entitytext chunk use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst, ) use_relations = await _find_most_related_edges_from_entities( node_datas, query_param, knowledge_graph_inst, ) tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") len_node_datas = len(node_datas) node_datas = truncate_list_by_token_size( node_datas, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_local_context, tokenizer=tokenizer, ) logger.debug( f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})" ) logger.info( f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks" ) # build prompt entites_section_list = [ [ "id", "entity", "type", "description", "rank", "created_at", "file_path", ] ] for i, n in enumerate(node_datas): created_at = n.get("created_at", "UNKNOWN") if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from node data file_path = n.get("file_path", "unknown_source") entites_section_list.append( [ i, n["entity_name"], n.get("entity_type", "UNKNOWN"), n.get("description", "UNKNOWN"), n["rank"], created_at, file_path, ] ) entities_context = list_of_list_to_json(entites_section_list) relations_section_list = [ [ "id", "source", "target", "description", "keywords", "weight", "rank", "created_at", "file_path", ] ] for i, e in enumerate(use_relations): created_at = e.get("created_at", "UNKNOWN") # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from edge data file_path = e.get("file_path", "unknown_source") relations_section_list.append( [ i, e["src_tgt"][0], e["src_tgt"][1], e["description"], e["keywords"], e["weight"], e["rank"], created_at, file_path, ] ) relations_context = list_of_list_to_json(relations_section_list) text_units_section_list = [["id", "content", "file_path"]] for i, t in enumerate(use_text_units): text_units_section_list.append( [i, t["content"], t.get("file_path", "unknown_source")] ) text_units_context = list_of_list_to_json(text_units_section_list) return entities_context, relations_context, text_units_context async def _find_most_related_text_unit_from_entities( node_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in node_datas if dp["source_id"] is not None ] node_names = [dp["entity_name"] for 
dp in node_datas] batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) # Build the edges list in the same order as node_datas. edges = [batch_edges_dict.get(name, []) for name in node_names] all_one_hop_nodes = set() for this_edges in edges: if not this_edges: continue all_one_hop_nodes.update([e[1] for e in this_edges]) all_one_hop_nodes = list(all_one_hop_nodes) # Batch retrieve one-hop node data using get_nodes_batch all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch( all_one_hop_nodes ) all_one_hop_nodes_data = [ all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes ] # Add null check for node data all_one_hop_text_units_lookup = { k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP])) for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data) if v is not None and "source_id" in v # Add source_id check } all_text_units_lookup = {} tasks = [] for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): for c_id in this_text_units: if c_id not in all_text_units_lookup: all_text_units_lookup[c_id] = index tasks.append((c_id, index, this_edges)) # Process in batches tasks at a time to avoid overwhelming resources batch_size = 5 results = [] for i in range(0, len(tasks), batch_size): batch_tasks = tasks[i : i + batch_size] batch_results = await asyncio.gather( *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks] ) results.extend(batch_results) for (c_id, index, this_edges), data in zip(tasks, results): all_text_units_lookup[c_id] = { "data": data, "order": index, "relation_counts": 0, } if this_edges: for e in this_edges: if ( e[1] in all_one_hop_text_units_lookup and c_id in all_one_hop_text_units_lookup[e[1]] ): all_text_units_lookup[c_id]["relation_counts"] += 1 # Filter out None values and ensure data has content all_text_units = [ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None and v.get("data") is not None and "content" in v["data"] ] if not all_text_units: logger.warning("No valid text units found") return [] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") all_text_units = sorted( all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) ) all_text_units = truncate_list_by_token_size( all_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, tokenizer=tokenizer, ) logger.debug( f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})" ) all_text_units = [t["data"] for t in all_text_units] return all_text_units async def _find_most_related_edges_from_entities( node_datas: list[dict], query_param: QueryParam, knowledge_graph_inst: BaseGraphStorage, ): node_names = [dp["entity_name"] for dp in node_datas] batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) all_edges = [] seen = set() for node_name in node_names: this_edges = batch_edges_dict.get(node_name, []) for e in this_edges: sorted_edge = tuple(sorted(e)) if sorted_edge not in seen: seen.add(sorted_edge) all_edges.append(sorted_edge) # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges] # For edge degrees, use tuples. edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples # Call the batched functions concurrently. 
edge_data_dict, edge_degrees_dict = await asyncio.gather( knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), ) # Reconstruct edge_datas list in the same order as the deduplicated results. all_edges_data = [] for pair in all_edges: edge_props = edge_data_dict.get(pair) if edge_props is not None: combined = { "src_tgt": pair, "rank": edge_degrees_dict.get(pair, 0), **edge_props, } all_edges_data.append(combined) tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer") all_edges_data = sorted( all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True ) all_edges_data = truncate_list_by_token_size( all_edges_data, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_global_context, tokenizer=tokenizer, ) logger.debug( f"Truncate relations from {len(all_edges)} to {len(all_edges_data)} (max tokens:{query_param.max_token_for_global_context})" ) return all_edges_data async def _get_edge_data( keywords, knowledge_graph_inst: BaseGraphStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, ): logger.info( f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" ) results = await relationships_vdb.query( keywords, top_k=query_param.top_k, ids=query_param.ids ) if not len(results): return "", "", "" # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] # For edge degrees, use tuples. edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results] # Call the batched functions concurrently. edge_data_dict, edge_degrees_dict = await asyncio.gather( knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), ) # Reconstruct edge_datas list in the same order as results. edge_datas = [] for k in results: pair = (k["src_id"], k["tgt_id"]) edge_props = edge_data_dict.get(pair) if edge_props is not None: # Use edge degree from the batch as rank. 
combined = { "src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": edge_degrees_dict.get(pair, k.get("rank", 0)), "created_at": k.get("__created_at__", None), **edge_props, } edge_datas.append(combined) tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") edge_datas = sorted( edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True ) edge_datas = truncate_list_by_token_size( edge_datas, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_global_context, tokenizer=tokenizer, ) use_entities, use_text_units = await asyncio.gather( _find_most_related_entities_from_relationships( edge_datas, query_param, knowledge_graph_inst, ), _find_related_text_unit_from_relationships( edge_datas, query_param, text_chunks_db, knowledge_graph_inst, ), ) logger.info( f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks" ) relations_section_list = [ [ "id", "source", "target", "description", "keywords", "weight", "rank", "created_at", "file_path", ] ] for i, e in enumerate(edge_datas): created_at = e.get("created_at", "Unknown") # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from edge data file_path = e.get("file_path", "unknown_source") relations_section_list.append( [ i, e["src_id"], e["tgt_id"], e["description"], e["keywords"], e["weight"], e["rank"], created_at, file_path, ] ) relations_context = list_of_list_to_json(relations_section_list) entites_section_list = [ ["id", "entity", "type", "description", "rank", "created_at", "file_path"] ] for i, n in enumerate(use_entities): created_at = n.get("created_at", "Unknown") # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Get file path from node data file_path = n.get("file_path", "unknown_source") entites_section_list.append( [ i, n["entity_name"], n.get("entity_type", "UNKNOWN"), n.get("description", "UNKNOWN"), n["rank"], created_at, file_path, ] ) entities_context = list_of_list_to_json(entites_section_list) text_units_section_list = [["id", "content", "file_path"]] for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")]) text_units_context = list_of_list_to_json(text_units_section_list) return entities_context, relations_context, text_units_context async def _find_most_related_entities_from_relationships( edge_datas: list[dict], query_param: QueryParam, knowledge_graph_inst: BaseGraphStorage, ): entity_names = [] seen = set() for e in edge_datas: if e["src_id"] not in seen: entity_names.append(e["src_id"]) seen.add(e["src_id"]) if e["tgt_id"] not in seen: entity_names.append(e["tgt_id"]) seen.add(e["tgt_id"]) # Batch approach: Retrieve nodes and their degrees concurrently with one query each. 
nodes_dict, degrees_dict = await asyncio.gather( knowledge_graph_inst.get_nodes_batch(entity_names), knowledge_graph_inst.node_degrees_batch(entity_names), ) # Rebuild the list in the same order as entity_names node_datas = [] for entity_name in entity_names: node = nodes_dict.get(entity_name) degree = degrees_dict.get(entity_name, 0) if node is None: logger.warning(f"Node '{entity_name}' not found in batch retrieval.") continue # Combine the node data with the entity name and computed degree (as rank) combined = {**node, "entity_name": entity_name, "rank": degree} node_datas.append(combined) tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer") len_node_datas = len(node_datas) node_datas = truncate_list_by_token_size( node_datas, key=lambda x: x["description"] if x["description"] is not None else "", max_token_size=query_param.max_token_for_local_context, tokenizer=tokenizer, ) logger.debug( f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})" ) return node_datas async def _find_related_text_unit_from_relationships( edge_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in edge_datas if dp["source_id"] is not None ] all_text_units_lookup = {} async def fetch_chunk_data(c_id, index): if c_id not in all_text_units_lookup: chunk_data = await text_chunks_db.get_by_id(c_id) # Only store valid data if chunk_data is not None and "content" in chunk_data: all_text_units_lookup[c_id] = { "data": chunk_data, "order": index, } tasks = [] for index, unit_list in enumerate(text_units): for c_id in unit_list: tasks.append(fetch_chunk_data(c_id, index)) await asyncio.gather(*tasks) if not all_text_units_lookup: logger.warning("No valid text chunks found") return [] all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()] all_text_units = sorted(all_text_units, key=lambda x: x["order"]) # Ensure all text chunks have content valid_text_units = [ t for t in all_text_units if t["data"] is not None and "content" in t["data"] ] if not valid_text_units: logger.warning("No valid text chunks after filtering") return [] tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") truncated_text_units = truncate_list_by_token_size( valid_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, tokenizer=tokenizer, ) logger.debug( f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})" ) all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] return all_text_units def combine_contexts(entities, relationships, sources): # Function to extract entities, relationships, and sources from context strings hl_entities, ll_entities = entities[0], entities[1] hl_relationships, ll_relationships = relationships[0], relationships[1] hl_sources, ll_sources = sources[0], sources[1] # Combine and deduplicate the entities combined_entities = process_combine_contexts(hl_entities, ll_entities) # Combine and deduplicate the relationships combined_relationships = process_combine_contexts( hl_relationships, ll_relationships ) # Combine and deduplicate the sources combined_sources = process_combine_contexts(hl_sources, ll_sources) return combined_entities, combined_relationships, combined_sources async def naive_query( query: str, 
chunks_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, ) -> str | AsyncIterator[str]: # Handle cache use_model_func = ( query_param.model_func if query_param.model_func else global_config["llm_model_func"] ) args_hash = compute_args_hash(query_param.mode, query, cache_type="query") cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) if cached_response is not None: return cached_response results = await chunks_vdb.query( query, top_k=query_param.top_k, ids=query_param.ids ) if not len(results): return PROMPTS["fail_response"] chunks_ids = [r["id"] for r in results] chunks = await text_chunks_db.get_by_ids(chunks_ids) # Filter out invalid chunks valid_chunks = [ chunk for chunk in chunks if chunk is not None and "content" in chunk ] if not valid_chunks: logger.warning("No valid chunks found after filtering") return PROMPTS["fail_response"] tokenizer: Tokenizer = global_config["tokenizer"] maybe_trun_chunks = truncate_list_by_token_size( valid_chunks, key=lambda x: x["content"], max_token_size=query_param.max_token_for_text_unit, tokenizer=tokenizer, ) if not maybe_trun_chunks: logger.warning("No chunks left after truncation") return PROMPTS["fail_response"] logger.debug( f"Truncate chunks from {len(chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" ) section = "\n--New Chunk--\n".join( [ "File path: " + c["file_path"] + "\n" + c["content"] for c in maybe_trun_chunks ] ) if query_param.only_need_context: return section # Process conversation history history_context = "" if query_param.conversation_history: history_context = get_conversation_turns( query_param.conversation_history, query_param.history_turns ) sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"] sys_prompt = sys_prompt_temp.format( content_data=section, response_type=query_param.response_type, history=history_context, ) if query_param.only_need_prompt: return sys_prompt len_of_prompts = len(tokenizer.encode(query + sys_prompt)) logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}") response = await use_model_func( query, system_prompt=sys_prompt, stream=query_param.stream, ) if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response[len(sys_prompt) :] .replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) if hashing_kv.global_config.get("enable_llm_cache"): # Save to cache await save_to_cache( hashing_kv, CacheData( args_hash=args_hash, content=response, prompt=query, quantized=quantized, min_val=min_val, max_val=max_val, mode=query_param.mode, cache_type="query", ), ) return response async def kg_query_with_keywords( query: str, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ) -> str | AsyncIterator[str]: """ Refactored kg_query that does NOT extract keywords by itself. It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. Then it uses those to build context and produce a final LLM response. 
""" # --------------------------- # 1) Handle potential cache for query results # --------------------------- use_model_func = ( query_param.model_func if query_param.model_func else global_config["llm_model_func"] ) args_hash = compute_args_hash(query_param.mode, query, cache_type="query") cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) if cached_response is not None: return cached_response # --------------------------- # 2) RETRIEVE KEYWORDS FROM query_param # --------------------------- # If these fields don't exist, default to empty lists/strings. hl_keywords = getattr(query_param, "hl_keywords", []) or [] ll_keywords = getattr(query_param, "ll_keywords", []) or [] # If neither has any keywords, you could handle that logic here. if not hl_keywords and not ll_keywords: logger.warning( "No keywords found in query_param. Could default to global mode or fail." ) return PROMPTS["fail_response"] if not ll_keywords and query_param.mode in ["local", "hybrid"]: logger.warning("low_level_keywords is empty, switching to global mode.") query_param.mode = "global" if not hl_keywords and query_param.mode in ["global", "hybrid"]: logger.warning("high_level_keywords is empty, switching to local mode.") query_param.mode = "local" # Flatten low-level and high-level keywords if needed ll_keywords_flat = ( [item for sublist in ll_keywords for item in sublist] if any(isinstance(i, list) for i in ll_keywords) else ll_keywords ) hl_keywords_flat = ( [item for sublist in hl_keywords for item in sublist] if any(isinstance(i, list) for i in hl_keywords) else hl_keywords ) # Join the flattened lists ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else "" hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else "" # --------------------------- # 3) BUILD CONTEXT # --------------------------- context = await _build_query_context( ll_keywords_str, hl_keywords_str, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, query_param, ) if not context: return PROMPTS["fail_response"] # If only context is needed, return it if query_param.only_need_context: return context # --------------------------- # 4) BUILD THE SYSTEM PROMPT + CALL LLM # --------------------------- # Process conversation history history_context = "" if query_param.conversation_history: history_context = get_conversation_turns( query_param.conversation_history, query_param.history_turns ) sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type, history=history_context, ) if query_param.only_need_prompt: return sys_prompt tokenizer: Tokenizer = global_config["tokenizer"] len_of_prompts = len(tokenizer.encode(query + sys_prompt)) logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}") # 6. Generate response response = await use_model_func( query, system_prompt=sys_prompt, stream=query_param.stream, ) # Clean up response content if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) if hashing_kv.global_config.get("enable_llm_cache"): # 7. 
# Save cache - only cache after the complete response has been collected
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=query_param.mode,
                cache_type="query",
            ),
        )

    return response


async def query_with_keywords(
    query: str,
    prompt: str,
    param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Extract keywords from the query and then use them for retrieving information.

    1. Extracts high-level and low-level keywords from the query
    2. Formats the query with the extracted keywords and prompt
    3. Uses the appropriate query method based on param.mode

    Args:
        query: The user's query
        prompt: Additional prompt to prepend to the query
        param: Query parameters
        knowledge_graph_inst: Knowledge graph storage
        entities_vdb: Entities vector database
        relationships_vdb: Relationships vector database
        chunks_vdb: Document chunks vector database
        text_chunks_db: Text chunks storage
        global_config: Global configuration
        hashing_kv: Cache storage

    Returns:
        Query response or async iterator
    """
    # Extract keywords
    hl_keywords, ll_keywords = await get_keywords_from_query(
        query=query,
        query_param=param,
        global_config=global_config,
        hashing_kv=hashing_kv,
    )

    # Create a new string with the prompt and the keywords
    ll_keywords_str = ", ".join(ll_keywords)
    hl_keywords_str = ", ".join(hl_keywords)
    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"

    # Use appropriate query method based on mode
    if param.mode in ["local", "global", "hybrid"]:
        return await kg_query_with_keywords(
            formatted_question,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=hashing_kv,
        )
    elif param.mode == "naive":
        return await naive_query(
            formatted_question,
            chunks_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=hashing_kv,
        )
    elif param.mode == "mix":
        return await mix_kg_vector_query(
            formatted_question,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            chunks_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=hashing_kv,
        )
    else:
        raise ValueError(f"Unknown mode {param.mode}")
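

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the library API).
# It exercises chunking_by_token_size with a toy whitespace "tokenizer";
# chunking_by_token_size only calls encode()/decode(), so any object with those
# two methods works here. Real callers pass the Tokenizer configured in
# global_config["tokenizer"]. Because this module uses relative imports, the
# sketch only runs when invoked as a module inside its package (the exact
# package path is an assumption), e.g. `python -m lightrag.operate`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyWhitespaceTokenizer:
        """Toy stand-in: one whitespace-separated word == one token."""

        def encode(self, text: str) -> list[str]:
            return text.split()

        def decode(self, tokens: list[str]) -> str:
            return " ".join(tokens)

    _sample = "one two three four five six seven eight nine ten"
    # max_token_size=4 with overlap_token_size=1 yields windows of four words,
    # each sharing one word with the previous chunk.
    for _chunk in chunking_by_token_size(
        _ToyWhitespaceTokenizer(),
        _sample,
        overlap_token_size=1,
        max_token_size=4,
    ):
        print(_chunk)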