# 231 lines
# 7.9 KiB
# Python

import textwrap
from enum import Enum
import pytest
import sqlglot
from datahub.sql_parsing.query_types import get_query_type_of_sql
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
from datahub.sql_parsing.sqlglot_utils import (
generalize_query,
generalize_query_fast,
get_dialect,
get_query_fingerprint,
is_dialect_instance,
)
def test_update_from_select():
    """Guard the exact contents of ``_UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT``."""
    expected_unsupported_args = {"returning", "this"}
    assert expected_unsupported_args == _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
def test_is_dialect_instance():
    """Check ``is_dialect_instance`` with a single name and with a list of names."""
    snowflake_dialect = get_dialect("snowflake")
    matches_own_name = is_dialect_instance(snowflake_dialect, "snowflake")
    assert matches_own_name
    matches_unrelated_name = is_dialect_instance(snowflake_dialect, "bigquery")
    assert not matches_unrelated_name

    redshift_dialect = get_dialect("redshift")
    # A list of candidate names matches when the dialect's own name is included.
    assert is_dialect_instance(redshift_dialect, ["redshift", "snowflake"])
    # Redshift is also expected to match "postgres" -- presumably because the
    # sqlglot Redshift dialect derives from Postgres; confirm against sqlglot.
    assert is_dialect_instance(redshift_dialect, ["postgres", "snowflake"])
def test_query_types():
    """Check the (query type, properties) pair reported for Redshift DDL."""
    dialect = "redshift"
    cases = [
        # Explicit TEMP keyword.
        (
            "create temp table foo as select * from bar",
            (QueryType.CREATE_TABLE_AS_SELECT, {"kind": "TABLE", "temporary": True}),
        ),
        # "#"-prefixed table name is expected to be reported as temporary too.
        (
            "create table #foo as select * from bar",
            (QueryType.CREATE_TABLE_AS_SELECT, {"kind": "TABLE", "temporary": True}),
        ),
        (
            "create view foo as select * from bar",
            (QueryType.CREATE_VIEW, {"kind": "VIEW"}),
        ),
    ]

    for sql, expected in cases:
        statement = sqlglot.parse_one(sql, dialect=dialect)
        assert get_query_type_of_sql(statement, dialect=dialect) == expected
class QueryGeneralizationTestMode(Enum):
    """Selects which generalization function(s) a parametrized case exercises."""

    # Only ``generalize_query`` is checked.
    FULL = "full"
    # Only ``generalize_query_fast`` is checked.
    FAST = "fast"
    # Both functions are checked against the same expected output.
    BOTH = "both"
@pytest.mark.parametrize(
    "query, dialect, expected, mode",
    [
        # --- Keyword normalization ---
        (
            "select * from foo",
            "redshift",
            "SELECT * FROM foo",
            QueryGeneralizationTestMode.FULL,
        ),
        # --- Comment removal and whitespace normalization ---
        (
            "/* query system = foo, id = asdf */\nSELECT /* inline comment */ *\nFROM foo",
            "redshift",
            "SELECT * FROM foo",
            QueryGeneralizationTestMode.BOTH,
        ),
        (
            "SELECT a\n -- comment--\n,b --another comment\n FROM books",
            "redshift",
            "SELECT a, b FROM books",
            QueryGeneralizationTestMode.BOTH,
        ),
        # --- Parameter (literal) normalization ---
        (
            "UPDATE \"books\" SET page_count = page_count + 1, author_count = author_count + 1 WHERE book_title = 'My New Book'",
            "redshift",
            'UPDATE "books" SET page_count = page_count + ?, author_count = author_count + ? WHERE book_title = ?',
            QueryGeneralizationTestMode.BOTH,
        ),
        (
            "SELECT * FROM foo WHERE date = '2021-01-01'",
            "redshift",
            "SELECT * FROM foo WHERE date = ?",
            QueryGeneralizationTestMode.BOTH,
        ),
        (
            "SELECT * FROM books WHERE category IN ('fiction', 'biography', 'fantasy')",
            "redshift",
            "SELECT * FROM books WHERE category IN (?)",
            QueryGeneralizationTestMode.BOTH,
        ),
        (
            textwrap.dedent(
                """\
                INSERT INTO MyTable
                (Column1, Column2, Column3)
                VALUES
                ('John', 123, 'Lloyds Office');
                """
            ),
            "mssql",
            "INSERT INTO MyTable (Column1, Column2, Column3) VALUES (?)",
            QueryGeneralizationTestMode.BOTH,
        ),
        (
            # Uneven spacing within the IN clause.
            "SELECT * FROM books WHERE zip_code IN (123,345, 423 )",
            "redshift",
            "SELECT * FROM books WHERE zip_code IN (?)",
            QueryGeneralizationTestMode.BOTH,
        ),
        # Uneven spacing in the column list.
        # Known limitation: function calls inside the select list are still
        # not handled perfectly.
        (
            "SELECT a\n ,b FROM books",
            "redshift",
            "SELECT a, b FROM books",
            QueryGeneralizationTestMode.BOTH,
        ),
        (
            textwrap.dedent(
                """\
                /* Copied from https://stackoverflow.com/a/452934/5004662 */
                INSERT INTO MyTable
                (Column1, Column2, Column3)
                VALUES
                /* multiple value rows */
                ('John', 123, 'Lloyds Office'),
                ('Jane', 124, 'Lloyds Office'),
                ('Billy', 125, 'London Office'),
                ('Miranda', 126, 'Bristol Office');
                """
            ),
            "mssql",
            "INSERT INTO MyTable (Column1, Column2, Column3) VALUES (?), (?), (?), (?)",
            QueryGeneralizationTestMode.FULL,
        ),
        # --- Table name normalization ---
        # Only the fast normalization path supports these.
        (
            "SELECT * FROM datahub_community.fivetran_interval_unconstitutional_staging.datahub_slack_mess-staging-480fd5a7-58f4-4cc9-b6fb-87358788efe6",
            "bigquery",
            "SELECT * FROM datahub_community.fivetran_interval_unconstitutional_staging.datahub_slack_mess-staging-00000000-0000-0000-0000-000000000000",
            QueryGeneralizationTestMode.FAST,
        ),
        (
            "SELECT * FROM datahub_community.maggie.commonroom_slack_members_20240315",
            "bigquery",
            "SELECT * FROM datahub_community.maggie.commonroom_slack_members_YYYYMMDD",
            QueryGeneralizationTestMode.FAST,
        ),
        (
            "SELECT COUNT(*) FROM ge_temp_aa91f1fd",
            "bigquery",
            "SELECT COUNT(*) FROM ge_temp_abcdefgh",
            QueryGeneralizationTestMode.FAST,
        ),
    ],
)
def test_query_generalization(
    query: str, dialect: str, expected: str, mode: QueryGeneralizationTestMode
) -> None:
    """Compare generalized SQL against ``expected`` for the selected mode(s).

    Args:
        query: Raw SQL text to generalize.
        dialect: Dialect name passed through to the generalization functions.
        expected: The generalized SQL each selected function must return.
        mode: Whether to check ``generalize_query``, ``generalize_query_fast``
            (called with ``change_table_names=True``), or both.
    """
    check_full = mode in {
        QueryGeneralizationTestMode.FULL,
        QueryGeneralizationTestMode.BOTH,
    }
    check_fast = mode in {
        QueryGeneralizationTestMode.FAST,
        QueryGeneralizationTestMode.BOTH,
    }

    if check_full:
        assert generalize_query(query, dialect=dialect) == expected

    if check_fast:
        fast_result = generalize_query_fast(
            query, dialect=dialect, change_table_names=True
        )
        assert fast_result == expected
def test_query_fingerprint():
    """Fingerprints ignore comments, casing and literal values, but not structure."""
    # Same statement shape: differs only in keyword case, a comment and a literal.
    commented_fp = get_query_fingerprint(
        "select * /* everything */ from foo where ts = 34", platform="redshift"
    )
    plain_fp = get_query_fingerprint(
        "SELECT * FROM foo where ts = 38", platform="redshift"
    )
    assert commented_fp == plain_fp

    # Different expression structure must yield a different fingerprint.
    sum_fp = get_query_fingerprint("select 1 + 1", platform="postgres")
    literal_fp = get_query_fingerprint("select 2", platform="postgres")
    assert sum_fp != literal_fp
def test_redshift_query_fingerprint():
    """Compare Redshift INSERT fingerprints with and without the third argument."""
    query1 = "insert into insert_into_table (select * from base_table);"
    query2 = "INSERT INTO insert_into_table (SELECT * FROM base_table)"

    # Default call: the two spellings are expected to collapse to one fingerprint.
    default_fp1 = get_query_fingerprint(query1, "redshift")
    default_fp2 = get_query_fingerprint(query2, "redshift")
    assert default_fp1 == default_fp2

    # Third positional argument set to True (presumably the "fast" flag --
    # confirm against get_query_fingerprint's signature): the spellings are
    # expected to stay distinct.
    flagged_fp1 = get_query_fingerprint(query1, "redshift", True)
    flagged_fp2 = get_query_fingerprint(query2, "redshift", True)
    assert flagged_fp1 != flagged_fp2
def test_query_fingerprint_with_secondary_id():
    """Check how ``secondary_id`` influences the fingerprint of one query."""
    sql = "SELECT * FROM users WHERE id = 123"
    platform = "snowflake"

    fp_without_id = get_query_fingerprint(sql, platform)
    fp_project_123 = get_query_fingerprint(
        sql, platform, secondary_id="project_id_123"
    )
    fp_project_456 = get_query_fingerprint(
        sql, platform, secondary_id="project_id_456"
    )

    assert all((fp_without_id, fp_project_123, fp_project_456)), (
        "Fingerprint should not be None"
    )
    assert fp_without_id != fp_project_123, (
        "Fingerprint should change with secondary_id"
    )
    assert fp_project_123 != fp_project_456, (
        "Different secondary_id should yield different fingerprints"
    )

    # Recompute with an already-used secondary_id to check determinism.
    fp_project_456_again = get_query_fingerprint(
        sql, platform, secondary_id="project_id_456"
    )
    assert fp_project_456 == fp_project_456_again, (
        "Fingerprints are deterministic for the same secondary_id"
    )