mirror of
https://github.com/datahub-project/datahub.git
synced 2025-12-04 22:52:54 +00:00
feat(sql-parsing-aggregator): add override_dialect for observed query (#14201)
This commit is contained in:
parent
63f6432653
commit
e24fc39966
@ -73,9 +73,8 @@ class SqlQueriesSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin):
|
||||
default=None,
|
||||
)
|
||||
override_dialect: Optional[str] = Field(
|
||||
description="DEPRECATED: This field is ignored. SQL dialect detection is now handled automatically by the SQL parsing aggregator based on the platform.",
|
||||
description="The SQL dialect to use when parsing queries. Overrides automatic dialect detection.",
|
||||
default=None,
|
||||
hidden_from_docs=True,
|
||||
)
|
||||
|
||||
|
||||
@ -230,6 +229,7 @@ class SqlQueriesSource(Source):
|
||||
session_id=query_entry.session_id,
|
||||
default_db=self.config.default_db,
|
||||
default_schema=self.config.default_schema,
|
||||
override_dialect=self.config.override_dialect,
|
||||
)
|
||||
self.aggregator.add_observed_query(observed_query)
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ from datahub.sql_parsing.sqlglot_lineage import (
|
||||
sqlglot_lineage,
|
||||
)
|
||||
from datahub.sql_parsing.sqlglot_utils import (
|
||||
DialectOrStr,
|
||||
_parse_statement,
|
||||
get_query_fingerprint,
|
||||
try_format_query,
|
||||
@ -109,6 +110,7 @@ class ObservedQuery:
|
||||
default_schema: Optional[str] = None
|
||||
query_hash: Optional[str] = None
|
||||
usage_multiplier: int = 1
|
||||
override_dialect: Optional[DialectOrStr] = None
|
||||
|
||||
# Use this to store additional key-value information about the query for debugging.
|
||||
extra_info: Optional[dict] = None
|
||||
@ -834,6 +836,7 @@ class SqlParsingAggregator(Closeable):
|
||||
session_id=session_id,
|
||||
timestamp=observed.timestamp,
|
||||
user=observed.user,
|
||||
override_dialect=observed.override_dialect,
|
||||
)
|
||||
if parsed.debug_info.error:
|
||||
self.report.observed_query_parse_failures.append(
|
||||
@ -1168,6 +1171,7 @@ class SqlParsingAggregator(Closeable):
|
||||
session_id: str = _MISSING_SESSION_ID,
|
||||
timestamp: Optional[datetime] = None,
|
||||
user: Optional[Union[CorpUserUrn, CorpGroupUrn]] = None,
|
||||
override_dialect: Optional[DialectOrStr] = None,
|
||||
) -> SqlParsingResult:
|
||||
with self.report.sql_parsing_timer:
|
||||
parsed = sqlglot_lineage(
|
||||
@ -1175,6 +1179,7 @@ class SqlParsingAggregator(Closeable):
|
||||
schema_resolver=schema_resolver,
|
||||
default_db=default_db,
|
||||
default_schema=default_schema,
|
||||
override_dialect=override_dialect,
|
||||
)
|
||||
self.report.num_sql_parsed += 1
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import functools
|
||||
import os
|
||||
import pathlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
@ -1028,6 +1028,48 @@ def test_sql_aggreator_close_cleans_tmp(tmp_path):
|
||||
assert len(os.listdir(tmp_path)) == 0
|
||||
|
||||
|
||||
@freeze_time(FROZEN_TIME)
def test_override_dialect_passed_to_sqlglot_lineage() -> None:
    """Check that an observed query's override_dialect reaches sqlglot_lineage.

    The parser is replaced with a mock, so only the keyword arguments handed
    to it are inspected: first with an explicit dialect, then with the field
    cleared back to None.
    """
    observed = ObservedQuery(
        query="create table foo as select a, b from bar",
        default_db="dev",
        default_schema="public",
    )
    aggregator = SqlParsingAggregator(
        platform="redshift",
        generate_lineage=True,
        generate_usage_statistics=False,
        generate_operations=False,
    )

    patch_target = "datahub.sql_parsing.sql_parsing_aggregator.sqlglot_lineage"
    with patch(patch_target) as lineage_mock:
        lineage_mock.return_value = MagicMock()

        def forwarded_dialect(dialect):
            # Ingest the same query with the requested dialect and report
            # what the (mocked) parser received for override_dialect.
            observed.override_dialect = dialect
            aggregator.add_observed_query(observed)
            lineage_mock.assert_called_once()
            return lineage_mock.call_args.kwargs["override_dialect"]

        # Explicit dialect is handed through unchanged.
        assert forwarded_dialect("snowflake") == "snowflake"

        # Clear the recorded call so assert_called_once() is meaningful again.
        lineage_mock.reset_mock()

        # No override on the query means None is handed through.
        assert forwarded_dialect(None) is None
|
||||
|
||||
|
||||
@freeze_time(FROZEN_TIME)
|
||||
def test_diamond_problem(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
|
||||
aggregator = SqlParsingAggregator(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user