Mirror of https://github.com/datahub-project/datahub.git (synced 2025-11-01 19:25:56 +00:00)

feat(sdk): add datajob lineage & dataset sql parsing lineage (#13365)

parent 71e104068e
commit a414bbb798
@@ -0,0 +1,13 @@
from datahub.metadata.urns import DataFlowUrn, DataJobUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)

flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD")

lineage_client.add_datajob_lineage(
    datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"),
    upstreams=[DataJobUrn(flow=flow_urn, job_id="extract_job")],
)
@@ -0,0 +1,14 @@
from datahub.metadata.urns import DataFlowUrn, DataJobUrn, DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)

flow_urn = DataFlowUrn(orchestrator="airflow", flow_id="data_pipeline", cluster="PROD")

lineage_client.add_datajob_lineage(
    datajob=DataJobUrn(flow=flow_urn, job_id="data_pipeline"),
    upstreams=[DatasetUrn(platform="postgres", name="raw_data")],
    downstreams=[DatasetUrn(platform="snowflake", name="processed_data")],
)
metadata-ingestion/examples/library/lineage_dataset_copy.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)

lineage_client.add_dataset_copy_lineage(
    upstream=DatasetUrn(platform="postgres", name="customer_data"),
    downstream=DatasetUrn(platform="snowflake", name="customer_info"),
    column_lineage="auto_fuzzy",
)
# By default, column_lineage is "auto_fuzzy", which matches similar field names.
# It can also be "auto_strict" for exact field-name matching,
# or a dict mapping each downstream field to the upstream fields it derives from,
# e.g.
# column_lineage={
#     "customer_id": ["id"],
#     "full_name": ["first_name", "last_name"],
# }
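If the field mapping is already known, the dict form described in the comments can be passed directly. A minimal sketch, reusing the client and datasets from the example above:

lineage_client.add_dataset_copy_lineage(
    upstream=DatasetUrn(platform="postgres", name="customer_data"),
    downstream=DatasetUrn(platform="snowflake", name="customer_info"),
    column_lineage={
        "customer_id": ["id"],
        "full_name": ["first_name", "last_name"],
    },
)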
@@ -0,0 +1,27 @@
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)

sql_query = """
CREATE TABLE sales_summary AS
SELECT
    p.product_name,
    c.customer_segment,
    SUM(s.quantity) as total_quantity,
    SUM(s.amount) as total_sales
FROM sales s
JOIN products p ON s.product_id = p.id
JOIN customers c ON s.customer_id = c.id
GROUP BY p.product_name, c.customer_segment
"""

# sales_summary will be assumed to be in the default db/schema,
# e.g. prod_db.public.sales_summary
lineage_client.add_dataset_lineage_from_sql(
    query_text=sql_query,
    platform="snowflake",
    default_db="prod_db",
    default_schema="public",
)
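As the lineage_client.py changes further down show, the call raises SdkUsageError when the SQL cannot be parsed or yields no output table. A hedged sketch of guarding for that, reusing the setup above:

from datahub.errors import SdkUsageError

try:
    lineage_client.add_dataset_lineage_from_sql(
        query_text=sql_query,
        platform="snowflake",
        default_db="prod_db",
        default_schema="public",
    )
except SdkUsageError as e:
    # Raised when SQL parsing fails or no output table is found.
    print(f"Could not derive lineage: {e}")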
@@ -0,0 +1,17 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)


lineage_client.add_dataset_transform_lineage(
    upstream=DatasetUrn(platform="snowflake", name="source_table"),
    downstream=DatasetUrn(platform="snowflake", name="target_table"),
    column_lineage={
        "customer_id": ["id"],
        "full_name": ["first_name", "last_name"],
    },
)
# column_lineage is optional -- if not provided, only table-level lineage is recorded.
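For comparison, a minimal table-level-only sketch of the same call (no column mapping), reusing the URNs from the example above:

lineage_client.add_dataset_transform_lineage(
    upstream=DatasetUrn(platform="snowflake", name="source_table"),
    downstream=DatasetUrn(platform="snowflake", name="target_table"),
)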
@@ -0,0 +1,24 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()
lineage_client = LineageClient(client=client)

# This can be any transformation logic, e.g. a Spark job, an Airflow DAG, a Python script, etc.
# If you have a SQL query, we recommend using add_dataset_lineage_from_sql instead.

query_text = """
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("HighValueFilter").getOrCreate()
df = spark.read.table("customers")
high_value = df.filter("lifetime_value > 10000")
high_value.write.saveAsTable("high_value_customers")
"""

lineage_client.add_dataset_transform_lineage(
    upstream=DatasetUrn(platform="snowflake", name="customers"),
    downstream=DatasetUrn(platform="snowflake", name="high_value_customers"),
    query_text=query_text,
)
@@ -4,22 +4,24 @@ import difflib
import logging
from typing import TYPE_CHECKING, List, Literal, Optional, Set, Union

from typing_extensions import assert_never

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.errors import SdkUsageError
from datahub.metadata.schema_classes import SchemaMetadataClass
from datahub.metadata.urns import DatasetUrn, QueryUrn
from datahub.sdk._shared import DatasetUrnOrStr
from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn
from datahub.sdk._shared import DatajobUrnOrStr, DatasetUrnOrStr
from datahub.sdk._utils import DEFAULT_ACTOR_URN
from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
from datahub.specific.datajob import DataJobPatchBuilder
from datahub.specific.dataset import DatasetPatchBuilder
from datahub.sql_parsing.fingerprint_utils import generate_hash
from datahub.utilities.ordered_set import OrderedSet
from datahub.utilities.urns.error import InvalidUrnError

if TYPE_CHECKING:
    from datahub.sdk.main_client import DataHubClient

logger = logging.getLogger(__name__)

_empty_audit_stamp = models.AuditStampClass(
    time=0,
@@ -27,16 +29,19 @@ _empty_audit_stamp = models.AuditStampClass(
)


logger = logging.getLogger(__name__)


class LineageClient:
    def __init__(self, client: DataHubClient):
        self._client = client

    def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
        schema_metadata = self._client._graph.get_aspect(
            str(dataset_urn), SchemaMetadataClass
            str(dataset_urn), models.SchemaMetadataClass
        )
        if schema_metadata is None:
            return Set()
            return set()

        return {field.fieldPath for field in schema_metadata.fields}
@@ -122,7 +127,7 @@ class LineageClient:

        if column_lineage is None:
            cll = None
        elif column_lineage in ["auto_fuzzy", "auto_strict"]:
        elif column_lineage == "auto_fuzzy" or column_lineage == "auto_strict":
            upstream_schema = self._get_fields_from_dataset_urn(upstream)
            downstream_schema = self._get_fields_from_dataset_urn(downstream)
            if column_lineage == "auto_fuzzy":
@@ -144,6 +149,8 @@ class LineageClient:
                downstream=downstream,
                cll_mapping=column_lineage,
            )
        else:
            assert_never(column_lineage)

        updater = DatasetPatchBuilder(str(downstream))
        updater.add_upstream_lineage(
@@ -227,9 +234,129 @@ class LineageClient:
            raise SdkUsageError(
                f"Dataset {updater.urn} does not exist, and hence cannot be updated."
            )

        mcps: List[
            Union[MetadataChangeProposalWrapper, models.MetadataChangeProposalClass]
        ] = list(updater.build())
        if query_entity:
            mcps.extend(query_entity)
        self._client._graph.emit_mcps(mcps)

    def add_dataset_lineage_from_sql(
        self,
        *,
        query_text: str,
        platform: str,
        platform_instance: Optional[str] = None,
        env: str = "PROD",
        default_db: Optional[str] = None,
        default_schema: Optional[str] = None,
    ) -> None:
        """Add lineage by parsing a SQL query."""
        from datahub.sql_parsing.sqlglot_lineage import (
            create_lineage_sql_parsed_result,
        )

        # Parse the SQL query to extract lineage information
        parsed_result = create_lineage_sql_parsed_result(
            query=query_text,
            default_db=default_db,
            default_schema=default_schema,
            platform=platform,
            platform_instance=platform_instance,
            env=env,
            graph=self._client._graph,
        )

        if parsed_result.debug_info.table_error:
            raise SdkUsageError(
                f"Failed to parse SQL query: {parsed_result.debug_info.error}"
            )
        elif parsed_result.debug_info.column_error:
            logger.warning(
                f"Failed to parse SQL query: {parsed_result.debug_info.error}",
            )

        if not parsed_result.out_tables:
            raise SdkUsageError(
                "No output tables found in the query. Cannot establish lineage."
            )

        # Use the first output table as the downstream
        downstream_urn = parsed_result.out_tables[0]

        # Process all upstream tables found in the query
        for upstream_table in parsed_result.in_tables:
            # Skip self-lineage
            if upstream_table == downstream_urn:
                continue

            # Extract column-level lineage for this specific upstream table
            column_mapping = {}
            if parsed_result.column_lineage:
                for col_lineage in parsed_result.column_lineage:
                    if not (col_lineage.downstream and col_lineage.downstream.column):
                        continue

                    # Filter upstreams to only include columns from current upstream table
                    upstream_cols = [
                        ref.column
                        for ref in col_lineage.upstreams
                        if ref.table == upstream_table and ref.column
                    ]

                    if upstream_cols:
                        column_mapping[col_lineage.downstream.column] = upstream_cols

            # Add lineage, including query text
            self.add_dataset_transform_lineage(
                upstream=upstream_table,
                downstream=downstream_urn,
                column_lineage=column_mapping or None,
                query_text=query_text,
            )

    def add_datajob_lineage(
        self,
        *,
        datajob: DatajobUrnOrStr,
        upstreams: Optional[List[Union[DatasetUrnOrStr, DatajobUrnOrStr]]] = None,
        downstreams: Optional[List[DatasetUrnOrStr]] = None,
    ) -> None:
        """
        Add lineage between a datajob and datasets/datajobs.

        Args:
            datajob: The datajob URN to connect lineage with
            upstreams: List of upstream datasets or datajobs that serve as inputs to the datajob
            downstreams: List of downstream datasets that are outputs of the datajob
        """

        if not upstreams and not downstreams:
            raise SdkUsageError("No upstreams or downstreams provided")

        datajob_urn = DataJobUrn.from_string(datajob)

        # Initialize the patch builder for the datajob
        patch_builder = DataJobPatchBuilder(str(datajob_urn))

        # Process upstream connections (inputs to the datajob)
        if upstreams:
            for upstream in upstreams:
                # try converting to dataset urn
                try:
                    dataset_urn = DatasetUrn.from_string(upstream)
                    patch_builder.add_input_dataset(dataset_urn)
                except InvalidUrnError:
                    # try converting to datajob urn
                    datajob_urn = DataJobUrn.from_string(upstream)
                    patch_builder.add_input_datajob(datajob_urn)

        # Process downstream connections (outputs from the datajob)
        if downstreams:
            for downstream in downstreams:
                downstream_urn = DatasetUrn.from_string(downstream)
                patch_builder.add_output_dataset(downstream_urn)

        # Apply the changes to the entity
        self._client.entities.update(patch_builder)
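A minimal usage sketch of the new add_datajob_lineage method, mirroring the tests below (the URNs are illustrative; string URNs are accepted, and upstream entries may be datasets or datajobs thanks to the URN-type fallback above):

from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()
client.lineage.add_datajob_lineage(
    datajob="urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
    upstreams=[
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)",
    ],
    downstreams=["urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"],
)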
@@ -0,0 +1,19 @@
[
  {
    "entityType": "dataJob",
    "entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)",
    "changeType": "PATCH",
    "aspectName": "dataJobInputOutput",
    "aspect": {
      "json": [
        {
          "op": "add",
          "path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
          "value": {
            "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
          }
        }
      ]
    }
  }
]
@@ -0,0 +1,33 @@
[
  {
    "entityType": "dataJob",
    "entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
    "changeType": "PATCH",
    "aspectName": "dataJobInputOutput",
    "aspect": {
      "json": [
        {
          "op": "add",
          "path": "/inputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)",
          "value": {
            "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
          }
        },
        {
          "op": "add",
          "path": "/inputDatajobEdges/urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)",
          "value": {
            "destinationUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
          }
        },
        {
          "op": "add",
          "path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
          "value": {
            "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
          }
        }
      ]
    }
  }
]
@@ -0,0 +1,19 @@
[
  {
    "entityType": "dataJob",
    "entityUrn": "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)",
    "changeType": "PATCH",
    "aspectName": "dataJobInputOutput",
    "aspect": {
      "json": [
        {
          "op": "add",
          "path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
          "value": {
            "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
          }
        }
      ]
    }
  }
]
@@ -0,0 +1,67 @@
[
  {
    "entityType": "dataset",
    "entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)",
    "changeType": "PATCH",
    "aspectName": "upstreamLineage",
    "aspect": {
      "json": [
        {
          "op": "add",
          "path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)",
          "value": {
            "auditStamp": {
              "time": 0,
              "actor": "urn:li:corpuser:unknown"
            },
            "dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)",
            "type": "TRANSFORMED",
            "query": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6"
          }
        }
      ]
    }
  },
  {
    "entityType": "query",
    "entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6",
    "changeType": "UPSERT",
    "aspectName": "queryProperties",
    "aspect": {
      "json": {
        "customProperties": {},
        "statement": {
          "value": "create table sales_summary as SELECT price, qty, unit_cost FROM orders",
          "language": "SQL"
        },
        "source": "SYSTEM",
        "created": {
          "time": 0,
          "actor": "urn:li:corpuser:__ingestion"
        },
        "lastModified": {
          "time": 0,
          "actor": "urn:li:corpuser:__ingestion"
        }
      }
    }
  },
  {
    "entityType": "query",
    "entityUrn": "urn:li:query:743f1807704d3d59ff076232e8788b43d98292f7d98ad14ce283d606351b7bb6",
    "changeType": "UPSERT",
    "aspectName": "querySubjects",
    "aspect": {
      "json": {
        "subjects": [
          {
            "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"
          },
          {
            "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
          }
        ]
      }
    }
  }
]
@@ -0,0 +1,67 @@
[
  {
    "entityType": "dataset",
    "entityUrn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)",
    "changeType": "PATCH",
    "aspectName": "upstreamLineage",
    "aspect": {
      "json": [
        {
          "op": "add",
          "path": "/upstreams/urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
          "value": {
            "auditStamp": {
              "time": 0,
              "actor": "urn:li:corpuser:unknown"
            },
            "dataset": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
            "type": "TRANSFORMED",
            "query": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131"
          }
        }
      ]
    }
  },
  {
    "entityType": "query",
    "entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131",
    "changeType": "UPSERT",
    "aspectName": "queryProperties",
    "aspect": {
      "json": {
        "customProperties": {},
        "statement": {
          "value": "\n CREATE TABLE sales_summary AS\n SELECT \n p.product_name,\n SUM(s.quantity) as total_quantity,\n FROM sales s\n JOIN products p ON s.product_id = p.id\n GROUP BY p.product_name\n ",
          "language": "SQL"
        },
        "source": "SYSTEM",
        "created": {
          "time": 0,
          "actor": "urn:li:corpuser:__ingestion"
        },
        "lastModified": {
          "time": 0,
          "actor": "urn:li:corpuser:__ingestion"
        }
      }
    }
  },
  {
    "entityType": "query",
    "entityUrn": "urn:li:query:41fd73db4d7749a886910c3c7f06c29082420f5e6feb988c534c595561bb4131",
    "changeType": "UPSERT",
    "aspectName": "querySubjects",
    "aspect": {
      "json": {
        "subjects": [
          {
            "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)"
          },
          {
            "entity": "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
          }
        ]
      }
    }
  }
]
@@ -1,6 +1,6 @@
import pathlib
from typing import Dict, List, Set, cast
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch

import pytest
@@ -13,6 +13,14 @@ from datahub.metadata.schema_classes import (
)
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
    ColumnLineageInfo,
    ColumnRef,
    DownstreamColumnRef,
    SqlParsingResult,
)
from datahub.utilities.urns.error import InvalidUrnError
from tests.test_helpers import mce_helpers

_GOLDEN_DIR = pathlib.Path(__file__).parent / "lineage_client_golden"
|
||||
@pytest.fixture
|
||||
def mock_graph() -> Mock:
|
||||
graph = Mock()
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
@@ -40,12 +49,9 @@ def assert_client_golden(client: DataHubClient, golden_path: pathlib.Path) -> No
    )


def test_get_fuzzy_column_lineage():
def test_get_fuzzy_column_lineage(client: DataHubClient) -> None:
    """Test the fuzzy column lineage matching algorithm."""
    # Create a minimal client just for testing the method
    client = MagicMock(spec=DataHubClient)
    lineage_client = LineageClient(client=client)

    # Test cases
    test_cases = [
        # Case 1: Exact matches
@@ -104,7 +110,7 @@ def test_get_fuzzy_column_lineage():

    # Run test cases
    for i, test_case in enumerate(test_cases):
        result = lineage_client._get_fuzzy_column_lineage(
        result = client.lineage._get_fuzzy_column_lineage(
            cast(Set[str], test_case["upstream_fields"]),
            cast(Set[str], test_case["downstream_fields"]),
        )
@@ -113,11 +119,9 @@
    )


def test_get_strict_column_lineage():
def test_get_strict_column_lineage(client: DataHubClient) -> None:
    """Test the strict column lineage matching algorithm."""
    # Create a minimal client just for testing the method
    client = MagicMock(spec=DataHubClient)
    lineage_client = LineageClient(client=client)

    # Define test cases
    test_cases = [
@@ -143,7 +147,7 @@ def test_get_strict_column_lineage():

    # Run test cases
    for i, test_case in enumerate(test_cases):
        result = lineage_client._get_strict_column_lineage(
        result = client.lineage._get_strict_column_lineage(
            cast(Set[str], test_case["upstream_fields"]),
            cast(Set[str], test_case["downstream_fields"]),
        )
@@ -152,7 +156,6 @@

def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
    """Test auto fuzzy column lineage mapping."""
    lineage_client = LineageClient(client=client)

    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@@ -213,23 +216,24 @@ def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
        ],
    )

    # Mock the _get_fields_from_dataset_urn method to return our test fields
    lineage_client._get_fields_from_dataset_urn = MagicMock()  # type: ignore
    lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted(
        {  # type: ignore
            field.fieldPath
            for field in (
                upstream_schema if "upstream" in str(urn) else downstream_schema
            ).fields
        }
    )
    # Use patch.object with a context manager
    with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
        # Configure the mock with a simpler side effect function
        mock_method.side_effect = lambda urn: sorted(
            {
                field.fieldPath
                for field in (
                    upstream_schema if "upstream" in str(urn) else downstream_schema
                ).fields
            }
        )

    # Run the lineage function
    lineage_client.add_dataset_copy_lineage(
        upstream=upstream,
        downstream=downstream,
        column_lineage="auto_fuzzy",
    )
        # Now use client.lineage with the patched method
        client.lineage.add_dataset_copy_lineage(
            upstream=upstream,
            downstream=downstream,
            column_lineage="auto_fuzzy",
        )

    # Use golden file for assertion
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_fuzzy_golden.json")
@@ -237,8 +241,6 @@

def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
    """Test strict column lineage with field matches."""
    lineage_client = LineageClient(client=client)

    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@@ -303,41 +305,22 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
        ],
    )

    # Mock the _get_fields_from_dataset_urn method to return our test fields
    lineage_client._get_fields_from_dataset_urn = MagicMock()  # type: ignore
    lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted(
        {  # type: ignore
            field.fieldPath
            for field in (
                upstream_schema if "upstream" in str(urn) else downstream_schema
            ).fields
        }
    )
    with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
        mock_method.side_effect = lambda urn: sorted(
            {
                field.fieldPath
                for field in (
                    upstream_schema if "upstream" in str(urn) else downstream_schema
                ).fields
            }
        )

    # Run the lineage function
    lineage_client.add_dataset_copy_lineage(
        upstream=upstream,
        downstream=downstream,
        column_lineage="auto_strict",
    )

    # Mock the _get_fields_from_dataset_urn method to return our test fields
    lineage_client._get_fields_from_dataset_urn = MagicMock()  # type: ignore
    lineage_client._get_fields_from_dataset_urn.side_effect = lambda urn: sorted(
        {  # type: ignore
            field.fieldPath
            for field in (
                upstream_schema if "upstream" in str(urn) else downstream_schema
            ).fields
        }
    )

    # Run the lineage function
    lineage_client.add_dataset_copy_lineage(
        upstream=upstream,
        downstream=downstream,
        column_lineage="auto_strict",
    )
        # Run the lineage function
        client.lineage.add_dataset_copy_lineage(
            upstream=upstream,
            downstream=downstream,
            column_lineage="auto_strict",
        )

    # Use golden file for assertion
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_strict_golden.json")
@@ -345,13 +328,12 @@ def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:

def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None:
    """Test basic lineage without column mapping or query."""
    lineage_client = LineageClient(client=client)

    # Basic lineage test
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"

    lineage_client.add_dataset_transform_lineage(
    client.lineage.add_dataset_transform_lineage(
        upstream=upstream,
        downstream=downstream,
    )
@@ -360,7 +342,6 @@ def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None:

def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None:
    """Test complete lineage with column mapping and query."""
    lineage_client = LineageClient(client=client)

    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
@@ -372,10 +353,211 @@ def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None:
        "ds_col2": ["us_col2", "us_col3"],  # 2:1 mapping
    }

    lineage_client.add_dataset_transform_lineage(
    client.lineage.add_dataset_transform_lineage(
        upstream=upstream,
        downstream=downstream,
        query_text=query_text,
        column_lineage=column_lineage,
    )
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_complete_golden.json")


def test_add_dataset_lineage_from_sql(client: DataHubClient) -> None:
    """Test adding lineage from SQL parsing with a golden file."""

    # Create minimal mock result with necessary info
    mock_result = SqlParsingResult(
        in_tables=["urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"],
        out_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
        ],
        column_lineage=[],  # Simplified - we only care about table-level lineage for this test
        query_type=QueryType.SELECT,
        debug_info=MagicMock(error=None, table_error=None),
    )

    # Simple SQL that would produce the expected lineage
    query_text = (
        "create table sales_summary as SELECT price, qty, unit_cost FROM orders"
    )

    # Patch SQL parser and execute lineage creation
    with patch(
        "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
        return_value=mock_result,
    ):
        client.lineage.add_dataset_lineage_from_sql(
            query_text=query_text, platform="snowflake", env="PROD"
        )

    # Validate against golden file
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_from_sql_golden.json")

def test_add_dataset_lineage_from_sql_with_multiple_upstreams(
    client: DataHubClient,
) -> None:
    """Test adding lineage for a dataset with multiple upstreams."""

    # Create minimal mock result with necessary info
    mock_result = SqlParsingResult(
        in_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
        ],
        out_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
        ],
        column_lineage=[
            ColumnLineageInfo(
                downstream=DownstreamColumnRef(
                    column="product_name",
                ),
                upstreams=[
                    ColumnRef(
                        table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
                        column="product_name",
                    )
                ],
            ),
            ColumnLineageInfo(
                downstream=DownstreamColumnRef(
                    column="total_quantity",
                ),
                upstreams=[
                    ColumnRef(
                        table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
                        column="quantity",
                    )
                ],
            ),
        ],
        query_type=QueryType.SELECT,
        debug_info=MagicMock(error=None, table_error=None),
    )

    # Simple SQL that would produce the expected lineage
    query_text = """
    CREATE TABLE sales_summary AS
    SELECT
        p.product_name,
        SUM(s.quantity) as total_quantity,
    FROM sales s
    JOIN products p ON s.product_id = p.id
    GROUP BY p.product_name
    """

    # Patch SQL parser and execute lineage creation
    with patch(
        "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
        return_value=mock_result,
    ):
        client.lineage.add_dataset_lineage_from_sql(
            query_text=query_text, platform="snowflake", env="PROD"
        )

    # Validate against golden file
    assert_client_golden(
        client, _GOLDEN_DIR / "test_lineage_from_sql_multiple_upstreams_golden.json"
    )

def test_add_datajob_lineage(client: DataHubClient) -> None:
    """Test adding lineage for datajobs using DataJobPatchBuilder."""

    # Define URNs for test with correct format
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    input_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
    )
    input_datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
    )
    output_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
    )

    # Test adding both upstream and downstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn,
        upstreams=[input_dataset_urn, input_datajob_urn],
        downstreams=[output_dataset_urn],
    )

    # Validate lineage MCPs against golden file
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_lineage_golden.json")


def test_add_datajob_inputs_only(client: DataHubClient) -> None:
    """Test adding only inputs to a datajob."""

    # Define URNs for test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)"
    )
    input_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
    )

    # Test adding just upstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn,
        upstreams=[input_dataset_urn],
    )

    # Validate lineage MCPs
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_inputs_only_golden.json")


def test_add_datajob_outputs_only(client: DataHubClient) -> None:
    """Test adding only outputs to a datajob."""

    # Define URNs for test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    output_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
    )

    # Test adding just downstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn, downstreams=[output_dataset_urn]
    )

    # Validate lineage MCPs
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_outputs_only_golden.json")


def test_add_datajob_lineage_validation(client: DataHubClient) -> None:
    """Test validation checks in add_datajob_lineage."""

    # Define URNs for test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    invalid_urn = "urn:li:glossaryNode:something"

    # Test with invalid datajob URN
    with pytest.raises(
        InvalidUrnError,
        match="Passed an urn of type glossaryNode to the from_string method of DataJobUrn",
    ):
        client.lineage.add_datajob_lineage(
            datajob=invalid_urn,
            upstreams=[
                "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
            ],
        )

    # Test with invalid upstream URN
    with pytest.raises(InvalidUrnError):
        client.lineage.add_datajob_lineage(datajob=datajob_urn, upstreams=[invalid_urn])

    # Test with invalid downstream URN
    with pytest.raises(InvalidUrnError):
        client.lineage.add_datajob_lineage(
            datajob=datajob_urn, downstreams=[invalid_urn]
        )