[GEN-1996] feat(data-quality): use sampling config in data diff (#18532)

* feat(data-quality): use sampling config in data diff

- get the table profiling config
- use hashing to deterministically sample the same ids from each table
- use dirty-equals to assert results of stochastic processes

* - reverted missing md5
- added missing database service type

* - use a custom substr sql function

* fixed nounce

* added failure for mssql with sampling because it requires a larger change in the data-diff library

* fixed unit tests

* updated range for sampling
Imri Paran 2024-11-11 10:07:23 +01:00 committed by GitHub
parent 943b4efb4d
commit cdaa5c10af
8 changed files with 476 additions and 7 deletions
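
For orientation, here is a minimal standalone sketch of the hash-based sampling idea this PR implements, using pure Python and hashlib rather than the SQL the validator emits (all names are illustrative):

import hashlib

def deterministic_sample(ids, fraction, salt="a"):
    # Keep an id iff the first 8 hex chars of md5(id + salt) fall below the
    # threshold corresponding to the requested sample fraction.
    threshold = format(int((2**32 - 1) * fraction), "08x")
    return [i for i in ids if hashlib.md5((i + salt).encode()).hexdigest()[:8] < threshold]

table1_ids = ["1", "2", "3", "4", "5"]
table2_ids = ["1", "2", "3", "4", "6"]
s1 = deterministic_sample(table1_ids, 0.5)
s2 = deterministic_sample(table2_ids, 0.5)
# An id present in both tables is either sampled in both or in neither.
assert set(s1) & set(table2_ids) == set(s2) & set(table1_ids)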

View File

@@ -354,6 +354,7 @@ test = {
"pytest==7.0.0",
"pytest-cov",
"pytest-order",
"dirty-equals",
# install dbt dependency
"dbt-artifacts-parser",
"freezegun",

View File

@@ -4,13 +4,17 @@ from typing import List, Optional
from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Column
from metadata.generated.schema.entity.data.table import Column, TableProfilerConfig
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
class TableParameter(BaseModel):
serviceUrl: str
path: str
columns: List[Column]
database_service_type: DatabaseServiceType
class TableDiffRuntimeParameters(BaseModel):
@@ -19,3 +23,4 @@ class TableDiffRuntimeParameters(BaseModel):
keyColumns: List[str]
extraColumns: List[str]
whereClause: Optional[str]
table_profile_config: Optional[TableProfilerConfig]

View File

@@ -77,7 +77,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
key_columns = self.get_key_columns(test_case)
extra_columns = self.get_extra_columns(key_columns, test_case)
return TableDiffRuntimeParameters(
table_profile_config=self.table_entity.tableProfilerConfig,
table1=TableParameter(
database_service_type=service1.serviceType,
path=self.get_data_diff_table_path(
self.table_entity.fullyQualifiedName.root
),
@@ -94,6 +96,7 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
),
),
table2=TableParameter(
database_service_type=service2.serviceType,
path=self.get_data_diff_table_path(table2_fqn),
serviceUrl=self.get_data_diff_url(
service2,
@@ -118,8 +121,10 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
param_where_clause = self.get_parameter(test_case, "where", None)
partition_where_clause = (
None
if not self.sampler._partition_details
or not self.sampler._partition_details.enablePartitioning
if not (
self.sampler._partition_details
and self.sampler._partition_details.enablePartitioning
)
else self.sampler.get_partitioned_query().whereclause.compile(
compile_kwargs={"literal_binds": True}
)

View File

@@ -10,8 +10,11 @@
# limitations under the License.
# pylint: disable=missing-module-docstring
import logging
import random
import string
import traceback
from decimal import Decimal
from functools import reduce
from itertools import islice
from typing import Dict, Iterable, List, Optional, Tuple, cast
from urllib.parse import urlparse
@@ -22,6 +25,7 @@ from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import DataDiffMismatchingKeyTypesError
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
from sqlalchemy import Column as SAColumn
from sqlalchemy import literal, select
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
@@ -29,15 +33,21 @@ from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
)
from metadata.data_quality.validations.models import TableDiffRuntimeParameters
from metadata.generated.schema.entity.data.table import Column
from metadata.generated.schema.entity.data.table import Column, ProfileSampleType
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
SapHanaScheme,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.generated.schema.tests.basic import (
TestCaseResult,
TestCaseStatus,
TestResultValue,
)
from metadata.profiler.orm.converter.base import build_orm_col
from metadata.profiler.orm.functions.md5 import MD5
from metadata.profiler.orm.functions.substr import Substr
from metadata.profiler.orm.registry import Dialects
from metadata.utils.logger import test_suite_logger
@@ -57,6 +67,42 @@ SUPPORTED_DIALECTS = [
]
def compile_and_clauses(elements) -> str:
"""Compile a list of elements into a string with 'and' clauses.
Args:
elements: A string or a list of strings or lists
Returns:
A string with 'and' clauses
Raises:
ValueError: If the input is not a string or a list
Examples:
>>> compile_and_clauses("a")
'a'
>>> compile_and_clauses(["a", "b"])
'a and b'
>>> compile_and_clauses([["a", "b"], "c"])
'(a and b) and c'
"""
if isinstance(elements, str):
return elements
if isinstance(elements, list):
if len(elements) == 1:
return compile_and_clauses(elements[0])
return " and ".join(
(
f"({compile_and_clauses(e)})"
if isinstance(e, list)
else compile_and_clauses(e)
)
for e in elements
)
raise ValueError("Input must be a string or a list")
class UnsupportedDialectError(Exception):
def __init__(self, param: str, dialect: str):
super().__init__(f"Unsupported dialect in param {param}: {dialect}")
@@ -268,7 +314,105 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
def get_where(self) -> Optional[str]:
"""Returns the where clause from the test case parameters or None if it is a blank string."""
return self.runtime_params.whereClause or None
runtime_where = self.runtime_params.whereClause
sample_where = self.sample_where_clause()
where = compile_and_clauses([c for c in [runtime_where, sample_where] if c])
return where if where else None
def sample_where_clause(self) -> Optional[str]:
"""We use a where clause to sample the data for the diff. This is useful because with data diff
we do not have access to the underlying 'SELECT' statement. This method generates a where clause
that selects a random sample of the data based on the profile sample configuration.
The method uses the md5 hash of the key columns and a random salt to select a random sample of the data.
This ensures that the same data is selected from the two tables of the comparison.
Example:
-- Table 1 -- | -- Table 2 --
id | name | id | name
1 | Alice | 1 | Alice
2 | Bob | 2 | Bob
3 | Charlie | 3 | Charlie
4 | David | 4 | Edward
5 | Edward | 6 | Frank
If we want a 20% sample, the where clause aims to select one of the five rows from Table 1,
and the hash ensures that the same row is selected from Table 2. We want to avoid sampling
different ids from each table because comparing unrelated rows is not meaningful.
"""
if (
self.runtime_params.table_profile_config is None
or self.runtime_params.table_profile_config.profileSample is None
):
return None
if DatabaseServiceType.Mssql in [
self.runtime_params.table1.database_service_type,
self.runtime_params.table2.database_service_type,
]:
raise ValueError(
"MSSQL does not support sampling in data diff.\n"
"You can request this feature here:\n"
"https://github.com/open-metadata/OpenMetadata/issues/new?labels=enhancement&projects=&template=feature_request.md" # pylint: disable=line-too-long
)
nounce = self.calculate_nounce()
# SQL MD5 returns a 32-character hex string, including leading zeros, so we
# pad the nounce to 8 characters to preserve lexical order.
# example: SELECT md5('j(R1wzR*y[^GxWJ5B>L{-HLETRD');
hex_nounce = hex(nounce)[2:].rjust(8, "0")
# TODO: using strings for this is suboptimal, but using byte buffers requires a per-database
# implementation. We can use this as the default and add database-specific implementations as the
# need arises.
salt = "".join(
random.choices(string.ascii_letters + string.digits, k=5)
)  # ~62^5 possible salts should be enough entropy. Use letters and digits to avoid breaking SQL syntax
sql_alchemy_columns = [
build_orm_col(
i, c, self.runtime_params.table1.database_service_type
) # TODO: get from runtime params
for i, c in enumerate(self.runtime_params.table1.columns)
if c.name.root in self.runtime_params.keyColumns
]
reduced_concat = reduce(
lambda c1, c2: c1.concat(c2), sql_alchemy_columns + [literal(salt)]
)
return str(
(
select()
.filter(
Substr(
MD5(reduced_concat),
1,
8,
)
< hex_nounce
)
.whereclause.compile(compile_kwargs={"literal_binds": True})
)
)
def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
"""Calculate the nounce based on the profile sample configuration. The nounce is
the sample fraction projected to a number on a scale of 0 to max_nounce"""
if (
self.runtime_params.table_profile_config.profileSampleType
== ProfileSampleType.PERCENTAGE
):
return int(
max_nounce
* self.runtime_params.table_profile_config.profileSample
/ 100
)
if (
self.runtime_params.table_profile_config.profileSampleType
== ProfileSampleType.ROWS
):
row_count = self.get_row_count()
if row_count is None:
raise ValueError("Row count is required for ROWS profile sample type")
return int(
max_nounce
* (self.runtime_params.table_profile_config.profileSample / row_count)
)
raise ValueError("Invalid profile sample type")
def get_runtime_params(self) -> TableDiffRuntimeParameters:
raw = self.get_test_case_param_value(
@@ -470,3 +614,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
return utils.get_bool_test_case_param(
self.test_case.parameterValues, "caseSensitiveColumns"
)
def get_row_count(self) -> Optional[int]:
self.runner._sample = None # pylint: disable=protected-access
return self._compute_row_count(self.runner, None)
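
To make the generated predicate concrete, here is a hedged sketch of what sample_where_clause would produce for a 10% percentage sample on key column id, assuming the MD5 and Substr constructs added in this PR compile as shown below (the salt is fixed for readability; in practice it is random per run):

from sqlalchemy import column, literal

from metadata.profiler.orm.functions.md5 import MD5
from metadata.profiler.orm.functions.substr import Substr

nounce = int((2**32 - 1) * 10 / 100)        # 10% of the 32-bit space
hex_nounce = hex(nounce)[2:].rjust(8, "0")  # '19999999'
clause = Substr(MD5(column("id").concat(literal("abc12"))), 1, 8) < hex_nounce
print(clause.compile(compile_kwargs={"literal_binds": True}))
# SUBSTRING(MD5(id || 'abc12'), 1, 8) < '19999999'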

View File

@@ -0,0 +1,39 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Define MD5 function
"""
# Keep SQA docs style defining custom constructs
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import FunctionElement
from metadata.profiler.metrics.core import CACHE
from metadata.profiler.orm.registry import Dialects
from metadata.utils.logger import profiler_logger
logger = profiler_logger()
class MD5(FunctionElement):
inherit_cache = CACHE
@compiles(MD5)
def _(element, compiler, **kw):
return f"MD5({compiler.process(element.clauses, **kw)})"
@compiles(MD5, Dialects.MSSQL)
def _(element, compiler, **kw):
# TODO requires separate where clauses for each table
return f"CONVERT(VARCHAR(8), HashBytes('MD5', {compiler.process(element.clauses, **kw)}), 2)"

View File

@@ -0,0 +1,32 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Define Substr function
"""
# Keep SQA docs style defining custom constructs
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import FunctionElement
from metadata.profiler.metrics.core import CACHE
from metadata.utils.logger import profiler_logger
logger = profiler_logger()
class Substr(FunctionElement):
inherit_cache = CACHE
@compiles(Substr)
def _(element, compiler, **kw):
return f"SUBSTRING({compiler.process(element.clauses, **kw)})"

View File

@@ -2,6 +2,7 @@ import sys
from datetime import datetime
import pytest
from dirty_equals import IsApprox
from pydantic import BaseModel
from sqlalchemy import VARBINARY
from sqlalchemy import Column as SQAColumn
@@ -15,7 +16,11 @@ from sqlalchemy.sql import sqltypes
from _openmetadata_testutils.postgres.conftest import postgres_container
from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects
from metadata.data_quality.api.models import TestCaseDefinition
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.data.table import (
ProfileSampleType,
Table,
TableProfilerConfig,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.metadataIngestion.testSuitePipeline import (
TestSuiteConfigType,
@@ -40,6 +45,7 @@ class TestParameters(BaseModel):
test_case_defintion: TestCaseDefinition
table2_fqn: str
expected: TestCaseResult
table_profile_config: TableProfilerConfig = None
def __init__(self, *args, **kwargs):
if args:
@@ -74,6 +80,54 @@ class TestParameters(BaseModel):
passedRows=599,
),
),
(
TestCaseDefinition(
name="compare_same_tables_with_percentage_sample",
testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[
TestCaseParameterValue(
name="keyColumns", value="['customer_id']"
),
],
),
"POSTGRES_SERVICE.dvdrental.public.customer",
TestCaseResult.model_construct(
timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Success,
failedRows=0,
# we use approximations because the sampling is not deterministic
passedRows=IsApprox(59, delta=20),
),
TableProfilerConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=10,
),
),
(
TestCaseDefinition(
name="compare_same_tables_with_row_sample",
testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[
TestCaseParameterValue(
name="keyColumns", value="['customer_id']"
),
],
),
"POSTGRES_SERVICE.dvdrental.public.customer",
TestCaseResult.model_construct(
timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Success,
failedRows=0,
passedRows=IsApprox(10, delta=5),
),
TableProfilerConfig(
profileSampleType=ProfileSampleType.ROWS,
# we use approximations because the sampling is not deterministic
profileSample=10,
),
),
(
TestCaseDefinition(
name="with_explicit_key_columns",
@@ -303,7 +357,7 @@ def test_happy_paths(
cleanup_fqns,
):
metadata = patched_metadata
table1 = metadata.get_by_name(
table1: Table = metadata.get_by_name(
Table,
f"{postgres_service.fullyQualifiedName.root}.dvdrental.public.customer",
nullable=False,
@@ -328,6 +382,10 @@ def test_happy_paths(
),
]
)
if parameters.table_profile_config:
metadata.create_or_update_table_profiler_config(
table1.fullyQualifiedName.root, parameters.table_profile_config
)
workflow_config = {
"source": {
"type": TestSuiteConfigType.TestSuite.value,
@@ -347,6 +405,9 @@
"workflowConfig": workflow_config,
}
run_workflow(TestSuiteWorkflow, workflow_config)
metadata.create_or_update_table_profiler_config(
table1.fullyQualifiedName.root, TableProfilerConfig()
)
test_case_entity = metadata.get_by_name(
TestCase,
f"{table1.fullyQualifiedName.root}.{parameters.test_case_defintion.name}",

View File

@@ -0,0 +1,178 @@
from unittest.mock import Mock, patch
import pytest
from metadata.data_quality.validations.models import (
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.data_quality.validations.table.sqlalchemy.tableDiff import (
TableDiffValidator,
compile_and_clauses,
)
from metadata.generated.schema.entity.data.table import (
Column,
DataType,
ProfileSampleType,
TableProfilerConfig,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
@pytest.mark.parametrize(
"elements, expected",
[
("a", "a"),
(["a", "b"], "a and b"),
(["a", ["b", "c"]], "a and (b and c)"),
(["a", ["b", ["c", "d"]]], "a and (b and (c and d))"),
(["a", ["b", "c"], "d"], "a and (b and c) and d"),
([], ""),
("", ""),
(["a"], "a"),
([["a"]], "a"),
([["a"]], "a"),
],
)
def test_compile_and_clauses(elements, expected):
assert compile_and_clauses(elements) == expected
@pytest.mark.parametrize(
"config,expected",
[
(
TableDiffRuntimeParameters.model_construct(
**{
"database_service_type": "BigQuery",
"table_profile_config": TableProfilerConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=10,
),
"table1": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
),
"keyColumns": ["id"],
}
),
"SUBSTRING(MD5(id || 'a'), 1, 8) < '19999999'",
),
(
TableDiffRuntimeParameters.model_construct(
**{
"database_service_type": "BigQuery",
"table_profile_config": TableProfilerConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=20,
),
"table1": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
),
"keyColumns": ["id"],
}
),
"SUBSTRING(MD5(id || 'a'), 1, 8) < '33333333'",
),
(
TableDiffRuntimeParameters.model_construct(
**{
"database_service_type": "BigQuery",
"table_profile_config": TableProfilerConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=10,
),
"table1": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
),
"keyColumns": ["id", "name"],
}
),
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '19999999'",
),
(
TableDiffRuntimeParameters.model_construct(
**{
"database_service_type": "BigQuery",
"table_profile_config": TableProfilerConfig(
profileSampleType=ProfileSampleType.ROWS,
profileSample=20,
),
"table1": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
),
"keyColumns": ["id", "name"],
}
),
"SUBSTRING(MD5(id || name || 'a'), 1, 8) < '0083126e'",
),
(
TableDiffRuntimeParameters.model_construct(
**{
"table_profile_config": None,
"table1": TableParameter.model_construct(
**{
"database_service_type": DatabaseServiceType.Postgres,
"columns": [
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
}
),
"table2": TableParameter.model_construct(
**{"database_service_type": DatabaseServiceType.Postgres}
),
"keyColumns": ["id", "name"],
}
),
None,
),
],
)
def test_sample_where_clauses(config, expected):
validator = TableDiffValidator(None, None, None)
validator.runtime_params = config
if (
config.table_profile_config
and config.table_profile_config.profileSampleType == ProfileSampleType.ROWS
):
validator.get_row_count = Mock(return_value=10_000)
with patch("random.choices", Mock(return_value=["a"])):
assert validator.sample_where_clause() == expected
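
The expected thresholds above follow directly from the nounce arithmetic in calculate_nounce; a quick sanity check (the ROWS case uses the mocked row count of 10,000):

max_nounce = 2**32 - 1
assert hex(int(max_nounce * 10 / 100))[2:].rjust(8, "0") == "19999999"  # 10% sample
assert hex(int(max_nounce * 20 / 100))[2:].rjust(8, "0") == "33333333"  # 20% sample
assert hex(int(max_nounce * (20 / 10_000)))[2:].rjust(8, "0") == "0083126e"  # 20 of 10,000 rows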