| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | import logging | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import requests | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from flask import current_app, redirect, request | 
					
						
							| 
									
										
										
										
											2023-08-21 13:57:18 +08:00
										 |  |  | from flask_login import current_user | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | from flask_restful import Resource | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from werkzeug.exceptions import Forbidden | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  | from configs import dify_config | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from controllers.console import api | 
					
						
							| 
									
										
										
										
											2023-10-08 05:21:32 -05:00
										 |  |  | from libs.login import login_required | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | from libs.oauth_data_source import NotionOAuth | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | from ..setup import setup_required | 
					
						
							|  |  |  | from ..wraps import account_initialization_required | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_oauth_providers(): | 
					
						
							|  |  |  |     with current_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         notion_oauth = NotionOAuth( | 
					
						
							|  |  |  |             client_id=dify_config.NOTION_CLIENT_ID, | 
					
						
							|  |  |  |             client_secret=dify_config.NOTION_CLIENT_SECRET, | 
					
						
							|  |  |  |             redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         OAUTH_PROVIDERS = {"notion": notion_oauth} | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  |         return OAUTH_PROVIDERS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthDataSource(Resource): | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         # The role of the current user in the table must be admin or owner | 
					
						
							| 
									
										
										
										
											2024-01-26 12:47:42 +08:00
										 |  |  |         if not current_user.is_admin_or_owner: | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  |             raise Forbidden() | 
					
						
							|  |  |  |         OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | 
					
						
							|  |  |  |             print(vars(oauth_provider)) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"error": "Invalid provider"}, 400 | 
					
						
							|  |  |  |         if dify_config.NOTION_INTEGRATION_TYPE == "internal": | 
					
						
							| 
									
										
										
										
											2024-07-06 12:05:13 +08:00
										 |  |  |             internal_secret = dify_config.NOTION_INTERNAL_SECRET | 
					
						
							|  |  |  |             if not internal_secret: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 return ({"error": "Internal secret is not set"},) | 
					
						
							| 
									
										
										
										
											2023-06-17 19:50:21 +08:00
										 |  |  |             oauth_provider.save_internal_access_token(internal_secret) | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"data": ""} | 
					
						
							| 
									
										
										
										
											2023-06-17 19:50:21 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             auth_url = oauth_provider.get_authorization_url() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"data": auth_url}, 200 | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthDataSourceCallback(Resource): | 
					
						
							| 
									
										
										
										
											2023-09-28 14:39:13 +08:00
										 |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"error": "Invalid provider"}, 400 | 
					
						
							|  |  |  |         if "code" in request.args: | 
					
						
							|  |  |  |             code = request.args.get("code") | 
					
						
							| 
									
										
										
										
											2023-09-28 14:39:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") | 
					
						
							|  |  |  |         elif "error" in request.args: | 
					
						
							|  |  |  |             error = request.args.get("error") | 
					
						
							| 
									
										
										
										
											2023-09-28 14:39:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") | 
					
						
							| 
									
										
										
										
											2023-09-28 14:39:13 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 14:39:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class OAuthDataSourceBinding(Resource): | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"error": "Invalid provider"}, 400 | 
					
						
							|  |  |  |         if "code" in request.args: | 
					
						
							|  |  |  |             code = request.args.get("code") | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  |             try: | 
					
						
							|  |  |  |                 oauth_provider.get_access_token(code) | 
					
						
							|  |  |  |             except requests.exceptions.HTTPError as e: | 
					
						
							|  |  |  |                 logging.exception( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                     f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 return {"error": "OAuth data source process failed"}, 400 | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"result": "success"}, 200 | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthDataSourceSync(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def get(self, provider, binding_id): | 
					
						
							|  |  |  |         provider = str(provider) | 
					
						
							|  |  |  |         binding_id = str(binding_id) | 
					
						
							|  |  |  |         OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() | 
					
						
							|  |  |  |         with current_app.app_context(): | 
					
						
							|  |  |  |             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | 
					
						
							|  |  |  |         if not oauth_provider: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             return {"error": "Invalid provider"}, 400 | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             oauth_provider.sync_data_source(binding_id) | 
					
						
							|  |  |  |         except requests.exceptions.HTTPError as e: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") | 
					
						
							|  |  |  |             return {"error": "OAuth data source process failed"}, 400 | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return {"result": "success"}, 200 | 
					
						
							| 
									
										
										
										
											2023-06-16 21:47:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  | api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>") | 
					
						
							|  |  |  | api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>") | 
					
						
							|  |  |  | api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>") | 
					
						
							|  |  |  | api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") |