Added dbt multiple project support for s3, azure, gcs datalake sources (#12856)

* Added dbt multiple project support

* added reader

* common method to group by dir

* added return type
Onkar Ravgan 2023-08-17 14:49:20 +05:30 committed by GitHub
parent 801d07289c
commit 795294c87f
6 changed files with 197 additions and 143 deletions


@@ -63,6 +63,7 @@ from metadata.utils.constants import COMPLEX_COLUMN_SEPARATOR, DEFAULT_DATABASE
from metadata.utils.datalake.datalake_utils import fetch_dataframe
from metadata.utils.filters import filter_by_schema, filter_by_table
from metadata.utils.logger import ingestion_logger
from metadata.utils.s3_utils import list_s3_objects
logger = ingestion_logger()
@@ -235,15 +236,6 @@ class DatalakeSource(DatabaseServiceSource):
database=self.context.database.fullyQualifiedName,
)
def _list_s3_objects(self, **kwargs) -> Iterable:
try:
paginator = self.client.get_paginator("list_objects_v2")
for page in paginator.paginate(**kwargs):
yield from page.get("Contents", [])
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Unexpected exception to yield s3 object: {exc}")
def get_tables_name_and_type( # pylint: disable=too-many-branches
self,
) -> Optional[Iterable[Tuple[str, str]]]:
@@ -297,7 +289,7 @@ class DatalakeSource(DatabaseServiceSource):
kwargs = {"Bucket": bucket_name}
if prefix:
kwargs["Prefix"] = prefix if prefix.endswith("/") else f"{prefix}/"
for key in self._list_s3_objects(**kwargs):
for key in list_s3_objects(self.client, **kwargs):
table_name = self.standardize_table_name(bucket_name, key["Key"])
table_fqn = fqn.build(
self.metadata,


@@ -26,6 +26,12 @@ DBT_CATALOG_FILE_NAME = "catalog.json"
DBT_MANIFEST_FILE_NAME = "manifest.json"
DBT_RUN_RESULTS_FILE_NAME = "run_results.json"
DBT_FILE_NAMES_LIST = [
DBT_CATALOG_FILE_NAME,
DBT_MANIFEST_FILE_NAME,
DBT_RUN_RESULTS_FILE_NAME,
]
class SkipResourceTypeEnum(Enum):
"""


@@ -13,11 +13,13 @@ Hosts the singledispatch to get DBT files
"""
import json
import traceback
from collections import defaultdict
from functools import singledispatch
from typing import Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import requests
from metadata.clients.aws_client import AWSClient
from metadata.generated.schema.metadataIngestion.dbtconfig.dbtAzureConfig import (
DbtAzureConfig,
)
@@ -38,19 +40,22 @@ from metadata.generated.schema.metadataIngestion.dbtconfig.dbtS3Config import (
)
from metadata.ingestion.source.database.dbt.constants import (
DBT_CATALOG_FILE_NAME,
DBT_FILE_NAMES_LIST,
DBT_MANIFEST_FILE_NAME,
DBT_RUN_RESULTS_FILE_NAME,
)
from metadata.ingestion.source.database.dbt.models import DbtFiles
from metadata.readers.file.config_source_factory import get_reader
from metadata.utils.credentials import set_google_credentials
from metadata.utils.logger import ometa_logger
from metadata.utils.s3_utils import list_s3_objects
logger = ometa_logger()
class DBTConfigException(Exception):
"""
Raise when encountering errors while extacting dbt files
Raise when encountering errors while extracting dbt files
"""
@@ -69,32 +74,23 @@ def get_dbt_details(config):
@get_dbt_details.register
def _(config: DbtLocalConfig):
try:
dbt_run_results = None
dbt_catalog = None
if config.dbtManifestFilePath is not None:
logger.debug(
f"Reading [dbtManifestFilePath] from: {config.dbtCatalogFilePath}"
)
with open(config.dbtManifestFilePath, "r", encoding="utf-8") as manifest:
dbt_manifest = manifest.read()
if config.dbtRunResultsFilePath:
logger.debug(
f"Reading [dbtRunResultsFilePath] from: {config.dbtRunResultsFilePath}"
)
with open(
config.dbtRunResultsFilePath, "r", encoding="utf-8"
) as run_results:
dbt_run_results = run_results.read()
if config.dbtCatalogFilePath:
logger.debug(
f"Reading [dbtCatalogFilePath] from: {config.dbtCatalogFilePath}"
)
with open(config.dbtCatalogFilePath, "r", encoding="utf-8") as catalog:
dbt_catalog = catalog.read()
return DbtFiles(
dbt_catalog=json.loads(dbt_catalog) if dbt_catalog else None,
dbt_manifest=json.loads(dbt_manifest),
dbt_run_results=json.loads(dbt_run_results) if dbt_run_results else None,
blob_grouped_by_directory = defaultdict(list)
subdirectory = (
config.dbtManifestFilePath.rsplit("/", 1)[0]
if "/" in config.dbtManifestFilePath
else ""
)
blob_grouped_by_directory[subdirectory] = [
config.dbtManifestFilePath,
config.dbtCatalogFilePath,
config.dbtRunResultsFilePath,
]
yield from download_dbt_files(
blob_grouped_by_directory=blob_grouped_by_directory,
config=config,
client=None,
bucket_name=None,
)
except Exception as exc:
logger.debug(traceback.format_exc())
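For the local config, the three configured file paths are grouped under the manifest's parent directory so they flow through the same download_dbt_files helper as the object-store sources. An illustrative value of the grouped mapping, assuming a placeholder dbtManifestFilePath of "/opt/dbt/target/manifest.json":

# Illustrative only; paths are placeholders
# blob_grouped_by_directory == {
#     "/opt/dbt/target": [
#         "/opt/dbt/target/manifest.json",
#         "/opt/dbt/target/catalog.json",      # entry is None when dbtCatalogFilePath is unset
#         "/opt/dbt/target/run_results.json",  # entry is None when dbtRunResultsFilePath is unset
#     ]
# }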
@@ -129,7 +125,7 @@ def _(config: DbtHttpConfig):
)
if not dbt_manifest:
raise DBTConfigException("Manifest file not found in file server")
return DbtFiles(
yield DbtFiles(
dbt_catalog=dbt_catalog.json() if dbt_catalog else None,
dbt_manifest=dbt_manifest.json(),
dbt_run_results=dbt_run_results.json() if dbt_run_results else None,
@@ -206,7 +202,7 @@ def _(config: DbtCloudConfig): # pylint: disable=too-many-locals
if not dbt_manifest:
raise DBTConfigException("Manifest file not found in DBT Cloud")
return DbtFiles(
yield DbtFiles(
dbt_catalog=dbt_catalog,
dbt_manifest=dbt_manifest,
dbt_run_results=dbt_run_results,
@@ -218,47 +214,81 @@ def _(config: DbtCloudConfig): # pylint: disable=too-many-locals
raise DBTConfigException(f"Error fetching dbt files from DBT Cloud: {exc}")
def get_blobs_grouped_by_dir(blobs: List[str]) -> Dict[str, List[str]]:
"""
Method to group the blobs by directory
"""
blob_grouped_by_directory = defaultdict(list)
for blob in blobs:
if [file_name for file_name in DBT_FILE_NAMES_LIST if file_name in blob]:
subdirectory = blob.rsplit("/", 1)[0] if "/" in blob else ""
blob_grouped_by_directory[subdirectory].append(blob)
return blob_grouped_by_directory
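A minimal sketch of the grouping behaviour, with made-up object keys for two dbt projects; files whose names are not in DBT_FILE_NAMES_LIST are skipped:

# Illustrative only; the object keys below are placeholders
blobs = [
    "jaffle_shop/target/manifest.json",
    "jaffle_shop/target/catalog.json",
    "analytics/target/manifest.json",
    "analytics/target/run_results.json",
    "analytics/target/graph.gpickle",  # ignored: not one of the dbt artifact file names
]
get_blobs_grouped_by_dir(blobs=blobs)
# -> {
#     "jaffle_shop/target": ["jaffle_shop/target/manifest.json", "jaffle_shop/target/catalog.json"],
#     "analytics/target": ["analytics/target/manifest.json", "analytics/target/run_results.json"],
# }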
def download_dbt_files(
blob_grouped_by_directory: Dict, config, client, bucket_name: Optional[str]
) -> Iterable[DbtFiles]:
"""
Method to download the files from sources
"""
for key, blobs in blob_grouped_by_directory.items():
dbt_catalog = None
dbt_manifest = None
dbt_run_results = None
kwargs = {}
if bucket_name:
kwargs = {"bucket_name": bucket_name}
try:
for blob in blobs:
reader = get_reader(config_source=config, client=client)
if DBT_MANIFEST_FILE_NAME in blob:
logger.debug(f"{DBT_MANIFEST_FILE_NAME} found in {key}")
dbt_manifest = reader.read(path=blob, **kwargs)
if DBT_CATALOG_FILE_NAME in blob:
logger.debug(f"{DBT_CATALOG_FILE_NAME} found in {key}")
dbt_catalog = reader.read(path=blob, **kwargs)
if DBT_RUN_RESULTS_FILE_NAME in blob:
logger.debug(f"{DBT_RUN_RESULTS_FILE_NAME} found in {key}")
dbt_run_results = reader.read(path=blob, **kwargs)
if not dbt_manifest:
raise DBTConfigException(f"Manifest file not found at: {key}")
yield DbtFiles(
dbt_catalog=json.loads(dbt_catalog) if dbt_catalog else None,
dbt_manifest=json.loads(dbt_manifest),
dbt_run_results=json.loads(dbt_run_results)
if dbt_run_results
else None,
)
except DBTConfigException as exc:
logger.warning(exc)
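The two helpers compose into the per-project stream used by the object-store handlers below (and the local handler above); a sketch, assuming a config/client pair from one of the registered sources and a placeholder bucket name:

# Illustrative composition; config, client and bucket name are assumptions
for dbt_files in download_dbt_files(
    blob_grouped_by_directory=get_blobs_grouped_by_dir(blobs=object_keys),
    config=dbt_config,
    client=storage_client,
    bucket_name="dbt-artifacts",
):
    # one DbtFiles per directory that contains a manifest.json;
    # directories without a manifest are logged and skipped
    ...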
@get_dbt_details.register
def _(config: DbtS3Config):
dbt_catalog = None
dbt_manifest = None
dbt_run_results = None
try:
bucket_name, prefix = get_dbt_prefix_config(config)
from metadata.clients.aws_client import ( # pylint: disable=import-outside-toplevel
AWSClient,
)
aws_client = AWSClient(config.dbtSecurityConfig).get_resource("s3")
client = AWSClient(config.dbtSecurityConfig).get_client(service_name="s3")
if not bucket_name:
buckets = aws_client.buckets.all()
buckets = client.list_buckets()["Buckets"]
else:
buckets = [aws_client.Bucket(bucket_name)]
buckets = [{"Name": bucket_name}]
for bucket in buckets:
kwargs = {"Bucket": bucket["Name"]}
if prefix:
obj_list = bucket.objects.filter(Prefix=prefix)
else:
obj_list = bucket.objects.all()
for bucket_object in obj_list:
if DBT_MANIFEST_FILE_NAME in bucket_object.key:
logger.debug(f"{DBT_MANIFEST_FILE_NAME} found")
dbt_manifest = bucket_object.get()["Body"].read().decode()
if DBT_CATALOG_FILE_NAME in bucket_object.key:
logger.debug(f"{DBT_CATALOG_FILE_NAME} found")
dbt_catalog = bucket_object.get()["Body"].read().decode()
if DBT_RUN_RESULTS_FILE_NAME in bucket_object.key:
logger.debug(f"{DBT_RUN_RESULTS_FILE_NAME} found")
dbt_run_results = bucket_object.get()["Body"].read().decode()
if not dbt_manifest:
raise DBTConfigException("Manifest file not found in s3")
return DbtFiles(
dbt_catalog=json.loads(dbt_catalog) if dbt_catalog else None,
dbt_manifest=json.loads(dbt_manifest),
dbt_run_results=json.loads(dbt_run_results) if dbt_run_results else None,
)
except DBTConfigException as exc:
raise exc
kwargs["Prefix"] = prefix if prefix.endswith("/") else f"{prefix}/"
yield from download_dbt_files(
blob_grouped_by_directory=get_blobs_grouped_by_dir(
blobs=[key["Key"] for key in list_s3_objects(client, **kwargs)]
),
config=config,
client=client,
bucket_name=bucket["Name"],
)
except Exception as exc:
logger.debug(traceback.format_exc())
raise DBTConfigException(f"Error fetching dbt files from s3: {exc}")
@@ -266,14 +296,12 @@ def _(config: DbtS3Config):
@get_dbt_details.register
def _(config: DbtGcsConfig):
dbt_catalog = None
dbt_manifest = None
dbt_run_results = None
try:
bucket_name, prefix = get_dbt_prefix_config(config)
from google.cloud import storage # pylint: disable=import-outside-toplevel
set_google_credentials(gcp_credentials=config.dbtSecurityConfig)
client = storage.Client()
if not bucket_name:
buckets = client.list_buckets()
@@ -284,25 +312,15 @@ def _(config: DbtGcsConfig):
obj_list = client.list_blobs(bucket.name, prefix=prefix)
else:
obj_list = client.list_blobs(bucket.name)
for blob in obj_list:
if DBT_MANIFEST_FILE_NAME in blob.name:
logger.debug(f"{DBT_MANIFEST_FILE_NAME} found")
dbt_manifest = blob.download_as_string().decode()
if DBT_CATALOG_FILE_NAME in blob.name:
logger.debug(f"{DBT_CATALOG_FILE_NAME} found")
dbt_catalog = blob.download_as_string().decode()
if DBT_RUN_RESULTS_FILE_NAME in blob.name:
logger.debug(f"{DBT_RUN_RESULTS_FILE_NAME} found")
dbt_run_results = blob.download_as_string().decode()
if not dbt_manifest:
raise DBTConfigException("Manifest file not found in gcs")
return DbtFiles(
dbt_catalog=json.loads(dbt_catalog) if dbt_catalog else None,
dbt_manifest=json.loads(dbt_manifest),
dbt_run_results=json.loads(dbt_run_results) if dbt_run_results else None,
)
except DBTConfigException as exc:
raise exc
yield from download_dbt_files(
blob_grouped_by_directory=get_blobs_grouped_by_dir(
blobs=[blob.name for blob in obj_list]
),
config=config,
client=client,
bucket_name=bucket.name,
)
except Exception as exc:
logger.debug(traceback.format_exc())
raise DBTConfigException(f"Error fetching dbt files from gcs: {exc}")
@@ -310,9 +328,6 @@ def _(config: DbtGcsConfig):
@get_dbt_details.register
def _(config: DbtAzureConfig):
dbt_catalog = None
dbt_manifest = None
dbt_run_results = None
try:
bucket_name, prefix = get_dbt_prefix_config(config)
from azure.identity import ( # pylint: disable=import-outside-toplevel
@@ -322,7 +337,7 @@ def _(config: DbtAzureConfig):
BlobServiceClient,
)
azure_client = BlobServiceClient(
client = BlobServiceClient(
f"https://{config.dbtSecurityConfig.accountName}.blob.core.windows.net/",
credential=ClientSecretCredential(
config.dbtSecurityConfig.tenantId,
@@ -332,50 +347,29 @@ def _(config: DbtAzureConfig):
)
if not bucket_name:
container_dicts = azure_client.list_containers()
container_dicts = client.list_containers()
containers = [
azure_client.get_container_client(container["name"])
client.get_container_client(container["name"])
for container in container_dicts
]
else:
container_client = azure_client.get_container_client(bucket_name)
container_client = client.get_container_client(bucket_name)
containers = [container_client]
for container_client in containers:
if prefix:
blob_list = container_client.list_blobs(name_starts_with=prefix)
else:
blob_list = container_client.list_blobs()
for blob in blob_list:
if DBT_MANIFEST_FILE_NAME in blob.name:
logger.debug(f"{DBT_MANIFEST_FILE_NAME} found")
dbt_manifest = (
container_client.download_blob(blob.name)
.readall()
.decode("utf-8")
)
if DBT_CATALOG_FILE_NAME in blob.name:
logger.debug(f"{DBT_CATALOG_FILE_NAME} found")
dbt_catalog = (
container_client.download_blob(blob.name)
.readall()
.decode("utf-8")
)
if DBT_RUN_RESULTS_FILE_NAME in blob.name:
logger.debug(f"{DBT_RUN_RESULTS_FILE_NAME} found")
dbt_run_results = (
container_client.download_blob(blob.name)
.readall()
.decode("utf-8")
)
if not dbt_manifest:
raise DBTConfigException("Manifest file not found in Azure")
return DbtFiles(
dbt_catalog=json.loads(dbt_catalog) if dbt_catalog else None,
dbt_manifest=json.loads(dbt_manifest),
dbt_run_results=json.loads(dbt_run_results) if dbt_run_results else None,
)
except DBTConfigException as exc:
raise exc
yield from download_dbt_files(
blob_grouped_by_directory=get_blobs_grouped_by_dir(
blobs=[blob.name for blob in blob_list]
),
config=config,
client=client,
bucket_name=container_client.container_name,
)
except Exception as exc:
logger.debug(traceback.format_exc())
raise DBTConfigException(f"Error fetching dbt files from Azure: {exc}")


@@ -55,6 +55,11 @@ class DbtServiceTopology(ServiceTopology):
root = TopologyNode(
producer="get_dbt_files",
stages=[],
children=["process_dbt_files"],
)
process_dbt_files = TopologyNode(
producer="process_dbt_files",
stages=[
NodeStage(
type_=DbtFiles,
@@ -160,22 +165,29 @@ class DbtServiceSource(TopologyRunnerMixin, Source, ABC):
}
)
def get_dbt_files(self) -> DbtFiles:
dbt_files = get_dbt_details(self.source_config.dbtConfigSource)
self.context.dbt_files = dbt_files
yield dbt_files
def process_dbt_files(self) -> Iterable[DbtFiles]:
"""
Method to return the dbt file from the topology context
"""
yield self.context.dbt_file
def get_dbt_objects(self) -> DbtObjects:
def get_dbt_files(self) -> Iterable[DbtFiles]:
dbt_files = get_dbt_details(self.source_config.dbtConfigSource)
for dbt_file in dbt_files:
self.context.dbt_file = dbt_file
yield dbt_file
def get_dbt_objects(self) -> Iterable[DbtObjects]:
self.remove_manifest_non_required_keys(
manifest_dict=self.context.dbt_files.dbt_manifest
manifest_dict=self.context.dbt_file.dbt_manifest
)
dbt_objects = DbtObjects(
dbt_catalog=parse_catalog(self.context.dbt_files.dbt_catalog)
if self.context.dbt_files.dbt_catalog
dbt_catalog=parse_catalog(self.context.dbt_file.dbt_catalog)
if self.context.dbt_file.dbt_catalog
else None,
dbt_manifest=parse_manifest(self.context.dbt_files.dbt_manifest),
dbt_run_results=parse_run_results(self.context.dbt_files.dbt_run_results)
if self.context.dbt_files.dbt_run_results
dbt_manifest=parse_manifest(self.context.dbt_file.dbt_manifest),
dbt_run_results=parse_run_results(self.context.dbt_file.dbt_run_results)
if self.context.dbt_file.dbt_run_results
else None,
)
yield dbt_objects
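The net effect of the topology change is that each DbtFiles yielded by get_dbt_files is processed independently; a rough sketch of the per-project flow (the driving loop is illustrative, the real iteration is handled by the topology runner):

# Rough sketch only; in practice the topology runner drives this
for dbt_file in source.get_dbt_files():           # one DbtFiles per dbt project directory
    # get_dbt_files stores the current project in source.context.dbt_file
    for dbt_objects in source.get_dbt_objects():  # parsed manifest/catalog/run_results
        ...                                       # data models are yielded downstream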
@@ -200,7 +212,7 @@ class DbtServiceSource(TopologyRunnerMixin, Source, ABC):
Yield the data models
"""
def get_data_model(self) -> DataModelLink:
def get_data_model(self) -> Iterable[DataModelLink]:
"""
Prepare the data models
"""


@@ -30,6 +30,18 @@ from metadata.generated.schema.entity.services.connections.database.datalake.s3C
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
LocalConfig,
)
from metadata.generated.schema.metadataIngestion.dbtconfig.dbtAzureConfig import (
DbtAzureConfig,
)
from metadata.generated.schema.metadataIngestion.dbtconfig.dbtGCSConfig import (
DbtGcsConfig,
)
from metadata.generated.schema.metadataIngestion.dbtconfig.dbtLocalConfig import (
DbtLocalConfig,
)
from metadata.generated.schema.metadataIngestion.dbtconfig.dbtS3Config import (
DbtS3Config,
)
from metadata.readers.file.adls import ADLSReader
from metadata.readers.file.base import Reader
from metadata.readers.file.gcs import GCSReader
@@ -42,6 +54,10 @@ CONFIG_SOURCE_READER = {
AzureConfig.__name__: ADLSReader,
GCSConfig.__name__: GCSReader,
S3Config.__name__: S3Reader,
DbtLocalConfig.__name__: LocalReader,
DbtAzureConfig.__name__: ADLSReader,
DbtGcsConfig.__name__: GCSReader,
DbtS3Config.__name__: S3Reader,
}
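With the dbt config classes registered in the factory, download_dbt_files can resolve a reader purely from the config type; a minimal usage sketch, assuming a DbtS3Config instance and a boto3 client (the path and bucket name are placeholders):

# Illustrative only; config instance, client, path and bucket name are assumptions
reader = get_reader(config_source=dbt_s3_config, client=s3_client)  # -> S3Reader
manifest_str = reader.read(path="dbt/target/manifest.json", bucket_name="my-dbt-bucket")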


@@ -0,0 +1,34 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
s3 utils module
"""
import traceback
from typing import Iterable
from metadata.utils.logger import utils_logger
logger = utils_logger()
def list_s3_objects(client, **kwargs) -> Iterable:
"""
Method to get list of s3 objects using pagination
"""
try:
paginator = client.get_paginator("list_objects_v2")
for page in paginator.paginate(**kwargs):
yield from page.get("Contents", [])
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Unexpected exception to yield s3 object: {exc}")
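A short usage sketch with a boto3 client; the bucket name and prefix are placeholders:

# Illustrative usage; bucket and prefix are placeholders
import boto3

client = boto3.client("s3")
for obj in list_s3_objects(client, Bucket="my-dbt-bucket", Prefix="dbt/"):
    print(obj["Key"])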