fix(ingest/snowflake): Skip SQL parsing if all the features that need it are disabled in config (#14908)

This commit is contained in:
Tamas Nemeth 2025-10-03 11:08:26 +02:00 committed by GitHub
parent 9b6ad2263f
commit 78d258383f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 236 additions and 10 deletions

View File

@ -297,15 +297,31 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
if use_cached_audit_log:
logger.info(f"Using cached audit log at {audit_log_file}")
else:
logger.info(f"Fetching audit log into {audit_log_file}")
# Check if any query-based features are enabled before fetching
needs_query_data = any(
[
self.config.include_lineage,
self.config.include_queries,
self.config.include_usage_statistics,
self.config.include_query_usage_statistics,
self.config.include_operations,
]
)
with self.report.copy_history_fetch_timer:
for copy_entry in self.fetch_copy_history():
queries.append(copy_entry)
if not needs_query_data:
logger.info(
"All query-based features are disabled. Skipping expensive query log fetch."
)
else:
logger.info(f"Fetching audit log into {audit_log_file}")
with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log(users):
queries.append(entry)
with self.report.copy_history_fetch_timer:
for copy_entry in self.fetch_copy_history():
queries.append(copy_entry)
with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log(users):
queries.append(entry)
stored_proc_tracker: StoredProcLineageTracker = self._exit_stack.enter_context(
StoredProcLineageTracker(

View File

@ -1,13 +1,24 @@
import datetime
from unittest.mock import Mock, patch
import pytest
import sqlglot
from sqlglot.dialects.snowflake import Snowflake
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.time_window_config import BucketDuration
from datahub.ingestion.source.snowflake.snowflake_config import QueryDedupStrategyType
from datahub.ingestion.source.snowflake.snowflake_queries import QueryLogQueryBuilder
from datahub.configuration.time_window_config import (
BaseTimeWindowConfig,
BucketDuration,
)
from datahub.ingestion.source.snowflake.snowflake_config import (
QueryDedupStrategyType,
SnowflakeIdentifierConfig,
)
from datahub.ingestion.source.snowflake.snowflake_queries import (
QueryLogQueryBuilder,
SnowflakeQueriesExtractor,
SnowflakeQueriesExtractorConfig,
)
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
@ -389,3 +400,202 @@ class TestSnowflakeViewQueries:
assert where_clause is not None
where_str = str(where_clause).upper()
assert "TABLE_SCHEMA" in where_str and "PUBLIC" in where_str
class TestSnowflakeQueriesExtractorOptimization:
    """Tests for the query fetch optimization when all features are disabled.

    Each test builds a ``SnowflakeQueriesExtractor`` around a mocked
    connection, replaces its three fetch methods with mocks, drains
    ``get_workunits_internal()``, and then checks which fetch methods were
    (or were not) invoked for a given combination of ``include_*`` flags.
    """

    # Config flags exercised by these tests; enabling any one of them is
    # expected to trigger the copy-history and query-log fetches.
    _QUERY_FEATURES = (
        "include_lineage",
        "include_queries",
        "include_usage_statistics",
        "include_query_usage_statistics",
        "include_operations",
    )

    def _create_mock_extractor(
        self,
        include_lineage: bool = False,
        include_queries: bool = False,
        include_usage_statistics: bool = False,
        include_query_usage_statistics: bool = False,
        include_operations: bool = False,
    ) -> SnowflakeQueriesExtractor:
        """Build a SnowflakeQueriesExtractor with mocked dependencies.

        The connection, structured report, filters and identifiers are all
        ``Mock`` objects; the connection's ``query`` returns an empty list.
        The keyword arguments are forwarded unchanged to the same-named
        fields of ``SnowflakeQueriesExtractorConfig``, with a fixed one-day
        UTC time window (2021-01-01 to 2021-01-02).
        """
        mock_connection = Mock()
        mock_connection.query.return_value = []

        config = SnowflakeQueriesExtractorConfig(
            window=BaseTimeWindowConfig(
                start_time=datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
                end_time=datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
            ),
            include_lineage=include_lineage,
            include_queries=include_queries,
            include_usage_statistics=include_usage_statistics,
            include_query_usage_statistics=include_query_usage_statistics,
            include_operations=include_operations,
        )

        mock_report = Mock()
        mock_filters = Mock()
        mock_identifiers = Mock()
        mock_identifiers.platform = "snowflake"
        mock_identifiers.identifier_config = SnowflakeIdentifierConfig()

        return SnowflakeQueriesExtractor(
            connection=mock_connection,
            config=config,
            structured_report=mock_report,
            filters=mock_filters,
            identifiers=mock_identifiers,
        )

    @staticmethod
    def _run_with_patched_fetchers(
        extractor: SnowflakeQueriesExtractor,
    ) -> "tuple[Mock, Mock, Mock]":
        """Drain ``get_workunits_internal()`` with the fetch methods mocked.

        ``fetch_users`` is patched to return ``{}`` and ``fetch_copy_history``
        / ``fetch_query_log`` to return ``[]``. The extractor is closed once
        the run finishes (even on failure) so the resources it registered are
        released instead of leaking between tests.

        Returns:
            The mocks for ``fetch_users``, ``fetch_copy_history`` and
            ``fetch_query_log``, in that order. They stay inspectable after
            the patches have been undone.
        """
        with (
            patch.object(
                extractor, "fetch_users", return_value={}
            ) as mock_fetch_users,
            patch.object(
                extractor, "fetch_copy_history", return_value=[]
            ) as mock_fetch_copy_history,
            patch.object(
                extractor, "fetch_query_log", return_value=[]
            ) as mock_fetch_query_log,
        ):
            try:
                # The method is a generator; list() forces it to run to the end.
                list(extractor.get_workunits_internal())
            finally:
                extractor.close()

        return mock_fetch_users, mock_fetch_copy_history, mock_fetch_query_log

    def test_skip_query_fetch_when_all_features_disabled(self):
        """Test that query fetching is skipped when all query features are disabled."""
        extractor = self._create_mock_extractor(
            include_lineage=False,
            include_queries=False,
            include_usage_statistics=False,
            include_query_usage_statistics=False,
            include_operations=False,
        )

        (
            mock_fetch_users,
            mock_fetch_copy_history,
            mock_fetch_query_log,
        ) = self._run_with_patched_fetchers(extractor)

        # Verify fetch_users was called (always needed for setup)
        mock_fetch_users.assert_called_once()

        # Verify expensive fetches were NOT called
        mock_fetch_copy_history.assert_not_called()
        mock_fetch_query_log.assert_not_called()

    def test_fetch_queries_when_lineage_enabled(self):
        """Test that query fetching happens when lineage is enabled."""
        extractor = self._create_mock_extractor(
            include_lineage=True,
            include_queries=False,
            include_usage_statistics=False,
            include_query_usage_statistics=False,
            include_operations=False,
        )

        (
            mock_fetch_users,
            mock_fetch_copy_history,
            mock_fetch_query_log,
        ) = self._run_with_patched_fetchers(extractor)

        mock_fetch_users.assert_called_once()
        mock_fetch_copy_history.assert_called_once()
        mock_fetch_query_log.assert_called_once()

    def test_fetch_queries_when_usage_statistics_enabled(self):
        """Test that query fetching happens when usage statistics are enabled."""
        extractor = self._create_mock_extractor(
            include_lineage=False,
            include_queries=False,
            include_usage_statistics=True,
            include_query_usage_statistics=False,
            include_operations=False,
        )

        (
            _,
            mock_fetch_copy_history,
            mock_fetch_query_log,
        ) = self._run_with_patched_fetchers(extractor)

        mock_fetch_copy_history.assert_called_once()
        mock_fetch_query_log.assert_called_once()

    def test_fetch_queries_when_operations_enabled(self):
        """Test that query fetching happens when operations are enabled."""
        extractor = self._create_mock_extractor(
            include_lineage=False,
            include_queries=False,
            include_usage_statistics=False,
            include_query_usage_statistics=False,
            include_operations=True,
        )

        (
            _,
            mock_fetch_copy_history,
            mock_fetch_query_log,
        ) = self._run_with_patched_fetchers(extractor)

        mock_fetch_copy_history.assert_called_once()
        mock_fetch_query_log.assert_called_once()

    @pytest.mark.parametrize("feature", _QUERY_FEATURES)
    def test_fetch_queries_when_any_single_feature_enabled(self, feature: str):
        """Test that query fetching happens when any single feature is enabled.

        Parametrized so each flag is reported as its own test case: one
        failing flag neither hides the others nor leaves the failure
        anonymous.
        """
        # Start from "everything off", then switch on only the flag under test.
        kwargs = dict.fromkeys(self._QUERY_FEATURES, False)
        kwargs[feature] = True
        extractor = self._create_mock_extractor(**kwargs)

        (
            _,
            mock_fetch_copy_history,
            mock_fetch_query_log,
        ) = self._run_with_patched_fetchers(extractor)

        # Verify fetches were called
        mock_fetch_copy_history.assert_called_once()
        mock_fetch_query_log.assert_called_once()

    def test_report_counts_with_disabled_features(self):
        """Test that report counts are zero when features are disabled."""
        extractor = self._create_mock_extractor(
            include_lineage=False,
            include_queries=False,
            include_usage_statistics=False,
            include_query_usage_statistics=False,
            include_operations=False,
        )

        self._run_with_patched_fetchers(extractor)

        # Verify that num_preparsed_queries is 0
        assert extractor.report.sql_aggregator is not None
        assert extractor.report.sql_aggregator.num_preparsed_queries == 0