Merge branch 'migrate/es-os-new-client-api' into migrate/es-os-entities-api

This commit is contained in:
Bhanu Agrawal 2025-10-08 14:03:35 +05:30 committed by GitHub
commit 0d20e80c20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 760 additions and 65 deletions

View File

@ -0,0 +1,61 @@
-- Correct the tableDiff test definition (MySQL).
-- Replaces the whole `parameterDefinition` array so that it includes the new
-- `table2.keyColumns` parameter, and sets the definition `version` to 0.2.
-- Single quotes inside literals are escaped by doubling ('') instead of with a
-- backslash, so the statement also parses when sql_mode contains
-- NO_BACKSLASH_ESCAPES. The stored values are unchanged.
UPDATE test_definition
SET json = JSON_SET(
    json,
    '$.parameterDefinition',
    JSON_ARRAY(
        JSON_OBJECT(
            'name', 'keyColumns',
            'displayName', 'Table 1''s key columns',
            'description', 'The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.',
            'dataType', 'ARRAY',
            'required', false
        ),
        JSON_OBJECT(
            'name', 'table2',
            'displayName', 'Table 2',
            'description', 'Fully qualified name of the table to compare against.',
            'dataType', 'STRING',
            'required', true
        ),
        JSON_OBJECT(
            'name', 'table2.keyColumns',
            'displayName', 'Table 2''s key columns',
            'description', 'The columns in table 2 to use as comparison. If not provided, it will default to `Key Columns`, risking errors if the key columns'' names have changed.',
            'dataType', 'ARRAY',
            'required', false
        ),
        JSON_OBJECT(
            'name', 'threshold',
            'displayName', 'Threshold',
            'description', 'Threshold to use to determine if the test passes or fails (defaults to 0).',
            'dataType', 'NUMBER',
            'required', false
        ),
        JSON_OBJECT(
            'name', 'useColumns',
            'displayName', 'Use Columns',
            'description', 'Limits the scope of the test to this list of columns. If not provided, all columns will be used except the key columns.',
            'dataType', 'ARRAY',
            'required', false
        ),
        JSON_OBJECT(
            'name', 'where',
            'displayName', 'SQL Where Clause',
            'description', 'Use this where clause to filter the rows to compare.',
            'dataType', 'STRING',
            'required', false
        ),
        JSON_OBJECT(
            'name', 'caseSensitiveColumns',
            'displayName', 'Case sensitive columns',
            'description', 'Use case sensitivity when comparing the columns.',
            'dataType', 'BOOLEAN',
            'required', false
        )
    ),
    '$.version',
    0.2
)
WHERE name = 'tableDiff';

View File

@ -0,0 +1,58 @@
-- Correct the tableDiff test definition (PostgreSQL).
-- Overwrites the top-level `parameterDefinition` and `version` keys of the stored
-- document so that the definition includes the new `table2.keyColumns` parameter.
-- The patch is built directly as jsonb (no json -> jsonb round trip); `||` replaces
-- the two top-level keys and leaves every other key of the document untouched.
-- The display names match the MySQL migration and the seed definition exactly
-- ("Table 1's key columns", lower-case "columns").
UPDATE test_definition
SET json = json::jsonb || jsonb_build_object(
    'parameterDefinition', jsonb_build_array(
        jsonb_build_object(
            'name', 'keyColumns',
            'displayName', 'Table 1''s key columns',
            'description', 'The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.',
            'dataType', 'ARRAY',
            'required', false
        ),
        jsonb_build_object(
            'name', 'table2',
            'displayName', 'Table 2',
            'description', 'Fully qualified name of the table to compare against.',
            'dataType', 'STRING',
            'required', true
        ),
        jsonb_build_object(
            'name', 'table2.keyColumns',
            'displayName', 'Table 2''s key columns',
            'description', 'The columns in table 2 to use as comparison. If not provided, it will default to `Key Columns`, risking errors if the key columns'' names have changed.',
            'dataType', 'ARRAY',
            'required', false
        ),
        jsonb_build_object(
            'name', 'threshold',
            'displayName', 'Threshold',
            'description', 'Threshold to use to determine if the test passes or fails (defaults to 0).',
            'dataType', 'NUMBER',
            'required', false
        ),
        jsonb_build_object(
            'name', 'useColumns',
            'displayName', 'Use Columns',
            'description', 'Limits the scope of the test to this list of columns. If not provided, all columns will be used except the key columns.',
            'dataType', 'ARRAY',
            'required', false
        ),
        jsonb_build_object(
            'name', 'where',
            'displayName', 'SQL Where Clause',
            'description', 'Use this where clause to filter the rows to compare.',
            'dataType', 'STRING',
            'required', false
        ),
        jsonb_build_object(
            'name', 'caseSensitiveColumns',
            'displayName', 'Case sensitive columns',
            'description', 'Use case sensitivity when comparing the columns.',
            'dataType', 'BOOLEAN',
            'required', false
        )
    ),
    'version', 0.2
)
WHERE name = 'tableDiff';

View File

@ -35,7 +35,7 @@ VERSIONS = {
"neo4j": "neo4j~=5.3",
"pandas": "pandas~=2.0.3",
"pyarrow": "pyarrow~=16.0",
"pydantic": "pydantic~=2.0,>=2.7.0",
"pydantic": "pydantic~=2.0,>=2.7.0,<2.12", # Pin down to <2.12 due to breaking changes in 2.12.0
"pydantic-settings": "pydantic-settings~=2.0,>=2.7.0",
"pydomo": "pydomo~=0.3",
"pymysql": "pymysql~=1.0",

View File

@ -2,7 +2,7 @@
from typing import List, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, Field
from metadata.generated.schema.entity.data.table import (
Column,
@ -23,13 +23,19 @@ class TableParameter(BaseModel):
database_service_type: DatabaseServiceType
privateKey: Optional[CustomSecretStr]
passPhrase: Optional[CustomSecretStr]
key_columns: Optional[list[str]] = None
extra_columns: Optional[list[str]] = None
class TableDiffRuntimeParameters(BaseModel):
table1: TableParameter
table2: TableParameter
keyColumns: List[str]
extraColumns: List[str]
keyColumns: Optional[List[str]] = Field(
..., deprecated="Please use `tableX.key_columns` instead"
)
extraColumns: Optional[List[str]] = Field(
..., deprecated="Please use `tableX.extra_columns` instead"
)
whereClause: Optional[str]
table_profile_config: Optional[TableProfilerConfig]

View File

@ -92,6 +92,8 @@ class BaseTableParameter:
),
privateKey=None,
passPhrase=None,
key_columns=key_columns,
extra_columns=extra_columns,
)
@staticmethod
@ -154,6 +156,13 @@ class BaseTableParameter:
else None
)
@classmethod
def get_service_connection_config(
cls,
service: DatabaseService,
) -> Optional[Union[str, dict]]:
return cls._get_service_connection_config(service.connection.config)
def get_data_diff_url(
self,
db_service: DatabaseService,

View File

@ -10,10 +10,23 @@
# limitations under the License.
"""Module that defines the TableDiffParamsSetter class."""
from ast import literal_eval
from typing import List, Optional, Set
from typing import (
Any,
Callable,
List,
Optional,
Protocol,
Set,
Union,
runtime_checkable,
)
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters
from metadata.data_quality.validations.models import (
Column,
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
ServiceSpecPatch,
)
@ -28,9 +41,32 @@ from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList
@runtime_checkable
class TableParameterSetter(Protocol):
    """Structural interface of the per-service object used to build table parameters.

    ``TableDiffParamsSetter.get_param_setter`` is annotated as returning an object
    of this shape. Being ``runtime_checkable``, ``isinstance`` checks only verify
    that the two methods exist, not their signatures.
    """

    def get(
        self,
        service: DatabaseService,
        entity: Table,
        key_columns,
        extra_columns,
        case_sensitive_columns,
        service_url: Optional[Union[str, dict]],
    ) -> TableParameter:
        """Build the ``TableParameter`` for ``entity`` hosted on ``service``.

        ``key_columns`` and ``extra_columns`` are the column names resolved for this
        specific table; ``service_url`` is the connection url or config to use.
        """
        ...

    def get_service_connection_config(self, service: DatabaseService):
        """Return the connection url/config for ``service``.

        The result is what ``get_service_url`` hands back to its caller.
        """
        ...
def get_service_url(
    param_setter: TableParameterSetter, service: DatabaseService
) -> Optional[Union[str, dict[str, Any]]]:
    """Default service-url getter: ask the parameter setter for the service's connection config.

    Used as the default value of ``service_url_getter`` in ``TableDiffParamsSetter``,
    which lets tests swap in a getter that does not need a real connection.
    """
    connection_config = param_setter.get_service_connection_config(service)
    return connection_config
class TableDiffParamsSetter(RuntimeParameterSetter):
"""
Set runtime parameters for a the table diff test.
Set runtime parameters for a table diff test.
Sets the following variables:
- service1Url: The url of the first service (data diff compliant)
- service2Url: The url of the second service (data diff compliant)
@ -41,11 +77,17 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
- whereClause: Extract where clause based on partitioning and user input
"""
def __init__(self, *args, **kwargs):
def __init__(
self,
*args,
service_url_getter: Callable[
[TableParameterSetter, DatabaseService],
Optional[Union[str, dict[str, Any]]],
] = get_service_url,
**kwargs,
):
super().__init__(*args, **kwargs)
self._table_columns = {
column.name.root: column for column in self.table_entity.columns
}
self.get_service_url = service_url_getter
def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
service1: DatabaseService = self.ometa_client.get_by_id(
@ -63,34 +105,37 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
DatabaseService, table2.service.id, nullable=False
)
service_spec_patch_table_1 = ServiceSpecPatch(
ServiceType.Database, service1.connection.config.type.value.lower()
)
data_diff_class_1 = service_spec_patch_table_1.get_data_diff_class()()
service_spec_patch_table_2 = ServiceSpecPatch(
ServiceType.Database, service2.connection.config.type.value.lower()
)
data_diff_class_2 = service_spec_patch_table_2.get_data_diff_class()()
table1_param_setter = self.get_param_setter(service1)
table2_param_setter = self.get_param_setter(service2)
service1_url = data_diff_class_1._get_service_connection_config(
service1.connection.config
)
service2_url = (
self.get_parameter(test_case, "service2Url") or service1_url
if table2.service == self.table_entity.service
else data_diff_class_2._get_service_connection_config(
service2.connection.config
)
or None
)
service1_url = self.get_service_url(table1_param_setter, service1)
if table2.service == self.table_entity.service:
service2_url = self.get_parameter(test_case, "service2Url") or service1_url
else:
service2_url = self.get_service_url(table2_param_setter, service2)
key_columns = self.get_key_columns(test_case)
table2_key_columns = (
self.get_table_key_columns(test_case, table2) or key_columns
)
extra_columns = (
self.get_extra_columns(
key_columns, test_case, self.table_entity.columns, table2.columns
key_columns | table2_key_columns,
test_case,
self.table_entity.columns,
table2.columns,
)
or set()
)
table1_extra_columns = self.get_table_extra_columns(
test_case, self.table_entity
)
table2_extra_columns = (
self.get_table_extra_columns(test_case, table2) or extra_columns
)
case_sensitive_columns: bool = (
utils.get_bool_test_case_param(
test_case.parameterValues, "caseSensitiveColumns"
@ -100,19 +145,19 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
return TableDiffRuntimeParameters(
table_profile_config=self.table_entity.tableProfilerConfig,
table1=data_diff_class_1.get(
table1=table1_param_setter.get(
service1,
self.table_entity,
key_columns,
extra_columns,
table1_extra_columns or extra_columns,
case_sensitive_columns,
service1_url,
),
table2=data_diff_class_2.get(
table2=table2_param_setter.get(
service2,
table2,
key_columns,
extra_columns,
table2_key_columns,
table2_extra_columns,
case_sensitive_columns,
service2_url,
),
@ -178,6 +223,45 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
)
return set(key_columns)
def get_table_key_columns(
    self, test_case: TestCase, table: Table
) -> Optional[set[str]]:
    """Read the table-specific key columns from the test case parameters.

    The parameter looked up is ``table1.keyColumns`` when ``table`` is the entity
    under test and ``table2.keyColumns`` otherwise. Its value is parsed with
    ``literal_eval`` and defaults to an empty list.

    Returns:
        The set of configured key columns, or ``None`` when none are configured.

    Raises:
        ValueError: if a configured column does not exist in ``table``.
    """
    prefix = "table2"
    if table is self.table_entity:
        prefix = "table1"

    raw_value = self.get_parameter(test_case, f"{prefix}.keyColumns", "[]")
    configured: List[str] = literal_eval(raw_value)
    if configured:
        self.validate_columns(configured, table)
        return set(configured)
    return None
def get_table_extra_columns(
    self, test_case: TestCase, table: Table
) -> Optional[List[str]]:
    """Read the table-specific extra columns from the test case parameters.

    The parameter looked up is ``table1.extraColumns`` when ``table`` is the entity
    under test and ``table2.extraColumns`` otherwise. Its value is parsed with
    ``literal_eval`` and defaults to an empty list.

    Returns:
        The configured extra columns in their original order, or ``None`` when
        none are configured.

    Raises:
        ValueError: if a configured column does not exist in ``table``.
    """
    prefix = "table2"
    if table is self.table_entity:
        prefix = "table1"

    raw_value = self.get_parameter(test_case, f"{prefix}.extraColumns", "[]")
    configured: List[str] = literal_eval(raw_value)
    if configured:
        self.validate_columns(configured, table)
        return configured
    return None
def validate_columns(
    self, column_names: List[str], table: Optional[Table] = None
) -> None:
    """Check that every name in ``column_names`` is a column of ``table``.

    Args:
        column_names: column names to check, in order.
        table: table to check against; defaults to the entity under test.

    Raises:
        ValueError: for the first name that is not a column of the table.
    """
    target = self.table_entity if table is None else table
    known_names: Set[str] = {col.name.root for col in target.columns}
    for candidate in column_names:
        if candidate in known_names:
            continue
        # NOTE: the message says "key columns" although this check is also used
        # for extra columns.
        raise ValueError(
            f"Failed to resolve key columns for table diff.\n"
            f"Column '{candidate}' not found in table '{target.name.root}'.\n"
        )
@staticmethod
def filter_relevant_columns(
columns: List[Column],
@ -207,10 +291,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
"___SERVICE___", "__DATABASE__", schema, table
).replace("___SERVICE___.__DATABASE__.", "")
def validate_columns(self, column_names: List[str]):
for column in column_names:
if not self._table_columns.get(column):
raise ValueError(
f"Failed to resolve key columns for table diff.\n"
f"Column '{column}' not found in table '{self.table_entity.name.root}'.\n"
)
@staticmethod
def get_param_setter(service: DatabaseService) -> TableParameterSetter:
    """Instantiate the data-diff parameter class registered for ``service``'s type.

    The lower-cased connection type value of the service selects the
    ``ServiceSpecPatch``; the class it reports is instantiated without arguments.
    """
    service_type_name = service.connection.config.type.value.lower()
    data_diff_class = ServiceSpecPatch(
        ServiceType.Database, service_type_name
    ).get_data_diff_class()
    return data_diff_class()

View File

@ -266,7 +266,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
table1 = data_diff.connect_to_table(
self.runtime_params.table1.serviceUrl,
self.runtime_params.table1.path,
self.runtime_params.keyColumns,
self.runtime_params.table1.key_columns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
key_content=self.runtime_params.table1.privateKey.get_secret_value()
@ -279,7 +279,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns,
self.runtime_params.table2.key_columns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
key_content=self.runtime_params.table2.privateKey.get_secret_value()
@ -345,7 +345,8 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
table1 = data_diff.connect_to_table(
self.runtime_params.table1.serviceUrl,
self.runtime_params.table1.path,
self.runtime_params.keyColumns, # type: ignore
self.runtime_params.table1.key_columns, # type: ignore
extra_columns=self.runtime_params.table1.extra_columns,
case_sensitive=self.get_case_sensitive(),
where=left_where,
key_content=self.runtime_params.table1.privateKey.get_secret_value()
@ -358,7 +359,8 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns, # type: ignore
self.runtime_params.table2.key_columns, # type: ignore
extra_columns=self.runtime_params.table2.extra_columns,
case_sensitive=self.get_case_sensitive(),
where=right_where,
key_content=self.runtime_params.table1.privateKey.get_secret_value()
@ -369,8 +371,6 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
else None,
)
data_diff_kwargs = {
"key_columns": self.runtime_params.keyColumns,
"extra_columns": self.runtime_params.extraColumns,
"where": self.get_where(),
}
logger.debug(
@ -434,14 +434,27 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
salt = "".join(
random.choices(string.ascii_letters + string.digits, k=5)
) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax
key_columns = (
CaseInsensitiveList(self.runtime_params.keyColumns)
if not self.get_case_sensitive()
else self.runtime_params.keyColumns
return (
build_sample_where_clause(
self.runtime_params.table1,
self.maybe_case_sensitive(self.runtime_params.table1.key_columns),
salt,
hex_nounce,
),
build_sample_where_clause(
self.runtime_params.table2,
self.maybe_case_sensitive(self.runtime_params.table2.key_columns),
salt,
hex_nounce,
),
)
return tuple(
build_sample_where_clause(table, key_columns, salt, hex_nounce)
for table in [self.runtime_params.table1, self.runtime_params.table2]
def maybe_case_sensitive(self, iterable: Iterable[str]) -> list[str]:
    """Materialise ``iterable`` according to the test's case-sensitivity setting.

    Returns a plain ``list`` when ``get_case_sensitive()`` is truthy, otherwise a
    ``CaseInsensitiveList`` built from the same items.
    """
    if self.get_case_sensitive():
        return list(iterable)
    return CaseInsensitiveList(iterable)
def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
@ -522,8 +535,16 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
def get_column_diff(self) -> Optional[TestCaseResult]:
"""Get the column diff between the two tables. If there are no differences, return None."""
removed, added = self.get_changed_added_columns(
self.runtime_params.table1.columns,
self.runtime_params.table2.columns,
[
c
for c in self.runtime_params.table1.columns
if c.name.root not in self.runtime_params.table1.key_columns
],
[
c
for c in self.runtime_params.table2.columns
if c.name.root not in self.runtime_params.table2.key_columns
],
self.get_case_sensitive(),
)
changed = self.get_incomparable_columns()
@ -585,9 +606,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
f"Tables have {sum(map(len, [removed, added, changed]))} different columns:"
)
if removed:
message += f"\n Removed columns: {','.join(removed)}\n"
message += f"\n Removed columns: {', '.join(removed)}\n"
if added:
message += f"\n Added columns: {','.join(added)}\n"
message += f"\n Added columns: {', '.join(added)}\n"
if changed:
message += "\n Changed columns:"
table1_columns = {

View File

@ -0,0 +1,226 @@
import json
import uuid
from typing import List
from unittest.mock import create_autospec
import pytest
from dirty_equals import HasAttributes, IsInstance, IsListOrTuple
from metadata.data_quality.validations.models import (
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter,
TableParameterSetter,
)
from metadata.generated.schema.entity.data.table import (
Column,
ColumnName,
DataType,
Table,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseService,
DatabaseServiceType,
)
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
from metadata.generated.schema.type.basic import FullyQualifiedEntityName
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.sampler.sampler_interface import SamplerInterface
@pytest.fixture
def metadata(
    service1: DatabaseService, table1: Table, service2: DatabaseService, table2: Table
) -> OpenMetadata:
    """OpenMetadata client double serving the fixture services and ``table2`` from memory.

    ``get_by_id`` resolves the two database services by the service ids of the
    fixture tables; ``get_by_name`` resolves ``table2`` by fully qualified name.
    Anything else resolves to ``None``.
    """
    services_by_key = {
        (DatabaseService, table1.service.id): service1,
        (DatabaseService, table2.service.id): service2,
    }
    tables_by_key = {(Table, table2.fullyQualifiedName.root): table2}

    def mock_get_by_id(entity, entity_id, **kwargs):
        return services_by_key.get((entity, entity_id), None)

    def mock_get_by_name(entity, fqn, **kwargs):
        return tables_by_key.get((entity, fqn), None)

    client = create_autospec(OpenMetadata, spec_set=True, instance=True)
    client.get_by_id.side_effect = mock_get_by_id
    client.get_by_name.side_effect = mock_get_by_name
    return client
@pytest.fixture
def service_connection_config() -> DatabaseConnection:
    """Strict autospec stand-in for a ``DatabaseConnection`` instance."""
    connection_double = create_autospec(
        DatabaseConnection, spec_set=True, instance=True
    )
    return connection_double
@pytest.fixture
def sampler() -> SamplerInterface:
    """Autospec sampler double whose ``partition_details`` is ``None``."""
    sampler_double = create_autospec(SamplerInterface, instance=True)
    sampler_double.partition_details = None
    return sampler_double
@pytest.fixture
def service1() -> DatabaseService:
    """Unvalidated Postgres ``DatabaseService`` named ``TestService1``."""
    fields = {
        "id": uuid.uuid4(),
        "name": "TestService1",
        "fullyQualifiedName": "TestService1",
        "serviceType": DatabaseServiceType.Postgres,
        "connection": DatabaseConnection(
            config=PostgresConnection.model_construct()
        ),
    }
    return DatabaseService.model_construct(**fields)
@pytest.fixture
def service2() -> DatabaseService:
    """Unvalidated Postgres ``DatabaseService`` named ``TestService2``."""
    fields = {
        "id": uuid.uuid4(),
        "name": "TestService2",
        "fullyQualifiedName": "TestService2",
        "serviceType": DatabaseServiceType.Postgres,
        "connection": DatabaseConnection(
            config=PostgresConnection.model_construct()
        ),
    }
    return DatabaseService.model_construct(**fields)
@pytest.fixture
def table1() -> Table:
    """Table under test: string columns ``id`` and ``name`` on ``test_service1``."""
    string_columns = [
        Column.model_construct(
            name=ColumnName(root=column_name),
            dataType=DataType.STRING,
        )
        for column_name in ("id", "name")
    ]
    service_ref = EntityReference.model_construct(
        id=uuid.uuid4(), name="test_service1"
    )
    return Table.model_construct(
        id=uuid.uuid4(),
        name="table1",
        fullyQualifiedName=FullyQualifiedEntityName(
            root="TestService1.test_db.test_schema.table1"
        ),
        service=service_ref,
        columns=string_columns,
    )
@pytest.fixture
def table2() -> Table:
    """Comparison table: string columns ``table_id`` and ``name`` on ``test_service2``.

    Its key column is named differently from ``table1``'s (``table_id`` vs ``id``).
    """
    string_columns = [
        Column.model_construct(
            name=ColumnName(root=column_name),
            dataType=DataType.STRING,
        )
        for column_name in ("table_id", "name")
    ]
    service_ref = EntityReference.model_construct(
        id=uuid.uuid4(), name="test_service2"
    )
    return Table.model_construct(
        id=uuid.uuid4(),
        name="table2",
        fullyQualifiedName=FullyQualifiedEntityName(
            root="TestService2.test_db.test_schema.table2"
        ),
        service=service_ref,
        columns=string_columns,
    )
def fake_get_service_url(
    param_setter: TableParameterSetter, service: DatabaseService
) -> str:
    """Service-url getter for tests: ignores both arguments and returns a fixed url."""
    fixed_url = "postgresql+psycopg2://test:test@localhost/test"
    return fixed_url
@pytest.fixture
def setter(
    metadata: OpenMetadata,
    service_connection_config: DatabaseConnection,
    sampler: SamplerInterface,
    table1: Table,
) -> TableDiffParamsSetter:
    """``TableDiffParamsSetter`` for ``table1``, wired to the doubles and the fake url getter."""
    init_kwargs = {
        "ometa_client": metadata,
        "service_connection_config": service_connection_config,
        "sampler": sampler,
        "table_entity": table1,
        "service_url_getter": fake_get_service_url,
    }
    return TableDiffParamsSetter(**init_kwargs)
@pytest.fixture
def parameter_values() -> List[TestCaseParameterValue]:
    """Base test-case parameters: only ``table2``, pointing at the ``table2`` fixture's FQN."""
    table2_reference = TestCaseParameterValue(
        name="table2", value="TestService2.test_db.test_schema.table2"
    )
    return [table2_reference]
def test_setter_gets_default_key_columns(
    setter: TableDiffParamsSetter, parameter_values: List[TestCaseParameterValue]
) -> None:
    """Without ``table2.keyColumns``, both tables use the shared ``keyColumns`` value."""
    shared_key = TestCaseParameterValue(name="keyColumns", value=json.dumps(["id"]))
    test_case = TestCase.model_construct(
        parameterValues=[*parameter_values, shared_key],
    )

    keyed_on_id = IsInstance(TableParameter) & HasAttributes(key_columns=["id"])
    expected = IsInstance(TableDiffRuntimeParameters) & HasAttributes(
        keyColumns=["id"],
        extraColumns=IsListOrTuple("name", "table_id", check_order=False),
        table1=keyed_on_id,
        table2=keyed_on_id,
    )

    assert setter.get_parameters(test_case) == expected
def test_setter_gets_per_table_key_columns(
    setter: TableDiffParamsSetter, parameter_values: List[TestCaseParameterValue]
) -> None:
    """``table2.keyColumns`` overrides the key columns for table 2 only."""
    shared_key = TestCaseParameterValue(name="keyColumns", value=json.dumps(["id"]))
    table2_key = TestCaseParameterValue(
        name="table2.keyColumns", value=json.dumps(["table_id"])
    )
    test_case = TestCase.model_construct(
        parameterValues=[*parameter_values, shared_key, table2_key]
    )

    expected = IsInstance(TableDiffRuntimeParameters) & HasAttributes(
        keyColumns=["id"],
        extraColumns=IsListOrTuple("name", check_order=False),
        table1=IsInstance(TableParameter) & HasAttributes(key_columns=["id"]),
        table2=IsInstance(TableParameter) & HasAttributes(key_columns=["table_id"]),
    )

    assert setter.get_parameters(test_case) == expected

View File

@ -0,0 +1,212 @@
import datetime
from typing import Generator
from unittest.mock import MagicMock, Mock, patch
import pytest
from dirty_equals import Contains, DirtyEquals, HasAttributes
from metadata.data_quality.validations.models import (
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.data_quality.validations.table.sqlalchemy.tableDiff import (
TableDiffValidator,
)
from metadata.generated.schema.entity.data.table import Column, ColumnName, DataType
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.generated.schema.tests.basic import TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase
from metadata.generated.schema.type.basic import Timestamp
def build_table_parameter(
    *columns: Column,
    key_columns: list[str],
    extra_columns: list[str],
    service_url: str = "postgresql://postgres:postgres@service:5432/postgres",
) -> TableParameter:
    """Build an unvalidated Postgres ``TableParameter`` for ``test_schema.test_table``.

    Args:
        *columns: the table's columns, stored as given (a tuple).
        key_columns: names of the key columns for this table.
        extra_columns: names of the extra (compared) columns for this table.
        service_url: connection url stored as ``serviceUrl``.
    """
    fields = {
        "serviceUrl": service_url,
        "path": "test_schema.test_table",
        "database_service_type": DatabaseServiceType.Postgres,
        "columns": columns,
        "privateKey": None,
        "passPhrase": None,
        "key_columns": key_columns,
        "extra_columns": extra_columns,
    }
    return TableParameter.model_construct(**fields)
@pytest.fixture
def table1_parameter() -> TableParameter:
    """Table 1 parameter: columns ``id``, ``first_name``, ``last_name``; keyed on ``id``."""
    string_columns = [
        Column.model_construct(name=ColumnName(root=column_name), dataType=DataType.STRING)
        for column_name in ("id", "first_name", "last_name")
    ]
    return build_table_parameter(
        *string_columns,
        key_columns=["id"],
        extra_columns=["first_name", "last_name"],
        service_url="postgresql://postgres:postgres@service1:5432/postgres",
    )
@pytest.fixture
def table2_parameter() -> TableParameter:
    """Table 2 parameter: columns ``table_id``, ``first_name``, ``last_name``; keyed on ``table_id``."""
    string_columns = [
        Column.model_construct(name=ColumnName(root=column_name), dataType=DataType.STRING)
        for column_name in ("table_id", "first_name", "last_name")
    ]
    return build_table_parameter(
        *string_columns,
        key_columns=["table_id"],
        extra_columns=["first_name", "last_name"],
        service_url="postgresql://postgres:postgres@service2:5432/postgres",
    )
@pytest.fixture
def parameters(
    table1_parameter: TableParameter, table2_parameter: TableParameter
) -> TableDiffRuntimeParameters:
    """Runtime parameters built from the two table fixtures.

    The top-level ``keyColumns`` / ``extraColumns`` are left as ``None``; only the
    per-table values carry column information.
    """
    init_kwargs = {
        "table1": table1_parameter,
        "table2": table2_parameter,
        "table_profile_config": None,
        "whereClause": None,
        "keyColumns": None,
        "extraColumns": None,
    }
    return TableDiffRuntimeParameters(**init_kwargs)
@pytest.fixture
def validator(
    parameters: TableDiffRuntimeParameters,
) -> Generator[TableDiffValidator, None, None]:
    """Yield a ``TableDiffValidator`` whose ``data_diff`` module is patched out.

    ``connect_to_table`` returns a mock table with empty key and extra columns, and
    the validator's ``runtime_params`` are set to the ``parameters`` fixture. The
    patch stays active for the duration of the test.
    """
    patch_target = (
        "metadata.data_quality.validations.table.sqlalchemy.tableDiff.data_diff"
    )
    with patch(patch_target) as data_diff:
        connected_table = MagicMock()
        connected_table.key_columns = []
        connected_table.extra_columns = []
        data_diff.connect_to_table = Mock(return_value=connected_table)

        now_epoch = int(datetime.datetime.now().timestamp())
        instance = TableDiffValidator(
            runner=[],
            test_case=TestCase.model_construct(parameterValues=[]),
            execution_date=Timestamp(root=now_epoch),
        )
        instance.runtime_params = parameters
        yield instance
class TestGetColumnDiff:
    """Tests for ``TableDiffValidator.get_column_diff`` with per-table key columns."""

    def test_it_returns_none_when_no_diff(
        self, validator: TableDiffValidator, parameters: TableDiffRuntimeParameters
    ) -> None:
        """The default fixtures differ only in their key column name, so no diff is reported."""
        assert validator.get_column_diff() is None

    # Each case overrides the `table1_parameter` / `table2_parameter` fixtures and
    # gives the matcher the returned result must satisfy.
    @pytest.mark.parametrize(
        "table1_parameter, table2_parameter, expected",
        (
            # Case 1: same key column name in both tables, different extra column.
            (
                build_table_parameter(
                    Column.model_construct(
                        name=ColumnName(root="id"), dataType=DataType.STRING
                    ),
                    Column.model_construct(
                        name=ColumnName(root="last_name"), dataType=DataType.STRING
                    ),
                    key_columns=["id"],
                    extra_columns=["last_name"],
                ),
                build_table_parameter(
                    Column.model_construct(
                        name=ColumnName(root="id"), dataType=DataType.STRING
                    ),
                    Column.model_construct(
                        name=ColumnName(root="first_name"), dataType=DataType.STRING
                    ),
                    key_columns=["id"],
                    extra_columns=["first_name"],
                ),
                HasAttributes(
                    testCaseStatus=TestCaseStatus.Failed,
                    result=Contains("Removed columns: last_name")
                    & Contains("Added columns: first_name")
                    & ~Contains("Changed"),
                ),
            ),
            # Case 2: key columns are named differently and each table declares its
            # own key, so the key columns are not reported as added/removed.
            (
                build_table_parameter(
                    Column.model_construct(
                        name=ColumnName(root="id"), dataType=DataType.STRING
                    ),
                    Column.model_construct(
                        name=ColumnName(root="last_name"), dataType=DataType.STRING
                    ),
                    key_columns=["id"],
                    extra_columns=["last_name"],
                ),
                build_table_parameter(
                    Column.model_construct(
                        name=ColumnName(root="table_id"), dataType=DataType.STRING
                    ),
                    Column.model_construct(
                        name=ColumnName(root="first_name"), dataType=DataType.STRING
                    ),
                    key_columns=["table_id"],
                    extra_columns=["first_name"],
                ),
                HasAttributes(
                    testCaseStatus=TestCaseStatus.Failed,
                    result=Contains("Removed columns: last_name")
                    & Contains("Added columns: first_name")
                    & ~Contains("Changed"),
                ),
            ),
            # Case 3: table 2 is still keyed on "id" although its column is called
            # "table_id", so "table_id" shows up as an added column.
            (
                build_table_parameter(
                    Column.model_construct(
                        name=ColumnName(root="id"), dataType=DataType.STRING
                    ),
                    Column.model_construct(
                        name=ColumnName(root="last_name"), dataType=DataType.STRING
                    ),
                    key_columns=["id"],
                    extra_columns=["last_name"],
                ),
                build_table_parameter(
                    Column.model_construct(
                        name=ColumnName(root="table_id"), dataType=DataType.STRING
                    ),
                    Column.model_construct(
                        name=ColumnName(root="first_name"), dataType=DataType.STRING
                    ),
                    key_columns=["id"],  # The misconfiguration that #22302 sets out to solve
                    extra_columns=["first_name"],
                ),
                HasAttributes(
                    testCaseStatus=TestCaseStatus.Failed,
                    result=Contains("Removed columns: last_name")
                    & Contains("Added columns: table_id, first_name")
                    & ~Contains("Changed"),
                ),
            ),
        ),
    )
    def test_it_returns_the_expected_result(
        self,
        validator: TableDiffValidator,
        parameters: TableDiffRuntimeParameters,
        expected: DirtyEquals,
    ) -> None:
        """The column diff result matches the ``expected`` matcher of the parametrized case."""
        assert validator.get_column_diff() == expected

View File

@ -59,6 +59,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
}
),
"table2": TableParameter.model_construct(
@ -68,6 +69,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
}
),
"keyColumns": ["id"],
@ -90,6 +92,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
}
),
"table2": TableParameter.model_construct(
@ -99,6 +102,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
}
),
"keyColumns": ["id"],
@ -121,6 +125,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id", "name"],
}
),
"table2": TableParameter.model_construct(
@ -130,6 +135,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id", "name"],
}
),
"keyColumns": ["id", "name"],
@ -152,6 +158,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id", "name"],
}
),
"table2": TableParameter.model_construct(
@ -161,6 +168,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id", "name"],
},
),
"keyColumns": ["id", "name"],
@ -182,6 +190,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
}
),
"table2": TableParameter.model_construct(
@ -191,6 +200,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="ID", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
},
),
"keyColumns": ["id"],
@ -212,6 +222,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
}
),
"table2": TableParameter.model_construct(
@ -221,6 +232,7 @@ def test_compile_and_clauses(elements, expected):
Column(name="id", dataType=DataType.STRING),
Column(name="name", dataType=DataType.STRING),
],
"key_columns": ["id"],
},
),
"keyColumns": ["id"],

View File

@ -8,17 +8,24 @@
"OpenMetadata"
],
"parameterDefinition": [
{
"name": "keyColumns",
"displayName": "Table 1's key columns",
"description": "The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.",
"dataType": "ARRAY",
"required": false
},
{
"name": "table2",
"displayName": "Table 2",
"description": "Fully qualified name of the table to compare against.",
"dataType": "STRING",
"required": "true"
"required": true
},
{
"name": "keyColumns",
"displayName": "Key Columns",
"description": "The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.",
"name": "table2.keyColumns",
"displayName": "Table 2's key columns",
"description": "The columns in table 2 to use as comparison. If not provided, it will default to `Key Columns`, risking errors if the key columns' names have changed.",
"dataType": "ARRAY",
"required": false
},

View File

@ -164,7 +164,7 @@ test('Table difference test case', async ({ page }) => {
await page
.locator('label')
.filter({ hasText: 'Key Columns' })
.filter({ hasText: "Table 1's key columns" })
.getByRole('button')
.click();
await page.waitForSelector(`[data-id="tableDiff"]`, {