Feature/dimensionality column values stddev to be between (#24235)

* Initial implementation for Dimensionality on Data Quality Tests

* Fix ColumnValuesToBeUnique and create TestCaseResult API

* Refactor dimension result

* Initial E2E Implementation without Impact Score

* Dimensionality Thin Slice

* Update generated TypeScript types

* Update generated TypeScript types

* Remove redundant method in favor of the existing one

* Fix Pandas Dimensionality checks

* Remove useless comments

* Implement PR comments, fix Tests

* Improve the code a bit

* Fix imports

* Implement Dimensionality for ColumnMeanToBeBetween

* Removed useless comments and improved minor things

* Implement UnitTests

* Fixes

* Moved import pandas to type checking

* Fix Min/Max being optional

* Fix Unittests

* Small fixes

* Fix Unittests

* Fix Issue with counting total rows on mean

* Improve code

* Fix Merge

* Removed unused type

* Refactor to reduce code repetition and complexity

* Fix conflict

* Rename method

* Refactor some metrics

* Implement Dimensionality to ColumnLengthToBeBetween

* Implement Dimensionality for ColumnMedianToBeBetween in Pandas

* Implement Median Dimensionality for SQL

* Add database tests

* Fix median metric

* Implement Dimensionality SumToBeBetween

* Implement dimensionality for Column Values Not In Set

* Implement Dimensionality for ColumnValuesToMatchRegex and ColumnValuesToNotMatchRegex

* Implement NotNull and MissingCount dimensionality

* Implement columnValuesToBeBetween dimensionality

* Fix test

* Implement Pandas Dimensionality for ColumnValueStdDevToBeBetween

* Implement Dimensionality for ColumnValuesStdDevToBeBetween

* Fix tests now that SQLite supports stddev

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
IceS2 committed on 2025-11-10 12:13:27 +01:00 (committed by GitHub)
parent c56edc3df1
commit dddec06143
12 changed files with 886 additions and 39 deletions

View File

@@ -15,11 +15,19 @@ Validator for column value stddev to be between test case
import traceback
from abc import abstractmethod
from typing import Union
from typing import List, Optional, Union
from sqlalchemy import Column
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.base_test_handler import (
BaseTestValidator,
DimensionInfo,
DimensionResult,
TestEvaluation,
)
from metadata.data_quality.validations.checkers.between_bounds_checker import (
BetweenBoundsChecker,
)
from metadata.generated.schema.tests.basic import (
TestCaseResult,
TestCaseStatus,
@@ -37,6 +45,9 @@ STDDEV = "stddev"
class BaseColumnValueStdDevToBeBetweenValidator(BaseTestValidator):
"""Validator for column value stddev to be between test case"""
MIN_BOUND = "minValueForStdDevInCol"
MAX_BOUND = "maxValueForStdDevInCol"
def _run_validation(self) -> TestCaseResult:
"""Execute the specific test validation logic
@@ -46,9 +57,16 @@ class BaseColumnValueStdDevToBeBetweenValidator(BaseTestValidator):
Returns:
TestCaseResult: The test case result for the overall validation
"""
test_params = self._get_test_parameters()
try:
column: Union[SQALikeColumn, Column] = self.get_column()
res = self._run_results(Metrics.STDDEV, column)
stddev_value = self._run_results(Metrics.STDDEV, column)
metric_values = {
Metrics.STDDEV.name: stddev_value,
}
except (ValueError, RuntimeError) as exc:
msg = f"Error computing {self.test_case.fullyQualifiedName}: {exc}" # type: ignore
logger.debug(traceback.format_exc())
@@ -60,18 +78,162 @@ class BaseColumnValueStdDevToBeBetweenValidator(BaseTestValidator):
[TestResultValue(name=STDDEV, value=None)],
)
min_bound = self.get_min_bound("minValueForStdDevInCol")
max_bound = self.get_max_bound("maxValueForStdDevInCol")
evaluation = self._evaluate_test_condition(metric_values, test_params)
result_message = self._format_result_message(
metric_values, test_params=test_params
)
test_result_values = self._get_test_result_values(metric_values)
return self.get_test_case_result_object(
self.execution_date,
self.get_test_case_status(min_bound <= res <= max_bound),
f"Found stddev={res} vs. the expected min={min_bound}, max={max_bound}.",
[TestResultValue(name=STDDEV, value=str(res))],
min_bound=min_bound,
max_bound=max_bound,
self.get_test_case_status(evaluation["matched"]),
result_message,
test_result_values,
min_bound=test_params[self.MIN_BOUND],
max_bound=test_params[self.MAX_BOUND],
)
def _get_validation_checker(self, test_params: dict) -> BetweenBoundsChecker:
"""Get validation checker for this test
Args:
test_params: Test parameters including min and max bounds
Returns:
BetweenBoundsChecker: Checker instance configured with bounds
"""
return BetweenBoundsChecker(
min_bound=test_params[self.MIN_BOUND],
max_bound=test_params[self.MAX_BOUND],
)
def _get_test_parameters(self) -> dict:
"""Get test parameters for this validator
Returns:
dict: Test parameters including min and max bounds
"""
return {
self.MIN_BOUND: self.get_min_bound(self.MIN_BOUND),
self.MAX_BOUND: self.get_max_bound(self.MAX_BOUND),
}
def _get_metrics_to_compute(self, test_params: Optional[dict] = None) -> dict:
"""Get metrics that need to be computed for this test
Args:
test_params: Optional test parameters (unused for stddev validator)
Returns:
dict: Dictionary mapping metric names to Metrics enum values
"""
return {
Metrics.STDDEV.name: Metrics.STDDEV,
}
def _evaluate_test_condition(
self, metric_values: dict, test_params: dict
) -> TestEvaluation:
"""Evaluate the stddev-to-be-between test condition
For stddev test, the condition passes if the stddev value is within the specified bounds.
Since this is a statistical validator (group-level), passed/failed row counts are not applicable.
Args:
metric_values: Dictionary with keys from Metrics enum names
e.g., {"STDDEV": 15.5}
test_params: Dictionary with 'minValueForStdDevInCol' and 'maxValueForStdDevInCol'
Returns:
dict with keys:
- matched: bool - whether test passed (stddev within bounds)
- passed_rows: None - not applicable for statistical validators
- failed_rows: None - not applicable for statistical validators
- total_rows: None - not applicable for statistical validators
"""
stddev_value = metric_values[Metrics.STDDEV.name]
min_bound = test_params[self.MIN_BOUND]
max_bound = test_params[self.MAX_BOUND]
matched = min_bound <= stddev_value <= max_bound
return {
"matched": matched,
"passed_rows": None,
"failed_rows": None,
"total_rows": None,
}
def _format_result_message(
self,
metric_values: dict,
dimension_info: Optional[DimensionInfo] = None,
test_params: Optional[dict] = None,
) -> str:
"""Format the result message for stddev-to-be-between test
Args:
metric_values: Dictionary with Metrics enum names as keys
dimension_info: Optional DimensionInfo with dimension details
test_params: Test parameters with min/max bounds. Required for this test.
Returns:
str: Formatted result message
"""
if test_params is None:
raise ValueError(
"test_params is required for columnValueStdDevToBeBetween._format_result_message"
)
stddev_value = metric_values[Metrics.STDDEV.name]
min_bound = test_params[self.MIN_BOUND]
max_bound = test_params[self.MAX_BOUND]
if dimension_info:
return (
f"Dimension {dimension_info['dimension_name']}={dimension_info['dimension_value']}: "
f"Found stddev={stddev_value} vs. the expected min={min_bound}, max={max_bound}"
)
else:
return f"Found stddev={stddev_value} vs. the expected min={min_bound}, max={max_bound}."
def _get_test_result_values(self, metric_values: dict) -> List[TestResultValue]:
"""Get test result values for stddev-to-be-between test
Args:
metric_values: Dictionary with Metrics enum names as keys
Returns:
List[TestResultValue]: Test result values for the test case
"""
return [
TestResultValue(
name=STDDEV,
value=str(metric_values[Metrics.STDDEV.name]),
),
]
@abstractmethod
def _run_results(self, metric: Metrics, column: Union[SQALikeColumn, Column]):
raise NotImplementedError
@abstractmethod
def _execute_dimensional_validation(
self,
column: Union[SQALikeColumn, Column],
dimension_col: Union[SQALikeColumn, Column],
metrics_to_compute: dict,
test_params: dict,
) -> List[DimensionResult]:
"""Execute dimensional validation query for a single dimension column
Args:
column: The column being tested (e.g., revenue)
dimension_col: The dimension column to group by (e.g., region)
metrics_to_compute: Dict mapping metric names to Metrics enum values
test_params: Test parameters including min and max bounds
Returns:
List of DimensionResult objects for each dimension value
"""
raise NotImplementedError
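
The base validator above follows a template-method pattern: _run_validation composes _get_test_parameters, _get_metrics_to_compute, _evaluate_test_condition, and _format_result_message, while subclasses only supply metric execution. A minimal standalone sketch of the evaluation contract (the helper name evaluate_stddev_between and its simplified signature with default unbounded limits are illustrative assumptions; the real code reads bounds from test case parameters):

from typing import Optional, TypedDict


class TestEvaluation(TypedDict):
    matched: bool
    passed_rows: Optional[int]
    failed_rows: Optional[int]
    total_rows: Optional[int]


def evaluate_stddev_between(
    stddev: float,
    min_bound: float = float("-inf"),
    max_bound: float = float("inf"),
) -> TestEvaluation:
    # Group-level statistical check: per-row pass/fail counts do not apply.
    return {
        "matched": min_bound <= stddev <= max_bound,
        "passed_rows": None,
        "failed_rows": None,
        "total_rows": None,
    }


assert evaluate_stddev_between(15.5, min_bound=10.0, max_bound=20.0)["matched"]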

View File

@@ -13,17 +13,37 @@
Validator for column value stddev to be between test case
"""
from typing import Optional
from collections import defaultdict
from typing import List, Optional, cast
import pandas as pd
from metadata.data_quality.validations.base_test_handler import (
DIMENSION_FAILED_COUNT_KEY,
DIMENSION_TOTAL_COUNT_KEY,
DIMENSION_VALUE_KEY,
)
from metadata.data_quality.validations.column.base.columnValueStdDevToBeBetween import (
BaseColumnValueStdDevToBeBetweenValidator,
)
from metadata.data_quality.validations.impact_score import (
DEFAULT_TOP_DIMENSIONS,
calculate_impact_score_pandas,
)
from metadata.data_quality.validations.mixins.pandas_validator_mixin import (
PandasValidatorMixin,
aggregate_others_statistical_pandas,
)
from metadata.generated.schema.tests.dimensionResult import DimensionResult
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.metrics.static.stddev import StdDev, SumSumSquaresCount
from metadata.utils.logger import test_suite_logger
from metadata.utils.sqa_like_column import SQALikeColumn
logger = test_suite_logger()
SUM_SQUARES_KEY = "SUM_SQUARES"
class ColumnValueStdDevToBeBetweenValidator(
BaseColumnValueStdDevToBeBetweenValidator, PandasValidatorMixin
@@ -38,3 +58,202 @@ class ColumnValueStdDevToBeBetweenValidator(
column: column
"""
return self.run_dataframe_results(self.runner, metric, column)
def _execute_dimensional_validation(
self,
column: SQALikeColumn,
dimension_col: SQALikeColumn,
metrics_to_compute: dict,
test_params: dict,
) -> List[DimensionResult]:
"""Execute dimensional validation for stddev with proper weighted aggregation
Follows the iterate pattern from the StdDev metric's df_fn method to handle
multiple dataframes efficiently without concatenating them in memory.
Memory-efficient approach: Instead of concatenating all dataframes (which creates
a full copy in memory), we iterate over them and accumulate aggregates. This is
especially important for large parquet files split across many chunks.
For statistical validators like StdDev, we need special handling:
1. Iterate over all dataframes and accumulate sum/sum_squares/counts per dimension
2. Compute weighted stddev across dataframes for each dimension
3. Determine if stddev is within bounds (all rows pass/fail together)
4. For "Others": recompute weighted stddev from aggregated sum/sum_squares/counts
Args:
column: The column being validated
dimension_col: The dimension column to group by
metrics_to_compute: Dict mapping metric names to Metrics enums
test_params: Test parameters (min/max bounds)
Returns:
List[DimensionResult]: Top N dimensions plus "Others"
"""
checker = self._get_validation_checker(test_params)
dimension_results = []
try:
dfs = self.runner if isinstance(self.runner, list) else [self.runner]
stddev_impl = Metrics.STDDEV(column).get_pandas_computation()
row_count_impl = Metrics.ROW_COUNT().get_pandas_computation()
dimension_aggregates = defaultdict(
lambda: {
Metrics.STDDEV.name: stddev_impl.create_accumulator(),
DIMENSION_TOTAL_COUNT_KEY: row_count_impl.create_accumulator(),
}
)
for df in dfs:
df_typed = cast(pd.DataFrame, df)
grouped = df_typed.groupby(dimension_col.name, dropna=False)
for dimension_value, group_df in grouped:
dimension_value = self.format_dimension_value(dimension_value)
dimension_aggregates[dimension_value][
Metrics.STDDEV.name
] = stddev_impl.update_accumulator(
dimension_aggregates[dimension_value][Metrics.STDDEV.name],
group_df,
)
dimension_aggregates[dimension_value][
DIMENSION_TOTAL_COUNT_KEY
] += row_count_impl.update_accumulator(
dimension_aggregates[dimension_value][
DIMENSION_TOTAL_COUNT_KEY
],
group_df,
)
results_data = []
for dimension_value, agg in dimension_aggregates.items():
stddev_value = stddev_impl.aggregate_accumulator(
agg[Metrics.STDDEV.name]
)
total_rows = row_count_impl.aggregate_accumulator(
agg[DIMENSION_TOTAL_COUNT_KEY]
)
if stddev_value is None:
logger.warning(
"Skipping '%s=%s' dimension since 'stddev' is 'None'",
dimension_col.name,
dimension_value,
)
continue
# Statistical validator: when stddev fails, ALL rows in dimension fail
failed_count = (
total_rows
if checker.violates_pandas({Metrics.STDDEV.name: stddev_value})
else 0
)
results_data.append(
{
DIMENSION_VALUE_KEY: dimension_value,
Metrics.STDDEV.name: stddev_value,
Metrics.COUNT.name: agg[Metrics.STDDEV.name].count_value,
Metrics.SUM.name: agg[Metrics.STDDEV.name].sum_value,
SUM_SQUARES_KEY: agg[Metrics.STDDEV.name].sum_squares_value,
DIMENSION_TOTAL_COUNT_KEY: total_rows,
DIMENSION_FAILED_COUNT_KEY: failed_count,
}
)
results_df = pd.DataFrame(results_data)
if not results_df.empty:
results_df = calculate_impact_score_pandas(
results_df,
failed_column=DIMENSION_FAILED_COUNT_KEY,
total_column=DIMENSION_TOTAL_COUNT_KEY,
)
def calculate_weighted_stddev(
df_aggregated, others_mask, metric_column
):
"""Calculate weighted stddev for Others using StdDev accumulator
For "Others" group, we recompute stddev from aggregated statistics
by constructing an accumulator and using the exact same aggregation
logic as the StdDev metric (ensuring consistency and DRY principle).
For top N dimensions, we use the pre-computed stddev.
"""
result = df_aggregated[metric_column].copy()
if others_mask.any():
others_sum = df_aggregated.loc[
others_mask, Metrics.SUM.name
].iloc[0]
others_count = df_aggregated.loc[
others_mask, Metrics.COUNT.name
].iloc[0]
others_sum_squares = df_aggregated.loc[
others_mask, SUM_SQUARES_KEY
].iloc[0]
accumulator = SumSumSquaresCount(
sum_value=others_sum,
sum_squares_value=others_sum_squares,
count_value=others_count,
)
others_stddev = StdDev.aggregate_accumulator(accumulator)
if others_stddev is not None:
result.loc[others_mask] = others_stddev
return result
results_df = aggregate_others_statistical_pandas(
results_df,
dimension_column=DIMENSION_VALUE_KEY,
agg_functions={
Metrics.SUM.name: "sum",
Metrics.COUNT.name: "sum",
SUM_SQUARES_KEY: "sum",
DIMENSION_TOTAL_COUNT_KEY: "sum",
DIMENSION_FAILED_COUNT_KEY: "sum",
},
final_metric_calculators={
Metrics.STDDEV.name: calculate_weighted_stddev
},
exclude_from_final=[
Metrics.SUM.name,
Metrics.COUNT.name,
SUM_SQUARES_KEY,
],
top_n=DEFAULT_TOP_DIMENSIONS,
violation_metrics=[Metrics.STDDEV.name],
violation_predicate=checker.violates_pandas,
)
for row_dict in results_df.to_dict("records"):
metric_values = self._build_metric_values_from_row(
row_dict, metrics_to_compute, test_params
)
evaluation = self._evaluate_test_condition(
metric_values, test_params
)
dimension_result = self._create_dimension_result(
row_dict,
dimension_col.name,
metric_values,
evaluation,
test_params,
)
dimension_results.append(dimension_result)
except Exception as exc:
logger.warning(f"Error executing dimensional query: {exc}")
logger.debug("Full error details: ", exc_info=True)
return dimension_results
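
To make the accumulation pattern above concrete, here is a self-contained pandas sketch that groups several chunks by a dimension and pools sum / sum-of-squares / count per group, taking a single square root only at the end (toy column and dimension names; the real implementation routes this through PandasComputation and SumSumSquaresCount):

import math
from collections import defaultdict

import pandas as pd

chunks = [
    pd.DataFrame({"region": ["eu", "us"], "revenue": [10.0, 20.0]}),
    pd.DataFrame({"region": ["eu", "us"], "revenue": [30.0, 40.0]}),
]

# Per-dimension running aggregates: [sum, sum of squares, count]
acc = defaultdict(lambda: [0.0, 0.0, 0])
for df in chunks:
    for region, group in df.groupby("region", dropna=False):
        values = pd.to_numeric(group["revenue"], errors="coerce").dropna()
        acc[region][0] += values.sum()
        acc[region][1] += (values**2).sum()
        acc[region][2] += len(values)

for region, (s, sq, n) in acc.items():
    variance = sq / n - (s / n) ** 2  # population variance
    print(region, math.sqrt(max(variance, 0.0)))  # eu 10.0, us 10.0

Because the aggregates are additive, the same three numbers can be summed across groups to recompute an exact pooled stddev for "Others", which is what the weighted aggregation above relies on.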

View File

@@ -10,20 +10,36 @@
# limitations under the License.
"""
Validator for column value stddevv to be between test case
Validator for column value stddev to be between test case
"""
from typing import Optional
from typing import Any, Dict, List, Optional
from sqlalchemy import Column
from sqlalchemy import Column, case, func, select
from metadata.data_quality.validations.base_test_handler import (
DIMENSION_FAILED_COUNT_KEY,
DIMENSION_IMPACT_SCORE_KEY,
DIMENSION_OTHERS_LABEL,
DIMENSION_TOTAL_COUNT_KEY,
DIMENSION_VALUE_KEY,
)
from metadata.data_quality.validations.column.base.columnValueStdDevToBeBetween import (
BaseColumnValueStdDevToBeBetweenValidator,
)
from metadata.data_quality.validations.impact_score import (
DEFAULT_TOP_DIMENSIONS,
get_impact_score_expression,
)
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
DIMENSION_GROUP_LABEL,
SQAValidatorMixin,
)
from metadata.generated.schema.tests.dimensionResult import DimensionResult
from metadata.profiler.metrics.registry import Metrics
from metadata.utils.logger import test_suite_logger
logger = test_suite_logger()
class ColumnValueStdDevToBeBetweenValidator(
@@ -39,3 +55,232 @@ class ColumnValueStdDevToBeBetweenValidator(
column: column
"""
return self.run_query_results(self.runner, metric, column)
def _execute_dimensional_validation(
self,
column: Column,
dimension_col: Column,
metrics_to_compute: dict,
test_params: dict,
) -> List[DimensionResult]:
"""Execute dimensional validation for stddev using two-pass approach
Two-pass query strategy for accurate "Others" stddev:
Pass 1: Compute stddev for top N dimensions using CTE-based aggregation
Returns "Others" row with stddev=None (cannot aggregate stddevs)
Pass 2: Recompute stddev for "Others" from raw table data
Query: SELECT STDDEV(column) WHERE dimension NOT IN (top_N_values)
Uses native database STDDEV function (works across all dialects)
This approach ensures mathematical accuracy and dialect compatibility.
Args:
column: The column being validated
dimension_col: The dimension column to group by
metrics_to_compute: Dict mapping metric names to Metrics enums
test_params: Test parameters (min/max bounds)
Returns:
List[DimensionResult]: Top N dimensions plus "Others" with accurate stddev
"""
dimension_results = []
try:
# ==================== PASS 1: Top N Dimensions ====================
metric_expressions = {
DIMENSION_TOTAL_COUNT_KEY: Metrics.ROW_COUNT().fn(),
Metrics.STDDEV.name: Metrics.STDDEV(column).fn(),
}
def build_stddev_final(cte):
"""For top N: use pre-computed stddev. For Others: return None."""
return case(
[
(
getattr(cte.c, DIMENSION_GROUP_LABEL)
!= DIMENSION_OTHERS_LABEL,
func.max(getattr(cte.c, Metrics.STDDEV.name)),
)
],
else_=None,
)
failed_count_builder = self._get_validation_checker(
test_params
).get_sqa_failed_rows_builder(
{Metrics.STDDEV.name: Metrics.STDDEV.name},
DIMENSION_TOTAL_COUNT_KEY,
)
result_rows = self._execute_with_others_aggregation_statistical(
dimension_col,
metric_expressions,
failed_count_builder,
final_metric_builders={
Metrics.STDDEV.name: build_stddev_final,
},
top_dimensions_count=DEFAULT_TOP_DIMENSIONS,
)
# ==================== PASS 2: Recompute "Others" Stddev ====================
# Convert immutable RowMapping objects to mutable dicts
result_rows = [dict(row) for row in result_rows]
# Separate top N dimensions from "Others" row
top_n_rows = [
row
for row in result_rows
if row[DIMENSION_VALUE_KEY] != DIMENSION_OTHERS_LABEL
]
has_others = len(top_n_rows) < len(result_rows)
# Recompute "Others" only if it existed in Pass 1
if has_others:
if recomputed_others := self._compute_others_stddev(
column,
dimension_col,
failed_count_builder,
top_n_rows,
):
result_rows = top_n_rows + [recomputed_others]
else:
result_rows = top_n_rows
else:
result_rows = top_n_rows
# ==================== Process Results ====================
for row in result_rows:
stddev_value = row.get(Metrics.STDDEV.name)
if stddev_value is None:
logger.debug(
"Skipping dimension '%s=%s' with None stddev",
dimension_col.name,
row.get(DIMENSION_VALUE_KEY),
)
continue
metric_values = {
Metrics.STDDEV.name: stddev_value,
}
evaluation = self._evaluate_test_condition(metric_values, test_params)
dimension_result = self._create_dimension_result(
row,
dimension_col.name,
metric_values,
evaluation,
test_params,
)
dimension_results.append(dimension_result)
except Exception as exc:
logger.warning(f"Error executing dimensional query: {exc}")
logger.debug("Full error details: ", exc_info=True)
return dimension_results
def _compute_others_stddev(
self,
column: Column,
dimension_col: Column,
failed_count_builder,
top_dimension_values: List[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""Recompute stddev and metrics for "Others" dimension group.
Uses two-pass approach: Pass 1 computed top N dimensions, this computes
"Others" by rerunning stddev on all rows NOT in top N dimensions.
Args:
column: The column being validated
dimension_col: The dimension column to group by
failed_count_builder: SQL expression builder for failed count (from checker)
top_dimension_values: Results from Pass 1 without the "Others" row (top N dimensions only)
Returns:
New "Others" row dict with recomputed metrics, or None if computation failed
"""
# top_dimension_values comes from Pass 1 and no longer contains "Others"
# If no top dimensions to exclude, cannot compute "Others"
if not top_dimension_values:
return None
try:
# Compute stddev directly on base table with WHERE filter
stddev_expr = Metrics.STDDEV(column).fn()
total_count_expr = Metrics.ROW_COUNT().fn()
# Create stats subquery with WHERE filter for "Others" group
# Query: SELECT STDDEV(col), COUNT(*) FROM table WHERE dimension NOT IN (top_N)
# Extract just the dimension values from the top N result rows
top_dimension_value_list = [
row[DIMENSION_VALUE_KEY] for row in top_dimension_values
]
stats_subquery = (
select(
[
stddev_expr.label(Metrics.STDDEV.name),
total_count_expr.label(DIMENSION_TOTAL_COUNT_KEY),
]
)
.select_from(self.runner.dataset)
.where(dimension_col.notin_(top_dimension_value_list))
).alias("others_stats")
# Apply failed_count builder to stats subquery (reused from Pass 1)
failed_count_expr = failed_count_builder(stats_subquery)
# Calculate impact score in SQL (same expression as Pass 1)
total_count_col = getattr(stats_subquery.c, DIMENSION_TOTAL_COUNT_KEY)
impact_score_expr = get_impact_score_expression(
failed_count_expr, total_count_col
)
# Final query: stddev, total_count, failed_count, impact_score
# All computed in SQL just like Pass 1
others_query = select(
[
getattr(stats_subquery.c, Metrics.STDDEV.name),
total_count_col,
failed_count_expr.label(DIMENSION_FAILED_COUNT_KEY),
impact_score_expr.label(DIMENSION_IMPACT_SCORE_KEY),
]
).select_from(stats_subquery)
result = self.runner.session.execute(others_query).fetchone()
if result:
others_stddev, total_count, failed_count, impact_score = result
logger.debug(
"Recomputed 'Others' (SQL): stddev=%s, failed=%d/%d, impact=%.3f",
others_stddev,
failed_count,
total_count,
impact_score,
)
# Return new "Others" row with SQL-computed values
return {
DIMENSION_VALUE_KEY: DIMENSION_OTHERS_LABEL,
Metrics.STDDEV.name: others_stddev,
DIMENSION_TOTAL_COUNT_KEY: total_count,
DIMENSION_FAILED_COUNT_KEY: failed_count,
DIMENSION_IMPACT_SCORE_KEY: impact_score,
}
return None
except Exception as exc:
logger.warning(
"Failed to recompute 'Others' stddev, will be excluded: %s", exc
)
logger.debug("Full error details: ", exc_info=True)
return None
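
The Pass 2 query reduces to a single filtered aggregate over the base table. A conceptual SQLAlchemy sketch, assuming illustrative table and column names (the validator builds the equivalent expression from Metrics.STDDEV(column).fn() and the runner's dataset):

import sqlalchemy as sa

metadata_obj = sa.MetaData()
orders = sa.Table(
    "orders",
    metadata_obj,
    sa.Column("region", sa.String),
    sa.Column("revenue", sa.Float),
)

top_n_values = ["eu", "us", "apac"]  # dimension values kept by Pass 1

# SELECT stddev_pop(revenue), count(*) FROM orders WHERE region NOT IN (...)
others_stats = sa.select(
    sa.func.stddev_pop(orders.c.revenue).label("stddev"),
    sa.func.count().label("total_count"),
).where(orders.c.region.notin_(top_n_values))

print(others_stats)

Running stddev over the raw excluded rows, rather than trying to combine per-dimension stddevs, is what keeps the "Others" value mathematically exact.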

View File

@@ -16,6 +16,8 @@ Population Standard deviation Metric definition
# Keep SQA docs style defining custom constructs
# pylint: disable=consider-using-f-string,duplicate-code
import math
from typing import TYPE_CHECKING, NamedTuple, Optional
from sqlalchemy import column
from sqlalchemy.ext.compiler import compiles
@@ -23,6 +25,7 @@ from sqlalchemy.sql.functions import FunctionElement
from metadata.generated.schema.configuration.profilerConfiguration import MetricType
from metadata.profiler.metrics.core import CACHE, StaticMetric, _label
from metadata.profiler.metrics.pandas_metric_protocol import PandasComputation
from metadata.profiler.orm.functions.length import LenFn
from metadata.profiler.orm.registry import (
FLOAT_SET,
@@ -33,9 +36,20 @@ from metadata.profiler.orm.registry import (
)
from metadata.utils.logger import profiler_logger
if TYPE_CHECKING:
import pandas as pd
logger = profiler_logger()
class SumSumSquaresCount(NamedTuple):
"""Running sum, sum of squares, and count for computing stddev efficiently"""
sum_value: float
sum_squares_value: float
count_value: int
class StdDevFn(FunctionElement):
name = __qualname__
inherit_cache = CACHE
@@ -54,12 +68,11 @@ def _(element, compiler, **kw):
@compiles(StdDevFn, Dialects.SQLite) # Needed for unit tests
def _(element, compiler, **kw):
"""
This actually returns the squared STD, but as
it is only required for tests we can live with it.
SQLite standard deviation using computational formula.
Requires SQRT function (registered via tests/unit/conftest.py for unit tests).
"""
proc = compiler.process(element.clauses, **kw)
return "AVG(%s * %s) - AVG(%s) * AVG(%s)" % ((proc,) * 4)
return "SQRT(AVG(%s * %s) - AVG(%s) * AVG(%s))" % ((proc,) * 4)
@compiles(StdDevFn, Dialects.Trino)
@@ -127,23 +140,125 @@ class StdDev(StaticMetric):
def df_fn(self, dfs=None):
"""pandas function"""
import pandas as pd # pylint: disable=import-outside-toplevel
if is_quantifiable(self.col.type):
computation = self.get_pandas_computation()
accumulator = computation.create_accumulator()
for df in dfs:
try:
df = pd.to_numeric(pd.concat(df[self.col.name] for df in dfs))
if not df.empty:
return df.std()
return None
accumulator = computation.update_accumulator(accumulator, df)
except MemoryError:
logger.error(
f"Unable to compute Standard Deviation for {self.col.name} due to memory constraints."
f"Unable to compute 'Standard Deviation' for {self.col.name} due to memory constraints."
f"We recommend using a smaller sample size or partitionning."
)
return None
except Exception as err:
logger.debug(
f"Error while computing 'Standard Deviation' for column {self.col.name}: {err}"
)
return None
return computation.aggregate_accumulator(accumulator)
logger.debug(
f"{self.col.name} has type {self.col.type}, which is not listed as quantifiable."
+ " We won't compute STDDEV for it."
)
return None
def get_pandas_computation(self) -> PandasComputation:
"""Get pandas computation with accumulator for efficient stddev calculation
Returns:
PandasComputation: Computation protocol with create/update/aggregate methods
"""
return PandasComputation[SumSumSquaresCount, Optional[float]](
create_accumulator=lambda: SumSumSquaresCount(0.0, 0.0, 0),
update_accumulator=lambda acc, df: StdDev.update_accumulator(
acc, df, self.col
),
aggregate_accumulator=StdDev.aggregate_accumulator,
)
@staticmethod
def update_accumulator(
sum_sum_squares_count: SumSumSquaresCount, df: "pd.DataFrame", column
) -> SumSumSquaresCount:
"""Optimized accumulator: maintains running sum, sum of squares, and count
Instead of concatenating dataframes, directly accumulates the necessary
statistics for computing standard deviation. This is memory efficient (O(1))
and enables proper aggregation of "Others" in dimensional validation.
Formula for variance across multiple groups:
variance = (sum_squares / count) - (sum / count)²
stddev = √variance
Args:
sum_sum_squares_count: Current accumulator state
df: DataFrame chunk to process
column: Column to compute stddev for
Returns:
Updated accumulator with new chunk's statistics added
"""
import pandas as pd
clean_df = df[column.name].dropna()
if clean_df.empty:
return sum_sum_squares_count
chunk_count = len(clean_df)
if is_quantifiable(column.type):
numeric_df = pd.to_numeric(clean_df, errors="coerce").dropna()
if numeric_df.empty:
return sum_sum_squares_count
chunk_sum = numeric_df.sum()
chunk_sum_squares = (numeric_df**2).sum()
chunk_count = len(numeric_df)
else:
return sum_sum_squares_count
if pd.isnull(chunk_sum) or pd.isnull(chunk_sum_squares):
return sum_sum_squares_count
return SumSumSquaresCount(
sum_value=sum_sum_squares_count.sum_value + chunk_sum,
sum_squares_value=sum_sum_squares_count.sum_squares_value
+ chunk_sum_squares,
count_value=sum_sum_squares_count.count_value + chunk_count,
)
@staticmethod
def aggregate_accumulator(
sum_sum_squares_count: SumSumSquaresCount,
) -> Optional[float]:
"""Compute final stddev from running sum, sum of squares, and count
Uses the computational formula for variance:
variance = E[X²] - E[X]²
= (sum_squares / count) - (sum / count)²
Args:
sum_sum_squares_count: Accumulated statistics
Returns:
Population standard deviation, or None if no data
"""
if sum_sum_squares_count.count_value == 0:
return None
mean = sum_sum_squares_count.sum_value / sum_sum_squares_count.count_value
mean_of_squares = (
sum_sum_squares_count.sum_squares_value / sum_sum_squares_count.count_value
)
variance = mean_of_squares - (mean**2)
# Handle floating point precision issues
if variance < 0:
if abs(variance) < 1e-10: # Close to zero due to floating point
variance = 0
else:
logger.warning(
f"Negative variance ({variance}) encountered, returning None"
)
return None
return math.sqrt(variance)
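
The identity behind aggregate_accumulator is easy to check numerically. Note that it yields the population standard deviation (ddof=0), not pandas' default sample standard deviation (ddof=1), which is consistent with the updated test expectations below (0.816496580927726 is sqrt(2/3), the sample stddev of values like [30, 31, 31, 32]; 0.7071067811865476 is sqrt(1/2), the population stddev of the same values). A small sanity check under that assumed fixture:

import math

import pandas as pd

values = pd.Series([30.0, 31.0, 31.0, 32.0])
n = len(values)
mean = values.sum() / n
mean_of_squares = (values**2).sum() / n
variance = mean_of_squares - mean**2  # E[X**2] - E[X]**2

# Computational formula matches pandas' population stddev
assert math.isclose(math.sqrt(variance), values.std(ddof=0))
print(values.std(ddof=1), values.std(ddof=0))  # 0.8164965... 0.7071067...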

View File

@@ -1,3 +1,49 @@
import math
import sqlalchemy as sqa
from pytest import fixture
@fixture(scope="session", autouse=True)
def register_sqlite_math_functions():
"""
Register custom math functions for SQLite used in unit tests.
SQLite doesn't have a built-in SQRT function, so we register Python's math.sqrt
to make it available for all SQLite connections in tests.
This runs automatically for all unit tests (autouse=True) and only once
per test session (scope="session").
"""
def safe_sqrt(x):
"""
Safe square root that handles floating-point precision issues.
When computing variance using AVG(x*x) - AVG(x)*AVG(x), floating-point
precision can result in slightly negative values (e.g., -1e-15) when
the true variance is zero. This function treats near-zero negative
values as zero, matching the behavior in stddev.py:254-256.
"""
if x is None:
return None
if x < 0:
if abs(x) < 1e-10:
return 0.0
raise ValueError(f"Cannot compute square root of negative number: {x}")
return math.sqrt(x)
@sqa.event.listens_for(sqa.engine.Engine, "connect")
def register_functions(dbapi_conn, connection_record):
if "sqlite" in str(type(dbapi_conn)):
dbapi_conn.create_function("SQRT", 1, safe_sqrt)
yield
# Clean up event listener after tests
sqa.event.remove(sqa.engine.Engine, "connect", register_functions)
def pytest_pycollect_makeitem(collector, name, obj):
try:
if obj.__name__ in ("TestSuiteSource", "TestSuiteInterfaceFactory"):
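
The same registration can be exercised outside pytest with the DB-API directly. A minimal demonstration using plain sqlite3 (the conftest above hooks the equivalent UDF into every SQLAlchemy connection; unlike safe_sqrt, this sketch simply clamps negatives to zero):

import math
import sqlite3

conn = sqlite3.connect(":memory:")
# Register a SQRT user-defined function for this connection.
conn.create_function(
    "SQRT", 1, lambda x: None if x is None else math.sqrt(max(x, 0.0))
)
print(conn.execute("SELECT SQRT(2.0)").fetchone()[0])  # 1.4142135623730951
conn.close()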

View File

@@ -232,7 +232,7 @@ class DatalakeMetricsTest(TestCase):
)
res = profiler.compute_metrics()._column_results
assert round(res.get(User.age.name).get(Metrics.STDDEV.name), 2) == 0.82
assert round(res.get(User.age.name).get(Metrics.STDDEV.name), 2) == 0.71
def test_time(self):
"""

View File

@@ -240,7 +240,7 @@ class ProfilerTest(TestCase):
maxLength=None,
mean=31.0,
sum=124.0,
stddev=0.816496580927726,
stddev=0.7071067811865476,
variance=None,
median=31.0,
firstQuartile=30.5,

View File

@@ -182,9 +182,7 @@ class MetricsTest(TestCase):
profiler_interface=self.sqa_profiler_interface,
)
res = profiler.compute_metrics()._column_results
# SQLITE STD custom implementation returns the squared STD.
# Only useful for testing purposes
assert res.get(User.age.name).get(Metrics.STDDEV.name) == 0.25
assert res.get(User.age.name).get(Metrics.STDDEV.name) == 0.5
def test_earliest_time(self):
"""

View File

@@ -173,7 +173,7 @@ class ProfilerTest(TestCase):
maxLength=None,
mean=30.5,
sum=61.0,
stddev=0.25,
stddev=0.5,
variance=None,
distinctCount=2.0,
distinctProportion=1.0,
@@ -182,7 +182,7 @@ class ProfilerTest(TestCase):
firstQuartile=30.0,
thirdQuartile=31.0,
interQuartileRange=1.0,
nonParametricSkew=2.0,
nonParametricSkew=1.0,
histogram=Histogram(boundaries=["30.000 and up"], frequencies=[2]),
)

View File

@@ -1188,3 +1188,19 @@ def test_case_column_values_to_be_between_dimensional():
dimensionColumns=["name"],
computePassedFailedRowCount=True,
) # type: ignore
@pytest.fixture
def test_case_column_value_stddev_to_be_between_dimensional():
"""Test case for test column_value_median_to_be_between"""
return TestCase(
name=TEST_CASE_NAME,
entityLink=ENTITY_LINK_AGE,
testSuite=EntityReference(id=uuid4(), type="TestSuite"), # type: ignore
testDefinition=EntityReference(id=uuid4(), type="TestDefinition"), # type: ignore
parameterValues=[
TestCaseParameterValue(name="minValueForStdDevInCol", value="20"),
TestCaseParameterValue(name="maxValueForStdDevInCol", value="40"),
],
dimensionColumns=["name"],
) # type: ignore

View File

@@ -207,7 +207,7 @@ TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED = {
"COLUMN",
(
TestCaseResult,
"0.25",
"0.5",
None,
TestCaseStatus.Failed,
None,
@@ -956,6 +956,29 @@ TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED = {
("name=Others", TestCaseStatus.Success, 30, 0, 100, 0, 0),
],
),
(
"test_case_column_value_stddev_to_be_between_dimensional",
"columnValueStdDevToBeBetween",
"COLUMN",
(
TestCaseResult,
"0.5",
None,
TestCaseStatus.Failed,
None,
None,
None,
None,
),
[
("name=John", TestCaseStatus.Failed, None, None, None, None, 0.0667),
("name=Alice", TestCaseStatus.Failed, None, None, None, None, 0.0333),
("name=Bob", TestCaseStatus.Failed, None, None, None, None, 0.0333),
("name=Charlie", TestCaseStatus.Failed, None, None, None, None, 0.0333),
("name=Diana", TestCaseStatus.Failed, None, None, None, None, 0.0333),
("name=Others", TestCaseStatus.Failed, None, None, None, None, 0.0667),
],
),
],
)
def test_suite_validation_database(

View File

@@ -323,7 +323,7 @@ DATALAKE_DATA_FRAME = lambda times_increase_sample_data: DataFrame(
"COLUMN",
(
TestCaseResult,
"0.5000208346355071",
"0.5",
None,
TestCaseStatus.Failed,
None,
@@ -1154,6 +1154,29 @@ DATALAKE_DATA_FRAME = lambda times_increase_sample_data: DataFrame(
("name=Others", TestCaseStatus.Success, 4000, 0, 100, 0, 0),
],
),
(
"test_case_column_value_stddev_to_be_between_dimensional",
"columnValueStdDevToBeBetween",
"COLUMN",
(
TestCaseResult,
"0.5",
None,
TestCaseStatus.Failed,
None,
None,
None,
None,
),
[
("name=Alice", TestCaseStatus.Failed, None, None, None, None, 0.6667),
("name=Bob", TestCaseStatus.Failed, None, None, None, None, 0.6667),
("name=Charlie", TestCaseStatus.Failed, None, None, None, None, 0.6667),
("name=Diana", TestCaseStatus.Failed, None, None, None, None, 0.6667),
("name=Jane", TestCaseStatus.Failed, None, None, None, None, 0.6667),
("name=Others", TestCaseStatus.Failed, None, None, None, None, 0.6667),
],
),
],
)
def test_suite_validation_datalake(