mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-10-14 02:08:54 +00:00
210 lines
8.1 KiB
Python
210 lines
8.1 KiB
Python
![]() |
# Copyright 2024 Collate
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# pylint: disable=missing-module-docstring
|
||
|
import traceback
|
||
|
from itertools import islice
|
||
|
from typing import Optional
|
||
|
from urllib.parse import urlparse
|
||
|
|
||
|
import data_diff
|
||
|
from data_diff.diff_tables import DiffResultWrapper
|
||
|
|
||
|
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
|
||
|
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
|
||
|
SQAValidatorMixin,
|
||
|
)
|
||
|
from metadata.data_quality.validations.models import TableDiffRuntimeParameters
|
||
|
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
|
||
|
TableDiffParamsSetter,
|
||
|
)
|
||
|
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
|
||
|
SapHanaScheme,
|
||
|
)
|
||
|
from metadata.generated.schema.tests.basic import (
|
||
|
TestCaseResult,
|
||
|
TestCaseStatus,
|
||
|
TestResultValue,
|
||
|
)
|
||
|
from metadata.profiler.orm.registry import Dialects
|
||
|
from metadata.utils.logger import test_suite_logger
|
||
|
|
||
|
logger = test_suite_logger()

# URL schemes accepted for ``service1Url`` and ``service2Url``. The table diff
# validator compares ``urlparse(url).scheme`` of both URLs against this list
# before running the diff.
SUPPORTED_DIALECTS = [
    Dialects.Snowflake,
    Dialects.BigQuery,
    Dialects.Athena,
    Dialects.Redshift,
    Dialects.Postgres,
    Dialects.MySQL,
    Dialects.MSSQL,
    Dialects.Oracle,
    Dialects.Trino,
    # NOTE(review): the SAP HANA scheme is taken from the connection schema enum
    # rather than ``Dialects``; presumably ``Dialects`` has no HANA entry.
    # Confirm before consolidating.
    SapHanaScheme.hana.value,
]
|
||
|
|
||
|
|
||
|
class UnsupportedDialectError(Exception):
    """Raised when a service URL uses a scheme the table diff test does not support.

    Args:
        param: name of the runtime parameter holding the offending URL.
        dialect: the URL scheme that was rejected.
    """

    def __init__(self, param: str, dialect: str):
        message = f"Unsupported dialect in param {param}: {dialect}"
        super().__init__(message)
|
||
|
|
||
|
|
||
|
class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
    """
    Compare two tables and fail when they differ too much.

    With no ``threshold`` parameter (or ``threshold=0``), the test passes only
    if the tables have no differing rows. Otherwise it passes while the number
    of differing rows stays strictly below ``threshold``.
    """

    runtime_parameter_setter = TableDiffParamsSetter

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Parsed once from the "runtimeParams" test parameter; the other
        # methods read service URLs, tables and columns from this object.
        self.runtime_params: TableDiffRuntimeParameters = self.get_runtime_params()

    def run_validation(self) -> TestCaseResult:
        """Run the table diff test and turn known failure modes into results.

        Returns:
            The diff outcome on success. Returns ``Failed`` when a ``KeyError``
            is raised, which is treated as a column missing from one of the
            tables. Returns ``Aborted`` for an unsupported dialect or any other
            unexpected error.
        """
        try:
            self._validate_dialects()
            return self._run()
        except KeyError as e:
            # str(KeyError("col")) is already "'col'"; use the raw key so the
            # message below is not double-quoted.
            missing_column = e.args[0] if e.args else e
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Failed,
                result=f"MISMATCHED_COLUMNS: One of the tables is missing the column: '{missing_column}'\n"
                "Use two tables with the same schema or provide the extraColumns parameter.",
                testResultValue=[TestResultValue(name="diffCount", value=str(0))],
            )
            logger.error(result.result)
            return result
        except UnsupportedDialectError as e:
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=str(e),
            )
            return result
        except Exception as e:
            logger.error(
                f"Unexpected error while running the table diff test: {str(e)}\n{traceback.format_exc()}"
            )
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=f"ERROR: Unexpected error while running the table diff test: {str(e)}",
            )
            logger.debug(result.result)
            return result

    def _run(self) -> TestCaseResult:
        """Compute the diff and build the test result.

        When no threshold is set, or passed/failed row counts are requested, the
        whole diff is consumed to collect statistics. Otherwise only up to
        ``threshold`` differences are consumed, which is enough to decide the
        outcome.
        """
        threshold = self.get_test_case_param_value(
            self.test_case.parameterValues, "threshold", int, default=0
        )
        table_diff_iter = self.get_table_diff()

        if not threshold or self.test_case.computePassedFailedRowCount:
            stats = table_diff_iter.get_stats_dict()
            if stats["total"] > 0:
                logger.debug("Sample of failed rows:")
                # Iterate the same wrapper instead of calling get_table_diff()
                # again. data-diff's DiffResultWrapper caches the rows it has
                # already yielded, so this replays the computed diff rather than
                # reconnecting and re-diffing both tables just to log a sample.
                for s in islice(table_diff_iter, 10):
                    logger.debug(s)
            test_case_result = self.get_test_case_result(
                threshold,
                stats["total"],
                stats["updated"],
                stats["exclusive_A"],
                stats["exclusive_B"],
            )
            count = self._compute_row_count(self.runner, None)  # type: ignore
            test_case_result.passedRows = stats["unchanged"]
            test_case_result.failedRows = stats["total"]
            # An empty table has no meaningful percentages. Dividing by zero here
            # would otherwise abort the test with an "unexpected error".
            if count:
                test_case_result.passedRowsPercentage = (
                    test_case_result.passedRows / count * 100
                )
                test_case_result.failedRowsPercentage = (
                    test_case_result.failedRows / count * 100
                )
            return test_case_result

        # Stop counting once `threshold` differences are seen: that alone
        # already decides the test has failed.
        num_diffs = sum(1 for _ in islice(table_diff_iter, threshold))
        return self.get_test_case_result(
            threshold,
            num_diffs,
        )

    def get_table_diff(self) -> DiffResultWrapper:
        """Calls data_diff.diff_tables with the parameters from the test case."""
        table1 = data_diff.connect_to_table(
            self.runtime_params.service1Url, self.runtime_params.table1, self.runtime_params.keyColumns  # type: ignore
        )
        table2 = data_diff.connect_to_table(
            self.runtime_params.service2Url, self.runtime_params.table2, self.runtime_params.keyColumns  # type: ignore
        )

        data_diff_kwargs = {
            "key_columns": self.runtime_params.keyColumns,
            "extra_columns": self.runtime_params.extraColumns,
            "where": self.get_where(),
        }
        logger.debug(
            "Calling table diff with parameters:"  # pylint: disable=consider-using-f-string
            " table1={}, table2={}, kwargs={}".format(
                table1.table_path,
                table2.table_path,
                ",".join(f"{k}={v}" for k, v in data_diff_kwargs.items()),
            )
        )
        return data_diff.diff_tables(table1, table2, **data_diff_kwargs)  # type: ignore

    def get_where(self) -> Optional[str]:
        """Returns the where clause from the test case parameters or None if it is a blank string."""
        return self.runtime_params.whereClause or None

    def get_runtime_params(self) -> TableDiffRuntimeParameters:
        """Parse the JSON "runtimeParams" test parameter into its model."""
        raw = self.get_test_case_param_value(
            self.test_case.parameterValues, "runtimeParams", str
        )
        runtime_params = TableDiffRuntimeParameters.parse_raw(raw)
        return runtime_params

    @property
    def _param_name(self):
        # NOTE(review): "forbiddenRegex" looks copied from another validator and
        # is not used anywhere in this class. Confirm how the base class uses it
        # before changing it.
        return "forbiddenRegex"

    def get_test_case_result(
        self,
        threshold: int,
        total_diffs: int,
        changed: Optional[int] = None,
        removed: Optional[int] = None,
        added: Optional[int] = None,
    ) -> TestCaseResult:
        """Build the ``TestCaseResult`` for a diff outcome.

        Args:
            threshold: number of differences at which the test fails; ``0``
                means any difference fails the test.
            total_diffs: number of differing rows found.
            changed: optional count of updated rows, reported as "changed".
            removed: optional count of rows only in the first table, reported
                as "removed".
            added: optional count of rows only in the second table, reported
                as "added".

        Returns:
            A result carrying "diffCount" plus any provided breakdown values.
        """
        result_values = [
            TestResultValue(name="diffCount", value=str(total_diffs)),
        ]
        for name, value in (("changed", changed), ("removed", removed), ("added", added)):
            if value is not None:
                result_values.append(TestResultValue(name=name, value=str(value)))
        # Same rule as before, written out: no threshold means no differences
        # are allowed; otherwise pass while the count stays below the threshold.
        passed = total_diffs < threshold if threshold else total_diffs == 0
        return TestCaseResult(
            timestamp=self.execution_date,  # type: ignore
            testCaseStatus=self.get_test_case_status(passed),
            result=f"Found {total_diffs} different rows which is more than the threshold of {threshold}",
            testResultValue=result_values,
            validateColumns=False,
        )

    def _validate_dialects(self):
        """Check that both service URLs use a scheme listed in SUPPORTED_DIALECTS.

        Raises:
            UnsupportedDialectError: for the first URL whose scheme is not supported.
        """
        for param in ["service1Url", "service2Url"]:
            dialect = urlparse(getattr(self.runtime_params, param)).scheme
            if dialect not in SUPPORTED_DIALECTS:
                raise UnsupportedDialectError(param, dialect)
|