feat: add SageMaker connector (#8435)

* feat: add sagemaker connector

Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com>

* fix: fix linting errors and update imports

Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com>

* test: add unit tests for sagemake source

Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com>

Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com>
This commit is contained in:
Tushar Mittal 2022-11-03 22:49:20 +05:30 committed by GitHub
parent d93b46ef31
commit 6f2c93089c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 413 additions and 5 deletions

View File

@ -82,6 +82,7 @@ plugins: Dict[str, Set[str]] = {
"elasticsearch": {"elasticsearch==7.13.1", "requests-aws4auth==1.1.2"},
"glue": {"boto3~=1.19.12"},
"dynamodb": {"boto3~=1.19.12"},
"sagemaker": {"boto3~=1.19.12"},
"hive": {
"pyhive~=0.6.5",
"thrift~=0.13.0",

View File

@ -21,6 +21,7 @@ from metadata.clients.connection_clients import (
GlueDBClient,
GluePipelineClient,
KinesisClient,
SageMakerClient,
)
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
from metadata.utils.logger import utils_logger
@ -84,7 +85,7 @@ class AWSClient:
)
return session.resource(service_name=service_name)
def get_dynomo_client(self) -> DynamoClient:
def get_dynamo_client(self) -> DynamoClient:
return DynamoClient(self.get_resource("dynamodb"))
def get_glue_db_client(self) -> GlueDBClient:
@ -93,5 +94,8 @@ class AWSClient:
def get_glue_pipeline_client(self) -> GluePipelineClient:
return GluePipelineClient(self.get_client("glue"))
def get_sagemaker_client(self) -> SageMakerClient:
return SageMakerClient(self.get_client("sagemaker"))
def get_kinesis_client(self) -> KinesisClient:
return KinesisClient(self.get_client("kinesis"))

View File

@ -122,6 +122,12 @@ class MlflowClientWrapper:
self.client = client
@dataclass
class SageMakerClient:
def __init__(self, client) -> None:
self.client = client
@dataclass
class FivetranClient:
def __init__(self, client) -> None:

View File

@ -0,0 +1,21 @@
source:
type: sagemaker
serviceName: local_sagemaker
serviceConnection:
config:
type: SageMaker
awsConfig:
awsAccessKeyId: aws_access_key_id
awsSecretAccessKey: aws_secret_access_key
awsRegion: aws region
sourceConfig:
config:
type: MlModelMetadata
sink:
type: metadata-rest
config: {}
workflowConfig:
loggerLevel: "DEBUG"
openMetadataServerConfig:
hostPort: http://localhost:8585/api
authProvider: no-auth

View File

@ -0,0 +1,197 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SageMaker source module"""
import traceback
from typing import Iterable, List, Optional
from pydantic import BaseModel, Extra, Field, ValidationError
from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest
from metadata.generated.schema.entity.data.mlmodel import (
MlFeature,
MlHyperParameter,
MlStore,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.connections.mlmodel.sageMakerConnection import (
SageMakerConnection,
)
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.generated.schema.type.tagLabel import TagLabel
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.mlmodel.mlmodel_service import MlModelServiceSource
from metadata.utils.filters import filter_by_mlmodel
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
class SageMakerModel(BaseModel):
class Config:
extra = Extra.forbid
name: str = Field(..., description="Model name", title="Model Name")
arn: str = Field(..., description="Model ARN in AWS account", title="Model ARN")
creation_timestamp: str = Field(
...,
description="Timestamp of model creation in ISO format",
title="Creation Timestamp",
)
class SagemakerSource(MlModelServiceSource):
"""
Source implementation to ingest SageMaker data.
We will iterate on the ML Models
and prepare an iterator of CreateMlModelRequest
"""
def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection):
super().__init__(config, metadata_config)
self.sagemaker = self.connection.client
@classmethod
def create(cls, config_dict, metadata_config: OpenMetadataConnection):
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
connection: SageMakerConnection = config.serviceConnection.__root__.config
if not isinstance(connection, SageMakerConnection):
raise InvalidSourceException(
f"Expected SageMakerConnection, but got {connection}"
)
return cls(config, metadata_config)
def get_mlmodels( # pylint: disable=arguments-differ
self,
) -> Iterable[SageMakerModel]:
"""
List and filters models
"""
args, has_more_models, models = {"MaxResults": 100}, True, []
try:
while has_more_models:
response = self.sagemaker.list_models(**args)
models.append(response["Models"])
has_more_models = response.get("NextToken")
args["NextToken"] = response.get("NextToken")
except Exception as err:
logger.debug(traceback.format_exc())
logger.error(f"Failed to fetch models list - {err}")
for model in models:
try:
if filter_by_mlmodel(
self.source_config.mlModelFilterPattern,
mlmodel_name=model["ModelName"],
):
self.status.filter(
model["ModelName"],
"MlModel name pattern not allowed",
)
continue
yield SageMakerModel(
name=model["ModelName"],
arn=model["ModelArn"],
creation_timestamp=model["CreationTime"].isoformat(),
)
except ValidationError as err:
logger.debug(traceback.format_exc())
logger.warning(
f"Validation error while creating SageMakerModel from model details - {err}"
)
except Exception as err:
logger.debug(traceback.format_exc())
logger.warning(
f"Wild error while creating SageMakerModel from model details - {err}"
)
continue
def _get_algorithm(self) -> str: # pylint: disable=arguments-differ
logger.info(
"Setting algorithm to default value of `mlmodel` for SageMaker Model"
)
return "mlmodel"
def yield_mlmodel( # pylint: disable=arguments-differ
self, model: SageMakerModel
) -> Iterable[CreateMlModelRequest]:
"""
Prepare the Request model
"""
self.status.scanned(model.name)
yield CreateMlModelRequest(
name=model.name,
algorithm=self._get_algorithm(), # Setting this to a constant
mlStore=self._get_ml_store(model.name),
service=EntityReference(
id=self.context.mlmodel_service.id, type="mlmodelService"
),
)
def _get_ml_store( # pylint: disable=arguments-differ
self,
model_name: str,
) -> Optional[MlStore]:
"""
Get the Ml Store for the model
"""
try:
model_info = self.sagemaker.describe_model(ModelName=model_name)
return MlStore(imageRepository=model_info["PrimaryContainer"]["Image"])
except ValidationError as err:
logger.debug(traceback.format_exc())
logger.warning(
f"Validation error adding the MlModel store from model description: {model_name} - {err}"
)
except Exception as err:
logger.debug(traceback.format_exc())
logger.warning(
f"Wild error adding the MlModel store from model description: {model_name} - {err}"
)
return None
def _get_tags(self, model_arn: str) -> Optional[List[TagLabel]]:
try:
tags = self.sagemaker.list_tags(ResourceArn=model_arn)["Tags"]
return [
TagLabel(
tagFQN=tag["Key"],
description=tag["Value"],
source="Tag",
labelType="Propagated",
state="Confirmed",
)
for tag in tags
]
except ValidationError as err:
logger.debug(traceback.format_exc())
logger.warning(
f"Validation error adding TagLabel from model tags: {model_arn} - {err}"
)
except Exception as err:
logger.debug(traceback.format_exc())
logger.warning(
f"Wild error adding TagLabel from model tags: {model_arn} - {err}"
)
return None
def _get_hyper_params(self, *args, **kwargs) -> Optional[List[MlHyperParameter]]:
pass
def _get_ml_features(self, *args, **kwargs) -> Optional[List[MlFeature]]:
pass

View File

@ -48,6 +48,7 @@ from metadata.clients.connection_clients import (
NifiClientWrapper,
PowerBiClient,
RedashClient,
SageMakerClient,
SalesforceClient,
SupersetClient,
TableauClient,
@ -121,6 +122,9 @@ from metadata.generated.schema.entity.services.connections.messaging.redpandaCon
from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import (
MlflowConnection,
)
from metadata.generated.schema.entity.services.connections.mlmodel.sageMakerConnection import (
SageMakerConnection,
)
from metadata.generated.schema.entity.services.connections.pipeline.airbyteConnection import (
AirbyteConnection,
)
@ -285,8 +289,8 @@ def _(
) -> DynamoClient:
from metadata.clients.aws_client import AWSClient
dynomo_connection = AWSClient(connection.awsConfig).get_dynomo_client()
return dynomo_connection
dynamo_connection = AWSClient(connection.awsConfig).get_dynamo_client()
return dynamo_connection
@get_connection.register
@ -496,7 +500,7 @@ def _(connection: DynamoClient) -> None:
def _(connection: GlueDBClient) -> None:
"""
Test that we can connect to the source using the given aws resource
:param engine: boto cliet to test
:param engine: boto client to test
:return: None or raise an exception if we cannot connect
"""
from botocore.client import ClientError
@ -947,6 +951,36 @@ def _(connection: MlflowClientWrapper) -> None:
raise SourceConnectionException(msg) from exc
@get_connection.register
def _(
connection: SageMakerConnection,
verbose: bool = False, # pylint: disable=unused-argument
) -> SageMakerClient:
from metadata.clients.aws_client import AWSClient
sagemaker_connection = AWSClient(connection.awsConfig).get_sagemaker_client()
return sagemaker_connection
@test_connection.register
def _(connection: SageMakerClient) -> None:
"""
Test that we can connect to the SageMaker source using the given aws resource
:param engine: boto service resource to test
:return: None or raise an exception if we cannot connect
"""
from botocore.client import ClientError
try:
connection.client.list_models()
except ClientError as err:
msg = f"Connection error for {connection}: {err}. Check the connection details."
raise SourceConnectionException(msg) from err
except Exception as exc:
msg = f"Unknown error connecting with {connection}: {exc}."
raise SourceConnectionException(msg) from exc
@get_connection.register
def _(
connection: NifiConnection, verbose: bool = False

View File

@ -0,0 +1,105 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SageMaker unit tests
"""
import json
from unittest import TestCase
from unittest.mock import patch
import boto3
from moto import mock_sagemaker
from metadata.ingestion.api.workflow import Workflow
CONFIG = """
{
"source": {
"type": "sagemaker",
"serviceName": "local_sagemaker",
"serviceConnection": {
"config": {
"type": "SageMaker",
"awsConfig": {
"awsAccessKeyId": "aws_access_key_id",
"awsSecretAccessKey": "aws_secret_access_key",
"awsRegion": "us-east-2"
}
}
},
"sourceConfig": {
"config": {
"type": "MlModelMetadata"
}
}
},
"sink": {
"type": "file",
"config": {
"filename": "/var/tmp/datasets.json"
}
},
"workflowConfig": {
"openMetadataServerConfig": {
"hostPort": "http://localhost:8585/api",
"authProvider": "no-auth"
}
}
}
"""
def execute_workflow():
workflow = Workflow.create(json.loads(CONFIG))
workflow.execute()
workflow.print_status()
workflow.stop()
def get_file_path():
return json.loads(CONFIG)["sink"]["config"]["filename"]
def _setup_mock_sagemaker(create_model: bool = False):
sagemaker = boto3.Session().client("sagemaker")
if not create_model:
return
print("Creating model!!!!!!!")
sagemaker.create_model(
ModelName="mock-model",
PrimaryContainer={
"Environment": {},
"Image": "123.dkr.ecr.eu-west-1.amazonaws.com/image:mock-image",
"Mode": "SingleModel",
},
ExecutionRoleArn="arn:aws:iam::123:role/service-role/mockRole",
EnableNetworkIsolation=False,
)
@mock_sagemaker
@patch("sqlalchemy.engine.base.Engine.connect")
class SageMakerIngestionTest(TestCase):
def test_sagemaker_empty_models(self, mock_connect):
_setup_mock_sagemaker()
execute_workflow()
file_path = get_file_path()
with open(file_path, "r") as file:
assert len(json.loads(file.read())) == 0
def test_sagemaker_models(self, mock_connect):
_setup_mock_sagemaker(create_model=True)
execute_workflow()
file_path = get_file_path()
with open(file_path, "r") as file:
data = json.loads(file.read())
assert data[0]["name"] == "mock-model"

View File

@ -0,0 +1,34 @@
{
"$id": "https://open-metadata.org/schema/entity/services/connections/mlmodel/sageMakerConnection.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "SageMakerConnection",
"description": "SageMaker Connection Config",
"type": "object",
"javaType": "org.openmetadata.schema.services.connections.mlmodel.SageMakerConnection",
"definitions": {
"sageMakerType": {
"description": "Service type.",
"type": "string",
"enum": ["SageMaker"],
"default": "SageMaker"
}
},
"properties": {
"type": {
"title": "Service Type",
"description": "Service Type",
"$ref": "#/definitions/sageMakerType",
"default": "SageMaker"
},
"awsConfig": {
"title": "AWS Credentials Configuration",
"$ref": "../../../../security/credentials/awsCredentials.json"
},
"supportsMetadataExtraction": {
"title": "Supports Metadata Extraction",
"$ref": "../connectionBasicType.json#/definitions/supportsMetadataExtraction"
}
},
"additionalProperties": false,
"required": ["awsConfig"]
}

View File

@ -14,7 +14,7 @@
"description": "Type of MlModel service",
"type": "string",
"javaInterfaces": ["org.openmetadata.schema.EnumInterface"],
"enum": ["Mlflow", "Sklearn", "CustomMlModel"],
"enum": ["Mlflow", "Sklearn", "CustomMlModel", "SageMaker"],
"javaEnums": [
{
"name": "Mlflow"
@ -24,6 +24,9 @@
},
{
"name": "CustomMlModel"
},
{
"name": "SageMaker"
}
]
},
@ -46,6 +49,9 @@
},
{
"$ref": "./connections/mlmodel/customMlModelConnection.json"
},
{
"$ref": "./connections/mlmodel/sageMakerConnection.json"
}
]
}