| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  | import json | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | import random | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | import string | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import subprocess | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import uuid | 
					
						
							| 
									
										
										
										
											2024-12-02 10:24:21 +08:00
										 |  |  | from collections.abc import Generator, Mapping | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from datetime import datetime | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | from hashlib import sha256 | 
					
						
							| 
									
										
										
										
											2024-12-15 17:18:17 +08:00
										 |  |  | from typing import Any, Optional, Union, cast | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from zoneinfo import available_timezones | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-22 11:01:32 +08:00
										 |  |  | from flask import Response, stream_with_context | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  | from flask_restful import fields  # type: ignore | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-22 11:01:32 +08:00
										 |  |  | from configs import dify_config | 
					
						
							| 
									
										
										
										
											2024-07-22 22:58:07 +08:00
										 |  |  | from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | 
					
						
							| 
									
										
										
										
											2024-10-21 10:43:49 +08:00
										 |  |  | from core.file import helpers as file_helpers | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | from extensions.ext_redis import redis_client | 
					
						
							|  |  |  | from models.account import Account | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def run(script): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     return subprocess.getstatusoutput("source /root/.bashrc && " + script) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-19 09:16:33 +08:00
										 |  |  | class AppIconUrlField(fields.Raw): | 
					
						
							|  |  |  |     def output(self, key, obj): | 
					
						
							|  |  |  |         if obj is None: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 10:23:03 +08:00
										 |  |  |         from models.model import App, IconType, Site | 
					
						
							| 
									
										
										
										
											2024-08-19 09:16:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-22 15:05:04 +08:00
										 |  |  |         if isinstance(obj, dict) and "app" in obj: | 
					
						
							|  |  |  |             obj = obj["app"] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 10:23:03 +08:00
										 |  |  |         if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value: | 
					
						
							| 
									
										
										
										
											2024-10-21 10:43:49 +08:00
										 |  |  |             return file_helpers.get_signed_file_url(obj.icon) | 
					
						
							| 
									
										
										
										
											2024-08-19 09:16:33 +08:00
										 |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-22 11:11:31 +09:00
										 |  |  | class AvatarUrlField(fields.Raw): | 
					
						
							|  |  |  |     def output(self, key, obj): | 
					
						
							|  |  |  |         if obj is None: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         from models.account import Account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if isinstance(obj, Account) and obj.avatar is not None: | 
					
						
							|  |  |  |             return file_helpers.get_signed_file_url(obj.avatar) | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | class TimestampField(fields.Raw): | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |     def format(self, value) -> int: | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return int(value.timestamp()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def email(email): | 
					
						
							|  |  |  |     # Define a regex pattern for email addresses | 
					
						
							| 
									
										
										
										
											2024-06-17 22:32:59 +09:00
										 |  |  |     pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     # Check if the email matches the pattern | 
					
						
							|  |  |  |     if re.match(pattern, email) is not None: | 
					
						
							|  |  |  |         return email | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     error = "{email} is not a valid email.".format(email=email) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def uuid_value(value): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     if value == "": | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return str(value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         uuid_obj = uuid.UUID(value) | 
					
						
							|  |  |  |         return str(uuid_obj) | 
					
						
							|  |  |  |     except ValueError: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         error = "{value} is not a valid uuid.".format(value=value) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         raise ValueError(error) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | def alphanumeric(value: str): | 
					
						
							|  |  |  |     # check if the value is alphanumeric and underlined | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     if re.match(r"^[a-zA-Z0-9_]+$", value): | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  |         return value | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     raise ValueError(f"{value} is not a valid alphanumeric value") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def timestamp_value(timestamp): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         int_timestamp = int(timestamp) | 
					
						
							|  |  |  |         if int_timestamp < 0: | 
					
						
							|  |  |  |             raise ValueError | 
					
						
							|  |  |  |         return int_timestamp | 
					
						
							|  |  |  |     except ValueError: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  | class StrLen: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     """Restrict input to an integer in a range (inclusive)""" | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     def __init__(self, max_length, argument="argument"): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         self.max_length = max_length | 
					
						
							|  |  |  |         self.argument = argument | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, value): | 
					
						
							|  |  |  |         length = len(value) | 
					
						
							|  |  |  |         if length > self.max_length: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |             error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format( | 
					
						
							|  |  |  |                 arg=self.argument, val=value, length=self.max_length | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  | class FloatRange: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     """Restrict input to an float in a range (inclusive)""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, low, high, argument="argument"): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         self.low = low | 
					
						
							|  |  |  |         self.high = high | 
					
						
							|  |  |  |         self.argument = argument | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, value): | 
					
						
							|  |  |  |         value = _get_float(value) | 
					
						
							|  |  |  |         if value < self.low or value > self.high: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |             error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( | 
					
						
							|  |  |  |                 arg=self.argument, val=value, lo=self.low, hi=self.high | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  | class DatetimeString: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     def __init__(self, format, argument="argument"): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         self.format = format | 
					
						
							|  |  |  |         self.argument = argument | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, value): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             datetime.strptime(value, self.format) | 
					
						
							|  |  |  |         except ValueError: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |             error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( | 
					
						
							|  |  |  |                 arg=self.argument, val=value, format=self.format | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _get_float(value): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return float(value) | 
					
						
							|  |  |  |     except (TypeError, ValueError): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         raise ValueError("{} is not a valid float".format(value)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def timezone(timezone_string): | 
					
						
							|  |  |  |     if timezone_string and timezone_string in available_timezones(): | 
					
						
							|  |  |  |         return timezone_string | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def generate_string(n): | 
					
						
							|  |  |  |     letters_digits = string.ascii_letters + string.digits | 
					
						
							|  |  |  |     result = "" | 
					
						
							|  |  |  |     for i in range(n): | 
					
						
							|  |  |  |         result += random.choice(letters_digits) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return result | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-12 23:46:30 +08:00
										 |  |  | def extract_remote_ip(request) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     if request.headers.get("CF-Connecting-IP"): | 
					
						
							| 
									
										
										
										
											2024-12-15 17:18:17 +08:00
										 |  |  |         return cast(str, request.headers.get("Cf-Connecting-Ip")) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     elif request.headers.getlist("X-Forwarded-For"): | 
					
						
							| 
									
										
										
										
											2024-12-15 17:18:17 +08:00
										 |  |  |         return cast(str, request.headers.getlist("X-Forwarded-For")[0]) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     else: | 
					
						
							| 
									
										
										
										
											2024-12-15 17:18:17 +08:00
										 |  |  |         return cast(str, request.remote_addr) | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def generate_text_hash(text: str) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     hash_text = str(text) + "None" | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |     return sha256(hash_text.encode()).hexdigest() | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-02 10:24:21 +08:00
										 |  |  | def compact_generate_response( | 
					
						
							|  |  |  |     response: Union[Mapping[str, Any], RateLimitGenerator, Generator[str, None, None]], | 
					
						
							|  |  |  | ) -> Response: | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |     if isinstance(response, dict): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return Response(response=json.dumps(response), status=200, mimetype="application/json") | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |     else: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |         def generate() -> Generator: | 
					
						
							|  |  |  |             yield from response | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TokenManager: | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |     def generate_token( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         token_type: str, | 
					
						
							|  |  |  |         account: Optional[Account] = None, | 
					
						
							|  |  |  |         email: Optional[str] = None, | 
					
						
							|  |  |  |         additional_data: Optional[dict] = None, | 
					
						
							|  |  |  |     ) -> str: | 
					
						
							|  |  |  |         if account is None and email is None: | 
					
						
							|  |  |  |             raise ValueError("Account or email must be provided") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         account_id = account.id if account else None | 
					
						
							|  |  |  |         account_email = account.email if account else email | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if account_id: | 
					
						
							|  |  |  |             old_token = cls._get_current_token_for_account(account_id, token_type) | 
					
						
							|  |  |  |             if old_token: | 
					
						
							|  |  |  |                 if isinstance(old_token, bytes): | 
					
						
							|  |  |  |                     old_token = old_token.decode("utf-8") | 
					
						
							|  |  |  |                 cls.revoke_token(old_token, token_type) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         token = str(uuid.uuid4()) | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         token_data = {"account_id": account_id, "email": account_email, "token_type": token_type} | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         if additional_data: | 
					
						
							|  |  |  |             token_data.update(additional_data) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-22 11:01:32 +08:00
										 |  |  |         expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES") | 
					
						
							| 
									
										
										
										
											2024-12-15 17:18:17 +08:00
										 |  |  |         if expiry_minutes is None: | 
					
						
							|  |  |  |             raise ValueError(f"Expiry minutes for {token_type} token is not set") | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         token_key = cls._get_token_key(token, token_type) | 
					
						
							| 
									
										
										
										
											2024-10-21 18:14:26 +08:00
										 |  |  |         expiry_time = int(expiry_minutes * 60) | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         redis_client.setex(token_key, expiry_time, json.dumps(token_data)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if account_id: | 
					
						
							| 
									
										
										
										
											2024-12-15 17:18:17 +08:00
										 |  |  |             cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_token_key(cls, token: str, token_type: str) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return f"{token_type}:token:{token}" | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def revoke_token(cls, token: str, token_type: str): | 
					
						
							|  |  |  |         token_key = cls._get_token_key(token, token_type) | 
					
						
							|  |  |  |         redis_client.delete(token_key) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]: | 
					
						
							|  |  |  |         key = cls._get_token_key(token, token_type) | 
					
						
							|  |  |  |         token_data_json = redis_client.get(key) | 
					
						
							|  |  |  |         if token_data_json is None: | 
					
						
							|  |  |  |             logging.warning(f"{token_type} token {token} not found with key {key}") | 
					
						
							|  |  |  |             return None | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         token_data: Optional[dict[str, Any]] = json.loads(token_data_json) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         return token_data | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: | 
					
						
							|  |  |  |         key = cls._get_account_token_key(account_id, token_type) | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         current_token: Optional[str] = redis_client.get(key) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         return current_token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |     def _set_current_token_for_account( | 
					
						
							|  |  |  |         cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float] | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         key = cls._get_account_token_key(account_id, token_type) | 
					
						
							| 
									
										
										
										
											2024-10-21 10:03:40 +08:00
										 |  |  |         expiry_time = int(expiry_hours * 60 * 60) | 
					
						
							|  |  |  |         redis_client.setex(key, expiry_time, token) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_account_token_key(cls, account_id: str, token_type: str) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return f"{token_type}:account:{account_id}" | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RateLimiter: | 
					
						
							|  |  |  |     def __init__(self, prefix: str, max_attempts: int, time_window: int): | 
					
						
							|  |  |  |         self.prefix = prefix | 
					
						
							|  |  |  |         self.max_attempts = max_attempts | 
					
						
							|  |  |  |         self.time_window = time_window | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_key(self, email: str) -> str: | 
					
						
							|  |  |  |         return f"{self.prefix}:{email}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def is_rate_limited(self, email: str) -> bool: | 
					
						
							|  |  |  |         key = self._get_key(email) | 
					
						
							|  |  |  |         current_time = int(time.time()) | 
					
						
							|  |  |  |         window_start_time = current_time - self.time_window | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         redis_client.zremrangebyscore(key, "-inf", window_start_time) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         attempts = redis_client.zcard(key) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attempts and int(attempts) >= self.max_attempts: | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |         return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def increment_rate_limit(self, email: str): | 
					
						
							|  |  |  |         key = self._get_key(email) | 
					
						
							|  |  |  |         current_time = int(time.time()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         redis_client.zadd(key, {current_time: current_time}) | 
					
						
							|  |  |  |         redis_client.expire(key, self.time_window * 2) |