feat(dbt): enable dbt read artifacts from s3 (#4935)

Co-authored-by: Shirshanka Das <shirshanka@apache.org>
Authored by BZ on 2022-05-25 04:57:02 -04:00; committed by GitHub
parent c131e13582
commit 2eb7920fe0
2 changed files with 133 additions and 87 deletions

metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py

@@ -7,7 +7,7 @@ from botocore.config import Config
 from botocore.utils import fix_s3_host
 from pydantic.fields import Field
 
-from datahub.configuration.common import AllowDenyPattern
+from datahub.configuration.common import AllowDenyPattern, ConfigModel
 from datahub.configuration.source_common import EnvBasedSourceConfigBase
 
 if TYPE_CHECKING:
@@ -35,24 +35,16 @@ def assume_role(
     return assumed_role_object["Credentials"]
 
 
-class AwsSourceConfig(EnvBasedSourceConfigBase):
+class AwsConnectionConfig(ConfigModel):
     """
     Common AWS credentials config.
 
     Currently used by:
     - Glue source
     - SageMaker source
+    - dbt source
     """
 
-    database_pattern: AllowDenyPattern = Field(
-        default=AllowDenyPattern.allow_all(),
-        description="regex patterns for databases to filter in ingestion.",
-    )
-    table_pattern: AllowDenyPattern = Field(
-        default=AllowDenyPattern.allow_all(),
-        description="regex patterns for tables to filter in ingestion.",
-    )
     aws_access_key_id: Optional[str] = Field(
         default=None,
         description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
@@ -157,3 +149,22 @@ class AwsSourceConfig(EnvBasedSourceConfigBase):
 
     def get_sagemaker_client(self) -> "SageMakerClient":
         return self.get_session().client("sagemaker")
+
+
+class AwsSourceConfig(EnvBasedSourceConfigBase, AwsConnectionConfig):
+    """
+    Common AWS credentials config.
+
+    Currently used by:
+    - Glue source
+    - SageMaker source
+    """
+
+    database_pattern: AllowDenyPattern = Field(
+        default=AllowDenyPattern.allow_all(),
+        description="regex patterns for databases to filter in ingestion.",
+    )
+    table_pattern: AllowDenyPattern = Field(
+        default=AllowDenyPattern.allow_all(),
+        description="regex patterns for tables to filter in ingestion.",
+    )
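For orientation (an illustration, not part of this commit): after the split, AwsConnectionConfig carries only connection-level settings, while the Glue/SageMaker filter patterns move to the new AwsSourceConfig subclass. A minimal sketch of constructing the connection config on its own; the aws_region field is an assumption, in line with the boto3-style credential fields above:

from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig

# Connection-only settings: no database_pattern/table_pattern here anymore.
conn = AwsConnectionConfig(
    aws_access_key_id="AKIA...",   # optional; boto3 can also autodetect credentials
    aws_secret_access_key="...",   # hypothetical placeholder
    aws_region="us-east-1",        # assumed field name, per boto3 conventions
)
s3 = conn.get_s3_client()  # the accessor DBTConfig.s3_client uses (see the dbt diff below)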

metadata-ingestion/src/datahub/ingestion/source/dbt.py

@@ -3,6 +3,7 @@ import logging
 import re
 from dataclasses import dataclass, field
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
+from urllib.parse import urlparse
 
 import dateutil.parser
 import requests
@@ -23,6 +24,7 @@ from datahub.ingestion.api.decorators import (
 )
 from datahub.ingestion.api.ingestion_job_state_provider import JobId
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
 from datahub.ingestion.source.sql.sql_types import (
     BIGQUERY_TYPES_MAP,
     POSTGRES_TYPES_MAP,
@@ -172,6 +174,15 @@ class DBTConfig(StatefulIngestionConfigBase):
         default=None,
         description='Regex string to extract owner from the dbt node using the `(?P<name>...) syntax` of the [match object](https://docs.python.org/3/library/re.html#match-objects), where the group name must be `owner`. Examples: (1)`r"(?P<owner>(.*)): (\w+) (\w+)"` will extract `jdoe` as the owner from `"jdoe: John Doe"` (2) `r"@(?P<owner>(.*))"` will extract `alice` as the owner from `"@alice"`.',  # noqa: W605
     )
+    aws_connection: Optional[AwsConnectionConfig] = Field(
+        default=None,
+        description="When fetching manifest files from s3, configuration for aws connection details",
+    )
+
+    @property
+    def s3_client(self):
+        assert self.aws_connection
+        return self.aws_connection.get_s3_client()
 
     # Custom Stateful Ingestion settings
     stateful_ingestion: Optional[DBTStatefulIngestionConfig] = Field(
@@ -197,6 +208,22 @@ class DBTConfig(StatefulIngestionConfigBase):
             )
         return write_semantics
 
+    @validator("aws_connection")
+    def aws_connection_needed_if_s3_uris_present(
+        cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
+    ) -> Optional[AwsConnectionConfig]:
+        # first check if there are fields that contain s3 uris
+        uri_containing_fields = [
+            f
+            for f in ["manifest_path", "catalog_path", "sources_path"]
+            if values.get(f, "").startswith("s3://")
+        ]
+
+        if uri_containing_fields and not aws_connection:
+            raise ValueError(
+                f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
+            )
+        return aws_connection
+
 
 @dataclass
 class DBTColumn:
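This validator runs while the recipe is being parsed: if any artifact path is an s3:// URI and no aws_connection block accompanies it, configuration fails fast instead of erroring midway through ingestion. A hypothetical illustration (not from the commit; the paths and platform are made up, and aws_connection is passed explicitly as None so the validator fires even under pydantic's default skip-missing-fields behavior):

import pydantic

from datahub.ingestion.source.dbt import DBTConfig  # module path per the file above

try:
    DBTConfig.parse_obj(
        {
            "manifest_path": "s3://my-bucket/dbt/manifest.json",
            "catalog_path": "s3://my-bucket/dbt/catalog.json",
            "sources_path": "s3://my-bucket/dbt/sources.json",
            "target_platform": "postgres",
            "aws_connection": None,  # s3 uris present but no connection -> ValueError
        }
    )
except pydantic.ValidationError as e:
    print(e)  # names the offending fields: ['manifest_path', 'catalog_path', 'sources_path']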
@@ -387,80 +414,6 @@ def extract_dbt_entities(
 
     return dbt_entities
 
 
-def load_file_as_json(uri: str) -> Any:
-    if re.match("^https?://", uri):
-        return json.loads(requests.get(uri).text)
-    else:
-        with open(uri, "r") as f:
-            return json.load(f)
-
-
-def loadManifestAndCatalog(
-    manifest_path: str,
-    catalog_path: str,
-    sources_path: Optional[str],
-    load_schemas: bool,
-    use_identifiers: bool,
-    tag_prefix: str,
-    node_type_pattern: AllowDenyPattern,
-    report: DBTSourceReport,
-    node_name_pattern: AllowDenyPattern,
-) -> Tuple[
-    List[DBTNode],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Dict[str, Dict[str, Any]],
-]:
-    dbt_manifest_json = load_file_as_json(manifest_path)
-
-    dbt_catalog_json = load_file_as_json(catalog_path)
-
-    if sources_path is not None:
-        dbt_sources_json = load_file_as_json(sources_path)
-        sources_results = dbt_sources_json["results"]
-    else:
-        sources_results = {}
-
-    manifest_schema = dbt_manifest_json.get("metadata", {}).get("dbt_schema_version")
-    manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")
-
-    catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
-    catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")
-
-    manifest_nodes = dbt_manifest_json["nodes"]
-    manifest_sources = dbt_manifest_json["sources"]
-
-    all_manifest_entities = {**manifest_nodes, **manifest_sources}
-
-    catalog_nodes = dbt_catalog_json["nodes"]
-    catalog_sources = dbt_catalog_json["sources"]
-
-    all_catalog_entities = {**catalog_nodes, **catalog_sources}
-
-    nodes = extract_dbt_entities(
-        all_manifest_entities,
-        all_catalog_entities,
-        sources_results,
-        load_schemas,
-        use_identifiers,
-        tag_prefix,
-        node_type_pattern,
-        report,
-        node_name_pattern,
-    )
-
-    return (
-        nodes,
-        manifest_schema,
-        manifest_version,
-        catalog_schema,
-        catalog_version,
-        all_manifest_entities,
-    )
-
-
 def get_db_fqn(database: Optional[str], schema: str, name: str) -> str:
     if database is not None:
         fqn = f"{database}.{schema}.{name}"
@@ -769,6 +722,88 @@ class DBTSource(StatefulIngestionSourceBase):
         ):
             yield from soft_delete_item(table_urn, "dataset")
 
+    # s3://data-analysis.pelotime.com/dbt-artifacts/data-engineering-dbt/catalog.json
+    def load_file_as_json(self, uri: str) -> Any:
+        if re.match("^https?://", uri):
+            return json.loads(requests.get(uri).text)
+        elif re.match("^s3://", uri):
+            u = urlparse(uri)
+            response = self.config.s3_client.get_object(
+                Bucket=u.netloc, Key=u.path.lstrip("/")
+            )
+            return json.loads(response["Body"].read().decode("utf-8"))
+        else:
+            with open(uri, "r") as f:
+                return json.load(f)
+
+    def loadManifestAndCatalog(
+        self,
+        manifest_path: str,
+        catalog_path: str,
+        sources_path: Optional[str],
+        load_schemas: bool,
+        use_identifiers: bool,
+        tag_prefix: str,
+        node_type_pattern: AllowDenyPattern,
+        report: DBTSourceReport,
+        node_name_pattern: AllowDenyPattern,
+    ) -> Tuple[
+        List[DBTNode],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Dict[str, Dict[str, Any]],
+    ]:
+        dbt_manifest_json = self.load_file_as_json(manifest_path)
+
+        dbt_catalog_json = self.load_file_as_json(catalog_path)
+
+        if sources_path is not None:
+            dbt_sources_json = self.load_file_as_json(sources_path)
+            sources_results = dbt_sources_json["results"]
+        else:
+            sources_results = {}
+
+        manifest_schema = dbt_manifest_json.get("metadata", {}).get(
+            "dbt_schema_version"
+        )
+        manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")
+
+        catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
+        catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")
+
+        manifest_nodes = dbt_manifest_json["nodes"]
+        manifest_sources = dbt_manifest_json["sources"]
+
+        all_manifest_entities = {**manifest_nodes, **manifest_sources}
+
+        catalog_nodes = dbt_catalog_json["nodes"]
+        catalog_sources = dbt_catalog_json["sources"]
+
+        all_catalog_entities = {**catalog_nodes, **catalog_sources}
+
+        nodes = extract_dbt_entities(
+            all_manifest_entities,
+            all_catalog_entities,
+            sources_results,
+            load_schemas,
+            use_identifiers,
+            tag_prefix,
+            node_type_pattern,
+            report,
+            node_name_pattern,
+        )
+
+        return (
+            nodes,
+            manifest_schema,
+            manifest_version,
+            catalog_schema,
+            catalog_version,
+            all_manifest_entities,
+        )
+
     # create workunits from dbt nodes
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         if self.config.write_semantics == "PATCH" and not self.ctx.graph:
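The new s3:// branch of load_file_as_json leans on urlparse: for an s3 URI, netloc is the bucket and path, stripped of its leading slash, is the object key. A standalone sketch of the same lookup (not from the commit), using plain boto3 in place of self.config.s3_client and a hypothetical bucket:

import json
from urllib.parse import urlparse

import boto3

uri = "s3://my-bucket/dbt-artifacts/manifest.json"  # hypothetical
u = urlparse(uri)
assert u.netloc == "my-bucket"                              # -> Bucket=
assert u.path.lstrip("/") == "dbt-artifacts/manifest.json"  # -> Key=

s3 = boto3.client("s3")
obj = s3.get_object(Bucket=u.netloc, Key=u.path.lstrip("/"))
manifest = json.loads(obj["Body"].read().decode("utf-8"))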
@@ -784,7 +819,7 @@ class DBTSource(StatefulIngestionSourceBase):
             catalog_schema,
             catalog_version,
             manifest_nodes_raw,
-        ) = loadManifestAndCatalog(
+        ) = self.loadManifestAndCatalog(
             self.config.manifest_path,
             self.config.catalog_path,
             self.config.sources_path,
@@ -1356,7 +1391,7 @@ class DBTSource(StatefulIngestionSourceBase):
         """
         project_id = (
-            load_file_as_json(self.config.manifest_path)
+            self.load_file_as_json(self.config.manifest_path)
             .get("metadata", {})
             .get("project_id")
         )
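Taken together, a dbt recipe can now point all three artifact paths at s3 and supply the connection details inline. A sketch of end-to-end usage (not part of this commit; the bucket, region, and target platform are hypothetical, and the exact AwsConnectionConfig field set is assumed):

from datahub.ingestion.run.pipeline import Pipeline

pipeline = Pipeline.create(
    {
        "source": {
            "type": "dbt",
            "config": {
                "manifest_path": "s3://my-bucket/dbt/manifest.json",
                "catalog_path": "s3://my-bucket/dbt/catalog.json",
                "target_platform": "postgres",
                "aws_connection": {"aws_region": "us-east-1"},
            },
        },
        "sink": {"type": "console"},
    }
)
pipeline.run()
pipeline.raise_from_status()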