import json
import logging
import sys
import time
from dataclasses import dataclass
from datetime import timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import datahub.emitter.mce_builder as builder
import packaging.version
from datahub.cli.env_utils import get_boolean_env_variable
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.emitter.serialization_helper import pre_json_transform
from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import (
get_platform_from_sqlalchemy_uri,
)
from datahub.metadata.com.linkedin.pegasus2avro.assertion import (
AssertionInfo,
AssertionResult,
AssertionResultType,
AssertionRunEvent,
AssertionRunStatus,
AssertionStdAggregation,
AssertionStdOperator,
AssertionStdParameter,
AssertionStdParameters,
AssertionStdParameterType,
AssertionType,
BatchSpec,
DatasetAssertionInfo,
DatasetAssertionScope,
)
from datahub.metadata.com.linkedin.pegasus2avro.common import DataPlatformInstance
from datahub.metadata.schema_classes import PartitionSpecClass, PartitionTypeClass
from datahub.sql_parsing.sqlglot_lineage import create_lineage_sql_parsed_result
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
from datahub.utilities.urns.dataset_urn import DatasetUrn
from great_expectations.checkpoint.actions import ValidationAction
from great_expectations.core.batch import Batch
from great_expectations.core.batch_spec import (
RuntimeDataBatchSpec,
RuntimeQueryBatchSpec,
SqlAlchemyDatasourceBatchSpec,
)
from great_expectations.core.expectation_validation_result import (
ExpectationSuiteValidationResult,
)
from great_expectations.data_asset.data_asset import DataAsset
from great_expectations.data_context import AbstractDataContext
from great_expectations.data_context.types.resource_identifiers import (
ExpectationSuiteIdentifier,
ValidationResultIdentifier,
)
from great_expectations.execution_engine import PandasExecutionEngine
from great_expectations.execution_engine.sqlalchemy_execution_engine import (
SqlAlchemyExecutionEngine,
)
from great_expectations.validator.validator import Validator
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.engine.url import make_url
# TODO: move this, and the version check used in tests, to a common module
try:
from great_expectations import __version__ as GX_VERSION # type: ignore
has_name_positional_arg = packaging.version.parse(
GX_VERSION
) >= packaging.version.Version("0.18.14")
except Exception:
has_name_positional_arg = False
if TYPE_CHECKING:
from great_expectations.data_context.types.resource_identifiers import (
GXCloudIdentifier,
)
assert MARKUPSAFE_PATCHED
logger = logging.getLogger(__name__)
if get_boolean_env_variable("DATAHUB_DEBUG", False):
handler = logging.StreamHandler(stream=sys.stdout)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
GE_PLATFORM_NAME = "great-expectations"
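# Illustrative sketch (not part of this module): this action is typically
# registered in a Great Expectations checkpoint's action_list. The server_url
# below is a placeholder; module_name/class_name follow how the DataHub docs
# reference this integration — adjust if your install lays it out differently.
#
#   action_list:
#     - name: datahub_action
#       action:
#         module_name: datahub.integrations.great_expectations.action
#         class_name: DataHubValidationAction
#         server_url: http://localhost:8080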
class DataHubValidationAction(ValidationAction):
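    """Great Expectations ValidationAction that pushes assertion metadata and
    validation results to DataHub via the DataHub REST emitter.
    """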
def __init__(
self,
data_context: AbstractDataContext,
        # this captures the `name` positional arg added in GX 0.18.14
*args: Union[str, Any],
server_url: str,
env: str = builder.DEFAULT_ENV,
platform_alias: Optional[str] = None,
platform_instance_map: Optional[Dict[str, str]] = None,
graceful_exceptions: bool = True,
token: Optional[str] = None,
timeout_sec: Optional[float] = None,
retry_status_codes: Optional[List[int]] = None,
retry_max_times: Optional[int] = None,
extra_headers: Optional[Dict[str, str]] = None,
exclude_dbname: Optional[bool] = None,
parse_table_names_from_sql: bool = False,
convert_urns_to_lowercase: bool = False,
name: str = "DataHubValidationAction",
):
if has_name_positional_arg:
if len(args) >= 1 and isinstance(args[0], str):
name = args[0]
super().__init__(data_context, name)
else:
super().__init__(data_context)
self.server_url = server_url
self.env = env
self.platform_alias = platform_alias
self.platform_instance_map = platform_instance_map
self.graceful_exceptions = graceful_exceptions
self.token = token
self.timeout_sec = timeout_sec
self.retry_status_codes = retry_status_codes
self.retry_max_times = retry_max_times
self.extra_headers = extra_headers
self.exclude_dbname = exclude_dbname
self.parse_table_names_from_sql = parse_table_names_from_sql
self.convert_urns_to_lowercase = convert_urns_to_lowercase
def _run(
self,
validation_result_suite: ExpectationSuiteValidationResult,
validation_result_suite_identifier: Union[
ValidationResultIdentifier, "GXCloudIdentifier"
],
data_asset: Union[Validator, DataAsset, Batch],
payload: Optional[Any] = None,
expectation_suite_identifier: Optional[ExpectationSuiteIdentifier] = None,
checkpoint_identifier: Optional[Any] = None,
) -> Dict:
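        """Emit assertions and their results for this validation run to DataHub.

        Returns a dict with a "datahub_notification_result" key; emission
        failures are suppressed when graceful_exceptions is set.
        """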
datasets = []
try:
emitter = DatahubRestEmitter(
gms_server=self.server_url,
token=self.token,
read_timeout_sec=self.timeout_sec,
connect_timeout_sec=self.timeout_sec,
retry_status_codes=self.retry_status_codes,
retry_max_times=self.retry_max_times,
extra_headers=self.extra_headers,
)
expectation_suite_name = validation_result_suite.meta.get(
"expectation_suite_name"
)
run_id = validation_result_suite.meta.get("run_id")
if hasattr(data_asset, "active_batch_id"):
batch_identifier = data_asset.active_batch_id
else:
batch_identifier = data_asset.batch_id
if isinstance(
validation_result_suite_identifier, ValidationResultIdentifier
):
expectation_suite_name = (
validation_result_suite_identifier.expectation_suite_identifier.expectation_suite_name
)
run_id = validation_result_suite_identifier.run_id
batch_identifier = validation_result_suite_identifier.batch_identifier
# Returns datasets and corresponding batch requests
datasets = self.get_dataset_partitions(batch_identifier, data_asset)
if len(datasets) == 0 or datasets[0]["dataset_urn"] is None:
warn("Metadata not sent to datahub. No datasets found.")
return {"datahub_notification_result": "none required"}
# Returns assertion info and assertion results
assertions = self.get_assertions_with_results(
validation_result_suite,
expectation_suite_name,
run_id,
payload,
datasets,
)
logger.info("Sending metadata to datahub ...")
logger.info("Dataset URN - {urn}".format(urn=datasets[0]["dataset_urn"]))
for assertion in assertions:
logger.info(
"Assertion URN - {urn}".format(urn=assertion["assertionUrn"])
)
# Construct a MetadataChangeProposalWrapper object.
assertion_info_mcp = MetadataChangeProposalWrapper(
entityUrn=assertion["assertionUrn"],
aspect=assertion["assertionInfo"],
)
emitter.emit_mcp(assertion_info_mcp)
# Construct a MetadataChangeProposalWrapper object.
assertion_platform_mcp = MetadataChangeProposalWrapper(
entityUrn=assertion["assertionUrn"],
aspect=assertion["assertionPlatform"],
)
emitter.emit_mcp(assertion_platform_mcp)
for assertionResult in assertion["assertionResults"]:
dataset_assertionResult_mcp = MetadataChangeProposalWrapper(
entityUrn=assertionResult.assertionUrn,
aspect=assertionResult,
)
# Emit Result! (timeseries aspect)
emitter.emit_mcp(dataset_assertionResult_mcp)
logger.info("Metadata sent to datahub.")
result = "DataHub notification succeeded"
except Exception as e:
result = "DataHub notification failed"
if self.graceful_exceptions:
logger.error(e)
logger.info("Suppressing error because graceful_exceptions is set")
else:
raise
return {"datahub_notification_result": result}
def get_assertions_with_results(
self,
validation_result_suite,
expectation_suite_name,
run_id,
payload,
datasets,
):
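        """Convert each expectation result in the validation suite into a
        DataHub assertion (URN, AssertionInfo, platform instance) together
        with its AssertionRunEvent results.
        """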
dataPlatformInstance = DataPlatformInstance(
platform=builder.make_data_platform_urn(GE_PLATFORM_NAME)
)
docs_link = None
if payload:
# process the payload
            for action_names in payload:
if payload[action_names]["class"] == "UpdateDataDocsAction":
data_docs_pages = payload[action_names]
for docs_link_key, docs_link_val in data_docs_pages.items():
if "file://" not in docs_link_val and docs_link_key != "class":
docs_link = docs_link_val
assertions_with_results = []
for result in validation_result_suite.results:
expectation_config = result["expectation_config"]
expectation_type = expectation_config["expectation_type"]
success = bool(result["success"])
kwargs = {
k: v for k, v in expectation_config["kwargs"].items() if k != "batch_id"
}
result = result["result"]
assertion_datasets = [d["dataset_urn"] for d in datasets]
if len(datasets) == 1 and "column" in kwargs:
assertion_fields = [
builder.make_schema_field_urn(
datasets[0]["dataset_urn"], kwargs["column"]
)
]
else:
assertion_fields = None # type:ignore
            # Be careful which fields are used to create the assertion URN.
            # Any change in the fields below would lead to a new assertion.
            # FIXME: currently, when using evaluation parameters, a new
            # assertion is created whenever the runtime-resolved kwargs
            # differ, possibly for each validation run.
assertionUrn = builder.make_assertion_urn(
builder.datahub_guid(
pre_json_transform(
{
"platform": GE_PLATFORM_NAME,
"nativeType": expectation_type,
"nativeParameters": kwargs,
"dataset": assertion_datasets[0],
"fields": assertion_fields,
}
)
)
)
logger.debug(
"GE expectation_suite_name - {name}, expectation_type - {type}, Assertion URN - {urn}".format(
name=expectation_suite_name, type=expectation_type, urn=assertionUrn
)
)
assertionInfo: AssertionInfo = self.get_assertion_info(
expectation_type,
kwargs,
assertion_datasets[0],
assertion_fields,
expectation_suite_name,
)
# TODO: Understand why their run time is incorrect.
run_time = run_id.run_time.astimezone(timezone.utc)
evaluation_parameters = (
{
k: convert_to_string(v)
for k, v in validation_result_suite.evaluation_parameters.items()
if k and v
}
if validation_result_suite.evaluation_parameters
else None
)
nativeResults = {
k: convert_to_string(v)
for k, v in result.items()
if (
k
in [
"observed_value",
"partial_unexpected_list",
"partial_unexpected_counts",
"details",
]
and v
)
}
actualAggValue = (
result.get("observed_value")
if isinstance(result.get("observed_value"), (int, float))
else None
)
ds = datasets[0]
# https://docs.greatexpectations.io/docs/reference/expectations/result_format/
assertionResult = AssertionRunEvent(
timestampMillis=int(round(time.time() * 1000)),
assertionUrn=assertionUrn,
asserteeUrn=ds["dataset_urn"],
runId=run_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
result=AssertionResult(
type=(
AssertionResultType.SUCCESS
if success
else AssertionResultType.FAILURE
),
rowCount=parse_int_or_default(result.get("element_count")),
missingCount=parse_int_or_default(result.get("missing_count")),
unexpectedCount=parse_int_or_default(
result.get("unexpected_count")
),
actualAggValue=actualAggValue,
externalUrl=docs_link,
nativeResults=nativeResults,
),
batchSpec=ds["batchSpec"],
status=AssertionRunStatus.COMPLETE,
runtimeContext=evaluation_parameters,
)
if ds.get("partitionSpec") is not None:
assertionResult.partitionSpec = ds.get("partitionSpec")
assertionResults = [assertionResult]
assertions_with_results.append(
{
"assertionUrn": assertionUrn,
"assertionInfo": assertionInfo,
"assertionPlatform": dataPlatformInstance,
"assertionResults": assertionResults,
}
)
return assertions_with_results
def get_assertion_info(
self, expectation_type, kwargs, dataset, fields, expectation_suite_name
):
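        """Map a GE expectation type and its kwargs onto a DataHub
        AssertionInfo, using the known_expectations table below, with a
        heuristic fallback for unrecognized column expectations.
        """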
        # TODO: can we find the exact type of the min and max values?
def get_min_max(kwargs, type=AssertionStdParameterType.UNKNOWN):
return AssertionStdParameters(
minValue=AssertionStdParameter(
value=convert_to_string(kwargs.get("min_value")),
type=type,
),
maxValue=AssertionStdParameter(
value=convert_to_string(kwargs.get("max_value")),
type=type,
),
)
known_expectations: Dict[str, DataHubStdAssertion] = {
# column aggregate expectations
"expect_column_min_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.MIN,
parameters=get_min_max(kwargs),
),
"expect_column_max_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.MAX,
parameters=get_min_max(kwargs),
),
"expect_column_median_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.MEDIAN,
parameters=get_min_max(kwargs),
),
"expect_column_stdev_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.STDDEV,
parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
),
"expect_column_mean_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.MEAN,
parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
),
"expect_column_unique_value_count_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.UNIQUE_COUNT,
parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
),
"expect_column_proportion_of_unique_values_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.UNIQUE_PROPOTION,
parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
),
"expect_column_sum_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.SUM,
parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
),
"expect_column_quantile_values_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation._NATIVE_,
),
# column map expectations
"expect_column_values_to_not_be_null": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.NOT_NULL,
aggregation=AssertionStdAggregation.IDENTITY,
),
"expect_column_values_to_be_in_set": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.IN,
aggregation=AssertionStdAggregation.IDENTITY,
parameters=AssertionStdParameters(
value=AssertionStdParameter(
value=convert_to_string(kwargs.get("value_set")),
type=AssertionStdParameterType.SET,
)
),
),
"expect_column_values_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.IDENTITY,
parameters=get_min_max(kwargs),
),
"expect_column_values_to_match_regex": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.REGEX_MATCH,
aggregation=AssertionStdAggregation.IDENTITY,
parameters=AssertionStdParameters(
value=AssertionStdParameter(
value=kwargs.get("regex"),
type=AssertionStdParameterType.STRING,
)
),
),
"expect_column_values_to_match_regex_list": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_COLUMN,
operator=AssertionStdOperator.REGEX_MATCH,
aggregation=AssertionStdAggregation.IDENTITY,
parameters=AssertionStdParameters(
value=AssertionStdParameter(
value=convert_to_string(kwargs.get("regex_list")),
type=AssertionStdParameterType.LIST,
)
),
),
"expect_table_columns_to_match_ordered_list": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_SCHEMA,
operator=AssertionStdOperator.EQUAL_TO,
aggregation=AssertionStdAggregation.COLUMNS,
parameters=AssertionStdParameters(
value=AssertionStdParameter(
value=convert_to_string(kwargs.get("column_list")),
type=AssertionStdParameterType.LIST,
)
),
),
"expect_table_columns_to_match_set": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_SCHEMA,
operator=AssertionStdOperator.EQUAL_TO,
aggregation=AssertionStdAggregation.COLUMNS,
parameters=AssertionStdParameters(
value=AssertionStdParameter(
value=convert_to_string(kwargs.get("column_set")),
type=AssertionStdParameterType.SET,
)
),
),
"expect_table_column_count_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_SCHEMA,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.COLUMN_COUNT,
parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
),
"expect_table_column_count_to_equal": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_SCHEMA,
operator=AssertionStdOperator.EQUAL_TO,
aggregation=AssertionStdAggregation.COLUMN_COUNT,
parameters=AssertionStdParameters(
value=AssertionStdParameter(
value=convert_to_string(kwargs.get("value")),
type=AssertionStdParameterType.NUMBER,
)
),
),
"expect_column_to_exist": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_SCHEMA,
operator=AssertionStdOperator._NATIVE_,
aggregation=AssertionStdAggregation._NATIVE_,
),
"expect_table_row_count_to_equal": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_ROWS,
operator=AssertionStdOperator.EQUAL_TO,
aggregation=AssertionStdAggregation.ROW_COUNT,
parameters=AssertionStdParameters(
value=AssertionStdParameter(
value=convert_to_string(kwargs.get("value")),
type=AssertionStdParameterType.NUMBER,
)
),
),
"expect_table_row_count_to_be_between": DataHubStdAssertion(
scope=DatasetAssertionScope.DATASET_ROWS,
operator=AssertionStdOperator.BETWEEN,
aggregation=AssertionStdAggregation.ROW_COUNT,
parameters=get_min_max(kwargs, AssertionStdParameterType.NUMBER),
),
}
datasetAssertionInfo = DatasetAssertionInfo(
dataset=dataset,
fields=fields,
operator=AssertionStdOperator._NATIVE_,
aggregation=AssertionStdAggregation._NATIVE_,
nativeType=expectation_type,
nativeParameters={k: convert_to_string(v) for k, v in kwargs.items()},
scope=DatasetAssertionScope.DATASET_ROWS,
)
        if expectation_type in known_expectations:
assertion = known_expectations[expectation_type]
datasetAssertionInfo.scope = assertion.scope
datasetAssertionInfo.aggregation = assertion.aggregation
datasetAssertionInfo.operator = assertion.operator
datasetAssertionInfo.parameters = assertion.parameters
# Heuristically mapping other expectations
else:
if "column" in kwargs and expectation_type.startswith(
"expect_column_value"
):
datasetAssertionInfo.scope = DatasetAssertionScope.DATASET_COLUMN
datasetAssertionInfo.aggregation = AssertionStdAggregation.IDENTITY
elif "column" in kwargs:
datasetAssertionInfo.scope = DatasetAssertionScope.DATASET_COLUMN
datasetAssertionInfo.aggregation = AssertionStdAggregation._NATIVE_
return AssertionInfo(
type=AssertionType.DATASET,
datasetAssertion=datasetAssertionInfo,
customProperties={"expectation_suite_name": expectation_suite_name},
)
def get_dataset_partitions(self, batch_identifier, data_asset):
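        """Identify the dataset(s) being validated by the active batch and
        return their URNs along with partition and batch specs.
        """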
dataset_partitions = []
logger.debug("Finding datasets being validated")
        # For now, we support only the v3 API with the SQLAlchemy and Pandas execution engines.
is_sql_alchemy = isinstance(data_asset, Validator) and (
isinstance(data_asset.execution_engine, SqlAlchemyExecutionEngine)
)
is_pandas = isinstance(data_asset.execution_engine, PandasExecutionEngine)
if is_sql_alchemy or is_pandas:
ge_batch_spec = data_asset.active_batch_spec
partitionSpec = None
batchSpecProperties = {
"data_asset_name": str(
data_asset.active_batch_definition.data_asset_name
),
"datasource_name": str(
data_asset.active_batch_definition.datasource_name
),
}
sqlalchemy_uri = None
if is_sql_alchemy and isinstance(
data_asset.execution_engine.engine, Engine
):
sqlalchemy_uri = data_asset.execution_engine.engine.url
            # For Snowflake, sqlalchemy_execution_engine.engine is actually an instance of Connection
elif is_sql_alchemy and isinstance(
data_asset.execution_engine.engine, Connection
):
sqlalchemy_uri = data_asset.execution_engine.engine.engine.url
if isinstance(ge_batch_spec, SqlAlchemyDatasourceBatchSpec):
# e.g. ConfiguredAssetSqlDataConnector with splitter_method or sampling_method
schema_name = ge_batch_spec.get("schema_name")
table_name = ge_batch_spec.get("table_name")
dataset_urn = make_dataset_urn_from_sqlalchemy_uri(
sqlalchemy_uri,
schema_name,
table_name,
self.env,
self.get_platform_instance(
data_asset.active_batch_definition.datasource_name
),
self.exclude_dbname,
self.platform_alias,
self.convert_urns_to_lowercase,
)
batchSpec = BatchSpec(
nativeBatchId=batch_identifier,
customProperties=batchSpecProperties,
)
splitter_method = ge_batch_spec.get("splitter_method")
if (
splitter_method is not None
and splitter_method != "_split_on_whole_table"
):
batch_identifiers = ge_batch_spec.get("batch_identifiers", {})
partitionSpec = PartitionSpecClass(
partition=convert_to_string(batch_identifiers)
)
sampling_method = ge_batch_spec.get("sampling_method", "")
if sampling_method == "_sample_using_limit":
batchSpec.limit = ge_batch_spec["sampling_kwargs"]["n"]
dataset_partitions.append(
{
"dataset_urn": dataset_urn,
"partitionSpec": partitionSpec,
"batchSpec": batchSpec,
}
)
elif isinstance(ge_batch_spec, RuntimeQueryBatchSpec):
if not self.parse_table_names_from_sql:
                    warn(
                        "Enable parse_table_names_from_sql in DatahubValidationAction config "
                        "to try to parse the tables being asserted from the SQL query"
                    )
return []
query = data_asset.batches[
batch_identifier
].batch_request.runtime_parameters["query"]
partitionSpec = PartitionSpecClass(
type=PartitionTypeClass.QUERY,
partition=f"Query_{builder.datahub_guid(pre_json_transform(query))}",
)
batchSpec = BatchSpec(
nativeBatchId=batch_identifier,
query=query,
customProperties=batchSpecProperties,
)
data_platform = get_platform_from_sqlalchemy_uri(str(sqlalchemy_uri))
sql_parser_in_tables = create_lineage_sql_parsed_result(
query=query,
platform=data_platform,
env=self.env,
platform_instance=None,
default_db=None,
)
tables = [
DatasetUrn.from_string(table_urn).name
for table_urn in sql_parser_in_tables.in_tables
]
if sql_parser_in_tables.debug_info.table_error:
logger.warning(
f"Sql parser failed on {query} with {sql_parser_in_tables.debug_info.table_error}"
)
tables = []
if len(set(tables)) != 1:
warn(
"DataHubValidationAction does not support cross dataset assertions."
)
return []
for table in tables:
dataset_urn = make_dataset_urn_from_sqlalchemy_uri(
sqlalchemy_uri,
None,
table,
self.env,
self.get_platform_instance(
data_asset.active_batch_definition.datasource_name
),
self.exclude_dbname,
self.platform_alias,
self.convert_urns_to_lowercase,
)
dataset_partitions.append(
{
"dataset_urn": dataset_urn,
"partitionSpec": partitionSpec,
"batchSpec": batchSpec,
}
)
elif isinstance(ge_batch_spec, RuntimeDataBatchSpec):
data_platform = self.get_platform_instance(
data_asset.active_batch_definition.datasource_name
)
dataset_urn = builder.make_dataset_urn_with_platform_instance(
platform=(
data_platform
if self.platform_alias is None
else self.platform_alias
),
name=data_asset.active_batch_definition.datasource_name,
platform_instance="",
env=self.env,
)
batchSpec = BatchSpec(
nativeBatchId=batch_identifier,
query="",
customProperties=batchSpecProperties,
)
dataset_partitions.append(
{
"dataset_urn": dataset_urn,
"partitionSpec": partitionSpec,
"batchSpec": batchSpec,
}
)
else:
warn(
"DataHubValidationAction does not recognize this GE batch spec type- {batch_spec_type}.".format(
batch_spec_type=type(ge_batch_spec)
)
)
else:
# TODO - v2-spec - SqlAlchemyDataset support
warn(
"DataHubValidationAction does not recognize this GE data asset type - {asset_type}. This is either using v2-api or execution engine other than sqlalchemy.".format(
asset_type=type(data_asset)
)
)
return dataset_partitions
def get_platform_instance(self, datasource_name):
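        """Return the platform instance mapped to datasource_name, or None
        (with a warning) if no mapping is configured.
        """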
if self.platform_instance_map and datasource_name in self.platform_instance_map:
return self.platform_instance_map[datasource_name]
else:
warn(
f"Datasource {datasource_name} is not present in platform_instance_map"
)
return None
def parse_int_or_default(value, default_value=None):
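    """Cast value to int, returning default_value when value is None."""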
if value is None:
return default_value
else:
return int(value)
def make_dataset_urn_from_sqlalchemy_uri(
sqlalchemy_uri,
schema_name,
table_name,
env,
platform_instance=None,
exclude_dbname=None,
platform_alias=None,
convert_urns_to_lowercase=False,
):
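    """Build a DataHub dataset URN from a SQLAlchemy URI, schema, and table,
    applying per-platform defaults (e.g. the public/dbo schemas) and
    per-platform database name handling.
    """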
data_platform = get_platform_from_sqlalchemy_uri(str(sqlalchemy_uri))
url_instance = make_url(sqlalchemy_uri)
if schema_name is None and "." in table_name:
schema_name, table_name = table_name.split(".")[-2:]
if data_platform in ["redshift", "postgres"]:
schema_name = schema_name or "public"
if url_instance.database is None:
warn(
f"DataHubValidationAction failed to locate database name for {data_platform}."
)
return None
schema_name = (
schema_name if exclude_dbname else f"{url_instance.database}.{schema_name}"
)
elif data_platform == "mssql":
schema_name = schema_name or "dbo"
if url_instance.database is None:
warn(
f"DataHubValidationAction failed to locate database name for {data_platform}."
)
return None
schema_name = (
schema_name if exclude_dbname else f"{url_instance.database}.{schema_name}"
)
elif data_platform in ["trino", "snowflake"]:
if schema_name is None or url_instance.database is None:
warn(
"DataHubValidationAction failed to locate schema name and/or database name for {data_platform}.".format(
data_platform=data_platform
)
)
return None
# If data platform is snowflake, we artificially lowercase the Database name.
# This is because DataHub also does this during ingestion.
# Ref: https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py#L155
database_name = (
url_instance.database.lower()
if data_platform == "snowflake"
else url_instance.database
)
if database_name.endswith(f"/{schema_name}"):
database_name = database_name[: -len(f"/{schema_name}")]
schema_name = (
schema_name if exclude_dbname else f"{database_name}.{schema_name}"
)
elif data_platform == "bigquery":
if url_instance.host is None or url_instance.database is None:
warn(
"DataHubValidationAction failed to locate host and/or database name for {data_platform}. ".format(
data_platform=data_platform
)
)
return None
schema_name = f"{url_instance.host}.{url_instance.database}"
schema_name = schema_name or url_instance.database
if schema_name is None:
warn(
f"DataHubValidationAction failed to locate schema name for {data_platform}."
)
return None
dataset_name = f"{schema_name}.{table_name}"
if convert_urns_to_lowercase:
dataset_name = dataset_name.lower()
dataset_urn = builder.make_dataset_urn_with_platform_instance(
platform=data_platform if platform_alias is None else platform_alias,
name=dataset_name,
platform_instance=platform_instance,
env=env,
)
return dataset_urn
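# Illustrative example (hypothetical values): for the postgres URI
# "postgresql://user:pass@localhost:5432/mydb" with table "orders" and no
# schema, the schema defaults to "public" and, unless exclude_dbname is set,
# the resulting dataset name is "mydb.public.orders" on the "postgres"
# platform.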
@dataclass
class DataHubStdAssertion:
scope: Union[str, DatasetAssertionScope]
operator: Union[str, AssertionStdOperator]
aggregation: Union[str, AssertionStdAggregation]
parameters: Optional[AssertionStdParameters] = None
class DecimalEncoder(json.JSONEncoder):
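    """JSON encoder that renders Decimal values as strings."""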
def default(self, o):
if isinstance(o, Decimal):
return str(o)
return super().default(o)
def convert_to_string(var: Any) -> str:
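    """Render var as a string, JSON-encoding non-scalar values and falling
    back to str() when JSON serialization fails.
    """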
try:
tmp = (
str(var)
if isinstance(var, (str, int, float))
else json.dumps(var, cls=DecimalEncoder)
)
except TypeError as e:
logger.debug(e)
tmp = str(var)
return tmp
def warn(msg):
logger.warning(msg)