MINOR: Data diff sample fix (#18632)

* fix(data-diff): sampling configuration

handle the sampling condition separately for the 2 tables allowing to apply sampling on columns with mismatching cases

* format
This commit is contained in:
Imri Paran 2024-11-15 08:22:13 +01:00 committed by GitHub
parent 706d13e289
commit bde6ee4125
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 148 additions and 48 deletions

View File

@ -25,14 +25,17 @@ from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import DataDiffMismatchingKeyTypesError from data_diff.errors import DataDiffMismatchingKeyTypesError
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
from sqlalchemy import Column as SAColumn from sqlalchemy import Column as SAColumn
from sqlalchemy import literal, select from sqlalchemy import create_engine, literal, select
from metadata.data_quality.validations import utils from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.mixins.sqa_validator_mixin import ( from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin, SQAValidatorMixin,
) )
from metadata.data_quality.validations.models import TableDiffRuntimeParameters from metadata.data_quality.validations.models import (
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.generated.schema.entity.data.table import Column, ProfileSampleType from metadata.generated.schema.entity.data.table import Column, ProfileSampleType
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import ( from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
SapHanaScheme, SapHanaScheme,
@ -48,7 +51,8 @@ from metadata.generated.schema.tests.basic import (
from metadata.profiler.orm.converter.base import build_orm_col from metadata.profiler.orm.converter.base import build_orm_col
from metadata.profiler.orm.functions.md5 import MD5 from metadata.profiler.orm.functions.md5 import MD5
from metadata.profiler.orm.functions.substr import Substr from metadata.profiler.orm.functions.substr import Substr
from metadata.profiler.orm.registry import Dialects from metadata.profiler.orm.registry import Dialects, PythonDialects
from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.logger import test_suite_logger from metadata.utils.logger import test_suite_logger
logger = test_suite_logger() logger = test_suite_logger()
@ -67,6 +71,37 @@ SUPPORTED_DIALECTS = [
] ]
def build_sample_where_clause(
table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str
) -> str:
sql_alchemy_columns = [
build_orm_col(i, c, table.database_service_type)
for i, c in enumerate(table.columns)
if c.name.root in key_columns
]
reduced_concat = reduce(
lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
)
sqa_dialect = create_engine(
f"{PythonDialects[table.database_service_type.name].value}://"
).dialect
return str(
select()
.filter(
Substr(
MD5(reduced_concat),
1,
8,
)
< hex_nounce
)
.whereclause.compile(
dialect=sqa_dialect,
compile_kwargs={"literal_binds": True},
)
)
def compile_and_clauses(elements) -> str: def compile_and_clauses(elements) -> str:
"""Compile a list of elements into a string with 'and' clauses. """Compile a list of elements into a string with 'and' clauses.
@ -287,17 +322,20 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
def get_table_diff(self) -> DiffResultWrapper: def get_table_diff(self) -> DiffResultWrapper:
"""Calls data_diff.diff_tables with the parameters from the test case.""" """Calls data_diff.diff_tables with the parameters from the test case."""
left_where, right_where = self.sample_where_clause()
table1 = data_diff.connect_to_table( table1 = data_diff.connect_to_table(
self.runtime_params.table1.serviceUrl, self.runtime_params.table1.serviceUrl,
self.runtime_params.table1.path, self.runtime_params.table1.path,
self.runtime_params.keyColumns, # type: ignore self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(), case_sensitive=self.get_case_sensitive(),
where=left_where,
) )
table2 = data_diff.connect_to_table( table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl, self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path, self.runtime_params.table2.path,
self.runtime_params.keyColumns, # type: ignore self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(), case_sensitive=self.get_case_sensitive(),
where=right_where,
) )
data_diff_kwargs = { data_diff_kwargs = {
"key_columns": self.runtime_params.keyColumns, "key_columns": self.runtime_params.keyColumns,
@ -314,12 +352,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
def get_where(self) -> Optional[str]: def get_where(self) -> Optional[str]:
"""Returns the where clause from the test case parameters or None if it is a blank string.""" """Returns the where clause from the test case parameters or None if it is a blank string."""
runtime_where = self.runtime_params.whereClause return self.runtime_params.whereClause or None
sample_where = self.sample_where_clause()
where = compile_and_clauses([c for c in [runtime_where, sample_where] if c])
return where if where else None
def sample_where_clause(self) -> Optional[str]: def sample_where_clause(self) -> Tuple[Optional[str], Optional[str]]:
"""We use a where clause to sample the data for the diff. This is useful because with data diff """We use a where clause to sample the data for the diff. This is useful because with data diff
we do not have access to the underlying 'SELECT' statement. This method generates a where clause we do not have access to the underlying 'SELECT' statement. This method generates a where clause
that selects a random sample of the data based on the profile sample configuration. that selects a random sample of the data based on the profile sample configuration.
@ -340,19 +375,23 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
with different ids because the comparison will not be sensible. with different ids because the comparison will not be sensible.
""" """
if ( if (
# no sample configuration
self.runtime_params.table_profile_config is None self.runtime_params.table_profile_config is None
or self.runtime_params.table_profile_config.profileSample is None or self.runtime_params.table_profile_config.profileSample is None
# sample is 100% or in other words no sample is required
or (
self.runtime_params.table_profile_config.profileSampleType
== ProfileSampleType.PERCENTAGE
and self.runtime_params.table_profile_config.profileSample == 100
)
): ):
return None return None, None
if DatabaseServiceType.Mssql in [ if DatabaseServiceType.Mssql in [
self.runtime_params.table1.database_service_type, self.runtime_params.table1.database_service_type,
self.runtime_params.table2.database_service_type, self.runtime_params.table2.database_service_type,
]: ]:
raise ValueError( logger.warning("Sampling not supported in MSSQL. Skipping sampling.")
"MSSQL does not support sampling in data diff.\n" return None, None
"You can request this feature here:\n"
"https://github.com/open-metadata/OpenMetadata/issues/new?labels=enhancement&projects=&template=feature_request.md" # pylint: disable=line-too-long
)
nounce = self.calculate_nounce() nounce = self.calculate_nounce()
# SQL MD5 returns a 32 character hex string even with leading zeros so we need to # SQL MD5 returns a 32 character hex string even with leading zeros so we need to
# pad the nounce to 8 characters in preserve lexical order. # pad the nounce to 8 characters in preserve lexical order.
@ -364,29 +403,14 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
salt = "".join( salt = "".join(
random.choices(string.ascii_letters + string.digits, k=5) random.choices(string.ascii_letters + string.digits, k=5)
) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax ) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax
sql_alchemy_columns = [ key_columns = (
build_orm_col( CaseInsensitiveList(self.runtime_params.keyColumns)
i, c, self.runtime_params.table1.database_service_type if not self.get_case_sensitive()
) # TODO: get from runtime params else self.runtime_params.keyColumns
for i, c in enumerate(self.runtime_params.table1.columns)
if c.name.root in self.runtime_params.keyColumns
]
reduced_concat = reduce(
lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
)
return str(
(
select()
.filter(
Substr(
MD5(reduced_concat),
1,
8,
)
< hex_nounce
)
.whereclause.compile(compile_kwargs={"literal_binds": True})
) )
return tuple(
build_sample_where_clause(table, key_columns, salt, hex_nounce)
for table in [self.runtime_params.table1, self.runtime_params.table2]
) )
def calculate_nounce(self, max_nounce=2**32 - 1) -> int: def calculate_nounce(self, max_nounce=2**32 - 1) -> int:

View File

@ -326,6 +326,10 @@ class TestParameters(BaseModel):
timestamp=int(datetime.now().timestamp() * 1000), timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Success, testCaseStatus=TestCaseStatus.Success,
), ),
TableProfilerConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=10,
),
), ),
( (
TestCaseDefinition( TestCaseDefinition(

View File

@ -19,6 +19,7 @@ from metadata.generated.schema.entity.data.table import (
from metadata.generated.schema.entity.services.databaseService import ( from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType, DatabaseServiceType,
) )
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -61,12 +62,18 @@ def test_compile_and_clauses(elements, expected):
} }
), ),
"table2": TableParameter.model_construct( "table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres} **{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
), ),
"keyColumns": ["id"], "keyColumns": ["id"],
} }
), ),
"SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'", ("SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",) * 2,
), ),
( (
TableDiffRuntimeParameters.model_construct( TableDiffRuntimeParameters.model_construct(
@ -86,12 +93,18 @@ def test_compile_and_clauses(elements, expected):
} }
), ),
"table2": TableParameter.model_construct( "table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres} **{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
), ),
"keyColumns": ["id"], "keyColumns": ["id"],
} }
), ),
"SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'", ("SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",) * 2,
), ),
( (
TableDiffRuntimeParameters.model_construct( TableDiffRuntimeParameters.model_construct(
@ -111,12 +124,18 @@ def test_compile_and_clauses(elements, expected):
} }
), ),
"table2": TableParameter.model_construct( "table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres} **{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
), ),
"keyColumns": ["id", "name"], "keyColumns": ["id", "name"],
} }
), ),
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'", ("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",) * 2,
), ),
( (
TableDiffRuntimeParameters.model_construct( TableDiffRuntimeParameters.model_construct(
@ -136,12 +155,51 @@ def test_compile_and_clauses(elements, expected):
} }
), ),
"table2": TableParameter.model_construct( "table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres} **{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
},
), ),
"keyColumns": ["id", "name"], "keyColumns": ["id", "name"],
} }
), ),
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'", ("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",) * 2,
),
(
TableDiffRuntimeParameters.model_construct(
**{
"table_profile_config": TableProfilerConfig(
profileSampleType=ProfileSampleType.ROWS,
profileSample=20,
),
"table1": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"table2": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="ID", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
},
),
"keyColumns": ["id"],
}
),
(
"SUBSTRING(MD5(id || 'a'), 1, 8) < '0083126e'",
"SUBSTRING(MD5(\"ID\" || 'a'), 1, 8) < '0083126e'",
),
), ),
( (
TableDiffRuntimeParameters.model_construct( TableDiffRuntimeParameters.model_construct(
@ -157,17 +215,31 @@ def test_compile_and_clauses(elements, expected):
} }
), ),
"table2": TableParameter.model_construct( "table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres} **{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
},
), ),
"keyColumns": ["id", "name"], "keyColumns": ["id"],
} }
), ),
None, (None, None),
), ),
], ],
) )
def test_sample_where_clauses(config, expected): def test_sample_where_clauses(config, expected):
validator = TableDiffValidator(None, None, None) validator = TableDiffValidator(
None,
TestCase.model_construct(
parameterValues=[
TestCaseParameterValue(name="caseSensitiveColumns", value="false")
]
),
None,
)
validator.runtime_params = config validator.runtime_params = config
if ( if (
config.table_profile_config config.table_profile_config