| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import logging | 
					
						
							|  |  |  | from datetime import datetime | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import requests | 
					
						
							| 
									
										
										
										
											2024-01-23 21:14:53 +08:00
										 |  |  | from constants.languages import languages | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from flask import current_app, redirect, request | 
					
						
							|  |  |  | from flask_restful import Resource | 
					
						
							|  |  |  | from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from models.account import Account, AccountStatus | 
					
						
							|  |  |  | from services.account_service import AccountService, RegisterService | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from .. import api | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_oauth_providers(): | 
					
						
							|  |  |  |     with current_app.app_context(): | 
					
						
							|  |  |  |         github_oauth = GitHubOAuth(client_id=current_app.config.get('GITHUB_CLIENT_ID'), | 
					
						
							|  |  |  |                                    client_secret=current_app.config.get( | 
					
						
							|  |  |  |                                        'GITHUB_CLIENT_SECRET'), | 
					
						
							|  |  |  |                                    redirect_uri=current_app.config.get( | 
					
						
							| 
									
										
										
										
											2023-07-14 11:19:26 +08:00
										 |  |  |                                        'CONSOLE_API_URL') + '/console/api/oauth/authorize/github') | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'), | 
					
						
							|  |  |  |                                    client_secret=current_app.config.get( | 
					
						
							|  |  |  |                                        'GOOGLE_CLIENT_SECRET'), | 
					
						
							|  |  |  |                                    redirect_uri=current_app.config.get( | 
					
						
							| 
									
										
										
										
											2023-07-14 11:19:26 +08:00
										 |  |  |                                        'CONSOLE_API_URL') + '/console/api/oauth/authorize/google') | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         OAUTH_PROVIDERS = { | 
					
						
							|  |  |  |             'github': github_oauth, | 
					
						
							|  |  |  |             'google': google_oauth | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return OAUTH_PROVIDERS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthLogin(Resource): | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         OAUTH_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_PROVIDERS.get(provider) | 
					
						
							|  |  |  |             print(vars(oauth_provider)) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							|  |  |  |             return {'error': 'Invalid provider'}, 400 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auth_url = oauth_provider.get_authorization_url() | 
					
						
							|  |  |  |         return redirect(auth_url) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthCallback(Resource): | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         OAUTH_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_PROVIDERS.get(provider) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							|  |  |  |             return {'error': 'Invalid provider'}, 400 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         code = request.args.get('code') | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             token = oauth_provider.get_access_token(code) | 
					
						
							|  |  |  |             user_info = oauth_provider.get_user_info(token) | 
					
						
							|  |  |  |         except requests.exceptions.HTTPError as e: | 
					
						
							|  |  |  |             logging.exception( | 
					
						
							|  |  |  |                 f"An error occurred during the OAuth process with {provider}: {e.response.text}") | 
					
						
							|  |  |  |             return {'error': 'OAuth process failed'}, 400 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         account = _generate_account(provider, user_info) | 
					
						
							|  |  |  |         # Check account status | 
					
						
							|  |  |  |         if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: | 
					
						
							|  |  |  |             return {'error': 'Account is banned or closed.'}, 403 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if account.status == AccountStatus.PENDING.value: | 
					
						
							|  |  |  |             account.status = AccountStatus.ACTIVE.value | 
					
						
							|  |  |  |             account.initialized_at = datetime.utcnow() | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         AccountService.update_last_login(account, request) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-25 12:49:16 +08:00
										 |  |  |         token = AccountService.get_account_jwt_token(account) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}') | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: | 
					
						
							|  |  |  |     account = Account.get_by_openid(provider, user_info.id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not account: | 
					
						
							|  |  |  |         account = Account.query.filter_by(email=user_info.email).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _generate_account(provider: str, user_info: OAuthUserInfo): | 
					
						
							|  |  |  |     # Get account by openid or email. | 
					
						
							|  |  |  |     account = _get_account_by_openid_or_email(provider, user_info) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not account: | 
					
						
							|  |  |  |         # Create account | 
					
						
							|  |  |  |         account_name = user_info.name if user_info.name else 'Dify' | 
					
						
							|  |  |  |         account = RegisterService.register( | 
					
						
							|  |  |  |             email=user_info.email, | 
					
						
							|  |  |  |             name=account_name, | 
					
						
							|  |  |  |             password=None, | 
					
						
							|  |  |  |             open_id=user_info.id, | 
					
						
							|  |  |  |             provider=provider | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Set interface language | 
					
						
							| 
									
										
										
										
											2024-01-23 21:14:53 +08:00
										 |  |  |         preferred_lang = request.accept_languages.best_match(languages) | 
					
						
							|  |  |  |         if preferred_lang and preferred_lang in languages: | 
					
						
							|  |  |  |             interface_language = preferred_lang | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-01-23 21:14:53 +08:00
										 |  |  |             interface_language = languages[0] | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         account.interface_language = interface_language | 
					
						
							|  |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Link account | 
					
						
							|  |  |  |     AccountService.link_account_integrate(provider, user_info.id, account) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Route registration. `/oauth/login/<provider>` starts the flow (OAuthLogin).
# `/oauth/authorize/<provider>` receives the provider's redirect (OAuthCallback);
# its path must match the redirect_uri suffix built in get_oauth_providers
# (presumably `api` is mounted at `/console/api` — confirm).
api.add_resource(OAuthLogin, '/oauth/login/<provider>')
api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')