| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | from collections.abc import Callable | 
					
						
							| 
									
										
										
										
											2024-04-12 16:22:24 +08:00
										 |  |  | from datetime import datetime, timezone | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | from enum import Enum | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from functools import wraps | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | from typing import Optional | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from flask import current_app, request | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | from flask_login import user_logged_in | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from flask_restful import Resource | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | from pydantic import BaseModel | 
					
						
							| 
									
										
										
										
											2024-04-02 17:55:49 +08:00
										 |  |  | from werkzeug.exceptions import Forbidden, NotFound, Unauthorized | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2023-10-08 05:21:32 -05:00
										 |  |  | from libs.login import _get_user | 
					
						
							| 
									
										
										
										
											2024-04-18 17:33:32 +08:00
										 |  |  | from models.account import Account, Tenant, TenantAccountJoin, TenantStatus | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | from models.model import ApiToken, App, EndUser | 
					
						
							| 
									
										
										
										
											2023-12-20 15:37:57 +08:00
										 |  |  | from services.feature_service import FeatureService | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | class WhereisUserArg(Enum): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Enum for whereis_user_arg. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     QUERY = 'query' | 
					
						
							|  |  |  |     JSON = 'json' | 
					
						
							|  |  |  |     FORM = 'form' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class FetchUserArg(BaseModel): | 
					
						
							|  |  |  |     fetch_from: WhereisUserArg | 
					
						
							|  |  |  |     required: bool = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): | 
					
						
							|  |  |  |     def decorator(view_func): | 
					
						
							|  |  |  |         @wraps(view_func) | 
					
						
							|  |  |  |         def decorated_view(*args, **kwargs): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             api_token = validate_and_get_api_token('app') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-18 20:32:44 +08:00
										 |  |  |             app_model = db.session.query(App).filter(App.id == api_token.app_id).first() | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             if not app_model: | 
					
						
							|  |  |  |                 raise NotFound() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if app_model.status != 'normal': | 
					
						
							|  |  |  |                 raise NotFound() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not app_model.enable_api: | 
					
						
							|  |  |  |                 raise NotFound() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-18 17:33:32 +08:00
										 |  |  |             tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() | 
					
						
							|  |  |  |             if tenant.status == TenantStatus.ARCHIVE: | 
					
						
							|  |  |  |                 raise NotFound() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  |             kwargs['app_model'] = app_model | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-28 16:46:50 +08:00
										 |  |  |             if fetch_user_arg: | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  |                 if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: | 
					
						
							|  |  |  |                     user_id = request.args.get('user') | 
					
						
							|  |  |  |                 elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: | 
					
						
							|  |  |  |                     user_id = request.get_json().get('user') | 
					
						
							|  |  |  |                 elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: | 
					
						
							|  |  |  |                     user_id = request.form.get('user') | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     # use default-user | 
					
						
							|  |  |  |                     user_id = None | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  |                 if not user_id and fetch_user_arg.required: | 
					
						
							|  |  |  |                     raise ValueError("Arg user must be provided.") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-28 16:46:50 +08:00
										 |  |  |                 if user_id: | 
					
						
							|  |  |  |                     user_id = str(user_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return view_func(*args, **kwargs) | 
					
						
							|  |  |  |         return decorated_view | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if view is None: | 
					
						
							|  |  |  |         return decorator | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         return decorator(view) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-05 16:53:55 +08:00
										 |  |  | def cloud_edition_billing_resource_check(resource: str, | 
					
						
							|  |  |  |                                          api_token_type: str, | 
					
						
							|  |  |  |                                          error_msg: str = "You have reached the limit of your subscription."): | 
					
						
							|  |  |  |     def interceptor(view): | 
					
						
							|  |  |  |         def decorated(*args, **kwargs): | 
					
						
							| 
									
										
										
										
											2023-12-20 15:37:57 +08:00
										 |  |  |             api_token = validate_and_get_api_token(api_token_type) | 
					
						
							|  |  |  |             features = FeatureService.get_features(api_token.tenant_id) | 
					
						
							| 
									
										
										
										
											2023-12-05 16:53:55 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-20 15:37:57 +08:00
										 |  |  |             if features.billing.enabled: | 
					
						
							|  |  |  |                 members = features.members | 
					
						
							|  |  |  |                 apps = features.apps | 
					
						
							|  |  |  |                 vector_space = features.vector_space | 
					
						
							| 
									
										
										
										
											2024-03-03 12:45:06 +08:00
										 |  |  |                 documents_upload_quota = features.documents_upload_quota | 
					
						
							| 
									
										
										
										
											2023-12-05 16:53:55 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-20 15:37:57 +08:00
										 |  |  |                 if resource == 'members' and 0 < members.limit <= members.size: | 
					
						
							| 
									
										
										
										
											2024-04-02 17:55:49 +08:00
										 |  |  |                     raise Forbidden(error_msg) | 
					
						
							| 
									
										
										
										
											2023-12-20 15:37:57 +08:00
										 |  |  |                 elif resource == 'apps' and 0 < apps.limit <= apps.size: | 
					
						
							| 
									
										
										
										
											2024-04-02 17:55:49 +08:00
										 |  |  |                     raise Forbidden(error_msg) | 
					
						
							| 
									
										
										
										
											2023-12-20 15:37:57 +08:00
										 |  |  |                 elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: | 
					
						
							| 
									
										
										
										
											2024-04-02 17:55:49 +08:00
										 |  |  |                     raise Forbidden(error_msg) | 
					
						
							| 
									
										
										
										
											2024-03-03 12:45:06 +08:00
										 |  |  |                 elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: | 
					
						
							| 
									
										
										
										
											2024-04-02 17:55:49 +08:00
										 |  |  |                     raise Forbidden(error_msg) | 
					
						
							| 
									
										
										
										
											2023-12-05 16:53:55 +08:00
										 |  |  |                 else: | 
					
						
							|  |  |  |                     return view(*args, **kwargs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return view(*args, **kwargs) | 
					
						
							|  |  |  |         return decorated | 
					
						
							|  |  |  |     return interceptor | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-02 17:55:49 +08:00
										 |  |  | def cloud_edition_billing_knowledge_limit_check(resource: str, | 
					
						
							|  |  |  |                                                 api_token_type: str, | 
					
						
							|  |  |  |                                                 error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): | 
					
						
							|  |  |  |     def interceptor(view): | 
					
						
							|  |  |  |         @wraps(view) | 
					
						
							|  |  |  |         def decorated(*args, **kwargs): | 
					
						
							|  |  |  |             api_token = validate_and_get_api_token(api_token_type) | 
					
						
							|  |  |  |             features = FeatureService.get_features(api_token.tenant_id) | 
					
						
							|  |  |  |             if features.billing.enabled: | 
					
						
							|  |  |  |                 if resource == 'add_segment': | 
					
						
							|  |  |  |                     if features.billing.subscription.plan == 'sandbox': | 
					
						
							|  |  |  |                         raise Forbidden(error_msg) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     return view(*args, **kwargs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return view(*args, **kwargs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return decorated | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return interceptor | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | def validate_dataset_token(view=None): | 
					
						
							|  |  |  |     def decorator(view): | 
					
						
							|  |  |  |         @wraps(view) | 
					
						
							|  |  |  |         def decorated(*args, **kwargs): | 
					
						
							|  |  |  |             api_token = validate_and_get_api_token('dataset') | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |             tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ | 
					
						
							|  |  |  |                 .filter(Tenant.id == api_token.tenant_id) \ | 
					
						
							|  |  |  |                 .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | 
					
						
							| 
									
										
										
										
											2024-01-26 00:06:23 +08:00
										 |  |  |                 .filter(TenantAccountJoin.role.in_(['owner'])) \ | 
					
						
							| 
									
										
										
										
											2024-04-18 17:33:32 +08:00
										 |  |  |                 .filter(Tenant.status == TenantStatus.NORMAL) \ | 
					
						
							| 
									
										
										
										
											2024-01-26 12:47:42 +08:00
										 |  |  |                 .one_or_none() # TODO: only owner information is required, so only one is returned. | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |             if tenant_account_join: | 
					
						
							|  |  |  |                 tenant, ta = tenant_account_join | 
					
						
							|  |  |  |                 account = Account.query.filter_by(id=ta.account_id).first() | 
					
						
							|  |  |  |                 # Login admin | 
					
						
							|  |  |  |                 if account: | 
					
						
							|  |  |  |                     account.current_tenant = tenant | 
					
						
							|  |  |  |                     current_app.login_manager._update_request_context_with_user(account) | 
					
						
							|  |  |  |                     user_logged_in.send(current_app._get_current_object(), user=_get_user()) | 
					
						
							|  |  |  |                 else: | 
					
						
							| 
									
										
										
										
											2024-01-26 12:47:42 +08:00
										 |  |  |                     raise Unauthorized("Tenant owner account does not exist.") | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2024-01-26 12:47:42 +08:00
										 |  |  |                 raise Unauthorized("Tenant does not exist.") | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |             return view(api_token.tenant_id, *args, **kwargs) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return decorated | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if view: | 
					
						
							|  |  |  |         return decorator(view) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # if view is None, it means that the decorator is used without parentheses | 
					
						
							|  |  |  |     # use the decorator as a function for method_decorators | 
					
						
							|  |  |  |     return decorator | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def validate_and_get_api_token(scope=None): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Validate and get API token. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     auth_header = request.headers.get('Authorization') | 
					
						
							| 
									
										
										
										
											2023-08-18 20:32:44 +08:00
										 |  |  |     if auth_header is None or ' ' not in auth_header: | 
					
						
							|  |  |  |         raise Unauthorized("Authorization header must be provided and start with 'Bearer'") | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     auth_scheme, auth_token = auth_header.split(None, 1) | 
					
						
							|  |  |  |     auth_scheme = auth_scheme.lower() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if auth_scheme != 'bearer': | 
					
						
							| 
									
										
										
										
											2023-08-18 20:32:44 +08:00
										 |  |  |         raise Unauthorized("Authorization scheme must be 'Bearer'") | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     api_token = db.session.query(ApiToken).filter( | 
					
						
							|  |  |  |         ApiToken.token == auth_token, | 
					
						
							|  |  |  |         ApiToken.type == scope, | 
					
						
							|  |  |  |     ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not api_token: | 
					
						
							| 
									
										
										
										
											2023-08-18 20:32:44 +08:00
										 |  |  |         raise Unauthorized("Access token is invalid") | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-12 16:22:24 +08:00
										 |  |  |     api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return api_token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-28 16:09:56 +08:00
										 |  |  | def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Create or update session terminal based on user ID. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     if not user_id: | 
					
						
							|  |  |  |         user_id = 'DEFAULT-USER' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     end_user = db.session.query(EndUser) \ | 
					
						
							|  |  |  |         .filter( | 
					
						
							|  |  |  |         EndUser.tenant_id == app_model.tenant_id, | 
					
						
							|  |  |  |         EndUser.app_id == app_model.id, | 
					
						
							|  |  |  |         EndUser.session_id == user_id, | 
					
						
							|  |  |  |         EndUser.type == 'service_api' | 
					
						
							|  |  |  |     ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if end_user is None: | 
					
						
							|  |  |  |         end_user = EndUser( | 
					
						
							|  |  |  |             tenant_id=app_model.tenant_id, | 
					
						
							|  |  |  |             app_id=app_model.id, | 
					
						
							|  |  |  |             type='service_api', | 
					
						
							|  |  |  |             is_anonymous=True if user_id == 'DEFAULT-USER' else False, | 
					
						
							|  |  |  |             session_id=user_id | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         db.session.add(end_user) | 
					
						
							|  |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return end_user | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DatasetApiResource(Resource): | 
					
						
							|  |  |  |     method_decorators = [validate_dataset_token] |