import json
import logging
import sys
import time
from dataclasses import dataclass

# Note: imported before the great_expectations imports so that the markupsafe
# compatibility patch (asserted below) is applied first.
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
from datetime import timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import packaging.version
from great_expectations.checkpoint.actions import ValidationAction
from great_expectations.core.batch import Batch
from great_expectations.core.batch_spec import (
    RuntimeDataBatchSpec,
    RuntimeQueryBatchSpec,
    SqlAlchemyDatasourceBatchSpec,
)
from great_expectations.core.expectation_validation_result import (
    ExpectationSuiteValidationResult,
)
from great_expectations.data_asset.data_asset import DataAsset
from great_expectations.data_context import AbstractDataContext
from great_expectations.data_context.types.resource_identifiers import (
    ExpectationSuiteIdentifier,
    ValidationResultIdentifier,
)
from great_expectations.execution_engine import PandasExecutionEngine
from great_expectations.execution_engine.sqlalchemy_execution_engine import (
    SqlAlchemyExecutionEngine,
)
from great_expectations.validator.validator import Validator
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.engine.url import make_url

import datahub.emitter.mce_builder as builder
from datahub.cli.env_utils import get_boolean_env_variable
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.emitter.serialization_helper import pre_json_transform
from datahub.ingestion.graph.config import ClientMode
from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import (
    get_platform_from_sqlalchemy_uri,
)
from datahub.metadata.com.linkedin.pegasus2avro.assertion import (
    AssertionInfo,
    AssertionResult,
    AssertionResultType,
    AssertionRunEvent,
    AssertionRunStatus,
    AssertionStdAggregation,
    AssertionStdOperator,
    AssertionStdParameter,
    AssertionStdParameters,
    AssertionStdParameterType,
    AssertionType,
    BatchSpec,
    DatasetAssertionInfo,
    DatasetAssertionScope,
)
from datahub.metadata.com.linkedin.pegasus2avro.common import DataPlatformInstance
from datahub.metadata.schema_classes import PartitionSpecClass, PartitionTypeClass
from datahub.sql_parsing.sqlglot_lineage import create_lineage_sql_parsed_result
from datahub.utilities.urns.dataset_urn import DatasetUrn

# TODO: move this and the version check used in tests to some common module
try:
    from great_expectations import __version__ as GX_VERSION  # type: ignore

    # GX 0.18.14 added a `name` positional argument to ValidationAction.
    has_name_positional_arg = packaging.version.parse(
        GX_VERSION
    ) >= packaging.version.Version("0.18.14")
except Exception:
    has_name_positional_arg = False

if TYPE_CHECKING:
    from great_expectations.data_context.types.resource_identifiers import (
        GXCloudIdentifier,
    )

assert MARKUPSAFE_PATCHED
logger = logging.getLogger(__name__)
if get_boolean_env_variable("DATAHUB_DEBUG", False):
    handler = logging.StreamHandler(stream=sys.stdout)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

GE_PLATFORM_NAME = "great-expectations"


class DataHubValidationAction(ValidationAction):
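    """Great Expectations ValidationAction that reports expectation suites and
    validation results to DataHub as assertions and assertion run events.

    A minimal wiring sketch for a checkpoint's action list (this assumes the
    plugin is importable as ``datahub_gx_plugin.action``; adjust
    ``module_name`` and ``server_url`` for your installation):

        action_list = [
            {
                "name": "datahub_action",
                "action": {
                    "module_name": "datahub_gx_plugin.action",
                    "class_name": "DataHubValidationAction",
                    "server_url": "http://localhost:8080",
                },
            },
        ]
    """
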
    def __init__(
        self,
        data_context: AbstractDataContext,
        # this captures the `name` positional arg added in GX 0.18.14
        *args: Union[str, Any],
        server_url: str,
        env: str = builder.DEFAULT_ENV,
        platform_alias: Optional[str] = None,
        platform_instance_map: Optional[Dict[str, str]] = None,
        graceful_exceptions: bool = True,
        token: Optional[str] = None,
        timeout_sec: Optional[float] = None,
        retry_status_codes: Optional[List[int]] = None,
        retry_max_times: Optional[int] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        exclude_dbname: Optional[bool] = None,
        parse_table_names_from_sql: bool = False,
        convert_urns_to_lowercase: bool = False,
        name: str = "DataHubValidationAction",
    ):
        if has_name_positional_arg:
            if len(args) >= 1 and isinstance(args[0], str):
                name = args[0]
            super().__init__(data_context, name)
        else:
            super().__init__(data_context)
        self.server_url = server_url
        self.env = env
        self.platform_alias = platform_alias
        self.platform_instance_map = platform_instance_map
        self.graceful_exceptions = graceful_exceptions
        self.token = token
        self.timeout_sec = timeout_sec
        self.retry_status_codes = retry_status_codes
        self.retry_max_times = retry_max_times
        self.extra_headers = extra_headers
        self.exclude_dbname = exclude_dbname
        self.parse_table_names_from_sql = parse_table_names_from_sql
        self.convert_urns_to_lowercase = convert_urns_to_lowercase

    def _run(
        self,
        validation_result_suite: ExpectationSuiteValidationResult,
        validation_result_suite_identifier: Union[
            ValidationResultIdentifier, "GXCloudIdentifier"
        ],
        data_asset: Union[Validator, DataAsset, Batch],
        payload: Optional[Any] = None,
        expectation_suite_identifier: Optional[ExpectationSuiteIdentifier] = None,
        checkpoint_identifier: Optional[Any] = None,
    ) -> Dict:
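        """Emit the validation results for this run to DataHub.

        For each expectation in the suite, this emits an AssertionInfo aspect,
        a DataPlatformInstance aspect, and a timeseries AssertionRunEvent. The
        return value is a summary dict of the form
        ``{"datahub_notification_result": <outcome>}``.
        """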
        datasets = []
        try:
            emitter = DatahubRestEmitter(
                gms_server=self.server_url,
                token=self.token,
                read_timeout_sec=self.timeout_sec,
                connect_timeout_sec=self.timeout_sec,
                retry_status_codes=self.retry_status_codes,
                retry_max_times=self.retry_max_times,
                extra_headers=self.extra_headers,
                client_mode=ClientMode.INGESTION,
                datahub_component="gx-plugin",
            )

            expectation_suite_name = validation_result_suite.meta.get(
                "expectation_suite_name"
            )
            run_id = validation_result_suite.meta.get("run_id")
            if hasattr(data_asset, "active_batch_id"):
                batch_identifier = data_asset.active_batch_id
            else:
                batch_identifier = data_asset.batch_id

            if isinstance(
                validation_result_suite_identifier, ValidationResultIdentifier
            ):
                expectation_suite_name = validation_result_suite_identifier.expectation_suite_identifier.expectation_suite_name
                run_id = validation_result_suite_identifier.run_id
                batch_identifier = validation_result_suite_identifier.batch_identifier

            # Returns datasets and corresponding batch requests
            datasets = self.get_dataset_partitions(batch_identifier, data_asset)

            if len(datasets) == 0 or datasets[0]["dataset_urn"] is None:
                warn("Metadata not sent to datahub. No datasets found.")
                return {"datahub_notification_result": "none required"}

            # Returns assertion info and assertion results
            assertions = self.get_assertions_with_results(
                validation_result_suite,
                expectation_suite_name,
                run_id,
                payload,
                datasets,
            )

            logger.info("Sending metadata to datahub ...")
            logger.info("Dataset URN - {urn}".format(urn=datasets[0]["dataset_urn"]))

            for assertion in assertions:
                logger.info(
                    "Assertion URN - {urn}".format(urn=assertion["assertionUrn"])
                )

                # Construct a MetadataChangeProposalWrapper object.
                assertion_info_mcp = MetadataChangeProposalWrapper(
                    entityUrn=assertion["assertionUrn"],
                    aspect=assertion["assertionInfo"],
                )
                emitter.emit_mcp(assertion_info_mcp)

                # Construct a MetadataChangeProposalWrapper object.
                assertion_platform_mcp = MetadataChangeProposalWrapper(
                    entityUrn=assertion["assertionUrn"],
                    aspect=assertion["assertionPlatform"],
                )
                emitter.emit_mcp(assertion_platform_mcp)

                for assertionResult in assertion["assertionResults"]:
                    dataset_assertionResult_mcp = MetadataChangeProposalWrapper(
                        entityUrn=assertionResult.assertionUrn,
                        aspect=assertionResult,
                    )

                    # Emit the result (timeseries aspect)
                    emitter.emit_mcp(dataset_assertionResult_mcp)
            logger.info("Metadata sent to datahub.")
            result = "DataHub notification succeeded"
        except Exception as e:
            result = "DataHub notification failed"
            if self.graceful_exceptions:
                logger.error(e)
                logger.info("Suppressing error because graceful_exceptions is set")
            else:
                raise

        return {"datahub_notification_result": result}

    def get_assertions_with_results(
        self,
        validation_result_suite,
        expectation_suite_name,
        run_id,
        payload,
        datasets,
    ):
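        """Translate each expectation result in the suite into DataHub objects.

        Returns a list of dicts, one per expectation, with keys
        ``assertionUrn``, ``assertionInfo``, ``assertionPlatform``, and
        ``assertionResults``.
        """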
        dataPlatformInstance = DataPlatformInstance(
            platform=builder.make_data_platform_urn(GE_PLATFORM_NAME)
        )
        docs_link = None
        if payload:
            # process the payload
            for action_name in payload:
                if payload[action_name]["class"] == "UpdateDataDocsAction":
                    data_docs_pages = payload[action_name]
                    for docs_link_key, docs_link_val in data_docs_pages.items():
                        if "file://" not in docs_link_val and docs_link_key != "class":
                            docs_link = docs_link_val

        assertions_with_results = []
        for result in validation_result_suite.results:
            expectation_config = result["expectation_config"]
            expectation_type = expectation_config["expectation_type"]
            success = bool(result["success"])
            kwargs = {
                k: v for k, v in expectation_config["kwargs"].items() if k != "batch_id"
            }

            result = result["result"]
            assertion_datasets = [d["dataset_urn"] for d in datasets]
            if len(datasets) == 1 and "column" in kwargs:
                assertion_fields = [
                    builder.make_schema_field_urn(
                        datasets[0]["dataset_urn"], kwargs["column"]
                    )
                ]
            else:
                assertion_fields = None  # type:ignore

            # Be careful about which fields go into the assertion urn: any
            # change in the fields below leads to a new assertion.
            # FIXME: when evaluation parameters are used, a new assertion is
            # created whenever the runtime-resolved kwargs differ, possibly on
            # every validation run.
            assertionUrn = builder.make_assertion_urn(
                builder.datahub_guid(
                    pre_json_transform(
                        {
                            "platform": GE_PLATFORM_NAME,
                            "nativeType": expectation_type,
                            "nativeParameters": kwargs,
                            "dataset": assertion_datasets[0],
                            "fields": assertion_fields,
                        }
                    )
                )
            )
            logger.debug(
                "GE expectation_suite_name - {name}, expectation_type - {type}, Assertion URN - {urn}".format(
                    name=expectation_suite_name, type=expectation_type, urn=assertionUrn
                )
            )
            assertionInfo: AssertionInfo = self.get_assertion_info(
                expectation_type,
                kwargs,
                assertion_datasets[0],
                assertion_fields,
                expectation_suite_name,
            )

            # TODO: Understand why the reported run time is incorrect.
            run_time = run_id.run_time.astimezone(timezone.utc)
            evaluation_parameters = (
                {
                    k: convert_to_string(v)
                    for k, v in validation_result_suite.evaluation_parameters.items()
                    if k and v
                }
                if validation_result_suite.evaluation_parameters
                else None
            )

            nativeResults = {
                k: convert_to_string(v)
                for k, v in result.items()
                if (
                    k
                    in [
                        "observed_value",
                        "partial_unexpected_list",
                        "partial_unexpected_counts",
                        "details",
                    ]
                    and v
                )
            }

            actualAggValue = (
                result.get("observed_value")
                if isinstance(result.get("observed_value"), (int, float))
                else None
            )

            ds = datasets[0]
            # https://docs.greatexpectations.io/docs/reference/expectations/result_format/
            assertionResult = AssertionRunEvent(
                timestampMillis=int(round(time.time() * 1000)),
                assertionUrn=assertionUrn,
                asserteeUrn=ds["dataset_urn"],
                runId=run_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
                result=AssertionResult(
                    type=(
                        AssertionResultType.SUCCESS
                        if success
                        else AssertionResultType.FAILURE
                    ),
                    rowCount=parse_int_or_default(result.get("element_count")),
                    missingCount=parse_int_or_default(result.get("missing_count")),
                    unexpectedCount=parse_int_or_default(
                        result.get("unexpected_count")
                    ),
                    actualAggValue=actualAggValue,
                    externalUrl=docs_link,
                    nativeResults=nativeResults,
                ),
                batchSpec=ds["batchSpec"],
                status=AssertionRunStatus.COMPLETE,
                runtimeContext=evaluation_parameters,
            )
            if ds.get("partitionSpec") is not None:
                assertionResult.partitionSpec = ds.get("partitionSpec")
            assertionResults = [assertionResult]
            assertions_with_results.append(
                {
                    "assertionUrn": assertionUrn,
                    "assertionInfo": assertionInfo,
                    "assertionPlatform": dataPlatformInstance,
                    "assertionResults": assertionResults,
                }
            )
        return assertions_with_results

    def get_assertion_info(
        self, expectation_type, kwargs, dataset, fields, expectation_suite_name
    ):
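        """Map a GE expectation onto DataHub's standardized assertion model.

        Known expectation types get a precise scope/operator/aggregation
        mapping; anything else falls back to a heuristic based on whether the
        expectation targets a column.
        """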
        # TODO: can we find the exact type of the min and max values?
        def get_min_max(kwargs, type=AssertionStdParameterType.UNKNOWN):
            return AssertionStdParameters(
                minValue=AssertionStdParameter(
                    value=convert_to_string(kwargs.get("min_value")),
                    type=type,
                ),
                maxValue=AssertionStdParameter(
                    value=convert_to_string(kwargs.get("max_value")),
                    type=type,
                ),
            )

        known_expectations: Dict[str, DataHubStdAssertion] = {
            # column aggregate expectations
            "expect_column_min_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.MIN,
                parameters=get_min_max(kwargs),
            ),
            "expect_column_max_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.MAX,
                parameters=get_min_max(kwargs),
            ),
            "expect_column_median_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.MEDIAN,
                parameters=get_min_max(kwargs),
            ),
            "expect_column_stdev_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.STDDEV,
                parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
            ),
            "expect_column_mean_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.MEAN,
                parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
            ),
            "expect_column_unique_value_count_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.UNIQUE_COUNT,
                parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
            ),
            "expect_column_proportion_of_unique_values_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                # (sic) the enum value is spelled this way in the DataHub model
                aggregation=AssertionStdAggregation.UNIQUE_PROPOTION,
                parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
            ),
            "expect_column_sum_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.SUM,
                parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
            ),
            "expect_column_quantile_values_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation._NATIVE_,
            ),
            # column map expectations
            "expect_column_values_to_not_be_null": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.NOT_NULL,
                aggregation=AssertionStdAggregation.IDENTITY,
            ),
            "expect_column_values_to_be_in_set": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.IN,
                aggregation=AssertionStdAggregation.IDENTITY,
                parameters=AssertionStdParameters(
                    value=AssertionStdParameter(
                        value=convert_to_string(kwargs.get("value_set")),
                        type=AssertionStdParameterType.SET,
                    )
                ),
            ),
            "expect_column_values_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.IDENTITY,
                parameters=get_min_max(kwargs),
            ),
            "expect_column_values_to_match_regex": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.REGEX_MATCH,
                aggregation=AssertionStdAggregation.IDENTITY,
                parameters=AssertionStdParameters(
                    value=AssertionStdParameter(
                        value=kwargs.get("regex"),
                        type=AssertionStdParameterType.STRING,
                    )
                ),
            ),
            "expect_column_values_to_match_regex_list": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_COLUMN,
                operator=AssertionStdOperator.REGEX_MATCH,
                aggregation=AssertionStdAggregation.IDENTITY,
                parameters=AssertionStdParameters(
                    value=AssertionStdParameter(
                        value=convert_to_string(kwargs.get("regex_list")),
                        type=AssertionStdParameterType.LIST,
                    )
                ),
            ),
            "expect_table_columns_to_match_ordered_list": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_SCHEMA,
                operator=AssertionStdOperator.EQUAL_TO,
                aggregation=AssertionStdAggregation.COLUMNS,
                parameters=AssertionStdParameters(
                    value=AssertionStdParameter(
                        value=convert_to_string(kwargs.get("column_list")),
                        type=AssertionStdParameterType.LIST,
                    )
                ),
            ),
            "expect_table_columns_to_match_set": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_SCHEMA,
                operator=AssertionStdOperator.EQUAL_TO,
                aggregation=AssertionStdAggregation.COLUMNS,
                parameters=AssertionStdParameters(
                    value=AssertionStdParameter(
                        value=convert_to_string(kwargs.get("column_set")),
                        type=AssertionStdParameterType.SET,
                    )
                ),
            ),
            "expect_table_column_count_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_SCHEMA,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.COLUMN_COUNT,
                parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
            ),
            "expect_table_column_count_to_equal": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_SCHEMA,
                operator=AssertionStdOperator.EQUAL_TO,
                aggregation=AssertionStdAggregation.COLUMN_COUNT,
                parameters=AssertionStdParameters(
                    value=AssertionStdParameter(
                        value=convert_to_string(kwargs.get("value")),
                        type=AssertionStdParameterType.NUMBER,
                    )
                ),
            ),
            "expect_column_to_exist": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_SCHEMA,
                operator=AssertionStdOperator._NATIVE_,
                aggregation=AssertionStdAggregation._NATIVE_,
            ),
            "expect_table_row_count_to_equal": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_ROWS,
                operator=AssertionStdOperator.EQUAL_TO,
                aggregation=AssertionStdAggregation.ROW_COUNT,
                parameters=AssertionStdParameters(
                    value=AssertionStdParameter(
                        value=convert_to_string(kwargs.get("value")),
                        type=AssertionStdParameterType.NUMBER,
                    )
                ),
            ),
            "expect_table_row_count_to_be_between": DataHubStdAssertion(
                scope=DatasetAssertionScope.DATASET_ROWS,
                operator=AssertionStdOperator.BETWEEN,
                aggregation=AssertionStdAggregation.ROW_COUNT,
                parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
            ),
        }

        datasetAssertionInfo = DatasetAssertionInfo(
            dataset=dataset,
            fields=fields,
            operator=AssertionStdOperator._NATIVE_,
            aggregation=AssertionStdAggregation._NATIVE_,
            nativeType=expectation_type,
            nativeParameters={k: convert_to_string(v) for k, v in kwargs.items()},
            scope=DatasetAssertionScope.DATASET_ROWS,
        )

        if expectation_type in known_expectations:
            assertion = known_expectations[expectation_type]
            datasetAssertionInfo.scope = assertion.scope
            datasetAssertionInfo.aggregation = assertion.aggregation
            datasetAssertionInfo.operator = assertion.operator
            datasetAssertionInfo.parameters = assertion.parameters

        # Heuristically map other expectations
        else:
            if "column" in kwargs and expectation_type.startswith(
                "expect_column_value"
            ):
                datasetAssertionInfo.scope = DatasetAssertionScope.DATASET_COLUMN
                datasetAssertionInfo.aggregation = AssertionStdAggregation.IDENTITY
            elif "column" in kwargs:
                datasetAssertionInfo.scope = DatasetAssertionScope.DATASET_COLUMN
                datasetAssertionInfo.aggregation = AssertionStdAggregation._NATIVE_

        return AssertionInfo(
            type=AssertionType.DATASET,
            datasetAssertion=datasetAssertionInfo,
            customProperties={"expectation_suite_name": expectation_suite_name},
        )

    def get_dataset_partitions(self, batch_identifier, data_asset):
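        """Resolve the dataset(s) under validation for the active batch.

        Returns a list of dicts with keys ``dataset_urn``, ``partitionSpec``,
        and ``batchSpec``; an empty list means the batch could not be mapped
        to a DataHub dataset.
        """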
        dataset_partitions = []

        logger.debug("Finding datasets being validated")

        # For now, we support only the v3 API, and only the sqlalchemy and
        # Pandas execution engines.
        is_sql_alchemy = isinstance(data_asset, Validator) and (
            isinstance(data_asset.execution_engine, SqlAlchemyExecutionEngine)
        )
        is_pandas = isinstance(data_asset.execution_engine, PandasExecutionEngine)
        if is_sql_alchemy or is_pandas:
            ge_batch_spec = data_asset.active_batch_spec
            partitionSpec = None
            batchSpecProperties = {
                "data_asset_name": str(
                    data_asset.active_batch_definition.data_asset_name
                ),
                "datasource_name": str(
                    data_asset.active_batch_definition.datasource_name
                ),
            }
            sqlalchemy_uri = None
            if is_sql_alchemy and isinstance(
                data_asset.execution_engine.engine, Engine
            ):
                sqlalchemy_uri = data_asset.execution_engine.engine.url
            # For Snowflake, sqlalchemy_execution_engine.engine is actually an
            # instance of Connection
            elif is_sql_alchemy and isinstance(
                data_asset.execution_engine.engine, Connection
            ):
                sqlalchemy_uri = data_asset.execution_engine.engine.engine.url

            if isinstance(ge_batch_spec, SqlAlchemyDatasourceBatchSpec):
                # e.g. ConfiguredAssetSqlDataConnector with splitter_method or
                # sampling_method
                schema_name = ge_batch_spec.get("schema_name")
                table_name = ge_batch_spec.get("table_name")

                dataset_urn = make_dataset_urn_from_sqlalchemy_uri(
                    sqlalchemy_uri,
                    schema_name,
                    table_name,
                    self.env,
                    self.get_platform_instance(
                        data_asset.active_batch_definition.datasource_name
                    ),
                    self.exclude_dbname,
                    self.platform_alias,
                    self.convert_urns_to_lowercase,
                )
                batchSpec = BatchSpec(
                    nativeBatchId=batch_identifier,
                    customProperties=batchSpecProperties,
                )

                splitter_method = ge_batch_spec.get("splitter_method")
                if (
                    splitter_method is not None
                    and splitter_method != "_split_on_whole_table"
                ):
                    batch_identifiers = ge_batch_spec.get("batch_identifiers", {})
                    partitionSpec = PartitionSpecClass(
                        partition=convert_to_string(batch_identifiers)
                    )
                sampling_method = ge_batch_spec.get("sampling_method", "")
                if sampling_method == "_sample_using_limit":
                    batchSpec.limit = ge_batch_spec["sampling_kwargs"]["n"]

                dataset_partitions.append(
                    {
                        "dataset_urn": dataset_urn,
                        "partitionSpec": partitionSpec,
                        "batchSpec": batchSpec,
                    }
                )
            elif isinstance(ge_batch_spec, RuntimeQueryBatchSpec):
                if not self.parse_table_names_from_sql:
                    warn(
                        "Enable parse_table_names_from_sql in the "
                        "DataHubValidationAction config to try to parse the "
                        "tables being asserted from the SQL query"
                    )
                    return []
                query = data_asset.batches[
                    batch_identifier
                ].batch_request.runtime_parameters["query"]
                partitionSpec = PartitionSpecClass(
                    type=PartitionTypeClass.QUERY,
                    partition=f"Query_{builder.datahub_guid(pre_json_transform(query))}",
                )

                batchSpec = BatchSpec(
                    nativeBatchId=batch_identifier,
                    query=query,
                    customProperties=batchSpecProperties,
                )

                data_platform = get_platform_from_sqlalchemy_uri(str(sqlalchemy_uri))
                sql_parser_in_tables = create_lineage_sql_parsed_result(
                    query=query,
                    platform=data_platform,
                    env=self.env,
                    platform_instance=None,
                    default_db=None,
                )
                tables = [
                    DatasetUrn.from_string(table_urn).name
                    for table_urn in sql_parser_in_tables.in_tables
                ]
                if sql_parser_in_tables.debug_info.table_error:
                    logger.warning(
                        f"SQL parser failed on {query} with {sql_parser_in_tables.debug_info.table_error}"
                    )
                    tables = []

                if len(set(tables)) != 1:
                    warn(
                        "DataHubValidationAction does not support cross-dataset assertions."
                    )
                    return []
                for table in tables:
                    dataset_urn = make_dataset_urn_from_sqlalchemy_uri(
                        sqlalchemy_uri,
                        None,
                        table,
                        self.env,
                        self.get_platform_instance(
                            data_asset.active_batch_definition.datasource_name
                        ),
                        self.exclude_dbname,
                        self.platform_alias,
                        self.convert_urns_to_lowercase,
                    )
                    dataset_partitions.append(
                        {
                            "dataset_urn": dataset_urn,
                            "partitionSpec": partitionSpec,
                            "batchSpec": batchSpec,
                        }
                    )
            elif isinstance(ge_batch_spec, RuntimeDataBatchSpec):
                data_platform = self.get_platform_instance(
                    data_asset.active_batch_definition.datasource_name
                )
                dataset_urn = builder.make_dataset_urn_with_platform_instance(
                    platform=(
                        data_platform
                        if self.platform_alias is None
                        else self.platform_alias
                    ),
                    name=data_asset.active_batch_definition.datasource_name,
                    platform_instance="",
                    env=self.env,
                )
                batchSpec = BatchSpec(
                    nativeBatchId=batch_identifier,
                    query="",
                    customProperties=batchSpecProperties,
                )
                dataset_partitions.append(
                    {
                        "dataset_urn": dataset_urn,
                        "partitionSpec": partitionSpec,
                        "batchSpec": batchSpec,
                    }
                )
            else:
                warn(
                    "DataHubValidationAction does not recognize this GE batch spec type - {batch_spec_type}.".format(
                        batch_spec_type=type(ge_batch_spec)
                    )
                )
        else:
            # TODO - v2-spec - SqlAlchemyDataset support
            warn(
                "DataHubValidationAction does not recognize this GE data asset type - {asset_type}. This is either using the v2 API or an execution engine other than sqlalchemy.".format(
                    asset_type=type(data_asset)
                )
            )

        return dataset_partitions

    def get_platform_instance(self, datasource_name):
        if self.platform_instance_map and datasource_name in self.platform_instance_map:
            return self.platform_instance_map[datasource_name]
        else:
            warn(
                f"Datasource {datasource_name} is not present in platform_instance_map"
            )
        return None


def parse_int_or_default(value, default_value=None):
    if value is None:
        return default_value
    else:
        return int(value)


def make_dataset_urn_from_sqlalchemy_uri(
    sqlalchemy_uri,
    schema_name,
    table_name,
    env,
    platform_instance=None,
    exclude_dbname=None,
    platform_alias=None,
    convert_urns_to_lowercase=False,
):
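    """Build a DataHub dataset URN from a SQLAlchemy URI plus schema/table.

    Platform-specific rules fill in default schemas (e.g. ``public`` for
    postgres/redshift, ``dbo`` for mssql) and, unless ``exclude_dbname`` is
    set, prefix the schema with the database name. A rough illustration with
    hypothetical values:

        make_dataset_urn_from_sqlalchemy_uri(
            "postgresql+psycopg2://user:pass@host:5432/prod_db",
            None,
            "public.orders",
            "PROD",
        )
        # -> urn:li:dataset:(urn:li:dataPlatform:postgres,prod_db.public.orders,PROD)
    """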
    data_platform = get_platform_from_sqlalchemy_uri(str(sqlalchemy_uri))
    url_instance = make_url(sqlalchemy_uri)

    if schema_name is None and "." in table_name:
        schema_name, table_name = table_name.split(".")[-2:]

    if data_platform in ["redshift", "postgres"]:
        schema_name = schema_name or "public"
        if url_instance.database is None:
            warn(
                f"DataHubValidationAction failed to locate database name for {data_platform}."
            )
            return None
        schema_name = (
            schema_name if exclude_dbname else f"{url_instance.database}.{schema_name}"
        )
    elif data_platform == "mssql":
        schema_name = schema_name or "dbo"
        if url_instance.database is None:
            warn(
                f"DataHubValidationAction failed to locate database name for {data_platform}."
            )
            return None
        schema_name = (
            schema_name if exclude_dbname else f"{url_instance.database}.{schema_name}"
        )
    elif data_platform in ["trino", "snowflake"]:
        if schema_name is None or url_instance.database is None:
            warn(
                f"DataHubValidationAction failed to locate schema name and/or database name for {data_platform}."
            )
            return None
        # If the data platform is Snowflake, we artificially lowercase the
        # database name, because DataHub also does this during ingestion.
        # Ref: https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py#L155
        database_name = (
            url_instance.database.lower()
            if data_platform == "snowflake"
            else url_instance.database
        )
        if database_name.endswith(f"/{schema_name}"):
            database_name = database_name[: -len(f"/{schema_name}")]
        schema_name = (
            schema_name if exclude_dbname else f"{database_name}.{schema_name}"
        )
    elif data_platform == "bigquery":
        if url_instance.host is None or url_instance.database is None:
            warn(
                f"DataHubValidationAction failed to locate host and/or database name for {data_platform}."
            )
            return None
        schema_name = f"{url_instance.host}.{url_instance.database}"

    schema_name = schema_name or url_instance.database
    if schema_name is None:
        warn(
            f"DataHubValidationAction failed to locate schema name for {data_platform}."
        )
        return None

    dataset_name = f"{schema_name}.{table_name}"

    if convert_urns_to_lowercase:
        dataset_name = dataset_name.lower()

    dataset_urn = builder.make_dataset_urn_with_platform_instance(
        platform=data_platform if platform_alias is None else platform_alias,
        name=dataset_name,
        platform_instance=platform_instance,
        env=env,
    )

    return dataset_urn


@dataclass
class DataHubStdAssertion:
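    """DataHub's standardized scope/operator/aggregation view of a GE expectation."""
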
    scope: Union[str, DatasetAssertionScope]
    operator: Union[str, AssertionStdOperator]
    aggregation: Union[str, AssertionStdAggregation]
    parameters: Optional[AssertionStdParameters] = None


class DecimalEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, Decimal):
            return str(o)
        return super().default(o)


def convert_to_string(var: Any) -> str:
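    """Best-effort stringification: JSON-encode complex values (with Decimal
    support via DecimalEncoder), falling back to str() for values that are
    not JSON serializable."""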
    try:
        tmp = (
            str(var)
            if isinstance(var, (str, int, float))
            else json.dumps(var, cls=DecimalEncoder)
        )
    except TypeError as e:
        logger.debug(e)
        tmp = str(var)
    return tmp


def warn(msg):
    logger.warning(msg)