mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-30 18:33:30 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			260 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			260 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import datetime
 | |
| 
 | |
| import pytz
 | |
| from flask import current_app, request
 | |
| from flask_login import current_user
 | |
| from flask_restful import Resource, fields, marshal_with, reqparse
 | |
| 
 | |
| from constants.languages import supported_language
 | |
| from controllers.console import api
 | |
| from controllers.console.setup import setup_required
 | |
| from controllers.console.workspace.error import (
 | |
|     AccountAlreadyInitedError,
 | |
|     CurrentPasswordIncorrectError,
 | |
|     InvalidInvitationCodeError,
 | |
|     RepeatPasswordNotMatchError,
 | |
| )
 | |
| from controllers.console.wraps import account_initialization_required
 | |
| from extensions.ext_database import db
 | |
| from fields.member_fields import account_fields
 | |
| from libs.helper import TimestampField, timezone
 | |
| from libs.login import login_required
 | |
| from models.account import AccountIntegrate, InvitationCode
 | |
| from services.account_service import AccountService
 | |
| from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
 | |
| 
 | |
| 
 | |
class AccountInitApi(Resource):
    """Console resource that performs first-time initialization of the logged-in account."""

    @setup_required
    @login_required
    def post(self):
        """Activate the current account with its interface language and timezone.

        JSON body: ``interface_language`` and ``timezone`` (both required).
        When the ``EDITION`` config value is ``'CLOUD'`` an ``invitation_code``
        is parsed as well and consumed.

        Returns:
            ``{'result': 'success'}`` after the changes are committed.

        Raises:
            AccountAlreadyInitedError: the account status is already ``'active'``.
            ValueError: CLOUD edition and no invitation code was supplied.
            InvalidInvitationCodeError: no unused invitation code matches.
        """
        def naive_utc_now():
            # Current UTC time with the tzinfo stripped off.
            return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        user = current_user
        if user.status == 'active':
            raise AccountAlreadyInitedError()

        cloud_edition = current_app.config['EDITION'] == 'CLOUD'

        body_parser = reqparse.RequestParser()
        if cloud_edition:
            # Registered first so it is parsed ahead of the other arguments.
            body_parser.add_argument('invitation_code', type=str, location='json')
        body_parser.add_argument(
            'interface_language', type=supported_language, required=True, location='json')
        body_parser.add_argument(
            'timezone', type=timezone, required=True, location='json')
        payload = body_parser.parse_args()

        if cloud_edition:
            submitted_code = payload['invitation_code']
            if not submitted_code:
                raise ValueError('invitation_code is required')

            # Only codes still marked 'unused' may be redeemed.
            code_record = (
                db.session.query(InvitationCode)
                .filter(
                    InvitationCode.code == submitted_code,
                    InvitationCode.status == 'unused',
                )
                .first()
            )
            if not code_record:
                raise InvalidInvitationCodeError()

            # Mark the code as consumed by this account and its current tenant.
            code_record.status = 'used'
            code_record.used_at = naive_utc_now()
            code_record.used_by_tenant_id = user.current_tenant_id
            code_record.used_by_account_id = user.id

        user.interface_language = payload['interface_language']
        user.timezone = payload['timezone']
        user.interface_theme = 'light'
        user.status = 'active'
        user.initialized_at = naive_utc_now()
        db.session.commit()

        return {'result': 'success'}
 | |
| 
 | |
| 
 | |
class AccountProfileApi(Resource):
    """Console resource exposing the logged-in account's profile."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return the current user, marshalled with ``account_fields``."""
        return current_user
 | |
| 
 | |
| 
 | |
class AccountNameApi(Resource):
    """Console resource for renaming the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the current account's name.

        JSON body: ``name`` (required string, 3 to 30 characters inclusive).

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.

        Raises:
            ValueError: the name length is outside the 3..30 range.
        """
        body_parser = reqparse.RequestParser()
        body_parser.add_argument('name', type=str, required=True, location='json')
        new_name = body_parser.parse_args()['name']

        # Validate account name length
        if not 3 <= len(new_name) <= 30:
            raise ValueError(
                "Account name must be between 3 and 30 characters.")

        return AccountService.update_account(current_user, name=new_name)
 | |
| 
 | |
| 
 | |
class AccountAvatarApi(Resource):
    """Console resource for changing the logged-in account's avatar."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the current account's avatar.

        JSON body: ``avatar`` (required string).

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.
        """
        body_parser = reqparse.RequestParser()
        body_parser.add_argument('avatar', type=str, required=True, location='json')
        payload = body_parser.parse_args()

        return AccountService.update_account(current_user, avatar=payload['avatar'])
 | |
| 
 | |
| 
 | |
class AccountInterfaceLanguageApi(Resource):
    """Console resource for changing the logged-in account's interface language."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the current account's interface language.

        JSON body: ``interface_language`` (required; converted and validated
        by ``supported_language``).

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.
        """
        body_parser = reqparse.RequestParser()
        body_parser.add_argument(
            'interface_language', type=supported_language, required=True, location='json')
        payload = body_parser.parse_args()

        return AccountService.update_account(
            current_user, interface_language=payload['interface_language'])
 | |
| 
 | |
| 
 | |
class AccountInterfaceThemeApi(Resource):
    """Console resource for changing the logged-in account's interface theme."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the current account's interface theme.

        JSON body: ``interface_theme`` (required; one of ``'light'`` or ``'dark'``).

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.
        """
        body_parser = reqparse.RequestParser()
        body_parser.add_argument(
            'interface_theme',
            type=str,
            choices=['light', 'dark'],
            required=True,
            location='json',
        )
        payload = body_parser.parse_args()

        return AccountService.update_account(
            current_user, interface_theme=payload['interface_theme'])
 | |
| 
 | |
| 
 | |
class AccountTimezoneApi(Resource):
    """Console resource for changing the logged-in account's timezone."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the current account's timezone.

        JSON body: ``timezone`` (required string naming a pytz timezone,
        e.g. ``America/New_York`` or ``Asia/Shanghai``).

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.

        Raises:
            ValueError: the string is not a timezone name known to pytz.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('timezone', type=str,
                            required=True, location='json')
        args = parser.parse_args()

        # Validate timezone string, e.g. America/New_York, Asia/Shanghai.
        # ``all_timezones_set`` holds the same names as ``all_timezones`` but
        # gives a hashed lookup instead of a linear scan of the list.
        if args['timezone'] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        updated_account = AccountService.update_account(current_user, timezone=args['timezone'])

        return updated_account
 | |
| 
 | |
| 
 | |
class AccountPasswordApi(Resource):
    """Console resource for changing the logged-in account's password."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Change the current account's password.

        JSON body: ``password`` (optional current password),
        ``new_password`` and ``repeat_new_password`` (both required).

        Raises:
            RepeatPasswordNotMatchError: the two new-password values differ.
            CurrentPasswordIncorrectError: the service layer rejected the
                supplied current password.
        """
        body_parser = reqparse.RequestParser()
        for field_name, is_required in (
            ('password', False),
            ('new_password', True),
            ('repeat_new_password', True),
        ):
            body_parser.add_argument(
                field_name, type=str, required=is_required, location='json')
        payload = body_parser.parse_args()

        new_password = payload['new_password']
        if new_password != payload['repeat_new_password']:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(
                current_user, payload['password'], new_password)
        except ServiceCurrentPasswordIncorrectError:
            # Translate the service-layer error into the controller-level one.
            raise CurrentPasswordIncorrectError()

        # NOTE(review): this dict passes through @marshal_with(account_fields);
        # confirm the marshalled response is what clients expect.
        return {"result": "success"}
 | |
| 
 | |
| 
 | |
class AccountIntegrateApi(Resource):
    """Console resource listing the OAuth integrations of the logged-in account."""

    # Shape of a single integration entry in the response.
    integrate_fields = {
        'provider': fields.String,
        'created_at': TimestampField,
        'is_bound': fields.Boolean,
        'link': fields.String
    }

    # Response envelope: a list of entries under 'data'.
    integrate_list_fields = {
        'data': fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Report, for each known provider, whether the account is bound to it.

        Bound providers carry the integration's ``created_at`` and no link;
        unbound providers carry an OAuth login link built from the request's
        URL root.

        Returns:
            ``{'data': [...]}`` with one entry per provider, in the order
            ``github``, ``google``.
        """
        user = current_user

        records = db.session.query(AccountIntegrate).filter(
            AccountIntegrate.account_id == user.id).all()

        # Keep the first record seen for each provider.
        first_by_provider = {}
        for record in records:
            first_by_provider.setdefault(record.provider, record)

        login_url_prefix = request.url_root.rstrip('/') + "/console/api/oauth/login"

        # NOTE(review): the 'id' key below is not declared in integrate_fields;
        # confirm whether it is meant to reach the response.
        entries = []
        for provider in ("github", "google"):
            bound = first_by_provider.get(provider)
            if bound:
                entry = {
                    'id': bound.id,
                    'provider': provider,
                    'created_at': bound.created_at,
                    'is_bound': True,
                    'link': None
                }
            else:
                entry = {
                    'id': None,
                    'provider': provider,
                    'created_at': None,
                    'is_bound': False,
                    'link': f'{login_url_prefix}/{provider}'
                }
            entries.append(entry)

        return {'data': entries}
 | |
| 
 | |
| 
 | |
# Register API resources: each resource class is bound to its route on the
# console API, in the order listed.
for _resource_cls, _url_path in (
    (AccountInitApi, '/account/init'),
    (AccountProfileApi, '/account/profile'),
    (AccountNameApi, '/account/name'),
    (AccountAvatarApi, '/account/avatar'),
    (AccountInterfaceLanguageApi, '/account/interface-language'),
    (AccountInterfaceThemeApi, '/account/interface-theme'),
    (AccountTimezoneApi, '/account/timezone'),
    (AccountPasswordApi, '/account/password'),
    (AccountIntegrateApi, '/account/integrates'),
):
    api.add_resource(_resource_cls, _url_path)
# Drop the loop variables so the module namespace gains no extra names.
del _resource_cls, _url_path
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
 | 
