mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-18 14:06:59 +00:00
MINOR: Data diff sample fix (#18632)
* fix(data-diff): sampling configuration. Handle the sampling condition separately for each of the two tables, which allows sampling to be applied on key columns whose names differ only in case. * format
This commit is contained in:
parent
706d13e289
commit
bde6ee4125
@ -25,14 +25,17 @@ from data_diff.diff_tables import DiffResultWrapper
|
||||
from data_diff.errors import DataDiffMismatchingKeyTypesError
|
||||
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
|
||||
from sqlalchemy import Column as SAColumn
|
||||
from sqlalchemy import literal, select
|
||||
from sqlalchemy import create_engine, literal, select
|
||||
|
||||
from metadata.data_quality.validations import utils
|
||||
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
|
||||
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
|
||||
SQAValidatorMixin,
|
||||
)
|
||||
from metadata.data_quality.validations.models import TableDiffRuntimeParameters
|
||||
from metadata.data_quality.validations.models import (
|
||||
TableDiffRuntimeParameters,
|
||||
TableParameter,
|
||||
)
|
||||
from metadata.generated.schema.entity.data.table import Column, ProfileSampleType
|
||||
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
|
||||
SapHanaScheme,
|
||||
@ -48,7 +51,8 @@ from metadata.generated.schema.tests.basic import (
|
||||
from metadata.profiler.orm.converter.base import build_orm_col
|
||||
from metadata.profiler.orm.functions.md5 import MD5
|
||||
from metadata.profiler.orm.functions.substr import Substr
|
||||
from metadata.profiler.orm.registry import Dialects
|
||||
from metadata.profiler.orm.registry import Dialects, PythonDialects
|
||||
from metadata.utils.collections import CaseInsensitiveList
|
||||
from metadata.utils.logger import test_suite_logger
|
||||
|
||||
logger = test_suite_logger()
|
||||
@ -67,6 +71,37 @@ SUPPORTED_DIALECTS = [
|
||||
]
|
||||
|
||||
|
||||
def build_sample_where_clause(
    table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str
) -> str:
    """Render the sampling predicate for a single table as a SQL string.

    The predicate concatenates the table's key columns with ``salt``, hashes
    the result with MD5 and keeps the rows whose first 8 hash characters
    compare lower than ``hex_nounce``.

    Args:
        table: table whose ``columns`` and ``database_service_type`` are used
            to build the column expressions and to select the SQL dialect.
        key_columns: names of the columns that feed the hash. A column takes
            part when its ``name.root`` is contained in this collection.
        salt: literal appended to the concatenated key columns before hashing.
        hex_nounce: threshold the hash prefix is compared against.

    Returns:
        The where clause compiled for the table's dialect, with literal values
        rendered inline instead of as bind parameters.
    """
    # Collect the ORM columns for the key columns, keeping the table's order.
    hash_inputs = []
    for position, column in enumerate(table.columns):
        if column.name.root in key_columns:
            hash_inputs.append(
                build_orm_col(position, column, table.database_service_type)
            )
    hash_inputs.append(literal(salt))

    # Chain the parts left to right: col1 || col2 || ... || salt
    concatenated = hash_inputs[0]
    for part in hash_inputs[1:]:
        concatenated = concatenated.concat(part)

    # The engine is created from a bare URL only to obtain the dialect object
    # used for compilation.
    dialect_name = PythonDialects[table.database_service_type.name].value
    target_dialect = create_engine(f"{dialect_name}://").dialect

    sample_condition = Substr(MD5(concatenated), 1, 8) < hex_nounce
    compiled_clause = (
        select()
        .filter(sample_condition)
        .whereclause.compile(
            dialect=target_dialect,
            compile_kwargs={"literal_binds": True},
        )
    )
    return str(compiled_clause)
|
||||
|
||||
|
||||
def compile_and_clauses(elements) -> str:
|
||||
"""Compile a list of elements into a string with 'and' clauses.
|
||||
|
||||
@ -287,17 +322,20 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
||||
|
||||
def get_table_diff(self) -> DiffResultWrapper:
|
||||
"""Calls data_diff.diff_tables with the parameters from the test case."""
|
||||
left_where, right_where = self.sample_where_clause()
|
||||
table1 = data_diff.connect_to_table(
|
||||
self.runtime_params.table1.serviceUrl,
|
||||
self.runtime_params.table1.path,
|
||||
self.runtime_params.keyColumns, # type: ignore
|
||||
case_sensitive=self.get_case_sensitive(),
|
||||
where=left_where,
|
||||
)
|
||||
table2 = data_diff.connect_to_table(
|
||||
self.runtime_params.table2.serviceUrl,
|
||||
self.runtime_params.table2.path,
|
||||
self.runtime_params.keyColumns, # type: ignore
|
||||
case_sensitive=self.get_case_sensitive(),
|
||||
where=right_where,
|
||||
)
|
||||
data_diff_kwargs = {
|
||||
"key_columns": self.runtime_params.keyColumns,
|
||||
@ -314,12 +352,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
||||
|
||||
def get_where(self) -> Optional[str]:
    """Return the where clause from the test case parameters, or None if it is blank.

    Only the ``whereClause`` runtime parameter is considered. The sampling
    condition is intentionally not merged in here: ``sample_where_clause``
    returns one clause per table, so it cannot be folded into a single string
    shared by both tables.
    """
    # `or None` normalizes an empty string (or an unset value) to None.
    return self.runtime_params.whereClause or None
|
||||
|
||||
def sample_where_clause(self) -> Optional[str]:
|
||||
def sample_where_clause(self) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""We use a where clause to sample the data for the diff. This is useful because with data diff
|
||||
we do not have access to the underlying 'SELECT' statement. This method generates a where clause
|
||||
that selects a random sample of the data based on the profile sample configuration.
|
||||
@ -340,19 +375,23 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
||||
with different ids because the comparison will not be sensible.
|
||||
"""
|
||||
if (
|
||||
# no sample configuration
|
||||
self.runtime_params.table_profile_config is None
|
||||
or self.runtime_params.table_profile_config.profileSample is None
|
||||
# sample is 100% or in other words no sample is required
|
||||
or (
|
||||
self.runtime_params.table_profile_config.profileSampleType
|
||||
== ProfileSampleType.PERCENTAGE
|
||||
and self.runtime_params.table_profile_config.profileSample == 100
|
||||
)
|
||||
):
|
||||
return None
|
||||
return None, None
|
||||
if DatabaseServiceType.Mssql in [
|
||||
self.runtime_params.table1.database_service_type,
|
||||
self.runtime_params.table2.database_service_type,
|
||||
]:
|
||||
raise ValueError(
|
||||
"MSSQL does not support sampling in data diff.\n"
|
||||
"You can request this feature here:\n"
|
||||
"https://github.com/open-metadata/OpenMetadata/issues/new?labels=enhancement&projects=&template=feature_request.md" # pylint: disable=line-too-long
|
||||
)
|
||||
logger.warning("Sampling not supported in MSSQL. Skipping sampling.")
|
||||
return None, None
|
||||
nounce = self.calculate_nounce()
|
||||
# SQL MD5 returns a 32 character hex string even with leading zeros so we need to
|
||||
# pad the nounce to 8 characters to preserve lexical order.
|
||||
@ -364,29 +403,14 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
||||
salt = "".join(
|
||||
random.choices(string.ascii_letters + string.digits, k=5)
|
||||
) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax
|
||||
sql_alchemy_columns = [
|
||||
build_orm_col(
|
||||
i, c, self.runtime_params.table1.database_service_type
|
||||
) # TODO: get from runtime params
|
||||
for i, c in enumerate(self.runtime_params.table1.columns)
|
||||
if c.name.root in self.runtime_params.keyColumns
|
||||
]
|
||||
reduced_concat = reduce(
|
||||
lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
|
||||
)
|
||||
return str(
|
||||
(
|
||||
select()
|
||||
.filter(
|
||||
Substr(
|
||||
MD5(reduced_concat),
|
||||
1,
|
||||
8,
|
||||
)
|
||||
< hex_nounce
|
||||
)
|
||||
.whereclause.compile(compile_kwargs={"literal_binds": True})
|
||||
key_columns = (
|
||||
CaseInsensitiveList(self.runtime_params.keyColumns)
|
||||
if not self.get_case_sensitive()
|
||||
else self.runtime_params.keyColumns
|
||||
)
|
||||
return tuple(
|
||||
build_sample_where_clause(table, key_columns, salt, hex_nounce)
|
||||
for table in [self.runtime_params.table1, self.runtime_params.table2]
|
||||
)
|
||||
|
||||
def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
|
||||
|
@ -326,6 +326,10 @@ class TestParameters(BaseModel):
|
||||
timestamp=int(datetime.now().timestamp() * 1000),
|
||||
testCaseStatus=TestCaseStatus.Success,
|
||||
),
|
||||
TableProfilerConfig(
|
||||
profileSampleType=ProfileSampleType.PERCENTAGE,
|
||||
profileSample=10,
|
||||
),
|
||||
),
|
||||
(
|
||||
TestCaseDefinition(
|
||||
|
@ -19,6 +19,7 @@ from metadata.generated.schema.entity.data.table import (
|
||||
from metadata.generated.schema.entity.services.databaseService import (
|
||||
DatabaseServiceType,
|
||||
)
|
||||
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -61,12 +62,18 @@ def test_compile_and_clauses(elements, expected):
|
||||
}
|
||||
),
|
||||
"table2": TableParameter.model_construct(
|
||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
||||
**{
|
||||
"database_service_type": DatabaseServiceType.Postgres,
|
||||
"columns": [
|
||||
Column(name="id", dataType=DataType.STRING),
|
||||
Column(name="name", dataType=DataType.STRING),
|
||||
],
|
||||
}
|
||||
),
|
||||
"keyColumns": ["id"],
|
||||
}
|
||||
),
|
||||
"SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",
|
||||
("SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",) * 2,
|
||||
),
|
||||
(
|
||||
TableDiffRuntimeParameters.model_construct(
|
||||
@ -86,12 +93,18 @@ def test_compile_and_clauses(elements, expected):
|
||||
}
|
||||
),
|
||||
"table2": TableParameter.model_construct(
|
||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
||||
**{
|
||||
"database_service_type": DatabaseServiceType.Postgres,
|
||||
"columns": [
|
||||
Column(name="id", dataType=DataType.STRING),
|
||||
Column(name="name", dataType=DataType.STRING),
|
||||
],
|
||||
}
|
||||
),
|
||||
"keyColumns": ["id"],
|
||||
}
|
||||
),
|
||||
"SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",
|
||||
("SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",) * 2,
|
||||
),
|
||||
(
|
||||
TableDiffRuntimeParameters.model_construct(
|
||||
@ -111,12 +124,18 @@ def test_compile_and_clauses(elements, expected):
|
||||
}
|
||||
),
|
||||
"table2": TableParameter.model_construct(
|
||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
||||
**{
|
||||
"database_service_type": DatabaseServiceType.Postgres,
|
||||
"columns": [
|
||||
Column(name="id", dataType=DataType.STRING),
|
||||
Column(name="name", dataType=DataType.STRING),
|
||||
],
|
||||
}
|
||||
),
|
||||
"keyColumns": ["id", "name"],
|
||||
}
|
||||
),
|
||||
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",
|
||||
("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",) * 2,
|
||||
),
|
||||
(
|
||||
TableDiffRuntimeParameters.model_construct(
|
||||
@ -136,12 +155,51 @@ def test_compile_and_clauses(elements, expected):
|
||||
}
|
||||
),
|
||||
"table2": TableParameter.model_construct(
|
||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
||||
**{
|
||||
"database_service_type": DatabaseServiceType.Postgres,
|
||||
"columns": [
|
||||
Column(name="id", dataType=DataType.STRING),
|
||||
Column(name="name", dataType=DataType.STRING),
|
||||
],
|
||||
},
|
||||
),
|
||||
"keyColumns": ["id", "name"],
|
||||
}
|
||||
),
|
||||
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",
|
||||
("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",) * 2,
|
||||
),
|
||||
(
|
||||
TableDiffRuntimeParameters.model_construct(
|
||||
**{
|
||||
"table_profile_config": TableProfilerConfig(
|
||||
profileSampleType=ProfileSampleType.ROWS,
|
||||
profileSample=20,
|
||||
),
|
||||
"table1": TableParameter.model_construct(
|
||||
**{
|
||||
"database_service_type": DatabaseServiceType.Postgres,
|
||||
"columns": [
|
||||
Column(name="id", dataType=DataType.STRING),
|
||||
Column(name="name", dataType=DataType.STRING),
|
||||
],
|
||||
}
|
||||
),
|
||||
"table2": TableParameter.model_construct(
|
||||
**{
|
||||
"database_service_type": DatabaseServiceType.Postgres,
|
||||
"columns": [
|
||||
Column(name="ID", dataType=DataType.STRING),
|
||||
Column(name="name", dataType=DataType.STRING),
|
||||
],
|
||||
},
|
||||
),
|
||||
"keyColumns": ["id"],
|
||||
}
|
||||
),
|
||||
(
|
||||
"SUBSTRING(MD5(id || 'a'), 1, 8) < '0083126e'",
|
||||
"SUBSTRING(MD5(\"ID\" || 'a'), 1, 8) < '0083126e'",
|
||||
),
|
||||
),
|
||||
(
|
||||
TableDiffRuntimeParameters.model_construct(
|
||||
@ -157,17 +215,31 @@ def test_compile_and_clauses(elements, expected):
|
||||
}
|
||||
),
|
||||
"table2": TableParameter.model_construct(
|
||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
||||
**{
|
||||
"database_service_type": DatabaseServiceType.Postgres,
|
||||
"columns": [
|
||||
Column(name="id", dataType=DataType.STRING),
|
||||
Column(name="name", dataType=DataType.STRING),
|
||||
],
|
||||
},
|
||||
),
|
||||
"keyColumns": ["id", "name"],
|
||||
"keyColumns": ["id"],
|
||||
}
|
||||
),
|
||||
None,
|
||||
(None, None),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_sample_where_clauses(config, expected):
|
||||
validator = TableDiffValidator(None, None, None)
|
||||
validator = TableDiffValidator(
|
||||
None,
|
||||
TestCase.model_construct(
|
||||
parameterValues=[
|
||||
TestCaseParameterValue(name="caseSensitiveColumns", value="false")
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
validator.runtime_params = config
|
||||
if (
|
||||
config.table_profile_config
|
||||
|
Loading…
x
Reference in New Issue
Block a user