diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/config.py b/metadata-ingestion/src/datahub/ingestion/source/unity/config.py index 7e020bbc9e..7ee3aed992 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/config.py @@ -1,10 +1,11 @@ -from datetime import datetime, timezone -from typing import Dict, Optional +import os +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional import pydantic from pydantic import Field -from datahub.configuration.common import AllowDenyPattern +from datahub.configuration.common import AllowDenyPattern, ConfigModel from datahub.configuration.source_common import DatasetSourceConfigMixin from datahub.configuration.validate_field_rename import pydantic_renamed_field from datahub.ingestion.source.state.stale_entity_removal_handler import ( @@ -12,12 +13,71 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import ( ) from datahub.ingestion.source.state.stateful_ingestion_base import ( StatefulIngestionConfigBase, + StatefulProfilingConfigMixin, ) from datahub.ingestion.source.usage.usage_common import BaseUsageConfig +class UnityCatalogProfilerConfig(ConfigModel): + # TODO: Reduce duplicate code with DataLakeProfilerConfig, GEProfilingConfig, SQLAlchemyConfig + enabled: bool = Field( + default=False, description="Whether profiling should be done." + ) + + warehouse_id: Optional[str] = Field( + default=None, description="SQL Warehouse id, for running profiling queries." + ) + + profile_table_level_only: bool = Field( + default=False, + description="Whether to perform profiling at table-level only or include column-level profiling as well.", + ) + + pattern: AllowDenyPattern = Field( + default=AllowDenyPattern.allow_all(), + description=( + "Regex patterns to filter tables for profiling during ingestion. " + "Specify regex to match the `catalog.schema.table` format. " + "Note that only tables allowed by the `table_pattern` will be considered." + ), + ) + + call_analyze: bool = Field( + default=True, + description=( + "Whether to call ANALYZE TABLE as part of profile ingestion." + "If false, will ingest the results of the most recent ANALYZE TABLE call, if any." + ), + ) + + max_wait_secs: int = Field( + default=int(timedelta(hours=1).total_seconds()), + description="Maximum time to wait for an ANALYZE TABLE query to complete.", + ) + + max_workers: int = Field( + default=5 * (os.cpu_count() or 4), + description="Number of worker threads to use for profiling. 
Set to 1 to disable.", + ) + + @pydantic.root_validator + def warehouse_id_required_for_profiling( + cls, values: Dict[str, Any] + ) -> Dict[str, Any]: + if values.get("enabled") and not values.get("warehouse_id"): + raise ValueError("warehouse_id must be set when profiling is enabled.") + return values + + @property + def include_columns(self): + return not self.profile_table_level_only + + class UnityCatalogSourceConfig( - StatefulIngestionConfigBase, BaseUsageConfig, DatasetSourceConfigMixin + StatefulIngestionConfigBase, + BaseUsageConfig, + DatasetSourceConfigMixin, + StatefulProfilingConfigMixin, ): token: str = pydantic.Field(description="Databricks personal access token") workspace_url: str = pydantic.Field( @@ -76,6 +136,10 @@ class UnityCatalogSourceConfig( description="Generate usage statistics.", ) + profiling: UnityCatalogProfilerConfig = Field( + default=UnityCatalogProfilerConfig(), description="Data profiling configuration" + ) + stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field( default=None, description="Unity Catalog Stateful Ingestion Config." ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/profiler.py b/metadata-ingestion/src/datahub/ingestion/source/unity/profiler.py new file mode 100644 index 0000000000..9609f41997 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/profiler.py @@ -0,0 +1,119 @@ +import logging +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Callable, Collection, Iterable, Optional + +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.unity.config import UnityCatalogProfilerConfig +from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy +from datahub.ingestion.source.unity.proxy_types import ( + ColumnProfile, + TableProfile, + TableReference, +) +from datahub.ingestion.source.unity.report import UnityCatalogReport +from datahub.metadata.schema_classes import ( + DatasetFieldProfileClass, + DatasetProfileClass, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class UnityCatalogProfiler: + config: UnityCatalogProfilerConfig + report: UnityCatalogReport + proxy: UnityCatalogApiProxy + dataset_urn_builder: Callable[[TableReference], str] + + def get_workunits( + self, table_refs: Collection[TableReference] + ) -> Iterable[MetadataWorkUnit]: + try: + tables = self._filter_tables(table_refs) + with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor: + futures = [executor.submit(self.process_table, ref) for ref in tables] + for future in as_completed(futures): + wu: Optional[MetadataWorkUnit] = future.result() + if wu: + self.report.num_profile_workunits_emitted += 1 + yield wu + except Exception as e: + self.report.report_warning("profiling", str(e)) + logger.warning(f"Unexpected error during profiling: {e}", exc_info=True) + return + + def _filter_tables( + self, table_refs: Collection[TableReference] + ) -> Collection[TableReference]: + return [ + ref + for ref in table_refs + if self.config.pattern.allowed(ref.qualified_table_name) + ] + + def process_table(self, ref: TableReference) -> Optional[MetadataWorkUnit]: + try: + table_profile = self.proxy.get_table_stats( + ref, + max_wait_secs=self.config.max_wait_secs, + call_analyze=self.config.call_analyze, + include_columns=self.config.include_columns, + ) + if table_profile: + return 
self.gen_dataset_profile_workunit(ref, table_profile) + elif table_profile is not None: # table_profile is Falsy == empty + self.report.profile_table_empty.append(str(ref)) + except Exception as e: + self.report.report_warning("profiling", str(e)) + logger.warning( + f"Unexpected error during profiling table {ref}: {e}", exc_info=True + ) + return None + + def gen_dataset_profile_workunit( + self, ref: TableReference, table_profile: TableProfile + ) -> MetadataWorkUnit: + row_count = table_profile.num_rows + aspect = DatasetProfileClass( + timestampMillis=int(time.time() * 1000), + rowCount=row_count, + columnCount=table_profile.num_columns, + sizeInBytes=table_profile.total_size, + fieldProfiles=[ + self._gen_dataset_field_profile(row_count, column_profile) + for column_profile in table_profile.column_profiles + if column_profile # Drop column profiles with no data + ] + if self.config.include_columns + else None, + ) + return MetadataChangeProposalWrapper( + entityUrn=self.dataset_urn_builder(ref), + aspect=aspect, + ).as_workunit() + + @staticmethod + def _gen_dataset_field_profile( + num_rows: Optional[int], column_profile: ColumnProfile + ) -> DatasetFieldProfileClass: + unique_proportion: Optional[float] = None + null_proportion: Optional[float] = None + if num_rows: + if column_profile.distinct_count is not None: + unique_proportion = min(1.0, column_profile.distinct_count / num_rows) + if column_profile.null_count is not None: + null_proportion = min(1.0, column_profile.null_count / num_rows) + + return DatasetFieldProfileClass( + fieldPath=column_profile.name, + uniqueCount=column_profile.distinct_count, + uniqueProportion=unique_proportion, + nullCount=column_profile.null_count, + nullProportion=null_proportion, + min=column_profile.min, + max=column_profile.max, + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py index b246aa1375..3c2f6cc712 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py @@ -16,6 +16,9 @@ from databricks.sdk.service.sql import ( from databricks_cli.sdk.api_client import ApiClient from databricks_cli.unity_catalog.api import UnityCatalogApi +from datahub.ingestion.source.unity.proxy_profiling import ( + UnityCatalogProxyProfilingMixin, +) from datahub.ingestion.source.unity.proxy_types import ( ALLOWED_STATEMENT_TYPES, DATA_TYPE_REGISTRY, @@ -48,14 +51,19 @@ class QueryFilterWithStatementTypes(QueryFilter): return v -class UnityCatalogApiProxy: - _unity_catalog_api: UnityCatalogApi +class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin): _workspace_client: WorkspaceClient + _unity_catalog_api: UnityCatalogApi _workspace_url: str report: UnityCatalogReport + warehouse_id: str def __init__( - self, workspace_url: str, personal_access_token: str, report: UnityCatalogReport + self, + workspace_url: str, + personal_access_token: str, + warehouse_id: Optional[str], + report: UnityCatalogReport, ): self._workspace_client = WorkspaceClient( host=workspace_url, token=personal_access_token @@ -63,7 +71,7 @@ class UnityCatalogApiProxy: self._unity_catalog_api = UnityCatalogApi( ApiClient(host=workspace_url, token=personal_access_token) ) - self._workspace_url = workspace_url + self.warehouse_id = warehouse_id or "" self.report = report def check_connectivity(self) -> bool: @@ -131,6 +139,7 @@ class UnityCatalogApiProxy: yield self._create_table(schema=schema, obj=table) def 
service_principals(self) -> Iterable[ServicePrincipal]: + # TODO: Replace with self._workspace_client.service_principals.list() when it supports pagination start_index = 1 # Unfortunately 1-indexed items_per_page = 0 total_results = float("inf") diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_profiling.py b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_profiling.py new file mode 100644 index 0000000000..fe74a2f4c0 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_profiling.py @@ -0,0 +1,238 @@ +import logging +import time +from typing import Optional, Union + +from databricks.sdk import WorkspaceClient +from databricks.sdk.core import DatabricksError +from databricks.sdk.service._internal import Wait +from databricks.sdk.service.catalog import TableInfo +from databricks.sdk.service.sql import ( + ExecuteStatementResponse, + GetStatementResponse, + GetWarehouseResponse, + StatementState, + StatementStatus, +) +from databricks_cli.unity_catalog.api import UnityCatalogApi + +from datahub.ingestion.source.unity.proxy_types import ( + ColumnProfile, + TableProfile, + TableReference, +) +from datahub.ingestion.source.unity.report import UnityCatalogReport +from datahub.utilities.lossy_collections import LossyList + +logger: logging.Logger = logging.getLogger(__name__) + + +# TODO: Move to separate proxy/ directory with rest of proxy code +class UnityCatalogProxyProfilingMixin: + _workspace_client: WorkspaceClient + _unity_catalog_api: UnityCatalogApi + report: UnityCatalogReport + warehouse_id: str + + def check_profiling_connectivity(self): + self._workspace_client.warehouses.get(self.warehouse_id) + return True + + def start_warehouse(self) -> Optional[Wait[GetWarehouseResponse]]: + """Starts a databricks SQL warehouse. + + Returns: + - A Wait object that can be used to wait for warehouse start completion. + - None if the warehouse does not exist. + """ + try: + return self._workspace_client.warehouses.start(self.warehouse_id) + except DatabricksError as e: + logger.warning(f"Unable to start warehouse -- are you sure it exists? {e}") + return None + + def get_table_stats( + self, + ref: TableReference, + *, + max_wait_secs: int, + call_analyze: bool, + include_columns: bool, + ) -> Optional[TableProfile]: + """Returns profiling information for a table. + + Performs three steps: + 1. Call ANALYZE TABLE to compute statistics for all columns + 2. Poll for ANALYZE completion with exponential backoff, with `max_wait_secs` timeout + 3. Get the ANALYZE result via the properties field in the tables API. + This is supposed to be returned by a DESCRIBE TABLE EXTENDED command, but I don't see it. 
+ + Raises: + DatabricksError: If any of the above steps fail + """ + + # Currently uses databricks sdk, which is synchronous + # If we need to improve performance, we can manually make requests via aiohttp + try: + if call_analyze: + response = self._analyze_table(ref, include_columns=include_columns) + success = self._check_analyze_table_statement_status( + response, max_wait_secs=max_wait_secs + ) + if not success: + self.report.profile_table_timeouts.append(str(ref)) + return None + return self._get_table_profile(ref, include_columns=include_columns) + except DatabricksError as e: + # Attempt to parse out generic part of error message + msg = str(e) + idx = (str(msg).find("`") + 1) or (str(msg).find("'") + 1) or len(str(msg)) + base_msg = msg[:idx] + self.report.profile_table_errors.setdefault(base_msg, LossyList()).append( + (str(ref), msg) + ) + logger.warning( + f"Failure during profiling {ref}, {e.kwargs}: ({e.error_code}) {e}", + exc_info=True, + ) + + if ( + call_analyze + and include_columns + and self._should_retry_unsupported_column(ref, e) + ): + return self.get_table_stats( + ref, + max_wait_secs=max_wait_secs, + call_analyze=call_analyze, + include_columns=False, + ) + else: + return None + + def _should_retry_unsupported_column( + self, ref: TableReference, e: DatabricksError + ) -> bool: + if "[UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE]" in str(e): + logger.info( + f"Attempting to profile table without columns due to unsupported column type: {ref}" + ) + self.report.num_profile_failed_unsupported_column_type += 1 + return True + return False + + def _analyze_table( + self, ref: TableReference, include_columns: bool + ) -> ExecuteStatementResponse: + statement = f"ANALYZE TABLE {ref.schema}.{ref.table} COMPUTE STATISTICS" + if include_columns: + statement += " FOR ALL COLUMNS" + response = self._workspace_client.statement_execution.execute_statement( + statement=statement, + catalog=ref.catalog, + wait_timeout="0s", # Fetch result asynchronously + warehouse_id=self.warehouse_id, + ) + self._raise_if_error(response, "analyze-table") + return response + + def _check_analyze_table_statement_status( + self, execute_response: ExecuteStatementResponse, max_wait_secs: int + ) -> bool: + statement_id: str = execute_response.statement_id + status: StatementStatus = execute_response.status + + backoff_sec = 1 + total_wait_time = 0 + while ( + total_wait_time < max_wait_secs and status.state != StatementState.SUCCEEDED + ): + time.sleep(min(backoff_sec, max_wait_secs - total_wait_time)) + total_wait_time += backoff_sec + backoff_sec *= 2 + + response = self._workspace_client.statement_execution.get_statement( + statement_id + ) + self._raise_if_error(response, "get-statement") + status = response.status + + return status.state == StatementState.SUCCEEDED + + def _get_table_profile( + self, ref: TableReference, include_columns: bool + ) -> TableProfile: + table_info = self._workspace_client.tables.get(ref.qualified_table_name) + return self._create_table_profile(table_info, include_columns=include_columns) + + def _create_table_profile( + self, table_info: TableInfo, include_columns: bool + ) -> TableProfile: + # Warning: this implementation is brittle -- dependent on properties that can change + columns_names = [column.name for column in table_info.columns] + return TableProfile( + num_rows=self._get_int(table_info, "spark.sql.statistics.numRows"), + total_size=self._get_int(table_info, "spark.sql.statistics.totalSize"), + num_columns=len(columns_names), + column_profiles=[ 
+ self._create_column_profile(column, table_info) + for column in columns_names + ] + if include_columns + else [], + ) + + def _create_column_profile( + self, column: str, table_info: TableInfo + ) -> ColumnProfile: + return ColumnProfile( + name=column, + null_count=self._get_int( + table_info, f"spark.sql.statistics.colStats.{column}.nullCount" + ), + distinct_count=self._get_int( + table_info, f"spark.sql.statistics.colStats.{column}.distinctCount" + ), + min=table_info.properties.get( + f"spark.sql.statistics.colStats.{column}.min" + ), + max=table_info.properties.get( + f"spark.sql.statistics.colStats.{column}.max" + ), + avg_len=table_info.properties.get( + f"spark.sql.statistics.colStats.{column}.avgLen" + ), + max_len=table_info.properties.get( + f"spark.sql.statistics.colStats.{column}.maxLen" + ), + version=table_info.properties.get( + f"spark.sql.statistics.colStats.{column}.version" + ), + ) + + def _get_int(self, table_info: TableInfo, field: str) -> Optional[int]: + value = table_info.properties.get(field) + if value is not None: + try: + return int(value) + except ValueError: + logger.warning( + f"Failed to parse int for {table_info.name} - {field}: {value}" + ) + self.report.num_profile_failed_int_casts += 1 + return None + + @staticmethod + def _raise_if_error( + response: Union[ExecuteStatementResponse, GetStatementResponse], key: str + ) -> None: + if response.status.state in [ + StatementState.FAILED, + StatementState.CANCELED, + StatementState.CLOSED, + ]: + raise DatabricksError( + response.status.error.message, + error_code=response.status.error.error_code.value, + status=response.status.state.value, + context=key, + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py index af7a7d68c5..729d27a6d4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py @@ -174,3 +174,44 @@ class Query: # User whose credentials were used to run the query executed_as_user_id: int executed_as_user_name: str + + +@dataclass +class TableProfile: + num_rows: Optional[int] + num_columns: Optional[int] + total_size: Optional[int] + column_profiles: List["ColumnProfile"] + + def __bool__(self): + return any( + ( + self.num_rows is not None, + self.num_columns is not None, + self.total_size is not None, + any(self.column_profiles), + ) + ) + + +@dataclass +class ColumnProfile: + name: str + null_count: Optional[int] + distinct_count: Optional[int] + min: Optional[str] + max: Optional[str] + + version: Optional[str] + avg_len: Optional[str] + max_len: Optional[str] + + def __bool__(self): + return any( + ( + self.null_count is not None, + self.distinct_count is not None, + self.min is not None, + self.max is not None, + ) + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/report.py b/metadata-ingestion/src/datahub/ingestion/source/unity/report.py index 0788deb827..2b6fa649eb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/report.py @@ -1,9 +1,11 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Tuple from datahub.ingestion.api.report import EntityFilterReport from datahub.ingestion.source.state.stale_entity_removal_handler import ( StaleEntityRemovalSourceReport, ) +from datahub.utilities.lossy_collections import LossyDict, LossyList 
@dataclass @@ -12,6 +14,7 @@ class UnityCatalogReport(StaleEntityRemovalSourceReport): catalogs: EntityFilterReport = EntityFilterReport.field(type="catalog") schemas: EntityFilterReport = EntityFilterReport.field(type="schema") tables: EntityFilterReport = EntityFilterReport.field(type="table/view") + table_profiles: EntityFilterReport = EntityFilterReport.field(type="table profile") num_queries: int = 0 num_queries_dropped_parse_failure: int = 0 @@ -21,3 +24,12 @@ class UnityCatalogReport(StaleEntityRemovalSourceReport): num_operational_stats_workunits_emitted: int = 0 num_usage_workunits_emitted: int = 0 + + profile_table_timeouts: LossyList[str] = field(default_factory=LossyList) + profile_table_empty: LossyList[str] = field(default_factory=LossyList) + profile_table_errors: LossyDict[str, LossyList[Tuple[str, str]]] = field( + default_factory=LossyDict + ) + num_profile_failed_unsupported_column_type: int = 0 + num_profile_failed_int_casts: int = 0 + num_profile_workunits_emitted: int = 0 diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py index b6c77bb8f7..a288d3d015 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py @@ -1,6 +1,7 @@ import logging import re import time +from datetime import timedelta from typing import Dict, Iterable, List, Optional, Set from datahub.emitter.mce_builder import ( @@ -45,15 +46,18 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import ( from datahub.ingestion.source.state.stateful_ingestion_base import ( StatefulIngestionSourceBase, ) -from datahub.ingestion.source.unity import proxy from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig -from datahub.ingestion.source.unity.proxy import ( +from datahub.ingestion.source.unity.profiler import UnityCatalogProfiler +from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy +from datahub.ingestion.source.unity.proxy_types import ( Catalog, + Column, Metastore, Schema, ServicePrincipal, + Table, + TableReference, ) -from datahub.ingestion.source.unity.proxy_types import TableReference from datahub.ingestion.source.unity.report import UnityCatalogReport from datahub.ingestion.source.unity.usage import UnityCatalogUsageExtractor from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( @@ -115,7 +119,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): """ config: UnityCatalogSourceConfig - unity_catalog_api_proxy: proxy.UnityCatalogApiProxy + unity_catalog_api_proxy: UnityCatalogApiProxy platform: str = "databricks" platform_instance_name: str @@ -127,8 +131,11 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): self.config = config self.report: UnityCatalogReport = UnityCatalogReport() - self.unity_catalog_api_proxy = proxy.UnityCatalogApiProxy( - config.workspace_url, config.token, report=self.report + self.unity_catalog_api_proxy = UnityCatalogApiProxy( + config.workspace_url, + config.token, + config.profiling.warehouse_id, + report=self.report, ) # Determine the platform_instance_name @@ -156,21 +163,43 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): self.service_principals: Dict[str, ServicePrincipal] = {} # Global set of table refs self.table_refs: Set[TableReference] = set() + self.view_refs: Set[TableReference] = set() @staticmethod def test_connection(config_dict: dict) -> 
TestConnectionReport: test_report = TestConnectionReport() + test_report.capability_report = {} try: config = UnityCatalogSourceConfig.parse_obj_allow_extras(config_dict) report = UnityCatalogReport() - unity_proxy = proxy.UnityCatalogApiProxy( - config.workspace_url, config.token, report=report + unity_proxy = UnityCatalogApiProxy( + config.workspace_url, + config.token, + config.profiling.warehouse_id, + report=report, ) if unity_proxy.check_connectivity(): test_report.basic_connectivity = CapabilityReport(capable=True) else: test_report.basic_connectivity = CapabilityReport(capable=False) + # TODO: Refactor into separate file / method + if config.profiling.enabled and not config.profiling.warehouse_id: + test_report.capability_report[ + SourceCapability.DATA_PROFILING + ] = CapabilityReport( + capable=False, failure_reason="Warehouse ID not provided" + ) + elif config.profiling.enabled: + try: + unity_proxy.check_profiling_connectivity() + test_report.capability_report[ + SourceCapability.DATA_PROFILING + ] = CapabilityReport(capable=True) + except Exception as e: + test_report.capability_report[ + SourceCapability.DATA_PROFILING + ] = CapabilityReport(capable=False, failure_reason=str(e)) except Exception as e: test_report.basic_connectivity = CapabilityReport( capable=False, failure_reason=f"{e}" @@ -192,6 +221,17 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): ) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: + wait_on_warehouse = None + if self.config.profiling.enabled: + # Can take several minutes, so start now and wait later + wait_on_warehouse = self.unity_catalog_api_proxy.start_warehouse() + if wait_on_warehouse is None: + self.report.report_failure( + "initialization", + f"SQL warehouse {self.config.profiling.warehouse_id} not found", + ) + return + self.build_service_principal_map() yield from self.process_metastores() @@ -203,7 +243,19 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): table_urn_builder=self.gen_dataset_urn, user_urn_builder=self.gen_user_urn, ) - yield from usage_extractor.run(self.table_refs) + yield from usage_extractor.run(self.table_refs | self.view_refs) + + if self.config.profiling.enabled: + assert wait_on_warehouse + timeout = timedelta(seconds=self.config.profiling.max_wait_secs) + wait_on_warehouse.result(timeout) + profiling_extractor = UnityCatalogProfiler( + self.config.profiling, + self.report, + self.unity_catalog_api_proxy, + self.gen_dataset_urn, + ) + yield from profiling_extractor.get_workunits(self.table_refs) def build_service_principal_map(self) -> None: try: @@ -233,9 +285,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): self.report.metastores.processed(metastore.metastore_id) - def process_catalogs( - self, metastore: proxy.Metastore - ) -> Iterable[MetadataWorkUnit]: + def process_catalogs(self, metastore: Metastore) -> Iterable[MetadataWorkUnit]: for catalog in self.unity_catalog_api_proxy.catalogs(metastore=metastore): if not self.config.catalog_pattern.allowed(catalog.id): self.report.catalogs.dropped(catalog.id) @@ -246,7 +296,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): self.report.catalogs.processed(catalog.id) - def process_schemas(self, catalog: proxy.Catalog) -> Iterable[MetadataWorkUnit]: + def process_schemas(self, catalog: Catalog) -> Iterable[MetadataWorkUnit]: for schema in self.unity_catalog_api_proxy.schemas(catalog=catalog): if not self.config.schema_pattern.allowed(schema.id): 
self.report.schemas.dropped(schema.id) @@ -257,19 +307,20 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): self.report.schemas.processed(schema.id) - def process_tables(self, schema: proxy.Schema) -> Iterable[MetadataWorkUnit]: + def process_tables(self, schema: Schema) -> Iterable[MetadataWorkUnit]: for table in self.unity_catalog_api_proxy.tables(schema=schema): if not self.config.table_pattern.allowed(table.ref.qualified_table_name): self.report.tables.dropped(table.id, type=table.type) continue - self.table_refs.add(table.ref) + if table.type.lower() == "view": + self.view_refs.add(table.ref) + else: + self.table_refs.add(table.ref) yield from self.process_table(table, schema) self.report.tables.processed(table.id, type=table.type) - def process_table( - self, table: proxy.Table, schema: proxy.Schema - ) -> Iterable[MetadataWorkUnit]: + def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUnit]: dataset_urn = self.gen_dataset_urn(table.ref) yield from self.add_table_to_dataset_container(dataset_urn, schema) @@ -282,13 +333,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): sub_type = self._create_table_sub_type_aspect(table) schema_metadata = self._create_schema_metadata_aspect(table) operation = self._create_table_operation_aspect(table) - - domain = self._get_domain_aspect( - dataset_name=str( - f"{table.schema.catalog.name}.{table.schema.name}.{table.name}" - ) - ) - + domain = self._get_domain_aspect(dataset_name=table.ref.qualified_table_name) ownership = self._create_table_ownership_aspect(table) lineage: Optional[UpstreamLineageClass] = None @@ -317,7 +362,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): ] def _generate_column_lineage_aspect( - self, dataset_urn: str, table: proxy.Table + self, dataset_urn: str, table: Table ) -> Optional[UpstreamLineageClass]: upstreams: List[UpstreamClass] = [] finegrained_lineages: List[FineGrainedLineage] = [] @@ -353,7 +398,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): return None def _generate_lineage_aspect( - self, dataset_urn: str, table: proxy.Table + self, dataset_urn: str, table: Table ) -> Optional[UpstreamLineageClass]: upstreams: List[UpstreamClass] = [] for upstream in sorted(table.upstreams.keys()): @@ -485,9 +530,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): dataset_urn=dataset_urn, ) - def _create_table_property_aspect( - self, table: proxy.Table - ) -> DatasetPropertiesClass: + def _create_table_property_aspect(self, table: Table) -> DatasetPropertiesClass: custom_properties: dict = {} if table.storage_location is not None: custom_properties["storage_location"] = table.storage_location @@ -524,7 +567,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): lastModified=last_modified, ) - def _create_table_operation_aspect(self, table: proxy.Table) -> OperationClass: + def _create_table_operation_aspect(self, table: Table) -> OperationClass: """Produce an operation aspect for a table. If a last updated time is present, we produce an update operation. 
@@ -553,9 +596,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): return operation - def _create_table_ownership_aspect( - self, table: proxy.Table - ) -> Optional[OwnershipClass]: + def _create_table_ownership_aspect(self, table: Table) -> Optional[OwnershipClass]: owner_urn = self.get_owner_urn(table.owner) if owner_urn is not None: return OwnershipClass( @@ -568,7 +609,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): ) return None - def _create_table_sub_type_aspect(self, table: proxy.Table) -> SubTypesClass: + def _create_table_sub_type_aspect(self, table: Table) -> SubTypesClass: return SubTypesClass( typeNames=[ DatasetSubTypes.VIEW @@ -577,13 +618,13 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): ] ) - def _create_view_property_aspect(self, table: proxy.Table) -> ViewProperties: + def _create_view_property_aspect(self, table: Table) -> ViewProperties: assert table.view_definition return ViewProperties( materialized=False, viewLanguage="SQL", viewLogic=table.view_definition ) - def _create_schema_metadata_aspect(self, table: proxy.Table) -> SchemaMetadataClass: + def _create_schema_metadata_aspect(self, table: Table) -> SchemaMetadataClass: schema_fields: List[SchemaFieldClass] = [] for column in table.columns: @@ -599,7 +640,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource): ) @staticmethod - def _create_schema_field(column: proxy.Column) -> List[SchemaFieldClass]: + def _create_schema_field(column: Column) -> List[SchemaFieldClass]: _COMPLEX_TYPE = re.compile("^(struct|array)") if _COMPLEX_TYPE.match(column.type_text.lower()): diff --git a/metadata-ingestion/tests/unit/test_unity_catalog_config.py b/metadata-ingestion/tests/unit/test_unity_catalog_config.py index 1a296b0ea6..4be6f60171 100644 --- a/metadata-ingestion/tests/unit/test_unity_catalog_config.py +++ b/metadata-ingestion/tests/unit/test_unity_catalog_config.py @@ -33,6 +33,34 @@ def test_within_thirty_days(): ) +def test_profiling_requires_warehouses_id(): + config = UnityCatalogSourceConfig.parse_obj( + { + "token": "token", + "workspace_url": "https://workspace_url", + "profiling": {"enabled": True, "warehouse_id": "my_warehouse_id"}, + } + ) + assert config.profiling.enabled is True + + config = UnityCatalogSourceConfig.parse_obj( + { + "token": "token", + "workspace_url": "https://workspace_url", + "profiling": {"enabled": False}, + } + ) + assert config.profiling.enabled is False + + with pytest.raises(ValueError): + UnityCatalogSourceConfig.parse_obj( + { + "token": "token", + "workspace_url": "workspace_url", + } + ) + + @freeze_time(FROZEN_TIME) def test_workspace_url_should_start_with_https(): @@ -41,5 +69,6 @@ def test_workspace_url_should_start_with_https(): { "token": "token", "workspace_url": "workspace_url", + "profiling": {"enabled": True}, } )
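
A minimal usage sketch of the profiling configuration introduced above (illustrative only, not part of the patch): it mirrors test_profiling_requires_warehouses_id, the token and workspace URL values are placeholders, and it exercises the warehouse_id root validator and the include_columns property defined in UnityCatalogProfilerConfig.

import pydantic

from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig

# Profiling enabled with a SQL warehouse id: passes validation.
config = UnityCatalogSourceConfig.parse_obj(
    {
        "token": "<personal-access-token>",  # placeholder
        "workspace_url": "https://my-workspace.cloud.databricks.com",  # placeholder
        "profiling": {
            "enabled": True,
            "warehouse_id": "my_warehouse_id",  # required whenever enabled is True
            "profile_table_level_only": False,
            "call_analyze": True,  # run ANALYZE TABLE before reading table properties
            "max_wait_secs": 3600,  # cap on waiting for ANALYZE TABLE completion
        },
    }
)
assert config.profiling.include_columns  # derived: not profile_table_level_only

# Profiling enabled without a warehouse_id: rejected by the root validator.
try:
    UnityCatalogSourceConfig.parse_obj(
        {
            "token": "<personal-access-token>",
            "workspace_url": "https://my-workspace.cloud.databricks.com",
            "profiling": {"enabled": True},
        }
    )
except pydantic.ValidationError as e:
    print(e)  # warehouse_id must be set when profiling is enabled.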