Refactor: improve flow of GraphRAG and RAPTOR (#10709)

### What problem does this PR solve?

Improve flow of GraphRAG and RAPTOR.

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei 2025-10-22 09:29:20 +08:00 committed by GitHub
parent acc0f7396e
commit 2d491188b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 17 deletions

View File

@ -579,7 +579,7 @@ def run_graphrag():
sample_document = documents[0] sample_document = documents[0]
document_ids = [document["id"] for document in documents] document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}") logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
@ -648,7 +648,7 @@ def run_raptor():
sample_document = documents[0] sample_document = documents[0]
document_ids = [document["id"] for document in documents] document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}") logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
@ -717,7 +717,7 @@ def run_mindmap():
sample_document = documents[0] sample_document = documents[0]
document_ids = [document["id"] for document in documents] document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}): if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}") logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")

View File

@ -671,9 +671,11 @@ class DocumentService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def _sync_progress(cls, docs:list[dict]): def _sync_progress(cls, docs:list[dict]):
from api.db.services.task_service import TaskService
for d in docs: for d in docs:
try: try:
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time) tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
if not tsks: if not tsks:
continue continue
msg = [] msg = []
@ -791,21 +793,23 @@ class DocumentService(CommonService):
"cancelled": int(cancelled), "cancelled": int(cancelled),
} }
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]): def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", doc_ids=[]):
""" """
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level. You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
Optionally, specify a list of doc_ids to determine which documents participate in the task. Optionally, specify a list of doc_ids to determine which documents participate in the task.
""" """
chunking_config = DocumentService.get_chunking_config(doc["id"]) assert ty in ["graphrag", "raptor", "mindmap"], "type should be graphrag, raptor or mindmap"
chunking_config = DocumentService.get_chunking_config(sample_doc_id["id"])
hasher = xxhash.xxh64() hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()): for field in sorted(chunking_config.keys()):
hasher.update(str(chunking_config[field]).encode("utf-8")) hasher.update(str(chunking_config[field]).encode("utf-8"))
def new_task(): def new_task():
nonlocal doc nonlocal sample_doc_id
return { return {
"id": get_uuid(), "id": get_uuid(),
"doc_id": fake_doc_id if fake_doc_id else doc["id"], "doc_id": sample_doc_id["id"],
"from_page": 100000000, "from_page": 100000000,
"to_page": 100000000, "to_page": 100000000,
"task_type": ty, "task_type": ty,
@ -820,9 +824,9 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[])
task["digest"] = hasher.hexdigest() task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True) bulk_insert_into_db(Task, [task], True)
if ty in ["graphrag", "raptor", "mindmap"]: task["doc_id"] = fake_doc_id
task["doc_ids"] = doc_ids task["doc_ids"] = doc_ids
DocumentService.begin2parse(doc["id"]) DocumentService.begin2parse(sample_doc_id["id"])
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status." assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
return task["id"] return task["id"]

View File

@ -228,9 +228,10 @@ async def collect():
canceled = False canceled = False
if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]: if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]:
task = msg task = msg
if task["task_type"] in ["graphrag", "raptor", "mindmap"] and msg.get("doc_ids", []): if task["task_type"] in ["graphrag", "raptor", "mindmap"]:
task = TaskService.get_task(msg["id"], msg["doc_ids"]) task = TaskService.get_task(msg["id"], msg["doc_ids"])
task["doc_ids"] = msg["doc_ids"] task["doc_id"] = msg["doc_id"]
task["doc_ids"] = msg.get("doc_ids", []) or []
else: else:
task = TaskService.get_task(msg["id"]) task = TaskService.get_task(msg["id"])
@ -1052,12 +1053,12 @@ async def task_manager():
async def main(): async def main():
logging.info(r""" logging.info(r"""
____ __ _ ____ __ _
/ _/___ ____ ____ _____/ /_(_)___ ____ ________ ______ _____ _____ / _/___ ____ ____ _____/ /_(_)___ ____ ________ ______ _____ _____
/ // __ \/ __ `/ _ \/ ___/ __/ / __ \/ __ \ / ___/ _ \/ ___/ | / / _ \/ ___/ / // __ \/ __ `/ _ \/ ___/ __/ / __ \/ __ \ / ___/ _ \/ ___/ | / / _ \/ ___/
_/ // / / / /_/ / __(__ ) /_/ / /_/ / / / / (__ ) __/ / | |/ / __/ / _/ // / / / /_/ / __(__ ) /_/ / /_/ / / / / (__ ) __/ / | |/ / __/ /
/___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/ /___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/
/____/ /____/
""") """)
logging.info(f'RAGFlow version: {get_ragflow_version()}') logging.info(f'RAGFlow version: {get_ragflow_version()}')
settings.init_settings() settings.init_settings()