Improve dimensionality performance (#24488)

* Fix Bigquery Dimensionality Issue + Refactor

* Remove comment

* Improve Dimensionality Code and Changed Median to use Approx_Quantile for Snowflake

* Remove commented method

* Improve statistical validator failed row count strategy
Authored by IceS2 on 2025-11-21 18:07:31 +01:00, committed by GitHub
parent 76dd0f910e
commit 06c7d82101
8 changed files with 135 additions and 83 deletions
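The recurring change in the validator diffs below: the failed-count entry is dropped from metric_expressions (where it re-embedded the statistical aggregate, so engines such as BigQuery evaluated MAX/MEAN/etc. twice) and replaced with a failed_count_builder callback that reuses the already-aggregated CTE column. A minimal sketch of the callback shape, assuming SQLAlchemy and a hypothetical MAX_BOUND check standing in for the project's build_agg_level_violation_sqa:

from sqlalchemy import case

MAX_BOUND = 100  # hypothetical bound; the real validators read it from test_params

# The builder receives the Level 1 CTE and its total-count column, and derives
# the failed count from the already-computed "MAX" column instead of
# re-evaluating func.max(col) inside the violation expression.
failed_count_builder = lambda cte, row_count_expr: case(
    (cte.c.MAX > MAX_BOUND, row_count_expr),  # whole dimension fails the check
    else_=0,
)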


@@ -17,7 +17,6 @@ from typing import List, Optional
 from sqlalchemy import Column
 from metadata.data_quality.validations.base_test_handler import (
-    DIMENSION_FAILED_COUNT_KEY,
     DIMENSION_TOTAL_COUNT_KEY,
 )
 from metadata.data_quality.validations.column.base.columnValueMaxToBeBetween import (
@@ -78,13 +77,16 @@ class ColumnValueMaxToBeBetweenValidator(
         metric_expressions = {
             DIMENSION_TOTAL_COUNT_KEY: row_count_expr,
             Metrics.MAX.name: max_expr,
-            DIMENSION_FAILED_COUNT_KEY: (
-                self._get_validation_checker(
-                    test_params
-                ).build_agg_level_violation_sqa([max_expr], row_count_expr)
-            ),
         }
+        failed_count_builder = (
+            lambda cte, row_count_expr: self._get_validation_checker(
+                test_params
+            ).build_agg_level_violation_sqa(
+                [getattr(cte.c, Metrics.MAX.name)], row_count_expr
+            )
+        )
         normalized_dimension = self._get_normalized_dimension_expression(
             dimension_col
         )
@@ -93,6 +95,7 @@ class ColumnValueMaxToBeBetweenValidator(
         result_rows = self._run_dimensional_validation_query(
             source=self.runner.dataset,
             dimension_expr=normalized_dimension,
             metric_expressions=metric_expressions,
+            failed_count_builder=failed_count_builder,
         )
         for row in result_rows:


@@ -18,7 +18,6 @@ from typing import List, Optional
 from sqlalchemy import Column
 from metadata.data_quality.validations.base_test_handler import (
-    DIMENSION_FAILED_COUNT_KEY,
     DIMENSION_TOTAL_COUNT_KEY,
 )
 from metadata.data_quality.validations.column.base.columnValueMeanToBeBetween import (
@@ -79,13 +78,16 @@ class ColumnValueMeanToBeBetweenValidator(
         metric_expressions = {
             DIMENSION_TOTAL_COUNT_KEY: row_count_expr,
             Metrics.MEAN.name: mean_expr,
-            DIMENSION_FAILED_COUNT_KEY: (
-                self._get_validation_checker(
-                    test_params
-                ).build_agg_level_violation_sqa([mean_expr], row_count_expr)
-            ),
         }
+        failed_count_builder = (
+            lambda cte, row_count_expr: self._get_validation_checker(
+                test_params
+            ).build_agg_level_violation_sqa(
+                [getattr(cte.c, Metrics.MEAN.name)], row_count_expr
+            )
+        )
         normalized_dimension = self._get_normalized_dimension_expression(
             dimension_col
         )
@@ -94,6 +96,7 @@ class ColumnValueMeanToBeBetweenValidator(
         result_rows = self._run_dimensional_validation_query(
             source=self.runner.dataset,
             dimension_expr=normalized_dimension,
             metric_expressions=metric_expressions,
+            failed_count_builder=failed_count_builder,
         )
         for row in result_rows:


@@ -18,7 +18,6 @@ from typing import List, Optional
 from sqlalchemy import Column, select
 from metadata.data_quality.validations.base_test_handler import (
-    DIMENSION_FAILED_COUNT_KEY,
     DIMENSION_TOTAL_COUNT_KEY,
     DIMENSION_VALUE_KEY,
 )
@@ -113,13 +112,16 @@ class ColumnValueMedianToBeBetweenValidator(
         metric_expressions = {
             DIMENSION_TOTAL_COUNT_KEY: row_count_expr,
             Metrics.MEDIAN.name: median_expr,
-            DIMENSION_FAILED_COUNT_KEY: (
-                self._get_validation_checker(
-                    test_params
-                ).build_agg_level_violation_sqa([median_expr], row_count_expr)
-            ),
         }
+        failed_count_builder = (
+            lambda cte, row_count_expr: self._get_validation_checker(
+                test_params
+            ).build_agg_level_violation_sqa(
+                [getattr(cte.c, Metrics.MEDIAN.name)], row_count_expr
+            )
+        )
         result_rows = self._run_dimensional_validation_query(
             source=normalized_dim_cte,
             dimension_expr=normalized_dim_col,
@@ -127,6 +129,7 @@ class ColumnValueMedianToBeBetweenValidator(
             others_metric_expressions_builder=self._get_others_metric_expressions_builder(
                 test_params
             ),
+            failed_count_builder=failed_count_builder,
         )
         for row in result_rows:
             median_value = row.get(Metrics.MEDIAN.name)
@@ -167,11 +170,6 @@ class ColumnValueMedianToBeBetweenValidator(
             return {
                 DIMENSION_TOTAL_COUNT_KEY: row_count_expr,
                 Metrics.MEDIAN.name: median_expr,
-                DIMENSION_FAILED_COUNT_KEY: (
-                    self._get_validation_checker(
-                        test_params
-                    ).build_agg_level_violation_sqa([median_expr], row_count_expr)
-                ),
             }
         return build_others_metric_expressions


@@ -17,7 +17,6 @@ from typing import List, Optional
 from sqlalchemy import Column, func
 from metadata.data_quality.validations.base_test_handler import (
-    DIMENSION_FAILED_COUNT_KEY,
     DIMENSION_TOTAL_COUNT_KEY,
 )
 from metadata.data_quality.validations.column.base.columnValueMinToBeBetween import (
@@ -78,13 +77,16 @@ class ColumnValueMinToBeBetweenValidator(
         metric_expressions = {
             DIMENSION_TOTAL_COUNT_KEY: func.count(),
             Metrics.MIN.name: min_expr,
-            DIMENSION_FAILED_COUNT_KEY: (
-                self._get_validation_checker(
-                    test_params
-                ).build_agg_level_violation_sqa([min_expr], row_count_expr)
-            ),
         }
+        failed_count_builder = (
+            lambda cte, row_count_expr: self._get_validation_checker(
+                test_params
+            ).build_agg_level_violation_sqa(
+                [getattr(cte.c, Metrics.MIN.name)], row_count_expr
+            )
+        )
         normalized_dimension = self._get_normalized_dimension_expression(
             dimension_col
         )
@@ -93,6 +95,7 @@ class ColumnValueMinToBeBetweenValidator(
         result_rows = self._run_dimensional_validation_query(
             source=self.runner.dataset,
             dimension_expr=normalized_dimension,
             metric_expressions=metric_expressions,
+            failed_count_builder=failed_count_builder,
         )
         for row in result_rows:


@@ -17,7 +17,6 @@ from typing import List, Optional
 from sqlalchemy import Column
 from metadata.data_quality.validations.base_test_handler import (
-    DIMENSION_FAILED_COUNT_KEY,
     DIMENSION_TOTAL_COUNT_KEY,
     DIMENSION_VALUE_KEY,
 )
@@ -86,13 +85,16 @@ class ColumnValueStdDevToBeBetweenValidator(
         metric_expressions = {
             DIMENSION_TOTAL_COUNT_KEY: row_count_expr,
             Metrics.STDDEV.name: stddev_expr,
-            DIMENSION_FAILED_COUNT_KEY: (
-                self._get_validation_checker(
-                    test_params
-                ).build_agg_level_violation_sqa([stddev_expr], row_count_expr)
-            ),
         }
+        failed_count_builder = (
+            lambda cte, row_count_expr: self._get_validation_checker(
+                test_params
+            ).build_agg_level_violation_sqa(
+                [getattr(cte.c, Metrics.STDDEV.name)], row_count_expr
+            )
+        )
         normalized_dimension = self._get_normalized_dimension_expression(
             dimension_col
         )
@@ -101,6 +103,7 @@ class ColumnValueStdDevToBeBetweenValidator(
         result_rows = self._run_dimensional_validation_query(
             source=self.runner.dataset,
             dimension_expr=normalized_dimension,
             metric_expressions=metric_expressions,
+            failed_count_builder=failed_count_builder,
         )
         for row in result_rows:


@@ -18,7 +18,6 @@ from typing import List, Optional
 from sqlalchemy import Column
 from metadata.data_quality.validations.base_test_handler import (
-    DIMENSION_FAILED_COUNT_KEY,
     DIMENSION_TOTAL_COUNT_KEY,
 )
 from metadata.data_quality.validations.column.base.columnValuesSumToBeBetween import (
@@ -80,13 +79,16 @@ class ColumnValuesSumToBeBetweenValidator(
         metric_expressions = {
             DIMENSION_TOTAL_COUNT_KEY: row_count_expr,
             Metrics.SUM.name: sum_expr,
-            DIMENSION_FAILED_COUNT_KEY: (
-                self._get_validation_checker(
-                    test_params
-                ).build_agg_level_violation_sqa([sum_expr], row_count_expr)
-            ),
         }
+        failed_count_builder = (
+            lambda cte, row_count_expr: self._get_validation_checker(
+                test_params
+            ).build_agg_level_violation_sqa(
+                [getattr(cte.c, Metrics.SUM.name)], row_count_expr
+            )
+        )
         normalized_dimension = self._get_normalized_dimension_expression(
             dimension_col
         )
@@ -95,6 +97,7 @@ class ColumnValuesSumToBeBetweenValidator(
         result_rows = self._run_dimensional_validation_query(
             source=self.runner.dataset,
             dimension_expr=normalized_dimension,
             metric_expressions=metric_expressions,
+            failed_count_builder=failed_count_builder,
         )
         for row in result_rows:


@@ -226,73 +226,103 @@ class SQAValidatorMixin:
         metric_expressions: Dict[str, ClauseElement],
         query_type: DataQualityQueryType,
         filter_clause: Optional[ColumnElement] = None,
+        failed_count_builder: Optional[Callable] = None,
     ):
         """Build SELECT query for dimensional metrics with impact scoring.

         This method constructs identical queries for both the top N dimensions
         and the "Others" aggregation, differing only in filter, grouping, and
        limit behavior.

         Args:
             source: CTE or table to select from (e.g., runner.dataset, value_counts_cte)
             dimension_expr: Normalized dimension expression
             metric_expressions: Dict mapping metric names to SQLAlchemy expressions.
                 Must include DIMENSION_TOTAL_COUNT_KEY, and DIMENSION_FAILED_COUNT_KEY
                 unless failed_count_builder is provided
             query_type: DIMENSIONAL groups by dimension (top N); OTHERS aggregates
                 all remaining rows into a single bucket
             filter_clause: Optional WHERE filter (e.g., dimension.notin_(top_n_values))
             failed_count_builder: Optional callable (cte, total_count_col) returning
                 the failed-count expression, evaluated against the Level 1 CTE columns

         Returns:
             Select: SQLAlchemy Select object (not executed)
         """
-        if DIMENSION_FAILED_COUNT_KEY not in metric_expressions:
-            raise ValueError(
-                f"metric_expressions must contain 'DIMENSION_FAILED_COUNT_KEY' key"
-            )
         if DIMENSION_TOTAL_COUNT_KEY not in metric_expressions:
             raise ValueError(
                 f"metric_expressions must contain 'DIMENSION_TOTAL_COUNT_KEY' key"
             )
+        if (
+            DIMENSION_FAILED_COUNT_KEY not in metric_expressions
+            and failed_count_builder is None
+        ):
+            raise ValueError(
+                f"metric_expressions must contain '{DIMENSION_FAILED_COUNT_KEY}' key"
+            )
-        select_columns = []
+        # === Level 1: Basic Metrics CTE ===
+        # Compute all metrics from metric_expressions
+        basic_metrics_columns = []
         for metric_name, metric_expr in metric_expressions.items():
-            select_columns.append(metric_expr.label(metric_name))
-        failed_count_expr = metric_expressions[DIMENSION_FAILED_COUNT_KEY]
-        total_count_expr = metric_expressions[DIMENSION_TOTAL_COUNT_KEY]
-        impact_score_expr = get_impact_score_expression(
-            failed_count_expr, total_count_expr
-        )
-        select_columns.append(impact_score_expr.label(DIMENSION_IMPACT_SCORE_KEY))
+            basic_metrics_columns.append(metric_expr.label(metric_name))
         match query_type:
             case DataQualityQueryType.DIMENSIONAL:
-                select_columns.append(dimension_expr.label(DIMENSION_VALUE_KEY))
+                basic_metrics_columns.append(dimension_expr.label(DIMENSION_VALUE_KEY))
             case DataQualityQueryType.OTHERS:
-                select_columns.append(
+                basic_metrics_columns.append(
                     literal(DIMENSION_OTHERS_LABEL).label(DIMENSION_VALUE_KEY)
                 )
-        query = select(select_columns).select_from(source)
-        if query_type == DataQualityQueryType.DIMENSIONAL:
-            query = query.group_by(dimension_expr)
-            query = query.order_by(impact_score_expr.desc(), dimension_expr.asc())
-            query = query.limit(DEFAULT_TOP_DIMENSIONS + 1)
+        query = select(basic_metrics_columns).select_from(source)
         if filter_clause is not None:
             query = query.where(filter_clause)
-        return query
+        if query_type == DataQualityQueryType.DIMENSIONAL:
+            query = query.group_by(dimension_expr)
+        basic_metrics_cte = query.cte("basic_metrics")
+
+        # === Level 2: Final Metrics CTE ===
+        # Compute derived metrics
+        final_metrics_columns = []
+        match query_type:
+            case DataQualityQueryType.DIMENSIONAL:
+                final_metrics_columns.append(
+                    getattr(basic_metrics_cte.c, DIMENSION_VALUE_KEY).label(
+                        DIMENSION_VALUE_KEY
+                    )
+                )
+            case DataQualityQueryType.OTHERS:
+                final_metrics_columns.append(
+                    literal(DIMENSION_OTHERS_LABEL).label(DIMENSION_VALUE_KEY)
+                )
+        for metric_name in metric_expressions.keys():
+            if metric_name != DIMENSION_FAILED_COUNT_KEY:
+                final_metrics_columns.append(
+                    getattr(basic_metrics_cte.c, metric_name).label(metric_name)
+                )
+        total_count_col = getattr(basic_metrics_cte.c, DIMENSION_TOTAL_COUNT_KEY)
+        failed_count_expr = (
+            failed_count_builder(basic_metrics_cte, total_count_col)
+            if failed_count_builder
+            else getattr(basic_metrics_cte.c, DIMENSION_FAILED_COUNT_KEY)
+        )
+        impact_score_expr = get_impact_score_expression(
+            failed_count_expr, total_count_col
+        )
+        final_metrics_columns.append(
+            failed_count_expr.label(DIMENSION_FAILED_COUNT_KEY)
+        )
+        final_metrics_columns.append(
+            impact_score_expr.label(DIMENSION_IMPACT_SCORE_KEY)
+        )
+        final_metrics_cte = select(final_metrics_columns).cte("final_metrics")
+        final_query = select([final_metrics_cte])
+        if query_type == DataQualityQueryType.DIMENSIONAL:
+            final_query = final_query.order_by(
+                getattr(final_metrics_cte.c, DIMENSION_IMPACT_SCORE_KEY).desc(),
+                getattr(final_metrics_cte.c, DIMENSION_VALUE_KEY).asc(),
+            ).limit(DEFAULT_TOP_DIMENSIONS + 1)
+        return final_query

     def _run_dimensional_validation_query(
         self: HasValidatorContext,
         source: FromClause,
         dimension_expr: ColumnElement,
         metric_expressions: Dict[str, ClauseElement],
+        failed_count_builder: Optional[Callable] = None,
         others_source_builder: Optional[Callable[[List[str]], FromClause]] = None,
         others_metric_expressions_builder: Optional[
             Callable[[FromClause], Dict[str, ClauseElement]]
@@ -326,6 +356,7 @@ class SQAValidatorMixin:
             dimension_expr=dimension_expr,
             metric_expressions=metric_expressions,
             query_type=DataQualityQueryType.DIMENSIONAL,
+            failed_count_builder=failed_count_builder,
         )

         top_n_plus_one_results = self.runner.session.execute(
@@ -359,6 +390,7 @@ class SQAValidatorMixin:
             dimension_expr=dimension_expr,  # Only used for grouping, not SELECT
             metric_expressions=others_metrics,
             query_type=DataQualityQueryType.OTHERS,
+            failed_count_builder=failed_count_builder,
             filter_clause=others_filter,
         )

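For reference, the rewritten _build_dimensional_query above emits a two-level statement: a basic_metrics CTE computes every aggregate once per dimension, and a final_metrics CTE derives the failed count and impact score from those columns. A runnable SQLAlchemy sketch of that shape, with placeholder table, columns, and bound; the division stands in for get_impact_score_expression and limit(11) for DEFAULT_TOP_DIMENSIONS + 1:

from sqlalchemy import Column, Integer, MetaData, Table, case, func, select

orders = Table(
    "orders", MetaData(), Column("region", Integer), Column("amount", Integer)
)

# Level 1: a single pass over the source computes all metrics per dimension.
basic = (
    select(
        orders.c.region.label("dimension"),
        func.max(orders.c.amount).label("MAX"),
        func.count().label("rowCount"),
    )
    .group_by(orders.c.region)
    .cte("basic_metrics")
)

# Level 2: derived metrics reuse the CTE columns, so MAX(amount) is evaluated
# once rather than once per derived expression.
failed_expr = case((basic.c.MAX > 100, basic.c.rowCount), else_=0)
impact_expr = failed_expr / func.nullif(basic.c.rowCount, 0)
final = select(
    basic.c.dimension,
    basic.c.MAX,
    basic.c.rowCount,
    failed_expr.label("failedCount"),
    impact_expr.label("impactScore"),
).cte("final_metrics")

# Top N + 1 rows ordered by impact score, then dimension value.
query = (
    select(final)
    .order_by(final.c.impactScore.desc(), final.c.dimension.asc())
    .limit(11)
)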

@@ -43,6 +43,13 @@ def _(elements, compiler, **kwargs):
     return MedianFn.default_fn(elements, compiler, **kwargs)

+@compiles(MedianFn, Dialects.Snowflake)
+def _(elements, compiler, **kwargs):
+    col = compiler.process(elements.clauses.clauses[0])
+    percentile = elements.clauses.clauses[2].value
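+    # Snowflake: APPROX_PERCENTILE (t-Digest based) keeps the dimensional
+    # median a cheap single-pass aggregate; for the median metric the
+    # percentile clause is 0.5, so this renders as approx_percentile(<col>, 0.5).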
+    return "approx_percentile(%s, %s)" % (col, percentile)
+
 @compiles(MedianFn, Dialects.BigQuery)
 def _(elements, compiler, **kwargs):
     col, _, percentile = [