mirror of https://github.com/datahub-project/datahub.git (synced 2025-12-25 17:08:29 +00:00)

feat(ingest): use mainline sqlglot (#11693)

This commit is contained in:
parent 48f4b1a327
commit 35f30b7d3c
@@ -13,14 +13,6 @@ def get_long_description():
    return pathlib.Path(os.path.join(root, "README.md")).read_text()


rest_common = {"requests", "requests_file"}

sqlglot_lib = {
    # Using an Acryl fork of sqlglot.
    # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
    "acryl-sqlglot[rs]==24.0.1.dev7",
}

_version: str = package_metadata["__version__"]
_self_pin = (
    f"=={_version}"
@@ -32,11 +24,7 @@ base_requirements = {
    # Actual dependencies.
    "dagster >= 1.3.3",
    "dagit >= 1.3.3",
    *rest_common,
    # Ignoring the dependency below because it causes issues with the vercel built wheel install
    # f"acryl-datahub[datahub-rest]{_self_pin}",
    "acryl-datahub[datahub-rest]",
    *sqlglot_lib,
    f"acryl-datahub[datahub-rest,sql-parser]{_self_pin}",
}

mypy_stubs = {
@@ -14,7 +14,7 @@ target-version = ['py37', 'py38', 'py39', 'py310']
[tool.isort]
combine_as_imports = true
indent = ' '
known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub_provider._airflow_compat']
known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub.sql_parsing._sqlglot_patch']
profile = 'black'
sections = 'FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER'
skip_glob = 'src/datahub/metadata'
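Listing datahub.sql_parsing._sqlglot_patch under known_future_library places it in isort's FUTURE section, which sorts first per the sections setting above, so the patch module is imported before sqlglot in every file isort touches. A hypothetical module header illustrating the resulting order (the assert keeps the import from looking unused):

# FUTURE section: isort now keeps this import at the very top, before sqlglot itself.
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

# STDLIB and THIRDPARTY sections follow as usual.
import logging

import sqlglot

assert SQLGLOT_PATCHED

logger = logging.getLogger(__name__)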
@@ -99,9 +99,11 @@ usage_common = {
}

sqlglot_lib = {
    # Using an Acryl fork of sqlglot.
    # We heavily monkeypatch sqlglot.
    # Prior to the patching, we originally maintained an acryl-sqlglot fork:
    # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
    "acryl-sqlglot[rs]==25.25.2.dev9",
    "sqlglot[rs]==25.26.0",
    "patchy==2.8.0",
}

classification_lib = {
215 metadata-ingestion/src/datahub/sql_parsing/_sqlglot_patch.py Normal file
@@ -0,0 +1,215 @@
import dataclasses
import difflib
import logging

import patchy.api
import sqlglot
import sqlglot.expressions
import sqlglot.lineage
import sqlglot.optimizer.scope
import sqlglot.optimizer.unnest_subqueries

from datahub.utilities.is_pytest import is_pytest_running
from datahub.utilities.unified_diff import apply_diff

# This injects a few patches into sqlglot to add features and mitigate
# some bugs and performance issues.
# The diffs in this file should match the diffs declared in our fork.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main
# For a diff-formatted view, see:
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main.diff

_DEBUG_PATCHER = is_pytest_running() or True
logger = logging.getLogger(__name__)

_apply_diff_subprocess = patchy.api._apply_patch


def _new_apply_patch(source: str, patch_text: str, forwards: bool, name: str) -> str:
    assert forwards, "Only forward patches are supported"

    result = apply_diff(source, patch_text)

    # TODO: When in testing mode, still run the subprocess and check that the
    # results line up.
    if _DEBUG_PATCHER:
        result_subprocess = _apply_diff_subprocess(source, patch_text, forwards, name)
        if result_subprocess != result:
            logger.info("Results from subprocess and _apply_diff do not match")
            logger.debug(f"Subprocess result:\n{result_subprocess}")
            logger.debug(f"Our result:\n{result}")
            diff = difflib.unified_diff(
                result_subprocess.splitlines(), result.splitlines()
            )
            logger.debug("Diff:\n" + "\n".join(diff))
            raise ValueError("Results from subprocess and _apply_diff do not match")

    return result


patchy.api._apply_patch = _new_apply_patch


def _patch_deepcopy() -> None:
    patchy.patch(
        sqlglot.expressions.Expression.__deepcopy__,
        """\
@@ -1,4 +1,7 @@ def meta(self) -> t.Dict[str, t.Any]:
 def __deepcopy__(self, memo):
+ import datahub.utilities.cooperative_timeout
+ datahub.utilities.cooperative_timeout.cooperate()
+
 root = self.__class__()
 stack = [(self, root)]
""",
    )


def _patch_scope_traverse() -> None:
    # Circular scope dependencies can happen in somewhat specific circumstances
    # due to our usage of sqlglot.
    # See https://github.com/tobymao/sqlglot/pull/4244
    patchy.patch(
        sqlglot.optimizer.scope.Scope.traverse,
        """\
@@ -5,9 +5,16 @@ def traverse(self):
 Scope: scope instances in depth-first-search post-order
 \"""
 stack = [self]
+ seen_scopes = set()
 result = []
 while stack:
 scope = stack.pop()
+
+ # Scopes aren't hashable, so we use id(scope) instead.
+ if id(scope) in seen_scopes:
+ raise OptimizeError(f"Scope {scope} has a circular scope dependency")
+ seen_scopes.add(id(scope))
+
 result.append(scope)
 stack.extend(
 itertools.chain(
""",
    )


def _patch_unnest_subqueries() -> None:
    patchy.patch(
        sqlglot.optimizer.unnest_subqueries.decorrelate,
        """\
@@ -261,16 +261,19 @@ def remove_aggs(node):
 if key in group_by:
 key.replace(nested)
 elif isinstance(predicate, exp.EQ):
- parent_predicate = _replace(
- parent_predicate,
- f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
- )
+ if parent_predicate:
+ parent_predicate = _replace(
+ parent_predicate,
+ f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
+ )
 else:
 key.replace(exp.to_identifier("_x"))
- parent_predicate = _replace(
- parent_predicate,
- f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
- )
+
+ if parent_predicate:
+ parent_predicate = _replace(
+ parent_predicate,
+ f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
+ )
""",
    )


def _patch_lineage() -> None:
    # Add the "subfield" attribute to sqlglot.lineage.Node.
    # With dataclasses, the easiest way to do this is with inheritance.
    # Unfortunately, mypy won't pick up on the new field, so we need to
    # use type ignores everywhere we use subfield.
    @dataclasses.dataclass(frozen=True)
    class Node(sqlglot.lineage.Node):
        subfield: str = ""

    sqlglot.lineage.Node = Node  # type: ignore

    patchy.patch(
        sqlglot.lineage.lineage,
        """\
@@ -12,7 +12,8 @@ def lineage(
 \"""

 expression = maybe_parse(sql, dialect=dialect)
- column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
+ # column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
+ assert isinstance(column, str)

 if sources:
 expression = exp.expand(
""",
    )

    patchy.patch(
        sqlglot.lineage.to_node,
        """\
@@ -235,11 +237,12 @@ def to_node(
 )

 # Find all columns that went into creating this one to list their lineage nodes.
- source_columns = set(find_all_in_scope(select, exp.Column))
+ source_columns = list(find_all_in_scope(select, exp.Column))

- # If the source is a UDTF find columns used in the UTDF to generate the table
+ # If the source is a UDTF find columns used in the UDTF to generate the table
+ source = scope.expression
 if isinstance(source, exp.UDTF):
- source_columns |= set(source.find_all(exp.Column))
+ source_columns += list(source.find_all(exp.Column))
 derived_tables = [
 source.expression.parent
 for source in scope.sources.values()
@@ -254,6 +257,7 @@ def to_node(
 if dt.comments and dt.comments[0].startswith("source: ")
 }

+ c: exp.Column
 for c in source_columns:
 table = c.table
 source = scope.sources.get(table)
@@ -281,8 +285,21 @@ def to_node(
 # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
 # is not passed into the `sources` map.
 source = source or exp.Placeholder()
+
+ subfields = []
+ field: exp.Expression = c
+ while isinstance(field.parent, exp.Dot):
+ field = field.parent
+ subfields.append(field.name)
+ subfield = ".".join(subfields)
+
 node.downstream.append(
- Node(name=c.sql(comments=False), source=source, expression=source)
+ Node(
+ name=c.sql(comments=False),
+ source=source,
+ expression=source,
+ subfield=subfield,
+ )
 )

 return node
""",
    )


_patch_deepcopy()
_patch_scope_traverse()
_patch_unnest_subqueries()
_patch_lineage()

SQLGLOT_PATCHED = True
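For readers unfamiliar with patchy: patchy.patch() takes a function plus a unified diff of that function's source, applies the diff, and recompiles the function in place, which is what the _patch_* helpers above do once at import time. A minimal, self-contained sketch of the same pattern; the greet function and its patch are illustrative only, not part of this commit:

import patchy


def greet(name: str) -> str:
    return "Hello, " + name


# Rewrite greet()'s body by applying a unified diff of its source.
patchy.patch(
    greet,
    """\
@@ -1,2 +1,2 @@
 def greet(name: str) -> str:
-    return "Hello, " + name
+    return "Hi, " + name
""",
)

assert greet("world") == "Hi, world"

With the override installed above, these patches are applied by apply_diff from datahub.utilities.unified_diff rather than by patchy's default subprocess-based implementation.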
@@ -1,3 +1,5 @@
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import dataclasses
import functools
import logging
@@ -53,6 +55,8 @@ from datahub.utilities.cooperative_timeout import (
    cooperative_timeout,
)

assert SQLGLOT_PATCHED

logger = logging.getLogger(__name__)

Urn = str

@@ -1,3 +1,5 @@
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import functools
import hashlib
import logging
@@ -8,6 +10,8 @@ import sqlglot
import sqlglot.errors
import sqlglot.optimizer.eliminate_ctes

assert SQLGLOT_PATCHED

logger = logging.getLogger(__name__)
DialectOrStr = Union[sqlglot.Dialect, str]
SQL_PARSE_CACHE_SIZE = 1000
5 metadata-ingestion/src/datahub/utilities/is_pytest.py Normal file
@@ -0,0 +1,5 @@
import sys


def is_pytest_running() -> bool:
    return "pytest" in sys.modules
236 metadata-ingestion/src/datahub/utilities/unified_diff.py Normal file
@@ -0,0 +1,236 @@
import logging
from dataclasses import dataclass
from typing import List, Tuple

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

_LOOKAROUND_LINES = 300

# The Python difflib library can generate unified diffs, but it cannot apply them.
# There weren't any well-maintained and easy-to-use libraries for applying
# unified diffs, so I wrote my own.
#
# My implementation is focused on ensuring correctness, and will throw
# an exception whenever it detects an issue.
#
# Alternatives considered:
# - diff-match-patch: This was the most promising since it's from Google.
#   Unfortunately, they deprecated the library in Aug 2024. That may not have
#   been a dealbreaker, since a somewhat greenfield community fork exists:
#   https://github.com/dmsnell/diff-match-patch
#   However, there's also a long-standing bug in the library around the
#   handling of line breaks when parsing diffs. See:
#   https://github.com/google/diff-match-patch/issues/157
# - python-patch: Seems abandoned.
# - patch-ng: Fork of python-patch, but mainly targeted at applying patches to trees.
#   It did not have simple "apply patch to string" abstractions.
# - unidiff: Parses diffs, but cannot apply them.


class InvalidDiffError(Exception):
    pass


class DiffApplyError(Exception):
    pass


@dataclass
class Hunk:
    source_start: int
    source_lines: int
    target_start: int
    target_lines: int
    lines: List[Tuple[str, str]]


def parse_patch(patch_text: str) -> List[Hunk]:
    """
    Parses a unified diff patch into a list of Hunk objects.

    Args:
        patch_text: Unified diff format patch text

    Returns:
        List of parsed Hunk objects

    Raises:
        InvalidDiffError: If the patch is in an invalid format
    """
    hunks = []
    patch_lines = patch_text.splitlines()
    i = 0

    while i < len(patch_lines):
        line = patch_lines[i]

        if line.startswith("@@"):
            try:
                header_parts = line.split()
                if len(header_parts) < 3:
                    raise ValueError(f"Invalid hunk header format: {line}")

                source_changes, target_changes = header_parts[1:3]
                source_start, source_lines = map(int, source_changes[1:].split(","))
                target_start, target_lines = map(int, target_changes[1:].split(","))

                hunk = Hunk(source_start, source_lines, target_start, target_lines, [])
                i += 1

                while i < len(patch_lines) and not patch_lines[i].startswith("@@"):
                    hunk_line = patch_lines[i]
                    if hunk_line:
                        hunk.lines.append((hunk_line[0], hunk_line[1:]))
                    else:
                        # Fully empty lines usually mean an empty context line that was
                        # trimmed by trailing whitespace removal.
                        hunk.lines.append((" ", ""))
                    i += 1

                hunks.append(hunk)
            except (IndexError, ValueError) as e:
                raise InvalidDiffError(f"Failed to parse hunk: {str(e)}") from e
        else:
            raise InvalidDiffError(f"Invalid line format: {line}")

    return hunks


def find_hunk_start(source_lines: List[str], hunk: Hunk) -> int:
    """
    Finds the actual starting line of a hunk in the source lines.

    Args:
        source_lines: The original source lines
        hunk: The hunk to locate

    Returns:
        The actual line number where the hunk starts

    Raises:
        DiffApplyError: If the hunk's context cannot be found in the source lines
    """

    # Extract context lines from the hunk, stopping at the first non-context line
    context_lines = []
    for prefix, line in hunk.lines:
        if prefix == " ":
            context_lines.append(line)
        else:
            break

    if not context_lines:
        logger.debug("No context lines found in hunk.")
        return hunk.source_start - 1  # Default to the original start if no context

    logger.debug(
        f"Searching for {len(context_lines)} context lines, starting with {context_lines[0]}"
    )

    # Define the range to search for the context lines
    search_start = max(0, hunk.source_start - _LOOKAROUND_LINES)
    search_end = min(len(source_lines), hunk.source_start + _LOOKAROUND_LINES)

    # Iterate over the possible starting positions in the source lines
    for i in range(search_start, search_end):
        # Check if the context lines match the source lines starting at position i
        match = True
        for j, context_line in enumerate(context_lines):
            if (i + j >= len(source_lines)) or source_lines[i + j] != context_line:
                match = False
                break
        if match:
            # logger.debug(f"Context match found at line: {i}")
            return i

    logger.debug(f"Could not find match for hunk context lines: {context_lines}")
    raise DiffApplyError("Could not find match for hunk context.")


def apply_hunk(result_lines: List[str], hunk: Hunk, hunk_index: int) -> None:
    """
    Applies a single hunk to the result lines.

    Args:
        result_lines: The current state of the patched file
        hunk: The hunk to apply
        hunk_index: The index of the hunk (for logging purposes)

    Raises:
        DiffApplyError: If the hunk cannot be applied correctly
    """
    current_line = find_hunk_start(result_lines, hunk)
    logger.debug(f"Hunk {hunk_index + 1} start line: {current_line}")

    for line_index, (prefix, content) in enumerate(hunk.lines):
        # logger.debug(f"Processing line {line_index + 1} of hunk {hunk_index + 1}")
        # logger.debug(f"Current line: {current_line}, Total lines: {len(result_lines)}")
        # logger.debug(f"Prefix: {prefix}, Content: {content}")

        if current_line >= len(result_lines):
            logger.debug(f"Reached end of file while applying hunk {hunk_index + 1}")
            while line_index < len(hunk.lines) and hunk.lines[line_index][0] == "+":
                result_lines.append(hunk.lines[line_index][1])
                line_index += 1

            # If there's context or deletions past the end of the file, that's an error.
            if line_index < len(hunk.lines):
                raise DiffApplyError(
                    f"Found context or deletions after end of file in hunk {hunk_index + 1}"
                )
            break

        if prefix == "-":
            if result_lines[current_line].strip() != content.strip():
                raise DiffApplyError(
                    f"Removing line that doesn't exactly match. Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'"
                )
            result_lines.pop(current_line)
        elif prefix == "+":
            result_lines.insert(current_line, content)
            current_line += 1
        elif prefix == " ":
            if result_lines[current_line].strip() != content.strip():
                raise DiffApplyError(
                    f"Context line doesn't exactly match. Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'"
                )
            current_line += 1
        else:
            raise DiffApplyError(
                f"Invalid line prefix '{prefix}' in hunk {hunk_index + 1}, line {line_index + 1}"
            )


def apply_diff(source: str, patch_text: str) -> str:
    """
    Applies a unified diff patch to source text and returns the patched result.

    Args:
        source: Original source text to be patched
        patch_text: Unified diff format patch text (with @@ markers and hunks)

    Returns:
        The patched text result

    Raises:
        InvalidDiffError: If the patch is in an invalid format
        DiffApplyError: If the patch cannot be applied correctly
    """

    # logger.debug(f"Original source:\n{source}")
    # logger.debug(f"Patch text:\n{patch_text}")

    hunks = parse_patch(patch_text)
    logger.debug(f"Parsed into {len(hunks)} hunks")

    source_lines = source.splitlines()
    result_lines = source_lines.copy()

    for hunk_index, hunk in enumerate(hunks):
        logger.debug(f"Processing hunk {hunk_index + 1}")
        apply_hunk(result_lines, hunk, hunk_index)

    result = "\n".join(result_lines) + "\n"
    # logger.debug(f"Patched result:\n{result}")
    return result
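Since unified_diff.py is new in this commit, a minimal usage sketch may help; the source and patch strings below are illustrative only:

from datahub.utilities.unified_diff import apply_diff

# A tiny source file and a one-hunk unified diff that rewrites its second line.
source = "line one\nline two\nline three\n"
patch = """\
@@ -1,3 +1,3 @@
 line one
-line two
+line 2
 line three
"""

# Returns "line one\nline 2\nline three\n"; raises InvalidDiffError or
# DiffApplyError instead of silently mis-applying a patch.
print(apply_diff(source, patch))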
@@ -0,0 +1,48 @@
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import time

import pytest
import sqlglot
import sqlglot.errors
import sqlglot.lineage
import sqlglot.optimizer

from datahub.utilities.cooperative_timeout import (
    CooperativeTimeoutError,
    cooperative_timeout,
)
from datahub.utilities.perf_timer import PerfTimer

assert SQLGLOT_PATCHED


def test_cooperative_timeout_sql() -> None:
    statement = sqlglot.parse_one("SELECT pg_sleep(3)", dialect="postgres")
    with pytest.raises(
        CooperativeTimeoutError
    ), PerfTimer() as timer, cooperative_timeout(timeout=0.6):
        while True:
            # sql() implicitly calls copy(), which is where we check for the timeout.
            assert statement.sql() is not None
            time.sleep(0.0001)
    assert 0.6 <= timer.elapsed_seconds() <= 1.0


def test_scope_circular_dependency() -> None:
    scope = sqlglot.optimizer.build_scope(
        sqlglot.parse_one("WITH w AS (SELECT * FROM q) SELECT * FROM w")
    )
    assert scope is not None

    cte_scope = scope.cte_scopes[0]
    cte_scope.cte_scopes.append(cte_scope)

    with pytest.raises(sqlglot.errors.OptimizeError, match="circular scope dependency"):
        list(scope.traverse())


def test_lineage_node_subfield() -> None:
    expression = sqlglot.parse_one("SELECT 1 AS test")
    node = sqlglot.lineage.Node("test", expression, expression, subfield="subfield")  # type: ignore
    assert node.subfield == "subfield"  # type: ignore
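The deepcopy patch earlier in this commit only inserts a cooperate() call; the deadline itself comes from the cooperative_timeout context manager exercised in the first test above. A rough, self-contained sketch of that contract, using only the datahub.utilities.cooperative_timeout API already referenced in this commit (long_running_loop is hypothetical):

import time

from datahub.utilities.cooperative_timeout import (
    CooperativeTimeoutError,
    cooperate,
    cooperative_timeout,
)


def long_running_loop() -> None:
    # Any loop that calls cooperate() periodically becomes interruptible.
    while True:
        cooperate()  # raises CooperativeTimeoutError once the armed deadline passes
        time.sleep(0.01)


try:
    with cooperative_timeout(timeout=0.5):
        long_running_loop()
except CooperativeTimeoutError:
    print("timed out cooperatively")

This suggests why patching Expression.__deepcopy__ is enough: copy-heavy sqlglot operations then call cooperate() on every expression copy, which a surrounding cooperative_timeout block in the SQL parsing code can turn into a hard time limit.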
191 metadata-ingestion/tests/unit/utilities/test_unified_diff.py Normal file
@@ -0,0 +1,191 @@
import pytest

from datahub.utilities.unified_diff import (
    DiffApplyError,
    Hunk,
    InvalidDiffError,
    apply_diff,
    apply_hunk,
    find_hunk_start,
    parse_patch,
)


def test_parse_patch():
    patch_text = """@@ -1,3 +1,4 @@
 Line 1
-Line 2
+Line 2 modified
+Line 2.5
 Line 3"""
    hunks = parse_patch(patch_text)
    assert len(hunks) == 1
    assert hunks[0].source_start == 1
    assert hunks[0].source_lines == 3
    assert hunks[0].target_start == 1
    assert hunks[0].target_lines == 4
    assert hunks[0].lines == [
        (" ", "Line 1"),
        ("-", "Line 2"),
        ("+", "Line 2 modified"),
        ("+", "Line 2.5"),
        (" ", "Line 3"),
    ]


def test_parse_patch_invalid():
    with pytest.raises(InvalidDiffError):
        parse_patch("Invalid patch")


def test_parse_patch_bad_header():
    # A patch with a malformed header
    bad_patch_text = """@@ -1,3
 Line 1
-Line 2
+Line 2 modified
 Line 3"""
    with pytest.raises(InvalidDiffError):
        parse_patch(bad_patch_text)


def test_find_hunk_start():
    source_lines = ["Line 1", "Line 2", "Line 3", "Line 4"]
    hunk = Hunk(2, 2, 2, 2, [(" ", "Line 2"), (" ", "Line 3")])
    assert find_hunk_start(source_lines, hunk) == 1


def test_find_hunk_start_not_found():
    source_lines = ["Line 1", "Line 2", "Line 3", "Line 4"]
    hunk = Hunk(2, 2, 2, 2, [(" ", "Line X"), (" ", "Line Y")])
    with pytest.raises(DiffApplyError, match="Could not find match for hunk context."):
        find_hunk_start(source_lines, hunk)


def test_apply_hunk_success():
    result_lines = ["Line 1", "Line 2", "Line 3"]
    hunk = Hunk(
        2,
        2,
        2,
        3,
        [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified"), ("+", "Line 3.5")],
    )
    apply_hunk(result_lines, hunk, 0)
    assert result_lines == ["Line 1", "Line 2", "Line 3 modified", "Line 3.5"]


def test_apply_hunk_mismatch():
    result_lines = ["Line 1", "Line 2", "Line X"]
    hunk = Hunk(
        2, 2, 2, 2, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
    )
    with pytest.raises(
        DiffApplyError, match="Removing line that doesn't exactly match"
    ):
        apply_hunk(result_lines, hunk, 0)


def test_apply_hunk_context_mismatch():
    result_lines = ["Line 1", "Line 3"]
    hunk = Hunk(2, 2, 2, 2, [(" ", "Line 1"), ("+", "Line 2"), (" ", "Line 4")])
    with pytest.raises(DiffApplyError, match="Context line doesn't exactly match"):
        apply_hunk(result_lines, hunk, 0)


def test_apply_hunk_invalid_prefix():
    result_lines = ["Line 1", "Line 2", "Line 3"]
    hunk = Hunk(
        2, 2, 2, 2, [(" ", "Line 2"), ("*", "Line 3"), ("+", "Line 3 modified")]
    )
    with pytest.raises(DiffApplyError, match="Invalid line prefix"):
        apply_hunk(result_lines, hunk, 0)


def test_apply_hunk_end_of_file():
    result_lines = ["Line 1", "Line 2"]
    hunk = Hunk(
        2, 2, 2, 3, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
    )
    with pytest.raises(
        DiffApplyError, match="Found context or deletions after end of file"
    ):
        apply_hunk(result_lines, hunk, 0)


def test_apply_hunk_context_beyond_end_of_file():
    result_lines = ["Line 1", "Line 3"]
    hunk = Hunk(
        2, 2, 2, 3, [(" ", "Line 1"), ("+", "Line 2"), (" ", "Line 3"), (" ", "Line 4")]
    )
    with pytest.raises(
        DiffApplyError, match="Found context or deletions after end of file"
    ):
        apply_hunk(result_lines, hunk, 0)


def test_apply_hunk_remove_non_existent_line():
    result_lines = ["Line 1", "Line 2", "Line 4"]
    hunk = Hunk(
        2, 2, 2, 3, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
    )
    with pytest.raises(
        DiffApplyError, match="Removing line that doesn't exactly match"
    ):
        apply_hunk(result_lines, hunk, 0)


def test_apply_hunk_addition_beyond_end_of_file():
    result_lines = ["Line 1", "Line 2"]
    hunk = Hunk(
        2, 2, 2, 3, [(" ", "Line 2"), ("+", "Line 3 modified"), ("+", "Line 4")]
    )
    apply_hunk(result_lines, hunk, 0)
    assert result_lines == ["Line 1", "Line 2", "Line 3 modified", "Line 4"]


def test_apply_diff():
    source = """Line 1
Line 2
Line 3
Line 4"""
    patch = """@@ -1,4 +1,5 @@
 Line 1
-Line 2
+Line 2 modified
+Line 2.5
 Line 3
 Line 4"""
    result = apply_diff(source, patch)
    expected = """Line 1
Line 2 modified
Line 2.5
Line 3
Line 4
"""
    assert result == expected


def test_apply_diff_invalid_patch():
    source = "Line 1\nLine 2\n"
    patch = "Invalid patch"
    with pytest.raises(InvalidDiffError):
        apply_diff(source, patch)


def test_apply_diff_unapplicable_patch():
    source = "Line 1\nLine 2\n"
    patch = "@@ -1,2 +1,2 @@\n Line 1\n-Line X\n+Line 2 modified\n"
    with pytest.raises(DiffApplyError):
        apply_diff(source, patch)


def test_apply_diff_add_to_empty_file():
    source = ""
    patch = """\
@@ -1,0 +1,1 @@
+Line 1
+Line 2
"""
    result = apply_diff(source, patch)
    assert result == "Line 1\nLine 2\n"

@@ -1,6 +1,7 @@
import doctest

from datahub.utilities.delayed_iter import delayed_iter
from datahub.utilities.is_pytest import is_pytest_running
from datahub.utilities.sql_parser import SqlLineageSQLParser

@@ -295,3 +296,7 @@ def test_logging_name_extraction():
).attempted
> 0
)


def test_is_pytest_running() -> None:
    assert is_pytest_running()