mirror of
https://github.com/datahub-project/datahub.git
synced 2025-09-25 09:00:50 +00:00
feat(dbt): enable dbt read artifacts from s3 (#4935)
Co-authored-by: Shirshanka Das <shirshanka@apache.org>
This commit is contained in:
parent
c131e13582
commit
2eb7920fe0
@ -7,7 +7,7 @@ from botocore.config import Config
|
||||
from botocore.utils import fix_s3_host
|
||||
from pydantic.fields import Field
|
||||
|
||||
from datahub.configuration.common import AllowDenyPattern
|
||||
from datahub.configuration.common import AllowDenyPattern, ConfigModel
|
||||
from datahub.configuration.source_common import EnvBasedSourceConfigBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -35,24 +35,16 @@ def assume_role(
|
||||
return assumed_role_object["Credentials"]
|
||||
|
||||
|
||||
class AwsSourceConfig(EnvBasedSourceConfigBase):
|
||||
class AwsConnectionConfig(ConfigModel):
|
||||
"""
|
||||
Common AWS credentials config.
|
||||
|
||||
Currently used by:
|
||||
- Glue source
|
||||
- SageMaker source
|
||||
- dbt source
|
||||
"""
|
||||
|
||||
database_pattern: AllowDenyPattern = Field(
|
||||
default=AllowDenyPattern.allow_all(),
|
||||
description="regex patterns for databases to filter in ingestion.",
|
||||
)
|
||||
table_pattern: AllowDenyPattern = Field(
|
||||
default=AllowDenyPattern.allow_all(),
|
||||
description="regex patterns for tables to filter in ingestion.",
|
||||
)
|
||||
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
|
||||
@ -157,3 +149,22 @@ class AwsSourceConfig(EnvBasedSourceConfigBase):
|
||||
|
||||
def get_sagemaker_client(self) -> "SageMakerClient":
|
||||
return self.get_session().client("sagemaker")
|
||||
|
||||
|
||||
class AwsSourceConfig(EnvBasedSourceConfigBase, AwsConnectionConfig):
|
||||
"""
|
||||
Common AWS credentials config.
|
||||
|
||||
Currently used by:
|
||||
- Glue source
|
||||
- SageMaker source
|
||||
"""
|
||||
|
||||
database_pattern: AllowDenyPattern = Field(
|
||||
default=AllowDenyPattern.allow_all(),
|
||||
description="regex patterns for databases to filter in ingestion.",
|
||||
)
|
||||
table_pattern: AllowDenyPattern = Field(
|
||||
default=AllowDenyPattern.allow_all(),
|
||||
description="regex patterns for tables to filter in ingestion.",
|
||||
)
|
||||
|
@ -3,6 +3,7 @@ import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dateutil.parser
|
||||
import requests
|
||||
@ -23,6 +24,7 @@ from datahub.ingestion.api.decorators import (
|
||||
)
|
||||
from datahub.ingestion.api.ingestion_job_state_provider import JobId
|
||||
from datahub.ingestion.api.workunit import MetadataWorkUnit
|
||||
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
|
||||
from datahub.ingestion.source.sql.sql_types import (
|
||||
BIGQUERY_TYPES_MAP,
|
||||
POSTGRES_TYPES_MAP,
|
||||
@ -172,6 +174,15 @@ class DBTConfig(StatefulIngestionConfigBase):
|
||||
default=None,
|
||||
description='Regex string to extract owner from the dbt node using the `(?P<name>...) syntax` of the [match object](https://docs.python.org/3/library/re.html#match-objects), where the group name must be `owner`. Examples: (1)`r"(?P<owner>(.*)): (\w+) (\w+)"` will extract `jdoe` as the owner from `"jdoe: John Doe"` (2) `r"@(?P<owner>(.*))"` will extract `alice` as the owner from `"@alice"`.', # noqa: W605
|
||||
)
|
||||
aws_connection: Optional[AwsConnectionConfig] = Field(
|
||||
default=None,
|
||||
description="When fetching manifest files from s3, configuration for aws connection details",
|
||||
)
|
||||
|
||||
@property
def s3_client(self):
    """Boto3 S3 client built from the configured ``aws_connection``.

    Raises:
        ValueError: if ``aws_connection`` was not configured. (The
            ``aws_connection_needed_if_s3_uris_present`` validator normally
            guarantees it is set whenever s3:// uris are used.)
    """
    if self.aws_connection is None:
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # which would turn this misconfiguration into an AttributeError below.
        raise ValueError(
            "aws_connection configuration is required to fetch artifacts from s3"
        )
    return self.aws_connection.get_s3_client()
|
||||
|
||||
# Custom Stateful Ingestion settings
|
||||
stateful_ingestion: Optional[DBTStatefulIngestionConfig] = Field(
|
||||
@ -197,6 +208,22 @@ class DBTConfig(StatefulIngestionConfigBase):
|
||||
)
|
||||
return write_semantics
|
||||
|
||||
@validator("aws_connection")
def aws_connection_needed_if_s3_uris_present(
    cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
) -> Optional[AwsConnectionConfig]:
    """Require an aws_connection whenever any artifact path is an s3:// uri."""
    # Fields that may hold s3 uris. A field can be present in `values`
    # with the value None (e.g. sources_path defaults to None), in which
    # case dict.get's default is NOT used — guard with `or ""` so we never
    # call .startswith on None.
    uri_containing_fields = [
        f
        for f in ["manifest_path", "catalog_path", "sources_path"]
        if (values.get(f) or "").startswith("s3://")
    ]
    if uri_containing_fields and not aws_connection:
        raise ValueError(
            f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
        )
    return aws_connection
|
||||
|
||||
|
||||
@dataclass
|
||||
class DBTColumn:
|
||||
@ -387,80 +414,6 @@ def extract_dbt_entities(
|
||||
return dbt_entities
|
||||
|
||||
|
||||
def load_file_as_json(uri: str) -> Any:
    """Parse JSON from *uri*, which is either an http(s) URL or a local path."""
    if uri.startswith(("http://", "https://")):
        # Remote artifact: fetch it and parse the response body.
        return json.loads(requests.get(uri).text)
    # Local artifact: read straight from disk.
    with open(uri, "r") as f:
        return json.load(f)
|
||||
|
||||
|
||||
def loadManifestAndCatalog(
    manifest_path: str,
    catalog_path: str,
    sources_path: Optional[str],
    load_schemas: bool,
    use_identifiers: bool,
    tag_prefix: str,
    node_type_pattern: AllowDenyPattern,
    report: DBTSourceReport,
    node_name_pattern: AllowDenyPattern,
) -> Tuple[
    List[DBTNode],
    Optional[str],
    Optional[str],
    Optional[str],
    Optional[str],
    Dict[str, Dict[str, Any]],
]:
    """Load the dbt manifest/catalog (and optional sources) artifacts and
    convert them into DBTNode entities.

    Returns the extracted nodes, the manifest and catalog schema versions and
    dbt versions, and the raw merged manifest entity mapping.
    """
    manifest = load_file_as_json(manifest_path)
    catalog = load_file_as_json(catalog_path)

    # dbt source-freshness results are optional.
    source_results = (
        load_file_as_json(sources_path)["results"] if sources_path is not None else {}
    )

    manifest_metadata = manifest.get("metadata", {})
    catalog_metadata = catalog.get("metadata", {})

    # Models/tests live under "nodes" and sources under "sources";
    # merge both maps so downstream processing sees a single mapping.
    manifest_entities = {**manifest["nodes"], **manifest["sources"]}
    catalog_entities = {**catalog["nodes"], **catalog["sources"]}

    entities = extract_dbt_entities(
        manifest_entities,
        catalog_entities,
        source_results,
        load_schemas,
        use_identifiers,
        tag_prefix,
        node_type_pattern,
        report,
        node_name_pattern,
    )

    return (
        entities,
        manifest_metadata.get("dbt_schema_version"),
        manifest_metadata.get("dbt_version"),
        catalog_metadata.get("dbt_schema_version"),
        catalog_metadata.get("dbt_version"),
        manifest_entities,
    )
|
||||
|
||||
|
||||
def get_db_fqn(database: Optional[str], schema: str, name: str) -> str:
|
||||
if database is not None:
|
||||
fqn = f"{database}.{schema}.{name}"
|
||||
@ -769,6 +722,88 @@ class DBTSource(StatefulIngestionSourceBase):
|
||||
):
|
||||
yield from soft_delete_item(table_urn, "dataset")
|
||||
|
||||
# Example artifact uri: s3://my-bucket/dbt-artifacts/my-project/catalog.json
|
||||
def load_file_as_json(self, uri: str) -> Any:
    """Parse JSON from *uri*: an http(s) URL, an s3:// uri, or a local path.

    s3:// uris require ``aws_connection`` to be set in the config (enforced
    by the config validator).
    """
    if uri.startswith(("http://", "https://")):
        response = requests.get(uri)
        # Fail fast with an HTTPError on non-2xx responses instead of
        # feeding an error page to json.loads (confusing JSONDecodeError).
        response.raise_for_status()
        return json.loads(response.text)
    if uri.startswith("s3://"):
        parsed = urlparse(uri)
        # netloc is the bucket name; path (minus the leading slash) is the key.
        obj = self.config.s3_client.get_object(
            Bucket=parsed.netloc, Key=parsed.path.lstrip("/")
        )
        return json.loads(obj["Body"].read().decode("utf-8"))
    # Anything else is treated as a local filesystem path.
    with open(uri, "r") as f:
        return json.load(f)
|
||||
|
||||
def loadManifestAndCatalog(
    self,
    manifest_path: str,
    catalog_path: str,
    sources_path: Optional[str],
    load_schemas: bool,
    use_identifiers: bool,
    tag_prefix: str,
    node_type_pattern: AllowDenyPattern,
    report: DBTSourceReport,
    node_name_pattern: AllowDenyPattern,
) -> Tuple[
    List[DBTNode],
    Optional[str],
    Optional[str],
    Optional[str],
    Optional[str],
    Dict[str, Dict[str, Any]],
]:
    """Load the dbt manifest/catalog (and optional sources) artifacts and
    convert them into DBTNode entities.

    Artifacts are resolved via ``self.load_file_as_json``, so http(s), s3://
    and local paths are all supported. Returns the extracted nodes, the
    manifest and catalog schema versions and dbt versions, and the raw merged
    manifest entity mapping.
    """
    manifest = self.load_file_as_json(manifest_path)
    catalog = self.load_file_as_json(catalog_path)

    # dbt source-freshness results are optional.
    source_results = (
        self.load_file_as_json(sources_path)["results"]
        if sources_path is not None
        else {}
    )

    manifest_metadata = manifest.get("metadata", {})
    catalog_metadata = catalog.get("metadata", {})

    # Models/tests live under "nodes" and sources under "sources";
    # merge both maps so downstream processing sees a single mapping.
    manifest_entities = {**manifest["nodes"], **manifest["sources"]}
    catalog_entities = {**catalog["nodes"], **catalog["sources"]}

    entities = extract_dbt_entities(
        manifest_entities,
        catalog_entities,
        source_results,
        load_schemas,
        use_identifiers,
        tag_prefix,
        node_type_pattern,
        report,
        node_name_pattern,
    )

    return (
        entities,
        manifest_metadata.get("dbt_schema_version"),
        manifest_metadata.get("dbt_version"),
        catalog_metadata.get("dbt_schema_version"),
        catalog_metadata.get("dbt_version"),
        manifest_entities,
    )
|
||||
|
||||
# create workunits from dbt nodes
|
||||
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
|
||||
if self.config.write_semantics == "PATCH" and not self.ctx.graph:
|
||||
@ -784,7 +819,7 @@ class DBTSource(StatefulIngestionSourceBase):
|
||||
catalog_schema,
|
||||
catalog_version,
|
||||
manifest_nodes_raw,
|
||||
) = loadManifestAndCatalog(
|
||||
) = self.loadManifestAndCatalog(
|
||||
self.config.manifest_path,
|
||||
self.config.catalog_path,
|
||||
self.config.sources_path,
|
||||
@ -1356,7 +1391,7 @@ class DBTSource(StatefulIngestionSourceBase):
|
||||
"""
|
||||
|
||||
project_id = (
|
||||
load_file_as_json(self.config.manifest_path)
|
||||
self.load_file_as_json(self.config.manifest_path)
|
||||
.get("metadata", {})
|
||||
.get("project_id")
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user