feat(ingest/dbt): make catalog.json optional (#13352)

Harshal Sheth 2025-04-29 10:39:53 -07:00 committed by GitHub
parent 0029cbedf6
commit d264a7afba
4 changed files with 131 additions and 90 deletions

metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py

@@ -10,14 +10,12 @@ from pydantic import Field, root_validator
 from datahub.ingestion.api.decorators import (
     SupportStatus,
-    capability,
     config_class,
     platform_name,
     support_status,
 )
 from datahub.ingestion.api.source import (
     CapabilityReport,
-    SourceCapability,
     TestableSource,
     TestConnectionReport,
 )
@@ -262,16 +260,14 @@ query DatahubMetadataQuery_{type}($jobId: BigInt!, $runId: BigInt) {{
 @platform_name("dbt")
 @config_class(DBTCloudConfig)
-@support_status(SupportStatus.INCUBATING)
-@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
-@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
+@support_status(SupportStatus.CERTIFIED)
 class DBTCloudSource(DBTSourceBase, TestableSource):
     config: DBTCloudConfig

     @classmethod
     def create(cls, config_dict, ctx):
         config = DBTCloudConfig.parse_obj(config_dict)
-        return cls(config, ctx, "dbt")
+        return cls(config, ctx)

     @staticmethod
     def test_connection(config_dict: dict) -> TestConnectionReport:
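
Since DBTSourceBase now hardcodes self.platform = "dbt", call sites stop threading a platform string through the constructor. A minimal sketch of the new calling convention (the config values are illustrative placeholders, not a verified set of required DBTCloudConfig fields):

    from datahub.ingestion.api.common import PipelineContext

    ctx = PipelineContext(run_id="dbt-cloud-demo")
    source = DBTCloudSource.create(
        {
            # Placeholder credentials/identifiers for illustration only.
            "token": "<dbt cloud service token>",
            "account_id": 12345,
            "project_id": 67890,
            "job_id": 42,
            "target_platform": "snowflake",
        },
        ctx,
    )
    assert source.platform == "dbt"  # now set unconditionally in the base class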

metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py

@@ -125,6 +125,7 @@ _DEFAULT_ACTOR = mce_builder.make_user_urn("unknown")
 @dataclass
 class DBTSourceReport(StaleEntityRemovalSourceReport):
     sql_parser_skipped_missing_code: LossyList[str] = field(default_factory=LossyList)
+    sql_parser_skipped_non_sql_model: LossyList[str] = field(default_factory=LossyList)
     sql_parser_parse_failures: int = 0
     sql_parser_detach_ctes_failures: int = 0
     sql_parser_table_errors: int = 0
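
The new counter is a LossyList, DataHub's bounded list that retains only a sample of appended items, so a project with thousands of non-SQL models cannot bloat the report. A rough sketch of the behavior this relies on (treating the exact cap semantics as an assumption, not a documented contract):

    from datahub.utilities.lossy_collections import LossyList

    skipped: LossyList[str] = LossyList()
    for i in range(10_000):
        skipped.append(f"model_{i}")
    # Only a bounded sample is retained; the report output stays small.
    print(skipped)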
@@ -829,11 +830,13 @@ def get_column_type(
     "Enabled by default, configure using `include_column_lineage`",
 )
 class DBTSourceBase(StatefulIngestionSourceBase):
-    def __init__(self, config: DBTCommonConfig, ctx: PipelineContext, platform: str):
+    def __init__(self, config: DBTCommonConfig, ctx: PipelineContext):
         super().__init__(config, ctx)
+        self.platform: str = "dbt"
+
         self.config = config
-        self.platform: str = platform
         self.report: DBTSourceReport = DBTSourceReport()
+
         self.compiled_owner_extraction_pattern: Optional[Any] = None
         if self.config.owner_extraction_pattern:
             self.compiled_owner_extraction_pattern = re.compile(
@@ -1177,6 +1180,11 @@ class DBTSourceBase(StatefulIngestionSourceBase):
             logger.debug(
                 f"Not generating CLL for {node.dbt_name} because we don't need it."
             )
+        elif node.language != "sql":
+            logger.debug(
+                f"Not generating CLL for {node.dbt_name} because it is not a SQL model."
+            )
+            self.report.sql_parser_skipped_non_sql_model.append(node.dbt_name)
         elif node.compiled_code:
             # Add CTE stops based on the upstreams list.
             cte_mapping = {
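
The new elif sits between the "CLL not needed" check and the SQL-parsing branch, so Python models skip column-level lineage but still flow through the rest of the pipeline. A self-contained sketch of the decision order (a simplification of _infer_schemas_and_update_cll, with a stand-in node type):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Node:  # stand-in for DBTNode, with just the fields inspected here
        dbt_name: str
        language: Optional[str]
        compiled_code: Optional[str]

    def cll_decision(node: Node, cll_needed: bool) -> str:
        if not cll_needed:
            return "skip: CLL not needed"
        elif node.language != "sql":
            # Non-SQL models can't go through the SQL parser; they are
            # recorded in sql_parser_skipped_non_sql_model instead.
            return "skip: non-SQL model"
        elif node.compiled_code:
            return "parse compiled SQL"
        return "skip: missing compiled code"

    assert cll_decision(Node("model1", "python", "import pandas"), True) == "skip: non-SQL model"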

metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py

@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import logging
 import re
@@ -12,16 +13,15 @@ from pydantic import BaseModel, Field, validator
 from datahub.configuration.git import GitReference
 from datahub.configuration.validate_field_rename import pydantic_renamed_field
+from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.decorators import (
     SupportStatus,
-    capability,
     config_class,
     platform_name,
     support_status,
 )
 from datahub.ingestion.api.source import (
     CapabilityReport,
-    SourceCapability,
     TestableSource,
     TestConnectionReport,
 )
@@ -40,19 +40,28 @@ from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult
 logger = logging.getLogger(__name__)


+@dataclasses.dataclass
+class DBTCoreReport(DBTSourceReport):
+    catalog_info: Optional[dict] = None
+    manifest_info: Optional[dict] = None
+
+
 class DBTCoreConfig(DBTCommonConfig):
     manifest_path: str = Field(
-        description="Path to dbt manifest JSON. See https://docs.getdbt.com/reference/artifacts/manifest-json Note "
-        "this can be a local file or a URI."
+        description="Path to dbt manifest JSON. See https://docs.getdbt.com/reference/artifacts/manifest-json. "
+        "This can be a local file or a URI."
     )
-    catalog_path: str = Field(
-        description="Path to dbt catalog JSON. See https://docs.getdbt.com/reference/artifacts/catalog-json Note this "
-        "can be a local file or a URI."
+    catalog_path: Optional[str] = Field(
+        None,
+        description="Path to dbt catalog JSON. See https://docs.getdbt.com/reference/artifacts/catalog-json. "
+        "This file is optional, but highly recommended. Without it, some metadata like column info will be incomplete or missing. "
+        "This can be a local file or a URI.",
     )
     sources_path: Optional[str] = Field(
         default=None,
-        description="Path to dbt sources JSON. See https://docs.getdbt.com/reference/artifacts/sources-json. If not "
-        "specified, last-modified fields will not be populated. Note this can be a local file or a URI.",
+        description="Path to dbt sources JSON. See https://docs.getdbt.com/reference/artifacts/sources-json. "
+        "If not specified, last-modified fields will not be populated. "
+        "This can be a local file or a URI.",
     )
     run_results_paths: List[str] = Field(
         default=[],
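
With catalog_path now Optional[str] defaulting to None, a manifest-only config validates. A quick pydantic check (paths are placeholders, and manifest_path plus target_platform are assumed to be the only required fields):

    config = DBTCoreConfig.parse_obj(
        {
            "manifest_path": "target/manifest.json",  # placeholder path
            "target_platform": "postgres",
            # catalog_path omitted: previously a hard validation error
        }
    )
    assert config.catalog_path is None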
@@ -161,7 +170,7 @@ def get_columns(
 def extract_dbt_entities(
     all_manifest_entities: Dict[str, Dict[str, Any]],
-    all_catalog_entities: Dict[str, Dict[str, Any]],
+    all_catalog_entities: Optional[Dict[str, Dict[str, Any]]],
     sources_results: List[Dict[str, Any]],
     manifest_adapter: str,
     use_identifiers: bool,
@@ -186,15 +195,6 @@ def extract_dbt_entities(
         ):
             name = manifest_node["alias"]

-        # initialize comment to "" for consistency with descriptions
-        # (since dbt null/undefined descriptions as "")
-        comment = ""
-        if key in all_catalog_entities and all_catalog_entities[key]["metadata"].get(
-            "comment"
-        ):
-            comment = all_catalog_entities[key]["metadata"]["comment"]
-
         materialization = None
         if "materialized" in manifest_node.get("config", {}):
             # It's a model
@@ -204,8 +204,9 @@ def extract_dbt_entities(
         if "depends_on" in manifest_node and "nodes" in manifest_node["depends_on"]:
             upstream_nodes = manifest_node["depends_on"]["nodes"]

-        # It's a source
-        catalog_node = all_catalog_entities.get(key)
+        catalog_node = (
+            all_catalog_entities.get(key) if all_catalog_entities is not None else None
+        )
         missing_from_catalog = catalog_node is None
         catalog_type = None
@@ -214,16 +215,23 @@ def extract_dbt_entities(
                 # Test and ephemeral nodes will never show up in the catalog.
                 missing_from_catalog = False
             else:
-                if not only_include_if_in_catalog:
+                if all_catalog_entities is not None and not only_include_if_in_catalog:
+                    # If the catalog file is missing, we have already generated a general message.
                     report.warning(
                         title="Node missing from catalog",
                         message="Found a node in the manifest file but not in the catalog. "
                         "This usually means the catalog file was not generated by `dbt docs generate` and so is incomplete. "
-                        "Some metadata, such as column types and descriptions, will be impacted.",
+                        "Some metadata, particularly schema information, will be impacted.",
                         context=key,
                     )
         else:
-            catalog_type = all_catalog_entities[key]["metadata"]["type"]
+            catalog_type = catalog_node["metadata"]["type"]

+        # initialize comment to "" for consistency with descriptions
+        # (since dbt null/undefined descriptions as "")
+        comment = ""
+        if catalog_node is not None and catalog_node.get("metadata", {}).get("comment"):
+            comment = catalog_node["metadata"]["comment"]
+
         query_tag_props = manifest_node.get("query_tag", {})
@@ -231,12 +239,15 @@ def extract_dbt_entities(
         owner = meta.get("owner")
         if owner is None:
-            owner = manifest_node.get("config", {}).get("meta", {}).get("owner")
+            owner = (manifest_node.get("config", {}).get("meta") or {}).get("owner")
+
+        if not meta:
+            # On older versions of dbt, the meta field was nested under config
+            # for some node types.
+            meta = manifest_node.get("config", {}).get("meta") or {}

         tags = manifest_node.get("tags", [])
         tags = [tag_prefix + tag for tag in tags]
-        if not meta:
-            meta = manifest_node.get("config", {}).get("meta", {})

         max_loaded_at_str = sources_by_id.get(key, {}).get("max_loaded_at")
         max_loaded_at = None
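
The switch from .get("meta", {}) to .get("meta") or {} is subtle but deliberate: dict.get only applies its default when the key is absent, so a manifest that serializes "meta": null would previously hand None to the chained .get("owner"). A self-contained illustration:

    node = {"config": {"meta": None}}  # dbt can emit meta as an explicit null

    meta_old = node.get("config", {}).get("meta", {})    # -> None, not {}
    meta_new = node.get("config", {}).get("meta") or {}  # -> {}

    assert meta_old is None
    assert meta_new.get("owner") is None  # chained lookup is now safe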
@@ -453,15 +464,18 @@ def load_run_results(
 @platform_name("dbt")
 @config_class(DBTCoreConfig)
 @support_status(SupportStatus.CERTIFIED)
-@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
-@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
 class DBTCoreSource(DBTSourceBase, TestableSource):
     config: DBTCoreConfig
+    report: DBTCoreReport
+
+    def __init__(self, config: DBTCommonConfig, ctx: PipelineContext):
+        super().__init__(config, ctx)
+        self.report = DBTCoreReport()

     @classmethod
     def create(cls, config_dict, ctx):
         config = DBTCoreConfig.parse_obj(config_dict)
-        return cls(config, ctx, "dbt")
+        return cls(config, ctx)

     @staticmethod
     def test_connection(config_dict: dict) -> TestConnectionReport:
@@ -471,9 +485,10 @@ class DBTCoreSource(DBTSourceBase, TestableSource):
             DBTCoreSource.load_file_as_json(
                 source_config.manifest_path, source_config.aws_connection
             )
-            DBTCoreSource.load_file_as_json(
-                source_config.catalog_path, source_config.aws_connection
-            )
+            if source_config.catalog_path is not None:
+                DBTCoreSource.load_file_as_json(
+                    source_config.catalog_path, source_config.aws_connection
+                )
             test_report.basic_connectivity = CapabilityReport(capable=True)
         except Exception as e:
             test_report.basic_connectivity = CapabilityReport(
@@ -511,46 +526,30 @@ class DBTCoreSource(DBTSourceBase, TestableSource):
         dbt_manifest_json = self.load_file_as_json(
             self.config.manifest_path, self.config.aws_connection
         )
+        dbt_manifest_metadata = dbt_manifest_json["metadata"]
+        self.report.manifest_info = dict(
+            generated_at=dbt_manifest_metadata.get("generated_at", "unknown"),
+            dbt_version=dbt_manifest_metadata.get("dbt_version", "unknown"),
+            project_name=dbt_manifest_metadata.get("project_name", "unknown"),
+        )

-        dbt_catalog_json = self.load_file_as_json(
-            self.config.catalog_path, self.config.aws_connection
-        )
-
-        self.report.info(
-            title="DBT metadata files",
-            message="Manifest metadata",
-            context=str(
-                dict(
-                    generated_at=dbt_manifest_json["metadata"].get(
-                        "generated_at", "unknown"
-                    ),
-                    dbt_version=dbt_manifest_json["metadata"].get(
-                        "dbt_version", "unknown"
-                    ),
-                    project_name=dbt_manifest_json["metadata"].get(
-                        "project_name", "unknown"
-                    ),
-                )
-            ),
-        )
-        self.report.info(
-            title="DBT metadata files",
-            message="Catalog metadata",
-            context=str(
-                dict(
-                    generated_at=dbt_catalog_json.get("metadata", {}).get(
-                        "generated_at", "unknown"
-                    ),
-                    dbt_version=dbt_catalog_json.get("metadata", {}).get(
-                        "dbt_version", "unknown"
-                    ),
-                    project_name=dbt_catalog_json.get("metadata", {}).get(
-                        "project_name", "unknown"
-                    ),
-                )
-            ),
-        )
+        dbt_catalog_json = None
+        dbt_catalog_metadata = None
+        if self.config.catalog_path is not None:
+            dbt_catalog_json = self.load_file_as_json(
+                self.config.catalog_path, self.config.aws_connection
+            )
+            dbt_catalog_metadata = dbt_catalog_json.get("metadata", {})
+            self.report.catalog_info = dict(
+                generated_at=dbt_catalog_metadata.get("generated_at", "unknown"),
+                dbt_version=dbt_catalog_metadata.get("dbt_version", "unknown"),
+                project_name=dbt_catalog_metadata.get("project_name", "unknown"),
+            )
+        else:
+            self.report.warning(
+                title="No catalog file configured",
+                message="Some metadata, particularly schema information, will be missing.",
+            )

         if self.config.sources_path is not None:
             dbt_sources_json = self.load_file_as_json(
@@ -564,18 +563,23 @@ class DBTCoreSource(DBTSourceBase, TestableSource):
         manifest_version = dbt_manifest_json["metadata"].get("dbt_version")
         manifest_adapter = dbt_manifest_json["metadata"].get("adapter_type")

-        catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
-        catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")
+        catalog_schema = None
+        catalog_version = None
+        if dbt_catalog_metadata is not None:
+            catalog_schema = dbt_catalog_metadata.get("dbt_schema_version")
+            catalog_version = dbt_catalog_metadata.get("dbt_version")

         manifest_nodes = dbt_manifest_json["nodes"]
         manifest_sources = dbt_manifest_json["sources"]

         all_manifest_entities = {**manifest_nodes, **manifest_sources}

-        catalog_nodes = dbt_catalog_json["nodes"]
-        catalog_sources = dbt_catalog_json["sources"]
-
-        all_catalog_entities = {**catalog_nodes, **catalog_sources}
+        all_catalog_entities = None
+        if dbt_catalog_json is not None:
+            catalog_nodes = dbt_catalog_json["nodes"]
+            catalog_sources = dbt_catalog_json["sources"]
+            all_catalog_entities = {**catalog_nodes, **catalog_sources}

         nodes = extract_dbt_entities(
             all_manifest_entities=all_manifest_entities,
@@ -626,7 +630,7 @@ class DBTCoreSource(DBTSourceBase, TestableSource):
             except Exception as e:
                 self.report.info(
-                    title="Dbt Catalog Version",
+                    title="dbt Catalog Version",
                     message="Failed to determine the catalog version",
                     exc=e,
                 )
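
Moving the manifest/catalog summaries from report.info(...) log entries onto structured DBTCoreReport fields makes them inspectable after a run. A sketch (assumes the source has already loaded its artifacts, e.g. via loadManifestAndCatalog):

    report = source.report  # a DBTCoreReport
    print(report.manifest_info)  # {'generated_at': ..., 'dbt_version': ..., 'project_name': ...}
    if report.catalog_info is None:
        print("no catalog.json was provided; schema details may be missing")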

metadata-ingestion/tests/unit/dbt/test_dbt.py

@@ -61,7 +61,7 @@ def create_mocked_dbt_source() -> DBTCoreSource:
         ["non_dbt_existing", "dbt:existing"]
     )
     ctx.graph = graph
-    return DBTCoreSource(DBTCoreConfig(**create_base_dbt_config()), ctx, "dbt")
+    return DBTCoreSource(DBTCoreConfig(**create_base_dbt_config()), ctx)


 def create_base_dbt_config() -> Dict:
@@ -268,7 +268,7 @@ def test_dbt_prefer_sql_parser_lineage_no_self_reference():
             "prefer_sql_parser_lineage": True,
         }
     )
-    source: DBTCoreSource = DBTCoreSource(config, ctx, "dbt")
+    source: DBTCoreSource = DBTCoreSource(config, ctx)

     all_nodes_map = {
         "model1": DBTNode(
             name="model1",
@@ -277,7 +277,7 @@ def test_dbt_prefer_sql_parser_lineage_no_self_reference():
             alias=None,
             comment="",
             description="",
-            language=None,
+            language="sql",
             raw_code=None,
             dbt_adapter="postgres",
             dbt_name="model1",
@@ -300,6 +300,39 @@ def test_dbt_prefer_sql_parser_lineage_no_self_reference():
     assert len(upstream_lineage.upstreams) == 1


+def test_dbt_cll_skip_python_model() -> None:
+    ctx = PipelineContext(run_id="test-run-id")
+    config = DBTCoreConfig.parse_obj(create_base_dbt_config())
+    source: DBTCoreSource = DBTCoreSource(config, ctx)
+
+    all_nodes_map = {
+        "model1": DBTNode(
+            name="model1",
+            database=None,
+            schema=None,
+            alias=None,
+            comment="",
+            description="",
+            language="python",
+            raw_code=None,
+            dbt_adapter="postgres",
+            dbt_name="model1",
+            dbt_file_path=None,
+            dbt_package_name=None,
+            node_type="model",
+            materialization="table",
+            max_loaded_at=None,
+            catalog_type=None,
+            missing_from_catalog=False,
+            owner=None,
+            compiled_code="import pandas as pd\n# Other processing here...",
+        ),
+    }
+
+    source._infer_schemas_and_update_cll(all_nodes_map)
+
+    assert len(source.report.sql_parser_skipped_non_sql_model) == 1
+    # TODO: Also test that table-level lineage is still created.
+
+
 def test_dbt_s3_config():
     # test missing aws config
     config_dict: dict = {
@@ -526,8 +559,8 @@ def test_extract_dbt_entities():
         catalog_path="tests/unit/dbt/artifacts/catalog.json",
         target_platform="dummy",
     )
-    source = DBTCoreSource(config, ctx, "dbt")
+    source = DBTCoreSource(config, ctx)
     assert all(node.database is not None for node in source.loadManifestAndCatalog()[0])

     config.include_database_name = False
-    source = DBTCoreSource(config, ctx, "dbt")
+    source = DBTCoreSource(config, ctx)
     assert all(node.database is None for node in source.loadManifestAndCatalog()[0])
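
Taken together, a dbt Core recipe can now omit catalog_path entirely and still ingest, at the cost of a warning and reduced schema metadata. A minimal programmatic run, sketched with placeholder paths and the console sink:

    from datahub.ingestion.run.pipeline import Pipeline

    pipeline = Pipeline.create(
        {
            "run_id": "dbt-no-catalog",
            "source": {
                "type": "dbt",
                "config": {
                    "manifest_path": "target/manifest.json",  # placeholder
                    "target_platform": "postgres",
                    # catalog_path omitted: ingestion proceeds with a warning
                },
            },
            "sink": {"type": "console"},
        }
    )
    pipeline.run()
    pipeline.raise_from_status()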