Improve task execution with early failure detection

- Add early failure detection for async tasks
- Cancel pending tasks on first exception
This commit is contained in:
yangdx 2025-07-19 10:14:22 +08:00
parent 12d4f12e57
commit 05bc5cfb64
2 changed files with 39 additions and 5 deletions

View File

@ -1127,7 +1127,7 @@ class LightRAG:
} }
) )
# Concurrency is controlled by graph db lock for individual entities and relationships # Concurrency is controlled by keyed lock for individual entities and relationships
if file_extraction_stage_ok: if file_extraction_stage_ok:
try: try:
# Get chunk_results from entity_relation_task # Get chunk_results from entity_relation_task

View File

@ -480,8 +480,25 @@ async def _rebuild_knowledge_from_chunks(
pipeline_status["latest_message"] = status_message pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message) pipeline_status["history_messages"].append(status_message)
# Execute all tasks in parallel with semaphore control # Execute all tasks in parallel with semaphore control and early failure detection
await asyncio.gather(*tasks) done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
# Check if any task raised an exception
for task in done:
if task.exception():
# If a task failed, cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the exception to notify the caller
raise task.exception()
# If all tasks completed successfully, collect results
# (No need to collect results since these tasks don't return values)
# Final status report # Final status report
status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully." status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully."
@ -1313,8 +1330,25 @@ async def merge_nodes_and_edges(
for edge_key, edges in all_edges.items(): for edge_key, edges in all_edges.items():
tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges))) tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
# Execute all tasks in parallel with semaphore control # Execute all tasks in parallel with semaphore control and early failure detection
await asyncio.gather(*tasks) done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
# Check if any task raised an exception
for task in done:
if task.exception():
# If a task failed, cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the exception to notify the caller
raise task.exception()
# If all tasks completed successfully, collect results
# (No need to collect results since these tasks don't return values)
async def extract_entities( async def extract_entities(