# OpenMetadata/ingestion/src/metadata/utils/secrets_manager.py
# (161 lines, 5.6 KiB, Python)

# Copyright 2022 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from abc import abstractmethod
from pydoc import locate
from typing import Optional, Union
import boto3
from botocore.exceptions import ClientError
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
SecretsManagerProvider,
)
from metadata.generated.schema.entity.services.dashboardService import DashboardService
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.messagingService import MessagingService
from metadata.generated.schema.entity.services.mlmodelService import MlModelService
from metadata.generated.schema.entity.services.pipelineService import PipelineService
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
from metadata.utils.logger import ingestion_logger
from metadata.utils.singleton import Singleton
# Logger shared by everything in this module, obtained from `ingestion_logger()`.
logger = ingestion_logger()
# NOTE(review): presumably the name of the Airflow configuration section that
# holds the secrets manager settings; it is not referenced anywhere in this
# module -- confirm against the callers that import it.
SECRET_MANAGER_AIRFLOW_CONF = "openmetadata_secrets_manager"
class SecretsManager(metaclass=Singleton):
    """
    Common interface of the secrets manager implementations.

    The class is created through the `Singleton` metaclass.

    NOTE(review): `@abstractmethod` is only enforced at instantiation time when
    the metaclass derives from `abc.ABCMeta` -- confirm that `Singleton` does,
    otherwise the decorator below is purely informative.
    """

    @abstractmethod
    def add_service_config_connection(
        self,
        service: Union[
            DashboardService,
            DatabaseService,
            MessagingService,
            PipelineService,
            MlModelService,
        ],
        service_type: str,
    ) -> Union[
        DashboardService,
        DatabaseService,
        MessagingService,
        PipelineService,
        MlModelService,
    ]:
        """
        Hook to be overridden by each secrets manager implementation.

        :param service: the service entity to process
        :param service_type: the service type name, passed as a string
        :return: the service entity; this base implementation does nothing
            and returns None
        """

    @staticmethod
    def build_secret_id(
        service_type: str, service_connection_type: str, service_name: str
    ) -> str:
        """
        Build the identifier under which a service connection is stored.

        :param service_type: service type segment of the identifier
        :param service_connection_type: connection type segment of the identifier
        :param service_name: service name segment of the identifier
        :return: `openmetadata-<service_type>-<service_connection_type>-<service_name>`
            with the three segments lower-cased
        """
        secret_id = f"openmetadata-{service_type.lower()}-{service_connection_type.lower()}-{service_name.lower()}"
        return secret_id
class LocalSecretsManager(SecretsManager):
    """
    Secrets manager that does not look anything up: the service is handed
    back exactly as it was received.
    """

    def add_service_config_connection(
        self,
        service: Union[
            DashboardService,
            DatabaseService,
            MessagingService,
            PipelineService,
            MlModelService,
        ],
        service_type: str,
    ) -> Union[
        DashboardService,
        DatabaseService,
        MessagingService,
        PipelineService,
        MlModelService,
    ]:
        """
        Return the given service untouched.

        :param service: the service entity to process
        :param service_type: the service type name (unused here)
        :return: the very same `service` object
        """
        return service
class AWSSecretsManager(SecretsManager):
    """
    Secrets manager that loads the service connection configuration from
    AWS Secrets Manager through a boto3 `secretsmanager` client.
    """

    def __init__(self, credentials: AWSCredentials):
        """
        Open a boto3 session and keep a `secretsmanager` client on the instance.

        :param credentials: provides the access key id, the secret access key
            and the region used to open the boto3 session
        """
        session = boto3.Session(
            aws_access_key_id=credentials.awsAccessKeyId,
            aws_secret_access_key=credentials.awsSecretAccessKey.get_secret_value(),
            region_name=credentials.awsRegion,
        )
        self.secretsmanager_client = session.client("secretsmanager")

    def add_service_config_connection(
        self,
        service: Union[
            DashboardService,
            DatabaseService,
            MessagingService,
            PipelineService,
            MlModelService,
        ],
        service_type: str,
    ) -> Union[
        DashboardService,
        DatabaseService,
        MessagingService,
        PipelineService,
        MlModelService,
    ]:
        """
        Replace `service.connection.config` with the connection stored in AWS.

        The secret is looked up under the id produced by `build_secret_id`, its
        string value is parsed as JSON and loaded into the generated connection
        class that matches the service connection type.

        :param service: the service entity whose connection config is replaced
        :param service_type: the service type name; the "Service" substring is
            removed and the rest lower-cased before use
        :return: the same `service` object, with its connection config replaced
        :raises ValueError: when the generated connection class cannot be
            located, or when the secret has no string value
        :raises ClientError: when the AWS call fails
        """
        service_simple_type = service_type.replace("Service", "").lower()
        service_connection_type = service.serviceType.value
        service_name = service.name.__root__
        secret_id = self.build_secret_id(
            service_simple_type, service_connection_type, service_name
        )

        # The module segment is the connection type with its first character
        # lower-cased; the class name uses the connection type unchanged.
        connection_py_file = (
            service_connection_type[0].lower() + service_connection_type[1:]
        )
        conn_class_path = (
            "metadata.generated.schema.entity.services.connections."
            f"{service_simple_type}.{connection_py_file}Connection."
            f"{service_connection_type}Connection"
        )
        conn_class = locate(conn_class_path)
        if conn_class is None:
            # `locate` yields None for an unresolved path; fail with the path
            # in the message instead of an AttributeError on NoneType below.
            raise ValueError(
                f"Connection class [{conn_class_path}] could not be located."
            )

        service.connection.config = conn_class.parse_obj(
            json.loads(self._get_string_value(secret_id))
        )
        return service

    def _get_string_value(self, name: str) -> str:
        """
        Retrieve the string value of a secret.

        :param name: The secret name to retrieve. Only `SecretId` is sent, so
            no specific version or stage is requested.
        :return: The content of the `SecretString` field of the response.
        :raises ValueError: when `name` is None, or when the response carries no
            `SecretString` field (e.g. the secret is only available as
            `SecretBinary`).
        :raises ClientError: re-raised, after being logged, when the AWS call
            fails.
        """
        if name is None:
            raise ValueError("The secret name must not be None.")

        try:
            response = self.secretsmanager_client.get_secret_value(SecretId=name)
        except ClientError:
            logger.exception("Couldn't get value for secret %s.", name)
            raise

        logger.info("Got value for secret %s.", name)
        if "SecretString" not in response:
            raise ValueError("[SecretString] not present in the response.")
        return response["SecretString"]
def get_secrets_manager(
    secret_manager: SecretsManagerProvider,
    credentials: Optional[Union[AWSCredentials]] = None,
) -> SecretsManager:
    """
    Return the secrets manager implementation for the requested provider.

    :param secret_manager: the provider to use
    :param credentials: credentials handed to the AWS implementation; not used
        for the local provider
    :return: a `LocalSecretsManager` for the local provider, an
        `AWSSecretsManager` for the aws provider
    :raises NotImplementedError: for any other provider
    """
    if secret_manager == SecretsManagerProvider.local:
        return LocalSecretsManager()

    if secret_manager == SecretsManagerProvider.aws:
        return AWSSecretsManager(credentials)

    raise NotImplementedError(f"[{secret_manager}] is not implemented.")