| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  | import json | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | import random | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | import string | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import subprocess | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import uuid | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  | from collections.abc import Generator | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from datetime import datetime | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | from hashlib import sha256 | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | from typing import Any, Optional, Union | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from zoneinfo import available_timezones | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | from flask import Response, current_app, stream_with_context | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from flask_restful import fields | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-22 22:58:07 +08:00
										 |  |  | from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | 
					
						
							| 
									
										
										
										
											2024-08-19 09:16:33 +08:00
										 |  |  | from core.file.upload_file_parser import UploadFileParser | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | from extensions.ext_redis import redis_client | 
					
						
							|  |  |  | from models.account import Account | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def run(script): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     return subprocess.getstatusoutput("source /root/.bashrc && " + script) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-19 09:16:33 +08:00
										 |  |  | class AppIconUrlField(fields.Raw): | 
					
						
							|  |  |  |     def output(self, key, obj): | 
					
						
							|  |  |  |         if obj is None: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         from models.model import IconType | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if obj.icon_type == IconType.IMAGE.value: | 
					
						
							|  |  |  |             return UploadFileParser.get_signed_temp_image_url(obj.icon) | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | class TimestampField(fields.Raw): | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |     def format(self, value) -> int: | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return int(value.timestamp()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def email(email): | 
					
						
							|  |  |  |     # Define a regex pattern for email addresses | 
					
						
							| 
									
										
										
										
											2024-06-17 22:32:59 +09:00
										 |  |  |     pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     # Check if the email matches the pattern | 
					
						
							|  |  |  |     if re.match(pattern, email) is not None: | 
					
						
							|  |  |  |         return email | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     error = "{email} is not a valid email.".format(email=email) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def uuid_value(value): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     if value == "": | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         return str(value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         uuid_obj = uuid.UUID(value) | 
					
						
							|  |  |  |         return str(uuid_obj) | 
					
						
							|  |  |  |     except ValueError: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         error = "{value} is not a valid uuid.".format(value=value) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         raise ValueError(error) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | def alphanumeric(value: str): | 
					
						
							|  |  |  |     # check if the value is alphanumeric and underlined | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     if re.match(r"^[a-zA-Z0-9_]+$", value): | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  |         return value | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     raise ValueError(f"{value} is not a valid alphanumeric value") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def timestamp_value(timestamp): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         int_timestamp = int(timestamp) | 
					
						
							|  |  |  |         if int_timestamp < 0: | 
					
						
							|  |  |  |             raise ValueError | 
					
						
							|  |  |  |         return int_timestamp | 
					
						
							|  |  |  |     except ValueError: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  | class StrLen: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     """Restrict input to an integer in a range (inclusive)""" | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     def __init__(self, max_length, argument="argument"): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         self.max_length = max_length | 
					
						
							|  |  |  |         self.argument = argument | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, value): | 
					
						
							|  |  |  |         length = len(value) | 
					
						
							|  |  |  |         if length > self.max_length: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |             error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format( | 
					
						
							|  |  |  |                 arg=self.argument, val=value, length=self.max_length | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  | class FloatRange: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     """Restrict input to an float in a range (inclusive)""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, low, high, argument="argument"): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         self.low = low | 
					
						
							|  |  |  |         self.high = high | 
					
						
							|  |  |  |         self.argument = argument | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, value): | 
					
						
							|  |  |  |         value = _get_float(value) | 
					
						
							|  |  |  |         if value < self.low or value > self.high: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |             error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( | 
					
						
							|  |  |  |                 arg=self.argument, val=value, lo=self.low, hi=self.high | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  | class DatetimeString: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     def __init__(self, format, argument="argument"): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         self.format = format | 
					
						
							|  |  |  |         self.argument = argument | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, value): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             datetime.strptime(value, self.format) | 
					
						
							|  |  |  |         except ValueError: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |             error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( | 
					
						
							|  |  |  |                 arg=self.argument, val=value, format=self.format | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _get_float(value): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return float(value) | 
					
						
							|  |  |  |     except (TypeError, ValueError): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         raise ValueError("{} is not a valid float".format(value)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def timezone(timezone_string): | 
					
						
							|  |  |  |     if timezone_string and timezone_string in available_timezones(): | 
					
						
							|  |  |  |         return timezone_string | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     raise ValueError(error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def generate_string(n): | 
					
						
							|  |  |  |     letters_digits = string.ascii_letters + string.digits | 
					
						
							|  |  |  |     result = "" | 
					
						
							|  |  |  |     for i in range(n): | 
					
						
							|  |  |  |         result += random.choice(letters_digits) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return result | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-12 23:46:30 +08:00
										 |  |  | def extract_remote_ip(request) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     if request.headers.get("CF-Connecting-IP"): | 
					
						
							|  |  |  |         return request.headers.get("Cf-Connecting-Ip") | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |     elif request.headers.getlist("X-Forwarded-For"): | 
					
						
							|  |  |  |         return request.headers.getlist("X-Forwarded-For")[0] | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         return request.remote_addr | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def generate_text_hash(text: str) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |     hash_text = str(text) + "None" | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |     return sha256(hash_text.encode()).hexdigest() | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-22 22:58:07 +08:00
										 |  |  | def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |     if isinstance(response, dict): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return Response(response=json.dumps(response), status=200, mimetype="application/json") | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |     else: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |         def generate() -> Generator: | 
					
						
							|  |  |  |             yield from response | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TokenManager: | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-10-09 14:36:43 +08:00
										 |  |  |     def generate_token(cls, account: Account, token_type: str, additional_data: Optional[dict] = None) -> str: | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         old_token = cls._get_current_token_for_account(account.id, token_type) | 
					
						
							|  |  |  |         if old_token: | 
					
						
							|  |  |  |             if isinstance(old_token, bytes): | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |                 old_token = old_token.decode("utf-8") | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |             cls.revoke_token(old_token, token_type) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         token = str(uuid.uuid4()) | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         if additional_data: | 
					
						
							|  |  |  |             token_data.update(additional_data) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         token_key = cls._get_token_key(token, token_type) | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) | 
					
						
							|  |  |  |         return token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_token_key(cls, token: str, token_type: str) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return f"{token_type}:token:{token}" | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def revoke_token(cls, token: str, token_type: str): | 
					
						
							|  |  |  |         token_key = cls._get_token_key(token, token_type) | 
					
						
							|  |  |  |         redis_client.delete(token_key) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]: | 
					
						
							|  |  |  |         key = cls._get_token_key(token, token_type) | 
					
						
							|  |  |  |         token_data_json = redis_client.get(key) | 
					
						
							|  |  |  |         if token_data_json is None: | 
					
						
							|  |  |  |             logging.warning(f"{token_type} token {token} not found with key {key}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         token_data = json.loads(token_data_json) | 
					
						
							|  |  |  |         return token_data | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: | 
					
						
							|  |  |  |         key = cls._get_account_token_key(account_id, token_type) | 
					
						
							|  |  |  |         current_token = redis_client.get(key) | 
					
						
							|  |  |  |         return current_token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int): | 
					
						
							|  |  |  |         key = cls._get_account_token_key(account_id, token_type) | 
					
						
							|  |  |  |         redis_client.setex(key, expiry_hours * 60 * 60, token) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_account_token_key(cls, account_id: str, token_type: str) -> str: | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         return f"{token_type}:account:{account_id}" | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RateLimiter: | 
					
						
							|  |  |  |     def __init__(self, prefix: str, max_attempts: int, time_window: int): | 
					
						
							|  |  |  |         self.prefix = prefix | 
					
						
							|  |  |  |         self.max_attempts = max_attempts | 
					
						
							|  |  |  |         self.time_window = time_window | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_key(self, email: str) -> str: | 
					
						
							|  |  |  |         return f"{self.prefix}:{email}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def is_rate_limited(self, email: str) -> bool: | 
					
						
							|  |  |  |         key = self._get_key(email) | 
					
						
							|  |  |  |         current_time = int(time.time()) | 
					
						
							|  |  |  |         window_start_time = current_time - self.time_window | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 17:53:12 +08:00
										 |  |  |         redis_client.zremrangebyscore(key, "-inf", window_start_time) | 
					
						
							| 
									
										
										
										
											2024-07-05 13:38:51 +08:00
										 |  |  |         attempts = redis_client.zcard(key) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if attempts and int(attempts) >= self.max_attempts: | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |         return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def increment_rate_limit(self, email: str): | 
					
						
							|  |  |  |         key = self._get_key(email) | 
					
						
							|  |  |  |         current_time = int(time.time()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         redis_client.zadd(key, {current_time: current_time}) | 
					
						
							|  |  |  |         redis_client.expire(key, self.time_window * 2) |