feat(ingest): support CLL for redshift materialized views with auto refresh (#9508)

Harshal Sheth 2023-12-22 02:18:22 -05:00 committed by GitHub
parent 4fe1df6892
commit 52687f3eea
5 changed files with 207 additions and 89 deletions


@@ -98,7 +98,7 @@ usage_common = {
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
"acryl-sqlglot==19.0.2.dev10",
"acryl-sqlglot==20.4.1.dev14",
}
sql_common = (


@@ -5,7 +5,7 @@ import itertools
import logging
import pathlib
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import pydantic.dataclasses
import sqlglot
@@ -60,6 +60,8 @@ RULES_BEFORE_TYPE_ANNOTATION: tuple = tuple(
),
)
)
# Quick check that the rules were loaded correctly.
assert 0 < len(RULES_BEFORE_TYPE_ANNOTATION) < len(sqlglot.optimizer.optimizer.RULES)
class GraphQLSchemaField(TypedDict):
@@ -150,12 +152,16 @@ class _TableName(_FrozenModel):
def as_sqlglot_table(self) -> sqlglot.exp.Table:
return sqlglot.exp.Table(
catalog=self.database, db=self.db_schema, this=self.table
catalog=sqlglot.exp.Identifier(this=self.database)
if self.database
else None,
db=sqlglot.exp.Identifier(this=self.db_schema) if self.db_schema else None,
this=sqlglot.exp.Identifier(this=self.table),
)
def qualified(
self,
dialect: str,
dialect: sqlglot.Dialect,
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> "_TableName":
@@ -271,7 +277,9 @@ class SqlParsingResult(_ParserBaseModel):
)
def _parse_statement(sql: sqlglot.exp.ExpOrStr, dialect: str) -> sqlglot.Expression:
def _parse_statement(
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
) -> sqlglot.Expression:
statement: sqlglot.Expression = sqlglot.maybe_parse(
sql, dialect=dialect, error_level=sqlglot.ErrorLevel.RAISE
)
@@ -279,8 +287,7 @@ def _parse_statement(sql: sqlglot.exp.ExpOrStr, dialect: str) -> sqlglot.Expression:
def _table_level_lineage(
statement: sqlglot.Expression,
dialect: str,
statement: sqlglot.Expression, dialect: sqlglot.Dialect
) -> Tuple[Set[_TableName], Set[_TableName]]:
# Generate table-level lineage.
modified = {
@@ -482,6 +489,26 @@ _SupportedColumnLineageTypes = Union[
]
_SupportedColumnLineageTypesTuple = (sqlglot.exp.Subqueryable, sqlglot.exp.DerivedTable)
DIALECTS_WITH_CASE_INSENSITIVE_COLS = {
# Column identifiers are case-insensitive in BigQuery, so we need to
# do a normalization step beforehand to make sure it's resolved correctly.
"bigquery",
# Our snowflake source lowercases column identifiers, so we are forced
# to do fuzzy (case-insensitive) resolution instead of exact resolution.
"snowflake",
# Teradata column names are case-insensitive.
# A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same.
# See more below:
# https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
"teradata",
}
DIALECTS_WITH_DEFAULT_UPPERCASE_COLS = {
# In some dialects, column identifiers are effectively case insensitive
# because they are automatically converted to uppercase. Most other systems
# automatically lowercase unquoted identifiers.
"snowflake",
}
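For context on why snowflake appears in both sets (a hedged sketch of sqlglot behavior, not part of this diff): sqlglot's identifier normalization uppercases unquoted identifiers for Snowflake, while most other dialects lowercase them, which is what the fuzzy column resolution in _column_level_lineage has to mirror.

import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

# Sketch of sqlglot's identifier normalization (expected output as of the pinned version):
expr = sqlglot.parse_one("SELECT MyCol FROM MyTable", dialect="snowflake")
print(normalize_identifiers(expr, dialect="snowflake").sql(dialect="snowflake"))
# SELECT MYCOL FROM MYTABLE  -- Snowflake defaults unquoted identifiers to uppercase

expr = sqlglot.parse_one("SELECT MyCol FROM MyTable", dialect="redshift")
print(normalize_identifiers(expr, dialect="redshift").sql(dialect="redshift"))
# SELECT mycol FROM mytable  -- most other dialects lowercase unquoted identifiers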
class UnsupportedStatementTypeError(TypeError):
pass
@@ -495,8 +522,8 @@ class SqlUnderstandingError(Exception):
# TODO: Break this up into smaller functions.
def _column_level_lineage( # noqa: C901
statement: sqlglot.exp.Expression,
dialect: str,
input_tables: Dict[_TableName, SchemaInfo],
dialect: sqlglot.Dialect,
table_schemas: Dict[_TableName, SchemaInfo],
output_table: Optional[_TableName],
default_db: Optional[str],
default_schema: Optional[str],
@@ -515,19 +542,9 @@ def _column_level_lineage( # noqa: C901
column_lineage: List[_ColumnLineageInfo] = []
use_case_insensitive_cols = dialect in {
# Column identifiers are case-insensitive in BigQuery, so we need to
# do a normalization step beforehand to make sure it's resolved correctly.
"bigquery",
# Our snowflake source lowercases column identifiers, so we are forced
# to do fuzzy (case-insensitive) resolution instead of exact resolution.
"snowflake",
# Teradata column names are case-insensitive.
# A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same.
# See more below:
# https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
"teradata",
}
use_case_insensitive_cols = _is_dialect_instance(
dialect, DIALECTS_WITH_CASE_INSENSITIVE_COLS
)
sqlglot_db_schema = sqlglot.MappingSchema(
dialect=dialect,
@@ -537,14 +554,16 @@ def _column_level_lineage( # noqa: C901
table_schema_normalized_mapping: Dict[_TableName, Dict[str, str]] = defaultdict(
dict
)
for table, table_schema in input_tables.items():
for table, table_schema in table_schemas.items():
normalized_table_schema: SchemaInfo = {}
for col, col_type in table_schema.items():
if use_case_insensitive_cols:
col_normalized = (
# This is required to match Sqlglot's behavior.
col.upper()
if dialect in {"snowflake"}
if _is_dialect_instance(
dialect, DIALECTS_WITH_DEFAULT_UPPERCASE_COLS
)
else col.lower()
)
else:
@@ -561,7 +580,7 @@ def _column_level_lineage( # noqa: C901
if use_case_insensitive_cols:
def _sqlglot_force_column_normalizer(
node: sqlglot.exp.Expression, dialect: "sqlglot.DialectType" = None
node: sqlglot.exp.Expression,
) -> sqlglot.exp.Expression:
if isinstance(node, sqlglot.exp.Column):
node.this.set("quoted", False)
@@ -572,9 +591,7 @@ def _column_level_lineage( # noqa: C901
# "Prior to case normalization sql %s",
# statement.sql(pretty=True, dialect=dialect),
# )
statement = statement.transform(
_sqlglot_force_column_normalizer, dialect, copy=False
)
statement = statement.transform(_sqlglot_force_column_normalizer, copy=False)
# logger.debug(
# "Sql after casing normalization %s",
# statement.sql(pretty=True, dialect=dialect),
@@ -595,7 +612,8 @@ def _column_level_lineage( # noqa: C901
# Optimize the statement + qualify column references.
logger.debug(
"Prior to qualification sql %s", statement.sql(pretty=True, dialect=dialect)
"Prior to column qualification sql %s",
statement.sql(pretty=True, dialect=dialect),
)
try:
# Second time running qualify, this time with:
@@ -678,7 +696,7 @@ def _column_level_lineage( # noqa: C901
# Otherwise, we can't process it.
continue
if dialect == "bigquery" and output_col.lower() in {
if _is_dialect_instance(dialect, "bigquery") and output_col.lower() in {
"_partitiontime",
"_partitiondate",
}:
@@ -923,7 +941,7 @@ def _translate_sqlglot_type(
def _translate_internal_column_lineage(
table_name_urn_mapping: Dict[_TableName, str],
raw_column_lineage: _ColumnLineageInfo,
dialect: str,
dialect: sqlglot.Dialect,
) -> ColumnLineageInfo:
downstream_urn = None
if raw_column_lineage.downstream.table:
@@ -956,18 +974,44 @@ def _translate_internal_column_lineage(
)
def _get_dialect(platform: str) -> str:
def _get_dialect_str(platform: str) -> str:
# TODO: convert datahub platform names to sqlglot dialect
if platform == "presto-on-hive":
return "hive"
if platform == "mssql":
elif platform == "mssql":
return "tsql"
if platform == "athena":
elif platform == "athena":
return "trino"
elif platform == "mysql":
# In sqlglot v20+, MySQL is now case-sensitive by default, which is the
# default behavior on Linux. However, MySQL's default case sensitivity
# actually depends on the underlying OS.
# For us, it's simpler to just assume that it's case-insensitive, and
# let the fuzzy resolution logic handle it.
return "mysql, normalization_strategy = lowercase"
else:
return platform
def _get_dialect(platform: str) -> sqlglot.Dialect:
return sqlglot.Dialect.get_or_raise(_get_dialect_str(platform))
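The "mysql, normalization_strategy = lowercase" string above leans on sqlglot v20+ accepting dialect settings inline; a small sketch of how that resolves (assuming the pinned sqlglot version):

import sqlglot

# Sketch: Dialect.get_or_raise parses the trailing settings and returns a configured
# MySQL dialect instance rather than the bare class.
dialect = sqlglot.Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
print(type(dialect).__name__)  # MySQL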
def _is_dialect_instance(
dialect: sqlglot.Dialect, platforms: Union[str, Iterable[str]]
) -> bool:
if isinstance(platforms, str):
platforms = [platforms]
else:
platforms = list(platforms)
dialects = [sqlglot.Dialect.get_or_raise(platform) for platform in platforms]
if any(isinstance(dialect, dialect_class.__class__) for dialect_class in dialects):
return True
return False
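Because _is_dialect_instance relies on isinstance checks rather than string comparison, it follows sqlglot's dialect class hierarchy; a hedged usage sketch (assuming, for example, that sqlglot's Redshift dialect subclasses Postgres):

# Usage sketch for _is_dialect_instance (dialect inheritance is an assumption about sqlglot internals).
redshift = _get_dialect("redshift")
assert _is_dialect_instance(redshift, "redshift")
# Redshift extends Postgres in sqlglot, so a Postgres check also matches.
assert _is_dialect_instance(redshift, "postgres")
# Platform aliases still apply: "athena" resolves to the Trino dialect.
assert _is_dialect_instance(_get_dialect("athena"), "trino")
assert not _is_dialect_instance(redshift, ["bigquery", "snowflake"])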
def _sqlglot_lineage_inner(
sql: sqlglot.exp.ExpOrStr,
schema_resolver: SchemaResolver,
@@ -975,7 +1019,7 @@ def _sqlglot_lineage_inner(
default_schema: Optional[str] = None,
) -> SqlParsingResult:
dialect = _get_dialect(schema_resolver.platform)
if dialect == "snowflake":
if _is_dialect_instance(dialect, "snowflake"):
# in snowflake, table identifiers must be uppercased to match sqlglot's behavior.
if default_db:
default_db = default_db.upper()
@@ -1064,7 +1108,7 @@ def _sqlglot_lineage_inner(
column_lineage = _column_level_lineage(
select_statement,
dialect=dialect,
input_tables=table_name_schema_mapping,
table_schemas=table_name_schema_mapping,
output_table=downstream_table,
default_db=default_db,
default_schema=default_schema,
@@ -1204,13 +1248,13 @@ def detach_ctes(
full_new_name, dialect=dialect, into=sqlglot.exp.Table
)
# We expect node.parent to be a Table or Column.
# Either way, it should support catalog/db/name.
parent = node.parent
if "catalog" in parent.arg_types:
# We expect node.parent to be a Table or Column, both of which support catalog/db/name.
# However, we check the parent's arg_types to be safe.
if "catalog" in parent.arg_types and table_expr.catalog:
parent.set("catalog", table_expr.catalog)
if "db" in parent.arg_types:
if "db" in parent.arg_types and table_expr.db:
parent.set("db", table_expr.db)
new_node = sqlglot.exp.Identifier(this=table_expr.name)
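The new "and table_expr.catalog" / "and table_expr.db" guards matter when a CTE is remapped to a single-part name; a small sketch of the relevant sqlglot behavior (assuming the pinned version):

import sqlglot

# Sketch: for a single-part replacement name, the parsed Table carries no catalog or db,
# and sqlglot exposes those parts as empty strings, so the guards skip setting them on
# the parent node.
table_expr = sqlglot.maybe_parse("_my_cte_table", dialect="snowflake", into=sqlglot.exp.Table)
print(repr(table_expr.catalog), repr(table_expr.db), repr(table_expr.name))
# '' '' '_my_cte_table'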


@@ -0,0 +1,54 @@
{
"query_type": "CREATE",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)"
],
"out_tables": [
"urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)"
],
"column_lineage": [
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
"column": "cust_id",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
"column": "cust_id"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
"column": "first_name",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
"column": "first_name"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
"column": "total_amount",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)",
"column": "amount"
}
]
}
]
}


@@ -0,0 +1,46 @@
from datahub.utilities.sqlglot_lineage import detach_ctes
def test_detach_ctes_simple():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "_my_cte_table"},
)
detached = detached_expr.sql(dialect="snowflake")
assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table ON table2.id = _my_cte_table.id"
)
def test_detach_ctes_with_alias():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 AS tablealias ON table2.id = tablealias.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "_my_cte_table"},
)
detached = detached_expr.sql(dialect="snowflake")
assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table AS tablealias ON table2.id = tablealias.id"
)
def test_detach_ctes_with_multipart_replacement():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "my_db.my_schema.my_table"},
)
detached = detached_expr.sql(dialect="snowflake")
assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN my_db.my_schema.my_table ON table2.id = my_db.my_schema.my_table.id"
)


@@ -3,59 +3,11 @@ import pathlib
import pytest
from datahub.testing.check_sql_parser_result import assert_sql_result
from datahub.utilities.sqlglot_lineage import (
_UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT,
detach_ctes,
)
from datahub.utilities.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
RESOURCE_DIR = pathlib.Path(__file__).parent / "goldens"
def test_detach_ctes_simple():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "_my_cte_table"},
)
detached = detached_expr.sql(dialect="snowflake")
assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table ON table2.id = _my_cte_table.id"
)
def test_detach_ctes_with_alias():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 AS tablealias ON table2.id = tablealias.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "_my_cte_table"},
)
detached = detached_expr.sql(dialect="snowflake")
assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table AS tablealias ON table2.id = tablealias.id"
)
def test_detach_ctes_with_multipart_replacement():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "my_db.my_schema.my_table"},
)
detached = detached_expr.sql(dialect="snowflake")
assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN my_db.my_schema.my_table ON table2.id = my_db.my_schema.my_table.id"
)
def test_select_max():
# The COL2 should get normalized to col2.
assert_sql_result(
@@ -1023,3 +975,25 @@ UPDATE accounts SET (contact_first_name, contact_last_name) =
},
expected_file=RESOURCE_DIR / "test_postgres_complex_update.json",
)
def test_redshift_materialized_view_auto_refresh():
# Example query from the redshift docs: https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html
assert_sql_result(
"""
CREATE MATERIALIZED VIEW mv_total_orders
AUTO REFRESH YES -- Add this clause to auto refresh the MV
AS
SELECT c.cust_id,
c.first_name,
sum(o.amount) as total_amount
FROM orders o
JOIN customer c
ON c.cust_id = o.customer_id
GROUP BY c.cust_id,
c.first_name;
""",
dialect="redshift",
expected_file=RESOURCE_DIR
/ "test_redshift_materialized_view_auto_refresh.json",
)