feat(ingest/athena): Add option for Athena partitioned profiling (#10723)

Tamas Nemeth 2024-07-20 00:00:40 +02:00 committed by GitHub
parent 44cdb046d7
commit 20574cf1c6
9 changed files with 323 additions and 44 deletions

View File

@ -330,7 +330,14 @@ plugins: Dict[str, Set[str]] = {
# sqlalchemy-bigquery is included here since it provides an implementation of
# a SQLalchemy-conform STRUCT type definition
"athena": sql_common
| {"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0", "sqlalchemy-bigquery>=1.4.1"},
# We need to set tenacity lower than 8.4.0 as
# this version has missing dependency asyncio
# https://github.com/jd/tenacity/issues/471
| {
"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0",
"sqlalchemy-bigquery>=1.4.1",
"tenacity!=8.4.0",
},
"azure-ad": set(),
"bigquery": sql_common
| bigquery_common

View File

@ -13,6 +13,7 @@ import unittest.mock
import uuid
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@ -39,6 +40,7 @@ from great_expectations.data_context.types.base import (
from great_expectations.dataset.dataset import Dataset
from great_expectations.dataset.sqlalchemy_dataset import SqlAlchemyDataset
from great_expectations.datasource.sqlalchemy_datasource import SqlAlchemyDatasource
from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect
from great_expectations.profile.base import ProfilerDataType
from great_expectations.profile.basic_dataset_profiler import BasicDatasetProfilerBase
from sqlalchemy.engine import Connection, Engine
@ -72,9 +74,14 @@ from datahub.utilities.sqlalchemy_query_combiner import (
get_query_columns,
)
if TYPE_CHECKING:
from pyathena.cursor import Cursor
assert MARKUPSAFE_PATCHED
logger: logging.Logger = logging.getLogger(__name__)
_original_get_column_median = SqlAlchemyDataset.get_column_median
P = ParamSpec("P")
POSTGRESQL = "postgresql"
MYSQL = "mysql"
@ -203,6 +210,47 @@ def _get_column_quantiles_bigquery_patch( # type:ignore
return list()
def _get_column_quantiles_awsathena_patch( # type:ignore
self, column: str, quantiles: Iterable
) -> list:
import ast
table_name = ".".join(
[f'"{table_part}"' for table_part in str(self._table).split(".")]
)
quantiles_list = list(quantiles)
quantiles_query = (
f"SELECT approx_percentile({column}, ARRAY{str(quantiles_list)}) as quantiles "
f"from (SELECT {column} from {table_name})"
)
try:
quantiles_results = self.engine.execute(quantiles_query).fetchone()[0]
quantiles_results_list = ast.literal_eval(quantiles_results)
return quantiles_results_list
except ProgrammingError as pe:
self._treat_quantiles_exception(pe)
return []
def _get_column_median_patch(self, column):
# AWS Athena and Presto have a special function that can be used to retrieve the median
if (
self.sql_engine_dialect.name.lower() == GXSqlDialect.AWSATHENA
or self.sql_engine_dialect.name.lower() == GXSqlDialect.TRINO
):
table_name = ".".join(
[f'"{table_part}"' for table_part in str(self._table).split(".")]
)
element_values = self.engine.execute(
f"SELECT approx_percentile({column}, 0.5) FROM {table_name}"
)
return convert_to_json_serializable(element_values.fetchone()[0])
else:
return _original_get_column_median(self, column)
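
Both patches above push the quantile and median computation down to the engine with approx_percentile instead of Great Expectations' default SQL. A minimal sketch of the queries they end up issuing, using made-up table and column names (not taken from the commit):

# Illustrative only: the shape of the SQL the Athena patches issue.
# "datahub_db"."orders" and "amount" are hypothetical names for this sketch.
table_name = ".".join(f'"{part}"' for part in "datahub_db.orders".split("."))
column = "amount"
quantiles = [0.25, 0.5, 0.75]

quantiles_query = (
    f"SELECT approx_percentile({column}, ARRAY{quantiles}) as quantiles "
    f"from (SELECT {column} from {table_name})"
)
median_query = f"SELECT approx_percentile({column}, 0.5) FROM {table_name}"

print(quantiles_query)
# SELECT approx_percentile(amount, ARRAY[0.25, 0.5, 0.75]) as quantiles from (SELECT amount from "datahub_db"."orders")
print(median_query)
# SELECT approx_percentile(amount, 0.5) FROM "datahub_db"."orders"
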
def _is_single_row_query_method(query: Any) -> bool:
SINGLE_ROW_QUERY_FILES = {
# "great_expectations/dataset/dataset.py",
@ -1038,6 +1086,12 @@ class DatahubGEProfiler:
), unittest.mock.patch(
"great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_bigquery",
_get_column_quantiles_bigquery_patch,
), unittest.mock.patch(
"great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_awsathena",
_get_column_quantiles_awsathena_patch,
), unittest.mock.patch(
"great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_median",
_get_column_median_patch,
), concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as async_executor, SQLAlchemyQueryCombiner(
@ -1114,15 +1168,16 @@ class DatahubGEProfiler:
**request.batch_kwargs,
)
def _drop_trino_temp_table(self, temp_dataset: Dataset) -> None:
def _drop_temp_table(self, temp_dataset: Dataset) -> None:
schema = temp_dataset._table.schema
table = temp_dataset._table.name
table_name = f'"{schema}"."{table}"' if schema else f'"{table}"'
try:
with self.base_engine.connect() as connection:
connection.execute(f"drop view if exists {schema}.{table}")
logger.debug(f"View {schema}.{table} was dropped.")
connection.execute(f"drop view if exists {table_name}")
logger.debug(f"View {table_name} was dropped.")
except Exception:
logger.warning(f"Unable to delete trino temporary table: {schema}.{table}")
logger.warning(f"Unable to delete temporary table: {table_name}")
def _generate_single_profile(
self,
@ -1149,6 +1204,19 @@ class DatahubGEProfiler:
}
bigquery_temp_table: Optional[str] = None
temp_view: Optional[str] = None
if platform and platform.upper() == "ATHENA" and (custom_sql):
if custom_sql is not None:
# Note that limit and offset are not supported for custom SQL.
temp_view = create_athena_temp_table(
self, custom_sql, pretty_name, self.base_engine.raw_connection()
)
ge_config["table"] = temp_view
ge_config["schema"] = None
ge_config["limit"] = None
ge_config["offset"] = None
custom_sql = None
if platform == BIGQUERY and (
custom_sql or self.config.limit or self.config.offset
):
@ -1234,8 +1302,16 @@ class DatahubGEProfiler:
)
return None
finally:
if batch is not None and self.base_engine.engine.name == TRINO:
self._drop_trino_temp_table(batch)
if batch is not None and self.base_engine.engine.name.upper() in [
"TRINO",
"AWSATHENA",
]:
if (
self.base_engine.engine.name.upper() == "TRINO"
or temp_view is not None
):
self._drop_temp_table(batch)
# If we are not on Trino, we only drop the table if the temp view variable was set
def _get_ge_dataset(
self,
@ -1299,6 +1375,40 @@ def _get_column_types_to_ignore(dialect_name: str) -> List[str]:
return []
def create_athena_temp_table(
instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
sql: str,
table_pretty_name: str,
raw_connection: Any,
) -> Optional[str]:
try:
cursor: "Cursor" = cast("Cursor", raw_connection.cursor())
logger.debug(f"Creating view for {table_pretty_name}: {sql}")
temp_view = f"ge_{uuid.uuid4()}"
if "." in table_pretty_name:
schema_part = table_pretty_name.split(".")[-1]
schema_part_quoted = ".".join(
[f'"{part}"' for part in str(schema_part).split(".")]
)
temp_view = f"{schema_part_quoted}_{temp_view}"
temp_view = f"ge_{uuid.uuid4()}"
cursor.execute(f'create or replace view "{temp_view}" as {sql}')
except Exception as e:
if not instance.config.catch_exceptions:
raise e
logger.exception(f"Encountered exception while profiling {table_pretty_name}")
instance.report.report_warning(
table_pretty_name,
f"Profiling exception {e} when running custom sql {sql}",
)
return None
finally:
raw_connection.close()
return temp_view
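
create_athena_temp_table wraps arbitrary custom SQL in a short-lived Athena view so the profiler can treat it like a table. The DDL it executes through the PyAthena cursor looks roughly like this (the custom SQL and generated view name are placeholders for this sketch):

import uuid

# Roughly the DDL executed by create_athena_temp_table; names are hypothetical.
custom_sql = "SELECT id, amount FROM datahub_db.orders WHERE ds = '2024-07-19'"
temp_view = f"ge_{uuid.uuid4()}"
ddl = f'create or replace view "{temp_view}" as {custom_sql}'
print(ddl)
# create or replace view "ge_<uuid>" as SELECT id, amount FROM datahub_db.orders WHERE ds = '2024-07-19'
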
def create_bigquery_temp_table(
instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
bq_sql: str,

View File

@ -147,7 +147,7 @@ class GEProfilingConfig(ConfigModel):
partition_profiling_enabled: bool = Field(
default=True,
description="Whether to profile partitioned tables. Only BigQuery supports this. "
description="Whether to profile partitioned tables. Only BigQuery and Aws Athena supports this. "
"If enabled, latest partition data is used for profiling.",
)
partition_datetime: Optional[datetime.datetime] = Field(

View File

@ -1,7 +1,9 @@
import datetime
import json
import logging
import re
import typing
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import pydantic
@ -27,6 +29,7 @@ from datahub.ingestion.api.decorators import (
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.sql.sql_common import (
SQLAlchemySource,
register_custom_type,
@ -52,6 +55,14 @@ register_custom_type(STRUCT, RecordTypeClass)
register_custom_type(MapType, MapTypeClass)
class AthenaProfilingConfig(GEProfilingConfig):
# Overriding default value for partition_profiling
partition_profiling_enabled: bool = pydantic.Field(
default=False,
description="Enable partition profiling. This will profile the latest partition of the table.",
)
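
AthenaProfilingConfig only flips the default of partition_profiling_enabled; the base GEProfilingConfig keeps it on. A small sketch of that difference, assuming the classes remain importable from the modules touched in this commit:

# Default comparison; import paths are assumed to match the modules in this commit.
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.sql.athena import AthenaProfilingConfig

assert GEProfilingConfig().partition_profiling_enabled is True
assert AthenaProfilingConfig().partition_profiling_enabled is False
# Athena users opt in per recipe, e.g. profiling: {enabled: true, partition_profiling_enabled: true}
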
class CustomAthenaRestDialect(AthenaRestDialect):
"""Custom definition of the Athena dialect.
@ -171,13 +182,17 @@ class CustomAthenaRestDialect(AthenaRestDialect):
# To extract all of them, we simply iterate over all detected fields and
# convert them to SQLalchemy types
struct_args = []
for field in struct_type["fields"]:
for struct_field in struct_type["fields"]:
struct_args.append(
(
field["name"],
self._get_column_type(field["type"]["type"])
if field["type"]["type"] not in ["record", "array"]
else self._get_column_type(field["type"]),
struct_field["name"],
(
self._get_column_type(
struct_field["type"]["native_data_type"]
)
if struct_field["type"]["type"] not in ["record", "array"]
else self._get_column_type(struct_field["type"])
),
)
)
@ -189,7 +204,7 @@ class CustomAthenaRestDialect(AthenaRestDialect):
detected_col_type = MapType
# the type definition for maps looks like the following: key_type:val_type (e.g., string:string)
key_type_raw, value_type_raw = type_meta_information.split(",")
key_type_raw, value_type_raw = type_meta_information.split(",", 1)
# convert both type names to actual SQLalchemy types
args = [
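
The switch to split(",", 1) above matters for nested maps, where the value type itself contains commas; only the first comma separates key from value. A tiny check, assuming the raw metadata for a nested map looks like "string,map<string,string>":

# Why maxsplit=1 is needed: the value type of a nested map contains commas itself.
type_meta_information = "string,map<string,string>"
key_type_raw, value_type_raw = type_meta_information.split(",", 1)
assert key_type_raw == "string"
assert value_type_raw == "map<string,string>"
# A plain split(",") would yield three parts here and fail to unpack into two names.
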
@ -257,6 +272,8 @@ class AthenaConfig(SQLCommonConfig):
print_warning=True,
)
profiling: AthenaProfilingConfig = AthenaProfilingConfig()
def get_sql_alchemy_url(self):
return make_sqlalchemy_uri(
self.scheme,
@ -275,6 +292,12 @@ class AthenaConfig(SQLCommonConfig):
)
@dataclass
class Partitionitem:
partitions: List[str] = field(default_factory=list)
max_partition: Optional[Dict[str, str]] = None
@platform_name("Athena")
@support_status(SupportStatus.CERTIFIED)
@config_class(AthenaConfig)
@ -294,6 +317,8 @@ class AthenaSource(SQLAlchemySource):
- Profiling when enabled.
"""
table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
def __init__(self, config, ctx):
super().__init__(config, ctx, "athena")
self.cursor: Optional[BaseCursor] = None
@ -429,12 +454,53 @@ class AthenaSource(SQLAlchemySource):
return [schema for schema in schemas if schema == athena_config.database]
return schemas
# Override to get partitions
def get_partitions(
self, inspector: Inspector, schema: str, table: str
) -> List[str]:
partitions = []
if not self.cursor:
return []
metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
table_name=table, schema_name=schema
)
if metadata.partition_keys:
for key in metadata.partition_keys:
if key.name:
partitions.append(key.name)
if not partitions:
return []
# We create an artificial concatenated partition key to make querying the max partition easier
part_concat = "|| '-' ||".join(partitions)
max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
ret = self.cursor.execute(max_partition_query)
max_partition: Dict[str, str] = {}
if ret:
max_partitions = list(ret)
for idx, row in enumerate([row[0] for row in ret.description]):
max_partition[row] = max_partitions[0][idx]
if self.table_partition_cache.get(schema) is None:
self.table_partition_cache[schema] = {}
self.table_partition_cache[schema][table] = Partitionitem(
partitions=partitions,
max_partition=max_partition,
)
return partitions
return []
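
get_partitions reads the partition keys from the Athena table metadata and then asks the hidden "$partitions" metatable for the row with the largest concatenated key. For a hypothetical table partitioned by year and month, the lookup it builds looks like this (mirroring the string-building above; table names are made up):

# Illustrative rendering of max_partition_query for a hypothetical "datahub_db"."orders"
# table partitioned by (year, month).
schema, table = "datahub_db", "orders"
partitions = ["year", "month"]
part_concat = "|| '-' ||".join(partitions)
max_partition_query = (
    f'select {",".join(partitions)} from "{schema}"."{table}$partitions" '
    f"where {part_concat} = (select max({part_concat}) from \"{schema}\".\"{table}$partitions\")"
)
print(max_partition_query)
# select year,month from "datahub_db"."orders$partitions" where year|| '-' ||month =
#   (select max(year|| '-' ||month) from "datahub_db"."orders$partitions")
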
# Override to modify the creation of schema fields
def get_schema_fields_for_column(
self,
dataset_name: str,
column: Dict,
pk_constraints: Optional[dict] = None,
partition_keys: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
fields = get_schema_fields_for_sqlalchemy_column(
@ -442,17 +508,45 @@ class AthenaSource(SQLAlchemySource):
column_type=column["type"],
description=column.get("comment", None),
nullable=column.get("nullable", True),
is_part_of_key=True
if (
pk_constraints is not None
and isinstance(pk_constraints, dict)
and column["name"] in pk_constraints.get("constrained_columns", [])
)
else False,
is_part_of_key=(
True
if (
pk_constraints is not None
and isinstance(pk_constraints, dict)
and column["name"] in pk_constraints.get("constrained_columns", [])
)
else False
),
is_partitioning_key=(
True
if (partition_keys is not None and column["name"] in partition_keys)
else False
),
)
return fields
def generate_partition_profiler_query(
self, schema: str, table: str, partition_datetime: Optional[datetime.datetime]
) -> Tuple[Optional[str], Optional[str]]:
if not self.config.profiling.partition_profiling_enabled:
return None, None
partition: Optional[Partitionitem] = self.table_partition_cache.get(
schema, {}
).get(table, None)
if partition and partition.max_partition:
max_partition_filters = []
for key, value in partition.max_partition.items():
max_partition_filters.append(f"CAST({key} as VARCHAR) = '{value}'")
max_partition = str(partition.max_partition)
return (
max_partition,
f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(max_partition_filters)}',
)
return None, None
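
When partition profiling is enabled and a max partition was cached by get_partitions, the profiler restricts the batch to that partition with CAST-based equality filters. For a hypothetical cached max partition of {"year": "2024", "month": "07"}, the generated query would be:

# Illustrative: the filter query built by generate_partition_profiler_query
# for a made-up max partition of "datahub_db"."orders".
schema, table = "datahub_db", "orders"
max_partition = {"year": "2024", "month": "07"}
filters = [f"CAST({key} as VARCHAR) = '{value}'" for key, value in max_partition.items()]
query = f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(filters)}'
print(query)
# SELECT * FROM "datahub_db"."orders" WHERE CAST(year as VARCHAR) = '2024' AND CAST(month as VARCHAR) = '07'
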
def close(self):
if self.cursor:
self.cursor.close()

View File

@ -170,6 +170,7 @@ class HiveSource(TwoTierSQLAlchemySource):
dataset_name: str,
column: Dict[Any, Any],
pk_constraints: Optional[Dict[Any, Any]] = None,
partition_keys: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
fields = super().get_schema_fields_for_column(

View File

@ -878,6 +878,7 @@ class HiveMetastoreSource(SQLAlchemySource):
dataset_name: str,
column: Dict[Any, Any],
pk_constraints: Optional[Dict[Any, Any]] = None,
partition_keys: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
return get_schema_fields_for_hive_column(

View File

@ -713,7 +713,16 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
data_reader,
)
except Exception as e:
self.warn(logger, f"{schema}.{table}", f"Ingestion error: {e}")
self.warn(
logger,
f"{schema}.{table}",
f"Ingestion error: {e}",
)
logger.debug(
f"Error processing table {schema}.{table}: Error was: {e} Traceback:",
exc_info=e,
)
except Exception as e:
self.error(logger, f"{schema}", f"Tables error: {e}")
@ -725,6 +734,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
) -> Optional[Dict[str, List[str]]]:
return None
def get_partitions(
self, inspector: Inspector, schema: str, table: str
) -> Optional[List[str]]:
return None
def _process_table(
self,
dataset_name: str,
@ -769,9 +783,14 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
extra_tags = self.get_extra_tags(inspector, schema, table)
pk_constraints: dict = inspector.get_pk_constraint(table, schema)
partitions: Optional[List[str]] = self.get_partitions(inspector, schema, table)
foreign_keys = self._get_foreign_keys(dataset_urn, inspector, schema, table)
schema_fields = self.get_schema_fields(
dataset_name, columns, pk_constraints, tags=extra_tags
dataset_name,
columns,
pk_constraints,
tags=extra_tags,
partition_keys=partitions,
)
schema_metadata = get_schema_metadata(
self.report,
@ -921,6 +940,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
if len(columns) == 0:
self.warn(logger, "missing column information", dataset_name)
except Exception as e:
logger.error(traceback.format_exc())
self.warn(
logger,
dataset_name,
@ -949,6 +969,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
dataset_name: str,
columns: List[dict],
pk_constraints: Optional[dict] = None,
partition_keys: Optional[List[str]] = None,
tags: Optional[Dict[str, List[str]]] = None,
) -> List[SchemaField]:
canonical_schema = []
@ -957,7 +978,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
if tags:
column_tags = tags.get(column["name"], [])
fields = self.get_schema_fields_for_column(
dataset_name, column, pk_constraints, tags=column_tags
dataset_name,
column,
pk_constraints,
tags=column_tags,
partition_keys=partition_keys,
)
canonical_schema.extend(fields)
return canonical_schema
@ -967,6 +992,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
dataset_name: str,
column: dict,
pk_constraints: Optional[dict] = None,
partition_keys: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
gtc: Optional[GlobalTagsClass] = None
@ -989,6 +1015,10 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
and column["name"] in pk_constraints.get("constrained_columns", [])
):
field.isPartOfKey = True
if partition_keys is not None and column["name"] in partition_keys:
field.isPartitioningKey = True
return [field]
def loop_views(
@ -1017,7 +1047,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
sql_config=sql_config,
)
except Exception as e:
self.warn(logger, f"{schema}.{view}", f"Ingestion error: {e}")
self.warn(
logger,
f"{schema}.{view}",
f"Ingestion error: {e} {traceback.format_exc()}",
)
except Exception as e:
self.error(logger, f"{schema}", f"Views error: {e}")

View File

@ -388,6 +388,7 @@ class TrinoSource(SQLAlchemySource):
dataset_name: str,
column: dict,
pk_constraints: Optional[dict] = None,
partition_keys: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
fields = super().get_schema_fields_for_column(

View File

@ -1,5 +1,6 @@
import json
import logging
import traceback
import uuid
from typing import Any, Dict, List, Optional, Type, Union
@ -46,7 +47,6 @@ class SqlAlchemyColumnToAvroConverter:
cls, column_type: Union[types.TypeEngine, STRUCT, MapType], nullable: bool
) -> Dict[str, Any]:
"""Determines the concrete AVRO schema type for a SQLalchemy-typed column"""
if isinstance(
column_type, tuple(cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys())
):
@ -80,21 +80,38 @@ class SqlAlchemyColumnToAvroConverter:
}
if isinstance(column_type, types.ARRAY):
array_type = column_type.item_type
return {
"type": "array",
"items": cls.get_avro_type(column_type=array_type, nullable=nullable),
"native_data_type": f"array<{str(column_type.item_type)}>",
}
if isinstance(column_type, MapType):
key_type = column_type.types[0]
value_type = column_type.types[1]
return {
"type": "map",
"values": cls.get_avro_type(column_type=value_type, nullable=nullable),
"native_data_type": str(column_type),
"key_type": cls.get_avro_type(column_type=key_type, nullable=nullable),
"key_native_data_type": str(key_type),
}
try:
key_type = column_type.types[0]
value_type = column_type.types[1]
return {
"type": "map",
"values": cls.get_avro_type(
column_type=value_type, nullable=nullable
),
"native_data_type": str(column_type),
"key_type": cls.get_avro_type(
column_type=key_type, nullable=nullable
),
"key_native_data_type": str(key_type),
}
except Exception as e:
logger.warning(
f"Unable to parse MapType {column_type} the error was: {e}"
)
return {
"type": "map",
"values": {"type": "null", "_nullable": True},
"native_data_type": str(column_type),
"key_type": {"type": "null", "_nullable": True},
"key_native_data_type": "null",
}
if STRUCT and isinstance(column_type, STRUCT):
fields = []
for field_def in column_type._STRUCT_fields:
@ -108,14 +125,23 @@ class SqlAlchemyColumnToAvroConverter:
}
)
struct_name = f"__struct_{str(uuid.uuid4()).replace('-', '')}"
return {
"type": "record",
"name": struct_name,
"fields": fields,
"native_data_type": str(column_type),
"_nullable": nullable,
}
try:
return {
"type": "record",
"name": struct_name,
"fields": fields,
"native_data_type": str(column_type),
"_nullable": nullable,
}
except Exception:
# This is a workaround for the case when the struct name is not string-convertible, because SQLAlchemy throws an error
return {
"type": "record",
"name": struct_name,
"fields": fields,
"native_data_type": "map",
"_nullable": nullable,
}
return {
"type": "null",
@ -153,6 +179,7 @@ def get_schema_fields_for_sqlalchemy_column(
description: Optional[str] = None,
nullable: Optional[bool] = True,
is_part_of_key: Optional[bool] = False,
is_partitioning_key: Optional[bool] = False,
) -> List[SchemaField]:
"""Creates SchemaFields from a given SQLalchemy column.
@ -181,7 +208,7 @@ def get_schema_fields_for_sqlalchemy_column(
)
except Exception as e:
logger.warning(
f"Unable to parse column {column_name} and type {column_type} the error was: {e}"
f"Unable to parse column {column_name} and type {column_type} the error was: {e} Traceback: {traceback.format_exc()}"
)
# fallback description in case any exception occurred
@ -208,4 +235,8 @@ def get_schema_fields_for_sqlalchemy_column(
is_part_of_key if is_part_of_key is not None else False
)
schema_fields[0].isPartitioningKey = (
is_partitioning_key if is_partitioning_key is not None else False
)
return schema_fields
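
The new is_partitioning_key argument flows through to SchemaField.isPartitioningKey. A minimal sketch of calling the converter directly, assuming the helper lives in datahub.utilities.sqlalchemy_type_converter and that column_name is the keyword used at the Athena call site:

from sqlalchemy import types

# Assumed import path for the helper shown above.
from datahub.utilities.sqlalchemy_type_converter import (
    get_schema_fields_for_sqlalchemy_column,
)

fields = get_schema_fields_for_sqlalchemy_column(
    column_name="year",          # hypothetical partition column
    column_type=types.String(),
    description="partition key",
    nullable=True,
    is_part_of_key=False,
    is_partitioning_key=True,
)
assert fields[0].isPartitioningKey is True
assert fields[0].isPartOfKey is False
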