ISSUE #23435 - Fix pass / fail count for custom SQL (#23506)

* fix: added logic to compute pass/fail row counts for SQL queries with CTEs, nested queries, and joins

* added logic to correctly compute pass / fail rows

* style: ran python linting

* fix: failing tests

* style: fix linting error

* fix: flawed count logic

* fix: handle case where we don't compute row count
This commit is contained in:
Teddy 2025-09-23 16:53:51 +02:00 committed by GitHub
parent 79fde4ab02
commit 57c5a50d20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1728 additions and 13 deletions

View File

@ -87,11 +87,13 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
[TestResultValue(name=RESULT_ROW_COUNT, value=None)],
)
len_rows = rows if isinstance(rows, int) else len(rows)
if evaluate_threshold(
test_passed = evaluate_threshold(
threshold,
operator,
len_rows,
):
)
if test_passed:
status = TestCaseStatus.Success
result_value = len_rows
else:
@ -99,17 +101,23 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
result_value = len_rows
if self.test_case.computePassedFailedRowCount:
row_count = self.get_row_count()
row_count = self._get_total_row_count_if_needed()
passed_rows, failed_rows = self._calculate_passed_failed_rows(
test_passed, operator, threshold, len_rows, row_count
)
else:
passed_rows = None
failed_rows = None
row_count = None
return self.get_test_case_result_object(
self.execution_date,
status,
f"Found {result_value} row(s). Test query is expected to return {threshold} row.",
f"Found {result_value} row(s). Test query is expected to return {operator} {threshold} row(s).",
[TestResultValue(name=RESULT_ROW_COUNT, value=str(result_value))],
row_count=row_count,
failed_rows=result_value,
failed_rows=failed_rows,
passed_rows=passed_rows,
)
@abstractmethod
@ -132,3 +140,100 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
Tuple[int, int]:
"""
return self.compute_row_count()
def _get_total_row_count_if_needed(self) -> int:
    """Get the total row count used for passed/failed row computation.

    This is a thin delegate to ``get_row_count()``. The
    ``computePassedFailedRowCount`` flag is NOT checked here; callers are
    expected to test it before invoking this method.

    Returns:
        Whatever ``get_row_count()`` returns.
        NOTE(review): the Args of ``_calculate_passed_failed_rows`` document
        the row count as possibly ``None`` - confirm whether the ``int``
        annotation should be optional.
    """
    return self.get_row_count()
def _calculate_passed_failed_rows(
self,
test_passed: bool,
operator: str,
threshold: int,
len_rows: int,
row_count: int,
) -> tuple[int, int]:
"""Calculate passed and failed rows based on test result and operator
Args:
test_passed: Whether the test passed
operator: Comparison operator (>, >=, <, <=, ==)
threshold: Expected threshold value
len_rows: Number of rows returned by the test query
row_count: Total number of rows in the table (or None)
Returns:
Tuple of (passed_rows, failed_rows)
"""
if test_passed:
return self._calculate_passed_rows_success(operator, len_rows, row_count)
return self._calculate_passed_rows_failure(
operator, threshold, len_rows, row_count
)
def _calculate_passed_rows_success(
self, operator: str, len_rows: int, row_count: int
) -> tuple[int, int]:
"""Calculate passed/failed rows when test passed"""
if operator in (">", ">="):
passed_rows = len_rows
failed_rows = (row_count - len_rows) if row_count else 0
elif operator in ("<", "<="):
passed_rows = row_count - len_rows
failed_rows = len_rows
elif operator == "==":
passed_rows = len_rows
failed_rows = row_count - len_rows
else:
passed_rows = len_rows
failed_rows = 0
return max(0, passed_rows), max(0, failed_rows)
def _calculate_passed_rows_failure(
self, operator: str, threshold: int, len_rows: int, row_count: int
) -> tuple[int, int]:
"""Calculate passed/failed rows when test failed"""
if operator in (">", ">="):
return self._calculate_greater_than_failure(len_rows, row_count)
if operator in ("<", "<="):
return self._calculate_less_than_failure(len_rows, row_count)
if operator == "==":
return self._calculate_equal_failure(threshold, len_rows, row_count)
failed_rows = row_count if row_count else len_rows
return 0, max(0, failed_rows)
def _calculate_greater_than_failure(
self, len_rows: int, row_count: int
) -> tuple[int, int]:
"""Calculate rows for > or >= operator failure (expected more rows)"""
passed_rows = len_rows
failed_rows = (row_count - len_rows) if row_count else 0
return max(0, passed_rows), max(0, failed_rows)
def _calculate_less_than_failure(
self, len_rows: int, row_count: int
) -> tuple[int, int]:
"""Calculate rows for < or <= operator failure (expected fewer rows)"""
failed_rows = len_rows
passed_rows = row_count - failed_rows
return max(0, passed_rows), max(0, failed_rows)
def _calculate_equal_failure(
self, threshold: int, len_rows: int, row_count: int
) -> tuple[int, int]:
"""Calculate rows for == operator failure (expected exact count)"""
if row_count:
if len_rows > threshold:
failed_rows = len_rows - threshold
passed_rows = row_count - failed_rows
else:
failed_rows = row_count - len_rows
passed_rows = len_rows
else:
failed_rows = abs(len_rows - threshold)
passed_rows = 0
return max(0, passed_rows), max(0, failed_rows)

View File

@ -13,10 +13,13 @@
Validator for table custom SQL Query test case
"""
from typing import Optional, cast
from typing import Optional, Tuple, cast
import sqlparse
from sqlalchemy import text
from sqlalchemy.sql import func, select
from sqlparse.sql import Statement, Token, Where
from sqlparse.tokens import Keyword
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
@ -33,11 +36,244 @@ from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.functions.table_metric_computer import TableMetricComputer
from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.helpers import is_safe_sql_query
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidatorMixin):
"""Validator for table custom SQL Query test case"""
def _replace_where_clause(
    self, sql_query: str, partition_expression: str
) -> Optional[str]:
    """Swap in (or add) a top-level WHERE clause using sqlparse.

    Only the first statement returned by ``sqlparse.parse`` is rewritten.
    An existing top-level WHERE is replaced; otherwise the new clause is
    inserted before the first trailing clause keyword (GROUP BY, ORDER BY,
    LIMIT, ...) or appended at the end.

    Args:
        sql_query: Original SQL query
        partition_expression: New WHERE condition (without WHERE keyword)
    Returns:
        Modified SQL query with partition_expression as WHERE clause, or
        None when sqlparse yields no statement
    """
    statements = sqlparse.parse(sql_query)
    if not statements:
        return None

    token_list = list(statements[0].tokens)
    positions = self._find_clause_positions(token_list)
    rebuilt = self._build_new_tokens(token_list, *positions, partition_expression)
    return "".join(map(str, rebuilt))
def _find_clause_positions(
    self, tokens: list
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Locate the top-level WHERE clause and the fallback insertion point.

    The scan stops at the first ``Where`` token seen at parentheses depth
    0. Until then, the index of the first token accepted by
    ``_should_insert_before_token`` is remembered as the insertion point.

    Args:
        tokens: List of parsed SQL tokens
    Returns:
        Tuple of (where_idx, where_end_idx, insert_before_idx); the first
        two are None when no top-level WHERE was found
    """
    depth = 0
    insertion_point = None
    for position, current in enumerate(tokens):
        depth = self._update_parentheses_depth(current, depth)
        if depth == 0 and isinstance(current, Where):
            # The Where group is a single token, so it ends at position + 1.
            return position, position + 1, insertion_point
        if self._should_insert_before_token(current, insertion_point, depth):
            insertion_point = position
    return None, None, insertion_point
def _update_parentheses_depth(self, token: Token, current_depth: int) -> int:
    """Update parentheses depth based on token content.

    Grouped tokens (``ttype is None`` with a ``tokens`` attribute) are
    flattened and only leaf tokens whose value is exactly ``(`` or ``)``
    are counted. Parentheses embedded in string literals, quoted
    identifiers or comments therefore no longer shift the depth.

    Args:
        token: SQL token to analyze
        current_depth: Current parentheses depth
    Returns:
        Updated parentheses depth
    """
    if token.ttype is None and hasattr(token, "tokens"):
        # Counting on str(token) would also count the "(" in a clause such
        # as WHERE name = '(' and hide that top-level WHERE from the caller,
        # which only accepts a Where token at depth 0.
        paren_count = 0
        for leaf in token.flatten():
            if leaf.value == "(":
                paren_count += 1
            elif leaf.value == ")":
                paren_count -= 1
        return current_depth + paren_count
    if token.value == "(":
        return current_depth + 1
    if token.value == ")":
        return current_depth - 1
    return current_depth
def _should_insert_before_token(
    self, token: Token, insert_before_idx: Optional[int], paren_depth: int
) -> bool:
    """Tell whether a new WHERE clause belongs right before this token.

    Args:
        token: SQL token to check
        insert_before_idx: Current insertion index (None if not set)
        paren_depth: Current parentheses depth
    Returns:
        True only when no insertion index has been chosen yet, the token
        sits at parentheses depth 0, its type is exactly ``Keyword`` and
        its upper-cased value contains one of the clause keywords
    """
    boundary_keywords = (
        "GROUP BY",
        "ORDER BY",
        "HAVING",
        "LIMIT",
        "OFFSET",
        "UNION",
        "EXCEPT",
        "INTERSECT",
    )
    is_candidate = (
        insert_before_idx is None
        and paren_depth == 0
        and token.ttype is Keyword
    )
    if not is_candidate:
        return False

    upper_value = token.value.upper()
    return any(keyword in upper_value for keyword in boundary_keywords)
def _build_new_tokens(
self,
tokens: list,
where_idx: Optional[int],
where_end_idx: Optional[int],
insert_before_idx: Optional[int],
partition_expression: str,
) -> list:
"""Build new token list with WHERE clause inserted or replaced.
Args:
tokens: Original token list
where_idx: Index of existing WHERE clause (None if not found)
where_end_idx: End index of existing WHERE clause
insert_before_idx: Index to insert WHERE before (None if append)
partition_expression: WHERE condition expression
Returns:
New list of tokens with WHERE clause
"""
if where_idx is not None:
return self._replace_existing_where(
tokens, where_idx, where_end_idx, partition_expression
)
elif insert_before_idx is not None:
return self._insert_where_before_clause(
tokens, insert_before_idx, partition_expression
)
else:
return self._append_where_clause(tokens, partition_expression)
def _replace_existing_where(
    self,
    tokens: list,
    where_idx: int,
    where_end_idx: int,
    partition_expression: str,
) -> list:
    """Substitute the existing WHERE clause with the new expression.

    The whitespace that trailed the original clause is carried over so the
    following clause stays separated from the new condition.

    Args:
        tokens: Original token list
        where_idx: Index of WHERE clause to replace
        where_end_idx: End index of WHERE clause
        partition_expression: New WHERE condition
    Returns:
        Token list with replaced WHERE clause
    """
    preserved_tail = self._extract_trailing_whitespace(str(tokens[where_idx]))
    replacement = [
        Token(Keyword, "WHERE"),
        Token(None, f" {partition_expression}{preserved_tail}"),
    ]
    return [*tokens[:where_idx], *replacement, *tokens[where_end_idx:]]
def _insert_where_before_clause(
    self, tokens: list, insert_before_idx: int, partition_expression: str
) -> list:
    """Place a new WHERE clause immediately before the given token index.

    Args:
        tokens: Original token list
        insert_before_idx: Index to insert WHERE clause before
        partition_expression: WHERE condition expression
    Returns:
        Token list with WHERE clause inserted
    """
    rebuilt = list(tokens[:insert_before_idx])
    # The trailing space keeps the condition apart from the next keyword.
    rebuilt.append(Token(Keyword, "WHERE"))
    rebuilt.append(Token(None, f" {partition_expression} "))
    rebuilt.extend(tokens[insert_before_idx:])
    return rebuilt
def _append_where_clause(self, tokens: list, partition_expression: str) -> list:
    """Append WHERE clause to end of token list.

    Trailing whitespace and statement terminators (``;``) are kept after
    the new clause instead of before it, so a query such as
    ``SELECT * FROM users;`` becomes ``SELECT * FROM users WHERE ...;``
    rather than the invalid ``SELECT * FROM users; WHERE ...``.

    Args:
        tokens: Original token list
        partition_expression: WHERE condition expression
    Returns:
        Token list with WHERE clause appended
    """
    # Walk back over trailing tokens that render as whitespace or ";".
    content_end = len(tokens)
    while content_end > 0 and str(tokens[content_end - 1]).strip() in ("", ";"):
        content_end -= 1

    return (
        tokens[:content_end]
        + [
            Token(None, " "),
            Token(Keyword, "WHERE"),
            Token(None, f" {partition_expression}"),
        ]
        + tokens[content_end:]
    )
def _extract_trailing_whitespace(self, where_clause: str) -> str:
"""Extract trailing whitespace from WHERE clause string.
Args:
where_clause: Original WHERE clause string
Returns:
Trailing whitespace string
"""
if not where_clause.split():
return ""
last_word = where_clause.split()[-1]
where_content_end = where_clause.rfind(last_word) + len(last_word)
if where_content_end < len(where_clause):
return where_clause[where_content_end:]
return ""
def run_validation(self) -> TestCaseResult:
"""Run validation for the given test case
@ -85,12 +321,39 @@ class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidato
None,
)
if partition_expression:
stmt = (
select(func.count())
.select_from(self.runner.table)
.filter(text(partition_expression))
custom_sql = self.get_test_case_param_value(
self.test_case.parameterValues, # type: ignore
"sqlExpression",
str,
)
return self.runner.session.execute(stmt).scalar()
if custom_sql:
modified_query = self._replace_where_clause(
custom_sql, partition_expression
)
if modified_query is None:
return None
count_query = f"SELECT COUNT(*) FROM ({modified_query}) AS test_results"
try:
result = self.runner.session.execute(text(count_query)).scalar()
return result
except Exception as exc:
logger.error(
"Failed to execute custom SQL with partition expression. "
f"Query: {count_query}\n"
f"Error: {exc}\n",
exc_info=True,
)
self.runner.session.rollback()
raise exc
else:
stmt = (
select(func.count())
.select_from(self.runner.table)
.filter(text(partition_expression))
)
return self.runner.session.execute(stmt).scalar()
self.runner = cast(QueryRunner, self.runner)
dialect = self.runner._session.get_bind().dialect.name

View File

@ -86,7 +86,8 @@ class OpenMetadataValidationAction1xx(ValidationAction):
# This will be initialized in the run method
ometa_conn: Optional[OpenMetadata] = None
def run( # pylint: disable=unused-argument, arguments-differ
# pylint: disable=unused-argument,arguments-differ
def run(
self,
checkpoint_result: CheckpointResult,
action_context: Union[ActionContext, None],

View File

@ -26,7 +26,7 @@ def create_service_request(postgres_container, tmp_path_factory):
username=postgres_container.username,
authType=BasicAuth(password=postgres_container.password),
hostPort="localhost:"
+ postgres_container.get_exposed_port(postgres_container.port),
+ str(postgres_container.get_exposed_port(postgres_container.port)),
database="dvdrental",
)
),

View File

@ -0,0 +1,180 @@
"""
Unit tests for passed/failed row count calculation logic
"""
import unittest
def calculate_passed_failed_rows(
    test_passed: bool,
    operator: str,
    threshold: int,
    actual_rows: int,
    total_rows: int = None,
):
    """
    Compute (passed_rows, failed_rows) without knowledge of the table total.

    This mirrors the LEGACY behaviour used when no total row count is
    available; it is kept for backward compatibility. Supplying
    ``total_rows`` is not supported here and raises ``NotImplementedError``.
    """
    if total_rows is not None:
        raise NotImplementedError(
            "Use test_row_count_logic_with_total.py for total row count tests"
        )

    if test_passed:
        return actual_rows, 0

    if operator in (">", ">="):
        passed_rows, failed_rows = actual_rows, 0
    elif operator in ("<", "<="):
        passed_rows, failed_rows = threshold, max(0, actual_rows - threshold)
    elif operator == "==":
        passed_rows, failed_rows = 0, abs(actual_rows - threshold)
    else:
        passed_rows, failed_rows = 0, actual_rows
    return max(0, passed_rows), max(0, failed_rows)
class TestPassedFailedRowCalculation(unittest.TestCase):
    """Checks for the legacy passed/failed row computation (no table total)"""

    def _assert_counts(self, expected, **call_kwargs):
        """Run the calculation and compare (passed_rows, failed_rows)."""
        passed_rows, failed_rows = calculate_passed_failed_rows(**call_kwargs)
        self.assertEqual((passed_rows, failed_rows), expected)

    def test_greater_than_operator_success(self):
        """Passing ``>``: every returned row is counted as passed"""
        self._assert_counts(
            (10, 0), test_passed=True, operator=">", threshold=5, actual_rows=10
        )

    def test_greater_than_operator_failure(self):
        """Failing ``>`` (fewer rows than expected): nothing reported as failed"""
        self._assert_counts(
            (5, 0), test_passed=False, operator=">", threshold=10, actual_rows=5
        )

    def test_greater_than_equal_operator_success(self):
        """Passing ``>=`` with the row count exactly at the threshold"""
        self._assert_counts(
            (10, 0), test_passed=True, operator=">=", threshold=10, actual_rows=10
        )

    def test_greater_than_equal_operator_failure(self):
        """Failing ``>=``: returned rows are still counted as passed"""
        self._assert_counts(
            (8, 0), test_passed=False, operator=">=", threshold=15, actual_rows=8
        )

    def test_less_than_operator_success(self):
        """Passing ``<``: returned rows are counted as passed"""
        self._assert_counts(
            (5, 0), test_passed=True, operator="<", threshold=10, actual_rows=5
        )

    def test_less_than_operator_failure(self):
        """Failing ``<``: rows above the threshold are counted as failed"""
        self._assert_counts(
            (5, 7), test_passed=False, operator="<", threshold=5, actual_rows=12
        )

    def test_less_than_equal_operator_success(self):
        """Passing ``<=`` with the row count exactly at the threshold"""
        self._assert_counts(
            (10, 0), test_passed=True, operator="<=", threshold=10, actual_rows=10
        )

    def test_less_than_equal_operator_failure(self):
        """Failing ``<=``: rows above the threshold are counted as failed"""
        self._assert_counts(
            (8, 7), test_passed=False, operator="<=", threshold=8, actual_rows=15
        )

    def test_equal_operator_success(self):
        """Passing ``==``: every returned row is counted as passed"""
        self._assert_counts(
            (10, 0), test_passed=True, operator="==", threshold=10, actual_rows=10
        )

    def test_equal_operator_failure_more_rows(self):
        """Failing ``==`` with a surplus: the surplus is counted as failed"""
        self._assert_counts(
            (0, 5), test_passed=False, operator="==", threshold=10, actual_rows=15
        )

    def test_equal_operator_failure_fewer_rows(self):
        """Failing ``==`` with a shortfall: the shortfall is counted as failed"""
        self._assert_counts(
            (0, 7), test_passed=False, operator="==", threshold=10, actual_rows=3
        )

    def test_edge_case_zero_threshold_greater_than(self):
        """Passing ``>`` against a zero threshold"""
        self._assert_counts(
            (5, 0), test_passed=True, operator=">", threshold=0, actual_rows=5
        )

    def test_edge_case_zero_actual_rows_less_than(self):
        """Passing ``<`` when the query returned no rows"""
        self._assert_counts(
            (0, 0), test_passed=True, operator="<", threshold=5, actual_rows=0
        )

    def test_edge_case_negative_protection_greater_than_equal(self):
        """Passing ``>=`` never yields negative counts"""
        self._assert_counts(
            (10, 0), test_passed=True, operator=">=", threshold=5, actual_rows=10
        )

    def test_equal_operator_with_zero_threshold(self):
        """Failing ``==`` against a zero threshold"""
        self._assert_counts(
            (0, 5), test_passed=False, operator="==", threshold=0, actual_rows=5
        )

    def test_unknown_operator_fallback(self):
        """Unknown operator on failure: all returned rows are counted as failed"""
        self._assert_counts(
            (0, 15), test_passed=False, operator="!=", threshold=10, actual_rows=15
        )
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,302 @@
"""
Unit tests for passed/failed row count calculation with total row count
"""
import unittest
def calculate_passed_failed_rows_with_total(
    test_passed: bool,
    operator: str,
    threshold: int,
    actual_rows: int,
    total_rows: int = None,
):
    """
    Compute (passed_rows, failed_rows) taking the table total into account.

    Intended as a stand-alone replica of the logic in
    BaseTableCustomSQLQueryValidator. Both counts are clamped to be >= 0.

    NOTE(review): for ``<`` / ``<=`` this replica reports the whole table as
    passed on success and uses the threshold on failure; confirm this still
    matches the validator implementation it is meant to mirror.
    """
    if test_passed:
        if operator in (">", ">=", "=="):
            passed_rows = actual_rows
            failed_rows = (total_rows - actual_rows) if total_rows else 0
        elif operator in ("<", "<="):
            passed_rows = total_rows if total_rows else actual_rows
            failed_rows = 0
        else:
            passed_rows, failed_rows = actual_rows, 0
    elif operator in (">", ">="):
        passed_rows = actual_rows
        failed_rows = (total_rows - actual_rows) if total_rows else 0
    elif operator in ("<", "<="):
        if threshold <= 0:
            # An impossible expectation: everything we know about fails.
            passed_rows = 0
            failed_rows = total_rows if total_rows else max(actual_rows, 1)
        else:
            failed_rows = max(0, actual_rows - threshold)
            passed_rows = (total_rows - failed_rows) if total_rows else threshold
    elif operator == "==":
        if not total_rows:
            passed_rows, failed_rows = 0, abs(actual_rows - threshold)
        elif actual_rows > threshold:
            failed_rows = actual_rows - threshold
            passed_rows = total_rows - failed_rows
        else:
            passed_rows = actual_rows
            failed_rows = total_rows - actual_rows
    else:
        passed_rows = 0
        failed_rows = total_rows if total_rows else actual_rows

    return max(0, passed_rows), max(0, failed_rows)
class TestPassedFailedRowCalculationWithTotal(unittest.TestCase):
    """Checks for the row count calculation when the table total is supplied"""

    def _assert_counts(self, expected, **call_kwargs):
        """Run the calculation and compare (passed_rows, failed_rows)."""
        passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
            **call_kwargs
        )
        self.assertEqual((passed_rows, failed_rows), expected)

    def test_greater_than_operator_bug_case(self):
        """Bug case: failing ``>= 1`` with 0 rows returned and 1 total row"""
        self._assert_counts(
            (0, 1),
            test_passed=False,
            operator=">=",
            threshold=1,
            actual_rows=0,
            total_rows=1,
        )

    def test_greater_than_with_total_rows_success(self):
        """Passing ``>``: rows outside the result are counted as failed"""
        self._assert_counts(
            (10, 90),
            test_passed=True,
            operator=">",
            threshold=5,
            actual_rows=10,
            total_rows=100,
        )

    def test_greater_than_with_total_rows_failure(self):
        """Failing ``>``: rows outside the result are counted as failed"""
        self._assert_counts(
            (10, 90),
            test_passed=False,
            operator=">",
            threshold=50,
            actual_rows=10,
            total_rows=100,
        )

    def test_less_than_with_total_rows_success(self):
        """Passing ``<``: the whole table is counted as passed"""
        self._assert_counts(
            (100, 0),
            test_passed=True,
            operator="<",
            threshold=10,
            actual_rows=5,
            total_rows=100,
        )

    def test_less_than_with_total_rows_failure(self):
        """Failing ``<``: rows above the threshold are counted as failed"""
        self._assert_counts(
            (90, 10),
            test_passed=False,
            operator="<",
            threshold=10,
            actual_rows=20,
            total_rows=100,
        )

    def test_less_than_equal_with_total_rows_success(self):
        """Passing ``<=``: the whole table is counted as passed"""
        self._assert_counts(
            (100, 0),
            test_passed=True,
            operator="<=",
            threshold=10,
            actual_rows=10,
            total_rows=100,
        )

    def test_less_than_equal_with_total_rows_failure(self):
        """Failing ``<=``: rows above the threshold are counted as failed"""
        self._assert_counts(
            (90, 10),
            test_passed=False,
            operator="<=",
            threshold=5,
            actual_rows=15,
            total_rows=100,
        )

    def test_equal_with_total_rows_success(self):
        """Passing ``==``: rows outside the result are counted as failed"""
        self._assert_counts(
            (10, 90),
            test_passed=True,
            operator="==",
            threshold=10,
            actual_rows=10,
            total_rows=100,
        )

    def test_equal_with_total_rows_failure_too_many(self):
        """Failing ``==`` with a surplus over the threshold"""
        self._assert_counts(
            (95, 5),
            test_passed=False,
            operator="==",
            threshold=10,
            actual_rows=15,
            total_rows=100,
        )

    def test_equal_with_total_rows_failure_too_few(self):
        """Failing ``==`` with fewer rows than the threshold"""
        self._assert_counts(
            (3, 97),
            test_passed=False,
            operator="==",
            threshold=10,
            actual_rows=3,
            total_rows=100,
        )

    def test_edge_case_zero_total_rows(self):
        """A total of zero rows yields zero passed and zero failed"""
        self._assert_counts(
            (0, 0),
            test_passed=False,
            operator=">",
            threshold=5,
            actual_rows=0,
            total_rows=0,
        )

    def test_edge_case_no_total_row_count(self):
        """Total row count not available (None)"""
        self._assert_counts(
            (5, 0),
            test_passed=False,
            operator=">",
            threshold=10,
            actual_rows=5,
            total_rows=None,
        )

    def test_greater_equal_zero_threshold_zero_result(self):
        """Passing ``>= 0`` with 0 results but rows in the table"""
        self._assert_counts(
            (0, 100),
            test_passed=True,
            operator=">=",
            threshold=0,
            actual_rows=0,
            total_rows=100,
        )

    def test_all_rows_match_greater_than(self):
        """Passing ``>`` when the query returns every row of the table"""
        self._assert_counts(
            (100, 0),
            test_passed=True,
            operator=">",
            threshold=5,
            actual_rows=100,
            total_rows=100,
        )

    def test_no_rows_match_less_than(self):
        """Passing ``<`` when the query returns no rows"""
        self._assert_counts(
            (100, 0),
            test_passed=True,
            operator="<",
            threshold=10,
            actual_rows=0,
            total_rows=100,
        )

    def test_less_than_zero_threshold_failure_bug_case(self):
        """Reported bug case: failing ``< 0`` threshold"""
        self._assert_counts(
            (0, 1),
            test_passed=False,
            operator="<",
            threshold=0,
            actual_rows=0,
            total_rows=1,
        )

    def test_less_than_negative_threshold_failure(self):
        """Failing ``< -1`` (impossible expectation): whole table fails"""
        self._assert_counts(
            (0, 100),
            test_passed=False,
            operator="<",
            threshold=-1,
            actual_rows=5,
            total_rows=100,
        )
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,247 @@
"""
Unit tests for TableCustomSQLQueryValidator._replace_where_clause method
"""
import unittest
from datetime import datetime
from unittest.mock import Mock
from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import (
TableCustomSQLQueryValidator,
)
class TestTableCustomSQLQueryValidator(unittest.TestCase):
"""Test cases for TableCustomSQLQueryValidator._replace_where_clause method"""
def setUp(self):
"""Set up test fixtures"""
mock_runner = Mock()
mock_test_case = Mock()
mock_execution_date = datetime.now()
self.validator = TableCustomSQLQueryValidator(
runner=mock_runner,
test_case=mock_test_case,
execution_date=mock_execution_date,
)
def test_simple_query_no_where_clause(self):
"""Test adding WHERE clause to simple query without existing WHERE"""
sql = "SELECT * FROM users"
partition_expr = "created_at > '2023-01-01'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT * FROM users WHERE created_at > '2023-01-01'"
self.assertEqual(result, expected)
def test_simple_query_with_existing_where(self):
"""Test replacing existing WHERE clause in simple query"""
sql = "SELECT * FROM users WHERE id > 100"
partition_expr = "created_at > '2023-01-01'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT * FROM users WHERE created_at > '2023-01-01'"
self.assertEqual(result, expected)
def test_query_with_order_by_no_where(self):
"""Test adding WHERE clause before ORDER BY"""
sql = "SELECT * FROM users ORDER BY name"
partition_expr = "status = 'active'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT * FROM users WHERE status = 'active' ORDER BY name"
self.assertEqual(result, expected)
def test_query_with_group_by_no_where(self):
"""Test adding WHERE clause before GROUP BY"""
sql = "SELECT department, COUNT(*) FROM employees GROUP BY department"
partition_expr = "hire_date > '2020-01-01'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT department, COUNT(*) FROM employees WHERE hire_date > '2020-01-01' GROUP BY department"
self.assertEqual(result, expected)
def test_query_with_subquery_inner_where_preserved(self):
"""Test that WHERE clause in subquery is preserved"""
sql = "SELECT foo FROM a INNER JOIN (SELECT bar FROM b WHERE abc = 3) WHERE a.id BETWEEN 2 AND 4"
partition_expr = "a.status = 'active'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT foo FROM a INNER JOIN (SELECT bar FROM b WHERE abc = 3) WHERE a.status = 'active'"
self.assertEqual(result, expected)
def test_complex_subquery_multiple_levels(self):
"""Test complex nested subqueries with multiple WHERE clauses"""
sql = """SELECT * FROM users u
WHERE u.id IN (
SELECT user_id FROM orders o
WHERE o.total > (
SELECT AVG(total) FROM orders WHERE status = 'completed'
)
)"""
partition_expr = "u.created_at > '2023-01-01'"
result = self.validator._replace_where_clause(sql, partition_expr)
self.assertIn("WHERE u.created_at > '2023-01-01'", result)
self.assertNotIn("WHERE u.id IN", result)
def test_query_with_cte_where_preserved(self):
"""Test that WHERE clauses in CTEs are preserved"""
sql = """WITH active_users AS (
SELECT * FROM users WHERE status = 'active'
)
SELECT * FROM active_users WHERE id > 100"""
partition_expr = "created_at > '2023-01-01'"
result = self.validator._replace_where_clause(sql, partition_expr)
self.assertIn("WHERE status = 'active'", result)
self.assertIn("WHERE created_at > '2023-01-01'", result)
self.assertNotIn("WHERE id > 100", result)
def test_query_with_union_no_where(self):
"""Test adding WHERE clause before UNION"""
sql = "SELECT id FROM table1 UNION SELECT id FROM table2"
partition_expr = "status = 'active'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = (
"SELECT id FROM table1 WHERE status = 'active' UNION SELECT id FROM table2"
)
self.assertEqual(result, expected)
def test_query_with_having_no_where(self):
"""Test adding WHERE clause before HAVING"""
sql = "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5"
partition_expr = "hire_date > '2020-01-01'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT department, COUNT(*) FROM employees WHERE hire_date > '2020-01-01' GROUP BY department HAVING COUNT(*) > 5"
self.assertEqual(result, expected)
def test_query_with_limit_no_where(self):
"""Test adding WHERE clause before LIMIT"""
sql = "SELECT * FROM users LIMIT 10"
partition_expr = "status = 'active'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT * FROM users WHERE status = 'active' LIMIT 10"
self.assertEqual(result, expected)
def test_query_with_offset_no_where(self):
"""Test adding WHERE clause before OFFSET"""
sql = "SELECT * FROM users OFFSET 20"
partition_expr = "status = 'active'"
result = self.validator._replace_where_clause(sql, partition_expr)
expected = "SELECT * FROM users WHERE status = 'active' OFFSET 20"
self.assertEqual(result, expected)
def test_complex_query_with_joins_and_subqueries(self):
"""Test complex query with joins, subqueries, and existing WHERE"""
sql = """SELECT u.name, o.total
FROM users u
INNER JOIN orders o ON u.id = o.user_id
LEFT JOIN (
SELECT user_id, COUNT(*) as order_count
FROM orders
WHERE created_at > '2022-01-01'
GROUP BY user_id
) oc ON u.id = oc.user_id
WHERE u.status = 'active' AND o.total > 100
ORDER BY o.total DESC"""
partition_expr = "u.created_at BETWEEN '2023-01-01' AND '2023-12-31'"
result = self.validator._replace_where_clause(sql, partition_expr)
self.assertIn(
"WHERE u.created_at BETWEEN '2023-01-01' AND '2023-12-31'", result
)
self.assertIn("WHERE created_at > '2022-01-01'", result)
self.assertIn("ORDER BY o.total DESC", result)
self.assertNotIn("WHERE u.status = 'active' AND o.total > 100", result)
def test_query_with_multiple_clauses_existing_where(self):
    """An existing WHERE is swapped out while GROUP BY, HAVING, ORDER BY and LIMIT are kept as-is."""
    query = """SELECT department, AVG(salary)
FROM employees
WHERE salary > 50000
GROUP BY department
HAVING AVG(salary) > 60000
ORDER BY AVG(salary) DESC
LIMIT 5"""
    wanted = """SELECT department, AVG(salary)
FROM employees
WHERE hire_date > '2020-01-01'
GROUP BY department
HAVING AVG(salary) > 60000
ORDER BY AVG(salary) DESC
LIMIT 5"""
    rewritten = self.validator._replace_where_clause(
        query, "hire_date > '2020-01-01'"
    )
    self.assertEqual(rewritten, wanted)
def test_empty_sql_query(self):
    """An empty SQL string yields None instead of a rewritten query."""
    self.assertIsNone(
        self.validator._replace_where_clause("", "status = 'active'")
    )
def test_malformed_sql_query(self):
    """A malformed statement ("SELECT FROM") still produces a non-None result."""
    self.assertIsNotNone(
        self.validator._replace_where_clause("SELECT FROM", "status = 'active'")
    )
def test_case_insensitive_keywords(self):
    """Lower-case SQL keywords are recognised; the inserted WHERE keyword is upper-case."""
    rewritten = self.validator._replace_where_clause(
        "select * from users where id > 100 order by name",
        "status = 'active'",
    )
    self.assertEqual(
        rewritten, "select * from users WHERE status = 'active' order by name"
    )
def test_deeply_nested_subqueries(self):
    """With three levels of nested subqueries, the outermost WHERE is the one replaced."""
    query = """SELECT * FROM table1
WHERE id IN (
SELECT t2.id FROM table2 t2
WHERE t2.value > (
SELECT AVG(t3.value) FROM table3 t3
WHERE t3.category IN (
SELECT category FROM categories
WHERE active = 1
)
)
)"""
    rewritten = self.validator._replace_where_clause(
        query, "created_at > '2023-01-01'"
    )
    # The partition filter is present and the original outer predicate is gone.
    self.assertIn("WHERE created_at > '2023-01-01'", rewritten)
    self.assertNotIn("WHERE id IN", rewritten)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@ -0,0 +1,425 @@
"""
Unit tests for BaseTableCustomSQLQueryValidator passed/failed row count logic
"""
import unittest
from datetime import datetime
from unittest.mock import Mock, patch
from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import (
TableCustomSQLQueryValidator,
)
from metadata.generated.schema.tests.basic import TestCaseStatus
class TestTableCustomSQLQueryRowCounts(unittest.TestCase):
    """Test cases for passed/failed row count calculation logic.

    Every test calls ``TableCustomSQLQueryValidator.run_validation`` with
    ``get_runtime_parameters``, ``_run_results`` and ``compute_row_count``
    patched on the validator class, then checks the reported status and the
    passed/failed row counts on the result.
    """

    # Value returned by the patched ``compute_row_count`` in every scenario.
    TOTAL_ROWS = 1000

    def setUp(self):
        """Set up test fixtures: a validator built on a mocked runner and test case."""
        self.mock_runner = Mock()
        self.mock_test_case = Mock()
        self.mock_test_case.computePassedFailedRowCount = True
        self.mock_test_case.parameterValues = []
        self.mock_test_case.fullyQualifiedName = "test.case"
        self.mock_execution_date = int(datetime.now().timestamp())
        self.validator = TableCustomSQLQueryValidator(
            runner=self.mock_runner,
            test_case=self.mock_test_case,
            execution_date=self.mock_execution_date,
        )

    @staticmethod
    def _make_param(name, value):
        """Return a Mock parameter exposing ``name`` and ``value`` attributes."""
        param = Mock()
        # Assigned after construction: Mock(name=...) names the mock itself
        # instead of creating a ``name`` attribute.
        param.name = name
        param.value = value
        return param

    def _create_mock_param_values(
        self, operator, threshold, sql_expression="SELECT * FROM test"
    ):
        """Helper to create mock parameter values.

        Args:
            operator: comparison operator string stored in the "operator" parameter.
            threshold: value stored in the "threshold" parameter.
            sql_expression: query stored in the "sqlExpression" parameter.

        Returns:
            List of five Mock parameters: sqlExpression, operator, threshold,
            strategy ("COUNT") and the JSON-encoded runtime parameters.
        """
        import json

        # Runtime parameters, serialised to JSON below.
        runtime_params = {
            "conn_config": {
                "config": {
                    "type": "Mysql",
                    "scheme": "mysql+pymysql",
                    "username": "test",
                    "password": "test",
                    "hostPort": "localhost:3306",
                    "database": "test_db",
                }
            },
            "entity": {
                "id": "test-table-id",
                "name": "test_table",
                "fullyQualifiedName": "test.db.test_table",
            },
        }
        return [
            self._make_param("sqlExpression", sql_expression),
            self._make_param("operator", operator),
            self._make_param("threshold", threshold),
            self._make_param("strategy", "COUNT"),
            self._make_param(
                "TableCustomSQLQueryRuntimeParameters", json.dumps(runtime_params)
            ),
        ]

    def _check_row_counts(
        self, operator, threshold, query_rows, status, passed_rows, failed_rows
    ):
        """Run the validator for one scenario and assert its outcome.

        Args:
            operator: comparison operator configured on the test case.
            threshold: threshold configured on the test case.
            query_rows: value returned by the patched ``_run_results``.
            status: expected ``testCaseStatus`` of the result.
            passed_rows: expected ``passedRows`` of the result.
            failed_rows: expected ``failedRows`` of the result.
        """
        self.mock_test_case.parameterValues = self._create_mock_param_values(
            operator, threshold
        )
        target = TableCustomSQLQueryValidator
        # Same three attributes the per-test @patch decorators used to replace.
        with patch.object(
            target, "get_runtime_parameters", return_value=Mock()
        ), patch.object(
            target, "_run_results", return_value=query_rows
        ), patch.object(
            target, "compute_row_count", return_value=self.TOTAL_ROWS
        ):
            result = self.validator.run_validation()
        self.assertEqual(result.testCaseStatus, status)
        self.assertEqual(result.passedRows, passed_rows)
        self.assertEqual(result.failedRows, failed_rows)

    def test_greater_than_operator_success(self):
        """Test > operator when test passes"""
        self._check_row_counts(">", 5, 10, TestCaseStatus.Success, 10, 990)

    def test_greater_than_operator_failure(self):
        """Test > operator when test fails (got fewer rows than expected)"""
        self._check_row_counts(">", 10, 5, TestCaseStatus.Failed, 5, 995)

    def test_greater_than_equal_operator_success(self):
        """Test >= operator when test passes"""
        self._check_row_counts(">=", 10, 10, TestCaseStatus.Success, 10, 990)

    def test_greater_than_equal_operator_failure(self):
        """Test >= operator when test fails"""
        self._check_row_counts(">=", 15, 8, TestCaseStatus.Failed, 8, 992)

    def test_less_than_operator_success(self):
        """Test < operator when test passes"""
        self._check_row_counts("<", 10, 5, TestCaseStatus.Success, 995, 5)

    def test_less_than_operator_failure(self):
        """Test < operator when test fails (got more rows than expected)"""
        self._check_row_counts("<", 5, 12, TestCaseStatus.Failed, 988, 12)

    def test_less_than_equal_operator_success(self):
        """Test <= operator when test passes"""
        self._check_row_counts("<=", 10, 10, TestCaseStatus.Success, 990, 10)

    def test_less_than_equal_operator_failure(self):
        """Test <= operator when test fails"""
        self._check_row_counts("<=", 8, 15, TestCaseStatus.Failed, 985, 15)

    def test_equal_operator_success(self):
        """Test == operator when test passes"""
        self._check_row_counts("==", 10, 10, TestCaseStatus.Success, 10, 990)

    def test_equal_operator_failure_more_rows(self):
        """Test == operator when test fails with more rows than expected"""
        self._check_row_counts("==", 10, 15, TestCaseStatus.Failed, 995, 5)

    def test_equal_operator_failure_fewer_rows(self):
        """Test == operator when test fails with fewer rows than expected"""
        self._check_row_counts("==", 10, 3, TestCaseStatus.Failed, 3, 997)

    def test_edge_case_zero_threshold_greater_than(self):
        """Test edge case with zero threshold and > operator"""
        self._check_row_counts(">", 0, 5, TestCaseStatus.Success, 5, 995)

    def test_edge_case_zero_actual_rows_less_than(self):
        """Test edge case with zero actual rows and < operator"""
        self._check_row_counts("<", 5, 0, TestCaseStatus.Success, 1000, 0)

    def test_edge_case_negative_failed_rows_protection(self):
        """Test >= operator when the query returns more rows than the threshold.

        The reported failed-row count must be total minus returned rows
        (990), i.e. it never goes negative for this input.
        """
        self._check_row_counts(">=", 5, 10, TestCaseStatus.Success, 10, 990)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@ -0,0 +1,192 @@
"""
Unit tests for edge cases with zero or negative thresholds
"""
import unittest
def calculate_less_than_failure_fixed(
    threshold: int, len_rows: int, row_count: int
) -> tuple[int, int]:
    """
    Fixed implementation of _calculate_less_than_failure

    Args:
        threshold: upper bound the query result was expected to stay under.
        len_rows: number of rows returned by the test query.
        row_count: total row count; a falsy value (0 or None) is treated
            as "not available".

    Returns:
        ``(passed_rows, failed_rows)``, each clamped to be non-negative.
    """
    if threshold > 0:
        # Only the rows beyond the threshold count as failures.
        failed = max(0, len_rows - threshold)
        passed = row_count - failed if row_count else threshold
        return max(0, passed), max(0, failed)

    # A threshold <= 0 means nothing passes: report every known row as
    # failed, or at least one failure when the total is unavailable.
    failed = row_count if row_count else max(len_rows, 1)
    return 0, max(0, failed)
def calculate_less_than_failure_old(
    threshold: int, len_rows: int, row_count: int
) -> tuple[int, int]:
    """
    Original buggy implementation for comparison

    Counts only the rows beyond ``threshold`` as failed, regardless of
    whether the threshold is zero or negative. Returns
    ``(passed_rows, failed_rows)``, each clamped to be non-negative.
    """
    excess = max(0, len_rows - threshold)
    if row_count:
        remaining = row_count - excess
    else:
        # No usable total row count: fall back to the threshold itself.
        remaining = threshold
    return max(0, remaining), max(0, excess)
class TestZeroThresholdEdgeCases(unittest.TestCase):
    """Test cases for zero and negative threshold edge cases"""

    def _compare(self, label, threshold, len_rows, row_count):
        """Run the old and fixed calculations on the same inputs and print both.

        Returns ``(old, new)`` where each element is a
        ``(passed_rows, failed_rows)`` tuple.
        """
        old = calculate_less_than_failure_old(threshold, len_rows, row_count)
        new = calculate_less_than_failure_fixed(threshold, len_rows, row_count)
        print(
            f"{label} - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
        )
        print(f"Old: passed={old[0]}, failed={old[1]}")
        print(f"New: passed={new[0]}, failed={new[1]}")
        return old, new

    def test_bug_case_less_than_zero_threshold(self):
        """Reported bug case: zero threshold, no query rows, one table row"""
        old, new = self._compare("Bug case", threshold=0, len_rows=0, row_count=1)
        self.assertEqual(old, (1, 0))
        self.assertEqual(new, (0, 1))

    def test_less_than_negative_threshold(self):
        """Threshold of -1 (impossible expectation): every row is failed"""
        old, new = self._compare(
            "Negative threshold", threshold=-1, len_rows=5, row_count=100
        )
        self.assertEqual(old, (94, 6))
        self.assertEqual(new, (0, 100))

    def test_less_than_zero_no_results(self):
        """Zero threshold with no query results and a 50-row table"""
        old, new = self._compare(
            "Zero threshold, no results", threshold=0, len_rows=0, row_count=50
        )
        self.assertEqual(old, (50, 0))
        self.assertEqual(new, (0, 50))

    def test_normal_case_still_works(self):
        """Positive thresholds give the same answer in both implementations"""
        old, new = self._compare(
            "Normal case", threshold=10, len_rows=15, row_count=100
        )
        self.assertEqual(old, new)
        self.assertEqual(new, (95, 5))

    def test_edge_case_no_row_count(self):
        """Zero threshold with no total row count available"""
        len_rows, threshold, row_count = 0, 0, None
        # Only the fixed implementation is exercised in this scenario.
        passed, failed = calculate_less_than_failure_fixed(
            threshold, len_rows, row_count
        )
        print(
            f"No row count - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
        )
        print(f"New: passed={passed}, failed={failed}")
        self.assertEqual((passed, failed), (0, 1))

    def test_less_than_equal_zero_with_results(self):
        """Zero threshold with some query results"""
        old, new = self._compare(
            "<= 0 with results", threshold=0, len_rows=3, row_count=20
        )
        self.assertEqual(old, (17, 3))
        self.assertEqual(new, (0, 20))
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()