Mirror of https://github.com/HKUDS/LightRAG.git, synced 2025-11-14 17:13:13 +00:00
- Add missing query parameters (top_k, enable_rerank, max_tokens, etc.) to cache key generation in kg_query, naive_query, and extract_keywords_only functions
- Add queryparam field to CacheData structure and PostgreSQL storage for debugging
- Update PostgreSQL schema with automatic migration for queryparam JSONB column
- Prevent incorrect cache hits between queries with different parameters

Fixes issue where different query parameters incorrectly shared the same cached results.
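The gist of the fix, as a minimal illustrative sketch (the helper below and its exact parameter set are assumptions for illustration only; the real code presumably routes key generation through compute_args_hash from .utils, which this module imports): fold the query parameters into the hashed payload, so the same query text asked with different settings produces a distinct cache key.

    import hashlib
    import json

    def build_query_cache_key(query: str, mode: str, top_k: int,
                              enable_rerank: bool, max_tokens: int) -> str:
        # Hash the query text together with its parameters, so the same question
        # asked with different settings gets its own cache entry.
        payload = json.dumps(
            {
                "query": query,
                "mode": mode,
                "top_k": top_k,
                "enable_rerank": enable_rerank,
                "max_tokens": max_tokens,
            },
            sort_keys=True,
        )
        return hashlib.md5(payload.encode("utf-8")).hexdigest()

    # Same text, different settings -> different keys, hence no false cache hit:
    # build_query_cache_key("q", "hybrid", 60, True, 4000) != build_query_cache_key("q", "hybrid", 10, False, 4000)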
3390 lines
123 KiB
Python
from __future__ import annotations

from functools import partial

import asyncio
import json
import re
import os
import json_repair
from typing import Any, AsyncIterator
from collections import Counter, defaultdict

from .utils import (
    logger,
    clean_str,
    compute_mdhash_id,
    Tokenizer,
    is_float_regex,
    normalize_extracted_info,
    pack_user_ass_to_openai_messages,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
    compute_args_hash,
    handle_cache,
    save_to_cache,
    CacheData,
    get_conversation_turns,
    use_llm_func_with_cache,
    update_chunk_cache_list,
    remove_think_tags,
    linear_gradient_weighted_polling,
    process_chunks_unified,
    build_file_path,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    TextChunkSchema,
    QueryParam,
)
from .prompt import PROMPTS
from .constants import (
    GRAPH_FIELD_SEP,
    DEFAULT_MAX_ENTITY_TOKENS,
    DEFAULT_MAX_RELATION_TOKENS,
    DEFAULT_MAX_TOTAL_TOKENS,
    DEFAULT_RELATED_CHUNK_NUMBER,
)
from .kg.shared_storage import get_storage_keyed_lock
import time
from dotenv import load_dotenv

# Use the .env file in the current working directory,
# which allows each LightRAG instance to use its own .env file.
# OS environment variables take precedence over values in the .env file.
load_dotenv(dotenv_path=".env", override=False)


def chunking_by_token_size(
|
||
tokenizer: Tokenizer,
|
||
content: str,
|
||
split_by_character: str | None = None,
|
||
split_by_character_only: bool = False,
|
||
overlap_token_size: int = 128,
|
||
max_token_size: int = 1024,
|
||
) -> list[dict[str, Any]]:
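    """Split content into chunks of at most max_token_size tokens, with
    overlap_token_size tokens of overlap between consecutive chunks.

    If split_by_character is given, the text is first split on that character;
    unless split_by_character_only is True, any piece longer than max_token_size
    is further split by token window. Each returned dict contains tokens,
    content and chunk_order_index.
    """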
|
||
tokens = tokenizer.encode(content)
|
||
results: list[dict[str, Any]] = []
|
||
if split_by_character:
|
||
raw_chunks = content.split(split_by_character)
|
||
new_chunks = []
|
||
if split_by_character_only:
|
||
for chunk in raw_chunks:
|
||
_tokens = tokenizer.encode(chunk)
|
||
new_chunks.append((len(_tokens), chunk))
|
||
else:
|
||
for chunk in raw_chunks:
|
||
_tokens = tokenizer.encode(chunk)
|
||
if len(_tokens) > max_token_size:
|
||
for start in range(
|
||
0, len(_tokens), max_token_size - overlap_token_size
|
||
):
|
||
chunk_content = tokenizer.decode(
|
||
_tokens[start : start + max_token_size]
|
||
)
|
||
new_chunks.append(
|
||
(min(max_token_size, len(_tokens) - start), chunk_content)
|
||
)
|
||
else:
|
||
new_chunks.append((len(_tokens), chunk))
|
||
for index, (_len, chunk) in enumerate(new_chunks):
|
||
results.append(
|
||
{
|
||
"tokens": _len,
|
||
"content": chunk.strip(),
|
||
"chunk_order_index": index,
|
||
}
|
||
)
|
||
else:
|
||
for index, start in enumerate(
|
||
range(0, len(tokens), max_token_size - overlap_token_size)
|
||
):
|
||
chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
|
||
results.append(
|
||
{
|
||
"tokens": min(max_token_size, len(tokens) - start),
|
||
"content": chunk_content.strip(),
|
||
"chunk_order_index": index,
|
||
}
|
||
)
|
||
return results
|
||
|
||
|
||
async def _handle_entity_relation_summary(
|
||
entity_or_relation_name: str,
|
||
description: str,
|
||
global_config: dict,
|
||
llm_response_cache: BaseKVStorage | None = None,
|
||
) -> str:
|
||
"""Handle entity relation summary
|
||
For each entity or relation, input is the combined description of already existing description and new description.
|
||
If too long, use LLM to summarize.
|
||
"""
|
||
use_llm_func: callable = global_config["llm_model_func"]
|
||
# Apply higher priority (8) to entity/relation summary tasks
|
||
use_llm_func = partial(use_llm_func, _priority=8)
|
||
|
||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||
llm_max_tokens = global_config["summary_max_tokens"]
|
||
|
||
language = global_config["addon_params"].get(
|
||
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
||
)
|
||
|
||
tokens = tokenizer.encode(description)
|
||
|
||
    # Whether to summarize is no longer decided here; the caller decides based on
    # the number of accumulated description fragments (num_fragment).
    # if len(tokens) < summary_max_tokens:  # No need for summary
    #     return description
|
||
|
||
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
||
use_description = tokenizer.decode(tokens[:llm_max_tokens])
|
||
context_base = dict(
|
||
entity_name=entity_or_relation_name,
|
||
description_list=use_description.split(GRAPH_FIELD_SEP),
|
||
language=language,
|
||
)
|
||
use_prompt = prompt_template.format(**context_base)
|
||
logger.debug(f"Trigger summary: {entity_or_relation_name}")
|
||
|
||
# Use LLM function with cache (higher priority for summary generation)
|
||
summary = await use_llm_func_with_cache(
|
||
use_prompt,
|
||
use_llm_func,
|
||
llm_response_cache=llm_response_cache,
|
||
# max_tokens=summary_max_tokens,
|
||
cache_type="extract",
|
||
)
|
||
return summary
|
||
|
||
|
||
async def _handle_single_entity_extraction(
|
||
record_attributes: list[str],
|
||
chunk_key: str,
|
||
file_path: str = "unknown_source",
|
||
):
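    """Parse one extraction record into an entity dict.

    Returns None if the record is not a valid entity (wrong record type, too few
    attributes, or empty name/type/description after cleaning and normalization).
    """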
|
||
if len(record_attributes) < 4 or '"entity"' not in record_attributes[0]:
|
||
return None
|
||
|
||
# Clean and validate entity name
|
||
entity_name = clean_str(record_attributes[1]).strip()
|
||
if not entity_name:
|
||
logger.warning(
|
||
f"Entity extraction error: empty entity name in: {record_attributes}"
|
||
)
|
||
return None
|
||
|
||
# Normalize entity name
|
||
entity_name = normalize_extracted_info(entity_name, is_entity=True)
|
||
|
||
# Check if entity name became empty after normalization
|
||
if not entity_name or not entity_name.strip():
|
||
logger.warning(
|
||
f"Entity extraction error: entity name became empty after normalization. Original: '{record_attributes[1]}'"
|
||
)
|
||
return None
|
||
|
||
# Clean and validate entity type
|
||
entity_type = clean_str(record_attributes[2]).strip('"')
|
||
if not entity_type.strip() or entity_type.startswith('("'):
|
||
logger.warning(
|
||
f"Entity extraction error: invalid entity type in: {record_attributes}"
|
||
)
|
||
return None
|
||
|
||
# Clean and validate description
|
||
entity_description = clean_str(record_attributes[3])
|
||
entity_description = normalize_extracted_info(entity_description)
|
||
|
||
if not entity_description.strip():
|
||
logger.warning(
|
||
f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
|
||
)
|
||
return None
|
||
|
||
return dict(
|
||
entity_name=entity_name,
|
||
entity_type=entity_type,
|
||
description=entity_description,
|
||
source_id=chunk_key,
|
||
file_path=file_path,
|
||
)
|
||
|
||
|
||
async def _handle_single_relationship_extraction(
|
||
record_attributes: list[str],
|
||
chunk_key: str,
|
||
file_path: str = "unknown_source",
|
||
):
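    """Parse one extraction record into a relationship (edge) dict.

    Returns None if the record is not a valid relationship, if either endpoint is
    empty after normalization, or if source and target are the same entity.
    """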
|
||
if len(record_attributes) < 5 or '"relationship"' not in record_attributes[0]:
|
||
return None
|
||
# add this record as edge
|
||
source = clean_str(record_attributes[1])
|
||
target = clean_str(record_attributes[2])
|
||
|
||
# Normalize source and target entity names
|
||
source = normalize_extracted_info(source, is_entity=True)
|
||
target = normalize_extracted_info(target, is_entity=True)
|
||
|
||
# Check if source or target became empty after normalization
|
||
if not source or not source.strip():
|
||
logger.warning(
|
||
f"Relationship extraction error: source entity became empty after normalization. Original: '{record_attributes[1]}'"
|
||
)
|
||
return None
|
||
|
||
if not target or not target.strip():
|
||
logger.warning(
|
||
f"Relationship extraction error: target entity became empty after normalization. Original: '{record_attributes[2]}'"
|
||
)
|
||
return None
|
||
|
||
if source == target:
|
||
logger.debug(
|
||
f"Relationship source and target are the same in: {record_attributes}"
|
||
)
|
||
return None
|
||
|
||
edge_description = clean_str(record_attributes[3])
|
||
edge_description = normalize_extracted_info(edge_description)
|
||
|
||
edge_keywords = normalize_extracted_info(
|
||
clean_str(record_attributes[4]), is_entity=True
|
||
)
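    # Normalize full-width (Chinese) commas in the keywords string to ASCII commas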
    edge_keywords = edge_keywords.replace("，", ",")

edge_source_id = chunk_key
|
||
weight = (
|
||
float(record_attributes[-1].strip('"').strip("'"))
|
||
if is_float_regex(record_attributes[-1].strip('"').strip("'"))
|
||
else 1.0
|
||
)
|
||
return dict(
|
||
src_id=source,
|
||
tgt_id=target,
|
||
weight=weight,
|
||
description=edge_description,
|
||
keywords=edge_keywords,
|
||
source_id=edge_source_id,
|
||
file_path=file_path,
|
||
)
|
||
|
||
|
||
async def _rebuild_knowledge_from_chunks(
|
||
entities_to_rebuild: dict[str, set[str]],
|
||
relationships_to_rebuild: dict[tuple[str, str], set[str]],
|
||
knowledge_graph_inst: BaseGraphStorage,
|
||
entities_vdb: BaseVectorStorage,
|
||
relationships_vdb: BaseVectorStorage,
|
||
text_chunks_storage: BaseKVStorage,
|
||
llm_response_cache: BaseKVStorage,
|
||
global_config: dict[str, str],
|
||
pipeline_status: dict | None = None,
|
||
pipeline_status_lock=None,
|
||
) -> None:
|
||
"""Rebuild entity and relationship descriptions from cached extraction results with parallel processing
|
||
|
||
This method uses cached LLM extraction results instead of calling LLM again,
|
||
following the same approach as the insert process. Now with parallel processing
|
||
controlled by llm_model_max_async and using get_storage_keyed_lock for data consistency.
|
||
|
||
Args:
|
||
entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
|
||
relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
|
||
knowledge_graph_inst: Knowledge graph storage
|
||
entities_vdb: Entity vector database
|
||
relationships_vdb: Relationship vector database
|
||
text_chunks_storage: Text chunks storage
|
||
llm_response_cache: LLM response cache
|
||
global_config: Global configuration containing llm_model_max_async
|
||
pipeline_status: Pipeline status dictionary
|
||
pipeline_status_lock: Lock for pipeline status
|
||
"""
|
||
if not entities_to_rebuild and not relationships_to_rebuild:
|
||
return
|
||
|
||
# Get all referenced chunk IDs
|
||
all_referenced_chunk_ids = set()
|
||
for chunk_ids in entities_to_rebuild.values():
|
||
all_referenced_chunk_ids.update(chunk_ids)
|
||
for chunk_ids in relationships_to_rebuild.values():
|
||
all_referenced_chunk_ids.update(chunk_ids)
|
||
|
||
status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)"
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
|
||
# Get cached extraction results for these chunks using storage
|
||
# cached_results: chunk_id -> [list of extraction result from LLM cache sorted by created_at]
|
||
cached_results = await _get_cached_extraction_results(
|
||
llm_response_cache,
|
||
all_referenced_chunk_ids,
|
||
text_chunks_storage=text_chunks_storage,
|
||
)
|
||
|
||
if not cached_results:
|
||
status_message = "No cached extraction results found, cannot rebuild"
|
||
logger.warning(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
return
|
||
|
||
# Process cached results to get entities and relationships for each chunk
|
||
chunk_entities = {} # chunk_id -> {entity_name: [entity_data]}
|
||
chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]}
|
||
|
||
for chunk_id, extraction_results in cached_results.items():
|
||
try:
|
||
# Handle multiple extraction results per chunk
|
||
chunk_entities[chunk_id] = defaultdict(list)
|
||
chunk_relationships[chunk_id] = defaultdict(list)
|
||
|
||
# process multiple LLM extraction results for a single chunk_id
|
||
for extraction_result in extraction_results:
|
||
entities, relationships = await _parse_extraction_result(
|
||
text_chunks_storage=text_chunks_storage,
|
||
extraction_result=extraction_result,
|
||
chunk_id=chunk_id,
|
||
)
|
||
|
||
# Merge entities and relationships from this extraction result
|
||
# Only keep the first occurrence of each entity_name in the same chunk_id
|
||
for entity_name, entity_list in entities.items():
|
||
if (
|
||
entity_name not in chunk_entities[chunk_id]
|
||
or len(chunk_entities[chunk_id][entity_name]) == 0
|
||
):
|
||
chunk_entities[chunk_id][entity_name].extend(entity_list)
|
||
|
||
# Only keep the first occurrence of each rel_key in the same chunk_id
|
||
for rel_key, rel_list in relationships.items():
|
||
if (
|
||
rel_key not in chunk_relationships[chunk_id]
|
||
or len(chunk_relationships[chunk_id][rel_key]) == 0
|
||
):
|
||
chunk_relationships[chunk_id][rel_key].extend(rel_list)
|
||
|
||
except Exception as e:
|
||
status_message = (
|
||
f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
|
||
)
|
||
logger.info(status_message) # Per requirement, change to info
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
continue
|
||
|
||
# Get max async tasks limit from global_config for semaphore control
|
||
graph_max_async = global_config.get("llm_model_max_async", 4) * 2
|
||
semaphore = asyncio.Semaphore(graph_max_async)
|
||
|
||
# Counters for tracking progress
|
||
rebuilt_entities_count = 0
|
||
rebuilt_relationships_count = 0
|
||
failed_entities_count = 0
|
||
failed_relationships_count = 0
|
||
|
||
async def _locked_rebuild_entity(entity_name, chunk_ids):
|
||
nonlocal rebuilt_entities_count, failed_entities_count
|
||
async with semaphore:
|
||
workspace = global_config.get("workspace", "")
|
||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||
async with get_storage_keyed_lock(
|
||
[entity_name], namespace=namespace, enable_logging=False
|
||
):
|
||
try:
|
||
await _rebuild_single_entity(
|
||
knowledge_graph_inst=knowledge_graph_inst,
|
||
entities_vdb=entities_vdb,
|
||
entity_name=entity_name,
|
||
chunk_ids=chunk_ids,
|
||
chunk_entities=chunk_entities,
|
||
llm_response_cache=llm_response_cache,
|
||
global_config=global_config,
|
||
)
|
||
rebuilt_entities_count += 1
|
||
status_message = (
|
||
f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks"
|
||
)
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
except Exception as e:
|
||
failed_entities_count += 1
|
||
status_message = f"Failed to rebuild entity {entity_name}: {e}"
|
||
logger.info(status_message) # Per requirement, change to info
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
|
||
async def _locked_rebuild_relationship(src, tgt, chunk_ids):
|
||
nonlocal rebuilt_relationships_count, failed_relationships_count
|
||
async with semaphore:
|
||
workspace = global_config.get("workspace", "")
|
||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||
# Sort src and tgt to ensure order-independent lock key generation
|
||
sorted_key_parts = sorted([src, tgt])
|
||
async with get_storage_keyed_lock(
|
||
sorted_key_parts,
|
||
namespace=namespace,
|
||
enable_logging=False,
|
||
):
|
||
try:
|
||
await _rebuild_single_relationship(
|
||
knowledge_graph_inst=knowledge_graph_inst,
|
||
relationships_vdb=relationships_vdb,
|
||
src=src,
|
||
tgt=tgt,
|
||
chunk_ids=chunk_ids,
|
||
chunk_relationships=chunk_relationships,
|
||
llm_response_cache=llm_response_cache,
|
||
global_config=global_config,
|
||
)
|
||
rebuilt_relationships_count += 1
|
||
status_message = f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks"
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
except Exception as e:
|
||
failed_relationships_count += 1
|
||
status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}"
|
||
logger.info(status_message) # Per requirement, change to info
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
|
||
# Create tasks for parallel processing
|
||
tasks = []
|
||
|
||
# Add entity rebuilding tasks
|
||
for entity_name, chunk_ids in entities_to_rebuild.items():
|
||
task = asyncio.create_task(_locked_rebuild_entity(entity_name, chunk_ids))
|
||
tasks.append(task)
|
||
|
||
# Add relationship rebuilding tasks
|
||
for (src, tgt), chunk_ids in relationships_to_rebuild.items():
|
||
task = asyncio.create_task(_locked_rebuild_relationship(src, tgt, chunk_ids))
|
||
tasks.append(task)
|
||
|
||
# Log parallel processing start
|
||
status_message = f"Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})"
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
|
||
# Execute all tasks in parallel with semaphore control and early failure detection
|
||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
|
||
|
||
# Check if any task raised an exception
|
||
for task in done:
|
||
if task.exception():
|
||
# If a task failed, cancel all pending tasks
|
||
for pending_task in pending:
|
||
pending_task.cancel()
|
||
|
||
# Wait for cancellation to complete
|
||
if pending:
|
||
await asyncio.wait(pending)
|
||
|
||
# Re-raise the exception to notify the caller
|
||
raise task.exception()
|
||
|
||
# Final status report
|
||
status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully."
|
||
if failed_entities_count > 0 or failed_relationships_count > 0:
|
||
status_message += f" Failed: {failed_entities_count} entities, {failed_relationships_count} relationships."
|
||
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
|
||
|
||
async def _get_cached_extraction_results(
|
||
llm_response_cache: BaseKVStorage,
|
||
chunk_ids: set[str],
|
||
text_chunks_storage: BaseKVStorage,
|
||
) -> dict[str, list[str]]:
|
||
"""Get cached extraction results for specific chunk IDs
|
||
|
||
Args:
|
||
llm_response_cache: LLM response cache storage
|
||
chunk_ids: Set of chunk IDs to get cached results for
|
||
        text_chunks_storage: Text chunks storage used to look up each chunk's llm_cache_list
|
||
|
||
Returns:
|
||
Dict mapping chunk_id -> list of extraction_result_text
|
||
"""
|
||
cached_results = {}
|
||
|
||
# Collect all LLM cache IDs from chunks
|
||
all_cache_ids = set()
|
||
|
||
# Read from storage
|
||
chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
|
||
for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
|
||
if chunk_data and isinstance(chunk_data, dict):
|
||
llm_cache_list = chunk_data.get("llm_cache_list", [])
|
||
if llm_cache_list:
|
||
all_cache_ids.update(llm_cache_list)
|
||
else:
|
||
logger.warning(
|
||
f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
|
||
)
|
||
|
||
if not all_cache_ids:
|
||
logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
|
||
return cached_results
|
||
|
||
# Batch get LLM cache entries
|
||
cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids))
|
||
|
||
# Process cache entries and group by chunk_id
|
||
valid_entries = 0
|
||
for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
|
||
if (
|
||
cache_entry is not None
|
||
and isinstance(cache_entry, dict)
|
||
and cache_entry.get("cache_type") == "extract"
|
||
and cache_entry.get("chunk_id") in chunk_ids
|
||
):
|
||
chunk_id = cache_entry["chunk_id"]
|
||
extraction_result = cache_entry["return"]
|
||
create_time = cache_entry.get(
|
||
"create_time", 0
|
||
) # Get creation time, default to 0
|
||
valid_entries += 1
|
||
|
||
# Support multiple LLM caches per chunk
|
||
if chunk_id not in cached_results:
|
||
cached_results[chunk_id] = []
|
||
# Store tuple with extraction result and creation time for sorting
|
||
cached_results[chunk_id].append((extraction_result, create_time))
|
||
|
||
# Sort extraction results by create_time for each chunk
|
||
for chunk_id in cached_results:
|
||
# Sort by create_time (x[1]), then extract only extraction_result (x[0])
|
||
cached_results[chunk_id].sort(key=lambda x: x[1])
|
||
cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
|
||
|
||
logger.info(
|
||
f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
|
||
)
|
||
return cached_results
|
||
|
||
|
||
async def _parse_extraction_result(
|
||
text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str
|
||
) -> tuple[dict, dict]:
|
||
"""Parse cached extraction result using the same logic as extract_entities
|
||
|
||
Args:
|
||
text_chunks_storage: Text chunks storage to get chunk data
|
||
extraction_result: The cached LLM extraction result
|
||
chunk_id: The chunk ID for source tracking
|
||
|
||
Returns:
|
||
Tuple of (entities_dict, relationships_dict)
|
||
"""
|
||
|
||
# Get chunk data for file_path from storage
|
||
chunk_data = await text_chunks_storage.get_by_id(chunk_id)
|
||
file_path = (
|
||
chunk_data.get("file_path", "unknown_source")
|
||
if chunk_data
|
||
else "unknown_source"
|
||
)
|
||
context_base = dict(
|
||
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
||
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
|
||
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
||
)
|
||
maybe_nodes = defaultdict(list)
|
||
maybe_edges = defaultdict(list)
|
||
|
||
# Parse the extraction result using the same logic as in extract_entities
|
||
records = split_string_by_multi_markers(
|
||
extraction_result,
|
||
[context_base["record_delimiter"], context_base["completion_delimiter"]],
|
||
)
|
||
for record in records:
|
||
record = re.search(r"\((.*)\)", record)
|
||
if record is None:
|
||
continue
|
||
record = record.group(1)
|
||
record_attributes = split_string_by_multi_markers(
|
||
record, [context_base["tuple_delimiter"]]
|
||
)
|
||
|
||
# Try to parse as entity
|
||
entity_data = await _handle_single_entity_extraction(
|
||
record_attributes, chunk_id, file_path
|
||
)
|
||
if entity_data is not None:
|
||
maybe_nodes[entity_data["entity_name"]].append(entity_data)
|
||
continue
|
||
|
||
# Try to parse as relationship
|
||
relationship_data = await _handle_single_relationship_extraction(
|
||
record_attributes, chunk_id, file_path
|
||
)
|
||
if relationship_data is not None:
|
||
maybe_edges[
|
||
(relationship_data["src_id"], relationship_data["tgt_id"])
|
||
].append(relationship_data)
|
||
|
||
return dict(maybe_nodes), dict(maybe_edges)
|
||
|
||
|
||
async def _rebuild_single_entity(
|
||
knowledge_graph_inst: BaseGraphStorage,
|
||
entities_vdb: BaseVectorStorage,
|
||
entity_name: str,
|
||
chunk_ids: set[str],
|
||
chunk_entities: dict,
|
||
llm_response_cache: BaseKVStorage,
|
||
global_config: dict[str, str],
|
||
) -> None:
|
||
"""Rebuild a single entity from cached extraction results"""
|
||
|
||
# Get current entity data
|
||
current_entity = await knowledge_graph_inst.get_node(entity_name)
|
||
if not current_entity:
|
||
return
|
||
|
||
# Helper function to update entity in both graph and vector storage
|
||
async def _update_entity_storage(
|
||
final_description: str, entity_type: str, file_paths: set[str]
|
||
):
|
||
# Update entity in graph storage
|
||
updated_entity_data = {
|
||
**current_entity,
|
||
"description": final_description,
|
||
"entity_type": entity_type,
|
||
"source_id": GRAPH_FIELD_SEP.join(chunk_ids),
|
||
"file_path": GRAPH_FIELD_SEP.join(file_paths)
|
||
if file_paths
|
||
else current_entity.get("file_path", "unknown_source"),
|
||
}
|
||
await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data)
|
||
|
||
# Update entity in vector database
|
||
entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||
|
||
# Delete old vector record first
|
||
try:
|
||
await entities_vdb.delete([entity_vdb_id])
|
||
except Exception as e:
|
||
logger.debug(
|
||
f"Could not delete old entity vector record {entity_vdb_id}: {e}"
|
||
)
|
||
|
||
# Insert new vector record
|
||
entity_content = f"{entity_name}\n{final_description}"
|
||
await entities_vdb.upsert(
|
||
{
|
||
entity_vdb_id: {
|
||
"content": entity_content,
|
||
"entity_name": entity_name,
|
||
"source_id": updated_entity_data["source_id"],
|
||
"description": final_description,
|
||
"entity_type": entity_type,
|
||
"file_path": updated_entity_data["file_path"],
|
||
}
|
||
}
|
||
)
|
||
|
||
# Helper function to generate final description with optional LLM summary
|
||
async def _generate_final_description(combined_description: str) -> str:
|
||
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
|
||
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
|
||
|
||
if num_fragment >= force_llm_summary_on_merge:
|
||
return await _handle_entity_relation_summary(
|
||
entity_name,
|
||
combined_description,
|
||
global_config,
|
||
llm_response_cache=llm_response_cache,
|
||
)
|
||
else:
|
||
return combined_description
|
||
|
||
# Collect all entity data from relevant chunks
|
||
all_entity_data = []
|
||
for chunk_id in chunk_ids:
|
||
if chunk_id in chunk_entities and entity_name in chunk_entities[chunk_id]:
|
||
all_entity_data.extend(chunk_entities[chunk_id][entity_name])
|
||
|
||
if not all_entity_data:
|
||
logger.warning(
|
||
f"No cached entity data found for {entity_name}, trying to rebuild from relationships"
|
||
)
|
||
|
||
# Get all edges connected to this entity
|
||
edges = await knowledge_graph_inst.get_node_edges(entity_name)
|
||
if not edges:
|
||
logger.warning(f"No relationships found for entity {entity_name}")
|
||
return
|
||
|
||
# Collect relationship data to extract entity information
|
||
relationship_descriptions = []
|
||
file_paths = set()
|
||
|
||
# Get edge data for all connected relationships
|
||
for src_id, tgt_id in edges:
|
||
edge_data = await knowledge_graph_inst.get_edge(src_id, tgt_id)
|
||
if edge_data:
|
||
if edge_data.get("description"):
|
||
relationship_descriptions.append(edge_data["description"])
|
||
|
||
if edge_data.get("file_path"):
|
||
edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP)
|
||
file_paths.update(edge_file_paths)
|
||
|
||
# Generate description from relationships or fallback to current
|
||
if relationship_descriptions:
|
||
combined_description = GRAPH_FIELD_SEP.join(relationship_descriptions)
|
||
final_description = await _generate_final_description(combined_description)
|
||
else:
|
||
final_description = current_entity.get("description", "")
|
||
|
||
entity_type = current_entity.get("entity_type", "UNKNOWN")
|
||
await _update_entity_storage(final_description, entity_type, file_paths)
|
||
return
|
||
|
||
# Process cached entity data
|
||
descriptions = []
|
||
entity_types = []
|
||
file_paths = set()
|
||
|
||
for entity_data in all_entity_data:
|
||
if entity_data.get("description"):
|
||
descriptions.append(entity_data["description"])
|
||
if entity_data.get("entity_type"):
|
||
entity_types.append(entity_data["entity_type"])
|
||
if entity_data.get("file_path"):
|
||
file_paths.add(entity_data["file_path"])
|
||
|
||
# Combine all descriptions
|
||
combined_description = (
|
||
GRAPH_FIELD_SEP.join(descriptions)
|
||
if descriptions
|
||
else current_entity.get("description", "")
|
||
)
|
||
|
||
# Get most common entity type
|
||
entity_type = (
|
||
max(set(entity_types), key=entity_types.count)
|
||
if entity_types
|
||
else current_entity.get("entity_type", "UNKNOWN")
|
||
)
|
||
|
||
# Generate final description and update storage
|
||
final_description = await _generate_final_description(combined_description)
|
||
await _update_entity_storage(final_description, entity_type, file_paths)
|
||
|
||
|
||
async def _rebuild_single_relationship(
|
||
knowledge_graph_inst: BaseGraphStorage,
|
||
relationships_vdb: BaseVectorStorage,
|
||
src: str,
|
||
tgt: str,
|
||
chunk_ids: set[str],
|
||
chunk_relationships: dict,
|
||
llm_response_cache: BaseKVStorage,
|
||
global_config: dict[str, str],
|
||
) -> None:
|
||
"""Rebuild a single relationship from cached extraction results
|
||
|
||
Note: This function assumes the caller has already acquired the appropriate
|
||
keyed lock for the relationship pair to ensure thread safety.
|
||
"""
|
||
|
||
# Get current relationship data
|
||
current_relationship = await knowledge_graph_inst.get_edge(src, tgt)
|
||
if not current_relationship:
|
||
return
|
||
|
||
# Collect all relationship data from relevant chunks
|
||
all_relationship_data = []
|
||
for chunk_id in chunk_ids:
|
||
if chunk_id in chunk_relationships:
|
||
# Check both (src, tgt) and (tgt, src) since relationships can be bidirectional
|
||
for edge_key in [(src, tgt), (tgt, src)]:
|
||
if edge_key in chunk_relationships[chunk_id]:
|
||
all_relationship_data.extend(
|
||
chunk_relationships[chunk_id][edge_key]
|
||
)
|
||
|
||
if not all_relationship_data:
|
||
logger.warning(f"No cached relationship data found for {src}-{tgt}")
|
||
return
|
||
|
||
# Merge descriptions and keywords
|
||
descriptions = []
|
||
keywords = []
|
||
weights = []
|
||
file_paths = set()
|
||
|
||
for rel_data in all_relationship_data:
|
||
if rel_data.get("description"):
|
||
descriptions.append(rel_data["description"])
|
||
if rel_data.get("keywords"):
|
||
keywords.append(rel_data["keywords"])
|
||
if rel_data.get("weight"):
|
||
weights.append(rel_data["weight"])
|
||
if rel_data.get("file_path"):
|
||
file_paths.add(rel_data["file_path"])
|
||
|
||
# Combine descriptions and keywords
|
||
combined_description = (
|
||
GRAPH_FIELD_SEP.join(descriptions)
|
||
if descriptions
|
||
else current_relationship.get("description", "")
|
||
)
|
||
combined_keywords = (
|
||
", ".join(set(keywords))
|
||
if keywords
|
||
else current_relationship.get("keywords", "")
|
||
)
|
||
# weight = (
|
||
# sum(weights) / len(weights)
|
||
# if weights
|
||
# else current_relationship.get("weight", 1.0)
|
||
# )
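    # Note: weights from all contributing chunks are summed (accumulated), not averaged.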
|
||
weight = sum(weights) if weights else current_relationship.get("weight", 1.0)
|
||
|
||
# Use summary if description has too many fragments
|
||
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
|
||
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
|
||
|
||
if num_fragment >= force_llm_summary_on_merge:
|
||
final_description = await _handle_entity_relation_summary(
|
||
f"{src}-{tgt}",
|
||
combined_description,
|
||
global_config,
|
||
llm_response_cache=llm_response_cache,
|
||
)
|
||
else:
|
||
final_description = combined_description
|
||
|
||
# Update relationship in graph storage
|
||
updated_relationship_data = {
|
||
**current_relationship,
|
||
"description": final_description,
|
||
"keywords": combined_keywords,
|
||
"weight": weight,
|
||
"source_id": GRAPH_FIELD_SEP.join(chunk_ids),
|
||
"file_path": GRAPH_FIELD_SEP.join([fp for fp in file_paths if fp])
|
||
if file_paths
|
||
else current_relationship.get("file_path", "unknown_source"),
|
||
}
|
||
await knowledge_graph_inst.upsert_edge(src, tgt, updated_relationship_data)
|
||
|
||
# Update relationship in vector database
|
||
rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-")
|
||
rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-")
|
||
|
||
# Delete old vector records first (both directions to be safe)
|
||
try:
|
||
await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse])
|
||
except Exception as e:
|
||
logger.debug(
|
||
f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}"
|
||
)
|
||
|
||
# Insert new vector record
|
||
rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}"
|
||
await relationships_vdb.upsert(
|
||
{
|
||
rel_vdb_id: {
|
||
"src_id": src,
|
||
"tgt_id": tgt,
|
||
"source_id": updated_relationship_data["source_id"],
|
||
"content": rel_content,
|
||
"keywords": combined_keywords,
|
||
"description": final_description,
|
||
"weight": weight,
|
||
"file_path": updated_relationship_data["file_path"],
|
||
}
|
||
}
|
||
)
|
||
|
||
|
||
async def _merge_nodes_then_upsert(
|
||
entity_name: str,
|
||
nodes_data: list[dict],
|
||
knowledge_graph_inst: BaseGraphStorage,
|
||
global_config: dict,
|
||
pipeline_status: dict = None,
|
||
pipeline_status_lock=None,
|
||
llm_response_cache: BaseKVStorage | None = None,
|
||
):
|
||
"""Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
|
||
already_entity_types = []
|
||
already_source_ids = []
|
||
already_description = []
|
||
already_file_paths = []
|
||
|
||
already_node = await knowledge_graph_inst.get_node(entity_name)
|
||
if already_node:
|
||
already_entity_types.append(already_node["entity_type"])
|
||
already_source_ids.extend(
|
||
split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
|
||
)
|
||
already_file_paths.extend(
|
||
split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP])
|
||
)
|
||
already_description.append(already_node["description"])
|
||
|
||
entity_type = sorted(
|
||
Counter(
|
||
[dp["entity_type"] for dp in nodes_data] + already_entity_types
|
||
).items(),
|
||
key=lambda x: x[1],
|
||
reverse=True,
|
||
)[0][0]
|
||
description = GRAPH_FIELD_SEP.join(
|
||
sorted(set([dp["description"] for dp in nodes_data] + already_description))
|
||
)
|
||
source_id = GRAPH_FIELD_SEP.join(
|
||
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
|
||
)
|
||
file_path = build_file_path(already_file_paths, nodes_data, entity_name)
|
||
|
||
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
|
||
|
||
num_fragment = description.count(GRAPH_FIELD_SEP) + 1
|
||
num_new_fragment = len(set([dp["description"] for dp in nodes_data]))
|
||
|
||
if num_fragment > 1:
|
||
if num_fragment >= force_llm_summary_on_merge:
|
||
status_message = f"LLM merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}"
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
description = await _handle_entity_relation_summary(
|
||
entity_name,
|
||
description,
|
||
global_config,
|
||
llm_response_cache,
|
||
)
|
||
else:
|
||
status_message = f"Merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}"
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
|
||
node_data = dict(
|
||
entity_id=entity_name,
|
||
entity_type=entity_type,
|
||
description=description,
|
||
source_id=source_id,
|
||
file_path=file_path,
|
||
created_at=int(time.time()),
|
||
)
|
||
await knowledge_graph_inst.upsert_node(
|
||
entity_name,
|
||
node_data=node_data,
|
||
)
|
||
node_data["entity_name"] = entity_name
|
||
return node_data
|
||
|
||
|
||
async def _merge_edges_then_upsert(
|
||
src_id: str,
|
||
tgt_id: str,
|
||
edges_data: list[dict],
|
||
knowledge_graph_inst: BaseGraphStorage,
|
||
global_config: dict,
|
||
pipeline_status: dict = None,
|
||
pipeline_status_lock=None,
|
||
llm_response_cache: BaseKVStorage | None = None,
|
||
added_entities: list = None, # New parameter to track entities added during edge processing
|
||
):
|
||
if src_id == tgt_id:
|
||
return None
|
||
|
||
already_weights = []
|
||
already_source_ids = []
|
||
already_description = []
|
||
already_keywords = []
|
||
already_file_paths = []
|
||
|
||
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
|
||
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
|
||
# Handle the case where get_edge returns None or missing fields
|
||
if already_edge:
|
||
# Get weight with default 1.0 if missing
|
||
already_weights.append(already_edge.get("weight", 1.0))
|
||
|
||
# Get source_id with empty string default if missing or None
|
||
if already_edge.get("source_id") is not None:
|
||
already_source_ids.extend(
|
||
split_string_by_multi_markers(
|
||
already_edge["source_id"], [GRAPH_FIELD_SEP]
|
||
)
|
||
)
|
||
|
||
# Get file_path with empty string default if missing or None
|
||
if already_edge.get("file_path") is not None:
|
||
already_file_paths.extend(
|
||
split_string_by_multi_markers(
|
||
already_edge["file_path"], [GRAPH_FIELD_SEP]
|
||
)
|
||
)
|
||
|
||
# Get description with empty string default if missing or None
|
||
if already_edge.get("description") is not None:
|
||
already_description.append(already_edge["description"])
|
||
|
||
# Get keywords with empty string default if missing or None
|
||
if already_edge.get("keywords") is not None:
|
||
already_keywords.extend(
|
||
split_string_by_multi_markers(
|
||
already_edge["keywords"], [GRAPH_FIELD_SEP]
|
||
)
|
||
)
|
||
|
||
# Process edges_data with None checks
|
||
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
||
description = GRAPH_FIELD_SEP.join(
|
||
sorted(
|
||
set(
|
||
[dp["description"] for dp in edges_data if dp.get("description")]
|
||
+ already_description
|
||
)
|
||
)
|
||
)
|
||
|
||
# Split all existing and new keywords into individual terms, then combine and deduplicate
|
||
all_keywords = set()
|
||
# Process already_keywords (which are comma-separated)
|
||
for keyword_str in already_keywords:
|
||
if keyword_str: # Skip empty strings
|
||
all_keywords.update(k.strip() for k in keyword_str.split(",") if k.strip())
|
||
# Process new keywords from edges_data
|
||
for edge in edges_data:
|
||
if edge.get("keywords"):
|
||
all_keywords.update(
|
||
k.strip() for k in edge["keywords"].split(",") if k.strip()
|
||
)
|
||
# Join all unique keywords with commas
|
||
keywords = ",".join(sorted(all_keywords))
|
||
|
||
source_id = GRAPH_FIELD_SEP.join(
|
||
set(
|
||
[dp["source_id"] for dp in edges_data if dp.get("source_id")]
|
||
+ already_source_ids
|
||
)
|
||
)
|
||
file_path = build_file_path(already_file_paths, edges_data, f"{src_id}-{tgt_id}")
|
||
|
||
for need_insert_id in [src_id, tgt_id]:
|
||
if not (await knowledge_graph_inst.has_node(need_insert_id)):
|
||
node_data = {
|
||
"entity_id": need_insert_id,
|
||
"source_id": source_id,
|
||
"description": description,
|
||
"entity_type": "UNKNOWN",
|
||
"file_path": file_path,
|
||
"created_at": int(time.time()),
|
||
}
|
||
await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
|
||
|
||
# Track entities added during edge processing
|
||
if added_entities is not None:
|
||
entity_data = {
|
||
"entity_name": need_insert_id,
|
||
"entity_type": "UNKNOWN",
|
||
"description": description,
|
||
"source_id": source_id,
|
||
"file_path": file_path,
|
||
"created_at": int(time.time()),
|
||
}
|
||
added_entities.append(entity_data)
|
||
|
||
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
|
||
|
||
num_fragment = description.count(GRAPH_FIELD_SEP) + 1
|
||
num_new_fragment = len(
|
||
set([dp["description"] for dp in edges_data if dp.get("description")])
|
||
)
|
||
|
||
if num_fragment > 1:
|
||
if num_fragment >= force_llm_summary_on_merge:
|
||
status_message = f"LLM merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}"
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
description = await _handle_entity_relation_summary(
|
||
f"({src_id}, {tgt_id})",
|
||
description,
|
||
global_config,
|
||
llm_response_cache,
|
||
)
|
||
else:
|
||
status_message = f"Merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}"
|
||
logger.info(status_message)
|
||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = status_message
|
||
pipeline_status["history_messages"].append(status_message)
|
||
|
||
await knowledge_graph_inst.upsert_edge(
|
||
src_id,
|
||
tgt_id,
|
||
edge_data=dict(
|
||
weight=weight,
|
||
description=description,
|
||
keywords=keywords,
|
||
source_id=source_id,
|
||
file_path=file_path,
|
||
created_at=int(time.time()),
|
||
),
|
||
)
|
||
|
||
edge_data = dict(
|
||
src_id=src_id,
|
||
tgt_id=tgt_id,
|
||
description=description,
|
||
keywords=keywords,
|
||
source_id=source_id,
|
||
file_path=file_path,
|
||
created_at=int(time.time()),
|
||
)
|
||
|
||
return edge_data
|
||
|
||
|
||
async def merge_nodes_and_edges(
|
||
chunk_results: list,
|
||
knowledge_graph_inst: BaseGraphStorage,
|
||
entity_vdb: BaseVectorStorage,
|
||
relationships_vdb: BaseVectorStorage,
|
||
global_config: dict[str, str],
|
||
full_entities_storage: BaseKVStorage = None,
|
||
full_relations_storage: BaseKVStorage = None,
|
||
doc_id: str = None,
|
||
pipeline_status: dict = None,
|
||
pipeline_status_lock=None,
|
||
llm_response_cache: BaseKVStorage | None = None,
|
||
current_file_number: int = 0,
|
||
total_files: int = 0,
|
||
file_path: str = "unknown_source",
|
||
) -> None:
|
||
"""Two-phase merge: process all entities first, then all relationships
|
||
|
||
This approach ensures data consistency by:
|
||
1. Phase 1: Process all entities concurrently
|
||
2. Phase 2: Process all relationships concurrently (may add missing entities)
|
||
3. Phase 3: Update full_entities and full_relations storage with final results
|
||
|
||
Args:
|
||
chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
|
||
knowledge_graph_inst: Knowledge graph storage
|
||
entity_vdb: Entity vector database
|
||
relationships_vdb: Relationship vector database
|
||
global_config: Global configuration
|
||
full_entities_storage: Storage for document entity lists
|
||
full_relations_storage: Storage for document relation lists
|
||
doc_id: Document ID for storage indexing
|
||
pipeline_status: Pipeline status dictionary
|
||
pipeline_status_lock: Lock for pipeline status
|
||
llm_response_cache: LLM response cache
|
||
current_file_number: Current file number for logging
|
||
total_files: Total files for logging
|
||
file_path: File path for logging
|
||
"""
|
||
|
||
# Collect all nodes and edges from all chunks
|
||
all_nodes = defaultdict(list)
|
||
all_edges = defaultdict(list)
|
||
|
||
for maybe_nodes, maybe_edges in chunk_results:
|
||
# Collect nodes
|
||
for entity_name, entities in maybe_nodes.items():
|
||
all_nodes[entity_name].extend(entities)
|
||
|
||
# Collect edges with sorted keys for undirected graph
|
||
for edge_key, edges in maybe_edges.items():
|
||
sorted_edge_key = tuple(sorted(edge_key))
|
||
all_edges[sorted_edge_key].extend(edges)
|
||
|
||
total_entities_count = len(all_nodes)
|
||
total_relations_count = len(all_edges)
|
||
|
||
log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
|
||
logger.info(log_message)
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = log_message
|
||
pipeline_status["history_messages"].append(log_message)
|
||
|
||
# Get max async tasks limit from global_config for semaphore control
|
||
graph_max_async = global_config.get("llm_model_max_async", 4) * 2
|
||
semaphore = asyncio.Semaphore(graph_max_async)
|
||
|
||
# ===== Phase 1: Process all entities concurrently =====
|
||
log_message = f"Phase 1: Processing {total_entities_count} entities (async: {graph_max_async})"
|
||
logger.info(log_message)
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = log_message
|
||
pipeline_status["history_messages"].append(log_message)
|
||
|
||
async def _locked_process_entity_name(entity_name, entities):
|
||
async with semaphore:
|
||
workspace = global_config.get("workspace", "")
|
||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||
async with get_storage_keyed_lock(
|
||
[entity_name], namespace=namespace, enable_logging=False
|
||
):
|
||
entity_data = await _merge_nodes_then_upsert(
|
||
entity_name,
|
||
entities,
|
||
knowledge_graph_inst,
|
||
global_config,
|
||
pipeline_status,
|
||
pipeline_status_lock,
|
||
llm_response_cache,
|
||
)
|
||
if entity_vdb is not None:
|
||
data_for_vdb = {
|
||
compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
|
||
"entity_name": entity_data["entity_name"],
|
||
"entity_type": entity_data["entity_type"],
|
||
"content": f"{entity_data['entity_name']}\n{entity_data['description']}",
|
||
"source_id": entity_data["source_id"],
|
||
"file_path": entity_data.get("file_path", "unknown_source"),
|
||
}
|
||
}
|
||
await entity_vdb.upsert(data_for_vdb)
|
||
return entity_data
|
||
|
||
# Create entity processing tasks
|
||
entity_tasks = []
|
||
for entity_name, entities in all_nodes.items():
|
||
task = asyncio.create_task(_locked_process_entity_name(entity_name, entities))
|
||
entity_tasks.append(task)
|
||
|
||
# Execute entity tasks with error handling
|
||
processed_entities = []
|
||
if entity_tasks:
|
||
done, pending = await asyncio.wait(
|
||
entity_tasks, return_when=asyncio.FIRST_EXCEPTION
|
||
)
|
||
|
||
# Check if any task raised an exception
|
||
for task in done:
|
||
if task.exception():
|
||
# If a task failed, cancel all pending tasks
|
||
for pending_task in pending:
|
||
pending_task.cancel()
|
||
# Wait for cancellation to complete
|
||
if pending:
|
||
await asyncio.wait(pending)
|
||
# Re-raise the exception to notify the caller
|
||
raise task.exception()
|
||
|
||
# If all tasks completed successfully, collect results
|
||
processed_entities = [task.result() for task in entity_tasks]
|
||
|
||
# ===== Phase 2: Process all relationships concurrently =====
|
||
log_message = f"Phase 2: Processing {total_relations_count} relations (async: {graph_max_async})"
|
||
logger.info(log_message)
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = log_message
|
||
pipeline_status["history_messages"].append(log_message)
|
||
|
||
async def _locked_process_edges(edge_key, edges):
|
||
async with semaphore:
|
||
workspace = global_config.get("workspace", "")
|
||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
||
|
||
async with get_storage_keyed_lock(
|
||
sorted_edge_key,
|
||
namespace=namespace,
|
||
enable_logging=False,
|
||
):
|
||
added_entities = [] # Track entities added during edge processing
|
||
edge_data = await _merge_edges_then_upsert(
|
||
edge_key[0],
|
||
edge_key[1],
|
||
edges,
|
||
knowledge_graph_inst,
|
||
global_config,
|
||
pipeline_status,
|
||
pipeline_status_lock,
|
||
llm_response_cache,
|
||
added_entities, # Pass list to collect added entities
|
||
)
|
||
|
||
if edge_data is None:
|
||
return None, []
|
||
|
||
if relationships_vdb is not None:
|
||
data_for_vdb = {
|
||
compute_mdhash_id(
|
||
edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"
|
||
): {
|
||
"src_id": edge_data["src_id"],
|
||
"tgt_id": edge_data["tgt_id"],
|
||
"keywords": edge_data["keywords"],
|
||
"content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
|
||
"source_id": edge_data["source_id"],
|
||
"file_path": edge_data.get("file_path", "unknown_source"),
|
||
"weight": edge_data.get("weight", 1.0),
|
||
}
|
||
}
|
||
await relationships_vdb.upsert(data_for_vdb)
|
||
return edge_data, added_entities
|
||
|
||
# Create relationship processing tasks
|
||
edge_tasks = []
|
||
for edge_key, edges in all_edges.items():
|
||
task = asyncio.create_task(_locked_process_edges(edge_key, edges))
|
||
edge_tasks.append(task)
|
||
|
||
# Execute relationship tasks with error handling
|
||
processed_edges = []
|
||
all_added_entities = []
|
||
|
||
if edge_tasks:
|
||
done, pending = await asyncio.wait(
|
||
edge_tasks, return_when=asyncio.FIRST_EXCEPTION
|
||
)
|
||
|
||
# Check if any task raised an exception
|
||
for task in done:
|
||
if task.exception():
|
||
# If a task failed, cancel all pending tasks
|
||
for pending_task in pending:
|
||
pending_task.cancel()
|
||
# Wait for cancellation to complete
|
||
if pending:
|
||
await asyncio.wait(pending)
|
||
# Re-raise the exception to notify the caller
|
||
raise task.exception()
|
||
|
||
# If all tasks completed successfully, collect results
|
||
for task in edge_tasks:
|
||
edge_data, added_entities = task.result()
|
||
if edge_data is not None:
|
||
processed_edges.append(edge_data)
|
||
all_added_entities.extend(added_entities)
|
||
|
||
# ===== Phase 3: Update full_entities and full_relations storage =====
|
||
if full_entities_storage and full_relations_storage and doc_id:
|
||
try:
|
||
# Merge all entities: original entities + entities added during edge processing
|
||
final_entity_names = set()
|
||
|
||
# Add original processed entities
|
||
for entity_data in processed_entities:
|
||
if entity_data and entity_data.get("entity_name"):
|
||
final_entity_names.add(entity_data["entity_name"])
|
||
|
||
# Add entities that were added during relationship processing
|
||
for added_entity in all_added_entities:
|
||
if added_entity and added_entity.get("entity_name"):
|
||
final_entity_names.add(added_entity["entity_name"])
|
||
|
||
# Collect all relation pairs
|
||
final_relation_pairs = set()
|
||
for edge_data in processed_edges:
|
||
if edge_data:
|
||
src_id = edge_data.get("src_id")
|
||
tgt_id = edge_data.get("tgt_id")
|
||
if src_id and tgt_id:
|
||
relation_pair = tuple(sorted([src_id, tgt_id]))
|
||
final_relation_pairs.add(relation_pair)
|
||
|
||
# Update storage
|
||
if final_entity_names:
|
||
await full_entities_storage.upsert(
|
||
{
|
||
doc_id: {
|
||
"entity_names": list(final_entity_names),
|
||
"count": len(final_entity_names),
|
||
}
|
||
}
|
||
)
|
||
|
||
if final_relation_pairs:
|
||
await full_relations_storage.upsert(
|
||
{
|
||
doc_id: {
|
||
"relation_pairs": [
|
||
list(pair) for pair in final_relation_pairs
|
||
],
|
||
"count": len(final_relation_pairs),
|
||
}
|
||
}
|
||
)
|
||
|
||
logger.debug(
|
||
f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations"
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(
|
||
f"Failed to update entity-relation index for document {doc_id}: {e}"
|
||
)
|
||
# Don't raise exception to avoid affecting main flow
|
||
|
||
log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} added entities, {len(processed_edges)} relations"
|
||
logger.info(log_message)
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = log_message
|
||
pipeline_status["history_messages"].append(log_message)
|
||
|
||
|
||
async def extract_entities(
|
||
chunks: dict[str, TextChunkSchema],
|
||
global_config: dict[str, str],
|
||
pipeline_status: dict = None,
|
||
pipeline_status_lock=None,
|
||
llm_response_cache: BaseKVStorage | None = None,
|
||
text_chunks_storage: BaseKVStorage | None = None,
|
||
) -> list:
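    """Extract entities and relationships from text chunks using the LLM.

    Each chunk is processed with the entity-extraction prompt (plus optional
    gleaning passes), and the per-chunk results are returned as a list of
    (maybe_nodes, maybe_edges) tuples for later merging.
    """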
|
||
use_llm_func: callable = global_config["llm_model_func"]
|
||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||
|
||
ordered_chunks = list(chunks.items())
|
||
# add language and example number params to prompt
|
||
language = global_config["addon_params"].get(
|
||
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
||
)
|
||
entity_types = global_config["addon_params"].get(
|
||
"entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
|
||
)
|
||
example_number = global_config["addon_params"].get("example_number", None)
|
||
if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
|
||
examples = "\n".join(
|
||
PROMPTS["entity_extraction_examples"][: int(example_number)]
|
||
)
|
||
else:
|
||
examples = "\n".join(PROMPTS["entity_extraction_examples"])
|
||
|
||
example_context_base = dict(
|
||
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
||
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
|
||
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
||
entity_types=", ".join(entity_types),
|
||
language=language,
|
||
)
|
||
# add example's format
|
||
examples = examples.format(**example_context_base)
|
||
|
||
entity_extract_prompt = PROMPTS["entity_extraction"]
|
||
context_base = dict(
|
||
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
||
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
|
||
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
||
entity_types=",".join(entity_types),
|
||
examples=examples,
|
||
language=language,
|
||
)
|
||
|
||
continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
|
||
if_loop_prompt = PROMPTS["entity_if_loop_extraction"]
|
||
|
||
processed_chunks = 0
|
||
total_chunks = len(ordered_chunks)
|
||
|
||
async def _process_extraction_result(
|
||
result: str, chunk_key: str, file_path: str = "unknown_source"
|
||
):
|
||
"""Process a single extraction result (either initial or gleaning)
|
||
Args:
|
||
result (str): The extraction result to process
|
||
chunk_key (str): The chunk key for source tracking
|
||
file_path (str): The file path for citation
|
||
Returns:
|
||
tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
|
||
"""
|
||
maybe_nodes = defaultdict(list)
|
||
maybe_edges = defaultdict(list)
|
||
|
||
records = split_string_by_multi_markers(
|
||
result,
|
||
[context_base["record_delimiter"], context_base["completion_delimiter"]],
|
||
)
|
||
|
||
for record in records:
|
||
record = re.search(r"\((.*)\)", record)
|
||
if record is None:
|
||
continue
|
||
record = record.group(1)
|
||
record_attributes = split_string_by_multi_markers(
|
||
record, [context_base["tuple_delimiter"]]
|
||
)
|
||
|
||
if_entities = await _handle_single_entity_extraction(
|
||
record_attributes, chunk_key, file_path
|
||
)
|
||
if if_entities is not None:
|
||
maybe_nodes[if_entities["entity_name"]].append(if_entities)
|
||
continue
|
||
|
||
if_relation = await _handle_single_relationship_extraction(
|
||
record_attributes, chunk_key, file_path
|
||
)
|
||
if if_relation is not None:
|
||
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
|
||
if_relation
|
||
)
|
||
|
||
return maybe_nodes, maybe_edges
|
||
|
||
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
||
"""Process a single chunk
|
||
Args:
|
||
chunk_key_dp (tuple[str, TextChunkSchema]):
|
||
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
|
||
Returns:
|
||
tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships
|
||
"""
|
||
nonlocal processed_chunks
|
||
chunk_key = chunk_key_dp[0]
|
||
chunk_dp = chunk_key_dp[1]
|
||
content = chunk_dp["content"]
|
||
# Get file path from chunk data or use default
|
||
file_path = chunk_dp.get("file_path", "unknown_source")
|
||
|
||
# Create cache keys collector for batch processing
|
||
cache_keys_collector = []
|
||
|
||
# Get initial extraction
|
||
hint_prompt = entity_extract_prompt.format(
|
||
**{**context_base, "input_text": content}
|
||
)
|
||
|
||
final_result = await use_llm_func_with_cache(
|
||
hint_prompt,
|
||
use_llm_func,
|
||
llm_response_cache=llm_response_cache,
|
||
cache_type="extract",
|
||
chunk_id=chunk_key,
|
||
cache_keys_collector=cache_keys_collector,
|
||
)
|
||
|
||
# Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
|
||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
||
|
||
# Process initial extraction with file path
|
||
maybe_nodes, maybe_edges = await _process_extraction_result(
|
||
final_result, chunk_key, file_path
|
||
)
|
||
|
||
# Process additional gleaning results
|
||
for now_glean_index in range(entity_extract_max_gleaning):
|
||
glean_result = await use_llm_func_with_cache(
|
||
continue_prompt,
|
||
use_llm_func,
|
||
llm_response_cache=llm_response_cache,
|
||
history_messages=history,
|
||
cache_type="extract",
|
||
chunk_id=chunk_key,
|
||
cache_keys_collector=cache_keys_collector,
|
||
)
|
||
|
||
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
|
||
|
||
# Process gleaning result separately with file path
|
||
glean_nodes, glean_edges = await _process_extraction_result(
|
||
glean_result, chunk_key, file_path
|
||
)
|
||
|
||
# Merge results - only add entities and edges with new names
|
||
for entity_name, entities in glean_nodes.items():
|
||
if (
|
||
entity_name not in maybe_nodes
|
||
): # Only accetp entities with new name in gleaning stage
|
||
maybe_nodes[entity_name].extend(entities)
|
||
for edge_key, edges in glean_edges.items():
|
||
if (
|
||
edge_key not in maybe_edges
|
||
): # Only accetp edges with new name in gleaning stage
|
||
maybe_edges[edge_key].extend(edges)
|
||
|
||
if now_glean_index == entity_extract_max_gleaning - 1:
|
||
break
|
||
|
||
if_loop_result: str = await use_llm_func_with_cache(
|
||
if_loop_prompt,
|
||
use_llm_func,
|
||
llm_response_cache=llm_response_cache,
|
||
history_messages=history,
|
||
cache_type="extract",
|
||
cache_keys_collector=cache_keys_collector,
|
||
)
|
||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||
if if_loop_result != "yes":
|
||
break
|
||
|
||
# Batch update chunk's llm_cache_list with all collected cache keys
|
||
if cache_keys_collector and text_chunks_storage:
|
||
await update_chunk_cache_list(
|
||
chunk_key,
|
||
text_chunks_storage,
|
||
cache_keys_collector,
|
||
"entity_extraction",
|
||
)
|
||
|
||
processed_chunks += 1
|
||
entities_count = len(maybe_nodes)
|
||
relations_count = len(maybe_edges)
|
||
log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel"
|
||
logger.info(log_message)
|
||
if pipeline_status is not None:
|
||
async with pipeline_status_lock:
|
||
pipeline_status["latest_message"] = log_message
|
||
pipeline_status["history_messages"].append(log_message)
|
||
|
||
# Return the extracted nodes and edges for centralized processing
|
||
return maybe_nodes, maybe_edges
|
||
|
||
# Get max async tasks limit from global_config
|
||
chunk_max_async = global_config.get("llm_model_max_async", 4)
|
||
semaphore = asyncio.Semaphore(chunk_max_async)
|
||
|
||
async def _process_with_semaphore(chunk):
|
||
async with semaphore:
|
||
return await _process_single_content(chunk)
|
||
|
||
tasks = []
|
||
for c in ordered_chunks:
|
||
task = asyncio.create_task(_process_with_semaphore(c))
|
||
tasks.append(task)
|
||
|
||
# Wait for tasks to complete or for the first exception to occur
|
||
# This allows us to cancel remaining tasks if any task fails
|
||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
|
||
|
||
# Check if any task raised an exception
|
||
for task in done:
|
||
if task.exception():
|
||
# If a task failed, cancel all pending tasks
|
||
# This prevents unnecessary processing since the parent function will abort anyway
|
||
for pending_task in pending:
|
||
pending_task.cancel()
|
||
|
||
# Wait for cancellation to complete
|
||
if pending:
|
||
await asyncio.wait(pending)
|
||
|
||
# Re-raise the exception to notify the caller
|
||
raise task.exception()
|
||
|
||
# If all tasks completed successfully, collect results
|
||
chunk_results = [task.result() for task in tasks]
|
||
|
||
# Return the chunk_results for later processing in merge_nodes_and_edges
|
||
return chunk_results
|
||
|
||
|
||
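

# Illustrative sketch (not part of LightRAG): the fan-out above relies on
# asyncio.wait(..., return_when=FIRST_EXCEPTION) to fail fast, cancelling the
# still-pending chunk tasks as soon as one task raises. The hypothetical helper
# below shows the same pattern in isolation; it is never called by this module.
async def _example_fail_fast_gather(coros: list) -> list:
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
    for task in done:
        if task.exception():
            # Cancel whatever is still running, wait for the cancellations,
            # then surface the original error to the caller.
            for p in pending:
                p.cancel()
            if pending:
                await asyncio.wait(pending)
            raise task.exception()
    return [t.result() for t in tasks]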


async def kg_query(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
    chunks_vdb: BaseVectorStorage = None,
) -> str | AsyncIterator[str]:
    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    # Handle cache
    args_hash = compute_args_hash(
        query_param.mode,
        query,
        query_param.response_type,
        query_param.top_k,
        query_param.chunk_top_k,
        query_param.max_entity_tokens,
        query_param.max_relation_tokens,
        query_param.max_total_tokens,
        query_param.hl_keywords or [],
        query_param.ll_keywords or [],
        query_param.user_prompt or "",
        query_param.enable_rerank,
    )
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    hl_keywords, ll_keywords = await get_keywords_from_query(
        query, query_param, global_config, hashing_kv
    )

    logger.debug(f"High-level keywords: {hl_keywords}")
    logger.debug(f"Low-level keywords: {ll_keywords}")

    # Handle empty keywords
    if hl_keywords == [] and ll_keywords == []:
        logger.warning("Both low_level_keywords and high_level_keywords are empty")
        return PROMPTS["fail_response"]
    if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
        logger.warning(
            "low_level_keywords is empty, switching from %s mode to global mode",
            query_param.mode,
        )
        query_param.mode = "global"
    if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
        logger.warning(
            "high_level_keywords is empty, switching from %s mode to local mode",
            query_param.mode,
        )
        query_param.mode = "local"

    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

    # Build context
    context = await _build_query_context(
        query,
        ll_keywords_str,
        hl_keywords_str,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
        chunks_vdb,
    )

    if query_param.only_need_context:
        return context if context is not None else PROMPTS["fail_response"]
    if context is None:
        return PROMPTS["fail_response"]

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    # Build system prompt
    user_prompt = (
        query_param.user_prompt
        if query_param.user_prompt
        else PROMPTS["DEFAULT_USER_PROMPT"]
    )
    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context,
        response_type=query_param.response_type,
        history=history_context,
        user_prompt=user_prompt,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(
        f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
    )

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    if hashing_kv.global_config.get("enable_llm_cache"):
        # Save to cache with query parameters
        queryparam_dict = {
            "mode": query_param.mode,
            "response_type": query_param.response_type,
            "top_k": query_param.top_k,
            "chunk_top_k": query_param.chunk_top_k,
            "max_entity_tokens": query_param.max_entity_tokens,
            "max_relation_tokens": query_param.max_relation_tokens,
            "max_total_tokens": query_param.max_total_tokens,
            "hl_keywords": query_param.hl_keywords or [],
            "ll_keywords": query_param.ll_keywords or [],
            "user_prompt": query_param.user_prompt or "",
            "enable_rerank": query_param.enable_rerank,
        }
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                mode=query_param.mode,
                cache_type="query",
                queryparam=queryparam_dict,
            ),
        )

    return response
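

# Illustrative sketch (not part of LightRAG): the cache key above now folds the
# retrieval parameters (top_k, chunk_top_k, token limits, keywords, user_prompt,
# enable_rerank) into the hash, so queries that differ only in one of those
# parameters no longer share a cached answer. The hypothetical helper below
# demonstrates the idea with made-up values; it is never called by this module.
def _example_query_cache_keys_differ() -> bool:
    common = ("hybrid", "What is LightRAG?", "Multiple Paragraphs")
    key_a = compute_args_hash(*common, 60, 20, 6000, 8000, 30000, [], [], "", True)
    key_b = compute_args_hash(*common, 10, 20, 6000, 8000, 30000, [], [], "", True)
    return key_a != key_b  # expected True: different top_k, different cache entry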


async def get_keywords_from_query(
    query: str,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Retrieves high-level and low-level keywords for RAG operations.

    This function checks if keywords are already provided in query parameters,
    and if not, extracts them from the query text using LLM.

    Args:
        query: The user's query text
        query_param: Query parameters that may contain pre-defined keywords
        global_config: Global configuration dictionary
        hashing_kv: Optional key-value storage for caching results

    Returns:
        A tuple containing (high_level_keywords, low_level_keywords)
    """
    # Check if pre-defined keywords are already provided
    if query_param.hl_keywords or query_param.ll_keywords:
        return query_param.hl_keywords, query_param.ll_keywords

    # Extract keywords using extract_keywords_only function which already supports conversation history
    hl_keywords, ll_keywords = await extract_keywords_only(
        query, query_param, global_config, hashing_kv
    )
    return hl_keywords, ll_keywords


async def extract_keywords_only(
    text: str,
    param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Extract high-level and low-level keywords from the given 'text' using the LLM.
    This method does NOT build the final RAG context or provide a final answer.
    It ONLY extracts keywords (hl_keywords, ll_keywords).
    """

    # 1. Handle cache if needed - add cache type for keywords
    args_hash = compute_args_hash(
        param.mode,
        text,
        param.response_type,
        param.top_k,
        param.chunk_top_k,
        param.max_entity_tokens,
        param.max_relation_tokens,
        param.max_total_tokens,
        param.hl_keywords or [],
        param.ll_keywords or [],
        param.user_prompt or "",
        param.enable_rerank,
    )
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, text, param.mode, cache_type="keywords"
    )
    if cached_response is not None:
        try:
            keywords_data = json_repair.loads(cached_response)
            return keywords_data.get("high_level_keywords", []), keywords_data.get(
                "low_level_keywords", []
            )
        except (json.JSONDecodeError, KeyError):
            logger.warning(
                "Invalid cache format for keywords, proceeding with extraction"
            )

    # 2. Build the examples
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    # 3. Process conversation history
    history_context = ""
    if param.conversation_history:
        history_context = get_conversation_turns(
            param.conversation_history, param.history_turns
        )

    # 4. Build the keyword-extraction prompt
    kw_prompt = PROMPTS["keywords_extraction"].format(
        query=text, examples=examples, language=language, history=history_context
    )

    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(kw_prompt))
    logger.debug(
        f"[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})"
    )

    # 5. Call the LLM for keyword extraction
    if param.model_func:
        use_model_func = param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    result = await use_model_func(kw_prompt, keyword_extraction=True)

    # 6. Parse out JSON from the LLM response
    result = remove_think_tags(result)
    try:
        keywords_data = json_repair.loads(result)
        if not keywords_data:
            logger.error("No JSON-like structure found in the LLM response.")
            return [], []
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        logger.error(f"LLM response: {result}")
        return [], []

    hl_keywords = keywords_data.get("high_level_keywords", [])
    ll_keywords = keywords_data.get("low_level_keywords", [])

    # 7. Cache only the processed keywords with cache type
    if hl_keywords or ll_keywords:
        cache_data = {
            "high_level_keywords": hl_keywords,
            "low_level_keywords": ll_keywords,
        }
        if hashing_kv.global_config.get("enable_llm_cache"):
            # Save to cache with query parameters
            queryparam_dict = {
                "mode": param.mode,
                "response_type": param.response_type,
                "top_k": param.top_k,
                "chunk_top_k": param.chunk_top_k,
                "max_entity_tokens": param.max_entity_tokens,
                "max_relation_tokens": param.max_relation_tokens,
                "max_total_tokens": param.max_total_tokens,
                "hl_keywords": param.hl_keywords or [],
                "ll_keywords": param.ll_keywords or [],
                "user_prompt": param.user_prompt or "",
                "enable_rerank": param.enable_rerank,
            }
            await save_to_cache(
                hashing_kv,
                CacheData(
                    args_hash=args_hash,
                    content=json.dumps(cache_data),
                    prompt=text,
                    mode=param.mode,
                    cache_type="keywords",
                    queryparam=queryparam_dict,
                ),
            )

    return hl_keywords, ll_keywords
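

# Note (added for clarity): the keyword-extraction prompt is expected to return
# JSON shaped like
#     {"high_level_keywords": ["..."], "low_level_keywords": ["..."]}
# which is what json_repair.loads() above recovers; any other shape falls back
# to empty keyword lists. (Shape inferred from the .get() calls above.)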


async def _get_vector_context(
    query: str,
    chunks_vdb: BaseVectorStorage,
    query_param: QueryParam,
) -> list[dict]:
    """
    Retrieve text chunks from the vector database without reranking or truncation.

    This function performs vector search to find relevant text chunks for a query.
    Reranking and truncation will be handled later in the unified processing.

    Args:
        query: The query string to search for
        chunks_vdb: Vector database containing document chunks
        query_param: Query parameters including chunk_top_k and ids

    Returns:
        List of text chunks with metadata
    """
    try:
        # Use chunk_top_k if specified, otherwise fall back to top_k
        search_top_k = query_param.chunk_top_k or query_param.top_k

        results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
        if not results:
            return []

        valid_chunks = []
        for result in results:
            if "content" in result:
                chunk_with_metadata = {
                    "content": result["content"],
                    "created_at": result.get("created_at", None),
                    "file_path": result.get("file_path", "unknown_source"),
                    "source_type": "vector",  # Mark the source type
                    "chunk_id": result.get("id"),  # Add chunk_id for deduplication
                }
                valid_chunks.append(chunk_with_metadata)

        logger.info(
            f"Naive query: {len(valid_chunks)} chunks (chunk_top_k: {search_top_k})"
        )
        return valid_chunks

    except Exception as e:
        logger.error(f"Error in _get_vector_context: {e}")
        return []
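

# For reference, each item returned by _get_vector_context looks roughly like
# the following (field names taken from the construction above, values invented):
#     {
#         "content": "...chunk text...",
#         "created_at": 1718000000,         # or None
#         "file_path": "docs/example.txt",   # or "unknown_source"
#         "source_type": "vector",
#         "chunk_id": "chunk-xxxxxx",
#     }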


async def _build_query_context(
    query: str,
    ll_keywords: str,
    hl_keywords: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    chunks_vdb: BaseVectorStorage = None,
):
    logger.info(f"Process {os.getpid()} building query context...")

    # Collect chunks from different sources separately
    vector_chunks = []
    entity_chunks = []
    relation_chunks = []
    entities_context = []
    relations_context = []

    # Store original data for later text chunk retrieval
    local_entities = []
    local_relations = []
    global_entities = []
    global_relations = []

    # Handle local and global modes
    if query_param.mode == "local":
        local_entities, local_relations = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            query_param,
        )

    elif query_param.mode == "global":
        global_relations, global_entities = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            query_param,
        )

    else:  # hybrid or mix mode
        local_entities, local_relations = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            query_param,
        )
        global_relations, global_entities = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            query_param,
        )

        # Get vector chunks first if in mix mode
        if query_param.mode == "mix" and chunks_vdb:
            vector_chunks = await _get_vector_context(
                query,
                chunks_vdb,
                query_param,
            )

    # Use round-robin merge to combine local and global data fairly
    final_entities = []
    seen_entities = set()

    # Round-robin merge entities
    max_len = max(len(local_entities), len(global_entities))
    for i in range(max_len):
        # First from local
        if i < len(local_entities):
            entity = local_entities[i]
            entity_name = entity.get("entity_name")
            if entity_name and entity_name not in seen_entities:
                final_entities.append(entity)
                seen_entities.add(entity_name)

        # Then from global
        if i < len(global_entities):
            entity = global_entities[i]
            entity_name = entity.get("entity_name")
            if entity_name and entity_name not in seen_entities:
                final_entities.append(entity)
                seen_entities.add(entity_name)

    # Round-robin merge relations
    final_relations = []
    seen_relations = set()

    max_len = max(len(local_relations), len(global_relations))
    for i in range(max_len):
        # First from local
        if i < len(local_relations):
            relation = local_relations[i]
            # Build relation unique identifier
            if "src_tgt" in relation:
                rel_key = tuple(sorted(relation["src_tgt"]))
            else:
                rel_key = tuple(
                    sorted([relation.get("src_id"), relation.get("tgt_id")])
                )

            if rel_key not in seen_relations:
                final_relations.append(relation)
                seen_relations.add(rel_key)

        # Then from global
        if i < len(global_relations):
            relation = global_relations[i]
            # Build relation unique identifier
            if "src_tgt" in relation:
                rel_key = tuple(sorted(relation["src_tgt"]))
            else:
                rel_key = tuple(
                    sorted([relation.get("src_id"), relation.get("tgt_id")])
                )

            if rel_key not in seen_relations:
                final_relations.append(relation)
                seen_relations.add(rel_key)

    # Generate entities context
    entities_context = []
    for i, n in enumerate(final_entities):
        created_at = n.get("created_at", "UNKNOWN")
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

        # Get file path from node data
        file_path = n.get("file_path", "unknown_source")

        entities_context.append(
            {
                "id": i + 1,
                "entity": n["entity_name"],
                "type": n.get("entity_type", "UNKNOWN"),
                "description": n.get("description", "UNKNOWN"),
                "created_at": created_at,
                "file_path": file_path,
            }
        )

    # Generate relations context
    relations_context = []
    for i, e in enumerate(final_relations):
        created_at = e.get("created_at", "UNKNOWN")
        # Convert timestamp to readable format
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

        # Get file path from edge data
        file_path = e.get("file_path", "unknown_source")

        # Handle different relation data formats
        if "src_tgt" in e:
            entity1, entity2 = e["src_tgt"]
        else:
            entity1, entity2 = e.get("src_id"), e.get("tgt_id")

        relations_context.append(
            {
                "id": i + 1,
                "entity1": entity1,
                "entity2": entity2,
                "description": e.get("description", "UNKNOWN"),
                "created_at": created_at,
                "file_path": file_path,
            }
        )

    logger.debug(
        f"Initial KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
    )

    # Unified token control system - Apply precise token limits to entities and relations
    tokenizer = text_chunks_db.global_config.get("tokenizer")

    # Get new token limits from query_param (with fallback to global_config)
    max_entity_tokens = getattr(
        query_param,
        "max_entity_tokens",
        text_chunks_db.global_config.get(
            "max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS
        ),
    )
    max_relation_tokens = getattr(
        query_param,
        "max_relation_tokens",
        text_chunks_db.global_config.get(
            "max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS
        ),
    )
    max_total_tokens = getattr(
        query_param,
        "max_total_tokens",
        text_chunks_db.global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
    )

    # Truncate entities based on complete JSON serialization
    if entities_context:
        # Process entities context to replace GRAPH_FIELD_SEP with ";" in file_path fields
        for entity in entities_context:
            if "file_path" in entity and entity["file_path"]:
                entity["file_path"] = entity["file_path"].replace(GRAPH_FIELD_SEP, ";")

        entities_context = truncate_list_by_token_size(
            entities_context,
            key=lambda x: json.dumps(x, ensure_ascii=False),
            max_token_size=max_entity_tokens,
            tokenizer=tokenizer,
        )

    # Truncate relations based on complete JSON serialization
    if relations_context:
        # Process relations context to replace GRAPH_FIELD_SEP with ";" in file_path fields
        for relation in relations_context:
            if "file_path" in relation and relation["file_path"]:
                relation["file_path"] = relation["file_path"].replace(
                    GRAPH_FIELD_SEP, ";"
                )

        relations_context = truncate_list_by_token_size(
            relations_context,
            key=lambda x: json.dumps(x, ensure_ascii=False),
            max_token_size=max_relation_tokens,
            tokenizer=tokenizer,
        )

    # After truncation, get text chunks based on final entities and relations
    logger.info(
        f"Truncated KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
    )

    # Create filtered data based on truncated context
    final_node_datas = []
    if entities_context and final_entities:
        final_entity_names = {e["entity"] for e in entities_context}
        seen_nodes = set()
        for node in final_entities:
            name = node.get("entity_name")
            if name in final_entity_names and name not in seen_nodes:
                final_node_datas.append(node)
                seen_nodes.add(name)

    final_edge_datas = []
    if relations_context and final_relations:
        final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
        seen_edges = set()
        for edge in final_relations:
            src, tgt = edge.get("src_id"), edge.get("tgt_id")
            if src is None or tgt is None:
                src, tgt = edge.get("src_tgt", (None, None))

            pair = (src, tgt)
            if pair in final_relation_pairs and pair not in seen_edges:
                final_edge_datas.append(edge)
                seen_edges.add(pair)

    # Get text chunks based on final filtered data
    if final_node_datas:
        entity_chunks = await _find_most_related_text_unit_from_entities(
            final_node_datas,
            query_param,
            text_chunks_db,
            knowledge_graph_inst,
        )

    if final_edge_datas:
        relation_chunks = await _find_related_text_unit_from_relationships(
            final_edge_datas,
            query_param,
            text_chunks_db,
            entity_chunks,
        )

    # Round-robin merge chunks from different sources with deduplication by chunk_id
    merged_chunks = []
    seen_chunk_ids = set()
    max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks))
    origin_len = len(vector_chunks) + len(entity_chunks) + len(relation_chunks)

    for i in range(max_len):
        # Add from vector chunks first (Naive mode)
        if i < len(vector_chunks):
            chunk = vector_chunks[i]
            chunk_id = chunk.get("chunk_id") or chunk.get("id")
            if chunk_id and chunk_id not in seen_chunk_ids:
                seen_chunk_ids.add(chunk_id)
                merged_chunks.append(
                    {
                        "content": chunk["content"],
                        "file_path": chunk.get("file_path", "unknown_source"),
                    }
                )

        # Add from entity chunks (Local mode)
        if i < len(entity_chunks):
            chunk = entity_chunks[i]
            chunk_id = chunk.get("chunk_id") or chunk.get("id")
            if chunk_id and chunk_id not in seen_chunk_ids:
                seen_chunk_ids.add(chunk_id)
                merged_chunks.append(
                    {
                        "content": chunk["content"],
                        "file_path": chunk.get("file_path", "unknown_source"),
                    }
                )

        # Add from relation chunks (Global mode)
        if i < len(relation_chunks):
            chunk = relation_chunks[i]
            chunk_id = chunk.get("chunk_id") or chunk.get("id")
            if chunk_id and chunk_id not in seen_chunk_ids:
                seen_chunk_ids.add(chunk_id)
                merged_chunks.append(
                    {
                        "content": chunk["content"],
                        "file_path": chunk.get("file_path", "unknown_source"),
                    }
                )

    logger.debug(
        f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
    )

    # Apply token processing to merged chunks
    text_units_context = []
    if merged_chunks:
        # Calculate dynamic token limit for text chunks
        entities_str = json.dumps(entities_context, ensure_ascii=False)
        relations_str = json.dumps(relations_context, ensure_ascii=False)

        # Calculate base context tokens (entities + relations + template)
        kg_context_template = """-----Entities(KG)-----

```json
{entities_str}
```

-----Relationships(KG)-----

```json
{relations_str}
```

-----Document Chunks(DC)-----

```json
[]
```

"""
        kg_context = kg_context_template.format(
            entities_str=entities_str, relations_str=relations_str
        )
        kg_context_tokens = len(tokenizer.encode(kg_context))

        # Calculate actual system prompt overhead dynamically
        # 1. Calculate conversation history tokens
        history_context = ""
        if query_param.conversation_history:
            history_context = get_conversation_turns(
                query_param.conversation_history, query_param.history_turns
            )
        history_tokens = (
            len(tokenizer.encode(history_context)) if history_context else 0
        )

        # 2. Calculate system prompt template tokens (excluding context_data)
        user_prompt = query_param.user_prompt if query_param.user_prompt else ""
        response_type = (
            query_param.response_type
            if query_param.response_type
            else "Multiple Paragraphs"
        )

        # Get the system prompt template from PROMPTS
        sys_prompt_template = text_chunks_db.global_config.get(
            "system_prompt_template", PROMPTS["rag_response"]
        )

        # Create a sample system prompt with placeholders filled (excluding context_data)
        sample_sys_prompt = sys_prompt_template.format(
            history=history_context,
            context_data="",  # Empty for overhead calculation
            response_type=response_type,
            user_prompt=user_prompt,
        )
        sys_prompt_template_tokens = len(tokenizer.encode(sample_sys_prompt))

        # Total system prompt overhead = template + query tokens
        query_tokens = len(tokenizer.encode(query))
        sys_prompt_overhead = sys_prompt_template_tokens + query_tokens

        buffer_tokens = 100  # Safety buffer

        # Calculate available tokens for text chunks
        used_tokens = kg_context_tokens + sys_prompt_overhead + buffer_tokens
        available_chunk_tokens = max_total_tokens - used_tokens

        logger.debug(
            f"Token allocation - Total: {max_total_tokens}, History: {history_tokens}, SysPrompt: {sys_prompt_overhead}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
        )

        # Apply token truncation to chunks using the dynamic limit
        truncated_chunks = await process_chunks_unified(
            query=query,
            unique_chunks=merged_chunks,
            query_param=query_param,
            global_config=text_chunks_db.global_config,
            source_type=query_param.mode,
            chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
        )

        # Rebuild text_units_context with truncated chunks
        for i, chunk in enumerate(truncated_chunks):
            text_units_context.append(
                {
                    "id": i + 1,
                    "content": chunk["content"],
                    "file_path": chunk.get("file_path", "unknown_source"),
                }
            )

        logger.debug(
            f"Final chunk processing: {len(merged_chunks)} -> {len(text_units_context)} (chunk available tokens: {available_chunk_tokens})"
        )

    logger.info(
        f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
    )

    # not necessary to use LLM to generate a response
    if not entities_context and not relations_context:
        return None

    entities_str = json.dumps(entities_context, ensure_ascii=False)
    relations_str = json.dumps(relations_context, ensure_ascii=False)
    text_units_str = json.dumps(text_units_context, ensure_ascii=False)

    result = f"""-----Entities(KG)-----

```json
{entities_str}
```

-----Relationships(KG)-----

```json
{relations_str}
```

-----Document Chunks(DC)-----

```json
{text_units_str}
```

"""
    return result
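

# Worked example of the chunk token budget computed above (illustrative numbers):
#     max_total_tokens       = 30000
#     kg_context_tokens      =  9500   (entities + relations + template)
#     sys_prompt_overhead    =  1200   (filled prompt template + query tokens)
#     buffer_tokens          =   100
#     available_chunk_tokens = 30000 - (9500 + 1200 + 100) = 19200
# process_chunks_unified() then truncates (and optionally reranks) the merged
# chunks so they fit within those 19200 tokens.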


async def _get_node_data(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    query_param: QueryParam,
):
    # get similar entities
    logger.info(
        f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
    )

    results = await entities_vdb.query(
        query, top_k=query_param.top_k, ids=query_param.ids
    )

    if not len(results):
        return [], []

    # Extract all entity IDs from the results list
    node_ids = [r["entity_name"] for r in results]

    # Call the batch node retrieval and degree functions concurrently.
    nodes_dict, degrees_dict = await asyncio.gather(
        knowledge_graph_inst.get_nodes_batch(node_ids),
        knowledge_graph_inst.node_degrees_batch(node_ids),
    )

    # Now, if you need the node data and degree in order:
    node_datas = [nodes_dict.get(nid) for nid in node_ids]
    node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]

    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")

    node_datas = [
        {
            **n,
            "entity_name": k["entity_name"],
            "rank": d,
            "created_at": k.get("created_at"),
        }
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]

    use_relations = await _find_most_related_edges_from_entities(
        node_datas,
        query_param,
        knowledge_graph_inst,
    )

    logger.info(
        f"Local query: {len(node_datas)} entities, {len(use_relations)} relations"
    )

    # Entities are sorted by cosine similarity
    # Relations are sorted by rank + weight
    return node_datas, use_relations


async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
):
    """
    Find text chunks related to entities using the linear gradient weighted polling algorithm.

    This function implements the optimized text chunk selection strategy:
    1. Sort text chunks for each entity by occurrence count in other entities
    2. Use linear gradient weighted polling to select chunks fairly
    """
    logger.debug(f"Searching text chunks for {len(node_datas)} entities")

    if not node_datas:
        return []

    # Step 1: Collect all text chunks for each entity
    entities_with_chunks = []
    for entity in node_datas:
        if entity.get("source_id"):
            chunks = split_string_by_multi_markers(
                entity["source_id"], [GRAPH_FIELD_SEP]
            )
            if chunks:
                entities_with_chunks.append(
                    {
                        "entity_name": entity["entity_name"],
                        "chunks": chunks,
                        "entity_data": entity,
                    }
                )

    if not entities_with_chunks:
        logger.warning("No entities with text chunks found")
        return []

    # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
    chunk_occurrence_count = {}
    for entity_info in entities_with_chunks:
        deduplicated_chunks = []
        for chunk_id in entity_info["chunks"]:
            chunk_occurrence_count[chunk_id] = (
                chunk_occurrence_count.get(chunk_id, 0) + 1
            )

            # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
            if chunk_occurrence_count[chunk_id] == 1:
                deduplicated_chunks.append(chunk_id)
            # count > 1 means this chunk appeared in an earlier entity, so skip it

        # Update entity's chunks to deduplicated chunks
        entity_info["chunks"] = deduplicated_chunks

    # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
    for entity_info in entities_with_chunks:
        sorted_chunks = sorted(
            entity_info["chunks"],
            key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
            reverse=True,
        )
        entity_info["sorted_chunks"] = sorted_chunks

    # Step 4: Apply linear gradient weighted polling algorithm
    max_related_chunks = text_chunks_db.global_config.get(
        "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
    )

    selected_chunk_ids = linear_gradient_weighted_polling(
        entities_with_chunks, max_related_chunks, min_related_chunks=1
    )

    logger.debug(
        f"Found {len(selected_chunk_ids)} entity-related chunks using linear gradient weighted polling"
    )

    if not selected_chunk_ids:
        return []

    # Step 5: Batch retrieve chunk data
    unique_chunk_ids = list(
        dict.fromkeys(selected_chunk_ids)
    )  # Remove duplicates while preserving order
    chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)

    # Step 6: Build result chunks with valid data
    result_chunks = []
    for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
        if chunk_data is not None and "content" in chunk_data:
            chunk_data_copy = chunk_data.copy()
            chunk_data_copy["source_type"] = "entity"
            chunk_data_copy["chunk_id"] = chunk_id  # Add chunk_id for deduplication
            result_chunks.append(chunk_data_copy)

    return result_chunks
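

# Tiny worked example of the selection strategy above (hypothetical data):
#     entity A -> chunks [c1, c2], entity B -> chunks [c2, c3]
# Step 2 keeps c2 only under A (its first occurrence), leaving B with [c3];
# Step 3 orders each entity's remaining chunks by how often they were seen
# across all entities; Step 4 passes the per-entity lists to
# linear_gradient_weighted_polling(), which is expected to draw more chunks from
# earlier-ranked entities while guaranteeing at least min_related_chunks=1 per
# entity (the exact behaviour lives in utils.linear_gradient_weighted_polling).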


async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
    node_names = [dp["entity_name"] for dp in node_datas]
    batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)

    all_edges = []
    seen = set()

    for node_name in node_names:
        this_edges = batch_edges_dict.get(node_name, [])
        for e in this_edges:
            sorted_edge = tuple(sorted(e))
            if sorted_edge not in seen:
                seen.add(sorted_edge)
                all_edges.append(sorted_edge)

    # Prepare edge pairs in two forms:
    # For the batch edge properties function, use dicts.
    edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
    # For edge degrees, use tuples.
    edge_pairs_tuples = list(all_edges)  # all_edges is already a list of tuples

    # Call the batched functions concurrently.
    edge_data_dict, edge_degrees_dict = await asyncio.gather(
        knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
        knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
    )

    # Reconstruct edge_datas list in the same order as the deduplicated results.
    all_edges_data = []
    for pair in all_edges:
        edge_props = edge_data_dict.get(pair)
        if edge_props is not None:
            if "weight" not in edge_props:
                logger.warning(
                    f"Edge {pair} missing 'weight' attribute, using default value 1.0"
                )
                edge_props["weight"] = 1.0

            combined = {
                "src_tgt": pair,
                "rank": edge_degrees_dict.get(pair, 0),
                **edge_props,
            }
            all_edges_data.append(combined)

    all_edges_data = sorted(
        all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )

    return all_edges_data


async def _get_edge_data(
    keywords,
    knowledge_graph_inst: BaseGraphStorage,
    relationships_vdb: BaseVectorStorage,
    query_param: QueryParam,
):
    logger.info(
        f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
    )

    results = await relationships_vdb.query(
        keywords, top_k=query_param.top_k, ids=query_param.ids
    )

    if not len(results):
        return [], []

    # Prepare edge pairs in two forms:
    # For the batch edge properties function, use dicts.
    edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
    edge_data_dict = await knowledge_graph_inst.get_edges_batch(edge_pairs_dicts)

    # Reconstruct edge_datas list in the same order as results.
    edge_datas = []
    for k in results:
        pair = (k["src_id"], k["tgt_id"])
        edge_props = edge_data_dict.get(pair)
        if edge_props is not None:
            if "weight" not in edge_props:
                logger.warning(
                    f"Edge {pair} missing 'weight' attribute, using default value 1.0"
                )
                edge_props["weight"] = 1.0

            # Keep edge data without rank, maintain vector search order
            combined = {
                "src_id": k["src_id"],
                "tgt_id": k["tgt_id"],
                "created_at": k.get("created_at", None),
                **edge_props,
            }
            edge_datas.append(combined)

    # Relations maintain vector search order (sorted by similarity)

    use_entities = await _find_most_related_entities_from_relationships(
        edge_datas,
        query_param,
        knowledge_graph_inst,
    )

    logger.info(
        f"Global query: {len(use_entities)} entities, {len(edge_datas)} relations"
    )

    return edge_datas, use_entities


async def _find_most_related_entities_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
    entity_names = []
    seen = set()

    for e in edge_datas:
        if e["src_id"] not in seen:
            entity_names.append(e["src_id"])
            seen.add(e["src_id"])
        if e["tgt_id"] not in seen:
            entity_names.append(e["tgt_id"])
            seen.add(e["tgt_id"])

    # Only get nodes data, no need for node degrees
    nodes_dict = await knowledge_graph_inst.get_nodes_batch(entity_names)

    # Rebuild the list in the same order as entity_names
    node_datas = []
    for entity_name in entity_names:
        node = nodes_dict.get(entity_name)
        if node is None:
            logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
            continue
        # Combine the node data with the entity name, no rank needed
        combined = {**node, "entity_name": entity_name}
        node_datas.append(combined)

    return node_datas


async def _find_related_text_unit_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    entity_chunks: list[dict] = None,
):
    """
    Find text chunks related to relationships using the linear gradient weighted polling algorithm.

    This function implements the optimized text chunk selection strategy:
    1. Sort text chunks for each relationship by occurrence count in other relationships
    2. Use linear gradient weighted polling to select chunks fairly
    """
    logger.debug(f"Searching text chunks for {len(edge_datas)} relationships")

    if not edge_datas:
        return []

    # Step 1: Collect all text chunks for each relationship
    relations_with_chunks = []
    for relation in edge_datas:
        if relation.get("source_id"):
            chunks = split_string_by_multi_markers(
                relation["source_id"], [GRAPH_FIELD_SEP]
            )
            if chunks:
                # Build relation identifier
                if "src_tgt" in relation:
                    rel_key = tuple(sorted(relation["src_tgt"]))
                else:
                    rel_key = tuple(
                        sorted([relation.get("src_id"), relation.get("tgt_id")])
                    )

                relations_with_chunks.append(
                    {
                        "relation_key": rel_key,
                        "chunks": chunks,
                        "relation_data": relation,
                    }
                )

    if not relations_with_chunks:
        logger.warning("No relationships with text chunks found")
        return []

    # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships)
    chunk_occurrence_count = {}
    for relation_info in relations_with_chunks:
        deduplicated_chunks = []
        for chunk_id in relation_info["chunks"]:
            chunk_occurrence_count[chunk_id] = (
                chunk_occurrence_count.get(chunk_id, 0) + 1
            )

            # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
            if chunk_occurrence_count[chunk_id] == 1:
                deduplicated_chunks.append(chunk_id)
            # count > 1 means this chunk appeared in an earlier relationship, so skip it

        # Update relationship's chunks to deduplicated chunks
        relation_info["chunks"] = deduplicated_chunks

    # Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority)
    for relation_info in relations_with_chunks:
        sorted_chunks = sorted(
            relation_info["chunks"],
            key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
            reverse=True,
        )
        relation_info["sorted_chunks"] = sorted_chunks

    # Step 4: Apply linear gradient weighted polling algorithm
    max_related_chunks = text_chunks_db.global_config.get(
        "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
    )

    selected_chunk_ids = linear_gradient_weighted_polling(
        relations_with_chunks, max_related_chunks, min_related_chunks=1
    )

    logger.debug(
        f"Found {len(selected_chunk_ids)} relationship-related chunks using linear gradient weighted polling"
    )
    logger.info(
        f"KG related chunks: {len(entity_chunks) if entity_chunks else 0} from entities, {len(selected_chunk_ids)} from relations"
    )

    if not selected_chunk_ids:
        return []

    # Step 4.5: Remove duplicates with entity_chunks before batch retrieval
    if entity_chunks:
        # Extract chunk IDs from entity_chunks
        entity_chunk_ids = set()
        for chunk in entity_chunks:
            chunk_id = chunk.get("chunk_id")
            if chunk_id:
                entity_chunk_ids.add(chunk_id)

        # Filter out duplicate chunk IDs
        original_count = len(selected_chunk_ids)
        selected_chunk_ids = [
            chunk_id
            for chunk_id in selected_chunk_ids
            if chunk_id not in entity_chunk_ids
        ]

        logger.debug(
            f"Deduplication relation-chunks with entity-chunks: {original_count} -> {len(selected_chunk_ids)} chunks"
        )

        # Early return if no chunks remain after deduplication
        if not selected_chunk_ids:
            return []

    # Step 5: Batch retrieve chunk data
    unique_chunk_ids = list(
        dict.fromkeys(selected_chunk_ids)
    )  # Remove duplicates while preserving order
    chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)

    # Step 6: Build result chunks with valid data
    result_chunks = []
    for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
        if chunk_data is not None and "content" in chunk_data:
            chunk_data_copy = chunk_data.copy()
            chunk_data_copy["source_type"] = "relationship"
            chunk_data_copy["chunk_id"] = chunk_id  # Add chunk_id for deduplication
            result_chunks.append(chunk_data_copy)

    return result_chunks


async def naive_query(
    query: str,
    chunks_vdb: BaseVectorStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    # Handle cache
    args_hash = compute_args_hash(
        query_param.mode,
        query,
        query_param.response_type,
        query_param.top_k,
        query_param.chunk_top_k,
        query_param.max_entity_tokens,
        query_param.max_relation_tokens,
        query_param.max_total_tokens,
        query_param.hl_keywords or [],
        query_param.ll_keywords or [],
        query_param.user_prompt or "",
        query_param.enable_rerank,
    )
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    tokenizer: Tokenizer = global_config["tokenizer"]

    chunks = await _get_vector_context(query, chunks_vdb, query_param)

    if chunks is None or len(chunks) == 0:
        return PROMPTS["fail_response"]

    # Calculate dynamic token limit for chunks
    # Get token limits from query_param (with fallback to global_config)
    max_total_tokens = getattr(
        query_param,
        "max_total_tokens",
        global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
    )

    # Calculate conversation history tokens
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )
    history_tokens = len(tokenizer.encode(history_context)) if history_context else 0

    # Calculate system prompt template tokens (excluding content_data)
    user_prompt = query_param.user_prompt if query_param.user_prompt else ""
    response_type = (
        query_param.response_type
        if query_param.response_type
        else "Multiple Paragraphs"
    )

    # Use the provided system prompt or default
    sys_prompt_template = (
        system_prompt if system_prompt else PROMPTS["naive_rag_response"]
    )

    # Create a sample system prompt with empty content_data to calculate overhead
    sample_sys_prompt = sys_prompt_template.format(
        content_data="",  # Empty for overhead calculation
        response_type=response_type,
        history=history_context,
        user_prompt=user_prompt,
    )
    sys_prompt_template_tokens = len(tokenizer.encode(sample_sys_prompt))

    # Total system prompt overhead = template + query tokens
    query_tokens = len(tokenizer.encode(query))
    sys_prompt_overhead = sys_prompt_template_tokens + query_tokens

    buffer_tokens = 100  # Safety buffer

    # Calculate available tokens for chunks
    used_tokens = sys_prompt_overhead + buffer_tokens
    available_chunk_tokens = max_total_tokens - used_tokens

    logger.debug(
        f"Naive query token allocation - Total: {max_total_tokens}, History: {history_tokens}, SysPrompt: {sys_prompt_overhead}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
    )

    # Process chunks using unified processing with dynamic token limit
    processed_chunks = await process_chunks_unified(
        query=query,
        unique_chunks=chunks,
        query_param=query_param,
        global_config=global_config,
        source_type="vector",
        chunk_token_limit=available_chunk_tokens,  # Pass dynamic limit
    )

    logger.info(f"Final context: {len(processed_chunks)} chunks")

    # Build text_units_context from processed chunks
    text_units_context = []
    for i, chunk in enumerate(processed_chunks):
        text_units_context.append(
            {
                "id": i + 1,
                "content": chunk["content"],
                "file_path": chunk.get("file_path", "unknown_source"),
            }
        )

    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
    if query_param.only_need_context:
        return f"""
---Document Chunks---

```json
{text_units_str}
```

"""
    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    # Build system prompt
    user_prompt = (
        query_param.user_prompt
        if query_param.user_prompt
        else PROMPTS["DEFAULT_USER_PROMPT"]
    )
    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=text_units_str,
        response_type=query_param.response_type,
        history=history_context,
        user_prompt=user_prompt,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(
        f"[naive_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
    )

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response[len(sys_prompt) :]
            .replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    if hashing_kv.global_config.get("enable_llm_cache"):
        # Save to cache with query parameters
        queryparam_dict = {
            "mode": query_param.mode,
            "response_type": query_param.response_type,
            "top_k": query_param.top_k,
            "chunk_top_k": query_param.chunk_top_k,
            "max_entity_tokens": query_param.max_entity_tokens,
            "max_relation_tokens": query_param.max_relation_tokens,
            "max_total_tokens": query_param.max_total_tokens,
            "hl_keywords": query_param.hl_keywords or [],
            "ll_keywords": query_param.ll_keywords or [],
            "user_prompt": query_param.user_prompt or "",
            "enable_rerank": query_param.enable_rerank,
        }
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                mode=query_param.mode,
                cache_type="query",
                queryparam=queryparam_dict,
            ),
        )

    return response


# TODO: Deprecated, use user_prompt in QueryParam instead
async def kg_query_with_keywords(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    ll_keywords: list[str] = [],
    hl_keywords: list[str] = [],
    chunks_vdb: BaseVectorStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Refactored kg_query that does NOT extract keywords by itself.
    It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
    Then it uses those to build context and produce a final LLM response.
    """
    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    # If neither has any keywords, you could handle that logic here.
    if not hl_keywords and not ll_keywords:
        logger.warning(
            "No keywords found in query_param. Could default to global mode or fail."
        )
        return PROMPTS["fail_response"]
    if not ll_keywords and query_param.mode in ["local", "hybrid"]:
        logger.warning("low_level_keywords is empty, switching to global mode.")
        query_param.mode = "global"
    if not hl_keywords and query_param.mode in ["global", "hybrid"]:
        logger.warning("high_level_keywords is empty, switching to local mode.")
        query_param.mode = "local"

    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

    context = await _build_query_context(
        query,
        ll_keywords_str,
        hl_keywords_str,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
        chunks_vdb=chunks_vdb,
    )
    if not context:
        return PROMPTS["fail_response"]

    if query_param.only_need_context:
        return context

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context,
        response_type=query_param.response_type,
        history=history_context,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(
        f"[kg_query_with_keywords] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
    )

    # Generate response
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    # Clean up response content
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    if hashing_kv.global_config.get("enable_llm_cache"):
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                mode=query_param.mode,
                cache_type="query",
            ),
        )

    return response


# TODO: Deprecated, use user_prompt in QueryParam instead
async def query_with_keywords(
    query: str,
    prompt: str,
    param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Extract keywords from the query and then use them for retrieving information.

    1. Extracts high-level and low-level keywords from the query
    2. Formats the query with the extracted keywords and prompt
    3. Uses the appropriate query method based on param.mode

    Args:
        query: The user's query
        prompt: Additional prompt to prepend to the query
        param: Query parameters
        knowledge_graph_inst: Knowledge graph storage
        entities_vdb: Entities vector database
        relationships_vdb: Relationships vector database
        chunks_vdb: Document chunks vector database
        text_chunks_db: Text chunks storage
        global_config: Global configuration
        hashing_kv: Cache storage

    Returns:
        Query response or async iterator
    """
    # Extract keywords
    hl_keywords, ll_keywords = await get_keywords_from_query(
        query=query,
        query_param=param,
        global_config=global_config,
        hashing_kv=hashing_kv,
    )

    # Create a new string with the prompt and the keywords
    keywords_str = ", ".join(ll_keywords + hl_keywords)
    formatted_question = (
        f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
    )

    # Use appropriate query method based on mode
    if param.mode in ["local", "global", "hybrid", "mix"]:
        return await kg_query_with_keywords(
            formatted_question,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=hashing_kv,
            hl_keywords=hl_keywords,
            ll_keywords=ll_keywords,
            chunks_vdb=chunks_vdb,
        )
    elif param.mode == "naive":
        return await naive_query(
            formatted_question,
            chunks_vdb,
            param,
            global_config,
            hashing_kv=hashing_kv,
        )
    else:
        raise ValueError(f"Unknown mode {param.mode}")