diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
index de648ec29b..b8dce11616 100644
--- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
+++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -406,10 +406,11 @@ def _column_level_lineage(  # noqa: C901
             return default_col_name
 
     # Optimize the statement + qualify column references.
-    logger.debug(
-        "Prior to column qualification sql %s",
-        statement.sql(pretty=True, dialect=dialect),
-    )
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            "Prior to column qualification sql %s",
+            statement.sql(pretty=True, dialect=dialect),
+        )
     try:
         # Second time running qualify, this time with:
         # - the select instead of the full outer statement
@@ -434,7 +435,8 @@ def _column_level_lineage(  # noqa: C901
         raise SqlUnderstandingError(
             f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}"
         ) from e
-    logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
 
     # Handle the create DDL case.
     if is_create_ddl:
@@ -805,7 +807,7 @@ def _sqlglot_lineage_inner(
     logger.debug("Parsing lineage from sql statement: %s", sql)
     statement = parse_statement(sql, dialect=dialect)
 
-    original_statement = statement.copy()
+    original_statement, statement = statement, statement.copy()
     # logger.debug(
     #     "Formatted sql statement: %s",
     #     original_statement.sql(pretty=True, dialect=dialect),
diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
index c7cf975a3a..25988f9905 100644
--- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
+++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
@@ -1,3 +1,4 @@
+import functools
 import hashlib
 import logging
 from typing import Dict, Iterable, Optional, Tuple, Union
@@ -7,6 +8,7 @@ import sqlglot.errors
 logger = logging.getLogger(__name__)
 
 DialectOrStr = Union[sqlglot.Dialect, str]
+SQL_PARSE_CACHE_SIZE = 1000
 
 
 def _get_dialect_str(platform: str) -> str:
@@ -55,7 +57,8 @@ def is_dialect_instance(
         return False
 
 
-def parse_statement(
+@functools.lru_cache(maxsize=SQL_PARSE_CACHE_SIZE)
+def _parse_statement(
     sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
 ) -> sqlglot.Expression:
     statement: sqlglot.Expression = sqlglot.maybe_parse(
@@ -64,6 +67,16 @@ def parse_statement(
     return statement
 
 
+def parse_statement(
+    sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
+) -> sqlglot.Expression:
+    # Parsing is significantly more expensive than copying the expression.
+    # Because the expressions are mutable, we don't want to allow the caller
+    # to modify the parsed expression that sits in the cache. We keep
+    # the cached versions pristine by returning a copy on each call.
+    return _parse_statement(sql, dialect).copy()
+
+
 def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expression:
     dialect = get_dialect(platform)
     statements = [
@@ -277,4 +290,5 @@ def detach_ctes(
         else:
             return node
 
+    statement = statement.copy()
     return statement.transform(replace_cte_refs, copy=False)
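
For context on the two patterns the patch leans on, here is a minimal standalone sketch (not part of the diff): an `functools.lru_cache`-backed parse helper that returns a defensive copy so callers can mutate their expression freely without corrupting the cached one, plus the `logger.isEnabledFor(logging.DEBUG)` guard that skips pretty-printing unless DEBUG logging is actually enabled. The names `_cached_parse` and `parse`, the example query, and the `"snowflake"` dialect are illustrative only; the real helpers in the patch are `_parse_statement` / `parse_statement` in `sqlglot_utils.py`.

```python
import functools
import logging

import sqlglot

logger = logging.getLogger(__name__)

SQL_PARSE_CACHE_SIZE = 1000  # mirrors the constant added in sqlglot_utils.py


@functools.lru_cache(maxsize=SQL_PARSE_CACHE_SIZE)
def _cached_parse(sql: str, dialect: str) -> sqlglot.Expression:
    # Expensive step: actually parse the SQL text into an expression tree.
    return sqlglot.parse_one(sql, dialect=dialect)


def parse(sql: str, dialect: str) -> sqlglot.Expression:
    # Cheap step: hand each caller its own copy so later mutations
    # (qualification, optimization, CTE detachment) never touch the
    # pristine expression sitting in the cache.
    return _cached_parse(sql, dialect).copy()


if __name__ == "__main__":
    a = parse("SELECT id FROM users", dialect="snowflake")
    b = parse("SELECT id FROM users", dialect="snowflake")
    assert a is not b  # distinct objects, safe to mutate independently
    assert a.sql() == b.sql()  # but structurally identical
    print(_cached_parse.cache_info())  # hits=1 after the second call

    # Guarded debug logging: the pretty-printed SQL is only rendered
    # when DEBUG is actually enabled for this logger.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Parsed sql %s", a.sql(pretty=True, dialect="snowflake"))
```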