mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-18 22:18:23 +00:00
MINOR: Data diff sample fix (#18632)
* fix(data-diff): sampling configuration handle the sampling condition separately for the 2 tables allowing to apply sampling on columns with mismatching cases * format
This commit is contained in:
parent
706d13e289
commit
bde6ee4125
@ -25,14 +25,17 @@ from data_diff.diff_tables import DiffResultWrapper
|
|||||||
from data_diff.errors import DataDiffMismatchingKeyTypesError
|
from data_diff.errors import DataDiffMismatchingKeyTypesError
|
||||||
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
|
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
|
||||||
from sqlalchemy import Column as SAColumn
|
from sqlalchemy import Column as SAColumn
|
||||||
from sqlalchemy import literal, select
|
from sqlalchemy import create_engine, literal, select
|
||||||
|
|
||||||
from metadata.data_quality.validations import utils
|
from metadata.data_quality.validations import utils
|
||||||
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
|
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
|
||||||
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
|
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
|
||||||
SQAValidatorMixin,
|
SQAValidatorMixin,
|
||||||
)
|
)
|
||||||
from metadata.data_quality.validations.models import TableDiffRuntimeParameters
|
from metadata.data_quality.validations.models import (
|
||||||
|
TableDiffRuntimeParameters,
|
||||||
|
TableParameter,
|
||||||
|
)
|
||||||
from metadata.generated.schema.entity.data.table import Column, ProfileSampleType
|
from metadata.generated.schema.entity.data.table import Column, ProfileSampleType
|
||||||
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
|
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
|
||||||
SapHanaScheme,
|
SapHanaScheme,
|
||||||
@ -48,7 +51,8 @@ from metadata.generated.schema.tests.basic import (
|
|||||||
from metadata.profiler.orm.converter.base import build_orm_col
|
from metadata.profiler.orm.converter.base import build_orm_col
|
||||||
from metadata.profiler.orm.functions.md5 import MD5
|
from metadata.profiler.orm.functions.md5 import MD5
|
||||||
from metadata.profiler.orm.functions.substr import Substr
|
from metadata.profiler.orm.functions.substr import Substr
|
||||||
from metadata.profiler.orm.registry import Dialects
|
from metadata.profiler.orm.registry import Dialects, PythonDialects
|
||||||
|
from metadata.utils.collections import CaseInsensitiveList
|
||||||
from metadata.utils.logger import test_suite_logger
|
from metadata.utils.logger import test_suite_logger
|
||||||
|
|
||||||
logger = test_suite_logger()
|
logger = test_suite_logger()
|
||||||
@ -67,6 +71,37 @@ SUPPORTED_DIALECTS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_sample_where_clause(
|
||||||
|
table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str
|
||||||
|
) -> str:
|
||||||
|
sql_alchemy_columns = [
|
||||||
|
build_orm_col(i, c, table.database_service_type)
|
||||||
|
for i, c in enumerate(table.columns)
|
||||||
|
if c.name.root in key_columns
|
||||||
|
]
|
||||||
|
reduced_concat = reduce(
|
||||||
|
lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
|
||||||
|
)
|
||||||
|
sqa_dialect = create_engine(
|
||||||
|
f"{PythonDialects[table.database_service_type.name].value}://"
|
||||||
|
).dialect
|
||||||
|
return str(
|
||||||
|
select()
|
||||||
|
.filter(
|
||||||
|
Substr(
|
||||||
|
MD5(reduced_concat),
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
)
|
||||||
|
< hex_nounce
|
||||||
|
)
|
||||||
|
.whereclause.compile(
|
||||||
|
dialect=sqa_dialect,
|
||||||
|
compile_kwargs={"literal_binds": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compile_and_clauses(elements) -> str:
|
def compile_and_clauses(elements) -> str:
|
||||||
"""Compile a list of elements into a string with 'and' clauses.
|
"""Compile a list of elements into a string with 'and' clauses.
|
||||||
|
|
||||||
@ -287,17 +322,20 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
|||||||
|
|
||||||
def get_table_diff(self) -> DiffResultWrapper:
|
def get_table_diff(self) -> DiffResultWrapper:
|
||||||
"""Calls data_diff.diff_tables with the parameters from the test case."""
|
"""Calls data_diff.diff_tables with the parameters from the test case."""
|
||||||
|
left_where, right_where = self.sample_where_clause()
|
||||||
table1 = data_diff.connect_to_table(
|
table1 = data_diff.connect_to_table(
|
||||||
self.runtime_params.table1.serviceUrl,
|
self.runtime_params.table1.serviceUrl,
|
||||||
self.runtime_params.table1.path,
|
self.runtime_params.table1.path,
|
||||||
self.runtime_params.keyColumns, # type: ignore
|
self.runtime_params.keyColumns, # type: ignore
|
||||||
case_sensitive=self.get_case_sensitive(),
|
case_sensitive=self.get_case_sensitive(),
|
||||||
|
where=left_where,
|
||||||
)
|
)
|
||||||
table2 = data_diff.connect_to_table(
|
table2 = data_diff.connect_to_table(
|
||||||
self.runtime_params.table2.serviceUrl,
|
self.runtime_params.table2.serviceUrl,
|
||||||
self.runtime_params.table2.path,
|
self.runtime_params.table2.path,
|
||||||
self.runtime_params.keyColumns, # type: ignore
|
self.runtime_params.keyColumns, # type: ignore
|
||||||
case_sensitive=self.get_case_sensitive(),
|
case_sensitive=self.get_case_sensitive(),
|
||||||
|
where=right_where,
|
||||||
)
|
)
|
||||||
data_diff_kwargs = {
|
data_diff_kwargs = {
|
||||||
"key_columns": self.runtime_params.keyColumns,
|
"key_columns": self.runtime_params.keyColumns,
|
||||||
@ -314,12 +352,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
|||||||
|
|
||||||
def get_where(self) -> Optional[str]:
|
def get_where(self) -> Optional[str]:
|
||||||
"""Returns the where clause from the test case parameters or None if it is a blank string."""
|
"""Returns the where clause from the test case parameters or None if it is a blank string."""
|
||||||
runtime_where = self.runtime_params.whereClause
|
return self.runtime_params.whereClause or None
|
||||||
sample_where = self.sample_where_clause()
|
|
||||||
where = compile_and_clauses([c for c in [runtime_where, sample_where] if c])
|
|
||||||
return where if where else None
|
|
||||||
|
|
||||||
def sample_where_clause(self) -> Optional[str]:
|
def sample_where_clause(self) -> Tuple[Optional[str], Optional[str]]:
|
||||||
"""We use a where clause to sample the data for the diff. This is useful because with data diff
|
"""We use a where clause to sample the data for the diff. This is useful because with data diff
|
||||||
we do not have access to the underlying 'SELECT' statement. This method generates a where clause
|
we do not have access to the underlying 'SELECT' statement. This method generates a where clause
|
||||||
that selects a random sample of the data based on the profile sample configuration.
|
that selects a random sample of the data based on the profile sample configuration.
|
||||||
@ -340,19 +375,23 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
|||||||
with different ids because the comparison will not be sensible.
|
with different ids because the comparison will not be sensible.
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
|
# no sample configuration
|
||||||
self.runtime_params.table_profile_config is None
|
self.runtime_params.table_profile_config is None
|
||||||
or self.runtime_params.table_profile_config.profileSample is None
|
or self.runtime_params.table_profile_config.profileSample is None
|
||||||
|
# sample is 100% or in other words no sample is required
|
||||||
|
or (
|
||||||
|
self.runtime_params.table_profile_config.profileSampleType
|
||||||
|
== ProfileSampleType.PERCENTAGE
|
||||||
|
and self.runtime_params.table_profile_config.profileSample == 100
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return None
|
return None, None
|
||||||
if DatabaseServiceType.Mssql in [
|
if DatabaseServiceType.Mssql in [
|
||||||
self.runtime_params.table1.database_service_type,
|
self.runtime_params.table1.database_service_type,
|
||||||
self.runtime_params.table2.database_service_type,
|
self.runtime_params.table2.database_service_type,
|
||||||
]:
|
]:
|
||||||
raise ValueError(
|
logger.warning("Sampling not supported in MSSQL. Skipping sampling.")
|
||||||
"MSSQL does not support sampling in data diff.\n"
|
return None, None
|
||||||
"You can request this feature here:\n"
|
|
||||||
"https://github.com/open-metadata/OpenMetadata/issues/new?labels=enhancement&projects=&template=feature_request.md" # pylint: disable=line-too-long
|
|
||||||
)
|
|
||||||
nounce = self.calculate_nounce()
|
nounce = self.calculate_nounce()
|
||||||
# SQL MD5 returns a 32 character hex string even with leading zeros so we need to
|
# SQL MD5 returns a 32 character hex string even with leading zeros so we need to
|
||||||
# pad the nounce to 8 characters in preserve lexical order.
|
# pad the nounce to 8 characters in preserve lexical order.
|
||||||
@ -364,29 +403,14 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
|
|||||||
salt = "".join(
|
salt = "".join(
|
||||||
random.choices(string.ascii_letters + string.digits, k=5)
|
random.choices(string.ascii_letters + string.digits, k=5)
|
||||||
) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax
|
) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax
|
||||||
sql_alchemy_columns = [
|
key_columns = (
|
||||||
build_orm_col(
|
CaseInsensitiveList(self.runtime_params.keyColumns)
|
||||||
i, c, self.runtime_params.table1.database_service_type
|
if not self.get_case_sensitive()
|
||||||
) # TODO: get from runtime params
|
else self.runtime_params.keyColumns
|
||||||
for i, c in enumerate(self.runtime_params.table1.columns)
|
|
||||||
if c.name.root in self.runtime_params.keyColumns
|
|
||||||
]
|
|
||||||
reduced_concat = reduce(
|
|
||||||
lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
|
|
||||||
)
|
)
|
||||||
return str(
|
return tuple(
|
||||||
(
|
build_sample_where_clause(table, key_columns, salt, hex_nounce)
|
||||||
select()
|
for table in [self.runtime_params.table1, self.runtime_params.table2]
|
||||||
.filter(
|
|
||||||
Substr(
|
|
||||||
MD5(reduced_concat),
|
|
||||||
1,
|
|
||||||
8,
|
|
||||||
)
|
|
||||||
< hex_nounce
|
|
||||||
)
|
|
||||||
.whereclause.compile(compile_kwargs={"literal_binds": True})
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
|
def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
|
||||||
|
@ -326,6 +326,10 @@ class TestParameters(BaseModel):
|
|||||||
timestamp=int(datetime.now().timestamp() * 1000),
|
timestamp=int(datetime.now().timestamp() * 1000),
|
||||||
testCaseStatus=TestCaseStatus.Success,
|
testCaseStatus=TestCaseStatus.Success,
|
||||||
),
|
),
|
||||||
|
TableProfilerConfig(
|
||||||
|
profileSampleType=ProfileSampleType.PERCENTAGE,
|
||||||
|
profileSample=10,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TestCaseDefinition(
|
TestCaseDefinition(
|
||||||
|
@ -19,6 +19,7 @@ from metadata.generated.schema.entity.data.table import (
|
|||||||
from metadata.generated.schema.entity.services.databaseService import (
|
from metadata.generated.schema.entity.services.databaseService import (
|
||||||
DatabaseServiceType,
|
DatabaseServiceType,
|
||||||
)
|
)
|
||||||
|
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -61,12 +62,18 @@ def test_compile_and_clauses(elements, expected):
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
"table2": TableParameter.model_construct(
|
"table2": TableParameter.model_construct(
|
||||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
**{
|
||||||
|
"database_service_type": DatabaseServiceType.Postgres,
|
||||||
|
"columns": [
|
||||||
|
Column(name="id", dataType=DataType.STRING),
|
||||||
|
Column(name="name", dataType=DataType.STRING),
|
||||||
|
],
|
||||||
|
}
|
||||||
),
|
),
|
||||||
"keyColumns": ["id"],
|
"keyColumns": ["id"],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
"SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",
|
("SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",) * 2,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TableDiffRuntimeParameters.model_construct(
|
TableDiffRuntimeParameters.model_construct(
|
||||||
@ -86,12 +93,18 @@ def test_compile_and_clauses(elements, expected):
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
"table2": TableParameter.model_construct(
|
"table2": TableParameter.model_construct(
|
||||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
**{
|
||||||
|
"database_service_type": DatabaseServiceType.Postgres,
|
||||||
|
"columns": [
|
||||||
|
Column(name="id", dataType=DataType.STRING),
|
||||||
|
Column(name="name", dataType=DataType.STRING),
|
||||||
|
],
|
||||||
|
}
|
||||||
),
|
),
|
||||||
"keyColumns": ["id"],
|
"keyColumns": ["id"],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
"SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",
|
("SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",) * 2,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TableDiffRuntimeParameters.model_construct(
|
TableDiffRuntimeParameters.model_construct(
|
||||||
@ -111,12 +124,18 @@ def test_compile_and_clauses(elements, expected):
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
"table2": TableParameter.model_construct(
|
"table2": TableParameter.model_construct(
|
||||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
**{
|
||||||
|
"database_service_type": DatabaseServiceType.Postgres,
|
||||||
|
"columns": [
|
||||||
|
Column(name="id", dataType=DataType.STRING),
|
||||||
|
Column(name="name", dataType=DataType.STRING),
|
||||||
|
],
|
||||||
|
}
|
||||||
),
|
),
|
||||||
"keyColumns": ["id", "name"],
|
"keyColumns": ["id", "name"],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",
|
("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",) * 2,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TableDiffRuntimeParameters.model_construct(
|
TableDiffRuntimeParameters.model_construct(
|
||||||
@ -136,12 +155,51 @@ def test_compile_and_clauses(elements, expected):
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
"table2": TableParameter.model_construct(
|
"table2": TableParameter.model_construct(
|
||||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
**{
|
||||||
|
"database_service_type": DatabaseServiceType.Postgres,
|
||||||
|
"columns": [
|
||||||
|
Column(name="id", dataType=DataType.STRING),
|
||||||
|
Column(name="name", dataType=DataType.STRING),
|
||||||
|
],
|
||||||
|
},
|
||||||
),
|
),
|
||||||
"keyColumns": ["id", "name"],
|
"keyColumns": ["id", "name"],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",
|
("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",) * 2,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
TableDiffRuntimeParameters.model_construct(
|
||||||
|
**{
|
||||||
|
"table_profile_config": TableProfilerConfig(
|
||||||
|
profileSampleType=ProfileSampleType.ROWS,
|
||||||
|
profileSample=20,
|
||||||
|
),
|
||||||
|
"table1": TableParameter.model_construct(
|
||||||
|
**{
|
||||||
|
"database_service_type": DatabaseServiceType.Postgres,
|
||||||
|
"columns": [
|
||||||
|
Column(name="id", dataType=DataType.STRING),
|
||||||
|
Column(name="name", dataType=DataType.STRING),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"table2": TableParameter.model_construct(
|
||||||
|
**{
|
||||||
|
"database_service_type": DatabaseServiceType.Postgres,
|
||||||
|
"columns": [
|
||||||
|
Column(name="ID", dataType=DataType.STRING),
|
||||||
|
Column(name="name", dataType=DataType.STRING),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"keyColumns": ["id"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SUBSTRING(MD5(id || 'a'), 1, 8) < '0083126e'",
|
||||||
|
"SUBSTRING(MD5(\"ID\" || 'a'), 1, 8) < '0083126e'",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TableDiffRuntimeParameters.model_construct(
|
TableDiffRuntimeParameters.model_construct(
|
||||||
@ -157,17 +215,31 @@ def test_compile_and_clauses(elements, expected):
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
"table2": TableParameter.model_construct(
|
"table2": TableParameter.model_construct(
|
||||||
**{"database_service_type": DatabaseServiceType.Postgres}
|
**{
|
||||||
|
"database_service_type": DatabaseServiceType.Postgres,
|
||||||
|
"columns": [
|
||||||
|
Column(name="id", dataType=DataType.STRING),
|
||||||
|
Column(name="name", dataType=DataType.STRING),
|
||||||
|
],
|
||||||
|
},
|
||||||
),
|
),
|
||||||
"keyColumns": ["id", "name"],
|
"keyColumns": ["id"],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
None,
|
(None, None),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sample_where_clauses(config, expected):
|
def test_sample_where_clauses(config, expected):
|
||||||
validator = TableDiffValidator(None, None, None)
|
validator = TableDiffValidator(
|
||||||
|
None,
|
||||||
|
TestCase.model_construct(
|
||||||
|
parameterValues=[
|
||||||
|
TestCaseParameterValue(name="caseSensitiveColumns", value="false")
|
||||||
|
]
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
validator.runtime_params = config
|
validator.runtime_params = config
|
||||||
if (
|
if (
|
||||||
config.table_profile_config
|
config.table_profile_config
|
||||||
|
Loading…
x
Reference in New Issue
Block a user