| 
									
										
										
										
											2024-08-25 18:58:20 +08:00
										 |  |  | # Copyright (c) 2024 Microsoft Corporation. | 
					
						
							|  |  |  | # Licensed under the MIT License | 
					
						
							| 
									
										
										
										
											2024-08-02 18:51:14 +08:00
										 |  |  | """
 | 
					
						
							|  |  |  | Reference: | 
					
						
							|  |  |  |  - [graphrag](https://github.com/microsoft/graphrag) | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import json | 
					
						
							|  |  |  | from dataclasses import dataclass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from graphrag.utils import ErrorHandlerFn, perform_variable_replacements | 
					
						
							|  |  |  | from rag.llm.chat_model import Base as CompletionLLM | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from rag.utils import num_tokens_from_string | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | SUMMARIZE_PROMPT = """
 | 
					
						
							|  |  |  | You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. | 
					
						
							|  |  |  | Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. | 
					
						
							|  |  |  | Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. | 
					
						
							|  |  |  | If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. | 
					
						
							|  |  |  | Make sure it is written in third person, and include the entity names so we the have full context. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ####### | 
					
						
							|  |  |  | -Data- | 
					
						
							|  |  |  | Entities: {entity_name} | 
					
						
							|  |  |  | Description List: {description_list} | 
					
						
							|  |  |  | ####### | 
					
						
							|  |  |  | Output: | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Max token size for input prompts | 
					
						
							|  |  |  | DEFAULT_MAX_INPUT_TOKENS = 4_000 | 
					
						
							|  |  |  | # Max token count for LLM answers | 
					
						
							|  |  |  | DEFAULT_MAX_SUMMARY_LENGTH = 128 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
@dataclass
class SummarizationResult:
    """Result of summarizing the descriptions attached to one graph item.

    Produced by SummarizeExtractor.__call__ for a single entity or a
    (source, target) relationship pair.
    """

    # The entity name, or the (source, target) pair the summary belongs to.
    items: str | tuple[str, str]
    # The merged description text ("" when there was nothing to summarize).
    description: str
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SummarizeExtractor: | 
					
						
							|  |  |  |     """Unipartite graph extractor class definition.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     _llm: CompletionLLM | 
					
						
							|  |  |  |     _entity_name_key: str | 
					
						
							|  |  |  |     _input_descriptions_key: str | 
					
						
							|  |  |  |     _summarization_prompt: str | 
					
						
							|  |  |  |     _on_error: ErrorHandlerFn | 
					
						
							|  |  |  |     _max_summary_length: int | 
					
						
							|  |  |  |     _max_input_tokens: int | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         llm_invoker: CompletionLLM, | 
					
						
							|  |  |  |         entity_name_key: str | None = None, | 
					
						
							|  |  |  |         input_descriptions_key: str | None = None, | 
					
						
							|  |  |  |         summarization_prompt: str | None = None, | 
					
						
							|  |  |  |         on_error: ErrorHandlerFn | None = None, | 
					
						
							|  |  |  |         max_summary_length: int | None = None, | 
					
						
							|  |  |  |         max_input_tokens: int | None = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """Init method definition.""" | 
					
						
							|  |  |  |         # TODO: streamline construction | 
					
						
							|  |  |  |         self._llm = llm_invoker | 
					
						
							|  |  |  |         self._entity_name_key = entity_name_key or "entity_name" | 
					
						
							|  |  |  |         self._input_descriptions_key = input_descriptions_key or "description_list" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT | 
					
						
							|  |  |  |         self._on_error = on_error or (lambda _e, _s, _d: None) | 
					
						
							|  |  |  |         self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH | 
					
						
							|  |  |  |         self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         items: str | tuple[str, str], | 
					
						
							|  |  |  |         descriptions: list[str], | 
					
						
							|  |  |  |     ) -> SummarizationResult: | 
					
						
							|  |  |  |         """Call method definition.""" | 
					
						
							|  |  |  |         result = "" | 
					
						
							|  |  |  |         if len(descriptions) == 0: | 
					
						
							|  |  |  |             result = "" | 
					
						
							|  |  |  |         if len(descriptions) == 1: | 
					
						
							|  |  |  |             result = descriptions[0] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             result = self._summarize_descriptions(items, descriptions) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return SummarizationResult( | 
					
						
							|  |  |  |             items=items, | 
					
						
							|  |  |  |             description=result or "", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _summarize_descriptions( | 
					
						
							|  |  |  |         self, items: str | tuple[str, str], descriptions: list[str] | 
					
						
							|  |  |  |     ) -> str: | 
					
						
							|  |  |  |         """Summarize descriptions into a single description.""" | 
					
						
							|  |  |  |         sorted_items = sorted(items) if isinstance(items, list) else items | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Safety check, should always be a list | 
					
						
							|  |  |  |         if not isinstance(descriptions, list): | 
					
						
							|  |  |  |             descriptions = [descriptions] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Iterate over descriptions, adding all until the max input tokens is reached | 
					
						
							|  |  |  |         usable_tokens = self._max_input_tokens - num_tokens_from_string( | 
					
						
							|  |  |  |             self._summarization_prompt | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         descriptions_collected = [] | 
					
						
							|  |  |  |         result = "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for i, description in enumerate(descriptions): | 
					
						
							|  |  |  |             usable_tokens -= num_tokens_from_string(description) | 
					
						
							|  |  |  |             descriptions_collected.append(description) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # If buffer is full, or all descriptions have been added, summarize | 
					
						
							|  |  |  |             if (usable_tokens < 0 and len(descriptions_collected) > 1) or ( | 
					
						
							|  |  |  |                 i == len(descriptions) - 1 | 
					
						
							|  |  |  |             ): | 
					
						
							|  |  |  |                 # Calculate result (final or partial) | 
					
						
							|  |  |  |                 result = await self._summarize_descriptions_with_llm( | 
					
						
							|  |  |  |                     sorted_items, descriptions_collected | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # If we go for another loop, reset values to new | 
					
						
							|  |  |  |                 if i != len(descriptions) - 1: | 
					
						
							|  |  |  |                     descriptions_collected = [result] | 
					
						
							|  |  |  |                     usable_tokens = ( | 
					
						
							|  |  |  |                         self._max_input_tokens | 
					
						
							|  |  |  |                         - num_tokens_from_string(self._summarization_prompt) | 
					
						
							|  |  |  |                         - num_tokens_from_string(result) | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return result | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _summarize_descriptions_with_llm( | 
					
						
							|  |  |  |         self, items: str | tuple[str, str] | list[str], descriptions: list[str] | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """Summarize descriptions using the LLM.""" | 
					
						
							|  |  |  |         variables = { | 
					
						
							|  |  |  |                         self._entity_name_key: json.dumps(items), | 
					
						
							|  |  |  |                         self._input_descriptions_key: json.dumps(sorted(descriptions)), | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |         text = perform_variable_replacements(self._summarization_prompt, variables=variables) | 
					
						
							|  |  |  |         return self._llm.chat("", [{"role": "user", "content": text}]) |