feat(sdk): add datajob lineage & dataset sql parsing lineage (#13365)

Hyejin Yoon 2025-05-09 10:20:48 +09:00 committed by GitHub
parent 71e104068e
commit a414bbb798
13 changed files with 703 additions and 74 deletions

View File

@ -0,0 +1,13 @@
from datahub.metadata.urns import DataFlowUrn, DataJobUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD")
lineage_client.add_datajob_lineage(
datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"),
upstreams=[DataJobUrn(flow=flow_urn, job_id="extract_job")],
)
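# The datajob and upstreams arguments also accept plain URN strings
# (DatajobUrnOrStr / DatasetUrnOrStr). A minimal sketch of the same call
# written with string URNs:
lineage_client.add_datajob_lineage(
    datajob="urn:li:dataJob:(urn:li:dataFlow:(airflow,data_pipeline,PROD),data_pipeline)",
    upstreams=["urn:li:dataJob:(urn:li:dataFlow:(airflow,data_pipeline,PROD),extract_job)"],
)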

View File

@ -0,0 +1,14 @@
from datahub.metadata.urns import DataFlowUrn, DataJobUrn, DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD")
lineage_client.add_datajob_lineage(
datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"),
upstreams=[DatasetUrn(platform="postgres", name="raw_data")],
downstreams=[DatasetUrn(platform="snowflake", name="processed_data")],
)

View File

@ -0,0 +1,20 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
lineage_client.add_dataset_copy_lineage(
upstream=DatasetUrn(platform="postgres", name="customer_data"),
downstream=DatasetUrn(platform="snowflake", name="customer_info"),
column_lineage="auto_fuzzy",
)
# By default, column_lineage is "auto_fuzzy", which matches fields with similar names.
# It can also be "auto_strict" for strict field-name matching,
# or a dict mapping each downstream field to the upstream field(s) it is derived from, e.g.
# column_lineage={
#     "customer_id": ["id"],
#     "full_name": ["first_name", "last_name"],
# }
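# For illustration, the same copy lineage with an explicit mapping of each
# downstream field to its upstream field(s); a sketch using the fields above:
lineage_client.add_dataset_copy_lineage(
    upstream=DatasetUrn(platform="postgres", name="customer_data"),
    downstream=DatasetUrn(platform="snowflake", name="customer_info"),
    column_lineage={
        "customer_id": ["id"],
        "full_name": ["first_name", "last_name"],
    },
)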

View File

@ -0,0 +1,27 @@
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
sql_query = """
CREATE TABLE sales_summary AS
SELECT
p.product_name,
c.customer_segment,
SUM(s.quantity) as total_quantity,
SUM(s.amount) as total_sales
FROM sales s
JOIN products p ON s.product_id = p.id
JOIN customers c ON s.customer_id = c.id
GROUP BY p.product_name, c.customer_segment
"""
# Unqualified table names like sales_summary are resolved against the default
# db/schema, e.g. prod_db.public.sales_summary
lineage_client.add_dataset_lineage_from_sql(
query_text=sql_query,
platform="snowflake",
default_db="prod_db",
default_schema="public",
)
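# Per this implementation, lineage is recorded from each input table found in
# the query to the query's first output table, with column-level lineage
# attached where the parser can extract it. If the query cannot be parsed,
# an SdkUsageError is raised. An optional platform_instance and env
# (default "PROD") can also be passed when the tables live outside the
# default platform instance or environment.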

View File

@ -0,0 +1,17 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
lineage_client.add_dataset_transform_lineage(
upstream=DatasetUrn(platform="snowflake", name="source_table"),
downstream=DatasetUrn(platform="snowflake", name="target_table"),
column_lineage={
"customer_id": ["id"],
"full_name": ["first_name", "last_name"],
},
)
# column_lineage is optional; if omitted, only table-level lineage is emitted.
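# A minimal sketch of the same call with column_lineage omitted,
# which records table-level lineage only:
lineage_client.add_dataset_transform_lineage(
    upstream=DatasetUrn(platform="snowflake", name="source_table"),
    downstream=DatasetUrn(platform="snowflake", name="target_table"),
)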

View File

@ -0,0 +1,24 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)
# query_text can capture any transformation logic, e.g. a Spark job, an Airflow DAG, or a Python script.
# If the transformation is a SQL query, we recommend using add_dataset_lineage_from_sql instead.
query_text = """
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("HighValueFilter").getOrCreate()
df = spark.read.table("customers")
high_value = df.filter("lifetime_value > 10000")
high_value.write.saveAsTable("high_value_customers")
"""
lineage_client.add_dataset_transform_lineage(
upstream=DatasetUrn(platform="snowflake", name="customers"),
downstream=DatasetUrn(platform="snowflake", name="high_value_customers"),
query_text=query_text,
)

View File

@ -4,22 +4,24 @@ import difflib
import logging
from typing import TYPE_CHECKING, List, Literal, Optional, Set, Union
from typing_extensions import assert_never
import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.errors import SdkUsageError
from datahub.metadata.schema_classes import SchemaMetadataClass
from datahub.metadata.urns import DatasetUrn, QueryUrn
from datahub.sdk._shared import DatasetUrnOrStr
from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn
from datahub.sdk._shared import DatajobUrnOrStr, DatasetUrnOrStr
from datahub.sdk._utils import DEFAULT_ACTOR_URN
from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
from datahub.specific.datajob import DataJobPatchBuilder
from datahub.specific.dataset import DatasetPatchBuilder
from datahub.sql_parsing.fingerprint_utils import generate_hash
from datahub.utilities.ordered_set import OrderedSet
from datahub.utilities.urns.error import InvalidUrnError
if TYPE_CHECKING:
from datahub.sdk.main_client import DataHubClient
logger = logging.getLogger(__name__)
_empty_audit_stamp = models.AuditStampClass(
time=0,
@ -27,16 +29,19 @@ _empty_audit_stamp = models.AuditStampClass(
)
logger = logging.getLogger(__name__)
class LineageClient:
def __init__(self, client: DataHubClient):
self._client = client
def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
schema_metadata = self._client._graph.get_aspect(
str(dataset_urn), SchemaMetadataClass
str(dataset_urn), models.SchemaMetadataClass
)
if schema_metadata is None:
return Set()
return set()
return {field.fieldPath for field in schema_metadata.fields}
@ -122,7 +127,7 @@ class LineageClient:
if column_lineage is None:
cll = None
elif column_lineage in ["auto_fuzzy", "auto_strict"]:
elif column_lineage == "auto_fuzzy" or column_lineage == "auto_strict":
upstream_schema = self._get_fields_from_dataset_urn(upstream)
downstream_schema = self._get_fields_from_dataset_urn(downstream)
if column_lineage == "auto_fuzzy":
@ -144,6 +149,8 @@ class LineageClient:
downstream=downstream,
cll_mapping=column_lineage,
)
else:
assert_never(column_lineage)
updater = DatasetPatchBuilder(str(downstream))
updater.add_upstream_lineage(
@ -227,9 +234,129 @@ class LineageClient:
raise SdkUsageError(
f"Dataset {updater.urn} does not exist, and hence cannot be updated."
)
mcps: List[
Union[MetadataChangeProposalWrapper, models.MetadataChangeProposalClass]
] = list(updater.build())
if query_entity:
mcps.extend(query_entity)
self._client._graph.emit_mcps(mcps)
def add_dataset_lineage_from_sql(
self,
*,
query_text: str,
platform: str,
platform_instance: Optional[str] = None,
env: str = "PROD",
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> None:
"""Add lineage by parsing a SQL query."""
from datahub.sql_parsing.sqlglot_lineage import (
create_lineage_sql_parsed_result,
)
# Parse the SQL query to extract lineage information
parsed_result = create_lineage_sql_parsed_result(
query=query_text,
default_db=default_db,
default_schema=default_schema,
platform=platform,
platform_instance=platform_instance,
env=env,
graph=self._client._graph,
)
if parsed_result.debug_info.table_error:
raise SdkUsageError(
f"Failed to parse SQL query: {parsed_result.debug_info.error}"
)
elif parsed_result.debug_info.column_error:
logger.warning(
f"Failed to parse SQL query: {parsed_result.debug_info.error}",
)
if not parsed_result.out_tables:
raise SdkUsageError(
"No output tables found in the query. Cannot establish lineage."
)
# Use the first output table as the downstream
downstream_urn = parsed_result.out_tables[0]
# Process all upstream tables found in the query
for upstream_table in parsed_result.in_tables:
# Skip self-lineage
if upstream_table == downstream_urn:
continue
# Extract column-level lineage for this specific upstream table
column_mapping = {}
if parsed_result.column_lineage:
for col_lineage in parsed_result.column_lineage:
if not (col_lineage.downstream and col_lineage.downstream.column):
continue
# Filter upstreams to only include columns from current upstream table
upstream_cols = [
ref.column
for ref in col_lineage.upstreams
if ref.table == upstream_table and ref.column
]
if upstream_cols:
column_mapping[col_lineage.downstream.column] = upstream_cols
# Add lineage, including query text
self.add_dataset_transform_lineage(
upstream=upstream_table,
downstream=downstream_urn,
column_lineage=column_mapping or None,
query_text=query_text,
)
def add_datajob_lineage(
self,
*,
datajob: DatajobUrnOrStr,
upstreams: Optional[List[Union[DatasetUrnOrStr, DatajobUrnOrStr]]] = None,
downstreams: Optional[List[DatasetUrnOrStr]] = None,
) -> None:
"""
Add lineage between a datajob and datasets/datajobs.
Args:
datajob: The datajob URN to connect lineage with
upstreams: List of upstream datasets or datajobs that serve as inputs to the datajob
downstreams: List of downstream datasets that are outputs of the datajob
"""
if not upstreams and not downstreams:
raise SdkUsageError("No upstreams or downstreams provided")
datajob_urn = DataJobUrn.from_string(datajob)
# Initialize the patch builder for the datajob
patch_builder = DataJobPatchBuilder(str(datajob_urn))
# Process upstream connections (inputs to the datajob)
if upstreams:
for upstream in upstreams:
# try converting to dataset urn
try:
dataset_urn = DatasetUrn.from_string(upstream)
patch_builder.add_input_dataset(dataset_urn)
except InvalidUrnError:
# try converting to datajob urn
datajob_urn = DataJobUrn.from_string(upstream)
patch_builder.add_input_datajob(datajob_urn)
# Process downstream connections (outputs from the datajob)
if downstreams:
for downstream in downstreams:
downstream_urn = DatasetUrn.from_string(downstream)
patch_builder.add_output_dataset(downstream_urn)
# Apply the changes to the entity
self._client.entities.update(patch_builder)

View File

@ -0,0 +1,19 @@
[
{
"entityType": "dataJob",
"entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)",
"changeType": "PATCH",
"aspectName": "dataJobInputOutput",
"aspect": {
"json": [
{
"op": "add",
"path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
}
}
]
}
}
]

View File

@ -0,0 +1,33 @@
[
{
"entityType": "dataJob",
"entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
"changeType": "PATCH",
"aspectName": "dataJobInputOutput",
"aspect": {
"json": [
{
"op": "add",
"path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
}
},
{
"op": "add",
"path": "/inputDatajobEdges/urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)",
"value": {
"destinationUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
}
},
{
"op": "add",
"path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
}
}
]
}
}
]

View File

@ -0,0 +1,19 @@
[
{
"entityType": "dataJob",
"entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
"changeType": "PATCH",
"aspectName": "dataJobInputOutput",
"aspect": {
"json": [
{
"op": "add",
"path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
"value": {
"destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
}
}
]
}
}
]

View File

@ -0,0 +1,67 @@
[
{
"entityType": "dataset",
"entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)",
"changeType": "PATCH",
"aspectName": "upstreamLineage",
"aspect": {
"json": [
{
"op": "add",
"path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)",
"value": {
"auditStamp": {
"time": 0,
"actor": "urn:li:corpuser:unknown"
},
"dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)",
"type": "TRANSFORMED",
"query": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6"
}
}
]
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6",
"changeType": "UPSERT",
"aspectName": "queryProperties",
"aspect": {
"json": {
"customProperties": {},
"statement": {
"value": "create table sales_summary as SELECT price, qty, unit_cost FROM orders",
"language": "SQL"
},
"source": "SYSTEM",
"created": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
},
"lastModified": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
}
}
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6",
"changeType": "UPSERT",
"aspectName": "querySubjects",
"aspect": {
"json": {
"subjects": [
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"
},
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
}
]
}
}
}
]

View File

@ -0,0 +1,67 @@
[
{
"entityType": "dataset",
"entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)",
"changeType": "PATCH",
"aspectName": "upstreamLineage",
"aspect": {
"json": [
{
"op": "add",
"path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
"value": {
"auditStamp": {
"time": 0,
"actor": "urn:li:corpuser:unknown"
},
"dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
"type": "TRANSFORMED",
"query": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131"
}
}
]
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131",
"changeType": "UPSERT",
"aspectName": "queryProperties",
"aspect": {
"json": {
"customProperties": {},
"statement": {
"value": "\n CREATE TABLE sales_summary AS\n SELECT \n p.product_name,\n SUM(s.quantity) as total_quantity,\n FROM sales s\n JOIN products p ON s.product_id = p.id\n GROUP BY p.product_name\n ",
"language": "SQL"
},
"source": "SYSTEM",
"created": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
},
"lastModified": {
"time": 0,
"actor": "urn:li:corpuser:__ingestion"
}
}
}
},
{
"entityType": "query",
"entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131",
"changeType": "UPSERT",
"aspectName": "querySubjects",
"aspect": {
"json": {
"subjects": [
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)"
},
{
"entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
}
]
}
}
}
]

View File

@ -1,6 +1,6 @@
import pathlib
from typing import Dict, List, Set, cast
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -13,6 +13,14 @@ from datahub.metadata.schema_classes import (
)
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
ColumnLineageInfo,
ColumnRef,
DownstreamColumnRef,
SqlParsingResult,
)
from datahub.utilities.urns.error import InvalidUrnError
from tests.test_helpers import mce_helpers
_GOLDEN_DIR = pathlib.Path(__file__).parent / "lineage_client_golden"
@ -22,6 +30,7 @@ _GOLDEN_DIR.mkdir(exist_ok=True)
@pytest.fixture
def mock_graph() -> Mock:
graph = Mock()
return graph
@ -40,12 +49,9 @@ def assert_client_golden(client: DataHubClient, golden_path: pathlib.Path) -> No
)
def test_get_fuzzy_column_lineage():
def test_get_fuzzy_column_lineage(client: DataHubClient) -> None:
"""Test the fuzzy column lineage matching algorithm."""
# Create a minimal client just for testing the method
client = MagicMock(spec=DataHubClient)
lineage_client = LineageClient(client=client)
# Test cases
test_cases = [
# Case 1: Exact matches
@ -104,7 +110,7 @@ def test_get_fuzzy_column_lineage():
# Run test cases
for i, test_case in enumerate(test_cases):
result = lineage_client._get_fuzzy_column_lineage(
result = client.lineage._get_fuzzy_column_lineage(
cast(Set[str], test_case["upstream_fields"]),
cast(Set[str], test_case["downstream_fields"]),
)
@ -113,11 +119,9 @@ def test_get_fuzzy_column_lineage():
)
def test_get_strict_column_lineage():
def test_get_strict_column_lineage(client: DataHubClient) -> None:
"""Test the strict column lineage matching algorithm."""
# Create a minimal client just for testing the method
client = MagicMock(spec=DataHubClient)
lineage_client = LineageClient(client=client)
# Define test cases
test_cases = [
@ -143,7 +147,7 @@ def test_get_strict_column_lineage():
# Run test cases
for i, test_case in enumerate(test_cases):
result = lineage_client._get_strict_column_lineage(
result = client.lineage._get_strict_column_lineage(
cast(Set[str], test_case["upstream_fields"]),
cast(Set[str], test_case["downstream_fields"]),
)
@ -152,7 +156,6 @@ def test_get_strict_column_lineage():
def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
"""Test auto fuzzy column lineage mapping."""
lineage_client = LineageClient(client=client)
upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@ -213,23 +216,24 @@ def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
],
)
# Mock the _get_fields_from_dataset_urn method to return our test fields
lineage_client._get_fields_from_dataset_urn = MagicMock() # type: ignore
lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted(
{ # type: ignore
field.fieldPath
for field in (
upstream_schema if "upstream" in str(urn) else downstream_schema
).fields
}
)
# Use patch.object with a context manager
with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
# Configure the mock with a simpler side effect function
mock_method.side_effect = lambda urn: sorted(
{
field.fieldPath
for field in (
upstream_schema if "upstream" in str(urn) else downstream_schema
).fields
}
)
# Run the lineage function
lineage_client.add_dataset_copy_lineage(
upstream=upstream,
downstream=downstream,
column_lineage="auto_fuzzy",
)
# Now use client.lineage with the patched method
client.lineage.add_dataset_copy_lineage(
upstream=upstream,
downstream=downstream,
column_lineage="auto_fuzzy",
)
# Use golden file for assertion
assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_fuzzy_golden.json")
@ -237,8 +241,6 @@ def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
"""Test strict column lineage with field matches."""
lineage_client = LineageClient(client=client)
upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@ -303,41 +305,22 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
],
)
# Mock the _get_fields_from_dataset_urn method to return our test fields
lineage_client._get_fields_from_dataset_urn = MagicMock() # type: ignore
lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted(
{ # type: ignore
field.fieldPath
for field in (
upstream_schema if "upstream" in str(urn) else downstream_schema
).fields
}
)
with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
mock_method.side_effect = lambda urn: sorted(
{
field.fieldPath
for field in (
upstream_schema if "upstream" in str(urn) else downstream_schema
).fields
}
)
# Run the lineage function
lineage_client.add_dataset_copy_lineage(
upstream=upstream,
downstream=downstream,
column_lineage="auto_strict",
)
# Mock the _get_fields_from_dataset_urn method to return our test fields
lineage_client._get_fields_from_dataset_urn = MagicMock() # type: ignore
lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted(
{ # type: ignore
field.fieldPath
for field in (
upstream_schema if "upstream" in str(urn) else downstream_schema
).fields
}
)
# Run the lineage function
lineage_client.add_dataset_copy_lineage(
upstream=upstream,
downstream=downstream,
column_lineage="auto_strict",
)
# Run the lineage function
client.lineage.add_dataset_copy_lineage(
upstream=upstream,
downstream=downstream,
column_lineage="auto_strict",
)
# Use golden file for assertion
assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_strict_golden.json")
@ -345,13 +328,12 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None:
"""Test basic lineage without column mapping or query."""
lineage_client = LineageClient(client=client)
# Basic lineage test
upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
lineage_client.add_dataset_transform_lineage(
client.lineage.add_dataset_transform_lineage(
upstream=upstream,
downstream=downstream,
)
@ -360,7 +342,6 @@ def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None:
def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None:
"""Test complete lineage with column mapping and query."""
lineage_client = LineageClient(client=client)
upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@ -372,10 +353,211 @@ def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None:
"ds_col2": ["us_col2", "us_col3"], # 2:1 mapping
}
lineage_client.add_dataset_transform_lineage(
client.lineage.add_dataset_transform_lineage(
upstream=upstream,
downstream=downstream,
query_text=query_text,
column_lineage=column_lineage,
)
assert_client_golden(client, _GOLDEN_DIR / "test_lineage_complete_golden.json")
def test_add_dataset_lineage_from_sql(client: DataHubClient) -> None:
"""Test adding lineage from SQL parsing with a golden file."""
# Create minimal mock result with necessary info
mock_result = SqlParsingResult(
in_tables=["urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
],
column_lineage=[], # Simplified - we only care about table-level lineage for this test
query_type=QueryType.SELECT,
debug_info=MagicMock(error=None, table_error=None),
)
# Simple SQL that would produce the expected lineage
query_text = (
"create table sales_summary as SELECT price, qty, unit_cost FROM orders"
)
# Patch SQL parser and execute lineage creation
with patch(
"datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
return_value=mock_result,
):
client.lineage.add_dataset_lineage_from_sql(
query_text=query_text, platform="snowflake", env="PROD"
)
# Validate against golden file
assert_client_golden(client, _GOLDEN_DIR / "test_lineage_from_sql_golden.json")
def test_add_dataset_lineage_from_sql_with_multiple_upstreams(
client: DataHubClient,
) -> None:
"""Test adding lineage for a dataset with multiple upstreams."""
# Create minimal mock result with necessary info
mock_result = SqlParsingResult(
in_tables=[
"urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
],
out_tables=[
"urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
],
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(
column="product_name",
),
upstreams=[
ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
column="product_name",
)
],
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(
column="total_quantity",
),
upstreams=[
ColumnRef(
table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
column="quantity",
)
],
),
],
query_type=QueryType.SELECT,
debug_info=MagicMock(error=None, table_error=None),
)
# Simple SQL that would produce the expected lineage
query_text = """
CREATE TABLE sales_summary AS
SELECT
p.product_name,
SUM(s.quantity) as total_quantity,
FROM sales s
JOIN products p ON s.product_id = p.id
GROUP BY p.product_name
"""
# Patch SQL parser and execute lineage creation
with patch(
"datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
return_value=mock_result,
):
client.lineage.add_dataset_lineage_from_sql(
query_text=query_text, platform="snowflake", env="PROD"
)
# Validate against golden file
assert_client_golden(
client, _GOLDEN_DIR / "test_lineage_from_sql_multiple_upstreams_golden.json"
)
def test_add_datajob_lineage(client: DataHubClient) -> None:
"""Test adding lineage for datajobs using DataJobPatchBuilder."""
# Define URNs for test with correct format
datajob_urn = (
"urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
)
input_dataset_urn = (
"urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
)
input_datajob_urn = (
"urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
)
output_dataset_urn = (
"urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
)
# Test adding both upstream and downstream connections
client.lineage.add_datajob_lineage(
datajob=datajob_urn,
upstreams=[input_dataset_urn, input_datajob_urn],
downstreams=[output_dataset_urn],
)
# Validate lineage MCPs against golden file
assert_client_golden(client, _GOLDEN_DIR / "test_datajob_lineage_golden.json")
def test_add_datajob_inputs_only(client: DataHubClient) -> None:
"""Test adding only inputs to a datajob."""
# Define URNs for test
datajob_urn = (
"urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)"
)
input_dataset_urn = (
"urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
)
# Test adding just upstream connections
client.lineage.add_datajob_lineage(
datajob=datajob_urn,
upstreams=[input_dataset_urn],
)
# Validate lineage MCPs
assert_client_golden(client, _GOLDEN_DIR / "test_datajob_inputs_only_golden.json")
def test_add_datajob_outputs_only(client: DataHubClient) -> None:
"""Test adding only outputs to a datajob."""
# Define URNs for test
datajob_urn = (
"urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
)
output_dataset_urn = (
"urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
)
# Test adding just downstream connections
client.lineage.add_datajob_lineage(
datajob=datajob_urn, downstreams=[output_dataset_urn]
)
# Validate lineage MCPs
assert_client_golden(client, _GOLDEN_DIR / "test_datajob_outputs_only_golden.json")
def test_add_datajob_lineage_validation(client: DataHubClient) -> None:
"""Test validation checks in add_datajob_lineage."""
# Define URNs for test
datajob_urn = (
"urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
)
invalid_urn = "urn:li:glossaryNode:something"
# Test with invalid datajob URN
with pytest.raises(
InvalidUrnError,
match="Passed an urn of type glossaryNode to the from_string method of DataJobUrn",
):
client.lineage.add_datajob_lineage(
datajob=invalid_urn,
upstreams=[
"urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
],
)
# Test with invalid upstream URN
with pytest.raises(InvalidUrnError):
client.lineage.add_datajob_lineage(datajob=datajob_urn, upstreams=[invalid_urn])
# Test with invalid downstream URN
with pytest.raises(InvalidUrnError):
client.lineage.add_datajob_lineage(
datajob=datajob_urn, downstreams=[invalid_urn]
)