| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | import json | 
					
						
							|  |  |  | from enum import Enum | 
					
						
							|  |  |  | from json import JSONDecodeError | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from extensions.ext_redis import redis_client | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ProviderCredentialsCacheType(Enum): | 
					
						
							|  |  |  |     PROVIDER = "provider" | 
					
						
							|  |  |  |     MODEL = "provider_model" | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     LOAD_BALANCING_MODEL = "load_balancing_provider_model" | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ProviderCredentialsCache: | 
					
						
							|  |  |  |     def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): | 
					
						
							|  |  |  |         self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get(self) -> Optional[dict]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get cached model provider credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         cached_provider_credentials = redis_client.get(self.cache_key) | 
					
						
							|  |  |  |         if cached_provider_credentials: | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 cached_provider_credentials = cached_provider_credentials.decode("utf-8") | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                 cached_provider_credentials = json.loads(cached_provider_credentials) | 
					
						
							|  |  |  |             except JSONDecodeError: | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return cached_provider_credentials | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def set(self, credentials: dict) -> None: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Cache model provider credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param credentials: provider credentials | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-01-05 09:43:41 +08:00
										 |  |  |         redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def delete(self) -> None: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Delete cached model provider credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         redis_client.delete(self.cache_key) |