fix(ingest/snowflake): Skip sql parsing if all the features disable in config where it is needed (#14908)

This commit is contained in:
Tamas Nemeth 2025-10-03 11:08:26 +02:00 committed by GitHub
parent 9b6ad2263f
commit 78d258383f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 236 additions and 10 deletions

View File

@ -297,15 +297,31 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
if use_cached_audit_log: if use_cached_audit_log:
logger.info(f"Using cached audit log at {audit_log_file}") logger.info(f"Using cached audit log at {audit_log_file}")
else: else:
logger.info(f"Fetching audit log into {audit_log_file}") # Check if any query-based features are enabled before fetching
needs_query_data = any(
[
self.config.include_lineage,
self.config.include_queries,
self.config.include_usage_statistics,
self.config.include_query_usage_statistics,
self.config.include_operations,
]
)
with self.report.copy_history_fetch_timer: if not needs_query_data:
for copy_entry in self.fetch_copy_history(): logger.info(
queries.append(copy_entry) "All query-based features are disabled. Skipping expensive query log fetch."
)
else:
logger.info(f"Fetching audit log into {audit_log_file}")
with self.report.query_log_fetch_timer: with self.report.copy_history_fetch_timer:
for entry in self.fetch_query_log(users): for copy_entry in self.fetch_copy_history():
queries.append(entry) queries.append(copy_entry)
with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log(users):
queries.append(entry)
stored_proc_tracker: StoredProcLineageTracker = self._exit_stack.enter_context( stored_proc_tracker: StoredProcLineageTracker = self._exit_stack.enter_context(
StoredProcLineageTracker( StoredProcLineageTracker(

View File

@ -1,13 +1,24 @@
import datetime import datetime
from unittest.mock import Mock, patch
import pytest import pytest
import sqlglot import sqlglot
from sqlglot.dialects.snowflake import Snowflake from sqlglot.dialects.snowflake import Snowflake
from datahub.configuration.common import AllowDenyPattern from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.time_window_config import BucketDuration from datahub.configuration.time_window_config import (
from datahub.ingestion.source.snowflake.snowflake_config import QueryDedupStrategyType BaseTimeWindowConfig,
from datahub.ingestion.source.snowflake.snowflake_queries import QueryLogQueryBuilder BucketDuration,
)
from datahub.ingestion.source.snowflake.snowflake_config import (
QueryDedupStrategyType,
SnowflakeIdentifierConfig,
)
from datahub.ingestion.source.snowflake.snowflake_queries import (
QueryLogQueryBuilder,
SnowflakeQueriesExtractor,
SnowflakeQueriesExtractorConfig,
)
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
@ -389,3 +400,202 @@ class TestSnowflakeViewQueries:
assert where_clause is not None assert where_clause is not None
where_str = str(where_clause).upper() where_str = str(where_clause).upper()
assert "TABLE_SCHEMA" in where_str and "PUBLIC" in where_str assert "TABLE_SCHEMA" in where_str and "PUBLIC" in where_str
class TestSnowflakeQueriesExtractorOptimization:
"""Tests for the query fetch optimization when all features are disabled."""
def _create_mock_extractor(
self,
include_lineage: bool = False,
include_queries: bool = False,
include_usage_statistics: bool = False,
include_query_usage_statistics: bool = False,
include_operations: bool = False,
) -> SnowflakeQueriesExtractor:
"""Helper to create a SnowflakeQueriesExtractor with mocked dependencies."""
mock_connection = Mock()
mock_connection.query.return_value = []
config = SnowflakeQueriesExtractorConfig(
window=BaseTimeWindowConfig(
start_time=datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
end_time=datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
),
include_lineage=include_lineage,
include_queries=include_queries,
include_usage_statistics=include_usage_statistics,
include_query_usage_statistics=include_query_usage_statistics,
include_operations=include_operations,
)
mock_report = Mock()
mock_filters = Mock()
mock_identifiers = Mock()
mock_identifiers.platform = "snowflake"
mock_identifiers.identifier_config = SnowflakeIdentifierConfig()
extractor = SnowflakeQueriesExtractor(
connection=mock_connection,
config=config,
structured_report=mock_report,
filters=mock_filters,
identifiers=mock_identifiers,
)
return extractor
def test_skip_query_fetch_when_all_features_disabled(self):
"""Test that query fetching is skipped when all query features are disabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=False,
)
# Mock the fetch methods
with (
patch.object(extractor, "fetch_users", return_value={}) as mock_fetch_users,
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
# Execute the method
list(extractor.get_workunits_internal())
# Verify fetch_users was called (always needed for setup)
mock_fetch_users.assert_called_once()
# Verify expensive fetches were NOT called
mock_fetch_copy_history.assert_not_called()
mock_fetch_query_log.assert_not_called()
def test_fetch_queries_when_lineage_enabled(self):
"""Test that query fetching happens when lineage is enabled."""
extractor = self._create_mock_extractor(
include_lineage=True,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=False,
)
with (
patch.object(extractor, "fetch_users", return_value={}) as mock_fetch_users,
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())
mock_fetch_users.assert_called_once()
mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()
def test_fetch_queries_when_usage_statistics_enabled(self):
"""Test that query fetching happens when usage statistics are enabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=True,
include_query_usage_statistics=False,
include_operations=False,
)
with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())
mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()
def test_fetch_queries_when_operations_enabled(self):
"""Test that query fetching happens when operations are enabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=True,
)
with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())
mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()
def test_fetch_queries_when_any_single_feature_enabled(self):
"""Test that query fetching happens when any single feature is enabled."""
features = [
"include_lineage",
"include_queries",
"include_usage_statistics",
"include_query_usage_statistics",
"include_operations",
]
for feature in features:
kwargs = {f: False for f in features}
kwargs[feature] = True
extractor = self._create_mock_extractor(**kwargs)
with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())
# Verify fetches were called
mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()
def test_report_counts_with_disabled_features(self):
"""Test that report counts are zero when features are disabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=False,
)
with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(extractor, "fetch_copy_history", return_value=[]),
patch.object(extractor, "fetch_query_log", return_value=[]),
):
list(extractor.get_workunits_internal())
# Verify that num_preparsed_queries is 0
assert extractor.report.sql_aggregator is not None
assert extractor.report.sql_aggregator.num_preparsed_queries == 0