MINOR: Data diff sample fix (#18632)

* fix(data-diff): sampling configuration

Handle the sampling condition separately for each of the two tables, so that sampling can be applied even when the key column names differ only in letter case between the tables.

* format
This commit is contained in:
Imri Paran 2024-11-15 08:22:13 +01:00 committed by GitHub
parent 706d13e289
commit bde6ee4125
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 148 additions and 48 deletions

View File

@ -25,14 +25,17 @@ from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import DataDiffMismatchingKeyTypesError
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
from sqlalchemy import Column as SAColumn
from sqlalchemy import literal, select
from sqlalchemy import create_engine, literal, select
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
)
from metadata.data_quality.validations.models import TableDiffRuntimeParameters
from metadata.data_quality.validations.models import (
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.generated.schema.entity.data.table import Column, ProfileSampleType
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
SapHanaScheme,
@ -48,7 +51,8 @@ from metadata.generated.schema.tests.basic import (
from metadata.profiler.orm.converter.base import build_orm_col
from metadata.profiler.orm.functions.md5 import MD5
from metadata.profiler.orm.functions.substr import Substr
from metadata.profiler.orm.registry import Dialects
from metadata.profiler.orm.registry import Dialects, PythonDialects
from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.logger import test_suite_logger
logger = test_suite_logger()
@ -67,6 +71,37 @@ SUPPORTED_DIALECTS = [
]
def build_sample_where_clause(
    table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str
) -> str:
    """Render the sampling predicate for a single table as a SQL string.

    The key columns of ``table`` are concatenated, followed by ``salt``; the
    result is hashed with ``MD5`` and the ``Substr(..., 1, 8)`` slice of the
    digest is compared with ``hex_nounce``. The expression is compiled with
    the SQLAlchemy dialect that matches the table's own service type, and
    literal values are inlined into the output.

    Args:
        table: Table whose ``columns`` and ``database_service_type`` are used.
        key_columns: Key column names. A column takes part in the hash when
            its ``name.root`` is ``in`` this collection.
        salt: Literal appended after the key columns before hashing.
        hex_nounce: Threshold string; the predicate keeps rows whose digest
            slice compares below it.

    Returns:
        The compiled where clause, e.g.
        ``SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'``.
    """
    service_type = table.database_service_type

    # Build ORM columns only for the columns that are part of the key,
    # keeping each column's position within the table as its index.
    key_column_exprs = []
    for position, column in enumerate(table.columns):
        if column.name.root in key_columns:
            key_column_exprs.append(build_orm_col(position, column, service_type))

    # Left-to-right concatenation: col1 || col2 || ... || salt
    operands = [*key_column_exprs, literal(salt)]
    hashed_input = operands[0]
    for operand in operands[1:]:
        hashed_input = hashed_input.concat(operand)

    # The engine is created from a bare "<dialect>://" URL solely to obtain
    # the dialect object used for compiling the expression.
    dialect_url = f"{PythonDialects[service_type.name].value}://"
    target_dialect = create_engine(dialect_url).dialect

    digest_prefix = Substr(MD5(hashed_input), 1, 8)
    predicate = select().filter(digest_prefix < hex_nounce).whereclause
    compiled = predicate.compile(
        dialect=target_dialect,
        compile_kwargs={"literal_binds": True},
    )
    return str(compiled)
def compile_and_clauses(elements) -> str:
"""Compile a list of elements into a string with 'and' clauses.
@ -287,17 +322,20 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
def get_table_diff(self) -> DiffResultWrapper:
"""Calls data_diff.diff_tables with the parameters from the test case."""
left_where, right_where = self.sample_where_clause()
table1 = data_diff.connect_to_table(
self.runtime_params.table1.serviceUrl,
self.runtime_params.table1.path,
self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(),
where=left_where,
)
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(),
where=right_where,
)
data_diff_kwargs = {
"key_columns": self.runtime_params.keyColumns,
@ -314,12 +352,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
def get_where(self) -> Optional[str]:
"""Returns the where clause from the test case parameters or None if it is a blank string."""
runtime_where = self.runtime_params.whereClause
sample_where = self.sample_where_clause()
where = compile_and_clauses([c for c in [runtime_where, sample_where] if c])
return where if where else None
return self.runtime_params.whereClause or None
def sample_where_clause(self) -> Optional[str]:
def sample_where_clause(self) -> Tuple[Optional[str], Optional[str]]:
"""We use a where clause to sample the data for the diff. This is useful because with data diff
we do not have access to the underlying 'SELECT' statement. This method generates a where clause
that selects a random sample of the data based on the profile sample configuration.
@ -340,19 +375,23 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
with different ids because the comparison will not be sensible.
"""
if (
# no sample configuration
self.runtime_params.table_profile_config is None
or self.runtime_params.table_profile_config.profileSample is None
# sample is 100% or in other words no sample is required
or (
self.runtime_params.table_profile_config.profileSampleType
== ProfileSampleType.PERCENTAGE
and self.runtime_params.table_profile_config.profileSample == 100
)
):
return None
return None, None
if DatabaseServiceType.Mssql in [
self.runtime_params.table1.database_service_type,
self.runtime_params.table2.database_service_type,
]:
raise ValueError(
"MSSQL does not support sampling in data diff.\n"
"You can request this feature here:\n"
"https://github.com/open-metadata/OpenMetadata/issues/new?labels=enhancement&projects=&template=feature_request.md" # pylint: disable=line-too-long
)
logger.warning("Sampling not supported in MSSQL. Skipping sampling.")
return None, None
nounce = self.calculate_nounce()
# SQL MD5 returns a 32 character hex string even with leading zeros so we need to
# pad the nounce to 8 characters to preserve lexical order.
@ -364,29 +403,14 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
salt = "".join(
random.choices(string.ascii_letters + string.digits, k=5)
) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax
sql_alchemy_columns = [
build_orm_col(
i, c, self.runtime_params.table1.database_service_type
) # TODO: get from runtime params
for i, c in enumerate(self.runtime_params.table1.columns)
if c.name.root in self.runtime_params.keyColumns
]
reduced_concat = reduce(
lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
)
return str(
(
select()
.filter(
Substr(
MD5(reduced_concat),
1,
8,
)
< hex_nounce
)
.whereclause.compile(compile_kwargs={"literal_binds": True})
key_columns = (
CaseInsensitiveList(self.runtime_params.keyColumns)
if not self.get_case_sensitive()
else self.runtime_params.keyColumns
)
return tuple(
build_sample_where_clause(table, key_columns, salt, hex_nounce)
for table in [self.runtime_params.table1, self.runtime_params.table2]
)
def calculate_nounce(self, max_nounce=2**32 - 1) -> int:

View File

@ -326,6 +326,10 @@ class TestParameters(BaseModel):
timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Success,
),
TableProfilerConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=10,
),
),
(
TestCaseDefinition(

View File

@ -19,6 +19,7 @@ from metadata.generated.schema.entity.data.table import (
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
@pytest.mark.parametrize(
@ -61,12 +62,18 @@ def test_compile_and_clauses(elements, expected):
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"keyColumns": ["id"],
}
),
"SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",
("SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",) * 2,
),
(
TableDiffRuntimeParameters.model_construct(
@ -86,12 +93,18 @@ def test_compile_and_clauses(elements, expected):
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"keyColumns": ["id"],
}
),
"SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",
("SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",) * 2,
),
(
TableDiffRuntimeParameters.model_construct(
@ -111,12 +124,18 @@ def test_compile_and_clauses(elements, expected):
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"keyColumns": ["id", "name"],
}
),
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",
("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",) * 2,
),
(
TableDiffRuntimeParameters.model_construct(
@ -136,12 +155,51 @@ def test_compile_and_clauses(elements, expected):
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
},
),
"keyColumns": ["id", "name"],
}
),
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",
("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",) * 2,
),
(
TableDiffRuntimeParameters.model_construct(
**{
"table_profile_config": TableProfilerConfig(
profileSampleType=ProfileSampleType.ROWS,
profileSample=20,
),
"table1": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"table2": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="ID", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
},
),
"keyColumns": ["id"],
}
),
(
"SUBSTRING(MD5(id || 'a'), 1, 8) < '0083126e'",
"SUBSTRING(MD5(\"ID\" || 'a'), 1, 8) < '0083126e'",
),
),
(
TableDiffRuntimeParameters.model_construct(
@ -157,17 +215,31 @@ def test_compile_and_clauses(elements, expected):
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
},
),
"keyColumns": ["id", "name"],
"keyColumns": ["id"],
}
),
None,
(None, None),
),
],
)
def test_sample_where_clauses(config, expected):
validator = TableDiffValidator(None, None, None)
validator = TableDiffValidator(
None,
TestCase.model_construct(
parameterValues=[
TestCaseParameterValue(name="caseSensitiveColumns", value="false")
]
),
None,
)
validator.runtime_params = config
if (
config.table_profile_config