diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 9167de56d84..e157a148754 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -11,7 +11,7 @@ # pylint: disable=missing-module-docstring import traceback from itertools import islice -from typing import Dict, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from urllib.parse import urlparse import data_diff @@ -117,7 +117,11 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): if stats["total"] > 0: logger.debug("Sample of failed rows:") for s in islice(self.get_table_diff(), 10): - logger.debug(s) + # since the data can contain sensitive information, we don't want to log it + # we can uncomment this line if we must see the data in the logs + # logger.debug(s) + # by default we will log the data masked + logger.debug([s[0], ["*" for _ in s[1]]]) test_case_result = self.get_row_diff_test_case_result( threshold, stats["total"], @@ -127,7 +131,6 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): ) count = self._compute_row_count(self.runner, None) # type: ignore test_case_result.passedRows = stats["unchanged"] - test_case_result.failedRows = stats["total"] test_case_result.passedRowsPercentage = ( test_case_result.passedRows / count * 100 ) @@ -135,10 +138,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): test_case_result.failedRows / count * 100 ) return test_case_result - num_dffs = sum(1 for _ in islice(table_diff_iter, threshold)) return self.get_row_diff_test_case_result( - num_dffs, threshold, + self.calculate_diffs_with_limit(table_diff_iter, threshold), ) def get_incomparable_columns(self) -> List[str]: @@ -274,6 +276,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): 
(threshold or total_diffs) == 0 or total_diffs < threshold ), result=f"Found {total_diffs} different rows which is more than the threshold of {threshold}", + failedRows=total_diffs, validateColumns=False, testResultValue=[ TestResultValue(name="removedRows", value=str(removed)), @@ -374,3 +377,34 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): TestResultValue(name="changedColumns", value=str(len(changed))), ], ) + + def calculate_diffs_with_limit( + self, diff_iter: Iterable[Tuple[str, Tuple[str, ...]]], limit: int + ) -> int: + """Given an iterator of diffs like + - ('+', (...)) + - ('-', (...)) + ... + Calculate the total diffs by combining diffs for the same key. This gives an accurate count of the total diffs + as opposed to self.calculate_diff_num(diff_list), which just counts the number of diffs in the list. + + Args: + diff_iter: iterator returned from the data_diff algorithm + limit: stop counting as soon as the number of distinct keys exceeds this value + + Returns: + int: number of distinct keys that have a diff, capped at limit + 1 + + """ + len_key_columns = len(self.runtime_params.keyColumns) + key_set = set() + # combine diffs on same key to "!" 
+ for _, values in diff_iter: + k = values[:len_key_columns] + if k in key_set: + continue + key_set.add(k) + if len(key_set) > limit: + # limit already exceeded: stop here instead of consuming the rest of the iterator + return len(key_set) + return len(key_set) diff --git a/ingestion/tests/integration/data_quality/test_data_diff.py b/ingestion/tests/integration/data_quality/test_data_diff.py index 59a7d6daf6f..1e2a5f5a8ac 100644 --- a/ingestion/tests/integration/data_quality/test_data_diff.py +++ b/ingestion/tests/integration/data_quality/test_data_diff.py @@ -133,7 +133,6 @@ class TestParameters(BaseModel): TestCaseDefinition( name="with_passing_threshold", testDefinitionName="tableDiff", - computePassedFailedRowCount=True, parameterValues=[ TestCaseParameterValue(name="threshold", value="322"), ], @@ -141,7 +140,6 @@ class TestParameters(BaseModel): "POSTGRES_SERVICE.dvdrental.public.changed_customer", TestCaseResult( testCaseStatus=TestCaseStatus.Success, - passedRows=278, failedRows=321, ), ), @@ -149,7 +147,6 @@ class TestParameters(BaseModel): TestCaseDefinition( name="with_failing_threshold", testDefinitionName="tableDiff", - computePassedFailedRowCount=True, parameterValues=[ TestCaseParameterValue(name="threshold", value="321"), ], @@ -157,7 +154,6 @@ class TestParameters(BaseModel): "POSTGRES_SERVICE.dvdrental.public.changed_customer", TestCaseResult( testCaseStatus=TestCaseStatus.Failed, - passedRows=278, failedRows=321, ), ),