Fixes #11357 - Implement profiler custom metric processing (#14021)

* feat: add backend support for custom metrics

* feat: fix python test

* feat: support custom metrics computation

* feat: updated tests for custom metrics

* feat: added dl support for min max of datetime

* feat: added is safe query check for query sampler

* feat: added support for custom metric computation in dl

* feat: added explicit addProper for pydantic model import fo Extra

* feat: added custom metric to returned obj

* feat: wrapped trino import in __init__

* feat: fix python linting

* feat: fix typing in 3.8
This commit is contained in:
Teddy 2023-11-17 17:51:39 +01:00 committed by GitHub
parent efb6c5f221
commit c7ac28f2c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 738 additions and 148 deletions

View File

@ -428,7 +428,7 @@ def get_parser(args=None):
return parser.parse_args(args)
def metadata(args=None): # pylint: disable=too-many-branches
def metadata(args=None):
"""
This method implements parsing of the arguments passed from CLI
"""

View File

@ -15,7 +15,10 @@ Return types for Profiler workflow execution.
We need to define this class as we end up having
multiple profilers per table and columns.
"""
from typing import List, Optional, Union
from typing import List, Optional, Type, Union
from sqlalchemy import Column
from sqlalchemy.orm import DeclarativeMeta
from metadata.config.common import ConfigModel
from metadata.generated.schema.api.data.createTableProfile import (
@ -31,9 +34,12 @@ from metadata.generated.schema.entity.data.table import (
from metadata.generated.schema.entity.services.connections.connectionBasicType import (
SampleDataStorageConfig,
)
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.generated.schema.type.basic import FullyQualifiedEntityName
from metadata.ingestion.models.table_metadata import ColumnTag
from metadata.profiler.metrics.core import Metric, MetricTypes
from metadata.profiler.processor.models import ProfilerDef
from metadata.utils.sqa_like_column import SQALikeColumn
class ColumnConfig(ConfigModel):
@ -113,3 +119,15 @@ class ProfilerResponse(ConfigModel):
def __str__(self):
"""Return the table name being processed"""
return f"Table [{self.table.name.__root__}]"
class ThreadPoolMetrics(ConfigModel):
"""thread pool metric"""
metrics: Union[List[Union[Type[Metric], CustomMetric]], Type[Metric]]
metric_type: MetricTypes
column: Optional[Union[Column, SQALikeColumn]]
table: Union[Table, DeclarativeMeta]
class Config:
arbitrary_types_allowed = True

View File

@ -22,11 +22,13 @@ from typing import Dict, List, Optional
from sqlalchemy import Column
from metadata.generated.schema.entity.data.table import TableData
from metadata.generated.schema.entity.data.table import CustomMetricProfile, TableData
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.mixins.pandas.pandas_mixin import PandasInterfaceMixin
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import MetricTypes
from metadata.profiler.metrics.registry import Metrics
@ -239,35 +241,67 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
"""
return None # to be implemented
def _compute_custom_metrics(
self, metrics: List[CustomMetric], runner, *args, **kwargs
):
"""Compute custom metrics. For pandas source we expect expression
to be a boolean value. We'll return the length of the dataframe
Args:
metrics (List[Metrics]): list of customMetrics
runner (_type_): runner
"""
if not metrics:
return None
custom_metrics = []
for metric in metrics:
try:
row = sum(
len(df.query(metric.expression).index)
for df in runner
if len(df.query(metric.expression).index)
)
custom_metrics.append(
CustomMetricProfile(name=metric.name.__root__, value=row)
)
except Exception as exc:
msg = f"Error trying to compute profile for custom metric: {exc}"
logger.debug(traceback.format_exc())
logger.warning(msg)
if custom_metrics:
return {"customMetrics": custom_metrics}
return None
def compute_metrics(
self,
metrics,
metric_type,
column,
table,
metric_func: ThreadPoolMetrics,
):
"""Run metrics in processor worker"""
logger.debug(f"Running profiler for {table}")
logger.debug(f"Running profiler for {metric_func.table}")
try:
row = None
if self.complex_dataframe_sample:
row = self._get_metric_fn[metric_type.value](
metrics,
row = self._get_metric_fn[metric_func.metric_type.value](
metric_func.metrics,
self.complex_dataframe_sample,
column=column,
column=metric_func.column,
)
except Exception as exc:
name = f"{column if column is not None else table}"
name = f"{metric_func.column if metric_func.column is not None else metric_func.table}"
error = f"{name} metric_type.value: {exc}"
logger.error(error)
self.status.failed_profiler(error, traceback.format_exc())
row = None
if column is not None:
column = column.name
self.status.scanned(f"{table.name.__root__}.{column}")
if metric_func.column is not None:
column = metric_func.column.name
self.status.scanned(f"{metric_func.table.name.__root__}.{column}")
else:
self.status.scanned(table.name.__root__)
return row, column, metric_type.value
self.status.scanned(metric_func.table.name.__root__)
column = None
return row, column, metric_func.metric_type.value
def fetch_sample_data(self, table, columns: SQALikeColumn) -> TableData:
"""Fetch sample data from database
@ -329,7 +363,7 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
profile_results = {"table": {}, "columns": defaultdict(dict)}
metric_list = [
self.compute_metrics(*metric_func) for metric_func in metric_funcs
self.compute_metrics(metric_func) for metric_func in metric_funcs
]
for metric_result in metric_list:
profile, column, metric_type = metric_result
@ -338,6 +372,8 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
profile_results["table"].update(profile)
if metric_type == MetricTypes.System.value:
profile_results["system"] = profile
elif metric_type == MetricTypes.Custom.value and column is None:
profile_results["table"].update(profile)
else:
if profile:
profile_results["columns"][column].update(

View File

@ -46,6 +46,7 @@ from metadata.generated.schema.entity.services.databaseService import (
from metadata.generated.schema.metadataIngestion.databaseServiceProfilerPipeline import (
DatabaseServiceProfilerPipeline,
)
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.ingestion.api.models import StackTraceError
from metadata.ingestion.api.status import Status
from metadata.ingestion.ometa.ometa_api import OpenMetadata
@ -130,6 +131,7 @@ class ProfilerInterface(ABC):
MetricTypes.Query.value: self._compute_query_metrics,
MetricTypes.Window.value: self._compute_window_metrics,
MetricTypes.System.value: self._compute_system_metrics,
MetricTypes.Custom.value: self._compute_custom_metrics,
}
@abstractmethod
@ -459,6 +461,13 @@ class ProfilerInterface(ABC):
"""Get metrics"""
raise NotImplementedError
@abstractmethod
def _compute_custom_metrics(
self, metrics: List[CustomMetric], runner, *args, **kwargs
):
"""Compute custom metrics"""
raise NotImplementedError
@abstractmethod
def get_all_metrics(self, metric_funcs) -> dict:
"""run profiler metrics"""

View File

@ -22,13 +22,15 @@ from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, List, Optional
from sqlalchemy import Column, inspect
from sqlalchemy import Column, inspect, text
from sqlalchemy.exc import ProgrammingError, ResourceClosedError
from sqlalchemy.orm import scoped_session
from metadata.generated.schema.entity.data.table import TableData
from metadata.generated.schema.entity.data.table import CustomMetricProfile, TableData
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.ingestion.connections.session import create_and_bind_thread_safe_session
from metadata.mixins.sqalchemy.sqa_mixin import SQAInterfaceMixin
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import MetricTypes
from metadata.profiler.metrics.registry import Metrics
@ -42,6 +44,7 @@ from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT
from metadata.utils.custom_thread_pool import CustomThreadPoolExecutor
from metadata.utils.helpers import is_safe_sql_query
from metadata.utils.logger import profiler_interface_registry_logger
logger = profiler_interface_registry_logger()
@ -307,6 +310,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
row = runner.select_first_from_sample(
*[metric(column).fn() for metric in metrics],
)
if row:
return dict(row)
except ProgrammingError as exc:
logger.info(
f"Skipping metrics for {runner.table.__tablename__}.{column.name} due to {exc}"
@ -314,8 +319,43 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
except Exception as exc:
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
handle_query_exception(msg, exc, session)
if row:
return dict(row)
return None
def _compute_custom_metrics(
self, metrics: List[CustomMetric], runner, session, *args, **kwargs
):
"""Compute custom metrics
Args:
metrics (List[Metrics]): list of customMetrics
runner (_type_): runner
"""
if not metrics:
return None
custom_metrics = []
for metric in metrics:
try:
if not is_safe_sql_query(metric.expression):
raise RuntimeError(
f"SQL expression is not safe\n\n{metric.expression}"
)
crs = session.execute(text(metric.expression))
row = (
crs.scalar()
) # raise MultipleResultsFound if more than one row is returned
custom_metrics.append(
CustomMetricProfile(name=metric.name.__root__, value=row)
)
except Exception as exc:
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{metric.columnName}: {exc}"
logger.debug(traceback.format_exc())
logger.warning(msg)
if custom_metrics:
return {"customMetrics": custom_metrics}
return None
def _compute_system_metrics(
@ -376,14 +416,11 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
def compute_metrics_in_thread(
self,
metrics,
metric_type,
column,
table,
metric_func: ThreadPoolMetrics,
):
"""Run metrics in processor worker"""
logger.debug(
f"Running profiler for {table.__tablename__} on thread {threading.current_thread()}"
f"Running profiler for {metric_func.table.__tablename__} on thread {threading.current_thread()}"
)
Session = self.session_factory # pylint: disable=invalid-name
with Session() as session:
@ -391,36 +428,40 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
self.set_catalog(session)
sampler = self._create_thread_safe_sampler(
session,
table,
metric_func.table,
)
sample = sampler.random_sample()
runner = self._create_thread_safe_runner(
session,
table,
metric_func.table,
sample,
)
try:
row = self._get_metric_fn[metric_type.value](
metrics,
row = self._get_metric_fn[metric_func.metric_type.value](
metric_func.metrics,
runner=runner,
session=session,
column=column,
column=metric_func.column,
sample=sample,
)
except Exception as exc:
error = f"{column if column is not None else runner.table.__tablename__} metric_type.value: {exc}"
error = (
f"{metric_func.column if metric_func.column is not None else metric_func.table.__tablename__} "
f"metric_type.value: {exc}"
)
logger.error(error)
self.status.failed_profiler(error, traceback.format_exc())
row = None
if column is not None:
column = column.name
self.status.scanned(f"{table.__tablename__}.{column}")
if metric_func.column is not None:
column = metric_func.column.name
self.status.scanned(f"{metric_func.table.__tablename__}.{column}")
else:
self.status.scanned(table.__tablename__)
self.status.scanned(metric_func.table.__tablename__)
column = None
return row, column, metric_type.value
return row, column, metric_func.metric_type.value
# pylint: disable=use-dict-literal
def get_all_metrics(
@ -434,7 +475,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
futures = [
pool.submit(
self.compute_metrics_in_thread,
*metric_func,
metric_func,
)
for metric_func in metric_funcs
]
@ -455,6 +496,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
profile_results["table"].update(profile)
elif metric_type == MetricTypes.System.value:
profile_results["system"] = profile
elif metric_type == MetricTypes.Custom.value and column is None:
profile_results["table"].update(profile)
else:
profile_results["columns"][column].update(
{

View File

@ -72,6 +72,8 @@ class SingleStoreProfilerInterface(SQAProfilerInterface):
row = runner.select_first_from_sample(
*[metric(column).fn() for metric in metrics],
)
if row:
return dict(row)
except ProgrammingError:
logger.info(
f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to overflow"
@ -81,6 +83,4 @@ class SingleStoreProfilerInterface(SQAProfilerInterface):
except Exception as exc:
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
handle_query_exception(msg, exc, session)
if row:
return dict(row)
return None

View File

@ -67,6 +67,8 @@ class TrinoProfilerInterface(SQAProfilerInterface):
row = runner.select_first_from_sample(
*[metric(column).fn() for metric in metrics], **runner_kwargs
)
if row:
return dict(row)
except ProgrammingError as err:
logger.info(
f"Skipping window metrics for {runner.table.__tablename__}.{column.name} due to {err}"
@ -76,6 +78,4 @@ class TrinoProfilerInterface(SQAProfilerInterface):
except Exception as exc:
msg = f"Error trying to compute profile for {runner.table.__tablename__}.{column.name}: {exc}"
handle_query_exception(msg, exc, session)
if row:
return dict(row)
return None

View File

@ -90,6 +90,9 @@ class Max(StaticMetric):
def df_fn(self, dfs=None):
"""pandas function"""
if is_quantifiable(self.col.type) or is_date_time(self.col.type):
if is_quantifiable(self.col.type):
return max((df[self.col.name].max() for df in dfs))
if is_date_time(self.col.type):
max_ = max((df[self.col.name].max() for df in dfs))
return int(max_.timestamp() * 1000)
return 0

View File

@ -90,6 +90,9 @@ class Min(StaticMetric):
def df_fn(self, dfs=None):
"""pandas function"""
if is_quantifiable(self.col.type) or is_date_time(self.col.type):
if is_quantifiable(self.col.type):
return min((df[self.col.name].min() for df in dfs))
if is_date_time(self.col.type):
min_ = min((df[self.col.name].min() for df in dfs))
return int(min_.timestamp() * 1000)
return 0

View File

@ -33,7 +33,10 @@ from metadata.generated.schema.entity.data.table import (
TableData,
TableProfile,
)
from metadata.profiler.api.models import ProfilerResponse
from metadata.generated.schema.tests.customMetric import (
CustomMetric as CustomMetricEntity,
)
from metadata.profiler.api.models import ProfilerResponse, ThreadPoolMetrics
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.profiler.metrics.core import (
ComposedMetric,
@ -273,6 +276,33 @@ class Profiler(Generic[TMetric]):
return [metric for metric in metrics if metric.is_col_metric()]
def get_custom_metrics(
self, column_name: Optional[str] = None
) -> Optional[List[CustomMetricEntity]]:
"""Get custom metrics for a table or column
Args:
column (Optional[str]): optional column name. If None will fetch table level custom metrics
Returns:
List[str]
"""
if column_name is None:
return self.profiler_interface.table_entity.customMetrics or None
# if we have a column we'll get the custom metrics for this column
column = next(
(
clmn
for clmn in self.profiler_interface.table_entity.columns
if clmn.name.__root__ == column_name
),
None,
)
if column:
return column.customMetrics or None
return None
@property
def sample(self):
"""Return the sample used for the profiler"""
@ -345,22 +375,40 @@ class Profiler(Generic[TMetric]):
def _prepare_table_metrics(self) -> List:
"""prepare table metrics"""
metrics = []
table_metrics = [
metric
for metric in self.static_metrics
if (not metric.is_col_metric() and not metric.is_system_metrics())
]
custom_table_metrics = self.get_custom_metrics()
if table_metrics:
return [
(
table_metrics, # metric functions
MetricTypes.Table, # metric type for function mapping
None, # column name
self.table, # table name
),
]
return []
metrics.extend(
[
ThreadPoolMetrics(
metrics=table_metrics,
metric_type=MetricTypes.Table,
column=None,
table=self.table,
)
]
)
if custom_table_metrics:
metrics.extend(
[
ThreadPoolMetrics(
metrics=custom_table_metrics,
metric_type=MetricTypes.Custom,
column=None,
table=self.table,
)
]
)
return metrics
def _prepare_system_metrics(self) -> List:
"""prepare system metrics"""
@ -368,11 +416,11 @@ class Profiler(Generic[TMetric]):
if system_metrics:
return [
(
system_metric, # metric functions
MetricTypes.System, # metric type for function mapping
None, # column name
self.table, # table name
ThreadPoolMetrics(
metrics=system_metric, # metric functions
metric_type=MetricTypes.System, # metric type for function mapping
column=None, # column name
table=self.table, # table name
)
for system_metric in system_metrics
]
@ -386,45 +434,60 @@ class Profiler(Generic[TMetric]):
for column in self.columns
if column.type.__class__.__name__ not in NOT_COMPUTE
]
column_metrics_for_thread_pool = [
*[
(
[
metric
for metric in self.get_col_metrics(self.static_metrics, column)
if not metric.is_window_metric()
],
MetricTypes.Static,
column,
self.table,
)
for column in columns
],
*[
(
metric,
MetricTypes.Query,
column,
self.table,
)
for column in columns
for metric in self.get_col_metrics(self.query_metrics, column)
],
*[
(
[
metric
for metric in self.get_col_metrics(self.static_metrics, column)
if metric.is_window_metric()
],
MetricTypes.Window,
column,
self.table,
)
for column in columns
],
column_metrics_for_thread_pool = []
static_metrics = [
ThreadPoolMetrics(
metrics=[
metric
for metric in self.get_col_metrics(self.static_metrics, column)
if not metric.is_window_metric()
],
metric_type=MetricTypes.Static,
column=column,
table=self.table,
)
for column in columns
]
query_metrics = [
ThreadPoolMetrics(
metrics=metric,
metric_type=MetricTypes.Query,
column=column,
table=self.table,
)
for column in columns
for metric in self.get_col_metrics(self.query_metrics, column)
]
window_metrics = [
ThreadPoolMetrics(
metrics=[
metric
for metric in self.get_col_metrics(self.static_metrics, column)
if metric.is_window_metric()
],
metric_type=MetricTypes.Window,
column=column,
table=self.table,
)
for column in columns
]
# we'll add the system metrics to the thread pool computation
for metric_type in [static_metrics, query_metrics, window_metrics]:
column_metrics_for_thread_pool.extend(metric_type)
# we'll add the custom metrics to the thread pool computation
for column in columns:
custom_metrics = self.get_custom_metrics(column.name)
if custom_metrics:
column_metrics_for_thread_pool.append(
ThreadPoolMetrics(
metrics=custom_metrics,
metric_type=MetricTypes.Custom,
column=column,
table=self.table,
)
)
return column_metrics_for_thread_pool
@ -568,6 +631,7 @@ class Profiler(Generic[TMetric]):
profileSampleType=self.profile_sample_config.profile_sample_type
if self.profile_sample_config
else None,
customMetrics=self._table_results.get("customMetrics"),
)
if self._system_results:

View File

@ -31,6 +31,7 @@ from metadata.profiler.orm.functions.random_num import RandomNumFn
from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.handle_partition import partition_filter_handler
from metadata.profiler.processor.sampler.sampler_interface import SamplerInterface
from metadata.utils.helpers import is_safe_sql_query
from metadata.utils.logger import profiler_interface_registry_logger
from metadata.utils.sqa_utils import (
build_query_filter,
@ -171,6 +172,11 @@ class SQASampler(SamplerInterface):
def _fetch_sample_data_from_user_query(self) -> TableData:
"""Returns a table data object using results from query execution"""
if not is_safe_sql_query(self._profile_sample_query):
raise RuntimeError(
f"SQL expression is not safe\n\n{self._profile_sample_query}"
)
rnd = self.client.execute(f"{self._profile_sample_query}")
try:
columns = [col.name for col in rnd.cursor.description]
@ -183,6 +189,11 @@ class SQASampler(SamplerInterface):
def _rdn_sample_from_user_query(self) -> Query:
"""Returns sql alchemy object to use when running profiling"""
if not is_safe_sql_query(self._profile_sample_query):
raise RuntimeError(
f"SQL expression is not safe\n\n{self._profile_sample_query}"
)
return self.client.query(self.table).from_statement(
text(f"{self._profile_sample_query}")
)

View File

@ -13,7 +13,6 @@ Helper module to handle data sampling
for the profiler
"""
from sqlalchemy import inspect, or_, text
from trino.sqlalchemy.dialect import TrinoDialect
from metadata.profiler.orm.registry import FLOAT_SET
from metadata.profiler.processor.handle_partition import RANDOM_LABEL
@ -26,7 +25,13 @@ class TrinoSampler(SQASampler):
run the query in the whole table.
"""
TrinoDialect._json_deserializer = None # pylint: disable=protected-access
def __init__(self, *args, **kwargs):
# pylint: disable=import-outside-toplevel
from trino.sqlalchemy.dialect import TrinoDialect
TrinoDialect._json_deserializer = None
super().__init__(*args, **kwargs)
def _base_sample_query(self, label=None):
sqa_columns = [col for col in inspect(self.table).c if col.name != RANDOM_LABEL]

View File

@ -238,9 +238,7 @@ class OpenMetadataSource(Source):
"""
tables = self.metadata.list_all_entities(
entity=Table,
fields=[
"tableProfilerConfig",
],
fields=["tableProfilerConfig", "columns", "customMetrics"],
params={
"service": self.config.source.serviceName,
"database": fqn.build(

View File

@ -430,6 +430,9 @@ def is_safe_sql_query(sql_query: str) -> bool:
"SET TRANSACTION",
}
if sql_query is None:
return True
parsed_queries: Tuple[Statement] = sqlparse.parse(sql_query)
for parsed_query in parsed_queries:
validation = [

View File

@ -0,0 +1,4 @@
id,first_name,last_name,city,country,birthdate,age
1,John,Doe,Los Angeles,US,1980-01-01,40
2,Jane,Doe,Los Angeles,US,2000-12-31,39
3,Jane,Smith,Paris,,2001-11-11,28
1 id first_name last_name city country birthdate age
2 1 John Doe Los Angeles US 1980-01-01 40
3 2 Jane Doe Los Angeles US 2000-12-31 39
4 3 Jane Smith Paris 2001-11-11 28

View File

@ -0,0 +1,233 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test Metrics behavior
"""
# import datetime
import os
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
import boto3
import botocore
import pandas as pd
from moto import mock_s3
from metadata.generated.schema.entity.data.table import Column as EntityColumn
from metadata.generated.schema.entity.data.table import ColumnName, DataType, Table
from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import (
S3Config,
)
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.profiler.interface.pandas.profiler_interface import (
PandasProfilerInterface,
)
from metadata.profiler.processor.core import Profiler
BUCKET_NAME = "MyBucket"
@mock_s3
class MetricsTest(TestCase):
"""
Run checks on different metrics
"""
current_dir = os.path.dirname(__file__)
resources_dir = os.path.join(current_dir, "resources")
datalake_conn = DatalakeConnection(
configSource=S3Config(
securityConfig=AWSCredentials(
awsAccessKeyId="fake_access_key",
awsSecretAccessKey="fake_secret_key",
awsRegion="us-west-1",
)
)
)
dfs = [
pd.read_csv(os.path.join(resources_dir, "profiler_test_.csv"), parse_dates=[5])
]
table_entity = Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName(__root__="id"),
dataType=DataType.INT,
)
],
)
def setUp(self):
# Mock our S3 bucket and ingest a file
boto3.DEFAULT_SESSION = None
self.client = boto3.client(
"s3",
region_name="us-weat-1",
)
# check that we are not running our test against a real bucket
try:
s3 = boto3.resource(
"s3",
region_name="us-west-1",
aws_access_key_id="fake_access_key",
aws_secret_access_key="fake_secret_key",
)
s3.meta.client.head_bucket(Bucket=BUCKET_NAME)
except botocore.exceptions.ClientError:
pass
else:
err = f"{BUCKET_NAME} should not exist."
raise EnvironmentError(err)
self.client.create_bucket(
Bucket=BUCKET_NAME,
CreateBucketConfiguration={"LocationConstraint": "us-west-1"},
)
resources_paths = [
os.path.join(path, filename)
for path, _, files in os.walk(self.resources_dir)
for filename in files
]
self.s3_keys = []
for path in resources_paths:
key = os.path.relpath(path, self.resources_dir)
self.s3_keys.append(key)
self.client.upload_file(Filename=path, Bucket=BUCKET_NAME, Key=key)
with patch.object(
PandasProfilerInterface,
"_convert_table_to_list_of_dataframe_objects",
return_value=self.dfs,
):
self.sqa_profiler_interface = PandasProfilerInterface(
self.datalake_conn,
None,
self.table_entity,
None,
None,
None,
None,
None,
thread_count=1,
)
def test_table_custom_metric(self):
table_entity = Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName(__root__="id"),
dataType=DataType.INT,
)
],
customMetrics=[
CustomMetric(
name="LastNameFilter",
expression="'last_name' != Doe",
),
CustomMetric(
name="notUS",
expression="'country == US'",
),
],
)
with patch.object(
PandasProfilerInterface,
"_convert_table_to_list_of_dataframe_objects",
return_value=self.dfs,
):
self.sqa_profiler_interface = PandasProfilerInterface(
self.datalake_conn,
None,
table_entity,
None,
None,
None,
None,
None,
thread_count=1,
)
profiler = Profiler(
profiler_interface=self.sqa_profiler_interface,
)
metrics = profiler.compute_metrics()
for k, v in metrics._table_results.items():
for metric in v:
if metric.name == "LastNameFilter":
assert metric.value == 1
if metric.name == "notUS":
assert metric.value == 2
def test_column_custom_metric(self):
table_entity = Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName(__root__="id"),
dataType=DataType.INT,
customMetrics=[
CustomMetric(
name="LastNameFilter",
columnName="id",
expression="'last_name' != Doe",
),
CustomMetric(
name="notUS",
columnName="id",
expression="'country == US'",
),
],
)
],
)
with patch.object(
PandasProfilerInterface,
"_convert_table_to_list_of_dataframe_objects",
return_value=self.dfs,
):
self.sqa_profiler_interface = PandasProfilerInterface(
self.datalake_conn,
None,
table_entity,
None,
None,
None,
None,
None,
thread_count=1,
)
profiler = Profiler(
profiler_interface=self.sqa_profiler_interface,
)
metrics = profiler.compute_metrics()
for k, v in metrics._column_results.items():
for metric in v.get("customMetrics", []):
if metric.name == "CustomerBornedAfter1991":
assert metric.value == 1
if metric.name == "AverageAge":
assert metric.value == 2

View File

@ -269,5 +269,8 @@ class ProfilerTest(TestCase):
column_metrics = default_profiler._prepare_column_metrics()
for metric in column_metrics:
if metric[1] is not MetricTypes.Table and metric[2].name == "id":
assert all(metric_filter.count(m.name()) for m in metric[0])
if (
metric.metric_type is not MetricTypes.Table
and metric.column.name == "id"
):
assert all(metric_filter.count(m.name()) for m in metric.metrics)

View File

@ -32,6 +32,7 @@ from metadata.generated.schema.entity.data.table import (
Table,
TableProfile,
)
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.pandas.profiler_interface import (
PandasProfilerInterface,
)
@ -146,15 +147,15 @@ class PandasInterfaceTest(TestCase):
def test_get_all_metrics(self):
table_metrics = [
(
[
ThreadPoolMetrics(
metrics=[
metric
for metric in self.metrics
if (not metric.is_col_metric() and not metric.is_system_metrics())
],
MetricTypes.Table,
None,
self.table_entity,
metric_type=MetricTypes.Table,
column=None,
table=self.table_entity,
)
]
column_metrics = []
@ -164,36 +165,36 @@ class PandasInterfaceTest(TestCase):
if col.name == "id":
continue
column_metrics.append(
(
[
ThreadPoolMetrics(
metrics=[
metric
for metric in self.static_metrics
if metric.is_col_metric() and not metric.is_window_metric()
],
MetricTypes.Static,
col,
self.table_entity,
metric_type=MetricTypes.Static,
column=col,
table=self.table_entity,
)
)
for query_metric in self.query_metrics:
query_metrics.append(
(
query_metric,
MetricTypes.Query,
col,
self.table_entity,
ThreadPoolMetrics(
metrics=query_metric,
metric_type=MetricTypes.Query,
column=col,
table=self.table_entity,
)
)
window_metrics.append(
(
[
ThreadPoolMetrics(
metrics=[
metric
for metric in self.window_metrics
if metric.is_window_metric()
],
MetricTypes.Window,
col,
self.table_entity,
metric_type=MetricTypes.Window,
column=col,
table=self.table_entity,
)
)

View File

@ -27,6 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection,
SQLiteScheme,
)
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
)
@ -865,6 +866,102 @@ class MetricsTest(TestCase):
session = self.sqa_profiler_interface.session
system().sql(session)
def test_table_custom_metric(self):
table_entity = Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName(__root__="id"),
dataType=DataType.INT,
)
],
customMetrics=[
CustomMetric(
name="CustomerBornedAfter1991",
expression="SELECT COUNT(id) FROM users WHERE dob > '1991-01-01'",
),
CustomMetric(
name="AverageAge",
expression="SELECT SUM(age)/COUNT(*) FROM users",
),
],
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
self.sqa_profiler_interface = SQAProfilerInterface(
self.sqlite_conn,
None,
table_entity,
None,
None,
None,
None,
None,
thread_count=1,
)
profiler = Profiler(
profiler_interface=self.sqa_profiler_interface,
)
metrics = profiler.compute_metrics()
for k, v in metrics._table_results.items():
for metric in v:
if metric.name == "CustomerBornedAfter1991":
assert metric.value == 2
if metric.name == "AverageAge":
assert metric.value == 20.0
def test_column_custom_metric(self):
table_entity = Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName(__root__="id"),
dataType=DataType.INT,
customMetrics=[
CustomMetric(
name="CustomerBornedAfter1991",
columnName="id",
expression="SELECT SUM(id) FROM users WHERE dob > '1991-01-01'",
),
CustomMetric(
name="AverageAge",
columnName="id",
expression="SELECT SUM(age)/COUNT(*) FROM users",
),
],
)
],
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
):
self.sqa_profiler_interface = SQAProfilerInterface(
self.sqlite_conn,
None,
table_entity,
None,
None,
None,
None,
None,
thread_count=1,
)
profiler = Profiler(
profiler_interface=self.sqa_profiler_interface,
)
metrics = profiler.compute_metrics()
for k, v in metrics._column_results.items():
for metric in v.get("customMetrics", []):
if metric.name == "CustomerBornedAfter1991":
assert metric.value == 3.0
if metric.name == "AverageAge":
assert metric.value == 20.0
@classmethod
def tearDownClass(cls) -> None:
os.remove(cls.db_path)

View File

@ -42,6 +42,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection,
SQLiteScheme,
)
from metadata.generated.schema.tests.customMetric import CustomMetric
from metadata.ingestion.source import sqa_types
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
@ -83,8 +84,27 @@ class ProfilerTest(TestCase):
EntityColumn(
name=ColumnName(__root__="id"),
dataType=DataType.INT,
customMetrics=[
CustomMetric(
name="custom_metric",
description="custom metric",
expression="SELECT cos(id) FROM users",
)
],
)
],
customMetrics=[
CustomMetric(
name="custom_metric",
description="custom metric",
expression="SELECT COUNT(id) / COUNT(age) FROM users",
),
CustomMetric(
name="custom_metric_two",
description="custom metric",
expression="SELECT COUNT(id) * COUNT(age) FROM users",
),
],
)
with patch.object(
SQAProfilerInterface, "_convert_table_to_orm_object", return_value=User
@ -231,6 +251,30 @@ class ProfilerTest(TestCase):
)
)
def test__prepare_column_metrics(self):
"""test _prepare_column_metrics returns as expected"""
profiler = Profiler(
Metrics.FIRST_QUARTILE.value,
profiler_interface=self.sqa_profiler_interface,
)
metrics = profiler._prepare_column_metrics()
for metric in metrics:
if metric.metrics:
if isinstance(metric.metrics[0], CustomMetric):
assert metric.metrics[0].name.__root__ == "custom_metric"
else:
assert metric.metrics[0].name() == "firstQuartile"
def test__prepare_table_metrics(self):
"""test _prepare_table_metrics returns as expected"""
profiler = Profiler(
Metrics.COLUMN_COUNT.value,
profiler_interface=self.sqa_profiler_interface,
)
metrics = profiler._prepare_table_metrics()
self.assertEqual(2, len(metrics))
def test_profiler_with_timeout(self):
"""check timeout is properly used"""
@ -259,6 +303,7 @@ class ProfilerTest(TestCase):
def test_profiler_get_col_metrics(self):
"""check getc column metrics"""
metric_filter = ["mean", "min", "max", "firstQuartile"]
custom_metric_filter = ["custom_metric"]
self.sqa_profiler_interface.table_entity.tableProfilerConfig = (
TableProfilerConfig(
includeColumns=[
@ -273,8 +318,20 @@ class ProfilerTest(TestCase):
column_metrics = default_profiler._prepare_column_metrics()
for metric in column_metrics:
if metric[1] is not MetricTypes.Table and metric[2].name == "id":
assert all(metric_filter.count(m.name()) for m in metric[0])
if (
metric.metric_type is not MetricTypes.Table
and metric.column.name == "id"
):
assert all(
metric_filter.count(m.name())
for m in metric.metrics
if not isinstance(m, CustomMetric)
)
assert all(
custom_metric_filter.count(m.name.__root__)
for m in metric.metrics
if isinstance(m, CustomMetric)
)
@classmethod
def tearDownClass(cls) -> None:

View File

@ -38,6 +38,7 @@ from metadata.generated.schema.entity.services.connections.database.sqliteConnec
SQLiteConnection,
SQLiteScheme,
)
from metadata.profiler.api.models import ThreadPoolMetrics
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
)
@ -162,15 +163,15 @@ class SQAInterfaceTestMultiThread(TestCase):
def test_get_all_metrics(self):
table_metrics = [
(
[
ThreadPoolMetrics(
metrics=[
metric
for metric in self.metrics
if (not metric.is_col_metric() and not metric.is_system_metrics())
],
MetricTypes.Table,
None,
self.table,
metric_type=MetricTypes.Table,
column=None,
table=self.table,
)
]
column_metrics = []
@ -178,36 +179,36 @@ class SQAInterfaceTestMultiThread(TestCase):
window_metrics = []
for col in inspect(User).c:
column_metrics.append(
(
[
ThreadPoolMetrics(
metrics=[
metric
for metric in self.static_metrics
if metric.is_col_metric() and not metric.is_window_metric()
],
MetricTypes.Static,
col,
self.table,
metric_type=MetricTypes.Static,
column=col,
table=self.table,
)
)
for query_metric in self.query_metrics:
query_metrics.append(
(
query_metric,
MetricTypes.Query,
col,
self.table,
ThreadPoolMetrics(
metrics=query_metric,
metric_type=MetricTypes.Query,
column=col,
table=self.table,
)
)
window_metrics.append(
(
[
ThreadPoolMetrics(
metrics=[
metric
for metric in self.window_metrics
if metric.is_window_metric()
],
MetricTypes.Window,
col,
self.table,
metric_type=MetricTypes.Window,
column=col,
table=self.table,
)
)

View File

@ -26,5 +26,6 @@
}
]
}
}
},
"additionalProperties": false
}