mirror of
				https://github.com/langgenius/dify.git
				synced 2025-11-03 20:33:00 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			263 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			263 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import datetime
 | 
						|
 | 
						|
import pytz
 | 
						|
from flask import request
 | 
						|
from flask_login import current_user
 | 
						|
from flask_restful import Resource, fields, marshal_with, reqparse
 | 
						|
 | 
						|
from configs import dify_config
 | 
						|
from constants.languages import supported_language
 | 
						|
from controllers.console import api
 | 
						|
from controllers.console.setup import setup_required
 | 
						|
from controllers.console.workspace.error import (
 | 
						|
    AccountAlreadyInitedError,
 | 
						|
    CurrentPasswordIncorrectError,
 | 
						|
    InvalidInvitationCodeError,
 | 
						|
    RepeatPasswordNotMatchError,
 | 
						|
)
 | 
						|
from controllers.console.wraps import account_initialization_required
 | 
						|
from extensions.ext_database import db
 | 
						|
from fields.member_fields import account_fields
 | 
						|
from libs.helper import TimestampField, timezone
 | 
						|
from libs.login import login_required
 | 
						|
from models.account import AccountIntegrate, InvitationCode
 | 
						|
from services.account_service import AccountService
 | 
						|
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
 | 
						|
 | 
						|
 | 
						|
class AccountInitApi(Resource):
    """Complete first-time initialization of the logged-in account.

    Stores the chosen interface language and timezone on the account, forces
    the light theme, marks the account ``'active'`` and commits. On the CLOUD
    edition an unused invitation code must also be supplied and is redeemed.
    """

    @setup_required
    @login_required
    def post(self):
        """Initialize ``current_user``.

        JSON body: ``interface_language`` and ``timezone`` (required), plus
        ``invitation_code`` on the CLOUD edition.

        Returns:
            ``{'result': 'success'}``

        Raises:
            AccountAlreadyInitedError: the account status is already ``'active'``.
            ValueError: CLOUD edition and no invitation code was supplied.
            InvalidInvitationCodeError: no ``'unused'`` invitation code matches.
        """
        account = current_user
        if account.status == 'active':
            raise AccountAlreadyInitedError()

        is_cloud = dify_config.EDITION == 'CLOUD'

        # Argument order matters: reqparse reports the first failing argument.
        parser = reqparse.RequestParser()
        if is_cloud:
            parser.add_argument('invitation_code', type=str, location='json')
        parser.add_argument('interface_language', type=supported_language, required=True, location='json')
        parser.add_argument('timezone', type=timezone, required=True, location='json')
        args = parser.parse_args()

        if is_cloud:
            code = args['invitation_code']
            if not code:
                raise ValueError('invitation_code is required')

            # Only a code still marked 'unused' may be redeemed.
            invitation = (
                db.session.query(InvitationCode)
                .filter(
                    InvitationCode.code == code,
                    InvitationCode.status == 'unused',
                )
                .first()
            )
            if not invitation:
                raise InvalidInvitationCodeError()

            # Bind the redeemed code to this account and its current tenant.
            invitation.status = 'used'
            invitation.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            invitation.used_by_tenant_id = account.current_tenant_id
            invitation.used_by_account_id = account.id

        account.interface_language = args['interface_language']
        account.timezone = args['timezone']
        account.interface_theme = 'light'
        account.status = 'active'
        # Timestamps are stored as naive UTC (tzinfo stripped).
        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        return {'result': 'success'}
 | 
						|
 | 
						|
 | 
						|
class AccountProfileApi(Resource):
    """Expose the logged-in account's profile."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return ``current_user``, serialized through ``account_fields``."""
        profile = current_user
        return profile
 | 
						|
 | 
						|
 | 
						|
class AccountNameApi(Resource):
    """Rename the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account name from the JSON ``name`` field.

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.

        Raises:
            ValueError: the name is shorter than 3 or longer than 30 characters.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
        new_name = parser.parse_args()['name']

        # Accept only names whose length lies in the inclusive range [3, 30].
        if not 3 <= len(new_name) <= 30:
            raise ValueError(
                "Account name must be between 3 and 30 characters.")

        return AccountService.update_account(current_user, name=new_name)
 | 
						|
 | 
						|
 | 
						|
class AccountAvatarApi(Resource):
    """Update the logged-in account's avatar."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the JSON ``avatar`` value on the account.

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('avatar', type=str, required=True, location='json')
        avatar = parser.parse_args()['avatar']

        return AccountService.update_account(current_user, avatar=avatar)
 | 
						|
 | 
						|
 | 
						|
class AccountInterfaceLanguageApi(Resource):
    """Update the logged-in account's UI language."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the JSON ``interface_language`` value on the account.

        The value is validated by the ``supported_language`` argument type.

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('interface_language', type=supported_language, required=True, location='json')
        language = parser.parse_args()['interface_language']

        return AccountService.update_account(current_user, interface_language=language)
 | 
						|
 | 
						|
 | 
						|
class AccountInterfaceThemeApi(Resource):
    """Update the logged-in account's UI theme."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the JSON ``interface_theme`` value on the account.

        Only ``'light'`` and ``'dark'`` are accepted; reqparse rejects anything else.

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.
        """
        allowed_themes = ['light', 'dark']
        parser = reqparse.RequestParser()
        parser.add_argument('interface_theme', type=str, choices=allowed_themes, required=True, location='json')
        theme = parser.parse_args()['interface_theme']

        return AccountService.update_account(current_user, interface_theme=theme)
 | 
						|
 | 
						|
 | 
						|
class AccountTimezoneApi(Resource):
    """Update the logged-in account's timezone."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the JSON ``timezone`` value on the account.

        The value must be a timezone name known to pytz, e.g.
        ``America/New_York`` or ``Asia/Shanghai``.

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.

        Raises:
            ValueError: the string is not a pytz timezone name.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('timezone', type=str,
                            required=True, location='json')
        args = parser.parse_args()

        # pytz.all_timezones_set holds the same names as pytz.all_timezones,
        # but supports an O(1) membership test instead of a linear list scan.
        if args['timezone'] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        return AccountService.update_account(current_user, timezone=args['timezone'])
 | 
						|
 | 
						|
 | 
						|
class AccountPasswordApi(Resource):
    """Change the logged-in account's password."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Replace the account password.

        JSON body: ``password`` (current password, optional), plus
        ``new_password`` and ``repeat_new_password`` (both required and
        required to match).

        Returns:
            ``{"result": "success"}``. The method is intentionally NOT wrapped in
            ``marshal_with(account_fields)``. That decorator keeps only the
            ``account_fields`` keys, so it would drop ``result`` and send a body
            of null account fields instead.

        Raises:
            RepeatPasswordNotMatchError: the two new-password fields differ.
            CurrentPasswordIncorrectError: the service rejected the current password.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('password', type=str,
                            required=False, location='json')
        parser.add_argument('new_password', type=str,
                            required=True, location='json')
        parser.add_argument('repeat_new_password', type=str,
                            required=True, location='json')
        args = parser.parse_args()

        if args['new_password'] != args['repeat_new_password']:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(
                current_user, args['password'], args['new_password'])
        except ServiceCurrentPasswordIncorrectError as e:
            # Translate the service-layer error into the console API error.
            raise CurrentPasswordIncorrectError() from e

        return {"result": "success"}
 | 
						|
 | 
						|
 | 
						|
class AccountIntegrateApi(Resource):
    """List the supported OAuth providers and whether the account is bound to each."""

    # Per-provider entry shape. The handler also builds an 'id' key, but it is
    # not listed here, so marshal_with leaves it out of the response.
    integrate_fields = {
        'provider': fields.String,
        'created_at': TimestampField,
        'is_bound': fields.Boolean,
        'link': fields.String
    }

    integrate_list_fields = {
        'data': fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per provider ("github", then "google").

        A bound provider reports its ``AccountIntegrate`` creation time and a
        null link. An unbound one reports a login link of the form
        ``<url_root>/console/api/oauth/login/<provider>``.
        """
        rows = db.session.query(AccountIntegrate).filter(
            AccountIntegrate.account_id == current_user.id).all()

        # Keep the first row seen per provider (first-match semantics).
        bound_by_provider = {}
        for row in rows:
            bound_by_provider.setdefault(row.provider, row)

        base_url = request.url_root.rstrip('/')
        oauth_base_path = "/console/api/oauth/login"

        entries = []
        for provider in ("github", "google"):
            row = bound_by_provider.get(provider)
            if row:
                entry = {
                    'id': row.id,
                    'provider': provider,
                    'created_at': row.created_at,
                    'is_bound': True,
                    'link': None,
                }
            else:
                entry = {
                    'id': None,
                    'provider': provider,
                    'created_at': None,
                    'is_bound': False,
                    'link': f'{base_url}{oauth_base_path}/{provider}',
                }
            entries.append(entry)

        return {'data': entries}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
 | 
						|
# Register API resources, in order: (resource class, URL rule).
_ACCOUNT_ROUTES = (
    (AccountInitApi, '/account/init'),
    (AccountProfileApi, '/account/profile'),
    (AccountNameApi, '/account/name'),
    (AccountAvatarApi, '/account/avatar'),
    (AccountInterfaceLanguageApi, '/account/interface-language'),
    (AccountInterfaceThemeApi, '/account/interface-theme'),
    (AccountTimezoneApi, '/account/timezone'),
    (AccountPasswordApi, '/account/password'),
    (AccountIntegrateApi, '/account/integrates'),
)

for _resource, _url in _ACCOUNT_ROUTES:
    api.add_resource(_resource, _url)
# Do not leak the loop variables as module attributes.
del _resource, _url

# Not registered (left disabled in this module):
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
 |