datahub/metadata-ingestion/tests/unit/test_athena_properties_extractor.py

"""
Pytest tests for AthenaPropertiesExtractor.
Tests the extraction of properties, partitioning information,
and row format details from various Athena CREATE TABLE SQL statements.
"""
import pytest
from datahub.ingestion.source.sql.athena_properties_extractor import (
AthenaPropertiesExtractionError,
AthenaPropertiesExtractor,
AthenaTableInfo,
ColumnInfo,
PartitionInfo,
RowFormatInfo,
TableProperties,
TransformInfo,
)
class TestAthenaPropertiesExtractor:
"""Test class for AthenaPropertiesExtractor."""
def test_iceberg_table_with_complex_partitioning(self):
"""Test extraction from Iceberg table with complex partitioning."""
sql = """
CREATE TABLE iceberg_table (ts timestamp, id bigint, data string, category string)
PARTITIONED BY (category, bucket(16, id), year(ts), month(ts), day(ts), hour(ts), truncate(10, ts))
LOCATION 's3://amzn-s3-demo-bucket/your-folder/'
TBLPROPERTIES ( 'table_type' = 'ICEBERG' ) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
assert isinstance(result.partition_info, PartitionInfo)
assert isinstance(result.table_properties, TableProperties)
assert isinstance(result.row_format, RowFormatInfo)
# Test partition info
partition_info = result.partition_info
# Should have multiple simple columns
assert len(partition_info.simple_columns) > 0
# Check for category column (simple partition)
category_cols = [
col for col in partition_info.simple_columns if col.name == "category"
]
assert len(category_cols) == 1
assert category_cols[0].type == "TEXT"
# Check for id column (used in bucket transform)
id_cols = [col for col in partition_info.simple_columns if col.name == "id"]
assert len(id_cols) == 1
assert id_cols[0].type == "BIGINT"
# Check for ts column (used in time transforms)
ts_cols = [col for col in partition_info.simple_columns if col.name == "ts"]
assert len(ts_cols) == 1
assert ts_cols[0].type == "TIMESTAMP"
# Test transforms
transforms = partition_info.transforms
assert len(transforms) >= 6 # bucket, year, month, day, hour, truncate
# Check bucket transform
bucket_transforms = [t for t in transforms if t.type == "bucket"]
assert len(bucket_transforms) == 1
bucket_transform = bucket_transforms[0]
assert bucket_transform.column.name == "id"
assert bucket_transform.bucket_count == 16
# Check time transforms
time_transform_types = {
t.type for t in transforms if t.type in ["year", "month", "day", "hour"]
}
assert "year" in time_transform_types
assert "month" in time_transform_types
assert "day" in time_transform_types
assert "hour" in time_transform_types
# Check truncate transform
truncate_transforms = [t for t in transforms if t.type == "truncate"]
assert len(truncate_transforms) == 1
truncate_transform = truncate_transforms[0]
assert truncate_transform.column.name == "ts"
assert truncate_transform.length == 10
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://amzn-s3-demo-bucket/your-folder/"
assert table_props.additional_properties is not None
assert table_props.additional_properties.get("table_type") == "ICEBERG"
def test_trino_table_with_array_partitioning(self):
"""Test extraction from Trino table with ARRAY partitioning."""
sql = """
create table trino.db_collection (
col1 varchar,
col2 varchar,
col3 varchar
)with (
external_location = 's3a://bucket/trino/db_collection/*',
format = 'PARQUET',
partitioned_by = ARRAY['col1','col2']
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3a://bucket/trino/db_collection/*"
assert table_props.format == "PARQUET"
# Note: ARRAY partitioning might not be parsed the same way as standard PARTITIONED BY
# This tests that the extraction doesn't fail and extracts what it can
def test_simple_orc_table(self):
"""Test extraction from simple ORC table."""
sql = """
CREATE TABLE orders (
orderkey bigint,
orderstatus varchar,
totalprice double,
orderdate date
)
WITH (format = 'ORC')
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Should have no partitions
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
# Test table properties
table_props = result.table_properties
assert table_props.format == "ORC"
assert table_props.location is None
assert table_props.comment is None
def test_table_with_comments(self):
"""Test extraction from table with table and column comments."""
sql = """
CREATE TABLE IF NOT EXISTS orders (
orderkey bigint,
orderstatus varchar,
totalprice double COMMENT 'Price in cents.',
orderdate date
)
COMMENT 'A table to keep track of orders.' \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test table comment
table_props = result.table_properties
assert table_props.comment == "A table to keep track of orders."
# No partitions expected
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
def test_table_with_row_format_and_serde(self):
"""Test extraction from table with row format and SERDE properties."""
sql = """
CREATE TABLE IF NOT EXISTS orders (
orderkey bigint,
orderstatus varchar,
totalprice double,
orderdate date
)
ROW FORMAT DELIMITED COLLECTION ITEMS TERMINATED BY ','
STORED AS PARQUET
WITH SERDEPROPERTIES (
'serialization.format' = '1'
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test table properties
table_props = result.table_properties
# Test SERDE properties
assert table_props.serde_properties is not None
assert table_props.serde_properties.get("serialization.format") == "1"
# Test row format
row_format = result.row_format
assert isinstance(row_format, RowFormatInfo)
assert isinstance(row_format.properties, dict)
assert "No RowFormatDelimitedProperty found" not in row_format.json_formatted
def test_empty_sql_raises_error(self):
"""Test that empty SQL raises appropriate error."""
with pytest.raises(
AthenaPropertiesExtractionError, match="SQL statement cannot be empty"
):
AthenaPropertiesExtractor.get_table_properties("")
with pytest.raises(
AthenaPropertiesExtractionError, match="SQL statement cannot be empty"
):
AthenaPropertiesExtractor.get_table_properties(" ")
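        # A hedged extension (an assumption, not a documented contract): input
        # that is only tabs/newlines is expected to behave like the
        # space-only case above, i.e. the extractor strips all whitespace
        # before its emptiness check.
        with pytest.raises(
            AthenaPropertiesExtractionError, match="SQL statement cannot be empty"
        ):
            AthenaPropertiesExtractor.get_table_properties("\t\n  ")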
def test_minimal_create_table(self):
"""Test extraction from minimal CREATE TABLE statement."""
sql = "CREATE TABLE test (id int)"
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Should not fail and return basic structure
assert isinstance(result, AthenaTableInfo)
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
assert result.table_properties.location is None
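        # A minimal sketch of keyword-case handling: the lowercase Trino-style
        # statements elsewhere in this file parse, so lowercase keywords are
        # assumed to be accepted for a minimal statement as well.
        lowercase_result = AthenaPropertiesExtractor.get_table_properties(
            "create table test (id int)"
        )
        assert isinstance(lowercase_result, AthenaTableInfo)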
def test_column_info_dataclass(self):
"""Test ColumnInfo dataclass properties."""
sql = """
CREATE TABLE test (id bigint, name varchar)
PARTITIONED BY (id) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test that we get ColumnInfo objects
assert len(result.partition_info.simple_columns) == 1
column = result.partition_info.simple_columns[0]
assert isinstance(column, ColumnInfo)
assert column.name == "id"
assert column.type == "BIGINT"
def test_transform_info_dataclass(self):
"""Test TransformInfo dataclass properties."""
sql = """
CREATE TABLE test (ts timestamp, id bigint)
PARTITIONED BY (year(ts), bucket(8, id)) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
transforms = result.partition_info.transforms
assert len(transforms) >= 2
# Find year transform
year_transforms = [t for t in transforms if t.type == "year"]
assert len(year_transforms) == 1
year_transform = year_transforms[0]
assert isinstance(year_transform, TransformInfo)
assert year_transform.type == "year"
assert isinstance(year_transform.column, ColumnInfo)
assert year_transform.column.name == "ts"
assert year_transform.bucket_count is None
assert year_transform.length is None
# Find bucket transform
bucket_transforms = [t for t in transforms if t.type == "bucket"]
assert len(bucket_transforms) == 1
bucket_transform = bucket_transforms[0]
assert isinstance(bucket_transform, TransformInfo)
assert bucket_transform.type == "bucket"
assert bucket_transform.column.name == "id"
assert bucket_transform.bucket_count == 8
assert bucket_transform.length is None
def test_multiple_sql_statements_stateless(self):
"""Test that the extractor is stateless and works with multiple SQL statements."""
sql1 = "CREATE TABLE test1 (id int) WITH (format = 'PARQUET')"
sql2 = "CREATE TABLE test2 (name varchar) WITH (format = 'ORC')"
# Call multiple times to ensure no state interference
result1 = AthenaPropertiesExtractor.get_table_properties(sql1)
result2 = AthenaPropertiesExtractor.get_table_properties(sql2)
result1_again = AthenaPropertiesExtractor.get_table_properties(sql1)
# Results should be consistent
assert result1.table_properties.format == "PARQUET"
assert result2.table_properties.format == "ORC"
assert result1_again.table_properties.format == "PARQUET"
# Results should be independent
assert result1.table_properties.format != result2.table_properties.format
@pytest.mark.parametrize(
"sql,expected_location",
[
(
"CREATE TABLE test (id int) LOCATION 's3://bucket/path/'",
"s3://bucket/path/",
),
("CREATE TABLE test (id int)", None),
],
)
def test_location_extraction_parametrized(self, sql, expected_location):
"""Test location extraction with parametrized inputs."""
result = AthenaPropertiesExtractor.get_table_properties(sql)
assert result.table_properties.location == expected_location
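    # A companion sketch to the location test above: it mirrors the format
    # assertions from test_multiple_sql_statements_stateless, and the
    # WITH (format = ...) inputs are the only forms this file confirms.
    @pytest.mark.parametrize(
        "sql,expected_format",
        [
            ("CREATE TABLE test (id int) WITH (format = 'PARQUET')", "PARQUET"),
            ("CREATE TABLE test (id int) WITH (format = 'ORC')", "ORC"),
        ],
    )
    def test_format_extraction_parametrized(self, sql, expected_format):
        """Test format extraction with parametrized inputs."""
        result = AthenaPropertiesExtractor.get_table_properties(sql)
        assert result.table_properties.format == expected_format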
# Integration-style tests with larger, real-world-like CREATE TABLE statements
class TestAthenaPropertiesExtractorIntegration:
"""Integration tests for AthenaPropertiesExtractor."""
def test_complex_real_world_example(self):
"""Test with a complex real-world-like example."""
sql = """
CREATE TABLE analytics.user_events (
user_id bigint COMMENT 'Unique user identifier',
event_time timestamp COMMENT 'When the event occurred',
event_type varchar COMMENT 'Type of event',
session_id varchar,
properties map<varchar, varchar> COMMENT 'Event properties',
created_date date
)
COMMENT 'User event tracking table'
PARTITIONED BY (
created_date,
bucket(100, user_id),
hour(event_time)
)
LOCATION 's3://analytics-bucket/user-events/'
STORED AS PARQUET
TBLPROPERTIES (
'table_type' = 'ICEBERG',
'write.target-file-size-bytes' = '134217728',
'write.delete.mode' = 'copy-on-write'
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Comprehensive validation
assert isinstance(result, AthenaTableInfo)
# Check table properties
props = result.table_properties
assert props.location == "s3://analytics-bucket/user-events/"
assert props.comment == "User event tracking table"
assert props.additional_properties is not None
assert props.additional_properties.get("table_type") == "ICEBERG"
# Check partitioning
partition_info = result.partition_info
        # Should have user_id (from the bucket transform) and created_date as simple partition columns
        user_id_cols = [
            col for col in partition_info.simple_columns if col.name == "user_id"
        ]
        assert len(user_id_cols) == 1
        assert user_id_cols[0].type == "BIGINT"
date_cols = [
col for col in partition_info.simple_columns if col.name == "created_date"
]
assert len(date_cols) == 1
assert date_cols[0].type == "DATE"
# Should have transforms
transforms = partition_info.transforms
transform_types = {t.type for t in transforms}
assert "bucket" in transform_types
assert "hour" in transform_types
# Validate bucket transform
bucket_transforms = [t for t in transforms if t.type == "bucket"]
assert len(bucket_transforms) == 1
assert bucket_transforms[0].bucket_count == 100
assert bucket_transforms[0].column.name == "user_id"
def test_external_table_with_row_format_delimited(self):
"""Test extraction from external table with detailed row format."""
sql = """
CREATE EXTERNAL TABLE `my_table`(
`itcf id` string,
`itcf control name` string,
`itcf control description` string,
`itcf process` string,
`standard` string,
`controlid` string,
`threshold` string,
`status` string,
`date reported` string,
`remediation (accs specific)` string,
`aws account id` string,
`aws resource id` string,
`aws account owner` string)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ','
ESCAPED BY '\\\\'
LINES TERMINATED BY '\\n'
LOCATION
's3://myfolder/'
TBLPROPERTIES (
'skip.header.line.count'='1');
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://myfolder/"
# Test TBLPROPERTIES
assert table_props.additional_properties is not None
assert table_props.additional_properties.get("skip.header.line.count") == "1"
# Test row format
row_format = result.row_format
assert isinstance(row_format, RowFormatInfo)
# The row format should contain delimited properties
# Note: The exact keys depend on how sqlglot parses ROW FORMAT DELIMITED
assert isinstance(row_format.properties, dict)
# Should have structured JSON output
assert row_format.json_formatted != "No RowFormatDelimitedProperty found"
# Should not have partitions (no PARTITIONED BY clause)
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
def test_database_qualified_table_with_iceberg_properties(self):
"""Test extraction from database-qualified table with Iceberg properties."""
sql = """
CREATE TABLE mydatabase.my_table (
id string,
name string,
type string,
industry string,
annual_revenue double,
website string,
phone string,
billing_street string,
billing_city string,
billing_state string,
billing_postal_code string,
billing_country string,
shipping_street string,
shipping_city string,
shipping_state string,
shipping_postal_code string,
shipping_country string,
number_of_employees int,
description string,
owner_id string,
created_date timestamp,
last_modified_date timestamp,
is_deleted boolean)
LOCATION 's3://mybucket/myfolder/'
TBLPROPERTIES (
'table_type'='iceberg',
'write_compression'='snappy',
'format'='parquet',
'optimize_rewrite_delete_file_threshold'='10'
); \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://mybucket/myfolder/"
# Test multiple TBLPROPERTIES
assert table_props.additional_properties is not None
expected_props = {
"table_type": "iceberg",
"write_compression": "snappy",
"format": "parquet",
"optimize_rewrite_delete_file_threshold": "10",
}
for key, expected_value in expected_props.items():
assert table_props.additional_properties.get(key) == expected_value, (
f"Expected {key}={expected_value}, got {table_props.additional_properties.get(key)}"
)
# Should not have partitions (no PARTITIONED BY clause)
assert len(result.partition_info.simple_columns) == 0
assert len(result.partition_info.transforms) == 0
# Row format should be empty/default
row_format = result.row_format
assert isinstance(row_format, RowFormatInfo)
# Should either be empty dict or indicate no row format found
assert (
len(row_format.properties) == 0
or "No RowFormatDelimitedProperty found" in row_format.json_formatted
)
def test_iceberg_table_with_backtick_partitioning(self):
"""Test extraction from Iceberg table with backtick-quoted partition functions."""
sql = """
CREATE TABLE datalake_agg.ml_outdoor_master (
event_uuid string,
uuid string,
_pk string)
PARTITIONED BY (
`day(event_timestamp)`,
`month(event_timestamp)`
)
LOCATION 's3://bucket/folder/table'
TBLPROPERTIES (
'table_type'='iceberg',
'vacuum_max_snapshot_age_seconds'='60',
'format'='PARQUET',
'write_compression'='GZIP',
'optimize_rewrite_delete_file_threshold'='2',
'optimize_rewrite_data_file_threshold'='5',
'vacuum_min_snapshots_to_keep'='6'
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Test basic structure
assert isinstance(result, AthenaTableInfo)
# Test table properties
table_props = result.table_properties
assert table_props.location == "s3://bucket/folder/table"
# Test comprehensive TBLPROPERTIES for Iceberg
assert table_props.additional_properties is not None
expected_props = {
"table_type": "iceberg",
"vacuum_max_snapshot_age_seconds": "60",
"format": "PARQUET",
"write_compression": "GZIP",
"optimize_rewrite_delete_file_threshold": "2",
"optimize_rewrite_data_file_threshold": "5",
"vacuum_min_snapshots_to_keep": "6",
}
for key, expected_value in expected_props.items():
actual_value = table_props.additional_properties.get(key)
assert actual_value == expected_value, (
f"Expected {key}={expected_value}, got {actual_value}"
)
# Test partition info - this is the interesting part with backtick-quoted functions
partition_info = result.partition_info
# Should have transforms for day() and month() functions
transforms = partition_info.transforms
assert len(transforms) >= 2, (
f"Expected at least 2 transforms, got {len(transforms)}"
)
# Check for day transform
day_transforms = [t for t in transforms if t.type == "day"]
assert len(day_transforms) >= 1, (
f"Expected day transform, transforms: {[t.type for t in transforms]}"
)
if day_transforms:
day_transform = day_transforms[0]
assert isinstance(day_transform, TransformInfo)
assert day_transform.type == "day"
assert isinstance(day_transform.column, ColumnInfo)
# The column should be event_timestamp (extracted from day(event_timestamp))
assert day_transform.column.name == "event_timestamp"
# Check for month transform
month_transforms = [t for t in transforms if t.type == "month"]
assert len(month_transforms) >= 1, (
f"Expected month transform, transforms: {[t.type for t in transforms]}"
)
if month_transforms:
month_transform = month_transforms[0]
assert isinstance(month_transform, TransformInfo)
assert month_transform.type == "month"
assert isinstance(month_transform.column, ColumnInfo)
# The column should be event_timestamp (extracted from month(event_timestamp))
assert month_transform.column.name == "event_timestamp"
# Test simple columns - should include event_timestamp from the transforms
simple_columns = partition_info.simple_columns
event_timestamp_cols = [
col for col in simple_columns if col.name == "event_timestamp"
]
assert len(event_timestamp_cols) >= 1, (
f"Expected event_timestamp column, columns: {[col.name for col in simple_columns]}"
)
# The event_timestamp column type might be "unknown" since it's not in the table definition
# but referenced in partitioning - this tests our defensive handling
if event_timestamp_cols:
event_timestamp_col = event_timestamp_cols[0]
assert isinstance(event_timestamp_col, ColumnInfo)
assert event_timestamp_col.name == "event_timestamp"
# Type should be "unknown" since event_timestamp is not in the table columns
assert event_timestamp_col.type == "unknown"
def test_partition_function_extraction_edge_cases(self):
"""Test edge cases in partition function extraction with various formats."""
sql = """
CREATE TABLE test_partitions (
ts timestamp,
id bigint,
data string
)
PARTITIONED BY (
`day(ts)`,
`bucket(5, id)`,
`truncate(100, data)`
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
partition_info = result.partition_info
transforms = partition_info.transforms
# Should have 3 transforms
assert len(transforms) == 3
# Verify each transform type exists
transform_types = {t.type for t in transforms}
assert "day" in transform_types
assert "bucket" in transform_types
assert "truncate" in transform_types
# Test bucket transform parameters
bucket_transforms = [t for t in transforms if t.type == "bucket"]
if bucket_transforms:
bucket_transform = bucket_transforms[0]
assert bucket_transform.bucket_count == 5
assert bucket_transform.column.name == "id"
assert bucket_transform.column.type == "BIGINT"
# Test truncate transform parameters
truncate_transforms = [t for t in transforms if t.type == "truncate"]
if truncate_transforms:
truncate_transform = truncate_transforms[0]
assert truncate_transform.length == 100
assert truncate_transform.column.name == "data"
assert truncate_transform.column.type == "TEXT"
# Test day transform
day_transforms = [t for t in transforms if t.type == "day"]
if day_transforms:
day_transform = day_transforms[0]
assert day_transform.column.name == "ts"
assert day_transform.column.type == "TIMESTAMP"
def test_partition_function_extraction_edge_cases_with_different_quote(self):
"""Test edge cases in partition function extraction with various formats."""
sql = """
CREATE TABLE test_partitions (
ts timestamp,
id bigint,
data string
)
PARTITIONED BY (
day(`ts`),
bucket(5, `id`),
truncate(100, `data`)
) \
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
partition_info = result.partition_info
transforms = partition_info.transforms
# Should have 3 transforms
assert len(transforms) == 3
# Verify each transform type exists
transform_types = {t.type for t in transforms}
assert "day" in transform_types
assert "bucket" in transform_types
assert "truncate" in transform_types
# Test bucket transform parameters
bucket_transforms = [t for t in transforms if t.type == "bucket"]
if bucket_transforms:
bucket_transform = bucket_transforms[0]
assert bucket_transform.bucket_count == 5
assert bucket_transform.column.name == "id"
assert bucket_transform.column.type == "BIGINT"
# Test truncate transform parameters
truncate_transforms = [t for t in transforms if t.type == "truncate"]
if truncate_transforms:
truncate_transform = truncate_transforms[0]
assert truncate_transform.length == 100
assert truncate_transform.column.name == "data"
assert truncate_transform.column.type == "TEXT"
# Test day transform
day_transforms = [t for t in transforms if t.type == "day"]
if day_transforms:
day_transform = day_transforms[0]
assert day_transform.column.name == "ts"
assert day_transform.column.type == "TIMESTAMP"
def test_complex_real_world_example_with_non_escaped_column_name_and_column_comment(
self,
):
"""Athena's show create table statement doesn't return columns in escaped."""
sql = """
CREATE TABLE test_schema.test_table (
date___hour timestamp,
month string COMMENT 'Month of the year',
date string,
hourly_forecast bigint,
previous_year's_sales bigint COMMENT Previous year's sales,
sheet_name string,
_id string)
LOCATION 's3://analytics-bucket/user-events/'
TBLPROPERTIES (
'table_type'='iceberg',
'vacuum_max_snapshot_age_seconds'='60',
'write_compression'='gzip',
'format'='parquet',
'optimize_rewrite_delete_file_threshold'='2',
'optimize_rewrite_data_file_threshold'='5',
'vacuum_min_snapshots_to_keep'='6'
)
"""
result = AthenaPropertiesExtractor.get_table_properties(sql)
# Comprehensive validation
assert isinstance(result, AthenaTableInfo)
# Check table properties
props = result.table_properties
assert props.location == "s3://analytics-bucket/user-events/"
assert props.additional_properties is not None
assert props.additional_properties.get("table_type") == "iceberg"
assert (
props.additional_properties.get("vacuum_max_snapshot_age_seconds") == "60"
)
assert props.additional_properties.get("format") == "parquet"
assert props.additional_properties.get("write_compression") == "gzip"
assert (
props.additional_properties.get("optimize_rewrite_delete_file_threshold")
== "2"
)
assert (
props.additional_properties.get("optimize_rewrite_data_file_threshold")
== "5"
)
assert props.additional_properties.get("vacuum_min_snapshots_to_keep") == "6"