feat(ingest/athena): Iceberg partition columns extraction (#13607)

Tamas Nemeth 2025-07-07 13:46:40 +01:00 committed by GitHub
parent 17aa2d72a5
commit 5b8d4bad7c
5 changed files with 1677 additions and 26 deletions

View File

@@ -34,6 +34,9 @@ from datahub.ingestion.source.common.subtypes import (
SourceCapabilityModifier,
)
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.sql.athena_properties_extractor import (
AthenaPropertiesExtractor,
)
from datahub.ingestion.source.sql.sql_common import (
SQLAlchemySource,
register_custom_type,
@@ -47,12 +50,17 @@ from datahub.ingestion.source.sql.sql_utils import (
)
from datahub.ingestion.source.sql.sqlalchemy_uri import make_sqlalchemy_uri
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
from datahub.metadata.schema_classes import (
ArrayTypeClass,
MapTypeClass,
RecordTypeClass,
)
from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
from datahub.utilities.sqlalchemy_type_converter import (
MapType,
get_schema_fields_for_sqlalchemy_column,
)
from datahub.utilities.urns.field_paths import get_simple_field_path_from_v2_field_path
try:
from typing_extensions import override
@@ -284,6 +292,11 @@ class AthenaConfig(SQLCommonConfig):
description="Extract partitions for tables. Partition extraction needs to run a query (`select * from table$partitions`) on the table. Disable this if you don't want to grant select permission.",
)
extract_partitions_using_create_statements: bool = pydantic.Field(
default=False,
description="Extract partitions using the `SHOW CREATE TABLE` statement instead of querying the table's partitions directly. This needs to be enabled to extract Iceberg partitions. If extraction fails it falls back to the default partition extraction. This is experimental.",
)
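    # For reference, a minimal recipe sketch enabling this flag (hedged: the
    # region, workgroup, and result bucket below are placeholders, not part of
    # this commit):
    #
    #   source:
    #     type: athena
    #     config:
    #       aws_region: us-east-1
    #       work_group: primary
    #       query_result_location: "s3://my-athena-results/"
    #       extract_partitions_using_create_statements: true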
_s3_staging_dir_population = pydantic_renamed_field(
old_name="s3_staging_dir",
new_name="query_result_location",
@@ -496,23 +509,38 @@ class AthenaSource(SQLAlchemySource):
def get_partitions(
self, inspector: Inspector, schema: str, table: str
) -> Optional[List[str]]:
        if (
            not self.config.extract_partitions
            and not self.config.extract_partitions_using_create_statements
        ):
return None
if not self.cursor:
return None
if self.config.extract_partitions_using_create_statements:
try:
partitions = self._get_partitions_create_table(schema, table)
except Exception as e:
logger.warning(
f"Failed to get partitions from create table statement for {schema}.{table} because of {e}. Falling back to SQLAlchemy.",
exc_info=True,
)
                # If we can't get the CREATE TABLE statement, fall back to SQLAlchemy
partitions = self._get_partitions_sqlalchemy(schema, table)
else:
partitions = self._get_partitions_sqlalchemy(schema, table)
if not partitions:
return []
if (
not self.config.profiling.enabled
or not self.config.profiling.partition_profiling_enabled
):
return partitions
with self.report.report_exc(
message="Failed to extract partition details",
context=f"{schema}.{table}",
@@ -538,6 +566,56 @@
return partitions
def _get_partitions_create_table(self, schema: str, table: str) -> List[str]:
assert self.cursor
try:
res = self.cursor.execute(f"SHOW CREATE TABLE `{schema}`.`{table}`")
        except Exception as e:
            # Athena does not support SHOW CREATE TABLE for views and raises an
            # error; we log and re-raise so the caller can fall back to
            # SQLAlchemy-based partition extraction.
            logger.debug(
                f"Failed to get table properties for {schema}.{table}: {e}",
                exc_info=True,
            )
            raise
rows = res.fetchall()
# Concatenate all rows into a single string with newlines
create_table_statement = "\n".join(row[0] for row in rows)
try:
athena_table_info = AthenaPropertiesExtractor.get_table_properties(
create_table_statement
)
        except Exception as e:
            logger.debug(
                f"Failed to parse table properties for {schema}.{table}: {e}. Statement: {create_table_statement}",
                exc_info=True,
            )
            raise
partitions = []
if (
athena_table_info.partition_info
and athena_table_info.partition_info.simple_columns
):
partitions = [
ci.name for ci in athena_table_info.partition_info.simple_columns
]
return partitions
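    # Sketch of the flow above with a hypothetical Iceberg table: the statement
    #   SHOW CREATE TABLE `db`.`events`
    # might return
    #   CREATE TABLE db.events (ts timestamp, id bigint)
    #   PARTITIONED BY (`day(ts)`, id)
    #   ...
    # AthenaPropertiesExtractor.get_table_properties() then reports
    # partition_info.simple_columns as [ts, id], so this returns ["ts", "id"].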
def _get_partitions_sqlalchemy(self, schema: str, table: str) -> List[str]:
assert self.cursor
metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
table_name=table, schema_name=schema
)
partitions = []
for key in metadata.partition_keys:
if key.name:
partitions.append(key.name)
return partitions
    # Override to modify the creation of schema fields
def get_schema_fields_for_column(
self,
@@ -563,7 +641,14 @@
partition_keys is not None and column["name"] in partition_keys
),
)
if isinstance(
fields[0].type.type, (RecordTypeClass, MapTypeClass, ArrayTypeClass)
):
return fields
else:
fields[0].fieldPath = get_simple_field_path_from_v2_field_path(
fields[0].fieldPath
)
return fields
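    # For reference, the simplification applied above (mirrored by the
    # golden-file updates later in this commit):
    #   get_simple_field_path_from_v2_field_path(
    #       "[version=2.0].[type=string].employee_id"
    #   ) == "employee_id"
    # while record/map/array fields keep their v2 paths intact.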
def generate_partition_profiler_query(

View File

@@ -0,0 +1,777 @@
"""
Athena Properties Extractor - A robust tool for parsing CREATE TABLE statements.
This module provides functionality to extract properties, partitioning information,
and row format details from Athena CREATE TABLE SQL statements.
"""
import json
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Union
from sqlglot import ParseError, parse_one
from sqlglot.dialects.athena import Athena
from sqlglot.expressions import (
Anonymous,
ColumnDef,
Create,
Day,
Expression,
FileFormatProperty,
Identifier,
LocationProperty,
Month,
PartitionByTruncate,
PartitionedByBucket,
PartitionedByProperty,
Property,
RowFormatDelimitedProperty,
Schema,
SchemaCommentProperty,
SerdeProperties,
Year,
)
class AthenaPropertiesExtractionError(Exception):
"""Custom exception for Athena properties extraction errors."""
pass
@dataclass
class ColumnInfo:
"""Information about a table column."""
name: str
type: str
@dataclass
class TransformInfo:
"""Information about a partition transform."""
type: str
column: ColumnInfo
bucket_count: Optional[int] = None
length: Optional[int] = None
@dataclass
class PartitionInfo:
"""Information about table partitioning."""
simple_columns: List[ColumnInfo]
transforms: List[TransformInfo]
@dataclass
class TableProperties:
"""General table properties."""
location: Optional[str] = None
format: Optional[str] = None
comment: Optional[str] = None
serde_properties: Optional[Dict[str, str]] = None
row_format: Optional[Dict[str, str]] = None
additional_properties: Optional[Dict[str, str]] = None
@dataclass
class RowFormatInfo:
"""Row format information."""
properties: Dict[str, str]
json_formatted: str
@dataclass
class AthenaTableInfo:
"""Complete information about an Athena table."""
partition_info: PartitionInfo
table_properties: TableProperties
row_format: RowFormatInfo
class AthenaPropertiesExtractor:
"""A class to extract properties from Athena CREATE TABLE statements."""
    CREATE_TABLE_REGEXP = re.compile(
        r"(CREATE TABLE[\s\n]*)(.*?)(\s*\()", re.MULTILINE | re.IGNORECASE
    )
    PARTITIONED_BY_REGEXP = re.compile(
        r"(PARTITIONED BY[\s\n]*\()((?:[^()]|\([^)]*\))*?)(\))",
        re.MULTILINE | re.IGNORECASE,
    )
def __init__(self) -> None:
"""Initialize the extractor."""
pass
@staticmethod
def get_table_properties(sql: str) -> AthenaTableInfo:
"""Get all table properties from a SQL statement.
Args:
sql: The SQL statement to parse
Returns:
An AthenaTableInfo object containing all table properties
Raises:
AthenaPropertiesExtractionError: If extraction fails
"""
extractor = AthenaPropertiesExtractor()
return extractor._extract_all_properties(sql)
def _extract_all_properties(self, sql: str) -> AthenaTableInfo:
"""Extract all properties from a SQL statement.
Args:
sql: The SQL statement to parse
Returns:
An AthenaTableInfo object containing all properties
Raises:
AthenaPropertiesExtractionError: If extraction fails
"""
if not sql or not sql.strip():
raise AthenaPropertiesExtractionError("SQL statement cannot be empty")
try:
            # We need to apply certain transformations to the create statement:
            # - table names are not quoted
            # - column expressions are not quoted
            # - the sql parser fails if partition columns are quoted
fixed_sql = self._fix_sql_partitioning(sql)
parsed = parse_one(fixed_sql, dialect=Athena)
except ParseError as e:
raise AthenaPropertiesExtractionError(f"Failed to parse SQL: {e}") from e
except Exception as e:
raise AthenaPropertiesExtractionError(
f"Unexpected error during SQL parsing: {e}"
) from e
try:
partition_info = self._extract_partition_info(parsed)
table_properties = self._extract_table_properties(parsed)
row_format = self._extract_row_format(parsed)
return AthenaTableInfo(
partition_info=partition_info,
table_properties=table_properties,
row_format=row_format,
)
except Exception as e:
raise AthenaPropertiesExtractionError(
f"Failed to extract table properties: {e}"
) from e
@staticmethod
    def format_column_definition(line: str) -> str:
        # Use regex to parse the line more accurately
        # Pattern: column_name data_type [COMMENT comment_text] [,]
        # The comment is matched lazily, up to an optional trailing comma
        pattern = r"^\s*(.+?)\s+([\s,\w<>\[\]]+)((\s+COMMENT\s+(.+?)(,?))|(,?)\s*)?$"
match = re.match(pattern, line, re.IGNORECASE)
if not match:
return line
column_name = match.group(1)
data_type = match.group(2)
comment_part = match.group(5) # COMMENT part
        # the trailing comma is captured by a different group depending on whether a comment exists
if comment_part:
trailing_comma = match.group(6) if match.group(6) else ""
else:
trailing_comma = match.group(7) if match.group(7) else ""
# Add backticks to column name if not already present
if not (column_name.startswith("`") and column_name.endswith("`")):
column_name = f"`{column_name}`"
# Build the result
result_parts = [column_name, data_type]
if comment_part:
comment_part = comment_part.strip()
# Handle comment quoting and escaping
if comment_part.startswith("'") and comment_part.endswith("'"):
# Already properly single quoted - keep as is
formatted_comment = comment_part
elif comment_part.startswith('"') and comment_part.endswith('"'):
# Double quoted - convert to single quotes and escape internal single quotes
inner_content = comment_part[1:-1]
escaped_content = inner_content.replace("'", "''")
formatted_comment = f"'{escaped_content}'"
else:
# Not quoted - add quotes and escape any single quotes
escaped_content = comment_part.replace("'", "''")
formatted_comment = f"'{escaped_content}'"
result_parts.extend(["COMMENT", formatted_comment])
result = " " + " ".join(result_parts) + trailing_comma
return result
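    # e.g. (whitespace aside) the line
    #   month string COMMENT "Month of the year",
    # becomes
    #   `month` string COMMENT 'Month of the year',
    # i.e. the column name gains backticks and the comment is normalized to
    # single quotes.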
@staticmethod
def format_athena_column_definitions(sql_statement: str) -> str:
"""
Format Athena CREATE TABLE statement by:
1. Adding backticks around column names in column definitions (only in the main table definition)
2. Quoting comments (if any exist)
"""
lines = sql_statement.split("\n")
formatted_lines = []
in_column_definition = False
for line in lines:
stripped_line = line.strip()
# Check if we're entering column definitions
if "CREATE TABLE" in line.upper() and "(" in line:
in_column_definition = True
formatted_lines.append(line)
continue
# Check if we're exiting column definitions (closing parenthesis before PARTITIONED BY or end)
if in_column_definition and ")" in line:
in_column_definition = False
formatted_lines.append(line)
continue
# Process only column definitions (not PARTITIONED BY or other sections)
if in_column_definition and stripped_line:
# Match column definition pattern and format it
formatted_line = AthenaPropertiesExtractor.format_column_definition(
line
)
formatted_lines.append(formatted_line)
else:
# For all other lines, keep as-is
formatted_lines.append(line)
return "\n".join(formatted_lines)
@staticmethod
def _fix_sql_partitioning(sql: str) -> str:
"""Fix SQL partitioning by removing backticks from partition expressions and quoting table names.
Args:
sql: The SQL statement to fix
Returns:
The fixed SQL statement
"""
if not sql:
return sql
# Quote table name
table_name_match = AthenaPropertiesExtractor.CREATE_TABLE_REGEXP.search(sql)
if table_name_match:
table_name = table_name_match.group(2).strip()
if table_name and not (table_name.startswith("`") or "`" in table_name):
# Split on dots and quote each part
quoted_parts = [
f"`{part.strip()}`"
for part in table_name.split(".")
if part.strip()
]
if quoted_parts:
quoted_table = ".".join(quoted_parts)
create_part = table_name_match.group(0).replace(
table_name, quoted_table
)
sql = sql.replace(table_name_match.group(0), create_part)
# Fix partition expressions
partition_match = AthenaPropertiesExtractor.PARTITIONED_BY_REGEXP.search(sql)
if partition_match:
partition_section = partition_match.group(2)
if partition_section:
partition_section_modified = partition_section.replace("`", "")
sql = sql.replace(partition_section, partition_section_modified)
return AthenaPropertiesExtractor.format_athena_column_definitions(sql)
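    # Example of the rewriting above (input shape taken from the tests below):
    #   CREATE TABLE db.tbl (...) PARTITIONED BY (`day(ts)`, `bucket(5, id)`)
    # becomes
    #   CREATE TABLE `db`.`tbl` (...) PARTITIONED BY (day(ts), bucket(5, id))
    # i.e. the table name gains backticks while partition expressions lose
    # them, which is the form sqlglot's Athena dialect parses reliably.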
@staticmethod
def _extract_column_types(create_expr: Create) -> Dict[str, str]:
"""Extract column types from a CREATE TABLE expression.
Args:
create_expr: The CREATE TABLE expression to extract types from
Returns:
A dictionary mapping column names to their types
"""
column_types: Dict[str, str] = {}
if not create_expr.this or not hasattr(create_expr.this, "expressions"):
return column_types
try:
for expr in create_expr.this.expressions:
if isinstance(expr, ColumnDef) and expr.this:
column_types[expr.name] = str(expr.kind)
except Exception:
# If we can't extract column types, return empty dict
pass
return column_types
@staticmethod
def _create_column_info(column_name: str, column_type: str) -> ColumnInfo:
"""Create a column info object.
Args:
column_name: Name of the column
column_type: Type of the column
Returns:
A ColumnInfo object
"""
return ColumnInfo(
name=str(column_name) if column_name else "unknown",
type=column_type if column_type else "unknown",
)
@staticmethod
def _handle_function_expression(
expr: Identifier, column_types: Dict[str, str]
) -> Tuple[ColumnInfo, TransformInfo]:
"""Handle function expressions like day(event_timestamp).
Args:
expr: The function expression to handle
column_types: Dictionary of column types
Returns:
A tuple of (column_info, transform_info)
"""
func_str = str(expr)
if "(" not in func_str or ")" not in func_str:
# Fallback for malformed function expressions
column_info = AthenaPropertiesExtractor._create_column_info(
func_str, "unknown"
)
transform_info = TransformInfo(type="unknown", column=column_info)
return column_info, transform_info
try:
func_name = func_str.split("(")[0].lower()
column_part = func_str.split("(")[1].split(")")[0].strip("`")
column_info = AthenaPropertiesExtractor._create_column_info(
column_part, column_types.get(column_part, "unknown")
)
transform_info = TransformInfo(type=func_name, column=column_info)
return column_info, transform_info
except (IndexError, AttributeError):
# Fallback for parsing errors
column_info = AthenaPropertiesExtractor._create_column_info(
func_str, "unknown"
)
transform_info = TransformInfo(type="unknown", column=column_info)
return column_info, transform_info
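    # e.g. the backtick-quoted identifier `day(event_timestamp)` seen in
    # SHOW CREATE TABLE output yields
    #   TransformInfo(type="day", column=ColumnInfo(name="event_timestamp", ...))
    # where the column type falls back to "unknown" if event_timestamp is not
    # in the table's column list (see the tests below).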
@staticmethod
def _handle_time_function(
expr: Union[Year, Month, Day], column_types: Dict[str, str]
) -> Tuple[ColumnInfo, TransformInfo]:
"""Handle time-based functions like year, month, day.
Args:
expr: The time function expression to handle
column_types: Dictionary of column types
Returns:
A tuple of (column_info, transform_info)
"""
try:
# Navigate the expression tree safely
column_name = "unknown"
if hasattr(expr, "this") and expr.this:
if hasattr(expr.this, "this") and expr.this.this:
if hasattr(expr.this.this, "this") and expr.this.this.this:
column_name = str(expr.this.this.this)
else:
column_name = str(expr.this.this)
else:
column_name = str(expr.this)
column_info = AthenaPropertiesExtractor._create_column_info(
column_name, column_types.get(column_name, "unknown")
)
transform_info = TransformInfo(
type=expr.__class__.__name__.lower(), column=column_info
)
return column_info, transform_info
except (AttributeError, TypeError):
# Fallback for navigation errors
column_info = AthenaPropertiesExtractor._create_column_info(
"unknown", "unknown"
)
transform_info = TransformInfo(type="unknown", column=column_info)
return column_info, transform_info
@staticmethod
def _handle_transform_function(
expr: Anonymous, column_types: Dict[str, str]
) -> Tuple[ColumnInfo, TransformInfo]:
"""Handle transform functions like bucket, hour, truncate.
Args:
expr: The transform function expression to handle
column_types: Dictionary of column types
Returns:
A tuple of (column_info, transform_info)
"""
try:
# Safely extract column name from the last expression
column_name = "unknown"
if (
hasattr(expr, "expressions")
and expr.expressions
and len(expr.expressions) > 0
):
last_expr = expr.expressions[-1]
if hasattr(last_expr, "this") and last_expr.this:
if hasattr(last_expr.this, "this") and last_expr.this.this:
column_name = str(last_expr.this.this)
else:
column_name = str(last_expr.this)
column_info = AthenaPropertiesExtractor._create_column_info(
column_name, column_types.get(column_name, "unknown")
)
transform_type = str(expr.this).lower() if expr.this else "unknown"
transform_info = TransformInfo(type=transform_type, column=column_info)
# Add transform-specific parameters safely
if (
transform_type == "bucket"
and hasattr(expr, "expressions")
and expr.expressions
and len(expr.expressions) > 0
):
first_expr = expr.expressions[0]
if hasattr(first_expr, "this"):
transform_info.bucket_count = first_expr.this
elif (
transform_type == "truncate"
and hasattr(expr, "expressions")
and expr.expressions
and len(expr.expressions) > 0
):
first_expr = expr.expressions[0]
if hasattr(first_expr, "this"):
transform_info.length = first_expr.this
return column_info, transform_info
except (AttributeError, TypeError, IndexError):
# Fallback for any parsing errors
column_info = AthenaPropertiesExtractor._create_column_info(
"unknown", "unknown"
)
transform_info = TransformInfo(type="unknown", column=column_info)
return column_info, transform_info
def _extract_partition_info(self, parsed: Expression) -> PartitionInfo:
"""Extract partitioning information from the parsed SQL statement.
Args:
parsed: The parsed SQL expression
Returns:
A PartitionInfo object containing simple columns and transforms
"""
# Get the PARTITIONED BY expression
partition_by_expr: Optional[Schema] = None
try:
for prop in parsed.find_all(Property):
if isinstance(prop, PartitionedByProperty):
partition_by_expr = prop.this
break
except Exception:
# If we can't find properties, return empty result
return PartitionInfo(simple_columns=[], transforms=[])
if not partition_by_expr:
return PartitionInfo(simple_columns=[], transforms=[])
# Extract partitioning columns and transforms
simple_columns: List[ColumnInfo] = []
transforms: List[TransformInfo] = []
# Get column types from the table definition
column_types: Dict[str, str] = {}
if isinstance(parsed, Create):
column_types = self._extract_column_types(parsed)
# Process each expression in the PARTITIONED BY clause
if hasattr(partition_by_expr, "expressions") and partition_by_expr.expressions:
for expr in partition_by_expr.expressions:
try:
if isinstance(expr, Identifier) and "(" in str(expr):
column_info, transform_info = self._handle_function_expression(
expr, column_types
)
simple_columns.append(column_info)
transforms.append(transform_info)
elif isinstance(expr, PartitionByTruncate):
column_info = AthenaPropertiesExtractor._create_column_info(
str(expr.this), column_types.get(str(expr.this), "unknown")
)
expression = expr.args.get("expression")
transform_info = TransformInfo(
type="truncate",
column=column_info,
length=int(expression.name)
if expression and expression.name
else None,
)
transforms.append(transform_info)
simple_columns.append(column_info)
elif isinstance(expr, PartitionedByBucket):
column_info = AthenaPropertiesExtractor._create_column_info(
str(expr.this), column_types.get(str(expr.this), "unknown")
)
expression = expr.args.get("expression")
transform_info = TransformInfo(
type="bucket",
column=column_info,
bucket_count=int(expression.name)
if expression and expression.name
else None,
)
simple_columns.append(column_info)
transforms.append(transform_info)
elif isinstance(expr, (Year, Month, Day)):
column_info, transform_info = self._handle_time_function(
expr, column_types
)
transforms.append(transform_info)
simple_columns.append(column_info)
elif (
isinstance(expr, Anonymous)
and expr.this
and str(expr.this).lower() in ["bucket", "hour", "truncate"]
):
column_info, transform_info = self._handle_transform_function(
expr, column_types
)
transforms.append(transform_info)
simple_columns.append(column_info)
elif hasattr(expr, "this") and expr.this:
column_name = str(expr.this)
column_info = self._create_column_info(
column_name, column_types.get(column_name, "unknown")
)
simple_columns.append(column_info)
except Exception:
# Skip problematic expressions rather than failing completely
continue
# Remove duplicates from simple_columns while preserving order
seen_names: Set[str] = set()
unique_simple_columns: List[ColumnInfo] = []
for col in simple_columns:
if col.name and col.name not in seen_names:
seen_names.add(col.name)
unique_simple_columns.append(col)
return PartitionInfo(
simple_columns=unique_simple_columns, transforms=transforms
)
def _extract_table_properties(self, parsed: Expression) -> TableProperties:
"""Extract table properties from the parsed SQL statement.
Args:
parsed: The parsed SQL expression
Returns:
A TableProperties object
"""
location: Optional[str] = None
format_prop: Optional[str] = None
comment: Optional[str] = None
serde_properties: Optional[Dict[str, str]] = None
row_format: Optional[Dict[str, str]] = None
additional_properties: Dict[str, str] = {}
try:
props = list(parsed.find_all(Property))
except Exception:
return TableProperties()
for prop in props:
try:
if isinstance(prop, LocationProperty):
location = self._safe_get_property_value(prop)
elif isinstance(prop, FileFormatProperty):
format_prop = self._safe_get_property_value(prop)
elif isinstance(prop, SchemaCommentProperty):
comment = self._safe_get_property_value(prop)
elif isinstance(prop, PartitionedByProperty):
continue # Skip partition properties here
elif isinstance(prop, SerdeProperties):
serde_props = self._extract_serde_properties(prop)
if serde_props:
serde_properties = serde_props
elif isinstance(prop, RowFormatDelimitedProperty):
row_format_props = self._extract_row_format_properties(prop)
if row_format_props:
row_format = row_format_props
else:
# Handle generic properties
key, value = self._extract_generic_property(prop)
if (
key
and value
and (not serde_properties or key not in serde_properties)
):
additional_properties[key] = value
except Exception:
# Skip problematic properties rather than failing completely
continue
if (
not location
and additional_properties
and additional_properties.get("external_location")
):
location = additional_properties.pop("external_location")
return TableProperties(
location=location,
format=format_prop,
comment=comment,
serde_properties=serde_properties,
row_format=row_format,
additional_properties=additional_properties
if additional_properties
else None,
)
def _safe_get_property_value(self, prop: Property) -> Optional[str]:
"""Safely extract value from a property."""
try:
if (
hasattr(prop, "args")
and "this" in prop.args
and prop.args["this"]
and hasattr(prop.args["this"], "name")
):
return prop.args["this"].name
except (AttributeError, KeyError, TypeError):
pass
return None
def _extract_serde_properties(self, prop: SerdeProperties) -> Dict[str, str]:
"""Extract SERDE properties safely."""
serde_props: Dict[str, str] = {}
try:
if hasattr(prop, "expressions") and prop.expressions:
for exp in prop.expressions:
if (
hasattr(exp, "name")
and hasattr(exp, "args")
and "value" in exp.args
and exp.args["value"]
and hasattr(exp.args["value"], "name")
):
serde_props[exp.name] = exp.args["value"].name
except Exception:
pass
return serde_props
def _extract_row_format_properties(
self, prop: RowFormatDelimitedProperty
) -> Dict[str, str]:
"""Extract row format properties safely."""
row_format: Dict[str, str] = {}
try:
if hasattr(prop, "args") and prop.args:
for key, value in prop.args.items():
if hasattr(value, "this"):
row_format[key] = str(value.this)
else:
row_format[key] = str(value)
except Exception:
pass
return row_format
def _extract_generic_property(
self, prop: Property
) -> Tuple[Optional[str], Optional[str]]:
"""Extract key-value pair from generic property."""
try:
if (
hasattr(prop, "args")
and "this" in prop.args
and prop.args["this"]
and hasattr(prop.args["this"], "name")
and "value" in prop.args
and prop.args["value"]
and hasattr(prop.args["value"], "name")
):
key = prop.args["this"].name.lower()
value = prop.args["value"].name
return key, value
except (AttributeError, KeyError, TypeError):
pass
return None, None
def _extract_row_format(self, parsed: Expression) -> RowFormatInfo:
"""Extract and format RowFormatDelimitedProperty.
Args:
parsed: The parsed SQL expression
Returns:
A RowFormatInfo object
"""
row_format_props: Dict[str, str] = {}
try:
props = parsed.find_all(Property)
for prop in props:
if isinstance(prop, RowFormatDelimitedProperty):
row_format_props = self._extract_row_format_properties(prop)
break
except Exception:
pass
if row_format_props:
try:
json_formatted = json.dumps(row_format_props, indent=2)
except (TypeError, ValueError):
json_formatted = "Error formatting row format properties"
else:
json_formatted = "No RowFormatDelimitedProperty found"
return RowFormatInfo(properties=row_format_props, json_formatted=json_formatted)
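
For orientation, a short usage sketch of the extractor above; the CREATE TABLE statement is adapted from the tests added later in this commit, and the expected values follow those tests:

from datahub.ingestion.source.sql.athena_properties_extractor import (
    AthenaPropertiesExtractor,
)

sql = """
CREATE TABLE iceberg_table (ts timestamp, id bigint, category string)
PARTITIONED BY (category, bucket(16, id), day(ts))
LOCATION 's3://amzn-s3-demo-bucket/your-folder/'
TBLPROPERTIES ('table_type' = 'ICEBERG')
"""

info = AthenaPropertiesExtractor.get_table_properties(sql)
print([c.name for c in info.partition_info.simple_columns])  # ['category', 'id', 'ts']
print([t.type for t in info.partition_info.transforms])      # ['bucket', 'day']
print(info.table_properties.location)  # s3://amzn-s3-demo-bucket/your-folder/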

View File

@@ -273,7 +273,7 @@
},
"fields": [
{
"fieldPath": "[version=2.0].[type=string].employee_id",
"fieldPath": "employee_id",
"nullable": false,
"description": "Unique identifier for the employee",
"type": {
@@ -287,7 +287,7 @@
"isPartitioningKey": false
},
{
"fieldPath": "[version=2.0].[type=long].annual_salary",
"fieldPath": "annual_salary",
"nullable": true,
"description": "Annual salary of the employee in USD",
"type": {
@@ -301,7 +301,7 @@
"isPartitioningKey": false
},
{
"fieldPath": "[version=2.0].[type=string].employee_name",
"fieldPath": "employee_name",
"nullable": false,
"description": "Full name of the employee",
"type": {
@@ -515,7 +515,7 @@
},
"fields": [
{
"fieldPath": "[version=2.0].[type=string].employee_id",
"fieldPath": "employee_id",
"nullable": false,
"description": "Unique identifier for the employee",
"type": {
@@ -529,7 +529,7 @@
"isPartitioningKey": false
},
{
"fieldPath": "[version=2.0].[type=long].annual_salary",
"fieldPath": "annual_salary",
"nullable": true,
"description": "Annual salary of the employee in USD",
"type": {
@@ -543,7 +543,7 @@
"isPartitioningKey": false
},
{
"fieldPath": "[version=2.0].[type=string].employee_name",
"fieldPath": "employee_name",
"nullable": false,
"description": "Full name of the employee",
"type": {
@@ -775,7 +775,7 @@
},
"fields": [
{
"fieldPath": "[version=2.0].[type=string].employee_id",
"fieldPath": "employee_id",
"nullable": false,
"description": "Unique identifier for the employee",
"type": {
@@ -789,7 +789,7 @@
"isPartitioningKey": false
},
{
"fieldPath": "[version=2.0].[type=long].annual_salary",
"fieldPath": "annual_salary",
"nullable": true,
"description": "Annual salary of the employee in USD",
"type": {
@@ -803,7 +803,7 @@
"isPartitioningKey": false
},
{
"fieldPath": "[version=2.0].[type=string].employee_name",
"fieldPath": "employee_name",
"nullable": false,
"description": "Full name of the employee",
"type": {

View File

@@ -0,0 +1,766 @@
"""
Pytest tests for AthenaPropertiesExtractor.
Tests the extraction of properties, partitioning information,
and row format details from various Athena CREATE TABLE SQL statements.
"""
import pytest
from datahub.ingestion.source.sql.athena_properties_extractor import (
AthenaPropertiesExtractionError,
AthenaPropertiesExtractor,
AthenaTableInfo,
ColumnInfo,
PartitionInfo,
RowFormatInfo,
TableProperties,
TransformInfo,
)
class TestAthenaPropertiesExtractor:
"""Test class for AthenaPropertiesExtractor."""
def test_iceberg_table_with_complex_partitioning(self):
"""Test extraction from Iceberg table with complex partitioning."""
sql = """
CREATE TABLE iceberg_table (ts timestamp, id bigint, data string, category string)
PARTITIONED BY (category, bucket(16, id), year(ts), month(ts), day(ts), hour(ts), truncate(10, ts))
LOCATION 's3://amzn-s3-demo-bucket/your-folder/'
TBLPROPERTIES ( 'table_type' = 'ICEBERG' ) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
assert isinstance(result.partition_info, PartitionInfo)
assert isinstance(result.table_properties, TableProperties)
assert isinstance(result.row_format, RowFormatInfo)
# Test partition info
partition_info = result.partition_info
# Should have multiple simple columns
assert len(partition_info.simple_columns) > 0
# Check for category column (simple partition)
category_cols = [
col for col in partition_info.simple_columns if col.name == "category"
]
assert len(category_cols) == 1
assert category_cols[0].type == "TEXT"
# Check for id column (used in bucket transform)
id_cols = [col for col in partition_info.simple_columns if col.name == "id"]
assert len(id_cols) == 1
assert id_cols[0].type == "BIGINT"
# Check for ts column (used in time transforms)
ts_cols = [col for col in partition_info.simple_columns if col.name == "ts"]
assert len(ts_cols) == 1
assert ts_cols[0].type == "TIMESTAMP"
# Test transforms
transforms = partition_info.transforms
assert len(transforms) >= 6 # bucket, year, month, day, hour, truncate
# Check bucket transform
bucket_transforms = [t for t in transforms if t.type == "bucket"]
assert len(bucket_transforms) == 1
bucket_transform = bucket_transforms[0]
assert bucket_transform.column.name == "id"
assert bucket_transform.bucket_count == 16
# Check time transforms
time_transform_types = {
t.type for t in transforms if t.type in ["year", "month", "day", "hour"]
}
assert "year" in time_transform_types
assert "month" in time_transform_types
assert "day" in time_transform_types
assert "hour" in time_transform_types
# Check truncate transform
truncate_transforms = [t for t in transforms if t.type == "truncate"]
assert len(truncate_transforms) == 1
truncate_transform = truncate_transforms[0]
assert truncate_transform.column.name == "ts"
assert truncate_transform.length == 10
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://amzn-s3-demo-bucket/your-folder/"
assert table_props.additional_properties is not None
assert table_props.additional_properties.get("table_type") == "ICEBERG"
def test_trino_table_with_array_partitioning(self):
"""Test extraction from Trino table with ARRAY partitioning."""
sql = """
create table trino.db_collection (
col1 varchar,
col2 varchar,
col3 varchar
)with (
external_location = 's3a://bucket/trino/db_collection/*',
format = 'PARQUET',
partitioned_by = ARRAY['col1','col2']
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3a://bucket/trino/db_collection/*"
assert table_props.format == "PARQUET"
# Note: ARRAY partitioning might not be parsed the same way as standard PARTITIONED BY
# This tests that the extraction doesn't fail and extracts what it can
def test_simple_orc_table(self):
"""Test extraction from simple ORC table."""
sql = """
CREATE TABLE orders (
orderkey bigint,
orderstatus varchar,
totalprice double,
orderdate date
)
WITH (format = 'ORC')
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Should have no partitions
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
# Test table properties
table_props = result.table_properties
assert table_props.format == "ORC"
assert table_props.location is None
assert table_props.comment is None
def test_table_with_comments(self):
"""Test extraction from table with table and column comments."""
sql = """
CREATE TABLE IF NOT EXISTS orders (
orderkey bigint,
orderstatus varchar,
totalprice double COMMENT 'Price in cents.',
orderdate date
)
COMMENT 'A table to keep track of orders.' \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test table comment
table_props = result.table_properties
assert table_props.comment == "A table to keep track of orders."
# No partitions expected
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
def test_table_with_row_format_and_serde(self):
"""Test extraction from table with row format and SERDE properties."""
sql = """
CREATE TABLE IF NOT EXISTS orders (
orderkey bigint,
orderstatus varchar,
totalprice double,
orderdate date
)
ROW FORMAT DELIMITED COLLECTION ITEMS TERMINATED BY ','
STORED AS PARQUET
WITH SERDEPROPERTIES (
'serialization.format' = '1'
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test table properties
table_props = result.table_properties
# Test SERDE properties
assert table_props.serde_properties is not None
assert table_props.serde_properties.get("serialization.format") == "1"
# Test row format
row_format = result.row_format
assert isinstance(row_format, RowFormatInfo)
assert isinstance(row_format.properties, dict)
assert "No RowFormatDelimitedProperty found" not in row_format.json_formatted
def test_empty_sql_raises_error(self):
"""Test that empty SQL raises appropriate error."""
with pytest.raises(
AthenaPropertiesExtractionError, match="SQL statement cannot be empty"
):
AthenaPropertiesExtractor.get_table_properties("")
with pytest.raises(
AthenaPropertiesExtractionError, match="SQL statement cannot be empty"
):
AthenaPropertiesExtractor.get_table_properties(" ")
def test_minimal_create_table(self):
"""Test extraction from minimal CREATE TABLE statement."""
sql = "CREATE TABLE test (id int)"
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Should not fail and return basic structure
assert isinstance(result, AthenaTableInfo)
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
assert result.table_properties.location is None
def test_column_info_dataclass(self):
"""Test ColumnInfo dataclass properties."""
sql = """
CREATE TABLE test (id bigint, name varchar)
PARTITIONED BY (id) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test that we get ColumnInfo objects
assert len(result.partition_info.simple_columns) == 1
column = result.partition_info.simple_columns[0]
assert isinstance(column, ColumnInfo)
assert column.name == "id"
assert column.type == "BIGINT"
def test_transform_info_dataclass(self):
"""Test TransformInfo dataclass properties."""
sql = """
CREATE TABLE test (ts timestamp, id bigint)
PARTITIONED BY (year(ts), bucket(8, id)) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
transforms = result.partition_info.transforms
assert len(transforms) >= 2
# Find year transform
year_transforms = [t for t in transforms if t.type == "year"]
assert len(year_transforms) == 1
year_transform = year_transforms[0]
assert isinstance(year_transform, TransformInfo)
assert year_transform.type == "year"
assert isinstance(year_transform.column, ColumnInfo)
assert year_transform.column.name == "ts"
assert year_transform.bucket_count is None
assert year_transform.length is None
# Find bucket transform
bucket_transforms = [t for t in transforms if t.type == "bucket"]
assert len(bucket_transforms) == 1
bucket_transform = bucket_transforms[0]
assert isinstance(bucket_transform, TransformInfo)
assert bucket_transform.type == "bucket"
assert bucket_transform.column.name == "id"
assert bucket_transform.bucket_count == 8
assert bucket_transform.length is None
def test_multiple_sql_statements_stateless(self):
"""Test that the extractor is stateless and works with multiple SQL statements."""
sql1 = "CREATE TABLE test1 (id int) WITH (format = 'PARQUET')"
sql2 = "CREATE TABLE test2 (name varchar) WITH (format = 'ORC')"
# Call multiple times to ensure no state interference
result1 = AthenaPropertiesExtractor.get_table_properties(sql1)
result2 = AthenaPropertiesExtractor.get_table_properties(sql2)
result1_again = AthenaPropertiesExtractor.get_table_properties(sql1)
# Results should be consistent
assert result1.table_properties.format == "PARQUET"
assert result2.table_properties.format == "ORC"
assert result1_again.table_properties.format == "PARQUET"
# Results should be independent
assert result1.table_properties.format != result2.table_properties.format
@pytest.mark.parametrize(
"sql,expected_location",
[
(
"CREATE TABLE test (id int) LOCATION 's3://bucket/path/'",
"s3://bucket/path/",
),
("CREATE TABLE test (id int)", None),
],
)
def test_location_extraction_parametrized(self, sql, expected_location):
"""Test location extraction with parametrized inputs."""
result = AthenaPropertiesExtractor.get_table_properties(sql)
assert result.table_properties.location == expected_location
# Integration test that could be run with actual SQL files
class TestAthenaPropertiesExtractorIntegration:
"""Integration tests for AthenaPropertiesExtractor."""
def test_complex_real_world_example(self):
"""Test with a complex real-world-like example."""
sql = """
CREATE TABLE analytics.user_events (
user_id bigint COMMENT 'Unique user identifier',
event_time timestamp COMMENT 'When the event occurred',
event_type varchar COMMENT 'Type of event',
session_id varchar,
properties map<varchar, varchar> COMMENT 'Event properties',
created_date date
)
COMMENT 'User event tracking table'
PARTITIONED BY (
created_date,
bucket(100, user_id),
hour(event_time)
)
LOCATION 's3://analytics-bucket/user-events/'
STORED AS PARQUET
TBLPROPERTIES (
'table_type' = 'ICEBERG',
'write.target-file-size-bytes' = '134217728',
'write.delete.mode' = 'copy-on-write'
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Comprehensive validation
assert isinstance(result, AthenaTableInfo)
# Check table properties
props = result.table_properties
assert props.location == "s3://analytics-bucket/user-events/"
assert props.comment == "User event tracking table"
assert props.additional_properties is not None
assert props.additional_properties.get("table_type") == "ICEBERG"
# Check partitioning
partition_info = result.partition_info
        # Should have user_id (via bucket transform) and created_date as simple partitions
        user_id_cols = [
            col for col in partition_info.simple_columns if col.name == "user_id"
        ]
        assert len(user_id_cols) == 1
        assert user_id_cols[0].type == "BIGINT"
date_cols = [
col for col in partition_info.simple_columns if col.name == "created_date"
]
assert len(date_cols) == 1
assert date_cols[0].type == "DATE"
# Should have transforms
transforms = partition_info.transforms
transform_types = {t.type for t in transforms}
assert "bucket" in transform_types
assert "hour" in transform_types
# Validate bucket transform
bucket_transforms = [t for t in transforms if t.type == "bucket"]
assert len(bucket_transforms) == 1
assert bucket_transforms[0].bucket_count == 100
assert bucket_transforms[0].column.name == "user_id"
def test_external_table_with_row_format_delimited(self):
"""Test extraction from external table with detailed row format."""
sql = """
CREATE EXTERNAL TABLE `my_table`(
`itcf id` string,
`itcf control name` string,
`itcf control description` string,
`itcf process` string,
`standard` string,
`controlid` string,
`threshold` string,
`status` string,
`date reported` string,
`remediation (accs specific)` string,
`aws account id` string,
`aws resource id` string,
`aws account owner` string)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ','
ESCAPED BY '\\\\'
LINES TERMINATED BY '\\n'
LOCATION
's3://myfolder/'
TBLPROPERTIES (
'skip.header.line.count'='1');
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://myfolder/"
# Test TBLPROPERTIES
assert table_props.additional_properties is not None
assert table_props.additional_properties.get("skip.header.line.count") == "1"
# Test row format
row_format = result.row_format
assert isinstance(row_format, RowFormatInfo)
# The row format should contain delimited properties
# Note: The exact keys depend on how sqlglot parses ROW FORMAT DELIMITED
assert isinstance(row_format.properties, dict)
# Should have structured JSON output
assert row_format.json_formatted != "No RowFormatDelimitedProperty found"
# Should not have partitions (no PARTITIONED BY clause)
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
def test_database_qualified_table_with_iceberg_properties(self):
"""Test extraction from database-qualified table with Iceberg properties."""
sql = """
CREATE TABLE mydatabase.my_table (
id string,
name string,
type string,
industry string,
annual_revenue double,
website string,
phone string,
billing_street string,
billing_city string,
billing_state string,
billing_postal_code string,
billing_country string,
shipping_street string,
shipping_city string,
shipping_state string,
shipping_postal_code string,
shipping_country string,
number_of_employees int,
description string,
owner_id string,
created_date timestamp,
last_modified_date timestamp,
is_deleted boolean)
LOCATION 's3://mybucket/myfolder/'
TBLPROPERTIES (
'table_type'='iceberg',
'write_compression'='snappy',
'format'='parquet',
'optimize_rewrite_delete_file_threshold'='10'
); \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://mybucket/myfolder/"
# Test multiple TBLPROPERTIES
assert table_props.additional_properties is not None
expected_props = {
"table_type": "iceberg",
"write_compression": "snappy",
"format": "parquet",
"optimize_rewrite_delete_file_threshold": "10",
}
for key, expected_value in expected_props.items():
assert table_props.additional_properties.get(key) == expected_value, (
f"Expected {key}={expected_value}, got {table_props.additional_properties.get(key)}"
)
# Should not have partitions (no PARTITIONED BY clause)
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
# Row format should be empty/default
row_format = result.row_format
assert isinstance(row_format, RowFormatInfo)
# Should either be empty dict or indicate no row format found
assert (
len(row_format.properties) == 0
or "No RowFormatDelimitedProperty found" in row_format.json_formatted
)
def test_iceberg_table_with_backtick_partitioning(self):
"""Test extraction from Iceberg table with backtick-quoted partition functions."""
sql = """
CREATE TABLE datalake_agg.ml_outdoor_master (
event_uuid string,
uuid string,
_pk string)
PARTITIONED BY (
`day(event_timestamp)`,
`month(event_timestamp)`
)
LOCATION 's3://bucket/folder/table'
TBLPROPERTIES (
'table_type'='iceberg',
'vacuum_max_snapshot_age_seconds'='60',
'format'='PARQUET',
'write_compression'='GZIP',
'optimize_rewrite_delete_file_threshold'='2',
'optimize_rewrite_data_file_threshold'='5',
'vacuum_min_snapshots_to_keep'='6'
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://bucket/folder/table"
# Test comprehensive TBLPROPERTIES for Iceberg
assert table_props.additional_properties is not None
expected_props = {
"table_type": "iceberg",
"vacuum_max_snapshot_age_seconds": "60",
"format": "PARQUET",
"write_compression": "GZIP",
"optimize_rewrite_delete_file_threshold": "2",
"optimize_rewrite_data_file_threshold": "5",
"vacuum_min_snapshots_to_keep": "6",
}
for key, expected_value in expected_props.items():
actual_value = table_props.additional_properties.get(key)
assert actual_value == expected_value, (
f"Expected {key}={expected_value}, got {actual_value}"
)
# Test partition info - this is the interesting part with backtick-quoted functions
partition_info = result.partition_info
# Should have transforms for day() and month() functions
transforms = partition_info.transforms
assert len(transforms) >= 2, (
f"Expected at least 2 transforms, got {len(transforms)}"
)
# Check for day transform
day_transforms = [t for t in transforms if t.type == "day"]
assert len(day_transforms) >= 1, (
f"Expected day transform, transforms: {[t.type for t in transforms]}"
)
if day_transforms:
day_transform = day_transforms[0]
assert isinstance(day_transform, TransformInfo)
assert day_transform.type == "day"
assert isinstance(day_transform.column, ColumnInfo)
# The column should be event_timestamp (extracted from day(event_timestamp))
assert day_transform.column.name == "event_timestamp"
# Check for month transform
month_transforms = [t for t in transforms if t.type == "month"]
assert len(month_transforms) >= 1, (
f"Expected month transform, transforms: {[t.type for t in transforms]}"
)
if month_transforms:
month_transform = month_transforms[0]
assert isinstance(month_transform, TransformInfo)
assert month_transform.type == "month"
assert isinstance(month_transform.column, ColumnInfo)
# The column should be event_timestamp (extracted from month(event_timestamp))
assert month_transform.column.name == "event_timestamp"
# Test simple columns - should include event_timestamp from the transforms
simple_columns = partition_info.simple_columns
event_timestamp_cols = [
col for col in simple_columns if col.name == "event_timestamp"
]
assert len(event_timestamp_cols) >= 1, (
f"Expected event_timestamp column, columns: {[col.name for col in simple_columns]}"
)
# The event_timestamp column type might be "unknown" since it's not in the table definition
# but referenced in partitioning - this tests our defensive handling
if event_timestamp_cols:
event_timestamp_col = event_timestamp_cols[0]
assert isinstance(event_timestamp_col, ColumnInfo)
assert event_timestamp_col.name == "event_timestamp"
# Type should be "unknown" since event_timestamp is not in the table columns
assert event_timestamp_col.type == "unknown"
def test_partition_function_extraction_edge_cases(self):
"""Test edge cases in partition function extraction with various formats."""
sql = """
CREATE TABLE test_partitions (
ts timestamp,
id bigint,
data string
)
PARTITIONED BY (
`day(ts)`,
`bucket(5, id)`,
`truncate(100, data)`
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
partition_info = result.partition_info
transforms = partition_info.transforms
# Should have 3 transforms
assert len(transforms) == 3
# Verify each transform type exists
transform_types = {t.type for t in transforms}
assert "day" in transform_types
assert "bucket" in transform_types
assert "truncate" in transform_types
# Test bucket transform parameters
bucket_transforms = [t for t in transforms if t.type == "bucket"]
if bucket_transforms:
bucket_transform = bucket_transforms[0]
assert bucket_transform.bucket_count == 5
assert bucket_transform.column.name == "id"
assert bucket_transform.column.type == "BIGINT"
# Test truncate transform parameters
truncate_transforms = [t for t in transforms if t.type == "truncate"]
if truncate_transforms:
truncate_transform = truncate_transforms[0]
assert truncate_transform.length == 100
assert truncate_transform.column.name == "data"
assert truncate_transform.column.type == "TEXT"
# Test day transform
day_transforms = [t for t in transforms if t.type == "day"]
if day_transforms:
day_transform = day_transforms[0]
assert day_transform.column.name == "ts"
assert day_transform.column.type == "TIMESTAMP"
def test_partition_function_extraction_edge_cases_with_different_quote(self):
"""Test edge cases in partition function extraction with various formats."""
sql = """
CREATE TABLE test_partitions (
ts timestamp,
id bigint,
data string
)
PARTITIONED BY (
day(`ts`),
bucket(5, `id`),
truncate(100, `data`)
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
partition_info = result.partition_info
transforms = partition_info.transforms
# Should have 3 transforms
assert len(transforms) == 3
# Verify each transform type exists
transform_types = {t.type for t in transforms}
assert "day" in transform_types
assert "bucket" in transform_types
assert "truncate" in transform_types
# Test bucket transform parameters
bucket_transforms = [t for t in transforms if t.type == "bucket"]
if bucket_transforms:
bucket_transform = bucket_transforms[0]
assert bucket_transform.bucket_count == 5
assert bucket_transform.column.name == "id"
assert bucket_transform.column.type == "BIGINT"
# Test truncate transform parameters
truncate_transforms = [t for t in transforms if t.type == "truncate"]
if truncate_transforms:
truncate_transform = truncate_transforms[0]
assert truncate_transform.length == 100
assert truncate_transform.column.name == "data"
assert truncate_transform.column.type == "TEXT"
# Test day transform
day_transforms = [t for t in transforms if t.type == "day"]
if day_transforms:
day_transform = day_transforms[0]
assert day_transform.column.name == "ts"
assert day_transform.column.type == "TIMESTAMP"
def test_complex_real_world_example_with_non_escaped_column_name_and_column_comment(
self,
):
"""Athena's show create table statement doesn't return columns in escaped."""
sql = """
CREATE TABLE test_schema.test_table (
date___hour timestamp,
month string COMMENT 'Month of the year',
date string,
hourly_forecast bigint,
previous_year's_sales bigint COMMENT Previous year's sales,
sheet_name string,
_id string)
LOCATION 's3://analytics-bucket/user-events/'
TBLPROPERTIES (
'table_type'='iceberg',
'vacuum_max_snapshot_age_seconds'='60',
'write_compression'='gzip',
'format'='parquet',
'optimize_rewrite_delete_file_threshold'='2',
'optimize_rewrite_data_file_threshold'='5',
'vacuum_min_snapshots_to_keep'='6'
)
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Comprehensive validation
assert isinstance(result, AthenaTableInfo)
# Check table properties
props = result.table_properties
assert props.location == "s3://analytics-bucket/user-events/"
assert props.additional_properties is not None
assert props.additional_properties.get("table_type") == "iceberg"
assert (
props.additional_properties.get("vacuum_max_snapshot_age_seconds") == "60"
)
assert props.additional_properties.get("format") == "parquet"
assert props.additional_properties.get("write_compression") == "gzip"
assert (
props.additional_properties.get("optimize_rewrite_delete_file_threshold")
== "2"
)
assert (
props.additional_properties.get("optimize_rewrite_data_file_threshold")
== "5"
)
assert props.additional_properties.get("vacuum_min_snapshots_to_keep") == "6"

View File

@@ -1,8 +1,10 @@
from datetime import datetime
from typing import List
from unittest import mock
import pytest
from freezegun import freeze_time
from pyathena import OperationalError
from sqlalchemy import types
from sqlalchemy_bigquery import STRUCT
@@ -81,6 +83,8 @@ def test_athena_get_table_properties():
"aws_region": "us-west-1",
"s3_staging_dir": "s3://sample-staging-dir/",
"work_group": "test-workgroup",
"profiling": {"enabled": True, "partition_profiling_enabled": True},
"extract_partitions_using_create_statements": True,
}
)
schema: str = "test_schema"
@@ -108,17 +112,33 @@ def test_athena_get_table_properties():
mock_cursor = mock.MagicMock()
mock_inspector = mock.MagicMock()
mock_inspector.engine.raw_connection().cursor.return_value = mock_cursor
mock_cursor.get_table_metadata.return_value = AthenaTableMetadata(
response=table_metadata
)
class MockCursorResult:
def __init__(self, data: List, description: List):
self._data = data
self._description = description
def __iter__(self):
"""Makes the object iterable, which allows list() to work"""
return iter(self._data)
@property
def description(self):
"""Returns the description as requested"""
return self._description
mock_result = MockCursorResult(
data=[["2023", "12"]], description=[["year"], ["month"]]
)
# Mock partition query results
mock_cursor.execute.side_effect = [
OperationalError("First call fails"),
mock_result,
]
mock_cursor.fetchall.side_effect = [OperationalError("First call fails")]
ctx = PipelineContext(run_id="test")
source = AthenaSource(config=config, ctx=ctx)
@@ -148,13 +168,16 @@
assert partitions == ["year", "month"]
# Verify the correct SQL query was generated for partitions
expected_create_table_query = "SHOW CREATE TABLE `test_schema`.`test_table`"
expected_query = """\
select year,month from "test_schema"."test_table$partitions" \
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \
(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \
from "test_schema"."test_table$partitions")"""
assert mock_cursor.execute.call_count == 2
assert expected_create_table_query == mock_cursor.execute.call_args_list[0][0][0]
actual_query = mock_cursor.execute.call_args_list[1][0][0]
assert actual_query == expected_query
# Verify partition cache was populated correctly