feat(cli): cache sql parsing intermediates (#10399)

This commit is contained in:
Harshal Sheth 2024-05-06 16:59:00 -07:00 committed by GitHub
parent 1dae37a8ed
commit 0e8fc5129f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 7 deletions

View File

@ -406,6 +406,7 @@ def _column_level_lineage( # noqa: C901
return default_col_name
# Optimize the statement + qualify column references.
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Prior to column qualification sql %s",
statement.sql(pretty=True, dialect=dialect),
@ -434,6 +435,7 @@ def _column_level_lineage( # noqa: C901
raise SqlUnderstandingError(
f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}"
) from e
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
# Handle the create DDL case.
@ -805,7 +807,7 @@ def _sqlglot_lineage_inner(
logger.debug("Parsing lineage from sql statement: %s", sql)
statement = parse_statement(sql, dialect=dialect)
original_statement = statement.copy()
original_statement, statement = statement, statement.copy()
# logger.debug(
# "Formatted sql statement: %s",
# original_statement.sql(pretty=True, dialect=dialect),

View File

@ -1,3 +1,4 @@
import functools
import hashlib
import logging
from typing import Dict, Iterable, Optional, Tuple, Union
@ -7,6 +8,7 @@ import sqlglot.errors
logger = logging.getLogger(__name__)
DialectOrStr = Union[sqlglot.Dialect, str]
SQL_PARSE_CACHE_SIZE = 1000
def _get_dialect_str(platform: str) -> str:
@ -55,7 +57,8 @@ def is_dialect_instance(
return False
def parse_statement(
@functools.lru_cache(maxsize=SQL_PARSE_CACHE_SIZE)
def _parse_statement(
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
) -> sqlglot.Expression:
statement: sqlglot.Expression = sqlglot.maybe_parse(
@ -64,6 +67,16 @@ def parse_statement(
return statement
def parse_statement(
    sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
) -> sqlglot.Expression:
    """Parse ``sql`` for ``dialect`` and hand back a caller-owned expression.

    The actual parsing is delegated to the ``lru_cache``-decorated
    ``_parse_statement`` helper, so repeated calls with the same arguments
    reuse the previously parsed expression instead of parsing again.

    Returns:
        A copy of the cached expression.
    """
    # Copying is much cheaper than re-parsing. Expressions are mutable, so
    # the cached object itself must never escape: each caller gets its own
    # copy and the cache entry stays pristine.
    cached_expression = _parse_statement(sql, dialect)
    return cached_expression.copy()
def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expression:
dialect = get_dialect(platform)
statements = [
@ -277,4 +290,5 @@ def detach_ctes(
else:
return node
statement = statement.copy()
return statement.transform(replace_cte_refs, copy=False)