[GEN-970] feat(data-quality): support multiple runtime parameter types (#18588)

* feat(data-quality): support multiple runtime parameter types

- changed the runtime parameters setter factory to return sets
- add the runtime parameters based on the type name of the runtime parameter

**NOTE** requires changes on collate side

* empty set for default case
This commit is contained in:
Imri Paran 2024-11-21 08:07:33 +01:00 committed by GitHub
parent 890a4f8755
commit 0169aad418
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 66 additions and 44 deletions

View File

@ -16,7 +16,7 @@ Validators are test classes (e.g. columnValuesToBeBetween, etc.)
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional, Type, Union
from typing import TYPE_CHECKING, Set, Type, Union
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
@ -66,24 +66,20 @@ class IValidatorBuilder(ABC):
)
self.reset()
def set_runtime_params(
self, runtime_params_setter: Optional[RuntimeParameterSetter]
):
def set_runtime_params(self, runtime_params_setters: Set[RuntimeParameterSetter]):
"""Set the runtime parameters for the validator object
# TODO: We should support setting n runtime parameters
Args:
runtime_params_setter (Optional[RuntimeParameterSetter]): The runtime parameter setter
runtime_params_setters (Set[RuntimeParameterSetter]): The runtime parameter setters
"""
if runtime_params_setter:
params = runtime_params_setter.get_parameters(self.test_case)
for setter in runtime_params_setters:
params = setter.get_parameters(self.test_case)
if not self.test_case.parameterValues:
# If there are no parameters, create a new list
self.test_case.parameterValues = []
self.test_case.parameterValues.append(
TestCaseParameterValue(
name="runtimeParams", value=params.model_dump_json()
name=type(params).__name__, value=params.model_dump_json()
)
)

View File

@ -15,7 +15,7 @@ supporting sqlalchemy abstraction layer
"""
from abc import ABC, abstractmethod
from typing import Optional, Type
from typing import Optional, Set, Type
from metadata.data_quality.builders.i_validator_builder import IValidatorBuilder
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
@ -111,9 +111,9 @@ class TestSuiteInterface(ABC):
runtime_params_setter_fact: RuntimeParameterSetterFactory = (
self._get_runtime_params_setter_fact()
) # type: ignore
runtime_params_setter: Optional[
runtime_params_setters: Set[
RuntimeParameterSetter
] = runtime_params_setter_fact.get_runtime_param_setter(
] = runtime_params_setter_fact.get_runtime_param_setters(
test_case.testDefinition.fullyQualifiedName, # type: ignore
self.ometa_client,
self.service_connection_config,
@ -127,7 +127,7 @@ class TestSuiteInterface(ABC):
).entityType.value
validator_builder = self._get_validator_builder(test_case, entity_type)
validator_builder.set_runtime_params(runtime_params_setter)
validator_builder.set_runtime_params(runtime_params_setters)
validator: BaseTestValidator = validator_builder.validator
try:
return validator.run_validation()

View File

@ -19,10 +19,9 @@ import reprlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
from metadata.generated.schema.tests.basic import (
TestCaseResult,
TestCaseStatus,
@ -37,6 +36,7 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=Callable)
R = TypeVar("R")
S = TypeVar("S", bound=BaseModel)
class BaseTestValidator(ABC):
@ -45,8 +45,6 @@ class BaseTestValidator(ABC):
This can be useful to resolve complex test parameters based on the parameters given by the user.
"""
runtime_parameter_setter: Optional[Type[RuntimeParameterSetter]] = None
def __init__(
self,
runner: Union[QueryRunner, List["DataFrame"]],
@ -168,3 +166,10 @@ class BaseTestValidator(ABC):
def get_predicted_value(self) -> Optional[str]:
    """Get predicted value

    Returns:
        Optional[str]: this implementation always returns None
    """
    return None
def get_runtime_parameters(self, setter_class: Type[S]) -> S:
    """Get runtime parameters

    Looks through the test case parameter values for the first entry whose
    name equals the class name of `setter_class` and parses its JSON value
    with that class.

    Args:
        setter_class (Type[S]): class used both as the lookup key (via its
            `__name__`) and to parse the stored JSON value

    Returns:
        S: the result of `setter_class.model_validate_json` on the stored value

    Raises:
        ValueError: if no parameter value is named after `setter_class`
    """
    # `parameterValues` may be unset, in which case there is nothing to search
    candidates = self.test_case.parameterValues or []
    found = next(
        (entry for entry in candidates if entry.name == setter_class.__name__),
        None,
    )
    if found is None:
        raise ValueError(f"Runtime parameter {setter_class.__name__} not found")
    return setter_class.model_validate_json(found.value)

View File

@ -13,8 +13,8 @@ Module that defines the RuntimeParameterFactory class.
This class is responsible for creating instances of the RuntimeParameterSetter
based on the test case.
"""
from typing import Optional
import sys
from typing import Dict, Set, Type
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
@ -22,7 +22,35 @@ from metadata.data_quality.validations.runtime_param_setter.param_setter import
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter,
)
from metadata.data_quality.validations.table.sqlalchemy.tableDiff import (
TableDiffValidator,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.processor.sampler.sqlalchemy.sampler import SQASampler
def removesuffix(s: str, suffix: str) -> str:
    """A custom implementation of removesuffix for python versions < 3.9

    On python 3.9+ this delegates to the built-in `str.removesuffix`; on older
    versions it falls back to an equivalent slice.

    Args:
        s (str): The string to remove the suffix from
        suffix (str): The suffix to remove

    Returns:
        str: The string with the suffix removed, or `s` unchanged when it does
            not end with `suffix` or when `suffix` is empty
    """
    if sys.version_info >= (3, 9):
        return s.removesuffix(suffix)
    # An empty suffix must be a no-op: every string ends with "" and
    # s[:-0] is s[:0], which would wrongly return the empty string.
    if suffix and s.endswith(suffix):
        return s[: -len(suffix)]
    return s
def validator_name(test_case_class: Type) -> str:
    """Build the validator name from a validator class

    The first character of the class name is lowercased and a trailing
    "Validator" is then removed.

    Args:
        test_case_class (Type): The class whose `__name__` is converted

    Returns:
        str: The converted name
    """
    class_name = test_case_class.__name__
    lower_first = class_name[0].lower() + class_name[1:]
    return removesuffix(lower_first, "Validator")
class RuntimeParameterSetterFactory:
@ -30,25 +58,25 @@ class RuntimeParameterSetterFactory:
def __init__(self) -> None:
"""Set"""
self._setter_map = {
TableDiffParamsSetter: {"tableDiff"},
self._setter_map: Dict[str, Set[Type[RuntimeParameterSetter]]] = {
validator_name(TableDiffValidator): {TableDiffParamsSetter},
}
def get_runtime_param_setter(
def get_runtime_param_setters(
self,
name: str,
ometa: OpenMetadata,
service_connection_config,
table_entity,
sampler,
) -> Optional[RuntimeParameterSetter]:
table_entity: Table,
sampler: SQASampler,
) -> Set[RuntimeParameterSetter]:
"""Get the runtime parameter setter"""
for setter_cls, validator_names in self._setter_map.items():
if name in validator_names:
return setter_cls(
ometa,
service_connection_config,
table_entity,
sampler,
)
return None
return {
setter(
ometa,
service_connection_config,
table_entity,
sampler,
)
for setter in self._setter_map.get(name, set())
}

View File

@ -178,7 +178,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
runtime_params: TableDiffRuntimeParameters
def run_validation(self) -> TestCaseResult:
self.runtime_params = self.get_runtime_params()
self.runtime_params = self.get_runtime_parameters(TableDiffRuntimeParameters)
try:
self._validate_dialects()
return self._run()
@ -438,13 +438,6 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
)
raise ValueError("Invalid profile sample type")
def get_runtime_params(self) -> TableDiffRuntimeParameters:
raw = self.get_test_case_param_value(
self.test_case.parameterValues, "runtimeParams", str
)
runtime_params = TableDiffRuntimeParameters.model_validate_json(raw)
return runtime_params
def get_row_diff_test_case_result(
self,
threshold: int,