Repository: https://github.com/datahub-project/datahub.git

commit 2eb7920fe0 (parent c131e13582)
feat(dbt): enable dbt read artifacts from s3 (#4935)

Co-authored-by: Shirshanka Das <shirshanka@apache.org>
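The change below teaches the dbt source to read its artifacts (manifest, catalog, sources) from s3 in addition to http(s) and local paths, by extracting the reusable AWS credential fields into a new AwsConnectionConfig and wiring a validator plus an s3-aware loader into the dbt source. A minimal sketch of how the new surface might be exercised, assuming the field names shown in the hunks below and that target_platform is the only other required DBTConfig field; the bucket and object keys are placeholders:

from datahub.ingestion.source.dbt import DBTConfig

config = DBTConfig.parse_obj(
    {
        # placeholder s3 uris; any of the three artifact paths may be s3://
        "manifest_path": "s3://example-bucket/dbt/manifest.json",
        "catalog_path": "s3://example-bucket/dbt/catalog.json",
        "sources_path": "s3://example-bucket/dbt/sources.json",
        "target_platform": "postgres",
        # required by the new validator whenever any path above is s3://
        "aws_connection": {"aws_region": "us-east-1"},
    }
)
s3 = config.s3_client  # boto3 s3 client built from aws_connection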
metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py

@@ -7,7 +7,7 @@ from botocore.config import Config
 from botocore.utils import fix_s3_host
 from pydantic.fields import Field

-from datahub.configuration.common import AllowDenyPattern
+from datahub.configuration.common import AllowDenyPattern, ConfigModel
 from datahub.configuration.source_common import EnvBasedSourceConfigBase

 if TYPE_CHECKING:
@@ -35,24 +35,16 @@ def assume_role(
     return assumed_role_object["Credentials"]


-class AwsSourceConfig(EnvBasedSourceConfigBase):
+class AwsConnectionConfig(ConfigModel):
     """
     Common AWS credentials config.

     Currently used by:
     - Glue source
     - SageMaker source
+    - dbt source
     """

-    database_pattern: AllowDenyPattern = Field(
-        default=AllowDenyPattern.allow_all(),
-        description="regex patterns for databases to filter in ingestion.",
-    )
-    table_pattern: AllowDenyPattern = Field(
-        default=AllowDenyPattern.allow_all(),
-        description="regex patterns for tables to filter in ingestion.",
-    )
-
     aws_access_key_id: Optional[str] = Field(
         default=None,
         description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
@@ -157,3 +149,22 @@ class AwsSourceConfig(EnvBasedSourceConfigBase):

     def get_sagemaker_client(self) -> "SageMakerClient":
         return self.get_session().client("sagemaker")
+
+
+class AwsSourceConfig(EnvBasedSourceConfigBase, AwsConnectionConfig):
+    """
+    Common AWS credentials config.
+
+    Currently used by:
+    - Glue source
+    - SageMaker source
+    """
+
+    database_pattern: AllowDenyPattern = Field(
+        default=AllowDenyPattern.allow_all(),
+        description="regex patterns for databases to filter in ingestion.",
+    )
+    table_pattern: AllowDenyPattern = Field(
+        default=AllowDenyPattern.allow_all(),
+        description="regex patterns for tables to filter in ingestion.",
+    )
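Net effect of the hunks above: the connection-only fields (keys, session token, role, region, endpoint) now live in AwsConnectionConfig, while AwsSourceConfig keeps its Glue/SageMaker-specific database_pattern and table_pattern filters by inheriting from both EnvBasedSourceConfigBase and AwsConnectionConfig. That lets the dbt source depend on just the connection piece. A hedged sketch of standalone use (the region is a placeholder):

from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig

conn = AwsConnectionConfig(aws_region="us-east-1")  # placeholder region
s3 = conn.get_s3_client()  # same call DBTConfig.s3_client makes internally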
metadata-ingestion/src/datahub/ingestion/source/dbt.py

@@ -3,6 +3,7 @@ import logging
 import re
 from dataclasses import dataclass, field
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
+from urllib.parse import urlparse

 import dateutil.parser
 import requests
@@ -23,6 +24,7 @@ from datahub.ingestion.api.decorators import (
 )
 from datahub.ingestion.api.ingestion_job_state_provider import JobId
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
 from datahub.ingestion.source.sql.sql_types import (
     BIGQUERY_TYPES_MAP,
     POSTGRES_TYPES_MAP,
@@ -172,6 +174,15 @@ class DBTConfig(StatefulIngestionConfigBase):
         default=None,
         description='Regex string to extract owner from the dbt node using the `(?P<name>...) syntax` of the [match object](https://docs.python.org/3/library/re.html#match-objects), where the group name must be `owner`. Examples: (1)`r"(?P<owner>(.*)): (\w+) (\w+)"` will extract `jdoe` as the owner from `"jdoe: John Doe"` (2) `r"@(?P<owner>(.*))"` will extract `alice` as the owner from `"@alice"`.', # noqa: W605
     )
+    aws_connection: Optional[AwsConnectionConfig] = Field(
+        default=None,
+        description="When fetching manifest files from s3, configuration for aws connection details",
+    )
+
+    @property
+    def s3_client(self):
+        assert self.aws_connection
+        return self.aws_connection.get_s3_client()

     # Custom Stateful Ingestion settings
     stateful_ingestion: Optional[DBTStatefulIngestionConfig] = Field(
@@ -197,6 +208,22 @@ class DBTConfig(StatefulIngestionConfigBase):
         )
         return write_semantics

+    @validator("aws_connection")
+    def aws_connection_needed_if_s3_uris_present(
+        cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
+    ) -> Optional[AwsConnectionConfig]:
+        # first check if there are fields that contain s3 uris
+        uri_containing_fields = [
+            f
+            for f in ["manifest_path", "catalog_path", "sources_path"]
+            if values.get(f, "").startswith("s3://")
+        ]
+        if uri_containing_fields and not aws_connection:
+            raise ValueError(
+                f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
+            )
+        return aws_connection
+

 @dataclass
 class DBTColumn:
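The core of the validator is the list comprehension: it collects which of the three artifact paths are s3 uris and insists on aws_connection when that list is non-empty. A sketch of just that logic with placeholder values:

values = {
    "manifest_path": "s3://example-bucket/manifest.json",  # placeholder
    "catalog_path": "s3://example-bucket/catalog.json",  # placeholder
    "sources_path": "./target/sources.json",  # local path, not matched
}
uri_containing_fields = [
    f
    for f in ["manifest_path", "catalog_path", "sources_path"]
    if values.get(f, "").startswith("s3://")
]
print(uri_containing_fields)  # ['manifest_path', 'catalog_path']

One subtlety worth noting: values.get(f, "") only falls back to "" when the key is absent, so if sources_path is present as None (its default), .startswith would raise; a defensive variant is (values.get(f) or "").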
@@ -387,80 +414,6 @@ def extract_dbt_entities(
     return dbt_entities


-def load_file_as_json(uri: str) -> Any:
-    if re.match("^https?://", uri):
-        return json.loads(requests.get(uri).text)
-    else:
-        with open(uri, "r") as f:
-            return json.load(f)
-
-
-def loadManifestAndCatalog(
-    manifest_path: str,
-    catalog_path: str,
-    sources_path: Optional[str],
-    load_schemas: bool,
-    use_identifiers: bool,
-    tag_prefix: str,
-    node_type_pattern: AllowDenyPattern,
-    report: DBTSourceReport,
-    node_name_pattern: AllowDenyPattern,
-) -> Tuple[
-    List[DBTNode],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Dict[str, Dict[str, Any]],
-]:
-    dbt_manifest_json = load_file_as_json(manifest_path)
-
-    dbt_catalog_json = load_file_as_json(catalog_path)
-
-    if sources_path is not None:
-        dbt_sources_json = load_file_as_json(sources_path)
-        sources_results = dbt_sources_json["results"]
-    else:
-        sources_results = {}
-
-    manifest_schema = dbt_manifest_json.get("metadata", {}).get("dbt_schema_version")
-    manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")
-
-    catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
-    catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")
-
-    manifest_nodes = dbt_manifest_json["nodes"]
-    manifest_sources = dbt_manifest_json["sources"]
-
-    all_manifest_entities = {**manifest_nodes, **manifest_sources}
-
-    catalog_nodes = dbt_catalog_json["nodes"]
-    catalog_sources = dbt_catalog_json["sources"]
-
-    all_catalog_entities = {**catalog_nodes, **catalog_sources}
-
-    nodes = extract_dbt_entities(
-        all_manifest_entities,
-        all_catalog_entities,
-        sources_results,
-        load_schemas,
-        use_identifiers,
-        tag_prefix,
-        node_type_pattern,
-        report,
-        node_name_pattern,
-    )
-
-    return (
-        nodes,
-        manifest_schema,
-        manifest_version,
-        catalog_schema,
-        catalog_version,
-        all_manifest_entities,
-    )
-
-
 def get_db_fqn(database: Optional[str], schema: str, name: str) -> str:
     if database is not None:
         fqn = f"{database}.{schema}.{name}"
@@ -769,6 +722,88 @@ class DBTSource(StatefulIngestionSourceBase):
         ):
             yield from soft_delete_item(table_urn, "dataset")

+    # s3://data-analysis.pelotime.com/dbt-artifacts/data-engineering-dbt/catalog.json
+    def load_file_as_json(self, uri: str) -> Any:
+        if re.match("^https?://", uri):
+            return json.loads(requests.get(uri).text)
+        elif re.match("^s3://", uri):
+            u = urlparse(uri)
+            response = self.config.s3_client.get_object(
+                Bucket=u.netloc, Key=u.path.lstrip("/")
+            )
+            return json.loads(response["Body"].read().decode("utf-8"))
+        else:
+            with open(uri, "r") as f:
+                return json.load(f)
+
+    def loadManifestAndCatalog(
+        self,
+        manifest_path: str,
+        catalog_path: str,
+        sources_path: Optional[str],
+        load_schemas: bool,
+        use_identifiers: bool,
+        tag_prefix: str,
+        node_type_pattern: AllowDenyPattern,
+        report: DBTSourceReport,
+        node_name_pattern: AllowDenyPattern,
+    ) -> Tuple[
+        List[DBTNode],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Dict[str, Dict[str, Any]],
+    ]:
+        dbt_manifest_json = self.load_file_as_json(manifest_path)
+
+        dbt_catalog_json = self.load_file_as_json(catalog_path)
+
+        if sources_path is not None:
+            dbt_sources_json = self.load_file_as_json(sources_path)
+            sources_results = dbt_sources_json["results"]
+        else:
+            sources_results = {}
+
+        manifest_schema = dbt_manifest_json.get("metadata", {}).get(
+            "dbt_schema_version"
+        )
+        manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")
+
+        catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
+        catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")
+
+        manifest_nodes = dbt_manifest_json["nodes"]
+        manifest_sources = dbt_manifest_json["sources"]
+
+        all_manifest_entities = {**manifest_nodes, **manifest_sources}
+
+        catalog_nodes = dbt_catalog_json["nodes"]
+        catalog_sources = dbt_catalog_json["sources"]
+
+        all_catalog_entities = {**catalog_nodes, **catalog_sources}
+
+        nodes = extract_dbt_entities(
+            all_manifest_entities,
+            all_catalog_entities,
+            sources_results,
+            load_schemas,
+            use_identifiers,
+            tag_prefix,
+            node_type_pattern,
+            report,
+            node_name_pattern,
+        )
+
+        return (
+            nodes,
+            manifest_schema,
+            manifest_version,
+            catalog_schema,
+            catalog_version,
+            all_manifest_entities,
+        )
+
     # create workunits from dbt nodes
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         if self.config.write_semantics == "PATCH" and not self.ctx.graph:
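For reference, urlparse on an s3 uri puts the bucket in netloc and the object key in the slash-prefixed path, which is why the get_object call above strips the leading slash (standard-library behavior; the uri is a placeholder):

from urllib.parse import urlparse

u = urlparse("s3://example-bucket/dbt-artifacts/manifest.json")
assert u.netloc == "example-bucket"  # passed as Bucket
assert u.path.lstrip("/") == "dbt-artifacts/manifest.json"  # passed as Key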
@@ -784,7 +819,7 @@ class DBTSource(StatefulIngestionSourceBase):
             catalog_schema,
             catalog_version,
             manifest_nodes_raw,
-        ) = loadManifestAndCatalog(
+        ) = self.loadManifestAndCatalog(
             self.config.manifest_path,
             self.config.catalog_path,
             self.config.sources_path,
@@ -1356,7 +1391,7 @@ class DBTSource(StatefulIngestionSourceBase):
         """

         project_id = (
-            load_file_as_json(self.config.manifest_path)
+            self.load_file_as_json(self.config.manifest_path)
             .get("metadata", {})
             .get("project_id")
         )