MINOR: add column case sensitivity parameter (#18115)

* fix(data-quality): table diff

- added handling for case-insensitive columns
- added handling for different numeric types (int/float/Decimal)
- added handling of boolean test case parameters

* add migrations for table diff

* add migrations for table diff

* removed cross type diff for now. it appears to be flaky

* fixed migrations

* use casefold() instead of lower()

* - implemented utils.get_test_case_param_value
- fixed params for case sensitive column

* handle bool test case parameters

* format

* testing

* format

* list -> List

* list -> List

* - change caseSensitiveColumns default to fase
- added migration to stay backward compatible

* - removed migration files
- updated logging message for table diff migration

* changed bool test case parameters default to always be false

* format

* docs: data diff

- added the caseSensitiveColumns parameter

requires: https://github.com/open-metadata/OpenMetadata/pull/18115

* fixed test_get_bool_test_case_param

(cherry picked from commit be82086e2542d2d176ac66e0bf11100646448b4f)
This commit is contained in:
Imri Paran 2024-10-15 16:29:43 +02:00 committed by sushi30
parent 58b11669aa
commit 5f93b05d8b
18 changed files with 421 additions and 136 deletions

View File

@ -50,8 +50,8 @@ echo "Running local docker using mode [$mode] database [$database] and skipping
cd ../
echo "Stopping any previous Local Docker Containers"
docker compose -f docker/development/docker-compose-postgres.yml down
docker compose -f docker/development/docker-compose.yml down
docker compose -f docker/development/docker-compose-postgres.yml down --remove-orphans
docker compose -f docker/development/docker-compose.yml down --remove-orphans
if [[ $skipMaven == "false" ]]; then
if [[ $mode == "no-ui" ]]; then

View File

@ -20,6 +20,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Callable, List, Optional, Type, TypeVar, Union
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
@ -65,37 +66,18 @@ class BaseTestValidator(ABC):
"""
raise NotImplementedError
@staticmethod
def get_test_case_param_value(
self,
test_case_param_vals: list[TestCaseParameterValue],
test_case_param_vals: List[TestCaseParameterValue],
name: str,
type_: T,
default: Optional[R] = None,
pre_processor: Optional[Callable] = None,
) -> Optional[Union[R, T]]:
"""Give a column and a type return the value with the appropriate type casting for the
test case definition.
Args:
test_case: the test case
type_ (Union[float, int, str]): type for the value
name (str): column name
default (_type_, optional): Default value to return if column is not found
pre_processor: pre processor function/type to use against the value before casting to type_
"""
value = next(
(param.value for param in test_case_param_vals if param.name == name), None
return utils.get_test_case_param_value(
test_case_param_vals, name, type_, default, pre_processor
)
if not value:
return default if default is not None else None
if not pre_processor:
return type_(value)
pre_processed_value = pre_processor(value)
return type_(pre_processed_value)
def get_test_case_result_object( # pylint: disable=too-many-arguments
self,
execution_date: Union[datetime, float],

View File

@ -20,6 +20,7 @@ from typing import Union
from sqlalchemy import Column
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.generated.schema.tests.basic import (
TestCaseResult,
@ -50,11 +51,8 @@ class BaseColumnValuesToBeInSetValidator(BaseTestValidator):
literal_eval,
)
match_enum = self.get_test_case_param_value(
self.test_case.parameterValues, # type: ignore
"matchEnum",
bool,
default=False,
match_enum = utils.get_bool_test_case_param(
self.test_case.parameterValues, "matchEnum"
)
try:

View File

@ -13,6 +13,7 @@ from ast import literal_eval
from typing import List, Optional
from urllib.parse import urlparse
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.models import (
Column,
TableDiffRuntimeParameters,
@ -27,6 +28,7 @@ from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList
class TableDiffParamsSetter(RuntimeParameterSetter):
@ -58,6 +60,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
DatabaseService, self.table_entity.service.id, nullable=False
)
table2_fqn = self.get_parameter(test_case, "table2")
case_sensitive_columns: bool = utils.get_bool_test_case_param(
test_case.parameterValues, "caseSensitiveColumns"
)
if table2_fqn is None:
raise ValueError("table2 not set")
table2: Table = self.ometa_client.get_by_name(
@ -82,7 +87,10 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
override_url=service1_url,
),
columns=self.filter_relevant_columns(
self.table_entity.columns, key_columns, extra_columns
self.table_entity.columns,
key_columns,
extra_columns,
case_sensitive=case_sensitive_columns,
),
),
table2=TableParameter(
@ -94,7 +102,10 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
or service2_url,
),
columns=self.filter_relevant_columns(
table2.columns, key_columns, extra_columns
table2.columns,
key_columns,
extra_columns,
case_sensitive=case_sensitive_columns,
),
),
keyColumns=key_columns,
@ -156,9 +167,17 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
@staticmethod
def filter_relevant_columns(
columns: List[Column], key_columns: List[str], extra_columns: List[str]
columns: List[Column],
key_columns: List[str],
extra_columns: List[str],
case_sensitive: bool,
) -> List[Column]:
return [c for c in columns if c.name.root in [*key_columns, *extra_columns]]
validated_columns = (
[*key_columns, *extra_columns]
if case_sensitive
else CaseInsensitiveList([*key_columns, *extra_columns])
)
return [c for c in columns if c.name.root in validated_columns]
@staticmethod
def get_parameter(test_case: TestCase, key: str, default=None):
@ -195,7 +214,7 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
if hasattr(db_service.connection.config, "supportsDatabase"):
kwargs["path"] = f"/{database}"
if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake}:
kwargs["path"] += f"/{schema}"
kwargs["path"] = f"/{database}/{schema}"
return url._replace(**kwargs).geturl()
@staticmethod

View File

@ -11,17 +11,19 @@
# pylint: disable=missing-module-docstring
import logging
import traceback
from decimal import Decimal
from itertools import islice
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, cast
from urllib.parse import urlparse
import data_diff
import sqlalchemy.types
from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import DataDiffMismatchingKeyTypesError
from data_diff.utils import ArithAlphanumeric
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
from sqlalchemy import Column as SAColumn
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
@ -75,6 +77,18 @@ def masked(s: str, mask: bool = True) -> str:
return "***" if mask else s
def is_numeric(t: type) -> bool:
"""Check if a type is numeric.
Args:
t: type to check
Returns:
True if the type is numeric otherwise False
"""
return t in [int, float, Decimal]
class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
"""
Compare two tables and fail if the number of differences exceeds a threshold
@ -167,12 +181,14 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
self.runtime_params.table1.path,
self.runtime_params.keyColumns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
).with_schema()
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
).with_schema()
result = []
for column in table1.key_columns + table1.extra_columns:
@ -185,7 +201,8 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
col2_type = self._get_column_python_type(
table2._schema[column] # pylint: disable=protected-access
)
if is_numeric(col1_type) and is_numeric(col2_type):
continue
if col1_type != col2_type:
result.append(column)
return result
@ -228,11 +245,13 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
self.runtime_params.table1.serviceUrl,
self.runtime_params.table1.path,
self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(),
)
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(),
)
data_diff_kwargs = {
"key_columns": self.runtime_params.keyColumns,
@ -308,7 +327,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
def get_column_diff(self) -> Optional[TestCaseResult]:
"""Get the column diff between the two tables. If there are no differences, return None."""
removed, added = self.get_changed_added_columns(
self.runtime_params.table1.columns, self.runtime_params.table2.columns
self.runtime_params.table1.columns,
self.runtime_params.table2.columns,
self.get_case_sensitive(),
)
changed = self.get_incomparable_columns()
if removed or added or changed:
@ -321,7 +342,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
@staticmethod
def get_changed_added_columns(
left: List[Column], right: List[Column]
left: List[Column], right: List[Column], case_sensitive: bool
) -> Optional[Tuple[List[str], List[str]]]:
"""Given a list of columns from two tables, return the columns that are removed and added.
@ -335,6 +356,10 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
removed: List[str] = []
added: List[str] = []
right_columns_dict: Dict[str, Column] = {c.name.root: c for c in right}
if not case_sensitive:
right_columns_dict = cast(
Dict[str, Column], CaseInsensitiveDict(right_columns_dict)
)
for column in left:
table2_column = right_columns_dict.get(column.name.root)
if table2_column is None:
@ -345,7 +370,10 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
return removed, added
def column_validation_result(
self, removed: List[str], added: List[str], changed: List[str]
self,
removed: List[str],
added: List[str],
changed: List[str],
) -> TestCaseResult:
"""Build the result for a column validation result. Messages will only be added
for non-empty categories. Values will be populated reported for all categories.
@ -367,13 +395,18 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
message += f"\n Added columns: {','.join(added)}\n"
if changed:
message += "\n Changed columns:"
table1_columns = {
c.name.root: c for c in self.runtime_params.table1.columns
}
table2_columns = {
c.name.root: c for c in self.runtime_params.table2.columns
}
if not self.get_case_sensitive():
table1_columns = CaseInsensitiveDict(table1_columns)
table2_columns = CaseInsensitiveDict(table2_columns)
for col in changed:
col1 = next(
c for c in self.runtime_params.table1.columns if c.name.root == col
)
col2 = next(
c for c in self.runtime_params.table2.columns if c.name.root == col
)
col1 = table1_columns[col]
col2 = table2_columns[col]
message += (
f"\n {col}: {col1.dataType.value} -> {col2.dataType.value}"
)
@ -432,3 +465,8 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
if str(ex) == "2":
# This is a known issue in data_diff where the diff object is closed
pass
def get_case_sensitive(self):
return utils.get_bool_test_case_param(
self.test_case.parameterValues, "caseSensitiveColumns"
)

View File

@ -0,0 +1,56 @@
"""
Data quality validation utility functions.
"""
from typing import Callable, List, Optional, TypeVar, Union
from metadata.generated.schema.tests.testCase import TestCaseParameterValue
T = TypeVar("T", bound=Callable)
R = TypeVar("R")
def get_test_case_param_value(
test_case_param_vals: List[TestCaseParameterValue],
name: str,
type_: T,
default: Optional[R] = None,
pre_processor: Optional[Callable] = None,
) -> Optional[Union[R, T]]:
"""Return a test case parameter value with the appropriate type casting for the test case definition.
Args:
test_case_param_vals: list of test case parameter values
type_ (Union[float, int, str]): type for the value
name (str): column name
default (_type_, optional): Default value to return if column is not found
pre_processor: pre processor function/type to use against the value before casting to type_
"""
value = next(
(param.value for param in test_case_param_vals if param.name == name), None
)
if not value:
return default if default is not None else None
if not pre_processor:
return type_(value)
pre_processed_value = pre_processor(value)
return type_(pre_processed_value)
def get_bool_test_case_param(
test_case_param_vals: List[TestCaseParameterValue],
name: str,
) -> Optional[Union[R, T]]:
"""Return a test case parameter value as a boolean. Boolean values are always False by default.
Args:
test_case_param_vals: list of test case parameter values
name (str): column name
"""
str_val: str = get_test_case_param_value(test_case_param_vals, name, str, None)
if str_val is None:
return False
return str_val.lower() == "true"

View File

@ -0,0 +1,27 @@
"""
Uility classes for collections
"""
class CaseInsensitiveString(str):
"""
A case-insensitive string. Useful for case-insensitive comparisons like SQL.
"""
def __eq__(self, other):
return self.casefold() == other.casefold()
def __hash__(self):
return hash(self.casefold())
class CaseInsensitiveList(list):
"""A case-insensitive list that treats all its string elements as case-insensitive.
Non-string elements are treated with default behavior."""
def __contains__(self, item):
return (
any(CaseInsensitiveString(x) == item for x in self)
if isinstance(item, str)
else any(x == item for x in self)
)

View File

@ -1,84 +0,0 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper module for test suite functions
"""
from __future__ import annotations
from datetime import datetime
from typing import Callable, List, Optional
from metadata.generated.schema.tests.basic import (
TestCaseResult,
TestCaseStatus,
TestResultValue,
)
from metadata.generated.schema.tests.testCase import TestCaseParameterValue
def get_test_case_param_value(
test_case_param_vals: list[TestCaseParameterValue],
name: str,
type_,
default=None,
pre_processor: Optional[Callable] = None,
):
"""Give a column and a type return the value with the appropriate type casting for the
test case definition.
Args:
test_case: the test case
type_ (Union[float, int, str]): type for the value
name (str): column name
default (_type_, optional): Default value to return if column is not found
pre_processor: pre processor function/type to use against the value before casting to type_
"""
value = next(
(param.value for param in test_case_param_vals if param.name == name), None
)
if not value:
return default
if not pre_processor:
return type_(value)
pre_processed_value = pre_processor(value)
return type_(pre_processed_value)
def build_test_case_result(
execution_datetime: datetime,
status: TestCaseStatus,
result: str,
test_result_value: List[TestResultValue],
sample_data: Optional[str] = None,
) -> TestCaseResult:
"""create a test case result object
Args:
execution_datetime (datetime): execution datetime of the test
status (TestCaseStatus): failed, succeed, aborted
result (str): message to display
testResultValue (List[TestResultValue]): values for the test result
Returns:
TestCaseResult:
"""
return TestCaseResult(
timestamp=execution_datetime,
testCaseStatus=status,
result=result,
testResultValue=test_result_value,
sampleData=sample_data,
)

View File

@ -108,8 +108,20 @@ def ingest_postgres_metadata(
"source": {
"type": postgres_service.connection.config.type.value.lower(),
"serviceName": postgres_service.fullyQualifiedName.root,
"serviceConnection": postgres_service.connection,
"sourceConfig": {"config": {}},
"serviceConnection": postgres_service.connection.model_copy(
update={
"config": postgres_service.connection.config.model_copy(
update={
"ingestAllDatabases": True,
}
)
}
),
"sourceConfig": {
"config": {
"schemaFilterPattern": {"excludes": ["information_schema"]},
}
},
},
"sink": sink_config,
"workflowConfig": workflow_config,

View File

@ -1,4 +1,5 @@
import sys
from datetime import datetime
import pytest
from pydantic import BaseModel
@ -220,6 +221,59 @@ class TestParameters(BaseModel):
),
"MYSQL_SERVICE.default.test.changed_customer",
TestCaseResult(
timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Failed,
),
),
(
TestCaseDefinition(
name="postgres_different_case_columns_fail",
testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[
TestCaseParameterValue(
name="caseSensitiveColumns", value="true"
)
],
),
"POSTGRES_SERVICE.dvdrental.public.customer_different_case_columns",
TestCaseResult(
timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Failed,
testResultValue=[
TestResultValue(name="removedColumns", value="1"),
TestResultValue(name="addedColumns", value="0"),
TestResultValue(name="changedColumns", value="0"),
],
),
),
(
TestCaseDefinition(
name="postgres_different_case_columns_success",
testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[
TestCaseParameterValue(
name="caseSensitiveColumns", value="false"
)
],
),
"POSTGRES_SERVICE.dvdrental.public.customer_different_case_columns",
TestCaseResult(
timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Success,
),
),
(
TestCaseDefinition(
name="table_from_another_db",
testDefinitionName="tableDiff",
computePassedFailedRowCount=True,
parameterValues=[],
),
"POSTGRES_SERVICE.other_db.public.customer",
TestCaseResult(
timestamp=int(datetime.now().timestamp() * 1000),
testCaseStatus=TestCaseStatus.Failed,
),
),
@ -278,7 +332,7 @@ def test_happy_paths(
},
"processor": {
"type": "orm-test-runner",
"config": {"testCases": [parameters.test_case_defintion.dict()]},
"config": {"testCases": [parameters.test_case_defintion.model_dump()]},
},
"sink": sink_config,
"workflowConfig": workflow_config,
@ -410,6 +464,16 @@ def test_error_paths(
def add_changed_tables(connection: Connection):
connection.execute("CREATE TABLE customer_200 AS SELECT * FROM customer LIMIT 200;")
connection.execute(
"CREATE TABLE customer_different_case_columns AS SELECT * FROM customer;"
)
connection.execute(
'ALTER TABLE customer_different_case_columns RENAME COLUMN first_name TO "First_Name";'
)
# TODO: this appears to be unsupported by data diff. Cross data type comparison is flaky.
# connection.execute(
# "ALTER TABLE customer_different_case_columns ALTER COLUMN store_id TYPE decimal"
# )
connection.execute("CREATE TABLE changed_customer AS SELECT * FROM customer;")
connection.execute(
"UPDATE changed_customer SET first_name = 'John' WHERE MOD(customer_id, 2) = 0;"

View File

@ -73,7 +73,7 @@ def run_data_quality_workflow(
"columnName": "first_name",
"parameterValues": [
{"name": "allowedValues", "value": "['Tom', 'Jerry']"},
{"name": "matchEnum", "value": ""},
{"name": "matchEnum", "value": "false"},
],
},
{
@ -82,7 +82,7 @@ def run_data_quality_workflow(
"columnName": "first_name",
"parameterValues": [
{"name": "allowedValues", "value": "['Tom', 'Jerry']"},
{"name": "matchEnum", "value": "True"},
{"name": "matchEnum", "value": "true"},
],
},
{

View File

@ -0,0 +1,34 @@
from ast import literal_eval
import pytest
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.generated.schema.tests.testCase import TestCaseParameterValue
@pytest.mark.parametrize(
"param_values, name, type_, default, expected",
[
([TestCaseParameterValue(name="str", value="test")], "str", str, None, "test"),
(
[TestCaseParameterValue(name="param", value="[1, 2, 3]")],
"param",
literal_eval,
None,
[1, 2, 3],
),
([TestCaseParameterValue(name="param", value="123")], "param", int, None, 123),
(
[TestCaseParameterValue(name="param", value=None)],
"param",
str,
"default",
"default",
),
],
)
def test_get_test_case_param_value(param_values, name, type_, default, expected):
result = BaseTestValidator.get_test_case_param_value(
param_values, name, type_, default
)
assert result == expected

View File

@ -0,0 +1,22 @@
import pytest
from metadata.data_quality.validations.utils import get_bool_test_case_param
from metadata.generated.schema.tests.testCase import TestCaseParameterValue
@pytest.mark.parametrize(
"test_case_param_vals, name, expected",
[
([TestCaseParameterValue(name="param1", value="true")], "param1", True),
([TestCaseParameterValue(name="param1", value="false")], "param1", False),
([TestCaseParameterValue(name="param1", value="True")], "param1", True),
([TestCaseParameterValue(name="param1", value="False")], "param1", False),
([TestCaseParameterValue(name="param1", value="TRUE")], "param1", True),
([TestCaseParameterValue(name="param1", value="FALSE")], "param1", False),
([TestCaseParameterValue(name="param1", value="invalid")], "param1", False),
([], "param1", False),
([TestCaseParameterValue(name="param2", value="true")], "param1", False),
],
)
def test_get_bool_test_case_param(test_case_param_vals, name, expected):
assert get_bool_test_case_param(test_case_param_vals, name) == expected

View File

@ -485,6 +485,7 @@ Consistency
* `table2`: The table against which the comparison will be done. Must be the fully qualified name as defined in OpenMetadata
* `threshold`: The threshold of different rows above which the test should fail -- default to 0
* `where`: Any `where` clause to pass
* `caseSensitiveColumns`: Whether the column comparison should be case sensitive or not. Default to `false`.
**Behavior**
@ -511,6 +512,8 @@ parameterValues:
value: 10
- name: where
value: id != 999
- name: caseSensitiveColumns
value: false
```
**JSON Config**
@ -541,6 +544,10 @@ parameterValues:
{
"name": "where",
"value": "id != 999"
},
{
"name": "caseSensitiveColumns",
"value": false
}
]
}

View File

@ -0,0 +1,20 @@
package org.openmetadata.service.migration.mysql.v157;
import static org.openmetadata.service.migration.utils.v157.MigrationUtil.migrateTableDiffParams;
import lombok.SneakyThrows;
import org.openmetadata.service.migration.api.MigrationProcessImpl;
import org.openmetadata.service.migration.utils.MigrationFile;
public class Migration extends MigrationProcessImpl {
public Migration(MigrationFile migrationFile) {
super(migrationFile);
}
@Override
@SneakyThrows
public void runDataMigration() {
migrateTableDiffParams(handle, collectionDAO, authenticationConfiguration, false);
}
}

View File

@ -0,0 +1,20 @@
package org.openmetadata.service.migration.postgres.v157;
import static org.openmetadata.service.migration.utils.v157.MigrationUtil.migrateTableDiffParams;
import lombok.SneakyThrows;
import org.openmetadata.service.migration.api.MigrationProcessImpl;
import org.openmetadata.service.migration.utils.MigrationFile;
public class Migration extends MigrationProcessImpl {
public Migration(MigrationFile migrationFile) {
super(migrationFile);
}
@Override
@SneakyThrows
public void runDataMigration() {
migrateTableDiffParams(handle, collectionDAO, authenticationConfiguration, false);
}
}

View File

@ -0,0 +1,63 @@
package org.openmetadata.service.migration.utils.v157;
import static org.openmetadata.service.Entity.TEST_CASE;
import static org.openmetadata.service.Entity.TEST_DEFINITION;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.jdbi.v3.core.Handle;
import org.openmetadata.schema.api.security.AuthenticationConfiguration;
import org.openmetadata.schema.tests.TestCase;
import org.openmetadata.schema.tests.TestCaseParameterValue;
import org.openmetadata.schema.tests.TestDefinition;
import org.openmetadata.schema.type.Relationship;
import org.openmetadata.service.jdbi3.CollectionDAO;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
public class MigrationUtil {
private static final String TABLE_DIFF = "tableDiff";
public static void migrateTableDiffParams(
Handle handle,
CollectionDAO daoCollection,
AuthenticationConfiguration config,
boolean postgres) {
int pageSize = 1000;
int offset = 0;
while (true) {
List<String> jsons = daoCollection.testCaseDAO().listAfterWithOffset(pageSize, offset);
if (jsons.isEmpty()) {
break;
}
offset += pageSize;
for (String json : jsons) {
TestCase testCase = JsonUtils.readValue(json, TestCase.class);
TestDefinition td = getTestDefinition(daoCollection, testCase);
if (Objects.nonNull(td) && Objects.equals(td.getName(), TABLE_DIFF)) {
LOG.debug("Adding caseSensitiveColumns=true table diff test case: {}", testCase.getId());
testCase
.getParameterValues()
.add(new TestCaseParameterValue().withName("caseSensitiveColumns").withValue("true"));
daoCollection.testCaseDAO().update(testCase);
}
}
}
}
public static TestDefinition getTestDefinition(CollectionDAO dao, TestCase testCase) {
List<CollectionDAO.EntityRelationshipRecord> records =
dao.relationshipDAO()
.findFrom(
testCase.getId(), TEST_CASE, Relationship.CONTAINS.ordinal(), TEST_DEFINITION);
if (records.size() > 1) {
throw new RuntimeException(
"Multiple test definitions found for test case: " + testCase.getId());
}
if (records.isEmpty()) {
return null;
}
return dao.testDefinitionDAO().findEntityById(records.get(0).getId());
}
}

View File

@ -42,6 +42,13 @@
"description": "Use this where clause to filter the rows to compare.",
"dataType": "STRING",
"required": false
},
{
"name": "caseSensitiveColumns",
"displayName": "Case sensitive columns",
"description": "Use case sensitivity when comparing the columns.",
"dataType": "BOOLEAN",
"required": false
}
],
"provider": "system",