ISSUE #1753 - Add Row Count to Custom SQL Test (#22697)

* feat: add count rows support for custom SQL

* style: ran python linting

* feat: added logic for partitioned custom sql row count

* migration: partitionExpression parameter

* chore: resolve conflicts

(cherry picked from commit d58b8a63d675e9bf91a2283a5f37702648cdab7f)
This commit is contained in:
Teddy 2025-08-19 06:40:49 +02:00 committed by Teddy Crepineau
parent 8d98833622
commit c69c21d82e
16 changed files with 345 additions and 12 deletions

View File

@ -0,0 +1,18 @@
-- Mark the custom SQL test definition as supporting row-level passed/failed counts.
UPDATE test_definition
SET json = JSON_SET(json, '$.supportsRowLevelPassedFailed', true)
WHERE JSON_EXTRACT(json, '$.name') = 'tableCustomSQLQuery';
-- Append the optional 'partitionExpression' parameter to the test definition,
-- unless a parameter with that name is already present (idempotent migration).
UPDATE test_definition
SET json = JSON_ARRAY_APPEND(
json,
'$.parameterDefinition',
JSON_OBJECT(
'name', 'partitionExpression',
'displayName', 'Partition Expression',
'description', 'Partition expression that will be used to compute the passed/failed row count, if compute row count is enabled (e.g. created_date > DATE_SUB(CURDATE(), INTERVAL 1 DAY)).',
'dataType', 'STRING',
'required', false
)
)
WHERE JSON_EXTRACT(json, '$.name') = 'tableCustomSQLQuery'
-- JSON_EXTRACT with '[*].name' yields the array of existing parameter names;
-- skip the append when "partitionExpression" is already among them.
AND NOT JSON_CONTAINS(JSON_EXTRACT(json, '$.parameterDefinition[*].name'),'"partitionExpression"');

View File

@ -0,0 +1,22 @@
-- Mark the custom SQL test definition as supporting row-level passed/failed counts.
UPDATE test_definition
SET json = jsonb_set(json, '{supportsRowLevelPassedFailed}', 'true'::jsonb)
WHERE json->>'name' = 'tableCustomSQLQuery';
-- Append the optional 'partitionExpression' parameter to the definition's
-- parameterDefinition array ('array || object' appends), unless a parameter
-- with that name already exists (idempotent migration).
UPDATE test_definition
SET json = jsonb_set(
json,
'{parameterDefinition}',
(json->'parameterDefinition') || jsonb_build_object(
'name', 'partitionExpression',
'displayName', 'Partition Expression',
'description', 'Partition expression that will be used to compute the passed/failed row count, if compute row count is enabled.',
'dataType', 'STRING',
'required', false
)
)
WHERE json->>'name' = 'tableCustomSQLQuery'
-- Guard: only append when no existing parameter is named 'partitionExpression'.
AND NOT EXISTS (
SELECT 1
FROM jsonb_array_elements(json->'parameterDefinition') AS elem
WHERE elem->>'name' = 'partitionExpression'
);

View File

@ -4,8 +4,13 @@ from typing import List, Optional, Union
from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Column, TableProfilerConfig
from metadata.generated.schema.entity.data.table import (
Column,
Table,
TableProfilerConfig,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseServiceType,
)
from metadata.ingestion.models.custom_pydantic import CustomSecretStr
@ -27,3 +32,8 @@ class TableDiffRuntimeParameters(BaseModel):
extraColumns: List[str]
whereClause: Optional[str]
table_profile_config: Optional[TableProfilerConfig]
class TableCustomSQLQueryRuntimeParameters(BaseModel):
    """Runtime parameters for the table custom SQL query test.

    Bundles the connection configuration and the table entity that the
    validator needs to compute the table row count at execution time.
    """

    # Database connection config of the service that owns the tested table
    conn_config: DatabaseConnection
    # The table entity the test case runs against
    entity: Table

View File

@ -19,9 +19,15 @@ from typing import Dict, Set, Type
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
from metadata.data_quality.validations.runtime_param_setter.table_custom_sql_query_params_setter import (
TableCustomSQLQueryParamsSetter,
)
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter,
)
from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import (
TableCustomSQLQueryValidator,
)
from metadata.data_quality.validations.table.sqlalchemy.tableDiff import (
TableDiffValidator,
)
@ -60,6 +66,9 @@ class RuntimeParameterSetterFactory:
"""Set"""
self._setter_map: Dict[str, Set[Type[RuntimeParameterSetter]]] = {
validator_name(TableDiffValidator): {TableDiffParamsSetter},
validator_name(TableCustomSQLQueryValidator): {
TableCustomSQLQueryParamsSetter
},
}
def get_runtime_param_setters(

View File

@ -0,0 +1,31 @@
# Copyright 2024 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module that defines the TableCustomSQLQueryParamsSetter class."""
from metadata.data_quality.validations.models import (
TableCustomSQLQueryRuntimeParameters,
)
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
class TableCustomSQLQueryParamsSetter(RuntimeParameterSetter):
    """Set runtime parameters for the table custom SQL query test."""

    def get_parameters(self, test_case) -> TableCustomSQLQueryRuntimeParameters:
        """Build the runtime parameters for the custom SQL query test.

        Args:
            test_case: the test case being executed (not read here; the
                parameters come from the setter's own connection/table state)

        Returns:
            TableCustomSQLQueryRuntimeParameters: the service connection
            wrapped in a DatabaseConnection plus the table entity.
        """
        return TableCustomSQLQueryRuntimeParameters(
            conn_config=DatabaseConnection(
                config=self.service_connection_config,
            ),
            entity=self.table_entity,
        )

View File

@ -98,13 +98,37 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
status = TestCaseStatus.Failed
result_value = len_rows
if self.test_case.computePassedFailedRowCount:
row_count = self.get_row_count()
else:
row_count = None
return self.get_test_case_result_object(
self.execution_date,
status,
f"Found {result_value} row(s). Test query is expected to return {threshold} row.",
[TestResultValue(name=RESULT_ROW_COUNT, value=str(result_value))],
row_count=row_count,
failed_rows=result_value,
)
@abstractmethod
def _run_results(self, sql_expression: str, strategy: Strategy = Strategy.ROWS):
    """Execute the test's SQL expression and return its result per *strategy*.

    Must be implemented by the concrete (SQLAlchemy / pandas) validators.

    Raises:
        NotImplementedError: always, on the abstract base
    """
    raise NotImplementedError
@abstractmethod
def compute_row_count(self):
    """Compute the row count of the table under test.

    NOTE(review): this is a table-level count (see get_row_count usage),
    not a per-column computation as the original wording suggested.

    Raises:
        NotImplementedError: always, on the abstract base
    """
    raise NotImplementedError
def get_row_count(self) -> int:
    """Get the row count of the table under test.

    Delegates to the concrete validator's compute_row_count().

    Returns:
        int: total number of rows (the previous docstring's
        ``Tuple[int, int]`` was incorrect — a single count is returned)
    """
    return self.compute_row_count()

View File

@ -13,6 +13,8 @@
Validator for table custom SQL Query test case
"""
from typing import List, Optional
from metadata.data_quality.validations.mixins.pandas_validator_mixin import (
PandasValidatorMixin,
)
@ -39,3 +41,39 @@ class TableCustomSQLQueryValidator(
if len(runner.query(sql_expression))
]
)
def compute_row_count(self) -> Optional[int]:
    """Compute the row count across all runner dataframes.

    When the test case carries a ``partitionExpression`` parameter, only
    rows matching that expression (via ``DataFrame.query``) are counted.

    Returns:
        Optional[int]: total number of rows, or None when there are no
        dataframes, the total is zero, or the expression fails to evaluate
    """
    dataframes: List["DataFrame"] = self.runner  # type: ignore
    if not dataframes:
        return None
    # First "partitionExpression" parameter value, if the test defines one
    expression = None
    for param in self.test_case.parameterValues:
        if param.name == "partitionExpression":
            expression = param.value
            break
    row_total = 0
    for frame in dataframes:
        if frame is None:
            continue
        if not expression:
            row_total += len(frame.index)
            continue
        try:
            row_total += len(frame.query(expression))
        except Exception as exc:
            # Invalid expressions abort the whole computation: report None
            logger.error(
                "Error executing partition expression, "
                f"expression may be invalid: {expression} - {exc}"
            )
            return None
    return row_total or None

View File

@ -13,21 +13,42 @@
Validator for table custom SQL Query test case
"""
from typing import Optional, cast
from sqlalchemy import text
from sqlalchemy.sql import func, select
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
)
from metadata.data_quality.validations.models import (
TableCustomSQLQueryRuntimeParameters,
)
from metadata.data_quality.validations.table.base.tableCustomSQLQuery import (
BaseTableCustomSQLQueryValidator,
Strategy,
)
from metadata.generated.schema.tests.basic import TestCaseResult
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.functions.table_metric_computer import TableMetricComputer
from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.helpers import is_safe_sql_query
class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidatorMixin):
"""Validator for table custom SQL Query test case"""
def run_validation(self) -> TestCaseResult:
    """Run validation for the given test case.

    Resolves the runtime parameters (connection config + table entity)
    needed by compute_row_count before delegating to the base class.

    Returns:
        TestCaseResult: result, including row counts when enabled
    """
    # Stash the runtime params so compute_row_count can reach them later
    self.runtime_params = self.get_runtime_parameters(
        TableCustomSQLQueryRuntimeParameters
    )
    return super().run_validation()
def _run_results(self, sql_expression: str, strategy: Strategy = Strategy.ROWS):
"""compute result of the test case"""
if not is_safe_sql_query(sql_expression):
@ -48,3 +69,39 @@ class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidato
except Exception as exc:
self.runner._session.rollback() # pylint: disable=protected-access
raise exc
def compute_row_count(self) -> Optional[int]:
    """Compute the row count of the table under test.

    If the test case carries a ``partitionExpression`` parameter, count
    only the rows matching that expression; otherwise compute the full
    table row count through TableMetricComputer.

    Returns:
        Optional[int]: the row count, or None when it cannot be computed
        (the previous "Raises: NotImplementedError" docstring was a
        copy-paste from the abstract base and did not apply here)
    """
    # First "partitionExpression" parameter value, if the test defines one
    partition_expression = next(
        (
            param.value
            for param in self.test_case.parameterValues
            if param.name == "partitionExpression"
        ),
        None,
    )
    if partition_expression:
        # NOTE(review): the expression is interpolated as raw SQL via
        # text(); its validity/safety rests with the test author — confirm
        # upstream validation exists.
        stmt = (
            select(func.count())
            .select_from(self.runner.table)
            .filter(text(partition_expression))
        )
        # NOTE(review): this branch uses runner.session while the branch
        # below uses runner._session — verify both refer to the same session.
        return self.runner.session.execute(stmt).scalar()
    self.runner = cast(QueryRunner, self.runner)
    # Dialect name drives how TableMetricComputer fetches the row count
    dialect = self.runner._session.get_bind().dialect.name
    table_metric_computer: TableMetricComputer = TableMetricComputer(
        dialect,
        runner=self.runner,
        metrics=[Metrics.ROW_COUNT],
        conn_config=self.runtime_params.conn_config,
        entity=self.runtime_params.entity,
    )
    row = table_metric_computer.compute()
    if row:
        return dict(row).get(Metrics.ROW_COUNT.value.name())
    return None

View File

@ -38,18 +38,9 @@ from metadata.utils.logger import profiler_interface_registry_logger
logger = profiler_interface_registry_logger()
@inject
def get_row_count_metric(metrics: Inject[Type[MetricRegistry]] = None):
    """Return the canonical row-count metric name from the injected registry.

    Raises:
        DependencyNotFoundError: when no MetricRegistry has been registered
            with the dependency injector
    """
    if metrics is None:
        raise DependencyNotFoundError(
            "MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
        )
    return metrics.ROW_COUNT().name()
COLUMN_COUNT = "columnCount"
COLUMN_NAMES = "columnNames"
ROW_COUNT = get_row_count_metric()
ROW_COUNT = "rowCount"
SIZE_IN_BYTES = "sizeInBytes"
CREATE_DATETIME = "createDateTime"

View File

@ -51,12 +51,12 @@ from metadata.profiler.metrics.core import (
TMetric,
)
from metadata.profiler.metrics.static.row_count import RowCount
from metadata.profiler.orm.functions.table_metric_computer import CREATE_DATETIME
from metadata.profiler.orm.registry import NOT_COMPUTE
from metadata.profiler.processor.metric_filter import MetricFilter
from metadata.utils.logger import profiler_logger
logger = profiler_logger()
CREATE_DATETIME = "createDateTime"
class MissingMetricException(Exception):

View File

@ -156,6 +156,8 @@ def create_sqlite_table():
session.add_all(data)
session.commit()
runner.service_connection = sqlite_conn
runner.entity = TABLE
yield runner
# clean up
User.__table__.drop(bind=engine)
@ -635,6 +637,31 @@ def test_case_table_custom_sql_unsafe_query_aborted():
) # type: ignore
@pytest.fixture
def test_case_table_custom_sql_with_partition_condition():
    """Table custom SQL query test case carrying a partitionExpression parameter.

    (Previous docstring mentioned column_value_median_to_be_between — a
    copy-paste from another fixture.)
    """
    return TestCase(
        name=TEST_CASE_NAME,
        entityLink=ENTITY_LINK_USER,
        testSuite=EntityReference(id=uuid4(), type="TestSuite"),  # type: ignore
        testDefinition=EntityReference(id=uuid4(), type="TestDefinition"),  # type: ignore
        parameterValues=[
            TestCaseParameterValue(
                name="sqlExpression",
                value="SELECT * FROM users WHERE age > 20 AND name = 'John'",
            ),
            TestCaseParameterValue(
                name="strategy",
                value="ROWS",
            ),
            # Partition filter used for the passed/failed row count
            TestCaseParameterValue(
                name="partitionExpression",
                value="name = 'John'",
            ),
        ],
    )  # type: ignore
@pytest.fixture
def test_case_table_row_count_to_be_between():
"""Test case for test column_value_median_to_be_between"""
@ -709,6 +736,23 @@ def test_case_table_custom_sql_query_success_dl():
)
@pytest.fixture
def test_case_table_custom_sql_query_success_dl_with_partition_expression():
    """Datalake table custom SQL query test case with a partitionExpression parameter."""
    return TestCase(
        name=TEST_CASE_NAME,
        entityLink=ENTITY_LINK_USER,
        testSuite=EntityReference(id=uuid4(), type="TestSuite"),  # type: ignore
        testDefinition=EntityReference(id=uuid4(), type="TestDefinition"),  # type: ignore
        parameterValues=[
            TestCaseParameterValue(name="sqlExpression", value="age < 0"),
            # pandas DataFrame.query syntax (hence '==' rather than SQL '=')
            TestCaseParameterValue(
                name="partitionExpression", value="nickname == 'johnny b goode'"
            ),
        ],
    )
@pytest.fixture
def test_case_column_values_to_be_between_date():
return TestCase(

View File

@ -19,11 +19,27 @@ from unittest.mock import patch
import pytest
from metadata.data_quality.validations.models import (
TableCustomSQLQueryRuntimeParameters,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCaseParameterValue
from metadata.utils.importer import import_test_case_class
EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED = {
"columnValuesLengthToBeBetween",
"columnValuesToBeBetween",
"columnValuesToBeInSet",
"columnValuesToBeNotInSet",
"columnValuesToBeNotNull",
"columnValuesToBeUnique",
"columnValuesToMatchRegex",
"columnValuesToNotMatchRegex",
"tableCustomSQLQuery",
}
# pylint: disable=line-too-long
@pytest.mark.parametrize(
@ -359,6 +375,12 @@ EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
"TABLE",
(TestCaseResult, "0", None, TestCaseStatus.Success, None, None, None, None),
),
(
"test_case_table_custom_sql_with_partition_condition",
"tableCustomSQLQuery",
"TABLE",
(TestCaseResult, "10", None, TestCaseStatus.Failed, 10, 10, 50.0, 50.0),
),
(
"test_case_table_row_count_to_be_between",
"tableRowCountToBeBetween",
@ -460,6 +482,22 @@ def test_suite_validation_database(
failed_percentage,
) = expected
if test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED:
test_case.computePassedFailedRowCount = True
if test_case_type == "tableCustomSQLQuery":
runtime_params = TableCustomSQLQueryRuntimeParameters(
conn_config=DatabaseConnection(
config=create_sqlite_table.service_connection,
),
entity=create_sqlite_table.entity,
)
test_case.parameterValues.append(
TestCaseParameterValue(
name=TableCustomSQLQueryRuntimeParameters.__name__,
value=runtime_params.model_dump_json(),
)
)
if test_case_name == "test_case_column_values_to_be_between_date":
with patch(
"metadata.data_quality.validations.column.sqlalchemy.columnValuesToBeBetween.ColumnValuesToBeBetweenValidator._run_results",
@ -525,3 +563,11 @@ def test_suite_validation_database(
if failed_percentage:
assert round(res.failedRowsPercentage, 2) == failed_percentage
assert res.testCaseStatus == status
if (
test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED
and test_case_name != "test_case_table_custom_sql_unsafe_query_aborted"
):
assert res.failedRows is not None
assert res.failedRowsPercentage is not None
assert res.passedRows is not None
assert res.passedRowsPercentage is not None

View File

@ -23,6 +23,18 @@ from pandas import DataFrame
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.utils.importer import import_test_case_class
TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED = {
"columnValuesLengthToBeBetween",
"columnValuesToBeBetween",
"columnValuesToBeInSet",
"columnValuesToBeNotInSet",
"columnValuesToBeNotNull",
"columnValuesToBeUnique",
"columnValuesToMatchRegex",
"columnValuesToNotMatchRegex",
"tableCustomSQLQuery",
}
EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
DL_DATA = (
[
@ -536,6 +548,21 @@ DATALAKE_DATA_FRAME = lambda times_increase_sample_data: DataFrame(
0.0,
),
),
(
"test_case_table_custom_sql_query_success_dl_with_partition_expression",
"tableCustomSQLQuery",
"TABLE",
(
TestCaseResult,
None,
None,
TestCaseStatus.Success,
2000,
0,
100.0,
0.0,
),
),
],
)
def test_suite_validation_datalake(
@ -559,6 +586,9 @@ def test_suite_validation_datalake(
failed_percentage,
) = expected
if test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED:
test_case.computePassedFailedRowCount = True
test_handler_obj = import_test_case_class(
test_type,
"pandas",
@ -587,3 +617,8 @@ def test_suite_validation_datalake(
if failed_percentage:
assert round(res.failedRowsPercentage, 2) == failed_percentage
assert res.testCaseStatus == status
if test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED:
assert res.failedRows is not None
assert res.failedRowsPercentage is not None
assert res.passedRows is not None
assert res.passedRowsPercentage is not None

View File

@ -35,8 +35,16 @@
"description": "Threshold to use to determine if the test passes or fails (defaults to 0).",
"dataType": "NUMBER",
"required": false
},
{
"name": "partitionExpression",
"displayName": "Partition Expression",
"description": "Partition expression that will be used to compute the passed/failed row count, if compute row count is enabled.",
"dataType": "STRING",
"required": false
}
],
"supportsRowLevelPassedFailed": true,
"provider": "system",
"dataQualityDimension": "SQL"
}