mirror of
https://github.com/datahub-project/datahub.git
synced 2025-09-04 06:43:16 +00:00
feat(cli): cache sql parsing intermediates (#10399)
This commit is contained in:
parent
1dae37a8ed
commit
0e8fc5129f
@ -406,10 +406,11 @@ def _column_level_lineage( # noqa: C901
|
|||||||
return default_col_name
|
return default_col_name
|
||||||
|
|
||||||
# Optimize the statement + qualify column references.
|
# Optimize the statement + qualify column references.
|
||||||
logger.debug(
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
"Prior to column qualification sql %s",
|
logger.debug(
|
||||||
statement.sql(pretty=True, dialect=dialect),
|
"Prior to column qualification sql %s",
|
||||||
)
|
statement.sql(pretty=True, dialect=dialect),
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# Second time running qualify, this time with:
|
# Second time running qualify, this time with:
|
||||||
# - the select instead of the full outer statement
|
# - the select instead of the full outer statement
|
||||||
@ -434,7 +435,8 @@ def _column_level_lineage( # noqa: C901
|
|||||||
raise SqlUnderstandingError(
|
raise SqlUnderstandingError(
|
||||||
f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}"
|
f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}"
|
||||||
) from e
|
) from e
|
||||||
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
|
||||||
|
|
||||||
# Handle the create DDL case.
|
# Handle the create DDL case.
|
||||||
if is_create_ddl:
|
if is_create_ddl:
|
||||||
@ -805,7 +807,7 @@ def _sqlglot_lineage_inner(
|
|||||||
logger.debug("Parsing lineage from sql statement: %s", sql)
|
logger.debug("Parsing lineage from sql statement: %s", sql)
|
||||||
statement = parse_statement(sql, dialect=dialect)
|
statement = parse_statement(sql, dialect=dialect)
|
||||||
|
|
||||||
original_statement = statement.copy()
|
original_statement, statement = statement, statement.copy()
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# "Formatted sql statement: %s",
|
# "Formatted sql statement: %s",
|
||||||
# original_statement.sql(pretty=True, dialect=dialect),
|
# original_statement.sql(pretty=True, dialect=dialect),
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Iterable, Optional, Tuple, Union
|
from typing import Dict, Iterable, Optional, Tuple, Union
|
||||||
@ -7,6 +8,7 @@ import sqlglot.errors
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
DialectOrStr = Union[sqlglot.Dialect, str]
|
DialectOrStr = Union[sqlglot.Dialect, str]
|
||||||
|
SQL_PARSE_CACHE_SIZE = 1000
|
||||||
|
|
||||||
|
|
||||||
def _get_dialect_str(platform: str) -> str:
|
def _get_dialect_str(platform: str) -> str:
|
||||||
@ -55,7 +57,8 @@ def is_dialect_instance(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def parse_statement(
|
@functools.lru_cache(maxsize=SQL_PARSE_CACHE_SIZE)
|
||||||
|
def _parse_statement(
|
||||||
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
|
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
|
||||||
) -> sqlglot.Expression:
|
) -> sqlglot.Expression:
|
||||||
statement: sqlglot.Expression = sqlglot.maybe_parse(
|
statement: sqlglot.Expression = sqlglot.maybe_parse(
|
||||||
@ -64,6 +67,16 @@ def parse_statement(
|
|||||||
return statement
|
return statement
|
||||||
|
|
||||||
|
|
||||||
|
def parse_statement(
|
||||||
|
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
|
||||||
|
) -> sqlglot.Expression:
|
||||||
|
# Parsing is significantly more expensive than copying the expression.
|
||||||
|
# Because the expressions are mutable, we don't want to allow the caller
|
||||||
|
# to modify the parsed expression that sits in the cache. We keep
|
||||||
|
# the cached versions pristine by returning a copy on each call.
|
||||||
|
return _parse_statement(sql, dialect).copy()
|
||||||
|
|
||||||
|
|
||||||
def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expression:
|
def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expression:
|
||||||
dialect = get_dialect(platform)
|
dialect = get_dialect(platform)
|
||||||
statements = [
|
statements = [
|
||||||
@ -277,4 +290,5 @@ def detach_ctes(
|
|||||||
else:
|
else:
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
statement = statement.copy()
|
||||||
return statement.transform(replace_cte_refs, copy=False)
|
return statement.transform(replace_cte_refs, copy=False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user