# Source: datahub/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py
#
# Listing metadata: 250 lines, 7.7 KiB, Python

from unittest.mock import patch
from botocore.stub import Stubber
from freezegun import freeze_time
import datahub.ingestion.source.aws.sagemaker_processors.models
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.sink.file import write_metadata_file
from datahub.ingestion.source.aws.sagemaker import (
SagemakerSource,
SagemakerSourceConfig,
)
from datahub.ingestion.source.aws.sagemaker_processors.jobs import (
job_type_to_info,
job_types,
)
from datahub.testing.doctest import assert_doctest
from tests.test_helpers import mce_helpers
from tests.unit.sagemaker.test_sagemaker_source_stubs import (
describe_endpoint_response_1,
describe_endpoint_response_2,
describe_feature_group_response_1,
describe_feature_group_response_2,
describe_feature_group_response_3,
describe_group_response,
describe_model_response_1,
describe_model_response_2,
get_first_model_package_incoming_response,
get_model_group_incoming_response,
get_second_model_package_incoming_response,
job_stubs,
list_actions_response,
list_artifacts_response,
list_contexts_response,
list_endpoints_response,
list_feature_groups_response,
list_first_endpoint_incoming_response,
list_first_endpoint_outgoing_response,
list_groups_response,
list_models_response,
list_second_endpoint_incoming_response,
list_second_endpoint_outgoing_response,
)
# Fixed timestamp passed to ``freeze_time`` on ``test_sagemaker_ingest``.
# NOTE(review): presumably pinned so time-dependent output matches the golden
# file — confirm before changing the value.
FROZEN_TIME = "2020-04-14 07:00:00"
def sagemaker_source() -> SagemakerSource:
    """Build the ``SagemakerSource`` instance used by the tests in this module.

    Returns:
        A source created with run id ``"sagemaker-source-test"`` and a config
        whose ``aws_region`` is ``"us-west-2"``.
    """
    pipeline_context = PipelineContext(run_id="sagemaker-source-test")
    source_config = SagemakerSourceConfig(aws_region="us-west-2")
    return SagemakerSource(ctx=pipeline_context, config=source_config)
@freeze_time(FROZEN_TIME)
def test_sagemaker_ingest(tmp_path, pytestconfig):
    """Run the SageMaker source against a stubbed client and check the golden file.

    Canned responses are queued on a botocore ``Stubber`` wrapped around the
    source's SageMaker client. The metadata of every emitted work unit is
    written to ``tmp_path / "sagemaker_mces.json"`` and compared with
    ``sagemaker_mces_golden.json`` under ``tests/unit/sagemaker``.

    Args:
        tmp_path: Temporary directory fixture that receives the generated file.
        pytestconfig: pytest config fixture; provides ``rootpath`` and is
            forwarded to the golden-file checker.
    """
    source = sagemaker_source()
    output_file = tmp_path / "sagemaker_mces.json"

    # Each entry is (operation name, canned response, expected request params).
    # Order matters: the stubber hands queued responses out first-in, first-out,
    # so the entries must stay in the order the source issues its calls.
    stubs_before_jobs = [
        ("list_actions", list_actions_response, {}),
        ("list_artifacts", list_artifacts_response, {}),
        ("list_contexts", list_contexts_response, {}),
        (
            "list_associations",
            list_first_endpoint_incoming_response,
            {
                "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-first-endpoint"
            },
        ),
        (
            "list_associations",
            list_first_endpoint_outgoing_response,
            {
                "SourceArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-first-endpoint"
            },
        ),
        (
            "list_associations",
            list_second_endpoint_incoming_response,
            {
                "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-second-endpoint"
            },
        ),
        (
            "list_associations",
            list_second_endpoint_outgoing_response,
            {
                "SourceArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-second-endpoint"
            },
        ),
        (
            "list_associations",
            get_model_group_incoming_response,
            {
                "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:context/a-model-package-group-context"
            },
        ),
        (
            "list_associations",
            get_first_model_package_incoming_response,
            {
                "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:artifact/the-first-model-package-artifact"
            },
        ),
        (
            "list_associations",
            get_second_model_package_incoming_response,
            {
                "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:artifact/the-second-model-package-artifact"
            },
        ),
        ("list_feature_groups", list_feature_groups_response, {}),
        (
            "describe_feature_group",
            describe_feature_group_response_1,
            {"FeatureGroupName": "test-2"},
        ),
        (
            "describe_feature_group",
            describe_feature_group_response_2,
            {"FeatureGroupName": "test-1"},
        ),
        (
            "describe_feature_group",
            describe_feature_group_response_3,
            {"FeatureGroupName": "test"},
        ),
    ]

    # Queued only after all job list/describe stubs have been registered.
    stubs_after_jobs = [
        ("list_endpoints", list_endpoints_response, {}),
        (
            "describe_endpoint",
            describe_endpoint_response_1,
            {"EndpointName": "the-first-endpoint"},
        ),
        (
            "describe_endpoint",
            describe_endpoint_response_2,
            {"EndpointName": "the-second-endpoint"},
        ),
        ("list_model_package_groups", list_groups_response, {}),
        (
            "describe_model_package_group",
            describe_group_response,
            {"ModelPackageGroupName": "a-model-package-group"},
        ),
        ("list_models", list_models_response, {}),
        (
            "describe_model",
            describe_model_response_1,
            {"ModelName": "the-first-model"},
        ),
        (
            "describe_model",
            describe_model_response_2,
            {"ModelName": "the-second-model"},
        ),
    ]

    with Stubber(source.sagemaker_client) as stubber:
        for operation, response, expected_params in stubs_before_jobs:
            stubber.add_response(operation, response, expected_params)

        # All "list" stubs for every job type are queued first ...
        for job_type in job_types:
            info = job_type_to_info[job_type]
            stub = job_stubs[job_type.value]
            stubber.add_response(info.list_command, stub["list"], {})

        # ... followed by one "describe" stub per job type.
        for job_type in job_types:
            info = job_type_to_info[job_type]
            stub = job_stubs[job_type.value]
            stubber.add_response(
                info.describe_command,
                stub["describe"],
                {info.describe_name_key: stub["describe_name"]},
            )

        for operation, response, expected_params in stubs_after_jobs:
            stubber.add_response(operation, response, expected_params)

        # Patch the client factory's get_client method to return the stubbed
        # client for jobs.
        stubbed_factory = patch.object(
            source.client_factory,
            "get_client",
            return_value=source.sagemaker_client,
        )
        with stubbed_factory:
            # Run the source and collect the metadata of each work unit.
            emitted_metadata = [
                workunit.metadata for workunit in source.get_workunits()
            ]
            write_metadata_file(output_file, emitted_metadata)

    # Verify the output against the golden file.
    golden_dir = pytestconfig.rootpath / "tests/unit/sagemaker"
    mce_helpers.check_golden_file(
        pytestconfig,
        output_path=output_file,
        golden_path=golden_dir / "sagemaker_mces_golden.json",
    )
def test_doc_test_run() -> None:
    """Pass the SageMaker processors' ``models`` module to ``assert_doctest``."""
    models_module = datahub.ingestion.source.aws.sagemaker_processors.models
    assert_doctest(models_module)