mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-06 20:07:04 +00:00
### What problem does this PR solve? This feature is primarily ported from the [Onyx](https://github.com/onyx-dot-app/onyx) project with necessary modifications. Thanks for such a brilliant project. Minor: consistently use `google_drive` rather than `google_driver`. <img width="566" height="731" alt="image" src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4" /> <img width="904" height="830" alt="image" src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6" /> <img width="911" height="869" alt="image" src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22" /> <img width="947" height="323" alt="image" src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d" /> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
158 lines
6.8 KiB
Python
158 lines
6.8 KiB
Python
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from google.auth.transport.requests import Request # type: ignore
|
|
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
|
|
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
|
|
|
|
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
|
|
from common.data_source.google_util.constant import (
|
|
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
|
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
|
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
|
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
|
GOOGLE_SCOPES,
|
|
GoogleOAuthAuthenticationMethod,
|
|
)
|
|
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
|
|
|
|
|
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
    """Serialize OAuth credentials to JSON without the client id and secret.

    The client id and secret should never be persisted anywhere but the
    environment, so both fields are dropped from the serialized payload.

    Returns a string of serialized json.
    """
    payload: dict[str, Any] = json.loads(oauth_creds.to_json())

    # Drop the fields that must only ever come from the environment.
    for secret_field in ("client_id", "client_secret"):
        payload.pop(secret_field, None)

    return json.dumps(payload)
|
|
|
|
|
|
def get_google_creds(
    credentials: dict[str, str],
    source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
    """Build usable Google credentials from a stored credential dict.

    Checks for two different types of credentials.
    (1) A credential which holds a token acquired via a user going through
        the Google OAuth flow (key ``DB_CREDENTIALS_DICT_TOKEN_KEY``).
    (2) A credential which holds a service account key JSON file, which
        can then be used to impersonate any user in the workspace
        (key ``DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY``).

    Args:
        credentials: Stored credential mapping. The token / service account
            entries are JSON strings.
        source: Document source; selects the scopes via ``GOOGLE_SCOPES`` and
            is passed through to the OAuth helpers.

    Returns:
        A tuple where:
            The first element is the requested credentials.
            The second element is a new credentials dict that the caller
            should write back to the db, or None when nothing needs to be
            persisted. It is set when the token dict had to be regenerated or
            when the refresh token changed while loading credentials.

    Raises:
        PermissionError: If the OAuth credentials are incomplete and cannot be
            completed, if the OAuth credentials cannot be loaded/refreshed, if
            the service account credentials are invalid, or if ``credentials``
            contains neither supported credential type.
    """
    new_creds_dict = None

    if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
        # OAUTH
        authentication_method: str = credentials.get(
            DB_CREDENTIALS_AUTHENTICATION_METHOD,
            GoogleOAuthAuthenticationMethod.UPLOADED,
        )
        is_interactive = authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE

        credentials_dict = json.loads(credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])

        regenerated_from_client_secret = False
        if "client_id" not in credentials_dict or "client_secret" not in credentials_dict or "refresh_token" not in credentials_dict:
            try:
                credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
            except Exception as exc:
                raise PermissionError(
                    "Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
                ) from exc
            # Remember that the stored dict is stale so the caller persists the new one.
            regenerated_from_client_secret = True

        # only send what get_google_oauth_creds needs
        authorized_user_info = {}

        # oauth_interactive is sanitized and needs credentials from the environment
        if is_interactive:
            authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
            authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
        else:
            authorized_user_info["client_id"] = credentials_dict["client_id"]
            authorized_user_info["client_secret"] = credentials_dict["client_secret"]

        authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]

        # get_google_oauth_creds documents token and expiry as optional, so a
        # stored dict that lacks them must not fail with a KeyError here.
        authorized_user_info["token"] = credentials_dict.get("token")
        authorized_user_info["expiry"] = credentials_dict.get("expiry")

        oauth_creds = get_google_oauth_creds(
            token_json_str=json.dumps(authorized_user_info),
            source=source,
        )
        if oauth_creds is None:
            # The credential structure was recognised; it just could not be
            # turned into a valid token, so say that instead of "unknown".
            raise PermissionError(f"Unable to access {source} - OAuth credentials are invalid or could not be refreshed.")

        # tell caller to update token stored in DB if the refresh token changed
        should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
        if should_persist:
            # if oauth_interactive, sanitize the credentials so they don't get stored in the db
            if is_interactive:
                oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
            else:
                oauth_creds_json_str = oauth_creds.to_json()

            new_creds_dict = {
                DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
                DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY],
                DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
            }

        return oauth_creds, new_creds_dict

    if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
        # SERVICE ACCOUNT
        service_account_key = json.loads(credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY])

        service_creds = ServiceAccountCredentials.from_service_account_info(
            service_account_key,
            scopes=GOOGLE_SCOPES[source],
        )

        # NOTE(review): this used to read `not valid or not expired`, which is
        # only false for credentials that are valid *and* expired. google-auth
        # describes `valid` as "has a token that is not expired", so that test
        # was effectively always true; refresh whenever there is no valid token.
        if not service_creds.valid:
            service_creds.refresh(Request())

        if not service_creds.valid:
            raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")

        return service_creds, new_creds_dict

    raise PermissionError(f"Unable to access {source} - unknown credential structure.")
|
|
|
|
|
|
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
    """Load OAuth credentials from serialized JSON, refreshing them if needed.

    The JSON only needs to contain client_id, client_secret and refresh_token
    for a refresh to be possible.

    expiry and token are optional ... however, if expiry is passed in, token
    should be passed in as well, or else no creds may be returned.
    (probably a sign this function should be refactored)

    Returns the credentials when they are valid (possibly after a refresh),
    otherwise None. A failed refresh is logged and also yields None.
    """
    user_info = json.loads(token_json_str)
    loaded = OAuthCredentials.from_authorized_user_info(
        info=user_info,
        scopes=GOOGLE_SCOPES[source],
    )

    # Fast path: nothing to do when the token is already usable.
    if loaded.valid:
        return loaded

    # A refresh is only attempted for expired creds that carry a refresh token.
    if not (loaded.expired and loaded.refresh_token):
        return None

    try:
        loaded.refresh(Request())
        usable = loaded.valid
        if usable:
            logging.info("Refreshed Google Drive tokens.")
    except Exception:
        logging.exception("Failed to refresh google drive access token")
        return None

    return loaded if usable else None
|