import enum
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast

from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from tasks.mail_email_code_login import send_email_code_login_mail_task


class WebAppAuthType(enum.StrEnum):
    """Enum for web app authentication types."""

    PUBLIC = "public"
    INTERNAL = "internal"
    EXTERNAL = "external"


class WebAppAuthService:
    """Service for web app authentication."""

    # Access modes that require a per-user permission check (and map to INTERNAL
    # auth). Shared by is_app_require_permission_check and get_app_auth_type so
    # the two cannot drift apart.
    _PERMISSION_CHECK_ACCESS_MODES = frozenset({"private", "private_all"})

    @staticmethod
    def authenticate(email: str, password: str) -> Account:
        """Authenticate an account with email and password.

        Raises:
            AccountNotFoundError: no account has this email.
            AccountLoginError: the account is banned.
            AccountPasswordError: the account has no password set, or the
                password does not match.
        """
        account = db.session.query(Account).filter_by(email=email).first()
        if not account:
            raise AccountNotFoundError()

        if account.status == AccountStatus.BANNED.value:
            raise AccountLoginError("Account is banned.")

        if account.password is None or not compare_password(password, account.password, account.password_salt):
            raise AccountPasswordError("Invalid email or password.")

        return cast(Account, account)

    @classmethod
    def login(cls, account: Account) -> str:
        """Issue and return a web app access token (JWT) for ``account``."""
        access_token = cls._get_account_jwt_token(account=account)
        return access_token

    @classmethod
    def get_user_through_email(cls, email: str) -> Optional[Account]:
        """Return the account with ``email``, or None if no such account exists.

        Raises:
            Unauthorized: the account exists but is banned.
        """
        account = db.session.query(Account).filter(Account.email == email).first()
        if not account:
            return None

        if account.status == AccountStatus.BANNED.value:
            raise Unauthorized("Account is banned.")

        return account

    @classmethod
    def send_email_code_login_email(
        cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
    ):
        """Generate a 6-digit login code, queue the email, and return the token.

        The recipient is ``account.email`` when ``account`` is given, otherwise
        ``email``. The code is stored in the token's ``additional_data`` so it
        can be checked later via ``get_email_code_login_data``.

        Raises:
            ValueError: no recipient email could be determined.
        """
        email = account.email if account else email
        if email is None:
            raise ValueError("Email must be provided.")

        # secrets (not random) because the code is an authentication secret.
        code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
        token = TokenManager.generate_token(
            account=account, email=email, token_type="email_code_login", additional_data={"code": code}
        )
        send_email_code_login_mail_task.delay(
            language=language,
            # ``email`` was already resolved from ``account`` above.
            to=email,
            code=code,
        )

        return token

    @classmethod
    def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
        """Return the data stored for an email-code-login token, or None if TokenManager has none."""
        return TokenManager.get_token_data(token, "email_code_login")

    @classmethod
    def revoke_email_code_login_token(cls, token: str):
        """Revoke an email-code-login token."""
        TokenManager.revoke_token(token, "email_code_login")

    @classmethod
    def create_end_user(cls, app_code, email) -> EndUser:
        """Create and commit a non-anonymous browser EndUser for the app behind ``app_code``.

        ``email`` is stored as the end user's ``session_id``.

        Raises:
            NotFound: no Site has ``app_code``, or the Site's App does not exist.
        """
        site = db.session.query(Site).filter(Site.code == app_code).first()
        if not site:
            raise NotFound("Site not found.")

        app_model = db.session.query(App).filter(App.id == site.app_id).first()
        if not app_model:
            raise NotFound("App not found.")

        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="browser",
            is_anonymous=False,
            session_id=email,
            name="enterpriseuser",
            external_user_id="enterpriseuser",
        )

        db.session.add(end_user)
        db.session.commit()

        return end_user

    @classmethod
    def _get_account_jwt_token(cls, account: Account) -> str:
        """Build and sign the web app passport JWT for ``account``."""
        # NOTE(review): a setting named in *minutes* is used as hours * 24, so the
        # token effectively lives ACCESS_TOKEN_EXPIRE_MINUTES *days*. Kept as-is
        # because changing token lifetime affects clients; confirm the intent.
        exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
        exp = int(exp_dt.timestamp())

        payload = {
            "sub": "Web API Passport",
            "user_id": account.id,
            "session_id": account.email,
            "token_source": "webapp_login_token",
            "auth_type": "internal",
            "exp": exp,
        }

        token: str = PassportService().issue(payload)
        return token

    @classmethod
    def is_app_require_permission_check(
        cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
    ) -> bool:
        """
        Check if the app requires permission check based on its access mode.

        An explicit ``access_mode`` is checked directly. Otherwise the mode is
        looked up by app id; ``app_code`` (if given) takes precedence over
        ``app_id`` and is resolved to an app id first.

        Raises:
            ValueError: none of the identifiers was provided, or ``app_code``
                could not be resolved to an app id.
        """
        if access_mode:
            return access_mode in cls._PERMISSION_CHECK_ACCESS_MODES

        if not app_code and not app_id:
            raise ValueError("Either app_code or app_id must be provided.")

        if app_code:
            app_id = AppService.get_app_id_by_code(app_code)
        if not app_id:
            raise ValueError("App ID could not be determined from the provided app_code.")

        webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
        return bool(webapp_settings and webapp_settings.access_mode in cls._PERMISSION_CHECK_ACCESS_MODES)

    @classmethod
    def get_app_auth_type(cls, app_code: str | None = None, access_mode: str | None = None) -> WebAppAuthType:
        """
        Get the authentication type for the app based on its access mode.

        A recognised ``access_mode`` is mapped directly. Otherwise, if
        ``app_code`` is given, the app's configured access mode is fetched and
        mapped.

        Raises:
            ValueError: neither argument was provided, or no auth type could be
                determined (unknown mode, or no settings found for ``app_code``).
        """
        if not app_code and not access_mode:
            raise ValueError("Either app_code or access_mode must be provided.")

        if access_mode:
            if access_mode == "public":
                return WebAppAuthType.PUBLIC
            if access_mode in cls._PERMISSION_CHECK_ACCESS_MODES:
                return WebAppAuthType.INTERNAL
            if access_mode == "sso_verified":
                return WebAppAuthType.EXTERNAL

        if app_code:
            webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code)
            # Guard a missing settings record or empty mode: previously these
            # surfaced as AttributeError or a misleading "must be provided" error.
            if webapp_settings and webapp_settings.access_mode:
                return cls.get_app_auth_type(access_mode=webapp_settings.access_mode)

        raise ValueError("Could not determine app authentication type.")