Fix #8424: Remove brackets from tables and schemas on lineage (#9257)

* Refactor LineageRunner use

* Address PR comments

* Address pylint errors

* Fix failing test

* Remove brackets from tables and schemas on lineage
This commit is contained in:
Nahuel 2022-12-13 06:40:37 +01:00 committed by GitHub
parent c75ba751b7
commit 9a4e3a7a46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 2 deletions

View File

@ -13,6 +13,7 @@ Lineage Parser configuration
"""
import traceback
from collections import defaultdict
from copy import deepcopy
from logging.config import DictConfigurator
from typing import Any, Dict, List, Optional, Tuple
@ -55,7 +56,7 @@ class LineageParser:
def __init__(self, query: str):
    """
    Keep the raw query, compute its cleaned form and build the lineage
    runner on top of the cleaned query.

    Args:
        query (str): raw SQL statement to extract lineage from
    """
    self.query = query
    # clean_raw_query is annotated as Optional[str], so this may be None.
    self._clean_query = self.clean_raw_query(query)
    # Build the runner once, from the cleaned query only; a runner built
    # from the raw query would be discarded immediately.
    self.parser = LineageRunner(self._clean_query)
@cached_property
def involved_tables(self) -> Optional[List[Table]]:
@ -295,7 +296,9 @@ class LineageParser:
def retrieve_tables(self, tables: List[Any]) -> List[Table]:
    """
    Keep only the Table instances of the given list, with cleaned names.

    Args:
        tables (List[Any]): candidate objects; entries that are not
            Table instances are dropped
    Returns:
        The result of clean_table_name for every Table entry, or an empty
        list when there is no cleaned query (`_clean_query` is falsy)
    """
    if not self._clean_query:
        return []
    # Single exit point: every kept table goes through clean_table_name so
    # bracketed names never leak to callers.
    return [
        self.clean_table_name(table) for table in tables if isinstance(table, Table)
    ]
@classmethod
def clean_raw_query(cls, raw_query: str) -> Optional[str]:
@ -329,3 +332,28 @@ class LineageParser:
return None
return clean_query.strip()
@staticmethod
def clean_table_name(table: Table) -> Table:
    """
    Build a copy of the table whose table name and schema name have their
    surrounding square brackets removed.

    The table received is left untouched: every change is applied to a
    deep copy of it.

    Args:
        table (Table): table to be cleaned
    Returns:
        Copy of the table object with cleaned names
    """
    bracketed = r"\[.*\]"
    unwrap = r"\[(.*)\]"
    cleaned = deepcopy(table)

    table_name = cleaned.raw_name
    if insensitive_match(table_name, bracketed):
        cleaned.raw_name = insensitive_replace(table_name, unwrap, r"\1")

    schema_name = cleaned.schema.raw_name
    # A falsy schema name is never passed to the matcher.
    if schema_name and insensitive_match(schema_name, bracketed):
        cleaned.schema.raw_name = insensitive_replace(schema_name, unwrap, r"\1")

    return cleaned

View File

@ -54,6 +54,13 @@ class QueryParserTests(TestCase):
clean_tables = set(self.parser.clean_table_list)
self.assertEqual(clean_tables, {"db.grault", "db.holis", "foo", "db.random"})
def test_bracketed_parser_table_list(self):
    """
    Bracketed schema and table names must show up without brackets in
    the parser's clean table list.
    """
    query = "create view [test_schema].[test_view] as select * from [test_table];"
    expected_tables = {"test_schema.test_view", "test_table"}

    bracketed_parser = LineageParser(query)

    self.assertEqual(set(bracketed_parser.clean_table_list), expected_tables)
def test_parser_table_aliases(self):
aliases = self.parser.table_aliases
self.assertEqual(