# LightRAG/lightrag/operate.py
import asyncio
import json
import re
from typing import Union
from collections import Counter, defaultdict
import warnings
from .utils import (
logger,
clean_str,
compute_mdhash_id,
decode_tokens_by_tiktoken,
encode_string_by_tiktoken,
is_float_regex,
list_of_list_to_csv,
pack_user_ass_to_openai_messages,
split_string_by_multi_markers,
truncate_list_by_token_size,
process_combine_contexts,
    locate_json_string_body_from_string,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
TextChunkSchema,
QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
def chunking_by_token_size(
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
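    """Split `content` into overlapping chunks of at most `max_token_size`
    tokens, stepping by `max_token_size - overlap_token_size` tokens, and
    return a list of dicts with `tokens`, `content`, and `chunk_order_index`."""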
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
results = []
for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size)
):
chunk_content = decode_tokens_by_tiktoken(
tokens[start : start + max_token_size], model_name=tiktoken_model
)
results.append(
{
"tokens": min(max_token_size, len(tokens) - start),
"content": chunk_content.strip(),
"chunk_order_index": index,
}
)
return results
async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
global_config: dict,
) -> str:
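    """Return `description` unchanged if it fits within the summary token
    budget; otherwise ask the LLM to summarize the (token-truncated) list of
    descriptions accumulated for this entity or relation."""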
use_llm_func: callable = global_config["llm_model_func"]
llm_max_tokens = global_config["llm_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
if len(tokens) < summary_max_tokens: # No need for summary
return description
prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = decode_tokens_by_tiktoken(
tokens[:llm_max_tokens], model_name=tiktoken_model_name
)
context_base = dict(
entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP),
)
use_prompt = prompt_template.format(**context_base)
logger.debug(f"Trigger summary: {entity_or_relation_name}")
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
return summary
async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
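    """Parse one delimited extraction record; return an entity dict
    (name, type, description, source chunk id) or None if the record is
    not a well-formed entity."""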
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
if not entity_name.strip():
return None
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
entity_source_id = chunk_key
return dict(
entity_name=entity_name,
entity_type=entity_type,
description=entity_description,
source_id=entity_source_id,
)
async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
):
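    """Parse one delimited extraction record; return a relationship dict
    (src/tgt ids, weight, description, keywords, source chunk id) or None
    if the record is not a well-formed relationship."""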
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_keywords = clean_str(record_attributes[4])
edge_source_id = chunk_key
weight = (
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
)
return dict(
src_id=source,
tgt_id=target,
weight=weight,
description=edge_description,
keywords=edge_keywords,
source_id=edge_source_id,
)
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
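    """Merge freshly extracted records for `entity_name` with any existing
    graph node: majority-vote the entity type, join descriptions and source
    ids, summarize the description if it grows too long, then upsert the
    node and return its data."""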
already_entitiy_types = []
already_source_ids = []
already_description = []
    already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
)
already_description.append(already_node["description"])
entity_type = sorted(
Counter(
[dp["entity_type"] for dp in nodes_data] + already_entitiy_types
).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in nodes_data] + already_description))
)
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
)
description = await _handle_entity_relation_summary(
entity_name, description, global_config
)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=source_id,
)
    await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
node_data["entity_name"] = entity_name
return node_data
async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
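    """Merge freshly extracted edge records between `src_id` and `tgt_id`
    with any existing edge: sum weights, join descriptions, keywords and
    source ids, create missing endpoint nodes, summarize the description if
    needed, then upsert the edge and return its data."""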
already_weights = []
already_source_ids = []
already_description = []
already_keywords = []
    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
)
already_description.append(already_edge["description"])
already_keywords.extend(
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
)
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description))
)
keywords = GRAPH_FIELD_SEP.join(
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
)
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
)
for need_insert_id in [src_id, tgt_id]:
        if not (await knowledge_graph_inst.has_node(need_insert_id)):
            await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
"description": description,
"entity_type": '"UNKNOWN"',
},
)
description = await _handle_entity_relation_summary(
(src_id, tgt_id), description, global_config
)
    await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
edge_data=dict(
weight=weight,
description=description,
keywords=keywords,
source_id=source_id,
),
)
edge_data = dict(
src_id=src_id,
tgt_id=tgt_id,
description=description,
keywords=keywords,
)
return edge_data
async def extract_entities(
chunks: dict[str, TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
) -> Union[BaseGraphStorage, None]:
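    """Run LLM-based entity/relationship extraction over all text chunks,
    with optional gleaning passes, merge the results into the knowledge
    graph, and upsert entities and relationships into their vector stores.
    Returns the graph storage instance, or None if nothing was extracted."""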
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
ordered_chunks = list(chunks.items())
    # add language and example number params to prompt
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["entity_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["entity_extraction_examples"])

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
        examples=examples,
        language=language,
    )
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
already_processed = 0
already_entities = 0
already_relations = 0
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
nonlocal already_processed, already_entities, already_relations
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
final_result = await use_llm_func(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await use_llm_func(continue_prompt, history_messages=history)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
final_result += glean_result
if now_glean_index == entity_extract_max_gleaning - 1:
break
if_loop_result: str = await use_llm_func(
if_loop_prompt, history_messages=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
records = split_string_by_multi_markers(
final_result,
[context_base["record_delimiter"], context_base["completion_delimiter"]],
)
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
record = record.group(1)
record_attributes = split_string_by_multi_markers(
record, [context_base["tuple_delimiter"]]
)
if_entities = await _handle_single_entity_extraction(
record_attributes, chunk_key
)
if if_entities is not None:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
if_relation = await _handle_single_relationship_extraction(
record_attributes, chunk_key
)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
if_relation
)
already_processed += 1
already_entities += len(maybe_nodes)
already_relations += len(maybe_edges)
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
)
return dict(maybe_nodes), dict(maybe_edges)
    # use_llm_func is wrapped in an asyncio.Semaphore, limiting the number of concurrent calls to max_async
results = await asyncio.gather(
*[_process_single_content(c) for c in ordered_chunks]
)
print() # clear the progress bar
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for m_nodes, m_edges in results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
all_relationships_data = await asyncio.gather(
*[
            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
if not len(all_entities_data):
logger.warning("Didn't extract any entities, maybe your LLM is not working")
return None
if not len(all_relationships_data):
logger.warning(
"Didn't extract any relationships, maybe your LLM is not working"
)
return None
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
await entity_vdb.upsert(data_for_vdb)
if relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
    return knowledge_graph_inst
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
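    """Answer a query against the knowledge graph: extract high- and
    low-level keywords with the LLM, build a local/global/hybrid context
    from the graph and vector stores, and generate the final response."""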
    context = None
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
    # Use the LLM to extract high- and low-level keywords from the query
    use_model_func = global_config["llm_model_func"]
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
    result = await use_model_func(kw_prompt)
    logger.info("kw_prompt result:")
    print(result)
try:
        json_text = locate_json_string_body_from_string(result)
        keywords_data = json.loads(json_text)
        hl_keywords = keywords_data.get("high_level_keywords", [])
        ll_keywords = keywords_data.get("low_level_keywords", [])
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]
    # Handle missing keywords
    if hl_keywords == [] and ll_keywords == []:
        logger.warning("low_level_keywords and high_level_keywords are empty")
        return PROMPTS["fail_response"]
    if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
        logger.warning("low_level_keywords is empty")
        return PROMPTS["fail_response"]
    else:
        ll_keywords = ", ".join(ll_keywords)
    if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
        logger.warning("high_level_keywords is empty")
        return PROMPTS["fail_response"]
    else:
        hl_keywords = ", ".join(hl_keywords)
# Build context
    keywords = [ll_keywords, hl_keywords]
    context = await _build_query_context(
        keywords,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
    )
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
    )
    if query_param.only_need_prompt:
        return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
    )
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
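    """Build the CSV-formatted entities/relationships/sources context from
    the low-level and high-level keywords, according to the query mode."""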
    ll_keywords, hl_keywords = query[0], query[1]
    if query_param.mode in ["local", "hybrid"]:
        if ll_keywords == "":
            ll_entities_context, ll_relations_context, ll_text_units_context = (
                "",
                "",
                "",
            )
            warnings.warn(
                "Low-level keywords are empty; returning empty low-level entity/relationship/source context"
            )
            query_param.mode = "global"
        else:
            (
                ll_entities_context,
                ll_relations_context,
                ll_text_units_context,
            ) = await _get_node_data(
                ll_keywords,
                knowledge_graph_inst,
                entities_vdb,
                text_chunks_db,
                query_param,
            )
    if query_param.mode in ["global", "hybrid"]:
        if hl_keywords == "":
            hl_entities_context, hl_relations_context, hl_text_units_context = (
                "",
                "",
                "",
            )
            warnings.warn(
                "High-level keywords are empty; returning empty high-level entity/relationship/source context"
            )
            query_param.mode = "local"
        else:
            (
                hl_entities_context,
                hl_relations_context,
                hl_text_units_context,
            ) = await _get_edge_data(
                hl_keywords,
                knowledge_graph_inst,
                relationships_vdb,
                text_chunks_db,
                query_param,
            )
if query_param.mode == "hybrid":
entities_context, relations_context, text_units_context = combine_contexts(
[hl_entities_context, ll_entities_context],
[hl_relations_context, ll_relations_context],
[hl_text_units_context, ll_text_units_context],
)
elif query_param.mode == "local":
entities_context, relations_context, text_units_context = (
ll_entities_context,
ll_relations_context,
ll_text_units_context,
)
elif query_param.mode == "global":
entities_context, relations_context, text_units_context = (
hl_entities_context,
hl_relations_context,
hl_text_units_context,
)
    return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
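    """Local retrieval: find entities similar to the query, then gather
    their graph data, related edges, and source text chunks, and format
    them as CSV context sections."""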
    # Retrieve the entities most similar to the query from the vector store
    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return None

    # Fetch graph data for each matched entity
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
    # Get the degree (rank) of each entity node
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
)
node_datas = [
{**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
    ]  # TODO: clarify what text_chunks_db is doing here; check the diagram
    # Fetch the text chunks associated with these entities
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
    # Find the edges most related to these entities
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst
)
    logger.info(
        f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
)
    # Build the CSV context sections for the prompt
entites_section_list = [["id", "entity", "type", "description", "rank"]]
for i, n in enumerate(node_datas):
entites_section_list.append(
[
i,
n["entity_name"],
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
n["rank"],
]
)
entities_context = list_of_list_to_csv(entites_section_list)
relations_section_list = [
["id", "source", "target", "description", "keywords", "weight", "rank"]
]
for i, e in enumerate(use_relations):
relations_section_list.append(
[
i,
e["src_tgt"][0],
e["src_tgt"][1],
e["description"],
e["keywords"],
e["weight"],
e["rank"],
]
)
relations_context = list_of_list_to_csv(relations_section_list)
text_units_section_list = [["id", "content"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
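    """Collect the text chunks referenced by the given entities, rank them
    by entity order and one-hop relation counts, and truncate the list to
    the text-unit token budget."""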
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in node_datas
]
edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
)
all_one_hop_nodes = set()
for this_edges in edges:
if not this_edges:
continue
all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes)
all_one_hop_nodes_data = await asyncio.gather(
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
)
    # Only keep one-hop nodes that exist and carry a source_id
    all_one_hop_text_units_lookup = {
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None and "source_id" in v
}
    all_text_units_lookup = {}
    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        for c_id in this_text_units:
if c_id not in all_text_units_lookup:
all_text_units_lookup[c_id] = {
"data": await text_chunks_db.get_by_id(c_id),
"order": index,
"relation_counts": 0,
}
if this_edges:
                for e in this_edges:
                    if (
                        e[1] in all_one_hop_text_units_lookup
                        and c_id in all_one_hop_text_units_lookup[e[1]]
                    ):
all_text_units_lookup[c_id]["relation_counts"] += 1
    # Filter out None values and ensure data has content
    all_text_units = [
        {"id": k, **v}
        for k, v in all_text_units_lookup.items()
        if v is not None and v.get("data") is not None and "content" in v["data"]
]
    if not all_text_units:
        logger.warning("No valid text units found")
        return []

    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
)
    all_text_units = [t["data"] for t in all_text_units]
    return all_text_units
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
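    """Collect the edges incident to the given entities, deduplicate them,
    rank by degree and weight, and truncate to the global-context token
    budget."""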
all_related_edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
)
    all_edges = []
    seen = set()

    for this_edges in all_related_edges:
for e in this_edges:
sorted_edge = tuple(sorted(e))
if sorted_edge not in seen:
seen.add(sorted_edge)
all_edges.append(sorted_edge)
all_edges_pack = await asyncio.gather(
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
)
all_edges_degree = await asyncio.gather(
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
)
all_edges_data = [
{"src_tgt": k, "rank": d, **v}
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
if v is not None
]
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
all_edges_data = truncate_list_by_token_size(
all_edges_data,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_global_context,
)
return all_edges_data
async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
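    """Global retrieval: find relationships similar to the keywords, then
    gather their graph data, related entities, and source text chunks, and
    format them as CSV context sections."""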
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results):
return None
edge_datas = await asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
)
if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged")
edge_degree = await asyncio.gather(
*[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
)
edge_datas = [
{"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v}
for k, v, d in zip(results, edge_datas, edge_degree)
if v is not None
]
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
edge_datas = truncate_list_by_token_size(
edge_datas,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_global_context,
)
use_entities = await _find_most_related_entities_from_relationships(
edge_datas, query_param, knowledge_graph_inst
)
use_text_units = await _find_related_text_unit_from_relationships(
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
)
    logger.info(
        f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} text units"
    )
relations_section_list = [
["id", "source", "target", "description", "keywords", "weight", "rank"]
]
for i, e in enumerate(edge_datas):
relations_section_list.append(
[
i,
e["src_id"],
e["tgt_id"],
e["description"],
e["keywords"],
e["weight"],
e["rank"],
]
)
relations_context = list_of_list_to_csv(relations_section_list)
entites_section_list = [["id", "entity", "type", "description", "rank"]]
for i, n in enumerate(use_entities):
entites_section_list.append(
[
i,
n["entity_name"],
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
n["rank"],
]
)
entities_context = list_of_list_to_csv(entites_section_list)
text_units_section_list = [["id", "content"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context
async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
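    """Collect the endpoint entities of the given edges (deduplicated, in
    order of appearance) and truncate to the local-context token budget."""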
    entity_names = []
    seen = set()

    for e in edge_datas:
if e["src_id"] not in seen:
entity_names.append(e["src_id"])
seen.add(e["src_id"])
if e["tgt_id"] not in seen:
entity_names.append(e["tgt_id"])
seen.add(e["tgt_id"])
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
)
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
)
node_datas = [
{**n, "entity_name": k, "rank": d}
for k, n, d in zip(entity_names, node_datas, node_degrees)
]
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_local_context,
)
return node_datas
async def _find_related_text_unit_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
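    """Collect the text chunks referenced by the given edges, keep the
    first-occurrence order, and truncate to the text-unit token budget."""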
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in edge_datas
]
all_text_units_lookup = {}
for index, unit_list in enumerate(text_units):
for c_id in unit_list:
if c_id not in all_text_units_lookup:
all_text_units_lookup[c_id] = {
"data": await text_chunks_db.get_by_id(c_id),
"order": index,
}
if any([v is None for v in all_text_units_lookup.values()]):
logger.warning("Text chunks are missing, maybe the storage is damaged")
all_text_units = [
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
]
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
)
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
return all_text_units
def combine_contexts(entities, relationships, sources):
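    """Merge the high-level (global) and low-level (local) context tables,
    deduplicating entities, relationships, and sources."""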
    # Split each combined pair into its high-level and low-level parts
    hl_entities, ll_entities = entities[0], entities[1]
    hl_relationships, ll_relationships = relationships[0], relationships[1]
    hl_sources, ll_sources = sources[0], sources[1]
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)
    # Combine and deduplicate the relationships
combined_relationships = process_combine_contexts(
hl_relationships, ll_relationships
)
# Combine and deduplicate the sources
combined_sources = process_combine_contexts(hl_sources, ll_sources)
return combined_entities, combined_relationships, combined_sources
async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
):
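    """Vanilla RAG: retrieve the top-k chunks from the chunks vector store,
    truncate them to the token budget, and answer the query with the LLM."""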
use_model_func = global_config["llm_model_func"]
results = await chunks_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size(
chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)
    logger.info(f"Truncated {len(chunks)} chunks to {len(maybe_trun_chunks)} chunks")
    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context:
return section
sys_prompt_temp = PROMPTS["naive_rag_response"]
sys_prompt = sys_prompt_temp.format(
content_data=section, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response