From cdaa5c10af65f486b13237d60a0b85a3da590ea0 Mon Sep 17 00:00:00 2001 From: Imri Paran Date: Mon, 11 Nov 2024 10:07:23 +0100 Subject: [PATCH] [GEN-1996] feat(data-quality): use sampling config in data diff (#18532) * feat(data-quality): use sampling config in data diff - get the table profiling config - use hashing to sample deterministically the same ids from each table - use dirty-equals to assert results of stochastic processes * - reverted missing md5 - added missing database service type * - use a custom substr sql function * fixed nounce * added failure for mssql with sampling because it requires a larger change in the data-diff library * fixed unit tests * updated range for sampling --- ingestion/setup.py | 1 + .../data_quality/validations/models.py | 7 +- .../table_diff_params_setter.py | 9 +- .../validations/table/sqlalchemy/tableDiff.py | 152 ++++++++++++++- .../metadata/profiler/orm/functions/md5.py | 39 ++++ .../metadata/profiler/orm/functions/substr.py | 32 ++++ .../{test_table_diff.py => test_data_diff.py} | 65 ++++++- .../metadata/data_quality/test_data_diff.py | 178 ++++++++++++++++++ 8 files changed, 476 insertions(+), 7 deletions(-) create mode 100644 ingestion/src/metadata/profiler/orm/functions/md5.py create mode 100644 ingestion/src/metadata/profiler/orm/functions/substr.py rename ingestion/tests/integration/data_quality/{test_table_diff.py => test_data_diff.py} (90%) create mode 100644 ingestion/tests/unit/metadata/data_quality/test_data_diff.py diff --git a/ingestion/setup.py b/ingestion/setup.py index f7970ffeca4..e209682d835 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -354,6 +354,7 @@ test = { "pytest==7.0.0", "pytest-cov", "pytest-order", + "dirty-equals", # install dbt dependency "dbt-artifacts-parser", "freezegun", diff --git a/ingestion/src/metadata/data_quality/validations/models.py b/ingestion/src/metadata/data_quality/validations/models.py index cdb222b8ae4..f2e4128ff7e 100644 --- 
a/ingestion/src/metadata/data_quality/validations/models.py +++ b/ingestion/src/metadata/data_quality/validations/models.py @@ -4,13 +4,17 @@ from typing import List, Optional from pydantic import BaseModel -from metadata.generated.schema.entity.data.table import Column +from metadata.generated.schema.entity.data.table import Column, TableProfilerConfig +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseServiceType, +) class TableParameter(BaseModel): serviceUrl: str path: str columns: List[Column] + database_service_type: DatabaseServiceType class TableDiffRuntimeParameters(BaseModel): @@ -19,3 +23,4 @@ class TableDiffRuntimeParameters(BaseModel): keyColumns: List[str] extraColumns: List[str] whereClause: Optional[str] + table_profile_config: Optional[TableProfilerConfig] diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index ffadb0e0c6b..09d713d9ee8 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -77,7 +77,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter): key_columns = self.get_key_columns(test_case) extra_columns = self.get_extra_columns(key_columns, test_case) return TableDiffRuntimeParameters( + table_profile_config=self.table_entity.tableProfilerConfig, table1=TableParameter( + database_service_type=service1.serviceType, path=self.get_data_diff_table_path( self.table_entity.fullyQualifiedName.root ), @@ -94,6 +96,7 @@ class TableDiffParamsSetter(RuntimeParameterSetter): ), ), table2=TableParameter( + database_service_type=service2.serviceType, path=self.get_data_diff_table_path(table2_fqn), serviceUrl=self.get_data_diff_url( service2, @@ -118,8 +121,10 @@ class 
TableDiffParamsSetter(RuntimeParameterSetter): param_where_clause = self.get_parameter(test_case, "where", None) partition_where_clause = ( None - if not self.sampler._partition_details - or not self.sampler._partition_details.enablePartitioning + if not ( + self.sampler._partition_details + and self.sampler._partition_details.enablePartitioning + ) else self.sampler.get_partitioned_query().whereclause.compile( compile_kwargs={"literal_binds": True} ) diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 39a04474a40..02aea8f7741 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -10,8 +10,11 @@ # limitations under the License. # pylint: disable=missing-module-docstring import logging +import random +import string import traceback from decimal import Decimal +from functools import reduce from itertools import islice from typing import Dict, Iterable, List, Optional, Tuple, cast from urllib.parse import urlparse @@ -22,6 +25,7 @@ from data_diff.diff_tables import DiffResultWrapper from data_diff.errors import DataDiffMismatchingKeyTypesError from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict from sqlalchemy import Column as SAColumn +from sqlalchemy import literal, select from metadata.data_quality.validations import utils from metadata.data_quality.validations.base_test_handler import BaseTestValidator @@ -29,15 +33,21 @@ from metadata.data_quality.validations.mixins.sqa_validator_mixin import ( SQAValidatorMixin, ) from metadata.data_quality.validations.models import TableDiffRuntimeParameters -from metadata.generated.schema.entity.data.table import Column +from metadata.generated.schema.entity.data.table import Column, ProfileSampleType from 
metadata.generated.schema.entity.services.connections.database.sapHanaConnection import ( SapHanaScheme, ) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseServiceType, +) from metadata.generated.schema.tests.basic import ( TestCaseResult, TestCaseStatus, TestResultValue, ) +from metadata.profiler.orm.converter.base import build_orm_col +from metadata.profiler.orm.functions.md5 import MD5 +from metadata.profiler.orm.functions.substr import Substr from metadata.profiler.orm.registry import Dialects from metadata.utils.logger import test_suite_logger @@ -57,6 +67,42 @@ SUPPORTED_DIALECTS = [ ] +def compile_and_clauses(elements) -> str: + """Compile a list of elements into a string with 'and' clauses. + + Args: + elements: A string or a list of strings or lists + + Returns: + A string with 'and' clauses + + Raises: + ValueError: If the input is not a string or a list + + Examples: + >>> compile_and_clauses("a") + 'a' + >>> compile_and_clauses(["a", "b"]) + 'a and b' + >>> compile_and_clauses([["a", "b"], "c"]) + '(a and b) and c' + """ + if isinstance(elements, str): + return elements + if isinstance(elements, list): + if len(elements) == 1: + return compile_and_clauses(elements[0]) + return " and ".join( + ( + f"({compile_and_clauses(e)})" + if isinstance(e, list) + else compile_and_clauses(e) + ) + for e in elements + ) + raise ValueError("Input must be a string or a list") + + class UnsupportedDialectError(Exception): def __init__(self, param: str, dialect: str): super().__init__(f"Unsupported dialect in param {param}: {dialect}") @@ -268,7 +314,105 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): def get_where(self) -> Optional[str]: """Returns the where clause from the test case parameters or None if it is a blank string.""" - return self.runtime_params.whereClause or None + runtime_where = self.runtime_params.whereClause + sample_where = self.sample_where_clause() + where = compile_and_clauses([c for c in 
[runtime_where, sample_where] if c]) + return where if where else None + + def sample_where_clause(self) -> Optional[str]: + """We use a where clause to sample the data for the diff. This is useful because with data diff + we do not have access to the underlying 'SELECT' statement. This method generates a where clause + that selects a random sample of the data based on the profile sample configuration. + The method uses the md5 hash of the key columns and a random salt to select a random sample of the data. + This ensures that the same data is selected from the two tables of the comparison. + + Example: + -- Table 1 -- | -- Table 2 -- + id | name | id | name + 1 | Alice | 1 | Alice + 2 | Bob | 2 | Bob + 3 | Charlie | 3 | Charlie + 4 | David | 4 | Edward + 5 | Edward | 6 | Frank + + If we want a sample of 20% of the data, the where clause will intend to select one of the rows + on Table 1 and the hash will ensure that the same row is selected on Table 2. We want to avoid selecting rows + with different ids because the comparison will not be sensible. + """ + if ( + self.runtime_params.table_profile_config is None + or self.runtime_params.table_profile_config.profileSample is None + ): + return None + if DatabaseServiceType.Mssql in [ + self.runtime_params.table1.database_service_type, + self.runtime_params.table2.database_service_type, + ]: + raise ValueError( + "MSSQL does not support sampling in data diff.\n" + "You can request this feature here:\n" + "https://github.com/open-metadata/OpenMetadata/issues/new?labels=enhancement&projects=&template=feature_request.md" # pylint: disable=line-too-long + ) + nounce = self.calculate_nounce() + # SQL MD5 returns a 32 character hex string even with leading zeros so we need to + # pad the nounce to 8 characters to preserve lexical order. + # example: SELECT md5('j(R1wzR*y[^GxWJ5B>L{-HLETRD'); + hex_nounce = hex(nounce)[2:].rjust(8, "0") + # TODO: using strings for this is sub-optimal.
But using bytes buffers requires a by-database + # implementation. We can use this as default and add database specific implementations as the + # need arises. + salt = "".join( + random.choices(string.ascii_letters + string.digits, k=5) + ) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax + sql_alchemy_columns = [ + build_orm_col( + i, c, self.runtime_params.table1.database_service_type + ) # TODO: get from runtime params + for i, c in enumerate(self.runtime_params.table1.columns) + if c.name.root in self.runtime_params.keyColumns + ] + reduced_concat = reduce( + lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)] + ) + return str( + ( + select() + .filter( + Substr( + MD5(reduced_concat), + 1, + 8, + ) + < hex_nounce + ) + .whereclause.compile(compile_kwargs={"literal_binds": True}) + ) + ) + + def calculate_nounce(self, max_nounce=2**32 - 1) -> int: + """Calculate the nounce based on the profile sample configuration. The nounce is + the sample fraction projected to a number on a scale of 0 to max_nounce""" + if ( + self.runtime_params.table_profile_config.profileSampleType + == ProfileSampleType.PERCENTAGE + ): + return int( + max_nounce + * self.runtime_params.table_profile_config.profileSample + / 100 + ) + if ( + self.runtime_params.table_profile_config.profileSampleType + == ProfileSampleType.ROWS + ): + row_count = self.get_row_count() + if row_count is None: + raise ValueError("Row count is required for ROWS profile sample type") + return int( + max_nounce + * (self.runtime_params.table_profile_config.profileSample / row_count) + ) + raise ValueError("Invalid profile sample type") def get_runtime_params(self) -> TableDiffRuntimeParameters: raw = self.get_test_case_param_value( @@ -470,3 +614,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): return utils.get_bool_test_case_param( self.test_case.parameterValues, "caseSensitiveColumns" ) + + def get_row_count(self) -> 
Optional[int]: + self.runner._sample = None # pylint: disable=protected-access + return self._compute_row_count(self.runner, None) diff --git a/ingestion/src/metadata/profiler/orm/functions/md5.py b/ingestion/src/metadata/profiler/orm/functions/md5.py new file mode 100644 index 00000000000..035326cb17d --- /dev/null +++ b/ingestion/src/metadata/profiler/orm/functions/md5.py @@ -0,0 +1,39 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Define MD5 function +""" +# Keep SQA docs style defining custom constructs + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import FunctionElement + +from metadata.profiler.metrics.core import CACHE +from metadata.profiler.orm.registry import Dialects +from metadata.utils.logger import profiler_logger + +logger = profiler_logger() + + +class MD5(FunctionElement): + inherit_cache = CACHE + + +@compiles(MD5) +def _(element, compiler, **kw): + return f"MD5({compiler.process(element.clauses, **kw)})" + + +@compiles(MD5, Dialects.MSSQL) +def _(element, compiler, **kw): + # TODO requires separate where clauses for each table + return f"CONVERT(VARCHAR(8), HashBytes('MD5', {compiler.process(element.clauses, **kw)}), 2)" diff --git a/ingestion/src/metadata/profiler/orm/functions/substr.py b/ingestion/src/metadata/profiler/orm/functions/substr.py new file mode 100644 index 00000000000..9b884e15613 --- /dev/null +++ b/ingestion/src/metadata/profiler/orm/functions/substr.py @@ -0,0 +1,32 
@@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Define Substr function +""" +# Keep SQA docs style defining custom constructs + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import FunctionElement + +from metadata.profiler.metrics.core import CACHE +from metadata.utils.logger import profiler_logger + +logger = profiler_logger() + + +class Substr(FunctionElement): + inherit_cache = CACHE + + +@compiles(Substr) +def _(element, compiler, **kw): + return f"SUBSTRING({compiler.process(element.clauses, **kw)})" diff --git a/ingestion/tests/integration/data_quality/test_table_diff.py b/ingestion/tests/integration/data_quality/test_data_diff.py similarity index 90% rename from ingestion/tests/integration/data_quality/test_table_diff.py rename to ingestion/tests/integration/data_quality/test_data_diff.py index 62d69ca6beb..976d82cb784 100644 --- a/ingestion/tests/integration/data_quality/test_table_diff.py +++ b/ingestion/tests/integration/data_quality/test_data_diff.py @@ -2,6 +2,7 @@ import sys from datetime import datetime import pytest +from dirty_equals import IsApprox from pydantic import BaseModel from sqlalchemy import VARBINARY from sqlalchemy import Column as SQAColumn @@ -15,7 +16,11 @@ from sqlalchemy.sql import sqltypes from _openmetadata_testutils.postgres.conftest import postgres_container from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects from 
metadata.data_quality.api.models import TestCaseDefinition -from metadata.generated.schema.entity.data.table import Table +from metadata.generated.schema.entity.data.table import ( + ProfileSampleType, + Table, + TableProfilerConfig, +) from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.metadataIngestion.testSuitePipeline import ( TestSuiteConfigType, @@ -40,6 +45,7 @@ class TestParameters(BaseModel): test_case_defintion: TestCaseDefinition table2_fqn: str expected: TestCaseResult + table_profile_config: TableProfilerConfig = None def __init__(self, *args, **kwargs): if args: @@ -74,6 +80,54 @@ class TestParameters(BaseModel): passedRows=599, ), ), + ( + TestCaseDefinition( + name="compare_same_tables_with_percentage_sample", + testDefinitionName="tableDiff", + computePassedFailedRowCount=True, + parameterValues=[ + TestCaseParameterValue( + name="keyColumns", value="['customer_id']" + ), + ], + ), + "POSTGRES_SERVICE.dvdrental.public.customer", + TestCaseResult.model_construct( + timestamp=int(datetime.now().timestamp() * 1000), + testCaseStatus=TestCaseStatus.Success, + failedRows=0, + # we use approximations because the sampling is not deterministic + passedRows=IsApprox(59, delta=20), + ), + TableProfilerConfig( + profileSampleType=ProfileSampleType.PERCENTAGE, + profileSample=10, + ), + ), + ( + TestCaseDefinition( + name="compare_same_tables_with_row_sample", + testDefinitionName="tableDiff", + computePassedFailedRowCount=True, + parameterValues=[ + TestCaseParameterValue( + name="keyColumns", value="['customer_id']" + ), + ], + ), + "POSTGRES_SERVICE.dvdrental.public.customer", + TestCaseResult.model_construct( + timestamp=int(datetime.now().timestamp() * 1000), + testCaseStatus=TestCaseStatus.Success, + failedRows=0, + passedRows=IsApprox(10, delta=5), + ), + TableProfilerConfig( + profileSampleType=ProfileSampleType.ROWS, + # we use approximations because the sampling is not deterministic + 
profileSample=10, + ), + ), ( TestCaseDefinition( name="with_explicit_key_columns", @@ -303,7 +357,7 @@ def test_happy_paths( cleanup_fqns, ): metadata = patched_metadata - table1 = metadata.get_by_name( + table1: Table = metadata.get_by_name( Table, f"{postgres_service.fullyQualifiedName.root}.dvdrental.public.customer", nullable=False, @@ -328,6 +382,10 @@ def test_happy_paths( ), ] ) + if parameters.table_profile_config: + metadata.create_or_update_table_profiler_config( + table1.fullyQualifiedName.root, parameters.table_profile_config + ) workflow_config = { "source": { "type": TestSuiteConfigType.TestSuite.value, @@ -347,6 +405,9 @@ def test_happy_paths( "workflowConfig": workflow_config, } run_workflow(TestSuiteWorkflow, workflow_config) + metadata.create_or_update_table_profiler_config( + table1.fullyQualifiedName.root, TableProfilerConfig() + ) test_case_entity = metadata.get_by_name( TestCase, f"{table1.fullyQualifiedName.root}.{parameters.test_case_defintion.name}", diff --git a/ingestion/tests/unit/metadata/data_quality/test_data_diff.py b/ingestion/tests/unit/metadata/data_quality/test_data_diff.py new file mode 100644 index 00000000000..079e57a173c --- /dev/null +++ b/ingestion/tests/unit/metadata/data_quality/test_data_diff.py @@ -0,0 +1,178 @@ +from unittest.mock import Mock, patch + +import pytest + +from metadata.data_quality.validations.models import ( + TableDiffRuntimeParameters, + TableParameter, +) +from metadata.data_quality.validations.table.sqlalchemy.tableDiff import ( + TableDiffValidator, + compile_and_clauses, +) +from metadata.generated.schema.entity.data.table import ( + Column, + DataType, + ProfileSampleType, + TableProfilerConfig, +) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseServiceType, +) + + +@pytest.mark.parametrize( + "elements, expected", + [ + ("a", "a"), + (["a", "b"], "a and b"), + (["a", ["b", "c"]], "a and (b and c)"), + (["a", ["b", ["c", "d"]]], "a and (b and (c and d))"), + 
(["a", ["b", "c"], "d"], "a and (b and c) and d"), + ([], ""), + ("", ""), + (["a"], "a"), + ([["a"]], "a"), + ([["a"]], "a"), + ], +) +def test_compile_and_clauses(elements, expected): + assert compile_and_clauses(elements) == expected + + +@pytest.mark.parametrize( + "config,expected", + [ + ( + TableDiffRuntimeParameters.model_construct( + **{ + "database_service_type": "BigQuery", + "table_profile_config": TableProfilerConfig( + profileSampleType=ProfileSampleType.PERCENTAGE, + profileSample=10, + ), + "table1": TableParameter.model_construct( + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } + ), + "table2": TableParameter.model_construct( + **{"database_service_type": DatabaseServiceType.Postgres} + ), + "keyColumns": ["id"], + } + ), + "SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'", + ), + ( + TableDiffRuntimeParameters.model_construct( + **{ + "database_service_type": "BigQuery", + "table_profile_config": TableProfilerConfig( + profileSampleType=ProfileSampleType.PERCENTAGE, + profileSample=20, + ), + "table1": TableParameter.model_construct( + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } + ), + "table2": TableParameter.model_construct( + **{"database_service_type": DatabaseServiceType.Postgres} + ), + "keyColumns": ["id"], + } + ), + "SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'", + ), + ( + TableDiffRuntimeParameters.model_construct( + **{ + "database_service_type": "BigQuery", + "table_profile_config": TableProfilerConfig( + profileSampleType=ProfileSampleType.PERCENTAGE, + profileSample=10, + ), + "table1": TableParameter.model_construct( + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", 
dataType=DataType.STRING), + ], + } + ), + "table2": TableParameter.model_construct( + **{"database_service_type": DatabaseServiceType.Postgres} + ), + "keyColumns": ["id", "name"], + } + ), + "SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'", + ), + ( + TableDiffRuntimeParameters.model_construct( + **{ + "database_service_type": "BigQuery", + "table_profile_config": TableProfilerConfig( + profileSampleType=ProfileSampleType.ROWS, + profileSample=20, + ), + "table1": TableParameter.model_construct( + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } + ), + "table2": TableParameter.model_construct( + **{"database_service_type": DatabaseServiceType.Postgres} + ), + "keyColumns": ["id", "name"], + } + ), + "SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'", + ), + ( + TableDiffRuntimeParameters.model_construct( + **{ + "table_profile_config": None, + "table1": TableParameter.model_construct( + **{ + "database_service_type": DatabaseServiceType.Postgres, + "columns": [ + Column(name="id", dataType=DataType.STRING), + Column(name="name", dataType=DataType.STRING), + ], + } + ), + "table2": TableParameter.model_construct( + **{"database_service_type": DatabaseServiceType.Postgres} + ), + "keyColumns": ["id", "name"], + } + ), + None, + ), + ], +) +def test_sample_where_clauses(config, expected): + validator = TableDiffValidator(None, None, None) + validator.runtime_params = config + if ( + config.table_profile_config + and config.table_profile_config.profileSampleType == ProfileSampleType.ROWS + ): + validator.get_row_count = Mock(return_value=10_000) + with patch("random.choices", Mock(return_value=["a"])): + assert validator.sample_where_clause() == expected