| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-11-24 13:28:46 +08:00
										 |  |  | from datetime import UTC, datetime | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import requests | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from flask import current_app, redirect, request | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  | from flask_restful import Resource  # type: ignore | 
					
						
							| 
									
										
										
										
											2025-02-17 17:05:13 +08:00
										 |  |  | from sqlalchemy import select | 
					
						
							|  |  |  | from sqlalchemy.orm import Session | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  | from werkzeug.exceptions import Unauthorized | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  | from configs import dify_config | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from constants.languages import languages | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  | from events.tenant_event import tenant_was_created | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-10-12 23:46:30 +08:00
										 |  |  | from libs.helper import extract_remote_ip | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | 
					
						
							| 
									
										
										
										
											2024-10-21 10:43:49 +08:00
										 |  |  | from models import Account | 
					
						
							|  |  |  | from models.account import AccountStatus | 
					
						
							| 
									
										
										
										
											2024-02-12 02:09:01 +08:00
										 |  |  | from services.account_service import AccountService, RegisterService, TenantService | 
					
						
							| 
									
										
										
										
											2024-12-29 22:33:42 -05:00
										 |  |  | from services.errors.account import AccountNotFoundError, AccountRegisterError | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  | from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError | 
					
						
							|  |  |  | from services.feature_service import FeatureService | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from .. import api | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_oauth_providers(): | 
					
						
							|  |  |  |     with current_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |         if not dify_config.GITHUB_CLIENT_ID or not dify_config.GITHUB_CLIENT_SECRET: | 
					
						
							|  |  |  |             github_oauth = None | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             github_oauth = GitHubOAuth( | 
					
						
							|  |  |  |                 client_id=dify_config.GITHUB_CLIENT_ID, | 
					
						
							|  |  |  |                 client_secret=dify_config.GITHUB_CLIENT_SECRET, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |         if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: | 
					
						
							|  |  |  |             google_oauth = None | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             google_oauth = GoogleOAuth( | 
					
						
							|  |  |  |                 client_id=dify_config.GOOGLE_CLIENT_ID, | 
					
						
							|  |  |  |                 client_secret=dify_config.GOOGLE_CLIENT_SECRET, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return OAUTH_PROVIDERS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthLogin(Resource): | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         invite_token = request.args.get("invite_token") or None | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         OAUTH_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_PROVIDERS.get(provider) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"error": "Invalid provider"}, 400 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         auth_url = oauth_provider.get_authorization_url(invite_token=invite_token) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return redirect(auth_url) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthCallback(Resource): | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         OAUTH_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_PROVIDERS.get(provider) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"error": "Invalid provider"}, 400 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         code = request.args.get("code") | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         state = request.args.get("state") | 
					
						
							|  |  |  |         invite_token = None | 
					
						
							|  |  |  |         if state: | 
					
						
							|  |  |  |             invite_token = state | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             token = oauth_provider.get_access_token(code) | 
					
						
							|  |  |  |             user_info = oauth_provider.get_user_info(token) | 
					
						
							| 
									
										
										
										
											2024-12-23 15:23:11 +08:00
										 |  |  |         except requests.exceptions.RequestException as e: | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |             error_text = e.response.text if e.response else str(e) | 
					
						
							|  |  |  |             logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"error": "OAuth process failed"}, 400 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         if invite_token and RegisterService.is_valid_invite_token(invite_token): | 
					
						
							|  |  |  |             invitation = RegisterService._get_invitation_by_token(token=invite_token) | 
					
						
							|  |  |  |             if invitation: | 
					
						
							|  |  |  |                 invitation_email = invitation.get("email", None) | 
					
						
							|  |  |  |                 if invitation_email != user_info.email: | 
					
						
							|  |  |  |                     return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             account = _generate_account(provider, user_info) | 
					
						
							|  |  |  |         except AccountNotFoundError: | 
					
						
							|  |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") | 
					
						
							| 
									
										
										
										
											2024-10-23 10:59:30 +08:00
										 |  |  |         except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError): | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |             return redirect( | 
					
						
							|  |  |  |                 f"{dify_config.CONSOLE_WEB_URL}/signin" | 
					
						
							|  |  |  |                 "?message=Workspace not found, please contact system admin to invite you to join in a workspace." | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-12-29 22:33:42 -05:00
										 |  |  |         except AccountRegisterError as e: | 
					
						
							|  |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         # Check account status | 
					
						
							| 
									
										
										
										
											2024-10-23 10:59:30 +08:00
										 |  |  |         if account.status == AccountStatus.BANNED.value: | 
					
						
							|  |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if account.status == AccountStatus.PENDING.value: | 
					
						
							|  |  |  |             account.status = AccountStatus.ACTIVE.value | 
					
						
							| 
									
										
										
										
											2024-11-24 13:28:46 +08:00
										 |  |  |             account.initialized_at = datetime.now(UTC).replace(tzinfo=None) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             TenantService.create_owner_tenant_if_not_exist(account) | 
					
						
							|  |  |  |         except Unauthorized: | 
					
						
							|  |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.") | 
					
						
							|  |  |  |         except WorkSpaceNotAllowedCreateError: | 
					
						
							|  |  |  |             return redirect( | 
					
						
							|  |  |  |                 f"{dify_config.CONSOLE_WEB_URL}/signin" | 
					
						
							|  |  |  |                 "?message=Workspace not found, please contact system admin to invite you to join in a workspace." | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-02-12 02:09:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-12 23:46:30 +08:00
										 |  |  |         token_pair = AccountService.login( | 
					
						
							|  |  |  |             account=account, | 
					
						
							|  |  |  |             ip_address=extract_remote_ip(request), | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-09-25 12:49:16 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-12 23:46:30 +08:00
										 |  |  |         return redirect( | 
					
						
							|  |  |  |             f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |     account: Optional[Account] = Account.get_by_openid(provider, user_info.id) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if not account: | 
					
						
							| 
									
										
										
										
											2025-02-17 17:05:13 +08:00
										 |  |  |         with Session(db.engine) as session: | 
					
						
							|  |  |  |             account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     return account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _generate_account(provider: str, user_info: OAuthUserInfo): | 
					
						
							|  |  |  |     # Get account by openid or email. | 
					
						
							|  |  |  |     account = _get_account_by_openid_or_email(provider, user_info) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |     if account: | 
					
						
							|  |  |  |         tenant = TenantService.get_join_tenants(account) | 
					
						
							|  |  |  |         if not tenant: | 
					
						
							|  |  |  |             if not FeatureService.get_system_features().is_allow_create_workspace: | 
					
						
							|  |  |  |                 raise WorkSpaceNotAllowedCreateError() | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | 
					
						
							|  |  |  |                 TenantService.create_tenant_member(tenant, account, role="owner") | 
					
						
							|  |  |  |                 account.current_tenant = tenant | 
					
						
							|  |  |  |                 tenant_was_created.send(tenant) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     if not account: | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         if not FeatureService.get_system_features().is_allow_register: | 
					
						
							|  |  |  |             raise AccountNotFoundError() | 
					
						
							| 
									
										
										
										
											2024-09-12 15:50:49 +08:00
										 |  |  |         account_name = user_info.name or "Dify" | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         account = RegisterService.register( | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Set interface language | 
					
						
							| 
									
										
										
										
											2024-01-23 21:14:53 +08:00
										 |  |  |         preferred_lang = request.accept_languages.best_match(languages) | 
					
						
							|  |  |  |         if preferred_lang and preferred_lang in languages: | 
					
						
							|  |  |  |             interface_language = preferred_lang | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-01-23 21:14:53 +08:00
										 |  |  |             interface_language = languages[0] | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         account.interface_language = interface_language | 
					
						
							|  |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Link account | 
					
						
							|  |  |  |     AccountService.link_account_integrate(provider, user_info.id, account) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  | api.add_resource(OAuthLogin, "/oauth/login/<provider>") | 
					
						
							|  |  |  | api.add_resource(OAuthCallback, "/oauth/authorize/<provider>") |