fix(mssql): correctly split stored procs SQL (#12618)

Co-authored-by: Tobias Tekampe <tobias.tekampe@movingdots.com>
This commit is contained in:
ttekampe 2025-02-18 21:09:06 +01:00 committed by GitHub
parent 3f492279aa
commit 27e72782b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 15 deletions

View File

@ -8,11 +8,11 @@ END_KEYWORD = "END"
CONTROL_FLOW_KEYWORDS = [
"GO",
r"BEGIN\w+TRY",
r"BEGIN\w+CATCH",
r"BEGIN\s+TRY",
r"BEGIN\s+CATCH",
"BEGIN",
r"END\w+TRY",
r"END\w+CATCH",
r"END\s+TRY",
r"END\s+CATCH",
# This isn't strictly correct, but we assume that IF | (condition) | (block) should all be split up
# This mainly ensures that IF statements don't get tacked onto the previous statement incorrectly
"IF",
@ -73,25 +73,31 @@ class _StatementSplitter:
# what a given END is closing.
self.current_case_statements = 0
def _is_keyword_at_position(self, pos: int, keyword: str) -> Tuple[bool, str]:
    """
    Check if a keyword exists at the given position using regex word boundaries.

    Args:
        pos: Index into ``self.sql`` at which the keyword has to start.
        keyword: Keyword to look for. It is used as a regular expression, so a
            multi-word keyword may describe the whitespace between its words
            with a whitespace-run pattern instead of a single literal space.

    Returns:
        A ``(is_match, matched_text)`` tuple. ``matched_text`` is the exact
        text found in the SQL (original casing and whitespace), or ``""``
        when nothing matched.
    """
    sql = self.sql

    # A `\s+` inside the keyword consumes at least one character, so count it
    # as one when checking whether the keyword can still fit in the input.
    keyword_length = len(keyword.replace(r"\s+", " "))
    if pos + keyword_length > len(sql):
        return False, ""

    # If we're not at a word boundary, we can't generate a keyword.
    if pos > 0 and not (
        bool(re.match(r"\w\W", sql[pos - 1 : pos + 1]))
        or bool(re.match(r"\W\w", sql[pos - 1 : pos + 1]))
    ):
        return False, ""

    # The keyword is deliberately NOT escaped: it may itself be a regex.
    pattern = rf"^{keyword}\b"
    match = re.match(pattern, sql[pos:], re.IGNORECASE)
    if match is None:
        return False, ""

    # Hand back what was really matched so callers can measure its true length.
    return True, match.group(0)
def _look_ahead_for_keywords(self, keywords: List[str]) -> Tuple[bool, str, int]:
"""
@ -99,7 +105,8 @@ class _StatementSplitter:
"""
for keyword in keywords:
if self._is_keyword_at_position(self.i, keyword):
is_match, keyword = self._is_keyword_at_position(self.i, keyword)
if is_match:
return True, keyword, len(keyword)
return False, "", 0
@ -118,7 +125,7 @@ class _StatementSplitter:
def process(self) -> Iterator[str]:
if not self.sql or not self.sql.strip():
return
yield from ()
prev_real_char = "\0" # the most recent non-whitespace, non-comment character
while self.i < len(self.sql):
@ -181,7 +188,7 @@ class _StatementSplitter:
def _process_normal(self, most_recent_real_char: str) -> Iterator[str]:
c = self.sql[self.i]
if self._is_keyword_at_position(self.i, CASE_KEYWORD):
if self._is_keyword_at_position(self.i, CASE_KEYWORD)[0]:
self.current_case_statements += 1
is_control_keyword, keyword, keyword_len = self._look_ahead_for_keywords(

View File

@ -1,3 +1,5 @@
from typing import List
from datahub.sql_parsing.split_statements import split_statements
@ -61,8 +63,8 @@ drop table #temp1
statements = [statement.strip() for statement in split_statements(test_sql)]
assert statements == [
"DROP TABLE #temp1",
"SELECT 'foo' into #temp1",
"DROP table #temp1",
"select 'foo' into #temp1",
"drop table #temp1",
]
@ -117,3 +119,62 @@ SELECT 1 as a INTO #foo
"TRUNCATE TABLE #foo",
"SELECT 1 as a INTO #foo",
]
def test_split_statement_with_try_catch():
    """TRY/CATCH markers, comments and statements are each split into their own chunk."""
    # Trailing semicolons are not part of the expected chunks.
    expected_chunks = [
        "BEGIN TRY",
        "-- Generate divide-by-zero error.",
        "SELECT 1 / 0",
        "END TRY",
        "BEGIN CATCH",
        "-- Execute error retrieval routine.",
        "SELECT ERROR_MESSAGE() AS ErrorMessage",
        "END CATCH",
    ]
    script = """\
BEGIN TRY
-- Generate divide-by-zero error.
SELECT 1 / 0;
END TRY
BEGIN CATCH
-- Execute error retrieval routine.
SELECT ERROR_MESSAGE() AS ErrorMessage;
END CATCH;
"""

    actual_chunks = list(map(str.strip, split_statements(script)))

    assert actual_chunks == expected_chunks
def test_split_statement_with_empty_query():
    """Splitting an empty SQL string produces no statements at all."""
    nothing_expected: List[str] = []

    produced = list(map(str.strip, split_statements("")))

    assert produced == nothing_expected
def test_split_statement_with_empty_string_in_query():
    """A query that ends in an empty string literal comes back as one unchanged statement."""
    query = """\
SELECT
a,
b as B
FROM myTable
WHERE
a = ''"""

    pieces = list(map(str.strip, split_statements(query)))

    # The whole input is expected back as the single statement.
    assert pieces == [query]
def test_split_statement_with_quotes_in_sting_in_query():
    """A doubled single quote inside a string literal does not split the query."""
    query = """\
SELECT
a,
b as B
FROM myTable
WHERE
a = 'hi, my name''s tim.'"""

    pieces = list(map(str.strip, split_statements(query)))

    # The whole input is expected back as the single statement.
    assert pieces == [query]