# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

import logging
import re
import traceback
from dataclasses import dataclass
from timeit import default_timer as timer
from typing import Any, Callable, Mapping

import networkx as nx
import tiktoken

from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
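
# Illustrative sketch of the record format the extraction prompt asks the
# model to emit, using the default delimiters above (the entity and
# relationship values here are hypothetical):
#   ("entity"<|>MICROSOFT<|>ORGANIZATION<|>A technology company)##
#   ("relationship"<|>MICROSOFT<|>OPENAI<|>Partnership on AI<|>9)<|COMPLETE|>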


@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    output: nx.Graph
    source_docs: dict[Any, Any]


class GraphExtractor:
    """Unipartite graph extractor class definition."""

    _llm: CompletionLLM
    _join_descriptions: bool
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _entity_types_key: str
    _input_text_key: str
    _completion_delimiter_key: str
    _entity_name_key: str
    _input_descriptions_key: str
    _extraction_prompt: str
    _summarization_prompt: str
    _loop_args: dict[str, Any]
    _max_gleanings: int
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        input_text_key: str | None = None,
        entity_types_key: str | None = None,
        completion_delimiter_key: str | None = None,
        join_descriptions=True,
        encoding_model: str | None = None,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._join_descriptions = join_descriptions
        self._input_text_key = input_text_key or "input_text"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = (
            completion_delimiter_key or "completion_delimiter"
        )
        self._entity_types_key = entity_types_key or "entity_types"
        self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
        self._max_gleanings = (
            max_gleanings
            if max_gleanings is not None
            else ENTITY_EXTRACTION_MAX_GLEANINGS
        )
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)

        # Construct the looping arguments
        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
        yes = encoding.encode("YES")
        no = encoding.encode("NO")
        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
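        # Note: the logit bias above steers the continuation check in
        # _process_document toward answering with a single "YES" or "NO" token.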
    def __call__(
        self,
        texts: list[str],
        prompt_variables: dict[str, Any] | None = None,
        callback: Callable | None = None,
    ) -> GraphExtractionResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}
        all_records: dict[int, str] = {}
        source_doc_map: dict[int, str] = {}

        # Wire defaults into the prompt variables
        prompt_variables = {
            **prompt_variables,
            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
            or DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: prompt_variables.get(
                self._completion_delimiter_key
            )
            or DEFAULT_COMPLETION_DELIMITER,
            self._entity_types_key: ",".join(
                prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
            ),
        }

        st = timer()
        total = len(texts)
        total_token_count = 0
        for doc_index, text in enumerate(texts):
            try:
                # Invoke the entity extraction
                result, token_count = self._process_document(text, prompt_variables)
                source_doc_map[doc_index] = text
                all_records[doc_index] = result
                total_token_count += token_count
                if callback:
                    callback(msg=f"{doc_index + 1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
            except Exception as e:
                if callback:
                    callback(msg="Knowledge graph extraction error:{}".format(str(e)))
                logging.exception("error extracting graph")
                self._on_error(
                    e,
                    traceback.format_exc(),
                    {
                        "doc_index": doc_index,
                        "text": text,
                    },
                )

        output = self._process_results(
            all_records,
            prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
            prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
        )

        return GraphExtractionResult(
            output=output,
            source_docs=source_doc_map,
        )

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> tuple[str, int]:
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        text = perform_variable_replacements(self._extraction_prompt, variables=variables)
        gen_conf = {"temperature": 0.3}
        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        token_count = num_tokens_from_string(text + response)

        results = response or ""
        history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]

        # Repeat to ensure we maximize entity count
        for i in range(self._max_gleanings):
            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
            history.append({"role": "user", "content": text})
            response = self._llm.chat("", history, gen_conf)
            if response.find("**ERROR**") >= 0:
                raise Exception(response)
            results += response or ""

            # if this is the final glean, don't bother updating the continuation flag
            if i >= self._max_gleanings - 1:
                break
            history.append({"role": "assistant", "content": response})
            history.append({"role": "user", "content": LOOP_PROMPT})
            continuation = self._llm.chat("", history, self._loop_args)
            if continuation != "YES":
                break

        return results, token_count
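
    # The gleaning loop above grows `history` roughly as: system (rendered
    # extraction prompt), assistant (first pass), then repeated user
    # CONTINUE_PROMPT / assistant additions, with LOOP_PROMPT eliciting the
    # one-token YES/NO decision on whether to keep going.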

    def _process_results(
        self,
        results: dict[int, str],
        tuple_delimiter: str,
        record_delimiter: str,
    ) -> nx.Graph:
        """Parse the result string to create an undirected unipartite graph.

        Args:
            - results - dict of results from the extraction chain
            - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
            - record_delimiter - delimiter between records, default is '##'
        Returns:
            - output - the constructed undirected unipartite networkx graph
        """
        graph = nx.Graph()
        for source_doc_id, extracted_data in results.items():
            records = [r.strip() for r in extracted_data.split(record_delimiter)]

            for record in records:
                # strip the enclosing parentheses, then split into attributes
                record = re.sub(r"^\(|\)$", "", record.strip())
                record_attributes = record.split(tuple_delimiter)

                if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
                    # add this record as a node in the graph
                    entity_name = clean_str(record_attributes[1].upper())
                    entity_type = clean_str(record_attributes[2].upper())
                    entity_description = clean_str(record_attributes[3])

                    if entity_name in graph.nodes():
                        node = graph.nodes[entity_name]
                        if self._join_descriptions:
                            node["description"] = "\n".join(
                                list({
                                    *_unpack_descriptions(node),
                                    entity_description,
                                })
                            )
                        else:
                            if len(entity_description) > len(node["description"]):
                                node["description"] = entity_description
                        node["source_id"] = ", ".join(
                            list({
                                *_unpack_source_ids(node),
                                str(source_doc_id),
                            })
                        )
                        node["entity_type"] = (
                            entity_type if entity_type != "" else node["entity_type"]
                        )
                    else:
                        graph.add_node(
                            entity_name,
                            entity_type=entity_type,
                            description=entity_description,
                            source_id=str(source_doc_id),
                            weight=1,
                        )

                if (
                    record_attributes[0] == '"relationship"'
                    and len(record_attributes) >= 5
                ):
                    # add this record as an edge
                    source = clean_str(record_attributes[1].upper())
                    target = clean_str(record_attributes[2].upper())
                    edge_description = clean_str(record_attributes[3])
                    edge_source_id = clean_str(str(source_doc_id))
                    # record_attributes come from splitting a string, so an
                    # isinstance(..., numbers.Number) check can never match;
                    # parse the relationship strength defensively instead
                    try:
                        weight = float(record_attributes[-1])
                    except ValueError:
                        weight = 1.0
                    if source not in graph.nodes():
                        graph.add_node(
                            source,
                            entity_type="",
                            description="",
                            source_id=edge_source_id,
                            weight=1,
                        )
                    if target not in graph.nodes():
                        graph.add_node(
                            target,
                            entity_type="",
                            description="",
                            source_id=edge_source_id,
                            weight=1,
                        )
                    if graph.has_edge(source, target):
                        edge_data = graph.get_edge_data(source, target)
                        if edge_data is not None:
                            weight += edge_data["weight"]
                            if self._join_descriptions:
                                edge_description = "\n".join(
                                    list({
                                        *_unpack_descriptions(edge_data),
                                        edge_description,
                                    })
                                )
                            edge_source_id = ", ".join(
                                list({
                                    *_unpack_source_ids(edge_data),
                                    str(source_doc_id),
                                })
                            )
                    graph.add_edge(
                        source,
                        target,
                        weight=weight,
                        description=edge_description,
                        source_id=edge_source_id,
                    )

        # rank each node by its degree
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
        return graph


def _unpack_descriptions(data: Mapping) -> list[str]:
    value = data.get("description", None)
    return [] if value is None else value.split("\n")
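
# e.g. _unpack_descriptions({"description": "desc A\ndesc B"}) -> ["desc A", "desc B"]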


def _unpack_source_ids(data: Mapping) -> list[str]:
    value = data.get("source_id", None)
    return [] if value is None else value.split(", ")
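
# e.g. _unpack_source_ids({"source_id": "0, 3"}) -> ["0", "3"]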
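

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module's API: it
    # wires GraphExtractor to a canned, duck-typed stand-in for the chat model
    # (real callers pass a configured rag.llm.chat_model instance). The canned
    # reply is hypothetical but follows the record format parsed above.
    # Assumes the tiktoken "cl100k_base" encoding files are available locally.
    class _CannedLLM:
        def chat(self, system: str, history: list[dict], gen_conf: dict) -> str:
            return (
                '("entity"<|>MICROSOFT<|>ORGANIZATION<|>A technology company)##'
                '("relationship"<|>MICROSOFT<|>OPENAI<|>Partnership on AI<|>9)'
            )

    extractor = GraphExtractor(_CannedLLM(), max_gleanings=0)
    result = extractor(["Microsoft partners with OpenAI."])
    print(result.output.nodes(data=True))
    print(result.output.edges(data=True))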