Mirror of https://github.com/datahub-project/datahub.git
feat(ingest/unity): Add profiling support (#7976)

- Also adds the new databricks-sdk package

parent 294f65fdd7
commit afcf462cb1
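As context for the diff below, here is a minimal sketch (not part of the commit) of how the new `profiling` block on `UnityCatalogSourceConfig` is populated; the token, workspace URL, and warehouse id are placeholder values, and the field names come from the `UnityCatalogProfilerConfig` class added in this change (this mirrors the unit test added at the end of the diff).

```python
# Sketch only: placeholder credentials and IDs, field names taken from the new config class.
from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig

config = UnityCatalogSourceConfig.parse_obj(
    {
        "token": "<personal-access-token>",
        "workspace_url": "https://<workspace>.cloud.databricks.com",
        "profiling": {
            "enabled": True,
            # Required when enabled=True; enforced by the root validator added below.
            "warehouse_id": "<sql-warehouse-id>",
            "profile_table_level_only": False,
            "call_analyze": True,
        },
    }
)
# include_columns is the inverse of profile_table_level_only.
assert config.profiling.include_columns
```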
@@ -1,10 +1,11 @@
-from datetime import datetime, timezone
-from typing import Dict, Optional
+import os
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict, Optional
 
 import pydantic
 from pydantic import Field
 
-from datahub.configuration.common import AllowDenyPattern
+from datahub.configuration.common import AllowDenyPattern, ConfigModel
 from datahub.configuration.source_common import DatasetSourceConfigMixin
 from datahub.configuration.validate_field_rename import pydantic_renamed_field
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
@@ -12,12 +13,71 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import (
 )
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionConfigBase,
+    StatefulProfilingConfigMixin,
 )
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 
 
+class UnityCatalogProfilerConfig(ConfigModel):
+    # TODO: Reduce duplicate code with DataLakeProfilerConfig, GEProfilingConfig, SQLAlchemyConfig
+    enabled: bool = Field(
+        default=False, description="Whether profiling should be done."
+    )
+
+    warehouse_id: Optional[str] = Field(
+        default=None, description="SQL Warehouse id, for running profiling queries."
+    )
+
+    profile_table_level_only: bool = Field(
+        default=False,
+        description="Whether to perform profiling at table-level only or include column-level profiling as well.",
+    )
+
+    pattern: AllowDenyPattern = Field(
+        default=AllowDenyPattern.allow_all(),
+        description=(
+            "Regex patterns to filter tables for profiling during ingestion. "
+            "Specify regex to match the `catalog.schema.table` format. "
+            "Note that only tables allowed by the `table_pattern` will be considered."
+        ),
+    )
+
+    call_analyze: bool = Field(
+        default=True,
+        description=(
+            "Whether to call ANALYZE TABLE as part of profile ingestion."
+            "If false, will ingest the results of the most recent ANALYZE TABLE call, if any."
+        ),
+    )
+
+    max_wait_secs: int = Field(
+        default=int(timedelta(hours=1).total_seconds()),
+        description="Maximum time to wait for an ANALYZE TABLE query to complete.",
+    )
+
+    max_workers: int = Field(
+        default=5 * (os.cpu_count() or 4),
+        description="Number of worker threads to use for profiling. Set to 1 to disable.",
+    )
+
+    @pydantic.root_validator
+    def warehouse_id_required_for_profiling(
+        cls, values: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        if values.get("enabled") and not values.get("warehouse_id"):
+            raise ValueError("warehouse_id must be set when profiling is enabled.")
+        return values
+
+    @property
+    def include_columns(self):
+        return not self.profile_table_level_only
+
+
 class UnityCatalogSourceConfig(
-    StatefulIngestionConfigBase, BaseUsageConfig, DatasetSourceConfigMixin
+    StatefulIngestionConfigBase,
+    BaseUsageConfig,
+    DatasetSourceConfigMixin,
+    StatefulProfilingConfigMixin,
 ):
     token: str = pydantic.Field(description="Databricks personal access token")
     workspace_url: str = pydantic.Field(
@@ -76,6 +136,10 @@ class UnityCatalogSourceConfig(
         description="Generate usage statistics.",
     )
 
+    profiling: UnityCatalogProfilerConfig = Field(
+        default=UnityCatalogProfilerConfig(), description="Data profiling configuration"
+    )
+
     stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field(
         default=None, description="Unity Catalog Stateful Ingestion Config."
     )
@@ -0,0 +1,119 @@
+import logging
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from typing import Callable, Collection, Iterable, Optional
+
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.unity.config import UnityCatalogProfilerConfig
+from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy
+from datahub.ingestion.source.unity.proxy_types import (
+    ColumnProfile,
+    TableProfile,
+    TableReference,
+)
+from datahub.ingestion.source.unity.report import UnityCatalogReport
+from datahub.metadata.schema_classes import (
+    DatasetFieldProfileClass,
+    DatasetProfileClass,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class UnityCatalogProfiler:
+    config: UnityCatalogProfilerConfig
+    report: UnityCatalogReport
+    proxy: UnityCatalogApiProxy
+    dataset_urn_builder: Callable[[TableReference], str]
+
+    def get_workunits(
+        self, table_refs: Collection[TableReference]
+    ) -> Iterable[MetadataWorkUnit]:
+        try:
+            tables = self._filter_tables(table_refs)
+            with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
+                futures = [executor.submit(self.process_table, ref) for ref in tables]
+                for future in as_completed(futures):
+                    wu: Optional[MetadataWorkUnit] = future.result()
+                    if wu:
+                        self.report.num_profile_workunits_emitted += 1
+                        yield wu
+        except Exception as e:
+            self.report.report_warning("profiling", str(e))
+            logger.warning(f"Unexpected error during profiling: {e}", exc_info=True)
+            return
+
+    def _filter_tables(
+        self, table_refs: Collection[TableReference]
+    ) -> Collection[TableReference]:
+        return [
+            ref
+            for ref in table_refs
+            if self.config.pattern.allowed(ref.qualified_table_name)
+        ]
+
+    def process_table(self, ref: TableReference) -> Optional[MetadataWorkUnit]:
+        try:
+            table_profile = self.proxy.get_table_stats(
+                ref,
+                max_wait_secs=self.config.max_wait_secs,
+                call_analyze=self.config.call_analyze,
+                include_columns=self.config.include_columns,
+            )
+            if table_profile:
+                return self.gen_dataset_profile_workunit(ref, table_profile)
+            elif table_profile is not None:  # table_profile is Falsy == empty
+                self.report.profile_table_empty.append(str(ref))
+        except Exception as e:
+            self.report.report_warning("profiling", str(e))
+            logger.warning(
+                f"Unexpected error during profiling table {ref}: {e}", exc_info=True
+            )
+        return None
+
+    def gen_dataset_profile_workunit(
+        self, ref: TableReference, table_profile: TableProfile
+    ) -> MetadataWorkUnit:
+        row_count = table_profile.num_rows
+        aspect = DatasetProfileClass(
+            timestampMillis=int(time.time() * 1000),
+            rowCount=row_count,
+            columnCount=table_profile.num_columns,
+            sizeInBytes=table_profile.total_size,
+            fieldProfiles=[
+                self._gen_dataset_field_profile(row_count, column_profile)
+                for column_profile in table_profile.column_profiles
+                if column_profile  # Drop column profiles with no data
+            ]
+            if self.config.include_columns
+            else None,
+        )
+        return MetadataChangeProposalWrapper(
+            entityUrn=self.dataset_urn_builder(ref),
+            aspect=aspect,
+        ).as_workunit()
+
+    @staticmethod
+    def _gen_dataset_field_profile(
+        num_rows: Optional[int], column_profile: ColumnProfile
+    ) -> DatasetFieldProfileClass:
+        unique_proportion: Optional[float] = None
+        null_proportion: Optional[float] = None
+        if num_rows:
+            if column_profile.distinct_count is not None:
+                unique_proportion = min(1.0, column_profile.distinct_count / num_rows)
+            if column_profile.null_count is not None:
+                null_proportion = min(1.0, column_profile.null_count / num_rows)
+
+        return DatasetFieldProfileClass(
+            fieldPath=column_profile.name,
+            uniqueCount=column_profile.distinct_count,
+            uniqueProportion=unique_proportion,
+            nullCount=column_profile.null_count,
+            nullProportion=null_proportion,
+            min=column_profile.min,
+            max=column_profile.max,
+        )
@@ -16,6 +16,9 @@ from databricks.sdk.service.sql import (
 from databricks_cli.sdk.api_client import ApiClient
 from databricks_cli.unity_catalog.api import UnityCatalogApi
 
+from datahub.ingestion.source.unity.proxy_profiling import (
+    UnityCatalogProxyProfilingMixin,
+)
 from datahub.ingestion.source.unity.proxy_types import (
     ALLOWED_STATEMENT_TYPES,
     DATA_TYPE_REGISTRY,
@@ -48,14 +51,19 @@ class QueryFilterWithStatementTypes(QueryFilter):
         return v
 
 
-class UnityCatalogApiProxy:
-    _unity_catalog_api: UnityCatalogApi
+class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
+    _workspace_client: WorkspaceClient
+    _unity_catalog_api: UnityCatalogApi
     _workspace_url: str
     report: UnityCatalogReport
+    warehouse_id: str
 
     def __init__(
-        self, workspace_url: str, personal_access_token: str, report: UnityCatalogReport
+        self,
+        workspace_url: str,
+        personal_access_token: str,
+        warehouse_id: Optional[str],
+        report: UnityCatalogReport,
     ):
         self._workspace_client = WorkspaceClient(
             host=workspace_url, token=personal_access_token
@@ -63,7 +71,7 @@ class UnityCatalogApiProxy:
         self._unity_catalog_api = UnityCatalogApi(
             ApiClient(host=workspace_url, token=personal_access_token)
         )
         self._workspace_url = workspace_url
+        self.warehouse_id = warehouse_id or ""
         self.report = report
 
     def check_connectivity(self) -> bool:
@@ -131,6 +139,7 @@ class UnityCatalogApiProxy:
             yield self._create_table(schema=schema, obj=table)
 
     def service_principals(self) -> Iterable[ServicePrincipal]:
+        # TODO: Replace with self._workspace_client.service_principals.list() when it supports pagination
        start_index = 1  # Unfortunately 1-indexed
        items_per_page = 0
        total_results = float("inf")
@@ -0,0 +1,238 @@
+import logging
+import time
+from typing import Optional, Union
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.core import DatabricksError
+from databricks.sdk.service._internal import Wait
+from databricks.sdk.service.catalog import TableInfo
+from databricks.sdk.service.sql import (
+    ExecuteStatementResponse,
+    GetStatementResponse,
+    GetWarehouseResponse,
+    StatementState,
+    StatementStatus,
+)
+from databricks_cli.unity_catalog.api import UnityCatalogApi
+
+from datahub.ingestion.source.unity.proxy_types import (
+    ColumnProfile,
+    TableProfile,
+    TableReference,
+)
+from datahub.ingestion.source.unity.report import UnityCatalogReport
+from datahub.utilities.lossy_collections import LossyList
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+# TODO: Move to separate proxy/ directory with rest of proxy code
+class UnityCatalogProxyProfilingMixin:
+    _workspace_client: WorkspaceClient
+    _unity_catalog_api: UnityCatalogApi
+    report: UnityCatalogReport
+    warehouse_id: str
+
+    def check_profiling_connectivity(self):
+        self._workspace_client.warehouses.get(self.warehouse_id)
+        return True
+
+    def start_warehouse(self) -> Optional[Wait[GetWarehouseResponse]]:
+        """Starts a databricks SQL warehouse.
+
+        Returns:
+            - A Wait object that can be used to wait for warehouse start completion.
+            - None if the warehouse does not exist.
+        """
+        try:
+            return self._workspace_client.warehouses.start(self.warehouse_id)
+        except DatabricksError as e:
+            logger.warning(f"Unable to start warehouse -- are you sure it exists? {e}")
+            return None
+
+    def get_table_stats(
+        self,
+        ref: TableReference,
+        *,
+        max_wait_secs: int,
+        call_analyze: bool,
+        include_columns: bool,
+    ) -> Optional[TableProfile]:
+        """Returns profiling information for a table.
+
+        Performs three steps:
+        1. Call ANALYZE TABLE to compute statistics for all columns
+        2. Poll for ANALYZE completion with exponential backoff, with `max_wait_secs` timeout
+        3. Get the ANALYZE result via the properties field in the tables API.
+        This is supposed to be returned by a DESCRIBE TABLE EXTENDED command, but I don't see it.
+
+        Raises:
+            DatabricksError: If any of the above steps fail
+        """
+
+        # Currently uses databricks sdk, which is synchronous
+        # If we need to improve performance, we can manually make requests via aiohttp
+        try:
+            if call_analyze:
+                response = self._analyze_table(ref, include_columns=include_columns)
+                success = self._check_analyze_table_statement_status(
+                    response, max_wait_secs=max_wait_secs
+                )
+                if not success:
+                    self.report.profile_table_timeouts.append(str(ref))
+                    return None
+            return self._get_table_profile(ref, include_columns=include_columns)
+        except DatabricksError as e:
+            # Attempt to parse out generic part of error message
+            msg = str(e)
+            idx = (str(msg).find("`") + 1) or (str(msg).find("'") + 1) or len(str(msg))
+            base_msg = msg[:idx]
+            self.report.profile_table_errors.setdefault(base_msg, LossyList()).append(
+                (str(ref), msg)
+            )
+            logger.warning(
+                f"Failure during profiling {ref}, {e.kwargs}: ({e.error_code}) {e}",
+                exc_info=True,
+            )
+
+            if (
+                call_analyze
+                and include_columns
+                and self._should_retry_unsupported_column(ref, e)
+            ):
+                return self.get_table_stats(
+                    ref,
+                    max_wait_secs=max_wait_secs,
+                    call_analyze=call_analyze,
+                    include_columns=False,
+                )
+            else:
+                return None
+
+    def _should_retry_unsupported_column(
+        self, ref: TableReference, e: DatabricksError
+    ) -> bool:
+        if "[UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE]" in str(e):
+            logger.info(
+                f"Attempting to profile table without columns due to unsupported column type: {ref}"
+            )
+            self.report.num_profile_failed_unsupported_column_type += 1
+            return True
+        return False
+
+    def _analyze_table(
+        self, ref: TableReference, include_columns: bool
+    ) -> ExecuteStatementResponse:
+        statement = f"ANALYZE TABLE {ref.schema}.{ref.table} COMPUTE STATISTICS"
+        if include_columns:
+            statement += " FOR ALL COLUMNS"
+        response = self._workspace_client.statement_execution.execute_statement(
+            statement=statement,
+            catalog=ref.catalog,
+            wait_timeout="0s",  # Fetch result asynchronously
+            warehouse_id=self.warehouse_id,
+        )
+        self._raise_if_error(response, "analyze-table")
+        return response
+
+    def _check_analyze_table_statement_status(
+        self, execute_response: ExecuteStatementResponse, max_wait_secs: int
+    ) -> bool:
+        statement_id: str = execute_response.statement_id
+        status: StatementStatus = execute_response.status
+
+        backoff_sec = 1
+        total_wait_time = 0
+        while (
+            total_wait_time < max_wait_secs and status.state != StatementState.SUCCEEDED
+        ):
+            time.sleep(min(backoff_sec, max_wait_secs - total_wait_time))
+            total_wait_time += backoff_sec
+            backoff_sec *= 2
+
+            response = self._workspace_client.statement_execution.get_statement(
+                statement_id
+            )
+            self._raise_if_error(response, "get-statement")
+            status = response.status
+
+        return status.state == StatementState.SUCCEEDED
+
+    def _get_table_profile(
+        self, ref: TableReference, include_columns: bool
+    ) -> TableProfile:
+        table_info = self._workspace_client.tables.get(ref.qualified_table_name)
+        return self._create_table_profile(table_info, include_columns=include_columns)
+
+    def _create_table_profile(
+        self, table_info: TableInfo, include_columns: bool
+    ) -> TableProfile:
+        # Warning: this implementation is brittle -- dependent on properties that can change
+        columns_names = [column.name for column in table_info.columns]
+        return TableProfile(
+            num_rows=self._get_int(table_info, "spark.sql.statistics.numRows"),
+            total_size=self._get_int(table_info, "spark.sql.statistics.totalSize"),
+            num_columns=len(columns_names),
+            column_profiles=[
+                self._create_column_profile(column, table_info)
+                for column in columns_names
+            ]
+            if include_columns
+            else [],
+        )
+
+    def _create_column_profile(
+        self, column: str, table_info: TableInfo
+    ) -> ColumnProfile:
+        return ColumnProfile(
+            name=column,
+            null_count=self._get_int(
+                table_info, f"spark.sql.statistics.colStats.{column}.nullCount"
+            ),
+            distinct_count=self._get_int(
+                table_info, f"spark.sql.statistics.colStats.{column}.distinctCount"
+            ),
+            min=table_info.properties.get(
+                f"spark.sql.statistics.colStats.{column}.min"
+            ),
+            max=table_info.properties.get(
+                f"spark.sql.statistics.colStats.{column}.max"
+            ),
+            avg_len=table_info.properties.get(
+                f"spark.sql.statistics.colStats.{column}.avgLen"
+            ),
+            max_len=table_info.properties.get(
+                f"spark.sql.statistics.colStats.{column}.maxLen"
+            ),
+            version=table_info.properties.get(
+                f"spark.sql.statistics.colStats.{column}.version"
+            ),
+        )
+
+    def _get_int(self, table_info: TableInfo, field: str) -> Optional[int]:
+        value = table_info.properties.get(field)
+        if value is not None:
+            try:
+                return int(value)
+            except ValueError:
+                logger.warning(
+                    f"Failed to parse int for {table_info.name} - {field}: {value}"
+                )
+                self.report.num_profile_failed_int_casts += 1
+        return None
+
+    @staticmethod
+    def _raise_if_error(
+        response: Union[ExecuteStatementResponse, GetStatementResponse], key: str
+    ) -> None:
+        if response.status.state in [
+            StatementState.FAILED,
+            StatementState.CANCELED,
+            StatementState.CLOSED,
+        ]:
+            raise DatabricksError(
+                response.status.error.message,
+                error_code=response.status.error.error_code.value,
+                status=response.status.state.value,
+                context=key,
+            )
@@ -174,3 +174,44 @@ class Query:
     # User whose credentials were used to run the query
     executed_as_user_id: int
     executed_as_user_name: str
+
+
+@dataclass
+class TableProfile:
+    num_rows: Optional[int]
+    num_columns: Optional[int]
+    total_size: Optional[int]
+    column_profiles: List["ColumnProfile"]
+
+    def __bool__(self):
+        return any(
+            (
+                self.num_rows is not None,
+                self.num_columns is not None,
+                self.total_size is not None,
+                any(self.column_profiles),
+            )
+        )
+
+
+@dataclass
+class ColumnProfile:
+    name: str
+    null_count: Optional[int]
+    distinct_count: Optional[int]
+    min: Optional[str]
+    max: Optional[str]
+
+    version: Optional[str]
+    avg_len: Optional[str]
+    max_len: Optional[str]
+
+    def __bool__(self):
+        return any(
+            (
+                self.null_count is not None,
+                self.distinct_count is not None,
+                self.min is not None,
+                self.max is not None,
+            )
+        )
@@ -1,9 +1,11 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Tuple
 
 from datahub.ingestion.api.report import EntityFilterReport
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
     StaleEntityRemovalSourceReport,
 )
+from datahub.utilities.lossy_collections import LossyDict, LossyList
 
 
 @dataclass
@@ -12,6 +14,7 @@ class UnityCatalogReport(StaleEntityRemovalSourceReport):
     catalogs: EntityFilterReport = EntityFilterReport.field(type="catalog")
     schemas: EntityFilterReport = EntityFilterReport.field(type="schema")
     tables: EntityFilterReport = EntityFilterReport.field(type="table/view")
+    table_profiles: EntityFilterReport = EntityFilterReport.field(type="table profile")
 
     num_queries: int = 0
     num_queries_dropped_parse_failure: int = 0
@@ -21,3 +24,12 @@ class UnityCatalogReport(StaleEntityRemovalSourceReport):
 
     num_operational_stats_workunits_emitted: int = 0
     num_usage_workunits_emitted: int = 0
+
+    profile_table_timeouts: LossyList[str] = field(default_factory=LossyList)
+    profile_table_empty: LossyList[str] = field(default_factory=LossyList)
+    profile_table_errors: LossyDict[str, LossyList[Tuple[str, str]]] = field(
+        default_factory=LossyDict
+    )
+    num_profile_failed_unsupported_column_type: int = 0
+    num_profile_failed_int_casts: int = 0
+    num_profile_workunits_emitted: int = 0
@@ -1,6 +1,7 @@
 import logging
 import re
 import time
+from datetime import timedelta
 from typing import Dict, Iterable, List, Optional, Set
 
 from datahub.emitter.mce_builder import (
@@ -45,15 +46,18 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import (
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionSourceBase,
 )
-from datahub.ingestion.source.unity import proxy
 from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
-from datahub.ingestion.source.unity.proxy import (
+from datahub.ingestion.source.unity.profiler import UnityCatalogProfiler
+from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy
+from datahub.ingestion.source.unity.proxy_types import (
     Catalog,
     Column,
     Metastore,
     Schema,
     ServicePrincipal,
     Table,
     TableReference,
 )
-from datahub.ingestion.source.unity.proxy_types import TableReference
 from datahub.ingestion.source.unity.report import UnityCatalogReport
 from datahub.ingestion.source.unity.usage import UnityCatalogUsageExtractor
 from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
@@ -115,7 +119,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
     """
 
     config: UnityCatalogSourceConfig
-    unity_catalog_api_proxy: proxy.UnityCatalogApiProxy
+    unity_catalog_api_proxy: UnityCatalogApiProxy
     platform: str = "databricks"
     platform_instance_name: str
 
@@ -127,8 +131,11 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
 
         self.config = config
         self.report: UnityCatalogReport = UnityCatalogReport()
-        self.unity_catalog_api_proxy = proxy.UnityCatalogApiProxy(
-            config.workspace_url, config.token, report=self.report
+        self.unity_catalog_api_proxy = UnityCatalogApiProxy(
+            config.workspace_url,
+            config.token,
+            config.profiling.warehouse_id,
+            report=self.report,
         )
 
         # Determine the platform_instance_name
@@ -156,21 +163,43 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
         self.service_principals: Dict[str, ServicePrincipal] = {}
         # Global set of table refs
         self.table_refs: Set[TableReference] = set()
+        self.view_refs: Set[TableReference] = set()
 
     @staticmethod
     def test_connection(config_dict: dict) -> TestConnectionReport:
         test_report = TestConnectionReport()
+        test_report.capability_report = {}
         try:
            config = UnityCatalogSourceConfig.parse_obj_allow_extras(config_dict)
            report = UnityCatalogReport()
-            unity_proxy = proxy.UnityCatalogApiProxy(
-                config.workspace_url, config.token, report=report
+            unity_proxy = UnityCatalogApiProxy(
+                config.workspace_url,
+                config.token,
+                config.profiling.warehouse_id,
+                report=report,
             )
            if unity_proxy.check_connectivity():
                test_report.basic_connectivity = CapabilityReport(capable=True)
            else:
                test_report.basic_connectivity = CapabilityReport(capable=False)
 
+            # TODO: Refactor into separate file / method
+            if config.profiling.enabled and not config.profiling.warehouse_id:
+                test_report.capability_report[
+                    SourceCapability.DATA_PROFILING
+                ] = CapabilityReport(
+                    capable=False, failure_reason="Warehouse ID not provided"
+                )
+            elif config.profiling.enabled:
+                try:
+                    unity_proxy.check_profiling_connectivity()
+                    test_report.capability_report[
+                        SourceCapability.DATA_PROFILING
+                    ] = CapabilityReport(capable=True)
+                except Exception as e:
+                    test_report.capability_report[
+                        SourceCapability.DATA_PROFILING
+                    ] = CapabilityReport(capable=False, failure_reason=str(e))
         except Exception as e:
            test_report.basic_connectivity = CapabilityReport(
                capable=False, failure_reason=f"{e}"
@@ -192,6 +221,17 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
         )
 
     def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
+        wait_on_warehouse = None
+        if self.config.profiling.enabled:
+            # Can take several minutes, so start now and wait later
+            wait_on_warehouse = self.unity_catalog_api_proxy.start_warehouse()
+            if wait_on_warehouse is None:
+                self.report.report_failure(
+                    "initialization",
+                    f"SQL warehouse {self.config.profiling.warehouse_id} not found",
+                )
+                return
+
         self.build_service_principal_map()
         yield from self.process_metastores()
 
@@ -203,7 +243,19 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
             table_urn_builder=self.gen_dataset_urn,
             user_urn_builder=self.gen_user_urn,
         )
-        yield from usage_extractor.run(self.table_refs)
+        yield from usage_extractor.run(self.table_refs | self.view_refs)
+
+        if self.config.profiling.enabled:
+            assert wait_on_warehouse
+            timeout = timedelta(seconds=self.config.profiling.max_wait_secs)
+            wait_on_warehouse.result(timeout)
+            profiling_extractor = UnityCatalogProfiler(
+                self.config.profiling,
+                self.report,
+                self.unity_catalog_api_proxy,
+                self.gen_dataset_urn,
+            )
+            yield from profiling_extractor.get_workunits(self.table_refs)
 
     def build_service_principal_map(self) -> None:
         try:
@@ -233,9 +285,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
 
         self.report.metastores.processed(metastore.metastore_id)
 
-    def process_catalogs(
-        self, metastore: proxy.Metastore
-    ) -> Iterable[MetadataWorkUnit]:
+    def process_catalogs(self, metastore: Metastore) -> Iterable[MetadataWorkUnit]:
         for catalog in self.unity_catalog_api_proxy.catalogs(metastore=metastore):
             if not self.config.catalog_pattern.allowed(catalog.id):
                 self.report.catalogs.dropped(catalog.id)
@@ -246,7 +296,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
 
         self.report.catalogs.processed(catalog.id)
 
-    def process_schemas(self, catalog: proxy.Catalog) -> Iterable[MetadataWorkUnit]:
+    def process_schemas(self, catalog: Catalog) -> Iterable[MetadataWorkUnit]:
         for schema in self.unity_catalog_api_proxy.schemas(catalog=catalog):
             if not self.config.schema_pattern.allowed(schema.id):
                 self.report.schemas.dropped(schema.id)
@@ -257,19 +307,20 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
 
         self.report.schemas.processed(schema.id)
 
-    def process_tables(self, schema: proxy.Schema) -> Iterable[MetadataWorkUnit]:
+    def process_tables(self, schema: Schema) -> Iterable[MetadataWorkUnit]:
         for table in self.unity_catalog_api_proxy.tables(schema=schema):
             if not self.config.table_pattern.allowed(table.ref.qualified_table_name):
                 self.report.tables.dropped(table.id, type=table.type)
                 continue
 
-            self.table_refs.add(table.ref)
+            if table.type.lower() == "view":
+                self.view_refs.add(table.ref)
+            else:
+                self.table_refs.add(table.ref)
             yield from self.process_table(table, schema)
             self.report.tables.processed(table.id, type=table.type)
 
-    def process_table(
-        self, table: proxy.Table, schema: proxy.Schema
-    ) -> Iterable[MetadataWorkUnit]:
+    def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUnit]:
         dataset_urn = self.gen_dataset_urn(table.ref)
         yield from self.add_table_to_dataset_container(dataset_urn, schema)
 
@@ -282,13 +333,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
         sub_type = self._create_table_sub_type_aspect(table)
         schema_metadata = self._create_schema_metadata_aspect(table)
         operation = self._create_table_operation_aspect(table)
-
-        domain = self._get_domain_aspect(
-            dataset_name=str(
-                f"{table.schema.catalog.name}.{table.schema.name}.{table.name}"
-            )
-        )
-
+        domain = self._get_domain_aspect(dataset_name=table.ref.qualified_table_name)
         ownership = self._create_table_ownership_aspect(table)
 
         lineage: Optional[UpstreamLineageClass] = None
@@ -317,7 +362,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
         ]
 
     def _generate_column_lineage_aspect(
-        self, dataset_urn: str, table: proxy.Table
+        self, dataset_urn: str, table: Table
     ) -> Optional[UpstreamLineageClass]:
         upstreams: List[UpstreamClass] = []
         finegrained_lineages: List[FineGrainedLineage] = []
@@ -353,7 +398,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
         return None
 
     def _generate_lineage_aspect(
-        self, dataset_urn: str, table: proxy.Table
+        self, dataset_urn: str, table: Table
     ) -> Optional[UpstreamLineageClass]:
         upstreams: List[UpstreamClass] = []
         for upstream in sorted(table.upstreams.keys()):
@@ -485,9 +530,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
             dataset_urn=dataset_urn,
         )
 
-    def _create_table_property_aspect(
-        self, table: proxy.Table
-    ) -> DatasetPropertiesClass:
+    def _create_table_property_aspect(self, table: Table) -> DatasetPropertiesClass:
         custom_properties: dict = {}
         if table.storage_location is not None:
             custom_properties["storage_location"] = table.storage_location
@@ -524,7 +567,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
             lastModified=last_modified,
         )
 
-    def _create_table_operation_aspect(self, table: proxy.Table) -> OperationClass:
+    def _create_table_operation_aspect(self, table: Table) -> OperationClass:
         """Produce an operation aspect for a table.
 
         If a last updated time is present, we produce an update operation.
@@ -553,9 +596,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
 
         return operation
 
-    def _create_table_ownership_aspect(
-        self, table: proxy.Table
-    ) -> Optional[OwnershipClass]:
+    def _create_table_ownership_aspect(self, table: Table) -> Optional[OwnershipClass]:
         owner_urn = self.get_owner_urn(table.owner)
         if owner_urn is not None:
             return OwnershipClass(
@@ -568,7 +609,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
             )
         return None
 
-    def _create_table_sub_type_aspect(self, table: proxy.Table) -> SubTypesClass:
+    def _create_table_sub_type_aspect(self, table: Table) -> SubTypesClass:
         return SubTypesClass(
             typeNames=[
                 DatasetSubTypes.VIEW
@@ -577,13 +618,13 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
             ]
         )
 
-    def _create_view_property_aspect(self, table: proxy.Table) -> ViewProperties:
+    def _create_view_property_aspect(self, table: Table) -> ViewProperties:
         assert table.view_definition
         return ViewProperties(
             materialized=False, viewLanguage="SQL", viewLogic=table.view_definition
         )
 
-    def _create_schema_metadata_aspect(self, table: proxy.Table) -> SchemaMetadataClass:
+    def _create_schema_metadata_aspect(self, table: Table) -> SchemaMetadataClass:
         schema_fields: List[SchemaFieldClass] = []
 
         for column in table.columns:
@@ -599,7 +640,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
         )
 
     @staticmethod
-    def _create_schema_field(column: proxy.Column) -> List[SchemaFieldClass]:
+    def _create_schema_field(column: Column) -> List[SchemaFieldClass]:
         _COMPLEX_TYPE = re.compile("^(struct|array)")
 
         if _COMPLEX_TYPE.match(column.type_text.lower()):
@@ -33,6 +33,34 @@ def test_within_thirty_days():
     )
 
 
+def test_profiling_requires_warehouses_id():
+    config = UnityCatalogSourceConfig.parse_obj(
+        {
+            "token": "token",
+            "workspace_url": "https://workspace_url",
+            "profiling": {"enabled": True, "warehouse_id": "my_warehouse_id"},
+        }
+    )
+    assert config.profiling.enabled is True
+
+    config = UnityCatalogSourceConfig.parse_obj(
+        {
+            "token": "token",
+            "workspace_url": "https://workspace_url",
+            "profiling": {"enabled": False},
+        }
+    )
+    assert config.profiling.enabled is False
+
+    with pytest.raises(ValueError):
+        UnityCatalogSourceConfig.parse_obj(
+            {
+                "token": "token",
+                "workspace_url": "workspace_url",
+            }
+        )
+
+
 @freeze_time(FROZEN_TIME)
 def test_workspace_url_should_start_with_https():
 
@@ -41,5 +69,6 @@ def test_workspace_url_should_start_with_https():
         {
             "token": "token",
             "workspace_url": "workspace_url",
+            "profiling": {"enabled": True},
         }
     )