Fixes #12127: Add Support for Complex types of Databricks & UnityCatalog in profiler (#15976)

This commit is contained in:
Ayush Shah 2024-04-23 15:54:36 +05:30 committed by GitHub
parent df5d5e1866
commit 0963a111fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 128 additions and 22 deletions

View File

@ -13,7 +13,11 @@ Define custom types as wrappers on top of
existing SQA types to have a bridge between
SQA dialects and OM rich type system
"""
from sqlalchemy import types
from sqlalchemy.sql.sqltypes import TypeDecorator
from metadata.utils.sqlalchemy_utils import convert_numpy_to_list
class SQAMap(types.String):
@ -22,11 +26,26 @@ class SQAMap(types.String):
"""
class SQAStruct(types.String):
class SQAStruct(TypeDecorator):
"""
Custom Struct type definition
"""
impl = types.String
cache_ok = True
def process_result_value(self, value, dialect):
"""This is executed during result retrieval
Args:
value: database record
dialect: database dialect
Returns:
python list conversion of ndarray
"""
return convert_numpy_to_list(value)
class SQADateTimeRange(types.String):
"""

View File

@ -13,14 +13,93 @@
Interfaces with database for all database engine
supporting sqlalchemy abstraction layer
"""
from typing import List
from pyhive.sqlalchemy_hive import HiveCompiler
from sqlalchemy import Column, inspect
from metadata.generated.schema.entity.data.table import Column as OMColumn
from metadata.generated.schema.entity.data.table import ColumnName, DataType, TableData
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
)
from metadata.profiler.orm.converter.base import build_orm_col
class DatabricksProfilerInterface(SQAProfilerInterface):
"""Databricks profiler interface"""
def visit_column(self, *args, **kwargs):
result = super( # pylint: disable=bad-super-call
HiveCompiler, self
).visit_column(*args, **kwargs)
dot_count = result.count(".")
# Here the databricks uses HiveCompiler.
# the `result` here would be `db.schema.table` or `db.schema.table.column`
# for struct it will be `db.schema.table.column.nestedchild.nestedchild` etc
# the logic is to add the backticks to nested children.
if dot_count > 2:
splitted_result = result.split(".", 2)[-1].split(".")
result = ".".join(result.split(".", 2)[:-1])
result += "." + "`.`".join(splitted_result)
return result
HiveCompiler.visit_column = visit_column
def __init__(self, service_connection_config, **kwargs):
super().__init__(service_connection_config=service_connection_config, **kwargs)
self.set_catalog(self.session)
def _get_struct_columns(self, columns: List[OMColumn], parent: str):
"""Get struct columns"""
columns_list = []
for idx, col in enumerate(columns):
if col.dataType != DataType.STRUCT:
col.name = ColumnName(__root__=f"{parent}.{col.name.__root__}")
col = build_orm_col(idx, col, DatabaseServiceType.Databricks)
col._set_parent( # pylint: disable=protected-access
self.table.__table__
)
columns_list.append(col)
else:
col = self._get_struct_columns(
col.children, f"{parent}.{col.name.__root__}"
)
columns_list.extend(col)
return columns_list
def get_columns(self) -> Column:
"""Get columns from table"""
columns = []
for idx, column in enumerate(self.table_entity.columns):
if column.dataType == DataType.STRUCT:
columns.extend(
self._get_struct_columns(column.children, column.name.__root__)
)
else:
col = build_orm_col(idx, column, DatabaseServiceType.Databricks)
col._set_parent( # pylint: disable=protected-access
self.table.__table__
)
columns.append(col)
return columns
def fetch_sample_data(self, table, columns) -> TableData:
"""Fetch sample data from database
Args:
table: ORM declarative table
Returns:
TableData: sample table data
"""
sampler = self._get_sampler(
table=table,
)
return sampler.fetch_sample_data(list(inspect(self.table).c))

View File

@ -18,15 +18,12 @@ supporting sqlalchemy abstraction layer
from metadata.ingestion.source.database.databricks.connection import (
get_connection as databricks_get_connection,
)
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
from metadata.profiler.interface.sqlalchemy.databricks.profiler_interface import (
DatabricksProfilerInterface,
)
class UnityCatalogProfilerInterface(SQAProfilerInterface):
def __init__(self, service_connection_config, **kwargs):
super().__init__(service_connection_config=service_connection_config, **kwargs)
class UnityCatalogProfilerInterface(DatabricksProfilerInterface):
def create_session(self):
self.connection = databricks_get_connection(self.service_connection_config)
super().create_session()

View File

@ -68,7 +68,7 @@ class MissingMetricException(Exception):
"""
class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attributes
class Profiler(Generic[TMetric]):
"""
Core Profiler.
@ -122,7 +122,6 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute
# We will get columns from the property
self._columns: Optional[List[Column]] = None
self.fetch_column_from_property()
self.data_frame_list = None
@property
@ -176,14 +175,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute
if column.name not in self._get_excluded_columns()
]
return [
column
for column in self._columns
if column.type.__class__.__name__ not in NOT_COMPUTE
]
def fetch_column_from_property(self) -> Optional[List[Column]]:
self._columns = self.columns
return self._columns
def _get_excluded_columns(self) -> Optional[Set[str]]:
"""Get excluded columns for table being profiled"""
@ -385,6 +377,11 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute
def _prepare_column_metrics(self) -> List:
"""prepare column metrics"""
column_metrics_for_thread_pool = []
columns = [
column
for column in self.columns
if column.type.__class__.__name__ not in NOT_COMPUTE
]
static_metrics = [
ThreadPoolMetrics(
metrics=[
@ -400,7 +397,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute
column=column,
table=self.table,
)
for column in self.columns
for column in columns
]
query_metrics = [
ThreadPoolMetrics(
@ -409,7 +406,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute
column=column,
table=self.table,
)
for column in self.columns
for column in columns
for metric in self.metric_filter.get_column_metrics(
QueryMetric, column, self.profiler_interface.table_entity.serviceType
)
@ -429,7 +426,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute
column=column,
table=self.table,
)
for column in self.columns
for column in columns
]
# we'll add the system metrics to the thread pool computation
@ -437,7 +434,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute
column_metrics_for_thread_pool.extend(metric_type)
# we'll add the custom metrics to the thread pool computation
for column in self.columns:
for column in columns:
custom_metrics = self.get_custom_metrics(column.name)
if custom_metrics:
column_metrics_for_thread_pool.append(

View File

@ -110,3 +110,18 @@ def get_display_datatype(
if scale is not None and precision is not None:
return f"{col_type}({str(precision)},{str(scale)})"
return col_type
def convert_numpy_to_list(data):
"""
Recursively converts numpy arrays to lists in a nested data structure.
"""
import numpy as np # pylint: disable=import-outside-toplevel
if isinstance(data, np.ndarray):
return data.tolist()
if isinstance(data, list):
return [convert_numpy_to_list(item) for item in data]
if isinstance(data, dict):
return {key: convert_numpy_to_list(value) for key, value in data.items()}
return data

View File

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
import boto3
import botocore
import pytest
from testcontainers.localstack import LocalStackContainer