mirror of
https://github.com/datahub-project/datahub.git
synced 2025-11-02 11:49:23 +00:00
fix(ingest): glue import type stubs only for testing (#3032)
This commit is contained in:
parent
8c4a1414fc
commit
3d061161d8
@ -1,16 +1,19 @@
|
||||
from functools import reduce
|
||||
from typing import List, Optional, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import boto3
|
||||
from boto3.session import Session
|
||||
from mypy_boto3_glue import GlueClient
|
||||
from mypy_boto3_s3 import S3Client
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
|
||||
from datahub.configuration import ConfigModel
|
||||
from datahub.configuration.common import AllowDenyPattern
|
||||
from datahub.emitter.mce_builder import DEFAULT_ENV
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from mypy_boto3_glue import GlueClient
|
||||
from mypy_boto3_s3 import S3Client
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
|
||||
|
||||
def assume_role(
|
||||
role_arn: str, aws_region: str, credentials: Optional[dict] = None
|
||||
@ -88,13 +91,13 @@ class AwsSourceConfig(ConfigModel):
|
||||
else:
|
||||
return Session(region_name=self.aws_region)
|
||||
|
||||
def get_s3_client(self) -> S3Client:
|
||||
def get_s3_client(self) -> "S3Client":
|
||||
return self.get_session().client("s3")
|
||||
|
||||
def get_glue_client(self) -> GlueClient:
|
||||
def get_glue_client(self) -> "GlueClient":
|
||||
return self.get_session().client("glue")
|
||||
|
||||
def get_sagemaker_client(self) -> SageMakerClient:
|
||||
def get_sagemaker_client(self) -> "SageMakerClient":
|
||||
return self.get_session().client("sagemaker")
|
||||
|
||||
|
||||
|
||||
@ -1,12 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List
|
||||
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
from mypy_boto3_sagemaker.type_defs import (
|
||||
DescribeFeatureGroupResponseTypeDef,
|
||||
FeatureDefinitionTypeDef,
|
||||
FeatureGroupSummaryTypeDef,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Iterable, List
|
||||
|
||||
import datahub.emitter.mce_builder as builder
|
||||
from datahub.ingestion.api.workunit import MetadataWorkUnit
|
||||
@ -27,14 +20,23 @@ from datahub.metadata.schema_classes import (
|
||||
MLPrimaryKeyPropertiesClass,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
from mypy_boto3_sagemaker.type_defs import (
|
||||
DescribeFeatureGroupResponseTypeDef,
|
||||
FeatureDefinitionTypeDef,
|
||||
FeatureGroupSummaryTypeDef,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureGroupProcessor:
|
||||
sagemaker_client: SageMakerClient
|
||||
sagemaker_client: "SageMakerClient"
|
||||
env: str
|
||||
report: SagemakerSourceReport
|
||||
|
||||
def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]:
|
||||
def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]:
|
||||
"""
|
||||
List all feature groups in SageMaker.
|
||||
"""
|
||||
@ -50,7 +52,7 @@ class FeatureGroupProcessor:
|
||||
|
||||
def get_feature_group_details(
|
||||
self, feature_group_name: str
|
||||
) -> DescribeFeatureGroupResponseTypeDef:
|
||||
) -> "DescribeFeatureGroupResponseTypeDef":
|
||||
"""
|
||||
Get details of a feature group (including list of component features).
|
||||
"""
|
||||
@ -74,7 +76,7 @@ class FeatureGroupProcessor:
|
||||
return feature_group
|
||||
|
||||
def get_feature_group_wu(
|
||||
self, feature_group_details: DescribeFeatureGroupResponseTypeDef
|
||||
self, feature_group_details: "DescribeFeatureGroupResponseTypeDef"
|
||||
) -> MetadataWorkUnit:
|
||||
"""
|
||||
Generate an MLFeatureTable workunit for a SageMaker feature group.
|
||||
@ -146,8 +148,8 @@ class FeatureGroupProcessor:
|
||||
|
||||
def get_feature_wu(
|
||||
self,
|
||||
feature_group_details: DescribeFeatureGroupResponseTypeDef,
|
||||
feature: FeatureDefinitionTypeDef,
|
||||
feature_group_details: "DescribeFeatureGroupResponseTypeDef",
|
||||
feature: "FeatureDefinitionTypeDef",
|
||||
) -> MetadataWorkUnit:
|
||||
"""
|
||||
Generate an MLFeature workunit for a SageMaker feature.
|
||||
|
||||
@ -2,6 +2,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
@ -16,8 +17,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
|
||||
from datahub.emitter import mce_builder
|
||||
from datahub.ingestion.api.workunit import MetadataWorkUnit
|
||||
from datahub.ingestion.source.aws.aws_common import make_s3_urn
|
||||
@ -47,6 +46,9 @@ from datahub.metadata.schema_classes import (
|
||||
JobStatusClass,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
|
||||
JobInfo = TypeVar(
|
||||
"JobInfo",
|
||||
AutoMlJobInfo,
|
||||
@ -151,7 +153,7 @@ class JobProcessor:
|
||||
"""
|
||||
|
||||
# boto3 SageMaker client
|
||||
sagemaker_client: SageMakerClient
|
||||
sagemaker_client: "SageMakerClient"
|
||||
env: str
|
||||
report: SagemakerSourceReport
|
||||
# config filter for specific job types to ingest (see metadata-ingestion README)
|
||||
|
||||
@ -1,19 +1,20 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, DefaultDict, Dict, List, Set
|
||||
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
from mypy_boto3_sagemaker.type_defs import (
|
||||
ActionSummaryTypeDef,
|
||||
ArtifactSummaryTypeDef,
|
||||
AssociationSummaryTypeDef,
|
||||
ContextSummaryTypeDef,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
|
||||
|
||||
from datahub.ingestion.source.aws.sagemaker_processors.common import (
|
||||
SagemakerSourceReport,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
from mypy_boto3_sagemaker.type_defs import (
|
||||
ActionSummaryTypeDef,
|
||||
ArtifactSummaryTypeDef,
|
||||
AssociationSummaryTypeDef,
|
||||
ContextSummaryTypeDef,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LineageInfo:
|
||||
@ -42,13 +43,13 @@ class LineageInfo:
|
||||
|
||||
@dataclass
|
||||
class LineageProcessor:
|
||||
sagemaker_client: SageMakerClient
|
||||
sagemaker_client: "SageMakerClient"
|
||||
env: str
|
||||
report: SagemakerSourceReport
|
||||
nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict)
|
||||
lineage_info: LineageInfo = field(default_factory=LineageInfo)
|
||||
|
||||
def get_all_actions(self) -> List[ActionSummaryTypeDef]:
|
||||
def get_all_actions(self) -> List["ActionSummaryTypeDef"]:
|
||||
"""
|
||||
List all actions in SageMaker.
|
||||
"""
|
||||
@ -62,7 +63,7 @@ class LineageProcessor:
|
||||
|
||||
return actions
|
||||
|
||||
def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]:
|
||||
def get_all_artifacts(self) -> List["ArtifactSummaryTypeDef"]:
|
||||
"""
|
||||
List all artifacts in SageMaker.
|
||||
"""
|
||||
@ -76,7 +77,7 @@ class LineageProcessor:
|
||||
|
||||
return artifacts
|
||||
|
||||
def get_all_contexts(self) -> List[ContextSummaryTypeDef]:
|
||||
def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
|
||||
"""
|
||||
List all contexts in SageMaker.
|
||||
"""
|
||||
@ -90,7 +91,7 @@ class LineageProcessor:
|
||||
|
||||
return contexts
|
||||
|
||||
def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
|
||||
def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
|
||||
"""
|
||||
Get all incoming edges for a node in the lineage graph.
|
||||
"""
|
||||
@ -104,7 +105,7 @@ class LineageProcessor:
|
||||
|
||||
return edges
|
||||
|
||||
def get_outgoing_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
|
||||
def get_outgoing_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
|
||||
"""
|
||||
Get all outgoing edges for a node in the lineage graph.
|
||||
"""
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
from mypy_boto3_sagemaker.type_defs import (
|
||||
DescribeEndpointOutputTypeDef,
|
||||
DescribeModelOutputTypeDef,
|
||||
DescribeModelPackageGroupOutputTypeDef,
|
||||
EndpointSummaryTypeDef,
|
||||
ModelPackageGroupSummaryTypeDef,
|
||||
ModelSummaryTypeDef,
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import datahub.emitter.mce_builder as builder
|
||||
@ -43,6 +42,17 @@ from datahub.metadata.schema_classes import (
|
||||
OwnershipTypeClass,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
from mypy_boto3_sagemaker.type_defs import (
|
||||
DescribeEndpointOutputTypeDef,
|
||||
DescribeModelOutputTypeDef,
|
||||
DescribeModelPackageGroupOutputTypeDef,
|
||||
EndpointSummaryTypeDef,
|
||||
ModelPackageGroupSummaryTypeDef,
|
||||
ModelSummaryTypeDef,
|
||||
)
|
||||
|
||||
ENDPOINT_STATUS_MAP: Dict[str, str] = {
|
||||
"OutOfService": DeploymentStatusClass.OUT_OF_SERVICE,
|
||||
"Creating": DeploymentStatusClass.CREATING,
|
||||
@ -58,7 +68,7 @@ ENDPOINT_STATUS_MAP: Dict[str, str] = {
|
||||
|
||||
@dataclass
|
||||
class ModelProcessor:
|
||||
sagemaker_client: SageMakerClient
|
||||
sagemaker_client: "SageMakerClient"
|
||||
env: str
|
||||
report: SagemakerSourceReport
|
||||
lineage: LineageInfo
|
||||
@ -81,7 +91,7 @@ class ModelProcessor:
|
||||
|
||||
group_arn_to_name: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def get_all_models(self) -> List[ModelSummaryTypeDef]:
|
||||
def get_all_models(self) -> List["ModelSummaryTypeDef"]:
|
||||
"""
|
||||
List all models in SageMaker.
|
||||
"""
|
||||
@ -95,7 +105,7 @@ class ModelProcessor:
|
||||
|
||||
return models
|
||||
|
||||
def get_model_details(self, model_name: str) -> DescribeModelOutputTypeDef:
|
||||
def get_model_details(self, model_name: str) -> "DescribeModelOutputTypeDef":
|
||||
"""
|
||||
Get details of a model.
|
||||
"""
|
||||
@ -103,7 +113,7 @@ class ModelProcessor:
|
||||
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model
|
||||
return self.sagemaker_client.describe_model(ModelName=model_name)
|
||||
|
||||
def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]:
|
||||
def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]:
|
||||
"""
|
||||
List all model groups in SageMaker.
|
||||
"""
|
||||
@ -118,7 +128,7 @@ class ModelProcessor:
|
||||
|
||||
def get_group_details(
|
||||
self, group_name: str
|
||||
) -> DescribeModelPackageGroupOutputTypeDef:
|
||||
) -> "DescribeModelPackageGroupOutputTypeDef":
|
||||
"""
|
||||
Get details of a model group.
|
||||
"""
|
||||
@ -128,7 +138,7 @@ class ModelProcessor:
|
||||
ModelPackageGroupName=group_name
|
||||
)
|
||||
|
||||
def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]:
|
||||
def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]:
|
||||
|
||||
endpoints = []
|
||||
|
||||
@ -140,7 +150,9 @@ class ModelProcessor:
|
||||
|
||||
return endpoints
|
||||
|
||||
def get_endpoint_details(self, endpoint_name: str) -> DescribeEndpointOutputTypeDef:
|
||||
def get_endpoint_details(
|
||||
self, endpoint_name: str
|
||||
) -> "DescribeEndpointOutputTypeDef":
|
||||
|
||||
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint
|
||||
return self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
|
||||
@ -162,7 +174,7 @@ class ModelProcessor:
|
||||
return endpoint_status
|
||||
|
||||
def get_endpoint_wu(
|
||||
self, endpoint_details: DescribeEndpointOutputTypeDef
|
||||
self, endpoint_details: "DescribeEndpointOutputTypeDef"
|
||||
) -> MetadataWorkUnit:
|
||||
"""a
|
||||
Get a workunit for an endpoint.
|
||||
@ -206,7 +218,7 @@ class ModelProcessor:
|
||||
|
||||
def get_model_endpoints(
|
||||
self,
|
||||
model_details: DescribeModelOutputTypeDef,
|
||||
model_details: "DescribeModelOutputTypeDef",
|
||||
endpoint_arn_to_name: Dict[str, str],
|
||||
model_image: Optional[str],
|
||||
model_uri: Optional[str],
|
||||
@ -235,7 +247,7 @@ class ModelProcessor:
|
||||
return model_endpoints_sorted
|
||||
|
||||
def get_group_wu(
|
||||
self, group_details: DescribeModelPackageGroupOutputTypeDef
|
||||
self, group_details: "DescribeModelPackageGroupOutputTypeDef"
|
||||
) -> MetadataWorkUnit:
|
||||
"""
|
||||
Get a workunit for a model group.
|
||||
@ -285,7 +297,7 @@ class ModelProcessor:
|
||||
return MetadataWorkUnit(id=group_name, mce=mce)
|
||||
|
||||
def match_model_jobs(
|
||||
self, model_details: DescribeModelOutputTypeDef
|
||||
self, model_details: "DescribeModelOutputTypeDef"
|
||||
) -> Tuple[Set[str], Set[str], List[MLHyperParamClass], List[MLMetricClass]]:
|
||||
|
||||
model_training_jobs: Set[str] = set()
|
||||
@ -380,7 +392,7 @@ class ModelProcessor:
|
||||
|
||||
def get_model_wu(
|
||||
self,
|
||||
model_details: DescribeModelOutputTypeDef,
|
||||
model_details: "DescribeModelOutputTypeDef",
|
||||
endpoint_arn_to_name: Dict[str, str],
|
||||
) -> MetadataWorkUnit:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user