Fixes #11357 - Implement profiler custom metric processing (#14021)

* feat: add backend support for custom metrics

* feat: fix python test

* feat: support custom metrics computation

* feat: updated tests for custom metrics

* feat: added dl support for min max of datetime

* feat: added is safe query check for query sampler

* feat: added support for custom metric computation in dl

* feat: added explicit addProper for pydantic model import for Extra

* feat: added custom metric to returned obj

* feat: wrapped trino import in __init__

* feat: fix python linting

* feat: fix typing in 3.8
This commit is contained in:
Teddy 2023-11-17 17:51:39 +01:00 committed by GitHub
parent efb6c5f221
commit c7ac28f2c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 738 additions and 148 deletions

View File

@ -428,7 +428,7 @@ def get_parser(args=None):
return parser.parse_args(args) return parser.parse_args(args)
def metadata(args=None): # pylint: disable=too-many-branches def metadata(args=None):
""" """
This method implements parsing of the arguments passed from CLI This method implements parsing of the arguments passed from CLI
""" """

View File

@ -15,7 +15,10 @@ Return types for Profiler workflow execution.
We need to define this class as we end up having We need to define this class as we end up having
multiple profilers per table and columns. multiple profilers per table and columns.
""" """
from typing import List, Optional, Union from typing import List, Optional, Type, Union
from sqlalchemy import Column
from sqlalchemy.orm import DeclarativeMeta
from metadata.config.common import ConfigModel from metadata.config.common import ConfigModel
from metadata.generated.schema.api.data.createTableProfile import ( from metadata.generated.schema.api.data.createTableProfile import (
@ -31,9 +34,12 @@ from metadata.generated.schema.entity.data.table import (
from metadata.generated.schema.entity.services.connections.connectionBasicType import ( from metadata.generated.schema.entity.services.connections.connectionBasicType import (
SampleDataStorageConfig, SampleDataStorageConfig,
) )
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.generated.schema.type.basic import FullyQualifiedEntityName from metadata.generated.schema.type.basic import FullyQualifiedEntityName
from metadata.ingestion.models.table_metadata import ColumnTag from metadata.ingestion.models.table_metadata import ColumnTag
from metadata.profiler.metrics.core import Metric, MetricTypes
from metadata.profiler.processor.models import ProfilerDef from metadata.profiler.processor.models import ProfilerDef
from metadata.utils.sqa_like_column import SQALikeColumn
class ColumnConfig(ConfigModel): class ColumnConfig(ConfigModel):
@ -113,3 +119,15 @@ class ProfilerResponse(ConfigModel):
def __str__(self): def __str__(self):
"""Return the table name being processed""" """Return the table name being processed"""
return f"Table [{self.table.name.__root__}]" return f"Table [{self.table.name.__root__}]"
class ThreadPoolMetrics(ConfigModel):
    """Bundle describing one unit of metric computation.

    Groups the metrics to compute, their metric type and the column/table
    they apply to into a single object, so that a worker can receive one
    argument instead of a positional tuple.
    """

    # Either a list of metric classes / custom metric definitions,
    # or a single metric class.
    metrics: Union[List[Union[Type[Metric], CustomMetric]], Type[Metric]]
    # Category of the metrics; its `.value` is what callers use to pick
    # the compute function.
    metric_type: MetricTypes
    # Column the metrics apply to; None for table-level metrics.
    column: Optional[Union[Column, SQALikeColumn]]
    # Table the metrics apply to (entity or SQLAlchemy ORM class).
    table: Union[Table, DeclarativeMeta]

    class Config:
        # NOTE(review): presumably required because the SQLAlchemy types
        # above are not pydantic-native field types -- confirm.
        arbitrary_types_allowed = True

View File

@ -22,11 +22,13 @@ from typing import Dict, List, Optional
from sqlalchemy import Column from sqlalchemy import Column
from metadata.generated.schema.entity.data.table import TableData from metadata.generated.schema.entity.data.table import CustomMetricProfile, TableData
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import ( from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection, DatalakeConnection,
) )
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.mixins.pandas.pandas_mixin import PandasInterfaceMixin from metadata.mixins.pandas.pandas_mixin import PandasInterfaceMixin
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.profiler_interface import ProfilerInterface from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import MetricTypes from metadata.profiler.metrics.core import MetricTypes
from metadata.profiler.metrics.registry import Metrics from metadata.profiler.metrics.registry import Metrics
@ -239,35 +241,67 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
""" """
return None # to be implemented return None # to be implemented
def _compute_custom_metrics(
    self, metrics: List[CustomMetric], runner, *args, **kwargs
):
    """Compute custom metrics for a pandas source.

    Each metric expression is handed to ``DataFrame.query`` and the metric
    value is the number of rows it selects, summed over every dataframe
    in ``runner``.

    Args:
        metrics (List[CustomMetric]): custom metrics to compute
        runner: iterable of dataframes the expressions are evaluated on

    Returns:
        ``{"customMetrics": [CustomMetricProfile, ...]}`` holding one entry
        per metric that could be computed, or ``None`` when ``metrics`` is
        empty or every metric failed.
    """
    if not metrics:
        return None

    custom_metrics = []

    for metric in metrics:
        try:
            # Run the expression once per dataframe: chunks without any
            # matching row simply contribute 0 to the total.
            row = sum(len(df.query(metric.expression).index) for df in runner)
            custom_metrics.append(
                CustomMetricProfile(name=metric.name.__root__, value=row)
            )
        except Exception as exc:
            # Best effort: a failing metric is logged and skipped so the
            # remaining ones are still computed.
            msg = f"Error trying to compute profile for custom metric: {exc}"
            logger.debug(traceback.format_exc())
            logger.warning(msg)
    if custom_metrics:
        return {"customMetrics": custom_metrics}
    return None
def compute_metrics( def compute_metrics(
self, self,
metrics, metric_func: ThreadPoolMetrics,
metric_type,
column,
table,
): ):
"""Run metrics in processor worker""" """Run metrics in processor worker"""
logger.debug(f"Running profiler for {table}") logger.debug(f"Running profiler for {metric_func.table}")
try: try:
row = None row = None
if self.complex_dataframe_sample: if self.complex_dataframe_sample:
row = self._get_metric_fn[metric_type.value]( row = self._get_metric_fn[metric_func.metric_type.value](
metrics, metric_func.metrics,
self.complex_dataframe_sample, self.complex_dataframe_sample,
column=column, column=metric_func.column,
) )
except Exception as exc: except Exception as exc:
name = f"{column if column is not None else table}" name = f"{metric_func.column if metric_func.column is not None else metric_func.table}"
error = f"{name} metric_type.value: {exc}" error = f"{name} metric_type.value: {exc}"
logger.error(error) logger.error(error)
self.status.failed_profiler(error, traceback.format_exc()) self.status.failed_profiler(error, traceback.format_exc())
row = None row = None
if column is not None: if metric_func.column is not None:
column = column.name column = metric_func.column.name
self.status.scanned(f"{table.name.__root__}.{column}") self.status.scanned(f"{metric_func.table.name.__root__}.{column}")
else: else:
self.status.scanned(table.name.__root__) self.status.scanned(metric_func.table.name.__root__)
return row, column, metric_type.value column = None
return row, column, metric_func.metric_type.value
def fetch_sample_data(self, table, columns: SQALikeColumn) -> TableData: def fetch_sample_data(self, table, columns: SQALikeColumn) -> TableData:
"""Fetch sample data from database """Fetch sample data from database
@ -329,7 +363,7 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
profile_results = {"table": {}, "columns": defaultdict(dict)} profile_results = {"table": {}, "columns": defaultdict(dict)}
metric_list = [ metric_list = [
self.compute_metrics(*metric_func) for metric_func in metric_funcs self.compute_metrics(metric_func) for metric_func in metric_funcs
] ]
for metric_result in metric_list: for metric_result in metric_list:
profile, column, metric_type = metric_result profile, column, metric_type = metric_result
@ -338,6 +372,8 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
profile_results["table"].update(profile) profile_results["table"].update(profile)
if metric_type == MetricTypes.System.value: if metric_type == MetricTypes.System.value:
profile_results["system"] = profile profile_results["system"] = profile
elif metric_type == MetricTypes.Custom.value and column is None:
profile_results["table"].update(profile)
else: else:
if profile: if profile:
profile_results["columns"][column].update( profile_results["columns"][column].update(

View File

@ -46,6 +46,7 @@ from metadata.generated.schema.entity.services.databaseService import (
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import ( from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
DatabaseServiceProfilerPipeline, DatabaseServiceProfilerPipeline,
) )
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.ingestion.api.models import StackTraceError from metadata.ingestion.api.models import StackTraceError
from metadata.ingestion.api.status import Status from metadata.ingestion.api.status import Status
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
@ -130,6 +131,7 @@ class ProfilerInterface(ABC):
MetricTypes.Query.value: self._compute_query_metrics, MetricTypes.Query.value: self._compute_query_metrics,
MetricTypes.Window.value: self._compute_window_metrics, MetricTypes.Window.value: self._compute_window_metrics,
MetricTypes.System.value: self._compute_system_metrics, MetricTypes.System.value: self._compute_system_metrics,
MetricTypes.Custom.value: self._compute_custom_metrics,
} }
@abstractmethod @abstractmethod
@ -459,6 +461,13 @@ class ProfilerInterface(ABC):
"""Get metrics""" """Get metrics"""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def _compute_custom_metrics(
    self, metrics: List[CustomMetric], runner, *args, **kwargs
):
    """Compute custom metrics.

    Abstract hook that each concrete profiler interface must implement.
    It is the function registered for ``MetricTypes.Custom.value`` in the
    metric-type dispatch mapping.

    Args:
        metrics (List[CustomMetric]): custom metric definitions to compute
        runner: object the metrics are computed against; its concrete type
            depends on the implementing interface
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def get_all_metrics(self, metric_funcs) -> dict: def get_all_metrics(self, metric_funcs) -> dict:
"""run profiler metrics""" """run profiler metrics"""

View File

@ -22,13 +22,15 @@ from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Dict, List, Optional from typing import Dict, List, Optional
from sqlalchemy import Column, inspect from sqlalchemy import Column, inspect, text
from sqlalchemy.exc import ProgrammingError, ResourceClosedError from sqlalchemy.exc import ProgrammingError, ResourceClosedError
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from metadata.generated.schema.entity.data.table import TableData from metadata.generated.schema.entity.data.table import CustomMetricProfile, TableData
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.ingestion.connections.session import create_and_bind_thread_safe_session from metadata.ingestion.connections.session import create_and_bind_thread_safe_session
from metadata.mixins.sqalchemy.sqa_mixin import SQAInterfaceMixin from metadata.mixins.sqalchemy.sqa_mixin import SQAInterfaceMixin
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.profiler_interface import ProfilerInterface from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import MetricTypes from metadata.profiler.metrics.core import MetricTypes
from metadata.profiler.metrics.registry import Metrics from metadata.profiler.metrics.registry import Metrics
@ -42,6 +44,7 @@ from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.runner import QueryRunner from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT
from metadata.utils.custom_thread_pool import CustomThreadPoolExecutor from metadata.utils.custom_thread_pool import CustomThreadPoolExecutor
from metadata.utils.helpers import is_safe_sql_query
from metadata.utils.logger import profiler_interface_registry_logger from metadata.utils.logger import profiler_interface_registry_logger
logger = profiler_interface_registry_logger() logger = profiler_interface_registry_logger()
@ -307,6 +310,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
row = runner.select_first_from_sample( row = runner.select_first_from_sample(
*[metric(column).fn() for metric in metrics], *[metric(column).fn() for metric in metrics],
) )
if row:
return dict(row)
except ProgrammingError as exc: except ProgrammingError as exc:
logger.info( logger.info(
f"Skipping metrics for {runner.table.__tablename__}.{column.name} due to {exc}" f"Skipping metrics for {runner.table.__tablename__}.{column.name} due to {exc}"
@ -314,8 +319,43 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
except Exception as exc: except Exception as exc:
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}" msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
handle_query_exception(msg, exc, session) handle_query_exception(msg, exc, session)
if row: return None
return dict(row)
def _compute_custom_metrics(
    self, metrics: List[CustomMetric], runner, session, *args, **kwargs
):
    """Compute custom metrics by executing their SQL expression.

    Args:
        metrics (List[CustomMetric]): custom metrics to compute
        runner: query runner; only used here to name the table in logs
        session: SQLAlchemy session the expressions are executed with

    Returns:
        ``{"customMetrics": [CustomMetricProfile, ...]}`` holding one entry
        per metric that could be computed, or ``None`` when ``metrics`` is
        empty or every metric failed.
    """
    if not metrics:
        return None

    custom_metrics = []

    for metric in metrics:
        try:
            if not is_safe_sql_query(metric.expression):
                raise RuntimeError(
                    f"SQL expression is not safe\n\n{metric.expression}"
                )

            crs = session.execute(text(metric.expression))
            # scalar() yields the first column of the first row (None when
            # the query returns no row); extra rows are discarded, it does
            # not raise when several rows come back.
            row = crs.scalar()
            custom_metrics.append(
                CustomMetricProfile(name=metric.name.__root__, value=row)
            )

        except Exception as exc:
            msg = f"Error trying to compute profile for {runner.table.__tablename__}.{metric.columnName}: {exc}"
            logger.debug(traceback.format_exc())
            logger.warning(msg)
            # A failed statement can leave the session's transaction in an
            # unusable state; roll back so the remaining custom metrics can
            # still be executed on this session.
            session.rollback()
    if custom_metrics:
        return {"customMetrics": custom_metrics}
    return None
def _compute_system_metrics( def _compute_system_metrics(
@ -376,14 +416,11 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
def compute_metrics_in_thread( def compute_metrics_in_thread(
self, self,
metrics, metric_func: ThreadPoolMetrics,
metric_type,
column,
table,
): ):
"""Run metrics in processor worker""" """Run metrics in processor worker"""
logger.debug( logger.debug(
f"Running profiler for {table.__tablename__} on thread {threading.current_thread()}" f"Running profiler for {metric_func.table.__tablename__} on thread {threading.current_thread()}"
) )
Session = self.session_factory # pylint: disable=invalid-name Session = self.session_factory # pylint: disable=invalid-name
with Session() as session: with Session() as session:
@ -391,36 +428,40 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
self.set_catalog(session) self.set_catalog(session)
sampler = self._create_thread_safe_sampler( sampler = self._create_thread_safe_sampler(
session, session,
table, metric_func.table,
) )
sample = sampler.random_sample() sample = sampler.random_sample()
runner = self._create_thread_safe_runner( runner = self._create_thread_safe_runner(
session, session,
table, metric_func.table,
sample, sample,
) )
try: try:
row = self._get_metric_fn[metric_type.value]( row = self._get_metric_fn[metric_func.metric_type.value](
metrics, metric_func.metrics,
runner=runner, runner=runner,
session=session, session=session,
column=column, column=metric_func.column,
sample=sample, sample=sample,
) )
except Exception as exc: except Exception as exc:
error = f"{column if column is not None else runner.table.__tablename__} metric_type.value: {exc}" error = (
f"{metric_func.column if metric_func.column is not None else metric_func.table.__tablename__} "
f"metric_type.value: {exc}"
)
logger.error(error) logger.error(error)
self.status.failed_profiler(error, traceback.format_exc()) self.status.failed_profiler(error, traceback.format_exc())
row = None row = None
if column is not None: if metric_func.column is not None:
column = column.name column = metric_func.column.name
self.status.scanned(f"{table.__tablename__}.{column}") self.status.scanned(f"{metric_func.table.__tablename__}.{column}")
else: else:
self.status.scanned(table.__tablename__) self.status.scanned(metric_func.table.__tablename__)
column = None
return row, column, metric_type.value return row, column, metric_func.metric_type.value
# pylint: disable=use-dict-literal # pylint: disable=use-dict-literal
def get_all_metrics( def get_all_metrics(
@ -434,7 +475,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
futures = [ futures = [
pool.submit( pool.submit(
self.compute_metrics_in_thread, self.compute_metrics_in_thread,
*metric_func, metric_func,
) )
for metric_func in metric_funcs for metric_func in metric_funcs
] ]
@ -455,6 +496,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
profile_results["table"].update(profile) profile_results["table"].update(profile)
elif metric_type == MetricTypes.System.value: elif metric_type == MetricTypes.System.value:
profile_results["system"] = profile profile_results["system"] = profile
elif metric_type == MetricTypes.Custom.value and column is None:
profile_results["table"].update(profile)
else: else:
profile_results["columns"][column].update( profile_results["columns"][column].update(
{ {

View File

@ -72,6 +72,8 @@ class SingleStoreProfilerInterface(SQAProfilerInterface):
row = runner.select_first_from_sample( row = runner.select_first_from_sample(
*[metric(column).fn() for metric in metrics], *[metric(column).fn() for metric in metrics],
) )
if row:
return dict(row)
except ProgrammingError: except ProgrammingError:
logger.info( logger.info(
f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to overflow" f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to overflow"
@ -81,6 +83,4 @@ class SingleStoreProfilerInterface(SQAProfilerInterface):
except Exception as exc: except Exception as exc:
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}" msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
handle_query_exception(msg, exc, session) handle_query_exception(msg, exc, session)
if row:
return dict(row)
return None return None

View File

@ -67,6 +67,8 @@ class TrinoProfilerInterface(SQAProfilerInterface):
row = runner.select_first_from_sample( row = runner.select_first_from_sample(
*[metric(column).fn() for metric in metrics], **runner_kwargs *[metric(column).fn() for metric in metrics], **runner_kwargs
) )
if row:
return dict(row)
except ProgrammingError as err: except ProgrammingError as err:
logger.info( logger.info(
f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to {err}" f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to {err}"
@ -76,6 +78,4 @@ class TrinoProfilerInterface(SQAProfilerInterface):
except Exception as exc: except Exception as exc:
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}" msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
handle_query_exception(msg, exc, session) handle_query_exception(msg, exc, session)
if row:
return dict(row)
return None return None

View File

@ -90,6 +90,9 @@ class Max(StaticMetric):
def df_fn(self, dfs=None): def df_fn(self, dfs=None):
"""pandas function""" """pandas function"""
if is_quantifiable(self.col.type) or is_date_time(self.col.type): if is_quantifiable(self.col.type):
return max((df[self.col.name].max() for df in dfs)) return max((df[self.col.name].max() for df in dfs))
if is_date_time(self.col.type):
max_ = max((df[self.col.name].max() for df in dfs))
return int(max_.timestamp() * 1000)
return 0 return 0

View File

@ -90,6 +90,9 @@ class Min(StaticMetric):
def df_fn(self, dfs=None): def df_fn(self, dfs=None):
"""pandas function""" """pandas function"""
if is_quantifiable(self.col.type) or is_date_time(self.col.type): if is_quantifiable(self.col.type):
return min((df[self.col.name].min() for df in dfs)) return min((df[self.col.name].min() for df in dfs))
if is_date_time(self.col.type):
min_ = min((df[self.col.name].min() for df in dfs))
return int(min_.timestamp() * 1000)
return 0 return 0

View File

@ -33,7 +33,10 @@ from metadata.generated.schema.entity.data.table import (
TableData, TableData,
TableProfile, TableProfile,
) )
from metadata.profiler.api.models import ProfilerResponse from metadata.generated.schema.tests.customMetric import (
CustomMetric as CustomMetricEntity,
)
from metadata.profiler.api.models import ProfilerResponse, ThreadPoolMetrics
from metadata.profiler.interface.profiler_interface import ProfilerInterface from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import ( from metadata.profiler.metrics.core import (
ComposedMetric, ComposedMetric,
@ -273,6 +276,33 @@ class Profiler(Generic[TMetric]):
return [metric for metric in metrics if metric.is_col_metric()] return [metric for metric in metrics if metric.is_col_metric()]
def get_custom_metrics(
    self, column_name: Optional[str] = None
) -> Optional[List[CustomMetricEntity]]:
    """Return the custom metrics defined on the table or on one column.

    Args:
        column_name (Optional[str]): name of the column to look up. When
            None the table level custom metrics are returned.
    Returns:
        The custom metrics found, or None when there are none (or when no
        column carries the given name).
    """
    table_entity = self.profiler_interface.table_entity

    if column_name is None:
        return table_entity.customMetrics or None

    # look for the first column entity carrying the requested name
    matched = None
    for candidate in table_entity.columns:
        if candidate.name.__root__ == column_name:
            matched = candidate
            break

    if not matched:
        return None
    return matched.customMetrics or None
@property @property
def sample(self): def sample(self):
"""Return the sample used for the profiler""" """Return the sample used for the profiler"""
@ -345,22 +375,40 @@ class Profiler(Generic[TMetric]):
def _prepare_table_metrics(self) -> List: def _prepare_table_metrics(self) -> List:
"""prepare table metrics""" """prepare table metrics"""
metrics = []
table_metrics = [ table_metrics = [
metric metric
for metric in self.static_metrics for metric in self.static_metrics
if (not metric.is_col_metric() and not metric.is_system_metrics()) if (not metric.is_col_metric() and not metric.is_system_metrics())
] ]
custom_table_metrics = self.get_custom_metrics()
if table_metrics: if table_metrics:
return [ metrics.extend(
( [
table_metrics, # metric functions ThreadPoolMetrics(
MetricTypes.Table, # metric type for function mapping metrics=table_metrics,
None, # column name metric_type=MetricTypes.Table,
self.table, # table name column=None,
), table=self.table,
] )
return [] ]
)
if custom_table_metrics:
metrics.extend(
[
ThreadPoolMetrics(
metrics=custom_table_metrics,
metric_type=MetricTypes.Custom,
column=None,
table=self.table,
)
]
)
return metrics
def _prepare_system_metrics(self) -> List: def _prepare_system_metrics(self) -> List:
"""prepare system metrics""" """prepare system metrics"""
@ -368,11 +416,11 @@ class Profiler(Generic[TMetric]):
if system_metrics: if system_metrics:
return [ return [
( ThreadPoolMetrics(
system_metric, # metric functions metrics=system_metric, # metric functions
MetricTypes.System, # metric type for function mapping metric_type=MetricTypes.System, # metric type for function mapping
None, # column name column=None, # column name
self.table, # table name table=self.table, # table name
) )
for system_metric in system_metrics for system_metric in system_metrics
] ]
@ -386,45 +434,60 @@ class Profiler(Generic[TMetric]):
for column in self.columns for column in self.columns
if column.type.__class__.__name__ not in NOT_COMPUTE if column.type.__class__.__name__ not in NOT_COMPUTE
] ]
column_metrics_for_thread_pool = []
column_metrics_for_thread_pool = [ static_metrics = [
*[ ThreadPoolMetrics(
( metrics=[
[ metric
metric for metric in self.get_col_metrics(self.static_metrics, column)
for metric in self.get_col_metrics(self.static_metrics, column) if not metric.is_window_metric()
if not metric.is_window_metric() ],
], metric_type=MetricTypes.Static,
MetricTypes.Static, column=column,
column, table=self.table,
self.table, )
) for column in columns
for column in columns
],
*[
(
metric,
MetricTypes.Query,
column,
self.table,
)
for column in columns
for metric in self.get_col_metrics(self.query_metrics, column)
],
*[
(
[
metric
for metric in self.get_col_metrics(self.static_metrics, column)
if metric.is_window_metric()
],
MetricTypes.Window,
column,
self.table,
)
for column in columns
],
] ]
query_metrics = [
ThreadPoolMetrics(
metrics=metric,
metric_type=MetricTypes.Query,
column=column,
table=self.table,
)
for column in columns
for metric in self.get_col_metrics(self.query_metrics, column)
]
window_metrics = [
ThreadPoolMetrics(
metrics=[
metric
for metric in self.get_col_metrics(self.static_metrics, column)
if metric.is_window_metric()
],
metric_type=MetricTypes.Window,
column=column,
table=self.table,
)
for column in columns
]
# we'll add the system metrics to the thread pool computation
for metric_type in [static_metrics, query_metrics, window_metrics]:
column_metrics_for_thread_pool.extend(metric_type)
# we'll add the custom metrics to the thread pool computation
for column in columns:
custom_metrics = self.get_custom_metrics(column.name)
if custom_metrics:
column_metrics_for_thread_pool.append(
ThreadPoolMetrics(
metrics=custom_metrics,
metric_type=MetricTypes.Custom,
column=column,
table=self.table,
)
)
return column_metrics_for_thread_pool return column_metrics_for_thread_pool
@ -568,6 +631,7 @@ class Profiler(Generic[TMetric]):
profileSampleType=self.profile_sample_config.profile_sample_type profileSampleType=self.profile_sample_config.profile_sample_type
if self.profile_sample_config if self.profile_sample_config
else None, else None,
customMetrics=self._table_results.get("customMetrics"),
) )
if self._system_results: if self._system_results:

View File

@ -31,6 +31,7 @@ from metadata.profiler.orm.functions.random_num import RandomNumFn
from metadata.profiler.orm.registry import Dialects from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.handle_partition import partition_filter_handler from metadata.profiler.processor.handle_partition import partition_filter_handler
from metadata.profiler.processor.sampler.sampler_interface import SamplerInterface from metadata.profiler.processor.sampler.sampler_interface import SamplerInterface
from metadata.utils.helpers import is_safe_sql_query
from metadata.utils.logger import profiler_interface_registry_logger from metadata.utils.logger import profiler_interface_registry_logger
from metadata.utils.sqa_utils import ( from metadata.utils.sqa_utils import (
build_query_filter, build_query_filter,
@ -171,6 +172,11 @@ class SQASampler(SamplerInterface):
def _fetch_sample_data_from_user_query(self) -> TableData: def _fetch_sample_data_from_user_query(self) -> TableData:
"""Returns a table data object using results from query execution""" """Returns a table data object using results from query execution"""
if not is_safe_sql_query(self._profile_sample_query):
raise RuntimeError(
f"SQL expression is not safe\n\n{self._profile_sample_query}"
)
rnd = self.client.execute(f"{self._profile_sample_query}") rnd = self.client.execute(f"{self._profile_sample_query}")
try: try:
columns = [col.name for col in rnd.cursor.description] columns = [col.name for col in rnd.cursor.description]
@ -183,6 +189,11 @@ class SQASampler(SamplerInterface):
def _rdn_sample_from_user_query(self) -> Query: def _rdn_sample_from_user_query(self) -> Query:
"""Returns sql alchemy object to use when running profiling""" """Returns sql alchemy object to use when running profiling"""
if not is_safe_sql_query(self._profile_sample_query):
raise RuntimeError(
f"SQL expression is not safe\n\n{self._profile_sample_query}"
)
return self.client.query(self.table).from_statement( return self.client.query(self.table).from_statement(
text(f"{self._profile_sample_query}") text(f"{self._profile_sample_query}")
) )

View File

@ -13,7 +13,6 @@ Helper module to handle data sampling
for the profiler for the profiler
""" """
from sqlalchemy import inspect, or_, text from sqlalchemy import inspect, or_, text
from trino.sqlalchemy.dialect import TrinoDialect
from metadata.profiler.orm.registry import FLOAT_SET from metadata.profiler.orm.registry import FLOAT_SET
from metadata.profiler.processor.handle_partition import RANDOM_LABEL from metadata.profiler.processor.handle_partition import RANDOM_LABEL
@ -26,7 +25,13 @@ class TrinoSampler(SQASampler):
run the query in the whole table. run the query in the whole table.
""" """
TrinoDialect._json_deserializer = None # pylint: disable=protected-access def __init__(self, *args, **kwargs):
# pylint: disable=import-outside-toplevel
from trino.sqlalchemy.dialect import TrinoDialect
TrinoDialect._json_deserializer = None
super().__init__(*args, **kwargs)
def _base_sample_query(self, label=None): def _base_sample_query(self, label=None):
sqa_columns = [col for col in inspect(self.table).c if col.name != RANDOM_LABEL] sqa_columns = [col for col in inspect(self.table).c if col.name != RANDOM_LABEL]

View File

@ -238,9 +238,7 @@ class OpenMetadataSource(Source):
""" """
tables = self.metadata.list_all_entities( tables = self.metadata.list_all_entities(
entity=Table, entity=Table,
fields=[ fields=["tableProfilerConfig", "columns", "customMetrics"],
"tableProfilerConfig",
],
params={ params={
"service": self.config.source.serviceName, "service": self.config.source.serviceName,
"database": fqn.build( "database": fqn.build(

View File

@ -430,6 +430,9 @@ def is_safe_sql_query(sql_query: str) -> bool:
"SET TRANSACTION", "SET TRANSACTION",
} }
if sql_query is None:
return True
parsed_queries: Tuple[Statement] = sqlparse.parse(sql_query) parsed_queries: Tuple[Statement] = sqlparse.parse(sql_query)
for parsed_query in parsed_queries: for parsed_query in parsed_queries:
validation = [ validation = [

View File

@ -0,0 +1,4 @@
id,first_name,last_name,city,country,birthdate,age
1,John,Doe,Los Angeles,US,1980-01-01,40
2,Jane,Doe,Los Angeles,US,2000-12-31,39
3,Jane,Smith,Paris,,2001-11-11,28
1 id first_name last_name city country birthdate age
2 1 John Doe Los Angeles US 1980-01-01 40
3 2 Jane Doe Los Angeles US 2000-12-31 39
4 3 Jane Smith Paris 2001-11-11 28

View File

@ -0,0 +1,233 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test Metrics behavior
"""
# import datetime
import os
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
import boto3
import botocore
import pandas as pd
from moto import mock_s3
from metadata.generated.schema.entity.data.table import Column as EntityColumn
from metadata.generated.schema.entity.data.table import ColumnName, DataType, Table
from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import (
S3Config,
)
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.profiler.interface.pandas.profiler_interface import (
PandasProfilerInterface,
)
from metadata.profiler.processor.core import Profiler
BUCKET_NAME = "MyBucket"


@mock_s3
class MetricsTest(TestCase):
    """
    Custom metric checks for the pandas (datalake) profiler interface.

    The class is decorated with moto's ``mock_s3`` and every test builds a
    ``PandasProfilerInterface`` whose dataframe loading is patched to return
    the dataframes read from the local resources directory.
    """

    current_dir = os.path.dirname(__file__)
    resources_dir = os.path.join(current_dir, "resources")

    datalake_conn = DatalakeConnection(
        configSource=S3Config(
            securityConfig=AWSCredentials(
                awsAccessKeyId="fake_access_key",
                awsSecretAccessKey="fake_secret_key",
                awsRegion="us-west-1",
            )
        )
    )

    # Dataframes returned by the patched loader; the column at index 5 is
    # parsed as dates.
    dfs = [
        pd.read_csv(os.path.join(resources_dir, "profiler_test_.csv"), parse_dates=[5])
    ]

    table_entity = Table(
        id=uuid4(),
        name="user",
        columns=[
            EntityColumn(
                name=ColumnName(__root__="id"),
                dataType=DataType.INT,
            )
        ],
    )

    def _build_profiler_interface(self, table_entity):
        """
        Instantiate a PandasProfilerInterface for ``table_entity`` while
        ``_convert_table_to_list_of_dataframe_objects`` is patched to return
        the class-level dataframes.
        """
        with patch.object(
            PandasProfilerInterface,
            "_convert_table_to_list_of_dataframe_objects",
            return_value=self.dfs,
        ):
            return PandasProfilerInterface(
                self.datalake_conn,
                None,
                table_entity,
                None,
                None,
                None,
                None,
                None,
                thread_count=1,
            )

    def setUp(self):
        """Create the bucket, upload the resource files and build the interface"""
        # Reset boto3's module-level default session before creating the client
        boto3.DEFAULT_SESSION = None
        self.client = boto3.client(
            "s3",
            region_name="us-west-1",
        )

        # check that we are not running our test against a real bucket
        try:
            s3 = boto3.resource(
                "s3",
                region_name="us-west-1",
                aws_access_key_id="fake_access_key",
                aws_secret_access_key="fake_secret_key",
            )
            s3.meta.client.head_bucket(Bucket=BUCKET_NAME)
        except botocore.exceptions.ClientError:
            pass
        else:
            err = f"{BUCKET_NAME} should not exist."
            raise EnvironmentError(err)

        self.client.create_bucket(
            Bucket=BUCKET_NAME,
            CreateBucketConfiguration={"LocationConstraint": "us-west-1"},
        )

        resources_paths = [
            os.path.join(path, filename)
            for path, _, files in os.walk(self.resources_dir)
            for filename in files
        ]

        # Keys are the file paths relative to the resources directory
        self.s3_keys = []
        for path in resources_paths:
            key = os.path.relpath(path, self.resources_dir)
            self.s3_keys.append(key)
            self.client.upload_file(Filename=path, Bucket=BUCKET_NAME, Key=key)

        self.sqa_profiler_interface = self._build_profiler_interface(
            self.table_entity
        )

    def test_table_custom_metric(self):
        """Check the values of the custom metrics declared at table level"""
        # NOTE(review): the quoting in these expressions looks misplaced
        # (column name quoted, value bare / whole condition quoted). Confirm
        # against the syntax the pandas interface expects; if the expressions
        # cannot be evaluated the checks below are silently skipped.
        table_entity = Table(
            id=uuid4(),
            name="user",
            columns=[
                EntityColumn(
                    name=ColumnName(__root__="id"),
                    dataType=DataType.INT,
                )
            ],
            customMetrics=[
                CustomMetric(
                    name="LastNameFilter",
                    expression="'last_name' != Doe",
                ),
                CustomMetric(
                    name="notUS",
                    expression="'country == US'",
                ),
            ],
        )
        self.sqa_profiler_interface = self._build_profiler_interface(table_entity)

        profiler = Profiler(
            profiler_interface=self.sqa_profiler_interface,
        )
        metrics = profiler.compute_metrics()
        for _, v in metrics._table_results.items():
            for metric in v:
                if metric.name == "LastNameFilter":
                    assert metric.value == 1
                if metric.name == "notUS":
                    assert metric.value == 2

    def test_column_custom_metric(self):
        """Check the values of the custom metrics declared on the `id` column"""
        # NOTE(review): same expression quoting concern as in
        # test_table_custom_metric; confirm the expected syntax.
        table_entity = Table(
            id=uuid4(),
            name="user",
            columns=[
                EntityColumn(
                    name=ColumnName(__root__="id"),
                    dataType=DataType.INT,
                    customMetrics=[
                        CustomMetric(
                            name="LastNameFilter",
                            columnName="id",
                            expression="'last_name' != Doe",
                        ),
                        CustomMetric(
                            name="notUS",
                            columnName="id",
                            expression="'country == US'",
                        ),
                    ],
                )
            ],
        )
        self.sqa_profiler_interface = self._build_profiler_interface(table_entity)

        profiler = Profiler(
            profiler_interface=self.sqa_profiler_interface,
        )
        metrics = profiler.compute_metrics()
        for _, v in metrics._column_results.items():
            for metric in v.get("customMetrics", []):
                # Match the metric names declared above; the previous names
                # belonged to another test and never matched, so nothing
                # was ever asserted.
                if metric.name == "LastNameFilter":
                    assert metric.value == 1
                if metric.name == "notUS":
                    assert metric.value == 2

View File

@ -269,5 +269,8 @@ class ProfilerTest(TestCase):
column_metrics = default_profiler._prepare_column_metrics() column_metrics = default_profiler._prepare_column_metrics()
for metric in column_metrics: for metric in column_metrics:
if metric[1] is not MetricTypes.Table and metric[2].name == "id": if (
assert all(metric_filter.count(m.name()) for m in metric[0]) metric.metric_type is not MetricTypes.Table
and metric.column.name == "id"
):
assert all(metric_filter.count(m.name()) for m in metric.metrics)

View File

@ -32,6 +32,7 @@ from metadata.generated.schema.entity.data.table import (
Table, Table,
TableProfile, TableProfile,
) )
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.pandas.profiler_interface import ( from metadata.profiler.interface.pandas.profiler_interface import (
PandasProfilerInterface, PandasProfilerInterface,
) )
@ -146,15 +147,15 @@ class PandasInterfaceTest(TestCase):
def test_get_all_metrics(self): def test_get_all_metrics(self):
table_metrics = [ table_metrics = [
( ThreadPoolMetrics(
[ metrics=[
metric metric
for metric in self.metrics for metric in self.metrics
if (not metric.is_col_metric() and not metric.is_system_metrics()) if (not metric.is_col_metric() and not metric.is_system_metrics())
], ],
MetricTypes.Table, metric_type=MetricTypes.Table,
None, column=None,
self.table_entity, table=self.table_entity,
) )
] ]
column_metrics = [] column_metrics = []
@ -164,36 +165,36 @@ class PandasInterfaceTest(TestCase):
if col.name == "id": if col.name == "id":
continue continue
column_metrics.append( column_metrics.append(
( ThreadPoolMetrics(
[ metrics=[
metric metric
for metric in self.static_metrics for metric in self.static_metrics
if metric.is_col_metric() and not metric.is_window_metric() if metric.is_col_metric() and not metric.is_window_metric()
], ],
MetricTypes.Static, metric_type=MetricTypes.Static,
col, column=col,
self.table_entity, table=self.table_entity,
) )
) )
for query_metric in self.query_metrics: for query_metric in self.query_metrics:
query_metrics.append( query_metrics.append(
( ThreadPoolMetrics(
query_metric, metrics=query_metric,
MetricTypes.Query, metric_type=MetricTypes.Query,
col, column=col,
self.table_entity, table=self.table_entity,
) )
) )
window_metrics.append( window_metrics.append(
( ThreadPoolMetrics(
[ metrics=[
metric metric
for metric in self.window_metrics for metric in self.window_metrics
if metric.is_window_metric() if metric.is_window_metric()
], ],
MetricTypes.Window, metric_type=MetricTypes.Window,
col, column=col,
self.table_entity, table=self.table_entity,
) )
) )

View File

@ -27,6 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection, SQLiteConnection,
SQLiteScheme, SQLiteScheme,
) )
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.profiler.interface.sqlalchemy.profiler_interface import ( from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface, SQAProfilerInterface,
) )
@ -865,6 +866,102 @@ class MetricsTest(TestCase):
session = self.sqa_profiler_interface.session session = self.sqa_profiler_interface.session
system().sql(session) system().sql(session)
def test_table_custom_metric(self):
    """Check values of table-level custom SQL metrics through the SQA interface"""
    id_column = EntityColumn(
        name=ColumnName(__root__="id"),
        dataType=DataType.INT,
    )
    born_after_1991 = CustomMetric(
        name="CustomerBornedAfter1991",
        expression="SELECT COUNT(id) FROM users WHERE dob > '1991-01-01'",
    )
    average_age = CustomMetric(
        name="AverageAge",
        expression="SELECT SUM(age)/COUNT(*) FROM users",
    )
    table_entity = Table(
        id=uuid4(),
        name="user",
        columns=[id_column],
        customMetrics=[born_after_1991, average_age],
    )

    # The ORM conversion is patched so the interface works on the User model
    orm_patch = patch.object(
        SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
    )
    with orm_patch:
        self.sqa_profiler_interface = SQAProfilerInterface(
            self.sqlite_conn,
            None,
            table_entity,
            None,
            None,
            None,
            None,
            None,
            thread_count=1,
        )

    profiler = Profiler(
        profiler_interface=self.sqa_profiler_interface,
    )
    computed = profiler.compute_metrics()

    # Only metrics whose name matches one of the pairs below are checked
    expectations = (
        ("CustomerBornedAfter1991", 2),
        ("AverageAge", 20.0),
    )
    for _, table_result in computed._table_results.items():
        for result in table_result:
            for expected_name, expected_value in expectations:
                if result.name == expected_name:
                    assert result.value == expected_value
def test_column_custom_metric(self):
    """Check values of custom SQL metrics attached to the `id` column"""
    sum_born_after_1991 = CustomMetric(
        name="CustomerBornedAfter1991",
        columnName="id",
        expression="SELECT SUM(id) FROM users WHERE dob > '1991-01-01'",
    )
    average_age = CustomMetric(
        name="AverageAge",
        columnName="id",
        expression="SELECT SUM(age)/COUNT(*) FROM users",
    )
    id_column = EntityColumn(
        name=ColumnName(__root__="id"),
        dataType=DataType.INT,
        customMetrics=[sum_born_after_1991, average_age],
    )
    table_entity = Table(
        id=uuid4(),
        name="user",
        columns=[id_column],
    )

    # The ORM conversion is patched so the interface works on the User model
    orm_patch = patch.object(
        SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
    )
    with orm_patch:
        self.sqa_profiler_interface = SQAProfilerInterface(
            self.sqlite_conn,
            None,
            table_entity,
            None,
            None,
            None,
            None,
            None,
            thread_count=1,
        )

    profiler = Profiler(
        profiler_interface=self.sqa_profiler_interface,
    )
    computed = profiler.compute_metrics()

    # Only metrics whose name matches one of the pairs below are checked
    expectations = (
        ("CustomerBornedAfter1991", 3.0),
        ("AverageAge", 20.0),
    )
    for _, column_result in computed._column_results.items():
        for result in column_result.get("customMetrics", []):
            for expected_name, expected_value in expectations:
                if result.name == expected_name:
                    assert result.value == expected_value
@classmethod @classmethod
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:
os.remove(cls.db_path) os.remove(cls.db_path)

View File

@ -42,6 +42,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection, SQLiteConnection,
SQLiteScheme, SQLiteScheme,
) )
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.ingestion.source import sqa_types from metadata.ingestion.source import sqa_types
from metadata.profiler.interface.sqlalchemy.profiler_interface import ( from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface, SQAProfilerInterface,
@ -83,8 +84,27 @@ class ProfilerTest(TestCase):
EntityColumn( EntityColumn(
name=ColumnName(__root__="id"), name=ColumnName(__root__="id"),
dataType=DataType.INT, dataType=DataType.INT,
customMetrics=[
CustomMetric(
name="custom_metric",
description="custom metric",
expression="SELECT cos(id) FROM users",
)
],
) )
], ],
customMetrics=[
CustomMetric(
name="custom_metric",
description="custom metric",
expression="SELECT COUNT(id) / COUNT(age) FROM users",
),
CustomMetric(
name="custom_metric_two",
description="custom metric",
expression="SELECT COUNT(id) * COUNT(age) FROM users",
),
],
) )
with patch.object( with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
@ -231,6 +251,30 @@ class ProfilerTest(TestCase):
) )
) )
def test__prepare_column_metrics(self):
    """_prepare_column_metrics yields the custom metric or firstQuartile first"""
    first_quartile_profiler = Profiler(
        Metrics.FIRST_QUARTILE.value,
        profiler_interface=self.sqa_profiler_interface,
    )
    prepared = first_quartile_profiler._prepare_column_metrics()
    for entry in prepared:
        # Entries without any metric are not checked
        if not entry.metrics:
            continue
        if isinstance(entry.metrics[0], CustomMetric):
            assert entry.metrics[0].name.__root__ == "custom_metric"
            continue
        assert entry.metrics[0].name() == "firstQuartile"
def test__prepare_table_metrics(self):
    """_prepare_table_metrics returns two entries for this table setup"""
    column_count_profiler = Profiler(
        Metrics.COLUMN_COUNT.value,
        profiler_interface=self.sqa_profiler_interface,
    )
    prepared = column_count_profiler._prepare_table_metrics()
    prepared_count = len(prepared)
    self.assertEqual(2, prepared_count)
def test_profiler_with_timeout(self): def test_profiler_with_timeout(self):
"""check timeout is properly used""" """check timeout is properly used"""
@ -259,6 +303,7 @@ class ProfilerTest(TestCase):
def test_profiler_get_col_metrics(self): def test_profiler_get_col_metrics(self):
"""check getc column metrics""" """check getc column metrics"""
metric_filter = ["mean", "min", "max", "firstQuartile"] metric_filter = ["mean", "min", "max", "firstQuartile"]
custom_metric_filter = ["custom_metric"]
self.sqa_profiler_interface.table_entity.tableProfilerConfig = ( self.sqa_profiler_interface.table_entity.tableProfilerConfig = (
TableProfilerConfig( TableProfilerConfig(
includeColumns=[ includeColumns=[
@ -273,8 +318,20 @@ class ProfilerTest(TestCase):
column_metrics = default_profiler._prepare_column_metrics() column_metrics = default_profiler._prepare_column_metrics()
for metric in column_metrics: for metric in column_metrics:
if metric[1] is not MetricTypes.Table and metric[2].name == "id": if (
assert all(metric_filter.count(m.name()) for m in metric[0]) metric.metric_type is not MetricTypes.Table
and metric.column.name == "id"
):
assert all(
metric_filter.count(m.name())
for m in metric.metrics
if not isinstance(m, CustomMetric)
)
assert all(
custom_metric_filter.count(m.name.__root__)
for m in metric.metrics
if isinstance(m, CustomMetric)
)
@classmethod @classmethod
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:

View File

@ -38,6 +38,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection, SQLiteConnection,
SQLiteScheme, SQLiteScheme,
) )
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.sqlalchemy.profiler_interface import ( from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface, SQAProfilerInterface,
) )
@ -162,15 +163,15 @@ class SQAInterfaceTestMultiThread(TestCase):
def test_get_all_metrics(self): def test_get_all_metrics(self):
table_metrics = [ table_metrics = [
( ThreadPoolMetrics(
[ metrics=[
metric metric
for metric in self.metrics for metric in self.metrics
if (not metric.is_col_metric() and not metric.is_system_metrics()) if (not metric.is_col_metric() and not metric.is_system_metrics())
], ],
MetricTypes.Table, metric_type=MetricTypes.Table,
None, column=None,
self.table, table=self.table,
) )
] ]
column_metrics = [] column_metrics = []
@ -178,36 +179,36 @@ class SQAInterfaceTestMultiThread(TestCase):
window_metrics = [] window_metrics = []
for col in inspect(User).c: for col in inspect(User).c:
column_metrics.append( column_metrics.append(
( ThreadPoolMetrics(
[ metrics=[
metric metric
for metric in self.static_metrics for metric in self.static_metrics
if metric.is_col_metric() and not metric.is_window_metric() if metric.is_col_metric() and not metric.is_window_metric()
], ],
MetricTypes.Static, metric_type=MetricTypes.Static,
col, column=col,
self.table, table=self.table,
) )
) )
for query_metric in self.query_metrics: for query_metric in self.query_metrics:
query_metrics.append( query_metrics.append(
( ThreadPoolMetrics(
query_metric, metrics=query_metric,
MetricTypes.Query, metric_type=MetricTypes.Query,
col, column=col,
self.table, table=self.table,
) )
) )
window_metrics.append( window_metrics.append(
( ThreadPoolMetrics(
[ metrics=[
metric metric
for metric in self.window_metrics for metric in self.window_metrics
if metric.is_window_metric() if metric.is_window_metric()
], ],
MetricTypes.Window, metric_type=MetricTypes.Window,
col, column=col,
self.table, table=self.table,
) )
) )

View File

@ -26,5 +26,6 @@
} }
] ]
} }
} },
"additionalProperties": false
} }