mirror of https://github.com/datahub-project/datahub.git, synced 2025-12-26 17:37:33 +00:00
fix(ingestion/snowflake): Address diamond lineage problem + performance improvements (#13918)
parent 1f8abfc762
commit f0a6e016e8
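The diamond problem this commit targets: when temp tables fan out and re-join (e.g. source1 -> t1 -> {t2, t3} -> t4 -> destination), resolving lineage by walking each path independently and concatenating lists makes every shared ancestor appear once per path, and the shared subtree is re-resolved once per path. The sketch below is a minimal, stand-alone illustration (hypothetical table names and helpers, not the aggregator's real data structures) of why list concatenation duplicates upstreams while set semantics plus a memo cache keep them unique and visit each node once.

    # Hypothetical sketch of the diamond pattern; names are illustrative only.
    from typing import Dict, List, Set

    # source1 -> t1 -> {t2, t3} -> t4
    graph: Dict[str, List[str]] = {
        "t4": ["t2", "t3"],
        "t2": ["t1"],
        "t3": ["t1"],
        "t1": ["source1"],
    }

    def upstreams_as_list(table: str) -> List[str]:
        # List concatenation: every path to a shared ancestor contributes
        # another copy, so "t1" and "source1" show up once per path.
        result: List[str] = []
        for parent in graph.get(table, []):
            result += [parent] + upstreams_as_list(parent)
        return result

    def upstreams_as_set(table: str, cache: Dict[str, Set[str]]) -> Set[str]:
        # Set semantics plus memoization: each table is resolved once and
        # each upstream is reported once, however many paths reach it.
        if table in cache:
            return cache[table]
        result: Set[str] = set()
        for parent in graph.get(table, []):
            result.add(parent)
            result |= upstreams_as_set(parent, cache)
        cache[table] = result
        return result

    print(upstreams_as_list("t4"))             # ['t2', 't1', 'source1', 't3', 't1', 'source1']
    print(sorted(upstreams_as_set("t4", {})))  # ['source1', 't1', 't2', 't3']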
@@ -1577,27 +1577,33 @@ class SqlParsingAggregator(Closeable):
         @dataclasses.dataclass
         class QueryLineageInfo:
-            upstreams: List[UrnStr]  # this is direct upstreams, with *no temp tables*
-            column_lineage: List[ColumnLineageInfo]
+            upstreams: OrderedSet[
+                UrnStr
+            ]  # this is direct upstreams, with *no temp tables*
+            column_lineage: OrderedSet[ColumnLineageInfo]
             confidence_score: float
 
             def _merge_lineage_from(self, other_query: "QueryLineageInfo") -> None:
-                self.upstreams += other_query.upstreams
-                self.column_lineage += other_query.column_lineage
+                self.upstreams.update(other_query.upstreams)
+                self.column_lineage.update(other_query.column_lineage)
                 self.confidence_score = min(
                     self.confidence_score, other_query.confidence_score
                 )
 
+        cache: Dict[str, QueryLineageInfo] = {}
+
         def _recurse_into_query(
             query: QueryMetadata, recursion_path: List[QueryId]
         ) -> QueryLineageInfo:
             if query.query_id in recursion_path:
                 # This is a cycle, so we just return the query as-is.
                 return QueryLineageInfo(
-                    upstreams=query.upstreams,
-                    column_lineage=query.column_lineage,
+                    upstreams=OrderedSet(query.upstreams),
+                    column_lineage=OrderedSet(query.column_lineage),
                     confidence_score=query.confidence_score,
                 )
+            if query.query_id in cache:
+                return cache[query.query_id]
             recursion_path = [*recursion_path, query.query_id]
             composed_of_queries.add(query.query_id)
 
@@ -1612,7 +1618,7 @@ class SqlParsingAggregator(Closeable):
                 upstream_query = self._query_map.get(upstream_query_id)
                 if (
                     upstream_query
                     and upstream_query.query_id not in composed_of_queries
                     and upstream_query.query_id not in recursion_path
                 ):
                     temp_query_lineage_info = _recurse_into_query(
                         upstream_query, recursion_path
 
@@ -1672,11 +1678,14 @@ class SqlParsingAggregator(Closeable):
                 ]
             )
 
-            return QueryLineageInfo(
-                upstreams=list(new_upstreams),
-                column_lineage=new_cll,
+            ret = QueryLineageInfo(
+                upstreams=new_upstreams,
+                column_lineage=OrderedSet(new_cll),
                 confidence_score=new_confidence_score,
             )
+            cache[query.query_id] = ret
+
+            return ret
 
         resolved_lineage_info = _recurse_into_query(base_query, [])
 
@@ -1716,8 +1725,8 @@ class SqlParsingAggregator(Closeable):
             base_query,
             query_id=composite_query_id,
             formatted_query_string=merged_query_text,
-            upstreams=resolved_lineage_info.upstreams,
-            column_lineage=resolved_lineage_info.column_lineage,
+            upstreams=list(resolved_lineage_info.upstreams),
+            column_lineage=list(resolved_lineage_info.column_lineage),
             confidence_score=resolved_lineage_info.confidence_score,
         )
 
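Taken together, the hunks above swap List for OrderedSet in QueryLineageInfo, merge with update() instead of +=, and memoize _recurse_into_query results in cache, so a query reachable through several temp-table paths is resolved once and its upstreams are deduplicated. Below is a minimal sketch of the merge semantics only, with a trimmed-down QueryLineageInfo and a simplified stand-in for OrderedSet (insertion-ordered, deduplicating), not DataHub's actual implementation.

    # Hedged sketch of the new merge behavior; MiniOrderedSet is a stand-in,
    # and the dataclass keeps only the fields needed to show the effect.
    import dataclasses
    from typing import Dict, Iterable


    class MiniOrderedSet:
        def __init__(self, items: Iterable[str] = ()) -> None:
            self._items: Dict[str, None] = dict.fromkeys(items)

        def update(self, other: Iterable[str]) -> None:
            # Duplicates collapse instead of piling up as they did with `+=` on lists.
            self._items.update(dict.fromkeys(other))

        def __iter__(self):
            return iter(self._items)


    @dataclasses.dataclass
    class QueryLineageInfo:
        upstreams: MiniOrderedSet
        confidence_score: float

        def _merge_lineage_from(self, other: "QueryLineageInfo") -> None:
            self.upstreams.update(other.upstreams)
            self.confidence_score = min(self.confidence_score, other.confidence_score)


    left = QueryLineageInfo(MiniOrderedSet(["urn:table:source1"]), 0.9)
    right = QueryLineageInfo(MiniOrderedSet(["urn:table:source1"]), 0.35)
    left._merge_lineage_from(right)
    print(list(left.upstreams))   # ['urn:table:source1'] -- merged once, not twice
    print(left.confidence_score)  # 0.35 -- the merge keeps the lower confidence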
@@ -125,6 +125,17 @@ class _DownstreamColumnRef(_ParserBaseModel):
 
 
 class DownstreamColumnRef(_ParserBaseModel):
+    """
+    TODO: Instead of implementing a custom __hash__ function, this class should simply inherit from _FrozenModel.
+    What stops us is that the `column_type` field of type `SchemaFieldDataTypeClass` is not hashable - it's an
+    auto-generated class from .pdl model files. We need a generic solution allowing us to either:
+    1. Implement hashing for .pdl model objects
+    2. Reliably provide pydantic (both v1 and v2) with information to skip particular fields from the default
+       hash function - with a twist here that _FrozenModel implements its own `__lt__` function, so it needs
+       to understand that instruction as well.
+    Instances of this class need to be hashable as we store them in a set when processing lineage from queries.
+    """
+
     table: Optional[Urn] = None
     column: str
     column_type: Optional[SchemaFieldDataTypeClass] = None
@@ -140,8 +151,11 @@ class DownstreamColumnRef(_ParserBaseModel):
             return v
         return SchemaFieldDataTypeClass.from_obj(v)
 
+    def __hash__(self) -> int:
+        return hash((self.table, self.column, self.native_column_type))
+
 
-class ColumnTransformation(_ParserBaseModel):
+class ColumnTransformation(_FrozenModel):
     is_direct_copy: bool
     column_logic: str
 
@@ -154,11 +168,21 @@ class _ColumnLineageInfo(_ParserBaseModel):
 
 
 class ColumnLineageInfo(_ParserBaseModel):
+    """
+    TODO: Instead of implementing a custom __hash__ function, this class should simply inherit from _FrozenModel.
+    To achieve this, we need to change `upstreams` to `Tuple[ColumnRef, ...]` - along with many code lines
+    depending on it.
+    Instances of this class need to be hashable as we store them in a set when processing lineage from queries.
+    """
+
     downstream: DownstreamColumnRef
     upstreams: List[ColumnRef]
 
     logic: Optional[ColumnTransformation] = pydantic.Field(default=None)
 
+    def __hash__(self) -> int:
+        return hash((self.downstream, tuple(self.upstreams), self.logic))
+
 
 class _JoinInfo(_ParserBaseModel):
     join_type: str
 
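Storing lineage records in an OrderedSet requires them to be hashable, which is what the __hash__ additions above provide; the TODOs note that inheriting from _FrozenModel would be cleaner once the unhashable fields are dealt with. The sketch below shows the same hash-by-field-tuple pattern on simplified dataclasses (hypothetical names, not the real pydantic models): equal records hash equal, so duplicates collapse when stored in a set.

    # Hedged sketch: hashing a tuple of fields so records can live in a set.
    from dataclasses import dataclass
    from typing import List


    @dataclass  # eq defaults to True; the explicit __hash__ keeps the class hashable
    class ColumnRefSketch:
        table: str
        column: str

        def __hash__(self) -> int:
            return hash((self.table, self.column))


    @dataclass
    class ColumnLineageSketch:
        downstream: ColumnRefSketch
        upstreams: List[ColumnRefSketch]

        def __hash__(self) -> int:
            # Lists are unhashable, so hash a tuple snapshot of the upstream refs.
            # Equal objects still hash equal, which is what set dedup requires.
            return hash((self.downstream, tuple(self.upstreams)))


    a = ColumnLineageSketch(ColumnRefSketch("dest", "col_a"), [ColumnRefSketch("src", "col_a")])
    b = ColumnLineageSketch(ColumnRefSketch("dest", "col_a"), [ColumnRefSketch("src", "col_a")])
    print(len({a, b}))  # 1 -- duplicate lineage entries collapse in a set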
@@ -0,0 +1,65 @@
+[
+  {
+    "entityType": "dataset",
+    "entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_destination,PROD)",
+    "changeType": "UPSERT",
+    "aspectName": "upstreamLineage",
+    "aspect": {
+      "json": {
+        "upstreams": [
+          {
+            "auditStamp": {
+              "time": 1707182625000,
+              "actor": "urn:li:corpuser:_ingestion"
+            },
+            "created": {
+              "time": 1751377942741,
+              "actor": "urn:li:corpuser:_ingestion"
+            },
+            "dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_source1,PROD)",
+            "type": "TRANSFORMED",
+            "query": "urn:li:query:composite_b904cb462bf8f3bfd8736ae83df0cdae5651c4de16183ccc395493e65d6a3824"
+          }
+        ],
+        "fineGrainedLineages": [
+          {
+            "upstreamType": "FIELD_SET",
+            "upstreams": [
+              "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_source1,PROD),col_a)"
+            ],
+            "downstreamType": "FIELD",
+            "downstreams": [
+              "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_destination,PROD),col_a)"
+            ],
+            "confidenceScore": 0.35,
+            "query": "urn:li:query:composite_b904cb462bf8f3bfd8736ae83df0cdae5651c4de16183ccc395493e65d6a3824"
+          },
+          {
+            "upstreamType": "FIELD_SET",
+            "upstreams": [
+              "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_source1,PROD),col_b)"
+            ],
+            "downstreamType": "FIELD",
+            "downstreams": [
+              "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_destination,PROD),col_b)"
+            ],
+            "confidenceScore": 0.35,
+            "query": "urn:li:query:composite_b904cb462bf8f3bfd8736ae83df0cdae5651c4de16183ccc395493e65d6a3824"
+          },
+          {
+            "upstreamType": "FIELD_SET",
+            "upstreams": [
+              "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_source1,PROD),col_c)"
+            ],
+            "downstreamType": "FIELD",
+            "downstreams": [
+              "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_test.diamond_problem.diamond_destination,PROD),col_c)"
+            ],
+            "confidenceScore": 0.35,
+            "query": "urn:li:query:composite_b904cb462bf8f3bfd8736ae83df0cdae5651c4de16183ccc395493e65d6a3824"
+          }
+        ]
+      }
+    }
+  }
+]
@@ -1,7 +1,7 @@
 import functools
 import os
 import pathlib
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from unittest.mock import patch
 
 import pytest
@@ -9,6 +9,7 @@ from freezegun import freeze_time
 
 from datahub.configuration.datetimes import parse_user_datetime
 from datahub.configuration.time_window_config import BucketDuration, get_time_bucket
+from datahub.ingestion.sink.file import write_metadata_file
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 from datahub.metadata.urns import CorpUserUrn, DatasetUrn
 from datahub.sql_parsing.sql_parsing_aggregator import (
@@ -1025,3 +1026,67 @@ def test_sql_aggreator_close_cleans_tmp(tmp_path):
     assert len(os.listdir(tmp_path)) > 0
     aggregator.close()
     assert len(os.listdir(tmp_path)) == 0
+
+
+@freeze_time(FROZEN_TIME)
+def test_diamond_problem(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+    aggregator = SqlParsingAggregator(
+        platform="snowflake",
+        generate_lineage=True,
+        generate_usage_statistics=False,
+        generate_operations=False,
+        is_temp_table=lambda x: x.lower()
+        in [
+            "dummy_test.diamond_problem.t1",
+            "dummy_test.diamond_problem.t2",
+            "dummy_test.diamond_problem.t3",
+            "dummy_test.diamond_problem.t4",
+        ],
+    )
+
+    aggregator._schema_resolver.add_raw_schema_info(
+        DatasetUrn("snowflake", "dummy_test.diamond_problem.diamond_source1").urn(),
+        {"col_a": "int", "col_b": "int", "col_c": "int"},
+    )
+
+    aggregator._schema_resolver.add_raw_schema_info(
+        DatasetUrn(
+            "snowflake",
+            "dummy_test.diamond_problem.diamond_destination",
+        ).urn(),
+        {"col_a": "int", "col_b": "int", "col_c": "int"},
+    )
+
+    # Diamond query pattern: source1 -> t1 -> {t2, t3} -> t4 -> destination
+    queries = [
+        "CREATE TEMPORARY TABLE t1 as select * from diamond_source1;",
+        "CREATE TEMPORARY TABLE t2 as select * from t1;",
+        "CREATE TEMPORARY TABLE t3 as select * from t1;",
+        "CREATE TEMPORARY TABLE t4 as select t2.col_a, t3.col_b, t2.col_c from t2 join t3 on t2.col_a = t3.col_a;",
+        "CREATE TABLE diamond_destination as select * from t4;",
+    ]
+
+    base_timestamp = datetime(2025, 7, 1, 13, 52, 18, 741000, tzinfo=timezone.utc)
+
+    for i, query in enumerate(queries):
+        aggregator.add(
+            ObservedQuery(
+                query=query,
+                default_db="dummy_test",
+                default_schema="diamond_problem",
+                session_id="14774700499701726",
+                timestamp=base_timestamp + timedelta(seconds=i),
+            )
+        )
+
+    mcpws = [mcp for mcp in aggregator.gen_metadata()]
+    lineage_mcpws = [mcpw for mcpw in mcpws if mcpw.aspectName == "upstreamLineage"]
+    out_path = tmp_path / "mcpw.json"
+    write_metadata_file(out_path, lineage_mcpws)
+
+    mce_helpers.check_golden_file(
+        pytestconfig,
+        out_path,
+        pytestconfig.rootpath
+        / "tests/unit/sql_parsing/aggregator_goldens/test_diamond_problem_golden.json",
+    )