datahub/metadata-ingestion/tests/unit/test_teradata_source.py

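"""Unit tests for the DataHub Teradata SQL source: configuration, source setup, lineage query construction, and streaming query reconstruction."""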
from datetime import datetime
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.sql.teradata import (
    TeradataConfig,
    TeradataSource,
    TeradataTable,
    get_schema_columns,
    get_schema_pk_constraints,
)
from datahub.metadata.urns import CorpUserUrn
from datahub.sql_parsing.sql_parsing_aggregator import ObservedQuery


def _base_config() -> Dict[str, Any]:
    """Base configuration for Teradata tests."""
    return {
        "username": "test_user",
        "password": "test_password",
        "host_port": "localhost:1025",
        "include_table_lineage": True,
        "include_usage_statistics": True,
        "include_queries": True,
    }


class TestTeradataConfig:
    """Test configuration validation and initialization."""

    def test_valid_config(self):
        """Test that valid configuration is accepted."""
        config_dict = _base_config()
        config = TeradataConfig.parse_obj(config_dict)
        assert config.host_port == "localhost:1025"
        assert config.include_table_lineage is True
        assert config.include_usage_statistics is True
        assert config.include_queries is True

    def test_max_workers_validation_valid(self):
        """Test valid max_workers configuration passes validation."""
        config_dict = {
            **_base_config(),
            "max_workers": 8,
        }
        config = TeradataConfig.parse_obj(config_dict)
        assert config.max_workers == 8

    def test_max_workers_default(self):
        """Test max_workers defaults to 10."""
        config_dict = _base_config()
        config = TeradataConfig.parse_obj(config_dict)
        assert config.max_workers == 10

    def test_max_workers_custom_value(self):
        """Test custom max_workers value is accepted."""
        config_dict = {
            **_base_config(),
            "max_workers": 5,
        }
        config = TeradataConfig.parse_obj(config_dict)
        assert config.max_workers == 5

    def test_include_queries_default(self):
        """Test include_queries defaults to True."""
        config_dict = _base_config()
        config = TeradataConfig.parse_obj(config_dict)
        assert config.include_queries is True


class TestTeradataSource:
    """Test Teradata source functionality."""

    @patch("datahub.ingestion.source.sql.teradata.create_engine")
    def test_source_initialization(self, mock_create_engine):
        """Test source initializes correctly."""
        config = TeradataConfig.parse_obj(_base_config())
        ctx = PipelineContext(run_id="test")
        # Mock the engine creation
        mock_engine = MagicMock()
        mock_create_engine.return_value = mock_engine
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, ctx)
            assert source.config == config
            assert source.platform == "teradata"
            assert hasattr(source, "aggregator")
            assert hasattr(source, "_tables_cache")
            assert hasattr(source, "_tables_cache_lock")
@patch("datahub.ingestion.source.sql.teradata.create_engine")
@patch("datahub.ingestion.source.sql.teradata.inspect")
def test_get_inspectors(self, mock_inspect, mock_create_engine):
"""Test inspector creation and database iteration."""
# Mock database names returned by inspector
mock_inspector = MagicMock()
mock_inspector.get_schema_names.return_value = ["db1", "db2", "test_db"]
mock_inspect.return_value = mock_inspector
mock_connection = MagicMock()
mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = mock_connection
mock_create_engine.return_value = mock_engine
config = TeradataConfig.parse_obj(_base_config())
with patch(
"datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
) as mock_aggregator_class:
mock_aggregator = MagicMock()
mock_aggregator_class.return_value = mock_aggregator
# Mock cache_tables_and_views to prevent database connection during init
with patch(
"datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
):
source = TeradataSource(config, PipelineContext(run_id="test"))
with patch.object(source, "get_metadata_engine", return_value=mock_engine):
inspectors = list(source.get_inspectors())
assert len(inspectors) == 3
# Check that each inspector has the database name set
for inspector in inspectors:
assert hasattr(inspector, "_datahub_database")

    def test_cache_tables_and_views_thread_safety(self):
        """Test that cache operations are thread-safe."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Mock engine and query results
            mock_entry = MagicMock()
            mock_entry.DataBaseName = "test_db"
            mock_entry.name = "test_table"
            mock_entry.description = "Test table"
            mock_entry.object_type = "Table"
            mock_entry.CreateTimeStamp = None
            mock_entry.LastAlterName = None
            mock_entry.LastAlterTimeStamp = None
            mock_entry.RequestText = None
            with patch.object(source, "get_metadata_engine") as mock_get_engine:
                mock_engine = MagicMock()
                mock_engine.execute.return_value = [mock_entry]
                mock_get_engine.return_value = mock_engine
                # Call the method after patching the engine
                source.cache_tables_and_views()
                # Verify table was added to cache
                assert "test_db" in source._tables_cache
                assert len(source._tables_cache["test_db"]) == 1
                assert source._tables_cache["test_db"][0].name == "test_table"
                # Verify engine was disposed
                mock_engine.dispose.assert_called_once()

    def test_convert_entry_to_observed_query(self):
        """Test conversion of database entries to ObservedQuery objects."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Mock database entry
            mock_entry = MagicMock()
            mock_entry.query_text = "SELECT * FROM table1 (NOT CASESPECIFIC)"
            mock_entry.session_id = "session123"
            mock_entry.timestamp = "2024-01-01 10:00:00"
            mock_entry.user = "test_user"
            mock_entry.default_database = "test_db"
            observed_query = source._convert_entry_to_observed_query(mock_entry)
            assert isinstance(observed_query, ObservedQuery)
            assert (
                observed_query.query == "SELECT * FROM table1 "
            )  # (NOT CASESPECIFIC) removed
            assert observed_query.session_id == "session123"
            assert observed_query.timestamp == "2024-01-01 10:00:00"
            assert isinstance(observed_query.user, CorpUserUrn)
            assert observed_query.default_db == "test_db"
            assert observed_query.default_schema == "test_db"

    def test_convert_entry_to_observed_query_with_none_user(self):
        """Test ObservedQuery conversion handles None user correctly."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            mock_entry = MagicMock()
            mock_entry.query_text = "SELECT 1"
            mock_entry.session_id = "session123"
            mock_entry.timestamp = "2024-01-01 10:00:00"
            mock_entry.user = None
            mock_entry.default_database = "test_db"
            observed_query = source._convert_entry_to_observed_query(mock_entry)
            assert observed_query.user is None

    def test_check_historical_table_exists_success(self):
        """Test historical table check when table exists."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Mock successful query execution
            mock_connection = MagicMock()
            mock_engine = MagicMock()
            mock_engine.connect.return_value.__enter__.return_value = mock_connection
            with patch.object(source, "get_metadata_engine", return_value=mock_engine):
                result = source._check_historical_table_exists()
                assert result is True
                mock_engine.dispose.assert_called_once()

    def test_check_historical_table_exists_failure(self):
        """Test historical table check when table doesn't exist."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Mock failed query execution
            mock_connection = MagicMock()
            mock_connection.execute.side_effect = Exception("Table not found")
            mock_engine = MagicMock()
            mock_engine.connect.return_value.__enter__.return_value = mock_connection
            with patch.object(source, "get_metadata_engine", return_value=mock_engine):
                result = source._check_historical_table_exists()
                assert result is False
                mock_engine.dispose.assert_called_once()

    def test_close_cleanup(self):
        """Test that close() properly cleans up resources."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.two_tier_sql_source.TwoTierSQLAlchemySource.close"
            ) as mock_super_close:
                source.close()
                mock_aggregator.close.assert_called_once()
                mock_super_close.assert_called_once()


class TestSQLInjectionSafety:
    """Test SQL injection vulnerability fixes."""

    def test_get_schema_columns_parameterized(self):
        """Test that get_schema_columns uses parameterized queries."""
        mock_connection = MagicMock()
        mock_connection.execute.return_value.fetchall.return_value = []
        # Call the function
        get_schema_columns(None, mock_connection, "columnsV", "test_schema")
        # Verify parameterized query was used
        call_args = mock_connection.execute.call_args
        query = call_args[0][0].text
        params = call_args[0][1] if len(call_args[0]) > 1 else call_args[1]
        assert ":schema" in query
        assert "schema" in params
        assert params["schema"] == "test_schema"

    def test_get_schema_pk_constraints_parameterized(self):
        """Test that get_schema_pk_constraints uses parameterized queries."""
        mock_connection = MagicMock()
        mock_connection.execute.return_value.fetchall.return_value = []
        # Call the function
        get_schema_pk_constraints(None, mock_connection, "test_schema")
        # Verify parameterized query was used
        call_args = mock_connection.execute.call_args
        query = call_args[0][0].text
        params = call_args[0][1] if len(call_args[0]) > 1 else call_args[1]
        assert ":schema" in query
        assert "schema" in params
        assert params["schema"] == "test_schema"


class TestMemoryEfficiency:
    """Test memory efficiency improvements."""

    def test_fetch_lineage_entries_chunked_streaming(self):
        """Test that lineage entries are processed in streaming fashion."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator

            # Mock the chunked fetching method to return a generator
            def mock_generator():
                for i in range(5):
                    mock_entry = MagicMock()
                    mock_entry.query_text = f"SELECT {i}"
                    mock_entry.session_id = f"session_{i}"
                    mock_entry.timestamp = "2024-01-01 10:00:00"
                    mock_entry.user = "test_user"
                    mock_entry.default_database = "test_db"
                    yield mock_entry

            with patch.object(
                source, "_fetch_lineage_entries_chunked", return_value=mock_generator()
            ):
                mock_aggregator.gen_metadata.return_value = []
                # Process entries
                list(source._get_audit_log_mcps_with_aggregator())
                # Verify aggregator.add was called for each entry (streaming)
                assert mock_aggregator.add.call_count == 5


class TestConcurrencySupport:
    """Test thread safety and concurrent operations."""

    def test_tables_cache_thread_safety(self):
        """Test that tables cache operations are thread-safe."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Verify lock exists
            assert hasattr(source, "_tables_cache_lock")
            # Test safe cache access methods
            result = source._tables_cache.get("nonexistent_schema", [])
            assert result == []

    def test_cached_loop_tables_safe_access(self):
        """Test cached_loop_tables uses safe cache access."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Add test data to cache
            test_table = TeradataTable(
                database="test_db",
                name="test_table",
                description="Test",
                object_type="Table",
                create_timestamp=datetime.now(),
                last_alter_name=None,
                last_alter_timestamp=None,
                request_text=None,
            )
            source._tables_cache["test_schema"] = [test_table]
            # Mock inspector and config
            mock_inspector = MagicMock()
            mock_sql_config = MagicMock()
            with patch(
                "datahub.ingestion.source.sql.two_tier_sql_source.TwoTierSQLAlchemySource.loop_tables"
            ) as mock_super:
                mock_super.return_value = []
                # This should not raise an exception even with missing schema
                list(
                    source.cached_loop_tables(
                        mock_inspector, "missing_schema", mock_sql_config
                    )
                )


class TestStageTracking:
    """Test stage tracking functionality."""

    def test_stage_tracking_in_cache_operation(self):
        """Test that table caching uses stage tracking."""
        config = TeradataConfig.parse_obj(_base_config())
        # Create source without mocking to test the actual stage tracking during init
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ), patch(
            "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
        ) as mock_cache:
            TeradataSource(config, PipelineContext(run_id="test"))
            # Verify cache_tables_and_views was called during init (stage tracking happens there)
            mock_cache.assert_called_once()

    def test_stage_tracking_in_aggregator_processing(self):
        """Test that aggregator processing uses stage tracking."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator
            with patch.object(source.report, "new_stage") as mock_new_stage:
                mock_context_manager = MagicMock()
                mock_new_stage.return_value = mock_context_manager
                with patch.object(
                    source, "_fetch_lineage_entries_chunked", return_value=[]
                ):
                    mock_aggregator.gen_metadata.return_value = []
                    list(source._get_audit_log_mcps_with_aggregator())
                    # Should have called new_stage for query processing and metadata generation
                    # The actual implementation uses new_stage for "Fetching queries" and "Generating metadata"
                    assert mock_new_stage.call_count >= 1


class TestErrorHandling:
    """Test error handling and edge cases."""

    def test_empty_lineage_entries(self):
        """Test handling of empty lineage entries."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            with patch.object(
                source, "_fetch_lineage_entries_chunked", return_value=[]
            ):
                mock_aggregator.gen_metadata.return_value = []
                result = list(source._get_audit_log_mcps_with_aggregator())
                assert result == []

    def test_malformed_query_entry(self):
        """Test handling of malformed query entries."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Mock entry with missing attributes
            mock_entry = MagicMock()
            mock_entry.query_text = "SELECT 1"
            # Simulate missing attributes
            del mock_entry.session_id
            del mock_entry.timestamp
            del mock_entry.user
            del mock_entry.default_database
            # Should handle gracefully using getattr with defaults
            observed_query = source._convert_entry_to_observed_query(mock_entry)
            assert observed_query.query == "SELECT 1"
            assert observed_query.session_id is None
            assert observed_query.user is None


class TestLineageQuerySeparation:
    """Test lineage query construction for current and historical (UNION) audit data."""

    def test_make_lineage_queries_current_only(self):
        """Test that only current query is returned when historical lineage is disabled."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": False,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            queries = source._make_lineage_queries()
            assert len(queries) == 1
            assert '"DBC".QryLogV' in queries[0]
            assert '"PDCRINFO".DBQLSqlTbl_Hst' not in queries[0]
            assert "2024-01-01" in queries[0]
            assert "2024-01-02" in queries[0]

    def test_make_lineage_queries_with_historical_available(self):
        """Test that UNION query is returned when historical lineage is enabled and table exists."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            with patch.object(
                source, "_check_historical_table_exists", return_value=True
            ):
                queries = source._make_lineage_queries()
                assert len(queries) == 1
                # Single UNION query should contain both historical and current data
                union_query = queries[0]
                assert '"DBC".QryLogV' in union_query
                assert '"PDCRINFO".DBQLSqlTbl_Hst' in union_query
                assert "UNION" in union_query
                assert "combined_results" in union_query
                # Should have the time filters
                assert "2024-01-01" in union_query
                assert "2024-01-02" in union_query

    def test_make_lineage_queries_with_historical_unavailable(self):
        """Test that only current query is returned when historical lineage is enabled but table doesn't exist."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            with patch.object(
                source, "_check_historical_table_exists", return_value=False
            ):
                queries = source._make_lineage_queries()
                assert len(queries) == 1
                assert '"DBC".QryLogV' in queries[0]
                assert '"PDCRINFO".DBQLSqlTbl_Hst' not in queries[0]

    def test_make_lineage_queries_with_database_filter(self):
        """Test that database filters are correctly applied to UNION query."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "databases": ["test_db1", "test_db2"],
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            with patch.object(
                source, "_check_historical_table_exists", return_value=True
            ):
                queries = source._make_lineage_queries()
                assert len(queries) == 1
                # UNION query should have database filters for both current and historical parts
                union_query = queries[0]
                assert "l.DefaultDatabase in ('test_db1','test_db2')" in union_query
                assert "h.DefaultDatabase in ('test_db1','test_db2')" in union_query

    def test_fetch_lineage_entries_chunked_multiple_queries(self):
        """Test that _fetch_lineage_entries_chunked handles multiple queries correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Mock the query generation to return 2 queries
            with patch.object(
                source, "_make_lineage_queries", return_value=["query1", "query2"]
            ):
                # Mock database execution for both queries
                mock_result1 = MagicMock()
                mock_result1.fetchmany.side_effect = [
                    [MagicMock(query_text="SELECT 1")],  # First batch
                    [],  # End of results
                ]
                mock_result2 = MagicMock()
                mock_result2.fetchmany.side_effect = [
                    [MagicMock(query_text="SELECT 2")],  # First batch
                    [],  # End of results
                ]
                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )
                with patch.object(
                    source, "get_metadata_engine", return_value=mock_engine
                ), patch.object(
                    source, "_execute_with_cursor_fallback"
                ) as mock_execute:
                    mock_execute.side_effect = [mock_result1, mock_result2]
                    entries = list(source._fetch_lineage_entries_chunked())
                    # Should have executed both queries
                    assert mock_execute.call_count == 2
                    mock_execute.assert_any_call(mock_connection, "query1")
                    mock_execute.assert_any_call(mock_connection, "query2")
                    # Should return entries from both queries
                    assert len(entries) == 2

    def test_fetch_lineage_entries_chunked_single_query(self):
        """Test that _fetch_lineage_entries_chunked handles single query correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": False,
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Mock the query generation to return 1 query
            with patch.object(source, "_make_lineage_queries", return_value=["query1"]):
                mock_result = MagicMock()
                mock_result.fetchmany.side_effect = [
                    [MagicMock(query_text="SELECT 1")],  # First batch
                    [],  # End of results
                ]
                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )
                with patch.object(
                    source, "get_metadata_engine", return_value=mock_engine
                ), patch.object(
                    source, "_execute_with_cursor_fallback", return_value=mock_result
                ) as mock_execute:
                    entries = list(source._fetch_lineage_entries_chunked())
                    # Should have executed only one query
                    assert mock_execute.call_count == 1
                    mock_execute.assert_called_with(mock_connection, "query1")
                    # Should return entries from the query
                    assert len(entries) == 1

    def test_fetch_lineage_entries_chunked_batch_processing(self):
        """Test that batch processing works correctly with configurable batch size."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": False,
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            with patch.object(source, "_make_lineage_queries", return_value=["query1"]):
                # Create mock entries
                mock_entries = [MagicMock(query_text=f"SELECT {i}") for i in range(7)]
                mock_result = MagicMock()
                # Simulate results arriving across multiple fetchmany() batches
                mock_result.fetchmany.side_effect = [
                    mock_entries[:5],  # First batch (5 items)
                    mock_entries[5:],  # Second batch (2 items)
                    [],  # End of results
                ]
                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )
                with patch.object(
                    source, "get_metadata_engine", return_value=mock_engine
                ), patch.object(
                    source, "_execute_with_cursor_fallback", return_value=mock_result
                ):
                    entries = list(source._fetch_lineage_entries_chunked())
                    # Should return all 7 entries
                    assert len(entries) == 7
                    # Verify fetchmany was called with the right batch size (5000 is hardcoded)
                    calls = mock_result.fetchmany.call_args_list
                    for call in calls:
                        if call[0]:  # If positional args
                            assert call[0][0] == 5000
                        else:  # If keyword args
                            assert call[1].get("size", 5000) == 5000

    def test_end_to_end_separate_queries_integration(self):
        """Test end-to-end integration of separate queries in the aggregator flow."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator
            # Mock entries from both current and historical queries
            current_entry = MagicMock()
            current_entry.query_text = "SELECT * FROM current_table"
            current_entry.user = "current_user"
            current_entry.timestamp = "2024-01-01 10:00:00"
            current_entry.default_database = "current_db"
            historical_entry = MagicMock()
            historical_entry.query_text = "SELECT * FROM historical_table"
            historical_entry.user = "historical_user"
            historical_entry.timestamp = "2023-12-01 10:00:00"
            historical_entry.default_database = "historical_db"

            def mock_fetch_generator():
                yield current_entry
                yield historical_entry

            with patch.object(
                source,
                "_fetch_lineage_entries_chunked",
                return_value=mock_fetch_generator(),
            ):
                mock_aggregator.gen_metadata.return_value = []
                # Execute the aggregator flow
                list(source._get_audit_log_mcps_with_aggregator())
                # Verify both entries were added to aggregator
                assert mock_aggregator.add.call_count == 2
                # Verify the entries were converted correctly
                added_queries = [
                    call[0][0] for call in mock_aggregator.add.call_args_list
                ]
                assert any(
                    "SELECT * FROM current_table" in query.query
                    for query in added_queries
                )
                assert any(
                    "SELECT * FROM historical_table" in query.query
                    for query in added_queries
                )

    def test_query_logging_and_progress_tracking(self):
        """Test that proper logging occurs when processing multiple queries."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            with patch.object(
                source, "_make_lineage_queries", return_value=["query1", "query2"]
            ):
                mock_result = MagicMock()
                call_counter = {"count": 0}

                def mock_fetchmany_side_effect(batch_size):
                    # Return one batch then empty to simulate end of results
                    call_counter["count"] += 1
                    if call_counter["count"] == 1:
                        return [MagicMock(query_text="SELECT 1")]
                    return []

                mock_result.fetchmany.side_effect = mock_fetchmany_side_effect
                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )
                with patch.object(
                    source, "get_metadata_engine", return_value=mock_engine
                ), patch.object(
                    source, "_execute_with_cursor_fallback", return_value=mock_result
                ), patch("datahub.ingestion.source.sql.teradata.logger") as mock_logger:
                    list(source._fetch_lineage_entries_chunked())
                    # Verify progress logging for multiple queries
                    info_calls = [
                        call for call in mock_logger.info.call_args_list if call[0]
                    ]
                    # Should log execution of query 1/2 and 2/2
                    assert any("query 1/2" in str(call) for call in info_calls)
                    assert any("query 2/2" in str(call) for call in info_calls)
                    # Should log completion of both queries
                    assert any("Completed query 1" in str(call) for call in info_calls)
                    assert any("Completed query 2" in str(call) for call in info_calls)


class TestQueryConstruction:
    """Test the construction of individual queries."""

    def test_current_query_construction(self):
        """Test that the current query is constructed correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            queries = source._make_lineage_queries()
            current_query = queries[0]
            # Verify current query structure
            assert 'FROM "DBC".QryLogV as l' in current_query
            assert 'JOIN "DBC".QryLogSqlV as s' in current_query
            assert "l.ErrorCode = 0" in current_query
            assert "2024-01-01" in current_query
            assert "2024-01-02" in current_query
            assert 'ORDER BY "timestamp", "query_id", "row_no"' in current_query

    def test_historical_query_construction(self):
        """Test that the UNION query contains historical data correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            with patch.object(
                source, "_check_historical_table_exists", return_value=True
            ):
                queries = source._make_lineage_queries()
                union_query = queries[0]
                # Verify UNION query contains historical data structure
                assert 'FROM "PDCRINFO".DBQLSqlTbl_Hst as h' in union_query
                assert "h.ErrorCode = 0" in union_query
                assert "h.StartTime AT TIME ZONE 'GMT'" in union_query
                assert "h.DefaultDatabase" in union_query
                assert "2024-01-01" in union_query
                assert "2024-01-02" in union_query
                assert 'ORDER BY "timestamp", "query_id", "row_no"' in union_query
                assert "UNION" in union_query


class TestStreamingQueryReconstruction:
    """Test the streaming query reconstruction functionality."""

    def test_reconstruct_queries_streaming_single_row_queries(self):
        """Test streaming reconstruction with single-row queries."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Create entries for single-row queries
            entries = [
                self._create_mock_entry(
                    "Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00"
                ),
                self._create_mock_entry(
                    "Q2", "SELECT * FROM table2", 1, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q3", "SELECT * FROM table3", 1, "2024-01-01 10:02:00"
                ),
            ]
            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 3
            assert reconstructed_queries[0].query == "SELECT * FROM table1"
            assert reconstructed_queries[1].query == "SELECT * FROM table2"
            assert reconstructed_queries[2].query == "SELECT * FROM table3"
            # Verify metadata preservation
            assert reconstructed_queries[0].timestamp == "2024-01-01 10:00:00"
            assert reconstructed_queries[1].timestamp == "2024-01-01 10:01:00"
            assert reconstructed_queries[2].timestamp == "2024-01-01 10:02:00"

    def test_reconstruct_queries_streaming_multi_row_queries(self):
        """Test streaming reconstruction with multi-row queries."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Create entries for multi-row queries
            entries = [
                # Query 1: 3 rows
                self._create_mock_entry(
                    "Q1", "SELECT a, b, c ", 1, "2024-01-01 10:00:00"
                ),
                self._create_mock_entry(
                    "Q1", "FROM large_table ", 2, "2024-01-01 10:00:00"
                ),
                self._create_mock_entry(
                    "Q1", "WHERE id > 1000", 3, "2024-01-01 10:00:00"
                ),
                # Query 2: 2 rows
                self._create_mock_entry(
                    "Q2", "UPDATE table3 SET ", 1, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q2", "status = 'active'", 2, "2024-01-01 10:01:00"
                ),
            ]
            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 2
            assert (
                reconstructed_queries[0].query
                == "SELECT a, b, c FROM large_table WHERE id > 1000"
            )
            assert (
                reconstructed_queries[1].query == "UPDATE table3 SET status = 'active'"
            )
            # Verify metadata preservation (should use metadata from first row of each query)
            assert reconstructed_queries[0].timestamp == "2024-01-01 10:00:00"
            assert reconstructed_queries[1].timestamp == "2024-01-01 10:01:00"

    def test_reconstruct_queries_streaming_mixed_queries(self):
        """Test streaming reconstruction with mixed single and multi-row queries."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Create entries mixing single and multi-row queries
            entries = [
                # Single-row query
                self._create_mock_entry(
                    "Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00"
                ),
                # Multi-row query (3 rows)
                self._create_mock_entry(
                    "Q2", "SELECT a, b, c ", 1, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q2", "FROM large_table ", 2, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q2", "WHERE id > 1000", 3, "2024-01-01 10:01:00"
                ),
                # Single-row query
                self._create_mock_entry(
                    "Q3", "SELECT COUNT(*) FROM table2", 1, "2024-01-01 10:02:00"
                ),
                # Multi-row query (2 rows)
                self._create_mock_entry(
                    "Q4", "UPDATE table3 SET ", 1, "2024-01-01 10:03:00"
                ),
                self._create_mock_entry(
                    "Q4", "status = 'active'", 2, "2024-01-01 10:03:00"
                ),
            ]
            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 4
            assert reconstructed_queries[0].query == "SELECT * FROM table1"
            assert (
                reconstructed_queries[1].query
                == "SELECT a, b, c FROM large_table WHERE id > 1000"
            )
            assert reconstructed_queries[2].query == "SELECT COUNT(*) FROM table2"
            assert (
                reconstructed_queries[3].query == "UPDATE table3 SET status = 'active'"
            )

    def test_reconstruct_queries_streaming_empty_entries(self):
        """Test streaming reconstruction with empty entries."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Test with empty entries
            entries: List[Any] = []
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 0

    def test_reconstruct_queries_streaming_teradata_specific_transformations(self):
        """Test that Teradata-specific transformations are applied."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Create entry with Teradata-specific syntax
            entries = [
                self._create_mock_entry(
                    "Q1",
                    "SELECT * FROM table1 (NOT CASESPECIFIC)",
                    1,
                    "2024-01-01 10:00:00",
                ),
            ]
            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 1
            # Should remove (NOT CASESPECIFIC)
            assert reconstructed_queries[0].query == "SELECT * FROM table1 "

    def test_reconstruct_queries_streaming_metadata_preservation(self):
        """Test that all metadata fields are preserved correctly."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Create entry with all metadata fields
            entries: List[Any] = [
                self._create_mock_entry(
                    "Q1",
                    "SELECT * FROM table1",
                    1,
                    "2024-01-01 10:00:00",
                    user="test_user",
                    default_database="test_db",
                    session_id="session123",
                ),
            ]
            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 1
            query = reconstructed_queries[0]
            # Verify all metadata fields
            assert query.query == "SELECT * FROM table1"
            assert query.timestamp == "2024-01-01 10:00:00"
            assert isinstance(query.user, CorpUserUrn)
            assert str(query.user) == "urn:li:corpuser:test_user"
            assert query.default_db == "test_db"
            assert query.default_schema == "test_db"  # Teradata uses database as schema
            assert query.session_id == "session123"

    def test_reconstruct_queries_streaming_with_none_user(self):
        """Test streaming reconstruction handles None user correctly."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Create entry with None user
            entries = [
                self._create_mock_entry(
                    "Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00", user=None
                ),
            ]
            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 1
            assert reconstructed_queries[0].user is None

    def test_reconstruct_queries_streaming_empty_query_text(self):
        """Test streaming reconstruction handles empty query text correctly."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Create entries with empty query text
            entries = [
                self._create_mock_entry("Q1", "", 1, "2024-01-01 10:00:00"),
                self._create_mock_entry(
                    "Q2", "SELECT * FROM table1", 1, "2024-01-01 10:01:00"
                ),
            ]
            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            # Should only get one query (the non-empty one)
            assert len(reconstructed_queries) == 1
            assert reconstructed_queries[0].query == "SELECT * FROM table1"

    def test_reconstruct_queries_streaming_space_joining_behavior(self):
        """Test that query parts are joined directly without adding spaces."""
        config = TeradataConfig.parse_obj(_base_config())
        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))
            # Test case 1: Parts that include their own spacing
            entries1 = [
                self._create_mock_entry("Q1", "SELECT ", 1, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "col1, ", 2, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "col2 ", 3, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "FROM ", 4, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "table1", 5, "2024-01-01 10:00:00"),
            ]
            # Test case 2: Parts that already have trailing/leading spaces
            entries2 = [
                self._create_mock_entry("Q2", "SELECT * ", 1, "2024-01-01 10:01:00"),
                self._create_mock_entry("Q2", "FROM table2 ", 2, "2024-01-01 10:01:00"),
                self._create_mock_entry("Q2", "WHERE id > 1", 3, "2024-01-01 10:01:00"),
            ]
            # Test streaming reconstruction
            all_entries = entries1 + entries2
            reconstructed_queries = list(
                source._reconstruct_queries_streaming(all_entries)
            )
            assert len(reconstructed_queries) == 2
            # Query 1: Should be joined directly without adding spaces
            assert reconstructed_queries[0].query == "SELECT col1, col2 FROM table1"
            # Query 2: Should handle existing spaces correctly (may have extra spaces)
            assert reconstructed_queries[1].query == "SELECT * FROM table2 WHERE id > 1"

    def _create_mock_entry(
        self,
        query_id,
        query_text,
        row_no,
        timestamp,
        user="test_user",
        default_database="test_db",
        session_id=None,
    ):
        """Create a mock database entry for testing."""
        entry = MagicMock()
        entry.query_id = query_id
        entry.query_text = query_text
        entry.row_no = row_no
        entry.timestamp = timestamp
        entry.user = user
        entry.default_database = default_database
        entry.session_id = session_id or f"session_{query_id}"
        return entry