feat(ingest/athena): Iceberg partition columns extraction (#13607)
parent 17aa2d72a5
commit 5b8d4bad7c
@@ -34,6 +34,9 @@ from datahub.ingestion.source.common.subtypes import (
    SourceCapabilityModifier,
)
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.sql.athena_properties_extractor import (
    AthenaPropertiesExtractor,
)
from datahub.ingestion.source.sql.sql_common import (
    SQLAlchemySource,
    register_custom_type,
@@ -47,12 +50,17 @@ from datahub.ingestion.source.sql.sql_utils import (
)
from datahub.ingestion.source.sql.sqlalchemy_uri import make_sqlalchemy_uri
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
from datahub.metadata.schema_classes import MapTypeClass, RecordTypeClass
from datahub.metadata.schema_classes import (
    ArrayTypeClass,
    MapTypeClass,
    RecordTypeClass,
)
from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
from datahub.utilities.sqlalchemy_type_converter import (
    MapType,
    get_schema_fields_for_sqlalchemy_column,
)
from datahub.utilities.urns.field_paths import get_simple_field_path_from_v2_field_path

try:
    from typing_extensions import override
@@ -284,6 +292,11 @@ class AthenaConfig(SQLCommonConfig):
        description="Extract partitions for tables. Partition extraction needs to run a query (`select * from table$partitions`) on the table. Disable this if you don't want to grant select permission.",
    )

    extract_partitions_using_create_statements: bool = pydantic.Field(
        default=False,
        description="Extract partitions using the `SHOW CREATE TABLE` statement instead of querying the table's partitions directly. This needs to be enabled to extract Iceberg partitions. If extraction fails, it falls back to the default partition extraction. This is experimental.",
    )
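
    # Hedged illustration (not part of this commit): in an ingestion recipe the
    # new flag sits next to the existing one, e.g.
    #   source:
    #     type: athena
    #     config:
    #       extract_partitions: true
    #       extract_partitions_using_create_statements: true  # required for Iceberg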

    _s3_staging_dir_population = pydantic_renamed_field(
        old_name="s3_staging_dir",
        new_name="query_result_location",
@@ -496,23 +509,38 @@ class AthenaSource(SQLAlchemySource):
    def get_partitions(
        self, inspector: Inspector, schema: str, table: str
    ) -> Optional[List[str]]:
        if not self.config.extract_partitions:
        if (
            not self.config.extract_partitions
            and not self.config.extract_partitions_using_create_statements
        ):
            return None

        if not self.cursor:
            return None

        metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
            table_name=table, schema_name=schema
        )
        if self.config.extract_partitions_using_create_statements:
            try:
                partitions = self._get_partitions_create_table(schema, table)
            except Exception as e:
                logger.warning(
                    f"Failed to get partitions from create table statement for {schema}.{table} because of {e}. Falling back to SQLAlchemy.",
                    exc_info=True,
                )

                # If we can't get the create table statement, we fall back to SQLAlchemy
                partitions = self._get_partitions_sqlalchemy(schema, table)
        else:
            partitions = self._get_partitions_sqlalchemy(schema, table)

        partitions = []
        for key in metadata.partition_keys:
            if key.name:
                partitions.append(key.name)
        if not partitions:
            return []

        if (
            not self.config.profiling.enabled
            or not self.config.profiling.partition_profiling_enabled
        ):
            return partitions

        with self.report.report_exc(
            message="Failed to extract partition details",
            context=f"{schema}.{table}",
@@ -538,6 +566,56 @@ class AthenaSource(SQLAlchemySource):

        return partitions

    def _get_partitions_create_table(self, schema: str, table: str) -> List[str]:
        assert self.cursor
        try:
            res = self.cursor.execute(f"SHOW CREATE TABLE `{schema}`.`{table}`")
        except Exception as e:
            # Athena does not support SHOW CREATE TABLE for views
            # and will throw an error. We need to handle this case,
            # and the caller needs to fall back to SQLAlchemy's get partitions call.
            logger.debug(
                f"Failed to get table properties for {schema}.{table}: {e}",
                exc_info=True,
            )
            raise e
        rows = res.fetchall()

        # Concatenate all rows into a single string with newlines
        create_table_statement = "\n".join(row[0] for row in rows)

        try:
            athena_table_info = AthenaPropertiesExtractor.get_table_properties(
                create_table_statement
            )
        except Exception as e:
            logger.debug(
                f"Failed to parse table properties for {schema}.{table}: {e} and statement: {create_table_statement}",
                exc_info=True,
            )
            raise e

        partitions = []
        if (
            athena_table_info.partition_info
            and athena_table_info.partition_info.simple_columns
        ):
            partitions = [
                ci.name for ci in athena_table_info.partition_info.simple_columns
            ]
        return partitions
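
    # Illustrative flow (assumed shapes, not verified output): `SHOW CREATE TABLE`
    # returns the DDL split across result rows, e.g.
    #   CREATE TABLE `db`.`tbl` (
    #     ts timestamp)
    #   PARTITIONED BY (`day(ts)`)
    # _get_partitions_create_table joins the rows into one statement and hands it
    # to AthenaPropertiesExtractor, which yields the partition column names (here: ts).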

    def _get_partitions_sqlalchemy(self, schema: str, table: str) -> List[str]:
        assert self.cursor
        metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
            table_name=table, schema_name=schema
        )
        partitions = []
        for key in metadata.partition_keys:
            if key.name:
                partitions.append(key.name)
        return partitions

    # Overridden to modify the creation of schema fields
    def get_schema_fields_for_column(
        self,
@@ -563,7 +641,14 @@ class AthenaSource(SQLAlchemySource):
                partition_keys is not None and column["name"] in partition_keys
            ),
        )

        if isinstance(
            fields[0].type.type, (RecordTypeClass, MapTypeClass, ArrayTypeClass)
        ):
            return fields
        else:
            fields[0].fieldPath = get_simple_field_path_from_v2_field_path(
                fields[0].fieldPath
            )
            return fields

    def generate_partition_profiler_query(
@@ -0,0 +1,777 @@
"""
Athena Properties Extractor - A robust tool for parsing CREATE TABLE statements.

This module provides functionality to extract properties, partitioning information,
and row format details from Athena CREATE TABLE SQL statements.
"""

import json
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Union

from sqlglot import ParseError, parse_one
from sqlglot.dialects.athena import Athena
from sqlglot.expressions import (
    Anonymous,
    ColumnDef,
    Create,
    Day,
    Expression,
    FileFormatProperty,
    Identifier,
    LocationProperty,
    Month,
    PartitionByTruncate,
    PartitionedByBucket,
    PartitionedByProperty,
    Property,
    RowFormatDelimitedProperty,
    Schema,
    SchemaCommentProperty,
    SerdeProperties,
    Year,
)


class AthenaPropertiesExtractionError(Exception):
    """Custom exception for Athena properties extraction errors."""

    pass


@dataclass
class ColumnInfo:
    """Information about a table column."""

    name: str
    type: str


@dataclass
class TransformInfo:
    """Information about a partition transform."""

    type: str
    column: ColumnInfo
    bucket_count: Optional[int] = None
    length: Optional[int] = None


@dataclass
class PartitionInfo:
    """Information about table partitioning."""

    simple_columns: List[ColumnInfo]
    transforms: List[TransformInfo]


@dataclass
class TableProperties:
    """General table properties."""

    location: Optional[str] = None
    format: Optional[str] = None
    comment: Optional[str] = None
    serde_properties: Optional[Dict[str, str]] = None
    row_format: Optional[Dict[str, str]] = None
    additional_properties: Optional[Dict[str, str]] = None


@dataclass
class RowFormatInfo:
    """Row format information."""

    properties: Dict[str, str]
    json_formatted: str


@dataclass
class AthenaTableInfo:
    """Complete information about an Athena table."""

    partition_info: PartitionInfo
    table_properties: TableProperties
    row_format: RowFormatInfo


class AthenaPropertiesExtractor:
    """A class to extract properties from Athena CREATE TABLE statements."""

    CREATE_TABLE_REGEXP = re.compile(
        r"(CREATE TABLE[\s\n]*)(.*?)(\s*\()", re.MULTILINE | re.IGNORECASE
    )
    PARTITIONED_BY_REGEXP = re.compile(
        r"(PARTITIONED BY[\s\n]*\()((?:[^()]|\([^)]*\))*?)(\))",
        re.MULTILINE | re.IGNORECASE,
    )
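
    # Illustration (hypothetical inputs) of what the two patterns capture:
    #   CREATE_TABLE_REGEXP.search("CREATE TABLE db.tbl (id int)").group(2)
    #       -> "db.tbl"          (the unquoted table name that gets backtick-quoted)
    #   PARTITIONED_BY_REGEXP.search("PARTITIONED BY (`day(ts)`, id)").group(2)
    #       -> "`day(ts)`, id"   (the partition list whose backticks are stripped)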

    def __init__(self) -> None:
        """Initialize the extractor."""
        pass

    @staticmethod
    def get_table_properties(sql: str) -> AthenaTableInfo:
        """Get all table properties from a SQL statement.

        Args:
            sql: The SQL statement to parse

        Returns:
            An AthenaTableInfo object containing all table properties

        Raises:
            AthenaPropertiesExtractionError: If extraction fails
        """
        extractor = AthenaPropertiesExtractor()
        return extractor._extract_all_properties(sql)

    def _extract_all_properties(self, sql: str) -> AthenaTableInfo:
        """Extract all properties from a SQL statement.

        Args:
            sql: The SQL statement to parse

        Returns:
            An AthenaTableInfo object containing all properties

        Raises:
            AthenaPropertiesExtractionError: If extraction fails
        """
        if not sql or not sql.strip():
            raise AthenaPropertiesExtractionError("SQL statement cannot be empty")

        try:
            # We need to do certain transformations on the sql create statement:
            # - table names are not quoted
            # - column expressions are not quoted
            # - the SQL parser fails if partition columns are quoted
            fixed_sql = self._fix_sql_partitioning(sql)
            parsed = parse_one(fixed_sql, dialect=Athena)
        except ParseError as e:
            raise AthenaPropertiesExtractionError(f"Failed to parse SQL: {e}") from e
        except Exception as e:
            raise AthenaPropertiesExtractionError(
                f"Unexpected error during SQL parsing: {e}"
            ) from e

        try:
            partition_info = self._extract_partition_info(parsed)
            table_properties = self._extract_table_properties(parsed)
            row_format = self._extract_row_format(parsed)

            return AthenaTableInfo(
                partition_info=partition_info,
                table_properties=table_properties,
                row_format=row_format,
            )
        except Exception as e:
            raise AthenaPropertiesExtractionError(
                f"Failed to extract table properties: {e}"
            ) from e

    @staticmethod
    def format_column_definition(line: str) -> str:
        # Use regex to parse the line more accurately
        # Pattern: column_name data_type [COMMENT comment_text] [,]
        # Use greedy match for comment to capture everything until trailing comma
        pattern = r"^\s*(.+?)\s+([\s,\w<>\[\]]+)((\s+COMMENT\s+(.+?)(,?))|(,?)\s*)?$"
        match = re.match(pattern, line, re.IGNORECASE)

        if not match:
            return line
        column_name = match.group(1)
        data_type = match.group(2)
        comment_part = match.group(5)  # COMMENT part
        # The trailing comma lands in a different match group depending on
        # whether a comment exists
        if comment_part:
            trailing_comma = match.group(6) if match.group(6) else ""
        else:
            trailing_comma = match.group(7) if match.group(7) else ""

        # Add backticks to column name if not already present
        if not (column_name.startswith("`") and column_name.endswith("`")):
            column_name = f"`{column_name}`"

        # Build the result
        result_parts = [column_name, data_type]

        if comment_part:
            comment_part = comment_part.strip()

            # Handle comment quoting and escaping
            if comment_part.startswith("'") and comment_part.endswith("'"):
                # Already properly single quoted - keep as is
                formatted_comment = comment_part
            elif comment_part.startswith('"') and comment_part.endswith('"'):
                # Double quoted - convert to single quotes and escape internal single quotes
                inner_content = comment_part[1:-1]
                escaped_content = inner_content.replace("'", "''")
                formatted_comment = f"'{escaped_content}'"
            else:
                # Not quoted - add quotes and escape any single quotes
                escaped_content = comment_part.replace("'", "''")
                formatted_comment = f"'{escaped_content}'"

            result_parts.extend(["COMMENT", formatted_comment])

        result = " " + " ".join(result_parts) + trailing_comma

        return result
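
    # Rough illustration of format_column_definition on hypothetical input lines
    # (exact spacing may differ):
    #   "user_id bigint,"                   -> " `user_id` bigint,"
    #   'price double COMMENT "in cents",'  -> " `price` double COMMENT 'in cents',"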

    @staticmethod
    def format_athena_column_definitions(sql_statement: str) -> str:
        """
        Format Athena CREATE TABLE statement by:
        1. Adding backticks around column names in column definitions (only in the main table definition)
        2. Quoting comments (if any exist)
        """
        lines = sql_statement.split("\n")
        formatted_lines = []

        in_column_definition = False

        for line in lines:
            stripped_line = line.strip()

            # Check if we're entering column definitions
            if "CREATE TABLE" in line.upper() and "(" in line:
                in_column_definition = True
                formatted_lines.append(line)
                continue

            # Check if we're exiting column definitions (closing parenthesis before PARTITIONED BY or end)
            if in_column_definition and ")" in line:
                in_column_definition = False
                formatted_lines.append(line)
                continue

            # Process only column definitions (not PARTITIONED BY or other sections)
            if in_column_definition and stripped_line:
                # Match column definition pattern and format it
                formatted_line = AthenaPropertiesExtractor.format_column_definition(
                    line
                )
                formatted_lines.append(formatted_line)
            else:
                # For all other lines, keep as-is
                formatted_lines.append(line)

        return "\n".join(formatted_lines)

    @staticmethod
    def _fix_sql_partitioning(sql: str) -> str:
        """Fix SQL partitioning by removing backticks from partition expressions and quoting table names.

        Args:
            sql: The SQL statement to fix

        Returns:
            The fixed SQL statement
        """
        if not sql:
            return sql

        # Quote table name
        table_name_match = AthenaPropertiesExtractor.CREATE_TABLE_REGEXP.search(sql)

        if table_name_match:
            table_name = table_name_match.group(2).strip()
            if table_name and not (table_name.startswith("`") or "`" in table_name):
                # Split on dots and quote each part
                quoted_parts = [
                    f"`{part.strip()}`"
                    for part in table_name.split(".")
                    if part.strip()
                ]
                if quoted_parts:
                    quoted_table = ".".join(quoted_parts)
                    create_part = table_name_match.group(0).replace(
                        table_name, quoted_table
                    )
                    sql = sql.replace(table_name_match.group(0), create_part)

        # Fix partition expressions
        partition_match = AthenaPropertiesExtractor.PARTITIONED_BY_REGEXP.search(sql)

        if partition_match:
            partition_section = partition_match.group(2)
            if partition_section:
                partition_section_modified = partition_section.replace("`", "")
                sql = sql.replace(partition_section, partition_section_modified)

        return AthenaPropertiesExtractor.format_athena_column_definitions(sql)
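
    # End-to-end sketch of _fix_sql_partitioning (hypothetical input): given
    #   CREATE TABLE db.tbl (ts timestamp) PARTITIONED BY (`day(ts)`)
    # it returns, roughly,
    #   CREATE TABLE `db`.`tbl` (ts timestamp) PARTITIONED BY (day(ts))
    # i.e. the table name gains backticks and the partition expressions lose them,
    # which is the shape sqlglot's Athena dialect parses reliably.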

    @staticmethod
    def _extract_column_types(create_expr: Create) -> Dict[str, str]:
        """Extract column types from a CREATE TABLE expression.

        Args:
            create_expr: The CREATE TABLE expression to extract types from

        Returns:
            A dictionary mapping column names to their types
        """
        column_types: Dict[str, str] = {}

        if not create_expr.this or not hasattr(create_expr.this, "expressions"):
            return column_types

        try:
            for expr in create_expr.this.expressions:
                if isinstance(expr, ColumnDef) and expr.this:
                    column_types[expr.name] = str(expr.kind)
        except Exception:
            # If we can't extract column types, return an empty dict
            pass

        return column_types

    @staticmethod
    def _create_column_info(column_name: str, column_type: str) -> ColumnInfo:
        """Create a column info object.

        Args:
            column_name: Name of the column
            column_type: Type of the column

        Returns:
            A ColumnInfo object
        """
        return ColumnInfo(
            name=str(column_name) if column_name else "unknown",
            type=column_type if column_type else "unknown",
        )

    @staticmethod
    def _handle_function_expression(
        expr: Identifier, column_types: Dict[str, str]
    ) -> Tuple[ColumnInfo, TransformInfo]:
        """Handle function expressions like day(event_timestamp).

        Args:
            expr: The function expression to handle
            column_types: Dictionary of column types

        Returns:
            A tuple of (column_info, transform_info)
        """
        func_str = str(expr)

        if "(" not in func_str or ")" not in func_str:
            # Fallback for malformed function expressions
            column_info = AthenaPropertiesExtractor._create_column_info(
                func_str, "unknown"
            )
            transform_info = TransformInfo(type="unknown", column=column_info)
            return column_info, transform_info

        try:
            func_name = func_str.split("(")[0].lower()
            column_part = func_str.split("(")[1].split(")")[0].strip("`")

            column_info = AthenaPropertiesExtractor._create_column_info(
                column_part, column_types.get(column_part, "unknown")
            )
            transform_info = TransformInfo(type=func_name, column=column_info)

            return column_info, transform_info
        except (IndexError, AttributeError):
            # Fallback for parsing errors
            column_info = AthenaPropertiesExtractor._create_column_info(
                func_str, "unknown"
            )
            transform_info = TransformInfo(type="unknown", column=column_info)
            return column_info, transform_info

    @staticmethod
    def _handle_time_function(
        expr: Union[Year, Month, Day], column_types: Dict[str, str]
    ) -> Tuple[ColumnInfo, TransformInfo]:
        """Handle time-based functions like year, month, day.

        Args:
            expr: The time function expression to handle
            column_types: Dictionary of column types

        Returns:
            A tuple of (column_info, transform_info)
        """
        try:
            # Navigate the expression tree safely
            column_name = "unknown"
            if hasattr(expr, "this") and expr.this:
                if hasattr(expr.this, "this") and expr.this.this:
                    if hasattr(expr.this.this, "this") and expr.this.this.this:
                        column_name = str(expr.this.this.this)
                    else:
                        column_name = str(expr.this.this)
                else:
                    column_name = str(expr.this)

            column_info = AthenaPropertiesExtractor._create_column_info(
                column_name, column_types.get(column_name, "unknown")
            )
            transform_info = TransformInfo(
                type=expr.__class__.__name__.lower(), column=column_info
            )

            return column_info, transform_info
        except (AttributeError, TypeError):
            # Fallback for navigation errors
            column_info = AthenaPropertiesExtractor._create_column_info(
                "unknown", "unknown"
            )
            transform_info = TransformInfo(type="unknown", column=column_info)
            return column_info, transform_info

    @staticmethod
    def _handle_transform_function(
        expr: Anonymous, column_types: Dict[str, str]
    ) -> Tuple[ColumnInfo, TransformInfo]:
        """Handle transform functions like bucket, hour, truncate.

        Args:
            expr: The transform function expression to handle
            column_types: Dictionary of column types

        Returns:
            A tuple of (column_info, transform_info)
        """
        try:
            # Safely extract column name from the last expression
            column_name = "unknown"
            if (
                hasattr(expr, "expressions")
                and expr.expressions
                and len(expr.expressions) > 0
            ):
                last_expr = expr.expressions[-1]
                if hasattr(last_expr, "this") and last_expr.this:
                    if hasattr(last_expr.this, "this") and last_expr.this.this:
                        column_name = str(last_expr.this.this)
                    else:
                        column_name = str(last_expr.this)

            column_info = AthenaPropertiesExtractor._create_column_info(
                column_name, column_types.get(column_name, "unknown")
            )

            transform_type = str(expr.this).lower() if expr.this else "unknown"
            transform_info = TransformInfo(type=transform_type, column=column_info)

            # Add transform-specific parameters safely
            if (
                transform_type == "bucket"
                and hasattr(expr, "expressions")
                and expr.expressions
                and len(expr.expressions) > 0
            ):
                first_expr = expr.expressions[0]
                if hasattr(first_expr, "this"):
                    transform_info.bucket_count = first_expr.this
            elif (
                transform_type == "truncate"
                and hasattr(expr, "expressions")
                and expr.expressions
                and len(expr.expressions) > 0
            ):
                first_expr = expr.expressions[0]
                if hasattr(first_expr, "this"):
                    transform_info.length = first_expr.this

            return column_info, transform_info
        except (AttributeError, TypeError, IndexError):
            # Fallback for any parsing errors
            column_info = AthenaPropertiesExtractor._create_column_info(
                "unknown", "unknown"
            )
            transform_info = TransformInfo(type="unknown", column=column_info)
            return column_info, transform_info

    def _extract_partition_info(self, parsed: Expression) -> PartitionInfo:
        """Extract partitioning information from the parsed SQL statement.

        Args:
            parsed: The parsed SQL expression

        Returns:
            A PartitionInfo object containing simple columns and transforms
        """
        # Get the PARTITIONED BY expression
        partition_by_expr: Optional[Schema] = None

        try:
            for prop in parsed.find_all(Property):
                if isinstance(prop, PartitionedByProperty):
                    partition_by_expr = prop.this
                    break
        except Exception:
            # If we can't find properties, return empty result
            return PartitionInfo(simple_columns=[], transforms=[])

        if not partition_by_expr:
            return PartitionInfo(simple_columns=[], transforms=[])

        # Extract partitioning columns and transforms
        simple_columns: List[ColumnInfo] = []
        transforms: List[TransformInfo] = []

        # Get column types from the table definition
        column_types: Dict[str, str] = {}
        if isinstance(parsed, Create):
            column_types = self._extract_column_types(parsed)

        # Process each expression in the PARTITIONED BY clause
        if hasattr(partition_by_expr, "expressions") and partition_by_expr.expressions:
            for expr in partition_by_expr.expressions:
                try:
                    if isinstance(expr, Identifier) and "(" in str(expr):
                        column_info, transform_info = self._handle_function_expression(
                            expr, column_types
                        )
                        simple_columns.append(column_info)
                        transforms.append(transform_info)
                    elif isinstance(expr, PartitionByTruncate):
                        column_info = AthenaPropertiesExtractor._create_column_info(
                            str(expr.this), column_types.get(str(expr.this), "unknown")
                        )

                        expression = expr.args.get("expression")
                        transform_info = TransformInfo(
                            type="truncate",
                            column=column_info,
                            length=int(expression.name)
                            if expression and expression.name
                            else None,
                        )
                        transforms.append(transform_info)
                        simple_columns.append(column_info)
                    elif isinstance(expr, PartitionedByBucket):
                        column_info = AthenaPropertiesExtractor._create_column_info(
                            str(expr.this), column_types.get(str(expr.this), "unknown")
                        )
                        expression = expr.args.get("expression")
                        transform_info = TransformInfo(
                            type="bucket",
                            column=column_info,
                            bucket_count=int(expression.name)
                            if expression and expression.name
                            else None,
                        )
                        simple_columns.append(column_info)
                        transforms.append(transform_info)
                    elif isinstance(expr, (Year, Month, Day)):
                        column_info, transform_info = self._handle_time_function(
                            expr, column_types
                        )
                        transforms.append(transform_info)
                        simple_columns.append(column_info)
                    elif (
                        isinstance(expr, Anonymous)
                        and expr.this
                        and str(expr.this).lower() in ["bucket", "hour", "truncate"]
                    ):
                        column_info, transform_info = self._handle_transform_function(
                            expr, column_types
                        )
                        transforms.append(transform_info)
                        simple_columns.append(column_info)
                    elif hasattr(expr, "this") and expr.this:
                        column_name = str(expr.this)
                        column_info = self._create_column_info(
                            column_name, column_types.get(column_name, "unknown")
                        )
                        simple_columns.append(column_info)
                except Exception:
                    # Skip problematic expressions rather than failing completely
                    continue

        # Remove duplicates from simple_columns while preserving order
        seen_names: Set[str] = set()
        unique_simple_columns: List[ColumnInfo] = []

        for col in simple_columns:
            if col.name and col.name not in seen_names:
                seen_names.add(col.name)
                unique_simple_columns.append(col)

        return PartitionInfo(
            simple_columns=unique_simple_columns, transforms=transforms
        )
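
    # For example (mirroring the unit tests below), "PARTITIONED BY (category,
    # bucket(16, id), day(ts))" yields simple_columns for category, id and ts, plus
    # TransformInfo(type="bucket", column=<id>, bucket_count=16) and
    # TransformInfo(type="day", column=<ts>) in transforms.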

    def _extract_table_properties(self, parsed: Expression) -> TableProperties:
        """Extract table properties from the parsed SQL statement.

        Args:
            parsed: The parsed SQL expression

        Returns:
            A TableProperties object
        """
        location: Optional[str] = None
        format_prop: Optional[str] = None
        comment: Optional[str] = None
        serde_properties: Optional[Dict[str, str]] = None
        row_format: Optional[Dict[str, str]] = None
        additional_properties: Dict[str, str] = {}

        try:
            props = list(parsed.find_all(Property))
        except Exception:
            return TableProperties()

        for prop in props:
            try:
                if isinstance(prop, LocationProperty):
                    location = self._safe_get_property_value(prop)

                elif isinstance(prop, FileFormatProperty):
                    format_prop = self._safe_get_property_value(prop)

                elif isinstance(prop, SchemaCommentProperty):
                    comment = self._safe_get_property_value(prop)

                elif isinstance(prop, PartitionedByProperty):
                    continue  # Skip partition properties here

                elif isinstance(prop, SerdeProperties):
                    serde_props = self._extract_serde_properties(prop)
                    if serde_props:
                        serde_properties = serde_props

                elif isinstance(prop, RowFormatDelimitedProperty):
                    row_format_props = self._extract_row_format_properties(prop)
                    if row_format_props:
                        row_format = row_format_props

                else:
                    # Handle generic properties
                    key, value = self._extract_generic_property(prop)
                    if (
                        key
                        and value
                        and (not serde_properties or key not in serde_properties)
                    ):
                        additional_properties[key] = value

            except Exception:
                # Skip problematic properties rather than failing completely
                continue

        if (
            not location
            and additional_properties
            and additional_properties.get("external_location")
        ):
            location = additional_properties.pop("external_location")

        return TableProperties(
            location=location,
            format=format_prop,
            comment=comment,
            serde_properties=serde_properties,
            row_format=row_format,
            additional_properties=additional_properties
            if additional_properties
            else None,
        )

    def _safe_get_property_value(self, prop: Property) -> Optional[str]:
        """Safely extract value from a property."""
        try:
            if (
                hasattr(prop, "args")
                and "this" in prop.args
                and prop.args["this"]
                and hasattr(prop.args["this"], "name")
            ):
                return prop.args["this"].name
        except (AttributeError, KeyError, TypeError):
            pass
        return None

    def _extract_serde_properties(self, prop: SerdeProperties) -> Dict[str, str]:
        """Extract SERDE properties safely."""
        serde_props: Dict[str, str] = {}
        try:
            if hasattr(prop, "expressions") and prop.expressions:
                for exp in prop.expressions:
                    if (
                        hasattr(exp, "name")
                        and hasattr(exp, "args")
                        and "value" in exp.args
                        and exp.args["value"]
                        and hasattr(exp.args["value"], "name")
                    ):
                        serde_props[exp.name] = exp.args["value"].name
        except Exception:
            pass
        return serde_props

    def _extract_row_format_properties(
        self, prop: RowFormatDelimitedProperty
    ) -> Dict[str, str]:
        """Extract row format properties safely."""
        row_format: Dict[str, str] = {}
        try:
            if hasattr(prop, "args") and prop.args:
                for key, value in prop.args.items():
                    if hasattr(value, "this"):
                        row_format[key] = str(value.this)
                    else:
                        row_format[key] = str(value)
        except Exception:
            pass
        return row_format

    def _extract_generic_property(
        self, prop: Property
    ) -> Tuple[Optional[str], Optional[str]]:
        """Extract key-value pair from generic property."""
        try:
            if (
                hasattr(prop, "args")
                and "this" in prop.args
                and prop.args["this"]
                and hasattr(prop.args["this"], "name")
                and "value" in prop.args
                and prop.args["value"]
                and hasattr(prop.args["value"], "name")
            ):
                key = prop.args["this"].name.lower()
                value = prop.args["value"].name
                return key, value
        except (AttributeError, KeyError, TypeError):
            pass
        return None, None

    def _extract_row_format(self, parsed: Expression) -> RowFormatInfo:
        """Extract and format RowFormatDelimitedProperty.

        Args:
            parsed: The parsed SQL expression

        Returns:
            A RowFormatInfo object
        """
        row_format_props: Dict[str, str] = {}

        try:
            props = parsed.find_all(Property)
            for prop in props:
                if isinstance(prop, RowFormatDelimitedProperty):
                    row_format_props = self._extract_row_format_properties(prop)
                    break
        except Exception:
            pass

        if row_format_props:
            try:
                json_formatted = json.dumps(row_format_props, indent=2)
            except (TypeError, ValueError):
                json_formatted = "Error formatting row format properties"
        else:
            json_formatted = "No RowFormatDelimitedProperty found"

        return RowFormatInfo(properties=row_format_props, json_formatted=json_formatted)
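A quick usage sketch of the extractor (hedged: assembled from this module's API and
the unit tests further down, not a separately verified transcript):

from datahub.ingestion.source.sql.athena_properties_extractor import (
    AthenaPropertiesExtractor,
)

ddl = """
CREATE TABLE iceberg_table (ts timestamp, id bigint, category string)
PARTITIONED BY (category, bucket(16, id), day(ts))
LOCATION 's3://amzn-s3-demo-bucket/your-folder/'
TBLPROPERTIES ('table_type' = 'ICEBERG')
"""

info = AthenaPropertiesExtractor.get_table_properties(ddl)
print([c.name for c in info.partition_info.simple_columns])
# ['category', 'id', 'ts']
print([(t.type, t.column.name) for t in info.partition_info.transforms])
# [('bucket', 'id'), ('day', 'ts')]
print(info.table_properties.location)
# s3://amzn-s3-demo-bucket/your-folder/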
@@ -273,7 +273,7 @@
        },
        "fields": [
          {
            "fieldPath": "[version=2.0].[type=string].employee_id",
            "fieldPath": "employee_id",
            "nullable": false,
            "description": "Unique identifier for the employee",
            "type": {
@@ -287,7 +287,7 @@
            "isPartitioningKey": false
          },
          {
            "fieldPath": "[version=2.0].[type=long].annual_salary",
            "fieldPath": "annual_salary",
            "nullable": true,
            "description": "Annual salary of the employee in USD",
            "type": {
@@ -301,7 +301,7 @@
            "isPartitioningKey": false
          },
          {
            "fieldPath": "[version=2.0].[type=string].employee_name",
            "fieldPath": "employee_name",
            "nullable": false,
            "description": "Full name of the employee",
            "type": {
@@ -515,7 +515,7 @@
        },
        "fields": [
          {
            "fieldPath": "[version=2.0].[type=string].employee_id",
            "fieldPath": "employee_id",
            "nullable": false,
            "description": "Unique identifier for the employee",
            "type": {
@@ -529,7 +529,7 @@
            "isPartitioningKey": false
          },
          {
            "fieldPath": "[version=2.0].[type=long].annual_salary",
            "fieldPath": "annual_salary",
            "nullable": true,
            "description": "Annual salary of the employee in USD",
            "type": {
@@ -543,7 +543,7 @@
            "isPartitioningKey": false
          },
          {
            "fieldPath": "[version=2.0].[type=string].employee_name",
            "fieldPath": "employee_name",
            "nullable": false,
            "description": "Full name of the employee",
            "type": {
@@ -775,7 +775,7 @@
        },
        "fields": [
          {
            "fieldPath": "[version=2.0].[type=string].employee_id",
            "fieldPath": "employee_id",
            "nullable": false,
            "description": "Unique identifier for the employee",
            "type": {
@@ -789,7 +789,7 @@
            "isPartitioningKey": false
          },
          {
            "fieldPath": "[version=2.0].[type=long].annual_salary",
            "fieldPath": "annual_salary",
            "nullable": true,
            "description": "Annual salary of the employee in USD",
            "type": {
@@ -803,7 +803,7 @@
            "isPartitioningKey": false
          },
          {
            "fieldPath": "[version=2.0].[type=string].employee_name",
            "fieldPath": "employee_name",
            "nullable": false,
            "description": "Full name of the employee",
            "type": {
@@ -0,0 +1,766 @@
"""
Pytest tests for AthenaPropertiesExtractor.

Tests the extraction of properties, partitioning information,
and row format details from various Athena CREATE TABLE SQL statements.
"""

import pytest

from datahub.ingestion.source.sql.athena_properties_extractor import (
    AthenaPropertiesExtractionError,
    AthenaPropertiesExtractor,
    AthenaTableInfo,
    ColumnInfo,
    PartitionInfo,
    RowFormatInfo,
    TableProperties,
    TransformInfo,
)


class TestAthenaPropertiesExtractor:
    """Test class for AthenaPropertiesExtractor."""

    def test_iceberg_table_with_complex_partitioning(self):
        """Test extraction from Iceberg table with complex partitioning."""
        sql = """
        CREATE TABLE iceberg_table (ts timestamp, id bigint, data string, category string)
        PARTITIONED BY (category, bucket(16, id), year(ts), month(ts), day(ts), hour(ts), truncate(10, ts))
        LOCATION 's3://amzn-s3-demo-bucket/your-folder/'
        TBLPROPERTIES ( 'table_type' = 'ICEBERG' ) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Test basic structure
        assert isinstance(result, AthenaTableInfo)
        assert isinstance(result.partition_info, PartitionInfo)
        assert isinstance(result.table_properties, TableProperties)
        assert isinstance(result.row_format, RowFormatInfo)

        # Test partition info
        partition_info = result.partition_info

        # Should have multiple simple columns
        assert len(partition_info.simple_columns) > 0

        # Check for category column (simple partition)
        category_cols = [
            col for col in partition_info.simple_columns if col.name == "category"
        ]
        assert len(category_cols) == 1
        assert category_cols[0].type == "TEXT"

        # Check for id column (used in bucket transform)
        id_cols = [col for col in partition_info.simple_columns if col.name == "id"]
        assert len(id_cols) == 1
        assert id_cols[0].type == "BIGINT"

        # Check for ts column (used in time transforms)
        ts_cols = [col for col in partition_info.simple_columns if col.name == "ts"]
        assert len(ts_cols) == 1
        assert ts_cols[0].type == "TIMESTAMP"

        # Test transforms
        transforms = partition_info.transforms
        assert len(transforms) >= 6  # bucket, year, month, day, hour, truncate

        # Check bucket transform
        bucket_transforms = [t for t in transforms if t.type == "bucket"]
        assert len(bucket_transforms) == 1
        bucket_transform = bucket_transforms[0]
        assert bucket_transform.column.name == "id"
        assert bucket_transform.bucket_count == 16

        # Check time transforms
        time_transform_types = {
            t.type for t in transforms if t.type in ["year", "month", "day", "hour"]
        }
        assert "year" in time_transform_types
        assert "month" in time_transform_types
        assert "day" in time_transform_types
        assert "hour" in time_transform_types

        # Check truncate transform
        truncate_transforms = [t for t in transforms if t.type == "truncate"]
        assert len(truncate_transforms) == 1
        truncate_transform = truncate_transforms[0]
        assert truncate_transform.column.name == "ts"
        assert truncate_transform.length == 10

        # Test table properties
        table_props = result.table_properties
        assert table_props.location == "s3://amzn-s3-demo-bucket/your-folder/"
        assert table_props.additional_properties is not None
        assert table_props.additional_properties.get("table_type") == "ICEBERG"

    def test_trino_table_with_array_partitioning(self):
        """Test extraction from Trino table with ARRAY partitioning."""
        sql = """
        create table trino.db_collection (
            col1 varchar,
            col2 varchar,
            col3 varchar
        )with (
            external_location = 's3a://bucket/trino/db_collection/*',
            format = 'PARQUET',
            partitioned_by = ARRAY['col1','col2']
        ) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Test table properties
        table_props = result.table_properties
        assert table_props.location == "s3a://bucket/trino/db_collection/*"
        assert table_props.format == "PARQUET"

        # Note: ARRAY partitioning might not be parsed the same way as standard PARTITIONED BY
        # This tests that the extraction doesn't fail and extracts what it can

    def test_simple_orc_table(self):
        """Test extraction from simple ORC table."""
        sql = """
        CREATE TABLE orders (
            orderkey bigint,
            orderstatus varchar,
            totalprice double,
            orderdate date
        )
        WITH (format = 'ORC')
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Test basic structure
        assert isinstance(result, AthenaTableInfo)

        # Should have no partitions
        assert len(result.partition_info.simple_columns) == 0
        assert len(result.partition_info.transforms) == 0

        # Test table properties
        table_props = result.table_properties
        assert table_props.format == "ORC"
        assert table_props.location is None
        assert table_props.comment is None

    def test_table_with_comments(self):
        """Test extraction from table with table and column comments."""
        sql = """
        CREATE TABLE IF NOT EXISTS orders (
            orderkey bigint,
            orderstatus varchar,
            totalprice double COMMENT 'Price in cents.',
            orderdate date
        )
        COMMENT 'A table to keep track of orders.' \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Test table comment
        table_props = result.table_properties
        assert table_props.comment == "A table to keep track of orders."

        # No partitions expected
        assert len(result.partition_info.simple_columns) == 0
        assert len(result.partition_info.transforms) == 0

    def test_table_with_row_format_and_serde(self):
        """Test extraction from table with row format and SERDE properties."""
        sql = """
        CREATE TABLE IF NOT EXISTS orders (
            orderkey bigint,
            orderstatus varchar,
            totalprice double,
            orderdate date
        )
        ROW FORMAT DELIMITED COLLECTION ITEMS TERMINATED BY ','
        STORED AS PARQUET
        WITH SERDEPROPERTIES (
            'serialization.format' = '1'
        ) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Test table properties
        table_props = result.table_properties

        # Test SERDE properties
        assert table_props.serde_properties is not None
        assert table_props.serde_properties.get("serialization.format") == "1"

        # Test row format
        row_format = result.row_format
        assert isinstance(row_format, RowFormatInfo)
        assert isinstance(row_format.properties, dict)
        assert "No RowFormatDelimitedProperty found" not in row_format.json_formatted

    def test_empty_sql_raises_error(self):
        """Test that empty SQL raises appropriate error."""
        with pytest.raises(
            AthenaPropertiesExtractionError, match="SQL statement cannot be empty"
        ):
            AthenaPropertiesExtractor.get_table_properties("")

        with pytest.raises(
            AthenaPropertiesExtractionError, match="SQL statement cannot be empty"
        ):
            AthenaPropertiesExtractor.get_table_properties(" ")

    def test_minimal_create_table(self):
        """Test extraction from minimal CREATE TABLE statement."""
        sql = "CREATE TABLE test (id int)"

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Should not fail and return basic structure
        assert isinstance(result, AthenaTableInfo)
        assert len(result.partition_info.simple_columns) == 0
        assert len(result.partition_info.transforms) == 0
        assert result.table_properties.location is None

    def test_column_info_dataclass(self):
        """Test ColumnInfo dataclass properties."""
        sql = """
        CREATE TABLE test (id bigint, name varchar)
        PARTITIONED BY (id) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Test that we get ColumnInfo objects
        assert len(result.partition_info.simple_columns) == 1
        column = result.partition_info.simple_columns[0]

        assert isinstance(column, ColumnInfo)
        assert column.name == "id"
        assert column.type == "BIGINT"

    def test_transform_info_dataclass(self):
        """Test TransformInfo dataclass properties."""
        sql = """
        CREATE TABLE test (ts timestamp, id bigint)
        PARTITIONED BY (year(ts), bucket(8, id)) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        transforms = result.partition_info.transforms
        assert len(transforms) >= 2

        # Find year transform
        year_transforms = [t for t in transforms if t.type == "year"]
        assert len(year_transforms) == 1
        year_transform = year_transforms[0]

        assert isinstance(year_transform, TransformInfo)
        assert year_transform.type == "year"
        assert isinstance(year_transform.column, ColumnInfo)
        assert year_transform.column.name == "ts"
        assert year_transform.bucket_count is None
        assert year_transform.length is None

        # Find bucket transform
        bucket_transforms = [t for t in transforms if t.type == "bucket"]
        assert len(bucket_transforms) == 1
        bucket_transform = bucket_transforms[0]

        assert isinstance(bucket_transform, TransformInfo)
        assert bucket_transform.type == "bucket"
        assert bucket_transform.column.name == "id"
        assert bucket_transform.bucket_count == 8
        assert bucket_transform.length is None

    def test_multiple_sql_statements_stateless(self):
        """Test that the extractor is stateless and works with multiple SQL statements."""
        sql1 = "CREATE TABLE test1 (id int) WITH (format = 'PARQUET')"
        sql2 = "CREATE TABLE test2 (name varchar) WITH (format = 'ORC')"

        # Call multiple times to ensure no state interference
        result1 = AthenaPropertiesExtractor.get_table_properties(sql1)
        result2 = AthenaPropertiesExtractor.get_table_properties(sql2)
        result1_again = AthenaPropertiesExtractor.get_table_properties(sql1)

        # Results should be consistent
        assert result1.table_properties.format == "PARQUET"
        assert result2.table_properties.format == "ORC"
        assert result1_again.table_properties.format == "PARQUET"

        # Results should be independent
        assert result1.table_properties.format != result2.table_properties.format

    @pytest.mark.parametrize(
        "sql,expected_location",
        [
            (
                "CREATE TABLE test (id int) LOCATION 's3://bucket/path/'",
                "s3://bucket/path/",
            ),
            ("CREATE TABLE test (id int)", None),
        ],
    )
    def test_location_extraction_parametrized(self, sql, expected_location):
        """Test location extraction with parametrized inputs."""
        result = AthenaPropertiesExtractor.get_table_properties(sql)
        assert result.table_properties.location == expected_location
|
||||
|
||||
|
||||
# Integration test that could be run with actual SQL files
|
||||
class TestAthenaPropertiesExtractorIntegration:
|
||||
"""Integration tests for AthenaPropertiesExtractor."""
|
||||
|
||||
def test_complex_real_world_example(self):
|
||||
"""Test with a complex real-world-like example."""
|
||||
sql = """
|
||||
CREATE TABLE analytics.user_events (
|
||||
user_id bigint COMMENT 'Unique user identifier',
|
||||
event_time timestamp COMMENT 'When the event occurred',
|
||||
event_type varchar COMMENT 'Type of event',
|
||||
session_id varchar,
|
||||
properties map<varchar, varchar> COMMENT 'Event properties',
|
||||
created_date date
|
||||
)
|
||||
COMMENT 'User event tracking table'
|
||||
PARTITIONED BY (
|
||||
created_date,
|
||||
bucket(100, user_id),
|
||||
hour(event_time)
|
||||
)
|
||||
LOCATION 's3://analytics-bucket/user-events/'
|
||||
STORED AS PARQUET
|
||||
TBLPROPERTIES (
|
||||
'table_type' = 'ICEBERG',
|
||||
'write.target-file-size-bytes' = '134217728',
|
||||
'write.delete.mode' = 'copy-on-write'
|
||||
) \
|
||||
"""
|
||||
|
||||
result = AthenaPropertiesExtractor.get_table_properties(sql)
|
||||
|
||||
# Comprehensive validation
|
||||
assert isinstance(result, AthenaTableInfo)
|
||||
|
||||
# Check table properties
|
||||
props = result.table_properties
|
||||
assert props.location == "s3://analytics-bucket/user-events/"
|
||||
assert props.comment == "User event tracking table"
|
||||
assert props.additional_properties is not None
|
||||
assert props.additional_properties.get("table_type") == "ICEBERG"
|
||||
|
||||
# Check partitioning
|
||||
partition_info = result.partition_info
|
||||
|
||||
# Should have created_date as simple partition
|
||||
date_cols = [
|
||||
col for col in partition_info.simple_columns if col.name == "user_id"
|
||||
]
|
||||
assert len(date_cols) == 1
|
||||
assert date_cols[0].type == "BIGINT"
|
||||
|
||||
date_cols = [
|
||||
col for col in partition_info.simple_columns if col.name == "created_date"
|
||||
]
|
||||
assert len(date_cols) == 1
|
||||
assert date_cols[0].type == "DATE"
|
||||
|
||||
# Should have transforms
|
||||
transforms = partition_info.transforms
|
||||
transform_types = {t.type for t in transforms}
|
||||
assert "bucket" in transform_types
|
||||
assert "hour" in transform_types
|
||||
|
||||
# Validate bucket transform
|
||||
bucket_transforms = [t for t in transforms if t.type == "bucket"]
|
||||
assert len(bucket_transforms) == 1
|
||||
assert bucket_transforms[0].bucket_count == 100
|
||||
assert bucket_transforms[0].column.name == "user_id"
|
||||
|
||||
def test_external_table_with_row_format_delimited(self):
|
||||
"""Test extraction from external table with detailed row format."""
|
||||
sql = """
|
||||
CREATE EXTERNAL TABLE `my_table`(
|
||||
`itcf id` string,
|
||||
`itcf control name` string,
|
||||
`itcf control description` string,
|
||||
`itcf process` string,
|
||||
`standard` string,
|
||||
`controlid` string,
|
||||
`threshold` string,
|
||||
`status` string,
|
||||
`date reported` string,
|
||||
`remediation (accs specific)` string,
|
||||
`aws account id` string,
|
||||
`aws resource id` string,
|
||||
`aws account owner` string)
|
||||
ROW FORMAT DELIMITED
|
||||
FIELDS TERMINATED BY ','
|
||||
ESCAPED BY '\\\\'
|
||||
LINES TERMINATED BY '\\n'
|
||||
LOCATION
|
||||
's3://myfolder/'
|
||||
TBLPROPERTIES (
|
||||
'skip.header.line.count'='1');
|
||||
"""
|
||||
|
||||
result = AthenaPropertiesExtractor.get_table_properties(sql)
|
||||
|
||||
# Test basic structure
|
||||
assert isinstance(result, AthenaTableInfo)
|
||||
|
||||
# Test table properties
|
||||
table_props = result.table_properties
|
||||
assert table_props.location == "s3://myfolder/"
|
||||
|
||||
# Test TBLPROPERTIES
|
||||
assert table_props.additional_properties is not None
|
||||
assert table_props.additional_properties.get("skip.header.line.count") == "1"
|
||||
|
||||
# Test row format
|
||||
row_format = result.row_format
|
||||
assert isinstance(row_format, RowFormatInfo)
|
||||
|
||||
# The row format should contain delimited properties
|
||||
# Note: The exact keys depend on how sqlglot parses ROW FORMAT DELIMITED
|
||||
assert isinstance(row_format.properties, dict)
|
||||
|
||||
# Should have structured JSON output
|
||||
assert row_format.json_formatted != "No RowFormatDelimitedProperty found"
|
||||
|
||||
# Should not have partitions (no PARTITIONED BY clause)
|
||||
assert len(result.partition_info.simple_columns) == 0
|
||||
assert len(result.partition_info.transforms) == 0
|
||||
|
||||
def test_database_qualified_table_with_iceberg_properties(self):
|
||||
"""Test extraction from database-qualified table with Iceberg properties."""
|
||||
sql = """
|
||||
CREATE TABLE mydatabase.my_table (
|
||||
id string,
|
||||
name string,
|
||||
type string,
|
||||
industry string,
|
||||
annual_revenue double,
|
||||
website string,
|
||||
phone string,
|
||||
billing_street string,
|
||||
billing_city string,
|
||||
billing_state string,
|
||||
billing_postal_code string,
|
||||
billing_country string,
|
||||
shipping_street string,
|
||||
shipping_city string,
|
||||
shipping_state string,
|
||||
shipping_postal_code string,
|
||||
shipping_country string,
|
||||
number_of_employees int,
|
||||
description string,
|
||||
owner_id string,
|
||||
created_date timestamp,
|
||||
last_modified_date timestamp,
|
||||
is_deleted boolean)
|
||||
LOCATION 's3://mybucket/myfolder/'
|
||||
TBLPROPERTIES (
|
||||
'table_type'='iceberg',
|
||||
'write_compression'='snappy',
|
||||
'format'='parquet',
|
||||
'optimize_rewrite_delete_file_threshold'='10'
|
||||
); \
|
||||
"""
|
||||
|
||||
result = AthenaPropertiesExtractor.get_table_properties(sql)
|
||||
|
||||
# Test basic structure
|
||||
assert isinstance(result, AthenaTableInfo)
|
||||
|
||||
# Test table properties
|
||||
table_props = result.table_properties
|
||||
assert table_props.location == "s3://mybucket/myfolder/"
|
||||
|
||||
# Test multiple TBLPROPERTIES
|
||||
assert table_props.additional_properties is not None
|
||||
expected_props = {
|
||||
"table_type": "iceberg",
|
||||
"write_compression": "snappy",
|
||||
"format": "parquet",
|
||||
"optimize_rewrite_delete_file_threshold": "10",
|
||||
}
|
||||
|
||||
for key, expected_value in expected_props.items():
|
||||
assert table_props.additional_properties.get(key) == expected_value, (
|
||||
f"Expected {key}={expected_value}, got {table_props.additional_properties.get(key)}"
|
||||
)
|
||||
|
||||
# Should not have partitions (no PARTITIONED BY clause)
|
||||
assert len(result.partition_info.simple_columns) == 0
|
||||
assert len(result.partition_info.transforms) == 0
|
||||
|
||||
# Row format should be empty/default
|
||||
row_format = result.row_format
|
||||
assert isinstance(row_format, RowFormatInfo)
|
||||
# Should either be empty dict or indicate no row format found
|
||||
assert (
|
||||
len(row_format.properties) == 0
|
||||
or "No RowFormatDelimitedProperty found" in row_format.json_formatted
|
||||
)
|
||||
|
||||
    def test_iceberg_table_with_backtick_partitioning(self):
        """Test extraction from Iceberg table with backtick-quoted partition functions."""
        sql = """
        CREATE TABLE datalake_agg.ml_outdoor_master (
          event_uuid string,
          uuid string,
          _pk string)
        PARTITIONED BY (
          `day(event_timestamp)`,
          `month(event_timestamp)`
        )
        LOCATION 's3://bucket/folder/table'
        TBLPROPERTIES (
          'table_type'='iceberg',
          'vacuum_max_snapshot_age_seconds'='60',
          'format'='PARQUET',
          'write_compression'='GZIP',
          'optimize_rewrite_delete_file_threshold'='2',
          'optimize_rewrite_data_file_threshold'='5',
          'vacuum_min_snapshots_to_keep'='6'
        ) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Test basic structure
        assert isinstance(result, AthenaTableInfo)

        # Test table properties
        table_props = result.table_properties
        assert table_props.location == "s3://bucket/folder/table"

        # Test comprehensive TBLPROPERTIES for Iceberg
        assert table_props.additional_properties is not None
        expected_props = {
            "table_type": "iceberg",
            "vacuum_max_snapshot_age_seconds": "60",
            "format": "PARQUET",
            "write_compression": "GZIP",
            "optimize_rewrite_delete_file_threshold": "2",
            "optimize_rewrite_data_file_threshold": "5",
            "vacuum_min_snapshots_to_keep": "6",
        }

        for key, expected_value in expected_props.items():
            actual_value = table_props.additional_properties.get(key)
            assert actual_value == expected_value, (
                f"Expected {key}={expected_value}, got {actual_value}"
            )

        # Test partition info - this is the interesting part with backtick-quoted functions
        partition_info = result.partition_info

        # Should have transforms for day() and month() functions
        transforms = partition_info.transforms
        assert len(transforms) >= 2, (
            f"Expected at least 2 transforms, got {len(transforms)}"
        )

        # Check for day transform
        day_transforms = [t for t in transforms if t.type == "day"]
        assert len(day_transforms) >= 1, (
            f"Expected day transform, transforms: {[t.type for t in transforms]}"
        )

        if day_transforms:
            day_transform = day_transforms[0]
            assert isinstance(day_transform, TransformInfo)
            assert day_transform.type == "day"
            assert isinstance(day_transform.column, ColumnInfo)
            # The column should be event_timestamp (extracted from day(event_timestamp))
            assert day_transform.column.name == "event_timestamp"

        # Check for month transform
        month_transforms = [t for t in transforms if t.type == "month"]
        assert len(month_transforms) >= 1, (
            f"Expected month transform, transforms: {[t.type for t in transforms]}"
        )

        if month_transforms:
            month_transform = month_transforms[0]
            assert isinstance(month_transform, TransformInfo)
            assert month_transform.type == "month"
            assert isinstance(month_transform.column, ColumnInfo)
            # The column should be event_timestamp (extracted from month(event_timestamp))
            assert month_transform.column.name == "event_timestamp"

        # Test simple columns - should include event_timestamp from the transforms
        simple_columns = partition_info.simple_columns
        event_timestamp_cols = [
            col for col in simple_columns if col.name == "event_timestamp"
        ]
        assert len(event_timestamp_cols) >= 1, (
            f"Expected event_timestamp column, columns: {[col.name for col in simple_columns]}"
        )

        # The event_timestamp column type might be "unknown" since it's not in the
        # table definition but referenced in partitioning - this tests our defensive handling
        if event_timestamp_cols:
            event_timestamp_col = event_timestamp_cols[0]
            assert isinstance(event_timestamp_col, ColumnInfo)
            assert event_timestamp_col.name == "event_timestamp"
            # Type should be "unknown" since event_timestamp is not in the table columns
            assert event_timestamp_col.type == "unknown"

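    # For reference, the assertions above pin down the expected result shape:
    # transforms -> [TransformInfo(type="day", ...), TransformInfo(type="month", ...)]
    # with each transform's column being event_timestamp, and simple_columns
    # containing ColumnInfo(name="event_timestamp", type="unknown").
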
    def test_partition_function_extraction_edge_cases(self):
        """Test edge cases in partition function extraction with various formats."""
        sql = """
        CREATE TABLE test_partitions (
          ts timestamp,
          id bigint,
          data string
        )
        PARTITIONED BY (
          `day(ts)`,
          `bucket(5, id)`,
          `truncate(100, data)`
        ) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        partition_info = result.partition_info
        transforms = partition_info.transforms

        # Should have 3 transforms
        assert len(transforms) == 3

        # Verify each transform type exists
        transform_types = {t.type for t in transforms}
        assert "day" in transform_types
        assert "bucket" in transform_types
        assert "truncate" in transform_types

        # Test bucket transform parameters
        bucket_transforms = [t for t in transforms if t.type == "bucket"]
        if bucket_transforms:
            bucket_transform = bucket_transforms[0]
            assert bucket_transform.bucket_count == 5
            assert bucket_transform.column.name == "id"
            assert bucket_transform.column.type == "BIGINT"

        # Test truncate transform parameters
        truncate_transforms = [t for t in transforms if t.type == "truncate"]
        if truncate_transforms:
            truncate_transform = truncate_transforms[0]
            assert truncate_transform.length == 100
            assert truncate_transform.column.name == "data"
            assert truncate_transform.column.type == "TEXT"

        # Test day transform
        day_transforms = [t for t in transforms if t.type == "day"]
        if day_transforms:
            day_transform = day_transforms[0]
            assert day_transform.column.name == "ts"
            assert day_transform.column.type == "TIMESTAMP"

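    # Illustrative sketch, not part of the test suite: one plausible way to
    # recognize the transform spellings exercised by these tests, e.g.
    # `day(ts)`, `bucket(5, id)`, truncate(100, `data`). The helper name and
    # regex are assumptions for illustration only; AthenaPropertiesExtractor's
    # actual parsing logic may differ.
    @staticmethod
    def _parse_transform_sketch(expr: str):
        import re

        match = re.match(
            r"`?(\w+)\(\s*(?:(\d+)\s*,\s*)?`?(\w+)`?\s*\)`?$", expr.strip()
        )
        if not match:
            return None  # a plain partition column, not a transform
        func, param, column = match.groups()
        # e.g. _parse_transform_sketch("`bucket(5, id)`")
        #      -> {"type": "bucket", "param": 5, "column": "id"}
        return {"type": func, "param": int(param) if param else None, "column": column}
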
    def test_partition_function_extraction_edge_cases_with_different_quote(self):
        """Test edge cases in partition function extraction with column-level backtick quoting."""
        sql = """
        CREATE TABLE test_partitions (
          ts timestamp,
          id bigint,
          data string
        )
        PARTITIONED BY (
          day(`ts`),
          bucket(5, `id`),
          truncate(100, `data`)
        ) \
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        partition_info = result.partition_info
        transforms = partition_info.transforms

        # Should have 3 transforms
        assert len(transforms) == 3

        # Verify each transform type exists
        transform_types = {t.type for t in transforms}
        assert "day" in transform_types
        assert "bucket" in transform_types
        assert "truncate" in transform_types

        # Test bucket transform parameters
        bucket_transforms = [t for t in transforms if t.type == "bucket"]
        if bucket_transforms:
            bucket_transform = bucket_transforms[0]
            assert bucket_transform.bucket_count == 5
            assert bucket_transform.column.name == "id"
            assert bucket_transform.column.type == "BIGINT"

        # Test truncate transform parameters
        truncate_transforms = [t for t in transforms if t.type == "truncate"]
        if truncate_transforms:
            truncate_transform = truncate_transforms[0]
            assert truncate_transform.length == 100
            assert truncate_transform.column.name == "data"
            assert truncate_transform.column.type == "TEXT"

        # Test day transform
        day_transforms = [t for t in transforms if t.type == "day"]
        if day_transforms:
            day_transform = day_transforms[0]
            assert day_transform.column.name == "ts"
            assert day_transform.column.type == "TIMESTAMP"

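    # Note: this test repeats the previous one with the alternative quoting style
    # (day(`ts`) instead of `day(ts)`); both spellings must normalize to identical
    # transforms, columns, and parameters.
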
    def test_complex_real_world_example_with_non_escaped_column_name_and_column_comment(
        self,
    ):
        """Athena's SHOW CREATE TABLE statement doesn't return column names escaped."""
        sql = """
        CREATE TABLE test_schema.test_table (
          date___hour timestamp,
          month string COMMENT 'Month of the year',
          date string,
          hourly_forecast bigint,
          previous_year's_sales bigint COMMENT Previous year's sales,
          sheet_name string,
          _id string)
        LOCATION 's3://analytics-bucket/user-events/'
        TBLPROPERTIES (
          'table_type'='iceberg',
          'vacuum_max_snapshot_age_seconds'='60',
          'write_compression'='gzip',
          'format'='parquet',
          'optimize_rewrite_delete_file_threshold'='2',
          'optimize_rewrite_data_file_threshold'='5',
          'vacuum_min_snapshots_to_keep'='6'
        )
        """

        result = AthenaPropertiesExtractor.get_table_properties(sql)

        # Comprehensive validation
        assert isinstance(result, AthenaTableInfo)

        # Check table properties
        props = result.table_properties
        assert props.location == "s3://analytics-bucket/user-events/"
        assert props.additional_properties is not None
        assert props.additional_properties.get("table_type") == "iceberg"
        assert (
            props.additional_properties.get("vacuum_max_snapshot_age_seconds") == "60"
        )
        assert props.additional_properties.get("format") == "parquet"
        assert props.additional_properties.get("write_compression") == "gzip"
        assert (
            props.additional_properties.get("optimize_rewrite_delete_file_threshold")
            == "2"
        )
        assert (
            props.additional_properties.get("optimize_rewrite_data_file_threshold")
            == "5"
        )
        assert props.additional_properties.get("vacuum_min_snapshots_to_keep") == "6"
@ -1,8 +1,10 @@
from datetime import datetime
from typing import List
from unittest import mock

import pytest
from freezegun import freeze_time
from pyathena import OperationalError
from sqlalchemy import types
from sqlalchemy_bigquery import STRUCT

@ -81,6 +83,8 @@ def test_athena_get_table_properties():
            "aws_region": "us-west-1",
            "s3_staging_dir": "s3://sample-staging-dir/",
            "work_group": "test-workgroup",
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
            "extract_partitions_using_create_statements": True,
        }
    )
    schema: str = "test_schema"
@ -108,17 +112,33 @@ def test_athena_get_table_properties():

    mock_cursor = mock.MagicMock()
    mock_inspector = mock.MagicMock()
    mock_inspector.engine.raw_connection().cursor.return_value = mock_cursor
    mock_cursor.get_table_metadata.return_value = AthenaTableMetadata(
        response=table_metadata
    )

    class MockCursorResult:
        def __init__(self, data: List, description: List):
            self._data = data
            self._description = description

        def __iter__(self):
            """Makes the object iterable, which allows list() to work"""
            return iter(self._data)

        @property
        def description(self):
            """Returns the description as requested"""
            return self._description

    mock_result = MockCursorResult(
        data=[["2023", "12"]], description=[["year"], ["month"]]
    )
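    # MockCursorResult mimics just enough of a DB-API cursor result for the code
    # under test: __iter__ lets list(...) materialize the partition rows, and
    # .description supplies the partition column names ("year", "month").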
     # Mock partition query results
-    mock_cursor.execute.return_value.description = [
-        ["year"],
-        ["month"],
+    mock_cursor.execute.side_effect = [
+        OperationalError("First call fails"),
+        mock_result,
     ]
-    mock_cursor.execute.return_value.__iter__.return_value = [["2023", "12"]]
+    mock_cursor.fetchall.side_effect = [OperationalError("First call fails")]

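    # The setup above drives the fallback path: the first execute() call is the
    # SHOW CREATE TABLE statement and raises OperationalError, so the source falls
    # back to querying "test_table$partitions", which returns mock_result.
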
    ctx = PipelineContext(run_id="test")
    source = AthenaSource(config=config, ctx=ctx)
@ -148,13 +168,16 @@ def test_athena_get_table_properties():
    assert partitions == ["year", "month"]

    # Verify the correct SQL queries were generated for partitions
    expected_create_table_query = "SHOW CREATE TABLE `test_schema`.`test_table`"

    expected_query = """\
select year,month from "test_schema"."test_table$partitions" \
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \
(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \
from "test_schema"."test_table$partitions")"""
-    mock_cursor.execute.assert_called_once()
-    actual_query = mock_cursor.execute.call_args[0][0]
+    assert mock_cursor.execute.call_count == 2
+    assert expected_create_table_query == mock_cursor.execute.call_args_list[0][0][0]
+    actual_query = mock_cursor.execute.call_args_list[1][0][0]
     assert actual_query == expected_query

    # Verify partition cache was populated correctly
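    # The fallback query picks the newest partition: it concatenates the partition
    # keys as year '-' month and keeps the row matching the maximum of that value
    # over "test_table$partitions".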