fix(ingest): glue import type stubs only for testing (#3032)

This commit is contained in:
Kevin Hu 2021-08-04 15:43:22 -04:00 committed by GitHub
parent 8c4a1414fc
commit 3d061161d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 81 additions and 61 deletions

View File

@ -1,16 +1,19 @@
from functools import reduce
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
import boto3
from boto3.session import Session
from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client
from mypy_boto3_sagemaker import SageMakerClient
from datahub.configuration import ConfigModel
from datahub.configuration.common import AllowDenyPattern
from datahub.emitter.mce_builder import DEFAULT_ENV
if TYPE_CHECKING:
from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client
from mypy_boto3_sagemaker import SageMakerClient
def assume_role(
role_arn: str, aws_region: str, credentials: Optional[dict] = None
@ -88,13 +91,13 @@ class AwsSourceConfig(ConfigModel):
else:
return Session(region_name=self.aws_region)
def get_s3_client(self) -> S3Client:
def get_s3_client(self) -> "S3Client":
return self.get_session().client("s3")
def get_glue_client(self) -> GlueClient:
def get_glue_client(self) -> "GlueClient":
return self.get_session().client("glue")
def get_sagemaker_client(self) -> SageMakerClient:
def get_sagemaker_client(self) -> "SageMakerClient":
return self.get_session().client("sagemaker")

View File

@ -1,12 +1,5 @@
from dataclasses import dataclass
from typing import Iterable, List
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeFeatureGroupResponseTypeDef,
FeatureDefinitionTypeDef,
FeatureGroupSummaryTypeDef,
)
from typing import TYPE_CHECKING, Iterable, List
import datahub.emitter.mce_builder as builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
@ -27,14 +20,23 @@ from datahub.metadata.schema_classes import (
MLPrimaryKeyPropertiesClass,
)
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeFeatureGroupResponseTypeDef,
FeatureDefinitionTypeDef,
FeatureGroupSummaryTypeDef,
)
@dataclass
class FeatureGroupProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]:
def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]:
"""
List all feature groups in SageMaker.
"""
@ -50,7 +52,7 @@ class FeatureGroupProcessor:
def get_feature_group_details(
self, feature_group_name: str
) -> DescribeFeatureGroupResponseTypeDef:
) -> "DescribeFeatureGroupResponseTypeDef":
"""
Get details of a feature group (including list of component features).
"""
@ -74,7 +76,7 @@ class FeatureGroupProcessor:
return feature_group
def get_feature_group_wu(
self, feature_group_details: DescribeFeatureGroupResponseTypeDef
self, feature_group_details: "DescribeFeatureGroupResponseTypeDef"
) -> MetadataWorkUnit:
"""
Generate an MLFeatureTable workunit for a SageMaker feature group.
@ -146,8 +148,8 @@ class FeatureGroupProcessor:
def get_feature_wu(
self,
feature_group_details: DescribeFeatureGroupResponseTypeDef,
feature: FeatureDefinitionTypeDef,
feature_group_details: "DescribeFeatureGroupResponseTypeDef",
feature: "FeatureDefinitionTypeDef",
) -> MetadataWorkUnit:
"""
Generate an MLFeature workunit for a SageMaker feature.

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
@ -16,8 +17,6 @@ from typing import (
Union,
)
from mypy_boto3_sagemaker import SageMakerClient
from datahub.emitter import mce_builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.aws_common import make_s3_urn
@ -47,6 +46,9 @@ from datahub.metadata.schema_classes import (
JobStatusClass,
)
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
@ -151,7 +153,7 @@ class JobProcessor:
"""
# boto3 SageMaker client
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
# config filter for specific job types to ingest (see metadata-ingestion README)

View File

@ -1,19 +1,20 @@
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, DefaultDict, Dict, List, Set
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
ActionSummaryTypeDef,
ArtifactSummaryTypeDef,
AssociationSummaryTypeDef,
ContextSummaryTypeDef,
)
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
from datahub.ingestion.source.aws.sagemaker_processors.common import (
SagemakerSourceReport,
)
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
ActionSummaryTypeDef,
ArtifactSummaryTypeDef,
AssociationSummaryTypeDef,
ContextSummaryTypeDef,
)
@dataclass
class LineageInfo:
@ -42,13 +43,13 @@ class LineageInfo:
@dataclass
class LineageProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict)
lineage_info: LineageInfo = field(default_factory=LineageInfo)
def get_all_actions(self) -> List[ActionSummaryTypeDef]:
def get_all_actions(self) -> List["ActionSummaryTypeDef"]:
"""
List all actions in SageMaker.
"""
@ -62,7 +63,7 @@ class LineageProcessor:
return actions
def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]:
def get_all_artifacts(self) -> List["ArtifactSummaryTypeDef"]:
"""
List all artifacts in SageMaker.
"""
@ -76,7 +77,7 @@ class LineageProcessor:
return artifacts
def get_all_contexts(self) -> List[ContextSummaryTypeDef]:
def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
"""
List all contexts in SageMaker.
"""
@ -90,7 +91,7 @@ class LineageProcessor:
return contexts
def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
"""
Get all incoming edges for a node in the lineage graph.
"""
@ -104,7 +105,7 @@ class LineageProcessor:
return edges
def get_outgoing_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
def get_outgoing_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
"""
Get all outgoing edges for a node in the lineage graph.
"""

View File

@ -1,16 +1,15 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeEndpointOutputTypeDef,
DescribeModelOutputTypeDef,
DescribeModelPackageGroupOutputTypeDef,
EndpointSummaryTypeDef,
ModelPackageGroupSummaryTypeDef,
ModelSummaryTypeDef,
from typing import (
TYPE_CHECKING,
DefaultDict,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)
import datahub.emitter.mce_builder as builder
@ -43,6 +42,17 @@ from datahub.metadata.schema_classes import (
OwnershipTypeClass,
)
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeEndpointOutputTypeDef,
DescribeModelOutputTypeDef,
DescribeModelPackageGroupOutputTypeDef,
EndpointSummaryTypeDef,
ModelPackageGroupSummaryTypeDef,
ModelSummaryTypeDef,
)
ENDPOINT_STATUS_MAP: Dict[str, str] = {
"OutOfService": DeploymentStatusClass.OUT_OF_SERVICE,
"Creating": DeploymentStatusClass.CREATING,
@ -58,7 +68,7 @@ ENDPOINT_STATUS_MAP: Dict[str, str] = {
@dataclass
class ModelProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
lineage: LineageInfo
@ -81,7 +91,7 @@ class ModelProcessor:
group_arn_to_name: Dict[str, str] = field(default_factory=dict)
def get_all_models(self) -> List[ModelSummaryTypeDef]:
def get_all_models(self) -> List["ModelSummaryTypeDef"]:
"""
List all models in SageMaker.
"""
@ -95,7 +105,7 @@ class ModelProcessor:
return models
def get_model_details(self, model_name: str) -> DescribeModelOutputTypeDef:
def get_model_details(self, model_name: str) -> "DescribeModelOutputTypeDef":
"""
Get details of a model.
"""
@ -103,7 +113,7 @@ class ModelProcessor:
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model
return self.sagemaker_client.describe_model(ModelName=model_name)
def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]:
def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]:
"""
List all model groups in SageMaker.
"""
@ -118,7 +128,7 @@ class ModelProcessor:
def get_group_details(
self, group_name: str
) -> DescribeModelPackageGroupOutputTypeDef:
) -> "DescribeModelPackageGroupOutputTypeDef":
"""
Get details of a model group.
"""
@ -128,7 +138,7 @@ class ModelProcessor:
ModelPackageGroupName=group_name
)
def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]:
def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]:
endpoints = []
@ -140,7 +150,9 @@ class ModelProcessor:
return endpoints
def get_endpoint_details(self, endpoint_name: str) -> DescribeEndpointOutputTypeDef:
def get_endpoint_details(
self, endpoint_name: str
) -> "DescribeEndpointOutputTypeDef":
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint
return self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
@ -162,7 +174,7 @@ class ModelProcessor:
return endpoint_status
def get_endpoint_wu(
self, endpoint_details: DescribeEndpointOutputTypeDef
self, endpoint_details: "DescribeEndpointOutputTypeDef"
) -> MetadataWorkUnit:
"""a
Get a workunit for an endpoint.
@ -206,7 +218,7 @@ class ModelProcessor:
def get_model_endpoints(
self,
model_details: DescribeModelOutputTypeDef,
model_details: "DescribeModelOutputTypeDef",
endpoint_arn_to_name: Dict[str, str],
model_image: Optional[str],
model_uri: Optional[str],
@ -235,7 +247,7 @@ class ModelProcessor:
return model_endpoints_sorted
def get_group_wu(
self, group_details: DescribeModelPackageGroupOutputTypeDef
self, group_details: "DescribeModelPackageGroupOutputTypeDef"
) -> MetadataWorkUnit:
"""
Get a workunit for a model group.
@ -285,7 +297,7 @@ class ModelProcessor:
return MetadataWorkUnit(id=group_name, mce=mce)
def match_model_jobs(
self, model_details: DescribeModelOutputTypeDef
self, model_details: "DescribeModelOutputTypeDef"
) -> Tuple[Set[str], Set[str], List[MLHyperParamClass], List[MLMetricClass]]:
model_training_jobs: Set[str] = set()
@ -380,7 +392,7 @@ class ModelProcessor:
def get_model_wu(
self,
model_details: DescribeModelOutputTypeDef,
model_details: "DescribeModelOutputTypeDef",
endpoint_arn_to_name: Dict[str, str],
) -> MetadataWorkUnit:
"""