feat(sql-parsing-aggregator): add override_dialect for observed query (#14201)

This commit is contained in:
Sergio Gómez Villamor 2025-07-25 08:21:44 +02:00 committed by GitHub
parent 63f6432653
commit e24fc39966
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 3 deletions

View File

@ -73,9 +73,8 @@ class SqlQueriesSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin):
default=None,
)
override_dialect: Optional[str] = Field(
description="DEPRECATED: This field is ignored. SQL dialect detection is now handled automatically by the SQL parsing aggregator based on the platform.",
description="The SQL dialect to use when parsing queries. Overrides automatic dialect detection.",
default=None,
hidden_from_docs=True,
)
@ -230,6 +229,7 @@ class SqlQueriesSource(Source):
session_id=query_entry.session_id,
default_db=self.config.default_db,
default_schema=self.config.default_schema,
override_dialect=self.config.override_dialect,
)
self.aggregator.add_observed_query(observed_query)

View File

@ -49,6 +49,7 @@ from datahub.sql_parsing.sqlglot_lineage import (
sqlglot_lineage,
)
from datahub.sql_parsing.sqlglot_utils import (
DialectOrStr,
_parse_statement,
get_query_fingerprint,
try_format_query,
@ -109,6 +110,7 @@ class ObservedQuery:
default_schema: Optional[str] = None
query_hash: Optional[str] = None
usage_multiplier: int = 1
override_dialect: Optional[DialectOrStr] = None
# Use this to store additional key-value information about the query for debugging.
extra_info: Optional[dict] = None
@ -834,6 +836,7 @@ class SqlParsingAggregator(Closeable):
session_id=session_id,
timestamp=observed.timestamp,
user=observed.user,
override_dialect=observed.override_dialect,
)
if parsed.debug_info.error:
self.report.observed_query_parse_failures.append(
@ -1168,6 +1171,7 @@ class SqlParsingAggregator(Closeable):
session_id: str = _MISSING_SESSION_ID,
timestamp: Optional[datetime] = None,
user: Optional[Union[CorpUserUrn, CorpGroupUrn]] = None,
override_dialect: Optional[DialectOrStr] = None,
) -> SqlParsingResult:
with self.report.sql_parsing_timer:
parsed = sqlglot_lineage(
@ -1175,6 +1179,7 @@ class SqlParsingAggregator(Closeable):
schema_resolver=schema_resolver,
default_db=default_db,
default_schema=default_schema,
override_dialect=override_dialect,
)
self.report.num_sql_parsed += 1

View File

@ -2,7 +2,7 @@ import functools
import os
import pathlib
from datetime import datetime, timedelta, timezone
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from freezegun import freeze_time
@ -1028,6 +1028,48 @@ def test_sql_aggreator_close_cleans_tmp(tmp_path):
assert len(os.listdir(tmp_path)) == 0
@freeze_time(FROZEN_TIME)
def test_override_dialect_passed_to_sqlglot_lineage() -> None:
    """Check that an ObservedQuery's override_dialect reaches sqlglot_lineage.

    The aggregator's reference to ``sqlglot_lineage`` is replaced with a mock,
    so no real SQL parsing happens; only the keyword arguments of the call
    are inspected. Two cases are exercised on the same query object: an
    explicit dialect string, and ``None``.
    """
    aggregator = SqlParsingAggregator(
        platform="redshift",
        generate_lineage=True,
        generate_usage_statistics=False,
        generate_operations=False,
    )

    observed = ObservedQuery(
        query="create table foo as select a, b from bar",
        default_db="dev",
        default_schema="public",
    )

    # Patch the name as looked up from the aggregator module, not the
    # module where sqlglot_lineage is defined.
    lineage_target = "datahub.sql_parsing.sql_parsing_aggregator.sqlglot_lineage"
    with patch(lineage_target) as lineage_mock:
        lineage_mock.return_value = MagicMock()

        # Case 1: an explicit dialect must be forwarded unchanged.
        observed.override_dialect = "snowflake"
        aggregator.add_observed_query(observed)
        lineage_mock.assert_called_once()
        assert lineage_mock.call_args.kwargs["override_dialect"] == "snowflake"

        # Clear the recorded calls so the second case is counted from zero.
        lineage_mock.reset_mock()

        # Case 2: with no override, the keyword is still passed, as None.
        observed.override_dialect = None
        aggregator.add_observed_query(observed)
        lineage_mock.assert_called_once()
        assert lineage_mock.call_args.kwargs["override_dialect"] is None
@freeze_time(FROZEN_TIME)
def test_diamond_problem(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
aggregator = SqlParsingAggregator(