diff --git a/ingestion/setup.py b/ingestion/setup.py index d4a9cacb40a..1ed1d3f085e 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -82,6 +82,7 @@ plugins: Dict[str, Set[str]] = { "elasticsearch": {"elasticsearch==7.13.1", "requests-aws4auth==1.1.2"}, "glue": {"boto3~=1.19.12"}, "dynamodb": {"boto3~=1.19.12"}, + "sagemaker": {"boto3~=1.19.12"}, "hive": { "pyhive~=0.6.5", "thrift~=0.13.0", diff --git a/ingestion/src/metadata/clients/aws_client.py b/ingestion/src/metadata/clients/aws_client.py index 45c25cbd328..7ed88223a61 100644 --- a/ingestion/src/metadata/clients/aws_client.py +++ b/ingestion/src/metadata/clients/aws_client.py @@ -21,6 +21,7 @@ from metadata.clients.connection_clients import ( GlueDBClient, GluePipelineClient, KinesisClient, + SageMakerClient, ) from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials from metadata.utils.logger import utils_logger @@ -84,7 +85,7 @@ class AWSClient: ) return session.resource(service_name=service_name) - def get_dynomo_client(self) -> DynamoClient: + def get_dynamo_client(self) -> DynamoClient: return DynamoClient(self.get_resource("dynamodb")) def get_glue_db_client(self) -> GlueDBClient: @@ -93,5 +94,8 @@ class AWSClient: def get_glue_pipeline_client(self) -> GluePipelineClient: return GluePipelineClient(self.get_client("glue")) + def get_sagemaker_client(self) -> SageMakerClient: + return SageMakerClient(self.get_client("sagemaker")) + def get_kinesis_client(self) -> KinesisClient: return KinesisClient(self.get_client("kinesis")) diff --git a/ingestion/src/metadata/clients/connection_clients.py b/ingestion/src/metadata/clients/connection_clients.py index 8f4fea99311..08196b737c7 100644 --- a/ingestion/src/metadata/clients/connection_clients.py +++ b/ingestion/src/metadata/clients/connection_clients.py @@ -122,6 +122,12 @@ class MlflowClientWrapper: self.client = client +@dataclass +class SageMakerClient: + def __init__(self, client) -> None: + self.client = client + + @dataclass class FivetranClient: def __init__(self, client) -> None: diff --git a/ingestion/src/metadata/examples/workflows/sagemaker.yaml b/ingestion/src/metadata/examples/workflows/sagemaker.yaml new file mode 100644 index 00000000000..ea8721046ba --- /dev/null +++ b/ingestion/src/metadata/examples/workflows/sagemaker.yaml @@ -0,0 +1,21 @@ +source: + type: sagemaker + serviceName: local_sagemaker + serviceConnection: + config: + type: SageMaker + awsConfig: + awsAccessKeyId: aws_access_key_id + awsSecretAccessKey: aws_secret_access_key + awsRegion: aws region + sourceConfig: + config: + type: MlModelMetadata +sink: + type: metadata-rest + config: {} +workflowConfig: + loggerLevel: "DEBUG" + openMetadataServerConfig: + hostPort: http://localhost:8585/api + authProvider: no-auth diff --git a/ingestion/src/metadata/ingestion/source/mlmodel/sagemaker.py b/ingestion/src/metadata/ingestion/source/mlmodel/sagemaker.py new file mode 100644 index 00000000000..dfbe5b0d9a3 --- /dev/null +++ b/ingestion/src/metadata/ingestion/source/mlmodel/sagemaker.py @@ -0,0 +1,197 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SageMaker source module""" + +import traceback +from typing import Iterable, List, Optional + +from pydantic import BaseModel, Extra, Field, ValidationError + +from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest +from metadata.generated.schema.entity.data.mlmodel import ( + MlFeature, + MlHyperParameter, + MlStore, +) +from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( + OpenMetadataConnection, +) +from metadata.generated.schema.entity.services.connections.mlmodel.sageMakerConnection import ( + SageMakerConnection, +) +from metadata.generated.schema.metadataIngestion.workflow import ( + Source as WorkflowSource, +) +from metadata.generated.schema.type.entityReference import EntityReference +from metadata.generated.schema.type.tagLabel import TagLabel +from metadata.ingestion.api.source import InvalidSourceException +from metadata.ingestion.source.mlmodel.mlmodel_service import MlModelServiceSource +from metadata.utils.filters import filter_by_mlmodel +from metadata.utils.logger import ingestion_logger + +logger = ingestion_logger() + + +class SageMakerModel(BaseModel): + class Config: + extra = Extra.forbid + + name: str = Field(..., description="Model name", title="Model Name") + arn: str = Field(..., description="Model ARN in AWS account", title="Model ARN") + creation_timestamp: str = Field( + ..., + description="Timestamp of model creation in ISO format", + title="Creation Timestamp", + ) + + +class SagemakerSource(MlModelServiceSource): + """ + Source implementation to ingest SageMaker data. + + We will iterate on the ML Models + and prepare an iterator of CreateMlModelRequest + """ + + def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection): + super().__init__(config, metadata_config) + self.sagemaker = self.connection.client + + @classmethod + def create(cls, config_dict, metadata_config: OpenMetadataConnection): + config: WorkflowSource = WorkflowSource.parse_obj(config_dict) + connection: SageMakerConnection = config.serviceConnection.__root__.config + if not isinstance(connection, SageMakerConnection): + raise InvalidSourceException( + f"Expected SageMakerConnection, but got {connection}" + ) + return cls(config, metadata_config) + + def get_mlmodels( # pylint: disable=arguments-differ + self, + ) -> Iterable[SageMakerModel]: + """ + List and filters models + """ + args, has_more_models, models = {"MaxResults": 100}, True, [] + try: + while has_more_models: + response = self.sagemaker.list_models(**args) + models.append(response["Models"]) + has_more_models = response.get("NextToken") + args["NextToken"] = response.get("NextToken") + except Exception as err: + logger.debug(traceback.format_exc()) + logger.error(f"Failed to fetch models list - {err}") + + for model in models: + try: + if filter_by_mlmodel( + self.source_config.mlModelFilterPattern, + mlmodel_name=model["ModelName"], + ): + self.status.filter( + model["ModelName"], + "MlModel name pattern not allowed", + ) + continue + yield SageMakerModel( + name=model["ModelName"], + arn=model["ModelArn"], + creation_timestamp=model["CreationTime"].isoformat(), + ) + except ValidationError as err: + logger.debug(traceback.format_exc()) + logger.warning( + f"Validation error while creating SageMakerModel from model details - {err}" + ) + except Exception as err: + logger.debug(traceback.format_exc()) + logger.warning( + f"Wild error while creating SageMakerModel from model details - {err}" + ) + continue + + def _get_algorithm(self) -> str: # pylint: disable=arguments-differ + logger.info( + "Setting algorithm to default value of `mlmodel` for SageMaker Model" + ) + return "mlmodel" + + def yield_mlmodel( # pylint: disable=arguments-differ + self, model: SageMakerModel + ) -> Iterable[CreateMlModelRequest]: + """ + Prepare the Request model + """ + self.status.scanned(model.name) + + yield CreateMlModelRequest( + name=model.name, + algorithm=self._get_algorithm(), # Setting this to a constant + mlStore=self._get_ml_store(model.name), + service=EntityReference( + id=self.context.mlmodel_service.id, type="mlmodelService" + ), + ) + + def _get_ml_store( # pylint: disable=arguments-differ + self, + model_name: str, + ) -> Optional[MlStore]: + """ + Get the Ml Store for the model + """ + try: + model_info = self.sagemaker.describe_model(ModelName=model_name) + return MlStore(imageRepository=model_info["PrimaryContainer"]["Image"]) + except ValidationError as err: + logger.debug(traceback.format_exc()) + logger.warning( + f"Validation error adding the MlModel store from model description: {model_name} - {err}" + ) + except Exception as err: + logger.debug(traceback.format_exc()) + logger.warning( + f"Wild error adding the MlModel store from model description: {model_name} - {err}" + ) + return None + + def _get_tags(self, model_arn: str) -> Optional[List[TagLabel]]: + try: + tags = self.sagemaker.list_tags(ResourceArn=model_arn)["Tags"] + return [ + TagLabel( + tagFQN=tag["Key"], + description=tag["Value"], + source="Tag", + labelType="Propagated", + state="Confirmed", + ) + for tag in tags + ] + except ValidationError as err: + logger.debug(traceback.format_exc()) + logger.warning( + f"Validation error adding TagLabel from model tags: {model_arn} - {err}" + ) + except Exception as err: + logger.debug(traceback.format_exc()) + logger.warning( + f"Wild error adding TagLabel from model tags: {model_arn} - {err}" + ) + return None + + def _get_hyper_params(self, *args, **kwargs) -> Optional[List[MlHyperParameter]]: + pass + + def _get_ml_features(self, *args, **kwargs) -> Optional[List[MlFeature]]: + pass diff --git a/ingestion/src/metadata/utils/connections.py b/ingestion/src/metadata/utils/connections.py index 7e145d940e1..e10889d38b7 100644 --- a/ingestion/src/metadata/utils/connections.py +++ b/ingestion/src/metadata/utils/connections.py @@ -48,6 +48,7 @@ from metadata.clients.connection_clients import ( NifiClientWrapper, PowerBiClient, RedashClient, + SageMakerClient, SalesforceClient, SupersetClient, TableauClient, @@ -121,6 +122,9 @@ from metadata.generated.schema.entity.services.connections.messaging.redpandaCon from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import ( MlflowConnection, ) +from metadata.generated.schema.entity.services.connections.mlmodel.sageMakerConnection import ( + SageMakerConnection, +) from metadata.generated.schema.entity.services.connections.pipeline.airbyteConnection import ( AirbyteConnection, ) @@ -285,8 +289,8 @@ def _( ) -> DynamoClient: from metadata.clients.aws_client import AWSClient - dynomo_connection = AWSClient(connection.awsConfig).get_dynomo_client() - return dynomo_connection + dynamo_connection = AWSClient(connection.awsConfig).get_dynamo_client() + return dynamo_connection @get_connection.register @@ -496,7 +500,7 @@ def _(connection: DynamoClient) -> None: def _(connection: GlueDBClient) -> None: """ Test that we can connect to the source using the given aws resource - :param engine: boto cliet to test + :param engine: boto client to test :return: None or raise an exception if we cannot connect """ from botocore.client import ClientError @@ -947,6 +951,36 @@ def _(connection: MlflowClientWrapper) -> None: raise SourceConnectionException(msg) from exc +@get_connection.register +def _( + connection: SageMakerConnection, + verbose: bool = False, # pylint: disable=unused-argument +) -> SageMakerClient: + from metadata.clients.aws_client import AWSClient + + sagemaker_connection = AWSClient(connection.awsConfig).get_sagemaker_client() + return sagemaker_connection + + +@test_connection.register +def _(connection: SageMakerClient) -> None: + """ + Test that we can connect to the SageMaker source using the given aws resource + :param engine: boto service resource to test + :return: None or raise an exception if we cannot connect + """ + from botocore.client import ClientError + + try: + connection.client.list_models() + except ClientError as err: + msg = f"Connection error for {connection}: {err}. Check the connection details." + raise SourceConnectionException(msg) from err + except Exception as exc: + msg = f"Unknown error connecting with {connection}: {exc}." + raise SourceConnectionException(msg) from exc + + @get_connection.register def _( connection: NifiConnection, verbose: bool = False diff --git a/ingestion/tests/unit/source/test_sagemaker.py b/ingestion/tests/unit/source/test_sagemaker.py new file mode 100644 index 00000000000..dd355e4e143 --- /dev/null +++ b/ingestion/tests/unit/source/test_sagemaker.py @@ -0,0 +1,105 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SageMaker unit tests +""" +import json +from unittest import TestCase +from unittest.mock import patch + +import boto3 +from moto import mock_sagemaker + +from metadata.ingestion.api.workflow import Workflow + +CONFIG = """ +{ + "source": { + "type": "sagemaker", + "serviceName": "local_sagemaker", + "serviceConnection": { + "config": { + "type": "SageMaker", + "awsConfig": { + "awsAccessKeyId": "aws_access_key_id", + "awsSecretAccessKey": "aws_secret_access_key", + "awsRegion": "us-east-2" + } + } + }, + "sourceConfig": { + "config": { + "type": "MlModelMetadata" + } + } + }, + "sink": { + "type": "file", + "config": { + "filename": "/var/tmp/datasets.json" + } + }, + "workflowConfig": { + "openMetadataServerConfig": { + "hostPort": "http://localhost:8585/api", + "authProvider": "no-auth" + } + } +} +""" + + +def execute_workflow(): + workflow = Workflow.create(json.loads(CONFIG)) + workflow.execute() + workflow.print_status() + workflow.stop() + + +def get_file_path(): + return json.loads(CONFIG)["sink"]["config"]["filename"] + + +def _setup_mock_sagemaker(create_model: bool = False): + sagemaker = boto3.Session().client("sagemaker") + if not create_model: + return + print("Creating model!!!!!!!") + sagemaker.create_model( + ModelName="mock-model", + PrimaryContainer={ + "Environment": {}, + "Image": "123.dkr.ecr.eu-west-1.amazonaws.com/image:mock-image", + "Mode": "SingleModel", + }, + ExecutionRoleArn="arn:aws:iam::123:role/service-role/mockRole", + EnableNetworkIsolation=False, + ) + + +@mock_sagemaker +@patch("sqlalchemy.engine.base.Engine.connect") +class SageMakerIngestionTest(TestCase): + def test_sagemaker_empty_models(self, mock_connect): + _setup_mock_sagemaker() + execute_workflow() + file_path = get_file_path() + with open(file_path, "r") as file: + assert len(json.loads(file.read())) == 0 + + def test_sagemaker_models(self, mock_connect): + _setup_mock_sagemaker(create_model=True) + execute_workflow() + file_path = get_file_path() + with open(file_path, "r") as file: + data = json.loads(file.read()) + assert data[0]["name"] == "mock-model" diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/mlmodel/sageMakerConnection.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/mlmodel/sageMakerConnection.json new file mode 100644 index 00000000000..c5447c75c38 --- /dev/null +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/connections/mlmodel/sageMakerConnection.json @@ -0,0 +1,34 @@ +{ + "$id": "https://open-metadata.org/schema/entity/services/connections/mlmodel/sageMakerConnection.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "SageMakerConnection", + "description": "SageMaker Connection Config", + "type": "object", + "javaType": "org.openmetadata.schema.services.connections.mlmodel.SageMakerConnection", + "definitions": { + "sageMakerType": { + "description": "Service type.", + "type": "string", + "enum": ["SageMaker"], + "default": "SageMaker" + } + }, + "properties": { + "type": { + "title": "Service Type", + "description": "Service Type", + "$ref": "#/definitions/sageMakerType", + "default": "SageMaker" + }, + "awsConfig": { + "title": "AWS Credentials Configuration", + "$ref": "../../../../security/credentials/awsCredentials.json" + }, + "supportsMetadataExtraction": { + "title": "Supports Metadata Extraction", + "$ref": "../connectionBasicType.json#/definitions/supportsMetadataExtraction" + } + }, + "additionalProperties": false, + "required": ["awsConfig"] +} diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/services/mlmodelService.json b/openmetadata-spec/src/main/resources/json/schema/entity/services/mlmodelService.json index e392b34ec0e..5a52ae46006 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/services/mlmodelService.json +++ b/openmetadata-spec/src/main/resources/json/schema/entity/services/mlmodelService.json @@ -14,7 +14,7 @@ "description": "Type of MlModel service", "type": "string", "javaInterfaces": ["org.openmetadata.schema.EnumInterface"], - "enum": ["Mlflow", "Sklearn", "CustomMlModel"], + "enum": ["Mlflow", "Sklearn", "CustomMlModel", "SageMaker"], "javaEnums": [ { "name": "Mlflow" @@ -24,6 +24,9 @@ }, { "name": "CustomMlModel" + }, + { + "name": "SageMaker" } ] }, @@ -46,6 +49,9 @@ }, { "$ref": "./connections/mlmodel/customMlModelConnection.json" + }, + { + "$ref": "./connections/mlmodel/sageMakerConnection.json" } ] }