# Mirror of https://github.com/datahub-project/datahub.git
# Synced 2025-12-06 07:34:37 +00:00 (582 lines, 22 KiB, Python)
import re
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from datahub.ingestion.source.sql_queries import (
|
|
SqlQueriesSource,
|
|
SqlQueriesSourceConfig,
|
|
)
|
|
|
|
|
|
class TestPerformanceConfigOptimizations:
    """Test performance optimization features.

    Covers the ``enable_lazy_schema_loading`` option of ``SqlQueriesSourceConfig``.
    """

    def test_lazy_schema_loading_enabled_by_default(self):
        """Lazy schema loading is on when the option is not given."""
        config = SqlQueriesSourceConfig(platform="snowflake", query_file="dummy.json")
        assert config.enable_lazy_schema_loading is True

    def test_lazy_schema_loading_can_be_disabled(self):
        """An explicit ``enable_lazy_schema_loading=False`` is kept on the config."""
        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="dummy.json",
            enable_lazy_schema_loading=False,
        )
        assert config.enable_lazy_schema_loading is False
|
|
|
|
|
|
class TestS3Support:
    """Test S3 support features."""

    def test_s3_uri_detection(self):
        """Test S3 URI detection."""
        # Create a minimal source instance without full initialization
        config = SqlQueriesSourceConfig(platform="snowflake", query_file="dummy.json")
        source = SqlQueriesSource.__new__(SqlQueriesSource)
        source.config = config

        # Test S3 URIs
        assert source._is_s3_uri("s3://bucket/path/file.json") is True
        assert source._is_s3_uri("s3://my-bucket/data/queries.jsonl") is True

        # Test non-S3 URIs
        assert source._is_s3_uri("/local/path/file.json") is False
        assert source._is_s3_uri("file://local/path/file.json") is False
        assert source._is_s3_uri("https://example.com/file.json") is False

    def test_aws_config_required_for_s3(self):
        """Test that AWS config is required for S3 files."""
        config = SqlQueriesSourceConfig(
            platform="snowflake", query_file="s3://bucket/file.json"
        )
        # Create a minimal source instance without full initialization
        source = SqlQueriesSource.__new__(SqlQueriesSource)
        source.config = config

        with pytest.raises(
            ValueError, match="AWS configuration required for S3 file access"
        ):
            list(source._parse_s3_query_file())

    @patch("datahub.ingestion.source.sql_queries.smart_open.open")
    def test_s3_file_processing(self, mock_smart_open):
        """Test S3 file processing.

        ``smart_open.open`` is patched, so nothing is read from S3. The patched
        call returns a stream that yields two JSON-lines records.
        """
        # MagicMock pre-configures the context-manager and iterator protocols
        from unittest.mock import MagicMock

        # Create a proper AWS config dict
        aws_config_dict = {
            "aws_access_key_id": "test_key",
            "aws_secret_access_key": "test_secret",
            "aws_session_token": "test_token",
        }

        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="s3://test-bucket/test-key",
            aws_config=aws_config_dict,
        )

        # Create a minimal source instance without full initialization
        source = SqlQueriesSource.__new__(SqlQueriesSource)
        source.config = config
        source.report = Mock()
        source.report.num_entries_processed = 0
        source.report.num_entries_failed = 0
        source.report.warning = Mock()

        # Swap the parsed AWS config for a mock so no real S3 client is built
        mock_aws_config = Mock()
        mock_aws_config.get_s3_client.return_value = Mock()
        config.aws_config = mock_aws_config

        # The stream is its own context manager and iterates the given lines
        mock_file_stream = MagicMock()
        mock_file_stream.__enter__.return_value = mock_file_stream
        mock_file_stream.__iter__.return_value = [
            '{"query": "SELECT * FROM table1", "timestamp": 1609459200}\n',
            '{"query": "SELECT * FROM table2", "timestamp": 1609459201}\n',
        ]
        mock_smart_open.return_value = mock_file_stream

        # Test S3 file processing
        queries = list(source._parse_s3_query_file())
        assert len(queries) == 2
        assert queries[0].query == "SELECT * FROM table1"
        assert queries[1].query == "SELECT * FROM table2"

        # The file must have been read through the patched opener and mocked client
        mock_smart_open.assert_called_once()
        mock_aws_config.get_s3_client.assert_called_once()
|
|
|
|
|
|
class TestTemporaryTableSupport:
    """Test temporary table support features."""

    @staticmethod
    def _make_source(config):
        """Build a SqlQueriesSource without running __init__.

        Sets only ``config``, a mock ``report`` whose
        ``num_temp_tables_detected`` starts at 0, and a mock ``ctx``.
        """
        source = SqlQueriesSource.__new__(SqlQueriesSource)
        source.config = config
        source.report = Mock()
        source.report.num_temp_tables_detected = 0
        source.ctx = Mock()
        return source

    @staticmethod
    def _pattern_matcher(patterns):
        """Return an ``is_temp_table`` callback matching names against *patterns*.

        Matching uses ``re.match`` with ``re.IGNORECASE``.
        """
        return lambda name: any(
            re.match(pattern, name, flags=re.IGNORECASE) for pattern in patterns
        )

    def test_temp_table_patterns_default(self):
        """Test default temp table patterns."""
        config = SqlQueriesSourceConfig(platform="snowflake", query_file="dummy.json")
        assert config.temp_table_patterns == []

    def test_temp_table_patterns_custom(self):
        """Test custom temp table patterns."""
        patterns = ["^temp_.*", "^tmp_.*", ".*_temp$"]
        config = SqlQueriesSourceConfig(
            platform="snowflake", query_file="dummy.json", temp_table_patterns=patterns
        )
        assert config.temp_table_patterns == patterns

    def test_is_temp_table_no_patterns(self):
        """Test temp table detection with no patterns."""
        config = SqlQueriesSourceConfig(platform="snowflake", query_file="dummy.json")
        source = self._make_source(config)

        assert source.is_temp_table("temp_table") is False
        assert source.is_temp_table("regular_table") is False

    def test_is_temp_table_with_patterns(self):
        """Test temp table detection with patterns."""
        patterns = ["^temp_.*", "^tmp_.*", ".*_temp$"]
        config = SqlQueriesSourceConfig(
            platform="snowflake", query_file="dummy.json", temp_table_patterns=patterns
        )
        source = self._make_source(config)

        # Test matching patterns
        assert source.is_temp_table("temp_table") is True
        assert source.is_temp_table("tmp_table") is True
        assert source.is_temp_table("my_temp") is True
        assert source.is_temp_table("TEMP_TABLE") is True  # Case insensitive

        # Test non-matching patterns
        assert source.is_temp_table("regular_table") is False
        assert source.is_temp_table("table_temp_other") is False

    def test_is_temp_table_invalid_regex(self):
        """Test temp table detection with invalid regex patterns."""
        patterns = ["[invalid_regex", "^temp_.*"]  # First pattern is invalid
        config = SqlQueriesSourceConfig(
            platform="snowflake", query_file="dummy.json", temp_table_patterns=patterns
        )
        source = self._make_source(config)

        # TODO: known bug. When one pattern is an invalid regex, is_temp_table
        # returns False immediately instead of trying the remaining patterns.
        # This test pins the current behavior; flip the first assertion to
        # `is True` once the implementation skips only the bad pattern.
        assert source.is_temp_table("temp_table") is False  # Current buggy behavior
        assert source.is_temp_table("regular_table") is False

    def test_temp_table_detection_counting(self):
        """Test that temp table detection is counted in reporting."""
        config = SqlQueriesSourceConfig(
            platform="snowflake", query_file="dummy.json", temp_table_patterns=["^temp_.*"]
        )
        source = self._make_source(config)

        # Initial count should be 0
        assert source.report.num_temp_tables_detected == 0

        # Test temp table detection
        source.is_temp_table("temp_table1")
        source.is_temp_table("temp_table2")
        source.is_temp_table("regular_table")

        # Should count only the temp tables
        assert source.report.num_temp_tables_detected == 2

    def test_temp_table_patterns_tracking(self):
        """Test that temp table patterns are tracked in reporting."""
        patterns = ["^temp_.*", "^tmp_.*"]
        config = SqlQueriesSourceConfig(
            platform="snowflake", query_file="dummy.json", temp_table_patterns=patterns
        )
        # Create a minimal source instance without full initialization
        source = SqlQueriesSource.__new__(SqlQueriesSource)
        source.config = config
        source.report = Mock()
        # NOTE(review): the test assigns this report field itself, so the check
        # below only verifies the copy. It does not exercise how SqlQueriesSource
        # populates temp_table_patterns_used.
        source.report.temp_table_patterns_used = patterns.copy()

        # Patterns should be copied to report
        assert source.report.temp_table_patterns_used == patterns

    def test_sql_parsing_temp_table_detection_without_patterns(self):
        """Test that SQL parsing detects CREATE TEMPORARY TABLE even without temp_table_patterns."""
        import sqlglot

        from datahub.sql_parsing.query_types import get_query_type_of_sql

        create_temp_sql = "CREATE TEMPORARY TABLE temp_users AS SELECT * FROM users"
        expression = sqlglot.parse_one(create_temp_sql, dialect="snowflake")
        _, query_type_props = get_query_type_of_sql(expression, "snowflake")

        # Should detect temporary table from SQL parsing
        assert query_type_props.get("temporary") is True

    def test_sql_parsing_temp_table_detection_with_patterns(self):
        """Test that SQL-parsing detection of CREATE TEMPORARY TABLE works alongside temp_table_patterns."""
        import sqlglot

        from datahub.sql_parsing.query_types import get_query_type_of_sql
        from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator

        # Test with temp table patterns configured
        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="dummy.json",
            temp_table_patterns=["^temp_.*"],
        )

        # Create aggregator with temp table patterns
        aggregator = SqlParsingAggregator(
            platform="snowflake",
            is_temp_table=self._pattern_matcher(config.temp_table_patterns),
        )
        # Close the aggregator so whatever it holds is released even when an
        # assertion fails.
        try:
            # Test SQL parsing detection for CREATE TEMPORARY TABLE
            create_temp_sql = "CREATE TEMPORARY TABLE temp_users AS SELECT * FROM users"
            expression = sqlglot.parse_one(create_temp_sql, dialect="snowflake")
            _, query_type_props = get_query_type_of_sql(expression, "snowflake")
            assert query_type_props.get("temporary") is True

            # Test pattern-based detection for non-SQL temp tables
            assert (
                aggregator.is_temp_table(
                    "urn:li:dataset:(urn:li:dataPlatform:snowflake,temp_users,PROD)"
                )
                is True
            )
            assert (
                aggregator.is_temp_table(
                    "urn:li:dataset:(urn:li:dataPlatform:snowflake,regular_table,PROD)"
                )
                is False
            )
        finally:
            aggregator.close()

    def test_sql_parsing_temp_table_detection_variations(self):
        """Test SQL parsing detection for different temporary table syntax variations."""
        import sqlglot

        from datahub.sql_parsing.query_types import get_query_type_of_sql

        test_cases = [
            "CREATE TEMPORARY TABLE temp_users AS SELECT * FROM users",
            "CREATE TEMP TABLE temp_users AS SELECT * FROM users",
            "CREATE TEMPORARY TABLE IF NOT EXISTS temp_users AS SELECT * FROM users",
            "CREATE TEMP TABLE IF NOT EXISTS temp_users AS SELECT * FROM users",
        ]

        for sql in test_cases:
            expression = sqlglot.parse_one(sql, dialect="snowflake")
            _, query_type_props = get_query_type_of_sql(expression, "snowflake")

            # All variations should be detected as temporary tables
            assert query_type_props.get("temporary") is True, f"Failed for SQL: {sql}"

    def test_sql_parsing_temp_table_detection_dialect_specific(self):
        """Test SQL parsing detection for dialect-specific temporary table syntax."""
        import sqlglot

        from datahub.sql_parsing.query_types import get_query_type_of_sql

        # MSSQL (tsql) and Redshift use a "#" prefix for temporary tables
        test_cases = [
            ("CREATE TABLE #temp_users AS SELECT * FROM users", "tsql"),
            ("CREATE TABLE #temp_users AS SELECT * FROM users", "redshift"),
        ]

        for sql, dialect in test_cases:
            expression = sqlglot.parse_one(sql, dialect=dialect)
            _, query_type_props = get_query_type_of_sql(expression, dialect)

            # Should detect # prefix as temporary table
            assert query_type_props.get("temporary") is True, (
                f"Failed for SQL: {sql} with dialect: {dialect}"
            )

    def test_combined_temp_table_detection_scenarios(self):
        """Test scenarios combining SQL parsing detection with pattern-based detection."""
        import sqlglot

        from datahub.sql_parsing.query_types import get_query_type_of_sql
        from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator

        # Configure with patterns that catch some temp tables but not others
        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="dummy.json",
            temp_table_patterns=["^staging_.*"],  # Only catches staging_* tables
        )

        aggregator = SqlParsingAggregator(
            platform="snowflake",
            is_temp_table=self._pattern_matcher(config.temp_table_patterns),
        )
        try:
            # SQL parsing should catch CREATE TEMPORARY TABLE regardless of patterns
            create_temp_sql = "CREATE TEMPORARY TABLE temp_users AS SELECT * FROM users"
            expression = sqlglot.parse_one(create_temp_sql, dialect="snowflake")
            _, query_type_props = get_query_type_of_sql(expression, "snowflake")
            assert query_type_props.get("temporary") is True

            # Pattern matching should catch only tables matching the configured patterns
            assert (
                aggregator.is_temp_table(
                    "urn:li:dataset:(urn:li:dataPlatform:snowflake,staging_data,PROD)"
                )
                is True
            )
            assert (
                aggregator.is_temp_table(
                    "urn:li:dataset:(urn:li:dataPlatform:snowflake,temp_users,PROD)"
                )
                is False
            )
        finally:
            aggregator.close()
|
|
|
|
|
|
class TestEnhancedReporting:
    """Test enhanced reporting features."""

    def test_schema_cache_tracking(self):
        """Test the schema cache hit/miss counters on SchemaResolverReport."""
        from datahub.sql_parsing.schema_resolver import SchemaResolverReport

        schema_report = SchemaResolverReport()

        # Initial counts should be 0
        assert schema_report.num_schema_cache_hits == 0
        assert schema_report.num_schema_cache_misses == 0

        # Bumping one counter must leave the other untouched
        schema_report.num_schema_cache_hits += 1
        assert schema_report.num_schema_cache_hits == 1
        assert schema_report.num_schema_cache_misses == 0

        schema_report.num_schema_cache_misses += 1
        assert schema_report.num_schema_cache_hits == 1
        assert schema_report.num_schema_cache_misses == 1

    def test_query_processing_counting(self):
        """Test query processing counting."""
        from datahub.ingestion.source.sql_queries import SqlQueriesSourceReport

        report = SqlQueriesSourceReport()

        # Initial counts should be 0
        assert report.num_queries_processed_sequential == 0

        # Simulate query processing
        report.num_queries_processed_sequential += 5
        assert report.num_queries_processed_sequential == 5

    def test_peak_memory_usage_tracking(self):
        """Test peak memory usage tracking."""
        from datahub.ingestion.source.sql_queries import SqlQueriesSourceReport

        report = SqlQueriesSourceReport()

        # Initial value should be 0
        assert report.peak_memory_usage_mb == 0.0

        # Simulate memory usage
        report.peak_memory_usage_mb = 150.5
        assert report.peak_memory_usage_mb == 150.5
|
|
|
|
|
|
class TestConfigurationValidation:
    """Test configuration validation."""

    def test_all_new_options_have_defaults(self):
        """Test that all new configuration options have sensible defaults."""
        config = SqlQueriesSourceConfig(platform="snowflake", query_file="dummy.json")

        # Performance options
        assert config.enable_lazy_schema_loading is True

        # S3 options
        assert config.aws_config is None

        # Temp table options
        assert config.temp_table_patterns == []

    def test_backward_compatibility(self):
        """Test that existing configurations still work."""
        # Test minimal configuration
        config = SqlQueriesSourceConfig(platform="snowflake", query_file="dummy.json")
        assert config.platform == "snowflake"
        assert config.query_file == "dummy.json"

        # Test with some existing options
        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="dummy.json",
            default_db="test_db",
            default_schema="test_schema",
        )
        assert config.default_db == "test_db"
        assert config.default_schema == "test_schema"

    def test_field_validation(self):
        """Test field validation for new options."""
        # aws_config supplied as a plain dict is parsed into a config object
        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="s3://bucket/file.json",
            aws_config={
                "aws_access_key_id": "test_key",
                "aws_secret_access_key": "test_secret",
                "aws_session_token": "test_token",
            },
        )
        assert config.aws_config is not None
        assert config.aws_config.aws_access_key_id == "test_key"

        # temp_table_patterns must be a list; a scalar is rejected. Pydantic's
        # ValidationError subclasses ValueError.
        with pytest.raises(ValueError):
            SqlQueriesSourceConfig(
                platform="snowflake", query_file="dummy.json", temp_table_patterns=123
            )
|
|
|
|
|
|
class TestEdgeCases:
    """Edge cases and error handling for SqlQueriesSource helpers."""

    @staticmethod
    def _bare_source(config):
        """Return a SqlQueriesSource built without __init__, holding *config*."""
        instance = SqlQueriesSource.__new__(SqlQueriesSource)
        instance.config = config
        return instance

    def test_empty_temp_table_patterns(self):
        """Test behavior with empty temp table patterns."""
        source = self._bare_source(
            SqlQueriesSourceConfig(
                platform="snowflake", query_file="dummy.json", temp_table_patterns=[]
            )
        )
        source.report = Mock()
        source.report.num_temp_tables_detected = 0
        source.ctx = Mock()

        # With no patterns configured, nothing counts as a temp table
        for table_name in ("temp_table", "regular_table"):
            assert source.is_temp_table(table_name) is False

    def test_none_aws_config(self):
        """Test behavior with None AWS config."""
        source = self._bare_source(
            SqlQueriesSourceConfig(
                platform="snowflake",
                query_file="s3://bucket/file.json",
                aws_config=None,
            )
        )

        expected_message = "AWS configuration required for S3 file access"
        with pytest.raises(ValueError, match=expected_message):
            list(source._parse_s3_query_file())

    def test_invalid_s3_uri_format(self):
        """Test behavior with invalid S3 URI format."""
        source = self._bare_source(
            SqlQueriesSourceConfig(platform="snowflake", query_file="dummy.json")
        )

        # A string without the scheme is not an S3 URI ...
        assert source._is_s3_uri("not-an-s3-uri") is False
        # ... but the bare scheme alone is still detected as one
        assert source._is_s3_uri("s3://") is True
|
|
|
|
|
|
class TestIntegrationScenarios:
    """Test integration scenarios combining multiple features."""

    def test_s3_with_lazy_loading(self):
        """Test S3 processing with lazy loading enabled."""
        aws_config_dict = {
            "aws_access_key_id": "test_key",
            "aws_secret_access_key": "test_secret",
            "aws_session_token": "test_token",
        }

        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="s3://bucket/file.json",
            aws_config=aws_config_dict,
        )

        # Create a minimal source instance without full initialization
        source = SqlQueriesSource.__new__(SqlQueriesSource)
        source.config = config

        # The file is routed through S3 with AWS config present, and lazy
        # schema loading (on by default) stays enabled alongside it.
        assert source._is_s3_uri(config.query_file) is True
        assert config.aws_config is not None
        assert config.enable_lazy_schema_loading is True

    def test_temp_tables_support(self):
        """Test temp table support."""
        config = SqlQueriesSourceConfig(
            platform="athena",
            query_file="dummy.json",
            temp_table_patterns=["^temp_.*", "^tmp_.*"],
        )

        # Create a minimal source instance without full initialization
        source = SqlQueriesSource.__new__(SqlQueriesSource)
        source.config = config
        source.report = Mock()
        source.report.num_temp_tables_detected = 0
        source.ctx = Mock()

        # Verify configuration
        assert len(config.temp_table_patterns) == 2

        # Test temp table detection
        assert source.is_temp_table("temp_table") is True
        assert source.is_temp_table("tmp_table") is True
        assert source.is_temp_table("regular_table") is False

    def test_performance_optimizations_combined(self):
        """Test all performance optimizations combined."""
        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="dummy.json",
        )

        # Verify default optimization settings
        assert config.enable_lazy_schema_loading is True
        assert config.temp_table_patterns == []

    def test_backward_compatibility_with_new_features(self):
        """Test that existing configurations work with new features available."""
        # Old-style configuration should still work
        config = SqlQueriesSourceConfig(
            platform="snowflake",
            query_file="dummy.json",
            default_db="test_db",
            default_schema="test_schema",
        )

        # New features should have sensible defaults
        assert config.enable_lazy_schema_loading is True
        assert config.aws_config is None
        assert config.temp_table_patterns == []

        # Old features should still work
        assert config.default_db == "test_db"
        assert config.default_schema == "test_schema"
|