perf(ingest): streamline CLL generation (#11645)

Author: Harshal Sheth
Date: 2024-10-17 17:50:59 -07:00 (committed via GitHub)
Parent: d7ccf3b2c3
Commit: 6e3724b2da
7 changed files with 118 additions and 28 deletions


@@ -101,7 +101,7 @@ usage_common = {
 sqlglot_lib = {
     # Using an Acryl fork of sqlglot.
     # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
-    "acryl-sqlglot[rs]==25.20.2.dev6",
+    "acryl-sqlglot[rs]==25.25.2.dev9",
 }
 
 classification_lib = {


@@ -133,7 +133,7 @@ def parse_alter_table_rename(default_schema: str, query: str) -> Tuple[str, str,
     assert isinstance(parsed_query, sqlglot.exp.Alter)
     prev_name = parsed_query.this.name
     rename_clause = parsed_query.args["actions"][0]
-    assert isinstance(rename_clause, sqlglot.exp.RenameTable)
+    assert isinstance(rename_clause, sqlglot.exp.AlterRename)
     new_name = rename_clause.this.name
     schema = parsed_query.this.db or default_schema
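For context, a minimal sqlglot sketch of the rename parsing this hunk updates. The dialect and table names are illustrative; AlterRename is the node name used by the newer pinned acryl-sqlglot (it previously appeared as RenameTable):

import sqlglot

parsed = sqlglot.parse_one(
    "ALTER TABLE my_schema.old_table RENAME TO new_table", dialect="redshift"
)
assert isinstance(parsed, sqlglot.exp.Alter)
rename_clause = parsed.args["actions"][0]  # the single RENAME action
# In the pinned sqlglot version this node is exp.AlterRename (was exp.RenameTable).
print(type(rename_clause).__name__)  # AlterRename
print(parsed.this.db, parsed.this.name, "->", rename_clause.this.name)
# my_schema old_table -> new_table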


@@ -2131,7 +2131,7 @@ class TableauSiteSource:
         fine_grained_lineages: List[FineGrainedLineage] = []
         if self.config.extract_column_level_lineage:
-            logger.info("Extracting CLL from custom sql")
+            logger.debug("Extracting CLL from custom sql")
             fine_grained_lineages = make_fine_grained_lineage_class(
                 parsed_result, csql_urn, out_columns
             )


@@ -1,6 +1,5 @@
 import dataclasses
 import functools
-import itertools
 import logging
 import traceback
 from collections import defaultdict
@@ -14,6 +13,8 @@ import sqlglot.optimizer
 import sqlglot.optimizer.annotate_types
 import sqlglot.optimizer.optimizer
 import sqlglot.optimizer.qualify
+import sqlglot.optimizer.qualify_columns
+import sqlglot.optimizer.unnest_subqueries
 from datahub.cli.env_utils import get_boolean_env_variable
 from datahub.ingestion.graph.client import DataHubGraph
@@ -63,24 +64,30 @@ SQL_LINEAGE_TIMEOUT_ENABLED = get_boolean_env_variable(
 SQL_LINEAGE_TIMEOUT_SECONDS = 10
 
-RULES_BEFORE_TYPE_ANNOTATION: tuple = tuple(
-    filter(
-        lambda func: func.__name__
-        not in {
-            # Skip pushdown_predicates because it sometimes throws exceptions, and we
-            # don't actually need it for anything.
-            "pushdown_predicates",
-            # Skip normalize because it can sometimes be expensive.
-            "normalize",
-        },
-        itertools.takewhile(
-            lambda func: func != sqlglot.optimizer.annotate_types.annotate_types,
-            sqlglot.optimizer.optimizer.RULES,
-        ),
-    )
-)
-# Quick check that the rules were loaded correctly.
-assert 0 < len(RULES_BEFORE_TYPE_ANNOTATION) < len(sqlglot.optimizer.optimizer.RULES)
+# These rules are a subset of the rules in sqlglot.optimizer.optimizer.RULES.
+# If there's a change in their rules, we probably need to re-evaluate our list as well.
+assert len(sqlglot.optimizer.optimizer.RULES) == 14
+
+_OPTIMIZE_RULES = (
+    sqlglot.optimizer.optimizer.qualify,
+    # We need to enable this in order for annotate types to work.
+    sqlglot.optimizer.optimizer.pushdown_projections,
+    # sqlglot.optimizer.optimizer.normalize,  # causes perf issues
+    sqlglot.optimizer.optimizer.unnest_subqueries,
+    # sqlglot.optimizer.optimizer.pushdown_predicates,  # causes perf issues
+    # sqlglot.optimizer.optimizer.optimize_joins,
+    # sqlglot.optimizer.optimizer.eliminate_subqueries,
+    # sqlglot.optimizer.optimizer.merge_subqueries,
+    # sqlglot.optimizer.optimizer.eliminate_joins,
+    # sqlglot.optimizer.optimizer.eliminate_ctes,
+    sqlglot.optimizer.optimizer.quote_identifiers,
+    # These three are run separately or not run at all.
+    # sqlglot.optimizer.optimizer.annotate_types,
+    # sqlglot.optimizer.canonicalize.canonicalize,
+    # sqlglot.optimizer.simplify.simplify,
+)
+
+_DEBUG_TYPE_ANNOTATIONS = False
 
 
 class _ColumnRef(_FrozenModel):
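As a rough sketch of how such a trimmed rule list is consumed: sqlglot's optimize() accepts a rules iterable and applies each rule in order. This is standalone sqlglot usage with an invented query and schema, not DataHub's actual call site (that is _prepare_query_columns below):

import sqlglot
import sqlglot.optimizer.optimizer

statement = sqlglot.parse_one("SELECT a, b FROM t", dialect="bigquery")
optimized = sqlglot.optimizer.optimizer.optimize(
    statement,
    schema={"t": {"a": "INT64", "b": "STRING"}},
    dialect="bigquery",
    # Pass only the rules we want; each one runs in order over the expression.
    rules=(
        sqlglot.optimizer.optimizer.qualify,
        sqlglot.optimizer.optimizer.pushdown_projections,
        sqlglot.optimizer.optimizer.quote_identifiers,
    ),
)
print(optimized.sql(dialect="bigquery"))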
@@ -385,11 +392,12 @@ def _prepare_query_columns(
             schema=sqlglot_db_schema,
             qualify_columns=True,
             validate_qualify_columns=False,
+            allow_partial_qualification=True,
             identify=True,
             # sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table".
             catalog=default_db,
             db=default_schema,
-            rules=RULES_BEFORE_TYPE_ANNOTATION,
+            rules=_OPTIMIZE_RULES,
         )
     except (sqlglot.errors.OptimizeError, ValueError) as e:
         raise SqlUnderstandingError(
@@ -408,6 +416,10 @@ def _prepare_query_columns(
     except (sqlglot.errors.OptimizeError, sqlglot.errors.ParseError) as e:
         # This is not a fatal error, so we can continue.
         logger.debug("sqlglot failed to annotate or parse types: %s", e)
+    if _DEBUG_TYPE_ANNOTATIONS and logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            "Type annotated sql %s", statement.sql(pretty=True, dialect=dialect)
+        )
 
     return statement, _ColumnResolver(
         sqlglot_db_schema=sqlglot_db_schema,
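For reference, a minimal sketch of the separate type-annotation step that the new _DEBUG_TYPE_ANNOTATIONS flag reports on. The table, column, and types below are invented for illustration:

import sqlglot
import sqlglot.optimizer.annotate_types
import sqlglot.optimizer.qualify

schema = {"t": {"a": "INT64"}}
stmt = sqlglot.parse_one("SELECT a + 1 AS b FROM t", dialect="bigquery")
# Qualification has to happen first so column references resolve against the schema.
qualified = sqlglot.optimizer.qualify.qualify(stmt, schema=schema, dialect="bigquery")
annotated = sqlglot.optimizer.annotate_types.annotate_types(qualified, schema=schema)
print(annotated.selects[0].type)  # the inferred DataType for the "b" projection
print(annotated.sql(pretty=True, dialect="bigquery"))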
@@ -907,6 +919,7 @@ def _sqlglot_lineage_inner(
         # At this stage we only want to qualify the table names. The columns will be dealt with later.
         qualify_columns=False,
         validate_qualify_columns=False,
+        allow_partial_qualification=True,
         # Only insert quotes where necessary.
         identify=False,
     )
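A hedged sketch of that first, table-only qualification pass, using standalone sqlglot with made-up defaults. The allow_partial_qualification flag added in the diff relaxes column qualification in the pinned acryl-sqlglot and is omitted here:

import sqlglot
import sqlglot.optimizer.qualify

stmt = sqlglot.parse_one("SELECT user_id FROM UserDetail", dialect="bigquery")
qualified = sqlglot.optimizer.qualify.qualify(
    stmt,
    dialect="bigquery",
    catalog="my_bigquery_project",  # corresponds to default_db in the diff
    db="invent_dw",                 # corresponds to default_schema in the diff
    qualify_columns=False,          # table names only; columns are handled later
    validate_qualify_columns=False,
    identify=False,
)
print(qualified.sql(dialect="bigquery"))
# The table reference is now fully qualified, e.g.
# SELECT user_id FROM my_bigquery_project.invent_dw.UserDetail AS UserDetail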


@@ -19,6 +19,7 @@ from tableauserverclient.models import (
 from datahub.configuration.source_common import DEFAULT_ENV
 from datahub.emitter.mce_builder import make_schema_field_urn
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
 from datahub.ingestion.source.tableau.tableau import (
     TableauConfig,
@@ -37,7 +38,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
     FineGrainedLineageUpstreamType,
     UpstreamLineage,
 )
-from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass
+from datahub.metadata.schema_classes import UpstreamClass
 from tests.test_helpers import mce_helpers, test_connection_helpers
 from tests.test_helpers.state_helpers import (
     get_current_checkpoint_from_pipeline,
@@ -939,11 +940,12 @@ def test_tableau_unsupported_csql():
         database_override_map={"production database": "prod"}
     )
 
-    def test_lineage_metadata(
+    def check_lineage_metadata(
         lineage, expected_entity_urn, expected_upstream_table, expected_cll
     ):
-        mcp = cast(MetadataChangeProposalClass, next(iter(lineage)).metadata)
-        assert mcp.aspect == UpstreamLineage(
+        mcp = cast(MetadataChangeProposalWrapper, list(lineage)[0].metadata)
+
+        expected = UpstreamLineage(
             upstreams=[
                 UpstreamClass(
                     dataset=expected_upstream_table,
@@ -966,6 +968,9 @@ def test_tableau_unsupported_csql():
         )
         assert mcp.entityUrn == expected_entity_urn
 
+        actual_aspect = mcp.aspect
+        assert actual_aspect == expected
+
     csql_urn = "urn:li:dataset:(urn:li:dataPlatform:tableau,09988088-05ad-173c-a2f1-f33ba3a13d1a,PROD)"
     expected_upstream_table = "urn:li:dataset:(urn:li:dataPlatform:bigquery,my_bigquery_project.invent_dw.UserDetail,PROD)"
     expected_cll = {
@@ -996,7 +1001,7 @@ def test_tableau_unsupported_csql():
         },
         out_columns=[],
     )
-    test_lineage_metadata(
+    check_lineage_metadata(
         lineage=lineage,
         expected_entity_urn=csql_urn,
         expected_upstream_table=expected_upstream_table,
@@ -1014,7 +1019,7 @@ def test_tableau_unsupported_csql():
         },
         out_columns=[],
     )
-    test_lineage_metadata(
+    check_lineage_metadata(
         lineage=lineage,
         expected_entity_urn=csql_urn,
         expected_upstream_table=expected_upstream_table,
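For readers unfamiliar with the assertion pattern, a small hedged sketch of what the renamed check_lineage_metadata helper compares: an emitted MetadataChangeProposalWrapper carries the aspect object directly, so tests can check it against a freshly built UpstreamLineage. The *Class names below from datahub.metadata.schema_classes are assumed equivalent to the pegasus2avro imports used in this test; the URNs are copied from the test data:

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import (
    DatasetLineageTypeClass,
    UpstreamClass,
    UpstreamLineageClass,
)

def make_lineage() -> UpstreamLineageClass:
    return UpstreamLineageClass(
        upstreams=[
            UpstreamClass(
                dataset="urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)",
                type=DatasetLineageTypeClass.TRANSFORMED,
            )
        ]
    )

mcp = MetadataChangeProposalWrapper(
    entityUrn="urn:li:dataset:(urn:li:dataPlatform:tableau,09988088-05ad-173c-a2f1-f33ba3a13d1a,PROD)",
    aspect=make_lineage(),
)
# Aspect objects compare structurally, so a freshly built expected value matches.
assert mcp.aspect == make_lineage()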


@@ -0,0 +1,57 @@
{
  "query_type": "SELECT",
  "query_type_props": {},
  "query_fingerprint": "4094ebd230c1d47c7e6879b05ab927e550923b1986eb58c5f3814396cf401d18",
  "in_tables": [
    "urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)"
  ],
  "out_tables": [],
  "column_lineage": [
    {
      "downstream": {
        "table": null,
        "column": "user_id",
        "column_type": null,
        "native_column_type": null
      },
      "upstreams": [
        {
          "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)",
          "column": "user_id"
        }
      ]
    },
    {
      "downstream": {
        "table": null,
        "column": "source",
        "column_type": null,
        "native_column_type": null
      },
      "upstreams": [
        {
          "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)",
          "column": "source"
        }
      ]
    },
    {
      "downstream": {
        "table": null,
        "column": "user_source",
        "column_type": null,
        "native_column_type": null
      },
      "upstreams": [
        {
          "table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)",
          "column": "user_source"
        }
      ]
    }
  ],
  "debug_info": {
    "confidence": 0.2,
    "generalized_statement": "SELECT user_id, source, user_source FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY __partition_day DESC) AS rank_ FROM invent_dw.UserDetail) AS source_user WHERE rank_ = ?"
  }
}


@@ -1253,3 +1253,18 @@ DROP SCHEMA my_schema
         dialect="snowflake",
         expected_file=RESOURCE_DIR / "test_snowflake_drop_schema.json",
     )
+
+
+def test_bigquery_subquery_column_inference() -> None:
+    assert_sql_result(
+        """\
+SELECT user_id, source, user_source
+FROM (
+    SELECT *, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY __partition_day DESC) AS rank_
+    FROM invent_dw.UserDetail
+) source_user
+WHERE rank_ = 1
+""",
+        dialect="bigquery",
+        expected_file=RESOURCE_DIR / "test_bigquery_subquery_column_inference.json",
+    )
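As a loose illustration of the column inference this test and its golden file exercise, the same query can be traced with plain sqlglot's lineage helper. This is not DataHub's sqlglot_lineage wrapper, so the output shape differs from the golden JSON, and the schema below is invented:

import sqlglot.lineage

SCHEMA = {
    "invent_dw": {
        "UserDetail": {
            "user_id": "STRING",
            "source": "STRING",
            "user_source": "STRING",
            "__partition_day": "DATE",
        }
    }
}

SQL = """
SELECT user_id, source, user_source
FROM (
    SELECT *, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY __partition_day DESC) AS rank_
    FROM invent_dw.UserDetail
) source_user
WHERE rank_ = 1
"""

# Trace one output column back through the subquery to the source table.
node = sqlglot.lineage.lineage("user_id", SQL, schema=SCHEMA, dialect="bigquery")
for n in node.walk():
    print(n.name)  # walks from the outer user_id down to invent_dw.UserDetail's user_id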