feat(sdk): add datajob lineage & dataset sql parsing lineage (#13365)

Hyejin Yoon 2025-05-09 10:20:48 +09:00 committed by GitHub
parent 71e104068e
commit a414bbb798
13 changed files with 703 additions and 74 deletions

View File

@ -0,0 +1,13 @@
from datahub.metadata.urns import DataFlowUrn, DataJobUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD")
lineage_client.add_datajob_lineage(
datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"),
upstreams=[DataJobUrn(flow=flow_urn, job_id="extract_job")],
)

View File

@ -0,0 +1,14 @@
from datahub.metadata.urns import DataFlowUrn, DataJobUrn, DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD")
lineage_client.add_datajob_lineage(
datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"),
upstreams=[DatasetUrn(platform="postgres", name="raw_data")],
downstreams=[DatasetUrn(platform="snowflake", name="processed_data")],
)

View File

@ -0,0 +1,20 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
lineage_client.add_dataset_copy_lineage(
upstream=DatasetUrn(platform="postgres", name="customer_data"),
downstream=DatasetUrn(platform="snowflake", name="customer_info"),
column_lineage="auto_fuzzy",
)
# by default, the column lineage is "auto_fuzzy", which will match similar field names.
# can also be "auto_strict" for strict matching.
# can also be a dict mapping upstream fields to downstream fields.
# e.g.
# column_lineage={
# "customer_id": ["id"],
# "full_name": ["first_name", "last_name"],
# }
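# For an explicit mapping, the dict form described above can be passed directly.
# A minimal usage sketch (illustrative, reusing the same client and dataset URNs):
lineage_client.add_dataset_copy_lineage(
    upstream=DatasetUrn(platform="postgres", name="customer_data"),
    downstream=DatasetUrn(platform="snowflake", name="customer_info"),
    column_lineage={
        "customer_id": ["id"],
        "full_name": ["first_name", "last_name"],
    },
)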

View File

@ -0,0 +1,27 @@
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
sql_query = """
CREATE TABLE sales_summary AS
SELECT
p.product_name,
c.customer_segment,
SUM(s.quantity) as total_quantity,
SUM(s.amount) as total_sales
FROM sales s
JOIN products p ON s.product_id = p.id
JOIN customers c ON s.customer_id = c.id
GROUP BY p.product_name, c.customer_segment
"""
# sales_summary will be assumed to be in the default db/schema
# e.g. prod_db.public.sales_summary
lineage_client.add_dataset_lineage_from_sql(
query_text=sql_query,
platform="snowflake",
default_db="prod_db",
default_schema="public",
)
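# If the statement already uses fully qualified table names, default_db and
# default_schema can be left out. A minimal sketch assuming the same client;
# the qualified names below are illustrative, not part of this change.
qualified_query = """
CREATE TABLE prod_db.public.sales_summary AS
SELECT product_name, SUM(quantity) AS total_quantity
FROM prod_db.public.sales
GROUP BY product_name
"""
lineage_client.add_dataset_lineage_from_sql(
    query_text=qualified_query,
    platform="snowflake",
)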

View File

@ -0,0 +1,17 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
lineage_client.add_dataset_transform_lineage(
upstream=DatasetUrn(platform="snowflake", name="source_table"),
downstream=DatasetUrn(platform="snowflake", name="target_table"),
column_lineage={
"customer_id": ["id"],
"full_name": ["first_name", "last_name"],
},
)
# column_lineage is optional -- if not provided, table-level lineage is inferred.
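# A minimal sketch of the table-level-only form, omitting column_lineage
# (illustrative, reusing the client and URNs above):
lineage_client.add_dataset_transform_lineage(
    upstream=DatasetUrn(platform="snowflake", name="source_table"),
    downstream=DatasetUrn(platform="snowflake", name="target_table"),
)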

View File

@ -0,0 +1,24 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
# this can be any transformation logic e.g. a spark job, an airflow DAG, python script, etc.
# if you have a SQL query, we recommend using add_dataset_lineage_from_sql instead.
query_text = """
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("HighValueFilter").getOrCreate()
df = spark.read.table("customers")
high_value = df.filter("lifetime_value > 10000")
high_value.write.saveAsTable("high_value_customers")
"""
lineage_client.add_dataset_transform_lineage(
upstream=DatasetUrn(platform="snowflake", name="customers"),
downstream=DatasetUrn(platform="snowflake", name="high_value_customers"),
query_text=query_text,
)

View File

@ -4,22 +4,24 @@ import difflib
import logging
from typing import TYPE_CHECKING, List, Literal, Optional, Set, Union

from typing_extensions import assert_never

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.errors import SdkUsageError
from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn
from datahub.sdk._shared import DatajobUrnOrStr, DatasetUrnOrStr
from datahub.sdk._utils import DEFAULT_ACTOR_URN
from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
from datahub.specific.datajob import DataJobPatchBuilder
from datahub.specific.dataset import DatasetPatchBuilder
from datahub.sql_parsing.fingerprint_utils import generate_hash
from datahub.utilities.ordered_set import OrderedSet
from datahub.utilities.urns.error import InvalidUrnError

if TYPE_CHECKING:
    from datahub.sdk.main_client import DataHubClient

logger = logging.getLogger(__name__)

_empty_audit_stamp = models.AuditStampClass(
    time=0,
@ -27,16 +29,19 @@ _empty_audit_stamp = models.AuditStampClass(
)


class LineageClient:
    def __init__(self, client: DataHubClient):
        self._client = client

    def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
        schema_metadata = self._client._graph.get_aspect(
            str(dataset_urn), models.SchemaMetadataClass
        )
        if schema_metadata is None:
            return set()

        return {field.fieldPath for field in schema_metadata.fields}
@ -122,7 +127,7 @@ class LineageClient:
        if column_lineage is None:
            cll = None
        elif column_lineage == "auto_fuzzy" or column_lineage == "auto_strict":
            upstream_schema = self._get_fields_from_dataset_urn(upstream)
            downstream_schema = self._get_fields_from_dataset_urn(downstream)

            if column_lineage == "auto_fuzzy":
@ -144,6 +149,8 @@ class LineageClient:
                downstream=downstream,
                cll_mapping=column_lineage,
            )
        else:
            assert_never(column_lineage)

        updater = DatasetPatchBuilder(str(downstream))
        updater.add_upstream_lineage(
@ -227,9 +234,129 @@ class LineageClient:
            raise SdkUsageError(
                f"Dataset {updater.urn} does not exist, and hence cannot be updated."
            )

        mcps: List[
            Union[MetadataChangeProposalWrapper, models.MetadataChangeProposalClass]
        ] = list(updater.build())
        if query_entity:
            mcps.extend(query_entity)

        self._client._graph.emit_mcps(mcps)

    def add_dataset_lineage_from_sql(
        self,
        *,
        query_text: str,
        platform: str,
        platform_instance: Optional[str] = None,
        env: str = "PROD",
        default_db: Optional[str] = None,
        default_schema: Optional[str] = None,
    ) -> None:
        """Add lineage by parsing a SQL query."""
        from datahub.sql_parsing.sqlglot_lineage import (
            create_lineage_sql_parsed_result,
        )

        # Parse the SQL query to extract lineage information
        parsed_result = create_lineage_sql_parsed_result(
            query=query_text,
            default_db=default_db,
            default_schema=default_schema,
            platform=platform,
            platform_instance=platform_instance,
            env=env,
            graph=self._client._graph,
        )

        if parsed_result.debug_info.table_error:
            raise SdkUsageError(
                f"Failed to parse SQL query: {parsed_result.debug_info.error}"
            )
        elif parsed_result.debug_info.column_error:
            logger.warning(
                f"Failed to parse SQL query: {parsed_result.debug_info.error}",
            )

        if not parsed_result.out_tables:
            raise SdkUsageError(
                "No output tables found in the query. Cannot establish lineage."
            )

        # Use the first output table as the downstream
        downstream_urn = parsed_result.out_tables[0]

        # Process all upstream tables found in the query
        for upstream_table in parsed_result.in_tables:
            # Skip self-lineage
            if upstream_table == downstream_urn:
                continue

            # Extract column-level lineage for this specific upstream table
            column_mapping = {}
            if parsed_result.column_lineage:
                for col_lineage in parsed_result.column_lineage:
                    if not (col_lineage.downstream and col_lineage.downstream.column):
                        continue

                    # Filter upstreams to only include columns from current upstream table
                    upstream_cols = [
                        ref.column
                        for ref in col_lineage.upstreams
                        if ref.table == upstream_table and ref.column
                    ]
                    if upstream_cols:
                        column_mapping[col_lineage.downstream.column] = upstream_cols

            # Add lineage, including query text
            self.add_dataset_transform_lineage(
                upstream=upstream_table,
                downstream=downstream_urn,
                column_lineage=column_mapping or None,
                query_text=query_text,
            )

    def add_datajob_lineage(
        self,
        *,
        datajob: DatajobUrnOrStr,
        upstreams: Optional[List[Union[DatasetUrnOrStr, DatajobUrnOrStr]]] = None,
        downstreams: Optional[List[DatasetUrnOrStr]] = None,
    ) -> None:
        """
        Add lineage between a datajob and datasets/datajobs.

        Args:
            datajob: The datajob URN to connect lineage with
            upstreams: List of upstream datasets or datajobs that serve as inputs to the datajob
            downstreams: List of downstream datasets that are outputs of the datajob
        """
        if not upstreams and not downstreams:
            raise SdkUsageError("No upstreams or downstreams provided")

        datajob_urn = DataJobUrn.from_string(datajob)

        # Initialize the patch builder for the datajob
        patch_builder = DataJobPatchBuilder(str(datajob_urn))

        # Process upstream connections (inputs to the datajob)
        if upstreams:
            for upstream in upstreams:
                # try converting to dataset urn
                try:
                    dataset_urn = DatasetUrn.from_string(upstream)
                    patch_builder.add_input_dataset(dataset_urn)
                except InvalidUrnError:
                    # try converting to datajob urn
                    datajob_urn = DataJobUrn.from_string(upstream)
                    patch_builder.add_input_datajob(datajob_urn)

        # Process downstream connections (outputs from the datajob)
        if downstreams:
            for downstream in downstreams:
                downstream_urn = DatasetUrn.from_string(downstream)
                patch_builder.add_output_dataset(downstream_urn)

        # Apply the changes to the entity
        self._client.entities.update(patch_builder)

View File

@ -0,0 +1,19 @@
[
{
"entityType": "dataJob",
"entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)",
"changeType": "PATCH",
"aspectName": "dataJobInputOutput",
"aspect": {
"json": [
{
"op": "add",
"path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
}
}
]
}
}
]

View File

@ -0,0 +1,33 @@
[
{
"entityType": "dataJob",
"entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
"changeType": "PATCH",
"aspectName": "dataJobInputOutput",
"aspect": {
"json": [
{
"op": "add",
"path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
}
},
{
"op": "add",
"path": "/inputDatajobEdges/urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)",
"value": {
"destinationUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
}
},
{
"op": "add",
"path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
}
}
]
}
}
]

View File

@ -0,0 +1,19 @@
[
{
"entityType": "dataJob",
"entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
"changeType": "PATCH",
"aspectName": "dataJobInputOutput",
"aspect": {
"json": [
{
"op": "add",
"path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
}
}
]
}
}
]

View File

@ -0,0 +1,67 @@
[
{
"entityType": "dataset",
"entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)",
"changeType": "PATCH",
"aspectName": "upstreamLineage",
"aspect": {
"json": [
{
"op": "add",
"path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)",
"value": {
"auditStamp": {
"time": 0,
"actor": "urn:li:corpuser:unknown"
},
"dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)",
"type": "TRANSFORMED",
"query": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6"
}
}
]
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6",
"changeType": "UPSERT",
"aspectName": "queryProperties",
"aspect": {
"json": {
"customProperties": {},
"statement": {
"value": "create table sales_summary as SELECT price, qty, unit_cost FROM orders",
"language": "SQL"
},
"source": "SYSTEM",
"created": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
},
"lastModified": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
}
}
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6",
"changeType": "UPSERT",
"aspectName": "querySubjects",
"aspect": {
"json": {
"subjects": [
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"
},
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
}
]
}
}
}
]

View File

@ -0,0 +1,67 @@
[
{
"entityType": "dataset",
"entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)",
"changeType": "PATCH",
"aspectName": "upstreamLineage",
"aspect": {
"json": [
{
"op": "add",
"path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
"value": {
"auditStamp": {
"time": 0,
"actor": "urn:li:corpuser:unknown"
},
"dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
"type": "TRANSFORMED",
"query": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131"
}
}
]
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131",
"changeType": "UPSERT",
"aspectName": "queryProperties",
"aspect": {
"json": {
"customProperties": {},
"statement": {
"value": "\n CREATE TABLE sales_summary AS\n SELECT \n p.product_name,\n SUM(s.quantity) as total_quantity,\n FROM sales s\n JOIN products p ON s.product_id = p.id\n GROUP BY p.product_name\n ",
"language": "SQL"
},
"source": "SYSTEM",
"created": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
},
"lastModified": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
}
}
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131",
"changeType": "UPSERT",
"aspectName": "querySubjects",
"aspect": {
"json": {
"subjects": [
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)"
},
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
}
]
}
}
}
]

View File

@ -1,6 +1,6 @@
import pathlib
from typing import Dict, List, Set, cast
from unittest.mock import MagicMock, Mock, patch

import pytest
@ -13,6 +13,14 @@ from datahub.metadata.schema_classes import (
)
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
    ColumnLineageInfo,
    ColumnRef,
    DownstreamColumnRef,
    SqlParsingResult,
)
from datahub.utilities.urns.error import InvalidUrnError
from tests.test_helpers import mce_helpers

_GOLDEN_DIR = pathlib.Path(__file__).parent / "lineage_client_golden"
@ -22,6 +30,7 @@ _GOLDEN_DIR.mkdir(exist_ok=True)
@pytest.fixture
def mock_graph() -> Mock:
    graph = Mock()

    return graph
@ -40,12 +49,9 @@ def assert_client_golden(client: DataHubClient, golden_path: pathlib.Path) -> No
    )


def test_get_fuzzy_column_lineage(client: DataHubClient) -> None:
    """Test the fuzzy column lineage matching algorithm."""
    # Create a minimal client just for testing the method

    # Test cases
    test_cases = [
        # Case 1: Exact matches
@ -104,7 +110,7 @@ def test_get_fuzzy_column_lineage():
    # Run test cases
    for i, test_case in enumerate(test_cases):
        result = client.lineage._get_fuzzy_column_lineage(
            cast(Set[str], test_case["upstream_fields"]),
            cast(Set[str], test_case["downstream_fields"]),
        )
@ -113,11 +119,9 @@ def test_get_fuzzy_column_lineage():
    )


def test_get_strict_column_lineage(client: DataHubClient) -> None:
    """Test the strict column lineage matching algorithm."""
    # Create a minimal client just for testing the method

    # Define test cases
    test_cases = [
@ -143,7 +147,7 @@ def test_get_strict_column_lineage():
    # Run test cases
    for i, test_case in enumerate(test_cases):
        result = client.lineage._get_strict_column_lineage(
            cast(Set[str], test_case["upstream_fields"]),
            cast(Set[str], test_case["downstream_fields"]),
        )
@ -152,7 +156,6 @@ def test_get_strict_column_lineage():
def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
    """Test auto fuzzy column lineage mapping."""
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@ -213,23 +216,24 @@ def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
        ],
    )

    # Use patch.object with a context manager
    with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
        # Configure the mock with a simpler side effect function
        mock_method.side_effect = lambda urn: sorted(
            {
                field.fieldPath
                for field in (
                    upstream_schema if "upstream" in str(urn) else downstream_schema
                ).fields
            }
        )

        # Now use client.lineage with the patched method
        client.lineage.add_dataset_copy_lineage(
            upstream=upstream,
            downstream=downstream,
            column_lineage="auto_fuzzy",
        )

    # Use golden file for assertion
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_fuzzy_golden.json")
@ -237,8 +241,6 @@ def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
    """Test strict column lineage with field matches."""
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@ -303,41 +305,22 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
        ],
    )

    with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
        mock_method.side_effect = lambda urn: sorted(
            {
                field.fieldPath
                for field in (
                    upstream_schema if "upstream" in str(urn) else downstream_schema
                ).fields
            }
        )

        # Run the lineage function
        client.lineage.add_dataset_copy_lineage(
            upstream=upstream,
            downstream=downstream,
            column_lineage="auto_strict",
        )

    # Use golden file for assertion
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_strict_golden.json")
@ -345,13 +328,12 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None:
    """Test basic lineage without column mapping or query."""
    # Basic lineage test
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"

    client.lineage.add_dataset_transform_lineage(
        upstream=upstream,
        downstream=downstream,
    )
@ -360,7 +342,6 @@ def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None:
def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None:
    """Test complete lineage with column mapping and query."""
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@ -372,10 +353,211 @@ def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None:
"ds_col2": ["us_col2", "us_col3"], # 2:1 mapping "ds_col2": ["us_col2", "us_col3"], # 2:1 mapping
} }
lineage_client.add_dataset_transform_lineage( client.lineage.add_dataset_transform_lineage(
upstream=upstream, upstream=upstream,
downstream=downstream, downstream=downstream,
query_text=query_text, query_text=query_text,
column_lineage=column_lineage, column_lineage=column_lineage,
) )
assert_client_golden(client, _GOLDEN_DIR / "test_lineage_complete_golden.json") assert_client_golden(client, _GOLDEN_DIR / "test_lineage_complete_golden.json")

def test_add_dataset_lineage_from_sql(client: DataHubClient) -> None:
    """Test adding lineage from SQL parsing with a golden file."""
    # Create minimal mock result with necessary info
    mock_result = SqlParsingResult(
        in_tables=["urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"],
        out_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
        ],
        column_lineage=[],  # Simplified - we only care about table-level lineage for this test
        query_type=QueryType.SELECT,
        debug_info=MagicMock(error=None, table_error=None),
    )

    # Simple SQL that would produce the expected lineage
    query_text = (
        "create table sales_summary as SELECT price, qty, unit_cost FROM orders"
    )

    # Patch SQL parser and execute lineage creation
    with patch(
        "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
        return_value=mock_result,
    ):
        client.lineage.add_dataset_lineage_from_sql(
            query_text=query_text, platform="snowflake", env="PROD"
        )

    # Validate against golden file
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_from_sql_golden.json")


def test_add_dataset_lineage_from_sql_with_multiple_upstreams(
    client: DataHubClient,
) -> None:
    """Test adding lineage for a dataset with multiple upstreams."""
    # Create minimal mock result with necessary info
    mock_result = SqlParsingResult(
        in_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
        ],
        out_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
        ],
        column_lineage=[
            ColumnLineageInfo(
                downstream=DownstreamColumnRef(
                    column="product_name",
                ),
                upstreams=[
                    ColumnRef(
                        table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
                        column="product_name",
                    )
                ],
            ),
            ColumnLineageInfo(
                downstream=DownstreamColumnRef(
                    column="total_quantity",
                ),
                upstreams=[
                    ColumnRef(
                        table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
                        column="quantity",
                    )
                ],
            ),
        ],
        query_type=QueryType.SELECT,
        debug_info=MagicMock(error=None, table_error=None),
    )

    # Simple SQL that would produce the expected lineage
    query_text = """
        CREATE TABLE sales_summary AS
        SELECT
            p.product_name,
            SUM(s.quantity) as total_quantity,
        FROM sales s
        JOIN products p ON s.product_id = p.id
        GROUP BY p.product_name
    """

    # Patch SQL parser and execute lineage creation
    with patch(
        "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
        return_value=mock_result,
    ):
        client.lineage.add_dataset_lineage_from_sql(
            query_text=query_text, platform="snowflake", env="PROD"
        )

    # Validate against golden file
    assert_client_golden(
        client, _GOLDEN_DIR / "test_lineage_from_sql_multiple_upstreams_golden.json"
    )


def test_add_datajob_lineage(client: DataHubClient) -> None:
    """Test adding lineage for datajobs using DataJobPatchBuilder."""
    # Define URNs for test with correct format
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    input_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
    )
    input_datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
    )
    output_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
    )

    # Test adding both upstream and downstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn,
        upstreams=[input_dataset_urn, input_datajob_urn],
        downstreams=[output_dataset_urn],
    )

    # Validate lineage MCPs against golden file
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_lineage_golden.json")


def test_add_datajob_inputs_only(client: DataHubClient) -> None:
    """Test adding only inputs to a datajob."""
    # Define URNs for test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)"
    )
    input_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
    )

    # Test adding just upstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn,
        upstreams=[input_dataset_urn],
    )

    # Validate lineage MCPs
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_inputs_only_golden.json")


def test_add_datajob_outputs_only(client: DataHubClient) -> None:
    """Test adding only outputs to a datajob."""
    # Define URNs for test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    output_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
    )

    # Test adding just downstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn, downstreams=[output_dataset_urn]
    )

    # Validate lineage MCPs
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_outputs_only_golden.json")


def test_add_datajob_lineage_validation(client: DataHubClient) -> None:
    """Test validation checks in add_datajob_lineage."""
    # Define URNs for test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    invalid_urn = "urn:li:glossaryNode:something"

    # Test with invalid datajob URN
    with pytest.raises(
        InvalidUrnError,
        match="Passed an urn of type glossaryNode to the from_string method of DataJobUrn",
    ):
        client.lineage.add_datajob_lineage(
            datajob=invalid_urn,
            upstreams=[
                "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
            ],
        )

    # Test with invalid upstream URN
    with pytest.raises(InvalidUrnError):
        client.lineage.add_datajob_lineage(datajob=datajob_urn, upstreams=[invalid_urn])

    # Test with invalid downstream URN
    with pytest.raises(InvalidUrnError):
        client.lineage.add_datajob_lineage(
            datajob=datajob_urn, downstreams=[invalid_urn]
        )