# datahub/metadata-ingestion/tests/unit/test_teradata_performance.py
"""
Performance and optimization tests for TeradataSource.
These tests focus on performance optimizations, memory management,
and efficiency improvements in the Teradata source.
"""

from datetime import datetime
from typing import Any, Dict
from unittest.mock import MagicMock, patch

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.sql.teradata import (
    TeradataConfig,
    TeradataSource,
    TeradataTable,
    get_schema_columns,
    get_schema_foreign_keys,
    get_schema_pk_constraints,
)


def _base_config() -> Dict[str, Any]:
    """Base configuration for Teradata tests."""
    return {
        "username": "test_user",
        "password": "test_password",
        "host_port": "localhost:1025",
        "include_table_lineage": True,
        "include_usage_statistics": True,
        "include_queries": True,
    }


class TestCachingOptimizations:
    """Test caching optimizations and LRU cache functionality."""

    def test_get_schema_columns_caching(self):
        """Test that get_schema_columns uses LRU cache effectively."""
        mock_connection = MagicMock()

        # Mock database response
        mock_row1 = MagicMock()
        mock_row1._mapping.TableName = "table1"
        mock_row1._mapping.ColumnName = "col1"

        mock_row2 = MagicMock()
        mock_row2._mapping.TableName = "table1"
        mock_row2._mapping.ColumnName = "col2"

        mock_connection.execute.return_value.fetchall.return_value = [
            mock_row1,
            mock_row2,
        ]

        # Mock self with report for timing tracking
        mock_self = MagicMock()
        mock_self.report = MagicMock()
        mock_self.report.column_extraction_duration_seconds = 0.0

        # First call
        result1 = get_schema_columns(
            mock_self, mock_connection, "columnsV", "test_schema"
        )
        # Second call with same parameters should use cache
        result2 = get_schema_columns(
            mock_self, mock_connection, "columnsV", "test_schema"
        )

        # Should only execute query once due to caching
        assert mock_connection.execute.call_count == 1
        assert result1 == result2
        assert "table1" in result1
        assert len(result1["table1"]) == 2

    def test_get_schema_pk_constraints_caching(self):
        """Test that get_schema_pk_constraints uses LRU cache effectively."""
        mock_connection = MagicMock()

        # Mock database response
        mock_row = MagicMock()
        mock_row._mapping.TableName = "table1"
        mock_row._mapping.ColumnName = "id"
        mock_row._mapping.IndexName = "pk_table1"

        mock_connection.execute.return_value.fetchall.return_value = [mock_row]

        # First call
        result1 = get_schema_pk_constraints(None, mock_connection, "test_schema")
        # Second call should use cache
        result2 = get_schema_pk_constraints(None, mock_connection, "test_schema")

        # Should only execute query once
        assert mock_connection.execute.call_count == 1
        assert result1 == result2

    def test_get_schema_foreign_keys_caching(self):
        """Test that get_schema_foreign_keys uses LRU cache effectively."""
        mock_connection = MagicMock()

        # Mock database response
        mock_row = MagicMock()
        mock_row._mapping.ChildTable = "child_table"
        mock_row._mapping.IndexName = "fk_child"

        mock_connection.execute.return_value.fetchall.return_value = [mock_row]

        # First call
        result1 = get_schema_foreign_keys(None, mock_connection, "test_schema")
        # Second call should use cache
        result2 = get_schema_foreign_keys(None, mock_connection, "test_schema")

        # Should only execute query once
        assert mock_connection.execute.call_count == 1
        assert result1 == result2
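
    # The cached helpers are keyed on their call arguments, so a different
    # schema name is a cache miss; the next test verifies this by expecting
    # two execute() calls.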

    def test_cache_invalidation_different_schemas(self):
        """Test that cache is properly invalidated for different schemas."""
        mock_connection = MagicMock()

        def mock_execute_side_effect(query, params):
            # Return different results based on schema
            schema = params["schema"]
            mock_result = MagicMock()
            if schema == "schema1":
                mock_row = MagicMock()
                mock_row._mapping.TableName = "table_schema1"
                mock_result.fetchall.return_value = [mock_row]
            else:
                mock_row = MagicMock()
                mock_row._mapping.TableName = "table_schema2"
                mock_result.fetchall.return_value = [mock_row]
            return mock_result

        mock_connection.execute.side_effect = mock_execute_side_effect

        mock_self = MagicMock()
        mock_self.report = MagicMock()
        mock_self.report.column_extraction_duration_seconds = 0.0

        # Call with different schemas
        result1 = get_schema_columns(mock_self, mock_connection, "columnsV", "schema1")
        result2 = get_schema_columns(mock_self, mock_connection, "columnsV", "schema2")

        # Should execute query twice for different schemas
        assert mock_connection.execute.call_count == 2
        assert "table_schema1" in result1
        assert "table_schema2" in result2


class TestMemoryOptimizations:
    """Test memory optimization features."""

    def test_tables_cache_memory_efficiency(self):
        """Test that tables cache is memory efficient."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Test that cache uses defaultdict for efficient memory usage
                from collections import defaultdict

                assert isinstance(source._tables_cache, defaultdict)

                # Test safe access to non-existent schema
                result = source._tables_cache["nonexistent_schema"]
                assert result == []

    def test_streaming_query_processing(self):
        """Test that query processing uses streaming for memory efficiency."""
        config_dict = {
            **_base_config(),
            "use_server_side_cursors": True,
        }
        config = TeradataConfig.parse_obj(config_dict)

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                mock_connection = MagicMock()
                mock_result = MagicMock()

                # Mock streaming execution
                mock_connection.execution_options.return_value.execute.return_value = (
                    mock_result
                )

                result = source._execute_with_cursor_fallback(
                    mock_connection, "SELECT 1"
                )

                # Verify streaming was enabled
                mock_connection.execution_options.assert_called_with(stream_results=True)
                assert result == mock_result

    def test_chunked_processing_batch_size(self):
        """Test that chunked processing uses appropriate batch sizes."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                with patch.object(
                    source, "_make_lineage_queries", return_value=["SELECT 1"]
                ):
                    # Create mock result that returns batches
                    mock_result = MagicMock()

                    # Simulate fetching in batches
                    call_count = [0]

                    def mock_fetchmany(batch_size):
                        call_count[0] += 1
                        if call_count[0] == 1:
                            # Return batch_size to verify it's being used
                            assert batch_size == 5000  # Expected batch size
                            return [MagicMock() for _ in range(batch_size)]
                        return []  # End of results

                    mock_result.fetchmany.side_effect = mock_fetchmany

                    mock_connection = MagicMock()
                    mock_engine = MagicMock()
                    mock_engine.connect.return_value.__enter__.return_value = (
                        mock_connection
                    )

                    with (
                        patch.object(
                            source, "get_metadata_engine", return_value=mock_engine
                        ),
                        patch.object(
                            source,
                            "_execute_with_cursor_fallback",
                            return_value=mock_result,
                        ),
                    ):
                        # Process entries
                        entries = list(source._fetch_lineage_entries_chunked())

                        # Should have fetched one full batch
                        assert len(entries) == 5000


class TestPerformanceReporting:
    """Test performance metrics and reporting."""

    def test_column_extraction_timing(self):
        """Test that column extraction timing is tracked."""
        mock_connection = MagicMock()
        mock_connection.execute.return_value.fetchall.return_value = []

        # Mock self with report
        mock_self = MagicMock()
        mock_self.report = MagicMock()
        mock_self.report.column_extraction_duration_seconds = 0.0

        # Test by setting a custom timing value and verifying it's updated
        get_schema_columns(mock_self, mock_connection, "columnsV", "test_schema")

        # The timing should be updated (it will be a small positive number)
        assert mock_self.report.column_extraction_duration_seconds >= 0.0

    def test_view_processing_timing_metrics(self):
        """Test that view processing timing is properly tracked."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Create view in cache
                test_view = TeradataTable(
                    database="test_schema",
                    name="test_view",
                    description="Test view",
                    object_type="View",
                    create_timestamp=datetime.now(),
                    last_alter_name=None,
                    last_alter_timestamp=None,
                    request_text="SELECT * FROM test_table",
                )
                source._tables_cache["test_schema"] = [test_view]

                mock_inspector = MagicMock()
                mock_sql_config = MagicMock()

                with patch.object(
                    source, "_process_views_single_threaded"
                ) as mock_process:
                    mock_process.return_value = []

                    # Process views and check timing
                    list(
                        source.cached_loop_views(
                            mock_inspector, "test_schema", mock_sql_config
                        )
                    )

                    # Verify timing metrics were updated
                    assert source.report.view_extraction_total_time_seconds > 0
                    assert source.report.num_views_processed == 1

    def test_connection_pool_metrics(self):
        """Test that connection pool metrics are tracked."""
        config_dict = {
            **_base_config(),
            "max_workers": 2,
        }
        config = TeradataConfig.parse_obj(config_dict)

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Verify initial metrics
                assert source.report.connection_pool_wait_time_seconds == 0.0
                assert source.report.connection_pool_max_wait_time_seconds == 0.0

                # Pool metrics are updated in the actual view processing threads.
                # Here we just verify the structure exists.

    def test_database_level_metrics_tracking(self):
        """Test that database-level metrics are properly tracked."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Mock database entries for metrics
                mock_entries = [
                    MagicMock(
                        DataBaseName="db1",
                        name="table1",
                        description="Test",
                        object_type="Table",
                        CreateTimeStamp=datetime.now(),
                        LastAlterName=None,
                        LastAlterTimeStamp=None,
                        RequestText=None,
                    ),
                    MagicMock(
                        DataBaseName="db1",
                        name="view1",
                        description="Test",
                        object_type="View",
                        CreateTimeStamp=datetime.now(),
                        LastAlterName=None,
                        LastAlterTimeStamp=None,
                        RequestText="SELECT 1",
                    ),
                    MagicMock(
                        DataBaseName="db2",
                        name="table2",
                        description="Test",
                        object_type="Table",
                        CreateTimeStamp=datetime.now(),
                        LastAlterName=None,
                        LastAlterTimeStamp=None,
                        RequestText=None,
                    ),
                ]

                with patch.object(source, "get_metadata_engine") as mock_get_engine:
                    mock_engine = MagicMock()
                    mock_engine.execute.return_value = mock_entries
                    mock_get_engine.return_value = mock_engine

                    source.cache_tables_and_views()

                    # Verify database-level metrics
                    assert source.report.num_database_tables_to_scan["db1"] == 1
                    assert source.report.num_database_views_to_scan["db1"] == 1
                    assert source.report.num_database_tables_to_scan["db2"] == 1
                    assert source.report.num_database_views_to_scan["db2"] == 0


class TestThreadSafetyOptimizations:
    """Test thread safety optimizations."""

    def test_tables_cache_lock_usage(self):
        """Test that tables cache uses locks for thread safety."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Verify lock exists (it's a class attribute)
                assert hasattr(source.__class__, "_tables_cache_lock")
                # Check that it's a lock object by checking its type name
                assert source.__class__._tables_cache_lock.__class__.__name__ == "lock"
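                # threading.Lock() is a factory returning a _thread.lock
                # instance, so these tests compare the type name ("lock")
                # instead of calling isinstance() against threading.Lock.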

    def test_report_lock_usage(self):
        """Test that report operations use locks for thread safety."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Verify report lock exists
                assert hasattr(source, "_report_lock")
                # Check that it's a lock object by checking its type name
                assert source._report_lock.__class__.__name__ == "lock"

    def test_pooled_engine_lock_usage(self):
        """Test that pooled engine creation uses locks."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Verify pooled engine lock exists (it's a class attribute)
                assert hasattr(source.__class__, "_pooled_engine_lock")
                # Check that it's a lock object by checking its type name
                assert source.__class__._pooled_engine_lock.__class__.__name__ == "lock"


class TestQueryOptimizations:
    """Test SQL query optimizations."""

    def test_parameterized_query_usage(self):
        """Test that parameterized queries are used to prevent SQL injection."""
        mock_connection = MagicMock()
        mock_connection.execute.return_value.fetchall.return_value = []

        mock_self = MagicMock()
        mock_self.report = MagicMock()
        mock_self.report.column_extraction_duration_seconds = 0.0

        # Call with special characters that could cause SQL injection
        schema_name = "test'schema"
        get_schema_columns(mock_self, mock_connection, "columnsV", schema_name)

        # Verify parameterized query was used
        call_args = mock_connection.execute.call_args
        query = call_args[0][0].text
        params = call_args[0][1] if len(call_args[0]) > 1 else call_args[1]

        # Query should use parameters, not string concatenation
        assert ":schema" in query
        assert "schema" in params
        assert params["schema"] == schema_name
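        # With a named bind parameter (":schema"), the quote in schema_name is
        # passed as a bound value and never interpolated into the SQL text.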

    def test_query_structure_optimization(self):
        """Test that queries are structured for optimal performance."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.ingestion.source.sql.teradata.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

                # Test TABLES_AND_VIEWS_QUERY structure
                query = source.TABLES_AND_VIEWS_QUERY

                # Should exclude system databases efficiently
                assert "NOT IN" in query
                assert "ORDER by DataBaseName, TableName" in query

                # Should only select necessary table types
                assert "t.TableKind in ('T', 'V', 'Q', 'O')" in query

    def test_usexviews_optimization(self):
        """Test that usexviews configuration optimizes queries."""
        mock_connection = MagicMock()
        mock_connection.execute.return_value.fetchall.return_value = []

        # This function doesn't actually take usexviews; just verify the query
        # structure.
        get_schema_pk_constraints(None, mock_connection, "test_schema")

        # Should use the IndicesV view
        call_args = mock_connection.execute.call_args
        query = call_args[0][0].text
        assert "IndicesV" in query
        assert "IndexType = 'K'" in query

    def test_qvci_optimization(self):
        """Test QVCI optimization for column information."""
        from datahub.ingestion.source.sql.teradata import optimized_get_columns

        # Create table in cache
        test_table = TeradataTable(
            database="test_schema",
            name="test_table",
            description="Test table",
            object_type="Table",
            create_timestamp=datetime.now(),
            last_alter_name=None,
            last_alter_timestamp=None,
            request_text=None,
        )
        tables_cache = {"test_schema": [test_table]}

        # Mock self object
        mock_self = MagicMock()
        mock_self.get_schema_columns.return_value = {"test_table": []}

        mock_connection = MagicMock()

        # Test with QVCI enabled
        optimized_get_columns(
            mock_self,
            mock_connection,
            "test_table",
            schema="test_schema",
            tables_cache=tables_cache,
            use_qvci=True,
        )

        # Should use the QVCI columns view
        mock_self.get_schema_columns.assert_called_with(
            mock_connection, "columnsQV", "test_schema"
        )