Fixes #12127: Add Support for Complex types of Databricks & UnityCatalog in profiler (#15976)

Ayush Shah 2024-04-23 15:54:36 +05:30 committed by GitHub
parent df5d5e1866
commit 0963a111fe
6 changed files with 128 additions and 22 deletions


@@ -13,7 +13,11 @@ Define custom types as wrappers on top of
 existing SQA types to have a bridge between
 SQA dialects and OM rich type system
 """
 from sqlalchemy import types
+from sqlalchemy.sql.sqltypes import TypeDecorator
+
+from metadata.utils.sqlalchemy_utils import convert_numpy_to_list
+


 class SQAMap(types.String):
@@ -22,11 +26,26 @@ class SQAMap(types.String):
     """


-class SQAStruct(types.String):
+class SQAStruct(TypeDecorator):
     """
     Custom Struct type definition
     """

+    impl = types.String
+    cache_ok = True
+
+    def process_result_value(self, value, dialect):
+        """This is executed during result retrieval
+
+        Args:
+            value: database record
+            dialect: database dialect
+        Returns:
+            python list conversion of ndarray
+        """
+        return convert_numpy_to_list(value)
+

 class SQADateTimeRange(types.String):
     """


@@ -13,14 +13,93 @@
 Interfaces with database for all database engine
 supporting sqlalchemy abstraction layer
 """
+from typing import List
+
+from pyhive.sqlalchemy_hive import HiveCompiler
+from sqlalchemy import Column, inspect
+
+from metadata.generated.schema.entity.data.table import Column as OMColumn
+from metadata.generated.schema.entity.data.table import ColumnName, DataType, TableData
+from metadata.generated.schema.entity.services.databaseService import (
+    DatabaseServiceType,
+)
 from metadata.profiler.interface.sqlalchemy.profiler_interface import (
     SQAProfilerInterface,
 )
+from metadata.profiler.orm.converter.base import build_orm_col


 class DatabricksProfilerInterface(SQAProfilerInterface):
+    """Databricks profiler interface"""
+
+    def visit_column(self, *args, **kwargs):
+        result = super(  # pylint: disable=bad-super-call
+            HiveCompiler, self
+        ).visit_column(*args, **kwargs)
+        dot_count = result.count(".")
+        # Here the databricks uses HiveCompiler.
+        # the `result` here would be `db.schema.table` or `db.schema.table.column`
+        # for struct it will be `db.schema.table.column.nestedchild.nestedchild` etc
+        # the logic is to add the backticks to nested children.
+        if dot_count > 2:
+            splitted_result = result.split(".", 2)[-1].split(".")
+            result = ".".join(result.split(".", 2)[:-1])
+            result += "." + "`.`".join(splitted_result)
+        return result
+
+    HiveCompiler.visit_column = visit_column
+
     def __init__(self, service_connection_config, **kwargs):
         super().__init__(service_connection_config=service_connection_config, **kwargs)
         self.set_catalog(self.session)
+
+    def _get_struct_columns(self, columns: List[OMColumn], parent: str):
+        """Get struct columns"""
+        columns_list = []
+        for idx, col in enumerate(columns):
+            if col.dataType != DataType.STRUCT:
+                col.name = ColumnName(__root__=f"{parent}.{col.name.__root__}")
+                col = build_orm_col(idx, col, DatabaseServiceType.Databricks)
+                col._set_parent(  # pylint: disable=protected-access
+                    self.table.__table__
+                )
+                columns_list.append(col)
+            else:
+                col = self._get_struct_columns(
+                    col.children, f"{parent}.{col.name.__root__}"
+                )
+                columns_list.extend(col)
+        return columns_list
+
+    def get_columns(self) -> Column:
+        """Get columns from table"""
+        columns = []
+        for idx, column in enumerate(self.table_entity.columns):
+            if column.dataType == DataType.STRUCT:
+                columns.extend(
+                    self._get_struct_columns(column.children, column.name.__root__)
+                )
+            else:
+                col = build_orm_col(idx, column, DatabaseServiceType.Databricks)
+                col._set_parent(  # pylint: disable=protected-access
+                    self.table.__table__
+                )
+                columns.append(col)
+        return columns
+
+    def fetch_sample_data(self, table, columns) -> TableData:
+        """Fetch sample data from database
+
+        Args:
+            table: ORM declarative table
+        Returns:
+            TableData: sample table data
+        """
+        sampler = self._get_sampler(
+            table=table,
+        )
+
+        return sampler.fetch_sample_data(list(inspect(self.table).c))
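
The back-tick logic in visit_column is pure string manipulation, so it can be replayed outside the compiler. A standalone sketch of the same code path, fed a plain unquoted path for illustration (real HiveCompiler output may already carry dialect quoting):

def add_backticks(result: str) -> str:
    # `db.schema.table` passes through untouched; only paths with more
    # than two dots (i.e. nested struct children) are rewritten.
    if result.count(".") > 2:
        splitted_result = result.split(".", 2)[-1].split(".")
        result = ".".join(result.split(".", 2)[:-1])
        result += "." + "`.`".join(splitted_result)
    return result

print(add_backticks("db.schema.table"))
# db.schema.table
print(add_backticks("db.schema.table.column.nested"))
# db.schema.table`.`column`.`nested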


@@ -18,15 +18,12 @@ supporting sqlalchemy abstraction layer
 from metadata.ingestion.source.database.databricks.connection import (
     get_connection as databricks_get_connection,
 )
-from metadata.profiler.interface.sqlalchemy.profiler_interface import (
-    SQAProfilerInterface,
+from metadata.profiler.interface.sqlalchemy.databricks.profiler_interface import (
+    DatabricksProfilerInterface,
 )


-class UnityCatalogProfilerInterface(SQAProfilerInterface):
-    def __init__(self, service_connection_config, **kwargs):
-        super().__init__(service_connection_config=service_connection_config, **kwargs)
-
+class UnityCatalogProfilerInterface(DatabricksProfilerInterface):
     def create_session(self):
         self.connection = databricks_get_connection(self.service_connection_config)
         super().create_session()


@@ -68,7 +68,7 @@ class MissingMetricException(Exception):
     """


-class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
+class Profiler(Generic[TMetric]):
     """
     Core Profiler.
@@ -122,7 +122,6 @@ class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
         # We will get columns from the property
         self._columns: Optional[List[Column]] = None
-        self.fetch_column_from_property()
         self.data_frame_list = None

     @property
@@ -176,14 +175,7 @@ class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
                 if column.name not in self._get_excluded_columns()
             ]

-        return [
-            column
-            for column in self._columns
-            if column.type.__class__.__name__ not in NOT_COMPUTE
-        ]
-
-    def fetch_column_from_property(self) -> Optional[List[Column]]:
-        self._columns = self.columns
+        return self._columns

     def _get_excluded_columns(self) -> Optional[Set[str]]:
         """Get excluded columns for table being profiled"""
@@ -385,6 +377,11 @@ class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
     def _prepare_column_metrics(self) -> List:
         """prepare column metrics"""
         column_metrics_for_thread_pool = []
+        columns = [
+            column
+            for column in self.columns
+            if column.type.__class__.__name__ not in NOT_COMPUTE
+        ]
         static_metrics = [
             ThreadPoolMetrics(
                 metrics=[
@@ -400,7 +397,7 @@ class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
                 column=column,
                 table=self.table,
             )
-            for column in self.columns
+            for column in columns
         ]
         query_metrics = [
             ThreadPoolMetrics(
@@ -409,7 +406,7 @@ class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
                 column=column,
                 table=self.table,
             )
-            for column in self.columns
+            for column in columns
             for metric in self.metric_filter.get_column_metrics(
                 QueryMetric, column, self.profiler_interface.table_entity.serviceType
             )
@@ -429,7 +426,7 @@ class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
                 column=column,
                 table=self.table,
             )
-            for column in self.columns
+            for column in columns
         ]

         # we'll add the system metrics to the thread pool computation
@@ -437,7 +434,7 @@ class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
             column_metrics_for_thread_pool.extend(metric_type)

         # we'll add the custom metrics to the thread pool computation
-        for column in self.columns:
+        for column in columns:
            custom_metrics = self.get_custom_metrics(column.name)
            if custom_metrics:
                column_metrics_for_thread_pool.append(
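
With fetch_column_from_property gone, the columns property now returns every column (structs included), and the NOT_COMPUTE filter applies only where metrics are prepared. A toy illustration of that split, using hypothetical stand-in classes rather than OpenMetadata's real ORM types (NOT_COMPUTE's exact members are assumed here):

NOT_COMPUTE = {"SQAStruct", "SQAMap"}  # assumed members, for illustration only

class SQAStruct: ...
class VARCHAR: ...

class FakeColumn:  # hypothetical stand-in for a SQLAlchemy Column
    def __init__(self, name, type_):
        self.name, self.type = name, type_

all_columns = [FakeColumn("id", VARCHAR()), FakeColumn("payload", SQAStruct())]
# every column stays visible, so Databricks can still expand the struct ...
print([c.name for c in all_columns])   # ['id', 'payload']
# ... but only computable columns reach the metric thread pool
computable = [
    c for c in all_columns if c.type.__class__.__name__ not in NOT_COMPUTE
]
print([c.name for c in computable])    # ['id']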


@@ -110,3 +110,18 @@ def get_display_datatype(
     if scale is not None and precision is not None:
         return f"{col_type}({str(precision)},{str(scale)})"
     return col_type
+
+
+def convert_numpy_to_list(data):
+    """
+    Recursively converts numpy arrays to lists in a nested data structure.
+    """
+    import numpy as np  # pylint: disable=import-outside-toplevel
+
+    if isinstance(data, np.ndarray):
+        return data.tolist()
+    if isinstance(data, list):
+        return [convert_numpy_to_list(item) for item in data]
+    if isinstance(data, dict):
+        return {key: convert_numpy_to_list(value) for key, value in data.items()}
+    return data
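
For reference, a quick round trip through the new helper on a nested record such as a struct row materialized as numpy arrays (the sample record is illustrative; the import path is the one this commit adds):

import numpy as np
from metadata.utils.sqlalchemy_utils import convert_numpy_to_list

record = {"ids": np.array([1, 2]), "nested": [{"scores": np.array([0.5])}]}
print(convert_numpy_to_list(record))
# {'ids': [1, 2], 'nested': [{'scores': [0.5]}]}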


@@ -1,7 +1,6 @@
 from typing import TYPE_CHECKING

 import boto3
-import botocore
 import pytest
 from testcontainers.localstack import LocalStackContainer