diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 02aea8f7741..56614bdfd21 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -25,14 +25,17 @@ from data_diff.diff_tables import DiffResultWrapper from data_diff.errors import DataDiffMismatchingKeyTypesError from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict from sqlalchemy import Column as SAColumn -from sqlalchemy import literal, select +from sqlalchemy import create_engine, literal, select from metadata.data_quality.validations import utils from metadata.data_quality.validations.base_test_handler import BaseTestValidator from metadata.data_quality.validations.mixins.sqa_validator_mixin import ( SQAValidatorMixin, ) -from metadata.data_quality.validations.models import TableDiffRuntimeParameters +from metadata.data_quality.validations.models import ( + TableDiffRuntimeParameters, + TableParameter, +) from metadata.generated.schema.entity.data.table import Column, ProfileSampleType from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import ( SapHanaScheme, @@ -48,7 +51,8 @@ from metadata.generated.schema.tests.basic import ( from metadata.profiler.orm.converter.base import build_orm_col from metadata.profiler.orm.functions.md5 import MD5 from metadata.profiler.orm.functions.substr import Substr -from metadata.profiler.orm.registry import Dialects +from metadata.profiler.orm.registry import Dialects, PythonDialects +from metadata.utils.collections import CaseInsensitiveList from metadata.utils.logger import test_suite_logger logger = test_suite_logger() @@ -67,6 +71,37 @@ SUPPORTED_DIALECTS = [ ] +def build_sample_where_clause( + table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str +) -> str: + sql_alchemy_columns = [ + build_orm_col(i, c, table.database_service_type) + for i, c in enumerate(table.columns) + if c.name.root in key_columns + ] + reduced_concat = reduce( + lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)] + ) + sqa_dialect = create_engine( + f"{PythonDialects[table.database_service_type.name].value}://" + ).dialect + return str( + select() + .filter( + Substr( + MD5(reduced_concat), + 1, + 8, + ) + < hex_nounce + ) + .whereclause.compile( + dialect=sqa_dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + + def compile_and_clauses(elements) -> str: """Compile a list of elements into a string with 'and' clauses. @@ -287,17 +322,20 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): def get_table_diff(self) -> DiffResultWrapper: """Calls data_diff.diff_tables with the parameters from the test case.""" + left_where, right_where = self.sample_where_clause() table1 = data_diff.connect_to_table( self.runtime_params.table1.serviceUrl, self.runtime_params.table1.path, self.runtime_params.keyColumns, # type: ignore case_sensitive=self.get_case_sensitive(), + where=left_where, ) table2 = data_diff.connect_to_table( self.runtime_params.table2.serviceUrl, self.runtime_params.table2.path, self.runtime_params.keyColumns, # type: ignore case_sensitive=self.get_case_sensitive(), + where=right_where, ) data_diff_kwargs = { "key_columns": self.runtime_params.keyColumns, @@ -314,12 +352,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): def get_where(self) -> Optional[str]: """Returns the where clause from the test case parameters or None if it is a blank string.""" - runtime_where = self.runtime_params.whereClause - sample_where = self.sample_where_clause() - where = compile_and_clauses([c for c in [runtime_where, sample_where] if c]) - return where if where else None + return self.runtime_params.whereClause or None - def sample_where_clause(self) -> Optional[str]: + def sample_where_clause(self) -> Tuple[Optional[str], Optional[str]]: """We use a where clause to sample the data for the diff. This is useful because with data diff we do not have access to the underlying 'SELECT' statement. This method generates a where clause that selects a random sample of the data based on the profile sample configuration. @@ -340,19 +375,23 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): with different ids because the comparison will not be sensible. """ if ( + # no sample configuration self.runtime_params.table_profile_config is None or self.runtime_params.table_profile_config.profileSample is None + # sample is 100% or in other words no sample is required + or ( + self.runtime_params.table_profile_config.profileSampleType + == ProfileSampleType.PERCENTAGE + and self.runtime_params.table_profile_config.profileSample == 100 + ) ): - return None + return None, None if DatabaseServiceType.Mssql in [ self.runtime_params.table1.database_service_type, self.runtime_params.table2.database_service_type, ]: - raise ValueError( - "MSSQL does not support sampling in data diff.\n" - "You can request this feature here:\n" - "https://github.com/open-metadata/OpenMetadata/issues/new?labels=enhancement&projects=&template=feature_request.md" # pylint: disable=line-too-long - ) + logger.warning("Sampling not supported in MSSQL. Skipping sampling.") + return None, None nounce = self.calculate_nounce() # SQL MD5 returns a 32 character hex string even with leading zeros so we need to # pad the nounce to 8 characters in preserve lexical order. @@ -364,29 +403,14 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): salt = "".join( random.choices(string.ascii_letters + string.digits, k=5) ) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax - sql_alchemy_columns = [ - build_orm_col( - i, c, self.runtime_params.table1.database_service_type - ) # TODO: get from runtime params - for i, c in enumerate(self.runtime_params.table1.columns) - if c.name.root in self.runtime_params.keyColumns - ] - reduced_concat = reduce( - lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)] + key_columns = ( + CaseInsensitiveList(self.runtime_params.keyColumns) + if not self.get_case_sensitive() + else self.runtime_params.keyColumns ) - return str( - ( - select() - .filter( - Substr( - MD5(reduced_concat), - 1, - 8, - ) - < hex_nounce - ) - .whereclause.compile(compile_kwargs={"literal_binds": True}) - ) + return tuple( + build_sample_where_clause(table, key_columns, salt, hex_nounce) + for table in [self.runtime_params.table1, self.runtime_params.table2] ) def calculate_nounce(self, max_nounce=2**32 - 1) -> int: diff --git a/ingestion/tests/integration/data_quality/test_data_diff.py b/ingestion/tests/integration/data_quality/test_data_diff.py index 976d82cb784..890e279000d 100644 --- a/ingestion/tests/integration/data_quality/test_data_diff.py +++ b/ingestion/tests/integration/data_quality/test_data_diff.py @@ -326,6 +326,10 @@ class TestParameters(BaseModel): timestamp=int(datetime.now().timestamp() * 1000), testCaseStatus=TestCaseStatus.Success, ), + TableProfilerConfig( + profileSampleType=ProfileSampleType.PERCENTAGE, + profileSample=10, + ), ), ( TestCaseDefinition( diff --git a/ingestion/tests/unit/metadata/data_quality/test_data_diff.py b/ingestion/tests/unit/metadata/data_quality/test_data_diff.py index 079e57a173c..f8af55b6b03 100644 --- a/ingestion/tests/unit/metadata/data_quality/test_data_diff.py +++ b/ingestion/tests/unit/metadata/data_quality/test_data_diff.py @@ -19,6 +19,7 @@ from metadata.generated.schema.entity.data.table import ( from metadata.generated.schema.entity.services.databaseService import ( DatabaseServiceType, ) +from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue @pytest.mark.parametrize( @@ -61,12 +62,18 @@ def test_compile_and_clauses(elements, expected): } ), "table2": TableParameter.model_construct( - **{"database_service_type": DatabaseServiceType.Postgres} + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } ), "keyColumns": ["id"], } ), - "SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'", + ("SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",) * 2, ), ( TableDiffRuntimeParameters.model_construct( @@ -86,12 +93,18 @@ def test_compile_and_clauses(elements, expected): } ), "table2": TableParameter.model_construct( - **{"database_service_type": DatabaseServiceType.Postgres} + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } ), "keyColumns": ["id"], } ), - "SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'", + ("SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",) * 2, ), ( TableDiffRuntimeParameters.model_construct( @@ -111,12 +124,18 @@ def test_compile_and_clauses(elements, expected): } ), "table2": TableParameter.model_construct( - **{"database_service_type": DatabaseServiceType.Postgres} + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } ), "keyColumns": ["id", "name"], } ), - "SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'", + ("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",) * 2, ), ( TableDiffRuntimeParameters.model_construct( @@ -136,12 +155,51 @@ def test_compile_and_clauses(elements, expected): } ), "table2": TableParameter.model_construct( - **{"database_service_type": DatabaseServiceType.Postgres} + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + }, ), "keyColumns": ["id", "name"], } ), - "SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'", + ("SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",) * 2, + ), + ( + TableDiffRuntimeParameters.model_construct( + **{ + "table_profile_config": TableProfilerConfig( + profileSampleType=ProfileSampleType.ROWS, + profileSample=20, + ), + "table1": TableParameter.model_construct( + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } + ), + "table2": TableParameter.model_construct( + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="ID", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + }, + ), + "keyColumns": ["id"], + } + ), + ( + "SUBSTRING(MD5(id || 'a'), 1, 8) < '0083126e'", + "SUBSTRING(MD5(\"ID\" || 'a'), 1, 8) < '0083126e'", + ), ), ( TableDiffRuntimeParameters.model_construct( @@ -157,17 +215,31 @@ def test_compile_and_clauses(elements, expected): } ), "table2": TableParameter.model_construct( - **{"database_service_type": DatabaseServiceType.Postgres} + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + }, ), - "keyColumns": ["id", "name"], + "keyColumns": ["id"], } ), - None, + (None, None), ), ], ) def test_sample_where_clauses(config, expected): - validator = TableDiffValidator(None, None, None) + validator = TableDiffValidator( + None, + TestCase.model_construct( + parameterValues=[ + TestCaseParameterValue(name="caseSensitiveColumns", value="false") + ] + ), + None, + ) validator.runtime_params = config if ( config.table_profile_config