# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

import json
import logging
import re
from dataclasses import dataclass
from typing import Callable

import networkx as nx
import pandas as pd
import trio

from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from graphrag.utils import chat_limiter, dict_has_keys_with_types, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string


@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    output: list[str]
    structured_output: list[dict]


class CommunityReportsExtractor(Extractor):
    """Community reports extractor class definition."""

    _extraction_prompt: str
    _output_formatter_prompt: str
    _max_report_length: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        max_report_length: int | None = None,
    ):
        """Init method definition."""
        super().__init__(llm_invoker)
        self._llm = llm_invoker
        self._extraction_prompt = COMMUNITY_REPORT_PROMPT
        self._max_report_length = max_report_length or 1500

    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        # Use each node's degree as its "rank" attribute.
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

        # Assumed community structure: {level: {community_id: {"weight": ..., "nodes": [...]}}}
        communities: dict[str, dict[str, list]] = leiden.run(graph, {})
        total = sum(len(comm) for comm in communities.values())
        res_str = []
        res_dict = []
        over, token_count = 0, 0

        async def extract_community_report(community):
            nonlocal res_str, res_dict, over, token_count
            cm_id, cm = community
            weight = cm["weight"]
            ents = cm["nodes"]
            if len(ents) < 2:
                return
            # One row per entity: its name plus the description stored on the graph node.
            ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
            ent_df = pd.DataFrame(ent_list)

            # Collect intra-community relations, capped at 10,000 to bound prompt size.
            rela_list = []
            k = 0
            for i in range(0, len(ents)):
                if k >= 10000:
                    break
                for j in range(i + 1, len(ents)):
                    if k >= 10000:
                        break
                    edge = graph.get_edge_data(ents[i], ents[j])
                    if edge is None:
                        continue
                    rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
                    k += 1
            rela_df = pd.DataFrame(rela_list)

            prompt_variables = {
                "entity_df": ent_df.to_csv(index_label="id"),
                "relation_df": rela_df.to_csv(index_label="id"),
            }
            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
            gen_conf = {"temperature": 0.3}
            # Run the blocking chat call in a worker thread; give up after 120 s.
            async with chat_limiter:
                try:
                    with trio.move_on_after(120) as cancel_scope:
                        response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
                    if cancel_scope.cancelled_caught:
                        logging.warning("extract_community_report._chat timeout, skipping...")
                        return
                except Exception as e:
                    logging.error(f"extract_community_report._chat failed: {e}")
                    return
            token_count += num_tokens_from_string(text + response)
            # Strip everything before the first "{" and after the last "}", then
            # collapse doubled braces, so the reply can be parsed as JSON.
            response = re.sub(r"^[^\{]*", "", response)
            response = re.sub(r"[^\}]*$", "", response)
            response = re.sub(r"\{\{", "{", response)
            response = re.sub(r"\}\}", "}", response)
            logging.debug(response)
            try:
                response = json.loads(response)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse JSON response: {e}")
                logging.error(f"Response content: {response}")
                return
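            # Illustrative expected shape, matching the key/type check below
            # (finding keys are those read by _get_text_output):
            # {
            #     "title": "...",
            #     "summary": "...",
            #     "findings": [{"summary": "...", "explanation": "..."}],
            #     "rating": 7.5,
            #     "rating_explanation": "..."
            # }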
            if not dict_has_keys_with_types(response, [
                ("title", str),
                ("summary", str),
                ("findings", list),
                ("rating", float),
                ("rating_explanation", str),
            ]):
                return
            response["weight"] = weight
            response["entities"] = ents
            add_community_info2graph(graph, ents, response["title"])
            res_str.append(self._get_text_output(response))
            res_dict.append(response)
            over += 1
            if callback:
                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")

        st = trio.current_time()
        async with trio.open_nursery() as nursery:
            for level, comm in communities.items():
                logging.info(f"Level {level}: Community: {len(comm)}")
                for community in comm.items():
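                    # Pass the function and its argument to start_soon separately
                    # rather than via a lambda: a lambda would capture the loop
                    # variable `community` by reference, so every task would end
                    # up processing the last community instead of its own.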
                    nursery.start_soon(extract_community_report, community)
        if callback:
            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    def _get_text_output(self, parsed_output: dict) -> str:
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding: dict):
            if isinstance(finding, str):
                return finding
            return finding.get("summary")

        def finding_explanation(finding: dict):
            if isinstance(finding, str):
                return ""
            return finding.get("explanation")

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"
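# Usage sketch (illustrative, not part of the module): assumes a configured
# CompletionLLM instance `llm` and a networkx graph whose nodes carry a
# "description" attribute, as __call__ expects.
#
#     extractor = CommunityReportsExtractor(llm)
#     result = trio.run(extractor, graph)
#     print(result.structured_output[0]["title"])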