Feat: add fault-tolerant mechanism to RAPTOR (#11206)
### What problem does this PR solve?

Add a fault-tolerant mechanism to RAPTOR.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent 70a0f081f6
commit 908450509f
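The RAPTOR hunks below implement the fault tolerance in two layers: each LLM call is retried up to three times inside `_chat`, and a cluster that still fails is skipped while an error counter climbs toward a configurable cap (`max_errors`, defaulting to the `RAPTOR_MAX_ERRORS` env var) that aborts the whole run. For orientation, here is a minimal synchronous sketch of that pattern; the class and method names are illustrative only and not the project's actual API.

```python
import os
import time


class FaultTolerantSummarizer:
    """Illustrative sketch: retry each LLM call, skip a failing cluster, abort after too many skips."""

    def __init__(self, max_errors=None):
        # Cap is env-configurable with a small default, mirroring RAPTOR_MAX_ERRORS in the diff below.
        self._max_errors = max(1, max_errors or int(os.environ.get("RAPTOR_MAX_ERRORS", "3")))
        self._error_count = 0

    def _with_retry(self, call, attempts=3):
        last_exc = None
        for attempt in range(attempts):
            try:
                return call()
            except Exception as exc:  # transient LLM/provider failure
                last_exc = exc
                if attempt < attempts - 1:
                    time.sleep(1 + attempt)  # simple linear backoff
        raise last_exc

    def summarize_cluster(self, texts, llm_call):
        """llm_call is any callable that turns a prompt string into a summary string."""
        try:
            return self._with_retry(lambda: llm_call("\n".join(texts)))
        except Exception as exc:
            # Skip this cluster, but abort the task once the cap is reached.
            self._error_count += 1
            if self._error_count >= self._max_errors:
                raise RuntimeError(f"aborted after {self._error_count} errors") from exc
            return None  # cluster skipped, processing continues
```

The real implementation in rag/raptor.py does the same thing with trio and the task-cancellation checks shown in the diff.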
@@ -3,15 +3,9 @@ import os
 import threading
 from typing import Any, Callable
 
-import requests
-
 from common.data_source.config import DocumentSource
 from common.data_source.google_util.constant import GOOGLE_SCOPES
 
-GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
-GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
-DEFAULT_DEVICE_INTERVAL = 5
-
 
 def _get_requested_scopes(source: DocumentSource) -> list[str]:
     """Return the scopes to request, honoring an optional override env var."""
@@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
     return result.get("value")
 
 
-def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
-    if "client_id" in credentials:
-        return credentials["client_id"], credentials.get("client_secret")
-    for key in ("installed", "web"):
-        if key in credentials and isinstance(credentials[key], dict):
-            nested = credentials[key]
-            if "client_id" not in nested:
-                break
-            return nested["client_id"], nested.get("client_secret")
-    raise ValueError("Provided Google OAuth credentials are missing client_id.")
-
-
-def start_device_authorization_flow(
-    credentials: dict[str, Any],
-    source: DocumentSource,
-) -> tuple[dict[str, Any], dict[str, Any]]:
-    client_id, client_secret = _extract_client_info(credentials)
-    data = {
-        "client_id": client_id,
-        "scope": " ".join(_get_requested_scopes(source)),
-    }
-    if client_secret:
-        data["client_secret"] = client_secret
-    resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
-    resp.raise_for_status()
-    payload = resp.json()
-    state = {
-        "client_id": client_id,
-        "client_secret": client_secret,
-        "device_code": payload.get("device_code"),
-        "interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
-    }
-    response_data = {
-        "user_code": payload.get("user_code"),
-        "verification_url": payload.get("verification_url") or payload.get("verification_uri"),
-        "verification_url_complete": payload.get("verification_url_complete")
-        or payload.get("verification_uri_complete"),
-        "expires_in": payload.get("expires_in"),
-        "interval": state["interval"],
-    }
-    return state, response_data
-
-
-def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
-    data = {
-        "client_id": state["client_id"],
-        "device_code": state["device_code"],
-        "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
-    }
-    if state.get("client_secret"):
-        data["client_secret"] = state["client_secret"]
-    resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
-    resp.raise_for_status()
-    return resp.json()
-
-
 def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
     """Launch the standard Google OAuth local-server flow to mint user tokens."""
     from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
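For context, the helpers removed above wrap the OAuth 2.0 device authorization grant (RFC 8628) against the two Google endpoints named in the deleted constants. A rough, self-contained sketch of that flow follows; the client ID, requested scope, and the authorization_pending / slow_down handling are illustrative assumptions, not values or behaviour taken from this repository.

```python
import time

import requests

# Endpoints match the constants removed above; CLIENT_ID and SCOPES are placeholders.
DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
TOKEN_URL = "https://oauth2.googleapis.com/token"
CLIENT_ID = "your-oauth-client-id.apps.googleusercontent.com"
SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]

# Step 1: ask Google for a device code and a user-facing verification code.
start = requests.post(DEVICE_CODE_URL, data={"client_id": CLIENT_ID, "scope": " ".join(SCOPES)}, timeout=15)
start.raise_for_status()
grant = start.json()
print(f"Visit {grant.get('verification_url') or grant.get('verification_uri')} and enter {grant['user_code']}")

# Step 2: poll the token endpoint until the user approves or the code expires.
while True:
    time.sleep(grant.get("interval", 5))
    resp = requests.post(
        TOKEN_URL,
        data={
            "client_id": CLIENT_ID,
            "device_code": grant["device_code"],
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
        },
        timeout=20,
    )
    payload = resp.json()
    if "access_token" in payload:
        break  # tokens issued
    if payload.get("error") == "slow_down":
        grant["interval"] = grant.get("interval", 5) + 5  # back off as the server requests
        continue
    if payload.get("error") != "authorization_pending":
        resp.raise_for_status()  # surface real errors such as access_denied or expired_token
        raise RuntimeError(f"unexpected token response: {payload}")
```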
@@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
     preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
     port = int(preferred_port) if preferred_port else 0
     timeout_secs = _get_oauth_timeout_secs()
-    timeout_message = (
-        f"Google OAuth verification timed out after {timeout_secs} seconds. "
-        "Close any pending consent windows and rerun the connector configuration to try again."
-    )
+    timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
 
     print("Launching Google OAuth flow. A browser window should open shortly.")
     print("If it does not, copy the URL shown in the console into your browser manually.")
@@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
             instructions = [
                 "Google rejected one or more of the requested OAuth scopes.",
                 "Fix options:",
-                " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
-                " (Drive metadata + Admin Directory read scopes), then re-run the flow.",
+                " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
                 " 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
                 " 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
                 " (be aware the connector may lose functionality).",
             ]
             raise RuntimeError("\n".join(instructions)) from warning
         raise
@@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
         client_config = {"web": credentials["web"]}
 
     if client_config is None:
-        raise ValueError(
-            "Provided Google OAuth credentials are missing both tokens and a client configuration."
-        )
+        raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
 
     return _run_local_server_flow(client_config, source)
@@ -114,7 +114,7 @@ class Extractor:
         async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
             out_results = []
             error_count = 0
-            max_errors = 3
+            max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
 
             limiter = trio.Semaphore(max_concurrency)
 
rag/raptor.py (154 changed lines)
@@ -15,27 +15,35 @@
 #
 import logging
 import re
-import umap
+
 import numpy as np
-from sklearn.mixture import GaussianMixture
 import trio
+import umap
+from sklearn.mixture import GaussianMixture
 
 from api.db.services.task_service import has_canceled
 from common.connection_utils import timeout
 from common.exceptions import TaskCanceledException
+from common.token_utils import truncate
 from graphrag.utils import (
-    get_llm_cache,
+    chat_limiter,
     get_embed_cache,
+    get_llm_cache,
     set_embed_cache,
     set_llm_cache,
-    chat_limiter,
 )
-from common.token_utils import truncate
 
 
 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     def __init__(
-        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+        self,
+        max_cluster,
+        llm_model,
+        embd_model,
+        prompt,
+        max_token=512,
+        threshold=0.1,
+        max_errors=3,
     ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
@@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._threshold = threshold
         self._prompt = prompt
         self._max_token = max_token
+        self._max_errors = max(1, max_errors)
+        self._error_count = 0
 
-    @timeout(60*20)
+    @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
-        response = await trio.to_thread.run_sync(
-            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
-        )
+        cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
+        if cached:
+            return cached
 
-        if response:
-            return response
-        response = await trio.to_thread.run_sync(
-            lambda: self._llm_model.chat(system, history, gen_conf)
-        )
-        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
-        if response.find("**ERROR**") >= 0:
-            raise Exception(response)
-        await trio.to_thread.run_sync(
-            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
-        )
-        return response
+        last_exc = None
+        for attempt in range(3):
+            try:
+                response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
+                response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
+                if response.find("**ERROR**") >= 0:
+                    raise Exception(response)
+                await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
+                return response
+            except Exception as exc:
+                last_exc = exc
+                logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
+                if attempt < 2:
+                    await trio.sleep(1 + attempt)
+
+        raise last_exc if last_exc else Exception("LLM chat failed without exception")
 
     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = await trio.to_thread.run_sync(
-            lambda: get_embed_cache(self._embd_model.llm_name, txt)
-        )
+        response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
         if response is not None:
             return response
         embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         n_clusters = np.arange(1, max_clusters)
         bics = []
         for n in n_clusters:
-
             if task_id:
                 if has_canceled(task_id):
                     logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)
 
-        @timeout(60*20)
+        @timeout(60 * 20)
         async def summarize(ck_idx: list[int]):
             nonlocal chunks
 
@@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     raise TaskCanceledException(f"Task {task_id} was cancelled")
 
             texts = [chunks[i][0] for i in ck_idx]
-            len_per_chunk = int(
-                (self._llm_model.max_length - self._max_token) / len(texts)
-            )
-            cluster_content = "\n".join(
-                [truncate(t, max(1, len_per_chunk)) for t in texts]
-            )
-            async with chat_limiter:
-
-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+            len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
+            cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
+            try:
+                async with chat_limiter:
+                    if task_id and has_canceled(task_id):
+                        logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
+                        raise TaskCanceledException(f"Task {task_id} was cancelled")
 
-                cnt = await self._chat(
-                    "You're a helpful assistant.",
-                    [
-                        {
-                            "role": "user",
-                            "content": self._prompt.format(
-                                cluster_content=cluster_content
-                            ),
-                        }
-                    ],
-                    {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
-                )
-                cnt = re.sub(
-                    "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
-                    "",
-                    cnt,
-                )
-                logging.debug(f"SUM: {cnt}")
+                    cnt = await self._chat(
+                        "You're a helpful assistant.",
+                        [
+                            {
+                                "role": "user",
+                                "content": self._prompt.format(cluster_content=cluster_content),
+                            }
+                        ],
+                        {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
+                    )
+                    cnt = re.sub(
+                        "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                        "",
+                        cnt,
+                    )
+                    logging.debug(f"SUM: {cnt}")
 
-            if task_id and has_canceled(task_id):
-                logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
-                raise TaskCanceledException(f"Task {task_id} was cancelled")
+                if task_id and has_canceled(task_id):
+                    logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
+                    raise TaskCanceledException(f"Task {task_id} was cancelled")
 
-            embds = await self._embedding_encode(cnt)
-            chunks.append((cnt, embds))
+                embds = await self._embedding_encode(cnt)
+                chunks.append((cnt, embds))
+            except TaskCanceledException:
+                raise
+            except Exception as exc:
+                self._error_count += 1
+                warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
+                logging.warning(warn_msg)
+                if callback:
+                    callback(msg=warn_msg)
+                if self._error_count >= self._max_errors:
+                    raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
 
         labels = []
         while end - start > 1:
-
             if task_id:
                 if has_canceled(task_id):
                     logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             if len(embeddings) == 2:
                 await summarize([start, start + 1])
                 if callback:
-                    callback(
-                        msg="Cluster one layer: {} -> {}".format(
-                            end - start, len(chunks) - end
-                        )
-                    )
+                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                 labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end
@@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
                     nursery.start_soon(summarize, ck_idx)
 
-            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
-                len(chunks) - end, n_clusters
-            )
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(
-                    msg="Cluster one layer: {} -> {}".format(
-                        end - start, len(chunks) - end
-                    )
-                )
+                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
             start = end
             end = len(chunks)
 
@@ -649,6 +649,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
 
     res = []
    tk_count = 0
+    max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
+
    async def generate(chunks, did):
        nonlocal tk_count, res
        raptor = Raptor(
@@ -658,6 +660,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
             raptor_config["prompt"],
             raptor_config["max_token"],
             raptor_config["threshold"],
+            max_errors=max_errors,
         )
         original_length = len(chunks)
         chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
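Both error caps are read from environment variables with a default of 3, so operators can loosen them without code changes. A hedged example of raising them from Python before the task workers start (the variable names come from the diff; in practice you would more likely export them in the service's shell or compose environment):

```python
import os

# Tolerate more skipped RAPTOR clusters / GraphRAG extraction errors before a task aborts.
os.environ.setdefault("RAPTOR_MAX_ERRORS", "5")
os.environ.setdefault("GRAPHRAG_MAX_ERRORS", "5")
```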