mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2026-01-07 13:07:22 +00:00
feat: add SageMaker connector (#8435)
* feat: add sagemaker connector Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com> * fix: fix linting errors and update imports Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com> * test: add unit tests for sagemake source Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com> Signed-off-by: Tushar Mittal <chiragmittal.mittal@gmail.com>
This commit is contained in:
parent
d93b46ef31
commit
6f2c93089c
@ -82,6 +82,7 @@ plugins: Dict[str, Set[str]] = {
|
||||
"elasticsearch": {"elasticsearch==7.13.1", "requests-aws4auth==1.1.2"},
|
||||
"glue": {"boto3~=1.19.12"},
|
||||
"dynamodb": {"boto3~=1.19.12"},
|
||||
"sagemaker": {"boto3~=1.19.12"},
|
||||
"hive": {
|
||||
"pyhive~=0.6.5",
|
||||
"thrift~=0.13.0",
|
||||
|
||||
@ -21,6 +21,7 @@ from metadata.clients.connection_clients import (
|
||||
GlueDBClient,
|
||||
GluePipelineClient,
|
||||
KinesisClient,
|
||||
SageMakerClient,
|
||||
)
|
||||
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
|
||||
from metadata.utils.logger import utils_logger
|
||||
@ -84,7 +85,7 @@ class AWSClient:
|
||||
)
|
||||
return session.resource(service_name=service_name)
|
||||
|
||||
def get_dynomo_client(self) -> DynamoClient:
|
||||
def get_dynamo_client(self) -> DynamoClient:
|
||||
return DynamoClient(self.get_resource("dynamodb"))
|
||||
|
||||
def get_glue_db_client(self) -> GlueDBClient:
|
||||
@ -93,5 +94,8 @@ class AWSClient:
|
||||
def get_glue_pipeline_client(self) -> GluePipelineClient:
    """
    Wrap the "glue" client returned by `get_client` in a GluePipelineClient.
    """
    boto_client = self.get_client("glue")
    return GluePipelineClient(boto_client)
|
||||
|
||||
def get_sagemaker_client(self) -> SageMakerClient:
    """
    Wrap the "sagemaker" client returned by `get_client` in a SageMakerClient.
    """
    boto_client = self.get_client("sagemaker")
    return SageMakerClient(boto_client)
|
||||
|
||||
def get_kinesis_client(self) -> KinesisClient:
    """
    Wrap the "kinesis" client returned by `get_client` in a KinesisClient.
    """
    boto_client = self.get_client("kinesis")
    return KinesisClient(boto_client)
|
||||
|
||||
@ -122,6 +122,12 @@ class MlflowClientWrapper:
|
||||
self.client = client
|
||||
|
||||
|
||||
@dataclass
class SageMakerClient:
    """
    Thin wrapper holding the client object used to talk to SageMaker.

    The explicit ``__init__`` below is kept by ``@dataclass`` (it does not
    overwrite an ``__init__`` defined on the class), so ``client`` is a plain
    instance attribute rather than a dataclass field.
    """

    def __init__(self, client) -> None:
        # Keep the wrapped client reachable as `.client`
        self.client = client
|
||||
|
||||
|
||||
@dataclass
|
||||
class FivetranClient:
|
||||
def __init__(self, client) -> None:
|
||||
|
||||
21
ingestion/src/metadata/examples/workflows/sagemaker.yaml
Normal file
21
ingestion/src/metadata/examples/workflows/sagemaker.yaml
Normal file
@ -0,0 +1,21 @@
|
||||
source:
|
||||
type: sagemaker
|
||||
serviceName: local_sagemaker
|
||||
serviceConnection:
|
||||
config:
|
||||
type: SageMaker
|
||||
awsConfig:
|
||||
awsAccessKeyId: aws_access_key_id
|
||||
awsSecretAccessKey: aws_secret_access_key
|
||||
awsRegion: aws region
|
||||
sourceConfig:
|
||||
config:
|
||||
type: MlModelMetadata
|
||||
sink:
|
||||
type: metadata-rest
|
||||
config: {}
|
||||
workflowConfig:
|
||||
loggerLevel: "DEBUG"
|
||||
openMetadataServerConfig:
|
||||
hostPort: http://localhost:8585/api
|
||||
authProvider: no-auth
|
||||
197
ingestion/src/metadata/ingestion/source/mlmodel/sagemaker.py
Normal file
197
ingestion/src/metadata/ingestion/source/mlmodel/sagemaker.py
Normal file
@ -0,0 +1,197 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SageMaker source module"""
|
||||
|
||||
import traceback
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, ValidationError
|
||||
|
||||
from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest
|
||||
from metadata.generated.schema.entity.data.mlmodel import (
|
||||
MlFeature,
|
||||
MlHyperParameter,
|
||||
MlStore,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
|
||||
OpenMetadataConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.mlmodel.sageMakerConnection import (
|
||||
SageMakerConnection,
|
||||
)
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
Source as WorkflowSource,
|
||||
)
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.generated.schema.type.tagLabel import TagLabel
|
||||
from metadata.ingestion.api.source import InvalidSourceException
|
||||
from metadata.ingestion.source.mlmodel.mlmodel_service import MlModelServiceSource
|
||||
from metadata.utils.filters import filter_by_mlmodel
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
|
||||
class SageMakerModel(BaseModel):
    """
    Model attributes this source keeps for each SageMaker model:
    its name, its ARN and its creation timestamp.
    """

    class Config:
        # Reject any attribute that is not declared as a field below
        extra = Extra.forbid

    name: str = Field(..., description="Model name", title="Model Name")
    arn: str = Field(..., description="Model ARN in AWS account", title="Model ARN")
    creation_timestamp: str = Field(
        ...,
        description="Timestamp of model creation in ISO format",
        title="Creation Timestamp",
    )
|
||||
|
||||
|
||||
class SagemakerSource(MlModelServiceSource):
    """
    Source implementation to ingest SageMaker data.

    We will iterate on the ML Models
    and prepare an iterator of CreateMlModelRequest
    """

    def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection):
        super().__init__(config, metadata_config)
        # NOTE(review): `self.connection` is presumably set by the parent
        # constructor to the SageMakerClient wrapper - confirm. We keep a
        # handle on the client it wraps.
        self.sagemaker = self.connection.client

    @classmethod
    def create(cls, config_dict, metadata_config: OpenMetadataConnection):
        """
        Build the source from a raw workflow source configuration.

        :param config_dict: dictionary parsed into a WorkflowSource
        :param metadata_config: OpenMetadata server connection
        :return: a new instance of this source
        :raises InvalidSourceException: if the configured service connection
            is not a SageMakerConnection
        """
        config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
        connection: SageMakerConnection = config.serviceConnection.__root__.config
        if not isinstance(connection, SageMakerConnection):
            raise InvalidSourceException(
                f"Expected SageMakerConnection, but got {connection}"
            )
        return cls(config, metadata_config)

    def get_mlmodels(  # pylint: disable=arguments-differ
        self,
    ) -> Iterable[SageMakerModel]:
        """
        List the models page by page and yield those not filtered out.

        A failure while listing is logged and the models fetched so far are
        still processed. A model that cannot be converted is logged and
        skipped.

        :return: one SageMakerModel per model that passes the filter pattern
        """
        args, models = {"MaxResults": 100}, []
        try:
            while True:
                response = self.sagemaker.list_models(**args)
                # Each page carries a list of model summaries: flatten it into
                # `models` (appending the page itself would nest the lists and
                # break the per-model lookups below).
                models.extend(response["Models"])
                next_token = response.get("NextToken")
                if not next_token:
                    break
                # Only send NextToken when there is a following page
                args["NextToken"] = next_token
        except Exception as err:
            logger.debug(traceback.format_exc())
            logger.error(f"Failed to fetch models list - {err}")

        for model in models:
            try:
                if filter_by_mlmodel(
                    self.source_config.mlModelFilterPattern,
                    mlmodel_name=model["ModelName"],
                ):
                    self.status.filter(
                        model["ModelName"],
                        "MlModel name pattern not allowed",
                    )
                    continue
                sagemaker_model = SageMakerModel(
                    name=model["ModelName"],
                    arn=model["ModelArn"],
                    creation_timestamp=model["CreationTime"].isoformat(),
                )
            except ValidationError as err:
                logger.debug(traceback.format_exc())
                logger.warning(
                    f"Validation error while creating SageMakerModel from model details - {err}"
                )
                continue
            except Exception as err:
                logger.debug(traceback.format_exc())
                logger.warning(
                    f"Wild error while creating SageMakerModel from model details - {err}"
                )
                continue
            # Yield outside the try so only the conversion is guarded
            yield sagemaker_model

    def _get_algorithm(self) -> str:  # pylint: disable=arguments-differ
        """
        Return the constant algorithm name used for every SageMaker model.
        """
        logger.info(
            "Setting algorithm to default value of `mlmodel` for SageMaker Model"
        )
        return "mlmodel"

    def yield_mlmodel(  # pylint: disable=arguments-differ
        self, model: SageMakerModel
    ) -> Iterable[CreateMlModelRequest]:
        """
        Prepare the Request model

        :param model: model details produced by `get_mlmodels`
        :return: a single CreateMlModelRequest for the given model
        """
        self.status.scanned(model.name)

        yield CreateMlModelRequest(
            name=model.name,
            algorithm=self._get_algorithm(),  # Setting this to a constant
            mlStore=self._get_ml_store(model.name),
            service=EntityReference(
                id=self.context.mlmodel_service.id, type="mlmodelService"
            ),
        )

    def _get_ml_store(  # pylint: disable=arguments-differ
        self,
        model_name: str,
    ) -> Optional[MlStore]:
        """
        Get the Ml Store for the model

        :param model_name: name passed to `describe_model`
        :return: MlStore pointing at the primary container image, or None if
            the description could not be fetched or converted (a warning is
            logged in that case)
        """
        try:
            model_info = self.sagemaker.describe_model(ModelName=model_name)
            return MlStore(imageRepository=model_info["PrimaryContainer"]["Image"])
        except ValidationError as err:
            logger.debug(traceback.format_exc())
            logger.warning(
                f"Validation error adding the MlModel store from model description: {model_name} - {err}"
            )
        except Exception as err:
            logger.debug(traceback.format_exc())
            logger.warning(
                f"Wild error adding the MlModel store from model description: {model_name} - {err}"
            )
        return None

    def _get_tags(self, model_arn: str) -> Optional[List[TagLabel]]:
        """
        Convert the tags of a model into TagLabel objects.

        :param model_arn: ARN passed to `list_tags`
        :return: one TagLabel per tag (key as tagFQN, value as description),
            or None if the tags could not be fetched or converted (a warning
            is logged in that case)
        """
        try:
            tags = self.sagemaker.list_tags(ResourceArn=model_arn)["Tags"]
            return [
                TagLabel(
                    tagFQN=tag["Key"],
                    description=tag["Value"],
                    source="Tag",
                    labelType="Propagated",
                    state="Confirmed",
                )
                for tag in tags
            ]
        except ValidationError as err:
            logger.debug(traceback.format_exc())
            logger.warning(
                f"Validation error adding TagLabel from model tags: {model_arn} - {err}"
            )
        except Exception as err:
            logger.debug(traceback.format_exc())
            logger.warning(
                f"Wild error adding TagLabel from model tags: {model_arn} - {err}"
            )
        return None

    def _get_hyper_params(self, *args, **kwargs) -> Optional[List[MlHyperParameter]]:
        """
        Hyper parameters are not extracted by this source: always None.
        """
        return None

    def _get_ml_features(self, *args, **kwargs) -> Optional[List[MlFeature]]:
        """
        Features are not extracted by this source: always None.
        """
        return None
|
||||
@ -48,6 +48,7 @@ from metadata.clients.connection_clients import (
|
||||
NifiClientWrapper,
|
||||
PowerBiClient,
|
||||
RedashClient,
|
||||
SageMakerClient,
|
||||
SalesforceClient,
|
||||
SupersetClient,
|
||||
TableauClient,
|
||||
@ -121,6 +122,9 @@ from metadata.generated.schema.entity.services.connections.messaging.redpandaCon
|
||||
from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import (
|
||||
MlflowConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.mlmodel.sageMakerConnection import (
|
||||
SageMakerConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.pipeline.airbyteConnection import (
|
||||
AirbyteConnection,
|
||||
)
|
||||
@ -285,8 +289,8 @@ def _(
|
||||
) -> DynamoClient:
|
||||
from metadata.clients.aws_client import AWSClient
|
||||
|
||||
dynomo_connection = AWSClient(connection.awsConfig).get_dynomo_client()
|
||||
return dynomo_connection
|
||||
dynamo_connection = AWSClient(connection.awsConfig).get_dynamo_client()
|
||||
return dynamo_connection
|
||||
|
||||
|
||||
@get_connection.register
|
||||
@ -496,7 +500,7 @@ def _(connection: DynamoClient) -> None:
|
||||
def _(connection: GlueDBClient) -> None:
|
||||
"""
|
||||
Test that we can connect to the source using the given aws resource
|
||||
:param engine: boto cliet to test
|
||||
:param engine: boto client to test
|
||||
:return: None or raise an exception if we cannot connect
|
||||
"""
|
||||
from botocore.client import ClientError
|
||||
@ -947,6 +951,36 @@ def _(connection: MlflowClientWrapper) -> None:
|
||||
raise SourceConnectionException(msg) from exc
|
||||
|
||||
|
||||
@get_connection.register
def _(
    connection: SageMakerConnection,
    verbose: bool = False,  # pylint: disable=unused-argument
) -> SageMakerClient:
    """
    Build the SageMaker client wrapper from the connection's AWS config.

    :param connection: SageMaker connection holding `awsConfig`
    :param verbose: unused
    :return: the client produced by `AWSClient.get_sagemaker_client`
    """
    from metadata.clients.aws_client import AWSClient

    aws_client = AWSClient(connection.awsConfig)
    return aws_client.get_sagemaker_client()
|
||||
|
||||
|
||||
@test_connection.register
def _(connection: SageMakerClient) -> None:
    """
    Test that we can connect to the SageMaker source by listing its models
    :param connection: SageMaker client wrapper to test
    :return: None or raise an exception if we cannot connect
    """
    from botocore.client import ClientError

    try:
        connection.client.list_models()
    except ClientError as client_err:
        raise SourceConnectionException(
            f"Connection error for {connection}: {client_err}. Check the connection details."
        ) from client_err
    except Exception as unknown_err:
        raise SourceConnectionException(
            f"Unknown error connecting with {connection}: {unknown_err}."
        ) from unknown_err
|
||||
|
||||
|
||||
@get_connection.register
|
||||
def _(
|
||||
connection: NifiConnection, verbose: bool = False
|
||||
|
||||
105
ingestion/tests/unit/source/test_sagemaker.py
Normal file
105
ingestion/tests/unit/source/test_sagemaker.py
Normal file
@ -0,0 +1,105 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SageMaker unit tests
|
||||
"""
|
||||
import json
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
from moto import mock_sagemaker
|
||||
|
||||
from metadata.ingestion.api.workflow import Workflow
|
||||
|
||||
CONFIG = """
|
||||
{
|
||||
"source": {
|
||||
"type": "sagemaker",
|
||||
"serviceName": "local_sagemaker",
|
||||
"serviceConnection": {
|
||||
"config": {
|
||||
"type": "SageMaker",
|
||||
"awsConfig": {
|
||||
"awsAccessKeyId": "aws_access_key_id",
|
||||
"awsSecretAccessKey": "aws_secret_access_key",
|
||||
"awsRegion": "us-east-2"
|
||||
}
|
||||
}
|
||||
},
|
||||
"sourceConfig": {
|
||||
"config": {
|
||||
"type": "MlModelMetadata"
|
||||
}
|
||||
}
|
||||
},
|
||||
"sink": {
|
||||
"type": "file",
|
||||
"config": {
|
||||
"filename": "/var/tmp/datasets.json"
|
||||
}
|
||||
},
|
||||
"workflowConfig": {
|
||||
"openMetadataServerConfig": {
|
||||
"hostPort": "http://localhost:8585/api",
|
||||
"authProvider": "no-auth"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def execute_workflow():
    """
    Run the workflow described by CONFIG: create it, execute it,
    print its status and stop it.
    """
    workflow_config = json.loads(CONFIG)
    workflow = Workflow.create(workflow_config)
    workflow.execute()
    workflow.print_status()
    workflow.stop()
|
||||
|
||||
|
||||
def get_file_path():
    """
    Return the file name configured for the sink in CONFIG.
    """
    sink_config = json.loads(CONFIG)["sink"]["config"]
    return sink_config["filename"]
|
||||
|
||||
|
||||
def _setup_mock_sagemaker(create_model: bool = False):
    """
    Optionally create the "mock-model" model through a boto3 SageMaker client.

    :param create_model: when False nothing is created and no client is built
    """
    if not create_model:
        return

    # Create the model in the region the workflow reads from. Without an
    # explicit region boto3 relies on the ambient AWS configuration, which may
    # be missing or point to a different region than the one in CONFIG.
    aws_config = json.loads(CONFIG)["source"]["serviceConnection"]["config"][
        "awsConfig"
    ]
    sagemaker = boto3.Session().client(
        "sagemaker", region_name=aws_config["awsRegion"]
    )
    sagemaker.create_model(
        ModelName="mock-model",
        PrimaryContainer={
            "Environment": {},
            "Image": "123.dkr.ecr.eu-west-1.amazonaws.com/image:mock-image",
            "Mode": "SingleModel",
        },
        ExecutionRoleArn="arn:aws:iam::123:role/service-role/mockRole",
        EnableNetworkIsolation=False,
    )
|
||||
|
||||
|
||||
@mock_sagemaker
@patch("sqlalchemy.engine.base.Engine.connect")
class SageMakerIngestionTest(TestCase):
    """
    Run the SageMaker workflow against the mocked service and check
    what ends up in the sink file.
    """

    def test_sagemaker_empty_models(self, mock_connect):
        """With no model created, the sink file holds an empty collection."""
        _setup_mock_sagemaker()
        execute_workflow()
        with open(get_file_path(), "r") as sink_file:
            ingested = json.load(sink_file)
        assert len(ingested) == 0

    def test_sagemaker_models(self, mock_connect):
        """With the mock model created, it is the first entry in the sink file."""
        _setup_mock_sagemaker(create_model=True)
        execute_workflow()
        with open(get_file_path(), "r") as sink_file:
            ingested = json.load(sink_file)
        assert ingested[0]["name"] == "mock-model"
|
||||
@ -0,0 +1,34 @@
|
||||
{
|
||||
"$id": "https://open-metadata.org/schema/entity/services/connections/mlmodel/sageMakerConnection.json",
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "SageMakerConnection",
|
||||
"description": "SageMaker Connection Config",
|
||||
"type": "object",
|
||||
"javaType": "org.openmetadata.schema.services.connections.mlmodel.SageMakerConnection",
|
||||
"definitions": {
|
||||
"sageMakerType": {
|
||||
"description": "Service type.",
|
||||
"type": "string",
|
||||
"enum": ["SageMaker"],
|
||||
"default": "SageMaker"
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"type": {
|
||||
"title": "Service Type",
|
||||
"description": "Service Type",
|
||||
"$ref": "#/definitions/sageMakerType",
|
||||
"default": "SageMaker"
|
||||
},
|
||||
"awsConfig": {
|
||||
"title": "AWS Credentials Configuration",
|
||||
"$ref": "../../../../security/credentials/awsCredentials.json"
|
||||
},
|
||||
"supportsMetadataExtraction": {
|
||||
"title": "Supports Metadata Extraction",
|
||||
"$ref": "../connectionBasicType.json#/definitions/supportsMetadataExtraction"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["awsConfig"]
|
||||
}
|
||||
@ -14,7 +14,7 @@
|
||||
"description": "Type of MlModel service",
|
||||
"type": "string",
|
||||
"javaInterfaces": ["org.openmetadata.schema.EnumInterface"],
|
||||
"enum": ["Mlflow", "Sklearn", "CustomMlModel"],
|
||||
"enum": ["Mlflow", "Sklearn", "CustomMlModel", "SageMaker"],
|
||||
"javaEnums": [
|
||||
{
|
||||
"name": "Mlflow"
|
||||
@ -24,6 +24,9 @@
|
||||
},
|
||||
{
|
||||
"name": "CustomMlModel"
|
||||
},
|
||||
{
|
||||
"name": "SageMaker"
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -46,6 +49,9 @@
|
||||
},
|
||||
{
|
||||
"$ref": "./connections/mlmodel/customMlModelConnection.json"
|
||||
},
|
||||
{
|
||||
"$ref": "./connections/mlmodel/sageMakerConnection.json"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user