From 2eb7920fe007fb4dd14bac84da54641c369728d3 Mon Sep 17 00:00:00 2001
From: BZ <93607724+BoyuanZhangDE@users.noreply.github.com>
Date: Wed, 25 May 2022 04:57:02 -0400
Subject: [PATCH] feat(dbt): enable dbt read artifacts from s3 (#4935)

Co-authored-by: Shirshanka Das
---
 .../ingestion/source/aws/aws_common.py        |  33 ++--
 .../src/datahub/ingestion/source/dbt.py       | 187 +++++++++++-------
 2 files changed, 133 insertions(+), 87 deletions(-)

diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
index 06cabe0be2..d4c123ede1 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
@@ -7,7 +7,7 @@ from botocore.config import Config
 from botocore.utils import fix_s3_host
 from pydantic.fields import Field
 
-from datahub.configuration.common import AllowDenyPattern
+from datahub.configuration.common import AllowDenyPattern, ConfigModel
 from datahub.configuration.source_common import EnvBasedSourceConfigBase
 
 if TYPE_CHECKING:
@@ -35,24 +35,16 @@ def assume_role(
     return assumed_role_object["Credentials"]
 
 
-class AwsSourceConfig(EnvBasedSourceConfigBase):
+class AwsConnectionConfig(ConfigModel):
     """
     Common AWS credentials config.
 
     Currently used by:
         - Glue source
         - SageMaker source
+        - dbt source
     """
 
-    database_pattern: AllowDenyPattern = Field(
-        default=AllowDenyPattern.allow_all(),
-        description="regex patterns for databases to filter in ingestion.",
-    )
-    table_pattern: AllowDenyPattern = Field(
-        default=AllowDenyPattern.allow_all(),
-        description="regex patterns for tables to filter in ingestion.",
-    )
-
     aws_access_key_id: Optional[str] = Field(
         default=None,
         description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
@@ -157,3 +149,22 @@ class AwsSourceConfig(EnvBasedSourceConfigBase):
 
     def get_sagemaker_client(self) -> "SageMakerClient":
         return self.get_session().client("sagemaker")
+
+
+class AwsSourceConfig(EnvBasedSourceConfigBase, AwsConnectionConfig):
+    """
+    Common AWS credentials config.
+
+    Currently used by:
+        - Glue source
+        - SageMaker source
+    """
+
+    database_pattern: AllowDenyPattern = Field(
+        default=AllowDenyPattern.allow_all(),
+        description="regex patterns for databases to filter in ingestion.",
+    )
+    table_pattern: AllowDenyPattern = Field(
+        default=AllowDenyPattern.allow_all(),
+        description="regex patterns for tables to filter in ingestion.",
+    )
diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt.py b/metadata-ingestion/src/datahub/ingestion/source/dbt.py
index 4ec561d537..72a89b02ab 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/dbt.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/dbt.py
@@ -3,6 +3,7 @@ import logging
 import re
 from dataclasses import dataclass, field
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
+from urllib.parse import urlparse
 
 import dateutil.parser
 import requests
@@ -23,6 +24,7 @@ from datahub.ingestion.api.decorators import (
 )
 from datahub.ingestion.api.ingestion_job_state_provider import JobId
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
 from datahub.ingestion.source.sql.sql_types import (
     BIGQUERY_TYPES_MAP,
     POSTGRES_TYPES_MAP,
@@ -172,6 +174,15 @@ class DBTConfig(StatefulIngestionConfigBase):
         default=None,
         description='Regex string to extract owner from the dbt node using the `(?P<name>...)` syntax of the [match object](https://docs.python.org/3/library/re.html#match-objects), where the group name must be `owner`. Examples: (1)`r"(?P<owner>(.*)): (\w+) (\w+)"` will extract `jdoe` as the owner from `"jdoe: John Doe"` (2) `r"@(?P<owner>(.*))"` will extract `alice` as the owner from `"@alice"`.',  # noqa: W605
     )
+    aws_connection: Optional[AwsConnectionConfig] = Field(
+        default=None,
+        description="When fetching manifest files from s3, configuration for aws connection details",
+    )
+
+    @property
+    def s3_client(self):
+        assert self.aws_connection
+        return self.aws_connection.get_s3_client()
 
     # Custom Stateful Ingestion settings
     stateful_ingestion: Optional[DBTStatefulIngestionConfig] = Field(
@@ -197,6 +208,22 @@ class DBTConfig(StatefulIngestionConfigBase):
             )
         return write_semantics
 
+    @validator("aws_connection")
+    def aws_connection_needed_if_s3_uris_present(
+        cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
+    ) -> Optional[AwsConnectionConfig]:
+        # first check if there are fields that contain s3 uris
+        uri_containing_fields = [
+            f
+            for f in ["manifest_path", "catalog_path", "sources_path"]
+            if (values.get(f) or "").startswith("s3://")
+        ]
+        if uri_containing_fields and not aws_connection:
+            raise ValueError(
+                f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
+            )
+        return aws_connection
+
 
 @dataclass
 class DBTColumn:
@@ -387,80 +414,6 @@
     return dbt_entities
 
 
-def load_file_as_json(uri: str) -> Any:
-    if re.match("^https?://", uri):
-        return json.loads(requests.get(uri).text)
-    else:
-        with open(uri, "r") as f:
-            return json.load(f)
-
-
-def loadManifestAndCatalog(
-    manifest_path: str,
-    catalog_path: str,
-    sources_path: Optional[str],
-    load_schemas: bool,
-    use_identifiers: bool,
-    tag_prefix: str,
-    node_type_pattern: AllowDenyPattern,
-    report: DBTSourceReport,
-    node_name_pattern: AllowDenyPattern,
-) -> Tuple[
-    List[DBTNode],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Dict[str, Dict[str, Any]],
-]:
-    dbt_manifest_json = load_file_as_json(manifest_path)
-
-    dbt_catalog_json = load_file_as_json(catalog_path)
-
-    if sources_path is not None:
-        dbt_sources_json = load_file_as_json(sources_path)
-        sources_results = dbt_sources_json["results"]
-    else:
-        sources_results = {}
-
-    manifest_schema = dbt_manifest_json.get("metadata", {}).get("dbt_schema_version")
-    manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")
-
-    catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
-    catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")
-
-    manifest_nodes = dbt_manifest_json["nodes"]
-    manifest_sources = dbt_manifest_json["sources"]
-
-    all_manifest_entities = {**manifest_nodes, **manifest_sources}
-
-    catalog_nodes = dbt_catalog_json["nodes"]
-    catalog_sources = dbt_catalog_json["sources"]
-
-    all_catalog_entities = {**catalog_nodes, **catalog_sources}
-
-    nodes = extract_dbt_entities(
-        all_manifest_entities,
-        all_catalog_entities,
-        sources_results,
-        load_schemas,
-        use_identifiers,
-        tag_prefix,
-        node_type_pattern,
-        report,
-        node_name_pattern,
-    )
-
-    return (
-        nodes,
-        manifest_schema,
-        manifest_version,
-        catalog_schema,
-        catalog_version,
-        all_manifest_entities,
-    )
-
-
 def get_db_fqn(database: Optional[str], schema: str, name: str) -> str:
     if database is not None:
         fqn = f"{database}.{schema}.{name}"
     else:
@@ -769,6 +722,88 @@ class DBTSource(StatefulIngestionSourceBase):
             ):
                 yield from soft_delete_item(table_urn, "dataset")
 
+    # Load a JSON artifact from a local path, an http(s) URL, or an s3:// URI.
+    def load_file_as_json(self, uri: str) -> Any:
+        if re.match("^https?://", uri):
+            return json.loads(requests.get(uri).text)
+        elif re.match("^s3://", uri):
+            u = urlparse(uri)
+            response = self.config.s3_client.get_object(
+                Bucket=u.netloc, Key=u.path.lstrip("/")
+            )
+            return json.loads(response["Body"].read().decode("utf-8"))
+        else:
+            with open(uri, "r") as f:
+                return json.load(f)
+
+    def loadManifestAndCatalog(
+        self,
+        manifest_path: str,
+        catalog_path: str,
+        sources_path: Optional[str],
+        load_schemas: bool,
+        use_identifiers: bool,
+        tag_prefix: str,
+        node_type_pattern: AllowDenyPattern,
+        report: DBTSourceReport,
+        node_name_pattern: AllowDenyPattern,
+    ) -> Tuple[
+        List[DBTNode],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Dict[str, Dict[str, Any]],
+    ]:
+        dbt_manifest_json = self.load_file_as_json(manifest_path)
+
+        dbt_catalog_json = self.load_file_as_json(catalog_path)
+
+        if sources_path is not None:
+            dbt_sources_json = self.load_file_as_json(sources_path)
+            sources_results = dbt_sources_json["results"]
+        else:
+            sources_results = {}
+
+        manifest_schema = dbt_manifest_json.get("metadata", {}).get(
+            "dbt_schema_version"
+        )
+        manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")
+
+        catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
+        catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")
+
+        manifest_nodes = dbt_manifest_json["nodes"]
+        manifest_sources = dbt_manifest_json["sources"]
+
+        all_manifest_entities = {**manifest_nodes, **manifest_sources}
+
+        catalog_nodes = dbt_catalog_json["nodes"]
+        catalog_sources = dbt_catalog_json["sources"]
+
+        all_catalog_entities = {**catalog_nodes, **catalog_sources}
+
+        nodes = extract_dbt_entities(
+            all_manifest_entities,
+            all_catalog_entities,
+            sources_results,
+            load_schemas,
+            use_identifiers,
+            tag_prefix,
+            node_type_pattern,
+            report,
+            node_name_pattern,
+        )
+
+        return (
+            nodes,
+            manifest_schema,
+            manifest_version,
+            catalog_schema,
+            catalog_version,
+            all_manifest_entities,
+        )
+
     # create workunits from dbt nodes
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         if self.config.write_semantics == "PATCH" and not self.ctx.graph:
@@ -784,7 +819,7 @@ class DBTSource(StatefulIngestionSourceBase):
             catalog_schema,
             catalog_version,
             manifest_nodes_raw,
-        ) = loadManifestAndCatalog(
+        ) = self.loadManifestAndCatalog(
             self.config.manifest_path,
             self.config.catalog_path,
             self.config.sources_path,
@@ -1356,7 +1391,7 @@ class DBTSource(StatefulIngestionSourceBase):
         """
 
         project_id = (
-            load_file_as_json(self.config.manifest_path)
+            self.load_file_as_json(self.config.manifest_path)
             .get("metadata", {})
             .get("project_id")
         )
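
The core mechanic of this patch is a URI-scheme dispatch: http(s) URIs are fetched with requests, s3:// URIs through a boto3 S3 client, and anything else is treated as a local file. Below is a minimal standalone sketch of the same technique; it uses a default boto3 client rather than the patch's AwsConnectionConfig, and the example bucket and key are hypothetical.

import json
import re
from typing import Any
from urllib.parse import urlparse

import boto3
import requests


def load_json_artifact(uri: str) -> Any:
    # http(s): fetch over the network.
    if re.match("^https?://", uri):
        return json.loads(requests.get(uri).text)
    # s3://bucket/key: netloc is the bucket, path (minus the leading "/") is the key.
    elif re.match("^s3://", uri):
        u = urlparse(uri)
        s3 = boto3.client("s3")  # the patch instead uses config.s3_client
        response = s3.get_object(Bucket=u.netloc, Key=u.path.lstrip("/"))
        return json.loads(response["Body"].read().decode("utf-8"))
    # Anything else: treat as a local filesystem path.
    else:
        with open(uri, "r") as f:
            return json.load(f)


# e.g. load_json_artifact("s3://my-bucket/dbt-artifacts/manifest.json")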
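The new s3_client property delegates to AwsConnectionConfig.get_s3_client(), which this diff does not show. Judging from the get_sagemaker_client implementation visible above, it presumably amounts to something like the sketch below; treat this as an inference, not the actual implementation.

from typing import Optional

import boto3


def get_s3_client(aws_region: Optional[str] = None, aws_profile: Optional[str] = None):
    # Mirrors the get_sagemaker_client pattern: build a session, then a client.
    session = boto3.Session(profile_name=aws_profile, region_name=aws_region)
    return session.client("s3")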
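Finally, the aws_connection_needed_if_s3_uris_present validator fails fast when any of the three artifact paths is an s3 URI but no aws_connection block was supplied. A hedged usage sketch follows; the URIs are hypothetical, and target_platform plus the aws_region key are assumptions about fields not shown in this diff.

from pydantic import ValidationError

from datahub.ingestion.source.dbt import DBTConfig

# Rejected: s3 uris present, but no aws_connection configured.
try:
    DBTConfig.parse_obj(
        {
            "manifest_path": "s3://my-bucket/dbt/manifest.json",
            "catalog_path": "s3://my-bucket/dbt/catalog.json",
            "target_platform": "postgres",  # assumed required field
        }
    )
except ValidationError as e:
    print(e)  # the error message lists the fields that carried s3 uris

# Accepted: aws_connection provided; omitted credentials fall back to the
# standard boto3 resolution chain.
config = DBTConfig.parse_obj(
    {
        "manifest_path": "s3://my-bucket/dbt/manifest.json",
        "catalog_path": "s3://my-bucket/dbt/catalog.json",
        "target_platform": "postgres",
        "aws_connection": {"aws_region": "us-east-1"},  # aws_region assumed accepted
    }
)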