Yongteng Lei df16a80f25
Feat: add initial Google Drive connector support (#11147)
### What problem does this PR solve?

This feature is primarily ported from the
[Onyx](https://github.com/onyx-dot-app/onyx) project with necessary
modifications. Thanks for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

<img width="566" height="731" alt="image"
src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4"
/>

<img width="904" height="830" alt="image"
src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6"
/>

<img width="911" height="869" alt="image"
src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22"
/>

<img width="947" height="323" alt="image"
src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d"
/>


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-10 19:15:02 +08:00

158 lines
6.8 KiB
Python

import json
import logging
from typing import Any
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
from common.data_source.google_util.constant import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
DB_CREDENTIALS_DICT_TOKEN_KEY,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
GOOGLE_SCOPES,
GoogleOAuthAuthenticationMethod,
)
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
    """Serialize OAuth credentials to JSON with the client id/secret removed.

    The client id and client secret are meant to live only in the
    environment, so they are dropped before the credentials are handed
    back for persistence.

    Returns:
        A JSON string of the credentials without the ``client_id`` and
        ``client_secret`` fields.
    """
    payload: dict[str, Any] = json.loads(oauth_creds.to_json())
    # Remove the secrets if present; a missing key is not an error.
    for secret_field in ("client_id", "client_secret"):
        payload.pop(secret_field, None)
    return json.dumps(payload)
def get_google_creds(
    credentials: dict[str, str],
    source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
    """Build Google API credentials from a stored credential dict.

    Two credential layouts are recognised:

    (1) A credential which holds a token acquired via a user going through
        the Google OAuth flow (``DB_CREDENTIALS_DICT_TOKEN_KEY``).
    (2) A credential which holds a service account key JSON file
        (``DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY``), which can then be used
        to impersonate any user in the workspace.

    Args:
        credentials: The credential dict as stored in the db.
        source: The document source; selects the scopes from ``GOOGLE_SCOPES``.

    Returns:
        A tuple where the first element is the requested credentials and the
        second element is a new credentials dict that the caller should write
        back to the db, or ``None`` when nothing needs to be persisted. A new
        dict is produced when the token dict had to be regenerated or when the
        refresh token changed while loading the credentials.

    Raises:
        PermissionError: If the OAuth token dict is incomplete and cannot be
            completed, if the service account credentials are still invalid
            after a refresh, or if no usable credentials could be built.
        KeyError: If a new credentials dict must be produced but
            ``DB_CREDENTIALS_PRIMARY_ADMIN_KEY`` is missing from ``credentials``.
    """
    oauth_creds = None
    service_creds = None
    new_creds_dict = None

    if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
        # OAUTH
        authentication_method: str = credentials.get(
            DB_CREDENTIALS_AUTHENTICATION_METHOD,
            GoogleOAuthAuthenticationMethod.UPLOADED,
        )
        credentials_dict = json.loads(credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])

        regenerated_from_client_secret = False
        required_keys = ("client_id", "client_secret", "refresh_token")
        if any(key not in credentials_dict for key in required_keys):
            # NOTE(review): ensure_oauth_token_dict presumably turns the stored
            # dict into a full token dict -- confirm against its implementation.
            try:
                credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
            except Exception as exc:
                raise PermissionError(
                    "Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
                ) from exc
            regenerated_from_client_secret = True

        # only send what get_google_oauth_creds needs
        authorized_user_info: dict[str, Any] = {}

        # oauth_interactive is sanitized and needs credentials from the environment
        if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
            authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
            authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
        else:
            authorized_user_info["client_id"] = credentials_dict["client_id"]
            authorized_user_info["client_secret"] = credentials_dict["client_secret"]

        authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]

        # ``token`` and ``expiry`` are optional: a serialized credential omits
        # fields whose value is None, so indexing them directly could raise
        # KeyError even though the refresh token is enough to get a new token.
        # ``expiry`` is only forwarded together with a token; an expiry without
        # a token can leave get_google_oauth_creds unable to return any creds.
        access_token = credentials_dict.get("token")
        if access_token:
            authorized_user_info["token"] = access_token
            token_expiry = credentials_dict.get("expiry")
            if token_expiry:
                authorized_user_info["expiry"] = token_expiry

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(token_json_str=token_json_str, source=source)

        # tell caller to update token stored in DB if the refresh token changed
        if oauth_creds:
            should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
            if should_persist:
                # if oauth_interactive, sanitize the credentials so they don't get stored in the db
                if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
                    oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
                else:
                    oauth_creds_json_str = oauth_creds.to_json()

                new_creds_dict = {
                    DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
                    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY],
                    DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
                }
    elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
        # SERVICE ACCOUNT
        service_account_key = json.loads(credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY])
        service_creds = ServiceAccountCredentials.from_service_account_info(service_account_key, scopes=GOOGLE_SCOPES[source])

        # The previous check (`not valid or not expired`) was always true,
        # since valid credentials are by definition not expired. Refreshing
        # whenever the credentials are not valid is the intended behaviour.
        if not service_creds.valid:
            service_creds.refresh(Request())

        if not service_creds.valid:
            raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")

    creds: ServiceAccountCredentials | OAuthCredentials | None = oauth_creds or service_creds
    if creds is None:
        raise PermissionError(f"Unable to access {source} - unknown credential structure.")

    return creds, new_creds_dict
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
    """Load OAuth credentials from a JSON string, refreshing them if expired.

    The JSON only needs to contain client_id, client_secret and refresh_token
    for a refresh to be possible. expiry and token are optional; however, if
    expiry is passed in, token should be passed in as well, or else no creds
    may be returned.

    Args:
        token_json_str: JSON-serialized authorized-user info.
        source: The document source; selects the scopes from ``GOOGLE_SCOPES``.

    Returns:
        Valid credentials, or ``None`` when the credentials are not valid and
        could not be refreshed.
    """
    user_info = json.loads(token_json_str)
    oauth_credentials = OAuthCredentials.from_authorized_user_info(
        info=user_info,
        scopes=GOOGLE_SCOPES[source],
    )

    # Already usable: nothing to refresh.
    if oauth_credentials.valid:
        return oauth_credentials

    # A refresh is only attempted for expired credentials with a refresh token.
    if not (oauth_credentials.expired and oauth_credentials.refresh_token):
        return None

    try:
        oauth_credentials.refresh(Request())
    except Exception:
        logging.exception("Failed to refresh google drive access token")
        return None

    if not oauth_credentials.valid:
        return None

    logging.info("Refreshed Google Drive tokens.")
    return oauth_credentials