From 0963a111fe89ecc8b3442aa3aa0d60c58d4ced91 Mon Sep 17 00:00:00 2001
From: Ayush Shah
Date: Tue, 23 Apr 2024 15:54:36 +0530
Subject: [PATCH] Fixes #12127: Add Support for Complex types of Databricks &
 UnityCatalog in profiler (#15976)

---
 .../metadata/ingestion/source/sqa_types.py    | 21 ++++-
 .../databricks/profiler_interface.py          | 79 +++++++++++++++++++
 .../unity_catalog/profiler_interface.py       |  9 +--
 .../src/metadata/profiler/processor/core.py   | 25 +++---
 .../src/metadata/utils/sqlalchemy_utils.py    | 15 ++++
 .../tests/integration/profiler/conftest.py    |  1 -
 6 files changed, 128 insertions(+), 22 deletions(-)

diff --git a/ingestion/src/metadata/ingestion/source/sqa_types.py b/ingestion/src/metadata/ingestion/source/sqa_types.py
index d36ddadb225..038c4573a7f 100644
--- a/ingestion/src/metadata/ingestion/source/sqa_types.py
+++ b/ingestion/src/metadata/ingestion/source/sqa_types.py
@@ -13,7 +13,11 @@
 Define custom types as wrappers on top of existing SQA types to have a bridge
 between SQA dialects and OM rich type system
 """
+
 from sqlalchemy import types
+from sqlalchemy.sql.sqltypes import TypeDecorator
+
+from metadata.utils.sqlalchemy_utils import convert_numpy_to_list
 
 
 class SQAMap(types.String):
@@ -22,11 +26,26 @@ class SQAMap(types.String):
     """
 
 
-class SQAStruct(types.String):
+class SQAStruct(TypeDecorator):
     """
     Custom Struct type definition
     """
 
+    impl = types.String
+    cache_ok = True
+
+    def process_result_value(self, value, dialect):
+        """This is executed during result retrieval
+
+        Args:
+            value: database record
+            dialect: database dialect
+        Returns:
+            Python list conversion of ndarray
+        """
+
+        return convert_numpy_to_list(value)
+
 
 class SQADateTimeRange(types.String):
     """
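For context: `SQAStruct` now extends SQLAlchemy's `TypeDecorator`, whose `process_result_value` hook runs on every value read back from the database, which is what lets struct results be unwrapped before profiling. A minimal, self-contained sketch of the same pattern, assuming only numpy and SQLAlchemy are installed (the `DemoStruct` class and `_to_plain_list` helper are illustrative stand-ins, not part of this patch):

import numpy as np
from sqlalchemy import types
from sqlalchemy.sql.sqltypes import TypeDecorator


def _to_plain_list(value):
    # Illustrative stand-in for convert_numpy_to_list (added to
    # metadata/utils/sqlalchemy_utils.py later in this patch).
    return value.tolist() if isinstance(value, np.ndarray) else value


class DemoStruct(TypeDecorator):
    # Toy equivalent of SQAStruct: stored as a string, unwrapped on read.
    impl = types.String
    cache_ok = True

    def process_result_value(self, value, dialect):
        # Called by SQLAlchemy once per fetched value of this column type.
        return _to_plain_list(value)


# The hook can be exercised directly; the dialect argument is unused here.
print(DemoStruct().process_result_value(np.array([1, 2, 3]), None))  # [1, 2, 3]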
diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py
index bbcbc85d7a9..35638433c42 100644
--- a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py
+++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py
@@ -13,14 +13,93 @@ Interfaces with database for all database engine
 supporting sqlalchemy abstraction layer
 """
 
+from typing import List
+
+from pyhive.sqlalchemy_hive import HiveCompiler
+from sqlalchemy import Column, inspect
+
+from metadata.generated.schema.entity.data.table import Column as OMColumn
+from metadata.generated.schema.entity.data.table import ColumnName, DataType, TableData
+from metadata.generated.schema.entity.services.databaseService import (
+    DatabaseServiceType,
+)
 from metadata.profiler.interface.sqlalchemy.profiler_interface import (
     SQAProfilerInterface,
 )
+from metadata.profiler.orm.converter.base import build_orm_col
 
 
 class DatabricksProfilerInterface(SQAProfilerInterface):
+    """Databricks profiler interface"""
+
+    def visit_column(self, *args, **kwargs):
+        result = super(  # pylint: disable=bad-super-call
+            HiveCompiler, self
+        ).visit_column(*args, **kwargs)
+        dot_count = result.count(".")
+        # Databricks uses the HiveCompiler, so `result` is a fully
+        # qualified name: `db.schema.table` or `db.schema.table.column`,
+        # and for structs `db.schema.table.column.nestedchild...` etc.
+        # Add backticks around the nested children so the compiled SQL
+        # addresses them correctly.
+        if dot_count > 2:
+            splitted_result = result.split(".", 2)[-1].split(".")
+            result = ".".join(result.split(".", 2)[:-1])
+            result += "." + "`.`".join(splitted_result)
+
+        return result
+
+    HiveCompiler.visit_column = visit_column
+
     def __init__(self, service_connection_config, **kwargs):
         super().__init__(service_connection_config=service_connection_config, **kwargs)
         self.set_catalog(self.session)
+
+    def _get_struct_columns(self, columns: List[OMColumn], parent: str):
+        """Flatten struct children into ORM columns named `parent.child`"""
+
+        columns_list = []
+        for idx, col in enumerate(columns):
+            if col.dataType != DataType.STRUCT:
+                col.name = ColumnName(__root__=f"{parent}.{col.name.__root__}")
+                col = build_orm_col(idx, col, DatabaseServiceType.Databricks)
+                col._set_parent(  # pylint: disable=protected-access
+                    self.table.__table__
+                )
+                columns_list.append(col)
+            else:
+                col = self._get_struct_columns(
+                    col.children, f"{parent}.{col.name.__root__}"
+                )
+                columns_list.extend(col)
+        return columns_list
+
+    def get_columns(self) -> List[Column]:
+        """Get columns from table, expanding structs into their leaf columns"""
+        columns = []
+        for idx, column in enumerate(self.table_entity.columns):
+            if column.dataType == DataType.STRUCT:
+                columns.extend(
+                    self._get_struct_columns(column.children, column.name.__root__)
+                )
+            else:
+                col = build_orm_col(idx, column, DatabaseServiceType.Databricks)
+                col._set_parent(  # pylint: disable=protected-access
+                    self.table.__table__
+                )
+                columns.append(col)
+        return columns
+
+    def fetch_sample_data(self, table, columns) -> TableData:
+        """Fetch sample data from the database
+
+        Args:
+            table: ORM declarative table
+            columns: requested columns (unused; all inspected columns are sampled)
+
+        Returns:
+            TableData: sample table data
+        """
+        sampler = self._get_sampler(
+            table=table,
+        )
+
+        return sampler.fetch_sample_data(list(inspect(self.table).c))
diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py
index 148dd38e105..5c44544b72a 100644
--- a/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py
+++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py
@@ -18,15 +18,12 @@ supporting sqlalchemy abstraction layer
 from metadata.ingestion.source.database.databricks.connection import (
     get_connection as databricks_get_connection,
 )
-from metadata.profiler.interface.sqlalchemy.profiler_interface import (
-    SQAProfilerInterface,
+from metadata.profiler.interface.sqlalchemy.databricks.profiler_interface import (
+    DatabricksProfilerInterface,
 )
 
 
-class UnityCatalogProfilerInterface(SQAProfilerInterface):
-    def __init__(self, service_connection_config, **kwargs):
-        super().__init__(service_connection_config=service_connection_config, **kwargs)
-
+class UnityCatalogProfilerInterface(DatabricksProfilerInterface):
     def create_session(self):
         self.connection = databricks_get_connection(self.service_connection_config)
         super().create_session()
diff --git a/ingestion/src/metadata/profiler/processor/core.py b/ingestion/src/metadata/profiler/processor/core.py
index 9e5e02efe94..e1bda12a9cd 100644
--- a/ingestion/src/metadata/profiler/processor/core.py
+++ b/ingestion/src/metadata/profiler/processor/core.py
@@ -68,7 +68,7 @@ class MissingMetricException(Exception):
     """
 
 
-class Profiler(Generic[TMetric]):  # pylint: disable=too-many-instance-attributes
+class Profiler(Generic[TMetric]):
     """
     Core Profiler.
 
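For reference, the backtick logic in `DatabricksProfilerInterface.visit_column` above is pure string manipulation and can be exercised in isolation. A minimal sketch, assuming nothing beyond the standard library (the function name `quote_nested_children` is illustrative, not part of this patch):

def quote_nested_children(result: str) -> str:
    # Mirrors the patch: names up to db.schema.table pass through unchanged;
    # anything deeper is a struct path whose members get backtick-quoted.
    if result.count(".") > 2:
        nested = result.split(".", 2)[-1].split(".")
        prefix = ".".join(result.split(".", 2)[:-1])
        result = prefix + "." + "`.`".join(nested)
    return result


print(quote_nested_children("db.schema.table"))
# -> db.schema.table
print(quote_nested_children("db.schema.table.col.child"))
# -> db.schema.table`.`col`.`child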
@@ -122,7 +122,6 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute # We will get columns from the property self._columns: Optional[List[Column]] = None - self.fetch_column_from_property() self.data_frame_list = None @property @@ -176,14 +175,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute if column.name not in self._get_excluded_columns() ] - return [ - column - for column in self._columns - if column.type.__class__.__name__ not in NOT_COMPUTE - ] - - def fetch_column_from_property(self) -> Optional[List[Column]]: - self._columns = self.columns + return self._columns def _get_excluded_columns(self) -> Optional[Set[str]]: """Get excluded columns for table being profiled""" @@ -385,6 +377,11 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute def _prepare_column_metrics(self) -> List: """prepare column metrics""" column_metrics_for_thread_pool = [] + columns = [ + column + for column in self.columns + if column.type.__class__.__name__ not in NOT_COMPUTE + ] static_metrics = [ ThreadPoolMetrics( metrics=[ @@ -400,7 +397,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute column=column, table=self.table, ) - for column in self.columns + for column in columns ] query_metrics = [ ThreadPoolMetrics( @@ -409,7 +406,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute column=column, table=self.table, ) - for column in self.columns + for column in columns for metric in self.metric_filter.get_column_metrics( QueryMetric, column, self.profiler_interface.table_entity.serviceType ) @@ -429,7 +426,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute column=column, table=self.table, ) - for column in self.columns + for column in columns ] # we'll add the system metrics to the thread pool computation @@ -437,7 +434,7 @@ class Profiler(Generic[TMetric]): # pylint: disable=too-many-instance-attribute column_metrics_for_thread_pool.extend(metric_type) # we'll add the custom metrics to the thread pool computation - for column in self.columns: + for column in columns: custom_metrics = self.get_custom_metrics(column.name) if custom_metrics: column_metrics_for_thread_pool.append( diff --git a/ingestion/src/metadata/utils/sqlalchemy_utils.py b/ingestion/src/metadata/utils/sqlalchemy_utils.py index 7eaf8c00fe9..30d8ca74657 100644 --- a/ingestion/src/metadata/utils/sqlalchemy_utils.py +++ b/ingestion/src/metadata/utils/sqlalchemy_utils.py @@ -110,3 +110,18 @@ def get_display_datatype( if scale is not None and precision is not None: return f"{col_type}({str(precision)},{str(scale)})" return col_type + + +def convert_numpy_to_list(data): + """ + Recursively converts numpy arrays to lists in a nested data structure. 
+ """ + import numpy as np # pylint: disable=import-outside-toplevel + + if isinstance(data, np.ndarray): + return data.tolist() + if isinstance(data, list): + return [convert_numpy_to_list(item) for item in data] + if isinstance(data, dict): + return {key: convert_numpy_to_list(value) for key, value in data.items()} + return data diff --git a/ingestion/tests/integration/profiler/conftest.py b/ingestion/tests/integration/profiler/conftest.py index f0607f74f96..4d71670746c 100644 --- a/ingestion/tests/integration/profiler/conftest.py +++ b/ingestion/tests/integration/profiler/conftest.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING import boto3 -import botocore import pytest from testcontainers.localstack import LocalStackContainer
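A quick usage sketch of the new `convert_numpy_to_list` helper, assuming an install of openmetadata-ingestion with this patch applied and numpy available (the sample record is made up):

import numpy as np

from metadata.utils.sqlalchemy_utils import convert_numpy_to_list

row = {
    "id": 7,
    "scores": np.array([0.1, 0.2]),
    "nested": [{"vec": np.array([[1, 2], [3, 4]])}],
}

# ndarrays are converted recursively; everything else passes through.
print(convert_numpy_to_list(row))
# -> {'id': 7, 'scores': [0.1, 0.2], 'nested': [{'vec': [[1, 2], [3, 4]]}]}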