feat(ingest): use mainline sqlglot (#11693)

Harshal Sheth 2024-10-22 19:57:46 -07:00 committed by GitHub
parent 48f4b1a327
commit 35f30b7d3c
11 changed files with 714 additions and 16 deletions

View File

@@ -13,14 +13,6 @@ def get_long_description():
return pathlib.Path(os.path.join(root, "README.md")).read_text()
rest_common = {"requests", "requests_file"}
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
"acryl-sqlglot[rs]==24.0.1.dev7",
}
_version: str = package_metadata["__version__"]
_self_pin = (
f"=={_version}"
@@ -32,11 +24,7 @@ base_requirements = {
# Actual dependencies.
"dagster >= 1.3.3",
"dagit >= 1.3.3",
*rest_common,
# Ignoring the dependency below because it causes issues with the vercel built wheel install
# f"acryl-datahub[datahub-rest]{_self_pin}",
"acryl-datahub[datahub-rest]",
*sqlglot_lib,
f"acryl-datahub[datahub-rest,sql-parser]{_self_pin}",
}
mypy_stubs = {
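
For illustration: the full `_self_pin` expression is cut off by the diff context above, but assuming a hypothetical release version of 0.14.1, the new requirement resolves roughly like this:

_version = "0.14.1"  # hypothetical; the real value comes from package_metadata
_self_pin = f"=={_version}"
requirement = f"acryl-datahub[datahub-rest,sql-parser]{_self_pin}"
assert requirement == "acryl-datahub[datahub-rest,sql-parser]==0.14.1"

The net effect is that the dagster plugin no longer carries its own sqlglot pin; it picks up the parser stack through the sql-parser extra of the matching acryl-datahub release.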

View File

@@ -14,7 +14,7 @@ target-version = ['py37', 'py38', 'py39', 'py310']
[tool.isort]
combine_as_imports = true
indent = ' '
known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub_provider._airflow_compat']
known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub.sql_parsing._sqlglot_patch']
profile = 'black'
sections = 'FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER'
skip_glob = 'src/datahub/metadata'
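
This isort change swaps the airflow compat shim for the new patch module in the FUTURE section, which forces the import to sort ahead of all stdlib and third-party imports. A sketch of the ordering this produces in consuming modules (mirroring the diffs further down):

from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import logging  # stdlib imports sort after the patch module

import sqlglot  # sqlglot is only imported once the patches are in place

assert SQLGLOT_PATCHED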

View File

@@ -99,9 +99,11 @@ usage_common = {
}
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# We heavily monkeypatch sqlglot.
# Prior to patching, we maintained an acryl-sqlglot fork:
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
"acryl-sqlglot[rs]==25.25.2.dev9",
"sqlglot[rs]==25.26.0",
"patchy==2.8.0",
}
classification_lib = {
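
patchy, the new dependency, rewrites a function's source at runtime by applying a unified diff to it. A minimal sketch of its API (the function and diff here are illustrative, not from this commit):

import patchy

def sample():
    return 1

patchy.patch(
    sample,
    """\
@@ -1,2 +1,2 @@
 def sample():
-    return 1
+    return 2
""",
)

assert sample() == 2  # the function body was rewritten in place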

View File

@@ -0,0 +1,215 @@
import dataclasses
import difflib
import logging
import patchy.api
import sqlglot
import sqlglot.expressions
import sqlglot.lineage
import sqlglot.optimizer.scope
import sqlglot.optimizer.unnest_subqueries
from datahub.utilities.is_pytest import is_pytest_running
from datahub.utilities.unified_diff import apply_diff
# This injects a few patches into sqlglot to add features and mitigate
# some bugs and performance issues.
# The diffs in this file should match the diffs declared in our fork.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main
# For a diff-formatted view, see:
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main.diff
_DEBUG_PATCHER = is_pytest_running() or True
logger = logging.getLogger(__name__)
_apply_diff_subprocess = patchy.api._apply_patch
def _new_apply_patch(source: str, patch_text: str, forwards: bool, name: str) -> str:
assert forwards, "Only forward patches are supported"
result = apply_diff(source, patch_text)
# TODO: When in testing mode, still run the subprocess and check that the
# results line up.
if _DEBUG_PATCHER:
result_subprocess = _apply_diff_subprocess(source, patch_text, forwards, name)
if result_subprocess != result:
logger.info("Results from subprocess and _apply_diff do not match")
logger.debug(f"Subprocess result:\n{result_subprocess}")
logger.debug(f"Our result:\n{result}")
diff = difflib.unified_diff(
result_subprocess.splitlines(), result.splitlines()
)
logger.debug("Diff:\n" + "\n".join(diff))
raise ValueError("Results from subprocess and _apply_diff do not match")
return result
patchy.api._apply_patch = _new_apply_patch
def _patch_deepcopy() -> None:
patchy.patch(
sqlglot.expressions.Expression.__deepcopy__,
"""\
@@ -1,4 +1,7 @@ def meta(self) -> t.Dict[str, t.Any]:
def __deepcopy__(self, memo):
+ import datahub.utilities.cooperative_timeout
+ datahub.utilities.cooperative_timeout.cooperate()
+
root = self.__class__()
stack = [(self, root)]
""",
)
def _patch_scope_traverse() -> None:
# Circular scope dependencies can happen in somewhat specific circumstances
# due to our usage of sqlglot.
# See https://github.com/tobymao/sqlglot/pull/4244
patchy.patch(
sqlglot.optimizer.scope.Scope.traverse,
"""\
@@ -5,9 +5,16 @@ def traverse(self):
Scope: scope instances in depth-first-search post-order
\"""
stack = [self]
+ seen_scopes = set()
result = []
while stack:
scope = stack.pop()
+
+ # Scopes aren't hashable, so we use id(scope) instead.
+ if id(scope) in seen_scopes:
+ raise OptimizeError(f"Scope {scope} has a circular scope dependency")
+ seen_scopes.add(id(scope))
+
result.append(scope)
stack.extend(
itertools.chain(
""",
)
def _patch_unnest_subqueries() -> None:
patchy.patch(
sqlglot.optimizer.unnest_subqueries.decorrelate,
"""\
@@ -261,16 +261,19 @@ def remove_aggs(node):
if key in group_by:
key.replace(nested)
elif isinstance(predicate, exp.EQ):
- parent_predicate = _replace(
- parent_predicate,
- f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
- )
+ if parent_predicate:
+ parent_predicate = _replace(
+ parent_predicate,
+ f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
+ )
else:
key.replace(exp.to_identifier("_x"))
- parent_predicate = _replace(
- parent_predicate,
- f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
- )
+
+ if parent_predicate:
+ parent_predicate = _replace(
+ parent_predicate,
+ f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
+ )
""",
)
def _patch_lineage() -> None:
# Add the "subfield" attribute to sqlglot.lineage.Node.
# With dataclasses, the easiest way to do this is with inheritance.
# Unfortunately, mypy won't pick up on the new field, so we need to
# use type ignores everywhere we use subfield.
@dataclasses.dataclass(frozen=True)
class Node(sqlglot.lineage.Node):
subfield: str = ""
sqlglot.lineage.Node = Node # type: ignore
patchy.patch(
sqlglot.lineage.lineage,
"""\
@@ -12,7 +12,8 @@ def lineage(
\"""
expression = maybe_parse(sql, dialect=dialect)
- column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
+ # column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
+ assert isinstance(column, str)
if sources:
expression = exp.expand(
""",
)
patchy.patch(
sqlglot.lineage.to_node,
"""\
@@ -235,11 +237,12 @@ def to_node(
)
# Find all columns that went into creating this one to list their lineage nodes.
- source_columns = set(find_all_in_scope(select, exp.Column))
+ source_columns = list(find_all_in_scope(select, exp.Column))
- # If the source is a UDTF find columns used in the UTDF to generate the table
+ # If the source is a UDTF find columns used in the UDTF to generate the table
+ source = scope.expression
if isinstance(source, exp.UDTF):
- source_columns |= set(source.find_all(exp.Column))
+ source_columns += list(source.find_all(exp.Column))
derived_tables = [
source.expression.parent
for source in scope.sources.values()
@@ -254,6 +257,7 @@ def to_node(
if dt.comments and dt.comments[0].startswith("source: ")
}
+ c: exp.Column
for c in source_columns:
table = c.table
source = scope.sources.get(table)
@@ -281,8 +285,21 @@ def to_node(
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
# is not passed into the `sources` map.
source = source or exp.Placeholder()
+
+ subfields = []
+ field: exp.Expression = c
+ while isinstance(field.parent, exp.Dot):
+ field = field.parent
+ subfields.append(field.name)
+ subfield = ".".join(subfields)
+
node.downstream.append(
- Node(name=c.sql(comments=False), source=source, expression=source)
+ Node(
+ name=c.sql(comments=False),
+ source=source,
+ expression=source,
+ subfield=subfield,
+ )
)
return node
""",
)
_patch_deepcopy()
_patch_scope_traverse()
_patch_unnest_subqueries()
_patch_lineage()
SQLGLOT_PATCHED = True
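
Distilled from the tests further down, a sketch of what the __deepcopy__ patch buys in practice: because the patched method calls cooperate(), copy-heavy sqlglot work can be bounded by a cooperative timeout rather than running unchecked.

from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import sqlglot

from datahub.utilities.cooperative_timeout import (
    CooperativeTimeoutError,
    cooperative_timeout,
)

assert SQLGLOT_PATCHED

statement = sqlglot.parse_one("SELECT pg_sleep(3)", dialect="postgres")
try:
    with cooperative_timeout(timeout=0.6):
        while True:
            # sql() implicitly calls copy(), which is where the timeout check runs.
            statement.sql()
except CooperativeTimeoutError:
    pass  # the runaway loop was interrupted cooperatively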

View File

@@ -1,3 +1,5 @@
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
import dataclasses
import functools
import logging
@@ -53,6 +55,8 @@ from datahub.utilities.cooperative_timeout import (
cooperative_timeout,
)
assert SQLGLOT_PATCHED
logger = logging.getLogger(__name__)
Urn = str

View File

@@ -1,3 +1,5 @@
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
import functools
import hashlib
import logging
@@ -8,6 +10,8 @@ import sqlglot
import sqlglot.errors
import sqlglot.optimizer.eliminate_ctes
assert SQLGLOT_PATCHED
logger = logging.getLogger(__name__)
DialectOrStr = Union[sqlglot.Dialect, str]
SQL_PARSE_CACHE_SIZE = 1000

View File

@@ -0,0 +1,5 @@
import sys
def is_pytest_running() -> bool:
return "pytest" in sys.modules

View File

@@ -0,0 +1,236 @@
import logging
from dataclasses import dataclass
from typing import List, Tuple
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
_LOOKAROUND_LINES = 300
# The Python difflib library can generate unified diffs, but it cannot apply them.
# There weren't any well-maintained and easy-to-use libraries for applying
# unified diffs, so I wrote my own.
#
# My implementation is focused on ensuring correctness, and will throw
# an exception whenever it detects an issue.
#
# Alternatives considered:
# - diff-match-patch: This was the most promising since it's from Google.
# Unfortunately, they deprecated the library in Aug 2024. That may not have
# been a dealbreaker, since a somewhat greenfield community fork exists:
# https://github.com/dmsnell/diff-match-patch
# However, there's also a long-standing bug in the library around the
# handling of line breaks when parsing diffs. See:
# https://github.com/google/diff-match-patch/issues/157
# - python-patch: Seems abandoned.
# - patch-ng: Fork of python-patch, but mainly targeted at applying patches to trees.
# It did not have simple "apply patch to string" abstractions.
# - unidiff: Parses diffs, but cannot apply them.
class InvalidDiffError(Exception):
pass
class DiffApplyError(Exception):
pass
@dataclass
class Hunk:
source_start: int
source_lines: int
target_start: int
target_lines: int
lines: List[Tuple[str, str]]
def parse_patch(patch_text: str) -> List[Hunk]:
"""
Parses a unified diff patch into a list of Hunk objects.
Args:
patch_text: Unified diff format patch text
Returns:
List of parsed Hunk objects
Raises:
InvalidDiffError: If the patch is in an invalid format
"""
hunks = []
patch_lines = patch_text.splitlines()
i = 0
while i < len(patch_lines):
line = patch_lines[i]
if line.startswith("@@"):
try:
header_parts = line.split()
if len(header_parts) < 3:
raise ValueError(f"Invalid hunk header format: {line}")
source_changes, target_changes = header_parts[1:3]
source_start, source_lines = map(int, source_changes[1:].split(","))
target_start, target_lines = map(int, target_changes[1:].split(","))
hunk = Hunk(source_start, source_lines, target_start, target_lines, [])
i += 1
while i < len(patch_lines) and not patch_lines[i].startswith("@@"):
hunk_line = patch_lines[i]
if hunk_line:
hunk.lines.append((hunk_line[0], hunk_line[1:]))
else:
# A fully empty line usually means an empty context line whose
# leading space was trimmed by trailing-whitespace removal.
hunk.lines.append((" ", ""))
i += 1
hunks.append(hunk)
except (IndexError, ValueError) as e:
raise InvalidDiffError(f"Failed to parse hunk: {str(e)}") from e
else:
raise InvalidDiffError(f"Invalid line format: {line}")
return hunks
def find_hunk_start(source_lines: List[str], hunk: Hunk) -> int:
"""
Finds the actual starting line of a hunk in the source lines.
Args:
source_lines: The original source lines
hunk: The hunk to locate
Returns:
The actual line number where the hunk starts
Raises:
DiffApplyError: If the hunk's context cannot be found in the source lines
"""
# Extract context lines from the hunk, stopping at the first non-context line
context_lines = []
for prefix, line in hunk.lines:
if prefix == " ":
context_lines.append(line)
else:
break
if not context_lines:
logger.debug("No context lines found in hunk.")
return hunk.source_start - 1 # Default to the original start if no context
logger.debug(
f"Searching for {len(context_lines)} context lines, starting with {context_lines[0]}"
)
# Define the range to search for the context lines
search_start = max(0, hunk.source_start - _LOOKAROUND_LINES)
search_end = min(len(source_lines), hunk.source_start + _LOOKAROUND_LINES)
# Iterate over the possible starting positions in the source lines
for i in range(search_start, search_end):
# Check if the context lines match the source lines starting at position i
match = True
for j, context_line in enumerate(context_lines):
if (i + j >= len(source_lines)) or source_lines[i + j] != context_line:
match = False
break
if match:
# logger.debug(f"Context match found at line: {i}")
return i
logger.debug(f"Could not find match for hunk context lines: {context_lines}")
raise DiffApplyError("Could not find match for hunk context.")
def apply_hunk(result_lines: List[str], hunk: Hunk, hunk_index: int) -> None:
"""
Applies a single hunk to the result lines.
Args:
result_lines: The current state of the patched file
hunk: The hunk to apply
hunk_index: The index of the hunk (for logging purposes)
Raises:
DiffApplyError: If the hunk cannot be applied correctly
"""
current_line = find_hunk_start(result_lines, hunk)
logger.debug(f"Hunk {hunk_index + 1} start line: {current_line}")
for line_index, (prefix, content) in enumerate(hunk.lines):
# logger.debug(f"Processing line {line_index + 1} of hunk {hunk_index + 1}")
# logger.debug(f"Current line: {current_line}, Total lines: {len(result_lines)}")
# logger.debug(f"Prefix: {prefix}, Content: {content}")
if current_line >= len(result_lines):
logger.debug(f"Reached end of file while applying hunk {hunk_index + 1}")
while line_index < len(hunk.lines) and hunk.lines[line_index][0] == "+":
result_lines.append(hunk.lines[line_index][1])
line_index += 1
# If there's context or deletions past the end of the file, that's an error.
if line_index < len(hunk.lines):
raise DiffApplyError(
f"Found context or deletions after end of file in hunk {hunk_index + 1}"
)
break
if prefix == "-":
if result_lines[current_line].strip() != content.strip():
raise DiffApplyError(
f"Removing line that doesn't exactly match. Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'"
)
result_lines.pop(current_line)
elif prefix == "+":
result_lines.insert(current_line, content)
current_line += 1
elif prefix == " ":
if result_lines[current_line].strip() != content.strip():
raise DiffApplyError(
f"Context line doesn't exactly match. Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'"
)
current_line += 1
else:
raise DiffApplyError(
f"Invalid line prefix '{prefix}' in hunk {hunk_index + 1}, line {line_index + 1}"
)
def apply_diff(source: str, patch_text: str) -> str:
"""
Applies a unified diff patch to source text and returns the patched result.
Args:
source: Original source text to be patched
patch_text: Unified diff format patch text (with @@ markers and hunks)
Returns:
The patched text result
Raises:
InvalidDiffError: If the patch is in an invalid format
DiffApplyError: If the patch cannot be applied correctly
"""
# logger.debug(f"Original source:\n{source}")
# logger.debug(f"Patch text:\n{patch_text}")
hunks = parse_patch(patch_text)
logger.debug(f"Parsed into {len(hunks)} hunks")
source_lines = source.splitlines()
result_lines = source_lines.copy()
for hunk_index, hunk in enumerate(hunks):
logger.debug(f"Processing hunk {hunk_index + 1}")
apply_hunk(result_lines, hunk, hunk_index)
result = "\n".join(result_lines) + "\n"
# logger.debug(f"Patched result:\n{result}")
return result
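
A round-trip sketch of the intended division of labor: difflib generates the unified diff and apply_diff applies it. The ---/+++ file headers are stripped first, since parse_patch only accepts @@ hunk lines:

import difflib

from datahub.utilities.unified_diff import apply_diff

source = "line 1\nline 2\nline 3\n"
target = "line 1\nline 2 changed\nline 3\n"

patch = "".join(
    line
    for line in difflib.unified_diff(
        source.splitlines(keepends=True),
        target.splitlines(keepends=True),
    )
    if not line.startswith(("---", "+++"))
)

assert apply_diff(source, patch) == target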

View File

@@ -0,0 +1,48 @@
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
import time
import pytest
import sqlglot
import sqlglot.errors
import sqlglot.lineage
import sqlglot.optimizer
from datahub.utilities.cooperative_timeout import (
CooperativeTimeoutError,
cooperative_timeout,
)
from datahub.utilities.perf_timer import PerfTimer
assert SQLGLOT_PATCHED
def test_cooperative_timeout_sql() -> None:
statement = sqlglot.parse_one("SELECT pg_sleep(3)", dialect="postgres")
with pytest.raises(
CooperativeTimeoutError
), PerfTimer() as timer, cooperative_timeout(timeout=0.6):
while True:
# sql() implicitly calls copy(), which is where we check for the timeout.
assert statement.sql() is not None
time.sleep(0.0001)
assert 0.6 <= timer.elapsed_seconds() <= 1.0
def test_scope_circular_dependency() -> None:
scope = sqlglot.optimizer.build_scope(
sqlglot.parse_one("WITH w AS (SELECT * FROM q) SELECT * FROM w")
)
assert scope is not None
cte_scope = scope.cte_scopes[0]
cte_scope.cte_scopes.append(cte_scope)
with pytest.raises(sqlglot.errors.OptimizeError, match="circular scope dependency"):
list(scope.traverse())
def test_lineage_node_subfield() -> None:
expression = sqlglot.parse_one("SELECT 1 AS test")
node = sqlglot.lineage.Node("test", expression, expression, subfield="subfield") # type: ignore
assert node.subfield == "subfield" # type: ignore

View File

@@ -0,0 +1,191 @@
import pytest
from datahub.utilities.unified_diff import (
DiffApplyError,
Hunk,
InvalidDiffError,
apply_diff,
apply_hunk,
find_hunk_start,
parse_patch,
)
def test_parse_patch():
patch_text = """@@ -1,3 +1,4 @@
Line 1
-Line 2
+Line 2 modified
+Line 2.5
Line 3"""
hunks = parse_patch(patch_text)
assert len(hunks) == 1
assert hunks[0].source_start == 1
assert hunks[0].source_lines == 3
assert hunks[0].target_start == 1
assert hunks[0].target_lines == 4
assert hunks[0].lines == [
(" ", "Line 1"),
("-", "Line 2"),
("+", "Line 2 modified"),
("+", "Line 2.5"),
(" ", "Line 3"),
]
def test_parse_patch_invalid():
with pytest.raises(InvalidDiffError):
parse_patch("Invalid patch")
def test_parse_patch_bad_header():
# A patch with a malformed header
bad_patch_text = """@@ -1,3
Line 1
-Line 2
+Line 2 modified
Line 3"""
with pytest.raises(InvalidDiffError):
parse_patch(bad_patch_text)
def test_find_hunk_start():
source_lines = ["Line 1", "Line 2", "Line 3", "Line 4"]
hunk = Hunk(2, 2, 2, 2, [(" ", "Line 2"), (" ", "Line 3")])
assert find_hunk_start(source_lines, hunk) == 1
def test_find_hunk_start_not_found():
source_lines = ["Line 1", "Line 2", "Line 3", "Line 4"]
hunk = Hunk(2, 2, 2, 2, [(" ", "Line X"), (" ", "Line Y")])
with pytest.raises(DiffApplyError, match="Could not find match for hunk context."):
find_hunk_start(source_lines, hunk)
def test_apply_hunk_success():
result_lines = ["Line 1", "Line 2", "Line 3"]
hunk = Hunk(
2,
2,
2,
3,
[(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified"), ("+", "Line 3.5")],
)
apply_hunk(result_lines, hunk, 0)
assert result_lines == ["Line 1", "Line 2", "Line 3 modified", "Line 3.5"]
def test_apply_hunk_mismatch():
result_lines = ["Line 1", "Line 2", "Line X"]
hunk = Hunk(
2, 2, 2, 2, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
)
with pytest.raises(
DiffApplyError, match="Removing line that doesn't exactly match"
):
apply_hunk(result_lines, hunk, 0)
def test_apply_hunk_context_mismatch():
result_lines = ["Line 1", "Line 3"]
hunk = Hunk(2, 2, 2, 2, [(" ", "Line 1"), ("+", "Line 2"), (" ", "Line 4")])
with pytest.raises(DiffApplyError, match="Context line doesn't exactly match"):
apply_hunk(result_lines, hunk, 0)
def test_apply_hunk_invalid_prefix():
result_lines = ["Line 1", "Line 2", "Line 3"]
hunk = Hunk(
2, 2, 2, 2, [(" ", "Line 2"), ("*", "Line 3"), ("+", "Line 3 modified")]
)
with pytest.raises(DiffApplyError, match="Invalid line prefix"):
apply_hunk(result_lines, hunk, 0)
def test_apply_hunk_end_of_file():
result_lines = ["Line 1", "Line 2"]
hunk = Hunk(
2, 2, 2, 3, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
)
with pytest.raises(
DiffApplyError, match="Found context or deletions after end of file"
):
apply_hunk(result_lines, hunk, 0)
def test_apply_hunk_context_beyond_end_of_file():
result_lines = ["Line 1", "Line 3"]
hunk = Hunk(
2, 2, 2, 3, [(" ", "Line 1"), ("+", "Line 2"), (" ", "Line 3"), (" ", "Line 4")]
)
with pytest.raises(
DiffApplyError, match="Found context or deletions after end of file"
):
apply_hunk(result_lines, hunk, 0)
def test_apply_hunk_remove_non_existent_line():
result_lines = ["Line 1", "Line 2", "Line 4"]
hunk = Hunk(
2, 2, 2, 3, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
)
with pytest.raises(
DiffApplyError, match="Removing line that doesn't exactly match"
):
apply_hunk(result_lines, hunk, 0)
def test_apply_hunk_addition_beyond_end_of_file():
result_lines = ["Line 1", "Line 2"]
hunk = Hunk(
2, 2, 2, 3, [(" ", "Line 2"), ("+", "Line 3 modified"), ("+", "Line 4")]
)
apply_hunk(result_lines, hunk, 0)
assert result_lines == ["Line 1", "Line 2", "Line 3 modified", "Line 4"]
def test_apply_diff():
source = """Line 1
Line 2
Line 3
Line 4"""
patch = """@@ -1,4 +1,5 @@
Line 1
-Line 2
+Line 2 modified
+Line 2.5
Line 3
Line 4"""
result = apply_diff(source, patch)
expected = """Line 1
Line 2 modified
Line 2.5
Line 3
Line 4
"""
assert result == expected
def test_apply_diff_invalid_patch():
source = "Line 1\nLine 2\n"
patch = "Invalid patch"
with pytest.raises(InvalidDiffError):
apply_diff(source, patch)
def test_apply_diff_unapplicable_patch():
source = "Line 1\nLine 2\n"
patch = "@@ -1,2 +1,2 @@\n Line 1\n-Line X\n+Line 2 modified\n"
with pytest.raises(DiffApplyError):
apply_diff(source, patch)
def test_apply_diff_add_to_empty_file():
source = ""
patch = """\
@@ -1,0 +1,1 @@
+Line 1
+Line 2
"""
result = apply_diff(source, patch)
assert result == "Line 1\nLine 2\n"

View File

@@ -1,6 +1,7 @@
import doctest
from datahub.utilities.delayed_iter import delayed_iter
from datahub.utilities.is_pytest import is_pytest_running
from datahub.utilities.sql_parser import SqlLineageSQLParser
@@ -295,3 +296,7 @@ def test_logging_name_extraction():
).attempted
> 0
)
def test_is_pytest_running() -> None:
assert is_pytest_running()