diff --git a/ingestion/src/metadata/ingestion/lineage/masker.py b/ingestion/src/metadata/ingestion/lineage/masker.py index 2e00bdd6ede..22e6946c791 100644 --- a/ingestion/src/metadata/ingestion/lineage/masker.py +++ b/ingestion/src/metadata/ingestion/lineage/masker.py @@ -13,6 +13,7 @@ Query masking utilities """ import traceback +from typing import Optional from cachetools import LRUCache from collate_sqllineage.runner import SQLPARSE_DIALECT, LineageRunner @@ -26,6 +27,7 @@ MASK_TOKEN = "?" # Cache size is 128 to avoid memory issues masked_query_cache = LRUCache(maxsize=128) + # pylint: disable=protected-access def get_logger(): # pylint: disable=import-outside-toplevel @@ -52,7 +54,6 @@ def mask_literals_with_sqlparse(query: str, parser: LineageRunner): Literal.Number.Integer, Literal.Number.Float, Literal.String.Single, - Literal.String.Symbol, ): token.value = MASK_TOKEN elif token.is_group: @@ -113,7 +114,9 @@ def mask_literals_with_sqlfluff(query: str, parser: LineageRunner) -> str: def mask_query( - query: str, dialect: str = Dialect.ANSI.value, parser: LineageRunner = None + query: str, + dialect: str = Dialect.ANSI.value, + parser: Optional[LineageRunner] = None, ) -> str: """ Mask a query using sqlparse or sqlfluff. diff --git a/ingestion/tests/unit/test_sql_lineage.py b/ingestion/tests/unit/test_sql_lineage.py index 5cd736caad4..c44575ef841 100644 --- a/ingestion/tests/unit/test_sql_lineage.py +++ b/ingestion/tests/unit/test_sql_lineage.py @@ -16,6 +16,7 @@ import uuid from unittest import TestCase import pytest +from collate_sqllineage.runner import SQLPARSE_DIALECT from metadata.generated.schema.entity.data.table import Table from metadata.ingestion.lineage.masker import mask_query @@ -261,7 +262,18 @@ class SqlLineageTest(TestCase): """select * from users where id > 2 and name <> 'pere';""", Dialect.ANSI.value, ), - ("""select * from users where id > 2 and name <> 'pere';""", "random"), + ( + """select * from users where id > 2 and name <> 'pere';""", + "random", + ), + ( + """CREATE TABLE "db001"."table001" AS SELECT * FROM "db002"."table002" WHERE age > 18 AND name = 'John';""", + SQLPARSE_DIALECT, # test with sqlparse + ), + ( + """CREATE TABLE "db001"."table001" AS SELECT * FROM "db002"."table002" WHERE age > 18 AND name = 'John';""", + Dialect.ANSI.value, # test with sqlfluff + ), ] expected_query_list = [ @@ -274,6 +286,8 @@ class SqlLineageTest(TestCase): """select * from (select * from (SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user));""", """select * from users where id > ? and name <> ?;""", """select * from users where id > ? and name <> ?;""", + """CREATE TABLE "db001"."table001" AS SELECT * FROM "db002"."table002" WHERE age > ? AND name = ?;""", + """CREATE TABLE "db001"."table001" AS SELECT * FROM "db002"."table002" WHERE age > ? AND name = ?;""", ] for i, query in enumerate(query_list):