# datahub/metadata-ingestion/tests/unit/test_teradata_source.py
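"""Unit tests for the DataHub Teradata ingestion source.

Covers configuration validation, source initialization, table/view caching,
SQL injection safety of the schema helpers, memory-efficient streaming of
audit-log entries, thread safety of the tables cache, stage tracking, lineage
query construction (current-only and historical UNION), and streaming query
reconstruction."""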
from datetime import datetime
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.sql.teradata import (
TeradataConfig,
TeradataSource,
TeradataTable,
get_schema_columns,
get_schema_pk_constraints,
)
from datahub.metadata.urns import CorpUserUrn
from datahub.sql_parsing.sql_parsing_aggregator import ObservedQuery
def _base_config() -> Dict[str, Any]:
"""Base configuration for Teradata tests."""
return {
"username": "test_user",
"password": "test_password",
"host_port": "localhost:1025",
"include_table_lineage": True,
"include_usage_statistics": True,
"include_queries": True,
}
class TestTeradataConfig:
"""Test configuration validation and initialization."""
def test_valid_config(self):
"""Test that valid configuration is accepted."""
config_dict = _base_config()
config = TeradataConfig.parse_obj(config_dict)
assert config.host_port == "localhost:1025"
assert config.include_table_lineage is True
assert config.include_usage_statistics is True
assert config.include_queries is True
def test_max_workers_validation_valid(self):
"""Test valid max_workers configuration passes validation."""
config_dict = {
**_base_config(),
"max_workers": 8,
}
config = TeradataConfig.parse_obj(config_dict)
assert config.max_workers == 8
def test_max_workers_default(self):
"""Test max_workers defaults to 10."""
config_dict = _base_config()
config = TeradataConfig.parse_obj(config_dict)
assert config.max_workers == 10
def test_max_workers_custom_value(self):
"""Test custom max_workers value is accepted."""
config_dict = {
**_base_config(),
"max_workers": 5,
}
config = TeradataConfig.parse_obj(config_dict)
assert config.max_workers == 5
def test_include_queries_default(self):
"""Test include_queries defaults to True."""
config_dict = _base_config()
config = TeradataConfig.parse_obj(config_dict)
assert config.include_queries is True
class TestTeradataSource:
"""Test Teradata source functionality."""
@patch("datahub.ingestion.source.sql.teradata.create_engine")
def test_source_initialization(self, mock_create_engine):
"""Test source initializes correctly."""
config = TeradataConfig.parse_obj(_base_config())
ctx = PipelineContext(run_id="test")
# Mock the engine creation
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, ctx)
assert source.config == config
assert source.platform == "teradata"
assert hasattr(source, "aggregator")
assert hasattr(source, "_tables_cache")
assert hasattr(source, "_tables_cache_lock")
@patch("datahub.ingestion.source.sql.teradata.create_engine")
@patch("datahub.ingestion.source.sql.teradata.inspect")
def test_get_inspectors(self, mock_inspect, mock_create_engine):
"""Test inspector creation and database iteration."""
# Mock database names returned by inspector
mock_inspector = MagicMock()
mock_inspector.get_schema_names.return_value = ["db1", "db2", "test_db"]
mock_inspect.return_value = mock_inspector
mock_connection = MagicMock()
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = mock_connection
mock_create_engine.return_value = mock_engine
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(source, "get_metadata_engine", return_value=mock_engine):
inspectors = list(source.get_inspectors())
assert len(inspectors) == 3
# Check that each inspector has the database name set
for inspector in inspectors:
assert hasattr(inspector, "_datahub_database")
def test_cache_tables_and_views_thread_safety(self):
"""Test that cache operations are thread-safe."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Mock engine and query results
mock_entry = MagicMock()
mock_entry.DataBaseName = "test_db"
mock_entry.name = "test_table"
mock_entry.description = "Test table"
mock_entry.object_type = "Table"
mock_entry.CreateTimeStamp = None
mock_entry.LastAlterName = None
mock_entry.LastAlterTimeStamp = None
mock_entry.RequestText = None
with patch.object(source, "get_metadata_engine") as mock_get_engine:
mock_engine = MagicMock()
mock_engine.execute.return_value = [mock_entry]
mock_get_engine.return_value = mock_engine
# Call the method after patching the engine
source.cache_tables_and_views()
# Verify table was added to cache
assert "test_db" in source._tables_cache
assert len(source._tables_cache["test_db"]) == 1
assert source._tables_cache["test_db"][0].name == "test_table"
# Verify engine was disposed
mock_engine.dispose.assert_called_once()
def test_convert_entry_to_observed_query(self):
"""Test conversion of database entries to ObservedQuery objects."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Mock database entry
mock_entry = MagicMock()
mock_entry.query_text = "SELECT * FROM table1 (NOT CASESPECIFIC)"
mock_entry.session_id = "session123"
mock_entry.timestamp = "2024-01-01 10:00:00"
mock_entry.user = "test_user"
mock_entry.default_database = "test_db"
observed_query = source._convert_entry_to_observed_query(mock_entry)
assert isinstance(observed_query, ObservedQuery)
assert (
observed_query.query == "SELECT * FROM table1 "
) # (NOT CASESPECIFIC) removed
assert observed_query.session_id == "session123"
assert observed_query.timestamp == "2024-01-01 10:00:00"
assert isinstance(observed_query.user, CorpUserUrn)
assert observed_query.default_db == "test_db"
assert observed_query.default_schema == "test_db"
def test_convert_entry_to_observed_query_with_none_user(self):
"""Test ObservedQuery conversion handles None user correctly."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
mock_entry = MagicMock()
mock_entry.query_text = "SELECT 1"
mock_entry.session_id = "session123"
mock_entry.timestamp = "2024-01-01 10:00:00"
mock_entry.user = None
mock_entry.default_database = "test_db"
observed_query = source._convert_entry_to_observed_query(mock_entry)
assert observed_query.user is None
def test_check_historical_table_exists_success(self):
"""Test historical table check when table exists."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Mock successful query execution
mock_connection = MagicMock()
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = mock_connection
with patch.object(source, "get_metadata_engine", return_value=mock_engine):
result = source._check_historical_table_exists()
assert result is True
mock_engine.dispose.assert_called_once()
def test_check_historical_table_exists_failure(self):
"""Test historical table check when table doesn't exist."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Mock failed query execution
mock_connection = MagicMock()
mock_connection.execute.side_effect = Exception("Table not found")
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = mock_connection
with patch.object(source, "get_metadata_engine", return_value=mock_engine):
result = source._check_historical_table_exists()
assert result is False
mock_engine.dispose.assert_called_once()
def test_close_cleanup(self):
"""Test that close() properly cleans up resources."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Replace the aggregator with our mock after creation
source.aggregator = mock_aggregator
with patch(
"datahub.ingestion.source.sql.two_tier_sql_source.TwoTierSQLAlchemySource.close"
) as mock_super_close:
source.close()
mock_aggregator.close.assert_called_once()
mock_super_close.assert_called_once()
class TestSQLInjectionSafety:
"""Test SQL injection vulnerability fixes."""
def test_get_schema_columns_parameterized(self):
"""Test that get_schema_columns uses parameterized queries."""
mock_connection = MagicMock()
mock_connection.execute.return_value.fetchall.return_value = []
# Call the function
get_schema_columns(None, mock_connection, "columnsV", "test_schema")
# Verify parameterized query was used
call_args = mock_connection.execute.call_args
query = call_args[0][0].text
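# The bind parameters may be passed as a second positional argument or as keyword arguments, so check both.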
params = call_args[0][1] if len(call_args[0]) > 1 else call_args[1]
assert ":schema" in query
assert "schema" in params
assert params["schema"] == "test_schema"
def test_get_schema_pk_constraints_parameterized(self):
"""Test that get_schema_pk_constraints uses parameterized queries."""
mock_connection = MagicMock()
mock_connection.execute.return_value.fetchall.return_value = []
# Call the function
get_schema_pk_constraints(None, mock_connection, "test_schema")
# Verify parameterized query was used
call_args = mock_connection.execute.call_args
query = call_args[0][0].text
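# As above, the bind parameters may arrive positionally or as keyword arguments.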
params = call_args[0][1] if len(call_args[0]) > 1 else call_args[1]
assert ":schema" in query
assert "schema" in params
assert params["schema"] == "test_schema"
class TestMemoryEfficiency:
"""Test memory efficiency improvements."""
def test_fetch_lineage_entries_chunked_streaming(self):
"""Test that lineage entries are processed in streaming fashion."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Replace the aggregator with our mock after creation
source.aggregator = mock_aggregator
# Mock the chunked fetching method to return a generator
def mock_generator():
for i in range(5):
mock_entry = MagicMock()
mock_entry.query_text = f"SELECT {i}"
mock_entry.session_id = f"session_{i}"
mock_entry.timestamp = "2024-01-01 10:00:00"
mock_entry.user = "test_user"
mock_entry.default_database = "test_db"
yield mock_entry
with patch.object(
source, "_fetch_lineage_entries_chunked", return_value=mock_generator()
):
mock_aggregator.gen_metadata.return_value = []
# Process entries
list(source._get_audit_log_mcps_with_aggregator())
# Verify aggregator.add was called for each entry (streaming)
assert mock_aggregator.add.call_count == 5
class TestConcurrencySupport:
"""Test thread safety and concurrent operations."""
def test_tables_cache_thread_safety(self):
"""Test that tables cache operations are thread-safe."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Verify lock exists
assert hasattr(source, "_tables_cache_lock")
# Test safe cache access methods
result = source._tables_cache.get("nonexistent_schema", [])
assert result == []
def test_cached_loop_tables_safe_access(self):
"""Test cached_loop_tables uses safe cache access."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Add test data to cache
test_table = TeradataTable(
database="test_db",
name="test_table",
description="Test",
object_type="Table",
create_timestamp=datetime.now(),
last_alter_name=None,
last_alter_timestamp=None,
request_text=None,
)
source._tables_cache["test_schema"] = [test_table]
# Mock inspector and config
mock_inspector = MagicMock()
mock_sql_config = MagicMock()
with patch(
"datahub.ingestion.source.sql.two_tier_sql_source.TwoTierSQLAlchemySource.loop_tables"
) as mock_super:
mock_super.return_value = []
# This should not raise an exception even with missing schema
list(
source.cached_loop_tables(
mock_inspector, "missing_schema", mock_sql_config
)
)
class TestStageTracking:
"""Test stage tracking functionality."""
def test_stage_tracking_in_cache_operation(self):
"""Test that table caching uses stage tracking."""
config = TeradataConfig.parse_obj(_base_config())
# Patch the aggregator and table caching, then verify caching (where stage tracking happens) is triggered during init
with (
patch("datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"),
patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
) as mock_cache,
):
TeradataSource(config, PipelineContext(run_id="test"))
# Verify cache_tables_and_views was called during init (stage tracking happens there)
mock_cache.assert_called_once()
def test_stage_tracking_in_aggregator_processing(self):
"""Test that aggregator processing uses stage tracking."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Replace the aggregator with our mock after creation
source.aggregator = mock_aggregator
with patch.object(source.report, "new_stage") as mock_new_stage:
mock_context_manager = MagicMock()
mock_new_stage.return_value = mock_context_manager
with patch.object(
source, "_fetch_lineage_entries_chunked", return_value=[]
):
mock_aggregator.gen_metadata.return_value = []
list(source._get_audit_log_mcps_with_aggregator())
# Should have called new_stage for query processing and metadata generation
# The actual implementation uses new_stage for "Fetching queries" and "Generating metadata"
assert mock_new_stage.call_count >= 1
class TestErrorHandling:
"""Test error handling and edge cases."""
def test_empty_lineage_entries(self):
"""Test handling of empty lineage entries."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(
source, "_fetch_lineage_entries_chunked", return_value=[]
):
mock_aggregator.gen_metadata.return_value = []
result = list(source._get_audit_log_mcps_with_aggregator())
assert result == []
def test_malformed_query_entry(self):
"""Test handling of malformed query entries."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Mock entry with missing attributes
mock_entry = MagicMock()
mock_entry.query_text = "SELECT 1"
# Simulate missing attributes
del mock_entry.session_id
del mock_entry.timestamp
del mock_entry.user
del mock_entry.default_database
# Should handle gracefully using getattr with defaults
observed_query = source._convert_entry_to_observed_query(mock_entry)
assert observed_query.query == "SELECT 1"
assert observed_query.session_id is None
assert observed_query.user is None
class TestLineageQuerySeparation:
"""Test the new separated lineage query functionality (no more UNION)."""
def test_make_lineage_queries_current_only(self):
"""Test that only current query is returned when historical lineage is disabled."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": False,
"start_time": "2024-01-01T00:00:00Z",
"end_time": "2024-01-02T00:00:00Z",
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
queries = source._make_lineage_queries()
assert len(queries) == 1
assert '"DBC".QryLogV' in queries[0]
assert "PDCRDATA.DBQLSqlTbl_Hst" not in queries[0]
assert "2024-01-01" in queries[0]
assert "2024-01-02" in queries[0]
def test_make_lineage_queries_with_historical_available(self):
"""Test that UNION query is returned when historical lineage is enabled and table exists."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": True,
"start_time": "2024-01-01T00:00:00Z",
"end_time": "2024-01-02T00:00:00Z",
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(
source, "_check_historical_table_exists", return_value=True
):
queries = source._make_lineage_queries()
assert len(queries) == 1
# Single UNION query should contain both historical and current data
union_query = queries[0]
assert '"DBC".QryLogV' in union_query
assert '"PDCRINFO".DBQLSqlTbl_Hst' in union_query
assert "UNION" in union_query
assert "combined_results" in union_query
# Should have the time filters
assert "2024-01-01" in union_query
assert "2024-01-02" in union_query
def test_make_lineage_queries_with_historical_unavailable(self):
"""Test that only current query is returned when historical lineage is enabled but table doesn't exist."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": True,
"start_time": "2024-01-01T00:00:00Z",
"end_time": "2024-01-02T00:00:00Z",
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(
source, "_check_historical_table_exists", return_value=False
):
queries = source._make_lineage_queries()
assert len(queries) == 1
assert '"DBC".QryLogV' in queries[0]
assert '"PDCRDATA".DBQLSqlTbl_Hst' not in queries[0]
def test_make_lineage_queries_with_database_filter(self):
"""Test that database filters are correctly applied to UNION query."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": True,
"databases": ["test_db1", "test_db2"],
"start_time": "2024-01-01T00:00:00Z",
"end_time": "2024-01-02T00:00:00Z",
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(
source, "_check_historical_table_exists", return_value=True
):
queries = source._make_lineage_queries()
assert len(queries) == 1
# UNION query should have database filters for both current and historical parts
union_query = queries[0]
assert "l.DefaultDatabase in ('test_db1','test_db2')" in union_query
assert "h.DefaultDatabase in ('test_db1','test_db2')" in union_query
def test_fetch_lineage_entries_chunked_multiple_queries(self):
"""Test that _fetch_lineage_entries_chunked handles multiple queries correctly."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": True,
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Mock the query generation to return 2 queries
with patch.object(
source, "_make_lineage_queries", return_value=["query1", "query2"]
):
# Mock database execution for both queries
mock_result1 = MagicMock()
mock_result1.fetchmany.side_effect = [
[MagicMock(query_text="SELECT 1")], # First batch
[], # End of results
]
mock_result2 = MagicMock()
mock_result2.fetchmany.side_effect = [
[MagicMock(query_text="SELECT 2")], # First batch
[], # End of results
]
mock_connection = MagicMock()
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = (
mock_connection
)
with (
patch.object(
source, "get_metadata_engine", return_value=mock_engine
),
patch.object(
source, "_execute_with_cursor_fallback"
) as mock_execute,
):
mock_execute.side_effect = [mock_result1, mock_result2]
entries = list(source._fetch_lineage_entries_chunked())
# Should have executed both queries
assert mock_execute.call_count == 2
mock_execute.assert_any_call(mock_connection, "query1")
mock_execute.assert_any_call(mock_connection, "query2")
# Should return entries from both queries
assert len(entries) == 2
def test_fetch_lineage_entries_chunked_single_query(self):
"""Test that _fetch_lineage_entries_chunked handles single query correctly."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": False,
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Mock the query generation to return 1 query
with patch.object(source, "_make_lineage_queries", return_value=["query1"]):
mock_result = MagicMock()
mock_result.fetchmany.side_effect = [
[MagicMock(query_text="SELECT 1")], # First batch
[], # End of results
]
mock_connection = MagicMock()
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = (
mock_connection
)
with (
patch.object(
source, "get_metadata_engine", return_value=mock_engine
),
patch.object(
source,
"_execute_with_cursor_fallback",
return_value=mock_result,
) as mock_execute,
):
entries = list(source._fetch_lineage_entries_chunked())
# Should have executed only one query
assert mock_execute.call_count == 1
mock_execute.assert_called_with(mock_connection, "query1")
# Should return entries from the query
assert len(entries) == 1
def test_fetch_lineage_entries_chunked_batch_processing(self):
"""Test that batch processing works correctly with configurable batch size."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": False,
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(source, "_make_lineage_queries", return_value=["query1"]):
# Create mock entries
mock_entries = [MagicMock(query_text=f"SELECT {i}") for i in range(7)]
mock_result = MagicMock()
# Simulate results arriving across multiple fetchmany() batches: 5 rows, then 2, then none
mock_result.fetchmany.side_effect = [
mock_entries[:5], # First batch (5 items)
mock_entries[5:], # Second batch (2 items)
[], # End of results
]
mock_connection = MagicMock()
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = (
mock_connection
)
with (
patch.object(
source, "get_metadata_engine", return_value=mock_engine
),
patch.object(
source,
"_execute_with_cursor_fallback",
return_value=mock_result,
),
):
entries = list(source._fetch_lineage_entries_chunked())
# Should return all 7 entries
assert len(entries) == 7
# Verify fetchmany was called with the right batch size (5000 is hardcoded)
calls = mock_result.fetchmany.call_args_list
for call in calls:
if call[0]: # If positional args
assert call[0][0] == 5000
else: # If keyword args
assert call[1].get("size", 5000) == 5000
def test_end_to_end_separate_queries_integration(self):
"""Test end-to-end integration of separate queries in the aggregator flow."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": True,
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Replace the aggregator with our mock after creation
source.aggregator = mock_aggregator
# Mock entries from both current and historical queries
current_entry = MagicMock()
current_entry.query_text = "SELECT * FROM current_table"
current_entry.user = "current_user"
current_entry.timestamp = "2024-01-01 10:00:00"
current_entry.default_database = "current_db"
historical_entry = MagicMock()
historical_entry.query_text = "SELECT * FROM historical_table"
historical_entry.user = "historical_user"
historical_entry.timestamp = "2023-12-01 10:00:00"
historical_entry.default_database = "historical_db"
def mock_fetch_generator():
yield current_entry
yield historical_entry
with patch.object(
source,
"_fetch_lineage_entries_chunked",
return_value=mock_fetch_generator(),
):
mock_aggregator.gen_metadata.return_value = []
# Execute the aggregator flow
list(source._get_audit_log_mcps_with_aggregator())
# Verify both entries were added to aggregator
assert mock_aggregator.add.call_count == 2
# Verify the entries were converted correctly
added_queries = [
call[0][0] for call in mock_aggregator.add.call_args_list
]
assert any(
"SELECT * FROM current_table" in query.query
for query in added_queries
)
assert any(
"SELECT * FROM historical_table" in query.query
for query in added_queries
)
def test_query_logging_and_progress_tracking(self):
"""Test that proper logging occurs when processing multiple queries."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": True,
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(
source, "_make_lineage_queries", return_value=["query1", "query2"]
):
mock_result = MagicMock()
call_counter = {"count": 0}
def mock_fetchmany_side_effect(batch_size):
# Return one batch then empty to simulate end of results
call_counter["count"] += 1
if call_counter["count"] == 1:
return [MagicMock(query_text="SELECT 1")]
return []
mock_result.fetchmany.side_effect = mock_fetchmany_side_effect
mock_connection = MagicMock()
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = (
mock_connection
)
with (
patch.object(
source, "get_metadata_engine", return_value=mock_engine
),
patch.object(
source,
"_execute_with_cursor_fallback",
return_value=mock_result,
),
patch(
"datahub.ingestion.source.sql.teradata.logger"
) as mock_logger,
):
list(source._fetch_lineage_entries_chunked())
# Verify progress logging for multiple queries
info_calls = [
call for call in mock_logger.info.call_args_list if call[0]
]
# Should log execution of query 1/2 and 2/2
assert any("query 1/2" in str(call) for call in info_calls)
assert any("query 2/2" in str(call) for call in info_calls)
# Should log completion of both queries
assert any("Completed query 1" in str(call) for call in info_calls)
assert any("Completed query 2" in str(call) for call in info_calls)
class TestQueryConstruction:
"""Test the construction of individual queries."""
def test_current_query_construction(self):
"""Test that the current query is constructed correctly."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"start_time": "2024-01-01T00:00:00Z",
"end_time": "2024-01-02T00:00:00Z",
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
queries = source._make_lineage_queries()
current_query = queries[0]
# Verify current query structure
assert 'FROM "DBC".QryLogV as l' in current_query
assert 'JOIN "DBC".QryLogSqlV as s' in current_query
assert "l.ErrorCode = 0" in current_query
assert "2024-01-01" in current_query
assert "2024-01-02" in current_query
assert 'ORDER BY "timestamp", "query_id", "row_no"' in current_query
def test_historical_query_construction(self):
"""Test that the UNION query contains historical data correctly."""
config = TeradataConfig.parse_obj(
{
**_base_config(),
"include_historical_lineage": True,
"start_time": "2024-01-01T00:00:00Z",
"end_time": "2024-01-02T00:00:00Z",
}
)
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(
source, "_check_historical_table_exists", return_value=True
):
queries = source._make_lineage_queries()
union_query = queries[0]
# Verify UNION query contains historical data structure
assert 'FROM "PDCRINFO".DBQLSqlTbl_Hst as h' in union_query
assert "h.ErrorCode = 0" in union_query
assert "h.StartTime AT TIME ZONE 'GMT'" in union_query
assert "h.DefaultDatabase" in union_query
assert "2024-01-01" in union_query
assert "2024-01-02" in union_query
assert 'ORDER BY "timestamp", "query_id", "row_no"' in union_query
assert "UNION" in union_query
class TestStreamingQueryReconstruction:
"""Test the streaming query reconstruction functionality."""
def test_reconstruct_queries_streaming_single_row_queries(self):
"""Test streaming reconstruction with single-row queries."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Create entries for single-row queries
entries = [
self._create_mock_entry(
"Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00"
),
self._create_mock_entry(
"Q2", "SELECT * FROM table2", 1, "2024-01-01 10:01:00"
),
self._create_mock_entry(
"Q3", "SELECT * FROM table3", 1, "2024-01-01 10:02:00"
),
]
# Test streaming reconstruction
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
assert len(reconstructed_queries) == 3
assert reconstructed_queries[0].query == "SELECT * FROM table1"
assert reconstructed_queries[1].query == "SELECT * FROM table2"
assert reconstructed_queries[2].query == "SELECT * FROM table3"
# Verify metadata preservation
assert reconstructed_queries[0].timestamp == "2024-01-01 10:00:00"
assert reconstructed_queries[1].timestamp == "2024-01-01 10:01:00"
assert reconstructed_queries[2].timestamp == "2024-01-01 10:02:00"
def test_reconstruct_queries_streaming_multi_row_queries(self):
"""Test streaming reconstruction with multi-row queries."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Create entries for multi-row queries
entries = [
# Query 1: 3 rows
self._create_mock_entry(
"Q1", "SELECT a, b, c ", 1, "2024-01-01 10:00:00"
),
self._create_mock_entry(
"Q1", "FROM large_table ", 2, "2024-01-01 10:00:00"
),
self._create_mock_entry(
"Q1", "WHERE id > 1000", 3, "2024-01-01 10:00:00"
),
# Query 2: 2 rows
self._create_mock_entry(
"Q2", "UPDATE table3 SET ", 1, "2024-01-01 10:01:00"
),
self._create_mock_entry(
"Q2", "status = 'active'", 2, "2024-01-01 10:01:00"
),
]
# Test streaming reconstruction
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
assert len(reconstructed_queries) == 2
assert (
reconstructed_queries[0].query
== "SELECT a, b, c FROM large_table WHERE id > 1000"
)
assert (
reconstructed_queries[1].query == "UPDATE table3 SET status = 'active'"
)
# Verify metadata preservation (should use metadata from first row of each query)
assert reconstructed_queries[0].timestamp == "2024-01-01 10:00:00"
assert reconstructed_queries[1].timestamp == "2024-01-01 10:01:00"
def test_reconstruct_queries_streaming_mixed_queries(self):
"""Test streaming reconstruction with mixed single and multi-row queries."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Create entries mixing single and multi-row queries
entries = [
# Single-row query
self._create_mock_entry(
"Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00"
),
# Multi-row query (3 rows)
self._create_mock_entry(
"Q2", "SELECT a, b, c ", 1, "2024-01-01 10:01:00"
),
self._create_mock_entry(
"Q2", "FROM large_table ", 2, "2024-01-01 10:01:00"
),
self._create_mock_entry(
"Q2", "WHERE id > 1000", 3, "2024-01-01 10:01:00"
),
# Single-row query
self._create_mock_entry(
"Q3", "SELECT COUNT(*) FROM table2", 1, "2024-01-01 10:02:00"
),
# Multi-row query (2 rows)
self._create_mock_entry(
"Q4", "UPDATE table3 SET ", 1, "2024-01-01 10:03:00"
),
self._create_mock_entry(
"Q4", "status = 'active'", 2, "2024-01-01 10:03:00"
),
]
# Test streaming reconstruction
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
assert len(reconstructed_queries) == 4
assert reconstructed_queries[0].query == "SELECT * FROM table1"
assert (
reconstructed_queries[1].query
== "SELECT a, b, c FROM large_table WHERE id > 1000"
)
assert reconstructed_queries[2].query == "SELECT COUNT(*) FROM table2"
assert (
reconstructed_queries[3].query == "UPDATE table3 SET status = 'active'"
)
def test_reconstruct_queries_streaming_empty_entries(self):
"""Test streaming reconstruction with empty entries."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Test with empty entries
entries: List[Any] = []
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
assert len(reconstructed_queries) == 0
def test_reconstruct_queries_streaming_teradata_specific_transformations(self):
"""Test that Teradata-specific transformations are applied."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Create entry with Teradata-specific syntax
entries = [
self._create_mock_entry(
"Q1",
"SELECT * FROM table1 (NOT CASESPECIFIC)",
1,
"2024-01-01 10:00:00",
),
]
# Test streaming reconstruction
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
assert len(reconstructed_queries) == 1
# Should remove (NOT CASESPECIFIC)
assert reconstructed_queries[0].query == "SELECT * FROM table1 "
def test_reconstruct_queries_streaming_metadata_preservation(self):
"""Test that all metadata fields are preserved correctly."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Create entry with all metadata fields
entries: List[Any] = [
self._create_mock_entry(
"Q1",
"SELECT * FROM table1",
1,
"2024-01-01 10:00:00",
user="test_user",
default_database="test_db",
session_id="session123",
),
]
# Test streaming reconstruction
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
assert len(reconstructed_queries) == 1
query = reconstructed_queries[0]
# Verify all metadata fields
assert query.query == "SELECT * FROM table1"
assert query.timestamp == "2024-01-01 10:00:00"
assert isinstance(query.user, CorpUserUrn)
assert str(query.user) == "urn:li:corpuser:test_user"
assert query.default_db == "test_db"
assert query.default_schema == "test_db" # Teradata uses database as schema
assert query.session_id == "session123"
def test_reconstruct_queries_streaming_with_none_user(self):
"""Test streaming reconstruction handles None user correctly."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Create entry with None user
entries = [
self._create_mock_entry(
"Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00", user=None
),
]
# Test streaming reconstruction
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
assert len(reconstructed_queries) == 1
assert reconstructed_queries[0].user is None
def test_reconstruct_queries_streaming_empty_query_text(self):
"""Test streaming reconstruction handles empty query text correctly."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Create entries with empty query text
entries = [
self._create_mock_entry("Q1", "", 1, "2024-01-01 10:00:00"),
self._create_mock_entry(
"Q2", "SELECT * FROM table1", 1, "2024-01-01 10:01:00"
),
]
# Test streaming reconstruction
reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
# Should only get one query (the non-empty one)
assert len(reconstructed_queries) == 1
assert reconstructed_queries[0].query == "SELECT * FROM table1"
def test_reconstruct_queries_streaming_space_joining_behavior(self):
"""Test that query parts are joined directly without adding spaces."""
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
# Test case 1: Parts that include their own spacing
entries1 = [
self._create_mock_entry("Q1", "SELECT ", 1, "2024-01-01 10:00:00"),
self._create_mock_entry("Q1", "col1, ", 2, "2024-01-01 10:00:00"),
self._create_mock_entry("Q1", "col2 ", 3, "2024-01-01 10:00:00"),
self._create_mock_entry("Q1", "FROM ", 4, "2024-01-01 10:00:00"),
self._create_mock_entry("Q1", "table1", 5, "2024-01-01 10:00:00"),
]
# Test case 2: Parts that already have trailing/leading spaces
entries2 = [
self._create_mock_entry("Q2", "SELECT * ", 1, "2024-01-01 10:01:00"),
self._create_mock_entry("Q2", "FROM table2 ", 2, "2024-01-01 10:01:00"),
self._create_mock_entry("Q2", "WHERE id > 1", 3, "2024-01-01 10:01:00"),
]
# Test streaming reconstruction
all_entries = entries1 + entries2
reconstructed_queries = list(
source._reconstruct_queries_streaming(all_entries)
)
assert len(reconstructed_queries) == 2
# Query 1: Should be joined directly without adding spaces
assert reconstructed_queries[0].query == "SELECT col1, col2 FROM table1"
# Query 2: Parts that already carry trailing spaces still join into correctly spaced SQL
assert reconstructed_queries[1].query == "SELECT * FROM table2 WHERE id > 1"
def _create_mock_entry(
self,
query_id,
query_text,
row_no,
timestamp,
user="test_user",
default_database="test_db",
session_id=None,
):
"""Create a mock database entry for testing."""
entry = MagicMock()
entry.query_id = query_id
entry.query_text = query_text
entry.row_no = row_no
entry.timestamp = timestamp
entry.user = user
entry.default_database = default_database
entry.session_id = session_id or f"session_{query_id}"
return entry