import asyncio
import json
import re
import time
from tqdm.asyncio import tqdm as tqdm_async
from typing import Union
from collections import Counter, defaultdict

from .utils import (
    logger,
    clean_str,
    compute_mdhash_id,
    decode_tokens_by_tiktoken,
    encode_string_by_tiktoken,
    is_float_regex,
    list_of_list_to_csv,
    pack_user_ass_to_openai_messages,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
    process_combine_contexts,
    compute_args_hash,
    handle_cache,
    save_to_cache,
    CacheData,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    TextChunkSchema,
    QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS


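# Split raw text into chunks by tiktoken token count, optionally pre-splitting on a
# character first; windows of `max_token_size` tokens are produced with
# `overlap_token_size` tokens of overlap.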
def chunking_by_token_size(
    content: str,
    split_by_character=None,
    split_by_character_only=False,
    overlap_token_size=128,
    max_token_size=1024,
    tiktoken_model="gpt-4o",
    **kwargs,
):
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
    if split_by_character:
        raw_chunks = content.split(split_by_character)
        new_chunks = []
        if split_by_character_only:
            for chunk in raw_chunks:
                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
                new_chunks.append((len(_tokens), chunk))
        else:
            for chunk in raw_chunks:
                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
                if len(_tokens) > max_token_size:
                    for start in range(
                        0, len(_tokens), max_token_size - overlap_token_size
                    ):
                        chunk_content = decode_tokens_by_tiktoken(
                            _tokens[start : start + max_token_size],
                            model_name=tiktoken_model,
                        )
                        new_chunks.append(
                            (min(max_token_size, len(_tokens) - start), chunk_content)
                        )
                else:
                    new_chunks.append((len(_tokens), chunk))
        for index, (_len, chunk) in enumerate(new_chunks):
            results.append(
                {
                    "tokens": _len,
                    "content": chunk.strip(),
                    "chunk_order_index": index,
                }
            )
    else:
        for index, start in enumerate(
            range(0, len(tokens), max_token_size - overlap_token_size)
        ):
            chunk_content = decode_tokens_by_tiktoken(
                tokens[start : start + max_token_size], model_name=tiktoken_model
            )
            results.append(
                {
                    "tokens": min(max_token_size, len(tokens) - start),
                    "content": chunk_content.strip(),
                    "chunk_order_index": index,
                }
            )
    return results


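# Ask the LLM to compress an entity/relation description that has grown past
# `entity_summary_to_max_tokens`; shorter descriptions are returned unchanged.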
async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
) -> str:
    use_llm_func: callable = global_config["llm_model_func"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
    tiktoken_model_name = global_config["tiktoken_model_name"]
    summary_max_tokens = global_config["entity_summary_to_max_tokens"]
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
    if len(tokens) < summary_max_tokens:  # No need for summary
        return description
    prompt_template = PROMPTS["summarize_entity_descriptions"]
    use_description = decode_tokens_by_tiktoken(
        tokens[:llm_max_tokens], model_name=tiktoken_model_name
    )
    context_base = dict(
        entity_name=entity_or_relation_name,
        description_list=use_description.split(GRAPH_FIELD_SEP),
        language=language,
    )
    use_prompt = prompt_template.format(**context_base)
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
    return summary


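# Parse one delimited extraction record into an entity dict; returns None if the
# record is not a well-formed '"entity"' record.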
async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the graph
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=entity_source_id,
    )


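# Parse one delimited extraction record into a relationship (edge) dict; returns None
# if the record is not a well-formed '"relationship"' record.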
async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as an edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])

    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
        metadata={"created_at": time.time()},
    )


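# Merge newly extracted node records with any node already stored in the graph,
# summarize the combined description if needed, and upsert the result.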
async def _merge_nodes_then_upsert(
    entity_name: str,
    nodes_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    already_entity_types = []
    already_source_ids = []
    already_description = []

    already_node = await knowledge_graph_inst.get_node(entity_name)
    if already_node is not None:
        already_entity_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_node["description"])

    entity_type = sorted(
        Counter(
            [dp["entity_type"] for dp in nodes_data] + already_entity_types
        ).items(),
        key=lambda x: x[1],
        reverse=True,
    )[0][0]
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in nodes_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in nodes_data] + already_source_ids)
    )
    description = await _handle_entity_relation_summary(
        entity_name, description, global_config
    )
    node_data = dict(
        entity_type=entity_type,
        description=description,
        source_id=source_id,
    )
    await knowledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    node_data["entity_name"] = entity_name
    return node_data


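# Merge newly extracted edge records with any existing edge between src_id and tgt_id,
# creating placeholder nodes for missing endpoints before upserting the edge.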
async def _merge_edges_then_upsert(
    src_id: str,
    tgt_id: str,
    edges_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    already_weights = []
    already_source_ids = []
    already_description = []
    already_keywords = []

    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        already_weights.append(already_edge["weight"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_edge["description"])
        already_keywords.extend(
            split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
        )

    weight = sum([dp["weight"] for dp in edges_data] + already_weights)
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in edges_data] + already_description))
    )
    keywords = GRAPH_FIELD_SEP.join(
        sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in edges_data] + already_source_ids)
    )
    for need_insert_id in [src_id, tgt_id]:
        if not (await knowledge_graph_inst.has_node(need_insert_id)):
            await knowledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "source_id": source_id,
                    "description": description,
                    "entity_type": '"UNKNOWN"',
                },
            )
    description = await _handle_entity_relation_summary(
        f"({src_id}, {tgt_id})", description, global_config
    )
    await knowledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight,
            description=description,
            keywords=keywords,
            source_id=source_id,
        ),
    )

    edge_data = dict(
        src_id=src_id,
        tgt_id=tgt_id,
        description=description,
        keywords=keywords,
    )

    return edge_data


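# Entity/relationship extraction pipeline: prompt the LLM over every text chunk
# (with optional gleaning passes), merge the extracted nodes and edges into the
# knowledge graph, and index them in the entity/relationship vector stores.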
async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    global_config: dict,
    llm_response_cache: BaseKVStorage = None,
) -> Union[BaseGraphStorage, None]:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    enable_llm_cache_for_entity_extract: bool = global_config[
        "enable_llm_cache_for_entity_extract"
    ]

    ordered_chunks = list(chunks.items())
    # add language and example number params to prompt
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )
    entity_types = global_config["addon_params"].get(
        "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
    )
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["entity_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["entity_extraction_examples"])

    example_context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(entity_types),
        language=language,
    )
    # add example's format
    examples = examples.format(**example_context_base)

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(entity_types),
        examples=examples,
        language=language,
    )

    continue_prompt = PROMPTS["entiti_continue_extraction"]
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]

    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _user_llm_func_with_cache(
        input_text: str, history_messages: list[dict[str, str]] = None
    ) -> str:
        if enable_llm_cache_for_entity_extract and llm_response_cache:
            need_to_restore = False
            if (
                global_config["embedding_cache_config"]
                and global_config["embedding_cache_config"]["enabled"]
            ):
                new_config = global_config.copy()
                new_config["embedding_cache_config"] = None
                new_config["enable_llm_cache"] = True
                llm_response_cache.global_config = new_config
                need_to_restore = True
            if history_messages:
                history = json.dumps(history_messages)
                _prompt = history + "\n" + input_text
            else:
                _prompt = input_text

            arg_hash = compute_args_hash(_prompt)
            cached_return, _1, _2, _3 = await handle_cache(
                llm_response_cache, arg_hash, _prompt, "default"
            )
            if need_to_restore:
                llm_response_cache.global_config = global_config
            if cached_return:
                return cached_return

            if history_messages:
                res: str = await use_llm_func(
                    input_text, history_messages=history_messages
                )
            else:
                res: str = await use_llm_func(input_text)
            await save_to_cache(
                llm_response_cache,
                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
            )
            return res

        if history_messages:
            return await use_llm_func(input_text, history_messages=history_messages)
        else:
            return await use_llm_func(input_text)

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        # hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
        hint_prompt = entity_extract_prompt.format(
            **context_base, input_text="{input_text}"
        ).format(**context_base, input_text=content)

        final_result = await _user_llm_func_with_cache(hint_prompt)
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await _user_llm_func_with_cache(
                continue_prompt, history_messages=history
            )

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await _user_llm_func_with_cache(
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        records = split_string_by_multi_markers(
            final_result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
        total=len(ordered_chunks),
        desc="Extracting entities from chunks",
        unit="chunk",
    ):
        results.append(await result)

    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[tuple(sorted(k))].extend(v)
    logger.info("Inserting entities into storage...")
    all_entities_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [
                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
                for k, v in maybe_nodes.items()
            ]
        ),
        total=len(maybe_nodes),
        desc="Inserting entities",
        unit="entity",
    ):
        all_entities_data.append(await result)

    logger.info("Inserting relationships into storage...")
    all_relationships_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [
                _merge_edges_then_upsert(
                    k[0], k[1], v, knowledge_graph_inst, global_config
                )
                for k, v in maybe_edges.items()
            ]
        ),
        total=len(maybe_edges),
        desc="Inserting relationships",
        unit="relationship",
    ):
        all_relationships_data.append(await result)

    if not len(all_entities_data) and not len(all_relationships_data):
        logger.warning(
            "Didn't extract any entities and relationships, maybe your LLM is not working"
        )
        return None

    if not len(all_entities_data):
        logger.warning("Didn't extract any entities")
    if not len(all_relationships_data):
        logger.warning("Didn't extract any relationships")

    if entity_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    if relationships_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                "src_id": dp["src_id"],
                "tgt_id": dp["tgt_id"],
                "content": dp["keywords"]
                + dp["src_id"]
                + dp["tgt_id"]
                + dp["description"],
                "metadata": {
                    "created_at": dp.get("metadata", {}).get("created_at", time.time())
                },
            }
            for dp in all_relationships_data
        }
        await relationships_vdb.upsert(data_for_vdb)

    return knowledge_graph_inst


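# Knowledge-graph query: extract high/low-level keywords from the query with the LLM,
# build a local/global/hybrid context from the graph, and generate the final answer.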
async def kg_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    # Set mode
    if query_param.mode not in ["local", "global", "hybrid"]:
        logger.error(f"Unknown mode {query_param.mode} in kg_query")
        return PROMPTS["fail_response"]

    # LLM generate keywords
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
    result = await use_model_func(kw_prompt, keyword_extraction=True)
    logger.info("kw_prompt result:")
    print(result)
    try:
        # json_text = locate_json_string_body_from_string(result)  # handled in use_model_func
        match = re.search(r"\{.*\}", result, re.DOTALL)
        if match:
            result = match.group(0)
            keywords_data = json.loads(result)

            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])
        else:
            logger.error("No JSON-like structure found in the result.")
            return PROMPTS["fail_response"]

    # Handle parsing error
    except json.JSONDecodeError as e:
        print(f"JSON parsing error: {e} {result}")
        return PROMPTS["fail_response"]

    # Handle missing keywords
    if hl_keywords == [] and ll_keywords == []:
        logger.warning("low_level_keywords and high_level_keywords are empty")
        return PROMPTS["fail_response"]
    if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
        logger.warning(
            "low_level_keywords is empty, switching from %s mode to global mode",
            query_param.mode,
        )
        query_param.mode = "global"
    if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
        logger.warning(
            "high_level_keywords is empty, switching from %s mode to local mode",
            query_param.mode,
        )
        query_param.mode = "local"

    ll_keywords = ", ".join(ll_keywords) if ll_keywords else ""
    hl_keywords = ", ".join(hl_keywords) if hl_keywords else ""

    logger.info("Using %s mode for query processing", query_param.mode)

    # Build context
    keywords = [ll_keywords, hl_keywords]
    context = await _build_query_context(
        keywords,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
    )

    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    if query_param.only_need_prompt:
        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
        ),
    )
    return response


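# Variant of kg_query that skips keyword extraction; see extract_keywords_only below
# for producing the hl_keywords/ll_keywords it consumes.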
async def kg_query_with_keywords(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    """
    Refactored kg_query that does NOT extract keywords by itself.
    It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
    Then it uses those to build context and produce a final LLM response.
    """

    # ---------------------------
    # 0) Handle potential cache
    # ---------------------------
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    # ---------------------------
    # 1) RETRIEVE KEYWORDS FROM query_param
    # ---------------------------

    # If these fields don't exist, default to empty lists/strings.
    hl_keywords = getattr(query_param, "hl_keywords", []) or []
    ll_keywords = getattr(query_param, "ll_keywords", []) or []

    # If neither has any keywords, you could handle that logic here.
    if not hl_keywords and not ll_keywords:
        logger.warning(
            "No keywords found in query_param. Could default to global mode or fail."
        )
        return PROMPTS["fail_response"]
    if not ll_keywords and query_param.mode in ["local", "hybrid"]:
        logger.warning("low_level_keywords is empty, switching to global mode.")
        query_param.mode = "global"
    if not hl_keywords and query_param.mode in ["global", "hybrid"]:
        logger.warning("high_level_keywords is empty, switching to local mode.")
        query_param.mode = "local"

    # Flatten low-level and high-level keywords if needed
    ll_keywords_flat = (
        [item for sublist in ll_keywords for item in sublist]
        if any(isinstance(i, list) for i in ll_keywords)
        else ll_keywords
    )
    hl_keywords_flat = (
        [item for sublist in hl_keywords for item in sublist]
        if any(isinstance(i, list) for i in hl_keywords)
        else hl_keywords
    )

    # Join the flattened lists
    ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
    hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""

    keywords = [ll_keywords_str, hl_keywords_str]

    logger.info("Using %s mode for query processing", query_param.mode)

    # ---------------------------
    # 2) BUILD CONTEXT
    # ---------------------------
    context = await _build_query_context(
        keywords,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
    )
    if not context:
        return PROMPTS["fail_response"]

    # If only context is needed, return it
    if query_param.only_need_context:
        return context

    # ---------------------------
    # 3) BUILD THE SYSTEM PROMPT + CALL LLM
    # ---------------------------
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )

    if query_param.only_need_prompt:
        return sys_prompt

    # Now call the LLM with the final system prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    # Clean up the response
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # ---------------------------
    # 4) SAVE TO CACHE
    # ---------------------------
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
        ),
    )
    return response


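# A minimal usage sketch for the two-step flow (inside an async function, and assuming
# QueryParam accepts hl_keywords/ll_keywords attributes and the storages/global_config
# are already initialized elsewhere):
#
#   param = QueryParam(mode="hybrid")
#   hl, ll = await extract_keywords_only(text, param, global_config, hashing_kv)
#   param.hl_keywords, param.ll_keywords = hl, ll
#   answer = await kg_query_with_keywords(
#       text, knowledge_graph_inst, entities_vdb, relationships_vdb,
#       text_chunks_db, param, global_config, hashing_kv,
#   )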
async def extract_keywords_only(
    text: str,
    param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> tuple[list[str], list[str]]:
    """
    Extract high-level and low-level keywords from the given 'text' using the LLM.
    This method does NOT build the final RAG context or provide a final answer.
    It ONLY extracts keywords (hl_keywords, ll_keywords).
    """

    # 1. Handle cache if needed
    args_hash = compute_args_hash(param.mode, text)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, text, param.mode
    )
    if cached_response is not None:
        # parse the cached_response if it's JSON containing keywords
        # or simply return (hl_keywords, ll_keywords) from cached
        # Assuming cached_response is in the same JSON structure:
        match = re.search(r"\{.*\}", cached_response, re.DOTALL)
        if match:
            keywords_data = json.loads(match.group(0))
            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])
            return hl_keywords, ll_keywords
        return [], []

    # 2. Build the examples
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    # 3. Build the keyword-extraction prompt
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)

    # 4. Call the LLM for keyword extraction
    use_model_func = global_config["llm_model_func"]
    result = await use_model_func(kw_prompt, keyword_extraction=True)

    # 5. Parse out JSON from the LLM response
    match = re.search(r"\{.*\}", result, re.DOTALL)
    if not match:
        logger.error("No JSON-like structure found in the result.")
        return [], []
    try:
        keywords_data = json.loads(match.group(0))
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        return [], []

    hl_keywords = keywords_data.get("high_level_keywords", [])
    ll_keywords = keywords_data.get("low_level_keywords", [])

    # 6. Cache the result if needed
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=result,
            prompt=text,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=param.mode,
        ),
    )
    return hl_keywords, ll_keywords


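# Build the CSV-formatted context (entities, relationships, sources) for a query:
# local mode uses entity retrieval, global mode uses relationship retrieval, and
# hybrid mode combines both via combine_contexts.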
async def _build_query_context(
    query: list,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
    # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""

    ll_keywords, hl_keywords = query[0], query[1]

    if query_param.mode == "local":
        entities_context, relations_context, text_units_context = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    elif query_param.mode == "global":
        entities_context, relations_context, text_units_context = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
    else:  # hybrid mode
        (
            ll_entities_context,
            ll_relations_context,
            ll_text_units_context,
        ) = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
        (
            hl_entities_context,
            hl_relations_context,
            hl_text_units_context,
        ) = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
        entities_context, relations_context, text_units_context = combine_contexts(
            [hl_entities_context, ll_entities_context],
            [hl_relations_context, ll_relations_context],
            [hl_text_units_context, ll_text_units_context],
        )
    return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""


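# Local retrieval: find entities similar to the low-level keywords, then gather their
# most related text chunks and edges, and render everything as CSV sections.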
async def _get_node_data(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    # get similar entities
    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return "", "", ""
    # get entity information
    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")

    # get entity degree
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    # get entity text chunks
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    # get related edges
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )
    logger.info(
        f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )

    # build prompt
    entites_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entites_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entites_section_list)

    relations_section_list = [
        [
            "id",
            "source",
            "target",
            "description",
            "keywords",
            "weight",
            "rank",
            "created_at",
        ]
    ]
    for i, e in enumerate(use_relations):
        created_at = e.get("created_at", "UNKNOWN")
        # Convert timestamp to readable format
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
        relations_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["keywords"],
                e["weight"],
                e["rank"],
                created_at,
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context


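# Rank candidate text chunks for the retrieved entities by extraction order and by how
# many one-hop neighbours also reference them, then truncate to the token budget.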
async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
):
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in node_datas
    ]
    edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    all_one_hop_nodes = set()
    for this_edges in edges:
        if not this_edges:
            continue
        all_one_hop_nodes.update([e[1] for e in this_edges])

    all_one_hop_nodes = list(all_one_hop_nodes)
    all_one_hop_nodes_data = await asyncio.gather(
        *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
    )

    # Add null check for node data
    all_one_hop_text_units_lookup = {
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None and "source_id" in v  # Add source_id check
    }

    all_text_units_lookup = {}
    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        for c_id in this_text_units:
            if c_id not in all_text_units_lookup:
                all_text_units_lookup[c_id] = {
                    "data": await text_chunks_db.get_by_id(c_id),
                    "order": index,
                    "relation_counts": 0,
                }

            if this_edges:
                for e in this_edges:
                    if (
                        e[1] in all_one_hop_text_units_lookup
                        and c_id in all_one_hop_text_units_lookup[e[1]]
                    ):
                        all_text_units_lookup[c_id]["relation_counts"] += 1

    # Filter out None values and ensure data has content
    all_text_units = [
        {"id": k, **v}
        for k, v in all_text_units_lookup.items()
        if v is not None and v.get("data") is not None and "content" in v["data"]
    ]

    if not all_text_units:
        logger.warning("No valid text units found")
        return []

    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )

    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    all_text_units = [t["data"] for t in all_text_units]
    return all_text_units


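# Collect, deduplicate, and rank (by degree and weight) the edges attached to the
# retrieved entities, truncated to the global-context token budget.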
async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
    all_related_edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    all_edges = []
    seen = set()

    for this_edges in all_related_edges:
        for e in this_edges:
            sorted_edge = tuple(sorted(e))
            if sorted_edge not in seen:
                seen.add(sorted_edge)
                all_edges.append(sorted_edge)

    all_edges_pack = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
    )
    all_edges_degree = await asyncio.gather(
        *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
    )
    all_edges_data = [
        {"src_tgt": k, "rank": d, **v}
        for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
        if v is not None
    ]
    all_edges_data = sorted(
        all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    all_edges_data = truncate_list_by_token_size(
        all_edges_data,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )
    return all_edges_data


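# Global retrieval: find relationships similar to the high-level keywords, then gather
# their endpoint entities and source text chunks, and render everything as CSV sections.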
async def _get_edge_data(
    keywords,
    knowledge_graph_inst: BaseGraphStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    results = await relationships_vdb.query(keywords, top_k=query_param.top_k)

    if not len(results):
        return "", "", ""

    edge_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
    )

    if not all([n is not None for n in edge_datas]):
        logger.warning("Some edges are missing, maybe the storage is damaged")
    edge_degree = await asyncio.gather(
        *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
    )
    edge_datas = [
        {
            "src_id": k["src_id"],
            "tgt_id": k["tgt_id"],
            "rank": d,
            "created_at": k.get("__created_at__", None),  # get time metadata from the KV storage
            **v,
        }
        for k, v, d in zip(results, edge_datas, edge_degree)
        if v is not None
    ]
    edge_datas = sorted(
        edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    edge_datas = truncate_list_by_token_size(
        edge_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )

    use_entities = await _find_most_related_entities_from_relationships(
        edge_datas, query_param, knowledge_graph_inst
    )
    use_text_units = await _find_related_text_unit_from_relationships(
        edge_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    logger.info(
        f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} text units"
    )

    relations_section_list = [
        [
            "id",
            "source",
            "target",
            "description",
            "keywords",
            "weight",
            "rank",
            "created_at",
        ]
    ]
    for i, e in enumerate(edge_datas):
        created_at = e.get("created_at", "Unknown")
        # Convert timestamp to readable format
        if isinstance(created_at, (int, float)):
            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
        relations_section_list.append(
            [
                i,
                e["src_id"],
                e["tgt_id"],
                e["description"],
                e["keywords"],
                e["weight"],
                e["rank"],
                created_at,
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)

    entites_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(use_entities):
        entites_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entites_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context


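# Collect the distinct endpoint entities of the retrieved relationships, ranked by
# node degree and truncated to the local-context token budget.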
async def _find_most_related_entities_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
    entity_names = []
    seen = set()

    for e in edge_datas:
        if e["src_id"] not in seen:
            entity_names.append(e["src_id"])
            seen.add(e["src_id"])
        if e["tgt_id"] not in seen:
            entity_names.append(e["tgt_id"])
            seen.add(e["tgt_id"])

    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
    )

    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
    )
    node_datas = [
        {**n, "entity_name": k, "rank": d}
        for k, n, d in zip(entity_names, node_datas, node_degrees)
    ]

    node_datas = truncate_list_by_token_size(
        node_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_local_context,
    )

    return node_datas


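# Gather the source text chunks referenced by the retrieved relationships, keeping
# first-occurrence order and truncating to the text-unit token budget.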
async def _find_related_text_unit_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
):
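    """Gather the text chunks referenced by the given relationships.

    Chunk ids are taken from each edge's GRAPH_FIELD_SEP-separated source_id
    field, fetched from text_chunks_db, de-duplicated, ordered by the first
    edge that referenced them, and truncated to
    query_param.max_token_for_text_unit tokens.
    """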
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in edge_datas
    ]
    all_text_units_lookup = {}

    for index, unit_list in enumerate(text_units):
        for c_id in unit_list:
            if c_id not in all_text_units_lookup:
                chunk_data = await text_chunks_db.get_by_id(c_id)
                # Only store valid data
                if chunk_data is not None and "content" in chunk_data:
                    all_text_units_lookup[c_id] = {
                        "data": chunk_data,
                        "order": index,
                    }

    if not all_text_units_lookup:
        logger.warning("No valid text chunks found")
        return []

    all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()]
    all_text_units = sorted(all_text_units, key=lambda x: x["order"])

    # Ensure all text chunks have content
    valid_text_units = [
        t for t in all_text_units if t["data"] is not None and "content" in t["data"]
    ]

    if not valid_text_units:
        logger.warning("No valid text chunks after filtering")
        return []

    truncated_text_units = truncate_list_by_token_size(
        valid_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]

    return all_text_units


def combine_contexts(entities, relationships, sources):
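    """Merge the high-level and low-level query contexts.

    Each argument is a (high_level, low_level) pair of CSV context strings;
    the matching pairs are combined and de-duplicated via
    process_combine_contexts.
    """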
    # Unpack the high-level and low-level parts of each context
    hl_entities, ll_entities = entities[0], entities[1]
    hl_relationships, ll_relationships = relationships[0], relationships[1]
    hl_sources, ll_sources = sources[0], sources[1]

    # Combine and deduplicate the entities
    combined_entities = process_combine_contexts(hl_entities, ll_entities)

    # Combine and deduplicate the relationships
    combined_relationships = process_combine_contexts(
        hl_relationships, ll_relationships
    )

    # Combine and deduplicate the sources
    combined_sources = process_combine_contexts(hl_sources, ll_sources)

    return combined_entities, combined_relationships, combined_sources


async def naive_query(
    query,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
):
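    """Plain vector-retrieval RAG query that does not use the knowledge graph.

    Retrieves the top-k chunks from chunks_vdb, truncates them to the token
    budget, and asks the LLM to answer from the concatenated chunks. Responses
    are cached in hashing_kv under a hash of (query_param.mode, query).
    """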
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    results = await chunks_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return PROMPTS["fail_response"]

    chunks_ids = [r["id"] for r in results]
    chunks = await text_chunks_db.get_by_ids(chunks_ids)

    # Filter out invalid chunks
    valid_chunks = [
        chunk for chunk in chunks if chunk is not None and "content" in chunk
    ]

    if not valid_chunks:
        logger.warning("No valid chunks found after filtering")
        return PROMPTS["fail_response"]

    maybe_trun_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    if not maybe_trun_chunks:
        logger.warning("No chunks left after truncation")
        return PROMPTS["fail_response"]

    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])

    if query_param.only_need_context:
        return section

    sys_prompt_temp = PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section, response_type=query_param.response_type
    )

    if query_param.only_need_prompt:
        return sys_prompt

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )

    # Strip any echoed system prompt / query text from the model output
    if len(response) > len(sys_prompt):
        response = (
            response[len(sys_prompt) :]
            .replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
        ),
    )

    return response


async def mix_kg_vector_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    """
    Hybrid retrieval implementation combining knowledge graph and vector search.

    This function performs a hybrid search by:
    1. Extracting semantic information from the knowledge graph
    2. Retrieving relevant text chunks through vector similarity
    3. Combining both results for comprehensive answer generation
    """
    # 1. Cache handling
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash("mix", query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, "mix"
    )
    if cached_response is not None:
        return cached_response

    # 2. Execute knowledge graph and vector searches in parallel
    async def get_kg_context():
        try:
            # Reuse keyword extraction logic from kg_query
            example_number = global_config["addon_params"].get("example_number", None)
            if example_number and example_number < len(
                PROMPTS["keywords_extraction_examples"]
            ):
                examples = "\n".join(
                    PROMPTS["keywords_extraction_examples"][: int(example_number)]
                )
            else:
                examples = "\n".join(PROMPTS["keywords_extraction_examples"])

            language = global_config["addon_params"].get(
                "language", PROMPTS["DEFAULT_LANGUAGE"]
            )

            # Extract keywords using LLM
            kw_prompt = PROMPTS["keywords_extraction"].format(
                query=query, examples=examples, language=language
            )
            result = await use_model_func(kw_prompt, keyword_extraction=True)

            match = re.search(r"\{.*\}", result, re.DOTALL)
            if not match:
                logger.warning(
                    "No JSON-like structure found in keywords extraction result"
                )
                return None

            result = match.group(0)
            keywords_data = json.loads(result)
            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])

            if not hl_keywords and not ll_keywords:
                logger.warning("Both high-level and low-level keywords are empty")
                return None

            # Convert keyword lists to strings
            ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
            hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

            # Set query mode based on available keywords
            if not ll_keywords_str and not hl_keywords_str:
                return None
            elif not ll_keywords_str:
                query_param.mode = "global"
            elif not hl_keywords_str:
                query_param.mode = "local"
            else:
                query_param.mode = "hybrid"

            # Build knowledge graph context
            context = await _build_query_context(
                [ll_keywords_str, hl_keywords_str],
                knowledge_graph_inst,
                entities_vdb,
                relationships_vdb,
                text_chunks_db,
                query_param,
            )

            return context

        except Exception as e:
            logger.error(f"Error in get_kg_context: {str(e)}")
            return None

    async def get_vector_context():
        # Reuse vector search logic from naive_query
        try:
            # Reduce top_k for vector search in hybrid mode since we have structured information from KG
            mix_topk = min(10, query_param.top_k)
            results = await chunks_vdb.query(query, top_k=mix_topk)
            if not results:
                return None

            chunks_ids = [r["id"] for r in results]
            chunks = await text_chunks_db.get_by_ids(chunks_ids)

            valid_chunks = []
            for chunk, result in zip(chunks, results):
                if chunk is not None and "content" in chunk:
                    # Merge chunk content and time metadata
                    chunk_with_time = {
                        "content": chunk["content"],
                        "created_at": result.get("created_at", None),
                    }
                    valid_chunks.append(chunk_with_time)

            if not valid_chunks:
                return None

            maybe_trun_chunks = truncate_list_by_token_size(
                valid_chunks,
                key=lambda x: x["content"],
                max_token_size=query_param.max_token_for_text_unit,
            )

            if not maybe_trun_chunks:
                return None

            # Include time information in content
            formatted_chunks = []
            for c in maybe_trun_chunks:
                chunk_text = c["content"]
                if c["created_at"]:
                    chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
                formatted_chunks.append(chunk_text)

            return "\n--New Chunk--\n".join(formatted_chunks)
        except Exception as e:
            logger.error(f"Error in get_vector_context: {e}")
            return None

    # 3. Execute both retrievals in parallel
    kg_context, vector_context = await asyncio.gather(
        get_kg_context(), get_vector_context()
    )

    # 4. Merge contexts
    if kg_context is None and vector_context is None:
        return PROMPTS["fail_response"]

    if query_param.only_need_context:
        return {"kg_context": kg_context, "vector_context": vector_context}

    # 5. Construct hybrid prompt
    sys_prompt = PROMPTS["mix_rag_response"].format(
        kg_context=kg_context
        if kg_context
        else "No relevant knowledge graph information found",
        vector_context=vector_context
        if vector_context
        else "No relevant text information found",
        response_type=query_param.response_type,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    # 6. Generate response
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )

    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # 7. Save cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode="mix",
        ),
    )

    return response