Update operate.py

This commit is contained in:
zrguo 2025-07-15 18:57:57 +08:00
parent 93b25a65d5
commit 29e82723e6

View File

@ -1892,26 +1892,42 @@ async def _build_query_context(
entities_context = []
relations_context = []
# Store original data for later text chunk retrieval
original_node_datas = []
original_edge_datas = []
# Handle local and global modes
if query_param.mode == "local":
entities_context, relations_context, entity_chunks = await _get_node_data(
(
entities_context,
relations_context,
node_datas,
use_relations,
) = await _get_node_data(
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param,
)
all_chunks.extend(entity_chunks)
original_node_datas = node_datas
original_edge_datas = use_relations
elif query_param.mode == "global":
entities_context, relations_context, relationship_chunks = await _get_edge_data(
(
entities_context,
relations_context,
edge_datas,
use_entities,
) = await _get_edge_data(
hl_keywords,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param,
)
all_chunks.extend(relationship_chunks)
original_edge_datas = edge_datas
original_node_datas = use_entities
else: # hybrid or mix mode
ll_data = await _get_node_data(
@ -1929,10 +1945,13 @@ async def _build_query_context(
query_param,
)
(ll_entities_context, ll_relations_context, ll_chunks) = ll_data
(hl_entities_context, hl_relations_context, hl_chunks) = hl_data
(ll_entities_context, ll_relations_context, ll_node_datas, ll_edge_datas) = (
ll_data
)
(hl_entities_context, hl_relations_context, hl_edge_datas, hl_node_datas) = (
hl_data
)
# Collect chunks from entity and relationship sources
# Get vector chunks first if in mix mode
if query_param.mode == "mix" and chunks_vdb:
vector_chunks = await _get_vector_context(
@ -1942,8 +1961,9 @@ async def _build_query_context(
)
all_chunks.extend(vector_chunks)
all_chunks.extend(ll_chunks)
all_chunks.extend(hl_chunks)
# Store original data from both sources
original_node_datas = ll_node_datas + hl_node_datas
original_edge_datas = ll_edge_datas + hl_edge_datas
# Combine entities and relations contexts
entities_context = process_combine_contexts(
@ -2027,6 +2047,73 @@ async def _build_query_context(
f"Truncated relations: {original_relation_count} -> {len(relations_context)} (relation max tokens: {max_relation_tokens})"
)
# After truncation, get text chunks based on final entities and relations
logger.info("Getting text chunks based on truncated entities and relations...")
# Create filtered data based on truncated context
final_node_datas = []
final_edge_datas = []
if entities_context and original_node_datas:
# Create a set of entity names from final truncated context
final_entity_names = {entity["entity"] for entity in entities_context}
# Filter original node data based on final entities
final_node_datas = [
node
for node in original_node_datas
if node.get("entity_name") in final_entity_names
]
if relations_context and original_edge_datas:
# Create a set of relation pairs from final truncated context
final_relation_pairs = {
(rel["entity1"], rel["entity2"]) for rel in relations_context
}
# Filter original edge data based on final relations
final_edge_datas = [
edge
for edge in original_edge_datas
if (edge.get("src_id"), edge.get("tgt_id")) in final_relation_pairs
or (
edge.get("src_tgt", (None, None))[0],
edge.get("src_tgt", (None, None))[1],
)
in final_relation_pairs
]
# Get text chunks based on final filtered data
text_chunk_tasks = []
if final_node_datas:
text_chunk_tasks.append(
_find_most_related_text_unit_from_entities(
final_node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
)
if final_edge_datas:
text_chunk_tasks.append(
_find_related_text_unit_from_relationships(
final_edge_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
)
# Execute text chunk retrieval in parallel
if text_chunk_tasks:
text_chunk_results = await asyncio.gather(*text_chunk_tasks)
for chunks in text_chunk_results:
if chunks:
all_chunks.extend(chunks)
# Apply token processing to chunks if tokenizer is available
text_units_context = []
if tokenizer and all_chunks:
# Calculate dynamic token limit for text chunks
entities_str = json.dumps(entities_context, ensure_ascii=False)
relations_str = json.dumps(relations_context, ensure_ascii=False)
@ -2122,7 +2209,6 @@ async def _build_query_context(
)
# Rebuild text_units_context with truncated chunks
text_units_context = []
for i, chunk in enumerate(truncated_chunks):
text_units_context.append(
{
@ -2187,7 +2273,7 @@ async def _get_node_data(
)
if not len(results):
return "", "", ""
return "", "", [], []
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
@ -2214,14 +2300,8 @@ async def _get_node_data(
}
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# get entitytext chunk
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
]
use_relations = await _find_most_related_edges_from_entities(
node_datas,
query_param,
@ -2229,7 +2309,7 @@ async def _get_node_data(
)
logger.info(
f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
f"Local query: {len(node_datas)} entites, {len(use_relations)} relations"
)
# build prompt
@ -2278,7 +2358,7 @@ async def _get_node_data(
}
)
return entities_context, relations_context, use_text_units
return entities_context, relations_context, node_datas, use_relations
async def _find_most_related_text_unit_from_entities(
@ -2456,7 +2536,7 @@ async def _get_edge_data(
)
if not len(results):
return "", "", ""
return "", "", [], []
# Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts.
@ -2495,21 +2575,15 @@ async def _get_edge_data(
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
edge_datas,
query_param,
knowledge_graph_inst,
),
_find_related_text_unit_from_relationships(
edge_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
),
use_entities = await _find_most_related_entities_from_relationships(
edge_datas,
query_param,
knowledge_graph_inst,
)
logger.info(
f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations"
)
relations_context = []
@ -2558,16 +2632,8 @@ async def _get_edge_data(
}
)
text_units_context = []
for i, t in enumerate(use_text_units):
text_units_context.append(
{
"id": i + 1,
"content": t["content"],
"file_path": t.get("file_path", "unknown"),
}
)
return entities_context, relations_context, text_units_context
# Return original data for later text chunk retrieval
return entities_context, relations_context, edge_datas, use_entities
async def _find_most_related_entities_from_relationships(