fix(sql-parsing): improve error handling (#4862)

This commit is contained in:
Aseem Bansal 2022-05-10 15:48:54 +05:30 committed by GitHub
parent e697b89bee
commit d0cdadbb3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@ import logging
import re import re
import unittest import unittest
import unittest.mock import unittest.mock
from typing import Dict, List, Set from typing import Dict, List, Optional, Set
from sqllineage.core.holders import Column, SQLLineageHolder from sqllineage.core.holders import Column, SQLLineageHolder
from sqllineage.exceptions import SQLLineageException from sqllineage.exceptions import SQLLineageException
@ -64,7 +64,8 @@ class SqlLineageSQLParserImpl:
logger.debug(f"Rewrote original query {original_sql_query} as {sql_query}") logger.debug(f"Rewrote original query {original_sql_query} as {sql_query}")
self._sql = sql_query self._sql = sql_query
self._stmt_holders: Optional[List[LineageAnalyzer]] = None
self._sql_holder: Optional[SQLLineageHolder] = None
try: try:
self._stmt = [ self._stmt = [
s s
@ -97,6 +98,9 @@ class SqlLineageSQLParserImpl:
def get_tables(self) -> List[str]: def get_tables(self) -> List[str]:
result: List[str] = list() result: List[str] = list()
if self._sql_holder is None:
logger.error("sql holder not present so cannot get tables")
return result
for table in self._sql_holder.source_tables: for table in self._sql_holder.source_tables:
table_normalized = re.sub(r"^<default>.", "", str(table)) table_normalized = re.sub(r"^<default>.", "", str(table))
result.append(str(table_normalized)) result.append(str(table_normalized))
@ -115,6 +119,9 @@ class SqlLineageSQLParserImpl:
return result return result
def get_columns(self) -> List[str]: def get_columns(self) -> List[str]:
if self._sql_holder is None:
logger.error("sql holder not present so cannot get columns")
return []
graph: DiGraph = self._sql_holder.graph # For mypy attribute checking graph: DiGraph = self._sql_holder.graph # For mypy attribute checking
column_nodes = [n for n in graph.nodes if isinstance(n, Column)] column_nodes = [n for n in graph.nodes if isinstance(n, Column)]
column_graph = graph.subgraph(column_nodes) column_graph = graph.subgraph(column_nodes)