Simplified API for validating DataFrames (#24009)

* Refactor previous tests for shared resources

* Add validation result models

This also includes a method for merging them, useful when running validation in batches

* Add `DataFrameValidationEngine` for running tests

This also includes a registry for mapping test names to pandas test classes

* Implement the DataFrameValidator facade

This includes the logic to load tests from different sources (OpenMetadata or code) and pass them down to the engine.

It also includes tests for the integration with OpenMetadata

* Add examples for the API

* Apply comments
Eugenio 2025-11-04 09:52:43 +01:00 committed by GitHub
parent 73da5b507d
commit 42416a513e
14 changed files with 2564 additions and 249 deletions

View File

@ -17,6 +17,7 @@ To be used by OpenMetadata class
import traceback
from datetime import datetime
from typing import List, Optional, Type, Union
from urllib.parse import urlencode, urljoin
from uuid import UUID
from metadata.generated.schema.api.tests.createLogicalTestCases import (
@ -170,7 +171,7 @@ class OMetaTestsMixin:
test_definition_fqn: Optional[str] = None,
test_case_parameter_values: Optional[List[TestCaseParameterValue]] = None,
description: Optional[str] = None,
):
) -> TestCase:
"""Get or create a test case
Args:
@ -203,6 +204,34 @@ class OMetaTestsMixin:
)
return test_case
def get_executable_test_suite(self, table_fqn: str) -> Optional[TestSuite]:
"""Given an entity fqn, retrieve the link test suite if it exists
Args:
table_fqn (str): entity fully qualified name
Returns:
An instance of TestSuite or None
"""
table_entity = self.get_by_name(
entity=Table, fqn=table_fqn, fields=["testSuite"]
)
if not table_entity:
raise RuntimeError(
f"Unable to find table {table_fqn} in OpenMetadata. "
"This could be because the table has not been ingested yet or your JWT Token is expired or missing."
)
if not table_entity.testSuite:
return None
return self.get_by_name(
entity=TestSuite,
fqn=table_entity.testSuite.fullyQualifiedName,
fields=["tests"],
nullable=False,
)
def get_or_create_executable_test_suite(
self, entity_fqn: str
) -> Union[EntityReference, TestSuite]:
@ -379,3 +408,25 @@ class OMetaTestsMixin:
data=inspection_query,
)
return TestCase(**resp)
def delete_test_case(
self,
test_case_fqn: str,
recursive: bool = True,
hard: bool = False,
) -> None:
"""Delete a test case
Args:
test_case_fqn: Fully qualified name of the test case to delete
recursive (bool, optional): delete children if true
hard (bool, optional): hard delete if true
"""
params = urlencode(
dict(
recursive="true" if recursive else "false",
hardDelete="true" if hard else "false",
)
)
url = f"{self.get_suffix(TestCase)}/name/{quote(test_case_fqn)}"
self.client.delete(urljoin(url, "?" + params))
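
For reference, a minimal sketch of how the new mixin methods could be called through the SDK client; the host, token, and FQNs below are placeholder values, not part of this change.

```python
# Minimal sketch: placeholder host, token, and FQNs.
from metadata.sdk import client, configure

configure(host="http://localhost:8585/api", jwt_token="<jwt token>")
metadata = client().ometa

# Fetch the executable test suite linked to a table, if one exists.
suite = metadata.get_executable_test_suite("MyService.my_db.my_schema.my_table")
if suite:
    for test in suite.tests or []:
        print(test.fullyQualifiedName)

# Permanently delete a test case, including its children.
metadata.delete_test_case(
    "MyService.my_db.my_schema.my_table.my_test", recursive=True, hard=True
)
```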

View File

@ -0,0 +1,2 @@
class WholeTableTestsWarning(RuntimeWarning):
"""Warns when the user runs tests that require the whole table on a subset of it"""

View File

@ -0,0 +1,145 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Orchestration engine for DataFrame validation execution."""
import logging
import time
from datetime import datetime
from typing import List, Tuple, Type
from pandas import DataFrame
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase
from metadata.generated.schema.type.basic import Timestamp
from metadata.sdk.data_quality.dataframes.validation_results import (
FailureMode,
ValidationResult,
)
from metadata.sdk.data_quality.dataframes.validators import VALIDATOR_REGISTRY
logger = logging.getLogger(__name__)
class DataFrameValidationEngine:
"""Orchestrates execution of multiple validators on a DataFrame."""
def __init__(self, test_cases: List[TestCase]):
self.test_cases: List[TestCase] = test_cases
def execute(
self,
df: DataFrame,
mode: FailureMode = FailureMode.SHORT_CIRCUIT,
) -> ValidationResult:
"""Execute all validations and return aggregated results.
Args:
df: DataFrame to validate
mode: Validation mode (only "short-circuit" supported)
Returns:
ValidationResult with outcomes for all tests
"""
results: List[Tuple[TestCase, TestCaseResult]] = []
start_time = time.time()
for test_case in self.test_cases:
test_result = self._execute_single_test(df, test_case)
results.append((test_case, test_result))
if mode is FailureMode.SHORT_CIRCUIT and test_result.testCaseStatus in (
TestCaseStatus.Failed,
TestCaseStatus.Aborted,
):
break
execution_time = (time.time() - start_time) * 1000
return self._build_validation_result(results, execution_time)
def _execute_single_test(
self, df: DataFrame, test_case: TestCase
) -> TestCaseResult:
"""Execute validation and return structured result.
Returns:
TestCaseResult with validation outcome
"""
validator_class = self._get_validator_class(test_case)
validator = validator_class(
runner=[df],
test_case=test_case,
execution_date=Timestamp(root=int(datetime.now().timestamp() * 1000)),
)
try:
return validator.run_validation()
except Exception as err:
message = (
f"Error executing {test_case.testDefinition.fullyQualifiedName} - {err}"
)
logger.exception(message)
return validator.get_test_case_result_object(
validator.execution_date,
TestCaseStatus.Aborted,
message,
[],
)
@staticmethod
def _build_validation_result(
test_results: List[Tuple[TestCase, TestCaseResult]], execution_time_ms: float
) -> ValidationResult:
"""Build aggregated validation result.
Args:
test_results: Individual test results
execution_time_ms: Total execution time
Returns:
ValidationResult with aggregated outcomes
"""
passed = sum(
1 for _, r in test_results if r.testCaseStatus == TestCaseStatus.Success
)
failed = len(test_results) - passed
success = failed == 0
return ValidationResult(
success=success,
total_tests=len(test_results),
passed_tests=passed,
failed_tests=failed,
test_cases_and_results=test_results,
execution_time_ms=execution_time_ms,
)
@staticmethod
def _get_validator_class(test_case: TestCase) -> Type[BaseTestValidator]:
"""Resolve validator class from test definition name.
Returns:
Validator class for the test definition
Raises:
ValueError: If test definition is not supported
"""
validator_class = VALIDATOR_REGISTRY.get(
test_case.testDefinition.fullyQualifiedName # pyright: ignore[reportArgumentType]
)
if not validator_class:
raise ValueError(
f"Unknown test definition: {test_case.testDefinition.fullyQualifiedName}"
)
return validator_class
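
As an illustration, a minimal sketch of driving the engine directly with a synthetic test case built by `create_mock_test_case`; in practice the `DataFrameValidator` facade is the intended entry point, so this only shows the orchestration.

```python
# Minimal sketch: running the engine directly on an in-memory DataFrame.
import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeNotNull
from metadata.sdk.data_quality.dataframes.dataframe_validation_engine import (
    DataFrameValidationEngine,
)
from metadata.sdk.data_quality.dataframes.models import create_mock_test_case
from metadata.sdk.data_quality.dataframes.validation_results import FailureMode

df = pd.DataFrame({"email": ["a@b.com", None, "c@d.com"]})
test_cases = [create_mock_test_case(ColumnValuesToBeNotNull(column="email"))]

engine = DataFrameValidationEngine(test_cases)
result = engine.execute(df, mode=FailureMode.SHORT_CIRCUIT)

print(result.success)       # False: one email value is null
print(result.failed_tests)  # 1
```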

View File

@ -0,0 +1,201 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DataFrame validation API."""
import warnings
from typing import Any, Callable, Iterable, List, Optional, cast, final
from pandas import DataFrame
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.ometa.ometa_api import OpenMetadata as OMeta
from metadata.sdk import OpenMetadata
from metadata.sdk import client as get_client
from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning
from metadata.sdk.data_quality.dataframes.dataframe_validation_engine import (
DataFrameValidationEngine,
)
from metadata.sdk.data_quality.dataframes.models import create_mock_test_case
from metadata.sdk.data_quality.dataframes.validation_results import (
FailureMode,
ValidationResult,
)
from metadata.sdk.data_quality.dataframes.validators import requires_whole_table
from metadata.sdk.data_quality.tests.base_tests import BaseTest
ValidatorCallback = Callable[[DataFrame, ValidationResult], None]
@final
class DataFrameValidator:
"""Facade for DataFrame data quality validation.
Provides a simple interface to configure and execute data quality tests
on pandas DataFrames using OpenMetadata test definitions.
Example:
validator = DataFrameValidator()
validator.add_test(ColumnValuesToBeNotNull(column="email"))
validator.add_test(ColumnValuesToBeUnique(column="customer_id"))
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
if not result.success:
print(f"Validation failed: {result.failures}")
"""
def __init__(
self,
client: Optional[ # pyright: ignore[reportRedeclaration]
OMeta[Any, Any]
] = None,
):
self._test_cases: List[TestCase] = []
if client is None:
metadata: OpenMetadata = get_client()
client: OMeta[Any, Any] = metadata.ometa
self._client = client
def add_test(self, test: BaseTest) -> None:
"""Add a single test definition to be executed.
Args:
test: Test definition (e.g., ColumnValuesToBeNotNull)
"""
self._test_cases.append(create_mock_test_case(test))
def add_tests(self, *tests: BaseTest) -> None:
"""Add multiple test definitions at once.
Args:
*tests: Variable number of test definitions
"""
self._test_cases.extend(create_mock_test_case(t) for t in tests)
def add_openmetadata_test(self, test_fqn: str) -> None:
test_case = cast(
TestCase,
self._client.get_by_name(
TestCase,
test_fqn,
fields=["testDefinition", "testSuite"],
nullable=False,
),
)
self._test_cases.append(test_case)
def add_openmetadata_table_tests(self, table_fqn: str) -> None:
test_suite = self._client.get_executable_test_suite(table_fqn)
if test_suite is None:
raise ValueError(f"Table {table_fqn!r} does not have a test suite to run")
for test in test_suite.tests or []:
assert test.fullyQualifiedName is not None
self.add_openmetadata_test(test.fullyQualifiedName)
def validate(
self,
df: DataFrame,
mode: FailureMode = FailureMode.SHORT_CIRCUIT,
) -> ValidationResult:
"""Execute all configured tests on the DataFrame.
Args:
df: DataFrame to validate
mode: Validation mode (`FailureMode.SHORT_CIRCUIT` stops on first failure)
Returns:
ValidationResult with outcomes for all tests
"""
engine = DataFrameValidationEngine(self._test_cases)
return engine.execute(df, mode)
def _check_full_table_tests_included(self) -> None:
test_names: set[str] = { # pyright: ignore[reportAssignmentType]
test.testDefinition.fullyQualifiedName
for test in self._test_cases
if requires_whole_table(
test.testDefinition.fullyQualifiedName # pyright: ignore[reportArgumentType]
)
}
if not test_names:
return
warnings.warn(
WholeTableTestsWarning(
"Running tests that require the whole table on chunks could lead to false positives. "
+ "For example, a DataFrame with 200 rows split in chunks of 50 could pass tests expecting "
+ "DataFrames to contain max 100 rows.\n\n"
+ "The following tests could have unexpected results:\n\n\t- "
+ "\n\t- ".join(sorted(test_names))
)
)
def run(
self,
data: Iterable[DataFrame],
on_success: ValidatorCallback,
on_failure: ValidatorCallback,
mode: FailureMode = FailureMode.SHORT_CIRCUIT,
) -> ValidationResult:
"""Execute all configured tests on the DataFrame and call callbacks.
Useful for running validation based on chunks, for example:
```python
validator = DataFrameValidator()
validator.add_test(ColumnValuesToBeNotNull(column="email"))
def load_df_to_destination(df, result):
...
def rollback(df, result):
"Clears data previously loaded"
...
result = validator.run(
pandas.read_csv('somefile.csv', chunksize=1000),
on_success=load_df_to_destination,
on_failure=rollback,
mode=FailureMode.SHORT_CIRCUIT,
)
```
Args:
data: An iterable of pandas DataFrames
on_success: Callback to execute after successful validation
on_failure: Callback to execute after failed validation
mode: Validation mode (`FailureMode.SHORT_CIRCUIT` stops on first failure)
Returns:
Merged ValidationResult aggregating all batch validations
"""
self._check_full_table_tests_included()
results: List[ValidationResult] = []
for df in data:
validation_result = self.validate(df, mode)
results.append(validation_result)
if validation_result.success:
on_success(df, validation_result)
else:
on_failure(df, validation_result)
if mode is FailureMode.SHORT_CIRCUIT:
break
return ValidationResult.merge(*results)
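
To make the chunked-validation warning concrete, a small sketch, assuming in-memory chunks and a uniqueness test (which is in the whole-table set): `run()` emits `WholeTableTestsWarning`, and each chunk can pass individually even though the ids are not globally unique.

```python
# Minimal sketch: a chunked run with a whole-table test triggers WholeTableTestsWarning.
import warnings

import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeUnique
from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator

chunks = [
    pd.DataFrame({"id": [1, 2, 3]}),
    pd.DataFrame({"id": [3, 4, 5]}),  # duplicate "3" only visible across chunks
]

validator = DataFrameValidator()
validator.add_test(ColumnValuesToBeUnique(column="id"))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = validator.run(
        chunks,
        on_success=lambda df, res: None,
        on_failure=lambda df, res: None,
    )

assert any(isinstance(w.message, WholeTableTestsWarning) for w in caught)
print(result.success)  # True per chunk, despite the cross-chunk duplicate
```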

View File

@ -0,0 +1,44 @@
from uuid import uuid4
from metadata.generated.schema.tests.testCase import TestCase
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.sdk.data_quality import BaseTest, ColumnTest
class MockTestCase(TestCase):
"""Mock test case."""
def create_mock_test_case(test_definition: BaseTest) -> MockTestCase:
"""Convert TestCaseDefinition to TestCase object.
Returns:
Synthetic TestCase for DataFrame validation
"""
entity_link = "<#E::table::dataframe_validation>"
if isinstance(test_definition, ColumnTest):
entity_link = (
f"<#E::table::dataframe_validation::columns::{test_definition.column_name}>"
)
return MockTestCase( # pyright: ignore[reportCallIssue]
id=uuid4(),
name=test_definition.name,
fullyQualifiedName=test_definition.name,
displayName=test_definition.display_name,
description=test_definition.description,
testDefinition=EntityReference( # pyright: ignore[reportCallIssue]
id=uuid4(),
name=test_definition.test_definition_name,
fullyQualifiedName=test_definition.test_definition_name,
type="testDefinition",
),
entityLink=entity_link,
parameterValues=test_definition.parameters,
testSuite=EntityReference( # pyright: ignore[reportCallIssue]
id=uuid4(),
name="dataframe_validation",
type="testSuite",
),
computePassedFailedRowCount=True,
)
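
For reference, a quick sketch of the synthetic `TestCase` produced for a column-level test; the column name is an arbitrary example.

```python
# Minimal sketch: inspecting the mock TestCase built from an in-code test definition.
from metadata.sdk.data_quality import ColumnValuesToBeNotNull
from metadata.sdk.data_quality.dataframes.models import create_mock_test_case

test_case = create_mock_test_case(ColumnValuesToBeNotNull(column="email"))

print(test_case.testDefinition.name)  # columnValuesToBeNotNull
print(test_case.entityLink.root)      # <#E::table::dataframe_validation::columns::email>
```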

View File

@ -0,0 +1,243 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DataFrame validation result models."""
import logging
from enum import Enum
from typing import List, Optional, Tuple, cast
from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase
from metadata.generated.schema.type.basic import FullyQualifiedEntityName
from metadata.sdk import OpenMetadata
from metadata.sdk import client as get_client
from metadata.sdk.data_quality.dataframes.models import MockTestCase
from metadata.utils.entity_link import (
get_entity_link, # pyright: ignore[reportUnknownVariableType]
)
from metadata.utils.entity_link import get_column_name_or_none
logger = logging.getLogger(__name__)
class FailureMode(Enum):
SHORT_CIRCUIT = "short-circuit"
class ValidationResult(BaseModel):
"""Aggregated results from validating multiple tests on a DataFrame.
Attributes:
success: True if all tests passed
total_tests: Total number of tests executed
passed_tests: Number of tests that passed
failed_tests: Number of tests that failed
test_cases_and_results: Test cases paired with their individual results
execution_time_ms: Total execution time in milliseconds
"""
success: bool
total_tests: int
passed_tests: int
failed_tests: int
test_cases_and_results: List[Tuple[TestCase, TestCaseResult]]
execution_time_ms: float
@property
def failures(self) -> List[TestCaseResult]:
"""Get only failed test results.
Returns:
List of test results where status is Failed or Aborted
"""
return [
result
for result in self.test_results
if result.testCaseStatus in (TestCaseStatus.Failed, TestCaseStatus.Aborted)
]
@property
def passes(self) -> List[TestCaseResult]:
"""Get only passed test results.
Returns:
List of test results where status is Success
"""
return [
result
for result in self.test_results
if result.testCaseStatus == TestCaseStatus.Success
]
@property
def test_results(self) -> List[TestCaseResult]:
"""Get all test results."""
return [result for _, result in self.test_cases_and_results]
def publish(self, table_fqn: str, client: Optional[OpenMetadata] = None) -> None:
"""Publish test results to OpenMetadata.
Args:
table_fqn: Fully qualified table name
client: OpenMetadata client
"""
if client is None:
client = get_client()
metadata = client.ometa
for test_case, result in self.test_cases_and_results:
if isinstance(test_case, MockTestCase):
test_case = metadata.get_or_create_test_case(
test_case_fqn=f"{table_fqn}.{test_case.name.root}",
entity_link=get_entity_link(
Table,
table_fqn,
column_name=get_column_name_or_none(test_case.entityLink.root),
),
test_definition_fqn=test_case.testDefinition.fullyQualifiedName,
test_case_parameter_values=test_case.parameterValues,
description=getattr(test_case.description, "root", None),
)
res = metadata.add_test_case_results(
result,
cast(FullyQualifiedEntityName, test_case.fullyQualifiedName).root,
)
logger.debug(f"Result: {res}")
@classmethod
def merge(cls, *results: "ValidationResult") -> "ValidationResult":
"""Merge multiple ValidationResult objects into one.
Aggregates results from multiple validation runs, useful when validating
DataFrames in batches. When the same test case is run multiple times across
batches, results are aggregated by test case FQN.
Args:
*results: Variable number of ValidationResult objects to merge
Returns:
A new ValidationResult with aggregated test case results
Raises:
ValueError: If no results are provided to merge
"""
if not results:
raise ValueError("At least one ValidationResult must be provided to merge")
from collections import defaultdict
aggregated_results: dict[
str, List[Tuple[TestCase, TestCaseResult]]
] = defaultdict(list)
total_execution_time = 0.0
for result in results:
for test_case, test_result in result.test_cases_and_results:
fqn = test_case.fullyQualifiedName
if fqn is None:
raise ValueError(
"Cannot merge results with test cases that have no fullyQualifiedName"
)
aggregated_results[str(fqn)].append((test_case, test_result))
total_execution_time += result.execution_time_ms
merged_test_cases_and_results: List[Tuple[TestCase, TestCaseResult]] = []
for fqn, test_cases_and_results_for_fqn in aggregated_results.items():
test_case = test_cases_and_results_for_fqn[0][0]
results_for_test = [result for _, result in test_cases_and_results_for_fqn]
merged_result = cls._aggregate_test_case_results(results_for_test)
merged_test_cases_and_results.append((test_case, merged_result))
total = len(merged_test_cases_and_results)
passed = sum(
1
for _, test_result in merged_test_cases_and_results
if test_result.testCaseStatus is TestCaseStatus.Success
)
failed = total - passed
return cls(
success=failed == 0,
total_tests=total,
passed_tests=passed,
failed_tests=failed,
test_cases_and_results=merged_test_cases_and_results,
execution_time_ms=total_execution_time,
)
@staticmethod
def _aggregate_test_case_results(
results: List[TestCaseResult],
) -> TestCaseResult:
"""Aggregate multiple TestCaseResult objects for the same test case.
Combines metrics from multiple test runs by summing passed/failed rows
and determining overall status.
Args:
results: List of TestCaseResult objects from different batch runs
Returns:
A single aggregated TestCaseResult
"""
if not results:
raise ValueError("At least one TestCaseResult must be provided")
if len(results) == 1:
return results[0]
total_passed_rows = sum(r.passedRows or 0 for r in results)
total_failed_rows = sum(r.failedRows or 0 for r in results)
total_rows = total_passed_rows + total_failed_rows
passed_rows_percentage = (
(total_passed_rows / total_rows * 100) if total_rows > 0 else None
)
failed_rows_percentage = (
(total_failed_rows / total_rows * 100) if total_rows > 0 else None
)
overall_status = TestCaseStatus.Success
if any(r.testCaseStatus == TestCaseStatus.Aborted for r in results):
overall_status = TestCaseStatus.Aborted
elif any(r.testCaseStatus == TestCaseStatus.Failed for r in results):
overall_status = TestCaseStatus.Failed
first_result = results[0]
merged_result_messages = [r.result for r in results if r.result]
return TestCaseResult(
id=None,
testCaseFQN=first_result.testCaseFQN,
timestamp=first_result.timestamp,
testCaseStatus=overall_status,
result=(
"; ".join(merged_result_messages) if merged_result_messages else None
),
sampleData=None,
testResultValue=None,
passedRows=total_passed_rows if total_rows > 0 else None,
failedRows=total_failed_rows if total_rows > 0 else None,
passedRowsPercentage=passed_rows_percentage,
failedRowsPercentage=failed_rows_percentage,
incidentId=None,
maxBound=first_result.maxBound,
minBound=first_result.minBound,
testCase=first_result.testCase,
testDefinition=first_result.testDefinition,
dimensionResults=None,
)
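
To make the batch aggregation concrete, a small sketch merging the per-chunk results of a single not-null test; the data is arbitrary.

```python
# Minimal sketch: merging per-chunk results aggregates counts per test case FQN.
import pandas as pd

from metadata.sdk.data_quality import ColumnValuesToBeNotNull
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
from metadata.sdk.data_quality.dataframes.validation_results import ValidationResult

validator = DataFrameValidator()
validator.add_test(ColumnValuesToBeNotNull(column="email"))

chunk_results = [
    validator.validate(pd.DataFrame({"email": ["a@b.com", "b@c.com"]})),
    validator.validate(pd.DataFrame({"email": ["c@d.com", None]})),
]

merged = ValidationResult.merge(*chunk_results)
print(merged.total_tests)  # 1: both chunks ran the same test case FQN
print(merged.success)      # False: the second chunk failed
print(merged.test_results[0].failedRows)  # failed rows summed across chunks
```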

View File

@ -0,0 +1,133 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registry of pandas validators."""
from metadata.data_quality.validations.column.pandas.columnValueLengthsToBeBetween import (
ColumnValueLengthsToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValueMaxToBeBetween import (
ColumnValueMaxToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValueMeanToBeBetween import (
ColumnValueMeanToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValueMedianToBeBetween import (
ColumnValueMedianToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValueMinToBeBetween import (
ColumnValueMinToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesMissingCount import (
ColumnValuesMissingCountValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesSumToBeBetween import (
ColumnValuesSumToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValueStdDevToBeBetween import (
ColumnValueStdDevToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToBeAtExpectedLocation import (
ColumnValuesToBeAtExpectedLocationValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToBeBetween import (
ColumnValuesToBeBetweenValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToBeInSet import (
ColumnValuesToBeInSetValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToBeNotInSet import (
ColumnValuesToBeNotInSetValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToBeNotNull import (
ColumnValuesToBeNotNullValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToBeUnique import (
ColumnValuesToBeUniqueValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToMatchRegex import (
ColumnValuesToMatchRegexValidator,
)
from metadata.data_quality.validations.column.pandas.columnValuesToNotMatchRegex import (
ColumnValuesToNotMatchRegexValidator,
)
from metadata.data_quality.validations.table.pandas.tableColumnCountToBeBetween import (
TableColumnCountToBeBetweenValidator,
)
from metadata.data_quality.validations.table.pandas.tableColumnCountToEqual import (
TableColumnCountToEqualValidator,
)
from metadata.data_quality.validations.table.pandas.tableColumnNameToExist import (
TableColumnNameToExistValidator,
)
from metadata.data_quality.validations.table.pandas.tableColumnToMatchSet import (
TableColumnToMatchSetValidator,
)
from metadata.data_quality.validations.table.pandas.tableRowCountToBeBetween import (
TableRowCountToBeBetweenValidator,
)
from metadata.data_quality.validations.table.pandas.tableRowCountToEqual import (
TableRowCountToEqualValidator,
)
VALIDATOR_REGISTRY = {
"columnValuesToBeNotNull": ColumnValuesToBeNotNullValidator,
"columnValuesToBeUnique": ColumnValuesToBeUniqueValidator,
"columnValuesToBeBetween": ColumnValuesToBeBetweenValidator,
"columnValuesToBeInSet": ColumnValuesToBeInSetValidator,
"columnValuesToBeNotInSet": ColumnValuesToBeNotInSetValidator,
"columnValuesToMatchRegex": ColumnValuesToMatchRegexValidator,
"columnValuesToNotMatchRegex": ColumnValuesToNotMatchRegexValidator,
"columnValueLengthsToBeBetween": ColumnValueLengthsToBeBetweenValidator,
"columnValueMaxToBeBetween": ColumnValueMaxToBeBetweenValidator,
"columnValueMeanToBeBetween": ColumnValueMeanToBeBetweenValidator,
"columnValueMedianToBeBetween": ColumnValueMedianToBeBetweenValidator,
"columnValueMinToBeBetween": ColumnValueMinToBeBetweenValidator,
"columnValueStdDevToBeBetween": ColumnValueStdDevToBeBetweenValidator,
"columnValuesSumToBeBetween": ColumnValuesSumToBeBetweenValidator,
"columnValuesMissingCount": ColumnValuesMissingCountValidator,
"columnValuesToBeAtExpectedLocation": ColumnValuesToBeAtExpectedLocationValidator,
"tableRowCountToBeBetween": TableRowCountToBeBetweenValidator,
"tableRowCountToEqual": TableRowCountToEqualValidator,
"tableColumnCountToBeBetween": TableColumnCountToBeBetweenValidator,
"tableColumnCountToEqual": TableColumnCountToEqualValidator,
"tableColumnNameToExist": TableColumnNameToExistValidator,
"tableColumnToMatchSet": TableColumnToMatchSetValidator,
}
VALIDATORS_THAT_REQUIRE_FULL_TABLE = {
"columnValuesToBeUnique",
"columnValueMeanToBeBetween",
"columnValueMedianToBeBetween",
"columnValueStdDevToBeBetween",
"columnValuesSumToBeBetween",
"columnValuesMissingCount",
"tableRowCountToBeBetween",
"tableRowCountToEqual",
}
def requires_whole_table(validator_name: str) -> bool:
"""Whether the validator requires a whole table to return appropriate results
Examples:
- `columnValuesToBeUnique` needs to see the whole column to make sure uniqueness is met
- `tableRowCountToEqual` needs to see the whole table to make sure the expected row count is met
These tests could return false positives when operating on batches
Args:
validator_name: The name of the validator to check
Returns:
Whether the validator requires a whole table to return appropriate results
"""
return validator_name in VALIDATORS_THAT_REQUIRE_FULL_TABLE
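
A quick check of the helper against the registry entries above:

```python
from metadata.sdk.data_quality.dataframes.validators import requires_whole_table

print(requires_whole_table("columnValuesToBeUnique"))   # True: needs the full column
print(requires_whole_table("columnValuesToBeNotNull"))  # False: row-local check
print(requires_whole_table("tableRowCountToEqual"))     # True: needs the full table
```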

View File

@ -0,0 +1,294 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Examples demonstrating DataFrame validation with OpenMetadata SDK."""
# pyright: reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportUnknownMemberType=false
# pyright: reportUnusedCallResult=false
# pylint: disable=W5001
import pandas as pd
from metadata.sdk import configure
from metadata.sdk.data_quality import (
ColumnValuesToBeBetween,
ColumnValuesToBeNotNull,
ColumnValuesToBeUnique,
ColumnValuesToMatchRegex,
TableRowCountToBeBetween,
)
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
from metadata.sdk.data_quality.dataframes.validation_results import (
FailureMode,
ValidationResult,
)
def basic_validation_example():
"""Basic example validating a customer DataFrame."""
print("\n=== Basic DataFrame Validation Example ===\n")
df = pd.DataFrame(
{
"customer_id": [1, 2, 3, 4, 5],
"email": [
"alice@example.com",
"bob@example.com",
"carol@example.com",
"dave@example.com",
"eve@example.com",
],
"age": [25, 30, 35, 40, 45],
}
)
validator = DataFrameValidator()
validator.add_test(ColumnValuesToBeNotNull(column="email"))
validator.add_test(ColumnValuesToBeUnique(column="customer_id"))
result = validator.validate(df)
if result.success:
print("✓ All validations passed!")
print(
f" Executed {result.total_tests} tests in {result.execution_time_ms:.2f}ms"
)
else:
print("✗ Validation failed")
for test_case, test_result in result.test_cases_and_results:
if test_result.testCaseStatus.value != "Success":
print(f" - {test_case.name.root}: {test_result.result}")
def multiple_tests_example():
"""Example with multiple validation rules."""
print("\n=== Multiple Tests Validation Example ===\n")
df = pd.DataFrame(
{
"customer_id": [1, 2, 3, 4, 5],
"email": [
"alice@example.com",
"bob@example.com",
"carol@example.com",
"dave@example.com",
"eve@example.com",
],
"age": [25, 30, 35, 40, 45],
"status": ["active", "active", "inactive", "active", "active"],
}
)
validator = DataFrameValidator()
validator.add_tests(
TableRowCountToBeBetween(min_count=1, max_count=1000),
ColumnValuesToBeNotNull(column="customer_id"),
ColumnValuesToBeNotNull(column="email"),
ColumnValuesToBeUnique(column="customer_id"),
ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
ColumnValuesToMatchRegex(column="email", regex=r"^[\w\.-]+@[\w\.-]+\.\w+$"),
)
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
print(f"Validation: {'PASSED' if result.success else 'FAILED'}")
print(f"Tests: {result.passed_tests}/{result.total_tests} passed")
print(f"Execution time: {result.execution_time_ms:.2f}ms\n")
for test_case, test_result in result.test_cases_and_results:
status_icon = "✓" if test_result.testCaseStatus.value == "Success" else "✗"
print(f"{status_icon} {test_case.name.root}")
passed_rows = test_result.passedRows or 0
failed_rows = test_result.failedRows or 0
total_rows = passed_rows + failed_rows
if passed_rows > 0:
print(f" Passed: {passed_rows}/{total_rows} rows")
if failed_rows > 0:
percentage = failed_rows / total_rows * 100
print(f" Failed: {failed_rows} rows ({percentage:.1f}%)")
def integrating_with_openmetadata_example():
"""Integrating with OpenMetadata."""
def transform_to_dwh_table(raw_df: pd.DataFrame) -> pd.DataFrame:
"""Transform the dataframe to dwh table."""
return raw_df
configure(host="http://localhost:8585/api", jwt_token="your jwt token")
df = pd.read_parquet("s3://some_bucket/raw_table.parquet")
df = transform_to_dwh_table(df)
# Instantiate validator and load the executable test suite for a table
validator = DataFrameValidator()
validator.add_openmetadata_table_tests(
"DbService.database_name.schema_name.dwh_table"
)
result = validator.validate(df)
print(f"Validation: {'PASSED' if result.success else 'FAILED'}")
# Publish the results back to OpenMetadata
result.publish("DbService.database_name.schema_name.dwh_table")
if result.success:
df.to_parquet("s3://some_bucket/dwh_table.parquet")
def processing_big_data_with_chunks_example():
"""Processing big data with chunks."""
configure(host="http://localhost:8585/api", jwt_token="your jwt token")
validator = DataFrameValidator()
validator.add_openmetadata_table_tests(
"DbService.database_name.schema_name.dwh_table"
)
def load_df_to_destination(_df: pd.DataFrame, _result: ValidationResult):
"""Loads data into destination."""
def rollback(_df: pd.DataFrame, _result: ValidationResult):
"""Clears data previously loaded"""
results = validator.run(
pd.read_csv("somefile.csv", chunksize=1000),
on_success=load_df_to_destination,
on_failure=rollback,
mode=FailureMode.SHORT_CIRCUIT,
)
results.publish("DbService.database_name.schema_name.dwh_table")
def validation_failure_example():
"""Example demonstrating validation failures."""
print("\n=== Validation Failure Example ===\n")
df = pd.DataFrame(
{
"customer_id": [1, 2, 2, 4, 5],
"email": [
"alice@example.com",
None,
"carol@example.com",
"dave@example.com",
"invalid-email",
],
"age": [25, 150, 35, -5, 45],
}
)
validator = DataFrameValidator()
validator.add_tests(
ColumnValuesToBeUnique(column="customer_id"),
ColumnValuesToBeNotNull(column="email"),
ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
)
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
print(f"Validation: {'PASSED' if result.success else 'FAILED'}\n")
if not result.success:
print("Failures detected:")
for test_case, test_result in result.test_cases_and_results:
if test_result.testCaseStatus.value == "Success":
continue
print(f"\n Test: {test_case.name.root}")
print(f" Type: {test_case.testDefinition.name}")
print(f" Message: {test_result.result}")
total_rows = (test_result.passedRows or 0) + (test_result.failedRows or 0)
print(f" Failed rows: {test_result.failedRows}/{total_rows}")
def etl_pipeline_integration_example():
"""Example integrating validation into an ETL pipeline."""
print("\n=== ETL Pipeline Integration Example ===\n")
def extract_data():
return pd.DataFrame(
{
"id": [1, 2, 3, 4, 5],
"name": ["Alice", "Bob", "Carol", "Dave", "Eve"],
"value": [100, 200, 300, 400, 500],
}
)
def transform_data(df: pd.DataFrame) -> pd.DataFrame:
df = df.copy()
df["value_doubled"] = df["value"] * 2
return df
def validate_data(df: pd.DataFrame) -> ValidationResult:
validator = DataFrameValidator()
validator.add_tests(
TableRowCountToBeBetween(min_count=1, max_count=10000),
ColumnValuesToBeNotNull(column="id"),
ColumnValuesToBeUnique(column="id"),
ColumnValuesToBeBetween(column="value", min_value=0, max_value=10000),
)
return validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
def load_data(df: pd.DataFrame) -> None:
print(f"Loading {len(df)} rows to data warehouse...")
print("Starting ETL pipeline...")
print("\n1. Extract")
raw_df = extract_data()
print(f" Extracted {len(raw_df)} rows")
print("\n2. Transform")
transformed_df = transform_data(raw_df)
print(f" Transformed {len(transformed_df)} rows")
print("\n3. Validate")
validation_result = validate_data(transformed_df)
if validation_result.success:
print(" ✓ Validation passed")
print("\n4. Load")
load_data(transformed_df)
print(" ✓ Data loaded successfully")
else:
print(" ✗ Validation failed")
print("\n Failures:")
for test_case, test_result in validation_result.test_cases_and_results:
if test_result.testCaseStatus.value != "Success":
print(f" - {test_case.name.root}: {test_result.result}")
print("\n Pipeline aborted - data not loaded")
def short_circuit_mode_example():
"""Example demonstrating short-circuit mode behavior."""
print("\n=== Short-Circuit Mode Example ===\n")
df = pd.DataFrame(
{
"id": [1, 2, 2, 3],
"email": [None, None, None, "test@example.com"],
"age": [25, 30, 35, 40],
}
)
validator = DataFrameValidator()
validator.add_tests(
ColumnValuesToBeUnique(column="id"),
ColumnValuesToBeNotNull(column="email"),
ColumnValuesToBeBetween(column="age", min_value=0, max_value=120),
)
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
print("Short-circuit mode stops at first failure:")
print(f" Tests executed: {len(result.test_results)} of {result.total_tests}")
print(f" First failure: {result.failures[0].test_name}")
print("\nRemaining tests were not executed due to short-circuit mode.")
if __name__ == "__main__":
basic_validation_example()
multiple_tests_example()
validation_failure_example()
etl_pipeline_integration_example()
short_circuit_mode_example()

View File

@ -106,6 +106,21 @@ def get_table_or_column_fqn(entity_link: str) -> str:
)
def get_column_name_or_none(entity_link: str) -> Optional[str]:
"""It attempts to get a column from an entity link
Args:
entity_link: entity link
Returns:
The column name or None
"""
split_entity_link = split(entity_link)
if len(split_entity_link) == 4 and split_entity_link[2] == "columns":
return split_entity_link[3]
return None
get_entity_link_registry = class_register()
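
A quick example of the helper on a column-level and a table-level entity link (the FQN is a placeholder):

```python
from metadata.utils.entity_link import get_column_name_or_none

print(get_column_name_or_none("<#E::table::svc.db.schema.users::columns::email>"))  # email
print(get_column_name_or_none("<#E::table::svc.db.schema.users>"))                  # None
```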

View File

@ -2,12 +2,252 @@
Minimal conftest for SDK integration tests.
Override the parent conftest to avoid testcontainers dependency.
"""
import pytest
from sqlalchemy import Column as SQAColumn
from sqlalchemy import Integer, MetaData, String
from sqlalchemy import Table as SQATable
from sqlalchemy import create_engine
from _openmetadata_testutils.ometa import int_admin_ometa
from _openmetadata_testutils.postgres.conftest import postgres_container
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
CreateDatabaseSchemaRequest,
)
from metadata.generated.schema.api.services.createDatabaseService import (
CreateDatabaseServiceRequest,
)
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseService,
DatabaseServiceType,
)
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.workflow.metadata import MetadataWorkflow
@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def metadata():
"""Provide authenticated OpenMetadata client"""
return int_admin_ometa()
@pytest.fixture(scope="module")
def create_postgres_service(postgres_container, tmp_path_factory):
return CreateDatabaseServiceRequest(
name="dq_test_service_" + tmp_path_factory.mktemp("dq").name,
serviceType=DatabaseServiceType.Postgres,
connection=DatabaseConnection(
config=PostgresConnection(
username=postgres_container.username,
authType=BasicAuth(password=postgres_container.password),
hostPort="localhost:"
+ str(postgres_container.get_exposed_port(postgres_container.port)),
database="dq_test_db",
)
),
)
@pytest.fixture(scope="module")
def db_service(metadata, create_postgres_service, postgres_container):
engine = create_engine(
postgres_container.get_connection_url(), isolation_level="AUTOCOMMIT"
)
engine.execute("CREATE DATABASE dq_test_db")
service_entity = metadata.create_or_update(data=create_postgres_service)
service_entity.connection.config.authType.password = (
create_postgres_service.connection.config.authType.password
)
yield service_entity
service = metadata.get_by_name(
DatabaseService, service_entity.fullyQualifiedName.root
)
if service:
metadata.delete(DatabaseService, service.id, recursive=True, hard_delete=True)
@pytest.fixture(scope="module")
def database(metadata, db_service):
database_entity = metadata.create_or_update(
CreateDatabaseRequest(
name="dq_test_db",
service=db_service.fullyQualifiedName,
)
)
return database_entity
@pytest.fixture(scope="module")
def schema(metadata, database):
schema_entity = metadata.create_or_update(
CreateDatabaseSchemaRequest(
name="public",
database=database.fullyQualifiedName,
)
)
return schema_entity
@pytest.fixture(scope="module")
def test_data(db_service, postgres_container):
engine = create_engine(
postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db")
)
sql_metadata = MetaData()
users_table = SQATable(
"users",
sql_metadata,
SQAColumn("id", Integer, primary_key=True),
SQAColumn("username", String(50), nullable=False),
SQAColumn("email", String(100)),
SQAColumn("age", Integer),
SQAColumn("score", Integer),
)
products_table = SQATable(
"products",
sql_metadata,
SQAColumn("product_id", Integer, primary_key=True),
SQAColumn("name", String(100)),
SQAColumn("price", Integer),
)
stg_products_table = SQATable(
"stg_products",
sql_metadata,
SQAColumn("id", Integer, primary_key=True),
SQAColumn("name", String(100)),
SQAColumn("price", Integer),
)
sql_metadata.create_all(engine)
with engine.connect() as conn:
conn.execute(
users_table.insert(),
[
{
"id": 1,
"username": "alice",
"email": "alice@example.com",
"age": 25,
"score": 85,
},
{
"id": 2,
"username": "bob",
"email": "bob@example.com",
"age": 30,
"score": 90,
},
{"id": 3, "username": "charlie", "email": None, "age": 35, "score": 75},
{
"id": 4,
"username": "diana",
"email": "diana@example.com",
"age": 28,
"score": 95,
},
{
"id": 5,
"username": "eve",
"email": "eve@example.com",
"age": 22,
"score": 88,
},
],
)
conn.execute(
products_table.insert(),
[
{"product_id": 1, "name": "Widget", "price": 100},
{"product_id": 2, "name": "Gadget", "price": 200},
{"product_id": 3, "name": "Doohickey", "price": 150},
],
)
conn.execute(
stg_products_table.insert(),
[
{"id": 1, "name": "Widget", "price": 100},
{"id": 2, "name": "Gadget", "price": 200},
{"id": 3, "name": "Doohickey", "price": 150},
],
)
return {
"users": users_table,
"products": products_table,
"stg_products": stg_products_table,
}
@pytest.fixture(scope="module")
def ingest_metadata(metadata, db_service, schema, test_data):
workflow_config = {
"source": {
"type": db_service.connection.config.type.value.lower(),
"serviceName": db_service.fullyQualifiedName.root,
"sourceConfig": {
"config": {
"type": "DatabaseMetadata",
"schemaFilterPattern": {"includes": ["public"]},
}
},
"serviceConnection": db_service.connection.model_dump(),
},
"sink": {"type": "metadata-rest", "config": {}},
"workflowConfig": {
"loggerLevel": "INFO",
"openMetadataServerConfig": metadata.config.model_dump(),
},
}
workflow = MetadataWorkflow.create(workflow_config)
workflow.execute()
workflow.raise_from_status()
return workflow
@pytest.fixture(scope="module")
def patch_passwords(db_service, monkeymodule):
def override_password(getter):
def inner(*args, **kwargs):
result = getter(*args, **kwargs)
if isinstance(result, DatabaseService):
if result.fullyQualifiedName.root == db_service.fullyQualifiedName.root:
result.connection.config.authType.password = (
db_service.connection.config.authType.password
)
return result
return inner
monkeymodule.setattr(
"metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_name",
override_password(OpenMetadata.get_by_name),
)
monkeymodule.setattr(
"metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_id",
override_password(OpenMetadata.get_by_id),
)
@pytest.fixture(scope="module")
def monkeymodule():
with pytest.MonkeyPatch.context() as mp:
yield mp

View File

@ -0,0 +1,255 @@
from typing import Any, Generator, Mapping
from unittest.mock import Mock, patch
import pandas
import pytest
from dirty_equals import HasAttributes, IsList, IsTuple
from pandas import DataFrame
from sqlalchemy import create_engine
from sqlalchemy.sql.schema import Table as SQATable
from testcontainers.postgres import PostgresContainer
from metadata.generated.schema.api.tests.createTestCase import CreateTestCaseRequest
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.tests.basic import TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase, TestCaseParameterValue
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.sdk.data_quality import ColumnValueMinToBeBetween
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
from metadata.utils.entity_link import get_entity_link
@pytest.fixture(scope="module")
def table_fqn(db_service: DatabaseService) -> str:
return f"{db_service.name.root}.dq_test_db.public.users"
@pytest.fixture(scope="module")
def column_unique_test(
table_fqn: str, metadata: OpenMetadata[TestCase, CreateTestCaseRequest]
) -> TestCase:
request = CreateTestCaseRequest(
name="column_not_null",
testDefinition="columnValuesToBeUnique",
entityLink=get_entity_link(
Table,
table_fqn,
column_name="username",
),
)
test_case = metadata.create_or_update(request)
return test_case
@pytest.fixture(scope="module")
def table_row_count_test(
table_fqn: str, metadata: OpenMetadata[TestCase, CreateTestCaseRequest]
) -> TestCase:
request = CreateTestCaseRequest(
name="table_row_count",
testDefinition="tableRowCountToEqual",
entityLink=get_entity_link(Table, table_fqn),
parameterValues=[TestCaseParameterValue(name="value", value="5")],
)
test_case = metadata.create_or_update(request)
return test_case
@pytest.fixture(scope="module")
def dataframe(
test_data: Mapping[str, SQATable], postgres_container: PostgresContainer
) -> Generator[DataFrame, None, None]:
engine = create_engine(
postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db"),
isolation_level="AUTOCOMMIT",
)
with engine.connect() as connection:
yield pandas.read_sql(
test_data["users"].select(),
connection,
)
def test_it_runs_tests_from_openmetadata(
ingest_metadata: None,
metadata: OpenMetadata[Any, Any],
dataframe: DataFrame,
column_unique_test: TestCase,
table_row_count_test: TestCase,
) -> None:
validator = DataFrameValidator(client=metadata)
validator.add_openmetadata_test(column_unique_test.fullyQualifiedName.root)
validator.add_openmetadata_test(table_row_count_test.fullyQualifiedName.root)
result = validator.validate(dataframe)
assert result == HasAttributes(
success=True,
total_tests=2,
passed_tests=2,
failed_tests=0,
)
def test_it_runs_openmetadata_table_tests(
table_fqn: str,
ingest_metadata: None,
metadata: OpenMetadata[Any, Any],
dataframe: DataFrame,
column_unique_test: TestCase,
table_row_count_test: TestCase,
) -> None:
validator = DataFrameValidator(client=metadata)
validator.add_openmetadata_table_tests(table_fqn)
result = validator.validate(dataframe)
assert result == HasAttributes(
success=True,
total_tests=2,
passed_tests=2,
failed_tests=0,
)
class TestFullUseCase:
def test_it_runs_tests_and_publishes_results(
self,
table_fqn: str,
ingest_metadata: None,
metadata: OpenMetadata[Any, Any],
dataframe: DataFrame,
column_unique_test: TestCase,
table_row_count_test: TestCase,
) -> None:
# First ensure only previously reported tests exist in OM
test_suite = metadata.get_executable_test_suite(table_fqn)
assert test_suite is not None
original_test_names = {t.fullyQualifiedName for t in test_suite.tests}
assert original_test_names == {
column_unique_test.fullyQualifiedName.root,
table_row_count_test.fullyQualifiedName.root,
}
# Run validation
validator = DataFrameValidator(client=metadata)
validator.add_openmetadata_table_tests(table_fqn)
# Forcing a failure with this test
validator.add_test(
ColumnValueMinToBeBetween(
name="column_value_min_to_be_between_90_and_100",
column="score",
min_value=90,
max_value=100,
)
)
result = validator.validate(dataframe)
assert result == HasAttributes(
success=False,
total_tests=3,
passed_tests=2,
failed_tests=1,
test_cases_and_results=IsList(
IsTuple(
HasAttributes(
fullyQualifiedName=HasAttributes(
root=column_unique_test.fullyQualifiedName.root
),
),
HasAttributes(
testCaseStatus=TestCaseStatus.Success,
),
),
IsTuple(
HasAttributes(
fullyQualifiedName=HasAttributes(
root=table_row_count_test.fullyQualifiedName.root
),
),
HasAttributes(
testCaseStatus=TestCaseStatus.Success,
),
),
IsTuple(
HasAttributes(
fullyQualifiedName=HasAttributes(
root="column_value_min_to_be_between_90_and_100"
),
),
HasAttributes(
testCaseStatus=TestCaseStatus.Failed,
),
),
check_order=False,
),
)
# Publish results
with patch(
"metadata.sdk.data_quality.dataframes.validation_results.get_client",
return_value=Mock(ometa=metadata),
) as mock_client:
result.publish(
table_fqn,
)
mock_client.assert_called_once()
# Validate test defined in code has been pushed to OM as well
test_suite = metadata.get_executable_test_suite(table_fqn)
assert test_suite is not None
assert len(test_suite.tests) == 3
assert original_test_names.issubset(
{t.fullyQualifiedName for t in test_suite.tests}
)
assert {t.fullyQualifiedName for t in test_suite.tests} == {
column_unique_test.fullyQualifiedName.root,
table_row_count_test.fullyQualifiedName.root,
f"{table_fqn}.score.column_value_min_to_be_between_90_and_100",
}
# Validate each test case result
required_fields = ["testCaseResult", "testSuite", "testDefinition"]
column_unique_result = metadata.get_by_name(
TestCase, column_unique_test.fullyQualifiedName.root, fields=required_fields
).testCaseResult
assert column_unique_result == HasAttributes(
testCaseStatus=TestCaseStatus.Success
)
table_row_count_result = metadata.get_by_name(
TestCase,
table_row_count_test.fullyQualifiedName.root,
fields=required_fields,
).testCaseResult
assert table_row_count_result == HasAttributes(
testCaseStatus=TestCaseStatus.Success
)
code_test_case_result = metadata.get_by_name(
TestCase,
f"{table_fqn}.score.column_value_min_to_be_between_90_and_100",
fields=required_fields,
).testCaseResult
assert code_test_case_result == HasAttributes(
testCaseStatus=TestCaseStatus.Failed
)
# Clean up code test
metadata.delete_test_case(
f"{table_fqn}.score.column_value_min_to_be_between_90_and_100",
recursive=True,
hard=True,
)

View File

@ -6,36 +6,11 @@ import sys
import pytest
from dirty_equals import HasAttributes
from sqlalchemy import Column as SQAColumn
from sqlalchemy import Integer, MetaData, String
from sqlalchemy import Table as SQATable
from sqlalchemy import create_engine
from _openmetadata_testutils.ometa import int_admin_ometa
from _openmetadata_testutils.postgres.conftest import postgres_container
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
CreateDatabaseSchemaRequest,
)
from metadata.generated.schema.api.services.createDatabaseService import (
CreateDatabaseServiceRequest,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseService,
DatabaseServiceType,
)
from metadata.generated.schema.tests.basic import TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase
from metadata.generated.schema.type.basic import EntityLink
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.sdk.data_quality import (
ColumnValuesToBeBetween,
ColumnValuesToBeNotNull,
@ -45,7 +20,6 @@ from metadata.sdk.data_quality import (
TableRowCountToBeBetween,
TestRunner,
)
from metadata.workflow.metadata import MetadataWorkflow
if not sys.version_info >= (3, 9):
pytest.skip(
@ -54,226 +28,6 @@ if not sys.version_info >= (3, 9):
)
@pytest.fixture(scope="module")
def metadata():
return int_admin_ometa()
@pytest.fixture(scope="module")
def create_postgres_service(postgres_container, tmp_path_factory):
return CreateDatabaseServiceRequest(
name="dq_test_service_" + tmp_path_factory.mktemp("dq").name,
serviceType=DatabaseServiceType.Postgres,
connection=DatabaseConnection(
config=PostgresConnection(
username=postgres_container.username,
authType=BasicAuth(password=postgres_container.password),
hostPort="localhost:"
+ str(postgres_container.get_exposed_port(postgres_container.port)),
database="dq_test_db",
)
),
)
@pytest.fixture(scope="module")
def db_service(metadata, create_postgres_service, postgres_container):
engine = create_engine(
postgres_container.get_connection_url(), isolation_level="AUTOCOMMIT"
)
engine.execute("CREATE DATABASE dq_test_db")
service_entity = metadata.create_or_update(data=create_postgres_service)
service_entity.connection.config.authType.password = (
create_postgres_service.connection.config.authType.password
)
yield service_entity
service = metadata.get_by_name(
DatabaseService, service_entity.fullyQualifiedName.root
)
if service:
metadata.delete(DatabaseService, service.id, recursive=True, hard_delete=True)
@pytest.fixture(scope="module")
def database(metadata, db_service):
database_entity = metadata.create_or_update(
CreateDatabaseRequest(
name="dq_test_db",
service=db_service.fullyQualifiedName,
)
)
return database_entity
@pytest.fixture(scope="module")
def schema(metadata, database):
schema_entity = metadata.create_or_update(
CreateDatabaseSchemaRequest(
name="public",
database=database.fullyQualifiedName,
)
)
return schema_entity
@pytest.fixture(scope="module")
def test_data(postgres_container):
engine = create_engine(
postgres_container.get_connection_url().replace("/dvdrental", "/dq_test_db")
)
sql_metadata = MetaData()
users_table = SQATable(
"users",
sql_metadata,
SQAColumn("id", Integer, primary_key=True),
SQAColumn("username", String(50), nullable=False),
SQAColumn("email", String(100)),
SQAColumn("age", Integer),
SQAColumn("score", Integer),
)
products_table = SQATable(
"products",
sql_metadata,
SQAColumn("product_id", Integer, primary_key=True),
SQAColumn("name", String(100)),
SQAColumn("price", Integer),
)
stg_products_table = SQATable(
"stg_products",
sql_metadata,
SQAColumn("id", Integer, primary_key=True),
SQAColumn("name", String(100)),
SQAColumn("price", Integer),
)
sql_metadata.create_all(engine)
with engine.connect() as conn:
conn.execute(
users_table.insert(),
[
{
"id": 1,
"username": "alice",
"email": "alice@example.com",
"age": 25,
"score": 85,
},
{
"id": 2,
"username": "bob",
"email": "bob@example.com",
"age": 30,
"score": 90,
},
{"id": 3, "username": "charlie", "email": None, "age": 35, "score": 75},
{
"id": 4,
"username": "diana",
"email": "diana@example.com",
"age": 28,
"score": 95,
},
{
"id": 5,
"username": "eve",
"email": "eve@example.com",
"age": 22,
"score": 88,
},
],
)
conn.execute(
products_table.insert(),
[
{"product_id": 1, "name": "Widget", "price": 100},
{"product_id": 2, "name": "Gadget", "price": 200},
{"product_id": 3, "name": "Doohickey", "price": 150},
],
)
conn.execute(
stg_products_table.insert(),
[
{"id": 1, "name": "Widget", "price": 100},
{"id": 2, "name": "Gadget", "price": 200},
{"id": 3, "name": "Doohickey", "price": 150},
],
)
return {
"users": users_table,
"products": products_table,
"stg_products": stg_products_table,
}
@pytest.fixture(scope="module")
def ingest_metadata(metadata, db_service, schema, test_data):
workflow_config = {
"source": {
"type": db_service.connection.config.type.value.lower(),
"serviceName": db_service.fullyQualifiedName.root,
"sourceConfig": {
"config": {
"type": "DatabaseMetadata",
"schemaFilterPattern": {"includes": ["public"]},
}
},
"serviceConnection": db_service.connection.model_dump(),
},
"sink": {"type": "metadata-rest", "config": {}},
"workflowConfig": {
"loggerLevel": "INFO",
"openMetadataServerConfig": metadata.config.model_dump(),
},
}
workflow = MetadataWorkflow.create(workflow_config)
workflow.execute()
workflow.raise_from_status()
return workflow
@pytest.fixture(scope="module")
def patch_passwords(db_service, monkeymodule):
def override_password(getter):
def inner(*args, **kwargs):
result = getter(*args, **kwargs)
if isinstance(result, DatabaseService):
if result.fullyQualifiedName.root == db_service.fullyQualifiedName.root:
result.connection.config.authType.password = (
db_service.connection.config.authType.password
)
return result
return inner
monkeymodule.setattr(
"metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_name",
override_password(OpenMetadata.get_by_name),
)
monkeymodule.setattr(
"metadata.ingestion.ometa.ometa_api.OpenMetadata.get_by_id",
override_password(OpenMetadata.get_by_id),
)
@pytest.fixture(scope="module")
def monkeymodule():
with pytest.MonkeyPatch.context() as mp:
yield mp
def test_table_row_count_tests(
metadata,
db_service,

View File

@ -0,0 +1,561 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for DataFrame validator."""
from typing import Generator, List, Tuple
from unittest.mock import Mock
import pandas as pd
import pytest
from pandas import DataFrame
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase
from metadata.sdk.data_quality import (
ColumnValuesToBeNotNull,
ColumnValuesToBeUnique,
TableRowCountToBeBetween,
)
from metadata.sdk.data_quality.dataframes.custom_warnings import WholeTableTestsWarning
from metadata.sdk.data_quality.dataframes.dataframe_validator import DataFrameValidator
from metadata.sdk.data_quality.dataframes.validation_results import (
FailureMode,
ValidationResult,
)
class TestDataFrameValidator:
"""Test DataFrameValidator initialization and configuration."""
def test_validator_initialization(self):
"""Test validator can be created."""
validator = DataFrameValidator(Mock())
assert validator is not None
assert validator._test_cases == []
def test_add_single_test(self):
"""Test adding a single test definition."""
validator = DataFrameValidator(Mock())
validator.add_test(ColumnValuesToBeNotNull(column="email"))
assert len(validator._test_cases) == 1
def test_add_multiple_tests(self):
"""Test adding multiple test definitions at once."""
validator = DataFrameValidator(Mock())
validator.add_tests(
ColumnValuesToBeNotNull(column="email"),
ColumnValuesToBeUnique(column="id"),
TableRowCountToBeBetween(min_count=1, max_count=100),
)
assert len(validator._test_cases) == 3
class TestValidationSuccess:
"""Test successful validation scenarios."""
def test_validate_not_null_success(self):
"""Test validation passes with valid DataFrame."""
df = pd.DataFrame({"email": ["a@b.com", "c@d.com", "e@f.com"]})
validator = DataFrameValidator(Mock())
validator.add_test(ColumnValuesToBeNotNull(column="email"))
result = validator.validate(df)
assert result.success is True
assert result.passed_tests == 1
assert result.failed_tests == 0
assert len(result.test_results) == 1
assert result.test_results[0].testCaseStatus is TestCaseStatus.Success
def test_validate_unique_success(self):
"""Test uniqueness validation passes."""
df = pd.DataFrame({"id": [1, 2, 3, 4, 5]})
validator = DataFrameValidator(Mock())
validator.add_test(ColumnValuesToBeUnique(column="id"))
result = validator.validate(df)
assert result.success is True
assert result.passed_tests == 1
def test_validate_multiple_tests_success(self):
"""Test multiple tests all passing."""
df = pd.DataFrame(
{
"id": [1, 2, 3],
"email": ["a@b.com", "c@d.com", "e@f.com"],
"age": [25, 30, 35],
}
)
validator = DataFrameValidator(Mock())
validator.add_tests(
ColumnValuesToBeNotNull(column="email"),
ColumnValuesToBeUnique(column="id"),
TableRowCountToBeBetween(min_count=1, max_count=10),
)
result = validator.validate(df)
assert result.success is True
assert result.passed_tests == 3
assert result.failed_tests == 0
class TestValidationFailure:
"""Test validation failure scenarios."""
def test_validate_not_null_failure(self):
"""Test validation fails with null values."""
df = pd.DataFrame({"email": ["a@b.com", None, "e@f.com"]})
validator = DataFrameValidator(Mock())
validator.add_test(ColumnValuesToBeNotNull(column="email"))
result = validator.validate(df)
assert result.success is False
assert result.passed_tests == 0
assert result.failed_tests == 1
assert result.test_results[0].testCaseStatus is TestCaseStatus.Failed
def test_validate_unique_failure(self):
"""Test uniqueness validation fails with duplicates."""
df = pd.DataFrame({"id": [1, 2, 2, 3]})
validator = DataFrameValidator(Mock())
validator.add_test(ColumnValuesToBeUnique(column="id"))
result = validator.validate(df)
assert result.success is False
assert result.failed_tests == 1
def test_validate_row_count_failure(self):
"""Test row count validation fails."""
df = pd.DataFrame({"col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]})
validator = DataFrameValidator(Mock())
validator.add_test(TableRowCountToBeBetween(min_count=1, max_count=10))
result = validator.validate(df)
assert result.success is False
class TestShortCircuitMode:
"""Test short-circuit validation mode."""
def test_short_circuit_stops_on_first_failure(self):
"""Test short-circuit mode stops after first failure."""
df = pd.DataFrame(
{
"id": [1, 2, 2],
"email": [None, None, None],
}
)
validator = DataFrameValidator(Mock())
validator.add_tests(
ColumnValuesToBeUnique(column="id"),
ColumnValuesToBeNotNull(column="email"),
)
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
assert result.success is False
assert len(result.test_results) == 1
def test_short_circuit_continues_on_success(self):
"""Test short-circuit mode continues when tests pass."""
df = pd.DataFrame(
{
"id": [1, 2, 3],
"email": ["a@b.com", "c@d.com", "e@f.com"],
}
)
validator = DataFrameValidator(Mock())
validator.add_tests(
ColumnValuesToBeUnique(column="id"),
ColumnValuesToBeNotNull(column="email"),
)
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
assert result.success is True
assert len(result.test_results) == 2
class TestValidationResultProperties:
"""Test ValidationResult helper properties."""
def test_failures_property(self):
"""Test failures property returns only failed tests."""
df = pd.DataFrame(
{
"id": [1, 2, 2],
"email": ["a@b.com", "c@d.com", "e@f.com"],
}
)
validator = DataFrameValidator(Mock())
validator.add_tests(
ColumnValuesToBeUnique(column="id"),
ColumnValuesToBeNotNull(column="email"),
)
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
assert len(result.failures) == 1
def test_passes_property(self):
"""Test passes property returns only passed tests."""
df = pd.DataFrame(
{
"id": [1, 2, 3],
"email": ["a@b.com", None, "e@f.com"],
}
)
validator = DataFrameValidator(Mock())
validator.add_tests(
ColumnValuesToBeUnique(column="id"),
ColumnValuesToBeNotNull(column="email"),
)
result = validator.validate(df, mode=FailureMode.SHORT_CIRCUIT)
assert len(result.passes) == 1
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_empty_dataframe(self):
"""Test validation on empty DataFrame."""
df = pd.DataFrame({"email": []})
validator = DataFrameValidator(Mock())
validator.add_test(ColumnValuesToBeNotNull(column="email"))
result = validator.validate(df)
assert result.success is True
def test_missing_column(self):
"""Test validation with missing column."""
df = pd.DataFrame({"name": ["Alice", "Bob"]})
validator = DataFrameValidator(Mock())
validator.add_test(ColumnValuesToBeNotNull(column="email"))
result = validator.validate(df)
assert result.success is False
assert result.test_results[0].testCaseStatus is TestCaseStatus.Aborted
def test_no_tests_configured(self):
"""Test validation with no tests configured."""
df = pd.DataFrame({"col": [1, 2, 3]})
validator = DataFrameValidator(Mock())
result = validator.validate(df)
assert result.success is True
assert result.total_tests == 0
assert len(result.test_results) == 0
def test_execution_time_recorded(self):
"""Test that execution time is recorded."""
df = pd.DataFrame({"col": [1, 2, 3]})
validator = DataFrameValidator(Mock())
validator.add_test(TableRowCountToBeBetween(min_count=1, max_count=10))
result = validator.validate(df)
assert result.execution_time_ms > 0
class TracksValidationCallbacks:
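"""Callable spy that records every (DataFrame, ValidationResult) pair it receives."""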
def __init__(self) -> None:
self.calls: List[Tuple[DataFrame, ValidationResult]] = []
@property
def times_called(self) -> int:
return len(self.calls)
@property
def was_called(self) -> bool:
return self.times_called > 0
def __call__(self, df: DataFrame, result: ValidationResult) -> None:
self.calls.append((df, result))
@pytest.mark.filterwarnings(
"error::metadata.sdk.data_quality.dataframes.custom_warnings.WholeTableTestsWarning"
)
class TestValidatorRun:
@pytest.fixture
def on_success_callback(self) -> TracksValidationCallbacks:
return TracksValidationCallbacks()
@pytest.fixture
def on_failure_callback(self) -> TracksValidationCallbacks:
return TracksValidationCallbacks()
@pytest.fixture
def validator(self) -> DataFrameValidator:
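"""DataFrameValidator preloaded with a single not-null test on the id column."""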
df_validator = DataFrameValidator(Mock())
df_validator.add_tests(
ColumnValuesToBeNotNull(column="id"),
)
return df_validator
def test_it_only_calls_on_success(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3))
validator.run(dfs, on_success_callback, on_failure_callback)
assert on_success_callback.times_called == 3
assert on_failure_callback.was_called is False
def test_it_calls_both_on_success_and_on_failure(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
def generate_data() -> Generator[DataFrame, None, None]:
for i in range(0, 6, 3):
yield pd.DataFrame({"id": [i + 1, i + 2, i + 3]})
yield pd.DataFrame({"id": [None]})
validator.run(generate_data(), on_success_callback, on_failure_callback)
assert on_success_callback.times_called == 2
assert on_failure_callback.times_called == 1
assert pd.DataFrame({"id": [None]}).equals(on_failure_callback.calls[0][0])
def test_it_aborts_on_failure(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
def generate_data() -> Generator[DataFrame, None, None]:
yield pd.DataFrame({"id": [None]})
for i in range(0, 6, 3):
yield pd.DataFrame({"id": [i + 1, i + 2, i + 3]})
validator.run(generate_data(), on_success_callback, on_failure_callback)
assert on_success_callback.was_called is False
assert on_failure_callback.times_called == 1
assert pd.DataFrame({"id": [None]}).equals(on_failure_callback.calls[0][0])
def test_it_warns_when_using_tests_that_require_the_whole_table_or_column(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3))
validator.add_tests(
TableRowCountToBeBetween(min_count=1, max_count=10),
ColumnValuesToBeUnique(column="id"),
)
expected_warning_match = (
"The following tests could have unexpected results:\n\n"
+ "\t- columnValuesToBeUnique\n"
+ "\t- tableRowCountToBeBetween"
)
with pytest.warns(WholeTableTestsWarning, match=expected_warning_match):
validator.run(dfs, on_success_callback, on_failure_callback)
def test_run_returns_merged_validation_result(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
dfs = (pd.DataFrame({"id": [i + 1, i + 2, i + 3]}) for i in range(0, 9, 3))
result = validator.run(dfs, on_success_callback, on_failure_callback)
assert isinstance(result, ValidationResult)
assert result.success is True
def test_merged_result_has_correct_aggregated_metrics(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
dfs = [
pd.DataFrame({"id": [1, 2, 3]}),
pd.DataFrame({"id": [4, 5, 6]}),
pd.DataFrame({"id": [7, 8, 9]}),
]
result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
assert result.total_tests == 1
assert result.passed_tests == 1
assert result.failed_tests == 0
assert result.success is True
def test_merged_result_aggregates_failures_correctly(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
dfs = [
pd.DataFrame({"id": [1, 2, 3]}),
pd.DataFrame({"id": [None]}),
]
result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
assert result.total_tests == 1
assert result.passed_tests == 0
assert result.failed_tests == 1
assert result.success is False
def test_merged_result_contains_aggregated_test_case_results(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
validator.add_test(ColumnValuesToBeUnique(column="id"))
dfs = [
pd.DataFrame({"id": [1, 2, 3]}),
pd.DataFrame({"id": [4, 5, 6]}),
]
with pytest.warns(WholeTableTestsWarning):
result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
assert len(result.test_cases_and_results) == 2
assert all(
isinstance(test_case, TestCase) and isinstance(test_result, TestCaseResult)
for test_case, test_result in result.test_cases_and_results
)
_, aggregated_not_null_result = result.test_cases_and_results[0]
assert aggregated_not_null_result.passedRows == 6
def test_merged_result_sums_execution_times(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
dfs = [
pd.DataFrame({"id": [1, 2, 3]}),
pd.DataFrame({"id": [4, 5, 6]}),
pd.DataFrame({"id": [7, 8, 9]}),
]
result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
assert result.execution_time_ms > 0
individual_times = [
call[1].execution_time_ms for call in on_success_callback.calls
]
assert result.execution_time_ms == sum(individual_times)
def test_merged_result_with_mixed_success_and_failure(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
def generate_mixed_data() -> Generator[DataFrame, None, None]:
yield pd.DataFrame({"id": [1, 2, 3]})
yield pd.DataFrame({"id": [4, 5, 6]})
yield pd.DataFrame({"id": [None]})
result = validator.run(
generate_mixed_data(), on_success_callback, on_failure_callback
)
assert result.total_tests == 1
assert result.passed_tests == 0
assert result.failed_tests == 1
assert result.success is False
assert len(result.test_cases_and_results) == 1
def test_merged_result_aggregates_multiple_tests_per_batch(
self,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
validator = DataFrameValidator(Mock())
validator.add_tests(
ColumnValuesToBeNotNull(column="id"),
ColumnValuesToBeUnique(column="id"),
TableRowCountToBeBetween(min_count=1, max_count=10),
)
dfs = [
pd.DataFrame({"id": [1, 2, 3]}),
pd.DataFrame({"id": [4, 5, 6]}),
]
with pytest.warns(WholeTableTestsWarning):
result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
assert result.total_tests == 3
assert result.passed_tests == 3
assert result.failed_tests == 0
assert result.success is True
def test_merged_result_with_short_circuit_on_failure(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
def generate_data() -> Generator[DataFrame, None, None]:
yield pd.DataFrame({"id": [None]})
yield pd.DataFrame({"id": [1, 2, 3]})
yield pd.DataFrame({"id": [4, 5, 6]})
result = validator.run(
generate_data(), on_success_callback, on_failure_callback
)
assert result.total_tests == 1
assert result.passed_tests == 0
assert result.failed_tests == 1
assert result.success is False
def test_merged_result_reflects_all_batches_processed(
self,
validator: DataFrameValidator,
on_success_callback: TracksValidationCallbacks,
on_failure_callback: TracksValidationCallbacks,
) -> None:
batch_count = 5
dfs = [
pd.DataFrame({"id": [i + 1, i + 2, i + 3]})
for i in range(0, batch_count * 3, 3)
]
result = validator.run(iter(dfs), on_success_callback, on_failure_callback)
assert result.total_tests == 1
assert len(result.test_cases_and_results) == 1
assert on_success_callback.times_called == batch_count
_, aggregated_result = result.test_cases_and_results[0]
assert aggregated_result.passedRows == batch_count * 3

View File

@ -0,0 +1,377 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ValidationResult."""
from datetime import datetime
from uuid import UUID
import pytest
from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
from metadata.generated.schema.tests.testCase import TestCase
from metadata.generated.schema.type.basic import FullyQualifiedEntityName, Timestamp
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.sdk.data_quality.dataframes.validation_results import ValidationResult
def create_test_case(fqn: str) -> TestCase:
"""Helper to create a test case with minimal required fields."""
return TestCase(
name=fqn.split(".")[-1],
fullyQualifiedName=FullyQualifiedEntityName(fqn),
testDefinition=EntityReference(
id=UUID("12345678-1234-1234-1234-123456789abc"),
type="testDefinition",
fullyQualifiedName=fqn,
),
entityLink="<#E::table::test_table>",
testSuite=EntityReference(
id=UUID("87654321-4321-4321-4321-cba987654321"),
type="testSuite",
fullyQualifiedName="test_suite",
),
)
def create_test_result(
status: TestCaseStatus,
passed_rows: int = 0,
failed_rows: int = 0,
) -> TestCaseResult:
"""Helper to create a test case result."""
total = passed_rows + failed_rows
return TestCaseResult(
timestamp=Timestamp(int(datetime.now().timestamp() * 1000)),
testCaseStatus=status,
passedRows=passed_rows if total > 0 else None,
failedRows=failed_rows if total > 0 else None,
passedRowsPercentage=(passed_rows / total * 100) if total > 0 else None,
failedRowsPercentage=(failed_rows / total * 100) if total > 0 else None,
)
class TestValidationResultMerge:
"""Test ValidationResult.merge method."""
def test_merge_single_result(self) -> None:
"""Test merging a single ValidationResult returns equivalent result."""
test_case = create_test_case("test.case.one")
test_result = create_test_result(TestCaseStatus.Success, passed_rows=100)
result = ValidationResult(
success=True,
total_tests=1,
passed_tests=1,
failed_tests=0,
test_cases_and_results=[(test_case, test_result)],
execution_time_ms=10.0,
)
merged = ValidationResult.merge(result)
assert merged.success is True
assert merged.total_tests == 1
assert merged.passed_tests == 1
assert merged.failed_tests == 0
assert merged.execution_time_ms == 10.0
assert len(merged.test_cases_and_results) == 1
def test_merge_multiple_results_same_test_case(self) -> None:
"""Test merging results for same test case aggregates metrics."""
test_case = create_test_case("test.case.one")
result1 = ValidationResult(
success=True,
total_tests=1,
passed_tests=1,
failed_tests=0,
test_cases_and_results=[
(test_case, create_test_result(TestCaseStatus.Success, passed_rows=50))
],
execution_time_ms=10.0,
)
result2 = ValidationResult(
success=True,
total_tests=1,
passed_tests=1,
failed_tests=0,
test_cases_and_results=[
(test_case, create_test_result(TestCaseStatus.Success, passed_rows=30))
],
execution_time_ms=8.0,
)
merged = ValidationResult.merge(result1, result2)
assert merged.success is True
assert merged.total_tests == 1
assert merged.passed_tests == 1
assert merged.failed_tests == 0
assert merged.execution_time_ms == 18.0
assert len(merged.test_cases_and_results) == 1
_, aggregated_result = merged.test_cases_and_results[0]
assert aggregated_result.passedRows == 80
assert aggregated_result.failedRows == 0
assert aggregated_result.testCaseStatus == TestCaseStatus.Success
def test_merge_aggregates_passed_and_failed_rows(self) -> None:
"""Test that passed and failed rows are summed correctly."""
test_case = create_test_case("test.case.one")
result1 = ValidationResult(
success=False,
total_tests=1,
passed_tests=0,
failed_tests=1,
test_cases_and_results=[
(
test_case,
create_test_result(
TestCaseStatus.Failed, passed_rows=40, failed_rows=10
),
)
],
execution_time_ms=10.0,
)
result2 = ValidationResult(
success=False,
total_tests=1,
passed_tests=0,
failed_tests=1,
test_cases_and_results=[
(
test_case,
create_test_result(
TestCaseStatus.Failed, passed_rows=30, failed_rows=20
),
)
],
execution_time_ms=12.0,
)
merged = ValidationResult.merge(result1, result2)
assert merged.success is False
assert merged.total_tests == 1
assert merged.failed_tests == 1
assert merged.execution_time_ms == 22.0
_, aggregated_result = merged.test_cases_and_results[0]
assert aggregated_result.passedRows == 70
assert aggregated_result.failedRows == 30
assert aggregated_result.testCaseStatus == TestCaseStatus.Failed
assert aggregated_result.passedRowsPercentage == pytest.approx(70.0)
assert aggregated_result.failedRowsPercentage == pytest.approx(30.0)
def test_merge_multiple_test_cases(self) -> None:
"""Test merging results with multiple different test cases."""
test_case1 = create_test_case("test.case.one")
test_case2 = create_test_case("test.case.two")
result1 = ValidationResult(
success=True,
total_tests=2,
passed_tests=2,
failed_tests=0,
test_cases_and_results=[
(
test_case1,
create_test_result(TestCaseStatus.Success, passed_rows=50),
),
(
test_case2,
create_test_result(TestCaseStatus.Success, passed_rows=30),
),
],
execution_time_ms=15.0,
)
result2 = ValidationResult(
success=True,
total_tests=2,
passed_tests=2,
failed_tests=0,
test_cases_and_results=[
(
test_case1,
create_test_result(TestCaseStatus.Success, passed_rows=25),
),
(
test_case2,
create_test_result(TestCaseStatus.Success, passed_rows=35),
),
],
execution_time_ms=12.0,
)
merged = ValidationResult.merge(result1, result2)
assert merged.success is True
assert merged.total_tests == 2
assert merged.passed_tests == 2
assert merged.failed_tests == 0
assert merged.execution_time_ms == 27.0
assert len(merged.test_cases_and_results) == 2
fqns_to_results = {
tc.fullyQualifiedName.root: result
for tc, result in merged.test_cases_and_results
}
assert fqns_to_results["test.case.one"].passedRows == 75
assert fqns_to_results["test.case.two"].passedRows == 65
def test_merge_with_failure_status_propagates(self) -> None:
"""Test that if any batch fails, overall status is Failed."""
test_case = create_test_case("test.case.one")
result1 = ValidationResult(
success=True,
total_tests=1,
passed_tests=1,
failed_tests=0,
test_cases_and_results=[
(test_case, create_test_result(TestCaseStatus.Success, passed_rows=50))
],
execution_time_ms=10.0,
)
result2 = ValidationResult(
success=False,
total_tests=1,
passed_tests=0,
failed_tests=1,
test_cases_and_results=[
(
test_case,
create_test_result(
TestCaseStatus.Failed, passed_rows=20, failed_rows=10
),
)
],
execution_time_ms=10.0,
)
merged = ValidationResult.merge(result1, result2)
assert merged.success is False
assert merged.failed_tests == 1
_, aggregated_result = merged.test_cases_and_results[0]
assert aggregated_result.testCaseStatus == TestCaseStatus.Failed
def test_merge_with_aborted_status_takes_precedence(self) -> None:
"""Test that Aborted status takes precedence over Failed."""
test_case = create_test_case("test.case.one")
result1 = ValidationResult(
success=False,
total_tests=1,
passed_tests=0,
failed_tests=1,
test_cases_and_results=[
(
test_case,
create_test_result(
TestCaseStatus.Failed, passed_rows=40, failed_rows=10
),
)
],
execution_time_ms=10.0,
)
result2 = ValidationResult(
success=False,
total_tests=1,
passed_tests=0,
failed_tests=1,
test_cases_and_results=[
(test_case, create_test_result(TestCaseStatus.Aborted))
],
execution_time_ms=5.0,
)
merged = ValidationResult.merge(result1, result2)
_, aggregated_result = merged.test_cases_and_results[0]
assert aggregated_result.testCaseStatus == TestCaseStatus.Aborted
def test_merge_empty_raises_error(self) -> None:
"""Test that merging with no results raises ValueError."""
with pytest.raises(ValueError, match="At least one ValidationResult"):
ValidationResult.merge()
def test_merge_without_fqn_raises_error(self) -> None:
"""Test that merging test cases without FQN raises ValueError."""
test_case = TestCase(
name="test",
fullyQualifiedName=None,
testDefinition=EntityReference(
id=UUID("12345678-1234-1234-1234-123456789abc"),
type="testDefinition",
fullyQualifiedName="test.def",
),
entityLink="<#E::table::test_table>",
testSuite=EntityReference(
id=UUID("87654321-4321-4321-4321-cba987654321"),
type="testSuite",
fullyQualifiedName="test_suite",
),
)
result = ValidationResult(
success=True,
total_tests=1,
passed_tests=1,
failed_tests=0,
test_cases_and_results=[
(test_case, create_test_result(TestCaseStatus.Success))
],
execution_time_ms=10.0,
)
with pytest.raises(ValueError, match="no fullyQualifiedName"):
ValidationResult.merge(result)
def test_merge_preserves_test_case_reference(self) -> None:
"""Test that the merged result contains the original test case reference."""
test_case = create_test_case("test.case.one")
result1 = ValidationResult(
success=True,
total_tests=1,
passed_tests=1,
failed_tests=0,
test_cases_and_results=[
(test_case, create_test_result(TestCaseStatus.Success, passed_rows=50))
],
execution_time_ms=10.0,
)
result2 = ValidationResult(
success=True,
total_tests=1,
passed_tests=1,
failed_tests=0,
test_cases_and_results=[
(test_case, create_test_result(TestCaseStatus.Success, passed_rows=30))
],
execution_time_ms=8.0,
)
merged = ValidationResult.merge(result1, result2)
merged_test_case, _ = merged.test_cases_and_results[0]
assert merged_test_case.fullyQualifiedName.root == "test.case.one"
assert merged_test_case.name.root == "one"