From a414bbb79890ddb3d8d2c0b7a3b4c97bb919f723 Mon Sep 17 00:00:00 2001 From: Hyejin Yoon <0327jane@gmail.com> Date: Fri, 9 May 2025 10:20:48 +0900 Subject: [PATCH] feat(sdk): add datajob lineage & dataset sql parsing lineage (#13365) --- .../library/lineage_datajob_to_datajob.py | 13 + .../library/lineage_datajob_to_dataset.py | 14 + .../examples/library/lineage_dataset_copy.py | 20 ++ .../library/lineage_dataset_from_sql.py | 27 ++ .../library/lineage_dataset_transform.py | 17 + ...neage_dataset_transform_with_query_text.py | 24 ++ .../src/datahub/sdk/lineage_client.py | 141 +++++++- .../test_datajob_inputs_only_golden.json | 19 ++ .../test_datajob_lineage_golden.json | 33 ++ .../test_datajob_outputs_only_golden.json | 19 ++ .../test_lineage_from_sql_golden.json | 67 ++++ ...ge_from_sql_multiple_upstreams_golden.json | 67 ++++ .../tests/unit/sdk_v2/test_lineage_client.py | 316 ++++++++++++++---- 13 files changed, 703 insertions(+), 74 deletions(-) create mode 100644 metadata-ingestion/examples/library/lineage_datajob_to_datajob.py create mode 100644 metadata-ingestion/examples/library/lineage_datajob_to_dataset.py create mode 100644 metadata-ingestion/examples/library/lineage_dataset_copy.py create mode 100644 metadata-ingestion/examples/library/lineage_dataset_from_sql.py create mode 100644 metadata-ingestion/examples/library/lineage_dataset_transform.py create mode 100644 metadata-ingestion/examples/library/lineage_dataset_transform_with_query_text.py create mode 100644 metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_inputs_only_golden.json create mode 100644 metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_lineage_golden.json create mode 100644 metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_outputs_only_golden.json create mode 100644 metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_golden.json create mode 100644 metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_multiple_upstreams_golden.json diff --git a/metadata-ingestion/examples/library/lineage_datajob_to_datajob.py b/metadata-ingestion/examples/library/lineage_datajob_to_datajob.py new file mode 100644 index 0000000000..29f932d621 --- /dev/null +++ b/metadata-ingestion/examples/library/lineage_datajob_to_datajob.py @@ -0,0 +1,13 @@ +from datahub.metadata.urns import DataFlowUrn, DataJobUrn +from datahub.sdk.lineage_client import LineageClient +from datahub.sdk.main_client import DataHubClient + +client = DataHubClient.from_env() +lineage_client = LineageClient(client=client) + +flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD") + +lineage_client.add_datajob_lineage( + datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"), + upstreams=[DataJobUrn(flow=flow_urn, job_id="extract_job")], +) diff --git a/metadata-ingestion/examples/library/lineage_datajob_to_dataset.py b/metadata-ingestion/examples/library/lineage_datajob_to_dataset.py new file mode 100644 index 0000000000..b583d9a009 --- /dev/null +++ b/metadata-ingestion/examples/library/lineage_datajob_to_dataset.py @@ -0,0 +1,14 @@ +from datahub.metadata.urns import DataFlowUrn, DataJobUrn, DatasetUrn +from datahub.sdk.lineage_client import LineageClient +from datahub.sdk.main_client import DataHubClient + +client = DataHubClient.from_env() +lineage_client = LineageClient(client=client) + +flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD") + 
+lineage_client.add_datajob_lineage(
+    datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"),
+    upstreams=[DatasetUrn(platform="postgres", name="raw_data")],
+    downstreams=[DatasetUrn(platform="snowflake", name="processed_data")],
+)
diff --git a/metadata-ingestion/examples/library/lineage_dataset_copy.py b/metadata-ingestion/examples/library/lineage_dataset_copy.py
new file mode 100644
index 0000000000..32c1e19cfc
--- /dev/null
+++ b/metadata-ingestion/examples/library/lineage_dataset_copy.py
@@ -0,0 +1,20 @@
+from datahub.metadata.urns import DatasetUrn
+from datahub.sdk.lineage_client import LineageClient
+from datahub.sdk.main_client import DataHubClient
+
+client = DataHubClient.from_env()
+lineage_client = LineageClient(client=client)
+
+lineage_client.add_dataset_copy_lineage(
+    upstream=DatasetUrn(platform="postgres", name="customer_data"),
+    downstream=DatasetUrn(platform="snowflake", name="customer_info"),
+    column_lineage="auto_fuzzy",
+)
+# By default, column_lineage is "auto_fuzzy", which matches similar field names.
+# It can also be "auto_strict" for exact field-name matching,
+# or a dict mapping each downstream field to the upstream field(s) it derives from.
+# e.g.
+# column_lineage={
+#     "customer_id": ["id"],
+#     "full_name": ["first_name", "last_name"],
+# }
diff --git a/metadata-ingestion/examples/library/lineage_dataset_from_sql.py b/metadata-ingestion/examples/library/lineage_dataset_from_sql.py
new file mode 100644
index 0000000000..3650802ff8
--- /dev/null
+++ b/metadata-ingestion/examples/library/lineage_dataset_from_sql.py
@@ -0,0 +1,27 @@
+from datahub.sdk.lineage_client import LineageClient
+from datahub.sdk.main_client import DataHubClient
+
+client = DataHubClient.from_env()
+lineage_client = LineageClient(client=client)
+
+sql_query = """
+CREATE TABLE sales_summary AS
+SELECT
+    p.product_name,
+    c.customer_segment,
+    SUM(s.quantity) as total_quantity,
+    SUM(s.amount) as total_sales
+FROM sales s
+JOIN products p ON s.product_id = p.id
+JOIN customers c ON s.customer_id = c.id
+GROUP BY p.product_name, c.customer_segment
+"""
+
+# sales_summary is assumed to live in the default db/schema,
+# e.g. prod_db.public.sales_summary
+lineage_client.add_dataset_lineage_from_sql(
+    query_text=sql_query,
+    platform="snowflake",
+    default_db="prod_db",
+    default_schema="public",
+)
diff --git a/metadata-ingestion/examples/library/lineage_dataset_transform.py b/metadata-ingestion/examples/library/lineage_dataset_transform.py
new file mode 100644
index 0000000000..fa3d278bd4
--- /dev/null
+++ b/metadata-ingestion/examples/library/lineage_dataset_transform.py
@@ -0,0 +1,17 @@
+from datahub.metadata.urns import DatasetUrn
+from datahub.sdk.lineage_client import LineageClient
+from datahub.sdk.main_client import DataHubClient
+
+client = DataHubClient.from_env()
+lineage_client = LineageClient(client=client)
+
+
+lineage_client.add_dataset_transform_lineage(
+    upstream=DatasetUrn(platform="snowflake", name="source_table"),
+    downstream=DatasetUrn(platform="snowflake", name="target_table"),
+    column_lineage={
+        "customer_id": ["id"],
+        "full_name": ["first_name", "last_name"],
+    },
+)
+# column_lineage is optional -- if omitted, only table-level lineage is recorded.
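[Editor's note] On the `lineage_dataset_from_sql.py` example above: per the implementation later in this patch, `add_dataset_lineage_from_sql` takes the first output table parsed from the statement as the downstream and emits one transform-lineage edge per parsed upstream table, attaching the query text (and any parsed column-level lineage) to each edge. At the table level, the call is roughly equivalent to the sketch below; the `prod_db.public.*` names are illustrative of how Snowflake dataset URNs would be rendered, and `lineage_client`/`sql_query` are the names from the example above.

```python
from datahub.metadata.urns import DatasetUrn

# Rough table-level equivalent of the add_dataset_lineage_from_sql call above:
# each upstream table parsed from the query becomes one transform-lineage edge
# into the first output table, with the SQL attached as query text.
for upstream_name in [
    "prod_db.public.sales",
    "prod_db.public.products",
    "prod_db.public.customers",
]:
    lineage_client.add_dataset_transform_lineage(
        upstream=DatasetUrn(platform="snowflake", name=upstream_name),
        downstream=DatasetUrn(platform="snowflake", name="prod_db.public.sales_summary"),
        query_text=sql_query,
    )
```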
diff --git a/metadata-ingestion/examples/library/lineage_dataset_transform_with_query_text.py b/metadata-ingestion/examples/library/lineage_dataset_transform_with_query_text.py new file mode 100644 index 0000000000..d88b023531 --- /dev/null +++ b/metadata-ingestion/examples/library/lineage_dataset_transform_with_query_text.py @@ -0,0 +1,24 @@ +from datahub.metadata.urns import DatasetUrn +from datahub.sdk.lineage_client import LineageClient +from datahub.sdk.main_client import DataHubClient + +client = DataHubClient.from_env() +lineage_client = LineageClient(client=client) + +# this can be any transformation logic e.g. a spark job, an airflow DAG, python script, etc. +# if you have a SQL query, we recommend using add_dataset_lineage_from_sql instead. + +query_text = """ +from pyspark.sql import SparkSession + +spark = SparkSession.builder.appName("HighValueFilter").getOrCreate() +df = spark.read.table("customers") +high_value = df.filter("lifetime_value > 10000") +high_value.write.saveAsTable("high_value_customers") +""" + +lineage_client.add_dataset_transform_lineage( + upstream=DatasetUrn(platform="snowflake", name="customers"), + downstream=DatasetUrn(platform="snowflake", name="high_value_customers"), + query_text=query_text, +) diff --git a/metadata-ingestion/src/datahub/sdk/lineage_client.py b/metadata-ingestion/src/datahub/sdk/lineage_client.py index 20b24405df..a83b5f3c08 100644 --- a/metadata-ingestion/src/datahub/sdk/lineage_client.py +++ b/metadata-ingestion/src/datahub/sdk/lineage_client.py @@ -4,22 +4,24 @@ import difflib import logging from typing import TYPE_CHECKING, List, Literal, Optional, Set, Union +from typing_extensions import assert_never + import datahub.metadata.schema_classes as models from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.errors import SdkUsageError -from datahub.metadata.schema_classes import SchemaMetadataClass -from datahub.metadata.urns import DatasetUrn, QueryUrn -from datahub.sdk._shared import DatasetUrnOrStr +from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn +from datahub.sdk._shared import DatajobUrnOrStr, DatasetUrnOrStr from datahub.sdk._utils import DEFAULT_ACTOR_URN from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping +from datahub.specific.datajob import DataJobPatchBuilder from datahub.specific.dataset import DatasetPatchBuilder from datahub.sql_parsing.fingerprint_utils import generate_hash from datahub.utilities.ordered_set import OrderedSet +from datahub.utilities.urns.error import InvalidUrnError if TYPE_CHECKING: from datahub.sdk.main_client import DataHubClient -logger = logging.getLogger(__name__) _empty_audit_stamp = models.AuditStampClass( time=0, @@ -27,16 +29,19 @@ _empty_audit_stamp = models.AuditStampClass( ) +logger = logging.getLogger(__name__) + + class LineageClient: def __init__(self, client: DataHubClient): self._client = client def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]: schema_metadata = self._client._graph.get_aspect( - str(dataset_urn), SchemaMetadataClass + str(dataset_urn), models.SchemaMetadataClass ) if schema_metadata is None: - return Set() + return set() return {field.fieldPath for field in schema_metadata.fields} @@ -122,7 +127,7 @@ class LineageClient: if column_lineage is None: cll = None - elif column_lineage in ["auto_fuzzy", "auto_strict"]: + elif column_lineage == "auto_fuzzy" or column_lineage == "auto_strict": upstream_schema = self._get_fields_from_dataset_urn(upstream) downstream_schema = 
self._get_fields_from_dataset_urn(downstream) if column_lineage == "auto_fuzzy": @@ -144,6 +149,8 @@ class LineageClient: downstream=downstream, cll_mapping=column_lineage, ) + else: + assert_never(column_lineage) updater = DatasetPatchBuilder(str(downstream)) updater.add_upstream_lineage( @@ -227,9 +234,129 @@ class LineageClient: raise SdkUsageError( f"Dataset {updater.urn} does not exist, and hence cannot be updated." ) + mcps: List[ Union[MetadataChangeProposalWrapper, models.MetadataChangeProposalClass] ] = list(updater.build()) if query_entity: mcps.extend(query_entity) self._client._graph.emit_mcps(mcps) + + def add_dataset_lineage_from_sql( + self, + *, + query_text: str, + platform: str, + platform_instance: Optional[str] = None, + env: str = "PROD", + default_db: Optional[str] = None, + default_schema: Optional[str] = None, + ) -> None: + """Add lineage by parsing a SQL query.""" + from datahub.sql_parsing.sqlglot_lineage import ( + create_lineage_sql_parsed_result, + ) + + # Parse the SQL query to extract lineage information + parsed_result = create_lineage_sql_parsed_result( + query=query_text, + default_db=default_db, + default_schema=default_schema, + platform=platform, + platform_instance=platform_instance, + env=env, + graph=self._client._graph, + ) + + if parsed_result.debug_info.table_error: + raise SdkUsageError( + f"Failed to parse SQL query: {parsed_result.debug_info.error}" + ) + elif parsed_result.debug_info.column_error: + logger.warning( + f"Failed to parse SQL query: {parsed_result.debug_info.error}", + ) + + if not parsed_result.out_tables: + raise SdkUsageError( + "No output tables found in the query. Cannot establish lineage." + ) + + # Use the first output table as the downstream + downstream_urn = parsed_result.out_tables[0] + + # Process all upstream tables found in the query + for upstream_table in parsed_result.in_tables: + # Skip self-lineage + if upstream_table == downstream_urn: + continue + + # Extract column-level lineage for this specific upstream table + column_mapping = {} + if parsed_result.column_lineage: + for col_lineage in parsed_result.column_lineage: + if not (col_lineage.downstream and col_lineage.downstream.column): + continue + + # Filter upstreams to only include columns from current upstream table + upstream_cols = [ + ref.column + for ref in col_lineage.upstreams + if ref.table == upstream_table and ref.column + ] + + if upstream_cols: + column_mapping[col_lineage.downstream.column] = upstream_cols + + # Add lineage, including query text + self.add_dataset_transform_lineage( + upstream=upstream_table, + downstream=downstream_urn, + column_lineage=column_mapping or None, + query_text=query_text, + ) + + def add_datajob_lineage( + self, + *, + datajob: DatajobUrnOrStr, + upstreams: Optional[List[Union[DatasetUrnOrStr, DatajobUrnOrStr]]] = None, + downstreams: Optional[List[DatasetUrnOrStr]] = None, + ) -> None: + """ + Add lineage between a datajob and datasets/datajobs. 
+
+        Args:
+            datajob: The datajob whose lineage is being updated.
+            upstreams: Upstream datasets or datajobs that serve as inputs to the datajob.
+            downstreams: Downstream datasets that are produced as outputs of the datajob.
+        """
+
+        if not upstreams and not downstreams:
+            raise SdkUsageError("No upstreams or downstreams provided")
+
+        datajob_urn = DataJobUrn.from_string(datajob)
+
+        # Initialize the patch builder for the datajob
+        patch_builder = DataJobPatchBuilder(str(datajob_urn))
+
+        # Process upstream connections (inputs to the datajob)
+        if upstreams:
+            for upstream in upstreams:
+                # Try interpreting the upstream as a dataset URN first
+                try:
+                    dataset_urn = DatasetUrn.from_string(upstream)
+                    patch_builder.add_input_dataset(dataset_urn)
+                except InvalidUrnError:
+                    # Fall back to interpreting the upstream as a datajob URN
+                    upstream_datajob_urn = DataJobUrn.from_string(upstream)
+                    patch_builder.add_input_datajob(upstream_datajob_urn)
+
+        # Process downstream connections (outputs from the datajob)
+        if downstreams:
+            for downstream in downstreams:
+                downstream_urn = DatasetUrn.from_string(downstream)
+                patch_builder.add_output_dataset(downstream_urn)
+
+        # Apply the changes to the entity
+        self._client.entities.update(patch_builder)
diff --git a/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_inputs_only_golden.json b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_inputs_only_golden.json
new file mode 100644
index 0000000000..ec33e8811f
--- /dev/null
+++ b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_inputs_only_golden.json
@@ -0,0 +1,19 @@
+[
+{
+    "entityType": "dataJob",
+    "entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)",
+    "changeType": "PATCH",
+    "aspectName": "dataJobInputOutput",
+    "aspect": {
+        "json": [
+            {
+                "op": "add",
+                "path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
+                "value": {
+                    "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+                }
+            }
+        ]
+    }
+}
+]
\ No newline at end of file
diff --git a/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_lineage_golden.json b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_lineage_golden.json
new file mode 100644
index 0000000000..a30b8ff8f0
--- /dev/null
+++ b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_lineage_golden.json
@@ -0,0 +1,33 @@
+[
+{
+    "entityType": "dataJob",
+    "entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
+    "changeType": "PATCH",
+    "aspectName": "dataJobInputOutput",
+    "aspect": {
+        "json": [
+            {
+                "op": "add",
+                "path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
+                "value": {
+                    "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+                }
+            },
+            {
+                "op": "add",
+                "path": "/inputDatajobEdges/urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)",
+                "value": {
+                    "destinationUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
+                }
+            },
+            {
+                "op": "add",
+                "path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
+                "value": {
+                    "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
+                }
+            }
+        ]
+    }
+}
+]
\ No newline at end of file
diff --git a/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_outputs_only_golden.json 
b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_outputs_only_golden.json new file mode 100644 index 0000000000..53c9cd7fdd --- /dev/null +++ b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_datajob_outputs_only_golden.json @@ -0,0 +1,19 @@ +[ +{ + "entityType": "dataJob", + "entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)", + "changeType": "PATCH", + "aspectName": "dataJobInputOutput", + "aspect": { + "json": [ + { + "op": "add", + "path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)", + "value": { + "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)" + } + } + ] + } +} +] \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_golden.json b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_golden.json new file mode 100644 index 0000000000..99d292ce26 --- /dev/null +++ b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_golden.json @@ -0,0 +1,67 @@ +[ +{ + "entityType": "dataset", + "entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)", + "changeType": "PATCH", + "aspectName": "upstreamLineage", + "aspect": { + "json": [ + { + "op": "add", + "path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)", + "value": { + "auditStamp": { + "time": 0, + "actor": "urn:li:corpuser:unknown" + }, + "dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)", + "type": "TRANSFORMED", + "query": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6" + } + } + ] + } +}, +{ + "entityType": "query", + "entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6", + "changeType": "UPSERT", + "aspectName": "queryProperties", + "aspect": { + "json": { + "customProperties": {}, + "statement": { + "value": "create table sales_summary as SELECT price, qty, unit_cost FROM orders", + "language": "SQL" + }, + "source": "SYSTEM", + "created": { + "time": 0, + "actor": "urn:li:corpuser:__ingestion" + }, + "lastModified": { + "time": 0, + "actor": "urn:li:corpuser:__ingestion" + } + } + } +}, +{ + "entityType": "query", + "entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6", + "changeType": "UPSERT", + "aspectName": "querySubjects", + "aspect": { + "json": { + "subjects": [ + { + "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)" + }, + { + "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)" + } + ] + } + } +} +] \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_multiple_upstreams_golden.json b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_multiple_upstreams_golden.json new file mode 100644 index 0000000000..e390211167 --- /dev/null +++ b/metadata-ingestion/tests/unit/sdk_v2/lineage_client_golden/test_lineage_from_sql_multiple_upstreams_golden.json @@ -0,0 +1,67 @@ +[ +{ + "entityType": "dataset", + "entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)", + "changeType": "PATCH", + "aspectName": "upstreamLineage", + "aspect": { + "json": [ + { + "op": "add", + "path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)", + "value": { + "auditStamp": { + "time": 0, + "actor": 
"urn:li:corpuser:unknown" + }, + "dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)", + "type": "TRANSFORMED", + "query": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131" + } + } + ] + } +}, +{ + "entityType": "query", + "entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131", + "changeType": "UPSERT", + "aspectName": "queryProperties", + "aspect": { + "json": { + "customProperties": {}, + "statement": { + "value": "\n CREATE TABLE sales_summary AS\n SELECT \n p.product_name,\n SUM(s.quantity) as total_quantity,\n FROM sales s\n JOIN products p ON s.product_id = p.id\n GROUP BY p.product_name\n ", + "language": "SQL" + }, + "source": "SYSTEM", + "created": { + "time": 0, + "actor": "urn:li:corpuser:__ingestion" + }, + "lastModified": { + "time": 0, + "actor": "urn:li:corpuser:__ingestion" + } + } + } +}, +{ + "entityType": "query", + "entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131", + "changeType": "UPSERT", + "aspectName": "querySubjects", + "aspect": { + "json": { + "subjects": [ + { + "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)" + }, + { + "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)" + } + ] + } + } +} +] \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py b/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py index 0c28b7f339..1af0323ed1 100644 --- a/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py +++ b/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py @@ -1,6 +1,6 @@ import pathlib from typing import Dict, List, Set, cast -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch import pytest @@ -13,6 +13,14 @@ from datahub.metadata.schema_classes import ( ) from datahub.sdk.lineage_client import LineageClient from datahub.sdk.main_client import DataHubClient +from datahub.sql_parsing.sql_parsing_common import QueryType +from datahub.sql_parsing.sqlglot_lineage import ( + ColumnLineageInfo, + ColumnRef, + DownstreamColumnRef, + SqlParsingResult, +) +from datahub.utilities.urns.error import InvalidUrnError from tests.test_helpers import mce_helpers _GOLDEN_DIR = pathlib.Path(__file__).parent / "lineage_client_golden" @@ -22,6 +30,7 @@ _GOLDEN_DIR.mkdir(exist_ok=True) @pytest.fixture def mock_graph() -> Mock: graph = Mock() + return graph @@ -40,12 +49,9 @@ def assert_client_golden(client: DataHubClient, golden_path: pathlib.Path) -> No ) -def test_get_fuzzy_column_lineage(): +def test_get_fuzzy_column_lineage(client: DataHubClient) -> None: """Test the fuzzy column lineage matching algorithm.""" # Create a minimal client just for testing the method - client = MagicMock(spec=DataHubClient) - lineage_client = LineageClient(client=client) - # Test cases test_cases = [ # Case 1: Exact matches @@ -104,7 +110,7 @@ def test_get_fuzzy_column_lineage(): # Run test cases for i, test_case in enumerate(test_cases): - result = lineage_client._get_fuzzy_column_lineage( + result = client.lineage._get_fuzzy_column_lineage( cast(Set[str], test_case["upstream_fields"]), cast(Set[str], test_case["downstream_fields"]), ) @@ -113,11 +119,9 @@ def test_get_fuzzy_column_lineage(): ) -def test_get_strict_column_lineage(): +def test_get_strict_column_lineage(client: DataHubClient) -> None: """Test the strict column lineage matching algorithm.""" # Create a minimal client just for testing the method - client 
= MagicMock(spec=DataHubClient) - lineage_client = LineageClient(client=client) # Define test cases test_cases = [ @@ -143,7 +147,7 @@ def test_get_strict_column_lineage(): # Run test cases for i, test_case in enumerate(test_cases): - result = lineage_client._get_strict_column_lineage( + result = client.lineage._get_strict_column_lineage( cast(Set[str], test_case["upstream_fields"]), cast(Set[str], test_case["downstream_fields"]), ) @@ -152,7 +156,6 @@ def test_get_strict_column_lineage(): def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None: """Test auto fuzzy column lineage mapping.""" - lineage_client = LineageClient(client=client) upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)" downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)" @@ -213,23 +216,24 @@ def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None: ], ) - # Mock the _get_fields_from_dataset_urn method to return our test fields - lineage_client._get_fields_from_dataset_urn = MagicMock() # type: ignore - lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted( - { # type: ignore - field.fieldPath - for field in ( - upstream_schema if "upstream" in str(urn) else downstream_schema - ).fields - } - ) + # Use patch.object with a context manager + with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method: + # Configure the mock with a simpler side effect function + mock_method.side_effect = lambda urn: sorted( + { + field.fieldPath + for field in ( + upstream_schema if "upstream" in str(urn) else downstream_schema + ).fields + } + ) - # Run the lineage function - lineage_client.add_dataset_copy_lineage( - upstream=upstream, - downstream=downstream, - column_lineage="auto_fuzzy", - ) + # Now use client.lineage with the patched method + client.lineage.add_dataset_copy_lineage( + upstream=upstream, + downstream=downstream, + column_lineage="auto_fuzzy", + ) # Use golden file for assertion assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_fuzzy_golden.json") @@ -237,8 +241,6 @@ def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None: def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None: """Test strict column lineage with field matches.""" - lineage_client = LineageClient(client=client) - upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)" downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)" @@ -303,41 +305,22 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None: ], ) - # Mock the _get_fields_from_dataset_urn method to return our test fields - lineage_client._get_fields_from_dataset_urn = MagicMock() # type: ignore - lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted( - { # type: ignore - field.fieldPath - for field in ( - upstream_schema if "upstream" in str(urn) else downstream_schema - ).fields - } - ) + with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method: + mock_method.side_effect = lambda urn: sorted( + { + field.fieldPath + for field in ( + upstream_schema if "upstream" in str(urn) else downstream_schema + ).fields + } + ) - # Run the lineage function - lineage_client.add_dataset_copy_lineage( - upstream=upstream, - downstream=downstream, - column_lineage="auto_strict", - ) - - # Mock the _get_fields_from_dataset_urn method to return our test fields - lineage_client._get_fields_from_dataset_urn 
= MagicMock() # type: ignore - lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted( - { # type: ignore - field.fieldPath - for field in ( - upstream_schema if "upstream" in str(urn) else downstream_schema - ).fields - } - ) - - # Run the lineage function - lineage_client.add_dataset_copy_lineage( - upstream=upstream, - downstream=downstream, - column_lineage="auto_strict", - ) + # Run the lineage function + client.lineage.add_dataset_copy_lineage( + upstream=upstream, + downstream=downstream, + column_lineage="auto_strict", + ) # Use golden file for assertion assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_strict_golden.json") @@ -345,13 +328,12 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None: def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None: """Test basic lineage without column mapping or query.""" - lineage_client = LineageClient(client=client) # Basic lineage test upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)" downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)" - lineage_client.add_dataset_transform_lineage( + client.lineage.add_dataset_transform_lineage( upstream=upstream, downstream=downstream, ) @@ -360,7 +342,6 @@ def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None: def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None: """Test complete lineage with column mapping and query.""" - lineage_client = LineageClient(client=client) upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)" downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)" @@ -372,10 +353,211 @@ def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None: "ds_col2": ["us_col2", "us_col3"], # 2:1 mapping } - lineage_client.add_dataset_transform_lineage( + client.lineage.add_dataset_transform_lineage( upstream=upstream, downstream=downstream, query_text=query_text, column_lineage=column_lineage, ) assert_client_golden(client, _GOLDEN_DIR / "test_lineage_complete_golden.json") + + +def test_add_dataset_lineage_from_sql(client: DataHubClient) -> None: + """Test adding lineage from SQL parsing with a golden file.""" + + # Create minimal mock result with necessary info + mock_result = SqlParsingResult( + in_tables=["urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)" + ], + column_lineage=[], # Simplified - we only care about table-level lineage for this test + query_type=QueryType.SELECT, + debug_info=MagicMock(error=None, table_error=None), + ) + + # Simple SQL that would produce the expected lineage + query_text = ( + "create table sales_summary as SELECT price, qty, unit_cost FROM orders" + ) + + # Patch SQL parser and execute lineage creation + with patch( + "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result", + return_value=mock_result, + ): + client.lineage.add_dataset_lineage_from_sql( + query_text=query_text, platform="snowflake", env="PROD" + ) + + # Validate against golden file + assert_client_golden(client, _GOLDEN_DIR / "test_lineage_from_sql_golden.json") + + +def test_add_dataset_lineage_from_sql_with_multiple_upstreams( + client: DataHubClient, +) -> None: + """Test adding lineage for a dataset with multiple upstreams.""" + + # Create minimal mock result with necessary info + mock_result = SqlParsingResult( + in_tables=[ 
+ "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)", + ], + out_tables=[ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)" + ], + column_lineage=[ + ColumnLineageInfo( + downstream=DownstreamColumnRef( + column="product_name", + ), + upstreams=[ + ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)", + column="product_name", + ) + ], + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef( + column="total_quantity", + ), + upstreams=[ + ColumnRef( + table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)", + column="quantity", + ) + ], + ), + ], + query_type=QueryType.SELECT, + debug_info=MagicMock(error=None, table_error=None), + ) + + # Simple SQL that would produce the expected lineage + query_text = """ + CREATE TABLE sales_summary AS + SELECT + p.product_name, + SUM(s.quantity) as total_quantity, + FROM sales s + JOIN products p ON s.product_id = p.id + GROUP BY p.product_name + """ + + # Patch SQL parser and execute lineage creation + with patch( + "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result", + return_value=mock_result, + ): + client.lineage.add_dataset_lineage_from_sql( + query_text=query_text, platform="snowflake", env="PROD" + ) + + # Validate against golden file + assert_client_golden( + client, _GOLDEN_DIR / "test_lineage_from_sql_multiple_upstreams_golden.json" + ) + + +def test_add_datajob_lineage(client: DataHubClient) -> None: + """Test adding lineage for datajobs using DataJobPatchBuilder.""" + + # Define URNs for test with correct format + datajob_urn = ( + "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)" + ) + input_dataset_urn = ( + "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)" + ) + input_datajob_urn = ( + "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)" + ) + output_dataset_urn = ( + "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)" + ) + + # Test adding both upstream and downstream connections + client.lineage.add_datajob_lineage( + datajob=datajob_urn, + upstreams=[input_dataset_urn, input_datajob_urn], + downstreams=[output_dataset_urn], + ) + + # Validate lineage MCPs against golden file + assert_client_golden(client, _GOLDEN_DIR / "test_datajob_lineage_golden.json") + + +def test_add_datajob_inputs_only(client: DataHubClient) -> None: + """Test adding only inputs to a datajob.""" + + # Define URNs for test + datajob_urn = ( + "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)" + ) + input_dataset_urn = ( + "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)" + ) + + # Test adding just upstream connections + client.lineage.add_datajob_lineage( + datajob=datajob_urn, + upstreams=[input_dataset_urn], + ) + + # Validate lineage MCPs + assert_client_golden(client, _GOLDEN_DIR / "test_datajob_inputs_only_golden.json") + + +def test_add_datajob_outputs_only(client: DataHubClient) -> None: + """Test adding only outputs to a datajob.""" + + # Define URNs for test + datajob_urn = ( + "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)" + ) + output_dataset_urn = ( + "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)" + ) + + # Test adding just downstream connections + client.lineage.add_datajob_lineage( + datajob=datajob_urn, downstreams=[output_dataset_urn] + ) + + # Validate lineage MCPs + assert_client_golden(client, _GOLDEN_DIR / 
"test_datajob_outputs_only_golden.json") + + +def test_add_datajob_lineage_validation(client: DataHubClient) -> None: + """Test validation checks in add_datajob_lineage.""" + + # Define URNs for test + datajob_urn = ( + "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)" + ) + invalid_urn = "urn:li:glossaryNode:something" + + # Test with invalid datajob URN + with pytest.raises( + InvalidUrnError, + match="Passed an urn of type glossaryNode to the from_string method of DataJobUrn", + ): + client.lineage.add_datajob_lineage( + datajob=invalid_urn, + upstreams=[ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)" + ], + ) + + # Test with invalid upstream URN + with pytest.raises(InvalidUrnError): + client.lineage.add_datajob_lineage(datajob=datajob_urn, upstreams=[invalid_urn]) + + # Test with invalid downstream URN + with pytest.raises(InvalidUrnError): + client.lineage.add_datajob_lineage( + datajob=datajob_urn, downstreams=[invalid_urn] + )