diff --git a/common/data_source/google_util/oauth_flow.py b/common/data_source/google_util/oauth_flow.py index 7e39e5283..e6ba58274 100644 --- a/common/data_source/google_util/oauth_flow.py +++ b/common/data_source/google_util/oauth_flow.py @@ -3,15 +3,9 @@ import os import threading from typing import Any, Callable -import requests - from common.data_source.config import DocumentSource from common.data_source.google_util.constant import GOOGLE_SCOPES -GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code" -GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token" -DEFAULT_DEVICE_INTERVAL = 5 - def _get_requested_scopes(source: DocumentSource) -> list[str]: """Return the scopes to request, honoring an optional override env var.""" @@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag return result.get("value") -def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]: - if "client_id" in credentials: - return credentials["client_id"], credentials.get("client_secret") - for key in ("installed", "web"): - if key in credentials and isinstance(credentials[key], dict): - nested = credentials[key] - if "client_id" not in nested: - break - return nested["client_id"], nested.get("client_secret") - raise ValueError("Provided Google OAuth credentials are missing client_id.") - - -def start_device_authorization_flow( - credentials: dict[str, Any], - source: DocumentSource, -) -> tuple[dict[str, Any], dict[str, Any]]: - client_id, client_secret = _extract_client_info(credentials) - data = { - "client_id": client_id, - "scope": " ".join(_get_requested_scopes(source)), - } - if client_secret: - data["client_secret"] = client_secret - resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15) - resp.raise_for_status() - payload = resp.json() - state = { - "client_id": client_id, - "client_secret": client_secret, - "device_code": payload.get("device_code"), - "interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL), - } - response_data = { - "user_code": payload.get("user_code"), - "verification_url": payload.get("verification_url") or payload.get("verification_uri"), - "verification_url_complete": payload.get("verification_url_complete") - or payload.get("verification_uri_complete"), - "expires_in": payload.get("expires_in"), - "interval": state["interval"], - } - return state, response_data - - -def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]: - data = { - "client_id": state["client_id"], - "device_code": state["device_code"], - "grant_type": "urn:ietf:params:oauth:grant-type:device_code", - } - if state.get("client_secret"): - data["client_secret"] = state["client_secret"] - resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20) - resp.raise_for_status() - return resp.json() - - def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]: """Launch the standard Google OAuth local-server flow to mint user tokens.""" from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore @@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT") port = int(preferred_port) if preferred_port else 0 timeout_secs = _get_oauth_timeout_secs() - timeout_message = ( - f"Google OAuth verification timed out after {timeout_secs} seconds. " - "Close any pending consent windows and rerun the connector configuration to try again." 
-    )
+    timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
 
     print("Launching Google OAuth flow. A browser window should open shortly.")
     print("If it does not, copy the URL shown in the console into your browser manually.")
@@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
             instructions = [
                 "Google rejected one or more of the requested OAuth scopes.",
                 "Fix options:",
-                " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
-                "    (Drive metadata + Admin Directory read scopes), then re-run the flow.",
+                " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
                 " 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
-                " 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
-                "    (be aware the connector may lose functionality).",
             ]
             raise RuntimeError("\n".join(instructions)) from warning
         raise
@@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
         client_config = {"web": credentials["web"]}
 
     if client_config is None:
-        raise ValueError(
-            "Provided Google OAuth credentials are missing both tokens and a client configuration."
-        )
+        raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
 
     return _run_local_server_flow(client_config, source)
diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py
index 1df38ed1c..495e562ed 100644
--- a/graphrag/general/extractor.py
+++ b/graphrag/general/extractor.py
@@ -114,7 +114,7 @@ class Extractor:
     async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
         out_results = []
         error_count = 0
-        max_errors = 3
+        max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
 
         limiter = trio.Semaphore(max_concurrency)
 
diff --git a/rag/raptor.py b/rag/raptor.py
index e6efe3504..a455d0127 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -15,27 +15,35 @@
 #
 import logging
 import re
-import umap
+
 import numpy as np
-from sklearn.mixture import GaussianMixture
 import trio
+import umap
+from sklearn.mixture import GaussianMixture
 
 from api.db.services.task_service import has_canceled
 from common.connection_utils import timeout
 from common.exceptions import TaskCanceledException
+from common.token_utils import truncate
 from graphrag.utils import (
-    get_llm_cache,
+    chat_limiter,
     get_embed_cache,
+    get_llm_cache,
     set_embed_cache,
     set_llm_cache,
-    chat_limiter,
 )
-from common.token_utils import truncate
 
 
 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     def __init__(
-        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+        self,
+        max_cluster,
+        llm_model,
+        embd_model,
+        prompt,
+        max_token=512,
+        threshold=0.1,
+        max_errors=3,
     ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
@@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._threshold = threshold
         self._prompt = prompt
         self._max_token = max_token
+        self._max_errors = max(1, max_errors)
+        self._error_count = 0
 
-    @timeout(60*20)
+    @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
-        response = await trio.to_thread.run_sync(
-            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
-        )
+        cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
+        if cached:
+            return cached
 
-        if response:
-            return response
-        response = await trio.to_thread.run_sync(
-            lambda: self._llm_model.chat(system, history, gen_conf)
-        )
-        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
-        if response.find("**ERROR**") >= 0:
-            raise Exception(response)
-        await trio.to_thread.run_sync(
-            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
-        )
-        return response
+        last_exc = None
+        for attempt in range(3):
+            try:
+                response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
+                response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
+                if response.find("**ERROR**") >= 0:
+                    raise Exception(response)
+                await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
+                return response
+            except Exception as exc:
+                last_exc = exc
+                logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
+                if attempt < 2:
+                    await trio.sleep(1 + attempt)
+
+        raise last_exc if last_exc else Exception("LLM chat failed without exception")
 
     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = await trio.to_thread.run_sync(
-            lambda: get_embed_cache(self._embd_model.llm_name, txt)
-        )
+        response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
         if response is not None:
             return response
         embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         n_clusters = np.arange(1, max_clusters)
         bics = []
         for n in n_clusters:
-
             if task_id:
                 if has_canceled(task_id):
                     logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)
 
-        @timeout(60*20)
+        @timeout(60 * 20)
         async def summarize(ck_idx: list[int]):
             nonlocal chunks
 
@@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     raise TaskCanceledException(f"Task {task_id} was cancelled")
 
             texts = [chunks[i][0] for i in ck_idx]
-            len_per_chunk = int(
-                (self._llm_model.max_length - self._max_token) / len(texts)
-            )
-            cluster_content = "\n".join(
-                [truncate(t, max(1, len_per_chunk)) for t in texts]
-            )
-            async with chat_limiter:
+            len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
+            cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
+            try:
+                async with chat_limiter:
+                    if task_id and has_canceled(task_id):
+                        logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
+                        raise TaskCanceledException(f"Task {task_id} was cancelled")
 
-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+                    cnt = await self._chat(
+                        "You're a helpful assistant.",
+                        [
+                            {
+                                "role": "user",
+                                "content": self._prompt.format(cluster_content=cluster_content),
+                            }
+                        ],
+                        {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
+                    )
+                    cnt = re.sub(
+                        "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                        "",
+                        cnt,
+                    )
+                    logging.debug(f"SUM: {cnt}")
 
-                cnt = await self._chat(
-                    "You're a helpful assistant.",
-                    [
-                        {
-                            "role": "user",
"role": "user", - "content": self._prompt.format( - cluster_content=cluster_content - ), - } - ], - {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 - ) - cnt = re.sub( - "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", - "", - cnt, - ) - logging.debug(f"SUM: {cnt}") + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") + raise TaskCanceledException(f"Task {task_id} was cancelled") - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - - embds = await self._embedding_encode(cnt) - chunks.append((cnt, embds)) + embds = await self._embedding_encode(cnt) + chunks.append((cnt, embds)) + except TaskCanceledException: + raise + except Exception as exc: + self._error_count += 1 + warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}" + logging.warning(warn_msg) + if callback: + callback(msg=warn_msg) + if self._error_count >= self._max_errors: + raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc labels = [] while end - start > 1: - if task_id: if has_canceled(task_id): logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.") @@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: if len(embeddings) == 2: await summarize([start, start + 1]) if callback: - callback( - msg="Cluster one layer: {} -> {}".format( - end - start, len(chunks) - end - ) - ) + callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) labels.extend([0, 0]) layers.append((end, len(chunks))) start = end @@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: nursery.start_soon(summarize, ck_idx) - assert len(chunks) - end == n_clusters, "{} vs. {}".format( - len(chunks) - end, n_clusters - ) + assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) labels.extend(lbls) layers.append((end, len(chunks))) if callback: - callback( - msg="Cluster one layer: {} -> {}".format( - end - start, len(chunks) - end - ) - ) + callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) start = end end = len(chunks) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index d926415e5..a183bf0cf 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -649,6 +649,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si res = [] tk_count = 0 + max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3)) + async def generate(chunks, did): nonlocal tk_count, res raptor = Raptor( @@ -658,6 +660,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si raptor_config["prompt"], raptor_config["max_token"], raptor_config["threshold"], + max_errors=max_errors, ) original_length = len(chunks) chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])