From 42416a513e376dfcda579ef88b8ef6b10ee9baef Mon Sep 17 00:00:00 2001
From: Eugenio
Date: Tue, 4 Nov 2025 09:52:43 +0100
Subject: [PATCH] Simplified API for validating DataFrames (#24009)

* Refactor previous tests for shared resources

* Add validation result models

This also includes a method for merging them, useful when running validation in batches

* Added `DataFrameValidationEngine` for running tests

This also includes a registry for mapping test names to pandas test classes

* Implement the DataFrameValidator facade

This includes the logic to load tests from different sources (OpenMetadata or code) and pass them down to the engine. It also includes tests for the integration with OpenMetadata

* Add examples for the API

* Apply comments
---
 .../ingestion/ometa/mixins/tests_mixin.py     |  53 +-
 .../dataframes/custom_warnings.py             |   2 +
 .../dataframes/dataframe_validation_engine.py | 145 +++++
 .../dataframes/dataframe_validator.py         | 201 +++++++
 .../sdk/data_quality/dataframes/models.py     |  44 ++
 .../dataframes/validation_results.py          | 243 ++++++++
 .../sdk/data_quality/dataframes/validators.py | 133 +++++
 .../examples/dataframe_validation_example.py  | 294 +++++++++
 ingestion/src/metadata/utils/entity_link.py   |  15 +
 ingestion/tests/integration/sdk/conftest.py   | 244 +++++++-
 .../dataframes/test_dataframe_validator.py    | 255 ++++++++
 .../sdk/test_dq_as_code_integration.py        | 246 --------
 .../data_quality/test_dataframe_validator.py  | 561 ++++++++++++++++++
 .../data_quality/test_validation_results.py   | 377 ++++++++++++
 14 files changed, 2564 insertions(+), 249 deletions(-)
 create mode 100644 ingestion/src/metadata/sdk/data_quality/dataframes/custom_warnings.py
 create mode 100644 ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validation_engine.py
 create mode 100644 ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validator.py
 create mode 100644 ingestion/src/metadata/sdk/data_quality/dataframes/models.py
 create mode 100644 ingestion/src/metadata/sdk/data_quality/dataframes/validation_results.py
 create mode 100644 ingestion/src/metadata/sdk/data_quality/dataframes/validators.py
 create mode 100644 ingestion/src/metadata/sdk/examples/dataframe_validation_example.py
 create mode 100644 ingestion/tests/integration/sdk/data_quality/dataframes/test_dataframe_validator.py
 create mode 100644 ingestion/tests/unit/sdk/data_quality/test_dataframe_validator.py
 create mode 100644 ingestion/tests/unit/sdk/data_quality/test_validation_results.py

diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py
index 1e93b9ad535..f799789279b 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py
@@ -17,6 +17,7 @@ To be used by OpenMetadata class
 import traceback
 from datetime import datetime
 from typing import List, Optional, Type, Union
+from urllib.parse import quote, urlencode
 from uuid import UUID
 
 from metadata.generated.schema.api.tests.createLogicalTestCases import (
@@ -170,7 +171,7 @@ class OMetaTestsMixin:
         test_definition_fqn: Optional[str] = None,
         test_case_parameter_values: Optional[List[TestCaseParameterValue]] = None,
         description: Optional[str] = None,
-    ):
+    ) -> TestCase:
         """Get or create a test case
 
         Args:
@@ -203,6 +204,34 @@ class OMetaTestsMixin:
         )
         return test_case
 
+    def get_executable_test_suite(self, table_fqn: str) -> Optional[TestSuite]:
+        """Given an entity fqn, retrieve the linked test suite if it exists
+
+        Args:
+            table_fqn (str): entity fully qualified name
+
+        Returns:
+            An instance of TestSuite or None
+        """
+        table_entity = self.get_by_name(
+            entity=Table, fqn=table_fqn, fields=["testSuite"]
+        )
+        if not table_entity:
+            raise RuntimeError(
+                f"Unable to find table {table_fqn} in OpenMetadata. "
+                "This could be because the table has not been ingested yet or your JWT Token is expired or missing."
+            )
+
+        if not table_entity.testSuite:
+            return None
+
+        return self.get_by_name(
+            entity=TestSuite,
+            fqn=table_entity.testSuite.fullyQualifiedName,
+            fields=["tests"],
+            nullable=False,
+        )
+
     def get_or_create_executable_test_suite(
         self, entity_fqn: str
     ) -> Union[EntityReference, TestSuite]:
@@ -379,3 +408,25 @@ class OMetaTestsMixin:
             data=inspection_query,
         )
         return TestCase(**resp)
+
+    def delete_test_case(
+        self,
+        test_case_fqn: str,
+        recursive: bool = True,
+        hard: bool = False,
+    ) -> None:
+        """Delete a test case
+        Args:
+            test_case_fqn: Fully qualified name of the test case to delete
+            recursive (bool, optional): delete children if true
+            hard (bool, optional): hard delete if true
+        """
+        params = urlencode(
+            dict(
+                recursive="true" if recursive else "false",
+                hardDelete="true" if hard else "false",
+            )
+        )
+        url = f"{self.get_suffix(TestCase)}/name/{quote(test_case_fqn)}"
+
+        self.client.delete(f"{url}?{params}")
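The `tests_mixin.py` changes above add `get_executable_test_suite` and `delete_test_case` to the OpenMetadata client. A minimal sketch of how the two compose, assuming a reachable OpenMetadata instance and an already-ingested table; `drop_table_tests` and the table FQN are hypothetical:

```python
from metadata.ingestion.ometa.ometa_api import OpenMetadata  # mixes in OMetaTestsMixin


def drop_table_tests(metadata: OpenMetadata, table_fqn: str) -> None:
    """Hard-delete every test case attached to a table's executable test suite."""
    suite = metadata.get_executable_test_suite(table_fqn)
    if suite is None:
        return  # the table has no test suite yet
    for test in suite.tests or []:
        # suite.tests holds EntityReference objects whose fullyQualifiedName is a str
        metadata.delete_test_case(test.fullyQualifiedName, recursive=True, hard=True)


# drop_table_tests(metadata, "DbService.database_name.schema_name.dwh_table")
```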
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/custom_warnings.py b/ingestion/src/metadata/sdk/data_quality/dataframes/custom_warnings.py
new file mode 100644
index 00000000000..b08efdca6e7
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/custom_warnings.py
@@ -0,0 +1,2 @@
+class WholeTableTestsWarning(RuntimeWarning):
+    """Warns when tests that require the whole table are run on a subset of it"""
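`WholeTableTestsWarning` is emitted by `DataFrameValidator.run()` (further down in this patch) when chunked validation includes whole-table tests. A hedged sketch of surfacing it, assuming the SDK default client is configured via `metadata.sdk.configure`; note that each chunk passes uniqueness in isolation, which is exactly the false positive the warning is about:

```python
import warnings

import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeUnique
from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator

validator = DataFrameValidator()
validator.add_test(ColumnValuesToBeUnique(column="id"))  # requires the whole table

# "id" 1 repeats across chunks, yet every chunk is unique on its own.
chunks = [pd.DataFrame({"id": [1, 2]}), pd.DataFrame({"id": [1, 3]})]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validator.run(
        chunks,
        on_success=lambda df, result: None,
        on_failure=lambda df, result: None,
    )

assert any(isinstance(w.message, WholeTableTestsWarning) for w in caught)
```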
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validation_engine.py b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validation_engine.py
new file mode 100644
index 00000000000..ae43019e148
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validation_engine.py
@@ -0,0 +1,145 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Orchestration engine for DataFrame validation execution."""
+import logging
+import time
+from datetime import datetime
+from typing import List, Tuple, Type
+
+from pandas import DataFrame
+
+from metadata.data_quality.validations.base_test_handler import BaseTestValidator
+from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.generated.schema.type.basic import Timestamp
+from metadata.sdk.data_quality.dataframes.validation_results import (
+    FailureMode,
+    ValidationResult,
+)
+from metadata.sdk.data_quality.dataframes.validators import VALIDATOR_REGISTRY
+
+logger = logging.getLogger(__name__)
+
+
+class DataFrameValidationEngine:
+    """Orchestrates execution of multiple validators on a DataFrame."""
+
+    def __init__(self, test_cases: List[TestCase]):
+        self.test_cases: List[TestCase] = test_cases
+
+    def execute(
+        self,
+        df: DataFrame,
+        mode: FailureMode = FailureMode.SHORT_CIRCUIT,
+    ) -> ValidationResult:
+        """Execute all validations and return aggregated results.
+
+        Args:
+            df: DataFrame to validate
+            mode: Validation mode (FailureMode.SHORT_CIRCUIT is currently the only option)
+
+        Returns:
+            ValidationResult with outcomes for all tests
+        """
+        results: List[Tuple[TestCase, TestCaseResult]] = []
+        start_time = time.time()
+
+        for test_case in self.test_cases:
+            test_result = self._execute_single_test(df, test_case)
+            results.append((test_case, test_result))
+
+            if mode is FailureMode.SHORT_CIRCUIT and test_result.testCaseStatus in (
+                TestCaseStatus.Failed,
+                TestCaseStatus.Aborted,
+            ):
+                break
+
+        execution_time = (time.time() - start_time) * 1000
+        return self._build_validation_result(results, execution_time)
+
+    def _execute_single_test(
+        self, df: DataFrame, test_case: TestCase
+    ) -> TestCaseResult:
+        """Execute validation and return structured result.
+
+        Returns:
+            TestCaseResult with the validation outcome
+        """
+        validator_class = self._get_validator_class(test_case)
+
+        validator = validator_class(
+            runner=[df],
+            test_case=test_case,
+            execution_date=Timestamp(root=int(datetime.now().timestamp() * 1000)),
+        )
+
+        try:
+            return validator.run_validation()
+        except Exception as err:
+            message = (
+                f"Error executing {test_case.testDefinition.fullyQualifiedName} - {err}"
+            )
+            logger.exception(message)
+            return validator.get_test_case_result_object(
+                validator.execution_date,
+                TestCaseStatus.Aborted,
+                message,
+                [],
+            )
+
+    @staticmethod
+    def _build_validation_result(
+        test_results: List[Tuple[TestCase, TestCaseResult]], execution_time_ms: float
+    ) -> ValidationResult:
+        """Build aggregated validation result.
+
+        Args:
+            test_results: Individual test results
+            execution_time_ms: Total execution time in milliseconds
+
+        Returns:
+            ValidationResult with aggregated outcomes
+        """
+        passed = sum(
+            1 for _, r in test_results if r.testCaseStatus == TestCaseStatus.Success
+        )
+        failed = len(test_results) - passed
+        success = failed == 0
+
+        return ValidationResult(
+            success=success,
+            total_tests=len(test_results),
+            passed_tests=passed,
+            failed_tests=failed,
+            test_cases_and_results=test_results,
+            execution_time_ms=execution_time_ms,
+        )
+
+    @staticmethod
+    def _get_validator_class(test_case: TestCase) -> Type[BaseTestValidator]:
+        """Resolve validator class from test definition name.
+
+        Returns:
+            Validator class for the test definition
+
+        Raises:
+            ValueError: If the test definition is not supported
+        """
+        validator_class = VALIDATOR_REGISTRY.get(
+            test_case.testDefinition.fullyQualifiedName  # pyright: ignore[reportArgumentType]
+        )
+        if not validator_class:
+            raise ValueError(
+                f"Unknown test definition: {test_case.testDefinition.fullyQualifiedName}"
+            )
+
+        return validator_class
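For orientation, a minimal sketch of driving the engine directly with an in-code test definition turned into a synthetic `TestCase` via `create_mock_test_case` (defined in `models.py` below). The `DataFrameValidator` facade that follows is the intended entry point; no OpenMetadata client is needed here because the test case is built locally:

```python
import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeNotNull
from metadata.sdk.data_quality.dataframes.dataframe_validation_engine import (
    DataFrameValidationEngine,
)
from metadata.sdk.data_quality.dataframes.models import create_mock_test_case
from metadata.sdk.data_quality.dataframes.validation_results import FailureMode

# Build a synthetic TestCase from an in-code definition and run it on a frame.
test_case = create_mock_test_case(ColumnValuesToBeNotNull(column="email"))
engine = DataFrameValidationEngine([test_case])

result = engine.execute(
    pd.DataFrame({"email": ["a@example.com", None]}),  # one null -> one failure
    mode=FailureMode.SHORT_CIRCUIT,
)
print(result.success, result.failed_tests)  # expected: False 1
```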
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validator.py b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validator.py
new file mode 100644
index 00000000000..ef5dcc789b1
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validator.py
@@ -0,0 +1,201 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DataFrame validation API."""
+import warnings
+from typing import Any, Callable, Iterable, List, Optional, cast, final
+
+from pandas import DataFrame
+
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.ingestion.ometa.ometa_api import OpenMetadata as OMeta
+from metadata.sdk import OpenMetadata
+from metadata.sdk import client as get_client
+from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning
+from metadata.sdk.data_quality.dataframes.dataframe_validation_engine import (
+    DataFrameValidationEngine,
+)
+from metadata.sdk.data_quality.dataframes.models import create_mock_test_case
+from metadata.sdk.data_quality.dataframes.validation_results import (
+    FailureMode,
+    ValidationResult,
+)
+from metadata.sdk.data_quality.dataframes.validators import requires_whole_table
+from metadata.sdk.data_quality.tests.base_tests import BaseTest
+
+ValidatorCallback = Callable[[DataFrame, ValidationResult], None]
+
+
+@final
+class DataFrameValidator:
+    """Facade for DataFrame data quality validation.
+
+    Provides a simple interface to configure and execute data quality tests
+    on pandas DataFrames using OpenMetadata test definitions.
+
+    Example:
+        validator = DataFrameValidator()
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+        validator.add_test(ColumnValuesToBeUnique(column="customer_id"))
+
+        result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+        if not result.success:
+            print(f"Validation failed: {result.failures}")
+    """
+
+    def __init__(
+        self,
+        client: Optional[  # pyright: ignore[reportRedeclaration]
+            OMeta[Any, Any]
+        ] = None,
+    ):
+        self._test_cases: List[TestCase] = []
+
+        if client is None:
+            metadata: OpenMetadata = get_client()
+            client: OMeta[Any, Any] = metadata.ometa
+
+        self._client = client
+
+    def add_test(self, test: BaseTest) -> None:
+        """Add a single test definition to be executed.
+
+        Args:
+            test: Test definition (e.g., ColumnValuesToBeNotNull)
+        """
+        self._test_cases.append(create_mock_test_case(test))
+
+    def add_tests(self, *tests: BaseTest) -> None:
+        """Add multiple test definitions at once.
+
+        Args:
+            *tests: Variable number of test definitions
+        """
+        self._test_cases.extend(create_mock_test_case(t) for t in tests)
+
+    def add_openmetadata_test(self, test_fqn: str) -> None:
+        """Load a test case from OpenMetadata by its fully qualified name."""
+        test_case = cast(
+            TestCase,
+            self._client.get_by_name(
+                TestCase,
+                test_fqn,
+                fields=["testDefinition", "testSuite"],
+                nullable=False,
+            ),
+        )
+
+        self._test_cases.append(test_case)
+
+    def add_openmetadata_table_tests(self, table_fqn: str) -> None:
+        """Load every test case attached to a table's executable test suite."""
+        test_suite = self._client.get_executable_test_suite(table_fqn)
+
+        if test_suite is None:
+            raise ValueError(f"Table {table_fqn!r} does not have a test suite to run")
+
+        for test in test_suite.tests or []:
+            assert test.fullyQualifiedName is not None
+            self.add_openmetadata_test(test.fullyQualifiedName)
+
+    def validate(
+        self,
+        df: DataFrame,
+        mode: FailureMode = FailureMode.SHORT_CIRCUIT,
+    ) -> ValidationResult:
+        """Execute all configured tests on the DataFrame.
+
+        Args:
+            df: DataFrame to validate
+            mode: Validation mode (`FailureMode.SHORT_CIRCUIT` stops on first failure)
+
+        Returns:
+            ValidationResult with outcomes for all tests
+        """
+        engine = DataFrameValidationEngine(self._test_cases)
+        return engine.execute(df, mode)
+
+    def _check_full_table_tests_included(self) -> None:
+        """Warn when configured tests need the full table but will run on chunks."""
+        test_names: set[str] = {  # pyright: ignore[reportAssignmentType]
+            test.testDefinition.fullyQualifiedName
+            for test in self._test_cases
+            if requires_whole_table(
+                test.testDefinition.fullyQualifiedName  # pyright: ignore[reportArgumentType]
+            )
+        }
+
+        if not test_names:
+            return
+
+        warnings.warn(
+            WholeTableTestsWarning(
+                "Running tests that require the whole table on chunks could lead to false positives. "
+                + "For example, a DataFrame with 200 rows split in chunks of 50 could pass tests expecting "
+                + "DataFrames to contain max 100 rows.\n\n"
+                + "The following tests could have unexpected results:\n\n\t- "
+                + "\n\t- ".join(sorted(test_names))
+            )
+        )
+
+    def run(
+        self,
+        data: Iterable[DataFrame],
+        on_success: ValidatorCallback,
+        on_failure: ValidatorCallback,
+        mode: FailureMode = FailureMode.SHORT_CIRCUIT,
+    ) -> ValidationResult:
+        """Execute all configured tests on each DataFrame and call the callbacks.
+
+        Useful for running validation based on chunks, for example:
+
+        ```python
+        validator = DataFrameValidator()
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+
+        def load_df_to_destination(df, result):
+            ...
+
+        def rollback(df, result):
+            "Clears data previously loaded"
+            ...
+
+        result = validator.run(
+            pandas.read_csv('somefile.csv', chunksize=1000),
+            on_success=load_df_to_destination,
+            on_failure=rollback,
+            mode=FailureMode.SHORT_CIRCUIT,
+        )
+        ```
+
+        Args:
+            data: An iterable of pandas DataFrames
+            on_success: Callback to execute after successful validation
+            on_failure: Callback to execute after failed validation
+            mode: Validation mode (`FailureMode.SHORT_CIRCUIT` stops on first failure)
+
+        Returns:
+            Merged ValidationResult aggregating all batch validations
+        """
+        self._check_full_table_tests_included()
+
+        results: List[ValidationResult] = []
+
+        for df in data:
+            validation_result = self.validate(df, mode)
+            results.append(validation_result)
+
+            if validation_result.success:
+                on_success(df, validation_result)
+            else:
+                on_failure(df, validation_result)
+
+                if mode is FailureMode.SHORT_CIRCUIT:
+                    break
+
+        return ValidationResult.merge(*results)
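A hedged sketch of mixing both test sources on the facade, assuming a configured client and an ingested table (the FQN and parquet path are hypothetical). Publishing afterwards pushes results back and registers the code-defined test in OpenMetadata, mirroring the integration test at the end of this patch:

```python
import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeNotNull
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator

TABLE_FQN = "DbService.database_name.schema_name.dwh_table"  # hypothetical

validator = DataFrameValidator()
# Tests maintained in OpenMetadata...
validator.add_openmetadata_table_tests(TABLE_FQN)
# ...combined with a test defined in code.
validator.add_test(ColumnValuesToBeNotNull(column="email"))

result = validator.validate(pd.read_parquet("dwh_table.parquet"))
result.publish(TABLE_FQN)  # also creates the code-defined test case server-side
```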
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/models.py b/ingestion/src/metadata/sdk/data_quality/dataframes/models.py
new file mode 100644
index 00000000000..40edc1d5a49
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/models.py
@@ -0,0 +1,44 @@
+from uuid import uuid4
+
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.sdk.data_quality import BaseTest, ColumnTest
+
+
+class MockTestCase(TestCase):
+    """Mock test case."""
+
+
+def create_mock_test_case(test_definition: BaseTest) -> MockTestCase:
+    """Convert an in-code test definition (BaseTest) to a TestCase object.
+
+    Returns:
+        Synthetic TestCase for DataFrame validation
+    """
+    entity_link = "<#E::table::dataframe_validation>"
+    if isinstance(test_definition, ColumnTest):
+        entity_link = (
+            f"<#E::table::dataframe_validation::columns::{test_definition.column_name}>"
+        )
+
+    return MockTestCase(  # pyright: ignore[reportCallIssue]
+        id=uuid4(),
+        name=test_definition.name,
+        fullyQualifiedName=test_definition.name,
+        displayName=test_definition.display_name,
+        description=test_definition.description,
+        testDefinition=EntityReference(  # pyright: ignore[reportCallIssue]
+            id=uuid4(),
+            name=test_definition.test_definition_name,
+            fullyQualifiedName=test_definition.test_definition_name,
+            type="testDefinition",
+        ),
+        entityLink=entity_link,
+        parameterValues=test_definition.parameters,
+        testSuite=EntityReference(  # pyright: ignore[reportCallIssue]
+            id=uuid4(),
+            name="dataframe_validation",
+            type="testSuite",
+        ),
+        computePassedFailedRowCount=True,
+    )
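Before the result models themselves, a hedged sketch of the merge semantics implemented in `validation_results.py` below, assuming the SDK default client is configured (only in-code tests are used, so the client is not otherwise touched):

```python
import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeNotNull
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
from metadata.sdk.data_quality.dataframes.validation_results import ValidationResult

validator = DataFrameValidator()
validator.add_test(ColumnValuesToBeNotNull(column="email"))

result_a = validator.validate(pd.DataFrame({"email": ["a@example.com"]}))
result_b = validator.validate(pd.DataFrame({"email": [None]}))

merged = ValidationResult.merge(result_a, result_b)
# Results are grouped by test case FQN: passedRows/failedRows are summed and a
# single Failed batch is enough to mark the merged test case as Failed.
print(merged.success, merged.total_tests, merged.failed_tests)  # False 1 1
```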
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/validation_results.py b/ingestion/src/metadata/sdk/data_quality/dataframes/validation_results.py
new file mode 100644
index 00000000000..171f775fce7
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/validation_results.py
@@ -0,0 +1,243 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DataFrame validation result models."""
+import logging
+from enum import Enum
+from typing import List, Optional, Tuple, cast
+
+from pydantic import BaseModel
+
+from metadata.generated.schema.entity.data.table import Table
+from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.generated.schema.type.basic import FullyQualifiedEntityName
+from metadata.sdk import OpenMetadata
+from metadata.sdk import client as get_client
+from metadata.sdk.data_quality.dataframes.models import MockTestCase
+from metadata.utils.entity_link import (
+    get_column_name_or_none,
+    get_entity_link,  # pyright: ignore[reportUnknownVariableType]
+)
+
+logger = logging.getLogger(__name__)
+
+
+class FailureMode(Enum):
+    SHORT_CIRCUIT = "short-circuit"
+
+
+class ValidationResult(BaseModel):
+    """Aggregated results from validating multiple tests on a DataFrame.
+
+    Attributes:
+        success: True if all tests passed
+        total_tests: Total number of tests executed
+        passed_tests: Number of tests that passed
+        failed_tests: Number of tests that failed
+        test_cases_and_results: Pairs of executed test case and its result
+        execution_time_ms: Total execution time in milliseconds
+    """
+
+    success: bool
+    total_tests: int
+    passed_tests: int
+    failed_tests: int
+    test_cases_and_results: List[Tuple[TestCase, TestCaseResult]]
+    execution_time_ms: float
+
+    @property
+    def failures(self) -> List[TestCaseResult]:
+        """Get only failed test results.
+
+        Returns:
+            List of test results where status is Failed or Aborted
+        """
+        return [
+            result
+            for result in self.test_results
+            if result.testCaseStatus in (TestCaseStatus.Failed, TestCaseStatus.Aborted)
+        ]
+
+    @property
+    def passes(self) -> List[TestCaseResult]:
+        """Get only passed test results.
+
+        Returns:
+            List of test results where status is Success
+        """
+        return [
+            result
+            for result in self.test_results
+            if result.testCaseStatus == TestCaseStatus.Success
+        ]
+
+    @property
+    def test_results(self) -> List[TestCaseResult]:
+        """Get all test results."""
+        return [result for _, result in self.test_cases_and_results]
+
+    def publish(self, table_fqn: str, client: Optional[OpenMetadata] = None) -> None:
+        """Publish test results to OpenMetadata.
+
+        Args:
+            table_fqn: Fully qualified table name
+            client: OpenMetadata client
+        """
+        if client is None:
+            client = get_client()
+
+        metadata = client.ometa
+
+        for test_case, result in self.test_cases_and_results:
+            if isinstance(test_case, MockTestCase):
+                test_case = metadata.get_or_create_test_case(
+                    test_case_fqn=f"{table_fqn}.{test_case.name.root}",
+                    entity_link=get_entity_link(
+                        Table,
+                        table_fqn,
+                        column_name=get_column_name_or_none(test_case.entityLink.root),
+                    ),
+                    test_definition_fqn=test_case.testDefinition.fullyQualifiedName,
+                    test_case_parameter_values=test_case.parameterValues,
+                    description=getattr(test_case.description, "root", None),
+                )
+
+            res = metadata.add_test_case_results(
+                result,
+                cast(FullyQualifiedEntityName, test_case.fullyQualifiedName).root,
+            )
+
+            logger.debug(f"Result: {res}")
+
+    @classmethod
+    def merge(cls, *results: "ValidationResult") -> "ValidationResult":
+        """Merge multiple ValidationResult objects into one.
+
+        Aggregates results from multiple validation runs, useful when validating
+        DataFrames in batches. When the same test case is run multiple times across
+        batches, results are aggregated by test case FQN.
+ + Args: + *results: Variable number of ValidationResult objects to merge + + Returns: + A new ValidationResult with aggregated test case results + + Raises: + ValueError: If no results are provided to merge + """ + if not results: + raise ValueError("At least one ValidationResult must be provided to merge") + + from collections import defaultdict + + aggregated_results: dict[ + str, List[Tuple[TestCase, TestCaseResult]] + ] = defaultdict(list) + total_execution_time = 0.0 + + for result in results: + for test_case, test_result in result.test_cases_and_results: + fqn = test_case.fullyQualifiedName + if fqn is None: + raise ValueError( + "Cannot merge results with test cases that have no fullyQualifiedName" + ) + aggregated_results[str(fqn)].append((test_case, test_result)) + total_execution_time += result.execution_time_ms + + merged_test_cases_and_results: List[Tuple[TestCase, TestCaseResult]] = [] + for fqn, test_cases_and_results_for_fqn in aggregated_results.items(): + test_case = test_cases_and_results_for_fqn[0][0] + results_for_test = [result for _, result in test_cases_and_results_for_fqn] + + merged_result = cls._aggregate_test_case_results(results_for_test) + merged_test_cases_and_results.append((test_case, merged_result)) + + total = len(merged_test_cases_and_results) + passed = sum( + 1 + for _, test_result in merged_test_cases_and_results + if test_result.testCaseStatus is TestCaseStatus.Success + ) + failed = total - passed + + return cls( + success=failed == 0, + total_tests=total, + passed_tests=passed, + failed_tests=failed, + test_cases_and_results=merged_test_cases_and_results, + execution_time_ms=total_execution_time, + ) + + @staticmethod + def _aggregate_test_case_results( + results: List[TestCaseResult], + ) -> TestCaseResult: + """Aggregate multiple TestCaseResult objects for the same test case. + + Combines metrics from multiple test runs by summing passed/failed rows + and determining overall status. 
+ + Args: + results: List of TestCaseResult objects from different batch runs + + Returns: + A single aggregated TestCaseResult + """ + if not results: + raise ValueError("At least one TestCaseResult must be provided") + + if len(results) == 1: + return results[0] + + total_passed_rows = sum(r.passedRows or 0 for r in results) + total_failed_rows = sum(r.failedRows or 0 for r in results) + total_rows = total_passed_rows + total_failed_rows + + passed_rows_percentage = ( + (total_passed_rows / total_rows * 100) if total_rows > 0 else None + ) + failed_rows_percentage = ( + (total_failed_rows / total_rows * 100) if total_rows > 0 else None + ) + + overall_status = TestCaseStatus.Success + if any(r.testCaseStatus == TestCaseStatus.Aborted for r in results): + overall_status = TestCaseStatus.Aborted + elif any(r.testCaseStatus == TestCaseStatus.Failed for r in results): + overall_status = TestCaseStatus.Failed + + first_result = results[0] + merged_result_messages = [r.result for r in results if r.result] + + return TestCaseResult( + id=None, + testCaseFQN=first_result.testCaseFQN, + timestamp=first_result.timestamp, + testCaseStatus=overall_status, + result=( + "; ".join(merged_result_messages) if merged_result_messages else None + ), + sampleData=None, + testResultValue=None, + passedRows=total_passed_rows if total_rows > 0 else None, + failedRows=total_failed_rows if total_rows > 0 else None, + passedRowsPercentage=passed_rows_percentage, + failedRowsPercentage=failed_rows_percentage, + incidentId=None, + maxBound=first_result.maxBound, + minBound=first_result.minBound, + testCase=first_result.testCase, + testDefinition=first_result.testDefinition, + dimensionResults=None, + ) diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/validators.py b/ingestion/src/metadata/sdk/data_quality/dataframes/validators.py new file mode 100644 index 00000000000..4430796b576 --- /dev/null +++ b/ingestion/src/metadata/sdk/data_quality/dataframes/validators.py @@ -0,0 +1,133 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Registry of pandas validators.""" + +from metadata.data_quality.validations.column.pandas.columnValueLengthsToBeBetween import ( + ColumnValueLengthsToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMaxToBeBetween import ( + ColumnValueMaxToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMeanToBeBetween import ( + ColumnValueMeanToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMedianToBeBetween import ( + ColumnValueMedianToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMinToBeBetween import ( + ColumnValueMinToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesMissingCount import ( + ColumnValuesMissingCountValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesSumToBeBetween import ( + ColumnValuesSumToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueStdDevToBeBetween import ( + ColumnValueStdDevToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeAtExpectedLocation import ( + ColumnValuesToBeAtExpectedLocationValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeBetween import ( + ColumnValuesToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeInSet import ( + ColumnValuesToBeInSetValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeNotInSet import ( + ColumnValuesToBeNotInSetValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeNotNull import ( + ColumnValuesToBeNotNullValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeUnique import ( + ColumnValuesToBeUniqueValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToMatchRegex import ( + ColumnValuesToMatchRegexValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToNotMatchRegex import ( + ColumnValuesToNotMatchRegexValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnCountToBeBetween import ( + TableColumnCountToBeBetweenValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnCountToEqual import ( + TableColumnCountToEqualValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnNameToExist import ( + TableColumnNameToExistValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnToMatchSet import ( + TableColumnToMatchSetValidator, +) +from metadata.data_quality.validations.table.pandas.tableRowCountToBeBetween import ( + TableRowCountToBeBetweenValidator, +) +from metadata.data_quality.validations.table.pandas.tableRowCountToEqual import ( + TableRowCountToEqualValidator, +) + +VALIDATOR_REGISTRY = { + "columnValuesToBeNotNull": ColumnValuesToBeNotNullValidator, + "columnValuesToBeUnique": ColumnValuesToBeUniqueValidator, + "columnValuesToBeBetween": ColumnValuesToBeBetweenValidator, + "columnValuesToBeInSet": ColumnValuesToBeInSetValidator, + "columnValuesToBeNotInSet": ColumnValuesToBeNotInSetValidator, + "columnValuesToMatchRegex": ColumnValuesToMatchRegexValidator, + "columnValuesToNotMatchRegex": ColumnValuesToNotMatchRegexValidator, + "columnValueLengthsToBeBetween": ColumnValueLengthsToBeBetweenValidator, + "columnValueMaxToBeBetween": ColumnValueMaxToBeBetweenValidator, + "columnValueMeanToBeBetween": 
ColumnValueMeanToBeBetweenValidator, + "columnValueMedianToBeBetween": ColumnValueMedianToBeBetweenValidator, + "columnValueMinToBeBetween": ColumnValueMinToBeBetweenValidator, + "columnValueStdDevToBeBetween": ColumnValueStdDevToBeBetweenValidator, + "columnValuesSumToBeBetween": ColumnValuesSumToBeBetweenValidator, + "columnValuesMissingCount": ColumnValuesMissingCountValidator, + "columnValuesToBeAtExpectedLocation": ColumnValuesToBeAtExpectedLocationValidator, + "tableRowCountToBeBetween": TableRowCountToBeBetweenValidator, + "tableRowCountToEqual": TableRowCountToEqualValidator, + "tableColumnCountToBeBetween": TableColumnCountToBeBetweenValidator, + "tableColumnCountToEqual": TableColumnCountToEqualValidator, + "tableColumnNameToExist": TableColumnNameToExistValidator, + "tableColumnToMatchSet": TableColumnToMatchSetValidator, +} + + +VALIDATORS_THAT_REQUIRE_FULL_TABLE = { + "columnValuesToBeUnique", + "columnValueMeanToBeBetween", + "columnValueMedianToBeBetween", + "columnValueStdDevToBeBetween", + "columnValuesSumToBeBetween", + "columnValuesMissingCount", + "tableRowCountToBeBetween", + "tableRowCountToEqual", +} + + +def requires_whole_table(validator_name: str) -> bool: + """Whether the validator requires a whole table to return appropriate results + + Examples: + - `columnValuesToBeUnique` needs to see the whole column to make sure uniqueness is met + - `tableRowCountToEqual` needs to see the whole table to make sure the expected row count is met + + These tests could return false positives when operating on batches + + Args: + validator_name: The name of the validator to check + + Returns: + Whether the validator requires a whole table to return appropriate results + """ + return validator_name in VALIDATORS_THAT_REQUIRE_FULL_TABLE diff --git a/ingestion/src/metadata/sdk/examples/dataframe_validation_example.py b/ingestion/src/metadata/sdk/examples/dataframe_validation_example.py new file mode 100644 index 00000000000..8959a8a1687 --- /dev/null +++ b/ingestion/src/metadata/sdk/examples/dataframe_validation_example.py @@ -0,0 +1,294 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Examples demonstrating DataFrame validation with OpenMetadata SDK."""
+# pyright: reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportUnknownMemberType=false
+# pyright: reportUnusedCallResult=false
+# pylint: disable=W5001
+import pandas as pd
+
+from metadata.generated.schema.tests.basic import TestCaseStatus
+from metadata.sdk import configure
+from metadata.sdk.data_quality import (
+    ColumnValuesToBeBetween,
+    ColumnValuesToBeNotNull,
+    ColumnValuesToBeUnique,
+    ColumnValuesToMatchRegex,
+    TableRowCountToBeBetween,
+)
+from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
+from metadata.sdk.data_quality.dataframes.validation_results import (
+    FailureMode,
+    ValidationResult,
+)
+
+
+def basic_validation_example():
+    """Basic example validating a customer DataFrame."""
+    print("\n=== Basic DataFrame Validation Example ===\n")
+
+    df = pd.DataFrame(
+        {
+            "customer_id": [1, 2, 3, 4, 5],
+            "email": [
+                "alice@example.com",
+                "bob@example.com",
+                "carol@example.com",
+                "dave@example.com",
+                "eve@example.com",
+            ],
+            "age": [25, 30, 35, 40, 45],
+        }
+    )
+
+    validator = DataFrameValidator()
+    validator.add_test(ColumnValuesToBeNotNull(column="email"))
+    validator.add_test(ColumnValuesToBeUnique(column="customer_id"))
+
+    result = validator.validate(df)
+
+    if result.success:
+        print("✓ All validations passed!")
+        print(
+            f"  Executed {result.total_tests} tests in {result.execution_time_ms:.2f}ms"
+        )
+    else:
+        print("✗ Validation failed")
+        for test_case, test_result in result.test_cases_and_results:
+            if test_result.testCaseStatus != TestCaseStatus.Success:
+                print(f"  - {test_case.name.root}: {test_result.result}")
+
+
+def multiple_tests_example():
+    """Example with multiple validation rules."""
+    print("\n=== Multiple Tests Validation Example ===\n")
+
+    df = pd.DataFrame(
+        {
+            "customer_id": [1, 2, 3, 4, 5],
+            "email": [
+                "alice@example.com",
+                "bob@example.com",
+                "carol@example.com",
+                "dave@example.com",
+                "eve@example.com",
+            ],
+            "age": [25, 30, 35, 40, 45],
+            "status": ["active", "active", "inactive", "active", "active"],
+        }
+    )
+
+    validator = DataFrameValidator()
+    validator.add_tests(
+        TableRowCountToBeBetween(min_count=1, max_count=1000),
+        ColumnValuesToBeNotNull(column="customer_id"),
+        ColumnValuesToBeNotNull(column="email"),
+        ColumnValuesToBeUnique(column="customer_id"),
+        ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
+        ColumnValuesToMatchRegex(column="email", regex=r"^[\w\.-]+@[\w\.-]+\.\w+$"),
+    )
+
+    result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+    print(f"Validation: {'PASSED' if result.success else 'FAILED'}")
+    print(f"Tests: {result.passed_tests}/{result.total_tests} passed")
+    print(f"Execution time: {result.execution_time_ms:.2f}ms\n")
+
+    for test_case, test_result in result.test_cases_and_results:
+        status_icon = "✓" if test_result.testCaseStatus == TestCaseStatus.Success else "✗"
+        print(f"{status_icon} {test_case.name.root}")
+        total_rows = (test_result.passedRows or 0) + (test_result.failedRows or 0)
+        if test_result.passedRows:
+            print(f"   Passed: {test_result.passedRows}/{total_rows} rows")
+        if test_result.failedRows:
+            percentage = test_result.failedRows / total_rows * 100
+            print(f"   Failed: {test_result.failedRows} rows ({percentage:.1f}%)")
+
+
+def integrating_with_openmetadata_example():
+    """Integrating with OpenMetadata."""
+
+    def transform_to_dwh_table(raw_df: pd.DataFrame) -> pd.DataFrame:
+        """Transform the dataframe to dwh table."""
+        return raw_df
+
+    configure(host="http://localhost:8585/api", jwt_token="your jwt token")
+
+    df = pd.read_parquet("s3://some_bucket/raw_table.parquet")
+
+    df = transform_to_dwh_table(df)
+
+    # Instantiate the validator and load the executable test suite for a table
+    validator = DataFrameValidator()
+    validator.add_openmetadata_table_tests(
+        "DbService.database_name.schema_name.dwh_table"
+    )
+
+    result = validator.validate(df)
+    print(f"Validation: {'PASSED' if result.success else 'FAILED'}")
+
+    # Publish the results back to OpenMetadata
+    result.publish("DbService.database_name.schema_name.dwh_table")
+
+    if result.success:
+        df.to_parquet("s3://some_bucket/dwh_table.parquet")
+
+
+def processing_big_data_with_chunks_example():
+    """Processing big data with chunks."""
+
+    configure(host="http://localhost:8585/api", jwt_token="your jwt token")
+
+    validator = DataFrameValidator()
+    validator.add_openmetadata_table_tests(
+        "DbService.database_name.schema_name.dwh_table"
+    )
+
+    def load_df_to_destination(_df: pd.DataFrame, _result: ValidationResult):
+        """Loads data into destination."""
+
+    def rollback(_df: pd.DataFrame, _result: ValidationResult):
+        """Clears data previously loaded"""
+
+    results = validator.run(
+        pd.read_csv("somefile.csv", chunksize=1000),
+        on_success=load_df_to_destination,
+        on_failure=rollback,
+        mode=FailureMode.SHORT_CIRCUIT,
+    )
+
+    results.publish("DbService.database_name.schema_name.dwh_table")
+
+
+def validation_failure_example():
+    """Example demonstrating validation failures."""
+    print("\n=== Validation Failure Example ===\n")
+
+    df = pd.DataFrame(
+        {
+            "customer_id": [1, 2, 2, 4, 5],
+            "email": [
+                "alice@example.com",
+                None,
+                "carol@example.com",
+                "dave@example.com",
+                "invalid-email",
+            ],
+            "age": [25, 150, 35, -5, 45],
+        }
+    )
+
+    validator = DataFrameValidator()
+    validator.add_tests(
+        ColumnValuesToBeUnique(column="customer_id"),
+        ColumnValuesToBeNotNull(column="email"),
+        ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
+    )
+
+    result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+    print(f"Validation: {'PASSED' if result.success else 'FAILED'}\n")
+
+    if not result.success:
+        print("Failures detected:")
+        for test_case, test_result in result.test_cases_and_results:
+            if test_result.testCaseStatus == TestCaseStatus.Success:
+                continue
+            total_rows = (test_result.passedRows or 0) + (test_result.failedRows or 0)
+            print(f"\n  Test: {test_case.name.root}")
+            print(f"  Definition: {test_case.testDefinition.fullyQualifiedName}")
+            print(f"  Message: {test_result.result}")
+            print(f"  Failed rows: {test_result.failedRows}/{total_rows}")
+
+
+def etl_pipeline_integration_example():
+    """Example integrating validation into an ETL pipeline."""
+    print("\n=== ETL Pipeline Integration Example ===\n")
+
+    def extract_data():
+        return pd.DataFrame(
+            {
+                "id": [1, 2, 3, 4, 5],
+                "name": ["Alice", "Bob", "Carol", "Dave", "Eve"],
+                "value": [100, 200, 300, 400, 500],
+            }
+        )
+
+    def transform_data(df: pd.DataFrame) -> pd.DataFrame:
+        df = df.copy()
+        df["value_doubled"] = df["value"] * 2
+        return df
+
+    def validate_data(df: pd.DataFrame) -> ValidationResult:
+        validator = DataFrameValidator()
+        validator.add_tests(
+            TableRowCountToBeBetween(min_count=1, max_count=10000),
+            ColumnValuesToBeNotNull(column="id"),
+            ColumnValuesToBeUnique(column="id"),
+            ColumnValuesToBeBetween(column="value", min_value=0, max_value=10000),
+        )
+        return validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+    def load_data(df: pd.DataFrame) -> None:
+        print(f"Loading {len(df)} rows to data warehouse...")
+
+    print("Starting ETL pipeline...")
+    print("\n1. Extract")
+    raw_df = extract_data()
+    print(f"   Extracted {len(raw_df)} rows")
+
+    print("\n2. Transform")
+    transformed_df = transform_data(raw_df)
+    print(f"   Transformed {len(transformed_df)} rows")
+
+    print("\n3. Validate")
+    validation_result = validate_data(transformed_df)
+
+    if validation_result.success:
+        print("   ✓ Validation passed")
+        print("\n4. Load")
+        load_data(transformed_df)
+        print("   ✓ Data loaded successfully")
+    else:
+        print("   ✗ Validation failed")
+        print("\n   Failures:")
+        for test_case, test_result in validation_result.test_cases_and_results:
+            if test_result.testCaseStatus != TestCaseStatus.Success:
+                print(f"    - {test_case.name.root}: {test_result.result}")
+        print("\n   Pipeline aborted - data not loaded")
+
+
+def short_circuit_mode_example():
+    """Example demonstrating short-circuit mode behavior."""
+    print("\n=== Short-Circuit Mode Example ===\n")
+
+    df = pd.DataFrame(
+        {
+            "id": [1, 2, 2, 3],
+            "email": [None, None, None, "test@example.com"],
+            "age": [25, 30, 35, 40],
+        }
+    )
+
+    validator = DataFrameValidator()
+    validator.add_tests(
+        ColumnValuesToBeUnique(column="id"),
+        ColumnValuesToBeNotNull(column="email"),
+        ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
+    )
+
+    result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+    print("Short-circuit mode stops at first failure:")
+    print(f"  Tests executed: {result.total_tests} of 3 configured")
+    failed_case = next(
+        test_case
+        for test_case, test_result in result.test_cases_and_results
+        if test_result.testCaseStatus != TestCaseStatus.Success
+    )
+    print(f"  First failure: {failed_case.name.root}")
+    print("\nRemaining tests were not executed due to short-circuit mode.")
+
+
+if __name__ == "__main__":
+    basic_validation_example()
+    multiple_tests_example()
+    validation_failure_example()
+    etl_pipeline_integration_example()
+    short_circuit_mode_example()
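The `entity_link.py` change below adds `get_column_name_or_none`. A quick illustration of the expected behavior, with link values assumed from the entity-link format used by the mock test cases earlier in this patch:

```python
from metadata.utils.entity_link import get_column_name_or_none

# Column-level entity links carry a trailing "columns::<name>" segment.
assert (
    get_column_name_or_none("<#E::table::dataframe_validation::columns::email>")
    == "email"
)
# Table-level links have no column segment.
assert get_column_name_or_none("<#E::table::dataframe_validation>") is None
```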
diff --git a/ingestion/src/metadata/utils/entity_link.py b/ingestion/src/metadata/utils/entity_link.py
index 3408d20e9d9..a3ba3bad16c 100644
--- a/ingestion/src/metadata/utils/entity_link.py
+++ b/ingestion/src/metadata/utils/entity_link.py
@@ -106,6 +106,21 @@ def get_table_or_column_fqn(entity_link: str) -> str:
     )
 
 
+def get_column_name_or_none(entity_link: str) -> Optional[str]:
+    """Attempt to extract the column name from an entity link
+
+    Args:
+        entity_link: entity link
+
+    Returns:
+        The column name or None
+    """
+    split_entity_link = split(entity_link)
+    if len(split_entity_link) == 4 and split_entity_link[2] == "columns":
+        return split_entity_link[3]
+    return None
+
+
 get_entity_link_registry = class_register()
diff --git a/ingestion/tests/integration/sdk/conftest.py b/ingestion/tests/integration/sdk/conftest.py
index 0854c14b63a..1d1e1bbff13 100644
--- a/ingestion/tests/integration/sdk/conftest.py
+++ b/ingestion/tests/integration/sdk/conftest.py
@@ -2,12 +2,252 @@
 Minimal conftest for SDK integration tests.
 
 Override the parent conftest to avoid testcontainers dependency.
""" + import pytest +from sqlalchemy import Column as SQAColumn +from sqlalchemy import Integer, MetaData, String +from sqlalchemy import Table as SQATable +from sqlalchemy import create_engine from _openmetadata_testutils.ometa import int_admin_ometa +from _openmetadata_testutils.postgres.conftest import postgres_container +from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest +from metadata.generated.schema.api.data.createDatabaseSchema import ( + CreateDatabaseSchemaRequest, +) +from metadata.generated.schema.api.services.createDatabaseService import ( + CreateDatabaseServiceRequest, +) +from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( + BasicAuth, +) +from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( + PostgresConnection, +) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseConnection, + DatabaseService, + DatabaseServiceType, +) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.workflow.metadata import MetadataWorkflow -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def metadata(): - """Provide authenticated OpenMetadata client""" return int_admin_ometa() + + +@pytest.fixture(scope="module") +def create_postgres_service(postgres_container, tmp_path_factory): + return CreateDatabaseServiceRequest( + name="dq_test_service_" + tmp_path_factory.mktemp("dq").name, + serviceType=DatabaseServiceType.Postgres, + connection=DatabaseConnection( + config=PostgresConnection( + username=postgres_container.username, + authType=BasicAuth(password=postgres_container.password), + hostPort="localhost:" + + str(postgres_container.get_exposed_port(postgres_container.port)), + database="dq_test_db", + ) + ), + ) + + +@pytest.fixture(scope="module") +def db_service(metadata, create_postgres_service, postgres_container): + engine = create_engine( + postgres_container.get_connection_url(), isolation_level="AUTOCOMMIT" + ) + engine.execute("CREATE DATABASE dq_test_db") + + service_entity = metadata.create_or_update(data=create_postgres_service) + service_entity.connection.config.authType.password = ( + create_postgres_service.connection.config.authType.password + ) + yield service_entity + + service = metadata.get_by_name( + DatabaseService, service_entity.fullyQualifiedName.root + ) + if service: + metadata.delete(DatabaseService, service.id, recursive=True, hard_delete=True) + + +@pytest.fixture(scope="module") +def database(metadata, db_service): + database_entity = metadata.create_or_update( + CreateDatabaseRequest( + name="dq_test_db", + service=db_service.fullyQualifiedName, + ) + ) + return database_entity + + +@pytest.fixture(scope="module") +def schema(metadata, database): + schema_entity = metadata.create_or_update( + CreateDatabaseSchemaRequest( + name="public", + database=database.fullyQualifiedName, + ) + ) + return schema_entity + + +@pytest.fixture(scope="module") +def test_data(db_service, postgres_container): + engine = create_engine( + postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db") + ) + + sql_metadata = MetaData() + + users_table = SQATable( + "users", + sql_metadata, + SQAColumn("id", Integer, primary_key=True), + SQAColumn("username", String(50), nullable=False), + SQAColumn("email", String(100)), + SQAColumn("age", Integer), + SQAColumn("score", Integer), + ) + + products_table = SQATable( + "products", + sql_metadata, + SQAColumn("product_id", Integer, 
primary_key=True), + SQAColumn("name", String(100)), + SQAColumn("price", Integer), + ) + + stg_products_table = SQATable( + "stg_products", + sql_metadata, + SQAColumn("id", Integer, primary_key=True), + SQAColumn("name", String(100)), + SQAColumn("price", Integer), + ) + + sql_metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute( + users_table.insert(), + [ + { + "id": 1, + "username": "alice", + "email": "alice@example.com", + "age": 25, + "score": 85, + }, + { + "id": 2, + "username": "bob", + "email": "bob@example.com", + "age": 30, + "score": 90, + }, + {"id": 3, "username": "charlie", "email": None, "age": 35, "score": 75}, + { + "id": 4, + "username": "diana", + "email": "diana@example.com", + "age": 28, + "score": 95, + }, + { + "id": 5, + "username": "eve", + "email": "eve@example.com", + "age": 22, + "score": 88, + }, + ], + ) + + conn.execute( + products_table.insert(), + [ + {"product_id": 1, "name": "Widget", "price": 100}, + {"product_id": 2, "name": "Gadget", "price": 200}, + {"product_id": 3, "name": "Doohickey", "price": 150}, + ], + ) + + conn.execute( + stg_products_table.insert(), + [ + {"id": 1, "name": "Widget", "price": 100}, + {"id": 2, "name": "Gadget", "price": 200}, + {"id": 3, "name": "Doohickey", "price": 150}, + ], + ) + + return { + "users": users_table, + "products": products_table, + "stg_products": stg_products_table, + } + + +@pytest.fixture(scope="module") +def ingest_metadata(metadata, db_service, schema, test_data): + workflow_config = { + "source": { + "type": db_service.connection.config.type.value.lower(), + "serviceName": db_service.fullyQualifiedName.root, + "sourceConfig": { + "config": { + "type": "DatabaseMetadata", + "schemaFilterPattern": {"includes": ["public"]}, + } + }, + "serviceConnection": db_service.connection.model_dump(), + }, + "sink": {"type": "metadata-rest", "config": {}}, + "workflowConfig": { + "loggerLevel": "INFO", + "openMetadataServerConfig": metadata.config.model_dump(), + }, + } + + workflow = MetadataWorkflow.create(workflow_config) + workflow.execute() + workflow.raise_from_status() + + return workflow + + +@pytest.fixture(scope="module") +def patch_passwords(db_service, monkeymodule): + def override_password(getter): + def inner(*args, **kwargs): + result = getter(*args, **kwargs) + if isinstance(result, DatabaseService): + if result.fullyQualifiedName.root == db_service.fullyQualifiedName.root: + result.connection.config.authType.password = ( + db_service.connection.config.authType.password + ) + return result + + return inner + + monkeymodule.setattr( + "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_name", + override_password(OpenMetadata.get_by_name), + ) + + monkeymodule.setattr( + "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_id", + override_password(OpenMetadata.get_by_id), + ) + + +@pytest.fixture(scope="module") +def monkeymodule(): + with pytest.MonkeyPatch.context() as mp: + yield mp diff --git a/ingestion/tests/integration/sdk/data_quality/dataframes/test_dataframe_validator.py b/ingestion/tests/integration/sdk/data_quality/dataframes/test_dataframe_validator.py new file mode 100644 index 00000000000..679fe896efd --- /dev/null +++ b/ingestion/tests/integration/sdk/data_quality/dataframes/test_dataframe_validator.py @@ -0,0 +1,255 @@ +from typing import Any, Generator, Mapping +from unittest.mock import Mock, patch + +import pandas +import pytest +from dirty_equals import HasAttributes, IsList, IsTuple +from pandas import DataFrame +from sqlalchemy import 
create_engine +from sqlalchemy.sql.schema import Table as SQATable +from testcontainers.postgres import PostgresContainer + +from metadata.generated.schema.api.tests.createTestCase import CreateTestCaseRequest +from metadata.generated.schema.entity.data.table import Table +from metadata.generated.schema.entity.services.databaseService import DatabaseService +from metadata.generated.schema.tests.basic import TestCaseStatus +from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.sdk.data_quality import ColumnValueMinToBeBetween +from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator +from metadata.utils.entity_link import get_entity_link + + +@pytest.fixture(scope="module") +def table_fqn(db_service: DatabaseService) -> str: + return f"{db_service.name.root}.dq_test_db.public.users" + + +@pytest.fixture(scope="module") +def column_unique_test( + table_fqn: str, metadata: OpenMetadata[TestCase, CreateTestCaseRequest] +) -> TestCase: + request = CreateTestCaseRequest( + name="column_not_null", + testDefinition="columnValuesToBeUnique", + entityLink=get_entity_link( + Table, + table_fqn, + column_name="username", + ), + ) + + test_case = metadata.create_or_update(request) + + return test_case + + +@pytest.fixture(scope="module") +def table_row_count_test( + table_fqn: str, metadata: OpenMetadata[TestCase, CreateTestCaseRequest] +) -> TestCase: + request = CreateTestCaseRequest( + name="table_row_count", + testDefinition="tableRowCountToEqual", + entityLink=get_entity_link(Table, table_fqn), + parameterValues=[TestCaseParameterValue(name="value", value="5")], + ) + + test_case = metadata.create_or_update(request) + + return test_case + + +@pytest.fixture(scope="module") +def dataframe( + test_data: Mapping[str, SQATable], postgres_container: PostgresContainer +) -> Generator[DataFrame, None, None]: + engine = create_engine( + postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db"), + isolation_level="AUTOCOMMIT", + ) + with engine.connect() as connection: + yield pandas.read_sql( + test_data["users"].select(), + connection, + ) + + +def test_it_runs_tests_from_openmetadata( + ingest_metadata: None, + metadata: OpenMetadata[Any, Any], + dataframe: DataFrame, + column_unique_test: TestCase, + table_row_count_test: TestCase, +) -> None: + validator = DataFrameValidator(client=metadata) + + validator.add_openmetadata_test(column_unique_test.fullyQualifiedName.root) + validator.add_openmetadata_test(table_row_count_test.fullyQualifiedName.root) + + result = validator.validate(dataframe) + + assert result == HasAttributes( + success=True, + total_tests=2, + passed_tests=2, + failed_tests=0, + ) + + +def test_it_runs_openmetadata_table_tests( + table_fqn: str, + ingest_metadata: None, + metadata: OpenMetadata[Any, Any], + dataframe: DataFrame, + column_unique_test: TestCase, + table_row_count_test: TestCase, +) -> None: + validator = DataFrameValidator(client=metadata) + + validator.add_openmetadata_table_tests(table_fqn) + + result = validator.validate(dataframe) + + assert result == HasAttributes( + success=True, + total_tests=2, + passed_tests=2, + failed_tests=0, + ) + + +class TestFullUseCase: + def test_it_runs_tests_and_publishes_results( + self, + table_fqn: str, + ingest_metadata: None, + metadata: OpenMetadata[Any, Any], + dataframe: DataFrame, + column_unique_test: TestCase, + table_row_count_test: TestCase, + ) -> None: + # First 
ensure only previously reported tests exist in OM + test_suite = metadata.get_executable_test_suite(table_fqn) + assert test_suite is not None + original_test_names = {t.fullyQualifiedName for t in test_suite.tests} + assert original_test_names == { + column_unique_test.fullyQualifiedName.root, + table_row_count_test.fullyQualifiedName.root, + } + + # Run validation + validator = DataFrameValidator(client=metadata) + + validator.add_openmetadata_table_tests(table_fqn) + # Forcing a failure with this test + validator.add_test( + ColumnValueMinToBeBetween( + name="column_value_min_to_be_between_90_and_100", + column="score", + min_value=90, + max_value=100, + ) + ) + + result = validator.validate(dataframe) + + assert result == HasAttributes( + success=False, + total_tests=3, + passed_tests=2, + failed_tests=1, + test_cases_and_results=IsList( + IsTuple( + HasAttributes( + fullyQualifiedName=HasAttributes( + root=column_unique_test.fullyQualifiedName.root + ), + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Success, + ), + ), + IsTuple( + HasAttributes( + fullyQualifiedName=HasAttributes( + root=table_row_count_test.fullyQualifiedName.root + ), + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Success, + ), + ), + IsTuple( + HasAttributes( + fullyQualifiedName=HasAttributes( + root="column_value_min_to_be_between_90_and_100" + ), + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Failed, + ), + ), + check_order=False, + ), + ) + + # Publish results + with patch( + "metadata.sdk.data_quality.dataframes.validation_results.get_client", + return_value=Mock(ometa=metadata), + ) as mock_client: + result.publish( + table_fqn, + ) + + mock_client.assert_called_once() + + # Validate test defined in code has been pushed to OM as well + test_suite = metadata.get_executable_test_suite(table_fqn) + + assert test_suite is not None + assert len(test_suite.tests) == 3 + assert original_test_names.issubset( + {t.fullyQualifiedName for t in test_suite.tests} + ) + + assert {t.fullyQualifiedName for t in test_suite.tests} == { + column_unique_test.fullyQualifiedName.root, + table_row_count_test.fullyQualifiedName.root, + f"{table_fqn}.score.column_value_min_to_be_between_90_and_100", + } + + # Validate each test case result + required_fields = ["testCaseResult", "testSuite", "testDefinition"] + column_unique_result = metadata.get_by_name( + TestCase, column_unique_test.fullyQualifiedName.root, fields=required_fields + ).testCaseResult + assert column_unique_result == HasAttributes( + testCaseStatus=TestCaseStatus.Success + ) + + table_row_count_result = metadata.get_by_name( + TestCase, + table_row_count_test.fullyQualifiedName.root, + fields=required_fields, + ).testCaseResult + assert table_row_count_result == HasAttributes( + testCaseStatus=TestCaseStatus.Success + ) + + code_test_case_result = metadata.get_by_name( + TestCase, + f"{table_fqn}.score.column_value_min_to_be_between_90_and_100", + fields=required_fields, + ).testCaseResult + assert code_test_case_result == HasAttributes( + testCaseStatus=TestCaseStatus.Failed + ) + + # Clean up code test + metadata.delete_test_case( + f"{table_fqn}.score.column_value_min_to_be_between_90_and_100", + recursive=True, + hard=True, + ) diff --git a/ingestion/tests/integration/sdk/test_dq_as_code_integration.py b/ingestion/tests/integration/sdk/test_dq_as_code_integration.py index ffbb6bcda8e..bea9cc0c79b 100644 --- a/ingestion/tests/integration/sdk/test_dq_as_code_integration.py +++ b/ingestion/tests/integration/sdk/test_dq_as_code_integration.py @@ 
-6,36 +6,11 @@ import sys import pytest from dirty_equals import HasAttributes -from sqlalchemy import Column as SQAColumn -from sqlalchemy import Integer, MetaData, String -from sqlalchemy import Table as SQATable -from sqlalchemy import create_engine -from _openmetadata_testutils.ometa import int_admin_ometa -from _openmetadata_testutils.postgres.conftest import postgres_container -from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest -from metadata.generated.schema.api.data.createDatabaseSchema import ( - CreateDatabaseSchemaRequest, -) -from metadata.generated.schema.api.services.createDatabaseService import ( - CreateDatabaseServiceRequest, -) from metadata.generated.schema.entity.data.table import Table -from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( - BasicAuth, -) -from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( - PostgresConnection, -) -from metadata.generated.schema.entity.services.databaseService import ( - DatabaseConnection, - DatabaseService, - DatabaseServiceType, -) from metadata.generated.schema.tests.basic import TestCaseStatus from metadata.generated.schema.tests.testCase import TestCase from metadata.generated.schema.type.basic import EntityLink -from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.sdk.data_quality import ( ColumnValuesToBeBetween, ColumnValuesToBeNotNull, @@ -45,7 +20,6 @@ from metadata.sdk.data_quality import ( TableRowCountToBeBetween, TestRunner, ) -from metadata.workflow.metadata import MetadataWorkflow if not sys.version_info >= (3, 9): pytest.skip( @@ -54,226 +28,6 @@ if not sys.version_info >= (3, 9): ) -@pytest.fixture(scope="module") -def metadata(): - return int_admin_ometa() - - -@pytest.fixture(scope="module") -def create_postgres_service(postgres_container, tmp_path_factory): - return CreateDatabaseServiceRequest( - name="dq_test_service_" + tmp_path_factory.mktemp("dq").name, - serviceType=DatabaseServiceType.Postgres, - connection=DatabaseConnection( - config=PostgresConnection( - username=postgres_container.username, - authType=BasicAuth(password=postgres_container.password), - hostPort="localhost:" - + str(postgres_container.get_exposed_port(postgres_container.port)), - database="dq_test_db", - ) - ), - ) - - -@pytest.fixture(scope="module") -def db_service(metadata, create_postgres_service, postgres_container): - engine = create_engine( - postgres_container.get_connection_url(), isolation_level="AUTOCOMMIT" - ) - engine.execute("CREATE DATABASE dq_test_db") - - service_entity = metadata.create_or_update(data=create_postgres_service) - service_entity.connection.config.authType.password = ( - create_postgres_service.connection.config.authType.password - ) - yield service_entity - - service = metadata.get_by_name( - DatabaseService, service_entity.fullyQualifiedName.root - ) - if service: - metadata.delete(DatabaseService, service.id, recursive=True, hard_delete=True) - - -@pytest.fixture(scope="module") -def database(metadata, db_service): - database_entity = metadata.create_or_update( - CreateDatabaseRequest( - name="dq_test_db", - service=db_service.fullyQualifiedName, - ) - ) - return database_entity - - -@pytest.fixture(scope="module") -def schema(metadata, database): - schema_entity = metadata.create_or_update( - CreateDatabaseSchemaRequest( - name="public", - database=database.fullyQualifiedName, - ) - ) - return schema_entity - - -@pytest.fixture(scope="module") -def 
test_data(postgres_container): - engine = create_engine( - postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db") - ) - - sql_metadata = MetaData() - - users_table = SQATable( - "users", - sql_metadata, - SQAColumn("id", Integer, primary_key=True), - SQAColumn("username", String(50), nullable=False), - SQAColumn("email", String(100)), - SQAColumn("age", Integer), - SQAColumn("score", Integer), - ) - - products_table = SQATable( - "products", - sql_metadata, - SQAColumn("product_id", Integer, primary_key=True), - SQAColumn("name", String(100)), - SQAColumn("price", Integer), - ) - - stg_products_table = SQATable( - "stg_products", - sql_metadata, - SQAColumn("id", Integer, primary_key=True), - SQAColumn("name", String(100)), - SQAColumn("price", Integer), - ) - - sql_metadata.create_all(engine) - - with engine.connect() as conn: - conn.execute( - users_table.insert(), - [ - { - "id": 1, - "username": "alice", - "email": "alice@example.com", - "age": 25, - "score": 85, - }, - { - "id": 2, - "username": "bob", - "email": "bob@example.com", - "age": 30, - "score": 90, - }, - {"id": 3, "username": "charlie", "email": None, "age": 35, "score": 75}, - { - "id": 4, - "username": "diana", - "email": "diana@example.com", - "age": 28, - "score": 95, - }, - { - "id": 5, - "username": "eve", - "email": "eve@example.com", - "age": 22, - "score": 88, - }, - ], - ) - - conn.execute( - products_table.insert(), - [ - {"product_id": 1, "name": "Widget", "price": 100}, - {"product_id": 2, "name": "Gadget", "price": 200}, - {"product_id": 3, "name": "Doohickey", "price": 150}, - ], - ) - - conn.execute( - stg_products_table.insert(), - [ - {"id": 1, "name": "Widget", "price": 100}, - {"id": 2, "name": "Gadget", "price": 200}, - {"id": 3, "name": "Doohickey", "price": 150}, - ], - ) - - return { - "users": users_table, - "products": products_table, - "stg_products": stg_products_table, - } - - -@pytest.fixture(scope="module") -def ingest_metadata(metadata, db_service, schema, test_data): - workflow_config = { - "source": { - "type": db_service.connection.config.type.value.lower(), - "serviceName": db_service.fullyQualifiedName.root, - "sourceConfig": { - "config": { - "type": "DatabaseMetadata", - "schemaFilterPattern": {"includes": ["public"]}, - } - }, - "serviceConnection": db_service.connection.model_dump(), - }, - "sink": {"type": "metadata-rest", "config": {}}, - "workflowConfig": { - "loggerLevel": "INFO", - "openMetadataServerConfig": metadata.config.model_dump(), - }, - } - - workflow = MetadataWorkflow.create(workflow_config) - workflow.execute() - workflow.raise_from_status() - - return workflow - - -@pytest.fixture(scope="module") -def patch_passwords(db_service, monkeymodule): - def override_password(getter): - def inner(*args, **kwargs): - result = getter(*args, **kwargs) - if isinstance(result, DatabaseService): - if result.fullyQualifiedName.root == db_service.fullyQualifiedName.root: - result.connection.config.authType.password = ( - db_service.connection.config.authType.password - ) - return result - - return inner - - monkeymodule.setattr( - "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_name", - override_password(OpenMetadata.get_by_name), - ) - - monkeymodule.setattr( - "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_id", - override_password(OpenMetadata.get_by_id), - ) - - -@pytest.fixture(scope="module") -def monkeymodule(): - with pytest.MonkeyPatch.context() as mp: - yield mp - - def test_table_row_count_tests( metadata, db_service, diff --git 
a/ingestion/tests/unit/sdk/data_quality/test_dataframe_validator.py b/ingestion/tests/unit/sdk/data_quality/test_dataframe_validator.py new file mode 100644 index 00000000000..3b00c84b286 --- /dev/null +++ b/ingestion/tests/unit/sdk/data_quality/test_dataframe_validator.py @@ -0,0 +1,561 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for DataFrame validator.""" +from typing import Generator, List, Tuple +from unittest.mock import Mock + +import pandas as pd +import pytest +from pandas import DataFrame + +from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus +from metadata.generated.schema.tests.testCase import TestCase +from metadata.sdk.data_quality import ( + ColumnValuesToBeNotNull, + ColumnValuesToBeUnique, + TableRowCountToBeBetween, +) +from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning +from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator +from metadata.sdk.data_quality.dataframes.validation_results import ( + FailureMode, + ValidationResult, +) + + +class TestDataFrameValidator: + """Test DataFrameValidator initialization and configuration.""" + + def test_validator_initialization(self): + """Test validator can be created.""" + validator = DataFrameValidator(Mock()) + assert validator is not None + assert validator._test_cases == [] + + def test_add_single_test(self): + """Test adding a single test definition.""" + validator = DataFrameValidator(Mock()) + validator.add_test(ColumnValuesToBeNotNull(column="email")) + assert len(validator._test_cases) == 1 + + def test_add_multiple_tests(self): + """Test adding multiple test definitions at once.""" + validator = DataFrameValidator(Mock()) + validator.add_tests( + ColumnValuesToBeNotNull(column="email"), + ColumnValuesToBeUnique(column="id"), + TableRowCountToBeBetween(min_count=1, max_count=100), + ) + assert len(validator._test_cases) == 3 + + +class TestValidationSuccess: + """Test successful validation scenarios.""" + + def test_validate_not_null_success(self): + """Test validation passes with valid DataFrame.""" + df = pd.DataFrame({"email": ["a@b.com", "c@d.com", "e@f.com"]}) + validator = DataFrameValidator(Mock()) + validator.add_test(ColumnValuesToBeNotNull(column="email")) + + result = validator.validate(df) + + assert result.success is True + assert result.passed_tests == 1 + assert result.failed_tests == 0 + assert len(result.test_results) == 1 + assert result.test_results[0].testCaseStatus is TestCaseStatus.Success + + def test_validate_unique_success(self): + """Test uniqueness validation passes.""" + df = pd.DataFrame({"id": [1, 2, 3, 4, 5]}) + validator = DataFrameValidator(Mock()) + validator.add_test(ColumnValuesToBeUnique(column="id")) + + result = validator.validate(df) + + assert result.success is True + assert result.passed_tests == 1 + + def test_validate_multiple_tests_success(self): + """Test multiple tests all passing.""" + df = 
pd.DataFrame( + { + "id": [1, 2, 3], + "email": ["a@b.com", "c@d.com", "e@f.com"], + "age": [25, 30, 35], + } + ) + validator = DataFrameValidator(Mock()) + validator.add_tests( + ColumnValuesToBeNotNull(column="email"), + ColumnValuesToBeUnique(column="id"), + TableRowCountToBeBetween(min_count=1, max_count=10), + ) + + result = validator.validate(df) + + assert result.success is True + assert result.passed_tests == 3 + assert result.failed_tests == 0 + + +class TestValidationFailure: + """Test validation failure scenarios.""" + + def test_validate_not_null_failure(self): + """Test validation fails with null values.""" + df = pd.DataFrame({"email": ["a@b.com", None, "e@f.com"]}) + validator = DataFrameValidator(Mock()) + validator.add_test(ColumnValuesToBeNotNull(column="email")) + + result = validator.validate(df) + + assert result.success is False + assert result.passed_tests == 0 + assert result.failed_tests == 1 + assert result.test_results[0].testCaseStatus is TestCaseStatus.Failed + + def test_validate_unique_failure(self): + """Test uniqueness validation fails with duplicates.""" + df = pd.DataFrame({"id": [1, 2, 2, 3]}) + validator = DataFrameValidator(Mock()) + validator.add_test(ColumnValuesToBeUnique(column="id")) + + result = validator.validate(df) + + assert result.success is False + assert result.failed_tests == 1 + + def test_validate_row_count_failure(self): + """Test row count validation fails.""" + df = pd.DataFrame({"col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}) + validator = DataFrameValidator(Mock()) + validator.add_test(TableRowCountToBeBetween(min_count=1, max_count=10)) + + result = validator.validate(df) + + assert result.success is False + + +class TestShortCircuitMode: + """Test short-circuit validation mode.""" + + def test_short_circuit_stops_on_first_failure(self): + """Test short-circuit mode stops after first failure.""" + df = pd.DataFrame( + { + "id": [1, 2, 2], + "email": [None, None, None], + } + ) + validator = DataFrameValidator(Mock()) + validator.add_tests( + ColumnValuesToBeUnique(column="id"), + ColumnValuesToBeNotNull(column="email"), + ) + + result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT) + + assert result.success is False + assert len(result.test_results) == 1 + + def test_short_circuit_continues_on_success(self): + """Test short-circuit mode continues when tests pass.""" + df = pd.DataFrame( + { + "id": [1, 2, 3], + "email": ["a@b.com", "c@d.com", "e@f.com"], + } + ) + validator = DataFrameValidator(Mock()) + validator.add_tests( + ColumnValuesToBeUnique(column="id"), + ColumnValuesToBeNotNull(column="email"), + ) + + result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT) + + assert result.success is True + assert len(result.test_results) == 2 + + +class TestValidationResultProperties: + """Test ValidationResult helper properties.""" + + def test_failures_property(self): + """Test failures property returns only failed tests.""" + df = pd.DataFrame( + { + "id": [1, 2, 2], + "email": ["a@b.com", "c@d.com", "e@f.com"], + } + ) + validator = DataFrameValidator(Mock()) + validator.add_tests( + ColumnValuesToBeUnique(column="id"), + ColumnValuesToBeNotNull(column="email"), + ) + + result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT) + + assert len(result.failures) == 1 + + def test_passes_property(self): + """Test passes property returns only passed tests.""" + df = pd.DataFrame( + { + "id": [1, 2, 3], + "email": ["a@b.com", None, "e@f.com"], + } + ) + validator = DataFrameValidator(Mock()) + validator.add_tests( + 
ColumnValuesToBeUnique(column="id"), + ColumnValuesToBeNotNull(column="email"), + ) + + result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT) + + assert len(result.passes) == 1 + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_dataframe(self): + """Test validation on empty DataFrame.""" + df = pd.DataFrame({"email": []}) + validator = DataFrameValidator(Mock()) + validator.add_test(ColumnValuesToBeNotNull(column="email")) + + result = validator.validate(df) + + assert result.success is True + + def test_missing_column(self): + """Test validation with missing column.""" + df = pd.DataFrame({"name": ["Alice", "Bob"]}) + validator = DataFrameValidator(Mock()) + validator.add_test(ColumnValuesToBeNotNull(column="email")) + + result = validator.validate(df) + + assert result.success is False + assert result.test_results[0].testCaseStatus is TestCaseStatus.Aborted + + def test_no_tests_configured(self): + """Test validation with no tests configured.""" + df = pd.DataFrame({"col": [1, 2, 3]}) + validator = DataFrameValidator(Mock()) + + result = validator.validate(df) + + assert result.success is True + assert result.total_tests == 0 + assert len(result.test_results) == 0 + + def test_execution_time_recorded(self): + """Test that execution time is recorded.""" + df = pd.DataFrame({"col": [1, 2, 3]}) + validator = DataFrameValidator(Mock()) + validator.add_test(TableRowCountToBeBetween(min_count=1, max_count=10)) + + result = validator.validate(df) + + assert result.execution_time_ms > 0 + + +class TracksValidationCallbacks: + def __init__(self) -> None: + self.calls: List[Tuple[DataFrame, ValidationResult]] = [] + + @property + def times_called(self) -> int: + return len(self.calls) + + @property + def was_called(self) -> bool: + return self.times_called > 0 + + def __call__(self, df: DataFrame, result: ValidationResult) -> None: + self.calls.append((df, result)) + + +@pytest.mark.filterwarnings( + "error::metadata.sdk.data_quality.dataframes.custom_warnings.WholeTableTestsWarning" +) +class TestValidatorRun: + @pytest.fixture + def on_success_callback(self) -> TracksValidationCallbacks: + return TracksValidationCallbacks() + + @pytest.fixture + def on_failure_callback(self) -> TracksValidationCallbacks: + return TracksValidationCallbacks() + + @pytest.fixture + def validator(self) -> DataFrameValidator: + df_validator = DataFrameValidator(Mock()) + df_validator.add_tests( + ColumnValuesToBeNotNull(column="id"), + ) + return df_validator + + def test_it_only_calls_on_success( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3)) + + validator.run(dfs, on_success_callback, on_failure_callback) + + assert on_success_callback.times_called == 3 + assert on_failure_callback.was_called is False + + def test_it_calls_both_on_success_and_on_failure( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + def generate_data() -> Generator[DataFrame, None, None]: + for i in range(0, 6, 3): + yield pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) + yield pd.DataFrame({"id": [None]}) + + validator.run(generate_data(), on_success_callback, on_failure_callback) + + assert on_success_callback.times_called == 2 + assert on_failure_callback.times_called == 1 + assert pd.DataFrame({"id": 
[None]}).equals(on_failure_callback.calls[0][0]) + + def test_it_aborts_on_failure( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + def generate_data() -> Generator[DataFrame, None, None]: + yield pd.DataFrame({"id": [None]}) + for i in range(0, 6, 3): + yield pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) + + validator.run(generate_data(), on_success_callback, on_failure_callback) + + assert on_success_callback.was_called is False + assert on_failure_callback.times_called == 1 + assert pd.DataFrame({"id": [None]}).equals(on_failure_callback.calls[0][0]) + + def test_it_warns_when_using_tests_that_require_the_whole_table_or_column( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3)) + + validator.add_tests( + TableRowCountToBeBetween(min_count=1, max_count=10), + ColumnValuesToBeUnique(column="id"), + ) + + expected_warning_match = ( + "The following tests could have unexpected results:\n\n" + + "\t- columnValuesToBeUnique\n" + + "\t- tableRowCountToBeBetween" + ) + + with pytest.warns(WholeTableTestsWarning, match=expected_warning_match): + validator.run(dfs, on_success_callback, on_failure_callback) + + def test_run_returns_merged_validation_result( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3)) + + result = validator.run(dfs, on_success_callback, on_failure_callback) + + assert isinstance(result, ValidationResult) + assert result.success is True + + def test_merged_result_has_correct_aggregated_metrics( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + dfs = [ + pd.DataFrame({"id": [1, 2, 3]}), + pd.DataFrame({"id": [4, 5, 6]}), + pd.DataFrame({"id": [7, 8, 9]}), + ] + + result = validator.run(iter(dfs), on_success_callback, on_failure_callback) + + assert result.total_tests == 1 + assert result.passed_tests == 1 + assert result.failed_tests == 0 + assert result.success is True + + def test_merged_result_aggregates_failures_correctly( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + dfs = [ + pd.DataFrame({"id": [1, 2, 3]}), + pd.DataFrame({"id": [None]}), + ] + + result = validator.run(iter(dfs), on_success_callback, on_failure_callback) + + assert result.total_tests == 1 + assert result.passed_tests == 0 + assert result.failed_tests == 1 + assert result.success is False + + def test_merged_result_contains_aggregated_test_case_results( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + validator.add_test(ColumnValuesToBeUnique(column="id")) + + dfs = [ + pd.DataFrame({"id": [1, 2, 3]}), + pd.DataFrame({"id": [4, 5, 6]}), + ] + + with pytest.warns(WholeTableTestsWarning): + result = validator.run(iter(dfs), on_success_callback, on_failure_callback) + + assert len(result.test_cases_and_results) == 2 + assert all( + isinstance(test_case, TestCase) and isinstance(test_result, TestCaseResult) + for test_case, 
test_result in result.test_cases_and_results + ) + _, aggregated_not_null_result = result.test_cases_and_results[0] + assert aggregated_not_null_result.passedRows == 6 + + def test_merged_result_sums_execution_times( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + dfs = [ + pd.DataFrame({"id": [1, 2, 3]}), + pd.DataFrame({"id": [4, 5, 6]}), + pd.DataFrame({"id": [7, 8, 9]}), + ] + + result = validator.run(iter(dfs), on_success_callback, on_failure_callback) + + assert result.execution_time_ms > 0 + individual_times = [ + call[1].execution_time_ms for call in on_success_callback.calls + ] + assert result.execution_time_ms == sum(individual_times) + + def test_merged_result_with_mixed_success_and_failure( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + def generate_mixed_data() -> Generator[DataFrame, None, None]: + yield pd.DataFrame({"id": [1, 2, 3]}) + yield pd.DataFrame({"id": [4, 5, 6]}) + yield pd.DataFrame({"id": [None]}) + + result = validator.run( + generate_mixed_data(), on_success_callback, on_failure_callback + ) + + assert result.total_tests == 1 + assert result.passed_tests == 0 + assert result.failed_tests == 1 + assert result.success is False + assert len(result.test_cases_and_results) == 1 + + def test_merged_result_aggregates_multiple_tests_per_batch( + self, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + validator = DataFrameValidator(Mock()) + validator.add_tests( + ColumnValuesToBeNotNull(column="id"), + ColumnValuesToBeUnique(column="id"), + TableRowCountToBeBetween(min_count=1, max_count=10), + ) + + dfs = [ + pd.DataFrame({"id": [1, 2, 3]}), + pd.DataFrame({"id": [4, 5, 6]}), + ] + + with pytest.warns(WholeTableTestsWarning): + result = validator.run(iter(dfs), on_success_callback, on_failure_callback) + + assert result.total_tests == 3 + assert result.passed_tests == 3 + assert result.failed_tests == 0 + assert result.success is True + + def test_merged_result_with_short_circuit_on_failure( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + def generate_data() -> Generator[DataFrame, None, None]: + yield pd.DataFrame({"id": [None]}) + yield pd.DataFrame({"id": [1, 2, 3]}) + yield pd.DataFrame({"id": [4, 5, 6]}) + + result = validator.run( + generate_data(), on_success_callback, on_failure_callback + ) + + assert result.total_tests == 1 + assert result.passed_tests == 0 + assert result.failed_tests == 1 + assert result.success is False + + def test_merged_result_reflects_all_batches_processed( + self, + validator: DataFrameValidator, + on_success_callback: TracksValidationCallbacks, + on_failure_callback: TracksValidationCallbacks, + ) -> None: + batch_count = 5 + dfs = [ + pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) + for i in range(0, batch_count * 3, 3) + ] + + result = validator.run(iter(dfs), on_success_callback, on_failure_callback) + + assert result.total_tests == 1 + assert len(result.test_cases_and_results) == 1 + assert on_success_callback.times_called == batch_count + _, aggregated_result = result.test_cases_and_results[0] + assert aggregated_result.passedRows == batch_count * 3 diff --git a/ingestion/tests/unit/sdk/data_quality/test_validation_results.py 
b/ingestion/tests/unit/sdk/data_quality/test_validation_results.py new file mode 100644 index 00000000000..31ef23a4469 --- /dev/null +++ b/ingestion/tests/unit/sdk/data_quality/test_validation_results.py @@ -0,0 +1,377 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ValidationResult.""" +from datetime import datetime +from uuid import UUID + +import pytest + +from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus +from metadata.generated.schema.tests.testCase import TestCase +from metadata.generated.schema.type.basic import FullyQualifiedEntityName, Timestamp +from metadata.generated.schema.type.entityReference import EntityReference +from metadata.sdk.data_quality.dataframes.validation_results import ValidationResult + + +def create_test_case(fqn: str) -> TestCase: + """Helper to create a test case with minimal required fields.""" + return TestCase( + name=fqn.split(".")[-1], + fullyQualifiedName=FullyQualifiedEntityName(fqn), + testDefinition=EntityReference( + id=UUID("12345678-1234-1234-1234-123456789abc"), + type="testDefinition", + fullyQualifiedName=fqn, + ), + entityLink="<#E::table::test_table>", + testSuite=EntityReference( + id=UUID("87654321-4321-4321-4321-cba987654321"), + type="testSuite", + fullyQualifiedName="test_suite", + ), + ) + + +def create_test_result( + status: TestCaseStatus, + passed_rows: int = 0, + failed_rows: int = 0, +) -> TestCaseResult: + """Helper to create a test case result.""" + total = passed_rows + failed_rows + return TestCaseResult( + timestamp=Timestamp(int(datetime.now().timestamp() * 1000)), + testCaseStatus=status, + passedRows=passed_rows if total > 0 else None, + failedRows=failed_rows if total > 0 else None, + passedRowsPercentage=(passed_rows / total * 100) if total > 0 else None, + failedRowsPercentage=(failed_rows / total * 100) if total > 0 else None, + ) + + +class TestValidationResultMerge: + """Test ValidationResult.merge method.""" + + def test_merge_single_result(self) -> None: + """Test merging a single ValidationResult returns equivalent result.""" + test_case = create_test_case("test.case.one") + test_result = create_test_result(TestCaseStatus.Success, passed_rows=100) + + result = ValidationResult( + success=True, + total_tests=1, + passed_tests=1, + failed_tests=0, + test_cases_and_results=[(test_case, test_result)], + execution_time_ms=10.0, + ) + + merged = ValidationResult.merge(result) + + assert merged.success is True + assert merged.total_tests == 1 + assert merged.passed_tests == 1 + assert merged.failed_tests == 0 + assert merged.execution_time_ms == 10.0 + assert len(merged.test_cases_and_results) == 1 + + def test_merge_multiple_results_same_test_case(self) -> None: + """Test merging results for same test case aggregates metrics.""" + test_case = create_test_case("test.case.one") + + result1 = ValidationResult( + success=True, + total_tests=1, + passed_tests=1, + failed_tests=0, + test_cases_and_results=[ + 
(test_case, create_test_result(TestCaseStatus.Success, passed_rows=50)) + ], + execution_time_ms=10.0, + ) + + result2 = ValidationResult( + success=True, + total_tests=1, + passed_tests=1, + failed_tests=0, + test_cases_and_results=[ + (test_case, create_test_result(TestCaseStatus.Success, passed_rows=30)) + ], + execution_time_ms=8.0, + ) + + merged = ValidationResult.merge(result1, result2) + + assert merged.success is True + assert merged.total_tests == 1 + assert merged.passed_tests == 1 + assert merged.failed_tests == 0 + assert merged.execution_time_ms == 18.0 + assert len(merged.test_cases_and_results) == 1 + + _, aggregated_result = merged.test_cases_and_results[0] + assert aggregated_result.passedRows == 80 + assert aggregated_result.failedRows == 0 + assert aggregated_result.testCaseStatus == TestCaseStatus.Success + + def test_merge_aggregates_passed_and_failed_rows(self) -> None: + """Test that passed and failed rows are summed correctly.""" + test_case = create_test_case("test.case.one") + + result1 = ValidationResult( + success=False, + total_tests=1, + passed_tests=0, + failed_tests=1, + test_cases_and_results=[ + ( + test_case, + create_test_result( + TestCaseStatus.Failed, passed_rows=40, failed_rows=10 + ), + ) + ], + execution_time_ms=10.0, + ) + + result2 = ValidationResult( + success=False, + total_tests=1, + passed_tests=0, + failed_tests=1, + test_cases_and_results=[ + ( + test_case, + create_test_result( + TestCaseStatus.Failed, passed_rows=30, failed_rows=20 + ), + ) + ], + execution_time_ms=12.0, + ) + + merged = ValidationResult.merge(result1, result2) + + assert merged.success is False + assert merged.total_tests == 1 + assert merged.failed_tests == 1 + assert merged.execution_time_ms == 22.0 + + _, aggregated_result = merged.test_cases_and_results[0] + assert aggregated_result.passedRows == 70 + assert aggregated_result.failedRows == 30 + assert aggregated_result.testCaseStatus == TestCaseStatus.Failed + assert aggregated_result.passedRowsPercentage == pytest.approx(70.0) + assert aggregated_result.failedRowsPercentage == pytest.approx(30.0) + + def test_merge_multiple_test_cases(self) -> None: + """Test merging results with multiple different test cases.""" + test_case1 = create_test_case("test.case.one") + test_case2 = create_test_case("test.case.two") + + result1 = ValidationResult( + success=True, + total_tests=2, + passed_tests=2, + failed_tests=0, + test_cases_and_results=[ + ( + test_case1, + create_test_result(TestCaseStatus.Success, passed_rows=50), + ), + ( + test_case2, + create_test_result(TestCaseStatus.Success, passed_rows=30), + ), + ], + execution_time_ms=15.0, + ) + + result2 = ValidationResult( + success=True, + total_tests=2, + passed_tests=2, + failed_tests=0, + test_cases_and_results=[ + ( + test_case1, + create_test_result(TestCaseStatus.Success, passed_rows=25), + ), + ( + test_case2, + create_test_result(TestCaseStatus.Success, passed_rows=35), + ), + ], + execution_time_ms=12.0, + ) + + merged = ValidationResult.merge(result1, result2) + + assert merged.success is True + assert merged.total_tests == 2 + assert merged.passed_tests == 2 + assert merged.failed_tests == 0 + assert merged.execution_time_ms == 27.0 + assert len(merged.test_cases_and_results) == 2 + + fqns_to_results = { + tc.fullyQualifiedName.root: result + for tc, result in merged.test_cases_and_results + } + + assert fqns_to_results["test.case.one"].passedRows == 75 + assert fqns_to_results["test.case.two"].passedRows == 65 + + def 
test_merge_with_failure_status_propagates(self) -> None: + """Test that if any batch fails, overall status is Failed.""" + test_case = create_test_case("test.case.one") + + result1 = ValidationResult( + success=True, + total_tests=1, + passed_tests=1, + failed_tests=0, + test_cases_and_results=[ + (test_case, create_test_result(TestCaseStatus.Success, passed_rows=50)) + ], + execution_time_ms=10.0, + ) + + result2 = ValidationResult( + success=False, + total_tests=1, + passed_tests=0, + failed_tests=1, + test_cases_and_results=[ + ( + test_case, + create_test_result( + TestCaseStatus.Failed, passed_rows=20, failed_rows=10 + ), + ) + ], + execution_time_ms=10.0, + ) + + merged = ValidationResult.merge(result1, result2) + + assert merged.success is False + assert merged.failed_tests == 1 + + _, aggregated_result = merged.test_cases_and_results[0] + assert aggregated_result.testCaseStatus == TestCaseStatus.Failed + + def test_merge_with_aborted_status_takes_precedence(self) -> None: + """Test that Aborted status takes precedence over Failed.""" + test_case = create_test_case("test.case.one") + + result1 = ValidationResult( + success=False, + total_tests=1, + passed_tests=0, + failed_tests=1, + test_cases_and_results=[ + ( + test_case, + create_test_result( + TestCaseStatus.Failed, passed_rows=40, failed_rows=10 + ), + ) + ], + execution_time_ms=10.0, + ) + + result2 = ValidationResult( + success=False, + total_tests=1, + passed_tests=0, + failed_tests=1, + test_cases_and_results=[ + (test_case, create_test_result(TestCaseStatus.Aborted)) + ], + execution_time_ms=5.0, + ) + + merged = ValidationResult.merge(result1, result2) + + _, aggregated_result = merged.test_cases_and_results[0] + assert aggregated_result.testCaseStatus == TestCaseStatus.Aborted + + def test_merge_empty_raises_error(self) -> None: + """Test that merging with no results raises ValueError.""" + with pytest.raises(ValueError, match="At least one ValidationResult"): + ValidationResult.merge() + + def test_merge_without_fqn_raises_error(self) -> None: + """Test that merging test cases without FQN raises ValueError.""" + test_case = TestCase( + name="test", + fullyQualifiedName=None, + testDefinition=EntityReference( + id=UUID("12345678-1234-1234-1234-123456789abc"), + type="testDefinition", + fullyQualifiedName="test.def", + ), + entityLink="<#E::table::test_table>", + testSuite=EntityReference( + id=UUID("87654321-4321-4321-4321-cba987654321"), + type="testSuite", + fullyQualifiedName="test_suite", + ), + ) + + result = ValidationResult( + success=True, + total_tests=1, + passed_tests=1, + failed_tests=0, + test_cases_and_results=[ + (test_case, create_test_result(TestCaseStatus.Success)) + ], + execution_time_ms=10.0, + ) + + with pytest.raises(ValueError, match="no fullyQualifiedName"): + ValidationResult.merge(result) + + def test_merge_preserves_test_case_reference(self) -> None: + """Test that the merged result contains the original test case reference.""" + test_case = create_test_case("test.case.one") + + result1 = ValidationResult( + success=True, + total_tests=1, + passed_tests=1, + failed_tests=0, + test_cases_and_results=[ + (test_case, create_test_result(TestCaseStatus.Success, passed_rows=50)) + ], + execution_time_ms=10.0, + ) + + result2 = ValidationResult( + success=True, + total_tests=1, + passed_tests=1, + failed_tests=0, + test_cases_and_results=[ + (test_case, create_test_result(TestCaseStatus.Success, passed_rows=30)) + ], + execution_time_ms=8.0, + ) + + merged = ValidationResult.merge(result1, result2) 
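+        # Merging aggregates the per-batch metrics, but the test case
+        # reference itself (FQN and name) must survive unchanged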
+ + merged_test_case, _ = merged.test_cases_and_results[0] + assert merged_test_case.fullyQualifiedName.root == "test.case.one" + assert merged_test_case.name.root == "one"