| 
									
										
										
										
											2024-06-15 02:46:02 +08:00
										 |  |  | import json | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from core.helper import encrypter | 
					
						
							|  |  |  | from extensions.ext_database import db | 
					
						
							|  |  |  | from models.source import DataSourceApiKeyAuthBinding | 
					
						
							|  |  |  | from services.auth.api_key_auth_factory import ApiKeyAuthFactory | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ApiKeyAuthService: | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							|  |  |  |     def get_provider_auth_list(tenant_id: str) -> list: | 
					
						
							| 
									
										
										
										
											2024-08-26 13:43:57 +08:00
										 |  |  |         data_source_api_key_bindings = ( | 
					
						
							|  |  |  |             db.session.query(DataSourceApiKeyAuthBinding) | 
					
						
							|  |  |  |             .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) | 
					
						
							|  |  |  |             .all() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-15 02:46:02 +08:00
										 |  |  |         return data_source_api_key_bindings | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							|  |  |  |     def create_provider_auth(tenant_id: str, args: dict): | 
					
						
							| 
									
										
										
										
											2024-08-26 13:43:57 +08:00
										 |  |  |         auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() | 
					
						
							| 
									
										
										
										
											2024-06-15 02:46:02 +08:00
										 |  |  |         if auth_result: | 
					
						
							|  |  |  |             # Encrypt the api key | 
					
						
							| 
									
										
										
										
											2024-08-26 13:43:57 +08:00
										 |  |  |             api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) | 
					
						
							|  |  |  |             args["credentials"]["config"]["api_key"] = api_key | 
					
						
							| 
									
										
										
										
											2024-06-15 02:46:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             data_source_api_key_binding = DataSourceApiKeyAuthBinding() | 
					
						
							|  |  |  |             data_source_api_key_binding.tenant_id = tenant_id | 
					
						
							| 
									
										
										
										
											2024-08-26 13:43:57 +08:00
										 |  |  |             data_source_api_key_binding.category = args["category"] | 
					
						
							|  |  |  |             data_source_api_key_binding.provider = args["provider"] | 
					
						
							|  |  |  |             data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) | 
					
						
							| 
									
										
										
										
											2024-06-15 02:46:02 +08:00
										 |  |  |             db.session.add(data_source_api_key_binding) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							|  |  |  |     def get_auth_credentials(tenant_id: str, category: str, provider: str): | 
					
						
							| 
									
										
										
										
											2024-08-26 13:43:57 +08:00
										 |  |  |         data_source_api_key_bindings = ( | 
					
						
							|  |  |  |             db.session.query(DataSourceApiKeyAuthBinding) | 
					
						
							|  |  |  |             .filter( | 
					
						
							|  |  |  |                 DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | 
					
						
							|  |  |  |                 DataSourceApiKeyAuthBinding.category == category, | 
					
						
							|  |  |  |                 DataSourceApiKeyAuthBinding.provider == provider, | 
					
						
							|  |  |  |                 DataSourceApiKeyAuthBinding.disabled.is_(False), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-15 02:46:02 +08:00
										 |  |  |         if not data_source_api_key_bindings: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         credentials = json.loads(data_source_api_key_bindings.credentials) | 
					
						
							|  |  |  |         return credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							|  |  |  |     def delete_provider_auth(tenant_id: str, binding_id: str): | 
					
						
							| 
									
										
										
										
											2024-08-26 13:43:57 +08:00
										 |  |  |         data_source_api_key_binding = ( | 
					
						
							|  |  |  |             db.session.query(DataSourceApiKeyAuthBinding) | 
					
						
							|  |  |  |             .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-15 02:46:02 +08:00
										 |  |  |         if data_source_api_key_binding: | 
					
						
							|  |  |  |             db.session.delete(data_source_api_key_binding) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def validate_api_key_auth_args(cls, args): | 
					
						
							| 
									
										
										
										
											2024-08-26 13:43:57 +08:00
										 |  |  |         if "category" not in args or not args["category"]: | 
					
						
							|  |  |  |             raise ValueError("category is required") | 
					
						
							|  |  |  |         if "provider" not in args or not args["provider"]: | 
					
						
							|  |  |  |             raise ValueError("provider is required") | 
					
						
							|  |  |  |         if "credentials" not in args or not args["credentials"]: | 
					
						
							|  |  |  |             raise ValueError("credentials is required") | 
					
						
							|  |  |  |         if not isinstance(args["credentials"], dict): | 
					
						
							|  |  |  |             raise ValueError("credentials must be a dictionary") | 
					
						
							|  |  |  |         if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: | 
					
						
							|  |  |  |             raise ValueError("auth_type is required") |