mirror of
https://github.com/datahub-project/datahub.git
synced 2025-12-12 18:47:45 +00:00
fix(mssql): correctly split stored procs SQL (#12618)
Co-authored-by: Tobias Tekampe <tobias.tekampe@movingdots.com>
This commit is contained in:
parent
3f492279aa
commit
27e72782b9
@ -8,11 +8,11 @@ END_KEYWORD = "END"
|
||||
|
||||
CONTROL_FLOW_KEYWORDS = [
|
||||
"GO",
|
||||
r"BEGIN\w+TRY",
|
||||
r"BEGIN\w+CATCH",
|
||||
r"BEGIN\s+TRY",
|
||||
r"BEGIN\s+CATCH",
|
||||
"BEGIN",
|
||||
r"END\w+TRY",
|
||||
r"END\w+CATCH",
|
||||
r"END\s+TRY",
|
||||
r"END\s+CATCH",
|
||||
# This isn't strictly correct, but we assume that IF | (condition) | (block) should all be split up
|
||||
# This mainly ensures that IF statements don't get tacked onto the previous statement incorrectly
|
||||
"IF",
|
||||
@ -73,25 +73,31 @@ class _StatementSplitter:
|
||||
# what a given END is closing.
|
||||
self.current_case_statements = 0
|
||||
|
||||
def _is_keyword_at_position(self, pos: int, keyword: str) -> Tuple[bool, str]:
    """
    Check whether a keyword pattern matches ``self.sql`` at the given position.

    ``keyword`` is interpreted as a regular expression (for example
    ``r"BEGIN\s+TRY"``), matched case-insensitively and required to end on a
    word boundary.

    Args:
        pos: Index into ``self.sql`` at which the keyword must start.
        keyword: Regex source for the keyword to look for.

    Returns:
        A ``(matched, text)`` tuple. ``text`` is the exact substring of
        ``self.sql`` that matched (original casing and whitespace preserved),
        or ``""`` when there is no match.
    """
    sql = self.sql

    # Shortest text the pattern can match: every `\s+` needs at least one char.
    min_length = len(keyword.replace(r"\s+", " "))
    if pos + min_length > len(sql):
        return False, ""

    # If we're not at a word boundary, we can't generate a keyword.
    if pos > 0:
        window = sql[pos - 1 : pos + 1]
        at_boundary = bool(re.match(r"\w\W", window)) or bool(
            re.match(r"\W\w", window)
        )
        if not at_boundary:
            return False, ""

    # Match in place from `pos` rather than slicing `sql[pos:]`: the slice copies
    # the rest of the SQL on every probe. `Pattern.match` already anchors the
    # match at `pos`, so no leading `^` is needed.
    match = re.compile(rf"{keyword}\b", re.IGNORECASE).match(sql, pos)
    if match is None:
        return False, ""
    return True, match.group(0)
|
||||
|
||||
def _look_ahead_for_keywords(self, keywords: List[str]) -> Tuple[bool, str, int]:
|
||||
"""
|
||||
@ -99,7 +105,8 @@ class _StatementSplitter:
|
||||
"""
|
||||
|
||||
for keyword in keywords:
|
||||
if self._is_keyword_at_position(self.i, keyword):
|
||||
is_match, keyword = self._is_keyword_at_position(self.i, keyword)
|
||||
if is_match:
|
||||
return True, keyword, len(keyword)
|
||||
return False, "", 0
|
||||
|
||||
@ -118,7 +125,7 @@ class _StatementSplitter:
|
||||
|
||||
def process(self) -> Iterator[str]:
|
||||
if not self.sql or not self.sql.strip():
|
||||
return
|
||||
yield from ()
|
||||
|
||||
prev_real_char = "\0" # the most recent non-whitespace, non-comment character
|
||||
while self.i < len(self.sql):
|
||||
@ -181,7 +188,7 @@ class _StatementSplitter:
|
||||
def _process_normal(self, most_recent_real_char: str) -> Iterator[str]:
|
||||
c = self.sql[self.i]
|
||||
|
||||
if self._is_keyword_at_position(self.i, CASE_KEYWORD):
|
||||
if self._is_keyword_at_position(self.i, CASE_KEYWORD)[0]:
|
||||
self.current_case_statements += 1
|
||||
|
||||
is_control_keyword, keyword, keyword_len = self._look_ahead_for_keywords(
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from datahub.sql_parsing.split_statements import split_statements
|
||||
|
||||
|
||||
@ -61,8 +63,8 @@ drop table #temp1
|
||||
statements = [statement.strip() for statement in split_statements(test_sql)]
|
||||
assert statements == [
|
||||
"DROP TABLE #temp1",
|
||||
"SELECT 'foo' into #temp1",
|
||||
"DROP table #temp1",
|
||||
"select 'foo' into #temp1",
|
||||
"drop table #temp1",
|
||||
]
|
||||
|
||||
|
||||
@ -117,3 +119,62 @@ SELECT 1 as a INTO #foo
|
||||
"TRUNCATE TABLE #foo",
|
||||
"SELECT 1 as a INTO #foo",
|
||||
]
|
||||
|
||||
|
||||
def test_split_statement_with_try_catch():
    """BEGIN/END TRY and BEGIN/END CATCH markers come out as their own statements."""
    sql_text = """\
BEGIN TRY
-- Generate divide-by-zero error.
SELECT 1 / 0;
END TRY

BEGIN CATCH
-- Execute error retrieval routine.
SELECT ERROR_MESSAGE() AS ErrorMessage;
END CATCH;
"""
    # Each piece is stripped so surrounding whitespace does not affect the comparison.
    actual = [piece.strip() for piece in split_statements(sql_text)]

    assert actual == [
        "BEGIN TRY",
        "-- Generate divide-by-zero error.",
        "SELECT 1 / 0",
        "END TRY",
        "BEGIN CATCH",
        "-- Execute error retrieval routine.",
        "SELECT ERROR_MESSAGE() AS ErrorMessage",
        "END CATCH",
    ]
|
||||
|
||||
|
||||
def test_split_statement_with_empty_query():
    """Splitting an empty SQL string yields no statements."""
    nothing: List[str] = []
    empty_sql = ""

    pieces = [piece.strip() for piece in split_statements(empty_sql)]

    assert pieces == nothing
|
||||
|
||||
|
||||
def test_split_statement_with_empty_string_in_query():
    """A query ending in an empty string literal ('') stays a single statement."""
    query = """\
SELECT
a,
b as B
FROM myTable
WHERE
a = ''"""

    pieces = [piece.strip() for piece in split_statements(query)]

    # The whole query must come back unchanged as the only statement.
    assert pieces == [query]
|
||||
|
||||
|
||||
def test_split_statement_with_quotes_in_sting_in_query():
    """A doubled single quote ('') inside a string literal does not split the query."""
    query = """\
SELECT
a,
b as B
FROM myTable
WHERE
a = 'hi, my name''s tim.'"""

    pieces = [piece.strip() for piece in split_statements(query)]

    # The whole query must come back unchanged as the only statement.
    assert pieces == [query]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user