ISSUE #1753 - Add Row Count to Custom SQL Test (#22697)

* feat: add count rows support for custom SQL

* style: ran python linting

* feat: added logic for partitioned custom sql row count

* migration: partitionExpression parameter

* chore: resolve conflicts

(cherry picked from commit d58b8a63d675e9bf91a2283a5f37702648cdab7f)
This commit is contained in:
Teddy 2025-08-19 06:40:49 +02:00 committed by Teddy Crepineau
parent 8d98833622
commit c69c21d82e
16 changed files with 345 additions and 12 deletions

View File

@ -0,0 +1,18 @@
-- Flag the tableCustomSQLQuery test definition as supporting
-- row-level passed/failed counts.
UPDATE test_definition
SET json = JSON_SET(json, '$.supportsRowLevelPassedFailed', true)
WHERE JSON_EXTRACT(json, '$.name') = 'tableCustomSQLQuery';
-- Append the optional partitionExpression parameter to the test definition's
-- parameterDefinition array. The NOT JSON_CONTAINS guard keeps this migration
-- idempotent: the parameter is only appended when not already present.
UPDATE test_definition
SET json = JSON_ARRAY_APPEND(
    json,
    '$.parameterDefinition',
    JSON_OBJECT(
        'name', 'partitionExpression',
        'displayName', 'Partition Expression',
        'description', 'Partition expression that will be used to compute the passed/failed row count, if compute row count is enabled (e.g. created_date > DATE_SUB(CURDATE(), INTERVAL 1 DAY)).',
        'dataType', 'STRING',
        'required', false
    )
)
WHERE JSON_EXTRACT(json, '$.name') = 'tableCustomSQLQuery'
AND NOT JSON_CONTAINS(JSON_EXTRACT(json, '$.parameterDefinition[*].name'),'"partitionExpression"');

View File

@ -0,0 +1,22 @@
-- Flag the tableCustomSQLQuery test definition as supporting
-- row-level passed/failed counts.
UPDATE test_definition
SET json = jsonb_set(json, '{supportsRowLevelPassedFailed}', 'true'::jsonb)
WHERE json->>'name' = 'tableCustomSQLQuery';
-- Append the optional partitionExpression parameter to the parameterDefinition
-- array (jsonb `array || object` appends the object as a new element).
-- The NOT EXISTS guard keeps this migration idempotent across re-runs.
UPDATE test_definition
SET json = jsonb_set(
    json,
    '{parameterDefinition}',
    (json->'parameterDefinition') || jsonb_build_object(
        'name', 'partitionExpression',
        'displayName', 'Partition Expression',
        'description', 'Partition expression that will be used to compute the passed/failed row count, if compute row count is enabled.',
        'dataType', 'STRING',
        'required', false
    )
)
WHERE json->>'name' = 'tableCustomSQLQuery'
AND NOT EXISTS (
    SELECT 1
    FROM jsonb_array_elements(json->'parameterDefinition') AS elem
    WHERE elem->>'name' = 'partitionExpression'
);

View File

@ -4,8 +4,13 @@ from typing import List, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Column, TableProfilerConfig from metadata.generated.schema.entity.data.table import (
Column,
Table,
TableProfilerConfig,
)
from metadata.generated.schema.entity.services.databaseService import ( from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseServiceType, DatabaseServiceType,
) )
from metadata.ingestion.models.custom_pydantic import CustomSecretStr from metadata.ingestion.models.custom_pydantic import CustomSecretStr
@ -27,3 +32,8 @@ class TableDiffRuntimeParameters(BaseModel):
extraColumns: List[str] extraColumns: List[str]
whereClause: Optional[str] whereClause: Optional[str]
table_profile_config: Optional[TableProfilerConfig] table_profile_config: Optional[TableProfilerConfig]
class TableCustomSQLQueryRuntimeParameters(BaseModel):
    """Runtime parameters resolved for the tableCustomSQLQuery test.

    Injected by ``TableCustomSQLQueryParamsSetter`` so the validator can
    compute the table row count through the profiler.
    """

    # Connection configuration of the database service the test runs against.
    conn_config: DatabaseConnection
    # The table entity the custom SQL test is attached to.
    entity: Table

View File

@ -19,9 +19,15 @@ from typing import Dict, Set, Type
from metadata.data_quality.validations.runtime_param_setter.param_setter import ( from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter, RuntimeParameterSetter,
) )
from metadata.data_quality.validations.runtime_param_setter.table_custom_sql_query_params_setter import (
TableCustomSQLQueryParamsSetter,
)
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import ( from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter, TableDiffParamsSetter,
) )
from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import (
TableCustomSQLQueryValidator,
)
from metadata.data_quality.validations.table.sqlalchemy.tableDiff import ( from metadata.data_quality.validations.table.sqlalchemy.tableDiff import (
TableDiffValidator, TableDiffValidator,
) )
@ -60,6 +66,9 @@ class RuntimeParameterSetterFactory:
"""Set""" """Set"""
self._setter_map: Dict[str, Set[Type[RuntimeParameterSetter]]] = { self._setter_map: Dict[str, Set[Type[RuntimeParameterSetter]]] = {
validator_name(TableDiffValidator): {TableDiffParamsSetter}, validator_name(TableDiffValidator): {TableDiffParamsSetter},
validator_name(TableCustomSQLQueryValidator): {
TableCustomSQLQueryParamsSetter
},
} }
def get_runtime_param_setters( def get_runtime_param_setters(

View File

@ -0,0 +1,31 @@
# Copyright 2024 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module that defines the TableCustomSQLQueryParamsSetter class."""
from metadata.data_quality.validations.models import (
TableCustomSQLQueryRuntimeParameters,
)
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
class TableCustomSQLQueryParamsSetter(RuntimeParameterSetter):
    """Runtime parameter setter for the table custom SQL query test."""

    def get_parameters(self, test_case) -> TableCustomSQLQueryRuntimeParameters:
        """Build the runtime parameters needed by the custom SQL validator.

        Args:
            test_case: the test case being executed (unused here; part of the
                ``RuntimeParameterSetter`` interface).

        Returns:
            TableCustomSQLQueryRuntimeParameters: connection config and table
            entity taken from the setter's own state.
        """
        connection = DatabaseConnection(config=self.service_connection_config)
        return TableCustomSQLQueryRuntimeParameters(
            conn_config=connection,
            entity=self.table_entity,
        )

View File

@ -98,13 +98,37 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
status = TestCaseStatus.Failed status = TestCaseStatus.Failed
result_value = len_rows result_value = len_rows
if self.test_case.computePassedFailedRowCount:
row_count = self.get_row_count()
else:
row_count = None
return self.get_test_case_result_object( return self.get_test_case_result_object(
self.execution_date, self.execution_date,
status, status,
f"Found {result_value} row(s). Test query is expected to return {threshold} row.", f"Found {result_value} row(s). Test query is expected to return {threshold} row.",
[TestResultValue(name=RESULT_ROW_COUNT, value=str(result_value))], [TestResultValue(name=RESULT_ROW_COUNT, value=str(result_value))],
row_count=row_count,
failed_rows=result_value,
) )
@abstractmethod @abstractmethod
def _run_results(self, sql_expression: str, strategy: Strategy = Strategy.ROWS): def _run_results(self, sql_expression: str, strategy: Strategy = Strategy.ROWS):
raise NotImplementedError raise NotImplementedError
@abstractmethod
def compute_row_count(self):
    """Compute the total row count for the table under test.

    Concrete validators (SQLAlchemy / pandas) provide the implementation.

    Raises:
        NotImplementedError: always, on the base class.
    """
    raise NotImplementedError
def get_row_count(self) -> int:
    """Get the row count of the table under test.

    Returns:
        int: total row count, as computed by the concrete
        ``compute_row_count`` implementation.
    """
    return self.compute_row_count()

View File

@ -13,6 +13,8 @@
Validator for table custom SQL Query test case Validator for table custom SQL Query test case
""" """
from typing import List, Optional
from metadata.data_quality.validations.mixins.pandas_validator_mixin import ( from metadata.data_quality.validations.mixins.pandas_validator_mixin import (
PandasValidatorMixin, PandasValidatorMixin,
) )
@ -39,3 +41,39 @@ class TableCustomSQLQueryValidator(
if len(runner.query(sql_expression)) if len(runner.query(sql_expression))
] ]
) )
def compute_row_count(self) -> Optional[int]:
    """Compute the total row count across the validator's dataframes.

    When a ``partitionExpression`` parameter is set on the test case, each
    dataframe is filtered with ``DataFrame.query`` before counting, so the
    returned count only covers the matching partition.

    Returns:
        Optional[int]: total number of rows across all dataframes, or
        ``None`` when there is nothing to count or the partition
        expression fails to evaluate.
    """
    runner: List["DataFrame"] = self.runner  # type: ignore
    if not runner:
        return None
    # `parameterValues` may be None on a test case with no parameters;
    # treat that the same as "no partition expression" instead of raising.
    partition_expression = next(
        (
            param.value
            for param in (self.test_case.parameterValues or [])
            if param.name == "partitionExpression"
        ),
        None,
    )
    total_rows = 0
    for dataframe in runner:
        if dataframe is None:
            continue
        if partition_expression:
            try:
                # partition_expression uses pandas query syntax
                # (e.g. "name == 'John'")
                total_rows += len(dataframe.query(partition_expression))
            except Exception as exc:  # pylint: disable=broad-except
                # Best-effort: an invalid expression must not fail the test
                # run, only the row count reporting.
                logger.error(
                    "Error executing partition expression, "
                    f"expression may be invalid: {partition_expression} - {exc}"
                )
                return None
        else:
            total_rows += len(dataframe.index)
    # A total of 0 is reported as None (no row count), matching the
    # original behavior.
    return total_rows if total_rows > 0 else None

View File

@ -13,21 +13,42 @@
Validator for table custom SQL Query test case Validator for table custom SQL Query test case
""" """
from typing import Optional, cast
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.sql import func, select
from metadata.data_quality.validations.mixins.sqa_validator_mixin import ( from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin, SQAValidatorMixin,
) )
from metadata.data_quality.validations.models import (
TableCustomSQLQueryRuntimeParameters,
)
from metadata.data_quality.validations.table.base.tableCustomSQLQuery import ( from metadata.data_quality.validations.table.base.tableCustomSQLQuery import (
BaseTableCustomSQLQueryValidator, BaseTableCustomSQLQueryValidator,
Strategy, Strategy,
) )
from metadata.generated.schema.tests.basic import TestCaseResult
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.functions.table_metric_computer import TableMetricComputer
from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.helpers import is_safe_sql_query from metadata.utils.helpers import is_safe_sql_query
class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidatorMixin): class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidatorMixin):
"""Validator for table custom SQL Query test case""" """Validator for table custom SQL Query test case"""
def run_validation(self) -> TestCaseResult:
    """Run validation for the given test case.

    Resolves the runtime parameters (connection config + table entity)
    before delegating to the base class, which may call back into
    ``compute_row_count`` when row counting is enabled.

    Returns:
        TestCaseResult: the result produced by the base validation.
    """
    self.runtime_params = self.get_runtime_parameters(
        TableCustomSQLQueryRuntimeParameters
    )
    return super().run_validation()
def _run_results(self, sql_expression: str, strategy: Strategy = Strategy.ROWS): def _run_results(self, sql_expression: str, strategy: Strategy = Strategy.ROWS):
"""compute result of the test case""" """compute result of the test case"""
if not is_safe_sql_query(sql_expression): if not is_safe_sql_query(sql_expression):
@ -48,3 +69,39 @@ class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidato
except Exception as exc: except Exception as exc:
self.runner._session.rollback() # pylint: disable=protected-access self.runner._session.rollback() # pylint: disable=protected-access
raise exc raise exc
def compute_row_count(self) -> Optional[int]:
    """Compute the total row count for the table under test.

    If a ``partitionExpression`` parameter is set on the test case, a
    ``SELECT COUNT(*)`` filtered by that raw SQL expression is run against
    the runner's table. Otherwise the count is resolved through the
    profiler's ``TableMetricComputer`` using the runtime parameters set in
    ``run_validation``.

    Returns:
        Optional[int]: the row count, or None when the metric computer
        returns no row.
    """
    partition_expression = next(
        (
            param.value
            for param in self.test_case.parameterValues
            if param.name == "partitionExpression"
        ),
        None,
    )
    if partition_expression:
        # NOTE(review): the expression is interpolated as raw SQL via
        # text(); assumed to come from a trusted test definition — confirm.
        stmt = (
            select(func.count())
            .select_from(self.runner.table)
            .filter(text(partition_expression))
        )
        return self.runner.session.execute(stmt).scalar()
    self.runner = cast(QueryRunner, self.runner)
    # The dialect name is passed to the computer, which selects how the
    # row count metric is fetched.
    dialect = self.runner._session.get_bind().dialect.name
    table_metric_computer: TableMetricComputer = TableMetricComputer(
        dialect,
        runner=self.runner,
        metrics=[Metrics.ROW_COUNT],
        conn_config=self.runtime_params.conn_config,
        entity=self.runtime_params.entity,
    )
    row = table_metric_computer.compute()
    if row:
        return dict(row).get(Metrics.ROW_COUNT.value.name())
    return None

View File

@ -38,18 +38,9 @@ from metadata.utils.logger import profiler_interface_registry_logger
logger = profiler_interface_registry_logger() logger = profiler_interface_registry_logger()
@inject
def get_row_count_metric(metrics: Inject[Type[MetricRegistry]] = None):
if metrics is None:
raise DependencyNotFoundError(
"MetricRegistry dependency not found. Please ensure the MetricRegistry is properly registered."
)
return metrics.ROW_COUNT().name()
COLUMN_COUNT = "columnCount" COLUMN_COUNT = "columnCount"
COLUMN_NAMES = "columnNames" COLUMN_NAMES = "columnNames"
ROW_COUNT = get_row_count_metric() ROW_COUNT = "rowCount"
SIZE_IN_BYTES = "sizeInBytes" SIZE_IN_BYTES = "sizeInBytes"
CREATE_DATETIME = "createDateTime" CREATE_DATETIME = "createDateTime"

View File

@ -51,12 +51,12 @@ from metadata.profiler.metrics.core import (
TMetric, TMetric,
) )
from metadata.profiler.metrics.static.row_count import RowCount from metadata.profiler.metrics.static.row_count import RowCount
from metadata.profiler.orm.functions.table_metric_computer import CREATE_DATETIME
from metadata.profiler.orm.registry import NOT_COMPUTE from metadata.profiler.orm.registry import NOT_COMPUTE
from metadata.profiler.processor.metric_filter import MetricFilter from metadata.profiler.processor.metric_filter import MetricFilter
from metadata.utils.logger import profiler_logger from metadata.utils.logger import profiler_logger
logger = profiler_logger() logger = profiler_logger()
CREATE_DATETIME = "createDateTime"
class MissingMetricException(Exception): class MissingMetricException(Exception):

View File

@ -156,6 +156,8 @@ def create_sqlite_table():
session.add_all(data) session.add_all(data)
session.commit() session.commit()
runner.service_connection = sqlite_conn
runner.entity = TABLE
yield runner yield runner
# clean up # clean up
User.__table__.drop(bind=engine) User.__table__.drop(bind=engine)
@ -635,6 +637,31 @@ def test_case_table_custom_sql_unsafe_query_aborted():
) # type: ignore ) # type: ignore
@pytest.fixture
def test_case_table_custom_sql_with_partition_condition():
    """Test case for the table custom SQL query test with a partition expression."""
    return TestCase(
        name=TEST_CASE_NAME,
        entityLink=ENTITY_LINK_USER,
        testSuite=EntityReference(id=uuid4(), type="TestSuite"),  # type: ignore
        testDefinition=EntityReference(id=uuid4(), type="TestDefinition"),  # type: ignore
        parameterValues=[
            TestCaseParameterValue(
                name="sqlExpression",
                value="SELECT * FROM users WHERE age > 20 AND name = 'John'",
            ),
            TestCaseParameterValue(
                name="strategy",
                value="ROWS",
            ),
            # Raw SQL filter used by the validator's row-count query.
            TestCaseParameterValue(
                name="partitionExpression",
                value="name = 'John'",
            ),
        ],
    )  # type: ignore
@pytest.fixture @pytest.fixture
def test_case_table_row_count_to_be_between(): def test_case_table_row_count_to_be_between():
"""Test case for test column_value_median_to_be_between""" """Test case for test column_value_median_to_be_between"""
@ -709,6 +736,23 @@ def test_case_table_custom_sql_query_success_dl():
) )
@pytest.fixture
def test_case_table_custom_sql_query_success_dl_with_partition_expression():
    """Datalake (pandas) custom SQL test case carrying a partition expression."""
    return TestCase(
        name=TEST_CASE_NAME,
        entityLink=ENTITY_LINK_USER,
        testSuite=EntityReference(id=uuid4(), type="TestSuite"),  # type: ignore
        testDefinition=EntityReference(id=uuid4(), type="TestDefinition"),  # type: ignore
        parameterValues=[
            TestCaseParameterValue(name="sqlExpression", value="age < 0"),
            # Uses pandas DataFrame.query syntax (==), not SQL equality.
            TestCaseParameterValue(
                name="partitionExpression", value="nickname == 'johnny b goode'"
            ),
        ],
    )
@pytest.fixture @pytest.fixture
def test_case_column_values_to_be_between_date(): def test_case_column_values_to_be_between_date():
return TestCase( return TestCase(

View File

@ -19,11 +19,27 @@ from unittest.mock import patch
import pytest import pytest
from metadata.data_quality.validations.models import (
TableCustomSQLQueryRuntimeParameters,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCaseParameterValue
from metadata.utils.importer import import_test_case_class from metadata.utils.importer import import_test_case_class
EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d") EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
# Test definitions that support row-level passed/failed counts: for these,
# computePassedFailedRowCount is enabled in the parametrized tests below and
# the result is expected to carry passed/failed row counts and percentages.
TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED = {
    "columnValuesLengthToBeBetween",
    "columnValuesToBeBetween",
    "columnValuesToBeInSet",
    "columnValuesToBeNotInSet",
    "columnValuesToBeNotNull",
    "columnValuesToBeUnique",
    "columnValuesToMatchRegex",
    "columnValuesToNotMatchRegex",
    "tableCustomSQLQuery",
}
# pylint: disable=line-too-long # pylint: disable=line-too-long
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -359,6 +375,12 @@ EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
"TABLE", "TABLE",
(TestCaseResult, "0", None, TestCaseStatus.Success, None, None, None, None), (TestCaseResult, "0", None, TestCaseStatus.Success, None, None, None, None),
), ),
(
"test_case_table_custom_sql_with_partition_condition",
"tableCustomSQLQuery",
"TABLE",
(TestCaseResult, "10", None, TestCaseStatus.Failed, 10, 10, 50.0, 50.0),
),
( (
"test_case_table_row_count_to_be_between", "test_case_table_row_count_to_be_between",
"tableRowCountToBeBetween", "tableRowCountToBeBetween",
@ -460,6 +482,22 @@ def test_suite_validation_database(
failed_percentage, failed_percentage,
) = expected ) = expected
if test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED:
test_case.computePassedFailedRowCount = True
if test_case_type == "tableCustomSQLQuery":
runtime_params = TableCustomSQLQueryRuntimeParameters(
conn_config=DatabaseConnection(
config=create_sqlite_table.service_connection,
),
entity=create_sqlite_table.entity,
)
test_case.parameterValues.append(
TestCaseParameterValue(
name=TableCustomSQLQueryRuntimeParameters.__name__,
value=runtime_params.model_dump_json(),
)
)
if test_case_name == "test_case_column_values_to_be_between_date": if test_case_name == "test_case_column_values_to_be_between_date":
with patch( with patch(
"metadata.data_quality.validations.column.sqlalchemy.columnValuesToBeBetween.ColumnValuesToBeBetweenValidator._run_results", "metadata.data_quality.validations.column.sqlalchemy.columnValuesToBeBetween.ColumnValuesToBeBetweenValidator._run_results",
@ -525,3 +563,11 @@ def test_suite_validation_database(
if failed_percentage: if failed_percentage:
assert round(res.failedRowsPercentage, 2) == failed_percentage assert round(res.failedRowsPercentage, 2) == failed_percentage
assert res.testCaseStatus == status assert res.testCaseStatus == status
if (
test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED
and test_case_name != "test_case_table_custom_sql_unsafe_query_aborted"
):
assert res.failedRows is not None
assert res.failedRowsPercentage is not None
assert res.passedRows is not None
assert res.passedRowsPercentage is not None

View File

@ -23,6 +23,18 @@ from pandas import DataFrame
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.utils.importer import import_test_case_class from metadata.utils.importer import import_test_case_class
# Test definitions that support row-level passed/failed counts: for these,
# computePassedFailedRowCount is enabled in the datalake tests below and the
# result is expected to carry passed/failed row counts and percentages.
TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED = {
    "columnValuesLengthToBeBetween",
    "columnValuesToBeBetween",
    "columnValuesToBeInSet",
    "columnValuesToBeNotInSet",
    "columnValuesToBeNotNull",
    "columnValuesToBeUnique",
    "columnValuesToMatchRegex",
    "columnValuesToNotMatchRegex",
    "tableCustomSQLQuery",
}
EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d") EXECUTION_DATE = datetime.strptime("2021-07-03", "%Y-%m-%d")
DL_DATA = ( DL_DATA = (
[ [
@ -536,6 +548,21 @@ DATALAKE_DATA_FRAME = lambda times_increase_sample_data: DataFrame(
0.0, 0.0,
), ),
), ),
(
"test_case_table_custom_sql_query_success_dl_with_partition_expression",
"tableCustomSQLQuery",
"TABLE",
(
TestCaseResult,
None,
None,
TestCaseStatus.Success,
2000,
0,
100.0,
0.0,
),
),
], ],
) )
def test_suite_validation_datalake( def test_suite_validation_datalake(
@ -559,6 +586,9 @@ def test_suite_validation_datalake(
failed_percentage, failed_percentage,
) = expected ) = expected
if test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED:
test_case.computePassedFailedRowCount = True
test_handler_obj = import_test_case_class( test_handler_obj = import_test_case_class(
test_type, test_type,
"pandas", "pandas",
@ -587,3 +617,8 @@ def test_suite_validation_datalake(
if failed_percentage: if failed_percentage:
assert round(res.failedRowsPercentage, 2) == failed_percentage assert round(res.failedRowsPercentage, 2) == failed_percentage
assert res.testCaseStatus == status assert res.testCaseStatus == status
if test_case_type in TEST_CASE_SUPPORT_ROW_LEVEL_PASS_FAILED:
assert res.failedRows is not None
assert res.failedRowsPercentage is not None
assert res.passedRows is not None
assert res.passedRowsPercentage is not None

View File

@ -35,8 +35,16 @@
"description": "Threshold to use to determine if the test passes or fails (defaults to 0).", "description": "Threshold to use to determine if the test passes or fails (defaults to 0).",
"dataType": "NUMBER", "dataType": "NUMBER",
"required": false "required": false
},
{
"name": "partitionExpression",
"displayName": "Partition Expression",
"description": "Partition expression that will be used to compute the passed/failed row count, if compute row count is enabled.",
"dataType": "STRING",
"required": false
} }
], ],
"supportsRowLevelPassedFailed": true,
"provider": "system", "provider": "system",
"dataQualityDimension": "SQL" "dataQualityDimension": "SQL"
} }