| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-04-12 16:22:24 +08:00
										 |  |  | from datetime import datetime, timezone | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import requests | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from flask import current_app, redirect, request | 
					
						
							|  |  |  | from flask_restful import Resource | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  | from configs import dify_config | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from constants.languages import languages | 
					
						
							|  |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-06-21 12:39:07 +08:00
										 |  |  | from libs.helper import get_remote_ip | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from models.account import Account, AccountStatus | 
					
						
							| 
									
										
										
										
											2024-02-12 02:09:01 +08:00
										 |  |  | from services.account_service import AccountService, RegisterService, TenantService | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from .. import api | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_oauth_providers(): | 
					
						
							|  |  |  |     with current_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |         if not dify_config.GITHUB_CLIENT_ID or not dify_config.GITHUB_CLIENT_SECRET: | 
					
						
							|  |  |  |             github_oauth = None | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             github_oauth = GitHubOAuth( | 
					
						
							|  |  |  |                 client_id=dify_config.GITHUB_CLIENT_ID, | 
					
						
							|  |  |  |                 client_secret=dify_config.GITHUB_CLIENT_SECRET, | 
					
						
							|  |  |  |                 redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github', | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: | 
					
						
							|  |  |  |             google_oauth = None | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             google_oauth = GoogleOAuth( | 
					
						
							|  |  |  |                 client_id=dify_config.GOOGLE_CLIENT_ID, | 
					
						
							|  |  |  |                 client_secret=dify_config.GOOGLE_CLIENT_SECRET, | 
					
						
							|  |  |  |                 redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google', | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth} | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return OAUTH_PROVIDERS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthLogin(Resource): | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         OAUTH_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_PROVIDERS.get(provider) | 
					
						
							|  |  |  |             print(vars(oauth_provider)) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							|  |  |  |             return {'error': 'Invalid provider'}, 400 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auth_url = oauth_provider.get_authorization_url() | 
					
						
							|  |  |  |         return redirect(auth_url) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthCallback(Resource): | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         OAUTH_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_PROVIDERS.get(provider) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							|  |  |  |             return {'error': 'Invalid provider'}, 400 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         code = request.args.get('code') | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             token = oauth_provider.get_access_token(code) | 
					
						
							|  |  |  |             user_info = oauth_provider.get_user_info(token) | 
					
						
							|  |  |  |         except requests.exceptions.HTTPError as e: | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |             logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}') | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             return {'error': 'OAuth process failed'}, 400 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         account = _generate_account(provider, user_info) | 
					
						
							|  |  |  |         # Check account status | 
					
						
							|  |  |  |         if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: | 
					
						
							|  |  |  |             return {'error': 'Account is banned or closed.'}, 403 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if account.status == AccountStatus.PENDING.value: | 
					
						
							|  |  |  |             account.status = AccountStatus.ACTIVE.value | 
					
						
							| 
									
										
										
										
											2024-04-12 16:22:24 +08:00
										 |  |  |             account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-12 02:09:01 +08:00
										 |  |  |         TenantService.create_owner_tenant_if_not_exist(account) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-21 12:39:07 +08:00
										 |  |  |         token = AccountService.login(account, ip_address=get_remote_ip(request)) | 
					
						
							| 
									
										
										
										
											2023-09-25 12:49:16 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |         return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}') | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: | 
					
						
							|  |  |  |     account = Account.get_by_openid(provider, user_info.id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not account: | 
					
						
							|  |  |  |         account = Account.query.filter_by(email=user_info.email).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _generate_account(provider: str, user_info: OAuthUserInfo): | 
					
						
							|  |  |  |     # Get account by openid or email. | 
					
						
							|  |  |  |     account = _get_account_by_openid_or_email(provider, user_info) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not account: | 
					
						
							|  |  |  |         # Create account | 
					
						
							|  |  |  |         account_name = user_info.name if user_info.name else 'Dify' | 
					
						
							|  |  |  |         account = RegisterService.register( | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Set interface language | 
					
						
							| 
									
										
										
										
											2024-01-23 21:14:53 +08:00
										 |  |  |         preferred_lang = request.accept_languages.best_match(languages) | 
					
						
							|  |  |  |         if preferred_lang and preferred_lang in languages: | 
					
						
							|  |  |  |             interface_language = preferred_lang | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-01-23 21:14:53 +08:00
										 |  |  |             interface_language = languages[0] | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         account.interface_language = interface_language | 
					
						
							|  |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Link account | 
					
						
							|  |  |  |     AccountService.link_account_integrate(provider, user_info.id, account) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | api.add_resource(OAuthLogin, '/oauth/login/<provider>') | 
					
						
							|  |  |  | api.add_resource(OAuthCallback, '/oauth/authorize/<provider>') |