117 lines
3.7 KiB
Python
Raw Normal View History

import textwrap
import sqlglot
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
_UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT,
get_query_type_of_sql,
)
from datahub.sql_parsing.sqlglot_utils import (
generalize_query,
get_dialect,
get_query_fingerprint,
is_dialect_instance,
)
def test_update_from_select():
    """Pin the exact contents of ``_UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT``."""
    # Spelled out literally so that any change to the constant fails here.
    expected_args = {"returning", "this"}
    assert _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT == expected_args
def test_is_dialect_instance():
    """Check ``is_dialect_instance`` with a single name and with lists of names."""
    # A single platform name: matches itself, rejects an unrelated platform.
    snowflake_dialect = get_dialect("snowflake")
    assert is_dialect_instance(snowflake_dialect, "snowflake")
    assert not is_dialect_instance(snowflake_dialect, "bigquery")

    # A list of platform names: one matching entry is expected to be enough.
    redshift_dialect = get_dialect("redshift")
    assert is_dialect_instance(redshift_dialect, ["redshift", "snowflake"])
    # NOTE(review): "redshift" is not in this list; the match presumably comes
    # from sqlglot's Redshift dialect deriving from Postgres -- confirm.
    assert is_dialect_instance(redshift_dialect, ["postgres", "snowflake"])
def test_query_types():
    """Verify the ``(query type, properties)`` pair reported for CREATE statements.

    Each statement is parsed with the redshift dialect and then handed to
    ``get_query_type_of_sql`` with that same dialect.
    """
    dialect = "redshift"
    temp_table_ctas = (
        QueryType.CREATE_TABLE_AS_SELECT,
        {"kind": "TABLE", "temporary": True},
    )

    cases = [
        # Explicit TEMP keyword.
        ("create temp table foo as select * from bar", temp_table_ctas),
        # '#'-prefixed table name; expected to be reported as temporary too.
        ("create table #foo as select * from bar", temp_table_ctas),
        # Plain view creation.
        (
            "create view foo as select * from bar",
            (QueryType.CREATE_VIEW, {"kind": "VIEW"}),
        ),
    ]

    for sql, expected in cases:
        statement = sqlglot.parse_one(sql, dialect=dialect)
        assert get_query_type_of_sql(statement, dialect=dialect) == expected
def test_query_generalization():
    """Verify the normalized text that ``generalize_query`` produces.

    The cases cover keyword upper-casing, comment and whitespace removal, and
    replacement of literal values (including IN lists and multi-row VALUES)
    with ``?`` placeholders.
    """
    # dedent() strips the common leading indentation, so the SQL handed to
    # generalize_query starts at column 0 regardless of how it is nested here.
    multi_row_insert = textwrap.dedent(
        """\
        /* Copied from https://stackoverflow.com/a/452934/5004662 */
        INSERT INTO MyTable
        ( Column1, Column2, Column3 )
        VALUES
        /* multiple value rows */
        ('John', 123, 'Lloyds Office'),
        ('Jane', 124, 'Lloyds Office'),
        ('Billy', 125, 'London Office'),
        ('Miranda', 126, 'Bristol Office');
        """
    )

    # Each case is (input SQL, dialect, expected generalized SQL).
    cases = [
        # Basic keyword normalization.
        ("select * from foo", "redshift", "SELECT * FROM foo"),
        # Comment removal and whitespace normalization.
        (
            "/* query system = foo, id = asdf */\nselect /* inline comment */ *\nfrom foo",
            "redshift",
            "SELECT * FROM foo",
        ),
        # Parameter normalization.
        (
            "UPDATE \"books\" SET page_count = page_count + 1, author_count = author_count + 1 WHERE book_title = 'My New Book'",
            "redshift",
            'UPDATE "books" SET page_count = page_count + ?, author_count = author_count + ? WHERE book_title = ?',
        ),
        (
            "select * from foo where date = '2021-01-01'",
            "redshift",
            "SELECT * FROM foo WHERE date = ?",
        ),
        # An IN list collapses to a single placeholder.
        (
            "select * from books where category in ('fiction', 'biography', 'fantasy')",
            "redshift",
            "SELECT * FROM books WHERE category IN (?)",
        ),
        # Every VALUES row collapses to one placeholder group.
        (
            multi_row_insert,
            "mssql",
            "INSERT INTO MyTable (Column1, Column2, Column3) VALUES (?), (?), (?), (?)",
        ),
    ]

    for sql, dialect, expected in cases:
        assert generalize_query(sql, dialect=dialect) == expected
def test_query_fingerprint():
    """Verify which query pairs share a fingerprint and which do not."""
    # Queries differing only in comments, keyword case and literal values
    # are expected to produce the same fingerprint.
    commented_fingerprint = get_query_fingerprint(
        "select * /* everything */ from foo where ts = 34", dialect="redshift"
    )
    plain_fingerprint = get_query_fingerprint(
        "SELECT * FROM foo where ts = 38", dialect="redshift"
    )
    assert commented_fingerprint == plain_fingerprint

    # Queries with a different expression structure must not collide.
    sum_fingerprint = get_query_fingerprint("select 1 + 1", dialect="postgres")
    literal_fingerprint = get_query_fingerprint("select 2", dialect="postgres")
    assert sum_fingerprint != literal_fingerprint