MINOR: Fix data diff with threshold (#16926)

* fix: table-diff

The threshold and the diff count were passed in the wrong order. The test was not covering this because of how its parameters were configured.
This commit is contained in:
Imri Paran 2024-07-05 07:51:24 +02:00 committed by GitHub
parent 61e63386c5
commit d08af1f86d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 9 deletions

View File

@ -11,7 +11,7 @@
# pylint: disable=missing-module-docstring
import traceback
from itertools import islice
from typing import Dict, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse
import data_diff
@ -117,7 +117,11 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
if stats["total"] > 0:
logger.debug("Sample of failed rows:")
for s in islice(self.get_table_diff(), 10):
logger.debug(s)
# since the data can contain sensitive information, we don't want to log it
# we can uncomment this line if we must see the data in the logs
# logger.debug(s)
# by default we will log the data masked
logger.debug([s[0], ["*" for _ in s[1]]])
test_case_result = self.get_row_diff_test_case_result(
threshold,
stats["total"],
@ -127,7 +131,6 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
)
count = self._compute_row_count(self.runner, None) # type: ignore
test_case_result.passedRows = stats["unchanged"]
test_case_result.failedRows = stats["total"]
test_case_result.passedRowsPercentage = (
test_case_result.passedRows / count * 100
)
@ -135,10 +138,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
test_case_result.failedRows / count * 100
)
return test_case_result
num_dffs = sum(1 for _ in islice(table_diff_iter, threshold))
return self.get_row_diff_test_case_result(
num_dffs,
threshold,
self.calculate_diffs_with_limit(table_diff_iter, threshold),
)
def get_incomparable_columns(self) -> List[str]:
@ -274,6 +276,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
(threshold or total_diffs) == 0 or total_diffs < threshold
),
result=f"Found {total_diffs} different rows which is more than the threshold of {threshold}",
failedRows=total_diffs,
validateColumns=False,
testResultValue=[
TestResultValue(name="removedRows", value=str(removed)),
@ -374,3 +377,32 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
TestResultValue(name="changedColumns", value=str(len(changed))),
],
)
def calculate_diffs_with_limit(
    self, diff_iter: Iterable[Tuple[str, Tuple[str, ...]]], limit: int
) -> int:
    """Count distinct differing keys in a diff stream, stopping past the limit.

    Given an iterator of diffs like
    - ('+', (...))
    - ('-', (...))
    ...
    diffs that share the same key are combined and counted once. The key of
    a diff is the leading ``len(self.runtime_params.keyColumns)`` entries of
    its values tuple. This gives an accurate count of the total diffs, as
    opposed to just counting the number of entries in the diff stream.

    Iteration stops as soon as the number of distinct keys exceeds ``limit``,
    so the remainder of ``diff_iter`` is left unconsumed and the returned
    value is never larger than ``limit + 1``.

    Args:
        diff_iter: iterator returned from the data_diff algorithm
        limit: number of distinct keys after which counting stops

    Returns:
        int: count of distinct keys seen; ``limit + 1`` at most, which
        signals that the limit was exceeded
    """
    len_key_columns = len(self.runtime_params.keyColumns)
    key_set = set()
    # combine diffs on same key to "!"
    for _, values in diff_iter:
        # tuple() keeps the key hashable even if a row arrives as a list
        k = tuple(values[:len_key_columns])
        if k in key_set:
            continue
        key_set.add(k)
        if len(key_set) > limit:
            # The limit is exceeded: stop here instead of draining the
            # (potentially very large) diff stream.
            return len(key_set)
    return len(key_set)

View File

@ -133,7 +133,6 @@ class TestParameters(BaseModel):
TestCaseDefinition(
name="with_passing_threshold",
testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[
TestCaseParameterValue(name="threshold", value="322"),
],
@ -141,7 +140,6 @@ class TestParameters(BaseModel):
"POSTGRES_SERVICE.dvdrental.public.changed_customer",
TestCaseResult(
testCaseStatus=TestCaseStatus.Success,
passedRows=278,
failedRows=321,
),
),
@ -149,7 +147,6 @@ class TestParameters(BaseModel):
TestCaseDefinition(
name="with_failing_threshold",
testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[
TestCaseParameterValue(name="threshold", value="321"),
],
@ -157,7 +154,6 @@ class TestParameters(BaseModel):
"POSTGRES_SERVICE.dvdrental.public.changed_customer",
TestCaseResult(
testCaseStatus=TestCaseStatus.Failed,
passedRows=278,
failedRows=321,
),
),