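"""Unit tests for the DataHub Teradata ingestion source.

Covers configuration validation, source initialization, SQL-injection safety of
schema queries, memory-efficient streaming of lineage entries, thread-safe table
caching, stage tracking, error handling, lineage query construction, and
streaming query reconstruction. All database access is mocked; no live Teradata
connection is required.
"""
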
from datetime import datetime
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.sql.teradata import (
    TeradataConfig,
    TeradataSource,
    TeradataTable,
    get_schema_columns,
    get_schema_pk_constraints,
)
from datahub.metadata.urns import CorpUserUrn
from datahub.sql_parsing.sql_parsing_aggregator import ObservedQuery


def _base_config() -> Dict[str, Any]:
    """Base configuration for Teradata tests."""
    return {
        "username": "test_user",
        "password": "test_password",
        "host_port": "localhost:1025",
        "include_table_lineage": True,
        "include_usage_statistics": True,
        "include_queries": True,
    }


class TestTeradataConfig:
    """Test configuration validation and initialization."""

    def test_valid_config(self):
        """Test that valid configuration is accepted."""
        config_dict = _base_config()
        config = TeradataConfig.parse_obj(config_dict)

        assert config.host_port == "localhost:1025"
        assert config.include_table_lineage is True
        assert config.include_usage_statistics is True
        assert config.include_queries is True

    def test_max_workers_validation_valid(self):
        """Test valid max_workers configuration passes validation."""
        config_dict = {
            **_base_config(),
            "max_workers": 8,
        }
        config = TeradataConfig.parse_obj(config_dict)
        assert config.max_workers == 8

    def test_max_workers_default(self):
        """Test max_workers defaults to 10."""
        config_dict = _base_config()
        config = TeradataConfig.parse_obj(config_dict)
        assert config.max_workers == 10

    def test_max_workers_custom_value(self):
        """Test custom max_workers value is accepted."""
        config_dict = {
            **_base_config(),
            "max_workers": 5,
        }
        config = TeradataConfig.parse_obj(config_dict)
        assert config.max_workers == 5

    def test_include_queries_default(self):
        """Test include_queries defaults to True."""
        config_dict = _base_config()
        config = TeradataConfig.parse_obj(config_dict)
        assert config.include_queries is True


class TestTeradataSource:
    """Test Teradata source functionality."""

    @patch("datahub.ingestion.source.sql.teradata.create_engine")
    def test_source_initialization(self, mock_create_engine):
        """Test source initializes correctly."""
        config = TeradataConfig.parse_obj(_base_config())
        ctx = PipelineContext(run_id="test")

        # Mock the engine creation
        mock_engine = MagicMock()
        mock_create_engine.return_value = mock_engine

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, ctx)

            assert source.config == config
            assert source.platform == "teradata"
            assert hasattr(source, "aggregator")
            assert hasattr(source, "_tables_cache")
            assert hasattr(source, "_tables_cache_lock")

    @patch("datahub.ingestion.source.sql.teradata.create_engine")
    @patch("datahub.ingestion.source.sql.teradata.inspect")
    def test_get_inspectors(self, mock_inspect, mock_create_engine):
        """Test inspector creation and database iteration."""
        # Mock database names returned by inspector
        mock_inspector = MagicMock()
        mock_inspector.get_schema_names.return_value = ["db1", "db2", "test_db"]
        mock_inspect.return_value = mock_inspector

        mock_connection = MagicMock()
        mock_engine = MagicMock()
        mock_engine.connect.return_value.__enter__.return_value = mock_connection
        mock_create_engine.return_value = mock_engine

        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(source, "get_metadata_engine", return_value=mock_engine):
                inspectors = list(source.get_inspectors())

            assert len(inspectors) == 3
            # Check that each inspector has the database name set
            for inspector in inspectors:
                assert hasattr(inspector, "_datahub_database")

    def test_cache_tables_and_views_thread_safety(self):
        """Test that cache operations are thread-safe."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Mock engine and query results
            mock_entry = MagicMock()
            mock_entry.DataBaseName = "test_db"
            mock_entry.name = "test_table"
            mock_entry.description = "Test table"
            mock_entry.object_type = "Table"
            mock_entry.CreateTimeStamp = None
            mock_entry.LastAlterName = None
            mock_entry.LastAlterTimeStamp = None
            mock_entry.RequestText = None

            with patch.object(source, "get_metadata_engine") as mock_get_engine:
                mock_engine = MagicMock()
                mock_engine.execute.return_value = [mock_entry]
                mock_get_engine.return_value = mock_engine

                # Call the method after patching the engine
                source.cache_tables_and_views()

                # Verify table was added to cache
                assert "test_db" in source._tables_cache
                assert len(source._tables_cache["test_db"]) == 1
                assert source._tables_cache["test_db"][0].name == "test_table"

                # Verify engine was disposed
                mock_engine.dispose.assert_called_once()

    def test_convert_entry_to_observed_query(self):
        """Test conversion of database entries to ObservedQuery objects."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Mock database entry
            mock_entry = MagicMock()
            mock_entry.query_text = "SELECT * FROM table1 (NOT CASESPECIFIC)"
            mock_entry.session_id = "session123"
            mock_entry.timestamp = "2024-01-01 10:00:00"
            mock_entry.user = "test_user"
            mock_entry.default_database = "test_db"

            observed_query = source._convert_entry_to_observed_query(mock_entry)

            assert isinstance(observed_query, ObservedQuery)
            assert (
                observed_query.query == "SELECT * FROM table1 "
            )  # (NOT CASESPECIFIC) removed
            assert observed_query.session_id == "session123"
            assert observed_query.timestamp == "2024-01-01 10:00:00"
            assert isinstance(observed_query.user, CorpUserUrn)
            assert observed_query.default_db == "test_db"
            assert observed_query.default_schema == "test_db"

    def test_convert_entry_to_observed_query_with_none_user(self):
        """Test ObservedQuery conversion handles None user correctly."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            mock_entry = MagicMock()
            mock_entry.query_text = "SELECT 1"
            mock_entry.session_id = "session123"
            mock_entry.timestamp = "2024-01-01 10:00:00"
            mock_entry.user = None
            mock_entry.default_database = "test_db"

            observed_query = source._convert_entry_to_observed_query(mock_entry)

            assert observed_query.user is None

    def test_check_historical_table_exists_success(self):
        """Test historical table check when table exists."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Mock successful query execution
            mock_connection = MagicMock()
            mock_engine = MagicMock()
            mock_engine.connect.return_value.__enter__.return_value = mock_connection

            with patch.object(source, "get_metadata_engine", return_value=mock_engine):
                result = source._check_historical_table_exists()

            assert result is True
            mock_engine.dispose.assert_called_once()

    def test_check_historical_table_exists_failure(self):
        """Test historical table check when table doesn't exist."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Mock failed query execution
            mock_connection = MagicMock()
            mock_connection.execute.side_effect = Exception("Table not found")
            mock_engine = MagicMock()
            mock_engine.connect.return_value.__enter__.return_value = mock_connection

            with patch.object(source, "get_metadata_engine", return_value=mock_engine):
                result = source._check_historical_table_exists()

            assert result is False
            mock_engine.dispose.assert_called_once()

    def test_close_cleanup(self):
        """Test that close() properly cleans up resources."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.two_tier_sql_source.TwoTierSQLAlchemySource.close"
            ) as mock_super_close:
                source.close()

                mock_aggregator.close.assert_called_once()
                mock_super_close.assert_called_once()


class TestSQLInjectionSafety:
    """Test SQL injection vulnerability fixes."""

    def test_get_schema_columns_parameterized(self):
        """Test that get_schema_columns uses parameterized queries."""
        mock_connection = MagicMock()
        mock_connection.execute.return_value.fetchall.return_value = []

        # Call the function
        get_schema_columns(None, mock_connection, "columnsV", "test_schema")

        # Verify parameterized query was used
        call_args = mock_connection.execute.call_args
        query = call_args[0][0].text
        params = call_args[0][1] if len(call_args[0]) > 1 else call_args[1]

        assert ":schema" in query
        assert "schema" in params
        assert params["schema"] == "test_schema"

    def test_get_schema_pk_constraints_parameterized(self):
        """Test that get_schema_pk_constraints uses parameterized queries."""
        mock_connection = MagicMock()
        mock_connection.execute.return_value.fetchall.return_value = []

        # Call the function
        get_schema_pk_constraints(None, mock_connection, "test_schema")

        # Verify parameterized query was used
        call_args = mock_connection.execute.call_args
        query = call_args[0][0].text
        params = call_args[0][1] if len(call_args[0]) > 1 else call_args[1]

        assert ":schema" in query
        assert "schema" in params
        assert params["schema"] == "test_schema"


class TestMemoryEfficiency:
    """Test memory efficiency improvements."""

    def test_fetch_lineage_entries_chunked_streaming(self):
        """Test that lineage entries are processed in streaming fashion."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator

            # Mock the chunked fetching method to return a generator
            def mock_generator():
                for i in range(5):
                    mock_entry = MagicMock()
                    mock_entry.query_text = f"SELECT {i}"
                    mock_entry.session_id = f"session_{i}"
                    mock_entry.timestamp = "2024-01-01 10:00:00"
                    mock_entry.user = "test_user"
                    mock_entry.default_database = "test_db"
                    yield mock_entry

            with patch.object(
                source, "_fetch_lineage_entries_chunked", return_value=mock_generator()
            ):
                mock_aggregator.gen_metadata.return_value = []

                # Process entries
                list(source._get_audit_log_mcps_with_aggregator())

                # Verify aggregator.add was called for each entry (streaming)
                assert mock_aggregator.add.call_count == 5


class TestConcurrencySupport:
    """Test thread safety and concurrent operations."""

    def test_tables_cache_thread_safety(self):
        """Test that tables cache operations are thread-safe."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Verify lock exists
            assert hasattr(source, "_tables_cache_lock")

            # Test safe cache access methods
            result = source._tables_cache.get("nonexistent_schema", [])
            assert result == []

    def test_cached_loop_tables_safe_access(self):
        """Test cached_loop_tables uses safe cache access."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Add test data to cache
            test_table = TeradataTable(
                database="test_db",
                name="test_table",
                description="Test",
                object_type="Table",
                create_timestamp=datetime.now(),
                last_alter_name=None,
                last_alter_timestamp=None,
                request_text=None,
            )
            source._tables_cache["test_schema"] = [test_table]

            # Mock inspector and config
            mock_inspector = MagicMock()
            mock_sql_config = MagicMock()

            with patch(
                "datahub.ingestion.source.sql.two_tier_sql_source.TwoTierSQLAlchemySource.loop_tables"
            ) as mock_super:
                mock_super.return_value = []

                # This should not raise an exception even with missing schema
                list(
                    source.cached_loop_tables(
                        mock_inspector, "missing_schema", mock_sql_config
                    )
                )


class TestStageTracking:
    """Test stage tracking functionality."""

    def test_stage_tracking_in_cache_operation(self):
        """Test that table caching uses stage tracking."""
        config = TeradataConfig.parse_obj(_base_config())

        # Create source without mocking to test the actual stage tracking during init
        with (
            patch("datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"),
            patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ) as mock_cache,
        ):
            TeradataSource(config, PipelineContext(run_id="test"))

            # Verify cache_tables_and_views was called during init (stage tracking happens there)
            mock_cache.assert_called_once()

    def test_stage_tracking_in_aggregator_processing(self):
        """Test that aggregator processing uses stage tracking."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator

            with patch.object(source.report, "new_stage") as mock_new_stage:
                mock_context_manager = MagicMock()
                mock_new_stage.return_value = mock_context_manager

                with patch.object(
                    source, "_fetch_lineage_entries_chunked", return_value=[]
                ):
                    mock_aggregator.gen_metadata.return_value = []

                    list(source._get_audit_log_mcps_with_aggregator())

                    # Should have called new_stage for query processing and metadata generation
                    # The actual implementation uses new_stage for "Fetching queries" and "Generating metadata"
                    assert mock_new_stage.call_count >= 1


class TestErrorHandling:
    """Test error handling and edge cases."""

    def test_empty_lineage_entries(self):
        """Test handling of empty lineage entries."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(
                source, "_fetch_lineage_entries_chunked", return_value=[]
            ):
                mock_aggregator.gen_metadata.return_value = []
                result = list(source._get_audit_log_mcps_with_aggregator())
                assert result == []

    def test_malformed_query_entry(self):
        """Test handling of malformed query entries."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            # Mock cache_tables_and_views to prevent database connection during init
            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Mock entry with missing attributes
            mock_entry = MagicMock()
            mock_entry.query_text = "SELECT 1"
            # Simulate missing attributes
            del mock_entry.session_id
            del mock_entry.timestamp
            del mock_entry.user
            del mock_entry.default_database

            # Should handle gracefully using getattr with defaults
            observed_query = source._convert_entry_to_observed_query(mock_entry)

            assert observed_query.query == "SELECT 1"
            assert observed_query.session_id is None
            assert observed_query.user is None


class TestLineageQuerySeparation:
    """Test the new separated lineage query functionality (no more UNION)."""

    def test_make_lineage_queries_current_only(self):
        """Test that only current query is returned when historical lineage is disabled."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": False,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            queries = source._make_lineage_queries()

            assert len(queries) == 1
            assert '"DBC".QryLogV' in queries[0]
            assert "PDCRDATA.DBQLSqlTbl_Hst" not in queries[0]
            assert "2024-01-01" in queries[0]
            assert "2024-01-02" in queries[0]

    def test_make_lineage_queries_with_historical_available(self):
        """Test that UNION query is returned when historical lineage is enabled and table exists."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(
                source, "_check_historical_table_exists", return_value=True
            ):
                queries = source._make_lineage_queries()

                assert len(queries) == 1

                # Single UNION query should contain both historical and current data
                union_query = queries[0]
                assert '"DBC".QryLogV' in union_query
                assert '"PDCRINFO".DBQLSqlTbl_Hst' in union_query
                assert "UNION" in union_query
                assert "combined_results" in union_query

                # Should have the time filters
                assert "2024-01-01" in union_query
                assert "2024-01-02" in union_query

    def test_make_lineage_queries_with_historical_unavailable(self):
        """Test that only current query is returned when historical lineage is enabled but table doesn't exist."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(
                source, "_check_historical_table_exists", return_value=False
            ):
                queries = source._make_lineage_queries()

                assert len(queries) == 1
                assert '"DBC".QryLogV' in queries[0]
                assert '"PDCRDATA".DBQLSqlTbl_Hst' not in queries[0]

    def test_make_lineage_queries_with_database_filter(self):
        """Test that database filters are correctly applied to UNION query."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "databases": ["test_db1", "test_db2"],
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(
                source, "_check_historical_table_exists", return_value=True
            ):
                queries = source._make_lineage_queries()

                assert len(queries) == 1

                # UNION query should have database filters for both current and historical parts
                union_query = queries[0]
                assert "l.DefaultDatabase in ('test_db1','test_db2')" in union_query
                assert "h.DefaultDatabase in ('test_db1','test_db2')" in union_query

    def test_fetch_lineage_entries_chunked_multiple_queries(self):
        """Test that _fetch_lineage_entries_chunked handles multiple queries correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Mock the query generation to return 2 queries
            with patch.object(
                source, "_make_lineage_queries", return_value=["query1", "query2"]
            ):
                # Mock database execution for both queries
                mock_result1 = MagicMock()
                mock_result1.fetchmany.side_effect = [
                    [MagicMock(query_text="SELECT 1")],  # First batch
                    [],  # End of results
                ]

                mock_result2 = MagicMock()
                mock_result2.fetchmany.side_effect = [
                    [MagicMock(query_text="SELECT 2")],  # First batch
                    [],  # End of results
                ]

                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )

                with (
                    patch.object(
                        source, "get_metadata_engine", return_value=mock_engine
                    ),
                    patch.object(
                        source, "_execute_with_cursor_fallback"
                    ) as mock_execute,
                ):
                    mock_execute.side_effect = [mock_result1, mock_result2]

                    entries = list(source._fetch_lineage_entries_chunked())

                    # Should have executed both queries
                    assert mock_execute.call_count == 2
                    mock_execute.assert_any_call(mock_connection, "query1")
                    mock_execute.assert_any_call(mock_connection, "query2")

                    # Should return entries from both queries
                    assert len(entries) == 2

    def test_fetch_lineage_entries_chunked_single_query(self):
        """Test that _fetch_lineage_entries_chunked handles single query correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": False,
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Mock the query generation to return 1 query
            with patch.object(source, "_make_lineage_queries", return_value=["query1"]):
                mock_result = MagicMock()
                mock_result.fetchmany.side_effect = [
                    [MagicMock(query_text="SELECT 1")],  # First batch
                    [],  # End of results
                ]

                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )

                with (
                    patch.object(
                        source, "get_metadata_engine", return_value=mock_engine
                    ),
                    patch.object(
                        source,
                        "_execute_with_cursor_fallback",
                        return_value=mock_result,
                    ) as mock_execute,
                ):
                    entries = list(source._fetch_lineage_entries_chunked())

                    # Should have executed only one query
                    assert mock_execute.call_count == 1
                    mock_execute.assert_called_with(mock_connection, "query1")

                    # Should return entries from the query
                    assert len(entries) == 1

    def test_fetch_lineage_entries_chunked_batch_processing(self):
        """Test that batch processing works correctly with configurable batch size."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": False,
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(source, "_make_lineage_queries", return_value=["query1"]):
                # Create mock entries
                mock_entries = [MagicMock(query_text=f"SELECT {i}") for i in range(7)]

                mock_result = MagicMock()
                # Simulate batching with batch_size=5 (hardcoded in the method)
                mock_result.fetchmany.side_effect = [
                    mock_entries[:5],  # First batch (5 items)
                    mock_entries[5:],  # Second batch (2 items)
                    [],  # End of results
                ]

                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )

                with (
                    patch.object(
                        source, "get_metadata_engine", return_value=mock_engine
                    ),
                    patch.object(
                        source,
                        "_execute_with_cursor_fallback",
                        return_value=mock_result,
                    ),
                ):
                    entries = list(source._fetch_lineage_entries_chunked())

                    # Should return all 7 entries
                    assert len(entries) == 7

                    # Verify fetchmany was called with the right batch size (5000 is hardcoded)
                    calls = mock_result.fetchmany.call_args_list
                    for call in calls:
                        if call[0]:  # If positional args
                            assert call[0][0] == 5000
                        else:  # If keyword args
                            assert call[1].get("size", 5000) == 5000

    def test_end_to_end_separate_queries_integration(self):
        """Test end-to-end integration of separate queries in the aggregator flow."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Replace the aggregator with our mock after creation
            source.aggregator = mock_aggregator

            # Mock entries from both current and historical queries
            current_entry = MagicMock()
            current_entry.query_text = "SELECT * FROM current_table"
            current_entry.user = "current_user"
            current_entry.timestamp = "2024-01-01 10:00:00"
            current_entry.default_database = "current_db"

            historical_entry = MagicMock()
            historical_entry.query_text = "SELECT * FROM historical_table"
            historical_entry.user = "historical_user"
            historical_entry.timestamp = "2023-12-01 10:00:00"
            historical_entry.default_database = "historical_db"

            def mock_fetch_generator():
                yield current_entry
                yield historical_entry

            with patch.object(
                source,
                "_fetch_lineage_entries_chunked",
                return_value=mock_fetch_generator(),
            ):
                mock_aggregator.gen_metadata.return_value = []

                # Execute the aggregator flow
                list(source._get_audit_log_mcps_with_aggregator())

                # Verify both entries were added to aggregator
                assert mock_aggregator.add.call_count == 2

                # Verify the entries were converted correctly
                added_queries = [
                    call[0][0] for call in mock_aggregator.add.call_args_list
                ]

                assert any(
                    "SELECT * FROM current_table" in query.query
                    for query in added_queries
                )
                assert any(
                    "SELECT * FROM historical_table" in query.query
                    for query in added_queries
                )

    def test_query_logging_and_progress_tracking(self):
        """Test that proper logging occurs when processing multiple queries."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(
                source, "_make_lineage_queries", return_value=["query1", "query2"]
            ):
                mock_result = MagicMock()

                call_counter = {"count": 0}

                def mock_fetchmany_side_effect(batch_size):
                    # Return one batch then empty to simulate end of results
                    call_counter["count"] += 1
                    if call_counter["count"] == 1:
                        return [MagicMock(query_text="SELECT 1")]
                    return []

                mock_result.fetchmany.side_effect = mock_fetchmany_side_effect

                mock_connection = MagicMock()
                mock_engine = MagicMock()
                mock_engine.connect.return_value.__enter__.return_value = (
                    mock_connection
                )

                with (
                    patch.object(
                        source, "get_metadata_engine", return_value=mock_engine
                    ),
                    patch.object(
                        source,
                        "_execute_with_cursor_fallback",
                        return_value=mock_result,
                    ),
                    patch(
                        "datahub.ingestion.source.sql.teradata.logger"
                    ) as mock_logger,
                ):
                    list(source._fetch_lineage_entries_chunked())

                    # Verify progress logging for multiple queries
                    info_calls = [
                        call for call in mock_logger.info.call_args_list if call[0]
                    ]

                    # Should log execution of query 1/2 and 2/2
                    assert any("query 1/2" in str(call) for call in info_calls)
                    assert any("query 2/2" in str(call) for call in info_calls)

                    # Should log completion of both queries
                    assert any("Completed query 1" in str(call) for call in info_calls)
                    assert any("Completed query 2" in str(call) for call in info_calls)


class TestQueryConstruction:
    """Test the construction of individual queries."""

    def test_current_query_construction(self):
        """Test that the current query is constructed correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            queries = source._make_lineage_queries()
            current_query = queries[0]

            # Verify current query structure
            assert 'FROM "DBC".QryLogV as l' in current_query
            assert 'JOIN "DBC".QryLogSqlV as s' in current_query
            assert "l.ErrorCode = 0" in current_query
            assert "2024-01-01" in current_query
            assert "2024-01-02" in current_query
            assert 'ORDER BY "timestamp", "query_id", "row_no"' in current_query

    def test_historical_query_construction(self):
        """Test that the UNION query contains historical data correctly."""
        config = TeradataConfig.parse_obj(
            {
                **_base_config(),
                "include_historical_lineage": True,
                "start_time": "2024-01-01T00:00:00Z",
                "end_time": "2024-01-02T00:00:00Z",
            }
        )

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            with patch.object(
                source, "_check_historical_table_exists", return_value=True
            ):
                queries = source._make_lineage_queries()
                union_query = queries[0]

                # Verify UNION query contains historical data structure
                assert 'FROM "PDCRINFO".DBQLSqlTbl_Hst as h' in union_query
                assert "h.ErrorCode = 0" in union_query
                assert "h.StartTime AT TIME ZONE 'GMT'" in union_query
                assert "h.DefaultDatabase" in union_query
                assert "2024-01-01" in union_query
                assert "2024-01-02" in union_query
                assert 'ORDER BY "timestamp", "query_id", "row_no"' in union_query
                assert "UNION" in union_query


class TestStreamingQueryReconstruction:
    """Test the streaming query reconstruction functionality."""

    def test_reconstruct_queries_streaming_single_row_queries(self):
        """Test streaming reconstruction with single-row queries."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Create entries for single-row queries
            entries = [
                self._create_mock_entry(
                    "Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00"
                ),
                self._create_mock_entry(
                    "Q2", "SELECT * FROM table2", 1, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q3", "SELECT * FROM table3", 1, "2024-01-01 10:02:00"
                ),
            ]

            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))

            assert len(reconstructed_queries) == 3
            assert reconstructed_queries[0].query == "SELECT * FROM table1"
            assert reconstructed_queries[1].query == "SELECT * FROM table2"
            assert reconstructed_queries[2].query == "SELECT * FROM table3"

            # Verify metadata preservation
            assert reconstructed_queries[0].timestamp == "2024-01-01 10:00:00"
            assert reconstructed_queries[1].timestamp == "2024-01-01 10:01:00"
            assert reconstructed_queries[2].timestamp == "2024-01-01 10:02:00"

    def test_reconstruct_queries_streaming_multi_row_queries(self):
        """Test streaming reconstruction with multi-row queries."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Create entries for multi-row queries
            entries = [
                # Query 1: 3 rows
                self._create_mock_entry(
                    "Q1", "SELECT a, b, c ", 1, "2024-01-01 10:00:00"
                ),
                self._create_mock_entry(
                    "Q1", "FROM large_table ", 2, "2024-01-01 10:00:00"
                ),
                self._create_mock_entry(
                    "Q1", "WHERE id > 1000", 3, "2024-01-01 10:00:00"
                ),
                # Query 2: 2 rows
                self._create_mock_entry(
                    "Q2", "UPDATE table3 SET ", 1, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q2", "status = 'active'", 2, "2024-01-01 10:01:00"
                ),
            ]

            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))

            assert len(reconstructed_queries) == 2
            assert (
                reconstructed_queries[0].query
                == "SELECT a, b, c FROM large_table WHERE id > 1000"
            )
            assert (
                reconstructed_queries[1].query == "UPDATE table3 SET status = 'active'"
            )

            # Verify metadata preservation (should use metadata from first row of each query)
            assert reconstructed_queries[0].timestamp == "2024-01-01 10:00:00"
            assert reconstructed_queries[1].timestamp == "2024-01-01 10:01:00"

    def test_reconstruct_queries_streaming_mixed_queries(self):
        """Test streaming reconstruction with mixed single and multi-row queries."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Create entries mixing single and multi-row queries
            entries = [
                # Single-row query
                self._create_mock_entry(
                    "Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00"
                ),
                # Multi-row query (3 rows)
                self._create_mock_entry(
                    "Q2", "SELECT a, b, c ", 1, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q2", "FROM large_table ", 2, "2024-01-01 10:01:00"
                ),
                self._create_mock_entry(
                    "Q2", "WHERE id > 1000", 3, "2024-01-01 10:01:00"
                ),
                # Single-row query
                self._create_mock_entry(
                    "Q3", "SELECT COUNT(*) FROM table2", 1, "2024-01-01 10:02:00"
                ),
                # Multi-row query (2 rows)
                self._create_mock_entry(
                    "Q4", "UPDATE table3 SET ", 1, "2024-01-01 10:03:00"
                ),
                self._create_mock_entry(
                    "Q4", "status = 'active'", 2, "2024-01-01 10:03:00"
                ),
            ]

            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))

            assert len(reconstructed_queries) == 4
            assert reconstructed_queries[0].query == "SELECT * FROM table1"
            assert (
                reconstructed_queries[1].query
                == "SELECT a, b, c FROM large_table WHERE id > 1000"
            )
            assert reconstructed_queries[2].query == "SELECT COUNT(*) FROM table2"
            assert (
                reconstructed_queries[3].query == "UPDATE table3 SET status = 'active'"
            )

    def test_reconstruct_queries_streaming_empty_entries(self):
        """Test streaming reconstruction with empty entries."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Test with empty entries
            entries: List[Any] = []
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))
            assert len(reconstructed_queries) == 0

    def test_reconstruct_queries_streaming_teradata_specific_transformations(self):
        """Test that Teradata-specific transformations are applied."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Create entry with Teradata-specific syntax
            entries = [
                self._create_mock_entry(
                    "Q1",
                    "SELECT * FROM table1 (NOT CASESPECIFIC)",
                    1,
                    "2024-01-01 10:00:00",
                ),
            ]

            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))

            assert len(reconstructed_queries) == 1
            # Should remove (NOT CASESPECIFIC)
            assert reconstructed_queries[0].query == "SELECT * FROM table1 "

    def test_reconstruct_queries_streaming_metadata_preservation(self):
        """Test that all metadata fields are preserved correctly."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Create entry with all metadata fields
            entries: List[Any] = [
                self._create_mock_entry(
                    "Q1",
                    "SELECT * FROM table1",
                    1,
                    "2024-01-01 10:00:00",
                    user="test_user",
                    default_database="test_db",
                    session_id="session123",
                ),
            ]

            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))

            assert len(reconstructed_queries) == 1
            query = reconstructed_queries[0]

            # Verify all metadata fields
            assert query.query == "SELECT * FROM table1"
            assert query.timestamp == "2024-01-01 10:00:00"
            assert isinstance(query.user, CorpUserUrn)
            assert str(query.user) == "urn:li:corpuser:test_user"
            assert query.default_db == "test_db"
            assert query.default_schema == "test_db"  # Teradata uses database as schema
            assert query.session_id == "session123"

    def test_reconstruct_queries_streaming_with_none_user(self):
        """Test streaming reconstruction handles None user correctly."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Create entry with None user
            entries = [
                self._create_mock_entry(
                    "Q1", "SELECT * FROM table1", 1, "2024-01-01 10:00:00", user=None
                ),
            ]

            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))

            assert len(reconstructed_queries) == 1
            assert reconstructed_queries[0].user is None

    def test_reconstruct_queries_streaming_empty_query_text(self):
        """Test streaming reconstruction handles empty query text correctly."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Create entries with empty query text
            entries = [
                self._create_mock_entry("Q1", "", 1, "2024-01-01 10:00:00"),
                self._create_mock_entry(
                    "Q2", "SELECT * FROM table1", 1, "2024-01-01 10:01:00"
                ),
            ]

            # Test streaming reconstruction
            reconstructed_queries = list(source._reconstruct_queries_streaming(entries))

            # Should only get one query (the non-empty one)
            assert len(reconstructed_queries) == 1
            assert reconstructed_queries[0].query == "SELECT * FROM table1"

    def test_reconstruct_queries_streaming_space_joining_behavior(self):
        """Test that query parts are joined directly without adding spaces."""
        config = TeradataConfig.parse_obj(_base_config())

        with patch(
            "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator"
        ) as mock_aggregator_class:
            mock_aggregator = MagicMock()
            mock_aggregator_class.return_value = mock_aggregator

            with patch(
                "datahub.ingestion.source.sql.teradata.TeradataSource.cache_tables_and_views"
            ):
                source = TeradataSource(config, PipelineContext(run_id="test"))

            # Test case 1: Parts that include their own spacing
            entries1 = [
                self._create_mock_entry("Q1", "SELECT ", 1, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "col1, ", 2, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "col2 ", 3, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "FROM ", 4, "2024-01-01 10:00:00"),
                self._create_mock_entry("Q1", "table1", 5, "2024-01-01 10:00:00"),
            ]

            # Test case 2: Parts that already have trailing/leading spaces
            entries2 = [
                self._create_mock_entry("Q2", "SELECT * ", 1, "2024-01-01 10:01:00"),
                self._create_mock_entry("Q2", "FROM table2 ", 2, "2024-01-01 10:01:00"),
                self._create_mock_entry("Q2", "WHERE id > 1", 3, "2024-01-01 10:01:00"),
            ]

            # Test streaming reconstruction
            all_entries = entries1 + entries2
            reconstructed_queries = list(
                source._reconstruct_queries_streaming(all_entries)
            )

            assert len(reconstructed_queries) == 2

            # Query 1: Should be joined directly without adding spaces
            assert reconstructed_queries[0].query == "SELECT col1, col2 FROM table1"

            # Query 2: Should handle existing spaces correctly (may have extra spaces)
            assert reconstructed_queries[1].query == "SELECT * FROM table2 WHERE id > 1"

    def _create_mock_entry(
        self,
        query_id,
        query_text,
        row_no,
        timestamp,
        user="test_user",
        default_database="test_db",
        session_id=None,
    ):
        """Create a mock database entry for testing."""
        entry = MagicMock()
        entry.query_id = query_id
        entry.query_text = query_text
        entry.row_no = row_no
        entry.timestamp = timestamp
        entry.user = user
        entry.default_database = default_database
        entry.session_id = session_id or f"session_{query_id}"
        return entry