diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py index 03c2e2b2d1..648445dd26 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py @@ -73,9 +73,8 @@ class SqlQueriesSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): default=None, ) override_dialect: Optional[str] = Field( - description="DEPRECATED: This field is ignored. SQL dialect detection is now handled automatically by the SQL parsing aggregator based on the platform.", + description="The SQL dialect to use when parsing queries. Overrides automatic dialect detection.", default=None, - hidden_from_docs=True, ) @@ -230,6 +229,7 @@ class SqlQueriesSource(Source): session_id=query_entry.session_id, default_db=self.config.default_db, default_schema=self.config.default_schema, + override_dialect=self.config.override_dialect, ) self.aggregator.add_observed_query(observed_query) diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index 134d38a03a..0b5c5b52a5 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -49,6 +49,7 @@ from datahub.sql_parsing.sqlglot_lineage import ( sqlglot_lineage, ) from datahub.sql_parsing.sqlglot_utils import ( + DialectOrStr, _parse_statement, get_query_fingerprint, try_format_query, @@ -109,6 +110,7 @@ class ObservedQuery: default_schema: Optional[str] = None query_hash: Optional[str] = None usage_multiplier: int = 1 + override_dialect: Optional[DialectOrStr] = None # Use this to store additional key-value information about the query for debugging. extra_info: Optional[dict] = None @@ -834,6 +836,7 @@ class SqlParsingAggregator(Closeable): session_id=session_id, timestamp=observed.timestamp, user=observed.user, + override_dialect=observed.override_dialect, ) if parsed.debug_info.error: self.report.observed_query_parse_failures.append( @@ -1168,6 +1171,7 @@ class SqlParsingAggregator(Closeable): session_id: str = _MISSING_SESSION_ID, timestamp: Optional[datetime] = None, user: Optional[Union[CorpUserUrn, CorpGroupUrn]] = None, + override_dialect: Optional[DialectOrStr] = None, ) -> SqlParsingResult: with self.report.sql_parsing_timer: parsed = sqlglot_lineage( @@ -1175,6 +1179,7 @@ class SqlParsingAggregator(Closeable): schema_resolver=schema_resolver, default_db=default_db, default_schema=default_schema, + override_dialect=override_dialect, ) self.report.num_sql_parsed += 1 diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py index afd053eced..9882f6aec4 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py @@ -2,7 +2,7 @@ import functools import os import pathlib from datetime import datetime, timedelta, timezone -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from freezegun import freeze_time @@ -1028,6 +1028,48 @@ def test_sql_aggreator_close_cleans_tmp(tmp_path): assert len(os.listdir(tmp_path)) == 0 +@freeze_time(FROZEN_TIME) +def test_override_dialect_passed_to_sqlglot_lineage() -> None: + """Test that override_dialect is correctly passed to sqlglot_lineage""" + aggregator = SqlParsingAggregator( + platform="redshift", + generate_lineage=True, + generate_usage_statistics=False, + generate_operations=False, + ) + base_query = ObservedQuery( + query="create table foo as select a, b from bar", + default_db="dev", + default_schema="public", + ) + + with patch( + "datahub.sql_parsing.sql_parsing_aggregator.sqlglot_lineage" + ) as mock_sqlglot_lineage: + mock_sqlglot_lineage.return_value = MagicMock() + + # Test with override_dialect set + + base_query.override_dialect = "snowflake" + aggregator.add_observed_query(base_query) + + mock_sqlglot_lineage.assert_called_once() + call_args = mock_sqlglot_lineage.call_args + assert call_args.kwargs["override_dialect"] == "snowflake" + + # Reset mock + mock_sqlglot_lineage.reset_mock() + + # Test without override_dialect (should be None) + + base_query.override_dialect = None + aggregator.add_observed_query(base_query) + + mock_sqlglot_lineage.assert_called_once() + call_args = mock_sqlglot_lineage.call_args + assert call_args.kwargs["override_dialect"] is None + + @freeze_time(FROZEN_TIME) def test_diamond_problem(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None: aggregator = SqlParsingAggregator(