Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

210 lines
8.1 KiB
Python
Raw Normal View History

Fix #16421: add tableDiff test case (#16554) * feat: add tableDiff test case This change introduces a "table diff" test case which compares two tables and fails if they are not identical. The comparison is based on a specific "key" (because the test only makes sense when performed on ordered collections). 1. Added the `tableDiff` test definition. 2. Implemented a "runtime" parameters feature which injects additional parameters for the test at runtime. 3. Integration tests (because of course). This feature was not tested end-to-end yet because "array" data * pydantic v2 * format * format * format and added data diff to setup.py * format * fixed param issue which has type ARRAY * fixed runtime_parameter_setter * moved models to parent directory * handle errors in table diff * fixed issue with edit test case * format * added more details to pytest skip * format * refactor: Improve createTestCaseParameters function in DataQualityUtils * fixed unit test * removed unused fixture * removed validator.py * fixed tests * added validate kwarg to tests_mixin * removed "postgres" data diff extra as they interfere with psycopg2-binary * fixed tests * pinned tenacity for tests * reverted tenacity pinning * added ui support for test diff * fixed dq cypress and added edit flow * organized the test case * added dialect support * fixed tests * option style fix * fixed calculation for passing/failing rows * restrict the tableDiff test to limited services * set where to None if blank string * fixed where clause * fixed tests for where clause * use displayName in place of name in edit form * added docs for RuntimeParameterSetter * fixed cypress --------- Co-authored-by: Shailesh Parmar <shailesh.parmar.webdev@gmail.com>
2024-06-20 16:54:12 +02:00
# Copyright 2024 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-module-docstring
import traceback
from itertools import islice
from typing import Optional
from urllib.parse import urlparse
import data_diff
from data_diff.diff_tables import DiffResultWrapper
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
)
from metadata.data_quality.validations.models import TableDiffRuntimeParameters
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter,
)
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
SapHanaScheme,
)
from metadata.generated.schema.tests.basic import (
TestCaseResult,
TestCaseStatus,
TestResultValue,
)
from metadata.profiler.orm.registry import Dialects
from metadata.utils.logger import test_suite_logger
# Module-level logger obtained from the test-suite logger factory.
logger = test_suite_logger()

# Dialects accepted by the table diff test. `TableDiffValidator._validate_dialects`
# compares the URL scheme of `service1Url` / `service2Url` against this list and
# raises `UnsupportedDialectError` for any scheme that is not listed here.
# NOTE(review): presumably this mirrors the databases the data_diff library can
# connect to -- confirm before adding new entries.
SUPPORTED_DIALECTS = [
    Dialects.Snowflake,
    Dialects.BigQuery,
    Dialects.Athena,
    Dialects.Redshift,
    Dialects.Postgres,
    Dialects.MySQL,
    Dialects.MSSQL,
    Dialects.Oracle,
    Dialects.Trino,
    SapHanaScheme.hana.value,
]
class UnsupportedDialectError(Exception):
    """Signal that the dialect found in a given parameter is not supported."""

    def __init__(self, param: str, dialect: str):
        """Build the error message.

        Args:
            param: name of the parameter that carries the offending dialect.
            dialect: the dialect that is not supported.
        """
        message = f"Unsupported dialect in param {param}: {dialect}"
        super().__init__(message)
class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
    """
    Compare two tables and fail if the number of differences reaches a threshold.

    The test passes when no differences are found and no threshold is set, or
    when the number of differing rows is strictly lower than the configured
    `threshold` parameter. The tables, key columns, extra columns and where
    clause are read from the `runtimeParams` test case parameter.
    """

    runtime_parameter_setter = TableDiffParamsSetter

    def __init__(self, *args, **kwargs):
        """Initialise the base validator and parse the runtime parameters."""
        super().__init__(*args, **kwargs)
        self.runtime_params: TableDiffRuntimeParameters = self.get_runtime_params()

    def run_validation(self) -> TestCaseResult:
        """Run the table diff and convert any error into a test case result.

        Returns:
            The result of the diff. A `KeyError` raised while running is reported
            as a Failed result (mismatched columns); an unsupported dialect or
            any other unexpected exception is reported as an Aborted result.
        """
        try:
            self._validate_dialects()
            return self._run()
        except KeyError as e:
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Failed,
                result=f"MISMATCHED_COLUMNS: One of the tables is missing the column: '{e}'\n"
                "Use two tables with the same schema or provide the extraColumns parameter.",
                testResultValue=[TestResultValue(name="diffCount", value=str(0))],
            )
            logger.error(result.result)
            return result
        except UnsupportedDialectError as e:
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=str(e),
            )
            return result
        except Exception as e:
            logger.error(
                f"Unexpected error while running the table diff test: {str(e)}\n{traceback.format_exc()}"
            )
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=f"ERROR: Unexpected error while running the table diff test: {str(e)}",
            )
            logger.debug(result.result)
            return result

    def _run(self) -> TestCaseResult:
        """Execute the diff and build the test case result.

        When no threshold is set, or when passed/failed row counts were
        requested, the full diff statistics are computed. Otherwise only the
        first `threshold` differences are consumed, which is enough to decide
        whether the threshold has been reached.
        """
        threshold = self.get_test_case_param_value(
            self.test_case.parameterValues, "threshold", int, default=0
        )
        table_diff_iter = self.get_table_diff()

        if not threshold or self.test_case.computePassedFailedRowCount:
            stats = table_diff_iter.get_stats_dict()
            if stats["total"] > 0:
                logger.debug("Sample of failed rows:")
                # Reuse the diff that was already consumed for the stats instead of
                # connecting and diffing both tables a second time just for logging.
                for sample in islice(table_diff_iter, 10):
                    logger.debug(sample)
            test_case_result = self.get_test_case_result(
                threshold=threshold,
                total_diffs=stats["total"],
                changed=stats["updated"],
                removed=stats["exclusive_A"],
                added=stats["exclusive_B"],
            )
            count = self._compute_row_count(self.runner, None)  # type: ignore
            test_case_result.passedRows = stats["unchanged"]
            test_case_result.failedRows = stats["total"]
            # An empty table has no meaningful percentage; leave the fields unset
            # rather than aborting the whole test with a ZeroDivisionError.
            if count:
                test_case_result.passedRowsPercentage = (
                    test_case_result.passedRows / count * 100
                )
                test_case_result.failedRowsPercentage = (
                    test_case_result.failedRows / count * 100
                )
            return test_case_result

        # Only the first `threshold` differences matter: reaching that many is
        # already a failure, so the rest of the diff does not need to be consumed.
        num_diffs = sum(1 for _ in islice(table_diff_iter, threshold))
        # Keyword arguments keep threshold and diff count from being swapped.
        return self.get_test_case_result(
            threshold=threshold,
            total_diffs=num_diffs,
        )

    def get_table_diff(self) -> DiffResultWrapper:
        """Calls data_diff.diff_tables with the parameters from the test case."""
        table1 = data_diff.connect_to_table(
            self.runtime_params.service1Url,
            self.runtime_params.table1,
            self.runtime_params.keyColumns,  # type: ignore
        )
        table2 = data_diff.connect_to_table(
            self.runtime_params.service2Url,
            self.runtime_params.table2,
            self.runtime_params.keyColumns,  # type: ignore
        )
        data_diff_kwargs = {
            "key_columns": self.runtime_params.keyColumns,
            "extra_columns": self.runtime_params.extraColumns,
            "where": self.get_where(),
        }
        logger.debug(
            "Calling table diff with parameters:"  # pylint: disable=consider-using-f-string
            " table1={}, table2={}, kwargs={}".format(
                table1.table_path,
                table2.table_path,
                ",".join(f"{k}={v}" for k, v in data_diff_kwargs.items()),
            )
        )
        return data_diff.diff_tables(table1, table2, **data_diff_kwargs)  # type: ignore

    def get_where(self) -> Optional[str]:
        """Returns the where clause from the test case parameters or None if it is a blank string."""
        return self.runtime_params.whereClause or None

    def get_runtime_params(self) -> TableDiffRuntimeParameters:
        """Parse the `runtimeParams` test case parameter into its model."""
        raw = self.get_test_case_param_value(
            self.test_case.parameterValues, "runtimeParams", str
        )
        runtime_params = TableDiffRuntimeParameters.parse_raw(raw)
        return runtime_params

    @property
    def _param_name(self):
        # NOTE(review): "forbiddenRegex" looks like a copy-paste leftover from
        # another validator and is not used anywhere in this class. The value is
        # kept unchanged because callers outside this file are not visible.
        return "forbiddenRegex"

    def get_test_case_result(
        self,
        threshold: int,
        total_diffs: int,
        changed: Optional[int] = None,
        removed: Optional[int] = None,
        added: Optional[int] = None,
    ) -> TestCaseResult:
        """Build the test case result for a given number of differences.

        Args:
            threshold: number of differences at which the test fails; 0 means
                that any difference fails the test.
            total_diffs: number of differing rows that were found.
            changed: optional number of rows present in both tables but different.
            removed: optional number of rows found only in the first table.
            added: optional number of rows found only in the second table.

        Returns:
            A result whose status is passing when there are no differences and
            no threshold, or when `total_diffs` is strictly below `threshold`.
        """
        result_values = [
            TestResultValue(name="diffCount", value=str(total_diffs)),
        ]
        if changed is not None:
            result_values.append(TestResultValue(name="changed", value=str(changed)))
        if removed is not None:
            result_values.append(TestResultValue(name="removed", value=str(removed)))
        if added is not None:
            result_values.append(TestResultValue(name="added", value=str(added)))

        passed = (threshold or total_diffs) == 0 or total_diffs < threshold
        if passed:
            # Do not claim the threshold was exceeded when the test passes.
            message = f"Found {total_diffs} different rows which is within the threshold of {threshold}"
        else:
            message = f"Found {total_diffs} different rows which is more than the threshold of {threshold}"

        return TestCaseResult(
            timestamp=self.execution_date,  # type: ignore
            testCaseStatus=self.get_test_case_status(passed),
            result=message,
            testResultValue=result_values,
            validateColumns=False,
        )

    def _validate_dialects(self):
        """Check that both service URLs use a supported dialect.

        Raises:
            UnsupportedDialectError: if the URL scheme of `service1Url` or
                `service2Url` is not listed in `SUPPORTED_DIALECTS`.
        """
        for param in ["service1Url", "service2Url"]:
            dialect = urlparse(getattr(self.runtime_params, param)).scheme
            if dialect not in SUPPORTED_DIALECTS:
                raise UnsupportedDialectError(param, dialect)