mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-10-10 08:16:25 +00:00
* feat: add backend support for custom metrics * feat: fix python test * feat: support custom metrics computation * feat: updated tests for custom metrics * feat: added dl support for min max of datetime * feat: added is safe query check for query sampler * feat: added support for custom metric computation in dl * feat: added explicit addProper for pydantic model import fo Extra * feat: added custom metric to returned obj * feat: wrapped trino import in __init__ * feat: fix python linting * feat: fix typing in 3.8
This commit is contained in:
parent
efb6c5f221
commit
c7ac28f2c2
@ -428,7 +428,7 @@ def get_parser(args=None):
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def metadata(args=None): # pylint: disable=too-many-branches
|
||||
def metadata(args=None):
|
||||
"""
|
||||
This method implements parsing of the arguments passed from CLI
|
||||
"""
|
||||
|
@ -15,7 +15,10 @@ Return types for Profiler workflow execution.
|
||||
We need to define this class as we end up having
|
||||
multiple profilers per table and columns.
|
||||
"""
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from metadata.config.common import ConfigModel
|
||||
from metadata.generated.schema.api.data.createTableProfile import (
|
||||
@ -31,9 +34,12 @@ from metadata.generated.schema.entity.data.table import (
|
||||
from metadata.generated.schema.entity.services.connections.connectionBasicType import (
|
||||
SampleDataStorageConfig,
|
||||
)
|
||||
from metadata.generated.schema.tests.customMetric import CustomMetric
|
||||
from metadata.generated.schema.type.basic import FullyQualifiedEntityName
|
||||
from metadata.ingestion.models.table_metadata import ColumnTag
|
||||
from metadata.profiler.metrics.core import Metric, MetricTypes
|
||||
from metadata.profiler.processor.models import ProfilerDef
|
||||
from metadata.utils.sqa_like_column import SQALikeColumn
|
||||
|
||||
|
||||
class ColumnConfig(ConfigModel):
|
||||
@ -113,3 +119,15 @@ class ProfilerResponse(ConfigModel):
|
||||
def __str__(self):
|
||||
"""Return the table name being processed"""
|
||||
return f"Table [{self.table.name.__root__}]"
|
||||
|
||||
|
||||
class ThreadPoolMetrics(ConfigModel):
    """Unit of work handed to a profiler worker.

    Groups the metrics to compute with the metric type, the column and the
    table they apply to, so the profiler interfaces receive a single object
    instead of a positional tuple.
    """

    class Config:
        # The column and table fields hold types that are not pydantic models.
        arbitrary_types_allowed = True

    # Either a single metric class or a list mixing metric classes and
    # custom metric entities.
    metrics: Union[List[Union[Type[Metric], CustomMetric]], Type[Metric]]
    # Used by the profiler interfaces to pick the compute function.
    metric_type: MetricTypes
    # Passed as None for table-level and system-level work.
    column: Optional[Union[Column, SQALikeColumn]]
    table: Union[Table, DeclarativeMeta]
|
||||
|
@ -22,11 +22,13 @@ from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import Column
|
||||
|
||||
from metadata.generated.schema.entity.data.table import TableData
|
||||
from metadata.generated.schema.entity.data.table import CustomMetricProfile, TableData
|
||||
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
|
||||
DatalakeConnection,
|
||||
)
|
||||
from metadata.generated.schema.tests.customMetric import CustomMetric
|
||||
from metadata.mixins.pandas.pandas_mixin import PandasInterfaceMixin
|
||||
from metadata.profiler.api.models import ThreadPoolMetrics
|
||||
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||
from metadata.profiler.metrics.core import MetricTypes
|
||||
from metadata.profiler.metrics.registry import Metrics
|
||||
@ -239,35 +241,67 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
|
||||
"""
|
||||
return None # to be implemented
|
||||
|
||||
def _compute_custom_metrics(
    self, metrics: List[CustomMetric], runner, *args, **kwargs
):
    """Compute custom metrics for a pandas source.

    Each metric expression is passed to ``DataFrame.query`` on every
    dataframe yielded by ``runner``. The metric value is the total number of
    rows selected across all dataframes.

    Args:
        metrics (List[CustomMetric]): custom metrics to compute
        runner: iterable of dataframes the expressions are evaluated against

    Returns:
        Optional[dict]: ``{"customMetrics": [CustomMetricProfile, ...]}``, or
        ``None`` when no metric was given or none could be computed.
    """
    if not metrics:
        return None

    custom_metrics = []

    for metric in metrics:
        try:
            # Query each dataframe exactly once. A dataframe with no matching
            # rows contributes 0, so no pre-filtering is needed.
            matching_rows = sum(
                len(df.query(metric.expression).index) for df in runner
            )
            custom_metrics.append(
                CustomMetricProfile(name=metric.name.__root__, value=matching_rows)
            )

        except Exception as exc:
            # Best effort: a failing metric is logged and skipped so the
            # remaining metrics are still computed.
            msg = f"Error trying to compute profile for custom metric: {exc}"
            logger.debug(traceback.format_exc())
            logger.warning(msg)

    if custom_metrics:
        return {"customMetrics": custom_metrics}
    return None
|
||||
|
||||
def compute_metrics(
|
||||
self,
|
||||
metrics,
|
||||
metric_type,
|
||||
column,
|
||||
table,
|
||||
metric_func: ThreadPoolMetrics,
|
||||
):
|
||||
"""Run metrics in processor worker"""
|
||||
logger.debug(f"Running profiler for {table}")
|
||||
logger.debug(f"Running profiler for {metric_func.table}")
|
||||
try:
|
||||
row = None
|
||||
if self.complex_dataframe_sample:
|
||||
row = self._get_metric_fn[metric_type.value](
|
||||
metrics,
|
||||
row = self._get_metric_fn[metric_func.metric_type.value](
|
||||
metric_func.metrics,
|
||||
self.complex_dataframe_sample,
|
||||
column=column,
|
||||
column=metric_func.column,
|
||||
)
|
||||
except Exception as exc:
|
||||
name = f"{column if column is not None else table}"
|
||||
name = f"{metric_func.column if metric_func.column is not None else metric_func.table}"
|
||||
error = f"{name} metric_type.value: {exc}"
|
||||
logger.error(error)
|
||||
self.status.failed_profiler(error, traceback.format_exc())
|
||||
row = None
|
||||
if column is not None:
|
||||
column = column.name
|
||||
self.status.scanned(f"{table.name.__root__}.{column}")
|
||||
if metric_func.column is not None:
|
||||
column = metric_func.column.name
|
||||
self.status.scanned(f"{metric_func.table.name.__root__}.{column}")
|
||||
else:
|
||||
self.status.scanned(table.name.__root__)
|
||||
return row, column, metric_type.value
|
||||
self.status.scanned(metric_func.table.name.__root__)
|
||||
column = None
|
||||
return row, column, metric_func.metric_type.value
|
||||
|
||||
def fetch_sample_data(self, table, columns: SQALikeColumn) -> TableData:
|
||||
"""Fetch sample data from database
|
||||
@ -329,7 +363,7 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
|
||||
|
||||
profile_results = {"table": {}, "columns": defaultdict(dict)}
|
||||
metric_list = [
|
||||
self.compute_metrics(*metric_func) for metric_func in metric_funcs
|
||||
self.compute_metrics(metric_func) for metric_func in metric_funcs
|
||||
]
|
||||
for metric_result in metric_list:
|
||||
profile, column, metric_type = metric_result
|
||||
@ -338,6 +372,8 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
|
||||
profile_results["table"].update(profile)
|
||||
if metric_type == MetricTypes.System.value:
|
||||
profile_results["system"] = profile
|
||||
elif metric_type == MetricTypes.Custom.value and column is None:
|
||||
profile_results["table"].update(profile)
|
||||
else:
|
||||
if profile:
|
||||
profile_results["columns"][column].update(
|
||||
|
@ -46,6 +46,7 @@ from metadata.generated.schema.entity.services.databaseService import (
|
||||
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
|
||||
DatabaseServiceProfilerPipeline,
|
||||
)
|
||||
from metadata.generated.schema.tests.customMetric import CustomMetric
|
||||
from metadata.ingestion.api.models import StackTraceError
|
||||
from metadata.ingestion.api.status import Status
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
@ -130,6 +131,7 @@ class ProfilerInterface(ABC):
|
||||
MetricTypes.Query.value: self._compute_query_metrics,
|
||||
MetricTypes.Window.value: self._compute_window_metrics,
|
||||
MetricTypes.System.value: self._compute_system_metrics,
|
||||
MetricTypes.Custom.value: self._compute_custom_metrics,
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
@ -459,6 +461,13 @@ class ProfilerInterface(ABC):
|
||||
"""Get metrics"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _compute_custom_metrics(
|
||||
self, metrics: List[CustomMetric], runner, *args, **kwargs
|
||||
):
|
||||
"""Compute custom metrics"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_all_metrics(self, metric_funcs) -> dict:
|
||||
"""run profiler metrics"""
|
||||
|
@ -22,13 +22,15 @@ from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import Column, inspect
|
||||
from sqlalchemy import Column, inspect, text
|
||||
from sqlalchemy.exc import ProgrammingError, ResourceClosedError
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
from metadata.generated.schema.entity.data.table import TableData
|
||||
from metadata.generated.schema.entity.data.table import CustomMetricProfile, TableData
|
||||
from metadata.generated.schema.tests.customMetric import CustomMetric
|
||||
from metadata.ingestion.connections.session import create_and_bind_thread_safe_session
|
||||
from metadata.mixins.sqalchemy.sqa_mixin import SQAInterfaceMixin
|
||||
from metadata.profiler.api.models import ThreadPoolMetrics
|
||||
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||
from metadata.profiler.metrics.core import MetricTypes
|
||||
from metadata.profiler.metrics.registry import Metrics
|
||||
@ -42,6 +44,7 @@ from metadata.profiler.orm.registry import Dialects
|
||||
from metadata.profiler.processor.runner import QueryRunner
|
||||
from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT
|
||||
from metadata.utils.custom_thread_pool import CustomThreadPoolExecutor
|
||||
from metadata.utils.helpers import is_safe_sql_query
|
||||
from metadata.utils.logger import profiler_interface_registry_logger
|
||||
|
||||
logger = profiler_interface_registry_logger()
|
||||
@ -307,6 +310,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
|
||||
row = runner.select_first_from_sample(
|
||||
*[metric(column).fn() for metric in metrics],
|
||||
)
|
||||
if row:
|
||||
return dict(row)
|
||||
except ProgrammingError as exc:
|
||||
logger.info(
|
||||
f"Skipping metrics for {runner.table.__tablename__}.{column.name} due to {exc}"
|
||||
@ -314,8 +319,43 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
|
||||
except Exception as exc:
|
||||
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
|
||||
handle_query_exception(msg, exc, session)
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
def _compute_custom_metrics(
    self, metrics: List[CustomMetric], runner, session, *args, **kwargs
):
    """Execute the custom metric SQL expressions and collect their results.

    An expression is only executed when ``is_safe_sql_query`` accepts it. A
    metric that is rejected or whose execution fails is logged and skipped;
    the remaining metrics are still computed.

    Args:
        metrics (List[CustomMetric]): custom metrics to compute
        runner: runner whose table name is used in the log message
        session: session used to execute the expressions

    Returns:
        Optional[dict]: ``{"customMetrics": [CustomMetricProfile, ...]}``, or
        ``None`` when no metric was given or none could be computed.
    """
    if not metrics:
        return None

    profiles = []

    for metric in metrics:
        try:
            if not is_safe_sql_query(metric.expression):
                raise RuntimeError(
                    f"SQL expression is not safe\n\n{metric.expression}"
                )

            result = session.execute(text(metric.expression))
            # scalar() yields the first column of the first row.
            value = result.scalar()
            profiles.append(
                CustomMetricProfile(name=metric.name.__root__, value=value)
            )

        except Exception as exc:
            msg = f"Error trying to compute profile for {runner.table.__tablename__}.{metric.columnName}: {exc}"
            logger.debug(traceback.format_exc())
            logger.warning(msg)

    return {"customMetrics": profiles} if profiles else None
|
||||
|
||||
def _compute_system_metrics(
|
||||
@ -376,14 +416,11 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
|
||||
|
||||
def compute_metrics_in_thread(
|
||||
self,
|
||||
metrics,
|
||||
metric_type,
|
||||
column,
|
||||
table,
|
||||
metric_func: ThreadPoolMetrics,
|
||||
):
|
||||
"""Run metrics in processor worker"""
|
||||
logger.debug(
|
||||
f"Running profiler for {table.__tablename__} on thread {threading.current_thread()}"
|
||||
f"Running profiler for {metric_func.table.__tablename__} on thread {threading.current_thread()}"
|
||||
)
|
||||
Session = self.session_factory # pylint: disable=invalid-name
|
||||
with Session() as session:
|
||||
@ -391,36 +428,40 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
|
||||
self.set_catalog(session)
|
||||
sampler = self._create_thread_safe_sampler(
|
||||
session,
|
||||
table,
|
||||
metric_func.table,
|
||||
)
|
||||
sample = sampler.random_sample()
|
||||
runner = self._create_thread_safe_runner(
|
||||
session,
|
||||
table,
|
||||
metric_func.table,
|
||||
sample,
|
||||
)
|
||||
|
||||
try:
|
||||
row = self._get_metric_fn[metric_type.value](
|
||||
metrics,
|
||||
row = self._get_metric_fn[metric_func.metric_type.value](
|
||||
metric_func.metrics,
|
||||
runner=runner,
|
||||
session=session,
|
||||
column=column,
|
||||
column=metric_func.column,
|
||||
sample=sample,
|
||||
)
|
||||
except Exception as exc:
|
||||
error = f"{column if column is not None else runner.table.__tablename__} metric_type.value: {exc}"
|
||||
error = (
|
||||
f"{metric_func.column if metric_func.column is not None else metric_func.table.__tablename__} "
|
||||
f"metric_type.value: {exc}"
|
||||
)
|
||||
logger.error(error)
|
||||
self.status.failed_profiler(error, traceback.format_exc())
|
||||
row = None
|
||||
|
||||
if column is not None:
|
||||
column = column.name
|
||||
self.status.scanned(f"{table.__tablename__}.{column}")
|
||||
if metric_func.column is not None:
|
||||
column = metric_func.column.name
|
||||
self.status.scanned(f"{metric_func.table.__tablename__}.{column}")
|
||||
else:
|
||||
self.status.scanned(table.__tablename__)
|
||||
self.status.scanned(metric_func.table.__tablename__)
|
||||
column = None
|
||||
|
||||
return row, column, metric_type.value
|
||||
return row, column, metric_func.metric_type.value
|
||||
|
||||
# pylint: disable=use-dict-literal
|
||||
def get_all_metrics(
|
||||
@ -434,7 +475,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
|
||||
futures = [
|
||||
pool.submit(
|
||||
self.compute_metrics_in_thread,
|
||||
*metric_func,
|
||||
metric_func,
|
||||
)
|
||||
for metric_func in metric_funcs
|
||||
]
|
||||
@ -455,6 +496,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
|
||||
profile_results["table"].update(profile)
|
||||
elif metric_type == MetricTypes.System.value:
|
||||
profile_results["system"] = profile
|
||||
elif metric_type == MetricTypes.Custom.value and column is None:
|
||||
profile_results["table"].update(profile)
|
||||
else:
|
||||
profile_results["columns"][column].update(
|
||||
{
|
||||
|
@ -72,6 +72,8 @@ class SingleStoreProfilerInterface(SQAProfilerInterface):
|
||||
row = runner.select_first_from_sample(
|
||||
*[metric(column).fn() for metric in metrics],
|
||||
)
|
||||
if row:
|
||||
return dict(row)
|
||||
except ProgrammingError:
|
||||
logger.info(
|
||||
f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to overflow"
|
||||
@ -81,6 +83,4 @@ class SingleStoreProfilerInterface(SQAProfilerInterface):
|
||||
except Exception as exc:
|
||||
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
|
||||
handle_query_exception(msg, exc, session)
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
@ -67,6 +67,8 @@ class TrinoProfilerInterface(SQAProfilerInterface):
|
||||
row = runner.select_first_from_sample(
|
||||
*[metric(column).fn() for metric in metrics], **runner_kwargs
|
||||
)
|
||||
if row:
|
||||
return dict(row)
|
||||
except ProgrammingError as err:
|
||||
logger.info(
|
||||
f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to {err}"
|
||||
@ -76,6 +78,4 @@ class TrinoProfilerInterface(SQAProfilerInterface):
|
||||
except Exception as exc:
|
||||
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
|
||||
handle_query_exception(msg, exc, session)
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
@ -90,6 +90,9 @@ class Max(StaticMetric):
|
||||
|
||||
def df_fn(self, dfs=None):
|
||||
"""pandas function"""
|
||||
if is_quantifiable(self.col.type) or is_date_time(self.col.type):
|
||||
if is_quantifiable(self.col.type):
|
||||
return max((df[self.col.name].max() for df in dfs))
|
||||
if is_date_time(self.col.type):
|
||||
max_ = max((df[self.col.name].max() for df in dfs))
|
||||
return int(max_.timestamp() * 1000)
|
||||
return 0
|
||||
|
@ -90,6 +90,9 @@ class Min(StaticMetric):
|
||||
|
||||
def df_fn(self, dfs=None):
|
||||
"""pandas function"""
|
||||
if is_quantifiable(self.col.type) or is_date_time(self.col.type):
|
||||
if is_quantifiable(self.col.type):
|
||||
return min((df[self.col.name].min() for df in dfs))
|
||||
if is_date_time(self.col.type):
|
||||
min_ = min((df[self.col.name].min() for df in dfs))
|
||||
return int(min_.timestamp() * 1000)
|
||||
return 0
|
||||
|
@ -33,7 +33,10 @@ from metadata.generated.schema.entity.data.table import (
|
||||
TableData,
|
||||
TableProfile,
|
||||
)
|
||||
from metadata.profiler.api.models import ProfilerResponse
|
||||
from metadata.generated.schema.tests.customMetric import (
|
||||
CustomMetric as CustomMetricEntity,
|
||||
)
|
||||
from metadata.profiler.api.models import ProfilerResponse, ThreadPoolMetrics
|
||||
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||
from metadata.profiler.metrics.core import (
|
||||
ComposedMetric,
|
||||
@ -273,6 +276,33 @@ class Profiler(Generic[TMetric]):
|
||||
|
||||
return [metric for metric in metrics if metric.is_col_metric()]
|
||||
|
||||
def get_custom_metrics(
    self, column_name: Optional[str] = None
) -> Optional[List[CustomMetricEntity]]:
    """Return the custom metrics defined on the table or on one of its columns.

    Args:
        column_name (Optional[str]): name of the column to look up. When None,
            the table-level custom metrics are returned.

    Returns:
        Optional[List[CustomMetricEntity]]: the custom metrics, or None when
        none are defined or the column is not part of the table entity.
    """
    if column_name is None:
        return self.profiler_interface.table_entity.customMetrics or None

    # Column-level lookup: the first column with a matching name wins.
    for candidate in self.profiler_interface.table_entity.columns:
        if candidate.name.__root__ != column_name:
            continue
        if candidate:
            return candidate.customMetrics or None
        return None

    return None
|
||||
|
||||
@property
|
||||
def sample(self):
|
||||
"""Return the sample used for the profiler"""
|
||||
@ -345,22 +375,40 @@ class Profiler(Generic[TMetric]):
|
||||
|
||||
def _prepare_table_metrics(self) -> List:
|
||||
"""prepare table metrics"""
|
||||
metrics = []
|
||||
table_metrics = [
|
||||
metric
|
||||
for metric in self.static_metrics
|
||||
if (not metric.is_col_metric() and not metric.is_system_metrics())
|
||||
]
|
||||
|
||||
custom_table_metrics = self.get_custom_metrics()
|
||||
|
||||
if table_metrics:
|
||||
return [
|
||||
(
|
||||
table_metrics, # metric functions
|
||||
MetricTypes.Table, # metric type for function mapping
|
||||
None, # column name
|
||||
self.table, # table name
|
||||
),
|
||||
]
|
||||
return []
|
||||
metrics.extend(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=table_metrics,
|
||||
metric_type=MetricTypes.Table,
|
||||
column=None,
|
||||
table=self.table,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if custom_table_metrics:
|
||||
metrics.extend(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=custom_table_metrics,
|
||||
metric_type=MetricTypes.Custom,
|
||||
column=None,
|
||||
table=self.table,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _prepare_system_metrics(self) -> List:
|
||||
"""prepare system metrics"""
|
||||
@ -368,11 +416,11 @@ class Profiler(Generic[TMetric]):
|
||||
|
||||
if system_metrics:
|
||||
return [
|
||||
(
|
||||
system_metric, # metric functions
|
||||
MetricTypes.System, # metric type for function mapping
|
||||
None, # column name
|
||||
self.table, # table name
|
||||
ThreadPoolMetrics(
|
||||
metrics=system_metric, # metric functions
|
||||
metric_type=MetricTypes.System, # metric type for function mapping
|
||||
column=None, # column name
|
||||
table=self.table, # table name
|
||||
)
|
||||
for system_metric in system_metrics
|
||||
]
|
||||
@ -386,45 +434,60 @@ class Profiler(Generic[TMetric]):
|
||||
for column in self.columns
|
||||
if column.type.__class__.__name__ not in NOT_COMPUTE
|
||||
]
|
||||
|
||||
column_metrics_for_thread_pool = [
|
||||
*[
|
||||
(
|
||||
[
|
||||
metric
|
||||
for metric in self.get_col_metrics(self.static_metrics, column)
|
||||
if not metric.is_window_metric()
|
||||
],
|
||||
MetricTypes.Static,
|
||||
column,
|
||||
self.table,
|
||||
)
|
||||
for column in columns
|
||||
],
|
||||
*[
|
||||
(
|
||||
metric,
|
||||
MetricTypes.Query,
|
||||
column,
|
||||
self.table,
|
||||
)
|
||||
for column in columns
|
||||
for metric in self.get_col_metrics(self.query_metrics, column)
|
||||
],
|
||||
*[
|
||||
(
|
||||
[
|
||||
metric
|
||||
for metric in self.get_col_metrics(self.static_metrics, column)
|
||||
if metric.is_window_metric()
|
||||
],
|
||||
MetricTypes.Window,
|
||||
column,
|
||||
self.table,
|
||||
)
|
||||
for column in columns
|
||||
],
|
||||
column_metrics_for_thread_pool = []
|
||||
static_metrics = [
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.get_col_metrics(self.static_metrics, column)
|
||||
if not metric.is_window_metric()
|
||||
],
|
||||
metric_type=MetricTypes.Static,
|
||||
column=column,
|
||||
table=self.table,
|
||||
)
|
||||
for column in columns
|
||||
]
|
||||
query_metrics = [
|
||||
ThreadPoolMetrics(
|
||||
metrics=metric,
|
||||
metric_type=MetricTypes.Query,
|
||||
column=column,
|
||||
table=self.table,
|
||||
)
|
||||
for column in columns
|
||||
for metric in self.get_col_metrics(self.query_metrics, column)
|
||||
]
|
||||
window_metrics = [
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.get_col_metrics(self.static_metrics, column)
|
||||
if metric.is_window_metric()
|
||||
],
|
||||
metric_type=MetricTypes.Window,
|
||||
column=column,
|
||||
table=self.table,
|
||||
)
|
||||
for column in columns
|
||||
]
|
||||
|
||||
# we'll add the static, query and window metrics to the thread pool computation
|
||||
for metric_type in [static_metrics, query_metrics, window_metrics]:
|
||||
column_metrics_for_thread_pool.extend(metric_type)
|
||||
|
||||
# we'll add the custom metrics to the thread pool computation
|
||||
for column in columns:
|
||||
custom_metrics = self.get_custom_metrics(column.name)
|
||||
if custom_metrics:
|
||||
column_metrics_for_thread_pool.append(
|
||||
ThreadPoolMetrics(
|
||||
metrics=custom_metrics,
|
||||
metric_type=MetricTypes.Custom,
|
||||
column=column,
|
||||
table=self.table,
|
||||
)
|
||||
)
|
||||
|
||||
return column_metrics_for_thread_pool
|
||||
|
||||
@ -568,6 +631,7 @@ class Profiler(Generic[TMetric]):
|
||||
profileSampleType=self.profile_sample_config.profile_sample_type
|
||||
if self.profile_sample_config
|
||||
else None,
|
||||
customMetrics=self._table_results.get("customMetrics"),
|
||||
)
|
||||
|
||||
if self._system_results:
|
||||
|
@ -31,6 +31,7 @@ from metadata.profiler.orm.functions.random_num import RandomNumFn
|
||||
from metadata.profiler.orm.registry import Dialects
|
||||
from metadata.profiler.processor.handle_partition import partition_filter_handler
|
||||
from metadata.profiler.processor.sampler.sampler_interface import SamplerInterface
|
||||
from metadata.utils.helpers import is_safe_sql_query
|
||||
from metadata.utils.logger import profiler_interface_registry_logger
|
||||
from metadata.utils.sqa_utils import (
|
||||
build_query_filter,
|
||||
@ -171,6 +172,11 @@ class SQASampler(SamplerInterface):
|
||||
|
||||
def _fetch_sample_data_from_user_query(self) -> TableData:
|
||||
"""Returns a table data object using results from query execution"""
|
||||
if not is_safe_sql_query(self._profile_sample_query):
|
||||
raise RuntimeError(
|
||||
f"SQL expression is not safe\n\n{self._profile_sample_query}"
|
||||
)
|
||||
|
||||
rnd = self.client.execute(f"{self._profile_sample_query}")
|
||||
try:
|
||||
columns = [col.name for col in rnd.cursor.description]
|
||||
@ -183,6 +189,11 @@ class SQASampler(SamplerInterface):
|
||||
|
||||
def _rdn_sample_from_user_query(self) -> Query:
|
||||
"""Returns sql alchemy object to use when running profiling"""
|
||||
if not is_safe_sql_query(self._profile_sample_query):
|
||||
raise RuntimeError(
|
||||
f"SQL expression is not safe\n\n{self._profile_sample_query}"
|
||||
)
|
||||
|
||||
return self.client.query(self.table).from_statement(
|
||||
text(f"{self._profile_sample_query}")
|
||||
)
|
||||
|
@ -13,7 +13,6 @@ Helper module to handle data sampling
|
||||
for the profiler
|
||||
"""
|
||||
from sqlalchemy import inspect, or_, text
|
||||
from trino.sqlalchemy.dialect import TrinoDialect
|
||||
|
||||
from metadata.profiler.orm.registry import FLOAT_SET
|
||||
from metadata.profiler.processor.handle_partition import RANDOM_LABEL
|
||||
@ -26,7 +25,13 @@ class TrinoSampler(SQASampler):
|
||||
run the query in the whole table.
|
||||
"""
|
||||
|
||||
TrinoDialect._json_deserializer = None # pylint: disable=protected-access
|
||||
def __init__(self, *args, **kwargs):
    """Reset the Trino dialect JSON deserializer, then initialise the sampler.

    The dialect is imported inside the constructor rather than at module
    level, so importing this module does not itself import trino.
    """
    # pylint: disable=import-outside-toplevel
    from trino.sqlalchemy.dialect import TrinoDialect

    setattr(TrinoDialect, "_json_deserializer", None)

    super().__init__(*args, **kwargs)
|
||||
|
||||
def _base_sample_query(self, label=None):
|
||||
sqa_columns = [col for col in inspect(self.table).c if col.name != RANDOM_LABEL]
|
||||
|
@ -238,9 +238,7 @@ class OpenMetadataSource(Source):
|
||||
"""
|
||||
tables = self.metadata.list_all_entities(
|
||||
entity=Table,
|
||||
fields=[
|
||||
"tableProfilerConfig",
|
||||
],
|
||||
fields=["tableProfilerConfig", "columns", "customMetrics"],
|
||||
params={
|
||||
"service": self.config.source.serviceName,
|
||||
"database": fqn.build(
|
||||
|
@ -430,6 +430,9 @@ def is_safe_sql_query(sql_query: str) -> bool:
|
||||
"SET TRANSACTION",
|
||||
}
|
||||
|
||||
if sql_query is None:
|
||||
return True
|
||||
|
||||
parsed_queries: Tuple[Statement] = sqlparse.parse(sql_query)
|
||||
for parsed_query in parsed_queries:
|
||||
validation = [
|
||||
|
@ -0,0 +1,4 @@
|
||||
id,first_name,last_name,city,country,birthdate,age
|
||||
1,John,Doe,Los Angeles,US,1980-01-01,40
|
||||
2,Jane,Doe,Los Angeles,US,2000-12-31,39
|
||||
3,Jane,Smith,Paris,,2001-11-11,28
|
|
233
ingestion/tests/unit/profiler/pandas/test_custom_metrics.py
Normal file
233
ingestion/tests/unit/profiler/pandas/test_custom_metrics.py
Normal file
@ -0,0 +1,233 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Test Metrics behavior
|
||||
"""
|
||||
# import datetime
|
||||
import os
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
import pandas as pd
|
||||
from moto import mock_s3
|
||||
|
||||
from metadata.generated.schema.entity.data.table import Column as EntityColumn
|
||||
from metadata.generated.schema.entity.data.table import ColumnName, DataType, Table
|
||||
from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import (
|
||||
S3Config,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
|
||||
DatalakeConnection,
|
||||
)
|
||||
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
|
||||
from metadata.generated.schema.tests.customMetric import CustomMetric
|
||||
from metadata.profiler.interface.pandas.profiler_interface import (
|
||||
PandasProfilerInterface,
|
||||
)
|
||||
from metadata.profiler.processor.core import Profiler
|
||||
|
||||
BUCKET_NAME = "MyBucket"
|
||||
|
||||
|
||||
@mock_s3
|
||||
class MetricsTest(TestCase):
|
||||
"""
|
||||
Run checks on different metrics
|
||||
"""
|
||||
|
||||
current_dir = os.path.dirname(__file__)
|
||||
resources_dir = os.path.join(current_dir, "resources")
|
||||
|
||||
datalake_conn = DatalakeConnection(
|
||||
configSource=S3Config(
|
||||
securityConfig=AWSCredentials(
|
||||
awsAccessKeyId="fake_access_key",
|
||||
awsSecretAccessKey="fake_secret_key",
|
||||
awsRegion="us-west-1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
dfs = [
|
||||
pd.read_csv(os.path.join(resources_dir, "profiler_test_.csv"), parse_dates=[5])
|
||||
]
|
||||
|
||||
table_entity = Table(
|
||||
id=uuid4(),
|
||||
name="user",
|
||||
columns=[
|
||||
EntityColumn(
|
||||
name=ColumnName(__root__="id"),
|
||||
dataType=DataType.INT,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def setUp(self):
    """Create the mocked S3 bucket, upload the resources and build the interface.

    Raises:
        EnvironmentError: if the bucket already exists, i.e. the test would
            not be running against the mocked S3.
    """
    # Mock our S3 bucket and ingest a file
    boto3.DEFAULT_SESSION = None
    self.client = boto3.client(
        "s3",
        # Same region as the credentials and the bucket location constraint
        region_name="us-west-1",
    )

    # check that we are not running our test against a real bucket
    try:
        s3 = boto3.resource(
            "s3",
            region_name="us-west-1",
            aws_access_key_id="fake_access_key",
            aws_secret_access_key="fake_secret_key",
        )
        s3.meta.client.head_bucket(Bucket=BUCKET_NAME)
    except botocore.exceptions.ClientError:
        pass
    else:
        err = f"{BUCKET_NAME} should not exist."
        raise EnvironmentError(err)
    self.client.create_bucket(
        Bucket=BUCKET_NAME,
        CreateBucketConfiguration={"LocationConstraint": "us-west-1"},
    )

    resources_paths = [
        os.path.join(path, filename)
        for path, _, files in os.walk(self.resources_dir)
        for filename in files
    ]

    self.s3_keys = []

    # Upload every resource file, keyed by its path relative to the
    # resources directory.
    for path in resources_paths:
        key = os.path.relpath(path, self.resources_dir)
        self.s3_keys.append(key)
        self.client.upload_file(Filename=path, Bucket=BUCKET_NAME, Key=key)

    # The interface is built with the dataframe loading patched out so it
    # works on the dataframes read from the local CSV.
    with patch.object(
        PandasProfilerInterface,
        "_convert_table_to_list_of_dataframe_objects",
        return_value=self.dfs,
    ):
        self.sqa_profiler_interface = PandasProfilerInterface(
            self.datalake_conn,
            None,
            self.table_entity,
            None,
            None,
            None,
            None,
            None,
            thread_count=1,
        )
|
||||
|
||||
def test_table_custom_metric(self):
    """Check the values computed for custom metrics declared on the table."""
    custom_metrics = [
        CustomMetric(
            name="LastNameFilter",
            expression="'last_name' != Doe",
        ),
        CustomMetric(
            name="notUS",
            expression="'country == US'",
        ),
    ]
    id_column = EntityColumn(
        name=ColumnName(__root__="id"),
        dataType=DataType.INT,
    )
    table_entity = Table(
        id=uuid4(),
        name="user",
        columns=[id_column],
        customMetrics=custom_metrics,
    )

    loader_patch = patch.object(
        PandasProfilerInterface,
        "_convert_table_to_list_of_dataframe_objects",
        return_value=self.dfs,
    )
    with loader_patch:
        self.sqa_profiler_interface = PandasProfilerInterface(
            self.datalake_conn,
            None,
            table_entity,
            None,
            None,
            None,
            None,
            None,
            thread_count=1,
        )

    # NOTE(review): the profiler run is assumed to happen after the patch
    # context has exited — confirm against the original indentation.
    metrics = Profiler(
        profiler_interface=self.sqa_profiler_interface,
    ).compute_metrics()

    # (metric name, expected value), checked in declaration order.
    expectations = (("LastNameFilter", 1), ("notUS", 2))
    for _, table_metrics in metrics._table_results.items():
        for metric in table_metrics:
            for expected_name, expected_value in expectations:
                if metric.name == expected_name:
                    assert metric.value == expected_value
|
||||
|
||||
def test_column_custom_metric(self):
    """Check the values computed for custom metrics declared on the ``id`` column.

    The assertions are keyed on the metric names declared in this test
    ("LastNameFilter" and "notUS"). They previously looked for
    "CustomerBornedAfter1991" and "AverageAge", names that are never
    defined here, so no assertion could ever run.
    """
    table_entity = Table(
        id=uuid4(),
        name="user",
        columns=[
            EntityColumn(
                name=ColumnName(__root__="id"),
                dataType=DataType.INT,
                customMetrics=[
                    CustomMetric(
                        name="LastNameFilter",
                        columnName="id",
                        expression="'last_name' != Doe",
                    ),
                    CustomMetric(
                        name="notUS",
                        columnName="id",
                        expression="'country == US'",
                    ),
                ],
            )
        ],
    )
    with patch.object(
        PandasProfilerInterface,
        "_convert_table_to_list_of_dataframe_objects",
        return_value=self.dfs,
    ):
        self.sqa_profiler_interface = PandasProfilerInterface(
            self.datalake_conn,
            None,
            table_entity,
            None,
            None,
            None,
            None,
            None,
            thread_count=1,
        )

    # NOTE(review): the profiler run is assumed to happen after the patch
    # context has exited — confirm against the original indentation.
    profiler = Profiler(
        profiler_interface=self.sqa_profiler_interface,
    )
    metrics = profiler.compute_metrics()

    # NOTE(review): expected values 1 and 2 are the ones the original
    # assertions carried; presumably they match the table-level test that
    # uses the same expressions — confirm by running the suite.
    for column_result in metrics._column_results.values():
        for metric in column_result.get("customMetrics", []):
            if metric.name == "LastNameFilter":
                assert metric.value == 1
            if metric.name == "notUS":
                assert metric.value == 2
|
@ -269,5 +269,8 @@ class ProfilerTest(TestCase):
|
||||
|
||||
column_metrics = default_profiler._prepare_column_metrics()
|
||||
for metric in column_metrics:
|
||||
if metric[1] is not MetricTypes.Table and metric[2].name == "id":
|
||||
assert all(metric_filter.count(m.name()) for m in metric[0])
|
||||
if (
|
||||
metric.metric_type is not MetricTypes.Table
|
||||
and metric.column.name == "id"
|
||||
):
|
||||
assert all(metric_filter.count(m.name()) for m in metric.metrics)
|
||||
|
@ -32,6 +32,7 @@ from metadata.generated.schema.entity.data.table import (
|
||||
Table,
|
||||
TableProfile,
|
||||
)
|
||||
from metadata.profiler.api.models import ThreadPoolMetrics
|
||||
from metadata.profiler.interface.pandas.profiler_interface import (
|
||||
PandasProfilerInterface,
|
||||
)
|
||||
@ -146,15 +147,15 @@ class PandasInterfaceTest(TestCase):
|
||||
|
||||
def test_get_all_metrics(self):
|
||||
table_metrics = [
|
||||
(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.metrics
|
||||
if (not metric.is_col_metric() and not metric.is_system_metrics())
|
||||
],
|
||||
MetricTypes.Table,
|
||||
None,
|
||||
self.table_entity,
|
||||
metric_type=MetricTypes.Table,
|
||||
column=None,
|
||||
table=self.table_entity,
|
||||
)
|
||||
]
|
||||
column_metrics = []
|
||||
@ -164,36 +165,36 @@ class PandasInterfaceTest(TestCase):
|
||||
if col.name == "id":
|
||||
continue
|
||||
column_metrics.append(
|
||||
(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.static_metrics
|
||||
if metric.is_col_metric() and not metric.is_window_metric()
|
||||
],
|
||||
MetricTypes.Static,
|
||||
col,
|
||||
self.table_entity,
|
||||
metric_type=MetricTypes.Static,
|
||||
column=col,
|
||||
table=self.table_entity,
|
||||
)
|
||||
)
|
||||
for query_metric in self.query_metrics:
|
||||
query_metrics.append(
|
||||
(
|
||||
query_metric,
|
||||
MetricTypes.Query,
|
||||
col,
|
||||
self.table_entity,
|
||||
ThreadPoolMetrics(
|
||||
metrics=query_metric,
|
||||
metric_type=MetricTypes.Query,
|
||||
column=col,
|
||||
table=self.table_entity,
|
||||
)
|
||||
)
|
||||
window_metrics.append(
|
||||
(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.window_metrics
|
||||
if metric.is_window_metric()
|
||||
],
|
||||
MetricTypes.Window,
|
||||
col,
|
||||
self.table_entity,
|
||||
metric_type=MetricTypes.Window,
|
||||
column=col,
|
||||
table=self.table_entity,
|
||||
)
|
||||
)
|
||||
|
@ -27,6 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
|
||||
SQLiteConnection,
|
||||
SQLiteScheme,
|
||||
)
|
||||
from metadata.generated.schema.tests.customMetric import CustomMetric
|
||||
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
|
||||
SQAProfilerInterface,
|
||||
)
|
||||
@ -865,6 +866,102 @@ class MetricsTest(TestCase):
|
||||
session = self.sqa_profiler_interface.session
|
||||
system().sql(session)
|
||||
|
||||
def test_table_custom_metric(self):
    """Check the values of SQL custom metrics declared on the table."""
    custom_metrics = [
        CustomMetric(
            name="CustomerBornedAfter1991",
            expression="SELECT COUNT(id) FROM users WHERE dob > '1991-01-01'",
        ),
        CustomMetric(
            name="AverageAge",
            expression="SELECT SUM(age)/COUNT(*) FROM users",
        ),
    ]
    id_column = EntityColumn(
        name=ColumnName(__root__="id"),
        dataType=DataType.INT,
    )
    table_entity = Table(
        id=uuid4(),
        name="user",
        columns=[id_column],
        customMetrics=custom_metrics,
    )

    orm_patch = patch.object(
        SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
    )
    with orm_patch:
        self.sqa_profiler_interface = SQAProfilerInterface(
            self.sqlite_conn,
            None,
            table_entity,
            None,
            None,
            None,
            None,
            None,
            thread_count=1,
        )

    # NOTE(review): the profiler run is assumed to happen after the patch
    # context has exited — confirm against the original indentation.
    metrics = Profiler(
        profiler_interface=self.sqa_profiler_interface,
    ).compute_metrics()

    # (metric name, expected value), checked in declaration order.
    expectations = (("CustomerBornedAfter1991", 2), ("AverageAge", 20.0))
    for _, table_metrics in metrics._table_results.items():
        for metric in table_metrics:
            for expected_name, expected_value in expectations:
                if metric.name == expected_name:
                    assert metric.value == expected_value
|
||||
|
||||
def test_column_custom_metric(self):
    """Check the values of SQL custom metrics declared on the ``id`` column."""
    column_custom_metrics = [
        CustomMetric(
            name="CustomerBornedAfter1991",
            columnName="id",
            expression="SELECT SUM(id) FROM users WHERE dob > '1991-01-01'",
        ),
        CustomMetric(
            name="AverageAge",
            columnName="id",
            expression="SELECT SUM(age)/COUNT(*) FROM users",
        ),
    ]
    id_column = EntityColumn(
        name=ColumnName(__root__="id"),
        dataType=DataType.INT,
        customMetrics=column_custom_metrics,
    )
    table_entity = Table(
        id=uuid4(),
        name="user",
        columns=[id_column],
    )

    orm_patch = patch.object(
        SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
    )
    with orm_patch:
        self.sqa_profiler_interface = SQAProfilerInterface(
            self.sqlite_conn,
            None,
            table_entity,
            None,
            None,
            None,
            None,
            None,
            thread_count=1,
        )

    # NOTE(review): the profiler run is assumed to happen after the patch
    # context has exited — confirm against the original indentation.
    metrics = Profiler(
        profiler_interface=self.sqa_profiler_interface,
    ).compute_metrics()

    # (metric name, expected value), checked in declaration order.
    expectations = (("CustomerBornedAfter1991", 3.0), ("AverageAge", 20.0))
    for _, column_result in metrics._column_results.items():
        for metric in column_result.get("customMetrics", []):
            for expected_name, expected_value in expectations:
                if metric.name == expected_name:
                    assert metric.value == expected_value
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
os.remove(cls.db_path)
|
||||
|
@ -42,6 +42,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
|
||||
SQLiteConnection,
|
||||
SQLiteScheme,
|
||||
)
|
||||
from metadata.generated.schema.tests.customMetric import CustomMetric
|
||||
from metadata.ingestion.source import sqa_types
|
||||
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
|
||||
SQAProfilerInterface,
|
||||
@ -83,8 +84,27 @@ class ProfilerTest(TestCase):
|
||||
EntityColumn(
|
||||
name=ColumnName(__root__="id"),
|
||||
dataType=DataType.INT,
|
||||
customMetrics=[
|
||||
CustomMetric(
|
||||
name="custom_metric",
|
||||
description="custom metric",
|
||||
expression="SELECT cos(id) FROM users",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
customMetrics=[
|
||||
CustomMetric(
|
||||
name="custom_metric",
|
||||
description="custom metric",
|
||||
expression="SELECT COUNT(id) / COUNT(age) FROM users",
|
||||
),
|
||||
CustomMetric(
|
||||
name="custom_metric_two",
|
||||
description="custom metric",
|
||||
expression="SELECT COUNT(id) * COUNT(age) FROM users",
|
||||
),
|
||||
],
|
||||
)
|
||||
with patch.object(
|
||||
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
|
||||
@ -231,6 +251,30 @@ class ProfilerTest(TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test__prepare_column_metrics(self):
    """_prepare_column_metrics yields the custom metric or the requested static metric first.

    Entries with no metrics are ignored. For every other entry the first
    metric must be either the ``CustomMetric`` named "custom_metric" or a
    metric whose ``name()`` is "firstQuartile".
    """
    column_profiler = Profiler(
        Metrics.FIRST_QUARTILE.value,
        profiler_interface=self.sqa_profiler_interface,
    )

    for prepared in column_profiler._prepare_column_metrics():
        if not prepared.metrics:
            continue
        if isinstance(prepared.metrics[0], CustomMetric):
            # Custom metrics expose their name as a root-model attribute.
            assert prepared.metrics[0].name.__root__ == "custom_metric"
        else:
            # Other metrics expose their name through a name() call.
            assert prepared.metrics[0].name() == "firstQuartile"
|
||||
|
||||
def test__prepare_table_metrics(self):
    """_prepare_table_metrics returns exactly two entries for this profiler setup."""
    table_profiler = Profiler(
        Metrics.COLUMN_COUNT.value,
        profiler_interface=self.sqa_profiler_interface,
    )
    prepared = table_profiler._prepare_table_metrics()
    self.assertEqual(2, len(prepared))
|
||||
|
||||
def test_profiler_with_timeout(self):
|
||||
"""check timeout is properly used"""
|
||||
|
||||
@ -259,6 +303,7 @@ class ProfilerTest(TestCase):
|
||||
def test_profiler_get_col_metrics(self):
|
||||
"""check getc column metrics"""
|
||||
metric_filter = ["mean", "min", "max", "firstQuartile"]
|
||||
custom_metric_filter = ["custom_metric"]
|
||||
self.sqa_profiler_interface.table_entity.tableProfilerConfig = (
|
||||
TableProfilerConfig(
|
||||
includeColumns=[
|
||||
@ -273,8 +318,20 @@ class ProfilerTest(TestCase):
|
||||
|
||||
column_metrics = default_profiler._prepare_column_metrics()
|
||||
for metric in column_metrics:
|
||||
if metric[1] is not MetricTypes.Table and metric[2].name == "id":
|
||||
assert all(metric_filter.count(m.name()) for m in metric[0])
|
||||
if (
|
||||
metric.metric_type is not MetricTypes.Table
|
||||
and metric.column.name == "id"
|
||||
):
|
||||
assert all(
|
||||
metric_filter.count(m.name())
|
||||
for m in metric.metrics
|
||||
if not isinstance(m, CustomMetric)
|
||||
)
|
||||
assert all(
|
||||
custom_metric_filter.count(m.name.__root__)
|
||||
for m in metric.metrics
|
||||
if isinstance(m, CustomMetric)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
|
@ -38,6 +38,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
|
||||
SQLiteConnection,
|
||||
SQLiteScheme,
|
||||
)
|
||||
from metadata.profiler.api.models import ThreadPoolMetrics
|
||||
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
|
||||
SQAProfilerInterface,
|
||||
)
|
||||
@ -162,15 +163,15 @@ class SQAInterfaceTestMultiThread(TestCase):
|
||||
|
||||
def test_get_all_metrics(self):
|
||||
table_metrics = [
|
||||
(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.metrics
|
||||
if (not metric.is_col_metric() and not metric.is_system_metrics())
|
||||
],
|
||||
MetricTypes.Table,
|
||||
None,
|
||||
self.table,
|
||||
metric_type=MetricTypes.Table,
|
||||
column=None,
|
||||
table=self.table,
|
||||
)
|
||||
]
|
||||
column_metrics = []
|
||||
@ -178,36 +179,36 @@ class SQAInterfaceTestMultiThread(TestCase):
|
||||
window_metrics = []
|
||||
for col in inspect(User).c:
|
||||
column_metrics.append(
|
||||
(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.static_metrics
|
||||
if metric.is_col_metric() and not metric.is_window_metric()
|
||||
],
|
||||
MetricTypes.Static,
|
||||
col,
|
||||
self.table,
|
||||
metric_type=MetricTypes.Static,
|
||||
column=col,
|
||||
table=self.table,
|
||||
)
|
||||
)
|
||||
for query_metric in self.query_metrics:
|
||||
query_metrics.append(
|
||||
(
|
||||
query_metric,
|
||||
MetricTypes.Query,
|
||||
col,
|
||||
self.table,
|
||||
ThreadPoolMetrics(
|
||||
metrics=query_metric,
|
||||
metric_type=MetricTypes.Query,
|
||||
column=col,
|
||||
table=self.table,
|
||||
)
|
||||
)
|
||||
window_metrics.append(
|
||||
(
|
||||
[
|
||||
ThreadPoolMetrics(
|
||||
metrics=[
|
||||
metric
|
||||
for metric in self.window_metrics
|
||||
if metric.is_window_metric()
|
||||
],
|
||||
MetricTypes.Window,
|
||||
col,
|
||||
self.table,
|
||||
metric_type=MetricTypes.Window,
|
||||
column=col,
|
||||
table=self.table,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -26,5 +26,6 @@
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user