mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-24 23:48:40 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			256 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			256 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import datetime
 | |
| 
 | |
| import pytz
 | |
| from flask import request
 | |
| from flask_login import current_user
 | |
| from flask_restful import Resource, fields, marshal_with, reqparse
 | |
| 
 | |
| from configs import dify_config
 | |
| from constants.languages import supported_language
 | |
| from controllers.console import api
 | |
| from controllers.console.workspace.error import (
 | |
|     AccountAlreadyInitedError,
 | |
|     CurrentPasswordIncorrectError,
 | |
|     InvalidInvitationCodeError,
 | |
|     RepeatPasswordNotMatchError,
 | |
| )
 | |
| from controllers.console.wraps import account_initialization_required, setup_required
 | |
| from extensions.ext_database import db
 | |
| from fields.member_fields import account_fields
 | |
| from libs.helper import TimestampField, timezone
 | |
| from libs.login import login_required
 | |
| from models import AccountIntegrate, InvitationCode
 | |
| from services.account_service import AccountService
 | |
| from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
 | |
| 
 | |
| 
 | |
class AccountInitApi(Resource):
    """Complete first-time setup of the logged-in account."""

    @setup_required
    @login_required
    def post(self):
        """Activate the current account and store its language/timezone preferences.

        On the CLOUD edition an unused invitation code must also be supplied in the
        JSON body; it is marked as used by this account and its current tenant.

        Returns:
            ``{"result": "success"}`` after the session is committed.

        Raises:
            AccountAlreadyInitedError: the account status is already ``"active"``.
            ValueError: CLOUD edition and ``invitation_code`` is missing/empty.
            InvalidInvitationCodeError: no invitation code with that value is ``"unused"``.
        """
        account = current_user
        if account.status == "active":
            raise AccountAlreadyInitedError()

        is_cloud = dify_config.EDITION == "CLOUD"

        # Argument order matters: reqparse reports the first failing argument.
        parser = reqparse.RequestParser()
        if is_cloud:
            parser.add_argument("invitation_code", type=str, location="json")
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        parser.add_argument("timezone", type=timezone, required=True, location="json")
        args = parser.parse_args()

        if is_cloud:
            self._consume_invitation_code(account, args["invitation_code"])

        account.interface_language = args["interface_language"]
        account.timezone = args["timezone"]
        account.interface_theme = "light"
        account.status = "active"
        # Stored as naive UTC.
        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        return {"result": "success"}

    @staticmethod
    def _consume_invitation_code(account, code):
        """Mark the unused invitation ``code`` as used by ``account`` (commit is left to the caller)."""
        if not code:
            raise ValueError("invitation_code is required")

        # NOTE(review): this lookup takes no row lock, so two concurrent requests
        # could both see the same code as "unused" before either commits — confirm
        # whether that matters here.
        invitation = (
            db.session.query(InvitationCode)
            .filter(
                InvitationCode.code == code,
                InvitationCode.status == "unused",
            )
            .first()
        )
        if not invitation:
            raise InvalidInvitationCodeError()

        invitation.status = "used"
        invitation.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        invitation.used_by_tenant_id = account.current_tenant_id
        invitation.used_by_account_id = account.id
 | |
| 
 | |
| 
 | |
class AccountProfileApi(Resource):
    """Expose the logged-in account's profile."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return the current account, serialized with ``account_fields``."""
        profile = current_user
        return profile
 | |
| 
 | |
| 
 | |
class AccountNameApi(Resource):
    """Rename the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account name from the JSON body and return the updated account.

        Raises:
            ValueError: ``name`` is null or not between 3 and 30 characters long.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        args = parser.parse_args()
        name = args["name"]

        # reqparse arguments are nullable by default, so {"name": null} arrives here
        # as None; reject it with the normal validation error instead of letting
        # len(None) raise TypeError (a 500).
        if name is None or not 3 <= len(name) <= 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        return AccountService.update_account(current_user, name=name)
 | |
| 
 | |
| 
 | |
class AccountAvatarApi(Resource):
    """Update the logged-in account's avatar."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the ``avatar`` string from the JSON body and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument("avatar", type=str, required=True, location="json")
        avatar = parser.parse_args()["avatar"]
        return AccountService.update_account(current_user, avatar=avatar)
 | |
| 
 | |
| 
 | |
class AccountInterfaceLanguageApi(Resource):
    """Update the logged-in account's interface language."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_language`` (converted/validated by ``supported_language``) and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        language = parser.parse_args()["interface_language"]
        return AccountService.update_account(current_user, interface_language=language)
 | |
| 
 | |
| 
 | |
class AccountInterfaceThemeApi(Resource):
    """Switch the logged-in account between the light and dark interface themes."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_theme`` (``"light"`` or ``"dark"``) and return the updated account."""
        parser = reqparse.RequestParser()
        # reqparse rejects any value outside ``choices``.
        parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
        theme = parser.parse_args()["interface_theme"]
        return AccountService.update_account(current_user, interface_theme=theme)
 | |
| 
 | |
| 
 | |
class AccountTimezoneApi(Resource):
    """Update the logged-in account's timezone."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``timezone`` to a pytz-known zone name and return the updated account.

        Raises:
            ValueError: the name is not one of ``pytz``'s known timezones.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("timezone", type=str, required=True, location="json")
        tz_name = parser.parse_args()["timezone"]

        # Accept only zone names pytz knows, e.g. America/New_York, Asia/Shanghai.
        # all_timezones_set has the same members as all_timezones, with set lookup.
        if tz_name not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        return AccountService.update_account(current_user, timezone=tz_name)
 | |
| 
 | |
| 
 | |
class AccountPasswordApi(Resource):
    """Change the logged-in account's password."""

    # No @marshal_with(account_fields) here: the response is a plain status dict,
    # not an account. Marshalling it against the account schema would drop
    # "result" and return an object whose fields are all null.
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Check that the new password was repeated correctly, then update it.

        ``password`` (the current password) is optional in the request and is
        passed through to ``AccountService.update_account_password`` as-is.

        Returns:
            ``{"result": "success"}``.

        Raises:
            RepeatPasswordNotMatchError: ``new_password`` != ``repeat_new_password``.
            CurrentPasswordIncorrectError: the service rejected the current password.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("password", type=str, required=False, location="json")
        parser.add_argument("new_password", type=str, required=True, location="json")
        parser.add_argument("repeat_new_password", type=str, required=True, location="json")
        args = parser.parse_args()

        if args["new_password"] != args["repeat_new_password"]:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(current_user, args["password"], args["new_password"])
        except ServiceCurrentPasswordIncorrectError as e:
            # Translate the service-layer error into the console API error.
            raise CurrentPasswordIncorrectError() from e

        return {"result": "success"}
 | |
| 
 | |
| 
 | |
class AccountIntegrateApi(Resource):
    """List the OAuth providers the current account can bind, with their binding status."""

    integrate_fields = {
        "provider": fields.String,
        "created_at": TimestampField,
        "is_bound": fields.Boolean,
        "link": fields.String,
    }

    integrate_list_fields = {
        "data": fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per supported provider ("github", "google").

        A bound provider reports its integration's ``created_at``. An unbound one
        carries a ``link`` to the console OAuth login endpoint for that provider.
        """
        account = current_user

        rows = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
        # Keep the first row per provider (first-match semantics).
        integrate_by_provider = {}
        for row in rows:
            integrate_by_provider.setdefault(row.provider, row)

        base_url = request.url_root.rstrip("/")
        oauth_base_path = "/console/api/oauth/login"

        integrate_data = []
        for provider in ["github", "google"]:
            bound = integrate_by_provider.get(provider)
            # NOTE: "id" is not declared in integrate_fields, so marshal_with omits it
            # from the response.
            if bound:
                entry = {
                    "id": bound.id,
                    "provider": provider,
                    "created_at": bound.created_at,
                    "is_bound": True,
                    "link": None,
                }
            else:
                entry = {
                    "id": None,
                    "provider": provider,
                    "created_at": None,
                    "is_bound": False,
                    "link": f"{base_url}{oauth_base_path}/{provider}",
                }
            integrate_data.append(entry)

        return {"data": integrate_data}
 | |
| 
 | |
| 
 | |
# Register API resources: (resource class, URL rule), registered in this order.
_ACCOUNT_ROUTES = (
    (AccountInitApi, "/account/init"),
    (AccountProfileApi, "/account/profile"),
    (AccountNameApi, "/account/name"),
    (AccountAvatarApi, "/account/avatar"),
    (AccountInterfaceLanguageApi, "/account/interface-language"),
    (AccountInterfaceThemeApi, "/account/interface-theme"),
    (AccountTimezoneApi, "/account/timezone"),
    (AccountPasswordApi, "/account/password"),
    (AccountIntegrateApi, "/account/integrates"),
)

for _resource, _rule in _ACCOUNT_ROUTES:
    api.add_resource(_resource, _rule)

# Disabled routes (resources not defined in this module):
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
 | 
