mirror of
				https://github.com/infiniflow/ragflow.git
				synced 2025-10-25 23:09:20 +00:00 
			
		
		
		
	 c6e723f2ee
			
		
	
	
		c6e723f2ee
		
			
		
	
	
	
	
		
			
			### What problem does this PR solve? #2270 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
		
			
				
	
	
		
			130 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			130 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) 2024 Microsoft Corporation.
 | |
| # Licensed under the MIT License
 | |
| """
 | |
| Reference:
 | |
|  - [graphrag](https://github.com/microsoft/graphrag)
 | |
| """
 | |
| 
 | |
| import json
 | |
| import logging
 | |
| import re
 | |
| import traceback
 | |
| from dataclasses import dataclass
 | |
| from typing import Any, List, Callable
 | |
| import networkx as nx
 | |
| import pandas as pd
 | |
| from graphrag import leiden
 | |
| from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
 | |
| from graphrag.leiden import add_community_info2graph
 | |
| from rag.llm.chat_model import Base as CompletionLLM
 | |
| from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
 | |
| from rag.utils import num_tokens_from_string
 | |
| from timeit import default_timer as timer
 | |
| 
 | |
| log = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
@dataclass
class CommunityReportsResult:
    """Container for the reports produced by one extractor run.

    Both lists are filled in step by the extractor: entry ``i`` of
    ``output`` is the rendered text of entry ``i`` of ``structured_output``.
    """

    # Rendered (Markdown-style) report text, one string per community report.
    output: list[str]
    # Parsed report dictionaries, one per community report.
    structured_output: list[dict]
 | |
| 
 | |
| 
 | |
| class CommunityReportsExtractor:
 | |
|     """Community reports extractor class definition."""
 | |
| 
 | |
|     _llm: CompletionLLM
 | |
|     _extraction_prompt: str
 | |
|     _output_formatter_prompt: str
 | |
|     _on_error: ErrorHandlerFn
 | |
|     _max_report_length: int
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         llm_invoker: CompletionLLM,
 | |
|         extraction_prompt: str | None = None,
 | |
|         on_error: ErrorHandlerFn | None = None,
 | |
|         max_report_length: int | None = None,
 | |
|     ):
 | |
|         """Init method definition."""
 | |
|         self._llm = llm_invoker
 | |
|         self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
 | |
|         self._on_error = on_error or (lambda _e, _s, _d: None)
 | |
|         self._max_report_length = max_report_length or 1500
 | |
| 
 | |
|     def __call__(self, graph: nx.Graph, callback: Callable | None = None):
 | |
|         communities: dict[str, dict[str, List]] = leiden.run(graph, {})
 | |
|         total = sum([len(comm.items()) for _, comm in communities.items()])
 | |
|         relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
 | |
|         res_str = []
 | |
|         res_dict = []
 | |
|         over, token_count = 0, 0
 | |
|         st = timer()
 | |
|         for level, comm in communities.items():
 | |
|             for cm_id, ents in comm.items():
 | |
|                 weight = ents["weight"]
 | |
|                 ents = ents["nodes"]
 | |
|                 ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
 | |
|                 rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
 | |
| 
 | |
|                 prompt_variables = {
 | |
|                     "entity_df": ent_df.to_csv(index_label="id"),
 | |
|                     "relation_df": rela_df.to_csv(index_label="id")
 | |
|                 }
 | |
|                 text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
 | |
|                 gen_conf = {"temperature": 0.3}
 | |
|                 try:
 | |
|                     response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
 | |
|                     token_count += num_tokens_from_string(text + response)
 | |
|                     response = re.sub(r"^[^\{]*", "", response)
 | |
|                     response = re.sub(r"[^\}]*$", "", response)
 | |
|                     print(response)
 | |
|                     response = json.loads(response)
 | |
|                     if not dict_has_keys_with_types(response, [
 | |
|                                 ("title", str),
 | |
|                                 ("summary", str),
 | |
|                                 ("findings", list),
 | |
|                                 ("rating", float),
 | |
|                                 ("rating_explanation", str),
 | |
|                             ]): continue
 | |
|                     response["weight"] = weight
 | |
|                     response["entities"] = ents
 | |
|                 except Exception as e:
 | |
|                     print("ERROR: ", traceback.format_exc())
 | |
|                     self._on_error(e, traceback.format_exc(), None)
 | |
|                     continue
 | |
| 
 | |
|                 add_community_info2graph(graph, ents, response["title"])
 | |
|                 res_str.append(self._get_text_output(response))
 | |
|                 res_dict.append(response)
 | |
|                 over += 1
 | |
|                 if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
 | |
| 
 | |
|         return CommunityReportsResult(
 | |
|             structured_output=res_dict,
 | |
|             output=res_str,
 | |
|         )
 | |
| 
 | |
|     def _get_text_output(self, parsed_output: dict) -> str:
 | |
|         title = parsed_output.get("title", "Report")
 | |
|         summary = parsed_output.get("summary", "")
 | |
|         findings = parsed_output.get("findings", [])
 | |
| 
 | |
|         def finding_summary(finding: dict):
 | |
|             if isinstance(finding, str):
 | |
|                 return finding
 | |
|             return finding.get("summary")
 | |
| 
 | |
|         def finding_explanation(finding: dict):
 | |
|             if isinstance(finding, str):
 | |
|                 return ""
 | |
|             return finding.get("explanation")
 | |
| 
 | |
|         report_sections = "\n\n".join(
 | |
|             f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
 | |
|         )
 | |
|      
 | |
|         return f"# {title}\n\n{summary}\n\n{report_sections}"
 |