feat(ingest/unity): Add profiling support (#7976)

- Also adds the new Databricks SDK
Andrew Sikowitz 2023-05-11 13:00:50 -04:00 committed by GitHub
parent 294f65fdd7
commit afcf462cb1
8 changed files with 600 additions and 47 deletions


@ -1,10 +1,11 @@
from datetime import datetime, timezone
from typing import Dict, Optional
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
import pydantic
from pydantic import Field
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.configuration.source_common import DatasetSourceConfigMixin
from datahub.configuration.validate_field_rename import pydantic_renamed_field
from datahub.ingestion.source.state.stale_entity_removal_handler import (
@ -12,12 +13,71 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import (
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionConfigBase,
StatefulProfilingConfigMixin,
)
from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
class UnityCatalogProfilerConfig(ConfigModel):
# TODO: Reduce duplicate code with DataLakeProfilerConfig, GEProfilingConfig, SQLAlchemyConfig
enabled: bool = Field(
default=False, description="Whether profiling should be done."
)
warehouse_id: Optional[str] = Field(
default=None, description="SQL Warehouse id, for running profiling queries."
)
profile_table_level_only: bool = Field(
default=False,
description="Whether to perform profiling at table-level only or include column-level profiling as well.",
)
pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description=(
"Regex patterns to filter tables for profiling during ingestion. "
"Specify regex to match the `catalog.schema.table` format. "
"Note that only tables allowed by the `table_pattern` will be considered."
),
)
call_analyze: bool = Field(
default=True,
description=(
"Whether to call ANALYZE TABLE as part of profile ingestion."
"If false, will ingest the results of the most recent ANALYZE TABLE call, if any."
),
)
max_wait_secs: int = Field(
default=int(timedelta(hours=1).total_seconds()),
description="Maximum time to wait for an ANALYZE TABLE query to complete.",
)
max_workers: int = Field(
default=5 * (os.cpu_count() or 4),
description="Number of worker threads to use for profiling. Set to 1 to disable.",
)
@pydantic.root_validator
def warehouse_id_required_for_profiling(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
if values.get("enabled") and not values.get("warehouse_id"):
raise ValueError("warehouse_id must be set when profiling is enabled.")
return values
@property
def include_columns(self):
return not self.profile_table_level_only
class UnityCatalogSourceConfig(
StatefulIngestionConfigBase, BaseUsageConfig, DatasetSourceConfigMixin
StatefulIngestionConfigBase,
BaseUsageConfig,
DatasetSourceConfigMixin,
StatefulProfilingConfigMixin,
):
token: str = pydantic.Field(description="Databricks personal access token")
workspace_url: str = pydantic.Field(
@ -76,6 +136,10 @@ class UnityCatalogSourceConfig(
description="Generate usage statistics.",
)
profiling: UnityCatalogProfilerConfig = Field(
default=UnityCatalogProfilerConfig(), description="Data profiling configuration"
)
stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field(
default=None, description="Unity Catalog Stateful Ingestion Config."
)
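
For context on how the new profiling block is wired in, here is a minimal sketch of building the source config, in the same parse_obj style as the unit tests at the bottom of this diff. The token, workspace URL, warehouse id, and deny pattern are placeholder values:

from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig

# Placeholder values throughout; enabling profiling without a warehouse_id
# would raise a ValueError from the root validator above.
config = UnityCatalogSourceConfig.parse_obj(
    {
        "token": "<personal-access-token>",
        "workspace_url": "https://<workspace>.cloud.databricks.com",
        "profiling": {
            "enabled": True,
            "warehouse_id": "<sql-warehouse-id>",
            "profile_table_level_only": False,
            "pattern": {"deny": ["hive_metastore\\..*"]},
        },
    }
)
assert config.profiling.include_columns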


@ -0,0 +1,119 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Callable, Collection, Iterable, Optional
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.unity.config import UnityCatalogProfilerConfig
from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy
from datahub.ingestion.source.unity.proxy_types import (
ColumnProfile,
TableProfile,
TableReference,
)
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.metadata.schema_classes import (
DatasetFieldProfileClass,
DatasetProfileClass,
)
logger = logging.getLogger(__name__)
@dataclass
class UnityCatalogProfiler:
config: UnityCatalogProfilerConfig
report: UnityCatalogReport
proxy: UnityCatalogApiProxy
dataset_urn_builder: Callable[[TableReference], str]
def get_workunits(
self, table_refs: Collection[TableReference]
) -> Iterable[MetadataWorkUnit]:
try:
tables = self._filter_tables(table_refs)
with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
futures = [executor.submit(self.process_table, ref) for ref in tables]
for future in as_completed(futures):
wu: Optional[MetadataWorkUnit] = future.result()
if wu:
self.report.num_profile_workunits_emitted += 1
yield wu
except Exception as e:
self.report.report_warning("profiling", str(e))
logger.warning(f"Unexpected error during profiling: {e}", exc_info=True)
return
def _filter_tables(
self, table_refs: Collection[TableReference]
) -> Collection[TableReference]:
return [
ref
for ref in table_refs
if self.config.pattern.allowed(ref.qualified_table_name)
]
def process_table(self, ref: TableReference) -> Optional[MetadataWorkUnit]:
try:
table_profile = self.proxy.get_table_stats(
ref,
max_wait_secs=self.config.max_wait_secs,
call_analyze=self.config.call_analyze,
include_columns=self.config.include_columns,
)
if table_profile:
return self.gen_dataset_profile_workunit(ref, table_profile)
elif table_profile is not None:  # Falsy but not None: profile exists but has no statistics
self.report.profile_table_empty.append(str(ref))
except Exception as e:
self.report.report_warning("profiling", str(e))
logger.warning(
f"Unexpected error during profiling table {ref}: {e}", exc_info=True
)
return None
def gen_dataset_profile_workunit(
self, ref: TableReference, table_profile: TableProfile
) -> MetadataWorkUnit:
row_count = table_profile.num_rows
aspect = DatasetProfileClass(
timestampMillis=int(time.time() * 1000),
rowCount=row_count,
columnCount=table_profile.num_columns,
sizeInBytes=table_profile.total_size,
fieldProfiles=[
self._gen_dataset_field_profile(row_count, column_profile)
for column_profile in table_profile.column_profiles
if column_profile # Drop column profiles with no data
]
if self.config.include_columns
else None,
)
return MetadataChangeProposalWrapper(
entityUrn=self.dataset_urn_builder(ref),
aspect=aspect,
).as_workunit()
@staticmethod
def _gen_dataset_field_profile(
num_rows: Optional[int], column_profile: ColumnProfile
) -> DatasetFieldProfileClass:
unique_proportion: Optional[float] = None
null_proportion: Optional[float] = None
if num_rows:
if column_profile.distinct_count is not None:
unique_proportion = min(1.0, column_profile.distinct_count / num_rows)
if column_profile.null_count is not None:
null_proportion = min(1.0, column_profile.null_count / num_rows)
return DatasetFieldProfileClass(
fieldPath=column_profile.name,
uniqueCount=column_profile.distinct_count,
uniqueProportion=unique_proportion,
nullCount=column_profile.null_count,
nullProportion=null_proportion,
min=column_profile.min,
max=column_profile.max,
)
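
To make the proportion math above concrete, a small worked example of _gen_dataset_field_profile; the column name and statistics are hypothetical:

from datahub.ingestion.source.unity.profiler import UnityCatalogProfiler
from datahub.ingestion.source.unity.proxy_types import ColumnProfile

# Hypothetical column: 1000 rows, 900 distinct values, 50 nulls.
column = ColumnProfile(
    name="customer_id",
    null_count=50,
    distinct_count=900,
    min="1",
    max="99871",
    version=None,
    avg_len=None,
    max_len=None,
)
field_profile = UnityCatalogProfiler._gen_dataset_field_profile(1000, column)
assert field_profile.uniqueProportion == 0.9  # min(1.0, 900 / 1000)
assert field_profile.nullProportion == 0.05   # min(1.0, 50 / 1000)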


@ -16,6 +16,9 @@ from databricks.sdk.service.sql import (
from databricks_cli.sdk.api_client import ApiClient
from databricks_cli.unity_catalog.api import UnityCatalogApi
from datahub.ingestion.source.unity.proxy_profiling import (
UnityCatalogProxyProfilingMixin,
)
from datahub.ingestion.source.unity.proxy_types import (
ALLOWED_STATEMENT_TYPES,
DATA_TYPE_REGISTRY,
@ -48,14 +51,19 @@ class QueryFilterWithStatementTypes(QueryFilter):
return v
class UnityCatalogApiProxy:
_unity_catalog_api: UnityCatalogApi
class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
_workspace_client: WorkspaceClient
_unity_catalog_api: UnityCatalogApi
_workspace_url: str
report: UnityCatalogReport
warehouse_id: str
def __init__(
self, workspace_url: str, personal_access_token: str, report: UnityCatalogReport
self,
workspace_url: str,
personal_access_token: str,
warehouse_id: Optional[str],
report: UnityCatalogReport,
):
self._workspace_client = WorkspaceClient(
host=workspace_url, token=personal_access_token
@ -63,7 +71,7 @@ class UnityCatalogApiProxy:
self._unity_catalog_api = UnityCatalogApi(
ApiClient(host=workspace_url, token=personal_access_token)
)
self._workspace_url = workspace_url
self.warehouse_id = warehouse_id or ""
self.report = report
def check_connectivity(self) -> bool:
@ -131,6 +139,7 @@ class UnityCatalogApiProxy:
yield self._create_table(schema=schema, obj=table)
def service_principals(self) -> Iterable[ServicePrincipal]:
# TODO: Replace with self._workspace_client.service_principals.list() when it supports pagination
start_index = 1 # Unfortunately 1-indexed
items_per_page = 0
total_results = float("inf")
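
A sketch of constructing the proxy with the new warehouse_id argument, mirroring how UnityCatalogSource builds it further down in this diff; the credentials and ids are placeholders:

from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy
from datahub.ingestion.source.unity.report import UnityCatalogReport

# Placeholder credentials and warehouse id.
proxy = UnityCatalogApiProxy(
    "https://<workspace>.cloud.databricks.com",
    "<personal-access-token>",
    warehouse_id="<sql-warehouse-id>",
    report=UnityCatalogReport(),
)
proxy.check_connectivity()            # existing basic connectivity check
proxy.check_profiling_connectivity()  # new: verifies the SQL warehouse exists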


@ -0,0 +1,238 @@
import logging
import time
from typing import Optional, Union
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import DatabricksError
from databricks.sdk.service._internal import Wait
from databricks.sdk.service.catalog import TableInfo
from databricks.sdk.service.sql import (
ExecuteStatementResponse,
GetStatementResponse,
GetWarehouseResponse,
StatementState,
StatementStatus,
)
from databricks_cli.unity_catalog.api import UnityCatalogApi
from datahub.ingestion.source.unity.proxy_types import (
ColumnProfile,
TableProfile,
TableReference,
)
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.utilities.lossy_collections import LossyList
logger: logging.Logger = logging.getLogger(__name__)
# TODO: Move to separate proxy/ directory with rest of proxy code
class UnityCatalogProxyProfilingMixin:
_workspace_client: WorkspaceClient
_unity_catalog_api: UnityCatalogApi
report: UnityCatalogReport
warehouse_id: str
def check_profiling_connectivity(self):
self._workspace_client.warehouses.get(self.warehouse_id)
return True
def start_warehouse(self) -> Optional[Wait[GetWarehouseResponse]]:
"""Starts a databricks SQL warehouse.
Returns:
- A Wait object that can be used to wait for warehouse start completion.
- None if the warehouse does not exist.
"""
try:
return self._workspace_client.warehouses.start(self.warehouse_id)
except DatabricksError as e:
logger.warning(f"Unable to start warehouse -- are you sure it exists? {e}")
return None
def get_table_stats(
self,
ref: TableReference,
*,
max_wait_secs: int,
call_analyze: bool,
include_columns: bool,
) -> Optional[TableProfile]:
"""Returns profiling information for a table.
Performs three steps:
1. Call ANALYZE TABLE to compute statistics for all columns
2. Poll for ANALYZE completion with exponential backoff, with `max_wait_secs` timeout
3. Get the ANALYZE result via the properties field in the tables API.
This is supposed to be returned by a DESCRIBE TABLE EXTENDED command, but I don't see it.
Raises:
DatabricksError: If any of the above steps fail
"""
# Currently uses databricks sdk, which is synchronous
# If we need to improve performance, we can manually make requests via aiohttp
try:
if call_analyze:
response = self._analyze_table(ref, include_columns=include_columns)
success = self._check_analyze_table_statement_status(
response, max_wait_secs=max_wait_secs
)
if not success:
self.report.profile_table_timeouts.append(str(ref))
return None
return self._get_table_profile(ref, include_columns=include_columns)
except DatabricksError as e:
# Attempt to parse out generic part of error message
msg = str(e)
idx = (str(msg).find("`") + 1) or (str(msg).find("'") + 1) or len(str(msg))
base_msg = msg[:idx]
self.report.profile_table_errors.setdefault(base_msg, LossyList()).append(
(str(ref), msg)
)
logger.warning(
f"Failure during profiling {ref}, {e.kwargs}: ({e.error_code}) {e}",
exc_info=True,
)
if (
call_analyze
and include_columns
and self._should_retry_unsupported_column(ref, e)
):
return self.get_table_stats(
ref,
max_wait_secs=max_wait_secs,
call_analyze=call_analyze,
include_columns=False,
)
else:
return None
def _should_retry_unsupported_column(
self, ref: TableReference, e: DatabricksError
) -> bool:
if "[UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE]" in str(e):
logger.info(
f"Attempting to profile table without columns due to unsupported column type: {ref}"
)
self.report.num_profile_failed_unsupported_column_type += 1
return True
return False
def _analyze_table(
self, ref: TableReference, include_columns: bool
) -> ExecuteStatementResponse:
statement = f"ANALYZE TABLE {ref.schema}.{ref.table} COMPUTE STATISTICS"
if include_columns:
statement += " FOR ALL COLUMNS"
response = self._workspace_client.statement_execution.execute_statement(
statement=statement,
catalog=ref.catalog,
wait_timeout="0s", # Fetch result asynchronously
warehouse_id=self.warehouse_id,
)
self._raise_if_error(response, "analyze-table")
return response
def _check_analyze_table_statement_status(
self, execute_response: ExecuteStatementResponse, max_wait_secs: int
) -> bool:
statement_id: str = execute_response.statement_id
status: StatementStatus = execute_response.status
backoff_sec = 1
total_wait_time = 0
while (
total_wait_time < max_wait_secs and status.state != StatementState.SUCCEEDED
):
time.sleep(min(backoff_sec, max_wait_secs - total_wait_time))
total_wait_time += backoff_sec
backoff_sec *= 2
response = self._workspace_client.statement_execution.get_statement(
statement_id
)
self._raise_if_error(response, "get-statement")
status = response.status
return status.state == StatementState.SUCCEEDED
def _get_table_profile(
self, ref: TableReference, include_columns: bool
) -> TableProfile:
table_info = self._workspace_client.tables.get(ref.qualified_table_name)
return self._create_table_profile(table_info, include_columns=include_columns)
def _create_table_profile(
self, table_info: TableInfo, include_columns: bool
) -> TableProfile:
# Warning: this implementation is brittle -- dependent on properties that can change
columns_names = [column.name for column in table_info.columns]
return TableProfile(
num_rows=self._get_int(table_info, "spark.sql.statistics.numRows"),
total_size=self._get_int(table_info, "spark.sql.statistics.totalSize"),
num_columns=len(columns_names),
column_profiles=[
self._create_column_profile(column, table_info)
for column in columns_names
]
if include_columns
else [],
)
def _create_column_profile(
self, column: str, table_info: TableInfo
) -> ColumnProfile:
return ColumnProfile(
name=column,
null_count=self._get_int(
table_info, f"spark.sql.statistics.colStats.{column}.nullCount"
),
distinct_count=self._get_int(
table_info, f"spark.sql.statistics.colStats.{column}.distinctCount"
),
min=table_info.properties.get(
f"spark.sql.statistics.colStats.{column}.min"
),
max=table_info.properties.get(
f"spark.sql.statistics.colStats.{column}.max"
),
avg_len=table_info.properties.get(
f"spark.sql.statistics.colStats.{column}.avgLen"
),
max_len=table_info.properties.get(
f"spark.sql.statistics.colStats.{column}.maxLen"
),
version=table_info.properties.get(
f"spark.sql.statistics.colStats.{column}.version"
),
)
def _get_int(self, table_info: TableInfo, field: str) -> Optional[int]:
value = table_info.properties.get(field)
if value is not None:
try:
return int(value)
except ValueError:
logger.warning(
f"Failed to parse int for {table_info.name} - {field}: {value}"
)
self.report.num_profile_failed_int_casts += 1
return None
@staticmethod
def _raise_if_error(
response: Union[ExecuteStatementResponse, GetStatementResponse], key: str
) -> None:
if response.status.state in [
StatementState.FAILED,
StatementState.CANCELED,
StatementState.CLOSED,
]:
raise DatabricksError(
response.status.error.message,
error_code=response.status.error.error_code.value,
status=response.status.state.value,
context=key,
)
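
The polling loop in _check_analyze_table_statement_status stops as soon as the statement succeeds, and otherwise sleeps with exponential backoff, each sleep capped by the remaining budget. A standalone illustration of just that timing behaviour (no Databricks calls), assuming a 60-second max_wait_secs:

max_wait_secs = 60
backoff_sec, total_wait_time, sleeps = 1, 0, []
while total_wait_time < max_wait_secs:
    # Same arithmetic as the loop above, minus the statement-status check.
    sleeps.append(min(backoff_sec, max_wait_secs - total_wait_time))
    total_wait_time += backoff_sec
    backoff_sec *= 2

print(sleeps)       # [1, 2, 4, 8, 16, 29]
print(sum(sleeps))  # 60 -- total sleep never exceeds max_wait_secs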


@ -174,3 +174,44 @@ class Query:
# User whose credentials were used to run the query
executed_as_user_id: int
executed_as_user_name: str
@dataclass
class TableProfile:
num_rows: Optional[int]
num_columns: Optional[int]
total_size: Optional[int]
column_profiles: List["ColumnProfile"]
def __bool__(self):
return any(
(
self.num_rows is not None,
self.num_columns is not None,
self.total_size is not None,
any(self.column_profiles),
)
)
@dataclass
class ColumnProfile:
name: str
null_count: Optional[int]
distinct_count: Optional[int]
min: Optional[str]
max: Optional[str]
version: Optional[str]
avg_len: Optional[str]
max_len: Optional[str]
def __bool__(self):
return any(
(
self.null_count is not None,
self.distinct_count is not None,
self.min is not None,
self.max is not None,
)
)
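
These two dataclasses define truthiness as "has at least one statistic", which is what UnityCatalogProfiler.process_table relies on to tell an empty profile apart from a missing one. A quick illustration:

from datahub.ingestion.source.unity.proxy_types import TableProfile

# No statistics at all -> falsy -> reported under profile_table_empty.
empty = TableProfile(num_rows=None, num_columns=None, total_size=None, column_profiles=[])
assert not empty

# Any single statistic makes the profile truthy and worth emitting.
populated = TableProfile(num_rows=1000, num_columns=None, total_size=None, column_profiles=[])
assert populated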


@ -1,9 +1,11 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Tuple
from datahub.ingestion.api.report import EntityFilterReport
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalSourceReport,
)
from datahub.utilities.lossy_collections import LossyDict, LossyList
@dataclass
@ -12,6 +14,7 @@ class UnityCatalogReport(StaleEntityRemovalSourceReport):
catalogs: EntityFilterReport = EntityFilterReport.field(type="catalog")
schemas: EntityFilterReport = EntityFilterReport.field(type="schema")
tables: EntityFilterReport = EntityFilterReport.field(type="table/view")
table_profiles: EntityFilterReport = EntityFilterReport.field(type="table profile")
num_queries: int = 0
num_queries_dropped_parse_failure: int = 0
@ -21,3 +24,12 @@ class UnityCatalogReport(StaleEntityRemovalSourceReport):
num_operational_stats_workunits_emitted: int = 0
num_usage_workunits_emitted: int = 0
profile_table_timeouts: LossyList[str] = field(default_factory=LossyList)
profile_table_empty: LossyList[str] = field(default_factory=LossyList)
profile_table_errors: LossyDict[str, LossyList[Tuple[str, str]]] = field(
default_factory=LossyDict
)
num_profile_failed_unsupported_column_type: int = 0
num_profile_failed_int_casts: int = 0
num_profile_workunits_emitted: int = 0
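
The profile_table_errors field is keyed by the generic prefix that get_table_stats slices off before the first backtick or quote, so repeated failures collapse into one key. A small sketch with hypothetical error messages:

from datahub.utilities.lossy_collections import LossyDict, LossyList

errors = LossyDict()

# Hypothetical Databricks error messages; the key is the text up to and
# including the first backtick, mirroring the slicing in get_table_stats.
failures = [
    ("main.sales.orders", "ANALYZE TABLE is not supported on table `main.sales.orders`"),
    ("main.sales.items", "ANALYZE TABLE is not supported on table `main.sales.items`"),
]
for ref, msg in failures:
    idx = (msg.find("`") + 1) or (msg.find("'") + 1) or len(msg)
    errors.setdefault(msg[:idx], LossyList()).append((ref, msg))

assert len(errors) == 1  # both failures grouped under one generic message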


@ -1,6 +1,7 @@
import logging
import re
import time
from datetime import timedelta
from typing import Dict, Iterable, List, Optional, Set
from datahub.emitter.mce_builder import (
@ -45,15 +46,18 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import (
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)
from datahub.ingestion.source.unity import proxy
from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
from datahub.ingestion.source.unity.proxy import (
from datahub.ingestion.source.unity.profiler import UnityCatalogProfiler
from datahub.ingestion.source.unity.proxy import UnityCatalogApiProxy
from datahub.ingestion.source.unity.proxy_types import (
Catalog,
Column,
Metastore,
Schema,
ServicePrincipal,
Table,
TableReference,
)
from datahub.ingestion.source.unity.proxy_types import TableReference
from datahub.ingestion.source.unity.report import UnityCatalogReport
from datahub.ingestion.source.unity.usage import UnityCatalogUsageExtractor
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
@ -115,7 +119,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
"""
config: UnityCatalogSourceConfig
unity_catalog_api_proxy: proxy.UnityCatalogApiProxy
unity_catalog_api_proxy: UnityCatalogApiProxy
platform: str = "databricks"
platform_instance_name: str
@ -127,8 +131,11 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
self.config = config
self.report: UnityCatalogReport = UnityCatalogReport()
self.unity_catalog_api_proxy = proxy.UnityCatalogApiProxy(
config.workspace_url, config.token, report=self.report
self.unity_catalog_api_proxy = UnityCatalogApiProxy(
config.workspace_url,
config.token,
config.profiling.warehouse_id,
report=self.report,
)
# Determine the platform_instance_name
@ -156,21 +163,43 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
self.service_principals: Dict[str, ServicePrincipal] = {}
# Global set of table refs
self.table_refs: Set[TableReference] = set()
self.view_refs: Set[TableReference] = set()
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
test_report.capability_report = {}
try:
config = UnityCatalogSourceConfig.parse_obj_allow_extras(config_dict)
report = UnityCatalogReport()
unity_proxy = proxy.UnityCatalogApiProxy(
config.workspace_url, config.token, report=report
unity_proxy = UnityCatalogApiProxy(
config.workspace_url,
config.token,
config.profiling.warehouse_id,
report=report,
)
if unity_proxy.check_connectivity():
test_report.basic_connectivity = CapabilityReport(capable=True)
else:
test_report.basic_connectivity = CapabilityReport(capable=False)
# TODO: Refactor into separate file / method
if config.profiling.enabled and not config.profiling.warehouse_id:
test_report.capability_report[
SourceCapability.DATA_PROFILING
] = CapabilityReport(
capable=False, failure_reason="Warehouse ID not provided"
)
elif config.profiling.enabled:
try:
unity_proxy.check_profiling_connectivity()
test_report.capability_report[
SourceCapability.DATA_PROFILING
] = CapabilityReport(capable=True)
except Exception as e:
test_report.capability_report[
SourceCapability.DATA_PROFILING
] = CapabilityReport(capable=False, failure_reason=str(e))
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=f"{e}"
@ -192,6 +221,17 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
)
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
wait_on_warehouse = None
if self.config.profiling.enabled:
# Can take several minutes, so start now and wait later
wait_on_warehouse = self.unity_catalog_api_proxy.start_warehouse()
if wait_on_warehouse is None:
self.report.report_failure(
"initialization",
f"SQL warehouse {self.config.profiling.warehouse_id} not found",
)
return
self.build_service_principal_map()
yield from self.process_metastores()
@ -203,7 +243,19 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
table_urn_builder=self.gen_dataset_urn,
user_urn_builder=self.gen_user_urn,
)
yield from usage_extractor.run(self.table_refs)
yield from usage_extractor.run(self.table_refs | self.view_refs)
if self.config.profiling.enabled:
assert wait_on_warehouse
timeout = timedelta(seconds=self.config.profiling.max_wait_secs)
wait_on_warehouse.result(timeout)
profiling_extractor = UnityCatalogProfiler(
self.config.profiling,
self.report,
self.unity_catalog_api_proxy,
self.gen_dataset_urn,
)
yield from profiling_extractor.get_workunits(self.table_refs)
def build_service_principal_map(self) -> None:
try:
@ -233,9 +285,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
self.report.metastores.processed(metastore.metastore_id)
def process_catalogs(
self, metastore: proxy.Metastore
) -> Iterable[MetadataWorkUnit]:
def process_catalogs(self, metastore: Metastore) -> Iterable[MetadataWorkUnit]:
for catalog in self.unity_catalog_api_proxy.catalogs(metastore=metastore):
if not self.config.catalog_pattern.allowed(catalog.id):
self.report.catalogs.dropped(catalog.id)
@ -246,7 +296,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
self.report.catalogs.processed(catalog.id)
def process_schemas(self, catalog: proxy.Catalog) -> Iterable[MetadataWorkUnit]:
def process_schemas(self, catalog: Catalog) -> Iterable[MetadataWorkUnit]:
for schema in self.unity_catalog_api_proxy.schemas(catalog=catalog):
if not self.config.schema_pattern.allowed(schema.id):
self.report.schemas.dropped(schema.id)
@ -257,19 +307,20 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
self.report.schemas.processed(schema.id)
def process_tables(self, schema: proxy.Schema) -> Iterable[MetadataWorkUnit]:
def process_tables(self, schema: Schema) -> Iterable[MetadataWorkUnit]:
for table in self.unity_catalog_api_proxy.tables(schema=schema):
if not self.config.table_pattern.allowed(table.ref.qualified_table_name):
self.report.tables.dropped(table.id, type=table.type)
continue
self.table_refs.add(table.ref)
if table.type.lower() == "view":
self.view_refs.add(table.ref)
else:
self.table_refs.add(table.ref)
yield from self.process_table(table, schema)
self.report.tables.processed(table.id, type=table.type)
def process_table(
self, table: proxy.Table, schema: proxy.Schema
) -> Iterable[MetadataWorkUnit]:
def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUnit]:
dataset_urn = self.gen_dataset_urn(table.ref)
yield from self.add_table_to_dataset_container(dataset_urn, schema)
@ -282,13 +333,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
sub_type = self._create_table_sub_type_aspect(table)
schema_metadata = self._create_schema_metadata_aspect(table)
operation = self._create_table_operation_aspect(table)
domain = self._get_domain_aspect(
dataset_name=str(
f"{table.schema.catalog.name}.{table.schema.name}.{table.name}"
)
)
domain = self._get_domain_aspect(dataset_name=table.ref.qualified_table_name)
ownership = self._create_table_ownership_aspect(table)
lineage: Optional[UpstreamLineageClass] = None
@ -317,7 +362,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
]
def _generate_column_lineage_aspect(
self, dataset_urn: str, table: proxy.Table
self, dataset_urn: str, table: Table
) -> Optional[UpstreamLineageClass]:
upstreams: List[UpstreamClass] = []
finegrained_lineages: List[FineGrainedLineage] = []
@ -353,7 +398,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
return None
def _generate_lineage_aspect(
self, dataset_urn: str, table: proxy.Table
self, dataset_urn: str, table: Table
) -> Optional[UpstreamLineageClass]:
upstreams: List[UpstreamClass] = []
for upstream in sorted(table.upstreams.keys()):
@ -485,9 +530,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
dataset_urn=dataset_urn,
)
def _create_table_property_aspect(
self, table: proxy.Table
) -> DatasetPropertiesClass:
def _create_table_property_aspect(self, table: Table) -> DatasetPropertiesClass:
custom_properties: dict = {}
if table.storage_location is not None:
custom_properties["storage_location"] = table.storage_location
@ -524,7 +567,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
lastModified=last_modified,
)
def _create_table_operation_aspect(self, table: proxy.Table) -> OperationClass:
def _create_table_operation_aspect(self, table: Table) -> OperationClass:
"""Produce an operation aspect for a table.
If a last updated time is present, we produce an update operation.
@ -553,9 +596,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
return operation
def _create_table_ownership_aspect(
self, table: proxy.Table
) -> Optional[OwnershipClass]:
def _create_table_ownership_aspect(self, table: Table) -> Optional[OwnershipClass]:
owner_urn = self.get_owner_urn(table.owner)
if owner_urn is not None:
return OwnershipClass(
@ -568,7 +609,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
)
return None
def _create_table_sub_type_aspect(self, table: proxy.Table) -> SubTypesClass:
def _create_table_sub_type_aspect(self, table: Table) -> SubTypesClass:
return SubTypesClass(
typeNames=[
DatasetSubTypes.VIEW
@ -577,13 +618,13 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
]
)
def _create_view_property_aspect(self, table: proxy.Table) -> ViewProperties:
def _create_view_property_aspect(self, table: Table) -> ViewProperties:
assert table.view_definition
return ViewProperties(
materialized=False, viewLanguage="SQL", viewLogic=table.view_definition
)
def _create_schema_metadata_aspect(self, table: proxy.Table) -> SchemaMetadataClass:
def _create_schema_metadata_aspect(self, table: Table) -> SchemaMetadataClass:
schema_fields: List[SchemaFieldClass] = []
for column in table.columns:
@ -599,7 +640,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
)
@staticmethod
def _create_schema_field(column: proxy.Column) -> List[SchemaFieldClass]:
def _create_schema_field(column: Column) -> List[SchemaFieldClass]:
_COMPLEX_TYPE = re.compile("^(struct|array)")
if _COMPLEX_TYPE.match(column.type_text.lower()):
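
One interaction worth spelling out: table_pattern controls which tables are ingested at all, while profiling.pattern further narrows which of those get profiled, and both are matched against catalog.schema.table names. A sketch with hypothetical patterns:

from datahub.configuration.common import AllowDenyPattern

# Hypothetical patterns: ingest everything in the prod catalog, but only
# profile tables in its gold_* schemas.
table_pattern = AllowDenyPattern(allow=["prod\\..*"])
profiling_pattern = AllowDenyPattern(allow=["prod\\.gold_.*\\..*"])

name = "prod.gold_sales.orders"
assert table_pattern.allowed(name) and profiling_pattern.allowed(name)

raw = "prod.raw.events"
assert table_pattern.allowed(raw)          # ingested
assert not profiling_pattern.allowed(raw)  # but not profiled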


@ -33,6 +33,34 @@ def test_within_thirty_days():
)
def test_profiling_requires_warehouses_id():
config = UnityCatalogSourceConfig.parse_obj(
{
"token": "token",
"workspace_url": "https://workspace_url",
"profiling": {"enabled": True, "warehouse_id": "my_warehouse_id"},
}
)
assert config.profiling.enabled is True
config = UnityCatalogSourceConfig.parse_obj(
{
"token": "token",
"workspace_url": "https://workspace_url",
"profiling": {"enabled": False},
}
)
assert config.profiling.enabled is False
with pytest.raises(ValueError):
UnityCatalogSourceConfig.parse_obj(
{
"token": "token",
"workspace_url": "workspace_url",
}
)
@freeze_time(FROZEN_TIME)
def test_workspace_url_should_start_with_https():
@ -41,5 +69,6 @@ def test_workspace_url_should_start_with_https():
{
"token": "token",
"workspace_url": "workspace_url",
"profiling": {"enabled": True},
}
)