From af0672e4cf8e9c173d97bbd3afdf7378c6da9aef Mon Sep 17 00:00:00 2001 From: Eugenio Date: Wed, 8 Oct 2025 09:32:00 +0200 Subject: [PATCH] Fixes #22302: add `table2.keyColumns` parameter for table diff validation (#23667) * Update `TableDiffParamsSetter` to move data at table level This means that `key_columns` and `extra_columns` will be defined per table instead of "globally", just like `data_diff` expects * Update `TableDiffValidator` to use table's `key_columns` Call `data_diff` and run validations using each table's `key_columns` * Create migration to update `tableDiff` test definition * Fix Playwright test --- .../mysql/postDataMigrationSQLScript.sql | 61 +++++ .../postgres/postDataMigrationSQLScript.sql | 58 +++++ .../data_quality/validations/models.py | 12 +- .../base_diff_params_setter.py | 9 + .../table_diff_params_setter.py | 161 ++++++++++--- .../validations/table/sqlalchemy/tableDiff.py | 55 +++-- .../test_table_diff_params_setter.py | 226 ++++++++++++++++++ .../validations/table/__init__.py | 0 .../validations/table/sqlalchemy/__init__.py | 0 .../table/sqlalchemy/test_table_diff.py | 212 ++++++++++++++++ .../metadata/data_quality/test_data_diff.py | 12 + .../resources/json/data/tests/tableDiff.json | 15 +- .../ui/playwright/e2e/Pages/TestCases.spec.ts | 2 +- 13 files changed, 759 insertions(+), 64 deletions(-) create mode 100644 ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py create mode 100644 ingestion/tests/unit/data_quality/validations/table/__init__.py create mode 100644 ingestion/tests/unit/data_quality/validations/table/sqlalchemy/__init__.py create mode 100644 ingestion/tests/unit/data_quality/validations/table/sqlalchemy/test_table_diff.py diff --git a/bootstrap/sql/migrations/native/1.11.0/mysql/postDataMigrationSQLScript.sql b/bootstrap/sql/migrations/native/1.11.0/mysql/postDataMigrationSQLScript.sql index e69de29bb2d..7a3a7ed7b95 100644 --- a/bootstrap/sql/migrations/native/1.11.0/mysql/postDataMigrationSQLScript.sql +++ b/bootstrap/sql/migrations/native/1.11.0/mysql/postDataMigrationSQLScript.sql @@ -0,0 +1,61 @@ +-- Correct the table diff test definition +-- This is to include the new parameter table2.keyColumns +UPDATE test_definition +SET json = JSON_SET( + json, + '$.parameterDefinition', + JSON_ARRAY( + JSON_OBJECT( + 'name', 'keyColumns', + 'displayName', 'Table 1\'s key columns', + 'description', 'The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.', + 'dataType', 'ARRAY', + 'required', false + ), + JSON_OBJECT( + 'name', 'table2', + 'displayName', 'Table 2', + 'description', 'Fully qualified name of the table to compare against.', + 'dataType', 'STRING', + 'required', true + ), + JSON_OBJECT( + 'name', 'table2.keyColumns', + 'displayName', 'Table 2\'s key columns', + 'description', 'The columns in table 2 to use as comparison. If not provided, it will default to `Key Columns`, risking errors if the key columns\' names have changed.', + 'dataType', 'ARRAY', + 'required', false + ), + JSON_OBJECT( + 'name', 'threshold', + 'displayName', 'Threshold', + 'description', 'Threshold to use to determine if the test passes or fails (defaults to 0).', + 'dataType', 'NUMBER', + 'required', false + ), + JSON_OBJECT( + 'name', 'useColumns', + 'displayName', 'Use Columns', + 'description', 'Limits the scope of the test to this list of columns. If not provided, all columns will be used except the key columns.', + 'dataType', 'ARRAY', + 'required', false + ), + JSON_OBJECT( + 'name', 'where', + 'displayName', 'SQL Where Clause', + 'description', 'Use this where clause to filter the rows to compare.', + 'dataType', 'STRING', + 'required', false + ), + JSON_OBJECT( + 'name', 'caseSensitiveColumns', + 'displayName', 'Case sensitive columns', + 'description', 'Use case sensitivity when comparing the columns.', + 'dataType', 'BOOLEAN', + 'required', false + ) + ), + '$.version', + 0.2 +) +WHERE name = 'tableDiff'; diff --git a/bootstrap/sql/migrations/native/1.11.0/postgres/postDataMigrationSQLScript.sql b/bootstrap/sql/migrations/native/1.11.0/postgres/postDataMigrationSQLScript.sql index e69de29bb2d..10a600b3adc 100644 --- a/bootstrap/sql/migrations/native/1.11.0/postgres/postDataMigrationSQLScript.sql +++ b/bootstrap/sql/migrations/native/1.11.0/postgres/postDataMigrationSQLScript.sql @@ -0,0 +1,58 @@ +-- Correct the table diff test definition +-- This is to include the new parameter table2.keyColumns +UPDATE test_definition +SET json = json::jsonb || json_build_object( + 'parameterDefinition', jsonb_build_array( + jsonb_build_object( + 'name', 'keyColumns', + 'displayName', 'Table 1''s key Columns', + 'description', 'The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.', + 'dataType', 'ARRAY', + 'required', false + ), + jsonb_build_object( + 'name', 'table2', + 'displayName', 'Table 2', + 'description', 'Fully qualified name of the table to compare against.', + 'dataType', 'STRING', + 'required', true + ), + jsonb_build_object( + 'name', 'table2.keyColumns', + 'displayName', 'Table 2''s key columns', + 'description', 'The columns in table 2 to use as comparison. If not provided, it will default to `Key Columns`, risking errors if the key columns'' names have changed.', + 'dataType', 'ARRAY', + 'required', false + ), + jsonb_build_object( + 'name', 'threshold', + 'displayName', 'Threshold', + 'description', 'Threshold to use to determine if the test passes or fails (defaults to 0).', + 'dataType', 'NUMBER', + 'required', false + ), + jsonb_build_object( + 'name', 'useColumns', + 'displayName', 'Use Columns', + 'description', 'Limits the scope of the test to this list of columns. If not provided, all columns will be used except the key columns.', + 'dataType', 'ARRAY', + 'required', false + ), + jsonb_build_object( + 'name', 'where', + 'displayName', 'SQL Where Clause', + 'description', 'Use this where clause to filter the rows to compare.', + 'dataType', 'STRING', + 'required', false + ), + jsonb_build_object( + 'name', 'caseSensitiveColumns', + 'displayName', 'Case sensitive columns', + 'description', 'Use case sensitivity when comparing the columns.', + 'dataType', 'BOOLEAN', + 'required', false + ) + ), + 'version', 0.2 +)::jsonb +WHERE name = 'tableDiff'; diff --git a/ingestion/src/metadata/data_quality/validations/models.py b/ingestion/src/metadata/data_quality/validations/models.py index c460b7d6eb1..5140ad7a273 100644 --- a/ingestion/src/metadata/data_quality/validations/models.py +++ b/ingestion/src/metadata/data_quality/validations/models.py @@ -2,7 +2,7 @@ from typing import List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from metadata.generated.schema.entity.data.table import ( Column, @@ -23,13 +23,19 @@ class TableParameter(BaseModel): database_service_type: DatabaseServiceType privateKey: Optional[CustomSecretStr] passPhrase: Optional[CustomSecretStr] + key_columns: Optional[list[str]] = None + extra_columns: Optional[list[str]] = None class TableDiffRuntimeParameters(BaseModel): table1: TableParameter table2: TableParameter - keyColumns: List[str] - extraColumns: List[str] + keyColumns: Optional[List[str]] = Field( + ..., deprecated="Please use `tableX.key_columns` instead" + ) + extraColumns: Optional[List[str]] = Field( + ..., deprecated="Please use `tableX.extra_columns` instead" + ) whereClause: Optional[str] table_profile_config: Optional[TableProfilerConfig] diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py index e1ffefd0e36..fe7dcece977 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py @@ -92,6 +92,8 @@ class BaseTableParameter: ), privateKey=None, passPhrase=None, + key_columns=key_columns, + extra_columns=extra_columns, ) @staticmethod @@ -154,6 +156,13 @@ class BaseTableParameter: else None ) + @classmethod + def get_service_connection_config( + cls, + service: DatabaseService, + ) -> Optional[Union[str, dict]]: + return cls._get_service_connection_config(service.connection.config) + def get_data_diff_url( self, db_service: DatabaseService, diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index 48f85402dcf..1ddfcd1bba6 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -10,10 +10,23 @@ # limitations under the License. """Module that defines the TableDiffParamsSetter class.""" from ast import literal_eval -from typing import List, Optional, Set +from typing import ( + Any, + Callable, + List, + Optional, + Protocol, + Set, + Union, + runtime_checkable, +) from metadata.data_quality.validations import utils -from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters +from metadata.data_quality.validations.models import ( + Column, + TableDiffRuntimeParameters, + TableParameter, +) from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import ( ServiceSpecPatch, ) @@ -28,9 +41,32 @@ from metadata.utils import fqn from metadata.utils.collections import CaseInsensitiveList +@runtime_checkable +class TableParameterSetter(Protocol): + def get( + self, + service: DatabaseService, + entity: Table, + key_columns, + extra_columns, + case_sensitive_columns, + service_url: Optional[Union[str, dict]], + ) -> TableParameter: + ... + + def get_service_connection_config(self, service: DatabaseService): + ... + + +def get_service_url( + param_setter: TableParameterSetter, service: DatabaseService +) -> Optional[Union[str, dict[str, Any]]]: + return param_setter.get_service_connection_config(service) + + class TableDiffParamsSetter(RuntimeParameterSetter): """ - Set runtime parameters for a the table diff test. + Set runtime parameters for a table diff test. Sets the following variables: - service1Url: The url of the first service (data diff compliant) - service2Url: The url of the second service (data diff compliant) @@ -41,11 +77,17 @@ class TableDiffParamsSetter(RuntimeParameterSetter): - whereClause: Exrtact where clause based on partitioning and user input """ - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + service_url_getter: Callable[ + [TableParameterSetter, DatabaseService], + Optional[Union[str, dict[str, Any]]], + ] = get_service_url, + **kwargs, + ): super().__init__(*args, **kwargs) - self._table_columns = { - column.name.root: column for column in self.table_entity.columns - } + self.get_service_url = service_url_getter def get_parameters(self, test_case) -> TableDiffRuntimeParameters: service1: DatabaseService = self.ometa_client.get_by_id( @@ -63,34 +105,37 @@ class TableDiffParamsSetter(RuntimeParameterSetter): DatabaseService, table2.service.id, nullable=False ) - service_spec_patch_table_1 = ServiceSpecPatch( - ServiceType.Database, service1.connection.config.type.value.lower() - ) - data_diff_class_1 = service_spec_patch_table_1.get_data_diff_class()() - service_spec_patch_table_2 = ServiceSpecPatch( - ServiceType.Database, service2.connection.config.type.value.lower() - ) - data_diff_class_2 = service_spec_patch_table_2.get_data_diff_class()() + table1_param_setter = self.get_param_setter(service1) + table2_param_setter = self.get_param_setter(service2) - service1_url = data_diff_class_1._get_service_connection_config( - service1.connection.config - ) - service2_url = ( - self.get_parameter(test_case, "service2Url") or service1_url - if table2.service == self.table_entity.service - else data_diff_class_2._get_service_connection_config( - service2.connection.config - ) - or None - ) + service1_url = self.get_service_url(table1_param_setter, service1) + + if table2.service == self.table_entity.service: + service2_url = self.get_parameter(test_case, "service2Url") or service1_url + else: + service2_url = self.get_service_url(table2_param_setter, service2) key_columns = self.get_key_columns(test_case) + table2_key_columns = ( + self.get_table_key_columns(test_case, table2) or key_columns + ) + extra_columns = ( self.get_extra_columns( - key_columns, test_case, self.table_entity.columns, table2.columns + key_columns | table2_key_columns, + test_case, + self.table_entity.columns, + table2.columns, ) or set() ) + table1_extra_columns = self.get_table_extra_columns( + test_case, self.table_entity + ) + table2_extra_columns = ( + self.get_table_extra_columns(test_case, table2) or extra_columns + ) + case_sensitive_columns: bool = ( utils.get_bool_test_case_param( test_case.parameterValues, "caseSensitiveColumns" @@ -100,19 +145,19 @@ class TableDiffParamsSetter(RuntimeParameterSetter): return TableDiffRuntimeParameters( table_profile_config=self.table_entity.tableProfilerConfig, - table1=data_diff_class_1.get( + table1=table1_param_setter.get( service1, self.table_entity, key_columns, - extra_columns, + table1_extra_columns or extra_columns, case_sensitive_columns, service1_url, ), - table2=data_diff_class_2.get( + table2=table2_param_setter.get( service2, table2, - key_columns, - extra_columns, + table2_key_columns, + table2_extra_columns, case_sensitive_columns, service2_url, ), @@ -178,6 +223,45 @@ class TableDiffParamsSetter(RuntimeParameterSetter): ) return set(key_columns) + def get_table_key_columns( + self, test_case: TestCase, table: Table + ) -> Optional[set[str]]: + key = "table1" if table is self.table_entity else "table2" + param = self.get_parameter(test_case, f"{key}.keyColumns", "[]") + key_columns: List[str] = literal_eval(param) + + if not key_columns: + return None + + self.validate_columns(key_columns, table) + return set(key_columns) + + def get_table_extra_columns( + self, test_case: TestCase, table: Table + ) -> Optional[List[str]]: + key = "table1" if table is self.table_entity else "table2" + param = self.get_parameter(test_case, f"{key}.extraColumns", "[]") + extra_columns: List[str] = literal_eval(param) + if not extra_columns: + return None + self.validate_columns(extra_columns, table) + return extra_columns + + def validate_columns( + self, column_names: List[str], table: Optional[Table] = None + ) -> None: + if table is None: + table = self.table_entity + + table_columns_names: Set[str] = {c.name.root for c in table.columns} + + for column in column_names: + if column not in table_columns_names: + raise ValueError( + f"Failed to resolve key columns for table diff.\n" + f"Column '{column}' not found in table '{table.name.root}'.\n" + ) + @staticmethod def filter_relevant_columns( columns: List[Column], @@ -207,10 +291,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter): "___SERVICE___", "__DATABASE__", schema, table ).replace("___SERVICE___.__DATABASE__.", "") - def validate_columns(self, column_names: List[str]): - for column in column_names: - if not self._table_columns.get(column): - raise ValueError( - f"Failed to resolve key columns for table diff.\n" - f"Column '{column}' not found in table '{self.table_entity.name.root}'.\n" - ) + @staticmethod + def get_param_setter(service: DatabaseService) -> TableParameterSetter: + patch = ServiceSpecPatch( + ServiceType.Database, service.connection.config.type.value.lower() + ) + return patch.get_data_diff_class()() diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 070263b9fea..31968cdb54b 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -266,7 +266,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): table1 = data_diff.connect_to_table( self.runtime_params.table1.serviceUrl, self.runtime_params.table1.path, - self.runtime_params.keyColumns, + self.runtime_params.table1.key_columns, extra_columns=self.runtime_params.extraColumns, case_sensitive=self.get_case_sensitive(), key_content=self.runtime_params.table1.privateKey.get_secret_value() @@ -279,7 +279,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): table2 = data_diff.connect_to_table( self.runtime_params.table2.serviceUrl, self.runtime_params.table2.path, - self.runtime_params.keyColumns, + self.runtime_params.table2.key_columns, extra_columns=self.runtime_params.extraColumns, case_sensitive=self.get_case_sensitive(), key_content=self.runtime_params.table2.privateKey.get_secret_value() @@ -345,7 +345,8 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): table1 = data_diff.connect_to_table( self.runtime_params.table1.serviceUrl, self.runtime_params.table1.path, - self.runtime_params.keyColumns, # type: ignore + self.runtime_params.table1.key_columns, # type: ignore + extra_columns=self.runtime_params.table1.extra_columns, case_sensitive=self.get_case_sensitive(), where=left_where, key_content=self.runtime_params.table1.privateKey.get_secret_value() @@ -358,7 +359,8 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): table2 = data_diff.connect_to_table( self.runtime_params.table2.serviceUrl, self.runtime_params.table2.path, - self.runtime_params.keyColumns, # type: ignore + self.runtime_params.table2.key_columns, # type: ignore + extra_columns=self.runtime_params.table2.extra_columns, case_sensitive=self.get_case_sensitive(), where=right_where, key_content=self.runtime_params.table1.privateKey.get_secret_value() @@ -369,8 +371,6 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): else None, ) data_diff_kwargs = { - "key_columns": self.runtime_params.keyColumns, - "extra_columns": self.runtime_params.extraColumns, "where": self.get_where(), } logger.debug( @@ -434,14 +434,27 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): salt = "".join( random.choices(string.ascii_letters + string.digits, k=5) ) # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax - key_columns = ( - CaseInsensitiveList(self.runtime_params.keyColumns) - if not self.get_case_sensitive() - else self.runtime_params.keyColumns + + return ( + build_sample_where_clause( + self.runtime_params.table1, + self.maybe_case_sensitive(self.runtime_params.table1.key_columns), + salt, + hex_nounce, + ), + build_sample_where_clause( + self.runtime_params.table2, + self.maybe_case_sensitive(self.runtime_params.table2.key_columns), + salt, + hex_nounce, + ), ) - return tuple( - build_sample_where_clause(table, key_columns, salt, hex_nounce) - for table in [self.runtime_params.table1, self.runtime_params.table2] + + def maybe_case_sensitive(self, iterable: Iterable[str]) -> list[str]: + return ( + CaseInsensitiveList(iterable) + if not self.get_case_sensitive() + else list(iterable) ) def calculate_nounce(self, max_nounce=2**32 - 1) -> int: @@ -522,8 +535,16 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): def get_column_diff(self) -> Optional[TestCaseResult]: """Get the column diff between the two tables. If there are no differences, return None.""" removed, added = self.get_changed_added_columns( - self.runtime_params.table1.columns, - self.runtime_params.table2.columns, + [ + c + for c in self.runtime_params.table1.columns + if c.name.root not in self.runtime_params.table1.key_columns + ], + [ + c + for c in self.runtime_params.table2.columns + if c.name.root not in self.runtime_params.table2.key_columns + ], self.get_case_sensitive(), ) changed = self.get_incomparable_columns() @@ -585,9 +606,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): f"Tables have {sum(map(len, [removed, added, changed]))} different columns:" ) if removed: - message += f"\n Removed columns: {','.join(removed)}\n" + message += f"\n Removed columns: {', '.join(removed)}\n" if added: - message += f"\n Added columns: {','.join(added)}\n" + message += f"\n Added columns: {', '.join(added)}\n" if changed: message += "\n Changed columns:" table1_columns = { diff --git a/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py b/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py new file mode 100644 index 00000000000..221f91b12ae --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py @@ -0,0 +1,226 @@ +import json +import uuid +from typing import List +from unittest.mock import create_autospec + +import pytest +from dirty_equals import HasAttributes, IsInstance, IsListOrTuple + +from metadata.data_quality.validations.models import ( + TableDiffRuntimeParameters, + TableParameter, +) +from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import ( + TableDiffParamsSetter, + TableParameterSetter, +) +from metadata.generated.schema.entity.data.table import ( + Column, + ColumnName, + DataType, + Table, +) +from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( + PostgresConnection, +) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseConnection, + DatabaseService, + DatabaseServiceType, +) +from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue +from metadata.generated.schema.type.basic import FullyQualifiedEntityName +from metadata.generated.schema.type.entityReference import EntityReference +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.sampler.sampler_interface import SamplerInterface + + +@pytest.fixture +def metadata( + service1: DatabaseService, table1: Table, service2: DatabaseService, table2: Table +) -> OpenMetadata: + mock = create_autospec(OpenMetadata, spec_set=True, instance=True) + + objects_by_entity_and_id = { + (DatabaseService, table1.service.id): service1, + (DatabaseService, table2.service.id): service2, + } + + objects_by_entity_and_name = { + (Table, table2.fullyQualifiedName.root): table2, + } + + def mock_get_by_id(entity, entity_id, **kwargs): + return objects_by_entity_and_id.get((entity, entity_id), None) + + def mock_get_by_name(entity, fqn, **kwargs): + return objects_by_entity_and_name.get((entity, fqn), None) + + mock.get_by_id.side_effect = mock_get_by_id + mock.get_by_name.side_effect = mock_get_by_name + + return mock + + +@pytest.fixture +def service_connection_config() -> DatabaseConnection: + return create_autospec(DatabaseConnection, spec_set=True, instance=True) + + +@pytest.fixture +def sampler() -> SamplerInterface: + mock = create_autospec(SamplerInterface, instance=True) + mock.partition_details = None + return mock + + +@pytest.fixture +def service1() -> DatabaseService: + return DatabaseService.model_construct( + id=uuid.uuid4(), + name="TestService1", + fullyQualifiedName="TestService1", + serviceType=DatabaseServiceType.Postgres, + connection=DatabaseConnection(config=PostgresConnection.model_construct()), + ) + + +@pytest.fixture +def service2() -> DatabaseService: + return DatabaseService.model_construct( + id=uuid.uuid4(), + name="TestService2", + fullyQualifiedName="TestService2", + serviceType=DatabaseServiceType.Postgres, + connection=DatabaseConnection(config=PostgresConnection.model_construct()), + ) + + +@pytest.fixture +def table1() -> Table: + return Table.model_construct( + id=uuid.uuid4(), + name="table1", + fullyQualifiedName=FullyQualifiedEntityName( + root="TestService1.test_db.test_schema.table1" + ), + service=EntityReference.model_construct(id=uuid.uuid4(), name="test_service1"), + columns=[ + Column.model_construct( + name=ColumnName(root="id"), + dataType=DataType.STRING, + ), + Column.model_construct( + name=ColumnName(root="name"), + dataType=DataType.STRING, + ), + ], + ) + + +@pytest.fixture +def table2() -> Table: + return Table.model_construct( + id=uuid.uuid4(), + name="table2", + fullyQualifiedName=FullyQualifiedEntityName( + root="TestService2.test_db.test_schema.table2" + ), + service=EntityReference.model_construct(id=uuid.uuid4(), name="test_service2"), + columns=[ + Column.model_construct( + name=ColumnName(root="table_id"), + dataType=DataType.STRING, + ), + Column.model_construct( + name=ColumnName(root="name"), + dataType=DataType.STRING, + ), + ], + ) + + +def fake_get_service_url( + param_setter: TableParameterSetter, service: DatabaseService +) -> str: + return "postgresql+psycopg2://test:test@localhost/test" + + +@pytest.fixture +def setter( + metadata: OpenMetadata, + service_connection_config: DatabaseConnection, + sampler: SamplerInterface, + table1: Table, +) -> TableDiffParamsSetter: + return TableDiffParamsSetter( + ometa_client=metadata, + service_connection_config=service_connection_config, + sampler=sampler, + table_entity=table1, + service_url_getter=fake_get_service_url, + ) + + +@pytest.fixture +def parameter_values() -> List[TestCaseParameterValue]: + return [ + TestCaseParameterValue( + name="table2", value="TestService2.test_db.test_schema.table2" + ) + ] + + +def test_setter_gets_default_key_columns( + setter: TableDiffParamsSetter, parameter_values: List[TestCaseParameterValue] +) -> None: + test_case = TestCase.model_construct( + parameterValues=[ + *parameter_values, + TestCaseParameterValue(name="keyColumns", value=json.dumps(["id"])), + ], + ) + + assert setter.get_parameters(test_case) == IsInstance( + TableDiffRuntimeParameters + ) & HasAttributes( + keyColumns=["id"], + extraColumns=IsListOrTuple("name", "table_id", check_order=False), + table1=IsInstance(TableParameter) + & HasAttributes( + key_columns=["id"], + ), + table2=IsInstance(TableParameter) + & HasAttributes( + key_columns=["id"], + ), + ) + + +def test_setter_gets_per_table_key_columns( + setter: TableDiffParamsSetter, parameter_values: List[TestCaseParameterValue] +) -> None: + test_case = TestCase.model_construct( + parameterValues=[ + *parameter_values, + TestCaseParameterValue(name="keyColumns", value=json.dumps(["id"])), + TestCaseParameterValue( + name="table2.keyColumns", value=json.dumps(["table_id"]) + ), + ] + ) + + assert setter.get_parameters(test_case) == IsInstance( + TableDiffRuntimeParameters + ) & HasAttributes( + keyColumns=["id"], + extraColumns=IsListOrTuple("name", check_order=False), + table1=IsInstance(TableParameter) + & HasAttributes( + key_columns=["id"], + ), + table2=IsInstance(TableParameter) + & HasAttributes( + key_columns=["table_id"], + ), + ) diff --git a/ingestion/tests/unit/data_quality/validations/table/__init__.py b/ingestion/tests/unit/data_quality/validations/table/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ingestion/tests/unit/data_quality/validations/table/sqlalchemy/__init__.py b/ingestion/tests/unit/data_quality/validations/table/sqlalchemy/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ingestion/tests/unit/data_quality/validations/table/sqlalchemy/test_table_diff.py b/ingestion/tests/unit/data_quality/validations/table/sqlalchemy/test_table_diff.py new file mode 100644 index 00000000000..487f913f5e2 --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/table/sqlalchemy/test_table_diff.py @@ -0,0 +1,212 @@ +import datetime +from typing import Generator +from unittest.mock import MagicMock, Mock, patch + +import pytest +from dirty_equals import Contains, DirtyEquals, HasAttributes + +from metadata.data_quality.validations.models import ( + TableDiffRuntimeParameters, + TableParameter, +) +from metadata.data_quality.validations.table.sqlalchemy.tableDiff import ( + TableDiffValidator, +) +from metadata.generated.schema.entity.data.table import Column, ColumnName, DataType +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseServiceType, +) +from metadata.generated.schema.tests.basic import TestCaseStatus +from metadata.generated.schema.tests.testCase import TestCase +from metadata.generated.schema.type.basic import Timestamp + + +def build_table_parameter( + *columns: Column, + key_columns: list[str], + extra_columns: list[str], + service_url: str = "postgresql://postgres:postgres@service:5432/postgres", +) -> TableParameter: + return TableParameter.model_construct( + serviceUrl=service_url, + path="test_schema.test_table", + database_service_type=DatabaseServiceType.Postgres, + columns=columns, + privateKey=None, + passPhrase=None, + key_columns=key_columns, + extra_columns=extra_columns, + ) + + +@pytest.fixture +def table1_parameter() -> TableParameter: + return build_table_parameter( + Column.model_construct(name=ColumnName(root="id"), dataType=DataType.STRING), + Column.model_construct( + name=ColumnName(root="first_name"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="last_name"), dataType=DataType.STRING + ), + key_columns=["id"], + extra_columns=["first_name", "last_name"], + service_url="postgresql://postgres:postgres@service1:5432/postgres", + ) + + +@pytest.fixture +def table2_parameter() -> TableParameter: + return build_table_parameter( + Column.model_construct( + name=ColumnName(root="table_id"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="first_name"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="last_name"), dataType=DataType.STRING + ), + key_columns=["table_id"], + extra_columns=["first_name", "last_name"], + service_url="postgresql://postgres:postgres@service2:5432/postgres", + ) + + +@pytest.fixture +def parameters( + table1_parameter: TableParameter, table2_parameter: TableParameter +) -> TableDiffRuntimeParameters: + return TableDiffRuntimeParameters( + table1=table1_parameter, + table2=table2_parameter, + table_profile_config=None, + whereClause=None, + keyColumns=None, + extraColumns=None, + ) + + +@pytest.fixture +def validator( + parameters: TableDiffRuntimeParameters, +) -> Generator[TableDiffValidator, None, None]: + with patch( + "metadata.data_quality.validations.table.sqlalchemy.tableDiff.data_diff" + ) as data_diff: + mock_table = MagicMock() + mock_table.key_columns = [] + mock_table.extra_columns = [] + data_diff.connect_to_table = Mock(return_value=mock_table) + + validator = TableDiffValidator( + runner=[], + test_case=TestCase.model_construct(parameterValues=[]), + execution_date=Timestamp(root=int(datetime.datetime.now().timestamp())), + ) + validator.runtime_params = parameters + yield validator + + +class TestGetColumnDiff: + def test_it_returns_none_when_no_diff( + self, validator: TableDiffValidator, parameters: TableDiffRuntimeParameters + ) -> None: + assert validator.get_column_diff() is None + + @pytest.mark.parametrize( + "table1_parameter, table2_parameter, expected", + ( + ( + build_table_parameter( + Column.model_construct( + name=ColumnName(root="id"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="last_name"), dataType=DataType.STRING + ), + key_columns=["id"], + extra_columns=["last_name"], + ), + build_table_parameter( + Column.model_construct( + name=ColumnName(root="id"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="first_name"), dataType=DataType.STRING + ), + key_columns=["id"], + extra_columns=["first_name"], + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Failed, + result=Contains("Removed columns: last_name") + & Contains("Added columns: first_name") + & ~Contains("Changed"), + ), + ), + ( + build_table_parameter( + Column.model_construct( + name=ColumnName(root="id"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="last_name"), dataType=DataType.STRING + ), + key_columns=["id"], + extra_columns=["last_name"], + ), + build_table_parameter( + Column.model_construct( + name=ColumnName(root="table_id"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="first_name"), dataType=DataType.STRING + ), + key_columns=["table_id"], + extra_columns=["first_name"], + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Failed, + result=Contains("Removed columns: last_name") + & Contains("Added columns: first_name") + & ~Contains("Changed"), + ), + ), + ( + build_table_parameter( + Column.model_construct( + name=ColumnName(root="id"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="last_name"), dataType=DataType.STRING + ), + key_columns=["id"], + extra_columns=["last_name"], + ), + build_table_parameter( + Column.model_construct( + name=ColumnName(root="table_id"), dataType=DataType.STRING + ), + Column.model_construct( + name=ColumnName(root="first_name"), dataType=DataType.STRING + ), + key_columns=["id"], # The error trying to solve in #22302 + extra_columns=["first_name"], + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Failed, + result=Contains("Removed columns: last_name") + & Contains("Added columns: table_id, first_name") + & ~Contains("Changed"), + ), + ), + ), + ) + def test_it_returns_the_expected_result( + self, + validator: TableDiffValidator, + parameters: TableDiffRuntimeParameters, + expected: DirtyEquals, + ) -> None: + assert validator.get_column_diff() == expected diff --git a/ingestion/tests/unit/metadata/data_quality/test_data_diff.py b/ingestion/tests/unit/metadata/data_quality/test_data_diff.py index 669d60c402f..fa42732d27b 100644 --- a/ingestion/tests/unit/metadata/data_quality/test_data_diff.py +++ b/ingestion/tests/unit/metadata/data_quality/test_data_diff.py @@ -59,6 +59,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], } ), "table2": TableParameter.model_construct( @@ -68,6 +69,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], } ), "keyColumns": ["id"], @@ -90,6 +92,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], } ), "table2": TableParameter.model_construct( @@ -99,6 +102,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], } ), "keyColumns": ["id"], @@ -121,6 +125,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id", "name"], } ), "table2": TableParameter.model_construct( @@ -130,6 +135,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id", "name"], } ), "keyColumns": ["id", "name"], @@ -152,6 +158,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id", "name"], } ), "table2": TableParameter.model_construct( @@ -161,6 +168,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id", "name"], }, ), "keyColumns": ["id", "name"], @@ -182,6 +190,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], } ), "table2": TableParameter.model_construct( @@ -191,6 +200,7 @@ def test_compile_and_clauses(elements, expected): Column(name="ID", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], }, ), "keyColumns": ["id"], @@ -212,6 +222,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], } ), "table2": TableParameter.model_construct( @@ -221,6 +232,7 @@ def test_compile_and_clauses(elements, expected): Column(name="id", dataType=DataType.STRING), Column(name="name", dataType=DataType.STRING), ], + "key_columns": ["id"], }, ), "keyColumns": ["id"], diff --git a/openmetadata-service/src/main/resources/json/data/tests/tableDiff.json b/openmetadata-service/src/main/resources/json/data/tests/tableDiff.json index ed0b68ec70f..0d6e61d92df 100644 --- a/openmetadata-service/src/main/resources/json/data/tests/tableDiff.json +++ b/openmetadata-service/src/main/resources/json/data/tests/tableDiff.json @@ -8,17 +8,24 @@ "OpenMetadata" ], "parameterDefinition": [ + { + "name": "keyColumns", + "displayName": "Table 1's key columns", + "description": "The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.", + "dataType": "ARRAY", + "required": false + }, { "name": "table2", "displayName": "Table 2", "description": "Fully qualified name of the table to compare against.", "dataType": "STRING", - "required": "true" + "required": true }, { - "name": "keyColumns", - "displayName": "Key Columns", - "description": "The columns to use as the key for the comparison. If not provided, it will be resolved from the primary key or unique columns. The tuples created from the key columns must be unique.", + "name": "table2.keyColumns", + "displayName": "Table 2's key columns", + "description": "The columns in table 2 to use as comparison. If not provided, it will default to `Key Columns`, risking errors if the key columns' names have changed.", "dataType": "ARRAY", "required": false }, diff --git a/openmetadata-ui/src/main/resources/ui/playwright/e2e/Pages/TestCases.spec.ts b/openmetadata-ui/src/main/resources/ui/playwright/e2e/Pages/TestCases.spec.ts index 9782d0d5e02..164b44a6ca0 100644 --- a/openmetadata-ui/src/main/resources/ui/playwright/e2e/Pages/TestCases.spec.ts +++ b/openmetadata-ui/src/main/resources/ui/playwright/e2e/Pages/TestCases.spec.ts @@ -164,7 +164,7 @@ test('Table difference test case', async ({ page }) => { await page .locator('label') - .filter({ hasText: 'Key Columns' }) + .filter({ hasText: "Table 1's key columns" }) .getByRole('button') .click(); await page.waitForSelector(`[data-id="tableDiff"]`, {