# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

import json
from dataclasses import dataclass

from graphrag.extractor import Extractor
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string

SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities,
please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.

#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""

# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 128


@dataclass
class SummarizationResult:
    """Entity description summarization result."""

    items: str | tuple[str, str]
    description: str


class SummarizeExtractor(Extractor):
    """Entity description summarization extractor."""

    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
    _on_error: ErrorHandlerFn
    _max_summary_length: int
    _max_input_tokens: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        entity_name_key: str | None = None,
        input_descriptions_key: str | None = None,
        summarization_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_summary_length: int | None = None,
        max_input_tokens: int | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._entity_name_key = entity_name_key or "entity_name"
        self._input_descriptions_key = input_descriptions_key or "description_list"
        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS

    def __call__(
        self,
        items: str | tuple[str, str],
        descriptions: list[str],
    ) -> SummarizationResult:
        """Call method definition."""
        # Only call the LLM when there is more than one description to merge;
        # a single description is returned as-is.
        if len(descriptions) == 0:
            result = ""
        elif len(descriptions) == 1:
            result = descriptions[0]
        else:
            result = self._summarize_descriptions(items, descriptions)

        return SummarizationResult(
            items=items,
            description=result or "",
        )

    def _summarize_descriptions(
        self, items: str | tuple[str, str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions into a single description."""
        # Sort entity pairs to normalize order; leave a single name untouched.
        sorted_items = sorted(items) if isinstance(items, (list, tuple)) else items

        # Safety check, should always be a list
        if not isinstance(descriptions, list):
            descriptions = [descriptions]

        # Iterate over descriptions, adding all until the max input tokens is reached
        usable_tokens = self._max_input_tokens - num_tokens_from_string(
            self._summarization_prompt
        )
        descriptions_collected = []
        result = ""

        for i, description in enumerate(descriptions):
            usable_tokens -= num_tokens_from_string(description)
            descriptions_collected.append(description)

            # If buffer is full, or all descriptions have been added, summarize
            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
                i == len(descriptions) - 1
            ):
                # Calculate result (final or partial)
                result = self._summarize_descriptions_with_llm(
                    sorted_items, descriptions_collected
                )

                # If we go for another loop, reset values to new
                if i != len(descriptions) - 1:
                    descriptions_collected = [result]
                    usable_tokens = (
                        self._max_input_tokens
                        - num_tokens_from_string(self._summarization_prompt)
                        - num_tokens_from_string(result)
                    )

        return result
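
    # Illustrative walk-through of the loop above (made-up numbers, not from the
    # original code): with max_input_tokens=4_000 and a ~100-token prompt,
    # usable_tokens starts near 3_900. Once the collected descriptions overrun
    # the budget, a partial summary is produced, the buffer restarts with that
    # summary, and the budget is recomputed, so arbitrarily long description
    # lists fold into a single summary.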

    def _summarize_descriptions_with_llm(
        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions using the LLM."""
        variables = {
            self._entity_name_key: json.dumps(items),
            self._input_descriptions_key: json.dumps(sorted(descriptions)),
        }
        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
        return self._chat("", [{"role": "user", "content": text}])
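

# Usage sketch (illustrative only, not part of the original module). Assumes
# `llm` is any chat model implementing the CompletionLLM interface; the entity
# name and descriptions are made-up examples.
#
#     extractor = SummarizeExtractor(llm_invoker=llm)
#     merged = extractor(
#         "RAGFlow",
#         [
#             "RAGFlow is an open-source RAG engine.",
#             "RAGFlow combines retrieval with LLM-based generation.",
#         ],
#     )
#     print(merged.description)  # one consolidated third-person description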