From a82d4e0647af9c3dbc28a4277bd6a43069585544 Mon Sep 17 00:00:00 2001
From: Tamas Nemeth
Date: Wed, 10 Sep 2025 12:33:54 +0200
Subject: [PATCH] fix(ingest/athena): Fix Athena partition extraction and CONCAT function type issues (#14712)

---
 .../datahub/ingestion/source/sql/athena.py    | 113 ++-
 .../source/sql/athena_properties_extractor.py |  20 +-
 .../tests/unit/test_athena_source.py          | 645 +++++++++++++++++-
 3 files changed, 746 insertions(+), 32 deletions(-)

diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
index 8707a04f1d..cd513ee488 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -73,6 +73,11 @@ except ImportError:
 
 logger = logging.getLogger(__name__)
 
+# Precompiled regex for SQL identifier validation. Athena identifiers may
+# contain letters, digits, underscores, and periods (periods occur in
+# complex-type field paths); Athena lowercases names itself, so both cases pass.
+_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9_.]+$")
+
 assert STRUCT, "required type modules are not available"
 register_custom_type(STRUCT, RecordTypeClass)
 register_custom_type(MapType, MapTypeClass)
@@ -510,20 +515,76 @@ class AthenaSource(SQLAlchemySource):
             return [schema for schema in schemas if schema == athena_config.database]
         return schemas
 
+    @classmethod
+    def _sanitize_identifier(cls, identifier: str) -> str:
+        """Sanitize SQL identifiers to prevent injection attacks.
+
+        Args:
+            identifier: The SQL identifier to sanitize
+
+        Returns:
+            Sanitized identifier safe for SQL queries
+
+        Raises:
+            ValueError: If identifier contains unsafe characters
+        """
+        if not identifier:
+            raise ValueError("Identifier cannot be empty")
+
+        # Allow only alphanumeric characters, underscores, and periods;
+        # this matches Athena's identifier naming rules.
+        if not _IDENTIFIER_PATTERN.match(identifier):
+            raise ValueError(
+                f"Identifier '{identifier}' contains unsafe characters. Only alphanumeric characters, underscores, and periods are allowed."
+            )
+
+        return identifier
+
     @classmethod
     def _casted_partition_key(cls, key: str) -> str:
         # We need to cast the partition keys to a VARCHAR, since otherwise
         # Athena may throw an error during concatenation / comparison.
-        return f"CAST({key} as VARCHAR)"
+        sanitized_key = cls._sanitize_identifier(key)
+        return f"CAST({sanitized_key} as VARCHAR)"
+
+    @classmethod
+    def _build_max_partition_query(
+        cls, schema: str, table: str, partitions: List[str]
+    ) -> str:
+        """Build SQL query to find the row with maximum partition values.
+ + Args: + schema: Database schema name + table: Table name + partitions: List of partition column names + + Returns: + SQL query string to find the maximum partition + + Raises: + ValueError: If any identifier contains unsafe characters + """ + # Sanitize all identifiers to prevent SQL injection + sanitized_schema = cls._sanitize_identifier(schema) + sanitized_table = cls._sanitize_identifier(table) + sanitized_partitions = [ + cls._sanitize_identifier(partition) for partition in partitions + ] + + casted_keys = [cls._casted_partition_key(key) for key in partitions] + if len(casted_keys) == 1: + part_concat = casted_keys[0] + else: + separator = "CAST('-' AS VARCHAR)" + part_concat = f"CONCAT({f', {separator}, '.join(casted_keys)})" + + return f'select {",".join(sanitized_partitions)} from "{sanitized_schema}"."{sanitized_table}$partitions" where {part_concat} = (select max({part_concat}) from "{sanitized_schema}"."{sanitized_table}$partitions")' @override def get_partitions( self, inspector: Inspector, schema: str, table: str ) -> Optional[List[str]]: - if ( - not self.config.extract_partitions - and not self.config.extract_partitions_using_create_statements - ): + if not self.config.extract_partitions: return None if not self.cursor: @@ -557,11 +618,9 @@ class AthenaSource(SQLAlchemySource): context=f"{schema}.{table}", level=StructuredLogLevel.WARN, ): - # We create an artifical concatenated partition key to be able to query max partition easier - part_concat = " || '-' || ".join( - self._casted_partition_key(key) for key in partitions + max_partition_query = self._build_max_partition_query( + schema, table, partitions ) - max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")' ret = self.cursor.execute(max_partition_query) max_partition: Dict[str, str] = {} if ret: @@ -678,16 +737,34 @@ class AthenaSource(SQLAlchemySource): ).get(table, None) if partition and partition.max_partition: - max_partition_filters = [] - for key, value in partition.max_partition.items(): - max_partition_filters.append( - f"{self._casted_partition_key(key)} = '{value}'" + try: + # Sanitize identifiers to prevent SQL injection + sanitized_schema = self._sanitize_identifier(schema) + sanitized_table = self._sanitize_identifier(table) + + max_partition_filters = [] + for key, value in partition.max_partition.items(): + # Sanitize partition key and properly escape the value + sanitized_key = self._sanitize_identifier(key) + # Escape single quotes in the value to prevent injection + escaped_value = value.replace("'", "''") if value else "" + max_partition_filters.append( + f"{self._casted_partition_key(sanitized_key)} = '{escaped_value}'" + ) + max_partition = str(partition.max_partition) + return ( + max_partition, + f'SELECT * FROM "{sanitized_schema}"."{sanitized_table}" WHERE {" AND ".join(max_partition_filters)}', ) - max_partition = str(partition.max_partition) - return ( - max_partition, - f'SELECT * FROM "{schema}"."{table}" WHERE {" AND ".join(max_partition_filters)}', - ) + except ValueError as e: + # If sanitization fails due to malicious identifiers, + # return None to disable partition profiling for this table + # rather than crashing the entire ingestion + logger.warning( + f"Failed to generate partition profiler query for {schema}.{table} due to unsafe identifiers: {e}. " + f"Partition profiling disabled for this table." 
+ ) + return None, None return None, None def close(self): diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena_properties_extractor.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena_properties_extractor.py index 4358789fcd..88d4dd9ac5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena_properties_extractor.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena_properties_extractor.py @@ -174,20 +174,16 @@ class AthenaPropertiesExtractor: def format_column_definition(line): # Use regex to parse the line more accurately # Pattern: column_name data_type [COMMENT comment_text] [,] - # Use greedy match for comment to capture everything until trailing comma - pattern = r"^\s*(.+?)\s+([\s,\w<>\[\]]+)((\s+COMMENT\s+(.+?)(,?))|(,?)\s*)?$" - match = re.match(pattern, line, re.IGNORECASE) + # Improved pattern to better separate column name, data type, and comment + pattern = r"^\s*([`\w']+)\s+([\w<>\[\](),\s]+?)(\s+COMMENT\s+(.+?))?(,?)\s*$" + match = re.match(pattern, line.strip(), re.IGNORECASE) if not match: return line - column_name = match.group(1) - data_type = match.group(2) - comment_part = match.group(5) # COMMENT part - # there are different number of match groups depending on whether comment exists - if comment_part: - trailing_comma = match.group(6) if match.group(6) else "" - else: - trailing_comma = match.group(7) if match.group(7) else "" + column_name = match.group(1).strip() + data_type = match.group(2).strip() + comment_part = match.group(4) # COMMENT part + trailing_comma = match.group(5) if match.group(5) else "" # Add backticks to column name if not already present if not (column_name.startswith("`") and column_name.endswith("`")): @@ -201,7 +197,7 @@ class AthenaPropertiesExtractor: # Handle comment quoting and escaping if comment_part.startswith("'") and comment_part.endswith("'"): - # Already properly single quoted - but check for proper escaping + # Already single quoted - but check for proper escaping inner_content = comment_part[1:-1] # Re-escape any single quotes that aren't properly escaped escaped_content = inner_content.replace("'", "''") diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py index 682a312c25..9d2c3a0a87 100644 --- a/metadata-ingestion/tests/unit/test_athena_source.py +++ b/metadata-ingestion/tests/unit/test_athena_source.py @@ -3,10 +3,12 @@ from typing import List from unittest import mock import pytest +import sqlglot from freezegun import freeze_time from pyathena import OperationalError from sqlalchemy import types from sqlalchemy_bigquery import STRUCT +from sqlglot.dialects import Athena from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.source.aws.s3_util import make_s3_urn @@ -14,6 +16,7 @@ from datahub.ingestion.source.sql.athena import ( AthenaConfig, AthenaSource, CustomAthenaRestDialect, + Partitionitem, ) from datahub.metadata.schema_classes import ( ArrayTypeClass, @@ -182,8 +185,8 @@ def test_athena_get_table_properties(): expected_query = """\ select year,month from "test_schema"."test_table$partitions" \ -where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \ -(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \ +where CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR)) = \ +(select max(CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR))) \ from "test_schema"."test_table$partitions")""" assert 
mock_cursor.execute.call_count == 2 assert expected_create_table_query == mock_cursor.execute.call_args_list[0][0][0] @@ -503,3 +506,641 @@ def test_convert_simple_field_paths_to_v1_default_behavior(): ) assert config.emit_schema_fieldpaths_as_v1 is False # Should default to False + + +def test_get_partitions_returns_none_when_extract_partitions_disabled(): + """Test that get_partitions returns None when extract_partitions is False""" + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "extract_partitions": False, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Mock inspector - should not be used if extract_partitions is False + mock_inspector = mock.MagicMock() + + # Call get_partitions - should return None immediately without any database calls + result = source.get_partitions(mock_inspector, "test_schema", "test_table") + + # Verify result is None + assert result is None + + # Verify that no inspector methods were called + assert not mock_inspector.called, ( + "Inspector should not be called when extract_partitions=False" + ) + + +def test_get_partitions_attempts_extraction_when_extract_partitions_enabled(): + """Test that get_partitions attempts partition extraction when extract_partitions is True""" + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "extract_partitions": True, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Mock inspector and cursor for partition extraction + mock_inspector = mock.MagicMock() + mock_cursor = mock.MagicMock() + + # Mock the table metadata response + mock_metadata = mock.MagicMock() + mock_partition_key = mock.MagicMock() + mock_partition_key.name = "year" + mock_metadata.partition_keys = [mock_partition_key] + mock_cursor.get_table_metadata.return_value = mock_metadata + + # Set the cursor on the source + source.cursor = mock_cursor + + # Call get_partitions - should attempt partition extraction + result = source.get_partitions(mock_inspector, "test_schema", "test_table") + + # Verify that the cursor was used (partition extraction was attempted) + mock_cursor.get_table_metadata.assert_called_once_with( + table_name="test_table", schema_name="test_schema" + ) + + # Result should be a list (even if empty) + assert isinstance(result, list) + assert result == ["year"] # Should contain the partition key name + + +def test_partition_profiling_sql_generation_single_key(): + """Test that partition profiling generates valid SQL for single partition key and can be parsed by SQLGlot.""" + + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "extract_partitions": True, + "profiling": {"enabled": True, "partition_profiling_enabled": True}, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Mock cursor and metadata for single partition key + mock_cursor = mock.MagicMock() + mock_metadata = mock.MagicMock() + mock_partition_key = mock.MagicMock() + mock_partition_key.name = "year" + mock_metadata.partition_keys = [mock_partition_key] + mock_cursor.get_table_metadata.return_value = mock_metadata + + # Mock successful partition query execution + mock_result = mock.MagicMock() + 
mock_result.description = [["year"]] + mock_result.__iter__ = lambda x: iter([["2023"]]) + mock_cursor.execute.return_value = mock_result + + source.cursor = mock_cursor + + # Call get_partitions to trigger SQL generation + result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table") + + # Get the generated SQL query + assert mock_cursor.execute.called + generated_query = mock_cursor.execute.call_args[0][0] + + # Verify the query structure for single partition key + assert "CAST(year as VARCHAR)" in generated_query + assert "CONCAT" not in generated_query # Single key shouldn't use CONCAT + assert '"test_schema"."test_table$partitions"' in generated_query + + # Validate that SQLGlot can parse the generated query using Athena dialect + try: + parsed = sqlglot.parse_one(generated_query, dialect=Athena) + assert parsed is not None + assert isinstance(parsed, sqlglot.expressions.Select) + print(f"✅ Single partition SQL parsed successfully: {generated_query}") + except Exception as e: + pytest.fail( + f"SQLGlot failed to parse single partition query: {e}\nQuery: {generated_query}" + ) + + assert result == ["year"] + + +def test_partition_profiling_sql_generation_multiple_keys(): + """Test that partition profiling generates valid SQL for multiple partition keys and can be parsed by SQLGlot.""" + + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "extract_partitions": True, + "profiling": {"enabled": True, "partition_profiling_enabled": True}, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Mock cursor and metadata for multiple partition keys + mock_cursor = mock.MagicMock() + mock_metadata = mock.MagicMock() + + mock_year_key = mock.MagicMock() + mock_year_key.name = "year" + mock_month_key = mock.MagicMock() + mock_month_key.name = "month" + mock_day_key = mock.MagicMock() + mock_day_key.name = "day" + + mock_metadata.partition_keys = [mock_year_key, mock_month_key, mock_day_key] + mock_cursor.get_table_metadata.return_value = mock_metadata + + # Mock successful partition query execution + mock_result = mock.MagicMock() + mock_result.description = [["year"], ["month"], ["day"]] + mock_result.__iter__ = lambda x: iter([["2023", "12", "25"]]) + mock_cursor.execute.return_value = mock_result + + source.cursor = mock_cursor + + # Call get_partitions to trigger SQL generation + result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table") + + # Get the generated SQL query + assert mock_cursor.execute.called + generated_query = mock_cursor.execute.call_args[0][0] + + # Verify the query structure for multiple partition keys + assert "CONCAT(" in generated_query # Multiple keys should use CONCAT + assert "CAST(year as VARCHAR)" in generated_query + assert "CAST(month as VARCHAR)" in generated_query + assert "CAST(day as VARCHAR)" in generated_query + assert "CAST('-' AS VARCHAR)" in generated_query # Separator should be cast + assert '"test_schema"."test_table$partitions"' in generated_query + + # Validate that SQLGlot can parse the generated query using Athena dialect + try: + parsed = sqlglot.parse_one(generated_query, dialect=Athena) + assert parsed is not None + assert isinstance(parsed, sqlglot.expressions.Select) + print(f"✅ Multiple partition SQL parsed successfully: {generated_query}") + except Exception as e: + pytest.fail( + f"SQLGlot failed to parse multiple partition query: {e}\nQuery: 
{generated_query}" + ) + + assert result == ["year", "month", "day"] + + +def test_partition_profiling_sql_generation_complex_schema_table_names(): + """Test that partition profiling handles complex schema/table names correctly and generates valid SQL.""" + + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "extract_partitions": True, + "profiling": {"enabled": True, "partition_profiling_enabled": True}, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Mock cursor and metadata + mock_cursor = mock.MagicMock() + mock_metadata = mock.MagicMock() + + mock_partition_key = mock.MagicMock() + mock_partition_key.name = "event_date" + mock_metadata.partition_keys = [mock_partition_key] + mock_cursor.get_table_metadata.return_value = mock_metadata + + # Mock successful partition query execution + mock_result = mock.MagicMock() + mock_result.description = [["event_date"]] + mock_result.__iter__ = lambda x: iter([["2023-12-25"]]) + mock_cursor.execute.return_value = mock_result + + source.cursor = mock_cursor + + # Test with complex schema and table names + schema = "ad_cdp_audience" # From the user's error + table = "system_import_label" + + result = source.get_partitions(mock.MagicMock(), schema, table) + + # Get the generated SQL query + assert mock_cursor.execute.called + generated_query = mock_cursor.execute.call_args[0][0] + + # Verify proper quoting of schema and table names + expected_table_ref = f'"{schema}"."{table}$partitions"' + assert expected_table_ref in generated_query + assert "CAST(event_date as VARCHAR)" in generated_query + + # Validate that SQLGlot can parse the generated query using Athena dialect + try: + parsed = sqlglot.parse_one(generated_query, dialect=Athena) + assert parsed is not None + assert isinstance(parsed, sqlglot.expressions.Select) + print(f"✅ Complex schema/table SQL parsed successfully: {generated_query}") + except Exception as e: + pytest.fail( + f"SQLGlot failed to parse complex schema/table query: {e}\nQuery: {generated_query}" + ) + + assert result == ["event_date"] + + +def test_casted_partition_key_method(): + """Test the _casted_partition_key helper method generates valid SQL fragments.""" + + # Test the static method directly + casted_key = AthenaSource._casted_partition_key("test_column") + assert casted_key == "CAST(test_column as VARCHAR)" + + # Verify SQLGlot can parse the CAST expression + try: + parsed = sqlglot.parse_one(casted_key, dialect=Athena) + assert parsed is not None + assert isinstance(parsed, sqlglot.expressions.Cast) + assert parsed.this.name == "test_column" + assert "VARCHAR" in str(parsed.to.this) # More flexible assertion + print(f"✅ CAST expression parsed successfully: {casted_key}") + except Exception as e: + pytest.fail( + f"SQLGlot failed to parse CAST expression: {e}\nExpression: {casted_key}" + ) + + # Test with various column name formats + test_cases = ["year", "month", "event_date", "created_at", "partition_key_123"] + + for column_name in test_cases: + casted = AthenaSource._casted_partition_key(column_name) + try: + parsed = sqlglot.parse_one(casted, dialect=Athena) + assert parsed is not None + assert isinstance(parsed, sqlglot.expressions.Cast) + except Exception as e: + pytest.fail(f"SQLGlot failed to parse CAST for column '{column_name}': {e}") + + +def test_concat_function_generation_validates_with_sqlglot(): + """Test that our CONCAT function 
generation produces valid Athena SQL according to SQLGlot.""" + + # Test the CONCAT function generation logic directly + partition_keys = ["year", "month", "day"] + casted_keys = [AthenaSource._casted_partition_key(key) for key in partition_keys] + + # Replicate the CONCAT generation logic from the source + concat_args = [] + for i, key in enumerate(casted_keys): + concat_args.append(key) + if i < len(casted_keys) - 1: # Add separator except for last element + concat_args.append("CAST('-' AS VARCHAR)") + + concat_expr = f"CONCAT({', '.join(concat_args)})" + + # Verify the generated CONCAT expression + expected = "CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR))" + assert concat_expr == expected + + # Validate with SQLGlot + try: + parsed = sqlglot.parse_one(concat_expr, dialect=Athena) + assert parsed is not None + assert isinstance(parsed, sqlglot.expressions.Concat) + # Verify all arguments are properly parsed + assert len(parsed.expressions) == 5 # 3 partition keys + 2 separators + print(f"✅ CONCAT expression parsed successfully: {concat_expr}") + except Exception as e: + pytest.fail( + f"SQLGlot failed to parse CONCAT expression: {e}\nExpression: {concat_expr}" + ) + + +def test_build_max_partition_query(): + """Test _build_max_partition_query method directly without mocking.""" + + # Test single partition key + query_single = AthenaSource._build_max_partition_query( + "test_schema", "test_table", ["year"] + ) + expected_single = 'select year from "test_schema"."test_table$partitions" where CAST(year as VARCHAR) = (select max(CAST(year as VARCHAR)) from "test_schema"."test_table$partitions")' + assert query_single == expected_single + + # Test multiple partition keys + query_multiple = AthenaSource._build_max_partition_query( + "test_schema", "test_table", ["year", "month", "day"] + ) + expected_multiple = "select year,month,day from \"test_schema\".\"test_table$partitions\" where CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR)) = (select max(CONCAT(CAST(year as VARCHAR), CAST('-' AS VARCHAR), CAST(month as VARCHAR), CAST('-' AS VARCHAR), CAST(day as VARCHAR))) from \"test_schema\".\"test_table$partitions\")" + assert query_multiple == expected_multiple + + # Validate with SQLGlot that generated queries are valid SQL + try: + parsed_single = sqlglot.parse_one(query_single, dialect=Athena) + assert parsed_single is not None + assert isinstance(parsed_single, sqlglot.expressions.Select) + + parsed_multiple = sqlglot.parse_one(query_multiple, dialect=Athena) + assert parsed_multiple is not None + assert isinstance(parsed_multiple, sqlglot.expressions.Select) + + print("✅ Both queries parsed successfully by SQLGlot") + except Exception as e: + pytest.fail(f"SQLGlot failed to parse generated query: {e}") + + +def test_partition_profiling_disabled_no_sql_generation(): + """Test that when partition profiling is disabled, no complex SQL is generated.""" + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "extract_partitions": True, + "profiling": {"enabled": False, "partition_profiling_enabled": False}, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Mock cursor and metadata + mock_cursor = mock.MagicMock() + mock_metadata = mock.MagicMock() + mock_partition_key = mock.MagicMock() + 
mock_partition_key.name = "year"
+    mock_metadata.partition_keys = [mock_partition_key]
+    mock_cursor.get_table_metadata.return_value = mock_metadata
+
+    source.cursor = mock_cursor
+
+    # Call get_partitions - should not generate complex profiling SQL
+    result = source.get_partitions(mock.MagicMock(), "test_schema", "test_table")
+
+    # Should only call get_table_metadata, not execute complex partition queries
+    mock_cursor.get_table_metadata.assert_called_once()
+    # The execute method should not be called for profiling queries when profiling is disabled
+    assert not mock_cursor.execute.called
+
+    assert result == ["year"]
+
+
+def test_sanitize_identifier_valid_names():
+    """Test _sanitize_identifier method with valid Athena identifiers."""
+    # Valid simple identifiers
+    valid_identifiers = [
+        "table_name",
+        "schema123",
+        "column_1",
+        "MyTable",  # Should be allowed (Athena converts to lowercase)
+        "DATABASE",
+        "a",  # Single character
+        "table123name",
+        "data_warehouse_table",
+    ]
+
+    for identifier in valid_identifiers:
+        result = AthenaSource._sanitize_identifier(identifier)
+        assert result == identifier, f"Expected {identifier} to be valid"
+
+
+def test_sanitize_identifier_valid_complex_types():
+    """Test _sanitize_identifier method with valid complex type identifiers."""
+    # Valid complex type identifiers (with periods)
+    valid_complex_identifiers = [
+        "struct.field",
+        "data.subfield",
+        "nested.struct.field",
+        "array.element",
+        "map.key",
+        "record.attribute.subfield",
+    ]
+
+    for identifier in valid_complex_identifiers:
+        result = AthenaSource._sanitize_identifier(identifier)
+        assert result == identifier, (
+            f"Expected {identifier} to be valid for complex types"
+        )
+
+
+def test_sanitize_identifier_invalid_characters():
+    """Test _sanitize_identifier method rejects invalid characters."""
+    # Invalid identifiers that should be rejected
+    invalid_identifiers = [
+        "table-name",  # hyphen not allowed in Athena
+        "table name",  # spaces not allowed
+        "table@domain",  # @ not allowed
+        "table#hash",  # # not allowed
+        "table$var",  # $ not allowed (except in system tables)
+        "table%percent",  # % not allowed
+        "table&and",  # & not allowed
+        "table*star",  # * not allowed
+        "table(paren",  # ( not allowed
+        "table)paren",  # ) not allowed
+        "table+plus",  # + not allowed
+        "table=equal",  # = not allowed
+        "table[bracket",  # [ not allowed
+        "table]bracket",  # ] not allowed
+        "table{brace",  # { not allowed
+        "table}brace",  # } not allowed
+        "table|pipe",  # | not allowed
+        "table\\backslash",  # \ not allowed
+        "table:colon",  # : not allowed
+        "table;semicolon",  # ; not allowed
+        "table<less",  # < not allowed
+        "table>greater",  # > not allowed
+        "table?question",  # ?
not allowed + "table/slash", # / not allowed + "table,comma", # , not allowed + ] + + for identifier in invalid_identifiers: + with pytest.raises(ValueError, match="contains unsafe characters"): + AthenaSource._sanitize_identifier(identifier) + + +def test_sanitize_identifier_sql_injection_attempts(): + """Test _sanitize_identifier method blocks SQL injection attempts.""" + # SQL injection attempts that should be blocked + sql_injection_attempts = [ + "table'; DROP TABLE users; --", + 'table" OR 1=1 --', + "table; DELETE FROM data;", + "'; UNION SELECT * FROM passwords --", + "table') OR ('1'='1", + "table\"; INSERT INTO logs VALUES ('hack'); --", + "table' AND 1=0 UNION SELECT * FROM admin --", + "table/*comment*/", + "table--comment", + "table' OR 'a'='a", + "1'; DROP DATABASE test; --", + "table'; EXEC xp_cmdshell('dir'); --", + ] + + for injection_attempt in sql_injection_attempts: + with pytest.raises(ValueError, match="contains unsafe characters"): + AthenaSource._sanitize_identifier(injection_attempt) + + +def test_sanitize_identifier_quote_injection_attempts(): + """Test _sanitize_identifier method blocks quote-based injection attempts.""" + # Quote injection attempts + quote_injection_attempts = [ + 'table"', + "table'", + 'table""', + "table''", + 'table"extra', + "table'extra", + "`table`", + '"table"', + "'table'", + ] + + for injection_attempt in quote_injection_attempts: + with pytest.raises(ValueError, match="contains unsafe characters"): + AthenaSource._sanitize_identifier(injection_attempt) + + +def test_sanitize_identifier_empty_and_edge_cases(): + """Test _sanitize_identifier method with empty and edge case inputs.""" + # Empty identifier + with pytest.raises(ValueError, match="Identifier cannot be empty"): + AthenaSource._sanitize_identifier("") + + # None input (should also raise ValueError from empty check) + with pytest.raises(ValueError, match="Identifier cannot be empty"): + AthenaSource._sanitize_identifier(None) # type: ignore[arg-type] + + # Very long valid identifier + long_identifier = "a" * 200 # Still within Athena's 255 byte limit + result = AthenaSource._sanitize_identifier(long_identifier) + assert result == long_identifier + + +def test_sanitize_identifier_integration_with_build_max_partition_query(): + """Test that _sanitize_identifier works correctly within _build_max_partition_query.""" + # Test with valid identifiers + query = AthenaSource._build_max_partition_query( + "valid_schema", "valid_table", ["valid_partition"] + ) + assert "valid_schema" in query + assert "valid_table" in query + assert "valid_partition" in query + + # Test that invalid identifiers are rejected before query building + with pytest.raises(ValueError, match="contains unsafe characters"): + AthenaSource._build_max_partition_query( + "schema'; DROP TABLE users; --", "table", ["partition"] + ) + + with pytest.raises(ValueError, match="contains unsafe characters"): + AthenaSource._build_max_partition_query( + "schema", "table-with-hyphen", ["partition"] + ) + + with pytest.raises(ValueError, match="contains unsafe characters"): + AthenaSource._build_max_partition_query( + "schema", "table", ["partition; DELETE FROM data;"] + ) + + +def test_sanitize_identifier_error_handling_in_get_partitions(): + """Test that ValueError from _sanitize_identifier is handled gracefully in get_partitions method.""" + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "extract_partitions": 
True, + "profiling": {"enabled": True, "partition_profiling_enabled": True}, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Mock cursor and metadata + mock_cursor = mock.MagicMock() + mock_metadata = mock.MagicMock() + mock_partition_key = mock.MagicMock() + mock_partition_key.name = "year" + mock_metadata.partition_keys = [mock_partition_key] + mock_cursor.get_table_metadata.return_value = mock_metadata + source.cursor = mock_cursor + + # Test with malicious schema name - should be handled gracefully by report_exc + result = source.get_partitions( + mock.MagicMock(), "schema'; DROP TABLE users; --", "valid_table" + ) + + # Should still return the partition list from metadata since the error + # occurs in the profiling section which is wrapped in report_exc + assert result == ["year"] + + # Verify metadata was still called + mock_cursor.get_table_metadata.assert_called_once() + + +def test_sanitize_identifier_error_handling_in_generate_partition_profiler_query( + caplog, +): + """Test that ValueError from _sanitize_identifier is handled gracefully in generate_partition_profiler_query.""" + import logging + + config = AthenaConfig.parse_obj( + { + "aws_region": "us-west-1", + "query_result_location": "s3://query-result-location/", + "work_group": "test-workgroup", + "profiling": {"enabled": True, "partition_profiling_enabled": True}, + } + ) + + ctx = PipelineContext(run_id="test") + source = AthenaSource(config=config, ctx=ctx) + + # Add a mock partition to the cache with malicious partition key + source.table_partition_cache["valid_schema"] = { + "valid_table": Partitionitem( + partitions=["year"], + max_partition={"malicious'; DROP TABLE users; --": "2023"}, + ) + } + + # Capture log messages at WARNING level + with caplog.at_level(logging.WARNING): + # This should handle the ValueError gracefully and return None, None + # instead of crashing the entire ingestion process + result = source.generate_partition_profiler_query( + "valid_schema", "valid_table", None + ) + + # Verify the method returns None, None when sanitization fails + assert result == (None, None), "Should return None, None when sanitization fails" + + # Verify that a warning log message was generated + assert len(caplog.records) == 1 + log_record = caplog.records[0] + assert log_record.levelname == "WARNING" + assert ( + "Failed to generate partition profiler query for valid_schema.valid_table due to unsafe identifiers" + in log_record.message + ) + assert "contains unsafe characters" in log_record.message + assert "Partition profiling disabled for this table" in log_record.message
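
Note on the CONCAT change: the previous implementation concatenated partition
keys as CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR), and Athena can
reject that expression with a type error on the bare '-' literal, which is
what this patch works around: every operand, including the separator, is cast
to VARCHAR and combined with CONCAT. A standalone sketch of the query
construction (distilled from the patched helpers; the schema/table names in
the usage line are illustrative):

    import re

    _IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9_.]+$")

    def casted(key: str) -> str:
        # Validate before interpolating into SQL, as _sanitize_identifier does.
        if not _IDENTIFIER_PATTERN.match(key or ""):
            raise ValueError(f"unsafe identifier: {key}")
        return f"CAST({key} as VARCHAR)"

    def max_partition_query(schema: str, table: str, partitions: list) -> str:
        for ident in (schema, table, *partitions):
            if not _IDENTIFIER_PATTERN.match(ident or ""):
                raise ValueError(f"unsafe identifier: {ident}")
        keys = [casted(k) for k in partitions]
        # One key compares directly; several are joined with a cast separator.
        sep = ", CAST('-' AS VARCHAR), "
        part = keys[0] if len(keys) == 1 else "CONCAT(" + sep.join(keys) + ")"
        src = f'"{schema}"."{table}$partitions"'
        return (
            f"select {','.join(partitions)} from {src} "
            f"where {part} = (select max({part}) from {src})"
        )

    print(max_partition_query("test_schema", "test_table", ["year", "month"]))

For ["year", "month"] this reproduces the query shape the updated unit test
expects. The new tests also run every generated query through
sqlglot.parse_one(query, dialect=Athena) as a cheap syntactic check.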
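Note on value escaping: partition values, unlike identifiers, cannot be
allow-listed, so generate_partition_profiler_query doubles any embedded single
quote before inlining them. A minimal sketch of that second layer (identifier
validation from the patch is omitted here for brevity; the helper names are
illustrative):

    def quote_sql_string(value: str) -> str:
        # Standard SQL string-literal escaping: '' represents a literal quote.
        return "'" + (value or "").replace("'", "''") + "'"

    def partition_filter(max_partition: dict) -> str:
        # Mirrors the WHERE clause built for the profiler query.
        return " AND ".join(
            f"CAST({key} as VARCHAR) = {quote_sql_string(value)}"
            for key, value in max_partition.items()
        )

    # {"year": "2023", "note": "it's"} ->
    #   CAST(year as VARCHAR) = '2023' AND CAST(note as VARCHAR) = 'it''s'

If an identifier still fails validation at this point, the patch logs a
warning and returns (None, None), so profiling is skipped for that one table
instead of aborting the whole ingestion run.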
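Note on the properties-extractor regex: the old pattern's permissive groups
made the split between column name, data type, and comment ambiguous; the new
pattern pins down four groups: name, type, optional COMMENT body, optional
trailing comma. A quick standalone check of what it captures (the sample DDL
line is invented):

    import re

    PATTERN = r"^\s*([`\w']+)\s+([\w<>\[\](),\s]+?)(\s+COMMENT\s+(.+?))?(,?)\s*$"

    m = re.match(PATTERN, "user_id bigint COMMENT 'primary id',", re.IGNORECASE)
    assert m is not None
    assert m.group(1) == "user_id"        # column name
    assert m.group(2) == "bigint"         # data type
    assert m.group(4) == "'primary id'"   # COMMENT body, already quoted
    assert m.group(5) == ","              # trailing comma

    # Complex types stay in the type group: "tags array<string>," yields
    # group(2) == "array<string>" and no COMMENT groups.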