fix(ingest/athena): Fix Athena partition extraction and CONCAT function type issues (#14712)

This commit is contained in:
Tamas Nemeth 2025-09-10 12:33:54 +02:00 committed by GitHub
parent c7ad3f45ea
commit a82d4e0647
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 746 additions and 32 deletions

View File

@ -73,6 +73,11 @@ except ImportError:
logger = logging.getLogger(__name__)
# Precompiled regex for SQL identifier validation
# Athena identifiers can only contain lowercase letters, numbers, underscore, and period (for complex types)
# Note: Athena automatically converts uppercase to lowercase, but we're being strict for security
_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9_.]+$")
assert STRUCT, "required type modules are not available"
register_custom_type(STRUCT, RecordTypeClass)
register_custom_type(MapType, MapTypeClass)
@ -510,20 +515,76 @@ class AthenaSource(SQLAlchemySource):
return [schema for schema in schemas if schema == athena_config.database]
return schemas
@classmethod
def _sanitize_identifier(cls, identifier: str) -> str:
"""Sanitize SQL identifiers to prevent injection attacks.
Args:
identifier: The SQL identifier to sanitize
Returns:
Sanitized identifier safe for SQL queries
Raises:
ValueError: If identifier contains unsafe characters
"""
if not identifier:
raise ValueError("Identifier cannot be empty")
# Allow only alphanumeric characters, underscores, and periods for identifiers
# This matches Athena's identifier naming rules
if not _IDENTIFIER_PATTERN.match(identifier):
raise ValueError(
f"Identifier '{identifier}' contains unsafe characters. Only alphanumeric characters, underscores, and periods are allowed."
)
return identifier
@classmethod
def _casted_partition_key(cls, key: str) -> str:
    """Return *key* wrapped in ``CAST(... as VARCHAR)``.

    We need to cast the partition keys to a VARCHAR, since otherwise
    Athena may throw an error during concatenation / comparison.

    Args:
        key: Partition column name.

    Returns:
        SQL fragment casting the sanitized key to VARCHAR.

    Raises:
        ValueError: If *key* contains characters unsafe for a SQL identifier.
    """
    # Sanitize BEFORE interpolating into SQL; returning the raw key here
    # would bypass the injection guard entirely.
    sanitized_key = cls._sanitize_identifier(key)
    return f"CAST({sanitized_key} as VARCHAR)"
@classmethod
def _build_max_partition_query(
    cls, schema: str, table: str, partitions: List[str]
) -> str:
    """Build the SQL that selects the row holding the maximum partition.

    Args:
        schema: Database schema name.
        table: Table name.
        partitions: Partition column names.

    Returns:
        A SELECT statement over ``"<schema>"."<table>$partitions"`` that
        returns the partition columns of the max partition.

    Raises:
        ValueError: If any identifier contains unsafe characters.
    """
    # Validate every identifier up front so nothing unsafe reaches the SQL.
    safe_schema = cls._sanitize_identifier(schema)
    safe_table = cls._sanitize_identifier(table)
    safe_columns = [cls._sanitize_identifier(col) for col in partitions]

    cast_exprs = [cls._casted_partition_key(col) for col in partitions]
    if len(cast_exprs) > 1:
        # Interleave the casted keys with a '-' separator inside CONCAT;
        # the separator is cast too so every CONCAT argument is a VARCHAR.
        joined = ", CAST('-' AS VARCHAR), ".join(cast_exprs)
        combined = f"CONCAT({joined})"
    else:
        combined = cast_exprs[0]

    partitions_table = f'"{safe_schema}"."{safe_table}$partitions"'
    return (
        f"select {','.join(safe_columns)} from {partitions_table} "
        f"where {combined} = (select max({combined}) from {partitions_table})"
    )
@override
def get_partitions(
self, inspector: Inspector, schema: str, table: str
) -> Optional[List[str]]:
if (
not self.config.extract_partitions
and not self.config.extract_partitions_using_create_statements
):
if not self.config.extract_partitions:
return None
if not self.cursor:
@ -557,11 +618,9 @@ class AthenaSource(SQLAlchemySource):
context=f"{schema}.{table}",
level=StructuredLogLevel.WARN,
):
# We create an artificial concatenated partition key to be able to query the max partition more easily
part_concat = " || '-' || ".join(
self._casted_partition_key(key) for key in partitions
max_partition_query = self._build_max_partition_query(
schema, table, partitions
)
max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
ret = self.cursor.execute(max_partition_query)
max_partition: Dict[str, str] = {}
if ret:
@ -678,16 +737,34 @@ class AthenaSource(SQLAlchemySource):
).get(table, None)
if partition and partition.max_partition:
max_partition_filters = []
for key, value in partition.max_partition.items():
max_partition_filters.append(
f"{self._casted_partition_key(key)} = '{value}'"
try:
# Sanitize identifiers to prevent SQL injection
sanitized_schema = self._sanitize_identifier(schema)
sanitized_table = self._sanitize_identifier(table)
max_partition_filters = []
for key, value in partition.max_partition.items():
# Sanitize partition key and properly escape the value
sanitized_key = self._sanitize_identifier(key)
# Escape single quotes in the value to prevent injection
escaped_value = value.replace("'", "''") if value else ""
max_partition_filters.append(
f"{self._casted_partition_key(sanitized_key)} = '{escaped_value}'"
)
max_partition = str(partition.max_partition)
return (
max_partition,
f'SELECT * FROM "{sanitized_schema}"."{sanitized_table}" WHERE {" AND ".join(max_partition_filters)}',
)
max_partition = str(partition.max_partition)
return (
max_partition,
f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(max_partition_filters)}',
)
except ValueError as e:
# If sanitization fails due to malicious identifiers,
# return None to disable partition profiling for this table
# rather than crashing the entire ingestion
logger.warning(
f"Failed to generate partition profiler query for {schema}.{table} due to unsafe identifiers: {e}. "
f"Partition profiling disabled for this table."
)
return None, None
return None, None
def close(self):

View File

@ -174,20 +174,16 @@ class AthenaPropertiesExtractor:
def format_column_definition(line):
# Use regex to parse the line more accurately
# Pattern: column_name data_type [COMMENT comment_text] [,]
# Use greedy match for comment to capture everything until trailing comma
pattern = r"^\s*(.+?)\s+([\s,\w<>\[\]]+)((\s+COMMENT\s+(.+?)(,?))|(,?)\s*)?$"
match = re.match(pattern, line, re.IGNORECASE)
# Improved pattern to better separate column name, data type, and comment
pattern = r"^\s*([`\w']+)\s+([\w<>\[\](),\s]+?)(\s+COMMENT\s+(.+?))?(,?)\s*$"
match = re.match(pattern, line.strip(), re.IGNORECASE)
if not match:
return line
column_name = match.group(1)
data_type = match.group(2)
comment_part = match.group(5) # COMMENT part
# there are different number of match groups depending on whether comment exists
if comment_part:
trailing_comma = match.group(6) if match.group(6) else ""
else:
trailing_comma = match.group(7) if match.group(7) else ""
column_name = match.group(1).strip()
data_type = match.group(2).strip()
comment_part = match.group(4) # COMMENT part
trailing_comma = match.group(5) if match.group(5) else ""
# Add backticks to column name if not already present
if not (column_name.startswith("`") and column_name.endswith("`")):
@ -201,7 +197,7 @@ class AthenaPropertiesExtractor:
# Handle comment quoting and escaping
if comment_part.startswith("'") and comment_part.endswith("'"):
# Already properly single quoted - but check for proper escaping
# Already single quoted - but check for proper escaping
inner_content = comment_part[1:-1]
# Re-escape any single quotes that aren't properly escaped
escaped_content = inner_content.replace("'", "''")

View File

@ -3,10 +3,12 @@ from typing import List
from unittest import mock
import pytest
import sqlglot
from freezegun import freeze_time
from pyathena import OperationalError
from sqlalchemy import types
from sqlalchemy_bigquery import STRUCT
from sqlglot.dialects import Athena
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.aws.s3_util import make_s3_urn
@ -14,6 +16,7 @@ from datahub.ingestion.source.sql.athena import (
AthenaConfig,
AthenaSource,
CustomAthenaRestDialect,
Partitionitem,
)
from datahub.metadata.schema_classes import (
ArrayTypeClass,
@ -182,8 +185,8 @@ def test_athena_get_table_properties():
expected_query = """\
select year,month from "test_schema"."test_table$partitions" \
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \
(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \
where CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR)) = \
(select max(CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR))) \
from "test_schema"."test_table$partitions")"""
assert mock_cursor.execute.call_count == 2
assert expected_create_table_query == mock_cursor.execute.call_args_list[0][0][0]
@ -503,3 +506,641 @@ def test_convert_simple_field_paths_to_v1_default_behavior():
)
assert config.emit_schema_fieldpaths_as_v1 is False # Should default to False
def test_get_partitions_returns_none_when_extract_partitions_disabled():
    """get_partitions short-circuits to None when extract_partitions is False."""
    cfg = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": False,
        }
    )
    source = AthenaSource(config=cfg, ctx=PipelineContext(run_id="test"))

    # The inspector must never be touched when partition extraction is off.
    inspector_stub = mock.MagicMock()

    assert source.get_partitions(inspector_stub, "test_schema", "test_table") is None
    assert not inspector_stub.called, (
        "Inspector should not be called when extract_partitions=False"
    )
def test_get_partitions_attempts_extraction_when_extract_partitions_enabled():
    """get_partitions queries table metadata when extract_partitions is True."""
    cfg = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": True,
        }
    )
    source = AthenaSource(config=cfg, ctx=PipelineContext(run_id="test"))

    # Wire up a cursor whose metadata reports a single "year" partition key.
    partition_key = mock.MagicMock()
    partition_key.name = "year"
    metadata = mock.MagicMock()
    metadata.partition_keys = [partition_key]
    cursor = mock.MagicMock()
    cursor.get_table_metadata.return_value = metadata
    source.cursor = cursor

    result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")

    # Extraction was attempted against the cursor ...
    cursor.get_table_metadata.assert_called_once_with(
        table_name="test_table", schema_name="test_schema"
    )
    # ... and the discovered partition key names are returned as a list.
    assert isinstance(result, list)
    assert result == ["year"]
def test_partition_profiling_sql_generation_single_key():
    """Test that partition profiling generates valid SQL for single partition key and can be parsed by SQLGlot."""
    config = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": True,
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
        }
    )
    ctx = PipelineContext(run_id="test")
    source = AthenaSource(config=config, ctx=ctx)

    # Mock cursor and metadata for single partition key.
    # NOTE: `name` must be assigned after construction — passing name= to
    # MagicMock() would configure the mock's own repr name instead.
    mock_cursor = mock.MagicMock()
    mock_metadata = mock.MagicMock()
    mock_partition_key = mock.MagicMock()
    mock_partition_key.name = "year"
    mock_metadata.partition_keys = [mock_partition_key]
    mock_cursor.get_table_metadata.return_value = mock_metadata

    # Mock successful partition query execution.
    # __iter__ is assigned as a function taking the instance so iterating the
    # mock result yields one row: ["2023"].
    mock_result = mock.MagicMock()
    mock_result.description = [["year"]]
    mock_result.__iter__ = lambda x: iter([["2023"]])
    mock_cursor.execute.return_value = mock_result

    source.cursor = mock_cursor

    # Call get_partitions to trigger SQL generation
    result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")

    # Get the generated SQL query (first positional arg of cursor.execute)
    assert mock_cursor.execute.called
    generated_query = mock_cursor.execute.call_args[0][0]

    # Verify the query structure for single partition key
    assert "CAST(year as VARCHAR)" in generated_query
    assert "CONCAT" not in generated_query  # Single key shouldn't use CONCAT
    assert '"test_schema"."test_table$partitions"' in generated_query

    # Validate that SQLGlot can parse the generated query using Athena dialect
    try:
        parsed = sqlglot.parse_one(generated_query, dialect=Athena)
        assert parsed is not None
        assert isinstance(parsed, sqlglot.expressions.Select)
        print(f"✅ Single partition SQL parsed successfully: {generated_query}")
    except Exception as e:
        pytest.fail(
            f"SQLGlot failed to parse single partition query: {e}\nQuery: {generated_query}"
        )

    assert result == ["year"]
def test_partition_profiling_sql_generation_multiple_keys():
    """Test that partition profiling generates valid SQL for multiple partition keys and can be parsed by SQLGlot."""
    config = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": True,
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
        }
    )
    ctx = PipelineContext(run_id="test")
    source = AthenaSource(config=config, ctx=ctx)

    # Mock cursor and metadata for multiple partition keys.
    # `name` is assigned post-construction (MagicMock(name=...) would set the
    # mock's repr name, not the attribute).
    mock_cursor = mock.MagicMock()
    mock_metadata = mock.MagicMock()
    mock_year_key = mock.MagicMock()
    mock_year_key.name = "year"
    mock_month_key = mock.MagicMock()
    mock_month_key.name = "month"
    mock_day_key = mock.MagicMock()
    mock_day_key.name = "day"
    mock_metadata.partition_keys = [mock_year_key, mock_month_key, mock_day_key]
    mock_cursor.get_table_metadata.return_value = mock_metadata

    # Mock successful partition query execution: one row with a value per key.
    mock_result = mock.MagicMock()
    mock_result.description = [["year"], ["month"], ["day"]]
    mock_result.__iter__ = lambda x: iter([["2023", "12", "25"]])
    mock_cursor.execute.return_value = mock_result

    source.cursor = mock_cursor

    # Call get_partitions to trigger SQL generation
    result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")

    # Get the generated SQL query
    assert mock_cursor.execute.called
    generated_query = mock_cursor.execute.call_args[0][0]

    # Verify the query structure for multiple partition keys
    assert "CONCAT(" in generated_query  # Multiple keys should use CONCAT
    assert "CAST(year as VARCHAR)" in generated_query
    assert "CAST(month as VARCHAR)" in generated_query
    assert "CAST(day as VARCHAR)" in generated_query
    assert "CAST('-' AS VARCHAR)" in generated_query  # Separator should be cast
    assert '"test_schema"."test_table$partitions"' in generated_query

    # Validate that SQLGlot can parse the generated query using Athena dialect
    try:
        parsed = sqlglot.parse_one(generated_query, dialect=Athena)
        assert parsed is not None
        assert isinstance(parsed, sqlglot.expressions.Select)
        print(f"✅ Multiple partition SQL parsed successfully: {generated_query}")
    except Exception as e:
        pytest.fail(
            f"SQLGlot failed to parse multiple partition query: {e}\nQuery: {generated_query}"
        )

    assert result == ["year", "month", "day"]
def test_partition_profiling_sql_generation_complex_schema_table_names():
    """Test that partition profiling handles complex schema/table names correctly and generates valid SQL."""
    config = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": True,
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
        }
    )
    ctx = PipelineContext(run_id="test")
    source = AthenaSource(config=config, ctx=ctx)

    # Mock cursor and metadata (single "event_date" partition key)
    mock_cursor = mock.MagicMock()
    mock_metadata = mock.MagicMock()
    mock_partition_key = mock.MagicMock()
    mock_partition_key.name = "event_date"
    mock_metadata.partition_keys = [mock_partition_key]
    mock_cursor.get_table_metadata.return_value = mock_metadata

    # Mock successful partition query execution
    mock_result = mock.MagicMock()
    mock_result.description = [["event_date"]]
    mock_result.__iter__ = lambda x: iter([["2023-12-25"]])
    mock_cursor.execute.return_value = mock_result

    source.cursor = mock_cursor

    # Test with complex schema and table names
    schema = "ad_cdp_audience"  # From the user's error
    table = "system_import_label"

    result = source.get_partitions(mock.MagicMock(), schema, table)

    # Get the generated SQL query
    assert mock_cursor.execute.called
    generated_query = mock_cursor.execute.call_args[0][0]

    # Verify proper quoting of schema and table names
    expected_table_ref = f'"{schema}"."{table}$partitions"'
    assert expected_table_ref in generated_query
    assert "CAST(event_date as VARCHAR)" in generated_query

    # Validate that SQLGlot can parse the generated query using Athena dialect
    try:
        parsed = sqlglot.parse_one(generated_query, dialect=Athena)
        assert parsed is not None
        assert isinstance(parsed, sqlglot.expressions.Select)
        print(f"✅ Complex schema/table SQL parsed successfully: {generated_query}")
    except Exception as e:
        pytest.fail(
            f"SQLGlot failed to parse complex schema/table query: {e}\nQuery: {generated_query}"
        )

    assert result == ["event_date"]
def test_casted_partition_key_method():
    """_casted_partition_key emits a CAST fragment that SQLGlot accepts."""
    cast_sql = AthenaSource._casted_partition_key("test_column")
    assert cast_sql == "CAST(test_column as VARCHAR)"

    # The fragment must parse as a CAST of the named column to VARCHAR.
    try:
        tree = sqlglot.parse_one(cast_sql, dialect=Athena)
        assert tree is not None
        assert isinstance(tree, sqlglot.expressions.Cast)
        assert tree.this.name == "test_column"
        assert "VARCHAR" in str(tree.to.this)  # More flexible assertion
        print(f"✅ CAST expression parsed successfully: {cast_sql}")
    except Exception as e:
        pytest.fail(
            f"SQLGlot failed to parse CAST expression: {e}\nExpression: {cast_sql}"
        )

    # A spread of realistic column names must all yield parseable CASTs.
    for column_name in ("year", "month", "event_date", "created_at", "partition_key_123"):
        fragment = AthenaSource._casted_partition_key(column_name)
        try:
            tree = sqlglot.parse_one(fragment, dialect=Athena)
            assert tree is not None
            assert isinstance(tree, sqlglot.expressions.Cast)
        except Exception as e:
            pytest.fail(f"SQLGlot failed to parse CAST for column '{column_name}': {e}")
def test_concat_function_generation_validates_with_sqlglot():
    """The CONCAT assembly used for multi-key partitions is valid Athena SQL."""
    keys = ["year", "month", "day"]
    cast_fragments = [AthenaSource._casted_partition_key(k) for k in keys]

    # Rebuild the CONCAT exactly as the source does: casted keys interleaved
    # with a casted '-' separator between consecutive keys.
    separator = "CAST('-' AS VARCHAR)"
    pieces = []
    for fragment in cast_fragments[:-1]:
        pieces.extend([fragment, separator])
    pieces.append(cast_fragments[-1])
    concat_expr = f"CONCAT({', '.join(pieces)})"

    expected = "CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR))"
    assert concat_expr == expected

    # SQLGlot must accept the expression as a CONCAT with all five arguments.
    try:
        tree = sqlglot.parse_one(concat_expr, dialect=Athena)
        assert tree is not None
        assert isinstance(tree, sqlglot.expressions.Concat)
        # 3 partition keys + 2 separators
        assert len(tree.expressions) == 5
        print(f"✅ CONCAT expression parsed successfully: {concat_expr}")
    except Exception as e:
        pytest.fail(
            f"SQLGlot failed to parse CONCAT expression: {e}\nExpression: {concat_expr}"
        )
def test_build_max_partition_query():
    """_build_max_partition_query produces the expected SQL for one and many keys."""
    cases = [
        (
            ["year"],
            'select year from "test_schema"."test_table$partitions" where CAST(year as VARCHAR) = (select max(CAST(year as VARCHAR)) from "test_schema"."test_table$partitions")',
        ),
        (
            ["year", "month", "day"],
            "select year,month,day from \"test_schema\".\"test_table$partitions\" where CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR)) = (select max(CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR))) from \"test_schema\".\"test_table$partitions\")",
        ),
    ]

    generated = []
    for partition_cols, expected_sql in cases:
        query = AthenaSource._build_max_partition_query(
            "test_schema", "test_table", partition_cols
        )
        assert query == expected_sql
        generated.append(query)

    # Both queries must also be parseable Athena SQL according to SQLGlot.
    try:
        for query in generated:
            tree = sqlglot.parse_one(query, dialect=Athena)
            assert tree is not None
            assert isinstance(tree, sqlglot.expressions.Select)
        print("✅ Both queries parsed successfully by SQLGlot")
    except Exception as e:
        pytest.fail(f"SQLGlot failed to parse generated query: {e}")
def test_partition_profiling_disabled_no_sql_generation():
    """With profiling off, get_partitions reads metadata but runs no queries."""
    cfg = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": True,
            "profiling": {"enabled": False, "partition_profiling_enabled": False},
        }
    )
    source = AthenaSource(config=cfg, ctx=PipelineContext(run_id="test"))

    # Cursor metadata advertises one "year" partition key.
    partition_key = mock.MagicMock()
    partition_key.name = "year"
    metadata = mock.MagicMock()
    metadata.partition_keys = [partition_key]
    cursor = mock.MagicMock()
    cursor.get_table_metadata.return_value = metadata
    source.cursor = cursor

    result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")

    # Metadata lookup happens, but no max-partition profiling query is issued.
    cursor.get_table_metadata.assert_called_once()
    assert not cursor.execute.called
    assert result == ["year"]
def test_sanitize_identifier_valid_names():
    """_sanitize_identifier returns valid simple identifiers unchanged."""
    for name in (
        "table_name",
        "schema123",
        "column_1",
        "MyTable",  # uppercase allowed; Athena lowercases identifiers itself
        "DATABASE",
        "a",  # single character
        "table123name",
        "data_warehouse_table",
    ):
        assert AthenaSource._sanitize_identifier(name) == name, (
            f"Expected {name} to be valid"
        )
def test_sanitize_identifier_valid_complex_types():
    """Dotted paths used for complex (struct/map/array) fields are accepted."""
    for dotted in (
        "struct.field",
        "data.subfield",
        "nested.struct.field",
        "array.element",
        "map.key",
        "record.attribute.subfield",
    ):
        assert AthenaSource._sanitize_identifier(dotted) == dotted, (
            f"Expected {dotted} to be valid for complex types"
        )
def test_sanitize_identifier_invalid_characters():
    """Identifiers containing punctuation outside [A-Za-z0-9_.] are rejected."""
    # One sample identifier per forbidden character.
    bad_identifiers = (
        "table-name",
        "table name",
        "table@domain",
        "table#hash",
        "table$var",
        "table%percent",
        "table&and",
        "table*star",
        "table(paren",
        "table)paren",
        "table+plus",
        "table=equal",
        "table[bracket",
        "table]bracket",
        "table{brace",
        "table}brace",
        "table|pipe",
        "table\\backslash",
        "table:colon",
        "table;semicolon",
        "table<less",
        "table>greater",
        "table?question",
        "table/slash",
        "table,comma",
    )
    for bad in bad_identifiers:
        with pytest.raises(ValueError, match="contains unsafe characters"):
            AthenaSource._sanitize_identifier(bad)
def test_sanitize_identifier_sql_injection_attempts():
    """Classic SQL-injection payloads all fail identifier sanitization."""
    payloads = (
        "table'; DROP TABLE users; --",
        'table" OR 1=1 --',
        "table; DELETE FROM data;",
        "'; UNION SELECT * FROM passwords --",
        "table') OR ('1'='1",
        "table\"; INSERT INTO logs VALUES ('hack'); --",
        "table' AND 1=0 UNION SELECT * FROM admin --",
        "table/*comment*/",
        "table--comment",
        "table' OR 'a'='a",
        "1'; DROP DATABASE test; --",
        "table'; EXEC xp_cmdshell('dir'); --",
    )
    for payload in payloads:
        with pytest.raises(ValueError, match="contains unsafe characters"):
            AthenaSource._sanitize_identifier(payload)
def test_sanitize_identifier_quote_injection_attempts():
    """Identifiers carrying any quoting characters are rejected."""
    quoted_candidates = (
        'table"',
        "table'",
        'table""',
        "table''",
        'table"extra',
        "table'extra",
        "`table`",
        '"table"',
        "'table'",
    )
    for candidate in quoted_candidates:
        with pytest.raises(ValueError, match="contains unsafe characters"):
            AthenaSource._sanitize_identifier(candidate)
def test_sanitize_identifier_empty_and_edge_cases():
    """Empty/None inputs raise; very long valid names pass through."""
    # Both "" and None hit the same emptiness guard.
    for empty in ("", None):
        with pytest.raises(ValueError, match="Identifier cannot be empty"):
            AthenaSource._sanitize_identifier(empty)  # type: ignore[arg-type]

    # 200 chars is still within Athena's 255-byte identifier limit.
    long_name = "a" * 200
    assert AthenaSource._sanitize_identifier(long_name) == long_name
def test_sanitize_identifier_integration_with_build_max_partition_query():
    """_build_max_partition_query sanitizes all identifiers before SQL assembly."""
    # Clean identifiers flow through into the generated query.
    query = AthenaSource._build_max_partition_query(
        "valid_schema", "valid_table", ["valid_partition"]
    )
    for fragment in ("valid_schema", "valid_table", "valid_partition"):
        assert fragment in query

    # An unsafe schema, table, or partition identifier is each rejected
    # before any SQL is built.
    unsafe_calls = (
        ("schema'; DROP TABLE users; --", "table", ["partition"]),
        ("schema", "table-with-hyphen", ["partition"]),
        ("schema", "table", ["partition; DELETE FROM data;"]),
    )
    for schema, table, partitions in unsafe_calls:
        with pytest.raises(ValueError, match="contains unsafe characters"):
            AthenaSource._build_max_partition_query(schema, table, partitions)
def test_sanitize_identifier_error_handling_in_get_partitions():
    """Test that ValueError from _sanitize_identifier is handled gracefully in get_partitions method."""
    config = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "extract_partitions": True,
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
        }
    )
    ctx = PipelineContext(run_id="test")
    source = AthenaSource(config=config, ctx=ctx)

    # Mock cursor and metadata reporting a single "year" partition key.
    mock_cursor = mock.MagicMock()
    mock_metadata = mock.MagicMock()
    mock_partition_key = mock.MagicMock()
    mock_partition_key.name = "year"
    mock_metadata.partition_keys = [mock_partition_key]
    mock_cursor.get_table_metadata.return_value = mock_metadata
    source.cursor = mock_cursor

    # Test with malicious schema name - should be handled gracefully by report_exc
    result = source.get_partitions(
        mock.MagicMock(), "schema'; DROP TABLE users; --", "valid_table"
    )

    # Should still return the partition list from metadata since the error
    # occurs in the profiling section which is wrapped in report_exc
    assert result == ["year"]

    # Verify metadata was still called (extraction itself is unaffected)
    mock_cursor.get_table_metadata.assert_called_once()
def test_sanitize_identifier_error_handling_in_generate_partition_profiler_query(
    caplog,
):
    """Test that ValueError from _sanitize_identifier is handled gracefully in generate_partition_profiler_query."""
    import logging

    config = AthenaConfig.parse_obj(
        {
            "aws_region": "us-west-1",
            "query_result_location": "s3://query-result-location/",
            "work_group": "test-workgroup",
            "profiling": {"enabled": True, "partition_profiling_enabled": True},
        }
    )
    ctx = PipelineContext(run_id="test")
    source = AthenaSource(config=config, ctx=ctx)

    # Add a mock partition to the cache with malicious partition key
    source.table_partition_cache["valid_schema"] = {
        "valid_table": Partitionitem(
            partitions=["year"],
            max_partition={"malicious'; DROP TABLE users; --": "2023"},
        )
    }

    # Capture log messages at WARNING level
    with caplog.at_level(logging.WARNING):
        # This should handle the ValueError gracefully and return None, None
        # instead of crashing the entire ingestion process
        result = source.generate_partition_profiler_query(
            "valid_schema", "valid_table", None
        )

    # Verify the method returns None, None when sanitization fails
    assert result == (None, None), "Should return None, None when sanitization fails"

    # Verify that exactly one warning log message was generated, and that it
    # names the table and explains why profiling was disabled.
    assert len(caplog.records) == 1
    log_record = caplog.records[0]
    assert log_record.levelname == "WARNING"
    assert (
        "Failed to generate partition profiler query for valid_schema.valid_table due to unsafe identifiers"
        in log_record.message
    )
    assert "contains unsafe characters" in log_record.message
    assert "Partition profiling disabled for this table" in log_record.message