mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-12-02 10:36:29 +00:00
* fix: added logic to compute pass/fail for sql queries with cte, nested queries, and joins * added logic to correctly compute pass / fail rows * style: ran python linting * fix: failing tests * style: fix linting error * fix: flawed count logic * fix: handle case where we don't compute row count
This commit is contained in:
parent
79fde4ab02
commit
57c5a50d20
@ -87,11 +87,13 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
|
||||
[TestResultValue(name=RESULT_ROW_COUNT, value=None)],
|
||||
)
|
||||
len_rows = rows if isinstance(rows, int) else len(rows)
|
||||
if evaluate_threshold(
|
||||
test_passed = evaluate_threshold(
|
||||
threshold,
|
||||
operator,
|
||||
len_rows,
|
||||
):
|
||||
)
|
||||
|
||||
if test_passed:
|
||||
status = TestCaseStatus.Success
|
||||
result_value = len_rows
|
||||
else:
|
||||
@ -99,17 +101,23 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
|
||||
result_value = len_rows
|
||||
|
||||
if self.test_case.computePassedFailedRowCount:
|
||||
row_count = self.get_row_count()
|
||||
row_count = self._get_total_row_count_if_needed()
|
||||
passed_rows, failed_rows = self._calculate_passed_failed_rows(
|
||||
test_passed, operator, threshold, len_rows, row_count
|
||||
)
|
||||
else:
|
||||
passed_rows = None
|
||||
failed_rows = None
|
||||
row_count = None
|
||||
|
||||
return self.get_test_case_result_object(
|
||||
self.execution_date,
|
||||
status,
|
||||
f"Found {result_value} row(s). Test query is expected to return {threshold} row.",
|
||||
f"Found {result_value} row(s). Test query is expected to return {operator} {threshold} row(s).",
|
||||
[TestResultValue(name=RESULT_ROW_COUNT, value=str(result_value))],
|
||||
row_count=row_count,
|
||||
failed_rows=result_value,
|
||||
failed_rows=failed_rows,
|
||||
passed_rows=passed_rows,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@ -132,3 +140,100 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator):
|
||||
Tuple[int, int]:
|
||||
"""
|
||||
return self.compute_row_count()
|
||||
|
||||
def _get_total_row_count_if_needed(self) -> int:
|
||||
"""Get total row count if computePassedFailedRowCount is enabled"""
|
||||
return self.get_row_count()
|
||||
|
||||
def _calculate_passed_failed_rows(
|
||||
self,
|
||||
test_passed: bool,
|
||||
operator: str,
|
||||
threshold: int,
|
||||
len_rows: int,
|
||||
row_count: int,
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate passed and failed rows based on test result and operator
|
||||
|
||||
Args:
|
||||
test_passed: Whether the test passed
|
||||
operator: Comparison operator (>, >=, <, <=, ==)
|
||||
threshold: Expected threshold value
|
||||
len_rows: Number of rows returned by the test query
|
||||
row_count: Total number of rows in the table (or None)
|
||||
|
||||
Returns:
|
||||
Tuple of (passed_rows, failed_rows)
|
||||
"""
|
||||
if test_passed:
|
||||
return self._calculate_passed_rows_success(operator, len_rows, row_count)
|
||||
return self._calculate_passed_rows_failure(
|
||||
operator, threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
def _calculate_passed_rows_success(
|
||||
self, operator: str, len_rows: int, row_count: int
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate passed/failed rows when test passed"""
|
||||
if operator in (">", ">="):
|
||||
passed_rows = len_rows
|
||||
failed_rows = (row_count - len_rows) if row_count else 0
|
||||
elif operator in ("<", "<="):
|
||||
passed_rows = row_count - len_rows
|
||||
failed_rows = len_rows
|
||||
elif operator == "==":
|
||||
passed_rows = len_rows
|
||||
failed_rows = row_count - len_rows
|
||||
else:
|
||||
passed_rows = len_rows
|
||||
failed_rows = 0
|
||||
|
||||
return max(0, passed_rows), max(0, failed_rows)
|
||||
|
||||
def _calculate_passed_rows_failure(
|
||||
self, operator: str, threshold: int, len_rows: int, row_count: int
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate passed/failed rows when test failed"""
|
||||
if operator in (">", ">="):
|
||||
return self._calculate_greater_than_failure(len_rows, row_count)
|
||||
if operator in ("<", "<="):
|
||||
return self._calculate_less_than_failure(len_rows, row_count)
|
||||
if operator == "==":
|
||||
return self._calculate_equal_failure(threshold, len_rows, row_count)
|
||||
|
||||
failed_rows = row_count if row_count else len_rows
|
||||
return 0, max(0, failed_rows)
|
||||
|
||||
def _calculate_greater_than_failure(
|
||||
self, len_rows: int, row_count: int
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate rows for > or >= operator failure (expected more rows)"""
|
||||
passed_rows = len_rows
|
||||
failed_rows = (row_count - len_rows) if row_count else 0
|
||||
return max(0, passed_rows), max(0, failed_rows)
|
||||
|
||||
def _calculate_less_than_failure(
|
||||
self, len_rows: int, row_count: int
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate rows for < or <= operator failure (expected fewer rows)"""
|
||||
failed_rows = len_rows
|
||||
passed_rows = row_count - failed_rows
|
||||
|
||||
return max(0, passed_rows), max(0, failed_rows)
|
||||
|
||||
def _calculate_equal_failure(
|
||||
self, threshold: int, len_rows: int, row_count: int
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate rows for == operator failure (expected exact count)"""
|
||||
if row_count:
|
||||
if len_rows > threshold:
|
||||
failed_rows = len_rows - threshold
|
||||
passed_rows = row_count - failed_rows
|
||||
else:
|
||||
failed_rows = row_count - len_rows
|
||||
passed_rows = len_rows
|
||||
else:
|
||||
failed_rows = abs(len_rows - threshold)
|
||||
passed_rows = 0
|
||||
|
||||
return max(0, passed_rows), max(0, failed_rows)
|
||||
|
||||
@ -13,10 +13,13 @@
|
||||
Validator for table custom SQL Query test case
|
||||
"""
|
||||
|
||||
from typing import Optional, cast
|
||||
from typing import Optional, Tuple, cast
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.sql import func, select
|
||||
from sqlparse.sql import Statement, Token, Where
|
||||
from sqlparse.tokens import Keyword
|
||||
|
||||
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
|
||||
SQAValidatorMixin,
|
||||
@ -33,11 +36,244 @@ from metadata.profiler.metrics.registry import Metrics
|
||||
from metadata.profiler.orm.functions.table_metric_computer import TableMetricComputer
|
||||
from metadata.profiler.processor.runner import QueryRunner
|
||||
from metadata.utils.helpers import is_safe_sql_query
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
|
||||
class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidatorMixin):
|
||||
"""Validator for table custom SQL Query test case"""
|
||||
|
||||
def _replace_where_clause(
|
||||
self, sql_query: str, partition_expression: str
|
||||
) -> Optional[str]:
|
||||
"""Replace or add WHERE clause in SQL query using sqlparse.
|
||||
|
||||
This method properly handles:
|
||||
- Queries with existing WHERE clause (replaces it)
|
||||
- Queries without WHERE clause (adds it)
|
||||
- Complex queries with joins, subqueries, CTEs
|
||||
- Preserves GROUP BY, ORDER BY, LIMIT, etc.
|
||||
|
||||
Args:
|
||||
sql_query: Original SQL query
|
||||
partition_expression: New WHERE condition (without WHERE keyword)
|
||||
|
||||
Returns:
|
||||
Modified SQL query with partition_expression as WHERE clause
|
||||
"""
|
||||
parsed = sqlparse.parse(sql_query)
|
||||
if not parsed or len(parsed) == 0:
|
||||
return None
|
||||
|
||||
statement: Statement = parsed[0]
|
||||
tokens = list(statement.tokens)
|
||||
|
||||
where_idx, where_end_idx, insert_before_idx = self._find_clause_positions(
|
||||
tokens
|
||||
)
|
||||
new_tokens = self._build_new_tokens(
|
||||
tokens, where_idx, where_end_idx, insert_before_idx, partition_expression
|
||||
)
|
||||
|
||||
return "".join(str(token) for token in new_tokens)
|
||||
|
||||
def _find_clause_positions(
|
||||
self, tokens: list
|
||||
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
|
||||
"""Find positions of WHERE clause and insertion points in token list.
|
||||
|
||||
Args:
|
||||
tokens: List of parsed SQL tokens
|
||||
|
||||
Returns:
|
||||
Tuple of (where_idx, where_end_idx, insert_before_idx)
|
||||
"""
|
||||
where_idx = None
|
||||
where_end_idx = None
|
||||
insert_before_idx = None
|
||||
paren_depth = 0
|
||||
|
||||
for i, token in enumerate(tokens):
|
||||
paren_depth = self._update_parentheses_depth(token, paren_depth)
|
||||
|
||||
if isinstance(token, Where) and paren_depth == 0:
|
||||
where_idx = i
|
||||
where_end_idx = i + 1
|
||||
break
|
||||
|
||||
if self._should_insert_before_token(token, insert_before_idx, paren_depth):
|
||||
insert_before_idx = i
|
||||
|
||||
return where_idx, where_end_idx, insert_before_idx
|
||||
|
||||
def _update_parentheses_depth(self, token: Token, current_depth: int) -> int:
|
||||
"""Update parentheses depth based on token content.
|
||||
|
||||
Args:
|
||||
token: SQL token to analyze
|
||||
current_depth: Current parentheses depth
|
||||
|
||||
Returns:
|
||||
Updated parentheses depth
|
||||
"""
|
||||
if token.ttype is None and hasattr(token, "tokens"):
|
||||
paren_count = str(token).count("(") - str(token).count(")")
|
||||
return current_depth + paren_count
|
||||
elif token.value == "(":
|
||||
return current_depth + 1
|
||||
elif token.value == ")":
|
||||
return current_depth - 1
|
||||
return current_depth
|
||||
|
||||
def _should_insert_before_token(
|
||||
self, token: Token, insert_before_idx: Optional[int], paren_depth: int
|
||||
) -> bool:
|
||||
"""Check if WHERE clause should be inserted before this token.
|
||||
|
||||
Args:
|
||||
token: SQL token to check
|
||||
insert_before_idx: Current insertion index (None if not set)
|
||||
paren_depth: Current parentheses depth
|
||||
|
||||
Returns:
|
||||
True if WHERE should be inserted before this token
|
||||
"""
|
||||
if insert_before_idx is not None or paren_depth != 0:
|
||||
return False
|
||||
|
||||
if token.ttype is not Keyword:
|
||||
return False
|
||||
|
||||
clause_keywords = {
|
||||
"GROUP BY",
|
||||
"ORDER BY",
|
||||
"HAVING",
|
||||
"LIMIT",
|
||||
"OFFSET",
|
||||
"UNION",
|
||||
"EXCEPT",
|
||||
"INTERSECT",
|
||||
}
|
||||
|
||||
return any(keyword in token.value.upper() for keyword in clause_keywords)
|
||||
|
||||
def _build_new_tokens(
|
||||
self,
|
||||
tokens: list,
|
||||
where_idx: Optional[int],
|
||||
where_end_idx: Optional[int],
|
||||
insert_before_idx: Optional[int],
|
||||
partition_expression: str,
|
||||
) -> list:
|
||||
"""Build new token list with WHERE clause inserted or replaced.
|
||||
|
||||
Args:
|
||||
tokens: Original token list
|
||||
where_idx: Index of existing WHERE clause (None if not found)
|
||||
where_end_idx: End index of existing WHERE clause
|
||||
insert_before_idx: Index to insert WHERE before (None if append)
|
||||
partition_expression: WHERE condition expression
|
||||
|
||||
Returns:
|
||||
New list of tokens with WHERE clause
|
||||
"""
|
||||
if where_idx is not None:
|
||||
return self._replace_existing_where(
|
||||
tokens, where_idx, where_end_idx, partition_expression
|
||||
)
|
||||
elif insert_before_idx is not None:
|
||||
return self._insert_where_before_clause(
|
||||
tokens, insert_before_idx, partition_expression
|
||||
)
|
||||
else:
|
||||
return self._append_where_clause(tokens, partition_expression)
|
||||
|
||||
def _replace_existing_where(
|
||||
self,
|
||||
tokens: list,
|
||||
where_idx: int,
|
||||
where_end_idx: int,
|
||||
partition_expression: str,
|
||||
) -> list:
|
||||
"""Replace existing WHERE clause with new expression.
|
||||
|
||||
Args:
|
||||
tokens: Original token list
|
||||
where_idx: Index of WHERE clause to replace
|
||||
where_end_idx: End index of WHERE clause
|
||||
partition_expression: New WHERE condition
|
||||
|
||||
Returns:
|
||||
Token list with replaced WHERE clause
|
||||
"""
|
||||
original_where = str(tokens[where_idx])
|
||||
trailing_whitespace = self._extract_trailing_whitespace(original_where)
|
||||
|
||||
return (
|
||||
tokens[:where_idx]
|
||||
+ [
|
||||
Token(Keyword, "WHERE"),
|
||||
Token(None, f" {partition_expression}{trailing_whitespace}"),
|
||||
]
|
||||
+ tokens[where_end_idx:]
|
||||
)
|
||||
|
||||
def _insert_where_before_clause(
|
||||
self, tokens: list, insert_before_idx: int, partition_expression: str
|
||||
) -> list:
|
||||
"""Insert WHERE clause before specified token index.
|
||||
|
||||
Args:
|
||||
tokens: Original token list
|
||||
insert_before_idx: Index to insert WHERE clause before
|
||||
partition_expression: WHERE condition expression
|
||||
|
||||
Returns:
|
||||
Token list with WHERE clause inserted
|
||||
"""
|
||||
return (
|
||||
tokens[:insert_before_idx]
|
||||
+ [Token(Keyword, "WHERE"), Token(None, f" {partition_expression} ")]
|
||||
+ tokens[insert_before_idx:]
|
||||
)
|
||||
|
||||
def _append_where_clause(self, tokens: list, partition_expression: str) -> list:
|
||||
"""Append WHERE clause to end of token list.
|
||||
|
||||
Args:
|
||||
tokens: Original token list
|
||||
partition_expression: WHERE condition expression
|
||||
|
||||
Returns:
|
||||
Token list with WHERE clause appended
|
||||
"""
|
||||
return tokens + [
|
||||
Token(None, " "),
|
||||
Token(Keyword, "WHERE"),
|
||||
Token(None, f" {partition_expression}"),
|
||||
]
|
||||
|
||||
def _extract_trailing_whitespace(self, where_clause: str) -> str:
|
||||
"""Extract trailing whitespace from WHERE clause string.
|
||||
|
||||
Args:
|
||||
where_clause: Original WHERE clause string
|
||||
|
||||
Returns:
|
||||
Trailing whitespace string
|
||||
"""
|
||||
if not where_clause.split():
|
||||
return ""
|
||||
|
||||
last_word = where_clause.split()[-1]
|
||||
where_content_end = where_clause.rfind(last_word) + len(last_word)
|
||||
|
||||
if where_content_end < len(where_clause):
|
||||
return where_clause[where_content_end:]
|
||||
|
||||
return ""
|
||||
|
||||
def run_validation(self) -> TestCaseResult:
|
||||
"""Run validation for the given test case
|
||||
|
||||
@ -85,12 +321,39 @@ class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidato
|
||||
None,
|
||||
)
|
||||
if partition_expression:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(self.runner.table)
|
||||
.filter(text(partition_expression))
|
||||
custom_sql = self.get_test_case_param_value(
|
||||
self.test_case.parameterValues, # type: ignore
|
||||
"sqlExpression",
|
||||
str,
|
||||
)
|
||||
return self.runner.session.execute(stmt).scalar()
|
||||
|
||||
if custom_sql:
|
||||
modified_query = self._replace_where_clause(
|
||||
custom_sql, partition_expression
|
||||
)
|
||||
if modified_query is None:
|
||||
return None
|
||||
count_query = f"SELECT COUNT(*) FROM ({modified_query}) AS test_results"
|
||||
|
||||
try:
|
||||
result = self.runner.session.execute(text(count_query)).scalar()
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to execute custom SQL with partition expression. "
|
||||
f"Query: {count_query}\n"
|
||||
f"Error: {exc}\n",
|
||||
exc_info=True,
|
||||
)
|
||||
self.runner.session.rollback()
|
||||
raise exc
|
||||
else:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(self.runner.table)
|
||||
.filter(text(partition_expression))
|
||||
)
|
||||
return self.runner.session.execute(stmt).scalar()
|
||||
|
||||
self.runner = cast(QueryRunner, self.runner)
|
||||
dialect = self.runner._session.get_bind().dialect.name
|
||||
|
||||
@ -86,7 +86,8 @@ class OpenMetadataValidationAction1xx(ValidationAction):
|
||||
# This will be initialized in the run method
|
||||
ometa_conn: Optional[OpenMetadata] = None
|
||||
|
||||
def run( # pylint: disable=unused-argument, arguments-differ
|
||||
# pylint: disable=unused-argument,arguments-differ
|
||||
def run(
|
||||
self,
|
||||
checkpoint_result: CheckpointResult,
|
||||
action_context: Union[ActionContext, None],
|
||||
|
||||
@ -26,7 +26,7 @@ def create_service_request(postgres_container, tmp_path_factory):
|
||||
username=postgres_container.username,
|
||||
authType=BasicAuth(password=postgres_container.password),
|
||||
hostPort="localhost:"
|
||||
+ postgres_container.get_exposed_port(postgres_container.port),
|
||||
+ str(postgres_container.get_exposed_port(postgres_container.port)),
|
||||
database="dvdrental",
|
||||
)
|
||||
),
|
||||
|
||||
@ -0,0 +1,180 @@
|
||||
"""
|
||||
Unit tests for passed/failed row count calculation logic
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
def calculate_passed_failed_rows(
|
||||
test_passed: bool,
|
||||
operator: str,
|
||||
threshold: int,
|
||||
actual_rows: int,
|
||||
total_rows: int = None,
|
||||
):
|
||||
"""
|
||||
Calculate passed and failed rows based on test result, operator, threshold, and actual row count.
|
||||
|
||||
This function replicates the LEGACY logic for cases without total row count.
|
||||
Note: This is kept for backward compatibility but the new logic with total_rows is more accurate.
|
||||
"""
|
||||
if total_rows is None:
|
||||
if test_passed:
|
||||
return actual_rows, 0
|
||||
else:
|
||||
if operator in (">", ">="):
|
||||
failed_rows = 0
|
||||
passed_rows = actual_rows
|
||||
elif operator in ("<", "<="):
|
||||
failed_rows = max(0, actual_rows - threshold)
|
||||
passed_rows = threshold
|
||||
elif operator == "==":
|
||||
failed_rows = abs(actual_rows - threshold)
|
||||
passed_rows = 0
|
||||
else:
|
||||
failed_rows = actual_rows
|
||||
passed_rows = 0
|
||||
|
||||
return max(0, passed_rows), max(0, failed_rows)
|
||||
else:
|
||||
|
||||
raise NotImplementedError(
|
||||
"Use test_row_count_logic_with_total.py for total row count tests"
|
||||
)
|
||||
|
||||
|
||||
class TestPassedFailedRowCalculation(unittest.TestCase):
|
||||
"""Test cases for passed/failed row count calculation logic"""
|
||||
|
||||
def test_greater_than_operator_success(self):
|
||||
"""Test > operator when test passes"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator=">", threshold=5, actual_rows=10
|
||||
)
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_greater_than_operator_failure(self):
|
||||
"""Test > operator when test fails (got fewer rows than expected)"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator=">", threshold=10, actual_rows=5
|
||||
)
|
||||
self.assertEqual(passed_rows, 5)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_greater_than_equal_operator_success(self):
|
||||
"""Test >= operator when test passes"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator=">=", threshold=10, actual_rows=10
|
||||
)
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_greater_than_equal_operator_failure(self):
|
||||
"""Test >= operator when test fails"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator=">=", threshold=15, actual_rows=8
|
||||
)
|
||||
self.assertEqual(passed_rows, 8)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_less_than_operator_success(self):
|
||||
"""Test < operator when test passes"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator="<", threshold=10, actual_rows=5
|
||||
)
|
||||
self.assertEqual(passed_rows, 5)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_less_than_operator_failure(self):
|
||||
"""Test < operator when test fails (got more rows than expected)"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator="<", threshold=5, actual_rows=12
|
||||
)
|
||||
self.assertEqual(passed_rows, 5)
|
||||
self.assertEqual(failed_rows, 7)
|
||||
|
||||
def test_less_than_equal_operator_success(self):
|
||||
"""Test <= operator when test passes"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator="<=", threshold=10, actual_rows=10
|
||||
)
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_less_than_equal_operator_failure(self):
|
||||
"""Test <= operator when test fails"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator="<=", threshold=8, actual_rows=15
|
||||
)
|
||||
self.assertEqual(passed_rows, 8)
|
||||
self.assertEqual(failed_rows, 7)
|
||||
|
||||
def test_equal_operator_success(self):
|
||||
"""Test == operator when test passes"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator="==", threshold=10, actual_rows=10
|
||||
)
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_equal_operator_failure_more_rows(self):
|
||||
"""Test == operator when test fails with more rows than expected"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator="==", threshold=10, actual_rows=15
|
||||
)
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 5)
|
||||
|
||||
def test_equal_operator_failure_fewer_rows(self):
|
||||
"""Test == operator when test fails with fewer rows than expected"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator="==", threshold=10, actual_rows=3
|
||||
)
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 7)
|
||||
|
||||
def test_edge_case_zero_threshold_greater_than(self):
|
||||
"""Test edge case with zero threshold and > operator"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator=">", threshold=0, actual_rows=5
|
||||
)
|
||||
self.assertEqual(passed_rows, 5)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_edge_case_zero_actual_rows_less_than(self):
|
||||
"""Test edge case with zero actual rows and < operator"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator="<", threshold=5, actual_rows=0
|
||||
)
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_edge_case_negative_protection_greater_than_equal(self):
|
||||
"""Test protection against negative calculations for >= operator"""
|
||||
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=True, operator=">=", threshold=5, actual_rows=10
|
||||
)
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_equal_operator_with_zero_threshold(self):
|
||||
"""Test == operator with zero threshold"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator="==", threshold=0, actual_rows=5
|
||||
)
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 5)
|
||||
|
||||
def test_unknown_operator_fallback(self):
|
||||
"""Test fallback for unknown operators"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows(
|
||||
test_passed=False, operator="!=", threshold=10, actual_rows=15
|
||||
)
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,302 @@
|
||||
"""
|
||||
Unit tests for passed/failed row count calculation with total row count
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
def calculate_passed_failed_rows_with_total(
|
||||
test_passed: bool,
|
||||
operator: str,
|
||||
threshold: int,
|
||||
actual_rows: int,
|
||||
total_rows: int = None,
|
||||
):
|
||||
"""
|
||||
Calculate passed and failed rows considering total row count.
|
||||
|
||||
This replicates the fixed logic from BaseTableCustomSQLQueryValidator.
|
||||
"""
|
||||
row_count = total_rows
|
||||
len_rows = actual_rows
|
||||
|
||||
if test_passed:
|
||||
|
||||
if operator in (">", ">=", "=="):
|
||||
|
||||
passed_rows = len_rows
|
||||
failed_rows = (row_count - len_rows) if row_count else 0
|
||||
elif operator in ("<", "<="):
|
||||
|
||||
passed_rows = row_count if row_count else len_rows
|
||||
failed_rows = 0
|
||||
else:
|
||||
|
||||
passed_rows = len_rows
|
||||
failed_rows = 0
|
||||
else:
|
||||
|
||||
if operator in (">", ">="):
|
||||
|
||||
passed_rows = len_rows
|
||||
failed_rows = (row_count - len_rows) if row_count else 0
|
||||
elif operator in ("<", "<="):
|
||||
|
||||
if threshold <= 0:
|
||||
|
||||
if row_count:
|
||||
failed_rows = row_count
|
||||
passed_rows = 0
|
||||
else:
|
||||
|
||||
failed_rows = max(len_rows, 1)
|
||||
passed_rows = 0
|
||||
else:
|
||||
|
||||
failed_rows = max(0, len_rows - threshold)
|
||||
passed_rows = (row_count - failed_rows) if row_count else threshold
|
||||
elif operator == "==":
|
||||
|
||||
if row_count:
|
||||
|
||||
if len_rows > threshold:
|
||||
|
||||
failed_rows = len_rows - threshold
|
||||
passed_rows = row_count - failed_rows
|
||||
else:
|
||||
|
||||
failed_rows = row_count - len_rows
|
||||
passed_rows = len_rows
|
||||
else:
|
||||
|
||||
failed_rows = abs(len_rows - threshold)
|
||||
passed_rows = 0
|
||||
else:
|
||||
|
||||
failed_rows = row_count if row_count else len_rows
|
||||
passed_rows = 0
|
||||
|
||||
passed_rows = max(0, passed_rows)
|
||||
failed_rows = max(0, failed_rows)
|
||||
|
||||
return passed_rows, failed_rows
|
||||
|
||||
|
||||
class TestPassedFailedRowCalculationWithTotal(unittest.TestCase):
|
||||
"""Test cases for the fixed row count calculation logic with total row count"""
|
||||
|
||||
def test_greater_than_operator_bug_case(self):
|
||||
"""Test the bug case: >= with threshold 0, 0 rows returned, 1 total row"""
|
||||
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator=">=",
|
||||
threshold=1,
|
||||
actual_rows=0,
|
||||
total_rows=1,
|
||||
)
|
||||
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 1)
|
||||
|
||||
def test_greater_than_with_total_rows_success(self):
|
||||
"""Test > operator success with total row count"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=True,
|
||||
operator=">",
|
||||
threshold=5,
|
||||
actual_rows=10,
|
||||
total_rows=100,
|
||||
)
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 90)
|
||||
|
||||
def test_greater_than_with_total_rows_failure(self):
|
||||
"""Test > operator failure with total row count"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator=">",
|
||||
threshold=50,
|
||||
actual_rows=10,
|
||||
total_rows=100,
|
||||
)
|
||||
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 90)
|
||||
|
||||
def test_less_than_with_total_rows_success(self):
|
||||
"""Test < operator success with total row count"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=True,
|
||||
operator="<",
|
||||
threshold=10,
|
||||
actual_rows=5,
|
||||
total_rows=100,
|
||||
)
|
||||
|
||||
self.assertEqual(passed_rows, 100)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_less_than_with_total_rows_failure(self):
|
||||
"""Test < operator failure with total row count"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator="<",
|
||||
threshold=10,
|
||||
actual_rows=20,
|
||||
total_rows=100,
|
||||
)
|
||||
|
||||
self.assertEqual(failed_rows, 10)
|
||||
self.assertEqual(passed_rows, 90)
|
||||
|
||||
def test_less_than_equal_with_total_rows_success(self):
|
||||
"""Test <= operator success with total row count"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=True,
|
||||
operator="<=",
|
||||
threshold=10,
|
||||
actual_rows=10,
|
||||
total_rows=100,
|
||||
)
|
||||
self.assertEqual(passed_rows, 100)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_less_than_equal_with_total_rows_failure(self):
|
||||
"""Test <= operator failure with total row count"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator="<=",
|
||||
threshold=5,
|
||||
actual_rows=15,
|
||||
total_rows=100,
|
||||
)
|
||||
|
||||
self.assertEqual(failed_rows, 10)
|
||||
self.assertEqual(passed_rows, 90)
|
||||
|
||||
def test_equal_with_total_rows_success(self):
|
||||
"""Test == operator success with total row count"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=True,
|
||||
operator="==",
|
||||
threshold=10,
|
||||
actual_rows=10,
|
||||
total_rows=100,
|
||||
)
|
||||
self.assertEqual(passed_rows, 10)
|
||||
self.assertEqual(failed_rows, 90)
|
||||
|
||||
def test_equal_with_total_rows_failure_too_many(self):
|
||||
"""Test == operator failure with too many rows"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator="==",
|
||||
threshold=10,
|
||||
actual_rows=15,
|
||||
total_rows=100,
|
||||
)
|
||||
|
||||
self.assertEqual(failed_rows, 5)
|
||||
self.assertEqual(passed_rows, 95)
|
||||
|
||||
def test_equal_with_total_rows_failure_too_few(self):
|
||||
"""Test == operator failure with too few rows"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator="==",
|
||||
threshold=10,
|
||||
actual_rows=3,
|
||||
total_rows=100,
|
||||
)
|
||||
|
||||
self.assertEqual(failed_rows, 97)
|
||||
self.assertEqual(passed_rows, 3)
|
||||
|
||||
def test_edge_case_zero_total_rows(self):
|
||||
"""Test edge case with zero total rows"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False, operator=">", threshold=5, actual_rows=0, total_rows=0
|
||||
)
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_edge_case_no_total_row_count(self):
|
||||
"""Test when total row count is not available (None)"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator=">",
|
||||
threshold=10,
|
||||
actual_rows=5,
|
||||
total_rows=None,
|
||||
)
|
||||
|
||||
self.assertEqual(passed_rows, 5)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_greater_equal_zero_threshold_zero_result(self):
|
||||
"""Test >= 0 with 0 results but rows in table"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=True,
|
||||
operator=">=",
|
||||
threshold=0,
|
||||
actual_rows=0,
|
||||
total_rows=100,
|
||||
)
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 100)
|
||||
|
||||
def test_all_rows_match_greater_than(self):
|
||||
"""Test when all rows in table match the condition for >"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=True,
|
||||
operator=">",
|
||||
threshold=5,
|
||||
actual_rows=100,
|
||||
total_rows=100,
|
||||
)
|
||||
self.assertEqual(passed_rows, 100)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_no_rows_match_less_than(self):
|
||||
"""Test when no rows match the condition for <"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=True,
|
||||
operator="<",
|
||||
threshold=10,
|
||||
actual_rows=0,
|
||||
total_rows=100,
|
||||
)
|
||||
self.assertEqual(passed_rows, 100)
|
||||
self.assertEqual(failed_rows, 0)
|
||||
|
||||
def test_less_than_zero_threshold_failure_bug_case(self):
|
||||
"""Test the reported bug case: < 0 threshold with failure"""
|
||||
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator="<",
|
||||
threshold=0,
|
||||
actual_rows=0,
|
||||
total_rows=1,
|
||||
)
|
||||
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 1)
|
||||
|
||||
def test_less_than_negative_threshold_failure(self):
|
||||
"""Test < -1 threshold failure (impossible expectation)"""
|
||||
passed_rows, failed_rows = calculate_passed_failed_rows_with_total(
|
||||
test_passed=False,
|
||||
operator="<",
|
||||
threshold=-1,
|
||||
actual_rows=5,
|
||||
total_rows=100,
|
||||
)
|
||||
|
||||
self.assertEqual(passed_rows, 0)
|
||||
self.assertEqual(failed_rows, 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,247 @@
|
||||
"""
|
||||
Unit tests for TableCustomSQLQueryValidator._replace_where_clause method
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import (
|
||||
TableCustomSQLQueryValidator,
|
||||
)
|
||||
|
||||
|
||||
class TestTableCustomSQLQueryValidator(unittest.TestCase):
|
||||
"""Test cases for TableCustomSQLQueryValidator._replace_where_clause method"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
mock_runner = Mock()
|
||||
mock_test_case = Mock()
|
||||
mock_execution_date = datetime.now()
|
||||
|
||||
self.validator = TableCustomSQLQueryValidator(
|
||||
runner=mock_runner,
|
||||
test_case=mock_test_case,
|
||||
execution_date=mock_execution_date,
|
||||
)
|
||||
|
||||
def test_simple_query_no_where_clause(self):
|
||||
"""Test adding WHERE clause to simple query without existing WHERE"""
|
||||
sql = "SELECT * FROM users"
|
||||
partition_expr = "created_at > '2023-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT * FROM users WHERE created_at > '2023-01-01'"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_simple_query_with_existing_where(self):
|
||||
"""Test replacing existing WHERE clause in simple query"""
|
||||
sql = "SELECT * FROM users WHERE id > 100"
|
||||
partition_expr = "created_at > '2023-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT * FROM users WHERE created_at > '2023-01-01'"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_query_with_order_by_no_where(self):
|
||||
"""Test adding WHERE clause before ORDER BY"""
|
||||
sql = "SELECT * FROM users ORDER BY name"
|
||||
partition_expr = "status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT * FROM users WHERE status = 'active' ORDER BY name"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_query_with_group_by_no_where(self):
|
||||
"""Test adding WHERE clause before GROUP BY"""
|
||||
sql = "SELECT department, COUNT(*) FROM employees GROUP BY department"
|
||||
partition_expr = "hire_date > '2020-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT department, COUNT(*) FROM employees WHERE hire_date > '2020-01-01' GROUP BY department"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_query_with_subquery_inner_where_preserved(self):
|
||||
"""Test that WHERE clause in subquery is preserved"""
|
||||
sql = "SELECT foo FROM a INNER JOIN (SELECT bar FROM b WHERE abc = 3) WHERE a.id BETWEEN 2 AND 4"
|
||||
partition_expr = "a.status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT foo FROM a INNER JOIN (SELECT bar FROM b WHERE abc = 3) WHERE a.status = 'active'"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_complex_subquery_multiple_levels(self):
|
||||
"""Test complex nested subqueries with multiple WHERE clauses"""
|
||||
sql = """SELECT * FROM users u
|
||||
WHERE u.id IN (
|
||||
SELECT user_id FROM orders o
|
||||
WHERE o.total > (
|
||||
SELECT AVG(total) FROM orders WHERE status = 'completed'
|
||||
)
|
||||
)"""
|
||||
partition_expr = "u.created_at > '2023-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
|
||||
self.assertIn("WHERE u.created_at > '2023-01-01'", result)
|
||||
self.assertNotIn("WHERE u.id IN", result)
|
||||
|
||||
def test_query_with_cte_where_preserved(self):
|
||||
"""Test that WHERE clauses in CTEs are preserved"""
|
||||
sql = """WITH active_users AS (
|
||||
SELECT * FROM users WHERE status = 'active'
|
||||
)
|
||||
SELECT * FROM active_users WHERE id > 100"""
|
||||
partition_expr = "created_at > '2023-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
|
||||
self.assertIn("WHERE status = 'active'", result)
|
||||
self.assertIn("WHERE created_at > '2023-01-01'", result)
|
||||
self.assertNotIn("WHERE id > 100", result)
|
||||
|
||||
def test_query_with_union_no_where(self):
|
||||
"""Test adding WHERE clause before UNION"""
|
||||
sql = "SELECT id FROM table1 UNION SELECT id FROM table2"
|
||||
partition_expr = "status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = (
|
||||
"SELECT id FROM table1 WHERE status = 'active' UNION SELECT id FROM table2"
|
||||
)
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_query_with_having_no_where(self):
|
||||
"""Test adding WHERE clause before HAVING"""
|
||||
sql = "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5"
|
||||
partition_expr = "hire_date > '2020-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT department, COUNT(*) FROM employees WHERE hire_date > '2020-01-01' GROUP BY department HAVING COUNT(*) > 5"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_query_with_limit_no_where(self):
|
||||
"""Test adding WHERE clause before LIMIT"""
|
||||
sql = "SELECT * FROM users LIMIT 10"
|
||||
partition_expr = "status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT * FROM users WHERE status = 'active' LIMIT 10"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_query_with_offset_no_where(self):
|
||||
"""Test adding WHERE clause before OFFSET"""
|
||||
sql = "SELECT * FROM users OFFSET 20"
|
||||
partition_expr = "status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "SELECT * FROM users WHERE status = 'active' OFFSET 20"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_complex_query_with_joins_and_subqueries(self):
|
||||
"""Test complex query with joins, subqueries, and existing WHERE"""
|
||||
sql = """SELECT u.name, o.total
|
||||
FROM users u
|
||||
INNER JOIN orders o ON u.id = o.user_id
|
||||
LEFT JOIN (
|
||||
SELECT user_id, COUNT(*) as order_count
|
||||
FROM orders
|
||||
WHERE created_at > '2022-01-01'
|
||||
GROUP BY user_id
|
||||
) oc ON u.id = oc.user_id
|
||||
WHERE u.status = 'active' AND o.total > 100
|
||||
ORDER BY o.total DESC"""
|
||||
partition_expr = "u.created_at BETWEEN '2023-01-01' AND '2023-12-31'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
|
||||
self.assertIn(
|
||||
"WHERE u.created_at BETWEEN '2023-01-01' AND '2023-12-31'", result
|
||||
)
|
||||
self.assertIn("WHERE created_at > '2022-01-01'", result)
|
||||
self.assertIn("ORDER BY o.total DESC", result)
|
||||
self.assertNotIn("WHERE u.status = 'active' AND o.total > 100", result)
|
||||
|
||||
def test_query_with_multiple_clauses_existing_where(self):
|
||||
"""Test query with WHERE, GROUP BY, HAVING, ORDER BY, LIMIT"""
|
||||
sql = """SELECT department, AVG(salary)
|
||||
FROM employees
|
||||
WHERE salary > 50000
|
||||
GROUP BY department
|
||||
HAVING AVG(salary) > 60000
|
||||
ORDER BY AVG(salary) DESC
|
||||
LIMIT 5"""
|
||||
partition_expr = "hire_date > '2020-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = """SELECT department, AVG(salary)
|
||||
FROM employees
|
||||
WHERE hire_date > '2020-01-01'
|
||||
GROUP BY department
|
||||
HAVING AVG(salary) > 60000
|
||||
ORDER BY AVG(salary) DESC
|
||||
LIMIT 5"""
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_empty_sql_query(self):
|
||||
"""Test handling of empty SQL query"""
|
||||
sql = ""
|
||||
partition_expr = "status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_malformed_sql_query(self):
|
||||
"""Test handling of malformed SQL query"""
|
||||
sql = "SELECT FROM"
|
||||
partition_expr = "status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_case_insensitive_keywords(self):
|
||||
"""Test that keyword matching is case insensitive"""
|
||||
sql = "select * from users where id > 100 order by name"
|
||||
partition_expr = "status = 'active'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
expected = "select * from users WHERE status = 'active' order by name"
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_deeply_nested_subqueries(self):
|
||||
"""Test handling of deeply nested subqueries"""
|
||||
sql = """SELECT * FROM table1
|
||||
WHERE id IN (
|
||||
SELECT t2.id FROM table2 t2
|
||||
WHERE t2.value > (
|
||||
SELECT AVG(t3.value) FROM table3 t3
|
||||
WHERE t3.category IN (
|
||||
SELECT category FROM categories
|
||||
WHERE active = 1
|
||||
)
|
||||
)
|
||||
)"""
|
||||
partition_expr = "created_at > '2023-01-01'"
|
||||
|
||||
result = self.validator._replace_where_clause(sql, partition_expr)
|
||||
|
||||
self.assertIn("WHERE created_at > '2023-01-01'", result)
|
||||
self.assertNotIn("WHERE id IN", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,425 @@
|
||||
"""
|
||||
Unit tests for BaseTableCustomSQLQueryValidator passed/failed row count logic
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import (
|
||||
TableCustomSQLQueryValidator,
|
||||
)
|
||||
from metadata.generated.schema.tests.basic import TestCaseStatus
|
||||
|
||||
|
||||
class TestTableCustomSQLQueryRowCounts(unittest.TestCase):
|
||||
"""Test cases for passed/failed row count calculation logic"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.mock_runner = Mock()
|
||||
self.mock_test_case = Mock()
|
||||
self.mock_test_case.computePassedFailedRowCount = True
|
||||
self.mock_test_case.parameterValues = []
|
||||
self.mock_test_case.fullyQualifiedName = "test.case"
|
||||
self.mock_execution_date = int(datetime.now().timestamp())
|
||||
|
||||
self.validator = TableCustomSQLQueryValidator(
|
||||
runner=self.mock_runner,
|
||||
test_case=self.mock_test_case,
|
||||
execution_date=self.mock_execution_date,
|
||||
)
|
||||
|
||||
def _create_mock_param_values(
|
||||
self, operator, threshold, sql_expression="SELECT * FROM test"
|
||||
):
|
||||
"""Helper to create mock parameter values"""
|
||||
import json
|
||||
|
||||
# Create runtime parameters JSON
|
||||
runtime_params = {
|
||||
"conn_config": {
|
||||
"config": {
|
||||
"type": "Mysql",
|
||||
"scheme": "mysql+pymysql",
|
||||
"username": "test",
|
||||
"password": "test",
|
||||
"hostPort": "localhost:3306",
|
||||
"database": "test_db",
|
||||
}
|
||||
},
|
||||
"entity": {
|
||||
"id": "test-table-id",
|
||||
"name": "test_table",
|
||||
"fullyQualifiedName": "test.db.test_table",
|
||||
},
|
||||
}
|
||||
|
||||
# Create parameter mocks with explicit name attributes
|
||||
sql_param = Mock()
|
||||
sql_param.name = "sqlExpression"
|
||||
sql_param.value = sql_expression
|
||||
|
||||
operator_param = Mock()
|
||||
operator_param.name = "operator"
|
||||
operator_param.value = operator
|
||||
|
||||
threshold_param = Mock()
|
||||
threshold_param.name = "threshold"
|
||||
threshold_param.value = threshold
|
||||
|
||||
strategy_param = Mock()
|
||||
strategy_param.name = "strategy"
|
||||
strategy_param.value = "COUNT"
|
||||
|
||||
runtime_param = Mock()
|
||||
runtime_param.name = "TableCustomSQLQueryRuntimeParameters"
|
||||
runtime_param.value = json.dumps(runtime_params)
|
||||
|
||||
return [
|
||||
sql_param,
|
||||
operator_param,
|
||||
threshold_param,
|
||||
strategy_param,
|
||||
runtime_param,
|
||||
]
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_greater_than_operator_success(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test > operator when test passes"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values(">", 5)
|
||||
mock_run_results.return_value = 10
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 10)
|
||||
self.assertEqual(result.failedRows, 990)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_greater_than_operator_failure(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test > operator when test fails (got fewer rows than expected)"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values(">", 10)
|
||||
mock_run_results.return_value = 5
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed)
|
||||
self.assertEqual(result.passedRows, 5)
|
||||
self.assertEqual(result.failedRows, 995)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_greater_than_equal_operator_success(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test >= operator when test passes"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values(">=", 10)
|
||||
mock_run_results.return_value = 10
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 10)
|
||||
self.assertEqual(result.failedRows, 990)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_greater_than_equal_operator_failure(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test >= operator when test fails"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values(">=", 15)
|
||||
mock_run_results.return_value = 8
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed)
|
||||
self.assertEqual(result.passedRows, 8)
|
||||
self.assertEqual(result.failedRows, 992)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_less_than_operator_success(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test < operator when test passes"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("<", 10)
|
||||
mock_run_results.return_value = 5
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 995)
|
||||
self.assertEqual(result.failedRows, 5)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_less_than_operator_failure(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test < operator when test fails (got more rows than expected)"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("<", 5)
|
||||
mock_run_results.return_value = 12
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed)
|
||||
self.assertEqual(result.passedRows, 988)
|
||||
self.assertEqual(result.failedRows, 12)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_less_than_equal_operator_success(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test <= operator when test passes"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("<=", 10)
|
||||
mock_run_results.return_value = 10
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 990)
|
||||
self.assertEqual(result.failedRows, 10)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_less_than_equal_operator_failure(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test <= operator when test fails"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("<=", 8)
|
||||
mock_run_results.return_value = 15
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed)
|
||||
self.assertEqual(result.passedRows, 985)
|
||||
self.assertEqual(result.failedRows, 15)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_equal_operator_success(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test == operator when test passes"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("==", 10)
|
||||
mock_run_results.return_value = 10
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 10)
|
||||
self.assertEqual(result.failedRows, 990)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_equal_operator_failure_more_rows(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test == operator when test fails with more rows than expected"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("==", 10)
|
||||
mock_run_results.return_value = 15
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed)
|
||||
self.assertEqual(result.passedRows, 995)
|
||||
self.assertEqual(result.failedRows, 5)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_equal_operator_failure_fewer_rows(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test == operator when test fails with fewer rows than expected"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("==", 10)
|
||||
mock_run_results.return_value = 3
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed)
|
||||
self.assertEqual(result.passedRows, 3)
|
||||
self.assertEqual(result.failedRows, 997)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_edge_case_zero_threshold_greater_than(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test edge case with zero threshold and > operator"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values(">", 0)
|
||||
mock_run_results.return_value = 5
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 5)
|
||||
self.assertEqual(result.failedRows, 995)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_edge_case_zero_actual_rows_less_than(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test edge case with zero actual rows and < operator"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values("<", 5)
|
||||
mock_run_results.return_value = 0
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 1000)
|
||||
self.assertEqual(result.failedRows, 0)
|
||||
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results"
|
||||
)
|
||||
@patch(
|
||||
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count"
|
||||
)
|
||||
def test_edge_case_negative_failed_rows_protection(
|
||||
self, mock_compute_row_count, mock_run_results, mock_get_runtime_params
|
||||
):
|
||||
"""Test protection against negative failed rows for >= operator"""
|
||||
self.mock_test_case.parameterValues = self._create_mock_param_values(">=", 5)
|
||||
mock_run_results.return_value = 10
|
||||
mock_compute_row_count.return_value = 1000
|
||||
mock_get_runtime_params.return_value = Mock()
|
||||
|
||||
result = self.validator.run_validation()
|
||||
|
||||
self.assertEqual(result.testCaseStatus, TestCaseStatus.Success)
|
||||
self.assertEqual(result.passedRows, 10)
|
||||
self.assertEqual(result.failedRows, 990)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,192 @@
|
||||
"""
|
||||
Unit tests for edge cases with zero or negative thresholds
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
def calculate_less_than_failure_fixed(
|
||||
threshold: int, len_rows: int, row_count: int
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Fixed implementation of _calculate_less_than_failure
|
||||
"""
|
||||
if threshold <= 0:
|
||||
|
||||
if row_count:
|
||||
failed_rows = row_count
|
||||
passed_rows = 0
|
||||
else:
|
||||
|
||||
failed_rows = max(len_rows, 1)
|
||||
passed_rows = 0
|
||||
else:
|
||||
|
||||
failed_rows = max(0, len_rows - threshold)
|
||||
passed_rows = (row_count - failed_rows) if row_count else threshold
|
||||
|
||||
return max(0, passed_rows), max(0, failed_rows)
|
||||
|
||||
|
||||
def calculate_less_than_failure_old(
|
||||
threshold: int, len_rows: int, row_count: int
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Original buggy implementation for comparison
|
||||
"""
|
||||
failed_rows = max(0, len_rows - threshold)
|
||||
passed_rows = (row_count - failed_rows) if row_count else threshold
|
||||
return max(0, passed_rows), max(0, failed_rows)
|
||||
|
||||
|
||||
class TestZeroThresholdEdgeCases(unittest.TestCase):
|
||||
"""Test cases for zero and negative threshold edge cases"""
|
||||
|
||||
def test_bug_case_less_than_zero_threshold(self):
|
||||
"""Test the reported bug case: < 0 threshold"""
|
||||
|
||||
len_rows = 0
|
||||
threshold = 0
|
||||
row_count = 1
|
||||
|
||||
old_passed, old_failed = calculate_less_than_failure_old(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
new_passed, new_failed = calculate_less_than_failure_fixed(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
print(
|
||||
f"Bug case - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
|
||||
)
|
||||
print(f"Old: passed={old_passed}, failed={old_failed}")
|
||||
print(f"New: passed={new_passed}, failed={new_failed}")
|
||||
|
||||
self.assertEqual(old_passed, 1)
|
||||
self.assertEqual(old_failed, 0)
|
||||
|
||||
self.assertEqual(new_passed, 0)
|
||||
self.assertEqual(new_failed, 1)
|
||||
|
||||
def test_less_than_negative_threshold(self):
|
||||
"""Test < -1 threshold (impossible expectation)"""
|
||||
len_rows = 5
|
||||
threshold = -1
|
||||
row_count = 100
|
||||
|
||||
old_passed, old_failed = calculate_less_than_failure_old(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
new_passed, new_failed = calculate_less_than_failure_fixed(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
print(
|
||||
f"Negative threshold - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
|
||||
)
|
||||
print(f"Old: passed={old_passed}, failed={old_failed}")
|
||||
print(f"New: passed={new_passed}, failed={new_failed}")
|
||||
|
||||
self.assertEqual(old_passed, 94)
|
||||
self.assertEqual(old_failed, 6)
|
||||
|
||||
self.assertEqual(new_passed, 0)
|
||||
self.assertEqual(new_failed, 100)
|
||||
|
||||
def test_less_than_zero_no_results(self):
|
||||
"""Test < 0 threshold with no query results"""
|
||||
len_rows = 0
|
||||
threshold = 0
|
||||
row_count = 50
|
||||
|
||||
old_passed, old_failed = calculate_less_than_failure_old(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
new_passed, new_failed = calculate_less_than_failure_fixed(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
print(
|
||||
f"Zero threshold, no results - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
|
||||
)
|
||||
print(f"Old: passed={old_passed}, failed={old_failed}")
|
||||
print(f"New: passed={new_passed}, failed={new_failed}")
|
||||
|
||||
self.assertEqual(old_passed, 50)
|
||||
self.assertEqual(old_failed, 0)
|
||||
|
||||
self.assertEqual(new_passed, 0)
|
||||
self.assertEqual(new_failed, 50)
|
||||
|
||||
def test_normal_case_still_works(self):
|
||||
"""Test that normal cases (threshold > 0) still work correctly"""
|
||||
len_rows = 15
|
||||
threshold = 10
|
||||
row_count = 100
|
||||
|
||||
old_passed, old_failed = calculate_less_than_failure_old(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
new_passed, new_failed = calculate_less_than_failure_fixed(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
print(
|
||||
f"Normal case - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
|
||||
)
|
||||
print(f"Old: passed={old_passed}, failed={old_failed}")
|
||||
print(f"New: passed={new_passed}, failed={new_failed}")
|
||||
|
||||
self.assertEqual(old_passed, new_passed)
|
||||
self.assertEqual(old_failed, new_failed)
|
||||
|
||||
self.assertEqual(new_passed, 95)
|
||||
self.assertEqual(new_failed, 5)
|
||||
|
||||
def test_edge_case_no_row_count(self):
|
||||
"""Test zero threshold with no total row count available"""
|
||||
len_rows = 0
|
||||
threshold = 0
|
||||
row_count = None
|
||||
|
||||
new_passed, new_failed = calculate_less_than_failure_fixed(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
print(
|
||||
f"No row count - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
|
||||
)
|
||||
print(f"New: passed={new_passed}, failed={new_failed}")
|
||||
|
||||
self.assertEqual(new_passed, 0)
|
||||
self.assertEqual(new_failed, 1)
|
||||
|
||||
def test_less_than_equal_zero_with_results(self):
|
||||
"""Test <= 0 threshold with some query results"""
|
||||
len_rows = 3
|
||||
threshold = 0
|
||||
row_count = 20
|
||||
|
||||
old_passed, old_failed = calculate_less_than_failure_old(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
new_passed, new_failed = calculate_less_than_failure_fixed(
|
||||
threshold, len_rows, row_count
|
||||
)
|
||||
|
||||
print(
|
||||
f"<= 0 with results - len_rows={len_rows}, threshold={threshold}, row_count={row_count}"
|
||||
)
|
||||
print(f"Old: passed={old_passed}, failed={old_failed}")
|
||||
print(f"New: passed={new_passed}, failed={new_failed}")
|
||||
|
||||
self.assertEqual(old_passed, 17)
|
||||
self.assertEqual(old_failed, 3)
|
||||
|
||||
self.assertEqual(new_passed, 0)
|
||||
self.assertEqual(new_failed, 20)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user