feat(ingest/athena): Add option for Athena partitioned profiling (#10723)
This commit is contained in:
parent 44cdb046d7
commit 20574cf1c6
@@ -330,7 +330,14 @@ plugins: Dict[str, Set[str]] = {
     # sqlalchemy-bigquery is included here since it provides an implementation of
     # a SQLalchemy-conform STRUCT type definition
     "athena": sql_common
-    | {"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0", "sqlalchemy-bigquery>=1.4.1"},
+    # We need to exclude tenacity 8.4.0, as
+    # that release has a missing asyncio dependency:
+    # https://github.com/jd/tenacity/issues/471
+    | {
+        "PyAthena[SQLAlchemy]>=2.6.0,<3.0.0",
+        "sqlalchemy-bigquery>=1.4.1",
+        "tenacity!=8.4.0",
+    },
     "azure-ad": set(),
     "bigquery": sql_common
     | bigquery_common
@@ -13,6 +13,7 @@ import unittest.mock
 import uuid
 from functools import lru_cache
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -39,6 +40,7 @@ from great_expectations.data_context.types.base import (
 from great_expectations.dataset.dataset import Dataset
 from great_expectations.dataset.sqlalchemy_dataset import SqlAlchemyDataset
 from great_expectations.datasource.sqlalchemy_datasource import SqlAlchemyDatasource
+from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect
 from great_expectations.profile.base import ProfilerDataType
 from great_expectations.profile.basic_dataset_profiler import BasicDatasetProfilerBase
 from sqlalchemy.engine import Connection, Engine
@@ -72,9 +74,14 @@ from datahub.utilities.sqlalchemy_query_combiner import (
     get_query_columns,
 )
 
+if TYPE_CHECKING:
+    from pyathena.cursor import Cursor
+
 assert MARKUPSAFE_PATCHED
 logger: logging.Logger = logging.getLogger(__name__)
 
+_original_get_column_median = SqlAlchemyDataset.get_column_median
+
 P = ParamSpec("P")
 POSTGRESQL = "postgresql"
 MYSQL = "mysql"
@@ -203,6 +210,47 @@ def _get_column_quantiles_bigquery_patch( # type:ignore
     return list()
 
 
+def _get_column_quantiles_awsathena_patch(  # type:ignore
+    self, column: str, quantiles: Iterable
+) -> list:
+    import ast
+
+    table_name = ".".join(
+        [f'"{table_part}"' for table_part in str(self._table).split(".")]
+    )
+
+    quantiles_list = list(quantiles)
+    quantiles_query = (
+        f"SELECT approx_percentile({column}, ARRAY{str(quantiles_list)}) as quantiles "
+        f"from (SELECT {column} from {table_name})"
+    )
+    try:
+        quantiles_results = self.engine.execute(quantiles_query).fetchone()[0]
+        quantiles_results_list = ast.literal_eval(quantiles_results)
+        return quantiles_results_list
+
+    except ProgrammingError as pe:
+        self._treat_quantiles_exception(pe)
+        return []
+
+
+def _get_column_median_patch(self, column):
+    # AWS Athena and Presto have a special function that can be used to retrieve the median
+    if (
+        self.sql_engine_dialect.name.lower() == GXSqlDialect.AWSATHENA
+        or self.sql_engine_dialect.name.lower() == GXSqlDialect.TRINO
+    ):
+        table_name = ".".join(
+            [f'"{table_part}"' for table_part in str(self._table).split(".")]
+        )
+        element_values = self.engine.execute(
+            f"SELECT approx_percentile({column}, 0.5) FROM {table_name}"
+        )
+        return convert_to_json_serializable(element_values.fetchone()[0])
+    else:
+        return _original_get_column_median(self, column)
+
+
 def _is_single_row_query_method(query: Any) -> bool:
     SINGLE_ROW_QUERY_FILES = {
         # "great_expectations/dataset/dataset.py",
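Athena returns the approx_percentile array as a string, which is why the patch above parses the fetched value with ast.literal_eval before handing it back to Great Expectations. A minimal standalone sketch of that parsing step (the sample result string is made up for illustration):

import ast

# Hypothetical value as the quantiles query might return it for
# approx_percentile(col, ARRAY[0.25, 0.5, 0.75]); the literal is illustrative only.
quantiles_results = "[10.0, 42.5, 97.25]"

# ast.literal_eval safely converts the textual array into a Python list of floats.
quantiles_results_list = ast.literal_eval(quantiles_results)
assert quantiles_results_list == [10.0, 42.5, 97.25]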
@@ -1038,6 +1086,12 @@ class DatahubGEProfiler:
         ), unittest.mock.patch(
             "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_bigquery",
             _get_column_quantiles_bigquery_patch,
+        ), unittest.mock.patch(
+            "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_awsathena",
+            _get_column_quantiles_awsathena_patch,
+        ), unittest.mock.patch(
+            "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_median",
+            _get_column_median_patch,
         ), concurrent.futures.ThreadPoolExecutor(
             max_workers=max_workers
         ) as async_executor, SQLAlchemyQueryCombiner(
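For readers unfamiliar with the pattern: unittest.mock.patch swaps the named attribute only while the with block is active and restores it afterwards, which is how the profiler temporarily injects the Athena-aware implementations registered above. A minimal standalone sketch using patch.object (the class and values are made up; it mirrors the dotted-string patch(...) calls in the hunk):

import unittest.mock


class Dataset:
    def get_column_median(self, column: str) -> float:
        return 0.0  # stand-in for the generic implementation


def _get_column_median_patch(self, column: str) -> float:
    return 42.0  # stand-in for a dialect-specific implementation


with unittest.mock.patch.object(Dataset, "get_column_median", _get_column_median_patch):
    print(Dataset().get_column_median("price"))  # 42.0 while the patch is active

print(Dataset().get_column_median("price"))  # 0.0 again once the patch is undone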
@@ -1114,15 +1168,16 @@ class DatahubGEProfiler:
             **request.batch_kwargs,
         )
 
-    def _drop_trino_temp_table(self, temp_dataset: Dataset) -> None:
+    def _drop_temp_table(self, temp_dataset: Dataset) -> None:
         schema = temp_dataset._table.schema
         table = temp_dataset._table.name
+        table_name = f'"{schema}"."{table}"' if schema else f'"{table}"'
         try:
             with self.base_engine.connect() as connection:
-                connection.execute(f"drop view if exists {schema}.{table}")
-                logger.debug(f"View {schema}.{table} was dropped.")
+                connection.execute(f"drop view if exists {table_name}")
+                logger.debug(f"View {table_name} was dropped.")
         except Exception:
-            logger.warning(f"Unable to delete trino temporary table: {schema}.{table}")
+            logger.warning(f"Unable to delete temporary table: {table_name}")
 
     def _generate_single_profile(
         self,
@@ -1149,6 +1204,19 @@ class DatahubGEProfiler:
         }
 
         bigquery_temp_table: Optional[str] = None
+        temp_view: Optional[str] = None
+        if platform and platform.upper() == "ATHENA" and (custom_sql):
+            if custom_sql is not None:
+                # Note that limit and offset are not supported for custom SQL.
+                temp_view = create_athena_temp_table(
+                    self, custom_sql, pretty_name, self.base_engine.raw_connection()
+                )
+                ge_config["table"] = temp_view
+                ge_config["schema"] = None
+                ge_config["limit"] = None
+                ge_config["offset"] = None
+                custom_sql = None
+
         if platform == BIGQUERY and (
             custom_sql or self.config.limit or self.config.offset
         ):
@@ -1234,8 +1302,16 @@ class DatahubGEProfiler:
                 )
                 return None
         finally:
-            if batch is not None and self.base_engine.engine.name == TRINO:
-                self._drop_trino_temp_table(batch)
+            if batch is not None and self.base_engine.engine.name.upper() in [
+                "TRINO",
+                "AWSATHENA",
+            ]:
+                if (
+                    self.base_engine.engine.name.upper() == "TRINO"
+                    or temp_view is not None
+                ):
+                    self._drop_temp_table(batch)
+                # if we are not on Trino, we only drop the table if the temp view variable was set
 
     def _get_ge_dataset(
         self,
@@ -1299,6 +1375,40 @@ def _get_column_types_to_ignore(dialect_name: str) -> List[str]:
     return []
 
 
+def create_athena_temp_table(
+    instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
+    sql: str,
+    table_pretty_name: str,
+    raw_connection: Any,
+) -> Optional[str]:
+    try:
+        cursor: "Cursor" = cast("Cursor", raw_connection.cursor())
+        logger.debug(f"Creating view for {table_pretty_name}: {sql}")
+        temp_view = f"ge_{uuid.uuid4()}"
+        if "." in table_pretty_name:
+            schema_part = table_pretty_name.split(".")[-1]
+            schema_part_quoted = ".".join(
+                [f'"{part}"' for part in str(schema_part).split(".")]
+            )
+            temp_view = f"{schema_part_quoted}_{temp_view}"
+
+        temp_view = f"ge_{uuid.uuid4()}"
+        cursor.execute(f'create or replace view "{temp_view}" as {sql}')
+    except Exception as e:
+        if not instance.config.catch_exceptions:
+            raise e
+        logger.exception(f"Encountered exception while profiling {table_pretty_name}")
+        instance.report.report_warning(
+            table_pretty_name,
+            f"Profiling exception {e} when running custom sql {sql}",
+        )
+        return None
+    finally:
+        raw_connection.close()
+
+    return temp_view
+
+
 def create_bigquery_temp_table(
     instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
     bq_sql: str,
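To make the effect of create_athena_temp_table concrete, here is a small sketch of the DDL it sends through the PyAthena cursor, rendered for a made-up custom SQL statement (table and column names are illustrative only):

import uuid

# Illustrative inputs; in the profiler these come from the profiling request.
custom_sql = "SELECT id, price FROM sales WHERE year = '2024' AND month = '06'"
temp_view = f"ge_{uuid.uuid4()}"

# Mirrors the statement executed above via cursor.execute(...).
ddl = f'create or replace view "{temp_view}" as {custom_sql}'
print(ddl)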
@@ -147,7 +147,7 @@ class GEProfilingConfig(ConfigModel):
 
     partition_profiling_enabled: bool = Field(
         default=True,
-        description="Whether to profile partitioned tables. Only BigQuery supports this. "
+        description="Whether to profile partitioned tables. Only BigQuery and AWS Athena support this. "
         "If enabled, latest partition data is used for profiling.",
     )
     partition_datetime: Optional[datetime.datetime] = Field(
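Because the Athena-specific config introduced below flips this default to False, partition profiling has to be switched on explicitly for Athena. A minimal sketch of the relevant profiling section written as a Python dict (connection settings such as region or result location are intentionally omitted; only keys that appear in this commit or in the standard profiling config are used):

athena_profiling_section = {
    "profiling": {
        "enabled": True,
        # Overrides AthenaProfilingConfig's default of False so that only the
        # latest partition of each table gets profiled.
        "partition_profiling_enabled": True,
    },
}
print(athena_profiling_section)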
@@ -1,7 +1,9 @@
+import datetime
 import json
 import logging
 import re
 import typing
+from dataclasses import dataclass, field
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
 
 import pydantic
@@ -27,6 +29,7 @@ from datahub.ingestion.api.decorators import (
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.aws.s3_util import make_s3_urn
 from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
+from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
 from datahub.ingestion.source.sql.sql_common import (
     SQLAlchemySource,
     register_custom_type,
@@ -52,6 +55,14 @@ register_custom_type(STRUCT, RecordTypeClass)
 register_custom_type(MapType, MapTypeClass)
 
 
+class AthenaProfilingConfig(GEProfilingConfig):
+    # Overriding default value for partition_profiling
+    partition_profiling_enabled: bool = pydantic.Field(
+        default=False,
+        description="Enable partition profiling. This will profile the latest partition of the table.",
+    )
+
+
 class CustomAthenaRestDialect(AthenaRestDialect):
     """Custom definition of the Athena dialect.
 
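AthenaProfilingConfig above relies on a plain pydantic feature: a subclass may redeclare a field to narrow its default. A minimal standalone sketch of the same pattern (the class names here are made up):

import pydantic


class BaseProfilingConfig(pydantic.BaseModel):
    partition_profiling_enabled: bool = pydantic.Field(default=True)


class AthenaLikeProfilingConfig(BaseProfilingConfig):
    # Same field, different default: Athena users must opt in explicitly.
    partition_profiling_enabled: bool = pydantic.Field(default=False)


print(BaseProfilingConfig().partition_profiling_enabled)        # True
print(AthenaLikeProfilingConfig().partition_profiling_enabled)  # False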
@@ -171,13 +182,17 @@ class CustomAthenaRestDialect(AthenaRestDialect):
             # To extract all of them, we simply iterate over all detected fields and
             # convert them to SQLalchemy types
             struct_args = []
-            for field in struct_type["fields"]:
+            for struct_field in struct_type["fields"]:
                 struct_args.append(
                     (
-                        field["name"],
-                        self._get_column_type(field["type"]["type"])
-                        if field["type"]["type"] not in ["record", "array"]
-                        else self._get_column_type(field["type"]),
+                        struct_field["name"],
+                        (
+                            self._get_column_type(
+                                struct_field["type"]["native_data_type"]
+                            )
+                            if struct_field["type"]["type"] not in ["record", "array"]
+                            else self._get_column_type(struct_field["type"])
+                        ),
                     )
                 )
 
@@ -189,7 +204,7 @@ class CustomAthenaRestDialect(AthenaRestDialect):
             detected_col_type = MapType
 
             # the type definition for maps looks like the following: key_type:val_type (e.g., string:string)
-            key_type_raw, value_type_raw = type_meta_information.split(",")
+            key_type_raw, value_type_raw = type_meta_information.split(",", 1)
 
             # convert both type names to actual SQLalchemy types
             args = [
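The maxsplit argument added above matters when the map's value type itself contains commas. A tiny sketch with a hypothetical type string:

type_meta_information = "string,map<string,int>"

# Splitting on every comma breaks the nested value type apart ...
print(type_meta_information.split(","))       # ['string', 'map<string', 'int>']

# ... while maxsplit=1 keeps everything after the first comma together.
key_type_raw, value_type_raw = type_meta_information.split(",", 1)
print(key_type_raw, value_type_raw)           # string map<string,int>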
@@ -257,6 +272,8 @@ class AthenaConfig(SQLCommonConfig):
         print_warning=True,
     )
 
+    profiling: AthenaProfilingConfig = AthenaProfilingConfig()
+
     def get_sql_alchemy_url(self):
         return make_sqlalchemy_uri(
             self.scheme,
@@ -275,6 +292,12 @@
         )
 
 
+@dataclass
+class Partitionitem:
+    partitions: List[str] = field(default_factory=list)
+    max_partition: Optional[Dict[str, str]] = None
+
+
 @platform_name("Athena")
 @support_status(SupportStatus.CERTIFIED)
 @config_class(AthenaConfig)
@@ -294,6 +317,8 @@ class AthenaSource(SQLAlchemySource):
     - Profiling when enabled.
     """
 
+    table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
+
     def __init__(self, config, ctx):
         super().__init__(config, ctx, "athena")
         self.cursor: Optional[BaseCursor] = None
@@ -429,12 +454,53 @@
             return [schema for schema in schemas if schema == athena_config.database]
         return schemas
 
+    # Overwrite to get partitions
+    def get_partitions(
+        self, inspector: Inspector, schema: str, table: str
+    ) -> List[str]:
+        partitions = []
+
+        if not self.cursor:
+            return []
+
+        metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
+            table_name=table, schema_name=schema
+        )
+
+        if metadata.partition_keys:
+            for key in metadata.partition_keys:
+                if key.name:
+                    partitions.append(key.name)
+
+            if not partitions:
+                return []
+
+            # We create an artificial concatenated partition key to make querying the max partition easier
+            part_concat = "|| '-' ||".join(partitions)
+            max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
+            ret = self.cursor.execute(max_partition_query)
+            max_partition: Dict[str, str] = {}
+            if ret:
+                max_partitons = list(ret)
+                for idx, row in enumerate([row[0] for row in ret.description]):
+                    max_partition[row] = max_partitons[0][idx]
+            if self.table_partition_cache.get(schema) is None:
+                self.table_partition_cache[schema] = {}
+            self.table_partition_cache[schema][table] = Partitionitem(
+                partitions=partitions,
+                max_partition=max_partition,
+            )
+            return partitions
+
+        return []
+
     # Overwrite to modify the creation of schema fields
     def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: Dict,
         pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = get_schema_fields_for_sqlalchemy_column(
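To illustrate the query that get_partitions issues against Athena's hidden "$partitions" metadata table, here is a standalone rendering for a hypothetical table sampledb.events partitioned by year and month (the names are made up; the string is built exactly as in the method above):

schema, table = "sampledb", "events"
partitions = ["year", "month"]

part_concat = "|| '-' ||".join(partitions)
max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
print(max_partition_query)
# select year,month from "sampledb"."events$partitions" where year|| '-' ||month = (select max(year|| '-' ||month) from "sampledb"."events$partitions")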
@@ -442,17 +508,45 @@
             column_type=column["type"],
             description=column.get("comment", None),
             nullable=column.get("nullable", True),
-            is_part_of_key=True
-            if (
-                pk_constraints is not None
-                and isinstance(pk_constraints, dict)
-                and column["name"] in pk_constraints.get("constrained_columns", [])
-            )
-            else False,
+            is_part_of_key=(
+                True
+                if (
+                    pk_constraints is not None
+                    and isinstance(pk_constraints, dict)
+                    and column["name"] in pk_constraints.get("constrained_columns", [])
+                )
+                else False
+            ),
+            is_partitioning_key=(
+                True
+                if (partition_keys is not None and column["name"] in partition_keys)
+                else False
+            ),
         )
 
         return fields
 
+    def generate_partition_profiler_query(
+        self, schema: str, table: str, partition_datetime: Optional[datetime.datetime]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        if not self.config.profiling.partition_profiling_enabled:
+            return None, None
+
+        partition: Optional[Partitionitem] = self.table_partition_cache.get(
+            schema, {}
+        ).get(table, None)
+
+        if partition and partition.max_partition:
+            max_partition_filters = []
+            for key, value in partition.max_partition.items():
+                max_partition_filters.append(f"CAST({key} as VARCHAR) = '{value}'")
+            max_partition = str(partition.max_partition)
+            return (
+                max_partition,
+                f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(max_partition_filters)}',
+            )
+        return None, None
+
     def close(self):
         if self.cursor:
             self.cursor.close()
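generate_partition_profiler_query turns the cached max partition into a plain SELECT with one CAST filter per partition column. A standalone sketch of the query it produces for a hypothetical cached entry (the values are made up):

schema, table = "sampledb", "events"
max_partition = {"year": "2024", "month": "06"}

max_partition_filters = [
    f"CAST({key} as VARCHAR) = '{value}'" for key, value in max_partition.items()
]
profile_query = f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(max_partition_filters)}'
print(profile_query)
# SELECT * FROM "sampledb"."events" WHERE CAST(year as VARCHAR) = '2024' AND CAST(month as VARCHAR) = '06'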
@@ -170,6 +170,7 @@ class HiveSource(TwoTierSQLAlchemySource):
         dataset_name: str,
         column: Dict[Any, Any],
         pk_constraints: Optional[Dict[Any, Any]] = None,
+        partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = super().get_schema_fields_for_column(
@@ -878,6 +878,7 @@ class HiveMetastoreSource(SQLAlchemySource):
         dataset_name: str,
         column: Dict[Any, Any],
         pk_constraints: Optional[Dict[Any, Any]] = None,
+        partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         return get_schema_fields_for_hive_column(
@@ -713,7 +713,16 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
                         data_reader,
                     )
                 except Exception as e:
-                    self.warn(logger, f"{schema}.{table}", f"Ingestion error: {e}")
+                    self.warn(
+                        logger,
+                        f"{schema}.{table}",
+                        f"Ingestion error: {e}",
+                    )
+                    logger.debug(
+                        f"Error processing table {schema}.{table}: Error was: {e} Traceback:",
+                        exc_info=e,
+                    )
 
         except Exception as e:
             self.error(logger, f"{schema}", f"Tables error: {e}")
@@ -725,6 +734,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
     ) -> Optional[Dict[str, List[str]]]:
         return None
 
+    def get_partitions(
+        self, inspector: Inspector, schema: str, table: str
+    ) -> Optional[List[str]]:
+        return None
+
     def _process_table(
         self,
         dataset_name: str,
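The new get_partitions hook follows the usual template-method pattern: the base source answers None, and dialect-specific sources override it. A simplified standalone sketch (the class names and the returned partition keys are stand-ins for SQLAlchemySource/AthenaSource, not the real classes):

from typing import List, Optional


class BaseSource:
    def get_partitions(self, schema: str, table: str) -> Optional[List[str]]:
        return None  # default: the dialect exposes no partition metadata


class AthenaLikeSource(BaseSource):
    def get_partitions(self, schema: str, table: str) -> Optional[List[str]]:
        # The real implementation reads cursor.get_table_metadata(...).partition_keys.
        return ["year", "month"]


for source in (BaseSource(), AthenaLikeSource()):
    # _process_table forwards this result to get_schema_fields(partition_keys=...).
    print(type(source).__name__, source.get_partitions("sampledb", "events"))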
@@ -769,9 +783,14 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
 
         extra_tags = self.get_extra_tags(inspector, schema, table)
         pk_constraints: dict = inspector.get_pk_constraint(table, schema)
+        partitions: Optional[List[str]] = self.get_partitions(inspector, schema, table)
         foreign_keys = self._get_foreign_keys(dataset_urn, inspector, schema, table)
         schema_fields = self.get_schema_fields(
-            dataset_name, columns, pk_constraints, tags=extra_tags
+            dataset_name,
+            columns,
+            pk_constraints,
+            tags=extra_tags,
+            partition_keys=partitions,
         )
         schema_metadata = get_schema_metadata(
             self.report,
@@ -921,6 +940,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
             if len(columns) == 0:
                 self.warn(logger, "missing column information", dataset_name)
         except Exception as e:
+            logger.error(traceback.format_exc())
             self.warn(
                 logger,
                 dataset_name,
@@ -949,6 +969,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
         dataset_name: str,
         columns: List[dict],
         pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
         tags: Optional[Dict[str, List[str]]] = None,
     ) -> List[SchemaField]:
         canonical_schema = []
@@ -957,7 +978,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
             if tags:
                 column_tags = tags.get(column["name"], [])
             fields = self.get_schema_fields_for_column(
-                dataset_name, column, pk_constraints, tags=column_tags
+                dataset_name,
+                column,
+                pk_constraints,
+                tags=column_tags,
+                partition_keys=partition_keys,
             )
             canonical_schema.extend(fields)
         return canonical_schema
@@ -967,6 +992,7 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
         dataset_name: str,
         column: dict,
         pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         gtc: Optional[GlobalTagsClass] = None
@@ -989,6 +1015,10 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
             and column["name"] in pk_constraints.get("constrained_columns", [])
         ):
             field.isPartOfKey = True
+
+        if partition_keys is not None and column["name"] in partition_keys:
+            field.isPartitioningKey = True
+
         return [field]
 
     def loop_views(
@@ -1017,7 +1047,11 @@ class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource):
                         sql_config=sql_config,
                     )
                 except Exception as e:
-                    self.warn(logger, f"{schema}.{view}", f"Ingestion error: {e}")
+                    self.warn(
+                        logger,
+                        f"{schema}.{view}",
+                        f"Ingestion error: {e} {traceback.format_exc()}",
+                    )
         except Exception as e:
             self.error(logger, f"{schema}", f"Views error: {e}")
 
@@ -388,6 +388,7 @@ class TrinoSource(SQLAlchemySource):
         dataset_name: str,
         column: dict,
         pk_constraints: Optional[dict] = None,
+        partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = super().get_schema_fields_for_column(
@@ -1,5 +1,6 @@
 import json
 import logging
+import traceback
 import uuid
 from typing import Any, Dict, List, Optional, Type, Union
 
@@ -46,7 +47,6 @@ class SqlAlchemyColumnToAvroConverter:
         cls, column_type: Union[types.TypeEngine, STRUCT, MapType], nullable: bool
     ) -> Dict[str, Any]:
         """Determines the concrete AVRO schema type for a SQLalchemy-typed column"""
-
         if isinstance(
             column_type, tuple(cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys())
         ):
@@ -80,21 +80,38 @@ class SqlAlchemyColumnToAvroConverter:
             }
         if isinstance(column_type, types.ARRAY):
             array_type = column_type.item_type
 
             return {
                 "type": "array",
                 "items": cls.get_avro_type(column_type=array_type, nullable=nullable),
                 "native_data_type": f"array<{str(column_type.item_type)}>",
             }
         if isinstance(column_type, MapType):
-            key_type = column_type.types[0]
-            value_type = column_type.types[1]
-            return {
-                "type": "map",
-                "values": cls.get_avro_type(column_type=value_type, nullable=nullable),
-                "native_data_type": str(column_type),
-                "key_type": cls.get_avro_type(column_type=key_type, nullable=nullable),
-                "key_native_data_type": str(key_type),
-            }
+            try:
+                key_type = column_type.types[0]
+                value_type = column_type.types[1]
+                return {
+                    "type": "map",
+                    "values": cls.get_avro_type(
+                        column_type=value_type, nullable=nullable
+                    ),
+                    "native_data_type": str(column_type),
+                    "key_type": cls.get_avro_type(
+                        column_type=key_type, nullable=nullable
+                    ),
+                    "key_native_data_type": str(key_type),
+                }
+            except Exception as e:
+                logger.warning(
+                    f"Unable to parse MapType {column_type} the error was: {e}"
+                )
+                return {
+                    "type": "map",
+                    "values": {"type": "null", "_nullable": True},
+                    "native_data_type": str(column_type),
+                    "key_type": {"type": "null", "_nullable": True},
+                    "key_native_data_type": "null",
+                }
         if STRUCT and isinstance(column_type, STRUCT):
             fields = []
             for field_def in column_type._STRUCT_fields:
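When the key or value type of a MapType cannot be converted, get_avro_type now degrades to a null-typed map instead of raising. The dict below reproduces that fallback branch for a hypothetical column whose native type renders as map<string,int> (the native_data_type string is illustrative; the rest mirrors the except branch above):

fallback = {
    "type": "map",
    "values": {"type": "null", "_nullable": True},
    "native_data_type": "map<string,int>",  # illustrative str(column_type) value
    "key_type": {"type": "null", "_nullable": True},
    "key_native_data_type": "null",
}
print(fallback)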
@@ -108,14 +125,23 @@ class SqlAlchemyColumnToAvroConverter:
                     }
                 )
             struct_name = f"__struct_{str(uuid.uuid4()).replace('-', '')}"
 
-            return {
-                "type": "record",
-                "name": struct_name,
-                "fields": fields,
-                "native_data_type": str(column_type),
-                "_nullable": nullable,
-            }
+            try:
+                return {
+                    "type": "record",
+                    "name": struct_name,
+                    "fields": fields,
+                    "native_data_type": str(column_type),
+                    "_nullable": nullable,
+                }
+            except Exception:
+                # This is a workaround for the case when the struct name is not string convertible because SQLAlchemy throws an error
+                return {
+                    "type": "record",
+                    "name": struct_name,
+                    "fields": fields,
+                    "native_data_type": "map",
+                    "_nullable": nullable,
+                }
 
         return {
             "type": "null",
@@ -153,6 +179,7 @@ def get_schema_fields_for_sqlalchemy_column(
     description: Optional[str] = None,
     nullable: Optional[bool] = True,
     is_part_of_key: Optional[bool] = False,
+    is_partitioning_key: Optional[bool] = False,
 ) -> List[SchemaField]:
     """Creates SchemaFields from a given SQLalchemy column.
 
@@ -181,7 +208,7 @@
         )
     except Exception as e:
         logger.warning(
-            f"Unable to parse column {column_name} and type {column_type} the error was: {e}"
+            f"Unable to parse column {column_name} and type {column_type} the error was: {e} Traceback: {traceback.format_exc()}"
         )
 
         # fallback description in case any exception occurred
@@ -208,4 +235,8 @@
         is_part_of_key if is_part_of_key is not None else False
     )
 
+    schema_fields[0].isPartitioningKey = (
+        is_partitioning_key if is_partitioning_key is not None else False
+    )
+
     return schema_fields
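A possible call site for the extended helper, shown only as a sketch: the import path and the column_name parameter are assumed from context rather than confirmed by this diff, and the SQLAlchemy type is arbitrary.

from sqlalchemy import types

# Assumed module path; the keyword arguments mirror the call sites shown earlier in this commit.
from datahub.utilities.sqlalchemy_type_converter import (
    get_schema_fields_for_sqlalchemy_column,
)

fields = get_schema_fields_for_sqlalchemy_column(
    column_name="year",                 # assumed parameter name
    column_type=types.VARCHAR(),
    description="Partition column (example)",
    nullable=False,
    is_part_of_key=False,
    is_partitioning_key=True,           # new flag introduced in this commit
)
print(fields[0].isPartitioningKey)      # -> True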