From 57c5a50d2014b1774cee1f787a43096e74db57e9 Mon Sep 17 00:00:00 2001 From: Teddy Date: Tue, 23 Sep 2025 16:53:51 +0200 Subject: [PATCH] ISSUE #23435 - Fix pass / fail count for custom SQL (#23506) * fix: added logic to compute pass/fail for sql queries with cte, nested queries, and joins * added logic to correctly compute pass / fail rows * style: ran python linting * fix: failing tests * style: fix linting error * fix: flawed count logic * fix: handle case where we don't compute row count --- .../table/base/tableCustomSQLQuery.py | 115 ++++- .../table/sqlalchemy/tableCustomSQLQuery.py | 275 +++++++++++- .../metadata/great_expectations/action1xx.py | 3 +- .../tests/integration/postgres/conftest.py | 2 +- .../test_passed_failed_row_calculation.py | 180 ++++++++ .../test_row_count_logic_with_total.py | 302 +++++++++++++ .../test_table_custom_sql_query.py | 247 ++++++++++ .../test_table_custom_sql_query_row_counts.py | 425 ++++++++++++++++++ .../test_zero_threshold_edge_cases.py | 192 ++++++++ 9 files changed, 1728 insertions(+), 13 deletions(-) create mode 100644 ingestion/tests/unit/data_quality/validations/test_passed_failed_row_calculation.py create mode 100644 ingestion/tests/unit/data_quality/validations/test_row_count_logic_with_total.py create mode 100644 ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query.py create mode 100644 ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query_row_counts.py create mode 100644 ingestion/tests/unit/data_quality/validations/test_zero_threshold_edge_cases.py diff --git a/ingestion/src/metadata/data_quality/validations/table/base/tableCustomSQLQuery.py b/ingestion/src/metadata/data_quality/validations/table/base/tableCustomSQLQuery.py index 9f31fd822ed..80704edc037 100644 --- a/ingestion/src/metadata/data_quality/validations/table/base/tableCustomSQLQuery.py +++ b/ingestion/src/metadata/data_quality/validations/table/base/tableCustomSQLQuery.py @@ -87,11 +87,13 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator): [TestResultValue(name=RESULT_ROW_COUNT, value=None)], ) len_rows = rows if isinstance(rows, int) else len(rows) - if evaluate_threshold( + test_passed = evaluate_threshold( threshold, operator, len_rows, - ): + ) + + if test_passed: status = TestCaseStatus.Success result_value = len_rows else: @@ -99,17 +101,23 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator): result_value = len_rows if self.test_case.computePassedFailedRowCount: - row_count = self.get_row_count() + row_count = self._get_total_row_count_if_needed() + passed_rows, failed_rows = self._calculate_passed_failed_rows( + test_passed, operator, threshold, len_rows, row_count + ) else: + passed_rows = None + failed_rows = None row_count = None return self.get_test_case_result_object( self.execution_date, status, - f"Found {result_value} row(s). Test query is expected to return {threshold} row.", + f"Found {result_value} row(s). 
Test query is expected to return {operator} {threshold} row(s).", [TestResultValue(name=RESULT_ROW_COUNT, value=str(result_value))], row_count=row_count, - failed_rows=result_value, + failed_rows=failed_rows, + passed_rows=passed_rows, ) @abstractmethod @@ -132,3 +140,100 @@ class BaseTableCustomSQLQueryValidator(BaseTestValidator): Tuple[int, int]: """ return self.compute_row_count() + + def _get_total_row_count_if_needed(self) -> int: + """Get total row count if computePassedFailedRowCount is enabled""" + return self.get_row_count() + + def _calculate_passed_failed_rows( + self, + test_passed: bool, + operator: str, + threshold: int, + len_rows: int, + row_count: int, + ) -> tuple[int, int]: + """Calculate passed and failed rows based on test result and operator + + Args: + test_passed: Whether the test passed + operator: Comparison operator (>, >=, <, <=, ==) + threshold: Expected threshold value + len_rows: Number of rows returned by the test query + row_count: Total number of rows in the table (or None) + + Returns: + Tuple of (passed_rows, failed_rows) + """ + if test_passed: + return self._calculate_passed_rows_success(operator, len_rows, row_count) + return self._calculate_passed_rows_failure( + operator, threshold, len_rows, row_count + ) + + def _calculate_passed_rows_success( + self, operator: str, len_rows: int, row_count: int + ) -> tuple[int, int]: + """Calculate passed/failed rows when test passed""" + if operator in (">", ">="): + passed_rows = len_rows + failed_rows = (row_count - len_rows) if row_count else 0 + elif operator in ("<", "<="): + passed_rows = row_count - len_rows + failed_rows = len_rows + elif operator == "==": + passed_rows = len_rows + failed_rows = row_count - len_rows + else: + passed_rows = len_rows + failed_rows = 0 + + return max(0, passed_rows), max(0, failed_rows) + + def _calculate_passed_rows_failure( + self, operator: str, threshold: int, len_rows: int, row_count: int + ) -> tuple[int, int]: + """Calculate passed/failed rows when test failed""" + if operator in (">", ">="): + return self._calculate_greater_than_failure(len_rows, row_count) + if operator in ("<", "<="): + return self._calculate_less_than_failure(len_rows, row_count) + if operator == "==": + return self._calculate_equal_failure(threshold, len_rows, row_count) + + failed_rows = row_count if row_count else len_rows + return 0, max(0, failed_rows) + + def _calculate_greater_than_failure( + self, len_rows: int, row_count: int + ) -> tuple[int, int]: + """Calculate rows for > or >= operator failure (expected more rows)""" + passed_rows = len_rows + failed_rows = (row_count - len_rows) if row_count else 0 + return max(0, passed_rows), max(0, failed_rows) + + def _calculate_less_than_failure( + self, len_rows: int, row_count: int + ) -> tuple[int, int]: + """Calculate rows for < or <= operator failure (expected fewer rows)""" + failed_rows = len_rows + passed_rows = row_count - failed_rows + + return max(0, passed_rows), max(0, failed_rows) + + def _calculate_equal_failure( + self, threshold: int, len_rows: int, row_count: int + ) -> tuple[int, int]: + """Calculate rows for == operator failure (expected exact count)""" + if row_count: + if len_rows > threshold: + failed_rows = len_rows - threshold + passed_rows = row_count - failed_rows + else: + failed_rows = row_count - len_rows + passed_rows = len_rows + else: + failed_rows = abs(len_rows - threshold) + passed_rows = 0 + + return max(0, passed_rows), max(0, failed_rows) diff --git 
a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableCustomSQLQuery.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableCustomSQLQuery.py index 1cc85790561..35a50ea32d7 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableCustomSQLQuery.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableCustomSQLQuery.py @@ -13,10 +13,13 @@ Validator for table custom SQL Query test case """ -from typing import Optional, cast +from typing import Optional, Tuple, cast +import sqlparse from sqlalchemy import text from sqlalchemy.sql import func, select +from sqlparse.sql import Statement, Token, Where +from sqlparse.tokens import Keyword from metadata.data_quality.validations.mixins.sqa_validator_mixin import ( SQAValidatorMixin, @@ -33,11 +36,244 @@ from metadata.profiler.metrics.registry import Metrics from metadata.profiler.orm.functions.table_metric_computer import TableMetricComputer from metadata.profiler.processor.runner import QueryRunner from metadata.utils.helpers import is_safe_sql_query +from metadata.utils.logger import ingestion_logger + +logger = ingestion_logger() class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidatorMixin): """Validator for table custom SQL Query test case""" + def _replace_where_clause( + self, sql_query: str, partition_expression: str + ) -> Optional[str]: + """Replace or add WHERE clause in SQL query using sqlparse. + + This method properly handles: + - Queries with existing WHERE clause (replaces it) + - Queries without WHERE clause (adds it) + - Complex queries with joins, subqueries, CTEs + - Preserves GROUP BY, ORDER BY, LIMIT, etc. + + Args: + sql_query: Original SQL query + partition_expression: New WHERE condition (without WHERE keyword) + + Returns: + Modified SQL query with partition_expression as WHERE clause + """ + parsed = sqlparse.parse(sql_query) + if not parsed or len(parsed) == 0: + return None + + statement: Statement = parsed[0] + tokens = list(statement.tokens) + + where_idx, where_end_idx, insert_before_idx = self._find_clause_positions( + tokens + ) + new_tokens = self._build_new_tokens( + tokens, where_idx, where_end_idx, insert_before_idx, partition_expression + ) + + return "".join(str(token) for token in new_tokens) + + def _find_clause_positions( + self, tokens: list + ) -> Tuple[Optional[int], Optional[int], Optional[int]]: + """Find positions of WHERE clause and insertion points in token list. + + Args: + tokens: List of parsed SQL tokens + + Returns: + Tuple of (where_idx, where_end_idx, insert_before_idx) + """ + where_idx = None + where_end_idx = None + insert_before_idx = None + paren_depth = 0 + + for i, token in enumerate(tokens): + paren_depth = self._update_parentheses_depth(token, paren_depth) + + if isinstance(token, Where) and paren_depth == 0: + where_idx = i + where_end_idx = i + 1 + break + + if self._should_insert_before_token(token, insert_before_idx, paren_depth): + insert_before_idx = i + + return where_idx, where_end_idx, insert_before_idx + + def _update_parentheses_depth(self, token: Token, current_depth: int) -> int: + """Update parentheses depth based on token content. 
+ + Args: + token: SQL token to analyze + current_depth: Current parentheses depth + + Returns: + Updated parentheses depth + """ + if token.ttype is None and hasattr(token, "tokens"): + paren_count = str(token).count("(") - str(token).count(")") + return current_depth + paren_count + elif token.value == "(": + return current_depth + 1 + elif token.value == ")": + return current_depth - 1 + return current_depth + + def _should_insert_before_token( + self, token: Token, insert_before_idx: Optional[int], paren_depth: int + ) -> bool: + """Check if WHERE clause should be inserted before this token. + + Args: + token: SQL token to check + insert_before_idx: Current insertion index (None if not set) + paren_depth: Current parentheses depth + + Returns: + True if WHERE should be inserted before this token + """ + if insert_before_idx is not None or paren_depth != 0: + return False + + if token.ttype is not Keyword: + return False + + clause_keywords = { + "GROUP BY", + "ORDER BY", + "HAVING", + "LIMIT", + "OFFSET", + "UNION", + "EXCEPT", + "INTERSECT", + } + + return any(keyword in token.value.upper() for keyword in clause_keywords) + + def _build_new_tokens( + self, + tokens: list, + where_idx: Optional[int], + where_end_idx: Optional[int], + insert_before_idx: Optional[int], + partition_expression: str, + ) -> list: + """Build new token list with WHERE clause inserted or replaced. + + Args: + tokens: Original token list + where_idx: Index of existing WHERE clause (None if not found) + where_end_idx: End index of existing WHERE clause + insert_before_idx: Index to insert WHERE before (None if append) + partition_expression: WHERE condition expression + + Returns: + New list of tokens with WHERE clause + """ + if where_idx is not None: + return self._replace_existing_where( + tokens, where_idx, where_end_idx, partition_expression + ) + elif insert_before_idx is not None: + return self._insert_where_before_clause( + tokens, insert_before_idx, partition_expression + ) + else: + return self._append_where_clause(tokens, partition_expression) + + def _replace_existing_where( + self, + tokens: list, + where_idx: int, + where_end_idx: int, + partition_expression: str, + ) -> list: + """Replace existing WHERE clause with new expression. + + Args: + tokens: Original token list + where_idx: Index of WHERE clause to replace + where_end_idx: End index of WHERE clause + partition_expression: New WHERE condition + + Returns: + Token list with replaced WHERE clause + """ + original_where = str(tokens[where_idx]) + trailing_whitespace = self._extract_trailing_whitespace(original_where) + + return ( + tokens[:where_idx] + + [ + Token(Keyword, "WHERE"), + Token(None, f" {partition_expression}{trailing_whitespace}"), + ] + + tokens[where_end_idx:] + ) + + def _insert_where_before_clause( + self, tokens: list, insert_before_idx: int, partition_expression: str + ) -> list: + """Insert WHERE clause before specified token index. + + Args: + tokens: Original token list + insert_before_idx: Index to insert WHERE clause before + partition_expression: WHERE condition expression + + Returns: + Token list with WHERE clause inserted + """ + return ( + tokens[:insert_before_idx] + + [Token(Keyword, "WHERE"), Token(None, f" {partition_expression} ")] + + tokens[insert_before_idx:] + ) + + def _append_where_clause(self, tokens: list, partition_expression: str) -> list: + """Append WHERE clause to end of token list. 
+ + Args: + tokens: Original token list + partition_expression: WHERE condition expression + + Returns: + Token list with WHERE clause appended + """ + return tokens + [ + Token(None, " "), + Token(Keyword, "WHERE"), + Token(None, f" {partition_expression}"), + ] + + def _extract_trailing_whitespace(self, where_clause: str) -> str: + """Extract trailing whitespace from WHERE clause string. + + Args: + where_clause: Original WHERE clause string + + Returns: + Trailing whitespace string + """ + if not where_clause.split(): + return "" + + last_word = where_clause.split()[-1] + where_content_end = where_clause.rfind(last_word) + len(last_word) + + if where_content_end < len(where_clause): + return where_clause[where_content_end:] + + return "" + def run_validation(self) -> TestCaseResult: """Run validation for the given test case @@ -85,12 +321,39 @@ class TableCustomSQLQueryValidator(BaseTableCustomSQLQueryValidator, SQAValidato None, ) if partition_expression: - stmt = ( - select(func.count()) - .select_from(self.runner.table) - .filter(text(partition_expression)) + custom_sql = self.get_test_case_param_value( + self.test_case.parameterValues, # type: ignore + "sqlExpression", + str, ) - return self.runner.session.execute(stmt).scalar() + + if custom_sql: + modified_query = self._replace_where_clause( + custom_sql, partition_expression + ) + if modified_query is None: + return None + count_query = f"SELECT COUNT(*) FROM ({modified_query}) AS test_results" + + try: + result = self.runner.session.execute(text(count_query)).scalar() + return result + except Exception as exc: + logger.error( + "Failed to execute custom SQL with partition expression. " + f"Query: {count_query}\n" + f"Error: {exc}\n", + exc_info=True, + ) + self.runner.session.rollback() + raise exc + else: + stmt = ( + select(func.count()) + .select_from(self.runner.table) + .filter(text(partition_expression)) + ) + return self.runner.session.execute(stmt).scalar() self.runner = cast(QueryRunner, self.runner) dialect = self.runner._session.get_bind().dialect.name diff --git a/ingestion/src/metadata/great_expectations/action1xx.py b/ingestion/src/metadata/great_expectations/action1xx.py index 0d497828d61..79a8c29bf85 100644 --- a/ingestion/src/metadata/great_expectations/action1xx.py +++ b/ingestion/src/metadata/great_expectations/action1xx.py @@ -86,7 +86,8 @@ class OpenMetadataValidationAction1xx(ValidationAction): # This will be initialized in the run method ometa_conn: Optional[OpenMetadata] = None - def run( # pylint: disable=unused-argument, arguments-differ + # pylint: disable=unused-argument,arguments-differ + def run( self, checkpoint_result: CheckpointResult, action_context: Union[ActionContext, None], diff --git a/ingestion/tests/integration/postgres/conftest.py b/ingestion/tests/integration/postgres/conftest.py index 3e8ea634d61..253c4219e72 100644 --- a/ingestion/tests/integration/postgres/conftest.py +++ b/ingestion/tests/integration/postgres/conftest.py @@ -26,7 +26,7 @@ def create_service_request(postgres_container, tmp_path_factory): username=postgres_container.username, authType=BasicAuth(password=postgres_container.password), hostPort="localhost:" - + postgres_container.get_exposed_port(postgres_container.port), + + str(postgres_container.get_exposed_port(postgres_container.port)), database="dvdrental", ) ), diff --git a/ingestion/tests/unit/data_quality/validations/test_passed_failed_row_calculation.py b/ingestion/tests/unit/data_quality/validations/test_passed_failed_row_calculation.py new file mode 100644 
index 00000000000..7dec1b94bee --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/test_passed_failed_row_calculation.py @@ -0,0 +1,180 @@ +""" +Unit tests for passed/failed row count calculation logic +""" + +import unittest + + +def calculate_passed_failed_rows( + test_passed: bool, + operator: str, + threshold: int, + actual_rows: int, + total_rows: int = None, +): + """ + Calculate passed and failed rows based on test result, operator, threshold, and actual row count. + + This function replicates the LEGACY logic for cases without total row count. + Note: This is kept for backward compatibility but the new logic with total_rows is more accurate. + """ + if total_rows is None: + if test_passed: + return actual_rows, 0 + else: + if operator in (">", ">="): + failed_rows = 0 + passed_rows = actual_rows + elif operator in ("<", "<="): + failed_rows = max(0, actual_rows - threshold) + passed_rows = threshold + elif operator == "==": + failed_rows = abs(actual_rows - threshold) + passed_rows = 0 + else: + failed_rows = actual_rows + passed_rows = 0 + + return max(0, passed_rows), max(0, failed_rows) + else: + + raise NotImplementedError( + "Use test_row_count_logic_with_total.py for total row count tests" + ) + + +class TestPassedFailedRowCalculation(unittest.TestCase): + """Test cases for passed/failed row count calculation logic""" + + def test_greater_than_operator_success(self): + """Test > operator when test passes""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator=">", threshold=5, actual_rows=10 + ) + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 0) + + def test_greater_than_operator_failure(self): + """Test > operator when test fails (got fewer rows than expected)""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator=">", threshold=10, actual_rows=5 + ) + self.assertEqual(passed_rows, 5) + self.assertEqual(failed_rows, 0) + + def test_greater_than_equal_operator_success(self): + """Test >= operator when test passes""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator=">=", threshold=10, actual_rows=10 + ) + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 0) + + def test_greater_than_equal_operator_failure(self): + """Test >= operator when test fails""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator=">=", threshold=15, actual_rows=8 + ) + self.assertEqual(passed_rows, 8) + self.assertEqual(failed_rows, 0) + + def test_less_than_operator_success(self): + """Test < operator when test passes""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator="<", threshold=10, actual_rows=5 + ) + self.assertEqual(passed_rows, 5) + self.assertEqual(failed_rows, 0) + + def test_less_than_operator_failure(self): + """Test < operator when test fails (got more rows than expected)""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator="<", threshold=5, actual_rows=12 + ) + self.assertEqual(passed_rows, 5) + self.assertEqual(failed_rows, 7) + + def test_less_than_equal_operator_success(self): + """Test <= operator when test passes""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator="<=", threshold=10, actual_rows=10 + ) + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 0) + + def test_less_than_equal_operator_failure(self): + """Test <= operator when test fails""" + 
passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator="<=", threshold=8, actual_rows=15 + ) + self.assertEqual(passed_rows, 8) + self.assertEqual(failed_rows, 7) + + def test_equal_operator_success(self): + """Test == operator when test passes""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator="==", threshold=10, actual_rows=10 + ) + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 0) + + def test_equal_operator_failure_more_rows(self): + """Test == operator when test fails with more rows than expected""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator="==", threshold=10, actual_rows=15 + ) + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 5) + + def test_equal_operator_failure_fewer_rows(self): + """Test == operator when test fails with fewer rows than expected""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator="==", threshold=10, actual_rows=3 + ) + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 7) + + def test_edge_case_zero_threshold_greater_than(self): + """Test edge case with zero threshold and > operator""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator=">", threshold=0, actual_rows=5 + ) + self.assertEqual(passed_rows, 5) + self.assertEqual(failed_rows, 0) + + def test_edge_case_zero_actual_rows_less_than(self): + """Test edge case with zero actual rows and < operator""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator="<", threshold=5, actual_rows=0 + ) + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 0) + + def test_edge_case_negative_protection_greater_than_equal(self): + """Test protection against negative calculations for >= operator""" + + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=True, operator=">=", threshold=5, actual_rows=10 + ) + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 0) + + def test_equal_operator_with_zero_threshold(self): + """Test == operator with zero threshold""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator="==", threshold=0, actual_rows=5 + ) + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 5) + + def test_unknown_operator_fallback(self): + """Test fallback for unknown operators""" + passed_rows, failed_rows = calculate_passed_failed_rows( + test_passed=False, operator="!=", threshold=10, actual_rows=15 + ) + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 15) + + +if __name__ == "__main__": + unittest.main() diff --git a/ingestion/tests/unit/data_quality/validations/test_row_count_logic_with_total.py b/ingestion/tests/unit/data_quality/validations/test_row_count_logic_with_total.py new file mode 100644 index 00000000000..ecb28151be8 --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/test_row_count_logic_with_total.py @@ -0,0 +1,302 @@ +""" +Unit tests for passed/failed row count calculation with total row count +""" + +import unittest + + +def calculate_passed_failed_rows_with_total( + test_passed: bool, + operator: str, + threshold: int, + actual_rows: int, + total_rows: int = None, +): + """ + Calculate passed and failed rows considering total row count. + + This replicates the fixed logic from BaseTableCustomSQLQueryValidator. 
+ """ + row_count = total_rows + len_rows = actual_rows + + if test_passed: + + if operator in (">", ">=", "=="): + + passed_rows = len_rows + failed_rows = (row_count - len_rows) if row_count else 0 + elif operator in ("<", "<="): + + passed_rows = row_count if row_count else len_rows + failed_rows = 0 + else: + + passed_rows = len_rows + failed_rows = 0 + else: + + if operator in (">", ">="): + + passed_rows = len_rows + failed_rows = (row_count - len_rows) if row_count else 0 + elif operator in ("<", "<="): + + if threshold <= 0: + + if row_count: + failed_rows = row_count + passed_rows = 0 + else: + + failed_rows = max(len_rows, 1) + passed_rows = 0 + else: + + failed_rows = max(0, len_rows - threshold) + passed_rows = (row_count - failed_rows) if row_count else threshold + elif operator == "==": + + if row_count: + + if len_rows > threshold: + + failed_rows = len_rows - threshold + passed_rows = row_count - failed_rows + else: + + failed_rows = row_count - len_rows + passed_rows = len_rows + else: + + failed_rows = abs(len_rows - threshold) + passed_rows = 0 + else: + + failed_rows = row_count if row_count else len_rows + passed_rows = 0 + + passed_rows = max(0, passed_rows) + failed_rows = max(0, failed_rows) + + return passed_rows, failed_rows + + +class TestPassedFailedRowCalculationWithTotal(unittest.TestCase): + """Test cases for the fixed row count calculation logic with total row count""" + + def test_greater_than_operator_bug_case(self): + """Test the bug case: >= with threshold 0, 0 rows returned, 1 total row""" + + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator=">=", + threshold=1, + actual_rows=0, + total_rows=1, + ) + + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 1) + + def test_greater_than_with_total_rows_success(self): + """Test > operator success with total row count""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=True, + operator=">", + threshold=5, + actual_rows=10, + total_rows=100, + ) + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 90) + + def test_greater_than_with_total_rows_failure(self): + """Test > operator failure with total row count""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator=">", + threshold=50, + actual_rows=10, + total_rows=100, + ) + + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 90) + + def test_less_than_with_total_rows_success(self): + """Test < operator success with total row count""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=True, + operator="<", + threshold=10, + actual_rows=5, + total_rows=100, + ) + + self.assertEqual(passed_rows, 100) + self.assertEqual(failed_rows, 0) + + def test_less_than_with_total_rows_failure(self): + """Test < operator failure with total row count""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator="<", + threshold=10, + actual_rows=20, + total_rows=100, + ) + + self.assertEqual(failed_rows, 10) + self.assertEqual(passed_rows, 90) + + def test_less_than_equal_with_total_rows_success(self): + """Test <= operator success with total row count""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=True, + operator="<=", + threshold=10, + actual_rows=10, + total_rows=100, + ) + self.assertEqual(passed_rows, 100) + self.assertEqual(failed_rows, 0) + + def 
test_less_than_equal_with_total_rows_failure(self): + """Test <= operator failure with total row count""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator="<=", + threshold=5, + actual_rows=15, + total_rows=100, + ) + + self.assertEqual(failed_rows, 10) + self.assertEqual(passed_rows, 90) + + def test_equal_with_total_rows_success(self): + """Test == operator success with total row count""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=True, + operator="==", + threshold=10, + actual_rows=10, + total_rows=100, + ) + self.assertEqual(passed_rows, 10) + self.assertEqual(failed_rows, 90) + + def test_equal_with_total_rows_failure_too_many(self): + """Test == operator failure with too many rows""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator="==", + threshold=10, + actual_rows=15, + total_rows=100, + ) + + self.assertEqual(failed_rows, 5) + self.assertEqual(passed_rows, 95) + + def test_equal_with_total_rows_failure_too_few(self): + """Test == operator failure with too few rows""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator="==", + threshold=10, + actual_rows=3, + total_rows=100, + ) + + self.assertEqual(failed_rows, 97) + self.assertEqual(passed_rows, 3) + + def test_edge_case_zero_total_rows(self): + """Test edge case with zero total rows""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, operator=">", threshold=5, actual_rows=0, total_rows=0 + ) + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 0) + + def test_edge_case_no_total_row_count(self): + """Test when total row count is not available (None)""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator=">", + threshold=10, + actual_rows=5, + total_rows=None, + ) + + self.assertEqual(passed_rows, 5) + self.assertEqual(failed_rows, 0) + + def test_greater_equal_zero_threshold_zero_result(self): + """Test >= 0 with 0 results but rows in table""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=True, + operator=">=", + threshold=0, + actual_rows=0, + total_rows=100, + ) + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 100) + + def test_all_rows_match_greater_than(self): + """Test when all rows in table match the condition for >""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=True, + operator=">", + threshold=5, + actual_rows=100, + total_rows=100, + ) + self.assertEqual(passed_rows, 100) + self.assertEqual(failed_rows, 0) + + def test_no_rows_match_less_than(self): + """Test when no rows match the condition for <""" + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=True, + operator="<", + threshold=10, + actual_rows=0, + total_rows=100, + ) + self.assertEqual(passed_rows, 100) + self.assertEqual(failed_rows, 0) + + def test_less_than_zero_threshold_failure_bug_case(self): + """Test the reported bug case: < 0 threshold with failure""" + + passed_rows, failed_rows = calculate_passed_failed_rows_with_total( + test_passed=False, + operator="<", + threshold=0, + actual_rows=0, + total_rows=1, + ) + + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 1) + + def test_less_than_negative_threshold_failure(self): + """Test < -1 threshold failure (impossible expectation)""" + passed_rows, failed_rows = 
calculate_passed_failed_rows_with_total( + test_passed=False, + operator="<", + threshold=-1, + actual_rows=5, + total_rows=100, + ) + + self.assertEqual(passed_rows, 0) + self.assertEqual(failed_rows, 100) + + +if __name__ == "__main__": + unittest.main() diff --git a/ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query.py b/ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query.py new file mode 100644 index 00000000000..76e92f5d8e4 --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query.py @@ -0,0 +1,247 @@ +""" +Unit tests for TableCustomSQLQueryValidator._replace_where_clause method +""" + +import unittest +from datetime import datetime +from unittest.mock import Mock + +from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import ( + TableCustomSQLQueryValidator, +) + + +class TestTableCustomSQLQueryValidator(unittest.TestCase): + """Test cases for TableCustomSQLQueryValidator._replace_where_clause method""" + + def setUp(self): + """Set up test fixtures""" + mock_runner = Mock() + mock_test_case = Mock() + mock_execution_date = datetime.now() + + self.validator = TableCustomSQLQueryValidator( + runner=mock_runner, + test_case=mock_test_case, + execution_date=mock_execution_date, + ) + + def test_simple_query_no_where_clause(self): + """Test adding WHERE clause to simple query without existing WHERE""" + sql = "SELECT * FROM users" + partition_expr = "created_at > '2023-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT * FROM users WHERE created_at > '2023-01-01'" + + self.assertEqual(result, expected) + + def test_simple_query_with_existing_where(self): + """Test replacing existing WHERE clause in simple query""" + sql = "SELECT * FROM users WHERE id > 100" + partition_expr = "created_at > '2023-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT * FROM users WHERE created_at > '2023-01-01'" + + self.assertEqual(result, expected) + + def test_query_with_order_by_no_where(self): + """Test adding WHERE clause before ORDER BY""" + sql = "SELECT * FROM users ORDER BY name" + partition_expr = "status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT * FROM users WHERE status = 'active' ORDER BY name" + + self.assertEqual(result, expected) + + def test_query_with_group_by_no_where(self): + """Test adding WHERE clause before GROUP BY""" + sql = "SELECT department, COUNT(*) FROM employees GROUP BY department" + partition_expr = "hire_date > '2020-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT department, COUNT(*) FROM employees WHERE hire_date > '2020-01-01' GROUP BY department" + + self.assertEqual(result, expected) + + def test_query_with_subquery_inner_where_preserved(self): + """Test that WHERE clause in subquery is preserved""" + sql = "SELECT foo FROM a INNER JOIN (SELECT bar FROM b WHERE abc = 3) WHERE a.id BETWEEN 2 AND 4" + partition_expr = "a.status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT foo FROM a INNER JOIN (SELECT bar FROM b WHERE abc = 3) WHERE a.status = 'active'" + + self.assertEqual(result, expected) + + def test_complex_subquery_multiple_levels(self): + """Test complex nested subqueries with multiple WHERE clauses""" + sql = """SELECT * FROM users u + WHERE u.id IN ( + SELECT user_id FROM orders o + WHERE 
o.total > ( + SELECT AVG(total) FROM orders WHERE status = 'completed' + ) + )""" + partition_expr = "u.created_at > '2023-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + + self.assertIn("WHERE u.created_at > '2023-01-01'", result) + self.assertNotIn("WHERE u.id IN", result) + + def test_query_with_cte_where_preserved(self): + """Test that WHERE clauses in CTEs are preserved""" + sql = """WITH active_users AS ( + SELECT * FROM users WHERE status = 'active' + ) + SELECT * FROM active_users WHERE id > 100""" + partition_expr = "created_at > '2023-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + + self.assertIn("WHERE status = 'active'", result) + self.assertIn("WHERE created_at > '2023-01-01'", result) + self.assertNotIn("WHERE id > 100", result) + + def test_query_with_union_no_where(self): + """Test adding WHERE clause before UNION""" + sql = "SELECT id FROM table1 UNION SELECT id FROM table2" + partition_expr = "status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = ( + "SELECT id FROM table1 WHERE status = 'active' UNION SELECT id FROM table2" + ) + + self.assertEqual(result, expected) + + def test_query_with_having_no_where(self): + """Test adding WHERE clause before HAVING""" + sql = "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5" + partition_expr = "hire_date > '2020-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT department, COUNT(*) FROM employees WHERE hire_date > '2020-01-01' GROUP BY department HAVING COUNT(*) > 5" + + self.assertEqual(result, expected) + + def test_query_with_limit_no_where(self): + """Test adding WHERE clause before LIMIT""" + sql = "SELECT * FROM users LIMIT 10" + partition_expr = "status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT * FROM users WHERE status = 'active' LIMIT 10" + + self.assertEqual(result, expected) + + def test_query_with_offset_no_where(self): + """Test adding WHERE clause before OFFSET""" + sql = "SELECT * FROM users OFFSET 20" + partition_expr = "status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "SELECT * FROM users WHERE status = 'active' OFFSET 20" + + self.assertEqual(result, expected) + + def test_complex_query_with_joins_and_subqueries(self): + """Test complex query with joins, subqueries, and existing WHERE""" + sql = """SELECT u.name, o.total + FROM users u + INNER JOIN orders o ON u.id = o.user_id + LEFT JOIN ( + SELECT user_id, COUNT(*) as order_count + FROM orders + WHERE created_at > '2022-01-01' + GROUP BY user_id + ) oc ON u.id = oc.user_id + WHERE u.status = 'active' AND o.total > 100 + ORDER BY o.total DESC""" + partition_expr = "u.created_at BETWEEN '2023-01-01' AND '2023-12-31'" + + result = self.validator._replace_where_clause(sql, partition_expr) + + self.assertIn( + "WHERE u.created_at BETWEEN '2023-01-01' AND '2023-12-31'", result + ) + self.assertIn("WHERE created_at > '2022-01-01'", result) + self.assertIn("ORDER BY o.total DESC", result) + self.assertNotIn("WHERE u.status = 'active' AND o.total > 100", result) + + def test_query_with_multiple_clauses_existing_where(self): + """Test query with WHERE, GROUP BY, HAVING, ORDER BY, LIMIT""" + sql = """SELECT department, AVG(salary) + FROM employees + WHERE salary > 50000 + GROUP BY department + HAVING AVG(salary) > 60000 + ORDER BY AVG(salary) DESC + LIMIT 5""" 
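+        # the existing WHERE on salary should be replaced by the partition filter, while GROUP BY, HAVING, ORDER BY and LIMIT are preserved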
+ partition_expr = "hire_date > '2020-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = """SELECT department, AVG(salary) + FROM employees + WHERE hire_date > '2020-01-01' + GROUP BY department + HAVING AVG(salary) > 60000 + ORDER BY AVG(salary) DESC + LIMIT 5""" + + self.assertEqual(result, expected) + + def test_empty_sql_query(self): + """Test handling of empty SQL query""" + sql = "" + partition_expr = "status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + + self.assertIsNone(result) + + def test_malformed_sql_query(self): + """Test handling of malformed SQL query""" + sql = "SELECT FROM" + partition_expr = "status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + + self.assertIsNotNone(result) + + def test_case_insensitive_keywords(self): + """Test that keyword matching is case insensitive""" + sql = "select * from users where id > 100 order by name" + partition_expr = "status = 'active'" + + result = self.validator._replace_where_clause(sql, partition_expr) + expected = "select * from users WHERE status = 'active' order by name" + + self.assertEqual(result, expected) + + def test_deeply_nested_subqueries(self): + """Test handling of deeply nested subqueries""" + sql = """SELECT * FROM table1 + WHERE id IN ( + SELECT t2.id FROM table2 t2 + WHERE t2.value > ( + SELECT AVG(t3.value) FROM table3 t3 + WHERE t3.category IN ( + SELECT category FROM categories + WHERE active = 1 + ) + ) + )""" + partition_expr = "created_at > '2023-01-01'" + + result = self.validator._replace_where_clause(sql, partition_expr) + + self.assertIn("WHERE created_at > '2023-01-01'", result) + self.assertNotIn("WHERE id IN", result) + + +if __name__ == "__main__": + unittest.main() diff --git a/ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query_row_counts.py b/ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query_row_counts.py new file mode 100644 index 00000000000..f65aea7b7e7 --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/test_table_custom_sql_query_row_counts.py @@ -0,0 +1,425 @@ +""" +Unit tests for BaseTableCustomSQLQueryValidator passed/failed row count logic +""" + +import unittest +from datetime import datetime +from unittest.mock import Mock, patch + +from metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery import ( + TableCustomSQLQueryValidator, +) +from metadata.generated.schema.tests.basic import TestCaseStatus + + +class TestTableCustomSQLQueryRowCounts(unittest.TestCase): + """Test cases for passed/failed row count calculation logic""" + + def setUp(self): + """Set up test fixtures""" + self.mock_runner = Mock() + self.mock_test_case = Mock() + self.mock_test_case.computePassedFailedRowCount = True + self.mock_test_case.parameterValues = [] + self.mock_test_case.fullyQualifiedName = "test.case" + self.mock_execution_date = int(datetime.now().timestamp()) + + self.validator = TableCustomSQLQueryValidator( + runner=self.mock_runner, + test_case=self.mock_test_case, + execution_date=self.mock_execution_date, + ) + + def _create_mock_param_values( + self, operator, threshold, sql_expression="SELECT * FROM test" + ): + """Helper to create mock parameter values""" + import json + + # Create runtime parameters JSON + runtime_params = { + "conn_config": { + "config": { + "type": "Mysql", + "scheme": "mysql+pymysql", + "username": "test", + "password": "test", + "hostPort": "localhost:3306", + "database": "test_db", 
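+                    # placeholder connection values; these unit tests mock the runner and never open a real database connection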
+ } + }, + "entity": { + "id": "test-table-id", + "name": "test_table", + "fullyQualifiedName": "test.db.test_table", + }, + } + + # Create parameter mocks with explicit name attributes + sql_param = Mock() + sql_param.name = "sqlExpression" + sql_param.value = sql_expression + + operator_param = Mock() + operator_param.name = "operator" + operator_param.value = operator + + threshold_param = Mock() + threshold_param.name = "threshold" + threshold_param.value = threshold + + strategy_param = Mock() + strategy_param.name = "strategy" + strategy_param.value = "COUNT" + + runtime_param = Mock() + runtime_param.name = "TableCustomSQLQueryRuntimeParameters" + runtime_param.value = json.dumps(runtime_params) + + return [ + sql_param, + operator_param, + threshold_param, + strategy_param, + runtime_param, + ] + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_greater_than_operator_success( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test > operator when test passes""" + self.mock_test_case.parameterValues = self._create_mock_param_values(">", 5) + mock_run_results.return_value = 10 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 10) + self.assertEqual(result.failedRows, 990) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_greater_than_operator_failure( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test > operator when test fails (got fewer rows than expected)""" + self.mock_test_case.parameterValues = self._create_mock_param_values(">", 10) + mock_run_results.return_value = 5 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed) + self.assertEqual(result.passedRows, 5) + self.assertEqual(result.failedRows, 995) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_greater_than_equal_operator_success( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test >= operator when test passes""" + self.mock_test_case.parameterValues = self._create_mock_param_values(">=", 10) + mock_run_results.return_value = 10 + mock_compute_row_count.return_value = 1000 + 
mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 10) + self.assertEqual(result.failedRows, 990) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_greater_than_equal_operator_failure( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test >= operator when test fails""" + self.mock_test_case.parameterValues = self._create_mock_param_values(">=", 15) + mock_run_results.return_value = 8 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed) + self.assertEqual(result.passedRows, 8) + self.assertEqual(result.failedRows, 992) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_less_than_operator_success( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test < operator when test passes""" + self.mock_test_case.parameterValues = self._create_mock_param_values("<", 10) + mock_run_results.return_value = 5 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 995) + self.assertEqual(result.failedRows, 5) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_less_than_operator_failure( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test < operator when test fails (got more rows than expected)""" + self.mock_test_case.parameterValues = self._create_mock_param_values("<", 5) + mock_run_results.return_value = 12 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed) + self.assertEqual(result.passedRows, 988) + self.assertEqual(result.failedRows, 12) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + 
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_less_than_equal_operator_success( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test <= operator when test passes""" + self.mock_test_case.parameterValues = self._create_mock_param_values("<=", 10) + mock_run_results.return_value = 10 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 990) + self.assertEqual(result.failedRows, 10) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_less_than_equal_operator_failure( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test <= operator when test fails""" + self.mock_test_case.parameterValues = self._create_mock_param_values("<=", 8) + mock_run_results.return_value = 15 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed) + self.assertEqual(result.passedRows, 985) + self.assertEqual(result.failedRows, 15) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_equal_operator_success( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test == operator when test passes""" + self.mock_test_case.parameterValues = self._create_mock_param_values("==", 10) + mock_run_results.return_value = 10 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 10) + self.assertEqual(result.failedRows, 990) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_equal_operator_failure_more_rows( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test == operator when test fails with more rows than expected""" + self.mock_test_case.parameterValues = self._create_mock_param_values("==", 10) + mock_run_results.return_value = 15 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + 
self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed) + self.assertEqual(result.passedRows, 995) + self.assertEqual(result.failedRows, 5) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_equal_operator_failure_fewer_rows( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test == operator when test fails with fewer rows than expected""" + self.mock_test_case.parameterValues = self._create_mock_param_values("==", 10) + mock_run_results.return_value = 3 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Failed) + self.assertEqual(result.passedRows, 3) + self.assertEqual(result.failedRows, 997) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_edge_case_zero_threshold_greater_than( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test edge case with zero threshold and > operator""" + self.mock_test_case.parameterValues = self._create_mock_param_values(">", 0) + mock_run_results.return_value = 5 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 5) + self.assertEqual(result.failedRows, 995) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_edge_case_zero_actual_rows_less_than( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test edge case with zero actual rows and < operator""" + self.mock_test_case.parameterValues = self._create_mock_param_values("<", 5) + mock_run_results.return_value = 0 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 1000) + self.assertEqual(result.failedRows, 0) + + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.get_runtime_parameters" + ) + @patch( + "metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator._run_results" + ) + @patch( + 
"metadata.data_quality.validations.table.sqlalchemy.tableCustomSQLQuery.TableCustomSQLQueryValidator.compute_row_count" + ) + def test_edge_case_negative_failed_rows_protection( + self, mock_compute_row_count, mock_run_results, mock_get_runtime_params + ): + """Test protection against negative failed rows for >= operator""" + self.mock_test_case.parameterValues = self._create_mock_param_values(">=", 5) + mock_run_results.return_value = 10 + mock_compute_row_count.return_value = 1000 + mock_get_runtime_params.return_value = Mock() + + result = self.validator.run_validation() + + self.assertEqual(result.testCaseStatus, TestCaseStatus.Success) + self.assertEqual(result.passedRows, 10) + self.assertEqual(result.failedRows, 990) + + +if __name__ == "__main__": + unittest.main() diff --git a/ingestion/tests/unit/data_quality/validations/test_zero_threshold_edge_cases.py b/ingestion/tests/unit/data_quality/validations/test_zero_threshold_edge_cases.py new file mode 100644 index 00000000000..1684c441d72 --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/test_zero_threshold_edge_cases.py @@ -0,0 +1,192 @@ +""" +Unit tests for edge cases with zero or negative thresholds +""" + +import unittest + + +def calculate_less_than_failure_fixed( + threshold: int, len_rows: int, row_count: int +) -> tuple[int, int]: + """ + Fixed implementation of _calculate_less_than_failure + """ + if threshold <= 0: + + if row_count: + failed_rows = row_count + passed_rows = 0 + else: + + failed_rows = max(len_rows, 1) + passed_rows = 0 + else: + + failed_rows = max(0, len_rows - threshold) + passed_rows = (row_count - failed_rows) if row_count else threshold + + return max(0, passed_rows), max(0, failed_rows) + + +def calculate_less_than_failure_old( + threshold: int, len_rows: int, row_count: int +) -> tuple[int, int]: + """ + Original buggy implementation for comparison + """ + failed_rows = max(0, len_rows - threshold) + passed_rows = (row_count - failed_rows) if row_count else threshold + return max(0, passed_rows), max(0, failed_rows) + + +class TestZeroThresholdEdgeCases(unittest.TestCase): + """Test cases for zero and negative threshold edge cases""" + + def test_bug_case_less_than_zero_threshold(self): + """Test the reported bug case: < 0 threshold""" + + len_rows = 0 + threshold = 0 + row_count = 1 + + old_passed, old_failed = calculate_less_than_failure_old( + threshold, len_rows, row_count + ) + + new_passed, new_failed = calculate_less_than_failure_fixed( + threshold, len_rows, row_count + ) + + print( + f"Bug case - len_rows={len_rows}, threshold={threshold}, row_count={row_count}" + ) + print(f"Old: passed={old_passed}, failed={old_failed}") + print(f"New: passed={new_passed}, failed={new_failed}") + + self.assertEqual(old_passed, 1) + self.assertEqual(old_failed, 0) + + self.assertEqual(new_passed, 0) + self.assertEqual(new_failed, 1) + + def test_less_than_negative_threshold(self): + """Test < -1 threshold (impossible expectation)""" + len_rows = 5 + threshold = -1 + row_count = 100 + + old_passed, old_failed = calculate_less_than_failure_old( + threshold, len_rows, row_count + ) + new_passed, new_failed = calculate_less_than_failure_fixed( + threshold, len_rows, row_count + ) + + print( + f"Negative threshold - len_rows={len_rows}, threshold={threshold}, row_count={row_count}" + ) + print(f"Old: passed={old_passed}, failed={old_failed}") + print(f"New: passed={new_passed}, failed={new_failed}") + + self.assertEqual(old_passed, 94) + self.assertEqual(old_failed, 6) + + 
self.assertEqual(new_passed, 0) + self.assertEqual(new_failed, 100) + + def test_less_than_zero_no_results(self): + """Test < 0 threshold with no query results""" + len_rows = 0 + threshold = 0 + row_count = 50 + + old_passed, old_failed = calculate_less_than_failure_old( + threshold, len_rows, row_count + ) + new_passed, new_failed = calculate_less_than_failure_fixed( + threshold, len_rows, row_count + ) + + print( + f"Zero threshold, no results - len_rows={len_rows}, threshold={threshold}, row_count={row_count}" + ) + print(f"Old: passed={old_passed}, failed={old_failed}") + print(f"New: passed={new_passed}, failed={new_failed}") + + self.assertEqual(old_passed, 50) + self.assertEqual(old_failed, 0) + + self.assertEqual(new_passed, 0) + self.assertEqual(new_failed, 50) + + def test_normal_case_still_works(self): + """Test that normal cases (threshold > 0) still work correctly""" + len_rows = 15 + threshold = 10 + row_count = 100 + + old_passed, old_failed = calculate_less_than_failure_old( + threshold, len_rows, row_count + ) + new_passed, new_failed = calculate_less_than_failure_fixed( + threshold, len_rows, row_count + ) + + print( + f"Normal case - len_rows={len_rows}, threshold={threshold}, row_count={row_count}" + ) + print(f"Old: passed={old_passed}, failed={old_failed}") + print(f"New: passed={new_passed}, failed={new_failed}") + + self.assertEqual(old_passed, new_passed) + self.assertEqual(old_failed, new_failed) + + self.assertEqual(new_passed, 95) + self.assertEqual(new_failed, 5) + + def test_edge_case_no_row_count(self): + """Test zero threshold with no total row count available""" + len_rows = 0 + threshold = 0 + row_count = None + + new_passed, new_failed = calculate_less_than_failure_fixed( + threshold, len_rows, row_count + ) + + print( + f"No row count - len_rows={len_rows}, threshold={threshold}, row_count={row_count}" + ) + print(f"New: passed={new_passed}, failed={new_failed}") + + self.assertEqual(new_passed, 0) + self.assertEqual(new_failed, 1) + + def test_less_than_equal_zero_with_results(self): + """Test <= 0 threshold with some query results""" + len_rows = 3 + threshold = 0 + row_count = 20 + + old_passed, old_failed = calculate_less_than_failure_old( + threshold, len_rows, row_count + ) + new_passed, new_failed = calculate_less_than_failure_fixed( + threshold, len_rows, row_count + ) + + print( + f"<= 0 with results - len_rows={len_rows}, threshold={threshold}, row_count={row_count}" + ) + print(f"Old: passed={old_passed}, failed={old_failed}") + print(f"New: passed={new_passed}, failed={new_failed}") + + self.assertEqual(old_passed, 17) + self.assertEqual(old_failed, 3) + + self.assertEqual(new_passed, 0) + self.assertEqual(new_failed, 20) + + +if __name__ == "__main__": + unittest.main()
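
A minimal standalone sketch (assuming a hypothetical helper named less_than_failure that mirrors _calculate_less_than_failure above) of how the patched accounting treats a failing "<" check when the total row count is known: every row returned by the custom query counts as failed and the remainder of the table counts as passed.

# hypothetical illustration of the "<" / "<=" failure branch introduced in this patch
def less_than_failure(len_rows: int, row_count: int) -> tuple[int, int]:
    failed_rows = len_rows                 # rows returned by the query violate the expectation
    passed_rows = row_count - failed_rows  # the rest of the table is treated as passing
    return max(0, passed_rows), max(0, failed_rows)

# matches test_less_than_operator_failure: threshold 5, 12 rows returned, 1000 total rows
assert less_than_failure(len_rows=12, row_count=1000) == (988, 12)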