mirror of
https://github.com/datahub-project/datahub.git
synced 2025-12-26 09:26:22 +00:00
fix(ingest/athena): Fix Athena partition extraction and CONCAT function type issues (#14712)
This commit is contained in:
parent
c7ad3f45ea
commit
a82d4e0647
@ -73,6 +73,11 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Precompiled regex for SQL identifier validation
|
||||
# Athena identifiers can only contain lowercase letters, numbers, underscore, and period (for complex types)
|
||||
# Note: Athena automatically converts uppercase to lowercase, but we're being strict for security
|
||||
_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9_.]+$")
|
||||
|
||||
assert STRUCT, "required type modules are not available"
|
||||
register_custom_type(STRUCT, RecordTypeClass)
|
||||
register_custom_type(MapType, MapTypeClass)
|
||||
@ -510,20 +515,76 @@ class AthenaSource(SQLAlchemySource):
|
||||
return [schema for schema in schemas if schema == athena_config.database]
|
||||
return schemas
|
||||
|
||||
@classmethod
|
||||
def _sanitize_identifier(cls, identifier: str) -> str:
|
||||
"""Sanitize SQL identifiers to prevent injection attacks.
|
||||
|
||||
Args:
|
||||
identifier: The SQL identifier to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized identifier safe for SQL queries
|
||||
|
||||
Raises:
|
||||
ValueError: If identifier contains unsafe characters
|
||||
"""
|
||||
if not identifier:
|
||||
raise ValueError("Identifier cannot be empty")
|
||||
|
||||
# Allow only alphanumeric characters, underscores, and periods for identifiers
|
||||
# This matches Athena's identifier naming rules
|
||||
if not _IDENTIFIER_PATTERN.match(identifier):
|
||||
raise ValueError(
|
||||
f"Identifier '{identifier}' contains unsafe characters. Only alphanumeric characters, underscores, and periods are allowed."
|
||||
)
|
||||
|
||||
return identifier
|
||||
|
||||
@classmethod
|
||||
def _casted_partition_key(cls, key: str) -> str:
|
||||
# We need to cast the partition keys to a VARCHAR, since otherwise
|
||||
# Athena may throw an error during concatenation / comparison.
|
||||
return f"CAST({key} as VARCHAR)"
|
||||
sanitized_key = cls._sanitize_identifier(key)
|
||||
return f"CAST({sanitized_key} as VARCHAR)"
|
||||
|
||||
@classmethod
|
||||
def _build_max_partition_query(
|
||||
cls, schema: str, table: str, partitions: List[str]
|
||||
) -> str:
|
||||
"""Build SQL query to find the row with maximum partition values.
|
||||
|
||||
Args:
|
||||
schema: Database schema name
|
||||
table: Table name
|
||||
partitions: List of partition column names
|
||||
|
||||
Returns:
|
||||
SQL query string to find the maximum partition
|
||||
|
||||
Raises:
|
||||
ValueError: If any identifier contains unsafe characters
|
||||
"""
|
||||
# Sanitize all identifiers to prevent SQL injection
|
||||
sanitized_schema = cls._sanitize_identifier(schema)
|
||||
sanitized_table = cls._sanitize_identifier(table)
|
||||
sanitized_partitions = [
|
||||
cls._sanitize_identifier(partition) for partition in partitions
|
||||
]
|
||||
|
||||
casted_keys = [cls._casted_partition_key(key) for key in partitions]
|
||||
if len(casted_keys) == 1:
|
||||
part_concat = casted_keys[0]
|
||||
else:
|
||||
separator = "CAST('-' AS VARCHAR)"
|
||||
part_concat = f"CONCAT({f', {separator}, '.join(casted_keys)})"
|
||||
|
||||
return f'select {",".join(sanitized_partitions)} from "{sanitized_schema}"."{sanitized_table}$partitions" where {part_concat} = (select max({part_concat}) from "{sanitized_schema}"."{sanitized_table}$partitions")'
|
||||
|
||||
@override
|
||||
def get_partitions(
|
||||
self, inspector: Inspector, schema: str, table: str
|
||||
) -> Optional[List[str]]:
|
||||
if (
|
||||
not self.config.extract_partitions
|
||||
and not self.config.extract_partitions_using_create_statements
|
||||
):
|
||||
if not self.config.extract_partitions:
|
||||
return None
|
||||
|
||||
if not self.cursor:
|
||||
@ -557,11 +618,9 @@ class AthenaSource(SQLAlchemySource):
|
||||
context=f"{schema}.{table}",
|
||||
level=StructuredLogLevel.WARN,
|
||||
):
|
||||
# We create an artifical concatenated partition key to be able to query max partition easier
|
||||
part_concat = " || '-' || ".join(
|
||||
self._casted_partition_key(key) for key in partitions
|
||||
max_partition_query = self._build_max_partition_query(
|
||||
schema, table, partitions
|
||||
)
|
||||
max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
|
||||
ret = self.cursor.execute(max_partition_query)
|
||||
max_partition: Dict[str, str] = {}
|
||||
if ret:
|
||||
@ -678,16 +737,34 @@ class AthenaSource(SQLAlchemySource):
|
||||
).get(table, None)
|
||||
|
||||
if partition and partition.max_partition:
|
||||
max_partition_filters = []
|
||||
for key, value in partition.max_partition.items():
|
||||
max_partition_filters.append(
|
||||
f"{self._casted_partition_key(key)} = '{value}'"
|
||||
try:
|
||||
# Sanitize identifiers to prevent SQL injection
|
||||
sanitized_schema = self._sanitize_identifier(schema)
|
||||
sanitized_table = self._sanitize_identifier(table)
|
||||
|
||||
max_partition_filters = []
|
||||
for key, value in partition.max_partition.items():
|
||||
# Sanitize partition key and properly escape the value
|
||||
sanitized_key = self._sanitize_identifier(key)
|
||||
# Escape single quotes in the value to prevent injection
|
||||
escaped_value = value.replace("'", "''") if value else ""
|
||||
max_partition_filters.append(
|
||||
f"{self._casted_partition_key(sanitized_key)} = '{escaped_value}'"
|
||||
)
|
||||
max_partition = str(partition.max_partition)
|
||||
return (
|
||||
max_partition,
|
||||
f'SELECT * FROM "{sanitized_schema}"."{sanitized_table}" WHERE {" AND ".join(max_partition_filters)}',
|
||||
)
|
||||
max_partition = str(partition.max_partition)
|
||||
return (
|
||||
max_partition,
|
||||
f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(max_partition_filters)}',
|
||||
)
|
||||
except ValueError as e:
|
||||
# If sanitization fails due to malicious identifiers,
|
||||
# return None to disable partition profiling for this table
|
||||
# rather than crashing the entire ingestion
|
||||
logger.warning(
|
||||
f"Failed to generate partition profiler query for {schema}.{table} due to unsafe identifiers: {e}. "
|
||||
f"Partition profiling disabled for this table."
|
||||
)
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
def close(self):
|
||||
|
||||
@ -174,20 +174,16 @@ class AthenaPropertiesExtractor:
|
||||
def format_column_definition(line):
|
||||
# Use regex to parse the line more accurately
|
||||
# Pattern: column_name data_type [COMMENT comment_text] [,]
|
||||
# Use greedy match for comment to capture everything until trailing comma
|
||||
pattern = r"^\s*(.+?)\s+([\s,\w<>\[\]]+)((\s+COMMENT\s+(.+?)(,?))|(,?)\s*)?$"
|
||||
match = re.match(pattern, line, re.IGNORECASE)
|
||||
# Improved pattern to better separate column name, data type, and comment
|
||||
pattern = r"^\s*([`\w']+)\s+([\w<>\[\](),\s]+?)(\s+COMMENT\s+(.+?))?(,?)\s*$"
|
||||
match = re.match(pattern, line.strip(), re.IGNORECASE)
|
||||
|
||||
if not match:
|
||||
return line
|
||||
column_name = match.group(1)
|
||||
data_type = match.group(2)
|
||||
comment_part = match.group(5) # COMMENT part
|
||||
# there are different number of match groups depending on whether comment exists
|
||||
if comment_part:
|
||||
trailing_comma = match.group(6) if match.group(6) else ""
|
||||
else:
|
||||
trailing_comma = match.group(7) if match.group(7) else ""
|
||||
column_name = match.group(1).strip()
|
||||
data_type = match.group(2).strip()
|
||||
comment_part = match.group(4) # COMMENT part
|
||||
trailing_comma = match.group(5) if match.group(5) else ""
|
||||
|
||||
# Add backticks to column name if not already present
|
||||
if not (column_name.startswith("`") and column_name.endswith("`")):
|
||||
@ -201,7 +197,7 @@ class AthenaPropertiesExtractor:
|
||||
|
||||
# Handle comment quoting and escaping
|
||||
if comment_part.startswith("'") and comment_part.endswith("'"):
|
||||
# Already properly single quoted - but check for proper escaping
|
||||
# Already single quoted - but check for proper escaping
|
||||
inner_content = comment_part[1:-1]
|
||||
# Re-escape any single quotes that aren't properly escaped
|
||||
escaped_content = inner_content.replace("'", "''")
|
||||
|
||||
@ -3,10 +3,12 @@ from typing import List
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import sqlglot
|
||||
from freezegun import freeze_time
|
||||
from pyathena import OperationalError
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy_bigquery import STRUCT
|
||||
from sqlglot.dialects import Athena
|
||||
|
||||
from datahub.ingestion.api.common import PipelineContext
|
||||
from datahub.ingestion.source.aws.s3_util import make_s3_urn
|
||||
@ -14,6 +16,7 @@ from datahub.ingestion.source.sql.athena import (
|
||||
AthenaConfig,
|
||||
AthenaSource,
|
||||
CustomAthenaRestDialect,
|
||||
Partitionitem,
|
||||
)
|
||||
from datahub.metadata.schema_classes import (
|
||||
ArrayTypeClass,
|
||||
@ -182,8 +185,8 @@ def test_athena_get_table_properties():
|
||||
|
||||
expected_query = """\
|
||||
select year,month from "test_schema"."test_table$partitions" \
|
||||
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \
|
||||
(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \
|
||||
where CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR)) = \
|
||||
(select max(CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR))) \
|
||||
from "test_schema"."test_table$partitions")"""
|
||||
assert mock_cursor.execute.call_count == 2
|
||||
assert expected_create_table_query == mock_cursor.execute.call_args_list[0][0][0]
|
||||
@ -503,3 +506,641 @@ def test_convert_simple_field_paths_to_v1_default_behavior():
|
||||
)
|
||||
|
||||
assert config.emit_schema_fieldpaths_as_v1 is False # Should default to False
|
||||
|
||||
|
||||
def test_get_partitions_returns_none_when_extract_partitions_disabled():
|
||||
"""Test that get_partitions returns None when extract_partitions is False"""
|
||||
config = AthenaConfig.parse_obj(
|
||||
{
|
||||
"aws_region": "us-west-1",
|
||||
"query_result_location": "s3://query-result-location/",
|
||||
"work_group": "test-workgroup",
|
||||
"extract_partitions": False,
|
||||
}
|
||||
)
|
||||
|
||||
ctx = PipelineContext(run_id="test")
|
||||
source = AthenaSource(config=config, ctx=ctx)
|
||||
|
||||
# Mock inspector - should not be used if extract_partitions is False
|
||||
mock_inspector = mock.MagicMock()
|
||||
|
||||
# Call get_partitions - should return None immediately without any database calls
|
||||
result = source.get_partitions(mock_inspector, "test_schema", "test_table")
|
||||
|
||||
# Verify result is None
|
||||
assert result is None
|
||||
|
||||
# Verify that no inspector methods were called
|
||||
assert not mock_inspector.called, (
|
||||
"Inspector should not be called when extract_partitions=False"
|
||||
)
|
||||
|
||||
|
||||
def test_get_partitions_attempts_extraction_when_extract_partitions_enabled():
|
||||
"""Test that get_partitions attempts partition extraction when extract_partitions is True"""
|
||||
config = AthenaConfig.parse_obj(
|
||||
{
|
||||
"aws_region": "us-west-1",
|
||||
"query_result_location": "s3://query-result-location/",
|
||||
"work_group": "test-workgroup",
|
||||
"extract_partitions": True,
|
||||
}
|
||||
)
|
||||
|
||||
ctx = PipelineContext(run_id="test")
|
||||
source = AthenaSource(config=config, ctx=ctx)
|
||||
|
||||
# Mock inspector and cursor for partition extraction
|
||||
mock_inspector = mock.MagicMock()
|
||||
mock_cursor = mock.MagicMock()
|
||||
|
||||
# Mock the table metadata response
|
||||
mock_metadata = mock.MagicMock()
|
||||
mock_partition_key = mock.MagicMock()
|
||||
mock_partition_key.name = "year"
|
||||
mock_metadata.partition_keys = [mock_partition_key]
|
||||
mock_cursor.get_table_metadata.return_value = mock_metadata
|
||||
|
||||
# Set the cursor on the source
|
||||
source.cursor = mock_cursor
|
||||
|
||||
# Call get_partitions - should attempt partition extraction
|
||||
result = source.get_partitions(mock_inspector, "test_schema", "test_table")
|
||||
|
||||
# Verify that the cursor was used (partition extraction was attempted)
|
||||
mock_cursor.get_table_metadata.assert_called_once_with(
|
||||
table_name="test_table", schema_name="test_schema"
|
||||
)
|
||||
|
||||
# Result should be a list (even if empty)
|
||||
assert isinstance(result, list)
|
||||
assert result == ["year"] # Should contain the partition key name
|
||||
|
||||
|
||||
def test_partition_profiling_sql_generation_single_key():
|
||||
"""Test that partition profiling generates valid SQL for single partition key and can be parsed by SQLGlot."""
|
||||
|
||||
config = AthenaConfig.parse_obj(
|
||||
{
|
||||
"aws_region": "us-west-1",
|
||||
"query_result_location": "s3://query-result-location/",
|
||||
"work_group": "test-workgroup",
|
||||
"extract_partitions": True,
|
||||
"profiling": {"enabled": True, "partition_profiling_enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
ctx = PipelineContext(run_id="test")
|
||||
source = AthenaSource(config=config, ctx=ctx)
|
||||
|
||||
# Mock cursor and metadata for single partition key
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_metadata = mock.MagicMock()
|
||||
mock_partition_key = mock.MagicMock()
|
||||
mock_partition_key.name = "year"
|
||||
mock_metadata.partition_keys = [mock_partition_key]
|
||||
mock_cursor.get_table_metadata.return_value = mock_metadata
|
||||
|
||||
# Mock successful partition query execution
|
||||
mock_result = mock.MagicMock()
|
||||
mock_result.description = [["year"]]
|
||||
mock_result.__iter__ = lambda x: iter([["2023"]])
|
||||
mock_cursor.execute.return_value = mock_result
|
||||
|
||||
source.cursor = mock_cursor
|
||||
|
||||
# Call get_partitions to trigger SQL generation
|
||||
result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")
|
||||
|
||||
# Get the generated SQL query
|
||||
assert mock_cursor.execute.called
|
||||
generated_query = mock_cursor.execute.call_args[0][0]
|
||||
|
||||
# Verify the query structure for single partition key
|
||||
assert "CAST(year as VARCHAR)" in generated_query
|
||||
assert "CONCAT" not in generated_query # Single key shouldn't use CONCAT
|
||||
assert '"test_schema"."test_table$partitions"' in generated_query
|
||||
|
||||
# Validate that SQLGlot can parse the generated query using Athena dialect
|
||||
try:
|
||||
parsed = sqlglot.parse_one(generated_query, dialect=Athena)
|
||||
assert parsed is not None
|
||||
assert isinstance(parsed, sqlglot.expressions.Select)
|
||||
print(f"✅ Single partition SQL parsed successfully: {generated_query}")
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"SQLGlot failed to parse single partition query: {e}\nQuery: {generated_query}"
|
||||
)
|
||||
|
||||
assert result == ["year"]
|
||||
|
||||
|
||||
def test_partition_profiling_sql_generation_multiple_keys():
|
||||
"""Test that partition profiling generates valid SQL for multiple partition keys and can be parsed by SQLGlot."""
|
||||
|
||||
config = AthenaConfig.parse_obj(
|
||||
{
|
||||
"aws_region": "us-west-1",
|
||||
"query_result_location": "s3://query-result-location/",
|
||||
"work_group": "test-workgroup",
|
||||
"extract_partitions": True,
|
||||
"profiling": {"enabled": True, "partition_profiling_enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
ctx = PipelineContext(run_id="test")
|
||||
source = AthenaSource(config=config, ctx=ctx)
|
||||
|
||||
# Mock cursor and metadata for multiple partition keys
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_metadata = mock.MagicMock()
|
||||
|
||||
mock_year_key = mock.MagicMock()
|
||||
mock_year_key.name = "year"
|
||||
mock_month_key = mock.MagicMock()
|
||||
mock_month_key.name = "month"
|
||||
mock_day_key = mock.MagicMock()
|
||||
mock_day_key.name = "day"
|
||||
|
||||
mock_metadata.partition_keys = [mock_year_key, mock_month_key, mock_day_key]
|
||||
mock_cursor.get_table_metadata.return_value = mock_metadata
|
||||
|
||||
# Mock successful partition query execution
|
||||
mock_result = mock.MagicMock()
|
||||
mock_result.description = [["year"], ["month"], ["day"]]
|
||||
mock_result.__iter__ = lambda x: iter([["2023", "12", "25"]])
|
||||
mock_cursor.execute.return_value = mock_result
|
||||
|
||||
source.cursor = mock_cursor
|
||||
|
||||
# Call get_partitions to trigger SQL generation
|
||||
result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")
|
||||
|
||||
# Get the generated SQL query
|
||||
assert mock_cursor.execute.called
|
||||
generated_query = mock_cursor.execute.call_args[0][0]
|
||||
|
||||
# Verify the query structure for multiple partition keys
|
||||
assert "CONCAT(" in generated_query # Multiple keys should use CONCAT
|
||||
assert "CAST(year as VARCHAR)" in generated_query
|
||||
assert "CAST(month as VARCHAR)" in generated_query
|
||||
assert "CAST(day as VARCHAR)" in generated_query
|
||||
assert "CAST('-' AS VARCHAR)" in generated_query # Separator should be cast
|
||||
assert '"test_schema"."test_table$partitions"' in generated_query
|
||||
|
||||
# Validate that SQLGlot can parse the generated query using Athena dialect
|
||||
try:
|
||||
parsed = sqlglot.parse_one(generated_query, dialect=Athena)
|
||||
assert parsed is not None
|
||||
assert isinstance(parsed, sqlglot.expressions.Select)
|
||||
print(f"✅ Multiple partition SQL parsed successfully: {generated_query}")
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"SQLGlot failed to parse multiple partition query: {e}\nQuery: {generated_query}"
|
||||
)
|
||||
|
||||
assert result == ["year", "month", "day"]
|
||||
|
||||
|
||||
def test_partition_profiling_sql_generation_complex_schema_table_names():
|
||||
"""Test that partition profiling handles complex schema/table names correctly and generates valid SQL."""
|
||||
|
||||
config = AthenaConfig.parse_obj(
|
||||
{
|
||||
"aws_region": "us-west-1",
|
||||
"query_result_location": "s3://query-result-location/",
|
||||
"work_group": "test-workgroup",
|
||||
"extract_partitions": True,
|
||||
"profiling": {"enabled": True, "partition_profiling_enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
ctx = PipelineContext(run_id="test")
|
||||
source = AthenaSource(config=config, ctx=ctx)
|
||||
|
||||
# Mock cursor and metadata
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_metadata = mock.MagicMock()
|
||||
|
||||
mock_partition_key = mock.MagicMock()
|
||||
mock_partition_key.name = "event_date"
|
||||
mock_metadata.partition_keys = [mock_partition_key]
|
||||
mock_cursor.get_table_metadata.return_value = mock_metadata
|
||||
|
||||
# Mock successful partition query execution
|
||||
mock_result = mock.MagicMock()
|
||||
mock_result.description = [["event_date"]]
|
||||
mock_result.__iter__ = lambda x: iter([["2023-12-25"]])
|
||||
mock_cursor.execute.return_value = mock_result
|
||||
|
||||
source.cursor = mock_cursor
|
||||
|
||||
# Test with complex schema and table names
|
||||
schema = "ad_cdp_audience" # From the user's error
|
||||
table = "system_import_label"
|
||||
|
||||
result = source.get_partitions(mock.MagicMock(), schema, table)
|
||||
|
||||
# Get the generated SQL query
|
||||
assert mock_cursor.execute.called
|
||||
generated_query = mock_cursor.execute.call_args[0][0]
|
||||
|
||||
# Verify proper quoting of schema and table names
|
||||
expected_table_ref = f'"{schema}"."{table}$partitions"'
|
||||
assert expected_table_ref in generated_query
|
||||
assert "CAST(event_date as VARCHAR)" in generated_query
|
||||
|
||||
# Validate that SQLGlot can parse the generated query using Athena dialect
|
||||
try:
|
||||
parsed = sqlglot.parse_one(generated_query, dialect=Athena)
|
||||
assert parsed is not None
|
||||
assert isinstance(parsed, sqlglot.expressions.Select)
|
||||
print(f"✅ Complex schema/table SQL parsed successfully: {generated_query}")
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"SQLGlot failed to parse complex schema/table query: {e}\nQuery: {generated_query}"
|
||||
)
|
||||
|
||||
assert result == ["event_date"]
|
||||
|
||||
|
||||
def test_casted_partition_key_method():
|
||||
"""Test the _casted_partition_key helper method generates valid SQL fragments."""
|
||||
|
||||
# Test the static method directly
|
||||
casted_key = AthenaSource._casted_partition_key("test_column")
|
||||
assert casted_key == "CAST(test_column as VARCHAR)"
|
||||
|
||||
# Verify SQLGlot can parse the CAST expression
|
||||
try:
|
||||
parsed = sqlglot.parse_one(casted_key, dialect=Athena)
|
||||
assert parsed is not None
|
||||
assert isinstance(parsed, sqlglot.expressions.Cast)
|
||||
assert parsed.this.name == "test_column"
|
||||
assert "VARCHAR" in str(parsed.to.this) # More flexible assertion
|
||||
print(f"✅ CAST expression parsed successfully: {casted_key}")
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"SQLGlot failed to parse CAST expression: {e}\nExpression: {casted_key}"
|
||||
)
|
||||
|
||||
# Test with various column name formats
|
||||
test_cases = ["year", "month", "event_date", "created_at", "partition_key_123"]
|
||||
|
||||
for column_name in test_cases:
|
||||
casted = AthenaSource._casted_partition_key(column_name)
|
||||
try:
|
||||
parsed = sqlglot.parse_one(casted, dialect=Athena)
|
||||
assert parsed is not None
|
||||
assert isinstance(parsed, sqlglot.expressions.Cast)
|
||||
except Exception as e:
|
||||
pytest.fail(f"SQLGlot failed to parse CAST for column '{column_name}': {e}")
|
||||
|
||||
|
||||
def test_concat_function_generation_validates_with_sqlglot():
|
||||
"""Test that our CONCAT function generation produces valid Athena SQL according to SQLGlot."""
|
||||
|
||||
# Test the CONCAT function generation logic directly
|
||||
partition_keys = ["year", "month", "day"]
|
||||
casted_keys = [AthenaSource._casted_partition_key(key) for key in partition_keys]
|
||||
|
||||
# Replicate the CONCAT generation logic from the source
|
||||
concat_args = []
|
||||
for i, key in enumerate(casted_keys):
|
||||
concat_args.append(key)
|
||||
if i < len(casted_keys) - 1: # Add separator except for last element
|
||||
concat_args.append("CAST('-' AS VARCHAR)")
|
||||
|
||||
concat_expr = f"CONCAT({', '.join(concat_args)})"
|
||||
|
||||
# Verify the generated CONCAT expression
|
||||
expected = "CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR))"
|
||||
assert concat_expr == expected
|
||||
|
||||
# Validate with SQLGlot
|
||||
try:
|
||||
parsed = sqlglot.parse_one(concat_expr, dialect=Athena)
|
||||
assert parsed is not None
|
||||
assert isinstance(parsed, sqlglot.expressions.Concat)
|
||||
# Verify all arguments are properly parsed
|
||||
assert len(parsed.expressions) == 5 # 3 partition keys + 2 separators
|
||||
print(f"✅ CONCAT expression parsed successfully: {concat_expr}")
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"SQLGlot failed to parse CONCAT expression: {e}\nExpression: {concat_expr}"
|
||||
)
|
||||
|
||||
|
||||
def test_build_max_partition_query():
|
||||
"""Test _build_max_partition_query method directly without mocking."""
|
||||
|
||||
# Test single partition key
|
||||
query_single = AthenaSource._build_max_partition_query(
|
||||
"test_schema", "test_table", ["year"]
|
||||
)
|
||||
expected_single = 'select year from "test_schema"."test_table$partitions" where CAST(year as VARCHAR) = (select max(CAST(year as VARCHAR)) from "test_schema"."test_table$partitions")'
|
||||
assert query_single == expected_single
|
||||
|
||||
# Test multiple partition keys
|
||||
query_multiple = AthenaSource._build_max_partition_query(
|
||||
"test_schema", "test_table", ["year", "month", "day"]
|
||||
)
|
||||
expected_multiple = "select year,month,day from \"test_schema\".\"test_table$partitions\" where CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR)) = (select max(CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR))) from \"test_schema\".\"test_table$partitions\")"
|
||||
assert query_multiple == expected_multiple
|
||||
|
||||
# Validate with SQLGlot that generated queries are valid SQL
|
||||
try:
|
||||
parsed_single = sqlglot.parse_one(query_single, dialect=Athena)
|
||||
assert parsed_single is not None
|
||||
assert isinstance(parsed_single, sqlglot.expressions.Select)
|
||||
|
||||
parsed_multiple = sqlglot.parse_one(query_multiple, dialect=Athena)
|
||||
assert parsed_multiple is not None
|
||||
assert isinstance(parsed_multiple, sqlglot.expressions.Select)
|
||||
|
||||
print("✅ Both queries parsed successfully by SQLGlot")
|
||||
except Exception as e:
|
||||
pytest.fail(f"SQLGlot failed to parse generated query: {e}")
|
||||
|
||||
|
||||
def test_partition_profiling_disabled_no_sql_generation():
|
||||
"""Test that when partition profiling is disabled, no complex SQL is generated."""
|
||||
config = AthenaConfig.parse_obj(
|
||||
{
|
||||
"aws_region": "us-west-1",
|
||||
"query_result_location": "s3://query-result-location/",
|
||||
"work_group": "test-workgroup",
|
||||
"extract_partitions": True,
|
||||
"profiling": {"enabled": False, "partition_profiling_enabled": False},
|
||||
}
|
||||
)
|
||||
|
||||
ctx = PipelineContext(run_id="test")
|
||||
source = AthenaSource(config=config, ctx=ctx)
|
||||
|
||||
# Mock cursor and metadata
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_metadata = mock.MagicMock()
|
||||
mock_partition_key = mock.MagicMock()
|
||||
mock_partition_key.name = "year"
|
||||
mock_metadata.partition_keys = [mock_partition_key]
|
||||
mock_cursor.get_table_metadata.return_value = mock_metadata
|
||||
|
||||
source.cursor = mock_cursor
|
||||
|
||||
# Call get_partitions - should not generate complex profiling SQL
|
||||
result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")
|
||||
|
||||
# Should only call get_table_metadata, not execute complex partition queries
|
||||
mock_cursor.get_table_metadata.assert_called_once()
|
||||
# The execute method should not be called for profiling queries when profiling is disabled
|
||||
assert not mock_cursor.execute.called
|
||||
|
||||
assert result == ["year"]
|
||||
|
||||
|
||||
def test_sanitize_identifier_valid_names():
|
||||
"""Test _sanitize_identifier method with valid Athena identifiers."""
|
||||
# Valid simple identifiers
|
||||
valid_identifiers = [
|
||||
"table_name",
|
||||
"schema123",
|
||||
"column_1",
|
||||
"MyTable", # Should be allowed (Athena converts to lowercase)
|
||||
"DATABASE",
|
||||
"a", # Single character
|
||||
"table123name",
|
||||
"data_warehouse_table",
|
||||
]
|
||||
|
||||
for identifier in valid_identifiers:
|
||||
result = AthenaSource._sanitize_identifier(identifier)
|
||||
assert result == identifier, f"Expected {identifier} to be valid"
|
||||
|
||||
|
||||
def test_sanitize_identifier_valid_complex_types():
|
||||
"""Test _sanitize_identifier method with valid complex type identifiers."""
|
||||
# Valid complex type identifiers (with periods)
|
||||
valid_complex_identifiers = [
|
||||
"struct.field",
|
||||
"data.subfield",
|
||||
"nested.struct.field",
|
||||
"array.element",
|
||||
"map.key",
|
||||
"record.attribute.subfield",
|
||||
]
|
||||
|
||||
for identifier in valid_complex_identifiers:
|
||||
result = AthenaSource._sanitize_identifier(identifier)
|
||||
assert result == identifier, (
|
||||
f"Expected {identifier} to be valid for complex types"
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_identifier_invalid_characters():
|
||||
"""Test _sanitize_identifier method rejects invalid characters."""
|
||||
# Invalid identifiers that should be rejected
|
||||
invalid_identifiers = [
|
||||
"table-name", # hyphen not allowed in Athena
|
||||
"table name", # spaces not allowed
|
||||
"table@domain", # @ not allowed
|
||||
"table#hash", # # not allowed
|
||||
"table$var", # $ not allowed (except in system tables)
|
||||
"table%percent", # % not allowed
|
||||
"table&and", # & not allowed
|
||||
"table*star", # * not allowed
|
||||
"table(paren", # ( not allowed
|
||||
"table)paren", # ) not allowed
|
||||
"table+plus", # + not allowed
|
||||
"table=equal", # = not allowed
|
||||
"table[bracket", # [ not allowed
|
||||
"table]bracket", # ] not allowed
|
||||
"table{brace", # { not allowed
|
||||
"table}brace", # } not allowed
|
||||
"table|pipe", # | not allowed
|
||||
"table\\backslash", # \ not allowed
|
||||
"table:colon", # : not allowed
|
||||
"table;semicolon", # ; not allowed
|
||||
"table<less", # < not allowed
|
||||
"table>greater", # > not allowed
|
||||
"table?question", # ? not allowed
|
||||
"table/slash", # / not allowed
|
||||
"table,comma", # , not allowed
|
||||
]
|
||||
|
||||
for identifier in invalid_identifiers:
|
||||
with pytest.raises(ValueError, match="contains unsafe characters"):
|
||||
AthenaSource._sanitize_identifier(identifier)
|
||||
|
||||
|
||||
def test_sanitize_identifier_sql_injection_attempts():
|
||||
"""Test _sanitize_identifier method blocks SQL injection attempts."""
|
||||
# SQL injection attempts that should be blocked
|
||||
sql_injection_attempts = [
|
||||
"table'; DROP TABLE users; --",
|
||||
'table" OR 1=1 --',
|
||||
"table; DELETE FROM data;",
|
||||
"'; UNION SELECT * FROM passwords --",
|
||||
"table') OR ('1'='1",
|
||||
"table\"; INSERT INTO logs VALUES ('hack'); --",
|
||||
"table' AND 1=0 UNION SELECT * FROM admin --",
|
||||
"table/*comment*/",
|
||||
"table--comment",
|
||||
"table' OR 'a'='a",
|
||||
"1'; DROP DATABASE test; --",
|
||||
"table'; EXEC xp_cmdshell('dir'); --",
|
||||
]
|
||||
|
||||
for injection_attempt in sql_injection_attempts:
|
||||
with pytest.raises(ValueError, match="contains unsafe characters"):
|
||||
AthenaSource._sanitize_identifier(injection_attempt)
|
||||
|
||||
|
||||
def test_sanitize_identifier_quote_injection_attempts():
|
||||
"""Test _sanitize_identifier method blocks quote-based injection attempts."""
|
||||
# Quote injection attempts
|
||||
quote_injection_attempts = [
|
||||
'table"',
|
||||
"table'",
|
||||
'table""',
|
||||
"table''",
|
||||
'table"extra',
|
||||
"table'extra",
|
||||
"`table`",
|
||||
'"table"',
|
||||
"'table'",
|
||||
]
|
||||
|
||||
for injection_attempt in quote_injection_attempts:
|
||||
with pytest.raises(ValueError, match="contains unsafe characters"):
|
||||
AthenaSource._sanitize_identifier(injection_attempt)
|
||||
|
||||
|
||||
def test_sanitize_identifier_empty_and_edge_cases():
|
||||
"""Test _sanitize_identifier method with empty and edge case inputs."""
|
||||
# Empty identifier
|
||||
with pytest.raises(ValueError, match="Identifier cannot be empty"):
|
||||
AthenaSource._sanitize_identifier("")
|
||||
|
||||
# None input (should also raise ValueError from empty check)
|
||||
with pytest.raises(ValueError, match="Identifier cannot be empty"):
|
||||
AthenaSource._sanitize_identifier(None) # type: ignore[arg-type]
|
||||
|
||||
# Very long valid identifier
|
||||
long_identifier = "a" * 200 # Still within Athena's 255 byte limit
|
||||
result = AthenaSource._sanitize_identifier(long_identifier)
|
||||
assert result == long_identifier
|
||||
|
||||
|
||||
def test_sanitize_identifier_integration_with_build_max_partition_query():
    """Test that _sanitize_identifier works correctly within _build_max_partition_query."""
    # Safe identifiers should all survive into the generated SQL.
    query = AthenaSource._build_max_partition_query(
        "valid_schema", "valid_table", ["valid_partition"]
    )
    for identifier in ("valid_schema", "valid_table", "valid_partition"):
        assert identifier in query

    # An unsafe schema, table, or partition name must each be rejected
    # before any query text is assembled.
    bad_arguments = [
        ("schema'; DROP TABLE users; --", "table", ["partition"]),
        ("schema", "table-with-hyphen", ["partition"]),
        ("schema", "table", ["partition; DELETE FROM data;"]),
    ]
    for schema, table, partitions in bad_arguments:
        with pytest.raises(ValueError, match="contains unsafe characters"):
            AthenaSource._build_max_partition_query(schema, table, partitions)
def test_sanitize_identifier_error_handling_in_get_partitions():
    """Test that ValueError from _sanitize_identifier is handled gracefully in get_partitions method."""
    config = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": True,
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
        }
    )

    source = AthenaSource(config=config, ctx=PipelineContext(run_id="test"))

    # Fake Athena cursor whose table metadata exposes a single
    # partition key named "year".
    partition_key = mock.MagicMock()
    partition_key.name = "year"
    table_metadata = mock.MagicMock()
    table_metadata.partition_keys = [partition_key]
    cursor = mock.MagicMock()
    cursor.get_table_metadata.return_value = table_metadata
    source.cursor = cursor

    # A malicious schema name should be handled gracefully by report_exc:
    # the sanitization failure occurs in the profiling section, so the
    # partition list from metadata is still returned.
    partitions = source.get_partitions(
        mock.MagicMock(), "schema'; DROP TABLE users; --", "valid_table"
    )

    assert partitions == ["year"]

    # Metadata lookup must still have happened despite the bad schema name.
    cursor.get_table_metadata.assert_called_once()
def test_sanitize_identifier_error_handling_in_generate_partition_profiler_query(
    caplog,
):
    """Test that ValueError from _sanitize_identifier is handled gracefully in generate_partition_profiler_query."""
    import logging

    config = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
        }
    )

    source = AthenaSource(config=config, ctx=PipelineContext(run_id="test"))

    # Seed the partition cache with a max-partition key that will fail
    # sanitization when the profiler query is built.
    source.table_partition_cache["valid_schema"] = {
        "valid_table": Partitionitem(
            partitions=["year"],
            max_partition={"malicious'; DROP TABLE users; --": "2023"},
        )
    }

    with caplog.at_level(logging.WARNING):
        # The sanitization failure must not crash ingestion; the method is
        # expected to fall back to (None, None) instead.
        outcome = source.generate_partition_profiler_query(
            "valid_schema", "valid_table", None
        )

    assert outcome == (None, None), "Should return None, None when sanitization fails"

    # Exactly one WARNING record should explain why profiling was skipped.
    assert len(caplog.records) == 1
    warning = caplog.records[0]
    assert warning.levelname == "WARNING"
    assert (
        "Failed to generate partition profiler query for valid_schema.valid_table due to unsafe identifiers"
        in warning.message
    )
    assert "contains unsafe characters" in warning.message
    assert "Partition profiling disabled for this table" in warning.message
Loading…
x
Reference in New Issue
Block a user