diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py
index 1e93b9ad535..f799789279b 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/tests_mixin.py
@@ -17,6 +17,7 @@ To be used by OpenMetadata class
 import traceback
 from datetime import datetime
 from typing import List, Optional, Type, Union
+from urllib.parse import quote, urlencode, urljoin
 from uuid import UUID
 
 from metadata.generated.schema.api.tests.createLogicalTestCases import (
@@ -170,7 +171,7 @@ class OMetaTestsMixin:
         test_definition_fqn: Optional[str] = None,
         test_case_parameter_values: Optional[List[TestCaseParameterValue]] = None,
         description: Optional[str] = None,
-    ):
+    ) -> TestCase:
         """Get or create a test case
 
         Args:
@@ -203,6 +204,34 @@ class OMetaTestsMixin:
         )
         return test_case
 
+    def get_executable_test_suite(self, table_fqn: str) -> Optional[TestSuite]:
+        """Given an entity FQN, retrieve the linked test suite if it exists
+
+        Args:
+            table_fqn (str): entity fully qualified name
+
+        Returns:
+            An instance of TestSuite or None
+        """
+        table_entity = self.get_by_name(
+            entity=Table, fqn=table_fqn, fields=["testSuite"]
+        )
+        if not table_entity:
+            raise RuntimeError(
+                f"Unable to find table {table_fqn} in OpenMetadata. "
+                "This could be because the table has not been ingested yet, or your JWT token is expired or missing."
+            )
+
+        if not table_entity.testSuite:
+            return None
+
+        return self.get_by_name(
+            entity=TestSuite,
+            fqn=table_entity.testSuite.fullyQualifiedName,
+            fields=["tests"],
+            nullable=False,
+        )
+
     def get_or_create_executable_test_suite(
         self, entity_fqn: str
     ) -> Union[EntityReference, TestSuite]:
@@ -379,3 +408,25 @@ class OMetaTestsMixin:
             data=inspection_query,
         )
         return TestCase(**resp)
+
+    def delete_test_case(
+        self,
+        test_case_fqn: str,
+        recursive: bool = True,
+        hard: bool = False,
+    ) -> None:
+        """Delete a test case
+        Args:
+            test_case_fqn: Fully qualified name of the test case to delete
+            recursive (bool, optional): delete children if true
+            hard (bool, optional): hard delete if true
+        """
+        params = urlencode(
+            dict(
+                recursive="true" if recursive else "false",
+                hardDelete="true" if hard else "false",
+            )
+        )
+        url = f"{self.get_suffix(TestCase)}/name/{quote(test_case_fqn)}"
+
+        self.client.delete(urljoin(url, "?" + params))
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/custom_warnings.py b/ingestion/src/metadata/sdk/data_quality/dataframes/custom_warnings.py
new file mode 100644
index 00000000000..b08efdca6e7
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/custom_warnings.py
@@ -0,0 +1,2 @@
+class WholeTableTestsWarning(RuntimeWarning):
+    """Warns when the user runs tests that require the whole table on a subset of it"""
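`WholeTableTestsWarning` subclasses `RuntimeWarning`, so callers who understand the chunk-level semantics can filter it with the standard `warnings` machinery. A minimal sketch; the `validator`, `chunks`, and callback names are assumed here, mirroring the `DataFrameValidator.run()` docstring later in this diff:

```python
import warnings

from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning

with warnings.catch_warnings():
    # Silence only the whole-table warning; other warnings still surface.
    warnings.simplefilter("ignore", category=WholeTableTestsWarning)
    result = validator.run(chunks, on_success=on_success, on_failure=on_failure)
```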
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validation_engine.py b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validation_engine.py
new file mode 100644
index 00000000000..ae43019e148
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validation_engine.py
@@ -0,0 +1,145 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Orchestration engine for DataFrame validation execution."""
+import logging
+import time
+from datetime import datetime
+from typing import List, Tuple, Type
+
+from pandas import DataFrame
+
+from metadata.data_quality.validations.base_test_handler import BaseTestValidator
+from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.generated.schema.type.basic import Timestamp
+from metadata.sdk.data_quality.dataframes.validation_results import (
+    FailureMode,
+    ValidationResult,
+)
+from metadata.sdk.data_quality.dataframes.validators import VALIDATOR_REGISTRY
+
+logger = logging.getLogger(__name__)
+
+
+class DataFrameValidationEngine:
+    """Orchestrates execution of multiple validators on a DataFrame."""
+
+    def __init__(self, test_cases: List[TestCase]):
+        self.test_cases: List[TestCase] = test_cases
+
+    def execute(
+        self,
+        df: DataFrame,
+        mode: FailureMode = FailureMode.SHORT_CIRCUIT,
+    ) -> ValidationResult:
+        """Execute all validations and return aggregated results.
+
+        Args:
+            df: DataFrame to validate
+            mode: Failure mode (FailureMode.SHORT_CIRCUIT, the only mode, stops on the first failure or abort)
+
+        Returns:
+            ValidationResult with outcomes for all tests
+        """
+        results: List[Tuple[TestCase, TestCaseResult]] = []
+        start_time = time.time()
+
+        for test_case in self.test_cases:
+            test_result = self._execute_single_test(df, test_case)
+            results.append((test_case, test_result))
+
+            if mode is FailureMode.SHORT_CIRCUIT and test_result.testCaseStatus in (
+                TestCaseStatus.Failed,
+                TestCaseStatus.Aborted,
+            ):
+                break
+
+        execution_time = (time.time() - start_time) * 1000
+        return self._build_validation_result(results, execution_time)
+
+    def _execute_single_test(
+        self, df: DataFrame, test_case: TestCase
+    ) -> TestCaseResult:
+        """Execute validation and return structured result.
+
+        Returns:
+            TestCaseResult with the validation outcome
+        """
+        validator_class = self._get_validator_class(test_case)
+
+        validator = validator_class(
+            runner=[df],
+            test_case=test_case,
+            execution_date=Timestamp(root=int(datetime.now().timestamp() * 1000)),
+        )
+
+        try:
+            return validator.run_validation()
+        except Exception as err:
+            message = (
+                f"Error executing {test_case.testDefinition.fullyQualifiedName} - {err}"
+            )
+            logger.exception(message)
+            return validator.get_test_case_result_object(
+                validator.execution_date,
+                TestCaseStatus.Aborted,
+                message,
+                [],
+            )
+
+    @staticmethod
+    def _build_validation_result(
+        test_results: List[Tuple[TestCase, TestCaseResult]], execution_time_ms: float
+    ) -> ValidationResult:
+        """Build aggregated validation result.
+
+        Args:
+            test_results: Individual test results
+            execution_time_ms: Total execution time in milliseconds
+
+        Returns:
+            ValidationResult with aggregated outcomes
+        """
+        passed = sum(
+            1 for _, r in test_results if r.testCaseStatus == TestCaseStatus.Success
+        )
+        failed = len(test_results) - passed
+        success = failed == 0
+
+        return ValidationResult(
+            success=success,
+            total_tests=len(test_results),
+            passed_tests=passed,
+            failed_tests=failed,
+            test_cases_and_results=test_results,
+            execution_time_ms=execution_time_ms,
+        )
+
+    @staticmethod
+    def _get_validator_class(test_case: TestCase) -> Type[BaseTestValidator]:
+        """Resolve validator class from test definition name.
+
+        Returns:
+            Validator class for the test definition
+
+        Raises:
+            ValueError: If the test definition is not supported
+        """
+        validator_class = VALIDATOR_REGISTRY.get(
+            test_case.testDefinition.fullyQualifiedName  # pyright: ignore[reportArgumentType]
+        )
+        if not validator_class:
+            raise ValueError(
+                f"Unknown test definition: {test_case.testDefinition.fullyQualifiedName}"
+            )
+
+        return validator_class
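The `DataFrameValidator` facade in the next file wraps this engine, but for a sense of the raw flow, a minimal sketch using the mock-test-case builder from `models.py` further down in this diff (the engine is internal, so treat this as illustrative rather than a public API):

```python
import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeNotNull
from metadata.sdk.data_quality.dataframes.dataframe_validation_engine import (
    DataFrameValidationEngine,
)
from metadata.sdk.data_quality.dataframes.models import create_mock_test_case

# One not-null check against a frame that contains a null.
engine = DataFrameValidationEngine(
    [create_mock_test_case(ColumnValuesToBeNotNull(column="email"))]
)
result = engine.execute(pd.DataFrame({"email": ["a@b.com", None]}))
print(result.success, result.failed_tests)  # expected: False 1
```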
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validator.py b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validator.py
new file mode 100644
index 00000000000..ef5dcc789b1
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/dataframe_validator.py
@@ -0,0 +1,201 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DataFrame validation API."""
+import warnings
+from typing import Any, Callable, Iterable, List, Optional, cast, final
+
+from pandas import DataFrame
+
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.ingestion.ometa.ometa_api import OpenMetadata as OMeta
+from metadata.sdk import OpenMetadata
+from metadata.sdk import client as get_client
+from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning
+from metadata.sdk.data_quality.dataframes.dataframe_validation_engine import (
+    DataFrameValidationEngine,
+)
+from metadata.sdk.data_quality.dataframes.models import create_mock_test_case
+from metadata.sdk.data_quality.dataframes.validation_results import (
+    FailureMode,
+    ValidationResult,
+)
+from metadata.sdk.data_quality.dataframes.validators import requires_whole_table
+from metadata.sdk.data_quality.tests.base_tests import BaseTest
+
+ValidatorCallback = Callable[[DataFrame, ValidationResult], None]
+
+
+@final
+class DataFrameValidator:
+    """Facade for DataFrame data quality validation.
+
+    Provides a simple interface to configure and execute data quality tests
+    on pandas DataFrames using OpenMetadata test definitions.
+
+    Example:
+        validator = DataFrameValidator()
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+        validator.add_test(ColumnValuesToBeUnique(column="customer_id"))
+
+        result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+        if not result.success:
+            print(f"Validation failed: {result.failures}")
+    """
+
+    def __init__(
+        self,
+        client: Optional[  # pyright: ignore[reportRedeclaration]
+            OMeta[Any, Any]
+        ] = None,
+    ):
+        self._test_cases: List[TestCase] = []
+
+        if client is None:
+            metadata: OpenMetadata = get_client()
+            client: OMeta[Any, Any] = metadata.ometa
+
+        self._client = client
+
+    def add_test(self, test: BaseTest) -> None:
+        """Add a single test definition to be executed.
+
+        Args:
+            test: Test definition (e.g., ColumnValuesToBeNotNull)
+        """
+        self._test_cases.append(create_mock_test_case(test))
+
+    def add_tests(self, *tests: BaseTest) -> None:
+        """Add multiple test definitions at once.
+
+        Args:
+            *tests: Variable number of test definitions
+        """
+        self._test_cases.extend(create_mock_test_case(t) for t in tests)
+
+    def add_openmetadata_test(self, test_fqn: str) -> None:
+        test_case = cast(
+            TestCase,
+            self._client.get_by_name(
+                TestCase,
+                test_fqn,
+                fields=["testDefinition", "testSuite"],
+                nullable=False,
+            ),
+        )
+
+        self._test_cases.append(test_case)
+
+    def add_openmetadata_table_tests(self, table_fqn: str) -> None:
+        test_suite = self._client.get_executable_test_suite(table_fqn)
+
+        if test_suite is None:
+            raise ValueError(f"Table {table_fqn!r} does not have a test suite to run")
+
+        for test in test_suite.tests or []:
+            assert test.fullyQualifiedName is not None
+            self.add_openmetadata_test(test.fullyQualifiedName)
+
+    def validate(
+        self,
+        df: DataFrame,
+        mode: FailureMode = FailureMode.SHORT_CIRCUIT,
+    ) -> ValidationResult:
+        """Execute all configured tests on the DataFrame.
+
+        Args:
+            df: DataFrame to validate
+            mode: Validation mode (`FailureMode.SHORT_CIRCUIT` stops on first failure)
+
+        Returns:
+            ValidationResult with outcomes for all tests
+        """
+        engine = DataFrameValidationEngine(self._test_cases)
+        return engine.execute(df, mode)
+
+    def _check_full_table_tests_included(self) -> None:
+        test_names: set[str] = {  # pyright: ignore[reportAssignmentType]
+            test.testDefinition.fullyQualifiedName
+            for test in self._test_cases
+            if requires_whole_table(
+                test.testDefinition.fullyQualifiedName  # pyright: ignore[reportArgumentType]
+            )
+        }
+
+        if not test_names:
+            return
+
+        warnings.warn(
+            WholeTableTestsWarning(
+                "Running tests that require the whole table on chunks could lead to false positives. "
+                + "For example, a DataFrame with 200 rows split in chunks of 50 could pass tests expecting "
+                + "DataFrames to contain max 100 rows.\n\n"
+                + "The following tests could have unexpected results:\n\n\t- "
+                + "\n\t- ".join(sorted(test_names))
+            )
+        )
+
+    def run(
+        self,
+        data: Iterable[DataFrame],
+        on_success: ValidatorCallback,
+        on_failure: ValidatorCallback,
+        mode: FailureMode = FailureMode.SHORT_CIRCUIT,
+    ) -> ValidationResult:
+        """Execute all configured tests on each DataFrame and call the callbacks.
+
+        Useful for running validation on chunks, for example:
+
+        ```python
+        validator = DataFrameValidator()
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+
+        def load_df_to_destination(df, result):
+            ...
+
+        def rollback(df, result):
+            "Clears data previously loaded"
+            ...
+
+        result = validator.run(
+            pandas.read_csv('somefile.csv', chunksize=1000),
+            on_success=load_df_to_destination,
+            on_failure=rollback,
+            mode=FailureMode.SHORT_CIRCUIT,
+        )
+        ```
+
+        Args:
+            data: An iterable of pandas DataFrames
+            on_success: Callback to execute after successful validation
+            on_failure: Callback to execute after failed validation
+            mode: Validation mode (`FailureMode.SHORT_CIRCUIT` stops on first failure)
+
+        Returns:
+            Merged ValidationResult aggregating all batch validations
+        """
+        self._check_full_table_tests_included()
+
+        results: List[ValidationResult] = []
+
+        for df in data:
+            validation_result = self.validate(df, mode)
+            results.append(validation_result)
+
+            if validation_result.success:
+                on_success(df, validation_result)
+            else:
+                on_failure(df, validation_result)
+
+                if mode is FailureMode.SHORT_CIRCUIT:
+                    break
+
+        return ValidationResult.merge(*results)
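Since `run()` returns `ValidationResult.merge(*results)`, per-chunk row counts for the same test case are summed and the worst status wins. A small worked sketch of that arithmetic, mirroring `_aggregate_test_case_results` in `validation_results.py` later in this diff (the numbers are illustrative):

```python
# One test case observed across two chunks:
chunk_outcomes = [
    {"passedRows": 40, "failedRows": 10},  # chunk 1 -> Failed
    {"passedRows": 50, "failedRows": 0},   # chunk 2 -> Success
]

passed = sum(c["passedRows"] for c in chunk_outcomes)  # 90
failed = sum(c["failedRows"] for c in chunk_outcomes)  # 10
failed_pct = failed / (passed + failed) * 100          # 10.0

# Status precedence in the merged result is Aborted > Failed > Success,
# so this test case merges to Failed even though chunk 2 passed.
print(passed, failed, failed_pct)
```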
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/models.py b/ingestion/src/metadata/sdk/data_quality/dataframes/models.py
new file mode 100644
index 00000000000..40edc1d5a49
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/models.py
@@ -0,0 +1,44 @@
+from uuid import uuid4
+
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.sdk.data_quality import BaseTest, ColumnTest
+
+
+class MockTestCase(TestCase):
+    """Mock test case."""
+
+
+def create_mock_test_case(test_definition: BaseTest) -> MockTestCase:
+    """Convert a BaseTest definition into a synthetic TestCase object.
+
+    Returns:
+        Synthetic TestCase for DataFrame validation
+    """
+    entity_link = "<#E::table::dataframe_validation>"
+    if isinstance(test_definition, ColumnTest):
+        entity_link = (
+            f"<#E::table::dataframe_validation::columns::{test_definition.column_name}>"
+        )
+
+    return MockTestCase(  # pyright: ignore[reportCallIssue]
+        id=uuid4(),
+        name=test_definition.name,
+        fullyQualifiedName=test_definition.name,
+        displayName=test_definition.display_name,
+        description=test_definition.description,
+        testDefinition=EntityReference(  # pyright: ignore[reportCallIssue]
+            id=uuid4(),
+            name=test_definition.test_definition_name,
+            fullyQualifiedName=test_definition.test_definition_name,
+            type="testDefinition",
+        ),
+        entityLink=entity_link,
+        parameterValues=test_definition.parameters,
+        testSuite=EntityReference(  # pyright: ignore[reportCallIssue]
+            id=uuid4(),
+            name="dataframe_validation",
+            type="testSuite",
+        ),
+        computePassedFailedRowCount=True,
+    )
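The entity-link strings built here are later parsed back by `get_column_name_or_none`, added to `metadata/utils/entity_link.py` further down in this diff, when `publish()` re-creates code-defined test cases server-side. A quick sanity sketch of the expected round-trip, using a hypothetical column name:

```python
from metadata.utils.entity_link import get_column_name_or_none

# Column-level link -> the column name is recovered.
assert (
    get_column_name_or_none("<#E::table::dataframe_validation::columns::email>")
    == "email"
)
# Table-level link -> no column component, so None.
assert get_column_name_or_none("<#E::table::dataframe_validation>") is None
```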
diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/validation_results.py b/ingestion/src/metadata/sdk/data_quality/dataframes/validation_results.py
new file mode 100644
index 00000000000..171f775fce7
--- /dev/null
+++ b/ingestion/src/metadata/sdk/data_quality/dataframes/validation_results.py
@@ -0,0 +1,243 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DataFrame validation result models."""
+import logging
+from enum import Enum
+from typing import List, Optional, Tuple, cast
+
+from pydantic import BaseModel
+
+from metadata.generated.schema.entity.data.table import Table
+from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.generated.schema.type.basic import FullyQualifiedEntityName
+from metadata.sdk import OpenMetadata
+from metadata.sdk import client as get_client
+from metadata.sdk.data_quality.dataframes.models import MockTestCase
+from metadata.utils.entity_link import (
+    get_entity_link,  # pyright: ignore[reportUnknownVariableType]
+)
+from metadata.utils.entity_link import get_column_name_or_none
+
+logger = logging.getLogger(__name__)
+
+
+class FailureMode(Enum):
+    SHORT_CIRCUIT = "short-circuit"
+
+
+class ValidationResult(BaseModel):
+    """Aggregated results from validating multiple tests on a DataFrame.
+
+    Attributes:
+        success: True if all tests passed
+        total_tests: Total number of tests executed
+        passed_tests: Number of tests that passed
+        failed_tests: Number of tests that failed
+        test_cases_and_results: Pairs of executed test case and its result
+        execution_time_ms: Total execution time in milliseconds
+    """
+
+    success: bool
+    total_tests: int
+    passed_tests: int
+    failed_tests: int
+    test_cases_and_results: List[Tuple[TestCase, TestCaseResult]]
+    execution_time_ms: float
+
+    @property
+    def failures(self) -> List[TestCaseResult]:
+        """Get only failed test results.
+
+        Returns:
+            List of test results where status is Failed or Aborted
+        """
+        return [
+            result
+            for result in self.test_results
+            if result.testCaseStatus in (TestCaseStatus.Failed, TestCaseStatus.Aborted)
+        ]
+
+    @property
+    def passes(self) -> List[TestCaseResult]:
+        """Get only passed test results.
+
+        Returns:
+            List of test results where status is Success
+        """
+        return [
+            result
+            for result in self.test_results
+            if result.testCaseStatus == TestCaseStatus.Success
+        ]
+
+    @property
+    def test_results(self) -> List[TestCaseResult]:
+        """Get all test results."""
+        return [result for _, result in self.test_cases_and_results]
+
+    def publish(self, table_fqn: str, client: Optional[OpenMetadata] = None) -> None:
+        """Publish test results to OpenMetadata.
+        Args:
+            table_fqn: Fully qualified table name
+            client: OpenMetadata client (defaults to the configured SDK client)
+        """
+        if client is None:
+            client = get_client()
+
+        metadata = client.ometa
+
+        for test_case, result in self.test_cases_and_results:
+            if isinstance(test_case, MockTestCase):
+                test_case = metadata.get_or_create_test_case(
+                    test_case_fqn=f"{table_fqn}.{test_case.name.root}",
+                    entity_link=get_entity_link(
+                        Table,
+                        table_fqn,
+                        column_name=get_column_name_or_none(test_case.entityLink.root),
+                    ),
+                    test_definition_fqn=test_case.testDefinition.fullyQualifiedName,
+                    test_case_parameter_values=test_case.parameterValues,
+                    description=getattr(test_case.description, "root", None),
+                )
+
+            res = metadata.add_test_case_results(
+                result,
+                cast(FullyQualifiedEntityName, test_case.fullyQualifiedName).root,
+            )
+
+            logger.debug(f"Result: {res}")
+
+    @classmethod
+    def merge(cls, *results: "ValidationResult") -> "ValidationResult":
+        """Merge multiple ValidationResult objects into one.
+
+        Aggregates results from multiple validation runs, useful when validating
+        DataFrames in batches. When the same test case is run multiple times across
+        batches, results are aggregated by test case FQN.
+ + Args: + *results: Variable number of ValidationResult objects to merge + + Returns: + A new ValidationResult with aggregated test case results + + Raises: + ValueError: If no results are provided to merge + """ + if not results: + raise ValueError("At least one ValidationResult must be provided to merge") + + from collections import defaultdict + + aggregated_results: dict[ + str, List[Tuple[TestCase, TestCaseResult]] + ] = defaultdict(list) + total_execution_time = 0.0 + + for result in results: + for test_case, test_result in result.test_cases_and_results: + fqn = test_case.fullyQualifiedName + if fqn is None: + raise ValueError( + "Cannot merge results with test cases that have no fullyQualifiedName" + ) + aggregated_results[str(fqn)].append((test_case, test_result)) + total_execution_time += result.execution_time_ms + + merged_test_cases_and_results: List[Tuple[TestCase, TestCaseResult]] = [] + for fqn, test_cases_and_results_for_fqn in aggregated_results.items(): + test_case = test_cases_and_results_for_fqn[0][0] + results_for_test = [result for _, result in test_cases_and_results_for_fqn] + + merged_result = cls._aggregate_test_case_results(results_for_test) + merged_test_cases_and_results.append((test_case, merged_result)) + + total = len(merged_test_cases_and_results) + passed = sum( + 1 + for _, test_result in merged_test_cases_and_results + if test_result.testCaseStatus is TestCaseStatus.Success + ) + failed = total - passed + + return cls( + success=failed == 0, + total_tests=total, + passed_tests=passed, + failed_tests=failed, + test_cases_and_results=merged_test_cases_and_results, + execution_time_ms=total_execution_time, + ) + + @staticmethod + def _aggregate_test_case_results( + results: List[TestCaseResult], + ) -> TestCaseResult: + """Aggregate multiple TestCaseResult objects for the same test case. + + Combines metrics from multiple test runs by summing passed/failed rows + and determining overall status. 
+ + Args: + results: List of TestCaseResult objects from different batch runs + + Returns: + A single aggregated TestCaseResult + """ + if not results: + raise ValueError("At least one TestCaseResult must be provided") + + if len(results) == 1: + return results[0] + + total_passed_rows = sum(r.passedRows or 0 for r in results) + total_failed_rows = sum(r.failedRows or 0 for r in results) + total_rows = total_passed_rows + total_failed_rows + + passed_rows_percentage = ( + (total_passed_rows / total_rows * 100) if total_rows > 0 else None + ) + failed_rows_percentage = ( + (total_failed_rows / total_rows * 100) if total_rows > 0 else None + ) + + overall_status = TestCaseStatus.Success + if any(r.testCaseStatus == TestCaseStatus.Aborted for r in results): + overall_status = TestCaseStatus.Aborted + elif any(r.testCaseStatus == TestCaseStatus.Failed for r in results): + overall_status = TestCaseStatus.Failed + + first_result = results[0] + merged_result_messages = [r.result for r in results if r.result] + + return TestCaseResult( + id=None, + testCaseFQN=first_result.testCaseFQN, + timestamp=first_result.timestamp, + testCaseStatus=overall_status, + result=( + "; ".join(merged_result_messages) if merged_result_messages else None + ), + sampleData=None, + testResultValue=None, + passedRows=total_passed_rows if total_rows > 0 else None, + failedRows=total_failed_rows if total_rows > 0 else None, + passedRowsPercentage=passed_rows_percentage, + failedRowsPercentage=failed_rows_percentage, + incidentId=None, + maxBound=first_result.maxBound, + minBound=first_result.minBound, + testCase=first_result.testCase, + testDefinition=first_result.testDefinition, + dimensionResults=None, + ) diff --git a/ingestion/src/metadata/sdk/data_quality/dataframes/validators.py b/ingestion/src/metadata/sdk/data_quality/dataframes/validators.py new file mode 100644 index 00000000000..4430796b576 --- /dev/null +++ b/ingestion/src/metadata/sdk/data_quality/dataframes/validators.py @@ -0,0 +1,133 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Registry of pandas validators.""" + +from metadata.data_quality.validations.column.pandas.columnValueLengthsToBeBetween import ( + ColumnValueLengthsToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMaxToBeBetween import ( + ColumnValueMaxToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMeanToBeBetween import ( + ColumnValueMeanToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMedianToBeBetween import ( + ColumnValueMedianToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueMinToBeBetween import ( + ColumnValueMinToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesMissingCount import ( + ColumnValuesMissingCountValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesSumToBeBetween import ( + ColumnValuesSumToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValueStdDevToBeBetween import ( + ColumnValueStdDevToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeAtExpectedLocation import ( + ColumnValuesToBeAtExpectedLocationValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeBetween import ( + ColumnValuesToBeBetweenValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeInSet import ( + ColumnValuesToBeInSetValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeNotInSet import ( + ColumnValuesToBeNotInSetValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeNotNull import ( + ColumnValuesToBeNotNullValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToBeUnique import ( + ColumnValuesToBeUniqueValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToMatchRegex import ( + ColumnValuesToMatchRegexValidator, +) +from metadata.data_quality.validations.column.pandas.columnValuesToNotMatchRegex import ( + ColumnValuesToNotMatchRegexValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnCountToBeBetween import ( + TableColumnCountToBeBetweenValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnCountToEqual import ( + TableColumnCountToEqualValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnNameToExist import ( + TableColumnNameToExistValidator, +) +from metadata.data_quality.validations.table.pandas.tableColumnToMatchSet import ( + TableColumnToMatchSetValidator, +) +from metadata.data_quality.validations.table.pandas.tableRowCountToBeBetween import ( + TableRowCountToBeBetweenValidator, +) +from metadata.data_quality.validations.table.pandas.tableRowCountToEqual import ( + TableRowCountToEqualValidator, +) + +VALIDATOR_REGISTRY = { + "columnValuesToBeNotNull": ColumnValuesToBeNotNullValidator, + "columnValuesToBeUnique": ColumnValuesToBeUniqueValidator, + "columnValuesToBeBetween": ColumnValuesToBeBetweenValidator, + "columnValuesToBeInSet": ColumnValuesToBeInSetValidator, + "columnValuesToBeNotInSet": ColumnValuesToBeNotInSetValidator, + "columnValuesToMatchRegex": ColumnValuesToMatchRegexValidator, + "columnValuesToNotMatchRegex": ColumnValuesToNotMatchRegexValidator, + "columnValueLengthsToBeBetween": ColumnValueLengthsToBeBetweenValidator, + "columnValueMaxToBeBetween": ColumnValueMaxToBeBetweenValidator, + "columnValueMeanToBeBetween": 
ColumnValueMeanToBeBetweenValidator, + "columnValueMedianToBeBetween": ColumnValueMedianToBeBetweenValidator, + "columnValueMinToBeBetween": ColumnValueMinToBeBetweenValidator, + "columnValueStdDevToBeBetween": ColumnValueStdDevToBeBetweenValidator, + "columnValuesSumToBeBetween": ColumnValuesSumToBeBetweenValidator, + "columnValuesMissingCount": ColumnValuesMissingCountValidator, + "columnValuesToBeAtExpectedLocation": ColumnValuesToBeAtExpectedLocationValidator, + "tableRowCountToBeBetween": TableRowCountToBeBetweenValidator, + "tableRowCountToEqual": TableRowCountToEqualValidator, + "tableColumnCountToBeBetween": TableColumnCountToBeBetweenValidator, + "tableColumnCountToEqual": TableColumnCountToEqualValidator, + "tableColumnNameToExist": TableColumnNameToExistValidator, + "tableColumnToMatchSet": TableColumnToMatchSetValidator, +} + + +VALIDATORS_THAT_REQUIRE_FULL_TABLE = { + "columnValuesToBeUnique", + "columnValueMeanToBeBetween", + "columnValueMedianToBeBetween", + "columnValueStdDevToBeBetween", + "columnValuesSumToBeBetween", + "columnValuesMissingCount", + "tableRowCountToBeBetween", + "tableRowCountToEqual", +} + + +def requires_whole_table(validator_name: str) -> bool: + """Whether the validator requires a whole table to return appropriate results + + Examples: + - `columnValuesToBeUnique` needs to see the whole column to make sure uniqueness is met + - `tableRowCountToEqual` needs to see the whole table to make sure the expected row count is met + + These tests could return false positives when operating on batches + + Args: + validator_name: The name of the validator to check + + Returns: + Whether the validator requires a whole table to return appropriate results + """ + return validator_name in VALIDATORS_THAT_REQUIRE_FULL_TABLE diff --git a/ingestion/src/metadata/sdk/examples/dataframe_validation_example.py b/ingestion/src/metadata/sdk/examples/dataframe_validation_example.py new file mode 100644 index 00000000000..8959a8a1687 --- /dev/null +++ b/ingestion/src/metadata/sdk/examples/dataframe_validation_example.py @@ -0,0 +1,294 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Examples demonstrating DataFrame validation with OpenMetadata SDK."""
+# pyright: reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportUnknownMemberType=false
+# pyright: reportUnusedCallResult=false
+# pylint: disable=W5001
+import pandas as pd
+
+from metadata.generated.schema.tests.basic import TestCaseStatus
+from metadata.sdk import configure
+from metadata.sdk.data_quality import (
+    ColumnValuesToBeBetween,
+    ColumnValuesToBeNotNull,
+    ColumnValuesToBeUnique,
+    ColumnValuesToMatchRegex,
+    TableRowCountToBeBetween,
+)
+from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
+from metadata.sdk.data_quality.dataframes.validation_results import (
+    FailureMode,
+    ValidationResult,
+)
+
+
+def basic_validation_example():
+    """Basic example validating a customer DataFrame."""
+    print("\n=== Basic DataFrame Validation Example ===\n")
+
+    df = pd.DataFrame(
+        {
+            "customer_id": [1, 2, 3, 4, 5],
+            "email": [
+                "alice@example.com",
+                "bob@example.com",
+                "carol@example.com",
+                "dave@example.com",
+                "eve@example.com",
+            ],
+            "age": [25, 30, 35, 40, 45],
+        }
+    )
+
+    validator = DataFrameValidator()
+    validator.add_test(ColumnValuesToBeNotNull(column="email"))
+    validator.add_test(ColumnValuesToBeUnique(column="customer_id"))
+
+    result = validator.validate(df)
+
+    if result.success:
+        print("✓ All validations passed!")
+        print(
+            f"  Executed {result.total_tests} tests in {result.execution_time_ms:.2f}ms"
+        )
+    else:
+        print("✗ Validation failed")
+        for test_case, test_result in result.test_cases_and_results:
+            if test_result.testCaseStatus is not TestCaseStatus.Success:
+                print(f"  - {test_case.name.root}: {test_result.result}")
+
+
+def multiple_tests_example():
+    """Example with multiple validation rules."""
+    print("\n=== Multiple Tests Validation Example ===\n")
+
+    df = pd.DataFrame(
+        {
+            "customer_id": [1, 2, 3, 4, 5],
+            "email": [
+                "alice@example.com",
+                "bob@example.com",
+                "carol@example.com",
+                "dave@example.com",
+                "eve@example.com",
+            ],
+            "age": [25, 30, 35, 40, 45],
+            "status": ["active", "active", "inactive", "active", "active"],
+        }
+    )
+
+    validator = DataFrameValidator()
+    validator.add_tests(
+        TableRowCountToBeBetween(min_count=1, max_count=1000),
+        ColumnValuesToBeNotNull(column="customer_id"),
+        ColumnValuesToBeNotNull(column="email"),
+        ColumnValuesToBeUnique(column="customer_id"),
+        ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
+        ColumnValuesToMatchRegex(column="email", regex=r"^[\w\.-]+@[\w\.-]+\.\w+$"),
+    )
+
+    result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+    print(f"Validation: {'PASSED' if result.success else 'FAILED'}")
+    print(f"Tests: {result.passed_tests}/{result.total_tests} passed")
+    print(f"Execution time: {result.execution_time_ms:.2f}ms\n")
+
+    for test_case, test_result in result.test_cases_and_results:
+        status_icon = "✓" if test_result.testCaseStatus is TestCaseStatus.Success else "✗"
+        print(f"{status_icon} {test_case.name.root}")
+        total_rows = (test_result.passedRows or 0) + (test_result.failedRows or 0)
+        if test_result.passedRows:
+            print(f"   Passed: {test_result.passedRows}/{total_rows} rows")
+        if test_result.failedRows:
+            percentage = test_result.failedRows / total_rows * 100
+            print(f"   Failed: {test_result.failedRows} rows ({percentage:.1f}%)")
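+
+
+def list_supported_tests_example():
+    """Illustrative sketch: list the test definitions with pandas validators.
+
+    VALIDATOR_REGISTRY and requires_whole_table come from
+    metadata.sdk.data_quality.dataframes.validators in this change; the
+    registry is internal, so treat this as a debugging aid, not a public API.
+    """
+    from metadata.sdk.data_quality.dataframes.validators import (
+        VALIDATOR_REGISTRY,
+        requires_whole_table,
+    )
+
+    print("\n=== Supported Test Definitions ===\n")
+    for test_name in sorted(VALIDATOR_REGISTRY):
+        note = " (requires the whole table)" if requires_whole_table(test_name) else ""
+        print(f"  - {test_name}{note}")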
+
+
+def integrating_with_openmetadata_example():
+    """Integrating with OpenMetadata."""
+
+    def transform_to_dwh_table(raw_df: pd.DataFrame) -> pd.DataFrame:
+        """Transform the dataframe into the dwh table shape."""
+        return raw_df
+
+    configure(host="http://localhost:8585/api", jwt_token="your jwt token")
+
+    df = pd.read_parquet("s3://some_bucket/raw_table.parquet")
+
+    df = transform_to_dwh_table(df)
+
+    # Instantiate the validator and load the executable test suite for a table
+    validator = DataFrameValidator()
+    validator.add_openmetadata_table_tests(
+        "DbService.database_name.schema_name.dwh_table"
+    )
+
+    result = validator.validate(df)
+    print(f"Validation: {'PASSED' if result.success else 'FAILED'}")
+
+    # Publish the results back to OpenMetadata
+    result.publish("DbService.database_name.schema_name.dwh_table")
+
+    if result.success:
+        df.to_parquet("s3://some_bucket/dwh_table.parquet")
+
+
+def processing_big_data_with_chunks_example():
+    """Processing big data in chunks."""
+
+    configure(host="http://localhost:8585/api", jwt_token="your jwt token")
+
+    validator = DataFrameValidator()
+    validator.add_openmetadata_table_tests(
+        "DbService.database_name.schema_name.dwh_table"
+    )
+
+    def load_df_to_destination(_df: pd.DataFrame, _result: ValidationResult):
+        """Loads data into the destination."""
+
+    def rollback(_df: pd.DataFrame, _result: ValidationResult):
+        """Clears data previously loaded."""
+
+    result = validator.run(
+        pd.read_csv("somefile.csv", chunksize=1000),
+        on_success=load_df_to_destination,
+        on_failure=rollback,
+        mode=FailureMode.SHORT_CIRCUIT,
+    )
+
+    result.publish("DbService.database_name.schema_name.dwh_table")
+
+
+def validation_failure_example():
+    """Example demonstrating validation failures."""
+    print("\n=== Validation Failure Example ===\n")
+
+    df = pd.DataFrame(
+        {
+            "customer_id": [1, 2, 2, 4, 5],
+            "email": [
+                "alice@example.com",
+                None,
+                "carol@example.com",
+                "dave@example.com",
+                "invalid-email",
+            ],
+            "age": [25, 150, 35, -5, 45],
+        }
+    )
+
+    validator = DataFrameValidator()
+    validator.add_tests(
+        ColumnValuesToBeUnique(column="customer_id"),
+        ColumnValuesToBeNotNull(column="email"),
+        ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
+    )
+
+    result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+    print(f"Validation: {'PASSED' if result.success else 'FAILED'}\n")
+
+    if not result.success:
+        print("Failures detected:")
+        for test_case, test_result in result.test_cases_and_results:
+            if test_result.testCaseStatus is TestCaseStatus.Success:
+                continue
+            total_rows = (test_result.passedRows or 0) + (test_result.failedRows or 0)
+            print(f"\n  Test: {test_case.name.root}")
+            print(f"  Definition: {test_case.testDefinition.name}")
+            print(f"  Message: {test_result.result}")
+            print(f"  Failed rows: {test_result.failedRows}/{total_rows}")
Validate") + validation_result = validate_data(transformed_df) + + if validation_result.success: + print(" ✓ Validation passed") + print("\n4. Load") + load_data(transformed_df) + print(" ✓ Data loaded successfully") + else: + print(" ✗ Validation failed") + print("\n Failures:") + for failure in validation_result.failures: + print(f" - {failure.test_name}: {failure.result_message}") + print("\n Pipeline aborted - data not loaded") + + +def short_circuit_mode_example(): + """Example demonstrating short-circuit mode behavior.""" + print("\n=== Short-Circuit Mode Example ===\n") + + df = pd.DataFrame( + { + "id": [1, 2, 2, 3], + "email": [None, None, None, "test@example.com"], + "age": [25, 30, 35, 40], + } + ) + + validator = DataFrameValidator() + validator.add_tests( + ColumnValuesToBeUnique(column="id"), + ColumnValuesToBeNotNull(column="email"), + ColumnValuesToBeBetween(column="age", min_value=0, max_value=120), + ) + + result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT) + + print("Short-circuit mode stops at first failure:") + print(f" Tests executed: {len(result.test_results)} of {result.total_tests}") + print(f" First failure: {result.failures[0].test_name}") + print("\nRemaining tests were not executed due to short-circuit mode.") + + +if __name__ == "__main__": + basic_validation_example() + multiple_tests_example() + validation_failure_example() + etl_pipeline_integration_example() + short_circuit_mode_example() diff --git a/ingestion/src/metadata/utils/entity_link.py b/ingestion/src/metadata/utils/entity_link.py index 3408d20e9d9..a3ba3bad16c 100644 --- a/ingestion/src/metadata/utils/entity_link.py +++ b/ingestion/src/metadata/utils/entity_link.py @@ -106,6 +106,21 @@ def get_table_or_column_fqn(entity_link: str) -> str: ) +def get_column_name_or_none(entity_link: str) -> Optional[str]: + """It attempts to get a column from an entity link + + Args: + entity_link: entity link + + Returns: + The column name or None + """ + split_entity_link = split(entity_link) + if len(split_entity_link) == 4 and split_entity_link[2] == "columns": + return split_entity_link[3] + return None + + get_entity_link_registry = class_register() diff --git a/ingestion/tests/integration/sdk/conftest.py b/ingestion/tests/integration/sdk/conftest.py index 0854c14b63a..1d1e1bbff13 100644 --- a/ingestion/tests/integration/sdk/conftest.py +++ b/ingestion/tests/integration/sdk/conftest.py @@ -2,12 +2,252 @@ Minimal conftest for SDK integration tests. Override the parent conftest to avoid testcontainers dependency. 
""" + import pytest +from sqlalchemy import Column as SQAColumn +from sqlalchemy import Integer, MetaData, String +from sqlalchemy import Table as SQATable +from sqlalchemy import create_engine from _openmetadata_testutils.ometa import int_admin_ometa +from _openmetadata_testutils.postgres.conftest import postgres_container +from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest +from metadata.generated.schema.api.data.createDatabaseSchema import ( + CreateDatabaseSchemaRequest, +) +from metadata.generated.schema.api.services.createDatabaseService import ( + CreateDatabaseServiceRequest, +) +from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( + BasicAuth, +) +from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( + PostgresConnection, +) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseConnection, + DatabaseService, + DatabaseServiceType, +) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.workflow.metadata import MetadataWorkflow -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def metadata(): - """Provide authenticated OpenMetadata client""" return int_admin_ometa() + + +@pytest.fixture(scope="module") +def create_postgres_service(postgres_container, tmp_path_factory): + return CreateDatabaseServiceRequest( + name="dq_test_service_" + tmp_path_factory.mktemp("dq").name, + serviceType=DatabaseServiceType.Postgres, + connection=DatabaseConnection( + config=PostgresConnection( + username=postgres_container.username, + authType=BasicAuth(password=postgres_container.password), + hostPort="localhost:" + + str(postgres_container.get_exposed_port(postgres_container.port)), + database="dq_test_db", + ) + ), + ) + + +@pytest.fixture(scope="module") +def db_service(metadata, create_postgres_service, postgres_container): + engine = create_engine( + postgres_container.get_connection_url(), isolation_level="AUTOCOMMIT" + ) + engine.execute("CREATE DATABASE dq_test_db") + + service_entity = metadata.create_or_update(data=create_postgres_service) + service_entity.connection.config.authType.password = ( + create_postgres_service.connection.config.authType.password + ) + yield service_entity + + service = metadata.get_by_name( + DatabaseService, service_entity.fullyQualifiedName.root + ) + if service: + metadata.delete(DatabaseService, service.id, recursive=True, hard_delete=True) + + +@pytest.fixture(scope="module") +def database(metadata, db_service): + database_entity = metadata.create_or_update( + CreateDatabaseRequest( + name="dq_test_db", + service=db_service.fullyQualifiedName, + ) + ) + return database_entity + + +@pytest.fixture(scope="module") +def schema(metadata, database): + schema_entity = metadata.create_or_update( + CreateDatabaseSchemaRequest( + name="public", + database=database.fullyQualifiedName, + ) + ) + return schema_entity + + +@pytest.fixture(scope="module") +def test_data(db_service, postgres_container): + engine = create_engine( + postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db") + ) + + sql_metadata = MetaData() + + users_table = SQATable( + "users", + sql_metadata, + SQAColumn("id", Integer, primary_key=True), + SQAColumn("username", String(50), nullable=False), + SQAColumn("email", String(100)), + SQAColumn("age", Integer), + SQAColumn("score", Integer), + ) + + products_table = SQATable( + "products", + sql_metadata, + SQAColumn("product_id", Integer, 
primary_key=True), + SQAColumn("name", String(100)), + SQAColumn("price", Integer), + ) + + stg_products_table = SQATable( + "stg_products", + sql_metadata, + SQAColumn("id", Integer, primary_key=True), + SQAColumn("name", String(100)), + SQAColumn("price", Integer), + ) + + sql_metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute( + users_table.insert(), + [ + { + "id": 1, + "username": "alice", + "email": "alice@example.com", + "age": 25, + "score": 85, + }, + { + "id": 2, + "username": "bob", + "email": "bob@example.com", + "age": 30, + "score": 90, + }, + {"id": 3, "username": "charlie", "email": None, "age": 35, "score": 75}, + { + "id": 4, + "username": "diana", + "email": "diana@example.com", + "age": 28, + "score": 95, + }, + { + "id": 5, + "username": "eve", + "email": "eve@example.com", + "age": 22, + "score": 88, + }, + ], + ) + + conn.execute( + products_table.insert(), + [ + {"product_id": 1, "name": "Widget", "price": 100}, + {"product_id": 2, "name": "Gadget", "price": 200}, + {"product_id": 3, "name": "Doohickey", "price": 150}, + ], + ) + + conn.execute( + stg_products_table.insert(), + [ + {"id": 1, "name": "Widget", "price": 100}, + {"id": 2, "name": "Gadget", "price": 200}, + {"id": 3, "name": "Doohickey", "price": 150}, + ], + ) + + return { + "users": users_table, + "products": products_table, + "stg_products": stg_products_table, + } + + +@pytest.fixture(scope="module") +def ingest_metadata(metadata, db_service, schema, test_data): + workflow_config = { + "source": { + "type": db_service.connection.config.type.value.lower(), + "serviceName": db_service.fullyQualifiedName.root, + "sourceConfig": { + "config": { + "type": "DatabaseMetadata", + "schemaFilterPattern": {"includes": ["public"]}, + } + }, + "serviceConnection": db_service.connection.model_dump(), + }, + "sink": {"type": "metadata-rest", "config": {}}, + "workflowConfig": { + "loggerLevel": "INFO", + "openMetadataServerConfig": metadata.config.model_dump(), + }, + } + + workflow = MetadataWorkflow.create(workflow_config) + workflow.execute() + workflow.raise_from_status() + + return workflow + + +@pytest.fixture(scope="module") +def patch_passwords(db_service, monkeymodule): + def override_password(getter): + def inner(*args, **kwargs): + result = getter(*args, **kwargs) + if isinstance(result, DatabaseService): + if result.fullyQualifiedName.root == db_service.fullyQualifiedName.root: + result.connection.config.authType.password = ( + db_service.connection.config.authType.password + ) + return result + + return inner + + monkeymodule.setattr( + "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_name", + override_password(OpenMetadata.get_by_name), + ) + + monkeymodule.setattr( + "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_id", + override_password(OpenMetadata.get_by_id), + ) + + +@pytest.fixture(scope="module") +def monkeymodule(): + with pytest.MonkeyPatch.context() as mp: + yield mp diff --git a/ingestion/tests/integration/sdk/data_quality/dataframes/test_dataframe_validator.py b/ingestion/tests/integration/sdk/data_quality/dataframes/test_dataframe_validator.py new file mode 100644 index 00000000000..679fe896efd --- /dev/null +++ b/ingestion/tests/integration/sdk/data_quality/dataframes/test_dataframe_validator.py @@ -0,0 +1,255 @@ +from typing import Any, Generator, Mapping +from unittest.mock import Mock, patch + +import pandas +import pytest +from dirty_equals import HasAttributes, IsList, IsTuple +from pandas import DataFrame +from sqlalchemy import 
create_engine
+from sqlalchemy.sql.schema import Table as SQATable
+from testcontainers.postgres import PostgresContainer
+
+from metadata.generated.schema.api.tests.createTestCase import CreateTestCaseRequest
+from metadata.generated.schema.entity.data.table import Table
+from metadata.generated.schema.entity.services.databaseService import DatabaseService
+from metadata.generated.schema.tests.basic import TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
+from metadata.sdk.data_quality import ColumnValueMinToBeBetween
+from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
+from metadata.utils.entity_link import get_entity_link
+
+
+@pytest.fixture(scope="module")
+def table_fqn(db_service: DatabaseService) -> str:
+    return f"{db_service.name.root}.dq_test_db.public.users"
+
+
+@pytest.fixture(scope="module")
+def column_unique_test(
+    table_fqn: str, metadata: OpenMetadata[TestCase, CreateTestCaseRequest]
+) -> TestCase:
+    request = CreateTestCaseRequest(
+        name="column_unique",
+        testDefinition="columnValuesToBeUnique",
+        entityLink=get_entity_link(
+            Table,
+            table_fqn,
+            column_name="username",
+        ),
+    )
+
+    test_case = metadata.create_or_update(request)
+
+    return test_case
+
+
+@pytest.fixture(scope="module")
+def table_row_count_test(
+    table_fqn: str, metadata: OpenMetadata[TestCase, CreateTestCaseRequest]
+) -> TestCase:
+    request = CreateTestCaseRequest(
+        name="table_row_count",
+        testDefinition="tableRowCountToEqual",
+        entityLink=get_entity_link(Table, table_fqn),
+        parameterValues=[TestCaseParameterValue(name="value", value="5")],
+    )
+
+    test_case = metadata.create_or_update(request)
+
+    return test_case
+
+
+@pytest.fixture(scope="module")
+def dataframe(
+    test_data: Mapping[str, SQATable], postgres_container: PostgresContainer
+) -> Generator[DataFrame, None, None]:
+    engine = create_engine(
+        postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db"),
+        isolation_level="AUTOCOMMIT",
+    )
+    with engine.connect() as connection:
+        yield pandas.read_sql(
+            test_data["users"].select(),
+            connection,
+        )
+
+
+def test_it_runs_tests_from_openmetadata(
+    ingest_metadata: None,
+    metadata: OpenMetadata[Any, Any],
+    dataframe: DataFrame,
+    column_unique_test: TestCase,
+    table_row_count_test: TestCase,
+) -> None:
+    validator = DataFrameValidator(client=metadata)
+
+    validator.add_openmetadata_test(column_unique_test.fullyQualifiedName.root)
+    validator.add_openmetadata_test(table_row_count_test.fullyQualifiedName.root)
+
+    result = validator.validate(dataframe)
+
+    assert result == HasAttributes(
+        success=True,
+        total_tests=2,
+        passed_tests=2,
+        failed_tests=0,
+    )
+
+
+def test_it_runs_openmetadata_table_tests(
+    table_fqn: str,
+    ingest_metadata: None,
+    metadata: OpenMetadata[Any, Any],
+    dataframe: DataFrame,
+    column_unique_test: TestCase,
+    table_row_count_test: TestCase,
+) -> None:
+    validator = DataFrameValidator(client=metadata)
+
+    validator.add_openmetadata_table_tests(table_fqn)
+
+    result = validator.validate(dataframe)
+
+    assert result == HasAttributes(
+        success=True,
+        total_tests=2,
+        passed_tests=2,
+        failed_tests=0,
+    )
+
+
+class TestFullUseCase:
+    def test_it_runs_tests_and_publishes_results(
+        self,
+        table_fqn: str,
+        ingest_metadata: None,
+        metadata: OpenMetadata[Any, Any],
+        dataframe: DataFrame,
+        column_unique_test: TestCase,
+        table_row_count_test: TestCase,
+    ) -> None:
+        # First
ensure only previously reported tests exist in OM + test_suite = metadata.get_executable_test_suite(table_fqn) + assert test_suite is not None + original_test_names = {t.fullyQualifiedName for t in test_suite.tests} + assert original_test_names == { + column_unique_test.fullyQualifiedName.root, + table_row_count_test.fullyQualifiedName.root, + } + + # Run validation + validator = DataFrameValidator(client=metadata) + + validator.add_openmetadata_table_tests(table_fqn) + # Forcing a failure with this test + validator.add_test( + ColumnValueMinToBeBetween( + name="column_value_min_to_be_between_90_and_100", + column="score", + min_value=90, + max_value=100, + ) + ) + + result = validator.validate(dataframe) + + assert result == HasAttributes( + success=False, + total_tests=3, + passed_tests=2, + failed_tests=1, + test_cases_and_results=IsList( + IsTuple( + HasAttributes( + fullyQualifiedName=HasAttributes( + root=column_unique_test.fullyQualifiedName.root + ), + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Success, + ), + ), + IsTuple( + HasAttributes( + fullyQualifiedName=HasAttributes( + root=table_row_count_test.fullyQualifiedName.root + ), + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Success, + ), + ), + IsTuple( + HasAttributes( + fullyQualifiedName=HasAttributes( + root="column_value_min_to_be_between_90_and_100" + ), + ), + HasAttributes( + testCaseStatus=TestCaseStatus.Failed, + ), + ), + check_order=False, + ), + ) + + # Publish results + with patch( + "metadata.sdk.data_quality.dataframes.validation_results.get_client", + return_value=Mock(ometa=metadata), + ) as mock_client: + result.publish( + table_fqn, + ) + + mock_client.assert_called_once() + + # Validate test defined in code has been pushed to OM as well + test_suite = metadata.get_executable_test_suite(table_fqn) + + assert test_suite is not None + assert len(test_suite.tests) == 3 + assert original_test_names.issubset( + {t.fullyQualifiedName for t in test_suite.tests} + ) + + assert {t.fullyQualifiedName for t in test_suite.tests} == { + column_unique_test.fullyQualifiedName.root, + table_row_count_test.fullyQualifiedName.root, + f"{table_fqn}.score.column_value_min_to_be_between_90_and_100", + } + + # Validate each test case result + required_fields = ["testCaseResult", "testSuite", "testDefinition"] + column_unique_result = metadata.get_by_name( + TestCase, column_unique_test.fullyQualifiedName.root, fields=required_fields + ).testCaseResult + assert column_unique_result == HasAttributes( + testCaseStatus=TestCaseStatus.Success + ) + + table_row_count_result = metadata.get_by_name( + TestCase, + table_row_count_test.fullyQualifiedName.root, + fields=required_fields, + ).testCaseResult + assert table_row_count_result == HasAttributes( + testCaseStatus=TestCaseStatus.Success + ) + + code_test_case_result = metadata.get_by_name( + TestCase, + f"{table_fqn}.score.column_value_min_to_be_between_90_and_100", + fields=required_fields, + ).testCaseResult + assert code_test_case_result == HasAttributes( + testCaseStatus=TestCaseStatus.Failed + ) + + # Clean up code test + metadata.delete_test_case( + f"{table_fqn}.score.column_value_min_to_be_between_90_and_100", + recursive=True, + hard=True, + ) diff --git a/ingestion/tests/integration/sdk/test_dq_as_code_integration.py b/ingestion/tests/integration/sdk/test_dq_as_code_integration.py index ffbb6bcda8e..bea9cc0c79b 100644 --- a/ingestion/tests/integration/sdk/test_dq_as_code_integration.py +++ b/ingestion/tests/integration/sdk/test_dq_as_code_integration.py @@ 
-6,36 +6,11 @@ import sys import pytest from dirty_equals import HasAttributes -from sqlalchemy import Column as SQAColumn -from sqlalchemy import Integer, MetaData, String -from sqlalchemy import Table as SQATable -from sqlalchemy import create_engine -from _openmetadata_testutils.ometa import int_admin_ometa -from _openmetadata_testutils.postgres.conftest import postgres_container -from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest -from metadata.generated.schema.api.data.createDatabaseSchema import ( - CreateDatabaseSchemaRequest, -) -from metadata.generated.schema.api.services.createDatabaseService import ( - CreateDatabaseServiceRequest, -) from metadata.generated.schema.entity.data.table import Table -from metadata.generated.schema.entity.services.connections.database.common.basicAuth import ( - BasicAuth, -) -from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( - PostgresConnection, -) -from metadata.generated.schema.entity.services.databaseService import ( - DatabaseConnection, - DatabaseService, - DatabaseServiceType, -) from metadata.generated.schema.tests.basic import TestCaseStatus from metadata.generated.schema.tests.testCase import TestCase from metadata.generated.schema.type.basic import EntityLink -from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.sdk.data_quality import ( ColumnValuesToBeBetween, ColumnValuesToBeNotNull, @@ -45,7 +20,6 @@ from metadata.sdk.data_quality import ( TableRowCountToBeBetween, TestRunner, ) -from metadata.workflow.metadata import MetadataWorkflow if not sys.version_info >= (3, 9): pytest.skip( @@ -54,226 +28,6 @@ if not sys.version_info >= (3, 9): ) -@pytest.fixture(scope="module") -def metadata(): - return int_admin_ometa() - - -@pytest.fixture(scope="module") -def create_postgres_service(postgres_container, tmp_path_factory): - return CreateDatabaseServiceRequest( - name="dq_test_service_" + tmp_path_factory.mktemp("dq").name, - serviceType=DatabaseServiceType.Postgres, - connection=DatabaseConnection( - config=PostgresConnection( - username=postgres_container.username, - authType=BasicAuth(password=postgres_container.password), - hostPort="localhost:" - + str(postgres_container.get_exposed_port(postgres_container.port)), - database="dq_test_db", - ) - ), - ) - - -@pytest.fixture(scope="module") -def db_service(metadata, create_postgres_service, postgres_container): - engine = create_engine( - postgres_container.get_connection_url(), isolation_level="AUTOCOMMIT" - ) - engine.execute("CREATE DATABASE dq_test_db") - - service_entity = metadata.create_or_update(data=create_postgres_service) - service_entity.connection.config.authType.password = ( - create_postgres_service.connection.config.authType.password - ) - yield service_entity - - service = metadata.get_by_name( - DatabaseService, service_entity.fullyQualifiedName.root - ) - if service: - metadata.delete(DatabaseService, service.id, recursive=True, hard_delete=True) - - -@pytest.fixture(scope="module") -def database(metadata, db_service): - database_entity = metadata.create_or_update( - CreateDatabaseRequest( - name="dq_test_db", - service=db_service.fullyQualifiedName, - ) - ) - return database_entity - - -@pytest.fixture(scope="module") -def schema(metadata, database): - schema_entity = metadata.create_or_update( - CreateDatabaseSchemaRequest( - name="public", - database=database.fullyQualifiedName, - ) - ) - return schema_entity - - -@pytest.fixture(scope="module") -def 
-def test_data(postgres_container):
-    engine = create_engine(
-        postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db")
-    )
-
-    sql_metadata = MetaData()
-
-    users_table = SQATable(
-        "users",
-        sql_metadata,
-        SQAColumn("id", Integer, primary_key=True),
-        SQAColumn("username", String(50), nullable=False),
-        SQAColumn("email", String(100)),
-        SQAColumn("age", Integer),
-        SQAColumn("score", Integer),
-    )
-
-    products_table = SQATable(
-        "products",
-        sql_metadata,
-        SQAColumn("product_id", Integer, primary_key=True),
-        SQAColumn("name", String(100)),
-        SQAColumn("price", Integer),
-    )
-
-    stg_products_table = SQATable(
-        "stg_products",
-        sql_metadata,
-        SQAColumn("id", Integer, primary_key=True),
-        SQAColumn("name", String(100)),
-        SQAColumn("price", Integer),
-    )
-
-    sql_metadata.create_all(engine)
-
-    with engine.connect() as conn:
-        conn.execute(
-            users_table.insert(),
-            [
-                {
-                    "id": 1,
-                    "username": "alice",
-                    "email": "alice@example.com",
-                    "age": 25,
-                    "score": 85,
-                },
-                {
-                    "id": 2,
-                    "username": "bob",
-                    "email": "bob@example.com",
-                    "age": 30,
-                    "score": 90,
-                },
-                {"id": 3, "username": "charlie", "email": None, "age": 35, "score": 75},
-                {
-                    "id": 4,
-                    "username": "diana",
-                    "email": "diana@example.com",
-                    "age": 28,
-                    "score": 95,
-                },
-                {
-                    "id": 5,
-                    "username": "eve",
-                    "email": "eve@example.com",
-                    "age": 22,
-                    "score": 88,
-                },
-            ],
-        )
-
-        conn.execute(
-            products_table.insert(),
-            [
-                {"product_id": 1, "name": "Widget", "price": 100},
-                {"product_id": 2, "name": "Gadget", "price": 200},
-                {"product_id": 3, "name": "Doohickey", "price": 150},
-            ],
-        )
-
-        conn.execute(
-            stg_products_table.insert(),
-            [
-                {"id": 1, "name": "Widget", "price": 100},
-                {"id": 2, "name": "Gadget", "price": 200},
-                {"id": 3, "name": "Doohickey", "price": 150},
-            ],
-        )
-
-    return {
-        "users": users_table,
-        "products": products_table,
-        "stg_products": stg_products_table,
-    }
-
-
-@pytest.fixture(scope="module")
-def ingest_metadata(metadata, db_service, schema, test_data):
-    workflow_config = {
-        "source": {
-            "type": db_service.connection.config.type.value.lower(),
-            "serviceName": db_service.fullyQualifiedName.root,
-            "sourceConfig": {
-                "config": {
-                    "type": "DatabaseMetadata",
-                    "schemaFilterPattern": {"includes": ["public"]},
-                }
-            },
-            "serviceConnection": db_service.connection.model_dump(),
-        },
-        "sink": {"type": "metadata-rest", "config": {}},
-        "workflowConfig": {
-            "loggerLevel": "INFO",
-            "openMetadataServerConfig": metadata.config.model_dump(),
-        },
-    }
-
-    workflow = MetadataWorkflow.create(workflow_config)
-    workflow.execute()
-    workflow.raise_from_status()
-
-    return workflow
-
-
-@pytest.fixture(scope="module")
-def patch_passwords(db_service, monkeymodule):
-    def override_password(getter):
-        def inner(*args, **kwargs):
-            result = getter(*args, **kwargs)
-            if isinstance(result, DatabaseService):
-                if result.fullyQualifiedName.root == db_service.fullyQualifiedName.root:
-                    result.connection.config.authType.password = (
-                        db_service.connection.config.authType.password
-                    )
-            return result
-
-        return inner
-
-    monkeymodule.setattr(
-        "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_name",
-        override_password(OpenMetadata.get_by_name),
-    )
-
-    monkeymodule.setattr(
-        "metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_id",
-        override_password(OpenMetadata.get_by_id),
-    )
-
-
-@pytest.fixture(scope="module")
-def monkeymodule():
-    with pytest.MonkeyPatch.context() as mp:
-        yield mp
-
-
 def test_table_row_count_tests(
     metadata,
     db_service,
diff --git a/ingestion/tests/unit/sdk/data_quality/test_dataframe_validator.py b/ingestion/tests/unit/sdk/data_quality/test_dataframe_validator.py
new file mode 100644
index 00000000000..3b00c84b286
--- /dev/null
+++ b/ingestion/tests/unit/sdk/data_quality/test_dataframe_validator.py
@@ -0,0 +1,561 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for DataFrame validator."""
+from typing import Generator, List, Tuple
+from unittest.mock import Mock
+
+import pandas as pd
+import pytest
+from pandas import DataFrame
+
+from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.sdk.data_quality import (
+    ColumnValuesToBeNotNull,
+    ColumnValuesToBeUnique,
+    TableRowCountToBeBetween,
+)
+from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning
+from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
+from metadata.sdk.data_quality.dataframes.validation_results import (
+    FailureMode,
+    ValidationResult,
+)
+
+
+class TestDataFrameValidator:
+    """Test DataFrameValidator initialization and configuration."""
+
+    def test_validator_initialization(self):
+        """Test validator can be created."""
+        validator = DataFrameValidator(Mock())
+        assert validator is not None
+        assert validator._test_cases == []
+
+    def test_add_single_test(self):
+        """Test adding a single test definition."""
+        validator = DataFrameValidator(Mock())
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+        assert len(validator._test_cases) == 1
+
+    def test_add_multiple_tests(self):
+        """Test adding multiple test definitions at once."""
+        validator = DataFrameValidator(Mock())
+        validator.add_tests(
+            ColumnValuesToBeNotNull(column="email"),
+            ColumnValuesToBeUnique(column="id"),
+            TableRowCountToBeBetween(min_count=1, max_count=100),
+        )
+        assert len(validator._test_cases) == 3
+
+
+class TestValidationSuccess:
+    """Test successful validation scenarios."""
+
+    def test_validate_not_null_success(self):
+        """Test validation passes with valid DataFrame."""
+        df = pd.DataFrame({"email": ["a@b.com", "c@d.com", "e@f.com"]})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+
+        result = validator.validate(df)
+
+        assert result.success is True
+        assert result.passed_tests == 1
+        assert result.failed_tests == 0
+        assert len(result.test_results) == 1
+        assert result.test_results[0].testCaseStatus is TestCaseStatus.Success
+
+    def test_validate_unique_success(self):
+        """Test uniqueness validation passes."""
+        df = pd.DataFrame({"id": [1, 2, 3, 4, 5]})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(ColumnValuesToBeUnique(column="id"))
+
+        result = validator.validate(df)
+
+        assert result.success is True
+        assert result.passed_tests == 1
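Taken together, these cases document the core API surface: construct a validator (the OpenMetadata client can be a mock when nothing is published), register tests, call validate, and read the aggregate flags. A minimal, self-contained sketch built only from calls exercised in these tests:

    # Minimal usage sketch of the API exercised by the tests above.
    from unittest.mock import Mock

    import pandas as pd

    from metadata.sdk.data_quality import (
        ColumnValuesToBeNotNull,
        ColumnValuesToBeUnique,
    )
    from metadata.sdk.data_quality.dataframes.dataframe_validator import (
        DataFrameValidator,
    )

    validator = DataFrameValidator(Mock())  # no live client needed for local checks
    validator.add_tests(
        ColumnValuesToBeNotNull(column="email"),
        ColumnValuesToBeUnique(column="id"),
    )

    df = pd.DataFrame({"id": [1, 2], "email": ["a@b.com", "c@d.com"]})
    result = validator.validate(df)
    assert result.success and result.passed_tests == 2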
+
+    def test_validate_multiple_tests_success(self):
+        """Test multiple tests all passing."""
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3],
+                "email": ["a@b.com", "c@d.com", "e@f.com"],
+                "age": [25, 30, 35],
+            }
+        )
+        validator = DataFrameValidator(Mock())
+        validator.add_tests(
+            ColumnValuesToBeNotNull(column="email"),
+            ColumnValuesToBeUnique(column="id"),
+            TableRowCountToBeBetween(min_count=1, max_count=10),
+        )
+
+        result = validator.validate(df)
+
+        assert result.success is True
+        assert result.passed_tests == 3
+        assert result.failed_tests == 0
+
+
+class TestValidationFailure:
+    """Test validation failure scenarios."""
+
+    def test_validate_not_null_failure(self):
+        """Test validation fails with null values."""
+        df = pd.DataFrame({"email": ["a@b.com", None, "e@f.com"]})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+
+        result = validator.validate(df)
+
+        assert result.success is False
+        assert result.passed_tests == 0
+        assert result.failed_tests == 1
+        assert result.test_results[0].testCaseStatus is TestCaseStatus.Failed
+
+    def test_validate_unique_failure(self):
+        """Test uniqueness validation fails with duplicates."""
+        df = pd.DataFrame({"id": [1, 2, 2, 3]})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(ColumnValuesToBeUnique(column="id"))
+
+        result = validator.validate(df)
+
+        assert result.success is False
+        assert result.failed_tests == 1
+
+    def test_validate_row_count_failure(self):
+        """Test row count validation fails."""
+        df = pd.DataFrame({"col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(TableRowCountToBeBetween(min_count=1, max_count=10))
+
+        result = validator.validate(df)
+
+        assert result.success is False
+
+
+class TestShortCircuitMode:
+    """Test short-circuit validation mode."""
+
+    def test_short_circuit_stops_on_first_failure(self):
+        """Test short-circuit mode stops after first failure."""
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 2],
+                "email": [None, None, None],
+            }
+        )
+        validator = DataFrameValidator(Mock())
+        validator.add_tests(
+            ColumnValuesToBeUnique(column="id"),
+            ColumnValuesToBeNotNull(column="email"),
+        )
+
+        result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+        assert result.success is False
+        assert len(result.test_results) == 1
+
+    def test_short_circuit_continues_on_success(self):
+        """Test short-circuit mode continues when tests pass."""
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3],
+                "email": ["a@b.com", "c@d.com", "e@f.com"],
+            }
+        )
+        validator = DataFrameValidator(Mock())
+        validator.add_tests(
+            ColumnValuesToBeUnique(column="id"),
+            ColumnValuesToBeNotNull(column="email"),
+        )
+
+        result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+        assert result.success is True
+        assert len(result.test_results) == 2
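The two cases above fix the contract of FailureMode.SHORT_CIRCUIT: evaluation stops at the first failing test, so only the results computed up to that point are reported. A small sketch of the observable behaviour:

    # Short-circuit sketch: the second test is never evaluated.
    from unittest.mock import Mock

    import pandas as pd

    from metadata.sdk.data_quality import (
        ColumnValuesToBeNotNull,
        ColumnValuesToBeUnique,
    )
    from metadata.sdk.data_quality.dataframes.dataframe_validator import (
        DataFrameValidator,
    )
    from metadata.sdk.data_quality.dataframes.validation_results import FailureMode

    validator = DataFrameValidator(Mock())
    validator.add_tests(
        ColumnValuesToBeUnique(column="id"),      # fails on the duplicate id
        ColumnValuesToBeNotNull(column="email"),  # skipped after the failure
    )

    df = pd.DataFrame({"id": [1, 1], "email": [None, None]})
    result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
    assert result.success is False and len(result.test_results) == 1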
+
+
+class TestValidationResultProperties:
+    """Test ValidationResult helper properties."""
+
+    def test_failures_property(self):
+        """Test failures property returns only failed tests."""
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 2],
+                "email": ["a@b.com", "c@d.com", "e@f.com"],
+            }
+        )
+        validator = DataFrameValidator(Mock())
+        validator.add_tests(
+            ColumnValuesToBeUnique(column="id"),
+            ColumnValuesToBeNotNull(column="email"),
+        )
+
+        result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+        assert len(result.failures) == 1
+
+    def test_passes_property(self):
+        """Test passes property returns only passed tests."""
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3],
+                "email": ["a@b.com", None, "e@f.com"],
+            }
+        )
+        validator = DataFrameValidator(Mock())
+        validator.add_tests(
+            ColumnValuesToBeUnique(column="id"),
+            ColumnValuesToBeNotNull(column="email"),
+        )
+
+        result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
+
+        assert len(result.passes) == 1
+
+
+class TestEdgeCases:
+    """Test edge cases and error conditions."""
+
+    def test_empty_dataframe(self):
+        """Test validation on empty DataFrame."""
+        df = pd.DataFrame({"email": []})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+
+        result = validator.validate(df)
+
+        assert result.success is True
+
+    def test_missing_column(self):
+        """Test validation with missing column."""
+        df = pd.DataFrame({"name": ["Alice", "Bob"]})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(ColumnValuesToBeNotNull(column="email"))
+
+        result = validator.validate(df)
+
+        assert result.success is False
+        assert result.test_results[0].testCaseStatus is TestCaseStatus.Aborted
+
+    def test_no_tests_configured(self):
+        """Test validation with no tests configured."""
+        df = pd.DataFrame({"col": [1, 2, 3]})
+        validator = DataFrameValidator(Mock())
+
+        result = validator.validate(df)
+
+        assert result.success is True
+        assert result.total_tests == 0
+        assert len(result.test_results) == 0
+
+    def test_execution_time_recorded(self):
+        """Test that execution time is recorded."""
+        df = pd.DataFrame({"col": [1, 2, 3]})
+        validator = DataFrameValidator(Mock())
+        validator.add_test(TableRowCountToBeBetween(min_count=1, max_count=10))
+
+        result = validator.validate(df)
+
+        assert result.execution_time_ms > 0
+
+
+class TracksValidationCallbacks:
+    def __init__(self) -> None:
+        self.calls: List[Tuple[DataFrame, ValidationResult]] = []
+
+    @property
+    def times_called(self) -> int:
+        return len(self.calls)
+
+    @property
+    def was_called(self) -> bool:
+        return self.times_called > 0
+
+    def __call__(self, df: DataFrame, result: ValidationResult) -> None:
+        self.calls.append((df, result))
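TracksValidationCallbacks is only a recording callable; per the run() calls in the tests that follow, any callable taking (df, result) works. A sketch of a streaming run over DataFrame batches with plain function callbacks (the print targets are illustrative):

    # Streaming-run sketch with ordinary functions as callbacks.
    from unittest.mock import Mock

    import pandas as pd
    from pandas import DataFrame

    from metadata.sdk.data_quality import ColumnValuesToBeNotNull
    from metadata.sdk.data_quality.dataframes.dataframe_validator import (
        DataFrameValidator,
    )
    from metadata.sdk.data_quality.dataframes.validation_results import (
        ValidationResult,
    )

    def on_success(df: DataFrame, result: ValidationResult) -> None:
        print(f"batch of {len(df)} rows passed in {result.execution_time_ms:.1f} ms")

    def on_failure(df: DataFrame, result: ValidationResult) -> None:
        print(f"batch failed: {result.failed_tests} failing test(s)")

    validator = DataFrameValidator(Mock())
    validator.add_test(ColumnValuesToBeNotNull(column="id"))

    batches = (pd.DataFrame({"id": [i, i + 1]}) for i in range(0, 6, 2))
    merged = validator.run(batches, on_success, on_failure)  # one merged result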
+
+
+@pytest.mark.filterwarnings(
+    "error::metadata.sdk.data_quality.dataframes.custom_warnings.WholeTableTestsWarning"
+)
+class TestValidatorRun:
+    @pytest.fixture
+    def on_success_callback(self) -> TracksValidationCallbacks:
+        return TracksValidationCallbacks()
+
+    @pytest.fixture
+    def on_failure_callback(self) -> TracksValidationCallbacks:
+        return TracksValidationCallbacks()
+
+    @pytest.fixture
+    def validator(self) -> DataFrameValidator:
+        df_validator = DataFrameValidator(Mock())
+        df_validator.add_tests(
+            ColumnValuesToBeNotNull(column="id"),
+        )
+        return df_validator
+
+    def test_it_only_calls_on_success(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3))
+
+        validator.run(dfs, on_success_callback, on_failure_callback)
+
+        assert on_success_callback.times_called == 3
+        assert on_failure_callback.was_called is False
+
+    def test_it_calls_both_on_success_and_on_failure(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        def generate_data() -> Generator[DataFrame, None, None]:
+            for i in range(0, 6, 3):
+                yield pd.DataFrame({"id": [i + 1, i + 2, i + 3]})
+            yield pd.DataFrame({"id": [None]})
+
+        validator.run(generate_data(), on_success_callback, on_failure_callback)
+
+        assert on_success_callback.times_called == 2
+        assert on_failure_callback.times_called == 1
+        assert pd.DataFrame({"id": [None]}).equals(on_failure_callback.calls[0][0])
+
+    def test_it_aborts_on_failure(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        def generate_data() -> Generator[DataFrame, None, None]:
+            yield pd.DataFrame({"id": [None]})
+            for i in range(0, 6, 3):
+                yield pd.DataFrame({"id": [i + 1, i + 2, i + 3]})
+
+        validator.run(generate_data(), on_success_callback, on_failure_callback)
+
+        assert on_success_callback.was_called is False
+        assert on_failure_callback.times_called == 1
+        assert pd.DataFrame({"id": [None]}).equals(on_failure_callback.calls[0][0])
+
+    def test_it_warns_when_using_tests_that_require_the_whole_table_or_column(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3))
+
+        validator.add_tests(
+            TableRowCountToBeBetween(min_count=1, max_count=10),
+            ColumnValuesToBeUnique(column="id"),
+        )
+
+        expected_warning_match = (
+            "The following tests could have unexpected results:\n\n"
+            + "\t- columnValuesToBeUnique\n"
+            + "\t- tableRowCountToBeBetween"
+        )
+
+        with pytest.warns(WholeTableTestsWarning, match=expected_warning_match):
+            validator.run(dfs, on_success_callback, on_failure_callback)
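The warning test above leans on the class-level filterwarnings mark; WholeTableTestsWarning is an ordinary RuntimeWarning, so outside pytest the stdlib warning filters should give the same promote-to-error behaviour. A hedged sketch under that assumption:

    # Promoting WholeTableTestsWarning to an error without pytest.
    import warnings
    from unittest.mock import Mock

    import pandas as pd

    from metadata.sdk.data_quality import TableRowCountToBeBetween
    from metadata.sdk.data_quality.dataframes.custom_warnings import (
        WholeTableTestsWarning,
    )
    from metadata.sdk.data_quality.dataframes.dataframe_validator import (
        DataFrameValidator,
    )

    validator = DataFrameValidator(Mock())
    validator.add_test(TableRowCountToBeBetween(min_count=1, max_count=10))

    with warnings.catch_warnings():
        warnings.simplefilter("error", WholeTableTestsWarning)
        try:
            validator.run(
                iter([pd.DataFrame({"id": [1]})]),
                lambda df, result: None,
                lambda df, result: None,
            )
        except WholeTableTestsWarning as warning:
            print(f"refusing batched run: {warning}")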
+
+    def test_run_returns_merged_validation_result(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3))
+
+        result = validator.run(dfs, on_success_callback, on_failure_callback)
+
+        assert isinstance(result, ValidationResult)
+        assert result.success is True
+
+    def test_merged_result_has_correct_aggregated_metrics(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        dfs = [
+            pd.DataFrame({"id": [1, 2, 3]}),
+            pd.DataFrame({"id": [4, 5, 6]}),
+            pd.DataFrame({"id": [7, 8, 9]}),
+        ]
+
+        result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
+
+        assert result.total_tests == 1
+        assert result.passed_tests == 1
+        assert result.failed_tests == 0
+        assert result.success is True
+
+    def test_merged_result_aggregates_failures_correctly(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        dfs = [
+            pd.DataFrame({"id": [1, 2, 3]}),
+            pd.DataFrame({"id": [None]}),
+        ]
+
+        result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
+
+        assert result.total_tests == 1
+        assert result.passed_tests == 0
+        assert result.failed_tests == 1
+        assert result.success is False
+
+    def test_merged_result_contains_aggregated_test_case_results(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        validator.add_test(ColumnValuesToBeUnique(column="id"))
+
+        dfs = [
+            pd.DataFrame({"id": [1, 2, 3]}),
+            pd.DataFrame({"id": [4, 5, 6]}),
+        ]
+
+        with pytest.warns(WholeTableTestsWarning):
+            result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
+
+        assert len(result.test_cases_and_results) == 2
+        assert all(
+            isinstance(test_case, TestCase) and isinstance(test_result, TestCaseResult)
+            for test_case, test_result in result.test_cases_and_results
+        )
+        _, aggregated_not_null_result = result.test_cases_and_results[0]
+        assert aggregated_not_null_result.passedRows == 6
+
+    def test_merged_result_sums_execution_times(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        dfs = [
+            pd.DataFrame({"id": [1, 2, 3]}),
+            pd.DataFrame({"id": [4, 5, 6]}),
+            pd.DataFrame({"id": [7, 8, 9]}),
+        ]
+
+        result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
+
+        assert result.execution_time_ms > 0
+        individual_times = [
+            call[1].execution_time_ms for call in on_success_callback.calls
+        ]
+        assert result.execution_time_ms == sum(individual_times)
+
+    def test_merged_result_with_mixed_success_and_failure(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        def generate_mixed_data() -> Generator[DataFrame, None, None]:
+            yield pd.DataFrame({"id": [1, 2, 3]})
+            yield pd.DataFrame({"id": [4, 5, 6]})
+            yield pd.DataFrame({"id": [None]})
+
+        result = validator.run(
+            generate_mixed_data(), on_success_callback, on_failure_callback
+        )
+
+        assert result.total_tests == 1
+        assert result.passed_tests == 0
+        assert result.failed_tests == 1
+        assert result.success is False
+        assert len(result.test_cases_and_results) == 1
+
+    def test_merged_result_aggregates_multiple_tests_per_batch(
+        self,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        validator = DataFrameValidator(Mock())
+        validator.add_tests(
+            ColumnValuesToBeNotNull(column="id"),
+            ColumnValuesToBeUnique(column="id"),
+            TableRowCountToBeBetween(min_count=1, max_count=10),
+        )
+
+        dfs = [
+            pd.DataFrame({"id": [1, 2, 3]}),
+            pd.DataFrame({"id": [4, 5, 6]}),
+        ]
+
+        with pytest.warns(WholeTableTestsWarning):
+            result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
+
+        assert result.total_tests == 3
+        assert result.passed_tests == 3
+        assert result.failed_tests == 0
+        assert result.success is True
+
+    def test_merged_result_with_short_circuit_on_failure(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        def generate_data() -> Generator[DataFrame, None, None]:
+            yield pd.DataFrame({"id": [None]})
+            yield pd.DataFrame({"id": [1, 2, 3]})
+            yield pd.DataFrame({"id": [4, 5, 6]})
+
+        result = validator.run(
+            generate_data(), on_success_callback, on_failure_callback
+        )
+
+        assert result.total_tests == 1
+        assert result.passed_tests == 0
+        assert result.failed_tests == 1
+        assert result.success is False
+
+    def test_merged_result_reflects_all_batches_processed(
+        self,
+        validator: DataFrameValidator,
+        on_success_callback: TracksValidationCallbacks,
+        on_failure_callback: TracksValidationCallbacks,
+    ) -> None:
+        batch_count = 5
+        dfs = [
+            pd.DataFrame({"id": [i + 1, i + 2, i + 3]})
+            for i in range(0, batch_count * 3, 3)
+        ]
+
+        result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
+
+        assert result.total_tests == 1
+        assert len(result.test_cases_and_results) == 1
+        assert on_success_callback.times_called == batch_count
+        _, aggregated_result = result.test_cases_and_results[0]
+        assert aggregated_result.passedRows == batch_count * 3
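These merged-result tests establish that per-test row counts are summed across batches and execution times accumulate. Reading the aggregates back off a batched run, as a self-contained sketch built from the same calls:

    # Inspecting aggregated per-test results after a batched run.
    from unittest.mock import Mock

    import pandas as pd

    from metadata.sdk.data_quality import ColumnValuesToBeNotNull
    from metadata.sdk.data_quality.dataframes.dataframe_validator import (
        DataFrameValidator,
    )

    validator = DataFrameValidator(Mock())
    validator.add_test(ColumnValuesToBeNotNull(column="id"))

    batches = [pd.DataFrame({"id": [1, 2, 3]}), pd.DataFrame({"id": [4, 5, 6]})]
    merged = validator.run(iter(batches), lambda df, r: None, lambda df, r: None)

    for test_case, test_result in merged.test_cases_and_results:
        # passedRows is 6 here: 3 rows per batch, summed by the merge
        print(test_case.name.root, test_result.testCaseStatus, test_result.passedRows)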
diff --git a/ingestion/tests/unit/sdk/data_quality/test_validation_results.py b/ingestion/tests/unit/sdk/data_quality/test_validation_results.py
new file mode 100644
index 00000000000..31ef23a4469
--- /dev/null
+++ b/ingestion/tests/unit/sdk/data_quality/test_validation_results.py
@@ -0,0 +1,377 @@
+# Copyright 2025 Collate
+# Licensed under the Collate Community License, Version 1.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for ValidationResult."""
+from datetime import datetime
+from uuid import UUID
+
+import pytest
+
+from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCase
+from metadata.generated.schema.type.basic import FullyQualifiedEntityName, Timestamp
+from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.sdk.data_quality.dataframes.validation_results import ValidationResult
+
+
+def create_test_case(fqn: str) -> TestCase:
+    """Helper to create a test case with minimal required fields."""
+    return TestCase(
+        name=fqn.split(".")[-1],
+        fullyQualifiedName=FullyQualifiedEntityName(fqn),
+        testDefinition=EntityReference(
+            id=UUID("12345678-1234-1234-1234-123456789abc"),
+            type="testDefinition",
+            fullyQualifiedName=fqn,
+        ),
+        entityLink="<#E::table::test_table>",
+        testSuite=EntityReference(
+            id=UUID("87654321-4321-4321-4321-cba987654321"),
+            type="testSuite",
+            fullyQualifiedName="test_suite",
+        ),
+    )
+
+
+def create_test_result(
+    status: TestCaseStatus,
+    passed_rows: int = 0,
+    failed_rows: int = 0,
+) -> TestCaseResult:
+    """Helper to create a test case result."""
+    total = passed_rows + failed_rows
+    return TestCaseResult(
+        timestamp=Timestamp(int(datetime.now().timestamp() * 1000)),
+        testCaseStatus=status,
+        passedRows=passed_rows if total > 0 else None,
+        failedRows=failed_rows if total > 0 else None,
+        passedRowsPercentage=(passed_rows / total * 100) if total > 0 else None,
+        failedRowsPercentage=(failed_rows / total * 100) if total > 0 else None,
+    )
+
+
+class TestValidationResultMerge:
+    """Test ValidationResult.merge method."""
+
+    def test_merge_single_result(self) -> None:
+        """Test merging a single ValidationResult returns equivalent result."""
+        test_case = create_test_case("test.case.one")
+        test_result = create_test_result(TestCaseStatus.Success, passed_rows=100)
+
+        result = ValidationResult(
+            success=True,
+            total_tests=1,
+            passed_tests=1,
+            failed_tests=0,
+            test_cases_and_results=[(test_case, test_result)],
+            execution_time_ms=10.0,
+        )
+
+        merged = ValidationResult.merge(result)
+
+        assert merged.success is True
+        assert merged.total_tests == 1
+        assert merged.passed_tests == 1
+        assert merged.failed_tests == 0
+        assert merged.execution_time_ms == 10.0
+        assert len(merged.test_cases_and_results) == 1
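A quick check of the arithmetic in create_test_result above: the percentage fields are plain ratios over the batch total, so a batch with 40 passed and 10 failed rows yields 80% and 20%. A tiny standalone illustration:

    # Worked example of the ratio used by create_test_result above.
    passed_rows, failed_rows = 40, 10
    total = passed_rows + failed_rows
    assert passed_rows / total * 100 == 80.0  # passedRowsPercentage
    assert failed_rows / total * 100 == 20.0  # failedRowsPercentage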
+
+    def test_merge_multiple_results_same_test_case(self) -> None:
+        """Test merging results for same test case aggregates metrics."""
+        test_case = create_test_case("test.case.one")
+
+        result1 = ValidationResult(
+            success=True,
+            total_tests=1,
+            passed_tests=1,
+            failed_tests=0,
+            test_cases_and_results=[
+                (test_case, create_test_result(TestCaseStatus.Success, passed_rows=50))
+            ],
+            execution_time_ms=10.0,
+        )
+
+        result2 = ValidationResult(
+            success=True,
+            total_tests=1,
+            passed_tests=1,
+            failed_tests=0,
+            test_cases_and_results=[
+                (test_case, create_test_result(TestCaseStatus.Success, passed_rows=30))
+            ],
+            execution_time_ms=8.0,
+        )
+
+        merged = ValidationResult.merge(result1, result2)
+
+        assert merged.success is True
+        assert merged.total_tests == 1
+        assert merged.passed_tests == 1
+        assert merged.failed_tests == 0
+        assert merged.execution_time_ms == 18.0
+        assert len(merged.test_cases_and_results) == 1
+
+        _, aggregated_result = merged.test_cases_and_results[0]
+        assert aggregated_result.passedRows == 80
+        assert aggregated_result.failedRows == 0
+        assert aggregated_result.testCaseStatus == TestCaseStatus.Success
+
+    def test_merge_aggregates_passed_and_failed_rows(self) -> None:
+        """Test that passed and failed rows are summed correctly."""
+        test_case = create_test_case("test.case.one")
+
+        result1 = ValidationResult(
+            success=False,
+            total_tests=1,
+            passed_tests=0,
+            failed_tests=1,
+            test_cases_and_results=[
+                (
+                    test_case,
+                    create_test_result(
+                        TestCaseStatus.Failed, passed_rows=40, failed_rows=10
+                    ),
+                )
+            ],
+            execution_time_ms=10.0,
+        )
+
+        result2 = ValidationResult(
+            success=False,
+            total_tests=1,
+            passed_tests=0,
+            failed_tests=1,
+            test_cases_and_results=[
+                (
+                    test_case,
+                    create_test_result(
+                        TestCaseStatus.Failed, passed_rows=30, failed_rows=20
+                    ),
+                )
+            ],
+            execution_time_ms=12.0,
+        )
+
+        merged = ValidationResult.merge(result1, result2)
+
+        assert merged.success is False
+        assert merged.total_tests == 1
+        assert merged.failed_tests == 1
+        assert merged.execution_time_ms == 22.0
+
+        _, aggregated_result = merged.test_cases_and_results[0]
+        assert aggregated_result.passedRows == 70
+        assert aggregated_result.failedRows == 30
+        assert aggregated_result.testCaseStatus == TestCaseStatus.Failed
+        assert aggregated_result.passedRowsPercentage == pytest.approx(70.0)
+        assert aggregated_result.failedRowsPercentage == pytest.approx(30.0)
+
+    def test_merge_multiple_test_cases(self) -> None:
+        """Test merging results with multiple different test cases."""
+        test_case1 = create_test_case("test.case.one")
+        test_case2 = create_test_case("test.case.two")
+
+        result1 = ValidationResult(
+            success=True,
+            total_tests=2,
+            passed_tests=2,
+            failed_tests=0,
+            test_cases_and_results=[
+                (
+                    test_case1,
+                    create_test_result(TestCaseStatus.Success, passed_rows=50),
+                ),
+                (
+                    test_case2,
+                    create_test_result(TestCaseStatus.Success, passed_rows=30),
+                ),
+            ],
+            execution_time_ms=15.0,
+        )
+
+        result2 = ValidationResult(
+            success=True,
+            total_tests=2,
+            passed_tests=2,
+            failed_tests=0,
+            test_cases_and_results=[
+                (
+                    test_case1,
+                    create_test_result(TestCaseStatus.Success, passed_rows=25),
+                ),
+                (
+                    test_case2,
+                    create_test_result(TestCaseStatus.Success, passed_rows=35),
+                ),
+            ],
+            execution_time_ms=12.0,
+        )
+
+        merged = ValidationResult.merge(result1, result2)
+
+        assert merged.success is True
+        assert merged.total_tests == 2
+        assert merged.passed_tests == 2
+        assert merged.failed_tests == 0
+        assert merged.execution_time_ms == 27.0
+        assert len(merged.test_cases_and_results) == 2
+
+        fqns_to_results = {
+            tc.fullyQualifiedName.root: result
+            for tc, result in merged.test_cases_and_results
+        }
+
+        assert fqns_to_results["test.case.one"].passedRows == 75
+        assert fqns_to_results["test.case.two"].passedRows == 65
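The next tests pin down how per-batch statuses combine: Failed overrides Success, and Aborted overrides Failed. A compact sketch of that ordering, illustrative only and not the SDK's internal implementation:

    # Precedence implied by the tests below: Aborted > Failed > Success.
    from metadata.generated.schema.tests.basic import TestCaseStatus

    _PRECEDENCE = {
        TestCaseStatus.Success: 0,
        TestCaseStatus.Failed: 1,
        TestCaseStatus.Aborted: 2,
    }

    def merged_status(statuses):
        """Pick the winning status across batches (sketch only)."""
        return max(statuses, key=_PRECEDENCE.__getitem__)

    assert merged_status([TestCaseStatus.Success, TestCaseStatus.Failed]) is TestCaseStatus.Failed
    assert merged_status([TestCaseStatus.Failed, TestCaseStatus.Aborted]) is TestCaseStatus.Aborted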
+
+    def test_merge_with_failure_status_propagates(self) -> None:
+        """Test that if any batch fails, overall status is Failed."""
+        test_case = create_test_case("test.case.one")
+
+        result1 = ValidationResult(
+            success=True,
+            total_tests=1,
+            passed_tests=1,
+            failed_tests=0,
+            test_cases_and_results=[
+                (test_case, create_test_result(TestCaseStatus.Success, passed_rows=50))
+            ],
+            execution_time_ms=10.0,
+        )
+
+        result2 = ValidationResult(
+            success=False,
+            total_tests=1,
+            passed_tests=0,
+            failed_tests=1,
+            test_cases_and_results=[
+                (
+                    test_case,
+                    create_test_result(
+                        TestCaseStatus.Failed, passed_rows=20, failed_rows=10
+                    ),
+                )
+            ],
+            execution_time_ms=10.0,
+        )
+
+        merged = ValidationResult.merge(result1, result2)
+
+        assert merged.success is False
+        assert merged.failed_tests == 1
+
+        _, aggregated_result = merged.test_cases_and_results[0]
+        assert aggregated_result.testCaseStatus == TestCaseStatus.Failed
+
+    def test_merge_with_aborted_status_takes_precedence(self) -> None:
+        """Test that Aborted status takes precedence over Failed."""
+        test_case = create_test_case("test.case.one")
+
+        result1 = ValidationResult(
+            success=False,
+            total_tests=1,
+            passed_tests=0,
+            failed_tests=1,
+            test_cases_and_results=[
+                (
+                    test_case,
+                    create_test_result(
+                        TestCaseStatus.Failed, passed_rows=40, failed_rows=10
+                    ),
+                )
+            ],
+            execution_time_ms=10.0,
+        )
+
+        result2 = ValidationResult(
+            success=False,
+            total_tests=1,
+            passed_tests=0,
+            failed_tests=1,
+            test_cases_and_results=[
+                (test_case, create_test_result(TestCaseStatus.Aborted))
+            ],
+            execution_time_ms=5.0,
+        )
+
+        merged = ValidationResult.merge(result1, result2)
+
+        _, aggregated_result = merged.test_cases_and_results[0]
+        assert aggregated_result.testCaseStatus == TestCaseStatus.Aborted
+
+    def test_merge_empty_raises_error(self) -> None:
+        """Test that merging with no results raises ValueError."""
+        with pytest.raises(ValueError, match="At least one ValidationResult"):
+            ValidationResult.merge()
+
+    def test_merge_without_fqn_raises_error(self) -> None:
+        """Test that merging test cases without FQN raises ValueError."""
+        test_case = TestCase(
+            name="test",
+            fullyQualifiedName=None,
+            testDefinition=EntityReference(
+                id=UUID("12345678-1234-1234-1234-123456789abc"),
+                type="testDefinition",
+                fullyQualifiedName="test.def",
+            ),
+            entityLink="<#E::table::test_table>",
+            testSuite=EntityReference(
+                id=UUID("87654321-4321-4321-4321-cba987654321"),
+                type="testSuite",
+                fullyQualifiedName="test_suite",
+            ),
+        )
+
+        result = ValidationResult(
+            success=True,
+            total_tests=1,
+            passed_tests=1,
+            failed_tests=0,
+            test_cases_and_results=[
+                (test_case, create_test_result(TestCaseStatus.Success))
+            ],
+            execution_time_ms=10.0,
+        )
+
+        with pytest.raises(ValueError, match="no fullyQualifiedName"):
+            ValidationResult.merge(result)
+
+    def test_merge_preserves_test_case_reference(self) -> None:
+        """Test that the merged result contains the original test case reference."""
+        test_case = create_test_case("test.case.one")
+
+        result1 = ValidationResult(
+            success=True,
+            total_tests=1,
+            passed_tests=1,
+            failed_tests=0,
+            test_cases_and_results=[
+                (test_case, create_test_result(TestCaseStatus.Success, passed_rows=50))
+            ],
+            execution_time_ms=10.0,
+        )
+
+        result2 = ValidationResult(
+            success=True,
+            total_tests=1,
+            passed_tests=1,
+            failed_tests=0,
+            test_cases_and_results=[
+                (test_case, create_test_result(TestCaseStatus.Success, passed_rows=30))
+            ],
+            execution_time_ms=8.0,
+        )
+
+        merged = ValidationResult.merge(result1, result2)
+
+        merged_test_case, _ = merged.test_cases_and_results[0]
+        assert merged_test_case.fullyQualifiedName.root == "test.case.one"
+        assert merged_test_case.name.root == "one"
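Pulling the merge behaviour together, a self-contained sketch that mirrors the helpers and assertions in the tests above; the FQN, UUIDs, and row counts are illustrative:

    # Self-contained sketch of ValidationResult.merge over two batch results.
    from datetime import datetime
    from uuid import UUID

    from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
    from metadata.generated.schema.tests.testCase import TestCase
    from metadata.generated.schema.type.basic import (
        FullyQualifiedEntityName,
        Timestamp,
    )
    from metadata.generated.schema.type.entityReference import EntityReference
    from metadata.sdk.data_quality.dataframes.validation_results import (
        ValidationResult,
    )

    test_case = TestCase(
        name="one",
        fullyQualifiedName=FullyQualifiedEntityName("test.case.one"),
        testDefinition=EntityReference(
            id=UUID("12345678-1234-1234-1234-123456789abc"), type="testDefinition"
        ),
        entityLink="<#E::table::test_table>",
        testSuite=EntityReference(
            id=UUID("87654321-4321-4321-4321-cba987654321"), type="testSuite"
        ),
    )

    def batch_result(passed: int) -> TestCaseResult:
        # Shaped like create_test_result above: all rows passed in this batch
        return TestCaseResult(
            timestamp=Timestamp(int(datetime.now().timestamp() * 1000)),
            testCaseStatus=TestCaseStatus.Success,
            passedRows=passed,
            failedRows=0,
            passedRowsPercentage=100.0,
            failedRowsPercentage=0.0,
        )

    batches = [
        ValidationResult(
            success=True,
            total_tests=1,
            passed_tests=1,
            failed_tests=0,
            test_cases_and_results=[(test_case, batch_result(rows))],
            execution_time_ms=elapsed,
        )
        for rows, elapsed in [(50, 10.0), (30, 8.0)]
    ]

    merged = ValidationResult.merge(*batches)
    assert merged.execution_time_ms == 18.0  # execution times are summed
    assert merged.test_cases_and_results[0][1].passedRows == 80  # rows are summed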