MINOR: Fix data diff with threshold (#16926)

* fix: table-diff

passed threshold and diff count in wrong order. test was not covering this due to how the parameters were configured.
This commit is contained in:
Imri Paran 2024-07-05 07:51:24 +02:00 committed by GitHub
parent 61e63386c5
commit d08af1f86d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 9 deletions

View File

@ -11,7 +11,7 @@
# pylint: disable=missing-module-docstring # pylint: disable=missing-module-docstring
import traceback import traceback
from itertools import islice from itertools import islice
from typing import Dict, List, Optional, Tuple from typing import Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
import data_diff import data_diff
@ -117,7 +117,11 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
if stats["total"] > 0: if stats["total"] > 0:
logger.debug("Sample of failed rows:") logger.debug("Sample of failed rows:")
for s in islice(self.get_table_diff(), 10): for s in islice(self.get_table_diff(), 10):
logger.debug(s) # since the data can contiant sensitive information, we don't want to log it
# we can uncomment this line if we must see the data in the logs
# logger.debug(s)
# by default we will log the data masked
logger.debug([s[0], ["*" for _ in s[1]]])
test_case_result = self.get_row_diff_test_case_result( test_case_result = self.get_row_diff_test_case_result(
threshold, threshold,
stats["total"], stats["total"],
@ -127,7 +131,6 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
) )
count = self._compute_row_count(self.runner, None) # type: ignore count = self._compute_row_count(self.runner, None) # type: ignore
test_case_result.passedRows = stats["unchanged"] test_case_result.passedRows = stats["unchanged"]
test_case_result.failedRows = stats["total"]
test_case_result.passedRowsPercentage = ( test_case_result.passedRowsPercentage = (
test_case_result.passedRows / count * 100 test_case_result.passedRows / count * 100
) )
@ -135,10 +138,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
test_case_result.failedRows / count * 100 test_case_result.failedRows / count * 100
) )
return test_case_result return test_case_result
num_dffs = sum(1 for _ in islice(table_diff_iter, threshold))
return self.get_row_diff_test_case_result( return self.get_row_diff_test_case_result(
num_dffs,
threshold, threshold,
self.calculate_diffs_with_limit(table_diff_iter, threshold),
) )
def get_incomparable_columns(self) -> List[str]: def get_incomparable_columns(self) -> List[str]:
@ -274,6 +276,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
(threshold or total_diffs) == 0 or total_diffs < threshold (threshold or total_diffs) == 0 or total_diffs < threshold
), ),
result=f"Found {total_diffs} different rows which is more than the threshold of {threshold}", result=f"Found {total_diffs} different rows which is more than the threshold of {threshold}",
failedRows=total_diffs,
validateColumns=False, validateColumns=False,
testResultValue=[ testResultValue=[
TestResultValue(name="removedRows", value=str(removed)), TestResultValue(name="removedRows", value=str(removed)),
@ -374,3 +377,32 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
TestResultValue(name="changedColumns", value=str(len(changed))), TestResultValue(name="changedColumns", value=str(len(changed))),
], ],
) )
def calculate_diffs_with_limit(
self, diff_iter: Iterable[Tuple[str, Tuple[str, ...]]], limit: int
) -> int:
"""Given an iterator of diffs like
- ('+', (...))
- ('-', (...))
...
Calculate the total diffs by combining diffs for the same key. This gives an accurate count of the total diffs
as opposed to self.calculate_diff_num(diff_list)just counting the number of diffs in the list.
Args:
diff_iter: iterator returned from the data_diff algorithm
Returns:
int: accurate count of the total diffs up to the limit
"""
len_key_columns = len(self.runtime_params.keyColumns)
key_set = set()
# combine diffs on same key to "!"
for _, values in diff_iter:
k = values[:len_key_columns]
if k in key_set:
continue
key_set.add(k)
if len(key_set) > limit:
len(key_set)
return len(key_set)

View File

@ -133,7 +133,6 @@ class TestParameters(BaseModel):
TestCaseDefinition( TestCaseDefinition(
name="with_passing_threshold", name="with_passing_threshold",
testDefinitionName="tableDiff", testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[ parameterValues=[
TestCaseParameterValue(name="threshold", value="322"), TestCaseParameterValue(name="threshold", value="322"),
], ],
@ -141,7 +140,6 @@ class TestParameters(BaseModel):
"POSTGRES_SERVICE.dvdrental.public.changed_customer", "POSTGRES_SERVICE.dvdrental.public.changed_customer",
TestCaseResult( TestCaseResult(
testCaseStatus=TestCaseStatus.Success, testCaseStatus=TestCaseStatus.Success,
passedRows=278,
failedRows=321, failedRows=321,
), ),
), ),
@ -149,7 +147,6 @@ class TestParameters(BaseModel):
TestCaseDefinition( TestCaseDefinition(
name="with_failing_threshold", name="with_failing_threshold",
testDefinitionName="tableDiff", testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[ parameterValues=[
TestCaseParameterValue(name="threshold", value="321"), TestCaseParameterValue(name="threshold", value="321"),
], ],
@ -157,7 +154,6 @@ class TestParameters(BaseModel):
"POSTGRES_SERVICE.dvdrental.public.changed_customer", "POSTGRES_SERVICE.dvdrental.public.changed_customer",
TestCaseResult( TestCaseResult(
testCaseStatus=TestCaseStatus.Failed, testCaseStatus=TestCaseStatus.Failed,
passedRows=278,
failedRows=321, failedRows=321,
), ),
), ),