From a8b07c5fe6dc55eebf44e63b35cd957709c56a26 Mon Sep 17 00:00:00 2001 From: Nadav Gross <33874964+nadavgross@users.noreply.github.com> Date: Tue, 16 Jul 2024 22:28:14 +0300 Subject: [PATCH] feat(ingestion/sqlglot): add optional `default_dialect` parameter to sqlglot lineage (#10830) --- .../src/datahub/ingestion/graph/client.py | 2 ++ .../src/datahub/sql_parsing/sqlglot_lineage.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index 7ba412b3e7..1d6097da23 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -1241,6 +1241,7 @@ class DataHubGraph(DatahubRestEmitter): env: str = DEFAULT_ENV, default_db: Optional[str] = None, default_schema: Optional[str] = None, + default_dialect: Optional[str] = None, ) -> "SqlParsingResult": from datahub.sql_parsing.sqlglot_lineage import sqlglot_lineage @@ -1254,6 +1255,7 @@ class DataHubGraph(DatahubRestEmitter): schema_resolver=schema_resolver, default_db=default_db, default_schema=default_schema, + default_dialect=default_dialect, ) def create_tag(self, tag_name: str) -> str: diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py index 9c2a588a57..976ff8bcc9 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py @@ -843,8 +843,14 @@ def _sqlglot_lineage_inner( schema_resolver: SchemaResolverInterface, default_db: Optional[str] = None, default_schema: Optional[str] = None, + default_dialect: Optional[str] = None, ) -> SqlParsingResult: - dialect = get_dialect(schema_resolver.platform) + + if not default_dialect: + dialect = get_dialect(schema_resolver.platform) + else: + dialect = get_dialect(default_dialect) + if 
is_dialect_instance(dialect, "snowflake"): # in snowflake, table identifiers must be uppercased to match sqlglot's behavior. if default_db: @@ -1003,6 +1009,7 @@ def sqlglot_lineage( schema_resolver: SchemaResolverInterface, default_db: Optional[str] = None, default_schema: Optional[str] = None, + default_dialect: Optional[str] = None, ) -> SqlParsingResult: """Parse a SQL statement and generate lineage information. @@ -1020,8 +1027,9 @@ def sqlglot_lineage( can be brittle with respect to missing schema information and complex SQL logic like UNNESTs. - The SQL dialect is inferred from the schema_resolver's platform. The - set of supported dialects is the same as sqlglot's. See their + The SQL dialect can be given as an argument called default_dialect or it can + be inferred from the schema_resolver's platform. + The set of supported dialects is the same as sqlglot's. See their `documentation <https://sqlglot.com/sqlglot/dialects/dialect.html#Dialects>`_ for the full list. @@ -1035,6 +1043,7 @@ def sqlglot_lineage( schema_resolver: The schema resolver to use for resolving table schemas. default_db: The default database to use for unqualified table names. default_schema: The default schema to use for unqualified table names. + default_dialect: A default dialect to override the dialect provided by 'schema_resolver'. Returns: A SqlParsingResult object containing the parsed lineage information. @@ -1059,6 +1068,7 @@ def sqlglot_lineage( schema_resolver=schema_resolver, default_db=default_db, default_schema=default_schema, + default_dialect=default_dialect, ) except Exception as e: return SqlParsingResult.make_from_error(e)