2024-06-20 16:54:12 +02:00
|
|
|
# Copyright 2024 Collate
|
2025-04-03 10:39:47 +05:30
|
|
|
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
2024-06-20 16:54:12 +02:00
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
2025-04-03 10:39:47 +05:30
|
|
|
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
2024-06-20 16:54:12 +02:00
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# pylint: disable=missing-module-docstring
|
2024-08-07 17:21:17 +02:00
|
|
|
import logging
|
2024-11-11 10:07:23 +01:00
|
|
|
import random
|
|
|
|
import string
|
2024-06-20 16:54:12 +02:00
|
|
|
import traceback
|
2024-10-15 16:29:43 +02:00
|
|
|
from decimal import Decimal
|
2024-11-11 10:07:23 +01:00
|
|
|
from functools import reduce
|
2024-06-20 16:54:12 +02:00
|
|
|
from itertools import islice
|
2024-10-15 16:29:43 +02:00
|
|
|
from typing import Dict, Iterable, List, Optional, Tuple, cast
|
2024-06-20 16:54:12 +02:00
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
import data_diff
|
2024-07-02 12:36:03 +02:00
|
|
|
import sqlalchemy.types
|
2024-06-20 16:54:12 +02:00
|
|
|
from data_diff.diff_tables import DiffResultWrapper
|
2024-07-02 12:36:03 +02:00
|
|
|
from data_diff.errors import DataDiffMismatchingKeyTypesError
|
2024-10-15 16:29:43 +02:00
|
|
|
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
|
2024-07-02 12:36:03 +02:00
|
|
|
from sqlalchemy import Column as SAColumn
|
2024-12-06 14:21:33 +01:00
|
|
|
from sqlalchemy import literal, select
|
|
|
|
from sqlalchemy.engine import make_url
|
2024-06-20 16:54:12 +02:00
|
|
|
|
2024-10-15 16:29:43 +02:00
|
|
|
from metadata.data_quality.validations import utils
|
2024-06-20 16:54:12 +02:00
|
|
|
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
|
|
|
|
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
|
|
|
|
SQAValidatorMixin,
|
|
|
|
)
|
2024-11-15 08:22:13 +01:00
|
|
|
from metadata.data_quality.validations.models import (
|
|
|
|
TableDiffRuntimeParameters,
|
|
|
|
TableParameter,
|
|
|
|
)
|
2024-11-11 10:07:23 +01:00
|
|
|
from metadata.generated.schema.entity.data.table import Column, ProfileSampleType
|
2024-06-20 16:54:12 +02:00
|
|
|
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
|
|
|
|
SapHanaScheme,
|
|
|
|
)
|
2024-11-11 10:07:23 +01:00
|
|
|
from metadata.generated.schema.entity.services.databaseService import (
|
|
|
|
DatabaseServiceType,
|
|
|
|
)
|
2024-06-20 16:54:12 +02:00
|
|
|
from metadata.generated.schema.tests.basic import (
|
|
|
|
TestCaseResult,
|
|
|
|
TestCaseStatus,
|
|
|
|
TestResultValue,
|
|
|
|
)
|
2024-11-27 08:50:54 +01:00
|
|
|
from metadata.profiler.metrics.registry import Metrics
|
2024-11-11 10:07:23 +01:00
|
|
|
from metadata.profiler.orm.converter.base import build_orm_col
|
|
|
|
from metadata.profiler.orm.functions.md5 import MD5
|
|
|
|
from metadata.profiler.orm.functions.substr import Substr
|
2024-11-15 08:22:13 +01:00
|
|
|
from metadata.profiler.orm.registry import Dialects, PythonDialects
|
|
|
|
from metadata.utils.collections import CaseInsensitiveList
|
2024-06-20 16:54:12 +02:00
|
|
|
from metadata.utils.logger import test_suite_logger
|
|
|
|
|
|
|
|
logger = test_suite_logger()

# Dialects/schemes for which the table diff comparison is supported.
# TableDiffValidator._validate_dialects checks both service URLs against
# this list before running the diff and aborts the test otherwise.
SUPPORTED_DIALECTS = [
    Dialects.Snowflake,
    Dialects.BigQuery,
    Dialects.Athena,
    Dialects.Redshift,
    Dialects.Postgres,
    Dialects.MySQL,
    Dialects.MSSQL,
    Dialects.Oracle,
    Dialects.Trino,
    SapHanaScheme.hana.value,
    Dialects.Databricks,
    Dialects.UnityCatalog,
]
|
|
|
|
|
|
|
|
|
2024-11-15 08:22:13 +01:00
|
|
|
def build_sample_where_clause(
    table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str
) -> str:
    """Build a deterministic sampling WHERE clause for one side of the diff.

    Concatenates the table's key columns with a shared random ``salt``, hashes
    the result with MD5, and keeps rows whose first 8 hex characters compare
    below ``hex_nounce``. Because the predicate depends only on key values and
    the shared salt, the same rows are selected on both tables being compared.

    Args:
        table: table to build the clause for
        key_columns: names of the key columns to include in the hash
        salt: random string mixed into the hash (shared by both tables)
        hex_nounce: 8-character hex upper bound; rows hashing below it are kept

    Returns:
        The WHERE clause rendered as a SQL string with literals inlined.
    """
    # Only the key columns participate in the hash
    sql_alchemy_columns = [
        build_orm_col(i, c, table.database_service_type)
        for i, c in enumerate(table.columns)
        if c.name.root in key_columns
    ]
    # key1 || key2 || ... || salt
    reduced_concat = reduce(
        lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
    )
    # Resolve the SQLAlchemy dialect for the target database so the clause is
    # compiled with that database's syntax
    sqa_dialect = make_url(
        f"{PythonDialects[table.database_service_type.name].value}://"
    ).get_dialect()
    return str(
        select()
        .filter(
            # First 8 hex chars of the MD5 digest, compared lexically
            Substr(
                MD5(reduced_concat),
                1,
                8,
            )
            < hex_nounce
        )
        .whereclause.compile(
            dialect=sqa_dialect(),
            # Inline literal values: the clause is passed to data_diff as text
            compile_kwargs={"literal_binds": True},
        )
    )
|
|
|
|
|
|
|
|
|
2024-11-11 10:07:23 +01:00
|
|
|
def compile_and_clauses(elements) -> str:
    """Compile a string or (arbitrarily nested) list of strings into an
    'and'-joined expression, parenthesizing nested lists.

    Args:
        elements: A string or a list of strings or lists

    Returns:
        A string with 'and' clauses

    Raises:
        ValueError: If the input is not a string or a list

    Examples:
        >>> compile_and_clauses("a")
        'a'
        >>> compile_and_clauses(["a", "b"])
        'a and b'
        >>> compile_and_clauses([["a", "b"], "c"])
        '(a and b) and c'
    """
    if isinstance(elements, str):
        return elements
    if not isinstance(elements, list):
        raise ValueError("Input must be a string or a list")
    # A single-element list collapses to its only member (no parentheses)
    if len(elements) == 1:
        return compile_and_clauses(elements[0])
    parts = []
    for element in elements:
        compiled = compile_and_clauses(element)
        # Nested lists are wrapped to preserve grouping
        parts.append(f"({compiled})" if isinstance(element, list) else compiled)
    return " and ".join(parts)
|
|
|
|
|
|
|
|
|
2024-06-20 16:54:12 +02:00
|
|
|
class UnsupportedDialectError(Exception):
    """Raised when a service URL's dialect is not in SUPPORTED_DIALECTS."""

    def __init__(self, param: str, dialect: str):
        message = f"Unsupported dialect in param {param}: {dialect}"
        super().__init__(message)
|
|
|
|
|
|
|
|
|
2024-08-07 17:21:17 +02:00
|
|
|
def masked(s: str, mask: bool = True) -> str:
    """Return "***" when ``mask`` is True, otherwise the string itself.

    Only for development purposes, do not use in production.
    Pass mask=False if you want to see the data in the logs.

    Args:
        s: string to mask
        mask: boolean to mask the string

    Returns:
        masked string if mask is True otherwise return the string
    """
    if mask:
        return "***"
    return s
|
|
|
|
|
|
|
|
|
2024-10-15 16:29:43 +02:00
|
|
|
def is_numeric(t: type) -> bool:
    """Check whether ``t`` is one of the numeric python types used by the
    diff comparison (int, float, Decimal).

    Args:
        t: type to check

    Returns:
        True if the type is numeric otherwise False
    """
    numeric_types = (int, float, Decimal)
    return t in numeric_types
|
|
|
|
|
|
|
|
|
2024-06-20 16:54:12 +02:00
|
|
|
class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
    """
    Compare two tables and fail if the number of differences exceeds a threshold
    """

    # Populated at the start of run_validation() from the test case's runtime
    # parameters; every other method on this class reads from this attribute.
    runtime_params: TableDiffRuntimeParameters
|
2024-06-20 16:54:12 +02:00
|
|
|
|
|
|
|
    def run_validation(self) -> TestCaseResult:
        """Run the table diff validation and map known failures to results.

        Returns:
            TestCaseResult: Failed when key column types mismatch; Aborted on
            an unsupported dialect or any unexpected error; otherwise the
            result produced by _run().
        """
        self.runtime_params = self.get_runtime_parameters(TableDiffRuntimeParameters)
        try:
            self._validate_dialects()
            return self._run()
        except DataDiffMismatchingKeyTypesError as e:
            # Key columns have incompatible types across the two tables:
            # this is a data problem, so the test fails (not aborts)
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Failed,
                result=str(e),
            )
            return result
        except UnsupportedDialectError as e:
            logger.warning(f"[Data Diff]: Unsupported dialect: {e}")
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=str(e),
            )
            return result
        except Exception as e:
            # Top-level boundary: log with traceback and abort the test
            logger.error(
                f"Unexpected error while running the table diff test: {str(e)}\n{traceback.format_exc()}"
            )
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=f"ERROR: Unexpected error while running the table diff test: {str(e)}",
            )
            logger.debug(result.result)
            return result
|
|
|
|
|
|
|
|
    def _run(self) -> TestCaseResult:
        """Compare schemas first, then rows.

        If the column comparison finds differences, the row diff is skipped and
        that result is returned. Otherwise the rows are diffed; when no
        threshold is set or passed/failed row counts were requested, full diff
        statistics (and percentages) are computed, else counting stops once the
        threshold is reached.
        """
        result = self.get_column_diff()
        if result:
            return result
        threshold = self.get_test_case_param_value(
            self.test_case.parameterValues, "threshold", int, default=0
        )
        table_diff_iter = self.get_table_diff()

        if not threshold or self.test_case.computePassedFailedRowCount:
            stats = table_diff_iter.get_stats_dict()
            if stats["total"] > 0:
                logger.debug("Sample of failed rows:")
                # depending on the data, this require scanning a lot of data
                # so we only log the sample in debug mode. data can be sensitive
                # so it is masked by default
                for s in islice(
                    self.safe_table_diff_iterator(),
                    10 if logger.level <= logging.DEBUG else 0,
                ):
                    logger.debug("%s", str([s[0]] + [masked(st) for st in s[1]]))
            test_case_result = self.get_row_diff_test_case_result(
                threshold,
                stats["total"],
                stats["updated"],
                stats["exclusive_A"],
                stats["exclusive_B"],
            )
            # Percentages are computed against the row count of the first table
            count = self._compute_row_count(self.runner, None)  # type: ignore
            test_case_result.passedRows = stats["unchanged"]
            test_case_result.passedRowsPercentage = (
                test_case_result.passedRows / count * 100
            )
            test_case_result.failedRowsPercentage = (
                test_case_result.failedRows / count * 100
            )
            return test_case_result
        # Threshold set and no row counts requested: stop counting diffs early
        return self.get_row_diff_test_case_result(
            threshold,
            self.calculate_diffs_with_limit(table_diff_iter, threshold),
        )
|
|
|
|
|
2024-07-02 12:36:03 +02:00
|
|
|
    def get_incomparable_columns(self) -> List[str]:
        """Get the columns that have types that are not comparable between the two tables. For example
        a column that is a string in one table and an integer in the other.

        Returns:
            List[str]: A list of column names that have incomparable types
        """
        table1 = data_diff.connect_to_table(
            self.runtime_params.table1.serviceUrl,
            self.runtime_params.table1.path,
            self.runtime_params.table1.key_columns,
            extra_columns=self.runtime_params.extraColumns,
            case_sensitive=self.get_case_sensitive(),
            key_content=self.runtime_params.table1.privateKey.get_secret_value()
            if self.runtime_params.table1.privateKey
            else None,
            private_key_passphrase=self.runtime_params.table1.passPhrase.get_secret_value()
            if self.runtime_params.table1.passPhrase
            else None,
        ).with_schema()
        table2 = data_diff.connect_to_table(
            self.runtime_params.table2.serviceUrl,
            self.runtime_params.table2.path,
            self.runtime_params.table2.key_columns,
            extra_columns=self.runtime_params.extraColumns,
            case_sensitive=self.get_case_sensitive(),
            key_content=self.runtime_params.table2.privateKey.get_secret_value()
            if self.runtime_params.table2.privateKey
            else None,
            private_key_passphrase=self.runtime_params.table2.passPhrase.get_secret_value()
            if self.runtime_params.table2.passPhrase
            else None,
        ).with_schema()
        result = []
        for column in table1.key_columns + table1.extra_columns:
            col1 = table1._schema.get(column)  # pylint: disable=protected-access
            if col1 is None:
                # Skip columns that are not in the first table. We cover this case in get_changed_added_columns.
                continue
            col2 = table2._schema.get(column)  # pylint: disable=protected-access
            if col2 is None:
                # Skip columns that are not in the second table. We cover this case in get_changed_added_columns.
                continue
            col1_type = self._get_column_python_type(col1)
            col2_type = self._get_column_python_type(col2)
            # Numeric-to-numeric is always comparable, even across different
            # python types (int vs float vs Decimal)
            if is_numeric(col1_type) and is_numeric(col2_type):
                continue
            if col1_type != col2_type:
                result.append(column)
        return result
|
|
|
|
|
|
|
|
    @staticmethod
    def _get_column_python_type(column: SAColumn):
        """Try to resolve the python_type of a column by cascading through different SQLAlchemy types.
        If no type is found, return the name of the column type. This is usually undesirable since it can
        be very database specific, but it is better than nothing.

        Args:
            column: An SQLAlchemy column object
        """
        # NOTE(review): each lookup below overwrites `result` even when an
        # earlier lookup succeeded — the last successful resolution wins.
        # Confirm this ordering is intentional before reordering.
        result = None
        try:
            result = column.python_type
        except AttributeError:
            pass
        try:
            # Look up the type by its class name in sqlalchemy.types
            result = getattr(sqlalchemy.types, type(column).__name__)().python_type
        except AttributeError:
            pass
        try:
            # Retry with the upper-cased class name (e.g. "varchar" -> VARCHAR)
            result = getattr(
                sqlalchemy.types, type(column).__name__.upper()
            )().python_type
        except AttributeError:
            pass
        # Normalize types so comparisons across databases are sensible
        if result == ArithAlphanumeric:
            result = str
        elif result == bool:
            result = int
        elif result is None:
            # Nothing resolved: returns NoneType (type(None))
            return type(result)
        return result
|
|
|
|
|
2024-06-20 16:54:12 +02:00
|
|
|
def get_table_diff(self) -> DiffResultWrapper:
|
|
|
|
"""Calls data_diff.diff_tables with the parameters from the test case."""
|
2024-11-15 08:22:13 +01:00
|
|
|
left_where, right_where = self.sample_where_clause()
|
2024-06-20 16:54:12 +02:00
|
|
|
table1 = data_diff.connect_to_table(
|
2024-07-02 12:36:03 +02:00
|
|
|
self.runtime_params.table1.serviceUrl,
|
|
|
|
self.runtime_params.table1.path,
|
2025-10-08 09:32:00 +02:00
|
|
|
self.runtime_params.table1.key_columns, # type: ignore
|
|
|
|
extra_columns=self.runtime_params.table1.extra_columns,
|
2024-10-15 16:29:43 +02:00
|
|
|
case_sensitive=self.get_case_sensitive(),
|
2024-11-15 08:22:13 +01:00
|
|
|
where=left_where,
|
2025-04-23 12:15:42 +02:00
|
|
|
key_content=self.runtime_params.table1.privateKey.get_secret_value()
|
|
|
|
if self.runtime_params.table1.privateKey
|
|
|
|
else None,
|
|
|
|
private_key_passphrase=self.runtime_params.table1.passPhrase.get_secret_value()
|
|
|
|
if self.runtime_params.table1.passPhrase
|
|
|
|
else None,
|
2024-06-20 16:54:12 +02:00
|
|
|
)
|
|
|
|
table2 = data_diff.connect_to_table(
|
2024-07-02 12:36:03 +02:00
|
|
|
self.runtime_params.table2.serviceUrl,
|
|
|
|
self.runtime_params.table2.path,
|
2025-10-08 09:32:00 +02:00
|
|
|
self.runtime_params.table2.key_columns, # type: ignore
|
|
|
|
extra_columns=self.runtime_params.table2.extra_columns,
|
2024-10-15 16:29:43 +02:00
|
|
|
case_sensitive=self.get_case_sensitive(),
|
2024-11-15 08:22:13 +01:00
|
|
|
where=right_where,
|
2025-04-23 12:15:42 +02:00
|
|
|
key_content=self.runtime_params.table1.privateKey.get_secret_value()
|
|
|
|
if self.runtime_params.table1.privateKey
|
|
|
|
else None,
|
|
|
|
private_key_passphrase=self.runtime_params.table1.passPhrase.get_secret_value()
|
|
|
|
if self.runtime_params.table1.passPhrase
|
|
|
|
else None,
|
2024-06-20 16:54:12 +02:00
|
|
|
)
|
|
|
|
data_diff_kwargs = {
|
|
|
|
"where": self.get_where(),
|
|
|
|
}
|
|
|
|
logger.debug(
|
2024-08-07 17:21:17 +02:00
|
|
|
"Calling table diff with parameters: table1=%s, table2=%s, kwargs=%s",
|
|
|
|
table1.table_path,
|
|
|
|
table2.table_path,
|
|
|
|
",".join(f"{k}={v}" for k, v in data_diff_kwargs.items()),
|
2024-06-20 16:54:12 +02:00
|
|
|
)
|
|
|
|
return data_diff.diff_tables(table1, table2, **data_diff_kwargs) # type: ignore
|
|
|
|
|
|
|
|
def get_where(self) -> Optional[str]:
|
|
|
|
"""Returns the where clause from the test case parameters or None if it is a blank string."""
|
2024-11-15 08:22:13 +01:00
|
|
|
return self.runtime_params.whereClause or None
|
2024-11-11 10:07:23 +01:00
|
|
|
|
2024-11-15 08:22:13 +01:00
|
|
|
    def sample_where_clause(self) -> Tuple[Optional[str], Optional[str]]:
        """We use a where clause to sample the data for the diff. This is useful because with data diff
        we do not have access to the underlying 'SELECT' statement. This method generates a where clause
        that selects a random sample of the data based on the profile sample configuration.
        The method uses the md5 hash of the key columns and a random salt to select a random sample of the data.
        This ensures that the same data is selected from the two tables of the comparison.

        Example:
            -- Table 1 -- | -- Table 2 --
            id | name     | id | name
            1  | Alice    | 1  | Alice
            2  | Bob      | 2  | Bob
            3  | Charlie  | 3  | Charlie
            4  | David    | 4  | Edward
            5  | Edward   | 6  | Frank

        If we want a sample of 20% of the data, the where clause will intend to select one of the rows
        on Table 1 and the hash will ensure that the same row is selected on Table 2. We want to avoid selecting rows
        with different ids because the comparison will not be sensible.

        Returns:
            A (table1_clause, table2_clause) pair, or (None, None) when no
            sampling applies.
        """
        if (
            # no sample configuration
            self.runtime_params.table_profile_config is None
            or self.runtime_params.table_profile_config.profileSample is None
            # sample is 100% or in other words no sample is required
            or (
                self.runtime_params.table_profile_config.profileSampleType
                == ProfileSampleType.PERCENTAGE
                and self.runtime_params.table_profile_config.profileSample == 100
            )
        ):
            return None, None
        if DatabaseServiceType.Mssql in [
            self.runtime_params.table1.database_service_type,
            self.runtime_params.table2.database_service_type,
        ]:
            logger.warning("Sampling not supported in MSSQL. Skipping sampling.")
            return None, None
        nounce = self.calculate_nounce()
        # SQL MD5 returns a 32 character hex string even with leading zeros so we need to
        # pad the nounce to 8 characters in preserve lexical order.
        # example: SELECT md5('j(R1wzR*y[^GxWJ5B>L{-HLETRD');
        hex_nounce = hex(nounce)[2:].rjust(8, "0")
        # TODO: using strings for this is sub-optimal. But using bytes buffers requires a by-database
        # implementaiton. We can use this as default and add database specific implementations as the
        # need arises.
        salt = "".join(
            random.choices(string.ascii_letters + string.digits, k=5)
        )  # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax

        # The same salt and nounce are used for both tables so that the
        # predicate selects matching rows on each side
        return (
            build_sample_where_clause(
                self.runtime_params.table1,
                self.maybe_case_sensitive(self.runtime_params.table1.key_columns),
                salt,
                hex_nounce,
            ),
            build_sample_where_clause(
                self.runtime_params.table2,
                self.maybe_case_sensitive(self.runtime_params.table2.key_columns),
                salt,
                hex_nounce,
            ),
        )
|
2025-10-08 09:32:00 +02:00
|
|
|
|
|
|
|
def maybe_case_sensitive(self, iterable: Iterable[str]) -> list[str]:
|
|
|
|
return (
|
|
|
|
CaseInsensitiveList(iterable)
|
|
|
|
if not self.get_case_sensitive()
|
|
|
|
else list(iterable)
|
2024-11-11 10:07:23 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
|
|
|
|
"""Calculate the nounce based on the profile sample configuration. The nounce is
|
|
|
|
the sample fraction projected to a number on a scale of 0 to max_nounce"""
|
|
|
|
if (
|
|
|
|
self.runtime_params.table_profile_config.profileSampleType
|
|
|
|
== ProfileSampleType.PERCENTAGE
|
|
|
|
):
|
|
|
|
return int(
|
|
|
|
max_nounce
|
|
|
|
* self.runtime_params.table_profile_config.profileSample
|
|
|
|
/ 100
|
|
|
|
)
|
|
|
|
if (
|
|
|
|
self.runtime_params.table_profile_config.profileSampleType
|
|
|
|
== ProfileSampleType.ROWS
|
|
|
|
):
|
2024-11-27 08:50:54 +01:00
|
|
|
row_count = self.get_total_row_count()
|
2024-11-11 10:07:23 +01:00
|
|
|
if row_count is None:
|
|
|
|
raise ValueError("Row count is required for ROWS profile sample type")
|
|
|
|
return int(
|
|
|
|
max_nounce
|
|
|
|
* (self.runtime_params.table_profile_config.profileSample / row_count)
|
|
|
|
)
|
|
|
|
raise ValueError("Invalid profile sample type")
|
2024-06-20 16:54:12 +02:00
|
|
|
|
2024-07-02 12:36:03 +02:00
|
|
|
    def get_row_diff_test_case_result(
        self,
        threshold: int,
        total_diffs: int,
        changed: Optional[int] = None,
        removed: Optional[int] = None,
        added: Optional[int] = None,
    ) -> TestCaseResult:
        """Build a test case result for a row diff test. If the number of differences is less than the threshold,
        the test will pass, otherwise it will fail. The result will contain the number of added, removed, and changed
        rows, as well as the total number of differences.

        Args:
            threshold: The maximum number of differences allowed before the test fails
            total_diffs: The total number of differences between the tables
            changed: The number of rows that have been changed
            removed: The number of rows that have been removed
            added: The number of rows that have been added

        Returns:
            TestCaseResult: The result of the row diff test
        """
        return TestCaseResult(
            timestamp=self.execution_date,  # type: ignore
            # Passes when there are no diffs (covers threshold == 0), or when
            # the diff count is strictly below a non-zero threshold
            testCaseStatus=self.get_test_case_status(
                (threshold or total_diffs) == 0 or total_diffs < threshold
            ),
            # NOTE(review): this message is set even when the test passes —
            # confirm consumers only surface it on failure
            result=f"Found {total_diffs} different rows which is more than the threshold of {threshold}",
            failedRows=total_diffs,
            validateColumns=False,
            testResultValue=[
                TestResultValue(name="removedRows", value=str(removed)),
                TestResultValue(name="addedRows", value=str(added)),
                TestResultValue(name="changedRows", value=str(changed)),
                TestResultValue(name="diffCount", value=str(total_diffs)),
            ],
        )
|
|
|
|
|
|
|
|
    def _validate_dialects(self):
        """Check both service URLs against SUPPORTED_DIALECTS.

        Raises:
            UnsupportedDialectError: if either table's dialect is unsupported
        """
        for name, param in [
            ("table1.serviceUrl", self.runtime_params.table1.serviceUrl),
            ("table2.serviceUrl", self.runtime_params.table2.serviceUrl),
        ]:
            # serviceUrl may be a connection dict (with a "driver" key) or a
            # plain URL string whose scheme carries the dialect
            if isinstance(param, dict):
                dialect = param.get("driver")
            else:
                dialect = urlparse(param).scheme
            if dialect not in SUPPORTED_DIALECTS:
                raise UnsupportedDialectError(name, dialect)
|
|
|
|
|
|
|
|
    def get_column_diff(self) -> Optional[TestCaseResult]:
        """Get the column diff between the two tables. If there are no differences, return None."""
        # Key columns are excluded from the removed/added comparison; their
        # compatibility is validated separately by the diff itself
        removed, added = self.get_changed_added_columns(
            [
                c
                for c in self.runtime_params.table1.columns
                if c.name.root not in self.runtime_params.table1.key_columns
            ],
            [
                c
                for c in self.runtime_params.table2.columns
                if c.name.root not in self.runtime_params.table2.key_columns
            ],
            self.get_case_sensitive(),
        )
        changed = self.get_incomparable_columns()
        if removed or added or changed:
            return self.column_validation_result(
                removed,
                added,
                changed,
            )
        return None
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def get_changed_added_columns(
|
2024-10-15 16:29:43 +02:00
|
|
|
left: List[Column], right: List[Column], case_sensitive: bool
|
2024-07-02 12:36:03 +02:00
|
|
|
) -> Optional[Tuple[List[str], List[str]]]:
|
|
|
|
"""Given a list of columns from two tables, return the columns that are removed and added.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
left: List of columns from the first table
|
|
|
|
right: List of columns from the second table
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A tuple of lists containing the removed and added columns or None if there are no differences
|
|
|
|
"""
|
|
|
|
removed: List[str] = []
|
|
|
|
added: List[str] = []
|
|
|
|
right_columns_dict: Dict[str, Column] = {c.name.root: c for c in right}
|
2024-10-15 16:29:43 +02:00
|
|
|
if not case_sensitive:
|
|
|
|
right_columns_dict = cast(
|
|
|
|
Dict[str, Column], CaseInsensitiveDict(right_columns_dict)
|
|
|
|
)
|
2024-07-02 12:36:03 +02:00
|
|
|
for column in left:
|
|
|
|
table2_column = right_columns_dict.get(column.name.root)
|
|
|
|
if table2_column is None:
|
|
|
|
removed.append(column.name.root)
|
|
|
|
continue
|
|
|
|
del right_columns_dict[column.name.root]
|
|
|
|
added.extend(right_columns_dict.keys())
|
|
|
|
return removed, added
|
|
|
|
|
|
|
|
    def column_validation_result(
        self,
        removed: List[str],
        added: List[str],
        changed: List[str],
    ) -> TestCaseResult:
        """Build the result for a column validation result. Messages will only be added
        for non-empty categories. Values will be populated reported for all categories.

        Args:
            removed: List of removed columns
            added: List of added columns
            changed: List of changed columns

        Returns:
            TestCaseResult: The result of the column validation with a meaningful message
        """
        message = (
            f"Tables have {sum(map(len, [removed, added, changed]))} different columns:"
        )
        if removed:
            message += f"\n Removed columns: {', '.join(removed)}\n"
        if added:
            message += f"\n Added columns: {', '.join(added)}\n"
        if changed:
            message += "\n Changed columns:"
            # Resolve each changed column's type on both sides to show the
            # type transition in the message
            table1_columns = {
                c.name.root: c for c in self.runtime_params.table1.columns
            }
            table2_columns = {
                c.name.root: c for c in self.runtime_params.table2.columns
            }
            if not self.get_case_sensitive():
                table1_columns = CaseInsensitiveDict(table1_columns)
                table2_columns = CaseInsensitiveDict(table2_columns)
            for col in changed:
                col1 = table1_columns[col]
                col2 = table2_columns[col]
                message += (
                    f"\n {col}: {col1.dataType.value} -> {col2.dataType.value}"
                )
        return TestCaseResult(
            timestamp=self.execution_date,  # type: ignore
            testCaseStatus=TestCaseStatus.Failed,
            result=message,
            testResultValue=[
                TestResultValue(name="removedColumns", value=str(len(removed))),
                TestResultValue(name="addedColumns", value=str(len(added))),
                TestResultValue(name="changedColumns", value=str(len(changed))),
            ],
        )
|
2024-07-05 07:51:24 +02:00
|
|
|
|
|
|
|
def calculate_diffs_with_limit(
|
|
|
|
self, diff_iter: Iterable[Tuple[str, Tuple[str, ...]]], limit: int
|
|
|
|
) -> int:
|
|
|
|
"""Given an iterator of diffs like
|
|
|
|
- ('+', (...))
|
|
|
|
- ('-', (...))
|
|
|
|
...
|
|
|
|
Calculate the total diffs by combining diffs for the same key. This gives an accurate count of the total diffs
|
|
|
|
as opposed to self.calculate_diff_num(diff_list)just counting the number of diffs in the list.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
diff_iter: iterator returned from the data_diff algorithm
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
int: accurate count of the total diffs up to the limit
|
|
|
|
|
|
|
|
"""
|
|
|
|
len_key_columns = len(self.runtime_params.keyColumns)
|
|
|
|
key_set = set()
|
|
|
|
# combine diffs on same key to "!"
|
|
|
|
for _, values in diff_iter:
|
|
|
|
k = values[:len_key_columns]
|
|
|
|
if k in key_set:
|
|
|
|
continue
|
|
|
|
key_set.add(k)
|
|
|
|
if len(key_set) > limit:
|
|
|
|
len(key_set)
|
|
|
|
return len(key_set)
|
2024-08-07 17:21:17 +02:00
|
|
|
|
|
|
|
    def safe_table_diff_iterator(self) -> DiffResultWrapper:
        """A safe iterator object which properly closes the diff object when the generator is exhausted.
        Otherwise the data_diff library will continue to hold the connection open and eventually
        raise a KeyError.
        """
        gen = self.get_table_diff()
        try:
            yield from gen
        finally:
            # Always close the underlying diff, even if the consumer stops
            # early or an exception propagates through the generator
            try:
                gen.diff.close()
            except KeyError as ex:
                if str(ex) == "2":
                    # This is a known issue in data_diff where the diff object is closed
                    pass
                # NOTE(review): a KeyError other than "2" is also swallowed
                # here since the except block falls through without re-raising
                # — confirm whether that is intentional.
2024-10-15 16:29:43 +02:00
|
|
|
|
|
|
|
def get_case_sensitive(self):
|
|
|
|
return utils.get_bool_test_case_param(
|
|
|
|
self.test_case.parameterValues, "caseSensitiveColumns"
|
|
|
|
)
|
2024-11-11 10:07:23 +01:00
|
|
|
|
|
|
|
def get_row_count(self) -> Optional[int]:
|
|
|
|
return self._compute_row_count(self.runner, None)
|
2024-11-27 08:50:54 +01:00
|
|
|
|
|
|
|
def get_total_row_count(self) -> Optional[int]:
|
|
|
|
row_count = Metrics.ROW_COUNT()
|
|
|
|
try:
|
|
|
|
row = self.runner.select_first_from_table(row_count.fn())
|
|
|
|
return dict(row).get(Metrics.ROW_COUNT.name)
|
|
|
|
except Exception as e:
|
|
|
|
logger.error(f"Error getting row count: {e}")
|
|
|
|
return None
|