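"""Tests for the sql_queries ingestion source.

Covers QueryEntry parsing (timestamp formats, table filtering, missing fields,
platform/platform_instance/env handling), SqlQueriesSourceConfig parsing of the
incremental_lineage flag, and SqlQueriesSource workunit generation and processors.
"""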

import json
import os
import tempfile
from datetime import datetime, timezone
from functools import partial
from typing import cast
from unittest.mock import Mock

import pytest

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.incremental_lineage_helper import auto_incremental_lineage
from datahub.ingestion.api.source_helpers import auto_workunit_reporter
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.sql_queries import (
QueryEntry,
SqlQueriesSource,
SqlQueriesSourceConfig,
)
from datahub.metadata.urns import CorpUserUrn, DatasetUrn


class TestQueryEntry:
@pytest.mark.parametrize(
"entry_dict,entry_config,expected_query_entry,should_raise",
[
# Timestamp format variations
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="numeric_unix_timestamp",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200.5,
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(
2021, 1, 1, 0, 0, 0, 500000, tzinfo=timezone.utc
),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="float_unix_timestamp",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": "1609459200",
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="string_unix_timestamp",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": "2021-01-01T00:00:00Z",
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="iso_format_with_z",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": "2025-07-19 15:08:07.000000",
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2025, 7, 19, 15, 8, 7, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="datetime_string_with_microseconds",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": "2025-07-19 15:08:07",
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2025, 7, 19, 15, 8, 7, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="datetime_string_without_microseconds",
),
# Table filtering - empty strings removed
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": ["table1", "", "table2"],
"downstream_tables": ["output_table"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,output_table,PROD)"
)
],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
),
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table2,PROD)"
),
],
session_id=None,
),
False,
id="filter_empty_upstream_tables",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": [""],
"downstream_tables": [""],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[],
session_id=None,
),
False,
id="all_empty_tables",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": [" ", "\t"],
"downstream_tables": [" "],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[],
session_id=None,
),
False,
id="whitespace_only_tables",
),
# Missing fields
pytest.param(
{
"query": "SELECT * FROM table",
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=None,
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="no_timestamp",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=None,
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="no_user",
),
# Both upstream and downstream tables
pytest.param(
{
"query": "CREATE TABLE out AS SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": ["table1"],
"downstream_tables": ["output_table"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
QueryEntry(
query="CREATE TABLE out AS SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,output_table,PROD)"
)
],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:athena,table1,PROD)"
)
],
session_id=None,
),
False,
id="upstream_and_downstream_tables",
),
# Config variations - platform, platform_instance, env
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(
platform="snowflake",
query_file="dummy.json",
platform_instance="prod_instance",
env="DEV",
),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:snowflake,prod_instance.table1,DEV)"
)
],
session_id=None,
),
False,
id="config_with_platform_instance_and_env",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(
platform="bigquery", query_file="dummy.json", env="PRE"
),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:bigquery,table1,PRE)"
)
],
session_id=None,
),
False,
id="config_with_different_env",
),
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": 1609459200,
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(
platform="postgres",
query_file="dummy.json",
platform_instance="dev_cluster",
),
QueryEntry(
query="SELECT * FROM table",
timestamp=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
user=CorpUserUrn("test_user"),
operation_type=None,
downstream_tables=[],
upstream_tables=[
DatasetUrn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:postgres,dev_cluster.table1,PROD)"
)
],
session_id=None,
),
False,
id="config_with_platform_instance_only",
),
# Error cases
pytest.param(
{
"query": "SELECT * FROM table",
"timestamp": "invalid-timestamp",
"user": "test_user",
"upstream_tables": ["table1"],
},
SqlQueriesSourceConfig(platform="athena", query_file="dummy.json"),
None,
True,
id="invalid_timestamp_format",
),
],
)
def test_create(self, entry_dict, entry_config, expected_query_entry, should_raise):
"""Test QueryEntry creation with various input formats and edge cases."""
if should_raise:
with pytest.raises(ValueError):
QueryEntry.create(entry_dict, config=entry_config)
return
query_entry = QueryEntry.create(entry_dict, config=entry_config)
# Check all fields match expected
assert query_entry.query == expected_query_entry.query
assert query_entry.timestamp == expected_query_entry.timestamp
assert query_entry.user == expected_query_entry.user
assert query_entry.operation_type == expected_query_entry.operation_type
assert query_entry.downstream_tables == expected_query_entry.downstream_tables
assert query_entry.upstream_tables == expected_query_entry.upstream_tables
assert query_entry.session_id == expected_query_entry.session_id


class TestSqlQueriesSourceConfig:
    """Test configuration parsing for SqlQueriesSourceConfig."""

def test_incremental_lineage_default(self):
"""Test that incremental_lineage defaults to False."""
config_dict = {"query_file": "test.jsonl", "platform": "snowflake"}
config = SqlQueriesSourceConfig.parse_obj(config_dict)
assert config.incremental_lineage is False

    def test_incremental_lineage_enabled(self):
"""Test that incremental_lineage can be enabled."""
config_dict = {
"query_file": "test.jsonl",
"platform": "snowflake",
"incremental_lineage": True,
}
config = SqlQueriesSourceConfig.parse_obj(config_dict)
assert config.incremental_lineage is True

    def test_incremental_lineage_disabled_explicitly(self):
"""Test that incremental_lineage can be explicitly disabled."""
config_dict = {
"query_file": "test.jsonl",
"platform": "snowflake",
"incremental_lineage": False,
}
config = SqlQueriesSourceConfig.parse_obj(config_dict)
assert config.incremental_lineage is False


class TestSqlQueriesSource:
    """Test SqlQueriesSource functionality, including incremental lineage support."""

@pytest.fixture
def mock_graph(self):
"""Create a mock DataHubGraph."""
mock_graph = Mock(spec=DataHubGraph)
        # The source is expected to build a schema resolver from the graph during
        # setup, so stub the call out to avoid any network access.
        mock_graph.initialize_schema_resolver_from_datahub.return_value = None
return mock_graph

    @pytest.fixture
def pipeline_context(self, mock_graph):
"""Create a PipelineContext with mock graph."""
return PipelineContext(run_id="test", graph=mock_graph)

    @pytest.fixture
def temp_query_file(self):
"""Create a temporary query file for testing."""
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
# Write test queries
queries = [
{
"query": "INSERT INTO target_table SELECT * FROM source_table",
"timestamp": 1640995200, # 2022-01-01
"user": "test_user",
"downstream_tables": ["target_table"],
"upstream_tables": ["source_table"],
},
{
"query": "CREATE TABLE output AS SELECT * FROM input1 JOIN input2",
"timestamp": 1641081600, # 2022-01-02
"user": "another_user",
"downstream_tables": ["output"],
"upstream_tables": ["input1", "input2"],
},
]
for query in queries:
f.write(json.dumps(query) + "\n")
query_file_path = f.name
yield query_file_path
        # Cleanup
        os.unlink(query_file_path)

    def test_workunit_generation_structure(self, pipeline_context, temp_query_file):
"""Test that MCPs are generated with proper structure."""
config = SqlQueriesSourceConfig(
query_file=temp_query_file, platform="snowflake", incremental_lineage=True
)
source = SqlQueriesSource(pipeline_context, config)
# Generate MCPs (these will be converted to workunits by the processors)
mcps = list(source.get_workunits_internal())
        # Should generate some MCPs (exact number depends on SQL aggregator behavior).
        # len() is always >= 0, so the meaningful check here is that the iteration
        # above completed without raising.
        assert len(mcps) >= 0
# All items should be MCPs
for mcp in mcps:
# Should be MetadataChangeProposalWrapper objects
assert hasattr(mcp, "aspectName") or hasattr(mcp, "aspect")

    @pytest.mark.parametrize("incremental_lineage", [None, True, False])
def test_workunit_processors_with_incremental_lineage(
self, pipeline_context, temp_query_file, incremental_lineage
):
"""Test workunit processors with different incremental_lineage settings."""
# Handle None case (default behavior) by not passing the parameter
if incremental_lineage is None:
config = SqlQueriesSourceConfig(
query_file=temp_query_file,
platform="snowflake",
# incremental_lineage not specified, should default to False
)
expected_value = False # Default value
else:
config = SqlQueriesSourceConfig(
query_file=temp_query_file,
platform="snowflake",
incremental_lineage=incremental_lineage,
)
expected_value = incremental_lineage
source = SqlQueriesSource(pipeline_context, config)
# Verify config is properly set
assert source.config.incremental_lineage == expected_value
# Verify processors are set up correctly
processors = source.get_workunit_processors()
assert len(processors) == 2
assert all(proc is not None for proc in processors)
# Check that processors are the expected functions
partial_processors = [
cast(partial, proc) for proc in processors if isinstance(proc, partial)
]
processor_funcs = [proc.func for proc in partial_processors]
assert auto_workunit_reporter in processor_funcs
assert auto_incremental_lineage in processor_funcs
# Verify processors reflect the config
auto_incremental_processor = next(
cast(partial, proc)
for proc in processors
if isinstance(proc, partial) and proc.func == auto_incremental_lineage
)
assert auto_incremental_processor.args[0] == expected_value

    def test_backward_compatibility(self, pipeline_context, temp_query_file):
"""Test that existing configurations without incremental_lineage still work."""
# This simulates an existing config that doesn't have incremental_lineage
config_dict = {
"query_file": temp_query_file,
"platform": "snowflake",
"usage": {"bucket_duration": "DAY"},
}
config = SqlQueriesSourceConfig.parse_obj(config_dict)
source = SqlQueriesSource(pipeline_context, config)
# Should default to False
assert source.config.incremental_lineage is False
# Should still work normally
processors = source.get_workunit_processors()
assert len(processors) == 2