mirror of https://github.com/datahub-project/datahub.git
feat(ingest): support CLL for redshift materialized views with auto refresh (#9508)
parent 4fe1df6892
commit 52687f3eea
@@ -98,7 +98,7 @@ usage_common = {
 sqlglot_lib = {
     # Using an Acryl fork of sqlglot.
     # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
-    "acryl-sqlglot==19.0.2.dev10",
+    "acryl-sqlglot==20.4.1.dev14",
 }

 sql_common = (
@@ -5,7 +5,7 @@ import itertools
 import logging
 import pathlib
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

 import pydantic.dataclasses
 import sqlglot
@@ -60,6 +60,8 @@ RULES_BEFORE_TYPE_ANNOTATION: tuple = tuple(
         ),
     )
 )
+# Quick check that the rules were loaded correctly.
+assert 0 < len(RULES_BEFORE_TYPE_ANNOTATION) < len(sqlglot.optimizer.optimizer.RULES)


 class GraphQLSchemaField(TypedDict):
@@ -150,12 +152,16 @@ class _TableName(_FrozenModel):

     def as_sqlglot_table(self) -> sqlglot.exp.Table:
         return sqlglot.exp.Table(
-            catalog=self.database, db=self.db_schema, this=self.table
+            catalog=sqlglot.exp.Identifier(this=self.database)
+            if self.database
+            else None,
+            db=sqlglot.exp.Identifier(this=self.db_schema) if self.db_schema else None,
+            this=sqlglot.exp.Identifier(this=self.table),
         )

     def qualified(
         self,
-        dialect: str,
+        dialect: sqlglot.Dialect,
         default_db: Optional[str] = None,
         default_schema: Optional[str] = None,
     ) -> "_TableName":
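As an aside, a minimal standalone sketch (not part of this diff; the names are made up) of what the Identifier wrapping buys: sqlglot gets a well-formed Table expression to render and qualify, instead of having to coerce raw strings.

import sqlglot

# Hypothetical values; in the real code these come from the _TableName fields.
table_expr = sqlglot.exp.Table(
    catalog=sqlglot.exp.Identifier(this="my_db"),
    db=sqlglot.exp.Identifier(this="my_schema"),
    this=sqlglot.exp.Identifier(this="my_table"),
)
# Renders roughly as: my_db.my_schema.my_table
print(table_expr.sql(dialect="redshift"))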
@@ -271,7 +277,9 @@ class SqlParsingResult(_ParserBaseModel):
     )


-def _parse_statement(sql: sqlglot.exp.ExpOrStr, dialect: str) -> sqlglot.Expression:
+def _parse_statement(
+    sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
+) -> sqlglot.Expression:
     statement: sqlglot.Expression = sqlglot.maybe_parse(
         sql, dialect=dialect, error_level=sqlglot.ErrorLevel.RAISE
     )
@@ -279,8 +287,7 @@ def _parse_statement(sql: sqlglot.exp.ExpOrStr, dialect: str) -> sqlglot.Expression:


 def _table_level_lineage(
-    statement: sqlglot.Expression,
-    dialect: str,
+    statement: sqlglot.Expression, dialect: sqlglot.Dialect
 ) -> Tuple[Set[_TableName], Set[_TableName]]:
     # Generate table-level lineage.
     modified = {
@@ -482,6 +489,26 @@ _SupportedColumnLineageTypes = Union[
 ]
 _SupportedColumnLineageTypesTuple = (sqlglot.exp.Subqueryable, sqlglot.exp.DerivedTable)

+DIALECTS_WITH_CASE_INSENSITIVE_COLS = {
+    # Column identifiers are case-insensitive in BigQuery, so we need to
+    # do a normalization step beforehand to make sure it's resolved correctly.
+    "bigquery",
+    # Our snowflake source lowercases column identifiers, so we are forced
+    # to do fuzzy (case-insensitive) resolution instead of exact resolution.
+    "snowflake",
+    # Teradata column names are case-insensitive.
+    # A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same.
+    # See more below:
+    # https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
+    "teradata",
+}
+DIALECTS_WITH_DEFAULT_UPPERCASE_COLS = {
+    # In some dialects, column identifiers are effectively case insensitive
+    # because they are automatically converted to uppercase. Most other systems
+    # automatically lowercase unquoted identifiers.
+    "snowflake",
+}
+

 class UnsupportedStatementTypeError(TypeError):
     pass
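Roughly speaking, these sets feed the column-name normalization further down in _column_level_lineage. A simplified standalone sketch of the rule they encode (illustrative only, not the actual helper, which checks Dialect instances via _is_dialect_instance):

def normalize_col(col: str, platform: str) -> str:
    # Illustrative restatement of the rule encoded by the two sets above.
    if platform not in {"bigquery", "snowflake", "teradata"}:
        return col  # case-sensitive dialects: leave the name untouched
    # Snowflake folds unquoted identifiers to uppercase; the others are lowercased.
    return col.upper() if platform == "snowflake" else col.lower()

assert normalize_col("CustId", "snowflake") == "CUSTID"
assert normalize_col("CustId", "bigquery") == "custid"
assert normalize_col("CustId", "postgres") == "CustId"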
@@ -495,8 +522,8 @@ class SqlUnderstandingError(Exception):
 # TODO: Break this up into smaller functions.
 def _column_level_lineage(  # noqa: C901
     statement: sqlglot.exp.Expression,
-    dialect: str,
-    input_tables: Dict[_TableName, SchemaInfo],
+    dialect: sqlglot.Dialect,
+    table_schemas: Dict[_TableName, SchemaInfo],
     output_table: Optional[_TableName],
     default_db: Optional[str],
     default_schema: Optional[str],
@@ -515,19 +542,9 @@ def _column_level_lineage(  # noqa: C901

     column_lineage: List[_ColumnLineageInfo] = []

-    use_case_insensitive_cols = dialect in {
-        # Column identifiers are case-insensitive in BigQuery, so we need to
-        # do a normalization step beforehand to make sure it's resolved correctly.
-        "bigquery",
-        # Our snowflake source lowercases column identifiers, so we are forced
-        # to do fuzzy (case-insensitive) resolution instead of exact resolution.
-        "snowflake",
-        # Teradata column names are case-insensitive.
-        # A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same.
-        # See more below:
-        # https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
-        "teradata",
-    }
+    use_case_insensitive_cols = _is_dialect_instance(
+        dialect, DIALECTS_WITH_CASE_INSENSITIVE_COLS
+    )

     sqlglot_db_schema = sqlglot.MappingSchema(
         dialect=dialect,
@@ -537,14 +554,16 @@ def _column_level_lineage(  # noqa: C901
     table_schema_normalized_mapping: Dict[_TableName, Dict[str, str]] = defaultdict(
         dict
     )
-    for table, table_schema in input_tables.items():
+    for table, table_schema in table_schemas.items():
         normalized_table_schema: SchemaInfo = {}
         for col, col_type in table_schema.items():
             if use_case_insensitive_cols:
                 col_normalized = (
                     # This is required to match Sqlglot's behavior.
                     col.upper()
-                    if dialect in {"snowflake"}
+                    if _is_dialect_instance(
+                        dialect, DIALECTS_WITH_DEFAULT_UPPERCASE_COLS
+                    )
                     else col.lower()
                 )
             else:
@@ -561,7 +580,7 @@ def _column_level_lineage(  # noqa: C901
     if use_case_insensitive_cols:

         def _sqlglot_force_column_normalizer(
-            node: sqlglot.exp.Expression, dialect: "sqlglot.DialectType" = None
+            node: sqlglot.exp.Expression,
         ) -> sqlglot.exp.Expression:
             if isinstance(node, sqlglot.exp.Column):
                 node.this.set("quoted", False)
@@ -572,9 +591,7 @@ def _column_level_lineage(  # noqa: C901
         #     "Prior to case normalization sql %s",
         #     statement.sql(pretty=True, dialect=dialect),
         # )
-        statement = statement.transform(
-            _sqlglot_force_column_normalizer, dialect, copy=False
-        )
+        statement = statement.transform(_sqlglot_force_column_normalizer, copy=False)
         # logger.debug(
         #     "Sql after casing normalization %s",
         #     statement.sql(pretty=True, dialect=dialect),
@@ -595,7 +612,8 @@ def _column_level_lineage(  # noqa: C901

     # Optimize the statement + qualify column references.
     logger.debug(
-        "Prior to qualification sql %s", statement.sql(pretty=True, dialect=dialect)
+        "Prior to column qualification sql %s",
+        statement.sql(pretty=True, dialect=dialect),
     )
     try:
         # Second time running qualify, this time with:
@@ -678,7 +696,7 @@ def _column_level_lineage(  # noqa: C901
             # Otherwise, we can't process it.
             continue

-        if dialect == "bigquery" and output_col.lower() in {
+        if _is_dialect_instance(dialect, "bigquery") and output_col.lower() in {
             "_partitiontime",
             "_partitiondate",
         }:
@@ -923,7 +941,7 @@ def _translate_sqlglot_type(
 def _translate_internal_column_lineage(
     table_name_urn_mapping: Dict[_TableName, str],
     raw_column_lineage: _ColumnLineageInfo,
-    dialect: str,
+    dialect: sqlglot.Dialect,
 ) -> ColumnLineageInfo:
     downstream_urn = None
     if raw_column_lineage.downstream.table:
@@ -956,18 +974,44 @@ def _translate_internal_column_lineage(
     )


-def _get_dialect(platform: str) -> str:
+def _get_dialect_str(platform: str) -> str:
     # TODO: convert datahub platform names to sqlglot dialect
     if platform == "presto-on-hive":
         return "hive"
-    if platform == "mssql":
+    elif platform == "mssql":
         return "tsql"
-    if platform == "athena":
+    elif platform == "athena":
         return "trino"
+    elif platform == "mysql":
+        # In sqlglot v20+, MySQL is now case-sensitive by default, which is the
+        # default behavior on Linux. However, MySQL's default case sensitivity
+        # actually depends on the underlying OS.
+        # For us, it's simpler to just assume that it's case-insensitive, and
+        # let the fuzzy resolution logic handle it.
+        return "mysql, normalization_strategy = lowercase"
     else:
         return platform


+def _get_dialect(platform: str) -> sqlglot.Dialect:
+    return sqlglot.Dialect.get_or_raise(_get_dialect_str(platform))
+
+
+def _is_dialect_instance(
+    dialect: sqlglot.Dialect, platforms: Union[str, Iterable[str]]
+) -> bool:
+    if isinstance(platforms, str):
+        platforms = [platforms]
+    else:
+        platforms = list(platforms)
+
+    dialects = [sqlglot.Dialect.get_or_raise(platform) for platform in platforms]
+
+    if any(isinstance(dialect, dialect_class.__class__) for dialect_class in dialects):
+        return True
+    return False
+
+
 def _sqlglot_lineage_inner(
     sql: sqlglot.exp.ExpOrStr,
     schema_resolver: SchemaResolver,
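For reference, a small sketch of how these helpers are meant to be used (illustrative only; it assumes the sqlglot 20.x behavior relied on in this commit, where Dialect.get_or_raise returns a dialect instance and accepts inline settings such as the normalization_strategy string above):

import sqlglot

mysql = sqlglot.Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
snowflake = sqlglot.Dialect.get_or_raise("snowflake")

# The old string comparisons (dialect == "snowflake") no longer apply to Dialect
# objects, so dialect-family membership is checked with isinstance instead.
assert isinstance(snowflake, sqlglot.Dialect)
assert not isinstance(mysql, snowflake.__class__)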
@@ -975,7 +1019,7 @@ def _sqlglot_lineage_inner(
     default_schema: Optional[str] = None,
 ) -> SqlParsingResult:
     dialect = _get_dialect(schema_resolver.platform)
-    if dialect == "snowflake":
+    if _is_dialect_instance(dialect, "snowflake"):
         # in snowflake, table identifiers must be uppercased to match sqlglot's behavior.
         if default_db:
             default_db = default_db.upper()
@@ -1064,7 +1108,7 @@ def _sqlglot_lineage_inner(
         column_lineage = _column_level_lineage(
             select_statement,
             dialect=dialect,
-            input_tables=table_name_schema_mapping,
+            table_schemas=table_name_schema_mapping,
             output_table=downstream_table,
             default_db=default_db,
             default_schema=default_schema,
@@ -1204,13 +1248,13 @@ def detach_ctes(
            full_new_name, dialect=dialect, into=sqlglot.exp.Table
        )

-       # We expect node.parent to be a Table or Column.
-       # Either way, it should support catalog/db/name.
        parent = node.parent

-       if "catalog" in parent.arg_types:
+       # We expect node.parent to be a Table or Column, both of which support catalog/db/name.
+       # However, we check the parent's arg_types to be safe.
+       if "catalog" in parent.arg_types and table_expr.catalog:
            parent.set("catalog", table_expr.catalog)
-       if "db" in parent.arg_types:
+       if "db" in parent.arg_types and table_expr.db:
            parent.set("db", table_expr.db)

        new_node = sqlglot.exp.Identifier(this=table_expr.name)
metadata-ingestion/tests/unit/sql_parsing/goldens/test_redshift_materialized_view_auto_refresh.json (new file)
@@ -0,0 +1,54 @@
+{
+  "query_type": "CREATE",
+  "in_tables": [
+    "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
+    "urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)"
+  ],
+  "out_tables": [
+    "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)"
+  ],
+  "column_lineage": [
+    {
+      "downstream": {
+        "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
+        "column": "cust_id",
+        "column_type": null,
+        "native_column_type": null
+      },
+      "upstreams": [
+        {
+          "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
+          "column": "cust_id"
+        }
+      ]
+    },
+    {
+      "downstream": {
+        "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
+        "column": "first_name",
+        "column_type": null,
+        "native_column_type": null
+      },
+      "upstreams": [
+        {
+          "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
+          "column": "first_name"
+        }
+      ]
+    },
+    {
+      "downstream": {
+        "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
+        "column": "total_amount",
+        "column_type": null,
+        "native_column_type": null
+      },
+      "upstreams": [
+        {
+          "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)",
+          "column": "amount"
+        }
+      ]
+    }
+  ]
+}
metadata-ingestion/tests/unit/sql_parsing/test_sql_detach.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+from datahub.utilities.sqlglot_lineage import detach_ctes
+
+
+def test_detach_ctes_simple():
+    original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
+    detached_expr = detach_ctes(
+        original,
+        platform="snowflake",
+        cte_mapping={"__cte_0": "_my_cte_table"},
+    )
+    detached = detached_expr.sql(dialect="snowflake")
+
+    assert (
+        detached
+        == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table ON table2.id = _my_cte_table.id"
+    )
+
+
+def test_detach_ctes_with_alias():
+    original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 AS tablealias ON table2.id = tablealias.id"
+    detached_expr = detach_ctes(
+        original,
+        platform="snowflake",
+        cte_mapping={"__cte_0": "_my_cte_table"},
+    )
+    detached = detached_expr.sql(dialect="snowflake")
+
+    assert (
+        detached
+        == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table AS tablealias ON table2.id = tablealias.id"
+    )
+
+
+def test_detach_ctes_with_multipart_replacement():
+    original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
+    detached_expr = detach_ctes(
+        original,
+        platform="snowflake",
+        cte_mapping={"__cte_0": "my_db.my_schema.my_table"},
+    )
+    detached = detached_expr.sql(dialect="snowflake")
+
+    assert (
+        detached
+        == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN my_db.my_schema.my_table ON table2.id = my_db.my_schema.my_table.id"
+    )
@@ -3,59 +3,11 @@ import pathlib
 import pytest

 from datahub.testing.check_sql_parser_result import assert_sql_result
-from datahub.utilities.sqlglot_lineage import (
-    _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT,
-    detach_ctes,
-)
+from datahub.utilities.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT

 RESOURCE_DIR = pathlib.Path(__file__).parent / "goldens"


-def test_detach_ctes_simple():
-    original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
-    detached_expr = detach_ctes(
-        original,
-        platform="snowflake",
-        cte_mapping={"__cte_0": "_my_cte_table"},
-    )
-    detached = detached_expr.sql(dialect="snowflake")
-
-    assert (
-        detached
-        == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table ON table2.id = _my_cte_table.id"
-    )
-
-
-def test_detach_ctes_with_alias():
-    original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 AS tablealias ON table2.id = tablealias.id"
-    detached_expr = detach_ctes(
-        original,
-        platform="snowflake",
-        cte_mapping={"__cte_0": "_my_cte_table"},
-    )
-    detached = detached_expr.sql(dialect="snowflake")
-
-    assert (
-        detached
-        == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table AS tablealias ON table2.id = tablealias.id"
-    )
-
-
-def test_detach_ctes_with_multipart_replacement():
-    original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
-    detached_expr = detach_ctes(
-        original,
-        platform="snowflake",
-        cte_mapping={"__cte_0": "my_db.my_schema.my_table"},
-    )
-    detached = detached_expr.sql(dialect="snowflake")
-
-    assert (
-        detached
-        == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN my_db.my_schema.my_table ON table2.id = my_db.my_schema.my_table.id"
-    )
-
-
 def test_select_max():
     # The COL2 should get normalized to col2.
     assert_sql_result(
@@ -1023,3 +975,25 @@ UPDATE accounts SET (contact_first_name, contact_last_name) =
         },
         expected_file=RESOURCE_DIR / "test_postgres_complex_update.json",
     )
+
+
+def test_redshift_materialized_view_auto_refresh():
+    # Example query from the redshift docs: https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html
+    assert_sql_result(
+        """
+CREATE MATERIALIZED VIEW mv_total_orders
+AUTO REFRESH YES -- Add this clause to auto refresh the MV
+AS
+ SELECT c.cust_id,
+        c.first_name,
+        sum(o.amount) as total_amount
+ FROM orders o
+ JOIN customer c
+ ON c.cust_id = o.customer_id
+ GROUP BY c.cust_id,
+          c.first_name;
+""",
+        dialect="redshift",
+        expected_file=RESOURCE_DIR
+        / "test_redshift_materialized_view_auto_refresh.json",
+    )