feat(ingest/athena): Add option for Athena partitioned profiling (#10723)

Tamas Nemeth 2024-07-20 00:00:40 +02:00 committed by GitHub
parent 44cdb046d7
commit 20574cf1c6
9 changed files with 323 additions and 44 deletions

View File

@ -330,7 +330,14 @@ plugins: Dict[str, Set[str]] = {
    # sqlalchemy-bigquery is included here since it provides an implementation of
    # a SQLalchemy-conform STRUCT type definition
    "athena": sql_common
-   | {"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0", "sqlalchemy-bigquery>=1.4.1"},
+   # We need to set tenacity lower than 8.4.0 as
+   # this version has missing dependency asyncio
+   # https://github.com/jd/tenacity/issues/471
+   | {
+       "PyAthena[SQLAlchemy]>=2.6.0,<3.0.0",
+       "sqlalchemy-bigquery>=1.4.1",
+       "tenacity!=8.4.0",
+   },
    "azure-ad": set(),
    "bigquery": sql_common
    | bigquery_common

View File

@ -13,6 +13,7 @@ import unittest.mock
import uuid
from functools import lru_cache
from typing import (
+    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
@ -39,6 +40,7 @@ from great_expectations.data_context.types.base import (
from great_expectations.dataset.dataset import Dataset
from great_expectations.dataset.sqlalchemy_dataset import SqlAlchemyDataset
from great_expectations.datasource.sqlalchemy_datasource import SqlAlchemyDatasource
+from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect
from great_expectations.profile.base import ProfilerDataType
from great_expectations.profile.basic_dataset_profiler import BasicDatasetProfilerBase
from sqlalchemy.engine import Connection, Engine
@ -72,9 +74,14 @@ from datahub.utilities.sqlalchemy_query_combiner import (
    get_query_columns,
)

+if TYPE_CHECKING:
+    from pyathena.cursor import Cursor
+
assert MARKUPSAFE_PATCHED
logger: logging.Logger = logging.getLogger(__name__)
+
+_original_get_column_median = SqlAlchemyDataset.get_column_median
+
P = ParamSpec("P")
POSTGRESQL = "postgresql"
MYSQL = "mysql"
@ -203,6 +210,47 @@ def _get_column_quantiles_bigquery_patch(  # type:ignore
    return list()


+def _get_column_quantiles_awsathena_patch(  # type:ignore
+    self, column: str, quantiles: Iterable
+) -> list:
+    import ast
+
+    table_name = ".".join(
+        [f'"{table_part}"' for table_part in str(self._table).split(".")]
+    )
+
+    quantiles_list = list(quantiles)
+    quantiles_query = (
+        f"SELECT approx_percentile({column}, ARRAY{str(quantiles_list)}) as quantiles "
+        f"from (SELECT {column} from {table_name})"
+    )
+
+    try:
+        quantiles_results = self.engine.execute(quantiles_query).fetchone()[0]
+        quantiles_results_list = ast.literal_eval(quantiles_results)
+        return quantiles_results_list
+    except ProgrammingError as pe:
+        self._treat_quantiles_exception(pe)
+        return []
+
+
+def _get_column_median_patch(self, column):
+    # AWS Athena and Presto have a special function that can be used to retrieve the median
+    if (
+        self.sql_engine_dialect.name.lower() == GXSqlDialect.AWSATHENA
+        or self.sql_engine_dialect.name.lower() == GXSqlDialect.TRINO
+    ):
+        table_name = ".".join(
+            [f'"{table_part}"' for table_part in str(self._table).split(".")]
+        )
+        element_values = self.engine.execute(
+            f"SELECT approx_percentile({column}, 0.5) FROM {table_name}"
+        )
+        return convert_to_json_serializable(element_values.fetchone()[0])
+    else:
+        return _original_get_column_median(self, column)
+
+
def _is_single_row_query_method(query: Any) -> bool:
    SINGLE_ROW_QUERY_FILES = {
        # "great_expectations/dataset/dataset.py",
@ -1038,6 +1086,12 @@ class DatahubGEProfiler:
        ), unittest.mock.patch(
            "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_bigquery",
            _get_column_quantiles_bigquery_patch,
+        ), unittest.mock.patch(
+            "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_awsathena",
+            _get_column_quantiles_awsathena_patch,
+        ), unittest.mock.patch(
+            "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_median",
+            _get_column_median_patch,
        ), concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers
        ) as async_executor, SQLAlchemyQueryCombiner(
@ -1114,15 +1168,16 @@ class DatahubGEProfiler:
            **request.batch_kwargs,
        )

-    def _drop_trino_temp_table(self, temp_dataset: Dataset) -> None:
+    def _drop_temp_table(self, temp_dataset: Dataset) -> None:
        schema = temp_dataset._table.schema
        table = temp_dataset._table.name
+        table_name = f'"{schema}"."{table}"' if schema else f'"{table}"'
        try:
            with self.base_engine.connect() as connection:
-                connection.execute(f"drop view if exists {schema}.{table}")
-                logger.debug(f"View {schema}.{table} was dropped.")
+                connection.execute(f"drop view if exists {table_name}")
+                logger.debug(f"View {table_name} was dropped.")
        except Exception:
-            logger.warning(f"Unable to delete trino temporary table: {schema}.{table}")
+            logger.warning(f"Unable to delete temporary table: {table_name}")
    def _generate_single_profile(
        self,
@ -1149,6 +1204,19 @@ class DatahubGEProfiler:
        }

        bigquery_temp_table: Optional[str] = None
+        temp_view: Optional[str] = None
+
+        if platform and platform.upper() == "ATHENA" and (custom_sql):
+            if custom_sql is not None:
+                # Note that limit and offset are not supported for custom SQL.
+                temp_view = create_athena_temp_table(
+                    self, custom_sql, pretty_name, self.base_engine.raw_connection()
+                )
+                ge_config["table"] = temp_view
+                ge_config["schema"] = None
+                ge_config["limit"] = None
+                ge_config["offset"] = None
+                custom_sql = None
+
        if platform == BIGQUERY and (
            custom_sql or self.config.limit or self.config.offset
        ):
@ -1234,8 +1302,16 @@ class DatahubGEProfiler:
            )
            return None
        finally:
-            if batch is not None and self.base_engine.engine.name == TRINO:
-                self._drop_trino_temp_table(batch)
+            if batch is not None and self.base_engine.engine.name.upper() in [
+                "TRINO",
+                "AWSATHENA",
+            ]:
+                if (
+                    self.base_engine.engine.name.upper() == "TRINO"
+                    or temp_view is not None
+                ):
+                    self._drop_temp_table(batch)
+                # If we are not on Trino, we only drop the table if the temp view variable was set.

    def _get_ge_dataset(
        self,
@ -1299,6 +1375,40 @@ def _get_column_types_to_ignore(dialect_name: str) -> List[str]:
    return []


+def create_athena_temp_table(
+    instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
+    sql: str,
+    table_pretty_name: str,
+    raw_connection: Any,
+) -> Optional[str]:
+    try:
+        cursor: "Cursor" = cast("Cursor", raw_connection.cursor())
+        logger.debug(f"Creating view for {table_pretty_name}: {sql}")
+        temp_view = f"ge_{uuid.uuid4()}"
+        if "." in table_pretty_name:
+            schema_part = table_pretty_name.split(".")[-1]
+            schema_part_quoted = ".".join(
+                [f'"{part}"' for part in str(schema_part).split(".")]
+            )
+            temp_view = f"{schema_part_quoted}_{temp_view}"
+
+        temp_view = f"ge_{uuid.uuid4()}"
+
+        cursor.execute(f'create or replace view "{temp_view}" as {sql}')
+    except Exception as e:
+        if not instance.config.catch_exceptions:
+            raise e
+        logger.exception(f"Encountered exception while profiling {table_pretty_name}")
+        instance.report.report_warning(
+            table_pretty_name,
+            f"Profiling exception {e} when running custom sql {sql}",
+        )
+        return None
+    finally:
+        raw_connection.close()
+
+    return temp_view
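A small sketch of the view-creation statement this helper runs for custom-SQL profiles; the custom SQL and names below are hypothetical.

```python
# Sketch only: the DDL shape used by create_athena_temp_table() above.
import uuid

custom_sql = "SELECT id, amount FROM sales WHERE ds = '2024-07-19'"  # hypothetical custom SQL
temp_view = f"ge_{uuid.uuid4()}"
statement = f'create or replace view "{temp_view}" as {custom_sql}'
print(statement)
```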
def create_bigquery_temp_table(
    instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
    bq_sql: str,

View File

@ -147,7 +147,7 @@ class GEProfilingConfig(ConfigModel):
    partition_profiling_enabled: bool = Field(
        default=True,
-        description="Whether to profile partitioned tables. Only BigQuery supports this. "
+        description="Whether to profile partitioned tables. Only BigQuery and AWS Athena support this. "
        "If enabled, latest partition data is used for profiling.",
    )
    partition_datetime: Optional[datetime.datetime] = Field(

View File

@ -1,7 +1,9 @@
+import datetime
import json
import logging
import re
import typing
+from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast

import pydantic
@ -27,6 +29,7 @@ from datahub.ingestion.api.decorators import (
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
+from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.sql.sql_common import (
    SQLAlchemySource,
    register_custom_type,
@ -52,6 +55,14 @@ register_custom_type(STRUCT, RecordTypeClass)
register_custom_type(MapType, MapTypeClass)


+class AthenaProfilingConfig(GEProfilingConfig):
+    # Overriding default value for partition_profiling
+    partition_profiling_enabled: bool = pydantic.Field(
+        default=False,
+        description="Enable partition profiling. This will profile the latest partition of the table.",
+    )
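A minimal usage sketch of the new config class; note that it flips the base default, so Athena partition profiling is opt-in. The import path below is assumed from this module's location.

```python
# Sketch only: the import path is assumed; the field names come from the diff above.
from datahub.ingestion.source.sql.athena import AthenaProfilingConfig

profiling = AthenaProfilingConfig(enabled=True, partition_profiling_enabled=True)
print(profiling.partition_profiling_enabled)  # True (the Athena-specific default is False)
```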
class CustomAthenaRestDialect(AthenaRestDialect):
    """Custom definition of the Athena dialect.
@ -171,13 +182,17 @@ class CustomAthenaRestDialect(AthenaRestDialect):
            # To extract all of them, we simply iterate over all detected fields and
            # convert them to SQLalchemy types
            struct_args = []
-            for field in struct_type["fields"]:
+            for struct_field in struct_type["fields"]:
                struct_args.append(
                    (
-                        field["name"],
-                        self._get_column_type(field["type"]["type"])
-                        if field["type"]["type"] not in ["record", "array"]
-                        else self._get_column_type(field["type"]),
+                        struct_field["name"],
+                        (
+                            self._get_column_type(
+                                struct_field["type"]["native_data_type"]
+                            )
+                            if struct_field["type"]["type"] not in ["record", "array"]
+                            else self._get_column_type(struct_field["type"])
+                        ),
                    )
                )
@ -189,7 +204,7 @@ class CustomAthenaRestDialect(AthenaRestDialect):
            detected_col_type = MapType

            # the type definition for maps looks like the following: key_type:val_type (e.g., string:string)
-            key_type_raw, value_type_raw = type_meta_information.split(",")
+            key_type_raw, value_type_raw = type_meta_information.split(",", 1)

            # convert both type names to actual SQLalchemy types
            args = [
@ -257,6 +272,8 @@ class AthenaConfig(SQLCommonConfig):
        print_warning=True,
    )

+    profiling: AthenaProfilingConfig = AthenaProfilingConfig()
+
    def get_sql_alchemy_url(self):
        return make_sqlalchemy_uri(
            self.scheme,
@ -275,6 +292,12 @@ class AthenaConfig(SQLCommonConfig):
        )


+@dataclass
+class Partitionitem:
+    partitions: List[str] = field(default_factory=list)
+    max_partition: Optional[Dict[str, str]] = None
+
+
@platform_name("Athena")
@support_status(SupportStatus.CERTIFIED)
@config_class(AthenaConfig)
@ -294,6 +317,8 @@ class AthenaSource(SQLAlchemySource):
    - Profiling when enabled.
    """

+    table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
+
    def __init__(self, config, ctx):
        super().__init__(config, ctx, "athena")
        self.cursor: Optional[BaseCursor] = None
@ -429,12 +454,53 @@ class AthenaSource(SQLAlchemySource):
            return [schema for schema in schemas if schema == athena_config.database]
        return schemas

+    # Overwrite to get partitions
+    def get_partitions(
+        self, inspector: Inspector, schema: str, table: str
+    ) -> List[str]:
+        partitions = []
+
+        if not self.cursor:
+            return []
+
+        metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
+            table_name=table, schema_name=schema
+        )
+
+        if metadata.partition_keys:
+            for key in metadata.partition_keys:
+                if key.name:
+                    partitions.append(key.name)
+
+            if not partitions:
+                return []
+
+            # We create an artificial concatenated partition key to make it easier to query the max partition
+            part_concat = "|| '-' ||".join(partitions)
+            max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
+            ret = self.cursor.execute(max_partition_query)
+            max_partition: Dict[str, str] = {}
+            if ret:
+                max_partitions = list(ret)
+                for idx, row in enumerate([row[0] for row in ret.description]):
+                    max_partition[row] = max_partitions[0][idx]
+
+            if self.table_partition_cache.get(schema) is None:
+                self.table_partition_cache[schema] = {}
+            self.table_partition_cache[schema][table] = Partitionitem(
+                partitions=partitions,
+                max_partition=max_partition,
+            )
+
+            return partitions
+        return []
+
    # Overwrite to modify the creation of schema fields
    def get_schema_fields_for_column(
        self,
        dataset_name: str,
        column: Dict,
        pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
    ) -> List[SchemaField]:
        fields = get_schema_fields_for_sqlalchemy_column(
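To make the "$partitions" lookup concrete, this sketch reproduces the query string that get_partitions() builds; the schema, table, and partition columns are hypothetical.

```python
# Sketch only: mirrors the max-partition query built in get_partitions() above.
partitions = ["year", "month", "day"]       # hypothetical partition columns
schema, table = "analytics", "web_events"   # hypothetical identifiers

part_concat = "|| '-' ||".join(partitions)
max_partition_query = (
    f'select {",".join(partitions)} from "{schema}"."{table}$partitions" '
    f'where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
)
print(max_partition_query)
```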
@ -442,17 +508,45 @@ class AthenaSource(SQLAlchemySource):
            column_type=column["type"],
            description=column.get("comment", None),
            nullable=column.get("nullable", True),
-            is_part_of_key=True
-            if (
-                pk_constraints is not None
-                and isinstance(pk_constraints, dict)
-                and column["name"] in pk_constraints.get("constrained_columns", [])
-            )
-            else False,
+            is_part_of_key=(
+                True
+                if (
+                    pk_constraints is not None
+                    and isinstance(pk_constraints, dict)
+                    and column["name"] in pk_constraints.get("constrained_columns", [])
+                )
+                else False
+            ),
+            is_partitioning_key=(
+                True
+                if (partition_keys is not None and column["name"] in partition_keys)
+                else False
+            ),
        )

        return fields
+    def generate_partition_profiler_query(
+        self, schema: str, table: str, partition_datetime: Optional[datetime.datetime]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        if not self.config.profiling.partition_profiling_enabled:
+            return None, None
+
+        partition: Optional[Partitionitem] = self.table_partition_cache.get(
+            schema, {}
+        ).get(table, None)
+
+        if partition and partition.max_partition:
+            max_partition_filters = []
+            for key, value in partition.max_partition.items():
+                max_partition_filters.append(f"CAST({key} as VARCHAR) = '{value}'")
+            max_partition = str(partition.max_partition)
+            return (
+                max_partition,
+                f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(max_partition_filters)}',
+            )
+        return None, None
+
    def close(self):
        if self.cursor:
            self.cursor.close()
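For reference, a sketch of the filter that generate_partition_profiler_query() emits once a max partition has been cached; the values are made up.

```python
# Sketch only: reproduces the WHERE-clause construction shown above.
schema, table = "analytics", "web_events"                     # hypothetical
max_partition = {"year": "2024", "month": "07", "day": "19"}  # hypothetical cached values

filters = [f"CAST({key} as VARCHAR) = '{value}'" for key, value in max_partition.items()]
profiler_query = f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(filters)}'
print(profiler_query)
# SELECT * FROM "analytics"."web_events" WHERE CAST(year as VARCHAR) = '2024' AND CAST(month as VARCHAR) = '07' AND CAST(day as VARCHAR) = '19'
```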

View File

@ -170,6 +170,7 @@ class HiveSource(TwoTierSQLAlchemySource):
        dataset_name: str,
        column: Dict[Any, Any],
        pk_constraints: Optional[Dict[Any, Any]] = None,
+        partition_keys: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
    ) -> List[SchemaField]:
        fields = super().get_schema_fields_for_column(

View File

@ -878,6 +878,7 @@ class HiveMetastoreSource(SQLAlchemySource):
        dataset_name: str,
        column: Dict[Any, Any],
        pk_constraints: Optional[Dict[Any, Any]] = None,
+        partition_keys: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
    ) -> List[SchemaField]:
        return get_schema_fields_for_hive_column(

View File

@ -713,7 +713,16 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
                        data_reader,
                    )
                except Exception as e:
-                    self.warn(logger, f"{schema}.{table}", f"Ingestion error: {e}")
+                    self.warn(
+                        logger,
+                        f"{schema}.{table}",
+                        f"Ingestion error: {e}",
+                    )
+                    logger.debug(
+                        f"Error processing table {schema}.{table}: Error was: {e} Traceback:",
+                        exc_info=e,
+                    )
        except Exception as e:
            self.error(logger, f"{schema}", f"Tables error: {e}")

@ -725,6 +734,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
    ) -> Optional[Dict[str, List[str]]]:
        return None

+    def get_partitions(
+        self, inspector: Inspector, schema: str, table: str
+    ) -> Optional[List[str]]:
+        return None
+
    def _process_table(
        self,
        dataset_name: str,
@ -769,9 +783,14 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
        extra_tags = self.get_extra_tags(inspector, schema, table)
        pk_constraints: dict = inspector.get_pk_constraint(table, schema)
+        partitions: Optional[List[str]] = self.get_partitions(inspector, schema, table)
        foreign_keys = self._get_foreign_keys(dataset_urn, inspector, schema, table)
        schema_fields = self.get_schema_fields(
-            dataset_name, columns, pk_constraints, tags=extra_tags
+            dataset_name,
+            columns,
+            pk_constraints,
+            tags=extra_tags,
+            partition_keys=partitions,
        )
        schema_metadata = get_schema_metadata(
            self.report,
@ -921,6 +940,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
            if len(columns) == 0:
                self.warn(logger, "missing column information", dataset_name)
        except Exception as e:
+            logger.error(traceback.format_exc())
            self.warn(
                logger,
                dataset_name,
@ -949,6 +969,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
        dataset_name: str,
        columns: List[dict],
        pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
        tags: Optional[Dict[str, List[str]]] = None,
    ) -> List[SchemaField]:
        canonical_schema = []
@ -957,7 +978,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
            if tags:
                column_tags = tags.get(column["name"], [])
            fields = self.get_schema_fields_for_column(
-                dataset_name, column, pk_constraints, tags=column_tags
+                dataset_name,
+                column,
+                pk_constraints,
+                tags=column_tags,
+                partition_keys=partition_keys,
            )
            canonical_schema.extend(fields)
        return canonical_schema
@ -967,6 +992,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
        dataset_name: str,
        column: dict,
        pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
    ) -> List[SchemaField]:
        gtc: Optional[GlobalTagsClass] = None
@ -989,6 +1015,10 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
            and column["name"] in pk_constraints.get("constrained_columns", [])
        ):
            field.isPartOfKey = True
+
+        if partition_keys is not None and column["name"] in partition_keys:
+            field.isPartitioningKey = True
+
        return [field]

    def loop_views(
@ -1017,7 +1047,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
                    sql_config=sql_config,
                )
            except Exception as e:
-                self.warn(logger, f"{schema}.{view}", f"Ingestion error: {e}")
+                self.warn(
+                    logger,
+                    f"{schema}.{view}",
+                    f"Ingestion error: {e} {traceback.format_exc()}",
+                )
        except Exception as e:
            self.error(logger, f"{schema}", f"Views error: {e}")

View File

@ -388,6 +388,7 @@ class TrinoSource(SQLAlchemySource):
        dataset_name: str,
        column: dict,
        pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
    ) -> List[SchemaField]:
        fields = super().get_schema_fields_for_column(
View File

@ -1,5 +1,6 @@
import json
import logging
+import traceback
import uuid
from typing import Any, Dict, List, Optional, Type, Union

@ -46,7 +47,6 @@ class SqlAlchemyColumnToAvroConverter:
        cls, column_type: Union[types.TypeEngine, STRUCT, MapType], nullable: bool
    ) -> Dict[str, Any]:
        """Determines the concrete AVRO schema type for a SQLalchemy-typed column"""
        if isinstance(
            column_type, tuple(cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys())
        ):
@ -80,21 +80,38 @@ class SqlAlchemyColumnToAvroConverter:
            }
        if isinstance(column_type, types.ARRAY):
            array_type = column_type.item_type
            return {
                "type": "array",
                "items": cls.get_avro_type(column_type=array_type, nullable=nullable),
                "native_data_type": f"array<{str(column_type.item_type)}>",
            }
        if isinstance(column_type, MapType):
-            key_type = column_type.types[0]
-            value_type = column_type.types[1]
-            return {
-                "type": "map",
-                "values": cls.get_avro_type(column_type=value_type, nullable=nullable),
-                "native_data_type": str(column_type),
-                "key_type": cls.get_avro_type(column_type=key_type, nullable=nullable),
-                "key_native_data_type": str(key_type),
-            }
+            try:
+                key_type = column_type.types[0]
+                value_type = column_type.types[1]
+                return {
+                    "type": "map",
+                    "values": cls.get_avro_type(
+                        column_type=value_type, nullable=nullable
+                    ),
+                    "native_data_type": str(column_type),
+                    "key_type": cls.get_avro_type(
+                        column_type=key_type, nullable=nullable
+                    ),
+                    "key_native_data_type": str(key_type),
+                }
+            except Exception as e:
+                logger.warning(
+                    f"Unable to parse MapType {column_type} the error was: {e}"
+                )
+                return {
+                    "type": "map",
+                    "values": {"type": "null", "_nullable": True},
+                    "native_data_type": str(column_type),
+                    "key_type": {"type": "null", "_nullable": True},
+                    "key_native_data_type": "null",
+                }
        if STRUCT and isinstance(column_type, STRUCT):
            fields = []
            for field_def in column_type._STRUCT_fields:
@ -108,14 +125,23 @@ class SqlAlchemyColumnToAvroConverter:
                    }
                )
            struct_name = f"__struct_{str(uuid.uuid4()).replace('-', '')}"
+            try:
                return {
                    "type": "record",
                    "name": struct_name,
                    "fields": fields,
                    "native_data_type": str(column_type),
                    "_nullable": nullable,
                }
+            except Exception:
+                # This is a workaround for the case when the struct name is not string-convertible
+                # because SQLAlchemy throws an error
+                return {
+                    "type": "record",
+                    "name": struct_name,
+                    "fields": fields,
+                    "native_data_type": "map",
+                    "_nullable": nullable,
+                }

        return {
            "type": "null",
@ -153,6 +179,7 @@ def get_schema_fields_for_sqlalchemy_column(
    description: Optional[str] = None,
    nullable: Optional[bool] = True,
    is_part_of_key: Optional[bool] = False,
+    is_partitioning_key: Optional[bool] = False,
) -> List[SchemaField]:
    """Creates SchemaFields from a given SQLalchemy column.

@ -181,7 +208,7 @@ def get_schema_fields_for_sqlalchemy_column(
        )
    except Exception as e:
        logger.warning(
-            f"Unable to parse column {column_name} and type {column_type} the error was: {e}"
+            f"Unable to parse column {column_name} and type {column_type} the error was: {e} Traceback: {traceback.format_exc()}"
        )

        # fallback description in case any exception occurred
@ -208,4 +235,8 @@ def get_schema_fields_for_sqlalchemy_column(
        is_part_of_key if is_part_of_key is not None else False
    )

+    schema_fields[0].isPartitioningKey = (
+        is_partitioning_key if is_partitioning_key is not None else False
+    )
+
    return schema_fields
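End to end, the new flag surfaces as isPartitioningKey on the emitted schema fields. A small sketch, assuming the generated classes in datahub.metadata.schema_classes; the field values are illustrative.

```python
# Sketch only: shows where the partitioning flag ends up on a SchemaField.
from datahub.metadata.schema_classes import (
    SchemaFieldClass,
    SchemaFieldDataTypeClass,
    StringTypeClass,
)

schema_field = SchemaFieldClass(
    fieldPath="year",                                      # hypothetical partition column
    type=SchemaFieldDataTypeClass(type=StringTypeClass()),
    nativeDataType="string",
    isPartOfKey=False,
    isPartitioningKey=True,
)
print(schema_field.isPartitioningKey)  # True
```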