Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-11-03 19:29:38 +00:00)
Merge pull request #1334 from danielaskdd/main

Refactor entity and edge merging and add env FORCE_LLM_SUMMARY_ON_MERGE

commit 0528c06209
@@ -43,11 +43,15 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
 SUMMARY_LANGUAGE=English
 # CHUNK_SIZE=1200
 # CHUNK_OVERLAP_SIZE=100
-### Max tokens for entity or relations summary
-# MAX_TOKEN_SUMMARY=500
+
 ### Number of parallel processing documents in one batch
 # MAX_PARALLEL_INSERT=2
 
+### Max tokens for entity/relations description after merge
+# MAX_TOKEN_SUMMARY=500
+### Number of entities/edges to trigger LLM re-summary on merge (at least 3 is recommended)
+# FORCE_LLM_SUMMARY_ON_MERGE=6
+
 ### Num of chunks sent to Embedding in a single request
 # EMBEDDING_BATCH_NUM=32
 ### Max concurrent requests for Embedding
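Note: both knobs resolve through os.getenv with the same hard-coded fallbacks shown in the lightrag.py hunk further down, so the commented-out .env entries document the effective defaults. A minimal sketch of the resolution order (assuming the .env file is loaded into the process environment, e.g. via python-dotenv):

    import os

    # Unset -> hard-coded fallback; set in .env/environment -> override.
    summary_to_max_tokens = int(os.getenv("MAX_TOKEN_SUMMARY", 500))
    force_llm_summary_on_merge = int(os.getenv("FORCE_LLM_SUMMARY_ON_MERGE", 6))

    print(summary_to_max_tokens, force_llm_summary_on_merge)  # 500 6 when both are unset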
@@ -1 +1 @@
-__api_version__ = "0143"
+__api_version__ = "0145"
@@ -261,8 +261,12 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.chunk_overlap_size}")
     ASCIIColors.white("    ├─ Cosine Threshold: ", end="")
     ASCIIColors.yellow(f"{args.cosine_threshold}")
-    ASCIIColors.white("    └─ Top-K: ", end="")
+    ASCIIColors.white("    ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
+    ASCIIColors.white("    ├─ Max Token Summary: ", end="")
+    ASCIIColors.yellow(f"{int(os.getenv('MAX_TOKEN_SUMMARY', 500))}")
+    ASCIIColors.white("    └─ Force LLM Summary on Merge: ", end="")
+    ASCIIColors.yellow(f"{int(os.getenv('FORCE_LLM_SUMMARY_ON_MERGE', 6))}")
 
     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")
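Note: Top-K moves from the closing branch (└─) to a middle branch (├─) so the two new settings can close the tree. With the stock env defaults, the affected part of the splash screen would read roughly (args values depend on the runtime configuration):

        ├─ Cosine Threshold: <args.cosine_threshold>
        ├─ Top-K: <args.top_k>
        ├─ Max Token Summary: 500
        └─ Force LLM Summary on Merge: 6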
										
											
File diff suppressed because one or more lines are too long

lightrag/api/webui/index.html (generated, 2 lines changed)
@@ -8,7 +8,7 @@
     <link rel="icon" type="image/svg+xml" href="logo.png" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>Lightrag</title>
-    <script type="module" crossorigin src="/webui/assets/index-Cicy56pP.js"></script>
+    <script type="module" crossorigin src="/webui/assets/index-BPm_J2w3.js"></script>
     <link rel="stylesheet" crossorigin href="/webui/assets/index-CTB4Vp_z.css">
   </head>
   <body>
@@ -103,8 +103,10 @@ class LightRAG:
     entity_extract_max_gleaning: int = field(default=1)
     """Maximum number of entity extraction attempts for ambiguous content."""
 
-    entity_summary_to_max_tokens: int = field(
-        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
+    summary_to_max_tokens: int = field(default=int(os.getenv("MAX_TOKEN_SUMMARY", 500)))
+
+    force_llm_summary_on_merge: int = field(
+        default=int(os.getenv("FORCE_LLM_SUMMARY_ON_MERGE", 6))
     )
 
     # Text chunking
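Note: the rename from entity_summary_to_max_tokens to summary_to_max_tokens matters because LightRAG appears to expose these dataclass fields to the pipeline as keys of a global_config dict (see the lookup change in the next hunk). A minimal sketch of that pattern, with Cfg as a hypothetical stand-in for the LightRAG dataclass:

    import os
    from dataclasses import asdict, dataclass, field

    @dataclass
    class Cfg:  # hypothetical stand-in for the LightRAG dataclass
        summary_to_max_tokens: int = field(default=int(os.getenv("MAX_TOKEN_SUMMARY", 500)))
        force_llm_summary_on_merge: int = field(
            default=int(os.getenv("FORCE_LLM_SUMMARY_ON_MERGE", 6))
        )

    global_config = asdict(Cfg())
    print(global_config["summary_to_max_tokens"])       # field name == dict key
    print(global_config["force_llm_summary_on_merge"])

Renaming the field therefore renames the key, which is why the lookup in _handle_entity_relation_summary changes in the next hunk.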
@@ -117,15 +117,13 @@ async def _handle_entity_relation_summary(
     use_llm_func: callable = global_config["llm_model_func"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     tiktoken_model_name = global_config["tiktoken_model_name"]
-    summary_max_tokens = global_config["entity_summary_to_max_tokens"]
+    summary_max_tokens = global_config["summary_to_max_tokens"]
 
-    language = global_config["addon_params"].get(
-        "language", PROMPTS["DEFAULT_LANGUAGE"]
-    )
+    language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"])
 
     tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
 
     prompt_template = PROMPTS["summarize_entity_descriptions"]
     use_description = decode_tokens_by_tiktoken(
         tokens[:llm_max_tokens], model_name=tiktoken_model_name
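Note: the function still short-circuits when the merged description fits the budget; only descriptions at or above summary_max_tokens tokens reach the LLM. A standalone sketch of that gate, assuming encode_string_by_tiktoken is a thin wrapper over the tiktoken package (the default model name here is also an assumption):

    import tiktoken  # assumption: mirrors LightRAG's encode_string_by_tiktoken helper

    def needs_llm_summary(description: str, summary_max_tokens: int = 500,
                          model_name: str = "gpt-4o-mini") -> bool:
        # Count tokens the same way the pipeline does; descriptions under the
        # budget are returned unchanged, so no LLM call is needed for them.
        tokens = tiktoken.encoding_for_model(model_name).encode(description)
        return len(tokens) >= summary_max_tokens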
@@ -138,14 +136,6 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
 
-    # Update pipeline status when LLM summary is needed
-    status_message = "Use LLM to re-summary description..."
-    logger.info(status_message)
-    if pipeline_status is not None and pipeline_status_lock is not None:
-        async with pipeline_status_lock:
-            pipeline_status["latest_message"] = status_message
-            pipeline_status["history_messages"].append(status_message)
-
     # Use LLM function with cache
     summary = await use_llm_func_with_cache(
         use_prompt,
@@ -244,14 +234,6 @@ async def _merge_nodes_then_upsert(
 
     already_node = await knowledge_graph_inst.get_node(entity_name)
     if already_node is not None:
-        # Update pipeline status when a node that needs merging is found
-        status_message = f"Merging entity: {entity_name}"
-        logger.info(status_message)
-        if pipeline_status is not None and pipeline_status_lock is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = status_message
-                pipeline_status["history_messages"].append(status_message)
-
         already_entity_types.append(already_node["entity_type"])
         already_source_ids.extend(
             split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
@@ -278,7 +260,19 @@ async def _merge_nodes_then_upsert(
         set([dp["file_path"] for dp in nodes_data] + already_file_paths)
     )
 
-    logger.debug(f"file_path: {file_path}")
-    description = await _handle_entity_relation_summary(
-        entity_name,
-        description,
+    force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
+
+    num_fragment = description.count(GRAPH_FIELD_SEP) + 1
+    num_new_fragment = len(set([dp["description"] for dp in nodes_data]))
+
+    if num_fragment > 1:
+        if num_fragment >= force_llm_summary_on_merge:
+            status_message = f"LLM merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}"
+            logger.info(status_message)
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    pipeline_status["latest_message"] = status_message
+                    pipeline_status["history_messages"].append(status_message)
+            description = await _handle_entity_relation_summary(
+                entity_name,
+                description,
@@ -287,6 +281,14 @@ async def _merge_nodes_then_upsert(
                 pipeline_status_lock,
                 llm_response_cache,
             )
+        else:
+            status_message = f"Merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}"
+            logger.info(status_message)
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    pipeline_status["latest_message"] = status_message
+                    pipeline_status["history_messages"].append(status_message)
+
     node_data = dict(
         entity_id=entity_name,
         entity_type=entity_type,
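Note: a merged description is a GRAPH_FIELD_SEP-joined list of per-chunk fragments, so num_fragment counts all fragments after the merge and num_new_fragment counts the distinct fragments contributed by the current pass. The status messages encode this as new+existing. A worked toy example (the separator value and the sorted-set join are assumptions about how the merge builds description):

    GRAPH_FIELD_SEP = "<SEP>"  # assumed separator value

    existing = ["ACME is a company.", "ACME builds rockets."]            # already in the graph
    new_fragments = ["ACME is a company.", "ACME was founded in 1999."]  # from this pass

    description = GRAPH_FIELD_SEP.join(sorted(set(existing + new_fragments)))

    num_fragment = description.count(GRAPH_FIELD_SEP) + 1   # 3 fragments total
    num_new_fragment = len(set(new_fragments))              # 2 distinct new ones

    # Below the threshold (default 6) the fragments stay concatenated and only
    # a cheap status line is emitted: "Merge N: ACME | 2+1".
    print(f"Merge N: ACME | {num_new_fragment}+{num_fragment - num_new_fragment}")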
@@ -319,14 +321,6 @@ async def _merge_edges_then_upsert(
     already_file_paths = []
 
     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
-        # Update pipeline status when an edge that needs merging is found
-        status_message = f"Merging edge::: {src_id} - {tgt_id}"
-        logger.info(status_message)
-        if pipeline_status is not None and pipeline_status_lock is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = status_message
-                pipeline_status["history_messages"].append(status_message)
-
         already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
         # Handle the case where get_edge returns None or missing fields
         if already_edge:
@@ -404,6 +398,22 @@ async def _merge_edges_then_upsert(
                     "file_path": file_path,
                 },
             )
-    description = await _handle_entity_relation_summary(
-        f"({src_id}, {tgt_id})",
-        description,
+
+    force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
+
+    num_fragment = description.count(GRAPH_FIELD_SEP) + 1
+    num_new_fragment = len(
+        set([dp["description"] for dp in edges_data if dp.get("description")])
+    )
+
+    if num_fragment > 1:
+        if num_fragment >= force_llm_summary_on_merge:
+            status_message = f"LLM merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}"
+            logger.info(status_message)
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    pipeline_status["latest_message"] = status_message
+                    pipeline_status["history_messages"].append(status_message)
+            description = await _handle_entity_relation_summary(
+                f"({src_id}, {tgt_id})",
+                description,
@@ -412,6 +422,14 @@ async def _merge_edges_then_upsert(
                 pipeline_status_lock,
                 llm_response_cache,
             )
+        else:
+            status_message = f"Merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}"
+            logger.info(status_message)
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    pipeline_status["latest_message"] = status_message
+                    pipeline_status["history_messages"].append(status_message)
+
     await knowledge_graph_inst.upsert_edge(
         src_id,
         tgt_id,
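Note: the edge path is the same gate with an E tag instead of N, plus one extra guard: relationship records can lack a description, so falsy values are filtered out before counting distinct new fragments. A small illustration with made-up records:

    edges_data = [
        {"description": "supplies parts to"},
        {"description": ""},   # extracted edge with an empty description
        {},                    # record without the key at all
    ]

    num_new_fragment = len({dp["description"] for dp in edges_data if dp.get("description")})
    print(num_new_fragment)  # 1 -- only the non-empty description counts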
@@ -550,8 +568,10 @@ async def extract_entities(
         Args:
             chunk_key_dp (tuple[str, TextChunkSchema]):
                 ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
+        Returns:
+            tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships
         """
-        nonlocal processed_chunks, total_entities_count, total_relations_count
+        nonlocal processed_chunks
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
@@ -623,13 +643,35 @@ async def extract_entities(
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
 
-        # Use graph database lock to ensure atomic merges and updates
-        chunk_entities_data = []
-        chunk_relationships_data = []
-
-        async with graph_db_lock:
-            # Process and update entities
-            for entity_name, entities in maybe_nodes.items():
-                entity_data = await _merge_nodes_then_upsert(
-                    entity_name,
-                    entities,
+        # Return the extracted nodes and edges for centralized processing
+        return maybe_nodes, maybe_edges
+
+    # Handle all chunks in parallel and collect results
+    tasks = [_process_single_content(c) for c in ordered_chunks]
+    chunk_results = await asyncio.gather(*tasks)
+
+    # Collect all nodes and edges from all chunks
+    all_nodes = defaultdict(list)
+    all_edges = defaultdict(list)
+
+    for maybe_nodes, maybe_edges in chunk_results:
+        # Collect nodes
+        for entity_name, entities in maybe_nodes.items():
+            all_nodes[entity_name].extend(entities)
+
+        # Collect edges with sorted keys for undirected graph
+        for edge_key, edges in maybe_edges.items():
+            sorted_edge_key = tuple(sorted(edge_key))
+            all_edges[sorted_edge_key].extend(edges)
+
+    # Centralized processing of all nodes and edges
+    entities_data = []
+    relationships_data = []
+
+    # Use graph database lock to ensure atomic merges and updates
+    async with graph_db_lock:
+        # Process and update all entities at once
+        for entity_name, entities in all_nodes.items():
+            entity_data = await _merge_nodes_then_upsert(
+                entity_name,
+                entities,
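Note: this hunk is the core of the refactor. _process_single_content no longer merges or upserts anything itself; every worker just returns its (maybe_nodes, maybe_edges), and the merge runs once under graph_db_lock after all chunks finish. A self-contained toy of the fan-out/fan-in pattern (the chunk data and extraction logic are invented for illustration):

    import asyncio
    from collections import defaultdict

    async def _process_single_content(chunk: str):
        # Invented stand-in extraction: one node and one undirected edge per chunk.
        maybe_nodes = {chunk.split()[0]: [{"description": chunk}]}
        maybe_edges = {("B", "A"): [{"description": chunk}]}
        return maybe_nodes, maybe_edges

    async def main():
        ordered_chunks = ["alpha text", "beta text"]
        chunk_results = await asyncio.gather(
            *(_process_single_content(c) for c in ordered_chunks)
        )

        all_nodes, all_edges = defaultdict(list), defaultdict(list)
        for maybe_nodes, maybe_edges in chunk_results:
            for name, entities in maybe_nodes.items():
                all_nodes[name].extend(entities)
            for edge_key, edges in maybe_edges.items():
                # Sorting makes ("B", "A") and ("A", "B") merge into one edge key.
                all_edges[tuple(sorted(edge_key))].extend(edges)

        print(sorted(all_nodes))   # ['alpha', 'beta']
        print(sorted(all_edges))   # [('A', 'B')]

    asyncio.run(main())

Taking graph_db_lock once around a single merge pass, as the hunk does, also removes the per-chunk lock contention of the old design.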
@@ -639,15 +681,13 @@ async def extract_entities(
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            chunk_entities_data.append(entity_data)
+            entities_data.append(entity_data)
 
-            # Process and update relationships
-            for edge_key, edges in maybe_edges.items():
-                # Ensure edge direction consistency
-                sorted_edge_key = tuple(sorted(edge_key))
+        # Process and update all relationships at once
+        for edge_key, edges in all_edges.items():
             edge_data = await _merge_edges_then_upsert(
-                    sorted_edge_key[0],
-                    sorted_edge_key[1],
+                edge_key[0],
+                edge_key[1],
                 edges,
                 knowledge_graph_inst,
                 global_config,
@@ -655,10 +695,10 @@ async def extract_entities(
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            chunk_relationships_data.append(edge_data)
+            relationships_data.append(edge_data)
 
-            # Update vector database (within the same lock to ensure atomicity)
-            if entity_vdb is not None and chunk_entities_data:
+        # Update vector databases with all collected data
+        if entity_vdb is not None and entities_data:
             data_for_vdb = {
                 compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                     "entity_name": dp["entity_name"],
@@ -667,11 +707,11 @@ async def extract_entities(
                     "source_id": dp["source_id"],
                     "file_path": dp.get("file_path", "unknown_source"),
                 }
-                for dp in chunk_entities_data
+                for dp in entities_data
             }
             await entity_vdb.upsert(data_for_vdb)
 
-            if relationships_vdb is not None and chunk_relationships_data:
+        if relationships_vdb is not None and relationships_data:
             data_for_vdb = {
                 compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                     "src_id": dp["src_id"],
@@ -681,17 +721,13 @@ async def extract_entities(
                     "source_id": dp["source_id"],
                     "file_path": dp.get("file_path", "unknown_source"),
                 }
-                for dp in chunk_relationships_data
+                for dp in relationships_data
             }
             await relationships_vdb.upsert(data_for_vdb)
 
-            # Update counters
-            total_entities_count += len(chunk_entities_data)
-            total_relations_count += len(chunk_relationships_data)
-
-    # Handle all chunks in parallel
-    tasks = [_process_single_content(c) for c in ordered_chunks]
-    await asyncio.gather(*tasks)
+    # Update total counts
+    total_entities_count = len(entities_data)
+    total_relations_count = len(relationships_data)
 
     log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
     logger.info(log_message)
@@ -967,7 +967,7 @@ async def use_llm_func_with_cache(
         res: str = await use_llm_func(input_text, **kwargs)
 
         # Save to cache
-        logger.info(f"Saving LLM cache for {arg_hash}")
+        logger.info(f" == LLM cache == saving {arg_hash}")
         await save_to_cache(
             llm_response_cache,
             CacheData(
@@ -166,7 +166,7 @@ export default function PipelineStatusDialog({
           {/* Latest Message */}
           <div className="space-y-2">
             <div className="text-sm font-medium">{t('documentPanel.pipelineStatus.latestMessage')}:</div>
-            <div className="font-mono text-sm rounded-md bg-zinc-800 text-zinc-100 p-3">
+            <div className="font-mono text-xs rounded-md bg-zinc-800 text-zinc-100 p-3">
               {status?.latest_message || '-'}
             </div>
           </div>
@@ -177,7 +177,7 @@ export default function PipelineStatusDialog({
             <div
               ref={historyRef}
               onScroll={handleScroll}
-              className="font-mono text-sm rounded-md bg-zinc-800 text-zinc-100 p-3 overflow-y-auto min-h-[7.5em] max-h-[40vh]"
+              className="font-mono text-xs rounded-md bg-zinc-800 text-zinc-100 p-3 overflow-y-auto min-h-[7.5em] max-h-[40vh]"
             >
               {status?.history_messages?.length ? (
                 status.history_messages.map((msg, idx) => (