mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-31 02:42:59 +00:00 
			
		
		
		
	chore(api/libs): Apply ruff format. (#7301)
This commit is contained in:
		
							parent
							
								
									d07b2b9915
								
							
						
					
					
						commit
						9414143b5f
					
				| @ -25,7 +25,7 @@ class FireCrawlDataSource(BearerDataSource): | ||||
|         TEST_CRAWL_SITE_URL = "https://www.google.com" | ||||
|         FIRECRAWL_API_VERSION = "v0" | ||||
| 
 | ||||
|         test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape" | ||||
|         test_api_endpoint = self.api_base_url.rstrip("/") + f"/{FIRECRAWL_API_VERSION}/scrape" | ||||
| 
 | ||||
|         headers = { | ||||
|             "Authorization": f"Bearer {self.api_key}", | ||||
| @ -45,9 +45,9 @@ class FireCrawlDataSource(BearerDataSource): | ||||
|         data_source_binding = DataSourceBearerBinding.query.filter( | ||||
|             db.and_( | ||||
|                 DataSourceBearerBinding.tenant_id == current_user.current_tenant_id, | ||||
|                 DataSourceBearerBinding.provider == 'firecrawl', | ||||
|                 DataSourceBearerBinding.provider == "firecrawl", | ||||
|                 DataSourceBearerBinding.endpoint_url == self.api_base_url, | ||||
|                 DataSourceBearerBinding.bearer_key == self.api_key | ||||
|                 DataSourceBearerBinding.bearer_key == self.api_key, | ||||
|             ) | ||||
|         ).first() | ||||
|         if data_source_binding: | ||||
| @ -56,9 +56,9 @@ class FireCrawlDataSource(BearerDataSource): | ||||
|         else: | ||||
|             new_data_source_binding = DataSourceBearerBinding( | ||||
|                 tenant_id=current_user.current_tenant_id, | ||||
|                 provider='firecrawl', | ||||
|                 provider="firecrawl", | ||||
|                 endpoint_url=self.api_base_url, | ||||
|                 bearer_key=self.api_key | ||||
|                 bearer_key=self.api_key, | ||||
|             ) | ||||
|             db.session.add(new_data_source_binding) | ||||
|             db.session.commit() | ||||
|  | ||||
| @ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException | ||||
| 
 | ||||
| 
 | ||||
| class BaseHTTPException(HTTPException): | ||||
|     error_code: str = 'unknown' | ||||
|     error_code: str = "unknown" | ||||
|     data: Optional[dict] = None | ||||
| 
 | ||||
|     def __init__(self, description=None, response=None): | ||||
| @ -14,4 +14,4 @@ class BaseHTTPException(HTTPException): | ||||
|             "code": self.error_code, | ||||
|             "message": self.description, | ||||
|             "status": self.code, | ||||
|         } | ||||
|         } | ||||
|  | ||||
| @ -10,7 +10,6 @@ from core.errors.error import AppInvokeQuotaExceededError | ||||
| 
 | ||||
| 
 | ||||
| class ExternalApi(Api): | ||||
| 
 | ||||
|     def handle_error(self, e): | ||||
|         """Error handler for the API transforms a raised exception into a Flask | ||||
|         response, with the appropriate HTTP status code and body. | ||||
| @ -29,54 +28,57 @@ class ExternalApi(Api): | ||||
| 
 | ||||
|             status_code = e.code | ||||
|             default_data = { | ||||
|                 'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(), | ||||
|                 'message': getattr(e, 'description', http_status_message(status_code)), | ||||
|                 'status': status_code | ||||
|                 "code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(), | ||||
|                 "message": getattr(e, "description", http_status_message(status_code)), | ||||
|                 "status": status_code, | ||||
|             } | ||||
| 
 | ||||
|             if default_data['message'] and default_data['message'] == 'Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)': | ||||
|                 default_data['message'] = 'Invalid JSON payload received or JSON payload is empty.' | ||||
|             if ( | ||||
|                 default_data["message"] | ||||
|                 and default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" | ||||
|             ): | ||||
|                 default_data["message"] = "Invalid JSON payload received or JSON payload is empty." | ||||
| 
 | ||||
|             headers = e.get_response().headers | ||||
|         elif isinstance(e, ValueError): | ||||
|             status_code = 400 | ||||
|             default_data = { | ||||
|                 'code': 'invalid_param', | ||||
|                 'message': str(e), | ||||
|                 'status': status_code | ||||
|                 "code": "invalid_param", | ||||
|                 "message": str(e), | ||||
|                 "status": status_code, | ||||
|             } | ||||
|         elif isinstance(e, AppInvokeQuotaExceededError): | ||||
|             status_code = 429 | ||||
|             default_data = { | ||||
|                 'code': 'too_many_requests', | ||||
|                 'message': str(e), | ||||
|                 'status': status_code | ||||
|                 "code": "too_many_requests", | ||||
|                 "message": str(e), | ||||
|                 "status": status_code, | ||||
|             } | ||||
|         else: | ||||
|             status_code = 500 | ||||
|             default_data = { | ||||
|                 'message': http_status_message(status_code), | ||||
|                 "message": http_status_message(status_code), | ||||
|             } | ||||
| 
 | ||||
|         # Werkzeug exceptions generate a content-length header which is added | ||||
|         # to the response in addition to the actual content-length header | ||||
|         # https://github.com/flask-restful/flask-restful/issues/534 | ||||
|         remove_headers = ('Content-Length',) | ||||
|         remove_headers = ("Content-Length",) | ||||
| 
 | ||||
|         for header in remove_headers: | ||||
|             headers.pop(header, None) | ||||
| 
 | ||||
|         data = getattr(e, 'data', default_data) | ||||
|         data = getattr(e, "data", default_data) | ||||
| 
 | ||||
|         error_cls_name = type(e).__name__ | ||||
|         if error_cls_name in self.errors: | ||||
|             custom_data = self.errors.get(error_cls_name, {}) | ||||
|             custom_data = custom_data.copy() | ||||
|             status_code = custom_data.get('status', 500) | ||||
|             status_code = custom_data.get("status", 500) | ||||
| 
 | ||||
|             if 'message' in custom_data: | ||||
|                 custom_data['message'] = custom_data['message'].format( | ||||
|                     message=str(e.description if hasattr(e, 'description') else e) | ||||
|             if "message" in custom_data: | ||||
|                 custom_data["message"] = custom_data["message"].format( | ||||
|                     message=str(e.description if hasattr(e, "description") else e) | ||||
|                 ) | ||||
|             data.update(custom_data) | ||||
| 
 | ||||
| @ -94,32 +96,20 @@ class ExternalApi(Api): | ||||
|             # another NotAcceptable error). | ||||
|             supported_mediatypes = list(self.representations.keys())  # only supported application/json | ||||
|             fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain" | ||||
|             data = { | ||||
|                 'code': 'not_acceptable', | ||||
|                 'message': data.get('message') | ||||
|             } | ||||
|             resp = self.make_response( | ||||
|                 data, | ||||
|                 status_code, | ||||
|                 headers, | ||||
|                 fallback_mediatype = fallback_mediatype | ||||
|             ) | ||||
|             data = {"code": "not_acceptable", "message": data.get("message")} | ||||
|             resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) | ||||
|         elif status_code == 400: | ||||
|             if isinstance(data.get('message'), dict): | ||||
|                 param_key, param_value = list(data.get('message').items())[0] | ||||
|                 data = { | ||||
|                     'code': 'invalid_param', | ||||
|                     'message': param_value, | ||||
|                     'params': param_key | ||||
|                 } | ||||
|             if isinstance(data.get("message"), dict): | ||||
|                 param_key, param_value = list(data.get("message").items())[0] | ||||
|                 data = {"code": "invalid_param", "message": param_value, "params": param_key} | ||||
|             else: | ||||
|                 if 'code' not in data: | ||||
|                     data['code'] = 'unknown' | ||||
|                 if "code" not in data: | ||||
|                     data["code"] = "unknown" | ||||
| 
 | ||||
|             resp = self.make_response(data, status_code, headers) | ||||
|         else: | ||||
|             if 'code' not in data: | ||||
|                 data['code'] = 'unknown' | ||||
|             if "code" not in data: | ||||
|                 data["code"] = "unknown" | ||||
| 
 | ||||
|             resp = self.make_response(data, status_code, headers) | ||||
| 
 | ||||
|  | ||||
| @ -70,7 +70,7 @@ class PKCS1OAEP_Cipher: | ||||
|         if mgfunc: | ||||
|             self._mgf = mgfunc | ||||
|         else: | ||||
|             self._mgf = lambda x,y: MGF1(x,y,self._hashObj) | ||||
|             self._mgf = lambda x, y: MGF1(x, y, self._hashObj) | ||||
| 
 | ||||
|         self._label = _copy_bytes(None, None, label) | ||||
|         self._randfunc = randfunc | ||||
| @ -107,7 +107,7 @@ class PKCS1OAEP_Cipher: | ||||
| 
 | ||||
|         # See 7.1.1 in RFC3447 | ||||
|         modBits = Crypto.Util.number.size(self._key.n) | ||||
|         k = ceil_div(modBits, 8) # Convert from bits to bytes | ||||
|         k = ceil_div(modBits, 8)  # Convert from bits to bytes | ||||
|         hLen = self._hashObj.digest_size | ||||
|         mLen = len(message) | ||||
| 
 | ||||
| @ -118,13 +118,13 @@ class PKCS1OAEP_Cipher: | ||||
|         # Step 2a | ||||
|         lHash = sha1(self._label).digest() | ||||
|         # Step 2b | ||||
|         ps = b'\x00' * ps_len | ||||
|         ps = b"\x00" * ps_len | ||||
|         # Step 2c | ||||
|         db = lHash + ps + b'\x01' + _copy_bytes(None, None, message) | ||||
|         db = lHash + ps + b"\x01" + _copy_bytes(None, None, message) | ||||
|         # Step 2d | ||||
|         ros = self._randfunc(hLen) | ||||
|         # Step 2e | ||||
|         dbMask = self._mgf(ros, k-hLen-1) | ||||
|         dbMask = self._mgf(ros, k - hLen - 1) | ||||
|         # Step 2f | ||||
|         maskedDB = strxor(db, dbMask) | ||||
|         # Step 2g | ||||
| @ -132,7 +132,7 @@ class PKCS1OAEP_Cipher: | ||||
|         # Step 2h | ||||
|         maskedSeed = strxor(ros, seedMask) | ||||
|         # Step 2i | ||||
|         em = b'\x00' + maskedSeed + maskedDB | ||||
|         em = b"\x00" + maskedSeed + maskedDB | ||||
|         # Step 3a (OS2IP) | ||||
|         em_int = bytes_to_long(em) | ||||
|         # Step 3b (RSAEP) | ||||
| @ -160,10 +160,10 @@ class PKCS1OAEP_Cipher: | ||||
|         """ | ||||
|         # See 7.1.2 in RFC3447 | ||||
|         modBits = Crypto.Util.number.size(self._key.n) | ||||
|         k = ceil_div(modBits,8) # Convert from bits to bytes | ||||
|         k = ceil_div(modBits, 8)  # Convert from bits to bytes | ||||
|         hLen = self._hashObj.digest_size | ||||
|         # Step 1b and 1c | ||||
|         if len(ciphertext) != k or k<hLen+2: | ||||
|         if len(ciphertext) != k or k < hLen + 2: | ||||
|             raise ValueError("Ciphertext with incorrect length.") | ||||
|         # Step 2a (O2SIP) | ||||
|         ct_int = bytes_to_long(ciphertext) | ||||
| @ -178,18 +178,18 @@ class PKCS1OAEP_Cipher: | ||||
|         y = em[0] | ||||
|         # y must be 0, but we MUST NOT check it here in order not to | ||||
|         # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143) | ||||
|         maskedSeed = em[1:hLen+1] | ||||
|         maskedDB = em[hLen+1:] | ||||
|         maskedSeed = em[1 : hLen + 1] | ||||
|         maskedDB = em[hLen + 1 :] | ||||
|         # Step 3c | ||||
|         seedMask = self._mgf(maskedDB, hLen) | ||||
|         # Step 3d | ||||
|         seed = strxor(maskedSeed, seedMask) | ||||
|         # Step 3e | ||||
|         dbMask = self._mgf(seed, k-hLen-1) | ||||
|         dbMask = self._mgf(seed, k - hLen - 1) | ||||
|         # Step 3f | ||||
|         db = strxor(maskedDB, dbMask) | ||||
|         # Step 3g | ||||
|         one_pos = hLen + db[hLen:].find(b'\x01') | ||||
|         one_pos = hLen + db[hLen:].find(b"\x01") | ||||
|         lHash1 = db[:hLen] | ||||
|         invalid = bord(y) | int(one_pos < hLen) | ||||
|         hash_compare = strxor(lHash1, lHash) | ||||
| @ -200,9 +200,10 @@ class PKCS1OAEP_Cipher: | ||||
|         if invalid != 0: | ||||
|             raise ValueError("Incorrect decryption.") | ||||
|         # Step 4 | ||||
|         return db[one_pos + 1:] | ||||
|         return db[one_pos + 1 :] | ||||
| 
 | ||||
| def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None): | ||||
| 
 | ||||
| def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): | ||||
|     """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption. | ||||
| 
 | ||||
|     :param key: | ||||
|  | ||||
| @ -21,7 +21,7 @@ from models.account import Account | ||||
| 
 | ||||
| 
 | ||||
| def run(script): | ||||
|     return subprocess.getstatusoutput('source /root/.bashrc && ' + script) | ||||
|     return subprocess.getstatusoutput("source /root/.bashrc && " + script) | ||||
| 
 | ||||
| 
 | ||||
| class TimestampField(fields.Raw): | ||||
| @ -36,29 +36,29 @@ def email(email): | ||||
|     if re.match(pattern, email) is not None: | ||||
|         return email | ||||
| 
 | ||||
|     error = ('{email} is not a valid email.' | ||||
|              .format(email=email)) | ||||
|     error = "{email} is not a valid email.".format(email=email) | ||||
|     raise ValueError(error) | ||||
| 
 | ||||
| 
 | ||||
| def uuid_value(value): | ||||
|     if value == '': | ||||
|     if value == "": | ||||
|         return str(value) | ||||
| 
 | ||||
|     try: | ||||
|         uuid_obj = uuid.UUID(value) | ||||
|         return str(uuid_obj) | ||||
|     except ValueError: | ||||
|         error = ('{value} is not a valid uuid.' | ||||
|                  .format(value=value)) | ||||
|         error = "{value} is not a valid uuid.".format(value=value) | ||||
|         raise ValueError(error) | ||||
| 
 | ||||
| 
 | ||||
| def alphanumeric(value: str): | ||||
|     # check if the value is alphanumeric and underlined | ||||
|     if re.match(r'^[a-zA-Z0-9_]+$', value): | ||||
|     if re.match(r"^[a-zA-Z0-9_]+$", value): | ||||
|         return value | ||||
| 
 | ||||
|     raise ValueError(f'{value} is not a valid alphanumeric value') | ||||
|     raise ValueError(f"{value} is not a valid alphanumeric value") | ||||
| 
 | ||||
| 
 | ||||
| def timestamp_value(timestamp): | ||||
|     try: | ||||
| @ -67,31 +67,32 @@ def timestamp_value(timestamp): | ||||
|             raise ValueError | ||||
|         return int_timestamp | ||||
|     except ValueError: | ||||
|         error = ('{timestamp} is not a valid timestamp.' | ||||
|                  .format(timestamp=timestamp)) | ||||
|         error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp) | ||||
|         raise ValueError(error) | ||||
| 
 | ||||
| 
 | ||||
| class str_len: | ||||
|     """ Restrict input to an integer in a range (inclusive) """ | ||||
|     """Restrict input to an integer in a range (inclusive)""" | ||||
| 
 | ||||
|     def __init__(self, max_length, argument='argument'): | ||||
|     def __init__(self, max_length, argument="argument"): | ||||
|         self.max_length = max_length | ||||
|         self.argument = argument | ||||
| 
 | ||||
|     def __call__(self, value): | ||||
|         length = len(value) | ||||
|         if length > self.max_length: | ||||
|             error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}' | ||||
|                      .format(arg=self.argument, val=value, length=self.max_length)) | ||||
|             error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format( | ||||
|                 arg=self.argument, val=value, length=self.max_length | ||||
|             ) | ||||
|             raise ValueError(error) | ||||
| 
 | ||||
|         return value | ||||
| 
 | ||||
| 
 | ||||
| class float_range: | ||||
|     """ Restrict input to an float in a range (inclusive) """ | ||||
|     def __init__(self, low, high, argument='argument'): | ||||
|     """Restrict input to an float in a range (inclusive)""" | ||||
| 
 | ||||
|     def __init__(self, low, high, argument="argument"): | ||||
|         self.low = low | ||||
|         self.high = high | ||||
|         self.argument = argument | ||||
| @ -99,15 +100,16 @@ class float_range: | ||||
|     def __call__(self, value): | ||||
|         value = _get_float(value) | ||||
|         if value < self.low or value > self.high: | ||||
|             error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}' | ||||
|                      .format(arg=self.argument, val=value, lo=self.low, hi=self.high)) | ||||
|             error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( | ||||
|                 arg=self.argument, val=value, lo=self.low, hi=self.high | ||||
|             ) | ||||
|             raise ValueError(error) | ||||
| 
 | ||||
|         return value | ||||
| 
 | ||||
| 
 | ||||
| class datetime_string: | ||||
|     def __init__(self, format, argument='argument'): | ||||
|     def __init__(self, format, argument="argument"): | ||||
|         self.format = format | ||||
|         self.argument = argument | ||||
| 
 | ||||
| @ -115,8 +117,9 @@ class datetime_string: | ||||
|         try: | ||||
|             datetime.strptime(value, self.format) | ||||
|         except ValueError: | ||||
|             error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}' | ||||
|                      .format(arg=self.argument, val=value, format=self.format)) | ||||
|             error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( | ||||
|                 arg=self.argument, val=value, format=self.format | ||||
|             ) | ||||
|             raise ValueError(error) | ||||
| 
 | ||||
|         return value | ||||
| @ -126,14 +129,14 @@ def _get_float(value): | ||||
|     try: | ||||
|         return float(value) | ||||
|     except (TypeError, ValueError): | ||||
|         raise ValueError('{} is not a valid float'.format(value)) | ||||
|         raise ValueError("{} is not a valid float".format(value)) | ||||
| 
 | ||||
| 
 | ||||
| def timezone(timezone_string): | ||||
|     if timezone_string and timezone_string in available_timezones(): | ||||
|         return timezone_string | ||||
| 
 | ||||
|     error = ('{timezone_string} is not a valid timezone.' | ||||
|              .format(timezone_string=timezone_string)) | ||||
|     error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) | ||||
|     raise ValueError(error) | ||||
| 
 | ||||
| 
 | ||||
| @ -147,8 +150,8 @@ def generate_string(n): | ||||
| 
 | ||||
| 
 | ||||
| def get_remote_ip(request) -> str: | ||||
|     if request.headers.get('CF-Connecting-IP'): | ||||
|         return request.headers.get('Cf-Connecting-Ip') | ||||
|     if request.headers.get("CF-Connecting-IP"): | ||||
|         return request.headers.get("Cf-Connecting-Ip") | ||||
|     elif request.headers.getlist("X-Forwarded-For"): | ||||
|         return request.headers.getlist("X-Forwarded-For")[0] | ||||
|     else: | ||||
| @ -156,54 +159,45 @@ def get_remote_ip(request) -> str: | ||||
| 
 | ||||
| 
 | ||||
| def generate_text_hash(text: str) -> str: | ||||
|     hash_text = str(text) + 'None' | ||||
|     hash_text = str(text) + "None" | ||||
|     return sha256(hash_text.encode()).hexdigest() | ||||
| 
 | ||||
| 
 | ||||
| def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: | ||||
|     if isinstance(response, dict): | ||||
|         return Response(response=json.dumps(response), status=200, mimetype='application/json') | ||||
|         return Response(response=json.dumps(response), status=200, mimetype="application/json") | ||||
|     else: | ||||
| 
 | ||||
|         def generate() -> Generator: | ||||
|             yield from response | ||||
| 
 | ||||
|         return Response(stream_with_context(generate()), status=200, | ||||
|                         mimetype='text/event-stream') | ||||
|         return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") | ||||
| 
 | ||||
| 
 | ||||
| class TokenManager: | ||||
| 
 | ||||
|     @classmethod | ||||
|     def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: | ||||
|         old_token = cls._get_current_token_for_account(account.id, token_type) | ||||
|         if old_token: | ||||
|             if isinstance(old_token, bytes): | ||||
|                 old_token = old_token.decode('utf-8') | ||||
|                 old_token = old_token.decode("utf-8") | ||||
|             cls.revoke_token(old_token, token_type) | ||||
| 
 | ||||
|         token = str(uuid.uuid4()) | ||||
|         token_data = { | ||||
|             'account_id': account.id, | ||||
|             'email': account.email, | ||||
|             'token_type': token_type | ||||
|         } | ||||
|         token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} | ||||
|         if additional_data: | ||||
|             token_data.update(additional_data) | ||||
| 
 | ||||
|         expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS'] | ||||
|         expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] | ||||
|         token_key = cls._get_token_key(token, token_type) | ||||
|         redis_client.setex( | ||||
|             token_key, | ||||
|             expiry_hours * 60 * 60, | ||||
|             json.dumps(token_data) | ||||
|         ) | ||||
|         redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) | ||||
| 
 | ||||
|         cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) | ||||
|         return token | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _get_token_key(cls, token: str, token_type: str) -> str: | ||||
|         return f'{token_type}:token:{token}' | ||||
|         return f"{token_type}:token:{token}" | ||||
| 
 | ||||
|     @classmethod | ||||
|     def revoke_token(cls, token: str, token_type: str): | ||||
| @ -233,7 +227,7 @@ class TokenManager: | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _get_account_token_key(cls, account_id: str, token_type: str) -> str: | ||||
|         return f'{token_type}:account:{account_id}' | ||||
|         return f"{token_type}:account:{account_id}" | ||||
| 
 | ||||
| 
 | ||||
| class RateLimiter: | ||||
| @ -250,7 +244,7 @@ class RateLimiter: | ||||
|         current_time = int(time.time()) | ||||
|         window_start_time = current_time - self.time_window | ||||
| 
 | ||||
|         redis_client.zremrangebyscore(key, '-inf', window_start_time) | ||||
|         redis_client.zremrangebyscore(key, "-inf", window_start_time) | ||||
|         attempts = redis_client.zcard(key) | ||||
| 
 | ||||
|         if attempts and int(attempts) >= self.max_attempts: | ||||
|  | ||||
| @ -1,4 +1,3 @@ | ||||
| 
 | ||||
| class InfiniteScrollPagination: | ||||
|     def __init__(self, data, limit, has_more): | ||||
|         self.data = data | ||||
|  | ||||
| @ -10,13 +10,13 @@ def parse_json_markdown(json_string: str) -> dict: | ||||
|     end_index = json_string.find("```", start_index + len("```json")) | ||||
| 
 | ||||
|     if start_index != -1 and end_index != -1: | ||||
|         extracted_content = json_string[start_index + len("```json"):end_index].strip() | ||||
|         extracted_content = json_string[start_index + len("```json") : end_index].strip() | ||||
| 
 | ||||
|         # Parse the JSON string into a Python dictionary | ||||
|         parsed = json.loads(extracted_content) | ||||
|     elif start_index != -1 and end_index == -1 and json_string.endswith("``"): | ||||
|         end_index = json_string.find("``", start_index + len("```json")) | ||||
|         extracted_content = json_string[start_index + len("```json"):end_index].strip() | ||||
|         extracted_content = json_string[start_index + len("```json") : end_index].strip() | ||||
| 
 | ||||
|         # Parse the JSON string into a Python dictionary | ||||
|         parsed = json.loads(extracted_content) | ||||
| @ -37,7 +37,6 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: | ||||
|     for key in expected_keys: | ||||
|         if key not in json_obj: | ||||
|             raise OutputParserException( | ||||
|                 f"Got invalid return object. Expected key `{key}` " | ||||
|                 f"to be present, but got {json_obj}" | ||||
|                 f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" | ||||
|             ) | ||||
|     return json_obj | ||||
|  | ||||
| @ -51,27 +51,29 @@ def login_required(func): | ||||
| 
 | ||||
|     @wraps(func) | ||||
|     def decorated_view(*args, **kwargs): | ||||
|         auth_header = request.headers.get('Authorization') | ||||
|         admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False') | ||||
|         if admin_api_key_enable.lower() == 'true': | ||||
|         auth_header = request.headers.get("Authorization") | ||||
|         admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") | ||||
|         if admin_api_key_enable.lower() == "true": | ||||
|             if auth_header: | ||||
|                 if ' ' not in auth_header: | ||||
|                     raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | ||||
|                 if " " not in auth_header: | ||||
|                     raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||||
|                 auth_scheme, auth_token = auth_header.split(None, 1) | ||||
|                 auth_scheme = auth_scheme.lower() | ||||
|                 if auth_scheme != 'bearer': | ||||
|                     raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | ||||
|                 admin_api_key = os.getenv('ADMIN_API_KEY') | ||||
|                 if auth_scheme != "bearer": | ||||
|                     raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||||
|                 admin_api_key = os.getenv("ADMIN_API_KEY") | ||||
| 
 | ||||
|                 if admin_api_key: | ||||
|                     if os.getenv('ADMIN_API_KEY') == auth_token: | ||||
|                         workspace_id = request.headers.get('X-WORKSPACE-ID') | ||||
|                     if os.getenv("ADMIN_API_KEY") == auth_token: | ||||
|                         workspace_id = request.headers.get("X-WORKSPACE-ID") | ||||
|                         if workspace_id: | ||||
|                             tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ | ||||
|                                 .filter(Tenant.id == workspace_id) \ | ||||
|                                 .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | ||||
|                                 .filter(TenantAccountJoin.role == 'owner') \ | ||||
|                             tenant_account_join = ( | ||||
|                                 db.session.query(Tenant, TenantAccountJoin) | ||||
|                                 .filter(Tenant.id == workspace_id) | ||||
|                                 .filter(TenantAccountJoin.tenant_id == Tenant.id) | ||||
|                                 .filter(TenantAccountJoin.role == "owner") | ||||
|                                 .one_or_none() | ||||
|                             ) | ||||
|                             if tenant_account_join: | ||||
|                                 tenant, ta = tenant_account_join | ||||
|                                 account = Account.query.filter_by(id=ta.account_id).first() | ||||
|  | ||||
| @ -35,31 +35,31 @@ class OAuth: | ||||
| 
 | ||||
| 
 | ||||
| class GitHubOAuth(OAuth): | ||||
|     _AUTH_URL = 'https://github.com/login/oauth/authorize' | ||||
|     _TOKEN_URL = 'https://github.com/login/oauth/access_token' | ||||
|     _USER_INFO_URL = 'https://api.github.com/user' | ||||
|     _EMAIL_INFO_URL = 'https://api.github.com/user/emails' | ||||
|     _AUTH_URL = "https://github.com/login/oauth/authorize" | ||||
|     _TOKEN_URL = "https://github.com/login/oauth/access_token" | ||||
|     _USER_INFO_URL = "https://api.github.com/user" | ||||
|     _EMAIL_INFO_URL = "https://api.github.com/user/emails" | ||||
| 
 | ||||
|     def get_authorization_url(self): | ||||
|         params = { | ||||
|             'client_id': self.client_id, | ||||
|             'redirect_uri': self.redirect_uri, | ||||
|             'scope': 'user:email'  # Request only basic user information | ||||
|             "client_id": self.client_id, | ||||
|             "redirect_uri": self.redirect_uri, | ||||
|             "scope": "user:email",  # Request only basic user information | ||||
|         } | ||||
|         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" | ||||
| 
 | ||||
|     def get_access_token(self, code: str): | ||||
|         data = { | ||||
|             'client_id': self.client_id, | ||||
|             'client_secret': self.client_secret, | ||||
|             'code': code, | ||||
|             'redirect_uri': self.redirect_uri | ||||
|             "client_id": self.client_id, | ||||
|             "client_secret": self.client_secret, | ||||
|             "code": code, | ||||
|             "redirect_uri": self.redirect_uri, | ||||
|         } | ||||
|         headers = {'Accept': 'application/json'} | ||||
|         headers = {"Accept": "application/json"} | ||||
|         response = requests.post(self._TOKEN_URL, data=data, headers=headers) | ||||
| 
 | ||||
|         response_json = response.json() | ||||
|         access_token = response_json.get('access_token') | ||||
|         access_token = response_json.get("access_token") | ||||
| 
 | ||||
|         if not access_token: | ||||
|             raise ValueError(f"Error in GitHub OAuth: {response_json}") | ||||
| @ -67,55 +67,51 @@ class GitHubOAuth(OAuth): | ||||
|         return access_token | ||||
| 
 | ||||
|     def get_raw_user_info(self, token: str): | ||||
|         headers = {'Authorization': f"token {token}"} | ||||
|         headers = {"Authorization": f"token {token}"} | ||||
|         response = requests.get(self._USER_INFO_URL, headers=headers) | ||||
|         response.raise_for_status() | ||||
|         user_info = response.json() | ||||
| 
 | ||||
|         email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) | ||||
|         email_info = email_response.json() | ||||
|         primary_email = next((email for email in email_info if email['primary'] == True), None) | ||||
|         primary_email = next((email for email in email_info if email["primary"] == True), None) | ||||
| 
 | ||||
|         return {**user_info, 'email': primary_email['email']} | ||||
|         return {**user_info, "email": primary_email["email"]} | ||||
| 
 | ||||
|     def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: | ||||
|         email = raw_info.get('email') | ||||
|         email = raw_info.get("email") | ||||
|         if not email: | ||||
|             email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" | ||||
|         return OAuthUserInfo( | ||||
|             id=str(raw_info['id']), | ||||
|             name=raw_info['name'], | ||||
|             email=email | ||||
|         ) | ||||
|         return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) | ||||
| 
 | ||||
| 
 | ||||
| class GoogleOAuth(OAuth): | ||||
|     _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth' | ||||
|     _TOKEN_URL = 'https://oauth2.googleapis.com/token' | ||||
|     _USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo' | ||||
|     _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" | ||||
|     _TOKEN_URL = "https://oauth2.googleapis.com/token" | ||||
|     _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" | ||||
| 
 | ||||
|     def get_authorization_url(self): | ||||
|         params = { | ||||
|             'client_id': self.client_id, | ||||
|             'response_type': 'code', | ||||
|             'redirect_uri': self.redirect_uri, | ||||
|             'scope': 'openid email' | ||||
|             "client_id": self.client_id, | ||||
|             "response_type": "code", | ||||
|             "redirect_uri": self.redirect_uri, | ||||
|             "scope": "openid email", | ||||
|         } | ||||
|         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" | ||||
| 
 | ||||
|     def get_access_token(self, code: str): | ||||
|         data = { | ||||
|             'client_id': self.client_id, | ||||
|             'client_secret': self.client_secret, | ||||
|             'code': code, | ||||
|             'grant_type': 'authorization_code', | ||||
|             'redirect_uri': self.redirect_uri | ||||
|             "client_id": self.client_id, | ||||
|             "client_secret": self.client_secret, | ||||
|             "code": code, | ||||
|             "grant_type": "authorization_code", | ||||
|             "redirect_uri": self.redirect_uri, | ||||
|         } | ||||
|         headers = {'Accept': 'application/json'} | ||||
|         headers = {"Accept": "application/json"} | ||||
|         response = requests.post(self._TOKEN_URL, data=data, headers=headers) | ||||
| 
 | ||||
|         response_json = response.json() | ||||
|         access_token = response_json.get('access_token') | ||||
|         access_token = response_json.get("access_token") | ||||
| 
 | ||||
|         if not access_token: | ||||
|             raise ValueError(f"Error in Google OAuth: {response_json}") | ||||
| @ -123,16 +119,10 @@ class GoogleOAuth(OAuth): | ||||
|         return access_token | ||||
| 
 | ||||
|     def get_raw_user_info(self, token: str): | ||||
|         headers = {'Authorization': f"Bearer {token}"} | ||||
|         headers = {"Authorization": f"Bearer {token}"} | ||||
|         response = requests.get(self._USER_INFO_URL, headers=headers) | ||||
|         response.raise_for_status() | ||||
|         return response.json() | ||||
| 
 | ||||
|     def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: | ||||
|         return OAuthUserInfo( | ||||
|             id=str(raw_info['sub']), | ||||
|             name=None, | ||||
|             email=raw_info['email'] | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
|         return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) | ||||
|  | ||||
| @ -21,53 +21,49 @@ class OAuthDataSource: | ||||
| 
 | ||||
| 
 | ||||
| class NotionOAuth(OAuthDataSource): | ||||
|     _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' | ||||
|     _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' | ||||
|     _AUTH_URL = "https://api.notion.com/v1/oauth/authorize" | ||||
|     _TOKEN_URL = "https://api.notion.com/v1/oauth/token" | ||||
|     _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search" | ||||
|     _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" | ||||
|     _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" | ||||
| 
 | ||||
|     def get_authorization_url(self): | ||||
|         params = { | ||||
|             'client_id': self.client_id, | ||||
|             'response_type': 'code', | ||||
|             'redirect_uri': self.redirect_uri, | ||||
|             'owner': 'user' | ||||
|             "client_id": self.client_id, | ||||
|             "response_type": "code", | ||||
|             "redirect_uri": self.redirect_uri, | ||||
|             "owner": "user", | ||||
|         } | ||||
|         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" | ||||
| 
 | ||||
|     def get_access_token(self, code: str): | ||||
|         data = { | ||||
|             'code': code, | ||||
|             'grant_type': 'authorization_code', | ||||
|             'redirect_uri': self.redirect_uri | ||||
|         } | ||||
|         headers = {'Accept': 'application/json'} | ||||
|         data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} | ||||
|         headers = {"Accept": "application/json"} | ||||
|         auth = (self.client_id, self.client_secret) | ||||
|         response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) | ||||
| 
 | ||||
|         response_json = response.json() | ||||
|         access_token = response_json.get('access_token') | ||||
|         access_token = response_json.get("access_token") | ||||
|         if not access_token: | ||||
|             raise ValueError(f"Error in Notion OAuth: {response_json}") | ||||
|         workspace_name = response_json.get('workspace_name') | ||||
|         workspace_icon = response_json.get('workspace_icon') | ||||
|         workspace_id = response_json.get('workspace_id') | ||||
|         workspace_name = response_json.get("workspace_name") | ||||
|         workspace_icon = response_json.get("workspace_icon") | ||||
|         workspace_id = response_json.get("workspace_id") | ||||
|         # get all authorized pages | ||||
|         pages = self.get_authorized_pages(access_token) | ||||
|         source_info = { | ||||
|             'workspace_name': workspace_name, | ||||
|             'workspace_icon': workspace_icon, | ||||
|             'workspace_id': workspace_id, | ||||
|             'pages': pages, | ||||
|             'total': len(pages) | ||||
|             "workspace_name": workspace_name, | ||||
|             "workspace_icon": workspace_icon, | ||||
|             "workspace_id": workspace_id, | ||||
|             "pages": pages, | ||||
|             "total": len(pages), | ||||
|         } | ||||
|         # save data source binding | ||||
|         data_source_binding = DataSourceOauthBinding.query.filter( | ||||
|             db.and_( | ||||
|                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | ||||
|                 DataSourceOauthBinding.provider == 'notion', | ||||
|                 DataSourceOauthBinding.access_token == access_token | ||||
|                 DataSourceOauthBinding.provider == "notion", | ||||
|                 DataSourceOauthBinding.access_token == access_token, | ||||
|             ) | ||||
|         ).first() | ||||
|         if data_source_binding: | ||||
| @ -79,7 +75,7 @@ class NotionOAuth(OAuthDataSource): | ||||
|                 tenant_id=current_user.current_tenant_id, | ||||
|                 access_token=access_token, | ||||
|                 source_info=source_info, | ||||
|                 provider='notion' | ||||
|                 provider="notion", | ||||
|             ) | ||||
|             db.session.add(new_data_source_binding) | ||||
|             db.session.commit() | ||||
| @ -91,18 +87,18 @@ class NotionOAuth(OAuthDataSource): | ||||
|         # get all authorized pages | ||||
|         pages = self.get_authorized_pages(access_token) | ||||
|         source_info = { | ||||
|             'workspace_name': workspace_name, | ||||
|             'workspace_icon': workspace_icon, | ||||
|             'workspace_id': workspace_id, | ||||
|             'pages': pages, | ||||
|             'total': len(pages) | ||||
|             "workspace_name": workspace_name, | ||||
|             "workspace_icon": workspace_icon, | ||||
|             "workspace_id": workspace_id, | ||||
|             "pages": pages, | ||||
|             "total": len(pages), | ||||
|         } | ||||
|         # save data source binding | ||||
|         data_source_binding = DataSourceOauthBinding.query.filter( | ||||
|             db.and_( | ||||
|                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | ||||
|                 DataSourceOauthBinding.provider == 'notion', | ||||
|                 DataSourceOauthBinding.access_token == access_token | ||||
|                 DataSourceOauthBinding.provider == "notion", | ||||
|                 DataSourceOauthBinding.access_token == access_token, | ||||
|             ) | ||||
|         ).first() | ||||
|         if data_source_binding: | ||||
| @ -114,7 +110,7 @@ class NotionOAuth(OAuthDataSource): | ||||
|                 tenant_id=current_user.current_tenant_id, | ||||
|                 access_token=access_token, | ||||
|                 source_info=source_info, | ||||
|                 provider='notion' | ||||
|                 provider="notion", | ||||
|             ) | ||||
|             db.session.add(new_data_source_binding) | ||||
|             db.session.commit() | ||||
| @ -124,9 +120,9 @@ class NotionOAuth(OAuthDataSource): | ||||
|         data_source_binding = DataSourceOauthBinding.query.filter( | ||||
|             db.and_( | ||||
|                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | ||||
|                 DataSourceOauthBinding.provider == 'notion', | ||||
|                 DataSourceOauthBinding.provider == "notion", | ||||
|                 DataSourceOauthBinding.id == binding_id, | ||||
|                 DataSourceOauthBinding.disabled == False | ||||
|                 DataSourceOauthBinding.disabled == False, | ||||
|             ) | ||||
|         ).first() | ||||
|         if data_source_binding: | ||||
| @ -134,17 +130,17 @@ class NotionOAuth(OAuthDataSource): | ||||
|             pages = self.get_authorized_pages(data_source_binding.access_token) | ||||
|             source_info = data_source_binding.source_info | ||||
|             new_source_info = { | ||||
|                 'workspace_name': source_info['workspace_name'], | ||||
|                 'workspace_icon': source_info['workspace_icon'], | ||||
|                 'workspace_id': source_info['workspace_id'], | ||||
|                 'pages': pages, | ||||
|                 'total': len(pages) | ||||
|                 "workspace_name": source_info["workspace_name"], | ||||
|                 "workspace_icon": source_info["workspace_icon"], | ||||
|                 "workspace_id": source_info["workspace_id"], | ||||
|                 "pages": pages, | ||||
|                 "total": len(pages), | ||||
|             } | ||||
|             data_source_binding.source_info = new_source_info | ||||
|             data_source_binding.disabled = False | ||||
|             db.session.commit() | ||||
|         else: | ||||
|             raise ValueError('Data source binding not found') | ||||
|             raise ValueError("Data source binding not found") | ||||
| 
 | ||||
|     def get_authorized_pages(self, access_token: str): | ||||
|         pages = [] | ||||
| @ -152,143 +148,121 @@ class NotionOAuth(OAuthDataSource): | ||||
|         database_results = self.notion_database_search(access_token) | ||||
|         # get page detail | ||||
|         for page_result in page_results: | ||||
|             page_id = page_result['id'] | ||||
|             page_name = 'Untitled' | ||||
|             for key in page_result['properties']: | ||||
|                 if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']: | ||||
|                     title_list = page_result['properties'][key]['title'] | ||||
|                     if len(title_list) > 0 and 'plain_text' in title_list[0]: | ||||
|                         page_name = title_list[0]['plain_text'] | ||||
|             page_icon = page_result['icon'] | ||||
|             page_id = page_result["id"] | ||||
|             page_name = "Untitled" | ||||
|             for key in page_result["properties"]: | ||||
|                 if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]: | ||||
|                     title_list = page_result["properties"][key]["title"] | ||||
|                     if len(title_list) > 0 and "plain_text" in title_list[0]: | ||||
|                         page_name = title_list[0]["plain_text"] | ||||
|             page_icon = page_result["icon"] | ||||
|             if page_icon: | ||||
|                 icon_type = page_icon['type'] | ||||
|                 if icon_type == 'external' or icon_type == 'file': | ||||
|                     url = page_icon[icon_type]['url'] | ||||
|                     icon = { | ||||
|                         'type': 'url', | ||||
|                         'url': url if url.startswith('http') else f'https://www.notion.so{url}' | ||||
|                     } | ||||
|                 icon_type = page_icon["type"] | ||||
|                 if icon_type == "external" or icon_type == "file": | ||||
|                     url = page_icon[icon_type]["url"] | ||||
|                     icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} | ||||
|                 else: | ||||
|                     icon = { | ||||
|                         'type': 'emoji', | ||||
|                         'emoji': page_icon[icon_type] | ||||
|                     } | ||||
|                     icon = {"type": "emoji", "emoji": page_icon[icon_type]} | ||||
|             else: | ||||
|                 icon = None | ||||
|             parent = page_result['parent'] | ||||
|             parent_type = parent['type'] | ||||
|             if parent_type == 'block_id': | ||||
|             parent = page_result["parent"] | ||||
|             parent_type = parent["type"] | ||||
|             if parent_type == "block_id": | ||||
|                 parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) | ||||
|             elif parent_type == 'workspace': | ||||
|                 parent_id = 'root' | ||||
|             elif parent_type == "workspace": | ||||
|                 parent_id = "root" | ||||
|             else: | ||||
|                 parent_id = parent[parent_type] | ||||
|             page = { | ||||
|                 'page_id': page_id, | ||||
|                 'page_name': page_name, | ||||
|                 'page_icon': icon, | ||||
|                 'parent_id': parent_id, | ||||
|                 'type': 'page' | ||||
|                 "page_id": page_id, | ||||
|                 "page_name": page_name, | ||||
|                 "page_icon": icon, | ||||
|                 "parent_id": parent_id, | ||||
|                 "type": "page", | ||||
|             } | ||||
|             pages.append(page) | ||||
|             # get database detail | ||||
|         for database_result in database_results: | ||||
|             page_id = database_result['id'] | ||||
|             if len(database_result['title']) > 0: | ||||
|                 page_name = database_result['title'][0]['plain_text'] | ||||
|             page_id = database_result["id"] | ||||
|             if len(database_result["title"]) > 0: | ||||
|                 page_name = database_result["title"][0]["plain_text"] | ||||
|             else: | ||||
|                 page_name = 'Untitled' | ||||
|             page_icon = database_result['icon'] | ||||
|                 page_name = "Untitled" | ||||
|             page_icon = database_result["icon"] | ||||
|             if page_icon: | ||||
|                 icon_type = page_icon['type'] | ||||
|                 if icon_type == 'external' or icon_type == 'file': | ||||
|                     url = page_icon[icon_type]['url'] | ||||
|                     icon = { | ||||
|                         'type': 'url', | ||||
|                         'url': url if url.startswith('http') else f'https://www.notion.so{url}' | ||||
|                     } | ||||
|                 icon_type = page_icon["type"] | ||||
|                 if icon_type == "external" or icon_type == "file": | ||||
|                     url = page_icon[icon_type]["url"] | ||||
|                     icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} | ||||
|                 else: | ||||
|                     icon = { | ||||
|                         'type': icon_type, | ||||
|                         icon_type: page_icon[icon_type] | ||||
|                     } | ||||
|                     icon = {"type": icon_type, icon_type: page_icon[icon_type]} | ||||
|             else: | ||||
|                 icon = None | ||||
|             parent = database_result['parent'] | ||||
|             parent_type = parent['type'] | ||||
|             if parent_type == 'block_id': | ||||
|             parent = database_result["parent"] | ||||
|             parent_type = parent["type"] | ||||
|             if parent_type == "block_id": | ||||
|                 parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) | ||||
|             elif parent_type == 'workspace': | ||||
|                 parent_id = 'root' | ||||
|             elif parent_type == "workspace": | ||||
|                 parent_id = "root" | ||||
|             else: | ||||
|                 parent_id = parent[parent_type] | ||||
|             page = { | ||||
|                 'page_id': page_id, | ||||
|                 'page_name': page_name, | ||||
|                 'page_icon': icon, | ||||
|                 'parent_id': parent_id, | ||||
|                 'type': 'database' | ||||
|                 "page_id": page_id, | ||||
|                 "page_name": page_name, | ||||
|                 "page_icon": icon, | ||||
|                 "parent_id": parent_id, | ||||
|                 "type": "database", | ||||
|             } | ||||
|             pages.append(page) | ||||
|         return pages | ||||
| 
 | ||||
|     def notion_page_search(self, access_token: str): | ||||
|         data = { | ||||
|             'filter': { | ||||
|                 "value": "page", | ||||
|                 "property": "object" | ||||
|             } | ||||
|         } | ||||
|         data = {"filter": {"value": "page", "property": "object"}} | ||||
|         headers = { | ||||
|             'Content-Type': 'application/json', | ||||
|             'Authorization': f"Bearer {access_token}", | ||||
|             'Notion-Version': '2022-06-28', | ||||
|             "Content-Type": "application/json", | ||||
|             "Authorization": f"Bearer {access_token}", | ||||
|             "Notion-Version": "2022-06-28", | ||||
|         } | ||||
|         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | ||||
|         response_json = response.json() | ||||
|         results = response_json.get('results', []) | ||||
|         results = response_json.get("results", []) | ||||
|         return results | ||||
| 
 | ||||
|     def notion_block_parent_page_id(self, access_token: str, block_id: str): | ||||
|         headers = { | ||||
|             'Authorization': f"Bearer {access_token}", | ||||
|             'Notion-Version': '2022-06-28', | ||||
|             "Authorization": f"Bearer {access_token}", | ||||
|             "Notion-Version": "2022-06-28", | ||||
|         } | ||||
|         response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers) | ||||
|         response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) | ||||
|         response_json = response.json() | ||||
|         parent = response_json['parent'] | ||||
|         parent_type = parent['type'] | ||||
|         if parent_type == 'block_id': | ||||
|         parent = response_json["parent"] | ||||
|         parent_type = parent["type"] | ||||
|         if parent_type == "block_id": | ||||
|             return self.notion_block_parent_page_id(access_token, parent[parent_type]) | ||||
|         return parent[parent_type] | ||||
| 
 | ||||
|     def notion_workspace_name(self, access_token: str): | ||||
|         headers = { | ||||
|             'Authorization': f"Bearer {access_token}", | ||||
|             'Notion-Version': '2022-06-28', | ||||
|             "Authorization": f"Bearer {access_token}", | ||||
|             "Notion-Version": "2022-06-28", | ||||
|         } | ||||
|         response = requests.get(url=self._NOTION_BOT_USER, headers=headers) | ||||
|         response_json = response.json() | ||||
|         if 'object' in response_json and response_json['object'] == 'user': | ||||
|             user_type = response_json['type'] | ||||
|         if "object" in response_json and response_json["object"] == "user": | ||||
|             user_type = response_json["type"] | ||||
|             user_info = response_json[user_type] | ||||
|             if 'workspace_name' in user_info: | ||||
|                 return user_info['workspace_name'] | ||||
|         return 'workspace' | ||||
|             if "workspace_name" in user_info: | ||||
|                 return user_info["workspace_name"] | ||||
|         return "workspace" | ||||
| 
 | ||||
|     def notion_database_search(self, access_token: str): | ||||
|         data = { | ||||
|             'filter': { | ||||
|                 "value": "database", | ||||
|                 "property": "object" | ||||
|             } | ||||
|         } | ||||
|         data = {"filter": {"value": "database", "property": "object"}} | ||||
|         headers = { | ||||
|             'Content-Type': 'application/json', | ||||
|             'Authorization': f"Bearer {access_token}", | ||||
|             'Notion-Version': '2022-06-28', | ||||
|             "Content-Type": "application/json", | ||||
|             "Authorization": f"Bearer {access_token}", | ||||
|             "Notion-Version": "2022-06-28", | ||||
|         } | ||||
|         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | ||||
|         response_json = response.json() | ||||
|         results = response_json.get('results', []) | ||||
|         results = response_json.get("results", []) | ||||
|         return results | ||||
|  | ||||
| @ -9,14 +9,14 @@ class PassportService: | ||||
|         self.sk = dify_config.SECRET_KEY | ||||
| 
 | ||||
|     def issue(self, payload): | ||||
|         return jwt.encode(payload, self.sk, algorithm='HS256') | ||||
|         return jwt.encode(payload, self.sk, algorithm="HS256") | ||||
| 
 | ||||
|     def verify(self, token): | ||||
|         try: | ||||
|             return jwt.decode(token, self.sk, algorithms=['HS256']) | ||||
|             return jwt.decode(token, self.sk, algorithms=["HS256"]) | ||||
|         except jwt.exceptions.InvalidSignatureError: | ||||
|             raise Unauthorized('Invalid token signature.') | ||||
|             raise Unauthorized("Invalid token signature.") | ||||
|         except jwt.exceptions.DecodeError: | ||||
|             raise Unauthorized('Invalid token.') | ||||
|             raise Unauthorized("Invalid token.") | ||||
|         except jwt.exceptions.ExpiredSignatureError: | ||||
|             raise Unauthorized('Token has expired.') | ||||
|             raise Unauthorized("Token has expired.") | ||||
|  | ||||
| @ -5,6 +5,7 @@ import re | ||||
| 
 | ||||
| password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$" | ||||
| 
 | ||||
| 
 | ||||
| def valid_password(password): | ||||
|     # Define a regex pattern for password rules | ||||
|     pattern = password_pattern | ||||
| @ -12,11 +13,11 @@ def valid_password(password): | ||||
|     if re.match(pattern, password) is not None: | ||||
|         return password | ||||
| 
 | ||||
|     raise ValueError('Not a valid password.') | ||||
|     raise ValueError("Not a valid password.") | ||||
| 
 | ||||
| 
 | ||||
| def hash_password(password_str, salt_byte): | ||||
|     dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000) | ||||
|     dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000) | ||||
|     return binascii.hexlify(dk) | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -48,7 +48,7 @@ def encrypt(text, public_key): | ||||
| def get_decrypt_decoding(tenant_id): | ||||
|     filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" | ||||
| 
 | ||||
|     cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) | ||||
|     cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) | ||||
|     private_key = redis_client.get(cache_key) | ||||
|     if not private_key: | ||||
|         try: | ||||
| @ -66,12 +66,12 @@ def get_decrypt_decoding(tenant_id): | ||||
| 
 | ||||
| def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): | ||||
|     if encrypted_text.startswith(prefix_hybrid): | ||||
|         encrypted_text = encrypted_text[len(prefix_hybrid):] | ||||
|         encrypted_text = encrypted_text[len(prefix_hybrid) :] | ||||
| 
 | ||||
|         enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()] | ||||
|         nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16] | ||||
|         tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32] | ||||
|         ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:] | ||||
|         enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()] | ||||
|         nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16] | ||||
|         tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32] | ||||
|         ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :] | ||||
| 
 | ||||
|         aes_key = cipher_rsa.decrypt(enc_aes_key) | ||||
| 
 | ||||
|  | ||||
| @ -5,7 +5,9 @@ from email.mime.text import MIMEText | ||||
| 
 | ||||
| 
 | ||||
| class SMTPClient: | ||||
|     def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False): | ||||
|     def __init__( | ||||
|         self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False | ||||
|     ): | ||||
|         self.server = server | ||||
|         self.port = port | ||||
|         self._from = _from | ||||
| @ -25,17 +27,17 @@ class SMTPClient: | ||||
|                     smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) | ||||
|             else: | ||||
|                 smtp = smtplib.SMTP(self.server, self.port, timeout=10) | ||||
|                  | ||||
| 
 | ||||
|             if self.username and self.password: | ||||
|                 smtp.login(self.username, self.password) | ||||
| 
 | ||||
|             msg = MIMEMultipart() | ||||
|             msg['Subject'] = mail['subject'] | ||||
|             msg['From'] = self._from | ||||
|             msg['To'] = mail['to'] | ||||
|             msg.attach(MIMEText(mail['html'], 'html')) | ||||
|             msg["Subject"] = mail["subject"] | ||||
|             msg["From"] = self._from | ||||
|             msg["To"] = mail["to"] | ||||
|             msg.attach(MIMEText(mail["html"], "html")) | ||||
| 
 | ||||
|             smtp.sendmail(self._from, mail['to'], msg.as_string()) | ||||
|             smtp.sendmail(self._from, mail["to"], msg.as_string()) | ||||
|         except smtplib.SMTPException as e: | ||||
|             logging.error(f"SMTP error occurred: {str(e)}") | ||||
|             raise | ||||
|  | ||||
| @ -73,12 +73,10 @@ exclude = [ | ||||
|     "core/**/*.py", | ||||
|     "controllers/**/*.py", | ||||
|     "models/**/*.py", | ||||
|     "utils/**/*.py", | ||||
|     "migrations/**/*", | ||||
|     "services/**/*.py", | ||||
|     "tasks/**/*.py", | ||||
|     "tests/**/*.py", | ||||
|     "libs/**/*.py", | ||||
|     "configs/**/*.py", | ||||
| ] | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 -LAN-
						-LAN-