mirror of
				https://github.com/infiniflow/ragflow.git
				synced 2025-11-04 03:39:41 +00:00 
			
		
		
		
	### What problem does this PR solve? Adds support for OAuth2 and OpenID Connect (OIDC) authentication through the following routes: - `/login/<channel>`: initiates the OAuth flow for the specified channel - `/oauth/callback/<channel>`: handles the OAuth callback after the provider has authenticated the user. Configure the callback URL in your OAuth provider as: ``` https://your-app.com/oauth/callback/<channel> ``` For detailed instructions on configuring **service_conf.yaml.template**, see `./api/apps/auth/README.md#usage`. - Related issues #3495 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Documentation Update
		
			
				
	
	
		
			804 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			804 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#
 | 
						||
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 | 
						||
#
 | 
						||
#  Licensed under the Apache License, Version 2.0 (the "License");
 | 
						||
#  you may not use this file except in compliance with the License.
 | 
						||
#  You may obtain a copy of the License at
 | 
						||
#
 | 
						||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
						||
#
 | 
						||
#  Unless required by applicable law or agreed to in writing, software
 | 
						||
#  distributed under the License is distributed on an "AS IS" BASIS,
 | 
						||
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						||
#  See the License for the specific language governing permissions and
 | 
						||
#  limitations under the License.
 | 
						||
#
 | 
						||
import logging
 | 
						||
import json
 | 
						||
import re
 | 
						||
from datetime import datetime
 | 
						||
 | 
						||
from flask import request, session, redirect
 | 
						||
from werkzeug.security import generate_password_hash, check_password_hash
 | 
						||
from flask_login import login_required, current_user, login_user, logout_user
 | 
						||
 | 
						||
from api.db.db_models import TenantLLM
 | 
						||
from api.db.services.llm_service import TenantLLMService, LLMService
 | 
						||
from api.utils.api_utils import (
 | 
						||
    server_error_response,
 | 
						||
    validate_request,
 | 
						||
    get_data_error_result,
 | 
						||
)
 | 
						||
from api.utils import (
 | 
						||
    get_uuid,
 | 
						||
    get_format_time,
 | 
						||
    decrypt,
 | 
						||
    download_img,
 | 
						||
    current_timestamp,
 | 
						||
    datetime_format,
 | 
						||
)
 | 
						||
from api.db import UserTenantRole, FileType
 | 
						||
from api import settings
 | 
						||
from api.db.services.user_service import UserService, TenantService, UserTenantService
 | 
						||
from api.db.services.file_service import FileService
 | 
						||
from api.utils.api_utils import get_json_result, construct_response
 | 
						||
from api.apps.auth import get_auth_client
 | 
						||
 | 
						||
 | 
						||
@manager.route("/login", methods=["POST", "GET"])  # noqa: F821
def login():
    """
    User login endpoint.
    ---
    tags:
      - User
    parameters:
      - in: body
        name: body
        description: Login credentials.
        required: true
        schema:
          type: object
          properties:
            email:
              type: string
              description: User email.
            password:
              type: string
              description: User password.
    responses:
      200:
        description: Login successful.
        schema:
          type: object
      401:
        description: Authentication failed.
        schema:
          type: object
    """
    # silent=True: a missing or non-JSON body (e.g. a plain GET) yields None
    # instead of raising, so it is reported through the "Unauthorized!" reply.
    req = request.get_json(silent=True)
    if not req:
        return get_json_result(
            data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
        )

    email = req.get("email", "")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(
            data=False,
            code=settings.RetCode.AUTHENTICATION_ERROR,
            message=f"Email: {email} is not registered!",
        )

    password = req.get("password")
    try:
        # The password arrives encrypted; decrypt it before verification.
        password = decrypt(password)
    except Exception:
        return get_json_result(
            data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
        )

    user = UserService.query_user(email, password)
    if not user:
        return get_json_result(
            data=False,
            code=settings.RetCode.AUTHENTICATION_ERROR,
            message="Email and password do not match!",
        )

    # The response payload is captured before the access token is rotated.
    response_data = user.to_json()
    user.access_token = get_uuid()
    login_user(user)
    # Plain values, not one-element tuples (the previous trailing commas
    # produced `(value,)`).
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Welcome back!"
    return construct_response(data=response_data, auth=user.get_id(), message=msg)
 | 
						||
 | 
						||
 | 
						||
@manager.route("/login/<channel>")  # noqa: F821
def oauth_login(channel):
    """
    Start the OAuth/OIDC flow for ``channel``.

    Looks the channel up in ``settings.OAUTH_CONFIG``, builds an auth client
    from that entry and redirects the browser to the client's authorization URL.

    Raises:
        ValueError: if ``channel`` has no (truthy) entry in ``settings.OAUTH_CONFIG``.
    """
    config = settings.OAUTH_CONFIG.get(channel)
    if not config:
        raise ValueError(f"Invalid channel name: {channel}")

    # NOTE(review): no OAuth ``state`` value is generated or stored here; confirm
    # whether the client from get_auth_client adds one, otherwise the flow has no
    # CSRF protection.
    client = get_auth_client(config)
    return redirect(client.get_authorization_url())
 | 
						||
 | 
						||
 | 
						||
@manager.route("/oauth/callback/<channel>", methods=["GET"])  # noqa: F821
def oauth_callback(channel):
    """
    Handle the OAuth/OIDC callback for various channels dynamically.

    Exchanges the ``code`` query parameter for tokens via the channel's auth
    client and fetches the user's profile. It then logs in the local user with
    that email, or registers a new one.

    Every outcome is a redirect to the web root:
      * success: ``/?auth=<id>``, the same parameter the GitHub and Feishu
        callbacks in this module use;
      * failure: ``/?error=<message>``. Dynamic messages are URL-encoded so that
        characters such as ``&`` or ``#`` cannot truncate or extend the query.
    """
    from urllib.parse import quote

    try:
        channel_config = settings.OAUTH_CONFIG.get(channel)
        if not channel_config:
            raise ValueError(f"Invalid channel name: {channel}")
        auth_cli = get_auth_client(channel_config)

        # NOTE(review): no OAuth ``state`` value is verified here; confirm whether
        # the auth client guards against CSRF, otherwise a state round-trip with
        # oauth_login is needed.
        code = request.args.get("code")
        if not code:
            return redirect("/?error=missing_code")

        # Exchange the authorization code for tokens.
        token_info = auth_cli.exchange_code_for_token(code)
        access_token = token_info.get("access_token")
        if not access_token:
            return redirect("/?error=token_failed")
        id_token = token_info.get("id_token")

        user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
        if not user_info.email:
            return redirect("/?error=email_missing")

        users = UserService.query(email=user_info.email)
        if users:
            # Existing account: rotate its access token and log in.
            user = users[0]
            user.access_token = get_uuid()
            login_user(user)
            user.save()
            return redirect(f"/?auth={user.get_id()}")

        # New account: register it, undoing partial inserts on failure.
        user_id = get_uuid()
        try:
            try:
                avatar = download_img(user_info.avatar_url)
            except Exception as e:
                logging.exception(e)
                avatar = ""

            users = user_register(
                user_id,
                {
                    "access_token": access_token,
                    "email": user_info.email,
                    "avatar": avatar,
                    # user_register concatenates the nickname into the tenant
                    # name, so fall back to the email's local part when the
                    # provider supplies none.
                    "nickname": user_info.nickname or user_info.email.split("@")[0],
                    "login_channel": channel,
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )
            if not users:
                raise Exception(f"Failed to register {user_info.email}")
            if len(users) > 1:
                raise Exception(f"Same email: {user_info.email} exists!")

            user = users[0]
            login_user(user)
            return redirect(f"/?auth={user.get_id()}")
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return redirect(f"/?error={quote(str(e))}")
    except Exception as e:
        logging.exception(e)
        return redirect(f"/?error={quote(str(e))}")
 | 
						||
 | 
						||
 | 
						||
@manager.route("/github_callback", methods=["GET"])  # noqa: F821
def github_callback():
    """
    GitHub OAuth callback endpoint.
    ---
    tags:
      - OAuth
    parameters:
      - in: query
        name: code
        type: string
        required: true
        description: Authorization code from GitHub.
    responses:
      200:
        description: Authentication successful.
        schema:
          type: object
    """
    import requests
    from urllib.parse import quote

    # Exchange the authorization code for a GitHub access token.
    res = requests.post(
        settings.GITHUB_OAUTH.get("url"),
        data={
            "client_id": settings.GITHUB_OAUTH.get("client_id"),
            "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
            "code": request.args.get("code"),
        },
        headers={"Accept": "application/json"},
    )
    res = res.json()
    if "error" in res:
        # error_description is optional in OAuth error responses; fall back to
        # the error code. URL-encode provider text so it stays one query value.
        return redirect("/?error=%s" % quote(str(res.get("error_description") or res["error"])))

    # GitHub reports granted scopes as a comma-separated string.
    if "user:email" not in res.get("scope", "").split(","):
        return redirect("/?error=user:email not in scope")

    session["access_token"] = res["access_token"]
    session["access_token_from"] = "github"
    user_info = user_info_from_github(session["access_token"])
    email_address = user_info.get("email")
    if not email_address:
        return redirect("/?error=email_missing")
    users = UserService.query(email=email_address)
    user_id = get_uuid()
    if not users:
        # No account with this email yet: register one.
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": session["access_token"],
                    "email": email_address,
                    "avatar": avatar,
                    "nickname": user_info["login"],
                    "login_channel": "github",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )
            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")

            # Try to log in
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return redirect("/?error=%s" % quote(str(e)))

    # User has already registered: rotate the access token and log in.
    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())
 | 
						||
 | 
						||
 | 
						||
@manager.route("/feishu_callback", methods=["GET"])  # noqa: F821
def feishu_callback():
    """
    Feishu OAuth callback endpoint.
    ---
    tags:
      - OAuth
    parameters:
      - in: query
        name: code
        type: string
        required: true
        description: Authorization code from Feishu.
    responses:
      200:
        description: Authentication successful.
        schema:
          type: object
    """
    import requests
    from urllib.parse import quote

    # Step 1: obtain an app access token, which authorizes the code exchange below.
    app_access_token_res = requests.post(
        settings.FEISHU_OAUTH.get("app_access_token_url"),
        data=json.dumps(
            {
                "app_id": settings.FEISHU_OAUTH.get("app_id"),
                "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
            }
        ),
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    app_access_token_res = app_access_token_res.json()
    if app_access_token_res["code"] != 0:
        # The whole response is reported; URL-encode it so it stays one query value.
        return redirect("/?error=%s" % quote(str(app_access_token_res)))

    # Step 2: exchange the authorization code for a user access token.
    res = requests.post(
        settings.FEISHU_OAUTH.get("user_access_token_url"),
        data=json.dumps(
            {
                "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
                "code": request.args.get("code"),
            }
        ),
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
        },
    )
    res = res.json()
    if res["code"] != 0:
        # Fall back to the full response if it carries no "message" field.
        return redirect("/?error=%s" % quote(str(res.get("message", res))))

    if "contact:user.email:readonly" not in res["data"]["scope"].split():
        return redirect("/?error=contact:user.email:readonly not in scope")
    session["access_token"] = res["data"]["access_token"]
    session["access_token_from"] = "feishu"
    user_info = user_info_from_feishu(session["access_token"])
    email_address = user_info.get("email")
    if not email_address:
        return redirect("/?error=email_missing")
    users = UserService.query(email=email_address)
    user_id = get_uuid()
    if not users:
        # No account with this email yet: register one.
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": session["access_token"],
                    "email": email_address,
                    "avatar": avatar,
                    # user_register concatenates the nickname into the tenant
                    # name, so fall back to the email's local part when no
                    # English name is provided.
                    "nickname": user_info.get("en_name") or email_address.split("@")[0],
                    "login_channel": "feishu",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )
            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")

            # Try to log in
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return redirect("/?error=%s" % quote(str(e)))

    # User has already registered: rotate the access token and log in.
    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())
 | 
						||
 | 
						||
 | 
						||
def user_info_from_feishu(access_token):
    """
    Fetch the signed-in user's profile from Feishu's ``authen/v1/user_info`` API.

    Args:
        access_token: Feishu user access token, sent as a Bearer token.

    Returns:
        The ``data`` object of the JSON response. Its ``"email"`` key is always
        present: an empty or absent address is normalized to ``None``.
    """
    import requests

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {access_token}",
    }
    res = requests.get(
        "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers
    )
    user_info = res.json()["data"]
    # `.get(...) or None` also covers a missing key, which previously raised KeyError.
    user_info["email"] = user_info.get("email") or None
    return user_info
 | 
						||
 | 
						||
 | 
						||
def user_info_from_github(access_token):
    """
    Fetch a GitHub user's profile and attach their primary email address.

    Args:
        access_token: GitHub OAuth access token.

    Returns:
        The JSON object from ``GET /user``, with an ``"email"`` key set to the
        address flagged ``primary`` in ``GET /user/emails``. The key is ``None``
        when no primary address is returned, including when the emails call
        does not return a list.
    """
    import requests

    # The token is sent only in the Authorization header. GitHub has deprecated
    # passing it as an ``access_token`` query parameter, and tokens in URLs can
    # leak into proxy/server logs.
    headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
    res = requests.get("https://api.github.com/user", headers=headers)
    user_info = res.json()
    email_info = requests.get(
        "https://api.github.com/user/emails",
        headers=headers,
    ).json()
    # An error reply is a JSON object rather than a list of email entries.
    emails = email_info if isinstance(email_info, list) else []
    primary = next((item for item in emails if item.get("primary")), None)
    user_info["email"] = primary["email"] if primary else None
    return user_info
 | 
						||
 | 
						||
 | 
						||
@manager.route("/logout", methods=["GET"])  # noqa: F821
@login_required
def log_out():
    """
    User logout endpoint.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: Logout successful.
        schema:
          type: object
    """
    # Blank and persist the stored access token, then end the Flask-Login session.
    me = current_user
    me.access_token = ""
    me.save()
    logout_user()
    return get_json_result(data=True)
 | 
						||
 | 
						||
 | 
						||
@manager.route("/setting", methods=["POST"])  # noqa: F821
@login_required
def setting_user():
    """
    Update user settings.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        description: User settings to update.
        required: true
        schema:
          type: object
          properties:
            nickname:
              type: string
              description: New nickname.
            email:
              type: string
              description: New email.
    responses:
      200:
        description: Settings updated successfully.
        schema:
          type: object
    """
    # Columns a client may never write through this endpoint: credentials,
    # identity (email, primary key, session token), account flags/state and
    # bookkeeping timestamps. Without "id" and "access_token" here, a request
    # could overwrite the primary key or the session credential.
    protected_keys = {
        "password",
        "new_password",
        "email",
        "status",
        "is_superuser",
        "login_channel",
        "is_anonymous",
        "is_active",
        "is_authenticated",
        "last_login_time",
        "id",
        "access_token",
        "create_time",
        "create_date",
        "update_time",
        "update_date",
    }

    update_dict = {}
    request_data = request.json
    if request_data.get("password"):
        # Changing the password requires the current one; both values arrive
        # encrypted and are decrypted before use.
        new_password = request_data.get("new_password")
        if not check_password_hash(
                current_user.password, decrypt(request_data["password"])
        ):
            return get_json_result(
                data=False,
                code=settings.RetCode.AUTHENTICATION_ERROR,
                message="Password error!",
            )

        if new_password:
            update_dict["password"] = generate_password_hash(decrypt(new_password))

    # Copy the remaining fields through unchanged. The hash stored above is kept
    # because "password" is itself a protected key.
    update_dict.update(
        {k: v for k, v in request_data.items() if k not in protected_keys}
    )

    try:
        UserService.update_by_id(current_user.id, update_dict)
        return get_json_result(data=True)
    except Exception as e:
        logging.exception(e)
        return get_json_result(
            data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
        )
 | 
						||
 | 
						||
 | 
						||
@manager.route("/info", methods=["GET"])  # noqa: F821
@login_required
def user_profile():
    """
    Get user profile information.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: User profile retrieved successfully.
        schema:
          type: object
          properties:
            id:
              type: string
              description: User ID.
            nickname:
              type: string
              description: User nickname.
            email:
              type: string
              description: User email.
    """
    # Serialize the logged-in user's model and wrap it in the standard JSON envelope.
    profile = current_user.to_dict()
    return get_json_result(data=profile)
 | 
						||
 | 
						||
 | 
						||
def rollback_user_registration(user_id):
    """
    Best-effort removal of the rows ``user_register`` creates for ``user_id``.

    Each cleanup step runs independently, so one failing step does not stop the
    others. Failures are logged rather than raised, because callers invoke this
    from ``except`` handlers that already report the original error.

    Args:
        user_id: Id given to the new user. ``user_register`` also uses it as the
            tenant id, the owner's user-tenant link and the root folder's tenant.
    """

    def _delete_user_tenants():
        # Remove every link to the tenant, not just the first one found.
        for link in UserTenantService.query(tenant_id=user_id):
            UserTenantService.delete_by_id(link.id)

    def _delete_root_folder():
        # user_register inserts a root "/" folder with tenant_id == created_by == user_id.
        # NOTE(review): assumes FileService exposes the same query/delete_by_id
        # helpers as the other services used here; failures are logged below.
        for folder in FileService.query(tenant_id=user_id, created_by=user_id):
            FileService.delete_by_id(folder.id)

    steps = (
        ("user", lambda: UserService.delete_by_id(user_id)),
        ("tenant", lambda: TenantService.delete_by_id(user_id)),
        ("user-tenant links", _delete_user_tenants),
        (
            "tenant LLMs",
            lambda: TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute(),
        ),
        ("root folder", _delete_root_folder),
    )
    for label, step in steps:
        try:
            step()
        except Exception:
            logging.exception("Rollback of %s for user %s failed", label, user_id)
 | 
						||
 | 
						||
 | 
						||
def user_register(user_id, user):
    """
    Create a user together with its personal tenant and default resources.

    Inserts, in order: the user row, a tenant with the same id, an OWNER
    user-tenant link, one tenant-LLM row per model of ``settings.LLM_FACTORY``,
    and a root "/" folder.

    Args:
        user_id: Id for the new user; also used as the tenant id.
        user: Column values for the user row. Mutated in place: its ``"id"``
            key is set to ``user_id``.

    Returns:
        ``UserService.query(email=user["email"])`` after the inserts. Returns
        ``None`` without inserting anything else if ``UserService.save``
        returns a falsy value.
    """
    user["id"] = user_id

    tenant_row = dict(
        id=user_id,
        name=user["nickname"] + "‘s Kingdom",
        llm_id=settings.CHAT_MDL,
        embd_id=settings.EMBEDDING_MDL,
        asr_id=settings.ASR_MDL,
        parser_ids=settings.PARSERS,
        img2txt_id=settings.IMAGE2TEXT_MDL,
        rerank_id=settings.RERANK_MDL,
    )
    owner_link = dict(
        tenant_id=user_id,
        user_id=user_id,
        invited_by=user_id,
        role=UserTenantRole.OWNER,
    )

    # The root folder is its own parent.
    root_id = get_uuid()
    root_folder = dict(
        id=root_id,
        parent_id=root_id,
        tenant_id=user_id,
        created_by=user_id,
        name="/",
        type=FileType.FOLDER.value,
        size=0,
        location="",
    )

    default_llms = [
        {
            "tenant_id": user_id,
            "llm_factory": settings.LLM_FACTORY,
            "llm_name": model.llm_name,
            "model_type": model.model_type,
            "api_key": settings.API_KEY,
            "api_base": settings.LLM_BASE_URL,
            "max_tokens": model.max_tokens or 8192,
        }
        for model in LLMService.query(fid=settings.LLM_FACTORY)
    ]

    saved = UserService.save(**user)
    if not saved:
        return None

    TenantService.insert(**tenant_row)
    UserTenantService.insert(**owner_link)
    TenantLLMService.insert_many(default_llms)
    FileService.insert(root_folder)
    return UserService.query(email=user["email"])
 | 
						||
 | 
						||
 | 
						||
@manager.route("/register", methods=["POST"])  # noqa: F821
@validate_request("nickname", "email", "password")
def user_add():
    """
    Register a new user.
    ---
    tags:
      - User
    parameters:
      - in: body
        name: body
        description: Registration details.
        required: true
        schema:
          type: object
          properties:
            nickname:
              type: string
              description: User nickname.
            email:
              type: string
              description: User email.
            password:
              type: string
              description: User password.
    responses:
      200:
        description: Registration successful.
        schema:
          type: object
    """

    if not settings.REGISTER_ENABLED:
        return get_json_result(
            data=False,
            message="User registration is disabled!",
            code=settings.RetCode.OPERATING_ERROR,
        )

    req = request.json
    email_address = req["email"]

    # Validate the email address
    if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address):
        return get_json_result(
            data=False,
            message=f"Invalid email address: {email_address}!",
            code=settings.RetCode.OPERATING_ERROR,
        )

    # Check if the email address is already used
    if UserService.query(email=email_address):
        return get_json_result(
            data=False,
            message=f"Email: {email_address} has already registered!",
            code=settings.RetCode.OPERATING_ERROR,
        )

    nickname = req["nickname"]

    # The password arrives encrypted. Report a decryption failure with the same
    # response /login uses instead of letting the exception escape the view.
    try:
        password = decrypt(req["password"])
    except Exception:
        return get_json_result(
            data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
        )

    # Construct user info data
    user_dict = {
        "access_token": get_uuid(),
        "email": email_address,
        "nickname": nickname,
        "password": password,
        "login_channel": "password",
        "last_login_time": get_format_time(),
        "is_superuser": False,
    }

    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users:
            raise Exception(f"Fail to register {email_address}.")
        if len(users) > 1:
            raise Exception(f"Same email: {email_address} exists!")
        user = users[0]
        login_user(user)
        return construct_response(
            data=user.to_json(),
            auth=user.get_id(),
            message=f"{nickname}, welcome aboard!",
        )
    except Exception as e:
        # Undo whatever part of the registration was already written.
        rollback_user_registration(user_id)
        logging.exception(e)
        return get_json_result(
            data=False,
            message=f"User registration failure, error: {str(e)}",
            code=settings.RetCode.EXCEPTION_ERROR,
        )
 | 
						||
 | 
						||
 | 
						||
@manager.route("/tenant_info", methods=["GET"])  # noqa: F821
@login_required
def tenant_info():
    """
    Get tenant information.
    ---
    tags:
      - Tenant
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: Tenant information retrieved successfully.
        schema:
          type: object
          properties:
            tenant_id:
              type: string
              description: Tenant ID.
            name:
              type: string
              description: Tenant name.
            llm_id:
              type: string
              description: LLM ID.
            embd_id:
              type: string
              description: Embedding model ID.
    """
    # Return the first tenant reported for the current user; any failure,
    # including while building the response, becomes a server error reply.
    try:
        found = TenantService.get_info_by(current_user.id)
        if found:
            return get_json_result(data=found[0])
        return get_data_error_result(message="Tenant not found!")
    except Exception as e:
        return server_error_response(e)
 | 
						||
 | 
						||
 | 
						||
@manager.route("/set_tenant_info", methods=["POST"])  # noqa: F821
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
def set_tenant_info():
    """
    Update tenant information.
    ---
    tags:
      - Tenant
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        description: Tenant information to update.
        required: true
        schema:
          type: object
          properties:
            tenant_id:
              type: string
              description: Tenant ID.
            llm_id:
              type: string
              description: LLM ID.
            embd_id:
              type: string
              description: Embedding model ID.
            asr_id:
              type: string
              description: ASR model ID.
            img2txt_id:
              type: string
              description: Image to Text model ID.
    responses:
      200:
        description: Tenant information updated successfully.
        schema:
          type: object
    """
    req = request.json
    try:
        tid = req.pop("tenant_id")
        # Without this check, any signed-in user could rewrite another tenant's
        # model settings just by knowing its id. Only the tenant's OWNER (the
        # role user_register assigns to a user's own tenant) may update it.
        if not UserTenantService.query(
            user_id=current_user.id, tenant_id=tid, role=UserTenantRole.OWNER
        ):
            return get_json_result(
                data=False,
                message="No authorization.",
                code=settings.RetCode.AUTHENTICATION_ERROR,
            )
        TenantService.update_by_id(tid, req)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
 |