from __future__ import annotations

import asyncio
import traceback
import json
import re
import os
from typing import Any, AsyncIterator
from collections import Counter, defaultdict

from .utils import (
    logger,
    clean_str,
    compute_mdhash_id,
    decode_tokens_by_tiktoken,
    encode_string_by_tiktoken,
    is_float_regex,
    list_of_list_to_csv,
    pack_user_ass_to_openai_messages,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
    process_combine_contexts,
    compute_args_hash,
    handle_cache,
    save_to_cache,
    CacheData,
    statistic_data,
    get_conversation_turns,
    verbose_debug,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    TextChunkSchema,
    QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
import time
from dotenv import load_dotenv

# Use the .env file in the current folder, which allows a different .env file
# per LightRAG instance. OS environment variables take precedence over values
# loaded from the .env file.
load_dotenv(dotenv_path=".env", override=False)


def chunking_by_token_size(
    content: str,
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
    overlap_token_size: int = 128,
    max_token_size: int = 1024,
    tiktoken_model: str = "gpt-4o",
) -> list[dict[str, Any]]:
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results: list[dict[str, Any]] = []
    if split_by_character:
        raw_chunks = content.split(split_by_character)
        new_chunks = []
        if split_by_character_only:
            for chunk in raw_chunks:
                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
                new_chunks.append((len(_tokens), chunk))
        else:
            for chunk in raw_chunks:
                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
                if len(_tokens) > max_token_size:
                    for start in range(
                        0, len(_tokens), max_token_size - overlap_token_size
                    ):
                        chunk_content = decode_tokens_by_tiktoken(
                            _tokens[start : start + max_token_size],
                            model_name=tiktoken_model,
                        )
                        new_chunks.append(
                            (min(max_token_size, len(_tokens) - start), chunk_content)
                        )
                else:
                    new_chunks.append((len(_tokens), chunk))
        for index, (_len, chunk) in enumerate(new_chunks):
            results.append(
                {
                    "tokens": _len,
                    "content": chunk.strip(),
                    "chunk_order_index": index,
                }
            )
    else:
        for index, start in enumerate(
            range(0, len(tokens), max_token_size - overlap_token_size)
        ):
            chunk_content = decode_tokens_by_tiktoken(
                tokens[start : start + max_token_size], model_name=tiktoken_model
            )
            results.append(
                {
                    "tokens": min(max_token_size, len(tokens) - start),
                    "content": chunk_content.strip(),
                    "chunk_order_index": index,
                }
            )
    return results


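# Illustration (hypothetical values): with max_token_size=1024 and
# overlap_token_size=128, window starts advance by 896 tokens, so a
# 2048-token document yields windows [0:1024], [896:1920], [1792:2048]:
#
#   chunks = chunking_by_token_size(text, max_token_size=1024, overlap_token_size=128)
#   # -> [{"tokens": 1024, "content": "...", "chunk_order_index": 0}, ...]

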
async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
) -> str:
    """Summarize an entity or relation description.

    The input is the combined description for an entity or relation (the
    already-existing description plus the newly extracted ones). If it is too
    long, the LLM is used to condense it; otherwise it is returned unchanged.
    """
    use_llm_func: callable = global_config["llm_model_func"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
    tiktoken_model_name = global_config["tiktoken_model_name"]
    summary_max_tokens = global_config["entity_summary_to_max_tokens"]
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
    if len(tokens) < summary_max_tokens:  # No need for summary
        return description
    prompt_template = PROMPTS["summarize_entity_descriptions"]
    use_description = decode_tokens_by_tiktoken(
        tokens[:llm_max_tokens], model_name=tiktoken_model_name
    )
    context_base = dict(
        entity_name=entity_or_relation_name,
        description_list=use_description.split(GRAPH_FIELD_SEP),
        language=language,
    )
    use_prompt = prompt_template.format(**context_base)
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
    return summary


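# Sketch of the expected input, for illustration only: descriptions are merged
# with GRAPH_FIELD_SEP before this call, e.g.
#
#   desc = GRAPH_FIELD_SEP.join(["CEO of Acme.", "Founded Acme in 1999."])
#   desc = await _handle_entity_relation_summary("Jane Doe", desc, global_config)
#
# If the merged text stays under entity_summary_to_max_tokens it is returned
# verbatim; only longer descriptions trigger an LLM summarization call.

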
async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
    file_path: str = "unknown_source",
):
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None

    # Clean and validate entity name
    entity_name = clean_str(record_attributes[1]).strip('"')
    if not entity_name.strip():
        logger.warning(
            f"Entity extraction error: empty entity name in: {record_attributes}"
        )
        return None

    # Clean and validate entity type
    entity_type = clean_str(record_attributes[2]).strip('"')
    if not entity_type.strip() or entity_type.startswith('("'):
        logger.warning(
            f"Entity extraction error: invalid entity type in: {record_attributes}"
        )
        return None

    # Clean and validate description
    entity_description = clean_str(record_attributes[3]).strip('"')
    if not entity_description.strip():
        logger.warning(
            f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
        )
        return None

    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=chunk_key,
        file_path=file_path,
    )


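# Example record, for illustration only: a well-formed entity record parsed
# from the LLM output looks like
#
#   record_attributes = ['"entity"', '"Acme Corp"', '"organization"', '"A widget maker."']
#   # -> {"entity_name": "Acme Corp", "entity_type": "organization", ...}
#
# Anything shorter than four fields, or not tagged '"entity"', returns None.

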
async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
    file_path: str = "unknown_source",
):
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as an edge
    source = clean_str(record_attributes[1]).strip('"')
    target = clean_str(record_attributes[2]).strip('"')
    edge_description = clean_str(record_attributes[3]).strip('"')
    edge_keywords = clean_str(record_attributes[4]).strip('"')
    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1].strip('"'))
        if is_float_regex(record_attributes[-1])
        else 1.0
    )
    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
        file_path=file_path,
    )


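# Example record, for illustration only: a well-formed relationship record is
#
#   record_attributes = ['"relationship"', '"Jane Doe"', '"Acme Corp"',
#                        '"Jane founded Acme."', '"founder"', '"9.0"']
#
# The trailing field is parsed as the edge weight when it looks like a float;
# otherwise the weight defaults to 1.0.

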
async def _merge_nodes_then_upsert(
    entity_name: str,
    nodes_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    """Look up an existing node by name; if found, merge its data with the new data, otherwise create it, then upsert."""
    already_entity_types = []
    already_source_ids = []
    already_description = []
    already_file_paths = []

    already_node = await knowledge_graph_inst.get_node(entity_name)
    if already_node is not None:
        already_entity_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
        )
        already_file_paths.extend(
            split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_node["description"])

    entity_type = sorted(
        Counter(
            [dp["entity_type"] for dp in nodes_data] + already_entity_types
        ).items(),
        key=lambda x: x[1],
        reverse=True,
    )[0][0]
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in nodes_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in nodes_data] + already_source_ids)
    )
    file_path = GRAPH_FIELD_SEP.join(
        set([dp["file_path"] for dp in nodes_data] + already_file_paths)
    )

    logger.debug(f"file_path: {file_path}")
    description = await _handle_entity_relation_summary(
        entity_name, description, global_config
    )
    node_data = dict(
        entity_id=entity_name,
        entity_type=entity_type,
        description=description,
        source_id=source_id,
        file_path=file_path,
    )
    await knowledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    node_data["entity_name"] = entity_name
    return node_data


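# Merge semantics, for illustration only: the entity type is decided by
# majority vote over all observations, e.g. two "person" records plus one
# "organization" record yield entity_type == "person", while descriptions
# and source ids are deduplicated and joined with GRAPH_FIELD_SEP.

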
async def _merge_edges_then_upsert(
    src_id: str,
    tgt_id: str,
    edges_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    already_weights = []
    already_source_ids = []
    already_description = []
    already_keywords = []
    already_file_paths = []

    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        # Handle the case where get_edge returns None or has missing fields
        if already_edge:
            # Default the weight to 0.0 if missing
            already_weights.append(already_edge.get("weight", 0.0))

            # Collect source_id only when present and not None
            if already_edge.get("source_id") is not None:
                already_source_ids.extend(
                    split_string_by_multi_markers(
                        already_edge["source_id"], [GRAPH_FIELD_SEP]
                    )
                )

            # Collect file_path only when present and not None
            if already_edge.get("file_path") is not None:
                already_file_paths.extend(
                    split_string_by_multi_markers(
                        already_edge["file_path"], [GRAPH_FIELD_SEP]
                    )
                )

            # Collect description only when present and not None
            if already_edge.get("description") is not None:
                already_description.append(already_edge["description"])

            # Collect keywords only when present and not None
            if already_edge.get("keywords") is not None:
                already_keywords.extend(
                    split_string_by_multi_markers(
                        already_edge["keywords"], [GRAPH_FIELD_SEP]
                    )
                )

    # Process edges_data with None checks
    weight = sum([dp["weight"] for dp in edges_data] + already_weights)
    description = GRAPH_FIELD_SEP.join(
        sorted(
            set(
                [dp["description"] for dp in edges_data if dp.get("description")]
                + already_description
            )
        )
    )
    keywords = GRAPH_FIELD_SEP.join(
        sorted(
            set(
                [dp["keywords"] for dp in edges_data if dp.get("keywords")]
                + already_keywords
            )
        )
    )
    source_id = GRAPH_FIELD_SEP.join(
        set(
            [dp["source_id"] for dp in edges_data if dp.get("source_id")]
            + already_source_ids
        )
    )
    file_path = GRAPH_FIELD_SEP.join(
        set(
            [dp["file_path"] for dp in edges_data if dp.get("file_path")]
            + already_file_paths
        )
    )

    for need_insert_id in [src_id, tgt_id]:
        if not (await knowledge_graph_inst.has_node(need_insert_id)):
            await knowledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "entity_id": need_insert_id,
                    "source_id": source_id,
                    "description": description,
                    "entity_type": "UNKNOWN",
                    "file_path": file_path,
                },
            )
    description = await _handle_entity_relation_summary(
        f"({src_id}, {tgt_id})", description, global_config
    )
    await knowledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight,
            description=description,
            keywords=keywords,
            source_id=source_id,
            file_path=file_path,
        ),
    )

    edge_data = dict(
        src_id=src_id,
        tgt_id=tgt_id,
        description=description,
        keywords=keywords,
        source_id=source_id,
        file_path=file_path,
    )

    return edge_data


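# Note on the merge: edge weights are additive, so an edge observed in three
# chunks with weights 1.0, 1.0, and 2.0, on top of an existing stored weight
# of 1.0, is upserted with weight 5.0; repeated co-occurrence strengthens it.

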
async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    global_config: dict[str, str],
    pipeline_status: dict | None = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
) -> None:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    enable_llm_cache_for_entity_extract: bool = global_config[
        "enable_llm_cache_for_entity_extract"
    ]

    ordered_chunks = list(chunks.items())
    # add language and example number params to prompt
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )
    entity_types = global_config["addon_params"].get(
        "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
    )
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["entity_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["entity_extraction_examples"])

    example_context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=", ".join(entity_types),
        language=language,
    )
    # fill the delimiters and entity types into the examples
    examples = examples.format(**example_context_base)

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(entity_types),
        examples=examples,
        language=language,
    )

    continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
    if_loop_prompt = PROMPTS["entity_if_loop_extraction"]

    processed_chunks = 0
    total_chunks = len(ordered_chunks)

    async def _user_llm_func_with_cache(
        input_text: str, history_messages: list[dict[str, str]] | None = None
    ) -> str:
        if enable_llm_cache_for_entity_extract and llm_response_cache:
            if history_messages:
                history = json.dumps(history_messages, ensure_ascii=False)
                _prompt = history + "\n" + input_text
            else:
                _prompt = input_text

            # TODO: add cache_type="extract"
            arg_hash = compute_args_hash(_prompt)
            cached_return, _1, _2, _3 = await handle_cache(
                llm_response_cache,
                arg_hash,
                _prompt,
                "default",
                cache_type="extract",
            )
            if cached_return:
                logger.debug(f"Found cache for {arg_hash}")
                statistic_data["llm_cache"] += 1
                return cached_return
            statistic_data["llm_call"] += 1
            if history_messages:
                res: str = await use_llm_func(
                    input_text, history_messages=history_messages
                )
            else:
                res: str = await use_llm_func(input_text)
            await save_to_cache(
                llm_response_cache,
                CacheData(
                    args_hash=arg_hash,
                    content=res,
                    prompt=_prompt,
                    cache_type="extract",
                ),
            )
            return res

        if history_messages:
            return await use_llm_func(input_text, history_messages=history_messages)
        else:
            return await use_llm_func(input_text)

    async def _process_extraction_result(
        result: str, chunk_key: str, file_path: str = "unknown_source"
    ):
        """Process a single extraction result (either initial or gleaning)

        Args:
            result (str): The extraction result to process
            chunk_key (str): The chunk key for source tracking
            file_path (str): The file path for citation

        Returns:
            tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
        """
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)

        records = split_string_by_multi_markers(
            result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )

            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key, file_path
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key, file_path
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )

        return maybe_nodes, maybe_edges

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        """Process a single chunk

        Args:
            chunk_key_dp (tuple[str, TextChunkSchema]):
                ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
        """
        nonlocal processed_chunks
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        # Get file path from chunk data or use default
        file_path = chunk_dp.get("file_path", "unknown_source")

        # Get initial extraction
        hint_prompt = entity_extract_prompt.format(
            **context_base, input_text="{input_text}"
        ).format(**context_base, input_text=content)

        final_result = await _user_llm_func_with_cache(hint_prompt)
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

        # Process initial extraction with file path
        maybe_nodes, maybe_edges = await _process_extraction_result(
            final_result, chunk_key, file_path
        )

        # Process additional gleaning results
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await _user_llm_func_with_cache(
                continue_prompt, history_messages=history
            )

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)

            # Process gleaning result separately with file path
            glean_nodes, glean_edges = await _process_extraction_result(
                glean_result, chunk_key, file_path
            )

            # Merge results
            for entity_name, entities in glean_nodes.items():
                maybe_nodes[entity_name].extend(entities)
            for edge_key, edges in glean_edges.items():
                maybe_edges[edge_key].extend(edges)

            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await _user_llm_func_with_cache(
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        processed_chunks += 1
        entities_count = len(maybe_nodes)
        relations_count = len(maybe_edges)
        log_message = f" Chk {processed_chunks}/{total_chunks}: extracted {entities_count} Ent + {relations_count} Rel (deduplicated)"
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
        return dict(maybe_nodes), dict(maybe_edges)

    tasks = [_process_single_content(c) for c in ordered_chunks]
    results = await asyncio.gather(*tasks)

    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[tuple(sorted(k))].extend(v)

    from .kg.shared_storage import get_graph_db_lock

    graph_db_lock = get_graph_db_lock(enable_logging=False)

    # Ensure that nodes and edges are merged and upserted atomically
    async with graph_db_lock:
        all_entities_data = await asyncio.gather(
            *[
                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
                for k, v in maybe_nodes.items()
            ]
        )

        all_relationships_data = await asyncio.gather(
            *[
                _merge_edges_then_upsert(
                    k[0], k[1], v, knowledge_graph_inst, global_config
                )
                for k, v in maybe_edges.items()
            ]
        )

    if not (all_entities_data or all_relationships_data):
        log_message = "Didn't extract any entities or relationships."
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
        return

    if not all_entities_data:
        log_message = "Didn't extract any entities"
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
    if not all_relationships_data:
        log_message = "Didn't extract any relationships"
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)

    log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
    logger.info(log_message)
    if pipeline_status is not None:
        async with pipeline_status_lock:
            pipeline_status["latest_message"] = log_message
            pipeline_status["history_messages"].append(log_message)
    verbose_debug(f"New entities: {all_entities_data}")
    verbose_debug(f"New relationships: {all_relationships_data}")

    if entity_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "entity_name": dp["entity_name"],
                "entity_type": dp["entity_type"],
                "content": f"{dp['entity_name']}\n{dp['description']}",
                "source_id": dp["source_id"],
                "file_path": dp.get("file_path", "unknown_source"),
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    if relationships_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                "src_id": dp["src_id"],
                "tgt_id": dp["tgt_id"],
                "keywords": dp["keywords"],
                "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
                "source_id": dp["source_id"],
                "file_path": dp.get("file_path", "unknown_source"),
            }
            for dp in all_relationships_data
        }
        await relationships_vdb.upsert(data_for_vdb)


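# A minimal driver sketch, for illustration only; the storage objects and the
# config dict here are placeholders for an initialized LightRAG-style setup,
# not a prescribed API:
#
#   await extract_entities(
#       chunks={"chunk-abc123": {"tokens": 120, "content": "...",
#                                "full_doc_id": "doc-1", "chunk_order_index": 0}},
#       knowledge_graph_inst=graph_storage,
#       entity_vdb=entity_vectors,
#       relationships_vdb=relation_vectors,
#       global_config=config_dict,
#   )

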
async def kg_query(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    # Handle cache
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    # Extract keywords using extract_keywords_only, which already supports conversation history
    hl_keywords, ll_keywords = await extract_keywords_only(
        query, query_param, global_config, hashing_kv
    )

    logger.debug(f"High-level keywords: {hl_keywords}")
    logger.debug(f"Low-level keywords: {ll_keywords}")

    # Handle empty keywords
    if hl_keywords == [] and ll_keywords == []:
        logger.warning("low_level_keywords and high_level_keywords are empty")
        return PROMPTS["fail_response"]
    if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
        logger.warning(
            "low_level_keywords is empty, switching from %s mode to global mode",
            query_param.mode,
        )
        query_param.mode = "global"
    if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
        logger.warning(
            "high_level_keywords is empty, switching from %s mode to local mode",
            query_param.mode,
        )
        query_param.mode = "local"

    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

    # Build context
    context = await _build_query_context(
        ll_keywords_str,
        hl_keywords_str,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
    )

    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context,
        response_type=query_param.response_type,
        history=history_context,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
    logger.debug(f"[kg_query] Prompt Tokens: {len_of_prompts}")

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
            cache_type="query",
        ),
    )
    return response


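# Mode fallback, for illustration only: a "hybrid" query whose keyword
# extraction yields hl_keywords=["supply chains"] but ll_keywords=[] is
# downgraded to "global" mode, so only the relationship (edge) context
# path is used when building the prompt.

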
async def extract_keywords_only(
    text: str,
    param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Extract high-level and low-level keywords from the given 'text' using the LLM.
    This method does NOT build the final RAG context or provide a final answer.
    It ONLY extracts keywords (hl_keywords, ll_keywords).
    """

    # 1. Handle cache if needed - add cache type for keywords
    args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, text, param.mode, cache_type="keywords"
    )
    if cached_response is not None:
        try:
            keywords_data = json.loads(cached_response)
            return keywords_data["high_level_keywords"], keywords_data[
                "low_level_keywords"
            ]
        except (json.JSONDecodeError, KeyError):
            logger.warning(
                "Invalid cache format for keywords, proceeding with extraction"
            )

    # 2. Build the examples
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    # 3. Process conversation history
    history_context = ""
    if param.conversation_history:
        history_context = get_conversation_turns(
            param.conversation_history, param.history_turns
        )

    # 4. Build the keyword-extraction prompt
    kw_prompt = PROMPTS["keywords_extraction"].format(
        query=text, examples=examples, language=language, history=history_context
    )

    len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
    logger.debug(f"[extract_keywords_only] Prompt Tokens: {len_of_prompts}")

    # 5. Call the LLM for keyword extraction
    use_model_func = (
        param.model_func if param.model_func else global_config["llm_model_func"]
    )
    result = await use_model_func(kw_prompt, keyword_extraction=True)

    # 6. Parse out JSON from the LLM response
    match = re.search(r"\{.*\}", result, re.DOTALL)
    if not match:
        logger.error("No JSON-like structure found in the LLM response.")
        return [], []
    try:
        keywords_data = json.loads(match.group(0))
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        return [], []

    hl_keywords = keywords_data.get("high_level_keywords", [])
    ll_keywords = keywords_data.get("low_level_keywords", [])

    # 7. Cache only the processed keywords with cache type
    if hl_keywords or ll_keywords:
        cache_data = {
            "high_level_keywords": hl_keywords,
            "low_level_keywords": ll_keywords,
        }
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=json.dumps(cache_data),
                prompt=text,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=param.mode,
                cache_type="keywords",
            ),
        )
    return hl_keywords, ll_keywords


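# Expected LLM output shape, for illustration only:
#
#   {"high_level_keywords": ["international trade"],
#    "low_level_keywords": ["tariffs", "exports"]}
#
# The regex pulls the first {...} block out of the raw response, so extra
# prose around the JSON is tolerated; malformed JSON falls back to ([], []).

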
async def mix_kg_vector_query(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    """
    Hybrid retrieval implementation combining knowledge graph and vector search.

    This function performs a hybrid search by:
    1. Extracting semantic information from the knowledge graph
    2. Retrieving relevant text chunks through vector similarity
    3. Combining both results for comprehensive answer generation
    """
    # 1. Cache handling
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash("mix", query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, "mix", cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    # 2. Define the knowledge graph and vector retrieval closures
    async def get_kg_context():
        try:
            # Extract keywords using extract_keywords_only, which already supports conversation history
            hl_keywords, ll_keywords = await extract_keywords_only(
                query, query_param, global_config, hashing_kv
            )

            if not hl_keywords and not ll_keywords:
                logger.warning("Both high-level and low-level keywords are empty")
                return None

            # Convert keyword lists to strings
            ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
            hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

            # Set query mode based on available keywords
            if not ll_keywords_str and not hl_keywords_str:
                return None
            elif not ll_keywords_str:
                query_param.mode = "global"
            elif not hl_keywords_str:
                query_param.mode = "local"
            else:
                query_param.mode = "hybrid"

            # Build knowledge graph context
            context = await _build_query_context(
                ll_keywords_str,
                hl_keywords_str,
                knowledge_graph_inst,
                entities_vdb,
                relationships_vdb,
                text_chunks_db,
                query_param,
            )

            return context

        except Exception as e:
            logger.error(f"Error in get_kg_context: {str(e)}")
            traceback.print_exc()
            return None

    async def get_vector_context():
        # Consider conversation history in vector search
        augmented_query = query
        if history_context:
            augmented_query = f"{history_context}\n{query}"

        try:
            # Reduce top_k for vector search in hybrid mode since we have structured information from KG
            mix_topk = min(10, query_param.top_k)
            results = await chunks_vdb.query(
                augmented_query, top_k=mix_topk, ids=query_param.ids
            )
            if not results:
                return None

            chunks_ids = [r["id"] for r in results]
            chunks = await text_chunks_db.get_by_ids(chunks_ids)

            valid_chunks = []
            for chunk, result in zip(chunks, results):
                if chunk is not None and "content" in chunk:
                    # Merge chunk content with file path and time metadata
                    # (file_path is needed later when formatting the chunk text)
                    chunk_with_time = {
                        "content": chunk["content"],
                        "created_at": result.get("created_at", None),
                        "file_path": chunk.get("file_path", "unknown_source"),
                    }
                    valid_chunks.append(chunk_with_time)

            if not valid_chunks:
                return None

            maybe_trun_chunks = truncate_list_by_token_size(
                valid_chunks,
                key=lambda x: x["content"],
                max_token_size=query_param.max_token_for_text_unit,
            )

            if not maybe_trun_chunks:
                return None

            # Include file path and time information in content
            formatted_chunks = []
            for c in maybe_trun_chunks:
                chunk_text = "File path: " + c["file_path"] + "\n" + c["content"]
                if c["created_at"]:
                    chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
                formatted_chunks.append(chunk_text)

            logger.debug(
                f"Truncate chunks from {len(chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
            )
            return "\n--New Chunk--\n".join(formatted_chunks)
        except Exception as e:
            logger.error(f"Error in get_vector_context: {e}")
            return None

    # 3. Execute both retrievals in parallel
    kg_context, vector_context = await asyncio.gather(
        get_kg_context(), get_vector_context()
    )

    # 4. Merge contexts
    if kg_context is None and vector_context is None:
        return PROMPTS["fail_response"]

    if query_param.only_need_context:
        return {"kg_context": kg_context, "vector_context": vector_context}

    # 5. Construct hybrid prompt
    sys_prompt = (
        system_prompt
        if system_prompt
        else PROMPTS["mix_rag_response"].format(
            kg_context=kg_context
            if kg_context
            else "No relevant knowledge graph information found",
            vector_context=vector_context
            if vector_context
            else "No relevant text information found",
            response_type=query_param.response_type,
            history=history_context,
        )
    )

    if query_param.only_need_prompt:
        return sys_prompt

    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
    logger.debug(f"[mix_kg_vector_query] Prompt Tokens: {len_of_prompts}")

    # 6. Generate response
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    # Clean up response content
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # 7. Save cache - only cache after collecting the complete response
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode="mix",
            cache_type="query",
        ),
    )

    return response


async def _build_query_context(
    ll_keywords: str,
    hl_keywords: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
):
    logger.info(f"Process {os.getpid()} building query context...")
    if query_param.mode == "local":
        entities_context, relations_context, text_units_context = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    elif query_param.mode == "global":
        entities_context, relations_context, text_units_context = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
    else:  # hybrid mode
        ll_data, hl_data = await asyncio.gather(
            _get_node_data(
                ll_keywords,
                knowledge_graph_inst,
                entities_vdb,
                text_chunks_db,
                query_param,
            ),
            _get_edge_data(
                hl_keywords,
                knowledge_graph_inst,
                relationships_vdb,
                text_chunks_db,
                query_param,
            ),
        )

        (
            ll_entities_context,
            ll_relations_context,
            ll_text_units_context,
        ) = ll_data

        (
            hl_entities_context,
            hl_relations_context,
            hl_text_units_context,
        ) = hl_data

        entities_context, relations_context, text_units_context = combine_contexts(
            [hl_entities_context, ll_entities_context],
            [hl_relations_context, ll_relations_context],
            [hl_text_units_context, ll_text_units_context],
        )
    # With no entity or relation context there is nothing for the LLM to
    # ground a response on, so skip generation
    if not entities_context.strip() and not relations_context.strip():
        return None

    result = f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
""".strip()
    return result


async def _get_node_data(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
):
    # get similar entities
    logger.info(
        f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
    )

    results = await entities_vdb.query(
        query, top_k=query_param.top_k, ids=query_param.ids
    )

    if not len(results):
        return "", "", ""
    # get entity information and node degrees in parallel
    node_datas, node_degrees = await asyncio.gather(
        asyncio.gather(
            *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
        ),
        asyncio.gather(
            *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
        ),
    )

    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")

    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    # get the text chunks and relations most related to these entities
    use_text_units, use_relations = await asyncio.gather(
        _find_most_related_text_unit_from_entities(
            node_datas, query_param, text_chunks_db, knowledge_graph_inst
        ),
        _find_most_related_edges_from_entities(
            node_datas, query_param, knowledge_graph_inst
        ),
    )

    len_node_datas = len(node_datas)
    node_datas = truncate_list_by_token_size(
        node_datas,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_local_context,
    )
    logger.debug(
        f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
    )

    logger.info(
        f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} chunks"
    )

    # build prompt
    entites_section_list = [
        [
            "id",
            "entity",
            "type",
            "description",
            "rank",
            "created_at",
            "file_path",
        ]
    ]
    for i, n in enumerate(node_datas):
        created_at = n.get("created_at", "UNKNOWN")
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

        # Get file path from node data
        file_path = n.get("file_path", "unknown_source")

        entites_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
                created_at,
                file_path,
            ]
        )
    entities_context = list_of_list_to_csv(entites_section_list)

    relations_section_list = [
        [
            "id",
            "source",
            "target",
            "description",
            "keywords",
            "weight",
            "rank",
            "created_at",
            "file_path",
        ]
    ]
    for i, e in enumerate(use_relations):
        created_at = e.get("created_at", "UNKNOWN")
        # Convert timestamp to readable format
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

        # Get file path from edge data
        file_path = e.get("file_path", "unknown_source")

        relations_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["keywords"],
                e["weight"],
                e["rank"],
                created_at,
                file_path,
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)

    text_units_section_list = [["id", "content", "file_path"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"], t["file_path"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context


async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
):
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in node_datas
    ]
    edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    all_one_hop_nodes = set()
    for this_edges in edges:
        if not this_edges:
            continue
        all_one_hop_nodes.update([e[1] for e in this_edges])

    all_one_hop_nodes = list(all_one_hop_nodes)
    all_one_hop_nodes_data = await asyncio.gather(
        *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
    )

    # Add null check for node data
    all_one_hop_text_units_lookup = {
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None and "source_id" in v  # Add source_id check
    }

    all_text_units_lookup = {}
    tasks = []

    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        for c_id in this_text_units:
            if c_id not in all_text_units_lookup:
                all_text_units_lookup[c_id] = index
                tasks.append((c_id, index, this_edges))

    # Process in batches of 25 tasks at a time to avoid overwhelming resources
    batch_size = 25
    results = []

    for i in range(0, len(tasks), batch_size):
        batch_tasks = tasks[i : i + batch_size]
        batch_results = await asyncio.gather(
            *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks]
        )
        results.extend(batch_results)

    for (c_id, index, this_edges), data in zip(tasks, results):
        all_text_units_lookup[c_id] = {
            "data": data,
            "order": index,
            "relation_counts": 0,
        }

        if this_edges:
            for e in this_edges:
                if (
                    e[1] in all_one_hop_text_units_lookup
                    and c_id in all_one_hop_text_units_lookup[e[1]]
                ):
                    all_text_units_lookup[c_id]["relation_counts"] += 1

    # Filter out None values and ensure data has content
    all_text_units = [
        {"id": k, **v}
        for k, v in all_text_units_lookup.items()
        if v is not None and v.get("data") is not None and "content" in v["data"]
    ]

    if not all_text_units:
        logger.warning("No valid text units found")
        return []

    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )

    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    logger.debug(
        f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
    )

    all_text_units = [t["data"] for t in all_text_units]
    return all_text_units


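# Batching note: with batch_size = 25, fetching 80 chunk ids issues four
# sequential rounds of at most 25 concurrent text_chunks_db.get_by_id calls
# (25 + 25 + 25 + 5), which bounds the number of in-flight storage requests.

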
async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
    all_related_edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    all_edges = []
    seen = set()

    for this_edges in all_related_edges:
        for e in this_edges:
            sorted_edge = tuple(sorted(e))
            if sorted_edge not in seen:
                seen.add(sorted_edge)
                all_edges.append(sorted_edge)

    all_edges_pack, all_edges_degree = await asyncio.gather(
        asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
        asyncio.gather(
            *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
        ),
    )
    all_edges_data = [
        {"src_tgt": k, "rank": d, **v}
        for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
        if v is not None
    ]
    all_edges_data = sorted(
        all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    all_edges_data = truncate_list_by_token_size(
        all_edges_data,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_global_context,
    )

    logger.debug(
        f"Truncate relations from {len(all_edges)} to {len(all_edges_data)} (max tokens:{query_param.max_token_for_global_context})"
    )

    return all_edges_data


async def _get_edge_data(
    keywords,
    knowledge_graph_inst: BaseGraphStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
):
    logger.info(
        f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
    )

    results = await relationships_vdb.query(
        keywords, top_k=query_param.top_k, ids=query_param.ids
    )

    if not len(results):
        return "", "", ""

    edge_datas, edge_degree = await asyncio.gather(
        asyncio.gather(
            *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
        ),
        asyncio.gather(
            *[
                knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
                for r in results
            ]
        ),
    )

    edge_datas = [
        {
            "src_id": k["src_id"],
            "tgt_id": k["tgt_id"],
            "rank": d,
            "created_at": k.get("__created_at__", None),
            **v,
        }
        for k, v, d in zip(results, edge_datas, edge_degree)
        if v is not None
    ]
    edge_datas = sorted(
        edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    edge_datas = truncate_list_by_token_size(
        edge_datas,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_global_context,
    )
    use_entities, use_text_units = await asyncio.gather(
        _find_most_related_entities_from_relationships(
            edge_datas, query_param, knowledge_graph_inst
        ),
        _find_related_text_unit_from_relationships(
            edge_datas, query_param, text_chunks_db, knowledge_graph_inst
        ),
    )
    logger.info(
        f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} chunks"
    )

    relations_section_list = [
        [
            "id",
            "source",
            "target",
            "description",
            "keywords",
            "weight",
            "rank",
            "created_at",
            "file_path",
        ]
    ]
    for i, e in enumerate(edge_datas):
        created_at = e.get("created_at", "Unknown")
        # Convert timestamp to readable format
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

        # Get file path from edge data
        file_path = e.get("file_path", "unknown_source")

        relations_section_list.append(
            [
                i,
                e["src_id"],
                e["tgt_id"],
                e["description"],
                e["keywords"],
                e["weight"],
                e["rank"],
                created_at,
                file_path,
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)

    entites_section_list = [
        ["id", "entity", "type", "description", "rank", "created_at", "file_path"]
    ]
    for i, n in enumerate(use_entities):
        created_at = n.get("created_at", "Unknown")
        # Convert timestamp to readable format
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

        # Get file path from node data
        file_path = n.get("file_path", "unknown_source")

        entites_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
                created_at,
                file_path,
            ]
        )
    entities_context = list_of_list_to_csv(entites_section_list)

    text_units_section_list = [["id", "content", "file_path"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context


async def _find_most_related_entities_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
    entity_names = []
    seen = set()

    for e in edge_datas:
        if e["src_id"] not in seen:
            entity_names.append(e["src_id"])
            seen.add(e["src_id"])
        if e["tgt_id"] not in seen:
            entity_names.append(e["tgt_id"])
            seen.add(e["tgt_id"])

    node_datas, node_degrees = await asyncio.gather(
        asyncio.gather(
            *[
                knowledge_graph_inst.get_node(entity_name)
                for entity_name in entity_names
            ]
        ),
        asyncio.gather(
            *[
                knowledge_graph_inst.node_degree(entity_name)
                for entity_name in entity_names
            ]
        ),
    )
    node_datas = [
        {**n, "entity_name": k, "rank": d}
        for k, n, d in zip(entity_names, node_datas, node_degrees)
    ]

    len_node_datas = len(node_datas)
    node_datas = truncate_list_by_token_size(
        node_datas,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_local_context,
    )
    logger.debug(
        f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
    )

    return node_datas


async def _find_related_text_unit_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
):
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in edge_datas
    ]
    all_text_units_lookup = {}

    async def fetch_chunk_data(c_id, index):
        if c_id not in all_text_units_lookup:
            chunk_data = await text_chunks_db.get_by_id(c_id)
            # Only store valid data
            if chunk_data is not None and "content" in chunk_data:
                all_text_units_lookup[c_id] = {
                    "data": chunk_data,
                    "order": index,
                }

    tasks = []
    for index, unit_list in enumerate(text_units):
        for c_id in unit_list:
            tasks.append(fetch_chunk_data(c_id, index))

    await asyncio.gather(*tasks)

    if not all_text_units_lookup:
        logger.warning("No valid text chunks found")
        return []

    all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()]
    all_text_units = sorted(all_text_units, key=lambda x: x["order"])

    # Ensure all text chunks have content
    valid_text_units = [
        t for t in all_text_units if t["data"] is not None and "content" in t["data"]
    ]

    if not valid_text_units:
        logger.warning("No valid text chunks after filtering")
        return []

    truncated_text_units = truncate_list_by_token_size(
        valid_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    logger.debug(
        f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
    )

    all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]

    return all_text_units


def combine_contexts(entities, relationships, sources):
    # Split each pair into its high-level and low-level context strings
    hl_entities, ll_entities = entities[0], entities[1]
    hl_relationships, ll_relationships = relationships[0], relationships[1]
    hl_sources, ll_sources = sources[0], sources[1]

    # Combine and deduplicate the entities
    combined_entities = process_combine_contexts(hl_entities, ll_entities)

    # Combine and deduplicate the relationships
    combined_relationships = process_combine_contexts(
        hl_relationships, ll_relationships
    )

    # Combine and deduplicate the sources
    combined_sources = process_combine_contexts(hl_sources, ll_sources)

    return combined_entities, combined_relationships, combined_sources


async def naive_query(
|
||
query: str,
|
||
chunks_vdb: BaseVectorStorage,
|
||
text_chunks_db: BaseKVStorage,
|
||
query_param: QueryParam,
|
||
global_config: dict[str, str],
|
||
hashing_kv: BaseKVStorage | None = None,
|
||
system_prompt: str | None = None,
|
||
) -> str | AsyncIterator[str]:
|
||
# Handle cache
|
||
use_model_func = (
|
||
query_param.model_func
|
||
if query_param.model_func
|
||
else global_config["llm_model_func"]
|
||
)
|
||
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
||
cached_response, quantized, min_val, max_val = await handle_cache(
|
||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||
)
|
||
if cached_response is not None:
|
||
return cached_response
|
||
|
||
results = await chunks_vdb.query(
|
||
query, top_k=query_param.top_k, ids=query_param.ids
|
||
)
|
||
if not len(results):
|
||
return PROMPTS["fail_response"]
|
||
|
||
chunks_ids = [r["id"] for r in results]
|
||
chunks = await text_chunks_db.get_by_ids(chunks_ids)
|
||
|
||
# Filter out invalid chunks
|
||
valid_chunks = [
|
||
chunk for chunk in chunks if chunk is not None and "content" in chunk
|
||
]
|
||
|
||
if not valid_chunks:
|
||
logger.warning("No valid chunks found after filtering")
|
||
return PROMPTS["fail_response"]
|
||
|
||
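    # Token-budget truncation: keep a leading slice of the retrieved chunks
    # whose combined token count stays within max_token_for_text_unit (a
    # sketch of the assumed truncate_list_by_token_size behavior).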
    truncated_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    if not truncated_chunks:
        logger.warning("No chunks left after truncation")
        return PROMPTS["fail_response"]

    logger.debug(
        f"Truncated chunks from {len(chunks)} to {len(truncated_chunks)} (max tokens: {query_param.max_token_for_text_unit})"
    )

    section = "\n--New Chunk--\n".join(
        [
            "File path: " + c["file_path"] + "\n" + c["content"]
            for c in truncated_chunks
        ]
    )

    if query_param.only_need_context:
        return section

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section,
        response_type=query_param.response_type,
        history=history_context,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
    logger.debug(f"[naive_query] Prompt Tokens: {len_of_prompts}")

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )

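    # Some models echo the prompt back; strip the system prompt, chat-role
    # markers, and the query itself from the raw completion before caching.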
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response[len(sys_prompt) :]
            .replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
            cache_type="query",
        ),
    )

    return response
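# Minimal usage sketch for naive_query (illustrative handles; assumes an
# already-initialized storage stack and config):
#   param = QueryParam(mode="naive", top_k=10)
#   answer = await naive_query(
#       "What are the main findings?",
#       chunks_vdb, text_chunks_db, param, global_config, hashing_kv=hashing_kv,
#   )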


async def kg_query_with_keywords(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Refactored kg_query that does NOT extract keywords by itself.
    It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
    Then it uses those to build context and produce a final LLM response.
    """

    # ---------------------------
    # 1) Handle potential cache for query results
    # ---------------------------
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
    if cached_response is not None:
        return cached_response

    # ---------------------------
    # 2) RETRIEVE KEYWORDS FROM query_param
    # ---------------------------

    # If these fields don't exist, default to empty lists/strings.
    hl_keywords = getattr(query_param, "hl_keywords", []) or []
    ll_keywords = getattr(query_param, "ll_keywords", []) or []

    # If neither level has any keywords, fail early.
    if not hl_keywords and not ll_keywords:
        logger.warning(
            "No keywords found in query_param. Returning fail response."
        )
        return PROMPTS["fail_response"]
    if not ll_keywords and query_param.mode in ["local", "hybrid"]:
        logger.warning("low_level_keywords is empty, switching to global mode.")
        query_param.mode = "global"
    if not hl_keywords and query_param.mode in ["global", "hybrid"]:
        logger.warning("high_level_keywords is empty, switching to local mode.")
        query_param.mode = "local"
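    # Net effect of the fallbacks above: hybrid with only high-level keywords
    # degrades to global, hybrid with only low-level keywords degrades to local.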

    # Flatten low-level and high-level keywords if needed
    ll_keywords_flat = (
        [item for sublist in ll_keywords for item in sublist]
        if any(isinstance(i, list) for i in ll_keywords)
        else ll_keywords
    )
    hl_keywords_flat = (
        [item for sublist in hl_keywords for item in sublist]
        if any(isinstance(i, list) for i in hl_keywords)
        else hl_keywords
    )
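    # e.g. [["alpha", "beta"], ["gamma"]] flattens to ["alpha", "beta", "gamma"];
    # an already-flat list passes through unchanged.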

    # Join the flattened lists
    ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
    hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""

    # ---------------------------
    # 3) BUILD CONTEXT
    # ---------------------------
    context = await _build_query_context(
        ll_keywords_str,
        hl_keywords_str,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
    )
    if not context:
        return PROMPTS["fail_response"]

    # If only context is needed, return it
    if query_param.only_need_context:
        return context

    # ---------------------------
    # 4) BUILD THE SYSTEM PROMPT + CALL LLM
    # ---------------------------

    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
        history_context = get_conversation_turns(
            query_param.conversation_history, query_param.history_turns
        )

    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context,
        response_type=query_param.response_type,
        history=history_context,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
    logger.debug(f"[kg_query_with_keywords] Prompt Tokens: {len_of_prompts}")

    # Generate response
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    # Clean up response content
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # Save cache - only cache once the full (non-streaming) response is collected
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
            cache_type="query",
        ),
    )

    return response
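# Illustrative call sequence (hypothetical keyword values): callers populate
# the keyword fields on QueryParam first, then query:
#   param = QueryParam(mode="hybrid")
#   param.hl_keywords = ["energy policy"]
#   param.ll_keywords = ["solar subsidy", "feed-in tariff"]
#   answer = await kg_query_with_keywords(
#       "How do subsidies affect solar adoption?",
#       knowledge_graph_inst, entities_vdb, relationships_vdb,
#       text_chunks_db, param, global_config, hashing_kv=hashing_kv,
#   )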


async def query_with_keywords(
    query: str,
    prompt: str,
    param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Extract keywords from the query and then use them for retrieving information.

    1. Extracts high-level and low-level keywords from the query
    2. Formats the query with the extracted keywords and prompt
    3. Uses the appropriate query method based on param.mode

    Args:
        query: The user's query
        prompt: Additional prompt to prepend to the query
        param: Query parameters
        knowledge_graph_inst: Knowledge graph storage
        entities_vdb: Entities vector database
        relationships_vdb: Relationships vector database
        chunks_vdb: Document chunks vector database
        text_chunks_db: Text chunks storage
        global_config: Global configuration
        hashing_kv: Cache storage

    Returns:
        Query response or async iterator
    """
    # Extract keywords
    hl_keywords, ll_keywords = await extract_keywords_only(
        text=query,
        param=param,
        global_config=global_config,
        hashing_kv=hashing_kv,
    )

    param.hl_keywords = hl_keywords
    param.ll_keywords = ll_keywords

    # Create a new string with the prompt and the keywords
    ll_keywords_str = ", ".join(ll_keywords)
    hl_keywords_str = ", ".join(hl_keywords)
    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
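    # Resulting prompt shape (illustrative):
    #   <prompt>
    #
    #   ### Keywords:
    #   High-level: theme_a, theme_b
    #   Low-level: entity_a, entity_b
    #
    #   ### Query:
    #   <original query>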

    # Use appropriate query method based on mode
    if param.mode in ["local", "global", "hybrid"]:
        return await kg_query_with_keywords(
            formatted_question,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=hashing_kv,
        )
    elif param.mode == "naive":
        return await naive_query(
            formatted_question,
            chunks_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=hashing_kv,
        )
    elif param.mode == "mix":
        return await mix_kg_vector_query(
            formatted_question,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            chunks_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=hashing_kv,
        )
    else:
        raise ValueError(f"Unknown mode {param.mode}")
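
# End-to-end sketch (illustrative handles; mirrors the dispatch above):
#   answer = await query_with_keywords(
#       query="Summarize the key risks",
#       prompt="Answer as a compliance analyst.",
#       param=QueryParam(mode="mix"),
#       knowledge_graph_inst=knowledge_graph_inst,
#       entities_vdb=entities_vdb,
#       relationships_vdb=relationships_vdb,
#       chunks_vdb=chunks_vdb,
#       text_chunks_db=text_chunks_db,
#       global_config=global_config,
#       hashing_kv=hashing_kv,
#   )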