mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-11-26 23:16:10 +00:00
Update operate.py
This commit is contained in:
parent
93b25a65d5
commit
29e82723e6
@ -1892,26 +1892,42 @@ async def _build_query_context(
|
||||
entities_context = []
|
||||
relations_context = []
|
||||
|
||||
# Store original data for later text chunk retrieval
|
||||
original_node_datas = []
|
||||
original_edge_datas = []
|
||||
|
||||
# Handle local and global modes
|
||||
if query_param.mode == "local":
|
||||
entities_context, relations_context, entity_chunks = await _get_node_data(
|
||||
(
|
||||
entities_context,
|
||||
relations_context,
|
||||
node_datas,
|
||||
use_relations,
|
||||
) = await _get_node_data(
|
||||
ll_keywords,
|
||||
knowledge_graph_inst,
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
all_chunks.extend(entity_chunks)
|
||||
original_node_datas = node_datas
|
||||
original_edge_datas = use_relations
|
||||
|
||||
elif query_param.mode == "global":
|
||||
entities_context, relations_context, relationship_chunks = await _get_edge_data(
|
||||
(
|
||||
entities_context,
|
||||
relations_context,
|
||||
edge_datas,
|
||||
use_entities,
|
||||
) = await _get_edge_data(
|
||||
hl_keywords,
|
||||
knowledge_graph_inst,
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
)
|
||||
all_chunks.extend(relationship_chunks)
|
||||
original_edge_datas = edge_datas
|
||||
original_node_datas = use_entities
|
||||
|
||||
else: # hybrid or mix mode
|
||||
ll_data = await _get_node_data(
|
||||
@ -1929,10 +1945,13 @@ async def _build_query_context(
|
||||
query_param,
|
||||
)
|
||||
|
||||
(ll_entities_context, ll_relations_context, ll_chunks) = ll_data
|
||||
(hl_entities_context, hl_relations_context, hl_chunks) = hl_data
|
||||
(ll_entities_context, ll_relations_context, ll_node_datas, ll_edge_datas) = (
|
||||
ll_data
|
||||
)
|
||||
(hl_entities_context, hl_relations_context, hl_edge_datas, hl_node_datas) = (
|
||||
hl_data
|
||||
)
|
||||
|
||||
# Collect chunks from entity and relationship sources
|
||||
# Get vector chunks first if in mix mode
|
||||
if query_param.mode == "mix" and chunks_vdb:
|
||||
vector_chunks = await _get_vector_context(
|
||||
@ -1942,8 +1961,9 @@ async def _build_query_context(
|
||||
)
|
||||
all_chunks.extend(vector_chunks)
|
||||
|
||||
all_chunks.extend(ll_chunks)
|
||||
all_chunks.extend(hl_chunks)
|
||||
# Store original data from both sources
|
||||
original_node_datas = ll_node_datas + hl_node_datas
|
||||
original_edge_datas = ll_edge_datas + hl_edge_datas
|
||||
|
||||
# Combine entities and relations contexts
|
||||
entities_context = process_combine_contexts(
|
||||
@ -2027,6 +2047,73 @@ async def _build_query_context(
|
||||
f"Truncated relations: {original_relation_count} -> {len(relations_context)} (relation max tokens: {max_relation_tokens})"
|
||||
)
|
||||
|
||||
# After truncation, get text chunks based on final entities and relations
|
||||
logger.info("Getting text chunks based on truncated entities and relations...")
|
||||
|
||||
# Create filtered data based on truncated context
|
||||
final_node_datas = []
|
||||
final_edge_datas = []
|
||||
|
||||
if entities_context and original_node_datas:
|
||||
# Create a set of entity names from final truncated context
|
||||
final_entity_names = {entity["entity"] for entity in entities_context}
|
||||
# Filter original node data based on final entities
|
||||
final_node_datas = [
|
||||
node
|
||||
for node in original_node_datas
|
||||
if node.get("entity_name") in final_entity_names
|
||||
]
|
||||
|
||||
if relations_context and original_edge_datas:
|
||||
# Create a set of relation pairs from final truncated context
|
||||
final_relation_pairs = {
|
||||
(rel["entity1"], rel["entity2"]) for rel in relations_context
|
||||
}
|
||||
# Filter original edge data based on final relations
|
||||
final_edge_datas = [
|
||||
edge
|
||||
for edge in original_edge_datas
|
||||
if (edge.get("src_id"), edge.get("tgt_id")) in final_relation_pairs
|
||||
or (
|
||||
edge.get("src_tgt", (None, None))[0],
|
||||
edge.get("src_tgt", (None, None))[1],
|
||||
)
|
||||
in final_relation_pairs
|
||||
]
|
||||
|
||||
# Get text chunks based on final filtered data
|
||||
text_chunk_tasks = []
|
||||
|
||||
if final_node_datas:
|
||||
text_chunk_tasks.append(
|
||||
_find_most_related_text_unit_from_entities(
|
||||
final_node_datas,
|
||||
query_param,
|
||||
text_chunks_db,
|
||||
knowledge_graph_inst,
|
||||
)
|
||||
)
|
||||
|
||||
if final_edge_datas:
|
||||
text_chunk_tasks.append(
|
||||
_find_related_text_unit_from_relationships(
|
||||
final_edge_datas,
|
||||
query_param,
|
||||
text_chunks_db,
|
||||
knowledge_graph_inst,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute text chunk retrieval in parallel
|
||||
if text_chunk_tasks:
|
||||
text_chunk_results = await asyncio.gather(*text_chunk_tasks)
|
||||
for chunks in text_chunk_results:
|
||||
if chunks:
|
||||
all_chunks.extend(chunks)
|
||||
|
||||
# Apply token processing to chunks if tokenizer is available
|
||||
text_units_context = []
|
||||
if tokenizer and all_chunks:
|
||||
# Calculate dynamic token limit for text chunks
|
||||
entities_str = json.dumps(entities_context, ensure_ascii=False)
|
||||
relations_str = json.dumps(relations_context, ensure_ascii=False)
|
||||
@ -2122,7 +2209,6 @@ async def _build_query_context(
|
||||
)
|
||||
|
||||
# Rebuild text_units_context with truncated chunks
|
||||
text_units_context = []
|
||||
for i, chunk in enumerate(truncated_chunks):
|
||||
text_units_context.append(
|
||||
{
|
||||
@ -2187,7 +2273,7 @@ async def _get_node_data(
|
||||
)
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
return "", "", [], []
|
||||
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
@ -2214,14 +2300,8 @@ async def _get_node_data(
|
||||
}
|
||||
for k, n, d in zip(results, node_datas, node_degrees)
|
||||
if n is not None
|
||||
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
||||
# get entitytext chunk
|
||||
use_text_units = await _find_most_related_text_unit_from_entities(
|
||||
node_datas,
|
||||
query_param,
|
||||
text_chunks_db,
|
||||
knowledge_graph_inst,
|
||||
)
|
||||
]
|
||||
|
||||
use_relations = await _find_most_related_edges_from_entities(
|
||||
node_datas,
|
||||
query_param,
|
||||
@ -2229,7 +2309,7 @@ async def _get_node_data(
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
|
||||
f"Local query: {len(node_datas)} entites, {len(use_relations)} relations"
|
||||
)
|
||||
|
||||
# build prompt
|
||||
@ -2278,7 +2358,7 @@ async def _get_node_data(
|
||||
}
|
||||
)
|
||||
|
||||
return entities_context, relations_context, use_text_units
|
||||
return entities_context, relations_context, node_datas, use_relations
|
||||
|
||||
|
||||
async def _find_most_related_text_unit_from_entities(
|
||||
@ -2456,7 +2536,7 @@ async def _get_edge_data(
|
||||
)
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
return "", "", [], []
|
||||
|
||||
# Prepare edge pairs in two forms:
|
||||
# For the batch edge properties function, use dicts.
|
||||
@ -2495,21 +2575,15 @@ async def _get_edge_data(
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
use_entities, use_text_units = await asyncio.gather(
|
||||
_find_most_related_entities_from_relationships(
|
||||
edge_datas,
|
||||
query_param,
|
||||
knowledge_graph_inst,
|
||||
),
|
||||
_find_related_text_unit_from_relationships(
|
||||
edge_datas,
|
||||
query_param,
|
||||
text_chunks_db,
|
||||
knowledge_graph_inst,
|
||||
),
|
||||
|
||||
use_entities = await _find_most_related_entities_from_relationships(
|
||||
edge_datas,
|
||||
query_param,
|
||||
knowledge_graph_inst,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
|
||||
f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations"
|
||||
)
|
||||
|
||||
relations_context = []
|
||||
@ -2558,16 +2632,8 @@ async def _get_edge_data(
|
||||
}
|
||||
)
|
||||
|
||||
text_units_context = []
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_context.append(
|
||||
{
|
||||
"id": i + 1,
|
||||
"content": t["content"],
|
||||
"file_path": t.get("file_path", "unknown"),
|
||||
}
|
||||
)
|
||||
return entities_context, relations_context, text_units_context
|
||||
# Return original data for later text chunk retrieval
|
||||
return entities_context, relations_context, edge_datas, use_entities
|
||||
|
||||
|
||||
async def _find_most_related_entities_from_relationships(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user