feat(dbt): enable dbt read artifacts from s3 (#4935)

Co-authored-by: Shirshanka Das <shirshanka@apache.org>
This commit is contained in:
BZ 2022-05-25 04:57:02 -04:00 committed by GitHub
parent c131e13582
commit 2eb7920fe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 87 deletions

View File

@ -7,7 +7,7 @@ from botocore.config import Config
from botocore.utils import fix_s3_host from botocore.utils import fix_s3_host
from pydantic.fields import Field from pydantic.fields import Field
from datahub.configuration.common import AllowDenyPattern from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.configuration.source_common import EnvBasedSourceConfigBase from datahub.configuration.source_common import EnvBasedSourceConfigBase
if TYPE_CHECKING: if TYPE_CHECKING:
@ -35,24 +35,16 @@ def assume_role(
return assumed_role_object["Credentials"] return assumed_role_object["Credentials"]
class AwsSourceConfig(EnvBasedSourceConfigBase): class AwsConnectionConfig(ConfigModel):
""" """
Common AWS credentials config. Common AWS credentials config.
Currently used by: Currently used by:
- Glue source - Glue source
- SageMaker source - SageMaker source
- dbt source
""" """
database_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="regex patterns for databases to filter in ingestion.",
)
table_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="regex patterns for tables to filter in ingestion.",
)
aws_access_key_id: Optional[str] = Field( aws_access_key_id: Optional[str] = Field(
default=None, default=None,
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html", description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
@ -157,3 +149,22 @@ class AwsSourceConfig(EnvBasedSourceConfigBase):
def get_sagemaker_client(self) -> "SageMakerClient": def get_sagemaker_client(self) -> "SageMakerClient":
return self.get_session().client("sagemaker") return self.get_session().client("sagemaker")
class AwsSourceConfig(EnvBasedSourceConfigBase, AwsConnectionConfig):
    """
    Common AWS source config: layers env handling (EnvBasedSourceConfigBase)
    and database/table filtering on top of the shared AWS credentials in
    AwsConnectionConfig.

    Currently used by:

    - Glue source
    - SageMaker source
    """

    # Regex allow/deny patterns selecting which databases are ingested;
    # defaults to allowing everything.
    database_pattern: AllowDenyPattern = Field(
        default=AllowDenyPattern.allow_all(),
        description="regex patterns for databases to filter in ingestion.",
    )
    # Regex allow/deny patterns selecting which tables are ingested;
    # defaults to allowing everything.
    table_pattern: AllowDenyPattern = Field(
        default=AllowDenyPattern.allow_all(),
        description="regex patterns for tables to filter in ingestion.",
    )

View File

@ -3,6 +3,7 @@ import logging
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
from urllib.parse import urlparse
import dateutil.parser import dateutil.parser
import requests import requests
@ -23,6 +24,7 @@ from datahub.ingestion.api.decorators import (
) )
from datahub.ingestion.api.ingestion_job_state_provider import JobId from datahub.ingestion.api.ingestion_job_state_provider import JobId
from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
from datahub.ingestion.source.sql.sql_types import ( from datahub.ingestion.source.sql.sql_types import (
BIGQUERY_TYPES_MAP, BIGQUERY_TYPES_MAP,
POSTGRES_TYPES_MAP, POSTGRES_TYPES_MAP,
@ -172,6 +174,15 @@ class DBTConfig(StatefulIngestionConfigBase):
default=None, default=None,
description='Regex string to extract owner from the dbt node using the `(?P<name>...) syntax` of the [match object](https://docs.python.org/3/library/re.html#match-objects), where the group name must be `owner`. Examples: (1)`r"(?P<owner>(.*)): (\w+) (\w+)"` will extract `jdoe` as the owner from `"jdoe: John Doe"` (2) `r"@(?P<owner>(.*))"` will extract `alice` as the owner from `"@alice"`.', # noqa: W605 description='Regex string to extract owner from the dbt node using the `(?P<name>...) syntax` of the [match object](https://docs.python.org/3/library/re.html#match-objects), where the group name must be `owner`. Examples: (1)`r"(?P<owner>(.*)): (\w+) (\w+)"` will extract `jdoe` as the owner from `"jdoe: John Doe"` (2) `r"@(?P<owner>(.*))"` will extract `alice` as the owner from `"@alice"`.', # noqa: W605
) )
# Optional AWS connection details; only needed when one of the artifact
# paths (manifest/catalog/sources) is an s3:// URI — enforced by the
# aws_connection validator on this config class.
aws_connection: Optional[AwsConnectionConfig] = Field(
    default=None,
    description="When fetching manifest files from s3, configuration for aws connection details",
)

@property
def s3_client(self):
    """Return a boto3 S3 client built from ``aws_connection``.

    NOTE(review): guarded only by ``assert`` — callers must access this
    after validation has ensured ``aws_connection`` is populated.
    """
    assert self.aws_connection
    return self.aws_connection.get_s3_client()
# Custom Stateful Ingestion settings # Custom Stateful Ingestion settings
stateful_ingestion: Optional[DBTStatefulIngestionConfig] = Field( stateful_ingestion: Optional[DBTStatefulIngestionConfig] = Field(
@ -197,6 +208,22 @@ class DBTConfig(StatefulIngestionConfigBase):
) )
return write_semantics return write_semantics
@validator("aws_connection")
def aws_connection_needed_if_s3_uris_present(
    cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
) -> Optional[AwsConnectionConfig]:
    """Require an aws_connection block whenever any artifact path is an s3:// URI.

    Raises:
        ValueError: if manifest_path/catalog_path/sources_path contains an
            s3:// URI but no aws_connection configuration was supplied.
    """
    # first check if there are fields that contain s3 uris.
    # NOTE: values.get(f, "") is not enough — sources_path may be present
    # with a None value (its default), and None.startswith would raise.
    # Coalesce to "" before testing the prefix.
    uri_containing_fields = [
        f
        for f in ["manifest_path", "catalog_path", "sources_path"]
        if (values.get(f) or "").startswith("s3://")
    ]
    if uri_containing_fields and not aws_connection:
        raise ValueError(
            f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
        )
    return aws_connection
@dataclass @dataclass
class DBTColumn: class DBTColumn:
@ -387,80 +414,6 @@ def extract_dbt_entities(
return dbt_entities return dbt_entities
def load_file_as_json(uri: str) -> Any:
    """Fetch and parse a JSON document from an http(s) URL or a local file path."""
    if uri.startswith(("http://", "https://")):
        # Remote artifact: download and decode the response body.
        return json.loads(requests.get(uri).text)
    # Anything else is treated as a path on the local filesystem.
    with open(uri, "r") as handle:
        return json.load(handle)
def loadManifestAndCatalog(
    manifest_path: str,
    catalog_path: str,
    sources_path: Optional[str],
    load_schemas: bool,
    use_identifiers: bool,
    tag_prefix: str,
    node_type_pattern: AllowDenyPattern,
    report: DBTSourceReport,
    node_name_pattern: AllowDenyPattern,
) -> Tuple[
    List[DBTNode],
    Optional[str],
    Optional[str],
    Optional[str],
    Optional[str],
    Dict[str, Dict[str, Any]],
]:
    """Load the dbt manifest and catalog artifacts and convert them to DBTNodes.

    Returns a tuple of:
        (nodes, manifest_schema, manifest_version, catalog_schema,
         catalog_version, all_manifest_entities)
    """
    dbt_manifest_json = load_file_as_json(manifest_path)
    dbt_catalog_json = load_file_as_json(catalog_path)

    # The sources artifact is optional; its "results" entries are passed
    # through to extract_dbt_entities (presumably dbt source freshness/test
    # results — confirm against extract_dbt_entities).
    if sources_path is not None:
        dbt_sources_json = load_file_as_json(sources_path)
        sources_results = dbt_sources_json["results"]
    else:
        sources_results = {}

    # Schema/version metadata is read best-effort — may be absent.
    manifest_schema = dbt_manifest_json.get("metadata", {}).get("dbt_schema_version")
    manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")

    catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
    catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")

    # "nodes" and "sources" are kept as separate top-level keys in both
    # artifacts; merge them so entities can be processed uniformly.
    manifest_nodes = dbt_manifest_json["nodes"]
    manifest_sources = dbt_manifest_json["sources"]

    all_manifest_entities = {**manifest_nodes, **manifest_sources}

    catalog_nodes = dbt_catalog_json["nodes"]
    catalog_sources = dbt_catalog_json["sources"]

    all_catalog_entities = {**catalog_nodes, **catalog_sources}

    nodes = extract_dbt_entities(
        all_manifest_entities,
        all_catalog_entities,
        sources_results,
        load_schemas,
        use_identifiers,
        tag_prefix,
        node_type_pattern,
        report,
        node_name_pattern,
    )

    return (
        nodes,
        manifest_schema,
        manifest_version,
        catalog_schema,
        catalog_version,
        all_manifest_entities,
    )
def get_db_fqn(database: Optional[str], schema: str, name: str) -> str: def get_db_fqn(database: Optional[str], schema: str, name: str) -> str:
if database is not None: if database is not None:
fqn = f"{database}.{schema}.{name}" fqn = f"{database}.{schema}.{name}"
@ -769,6 +722,88 @@ class DBTSource(StatefulIngestionSourceBase):
): ):
yield from soft_delete_item(table_urn, "dataset") yield from soft_delete_item(table_urn, "dataset")
def load_file_as_json(self, uri: str) -> Any:
    """Load a JSON artifact from an http(s) URL, an ``s3://`` URI, or a local path.

    For s3 URIs (e.g. ``s3://bucket/dbt-artifacts/catalog.json``) the object
    is fetched with the S3 client configured via ``aws_connection``.
    """
    if uri.startswith(("http://", "https://")):
        return json.loads(requests.get(uri).text)
    if uri.startswith("s3://"):
        parsed = urlparse(uri)
        # netloc is the bucket; path (minus the leading slash) is the key.
        obj = self.config.s3_client.get_object(
            Bucket=parsed.netloc, Key=parsed.path.lstrip("/")
        )
        return json.loads(obj["Body"].read().decode("utf-8"))
    with open(uri, "r") as handle:
        return json.load(handle)
def loadManifestAndCatalog(
    self,
    manifest_path: str,
    catalog_path: str,
    sources_path: Optional[str],
    load_schemas: bool,
    use_identifiers: bool,
    tag_prefix: str,
    node_type_pattern: AllowDenyPattern,
    report: DBTSourceReport,
    node_name_pattern: AllowDenyPattern,
) -> Tuple[
    List[DBTNode],
    Optional[str],
    Optional[str],
    Optional[str],
    Optional[str],
    Dict[str, Dict[str, Any]],
]:
    """Load the dbt manifest and catalog artifacts and convert them to DBTNodes.

    Paths are resolved through ``self.load_file_as_json``, so they may be
    http(s) URLs, s3:// URIs, or local file paths.

    Returns a tuple of:
        (nodes, manifest_schema, manifest_version, catalog_schema,
         catalog_version, all_manifest_entities)
    """
    dbt_manifest_json = self.load_file_as_json(manifest_path)
    dbt_catalog_json = self.load_file_as_json(catalog_path)

    # The sources artifact is optional; its "results" entries are passed
    # through to extract_dbt_entities.
    if sources_path is not None:
        dbt_sources_json = self.load_file_as_json(sources_path)
        sources_results = dbt_sources_json["results"]
    else:
        sources_results = {}

    # Schema/version metadata is read best-effort — may be absent.
    manifest_schema = dbt_manifest_json.get("metadata", {}).get(
        "dbt_schema_version"
    )
    manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")

    catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
    catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")

    # "nodes" and "sources" are separate top-level keys in both artifacts;
    # merge them so entities can be processed uniformly.
    manifest_nodes = dbt_manifest_json["nodes"]
    manifest_sources = dbt_manifest_json["sources"]

    all_manifest_entities = {**manifest_nodes, **manifest_sources}

    catalog_nodes = dbt_catalog_json["nodes"]
    catalog_sources = dbt_catalog_json["sources"]

    all_catalog_entities = {**catalog_nodes, **catalog_sources}

    nodes = extract_dbt_entities(
        all_manifest_entities,
        all_catalog_entities,
        sources_results,
        load_schemas,
        use_identifiers,
        tag_prefix,
        node_type_pattern,
        report,
        node_name_pattern,
    )

    return (
        nodes,
        manifest_schema,
        manifest_version,
        catalog_schema,
        catalog_version,
        all_manifest_entities,
    )
# create workunits from dbt nodes # create workunits from dbt nodes
def get_workunits(self) -> Iterable[MetadataWorkUnit]: def get_workunits(self) -> Iterable[MetadataWorkUnit]:
if self.config.write_semantics == "PATCH" and not self.ctx.graph: if self.config.write_semantics == "PATCH" and not self.ctx.graph:
@ -784,7 +819,7 @@ class DBTSource(StatefulIngestionSourceBase):
catalog_schema, catalog_schema,
catalog_version, catalog_version,
manifest_nodes_raw, manifest_nodes_raw,
) = loadManifestAndCatalog( ) = self.loadManifestAndCatalog(
self.config.manifest_path, self.config.manifest_path,
self.config.catalog_path, self.config.catalog_path,
self.config.sources_path, self.config.sources_path,
@ -1356,7 +1391,7 @@ class DBTSource(StatefulIngestionSourceBase):
""" """
project_id = ( project_id = (
load_file_as_json(self.config.manifest_path) self.load_file_as_json(self.config.manifest_path)
.get("metadata", {}) .get("metadata", {})
.get("project_id") .get("project_id")
) )