import pathlib
from typing import Dict, List, Set, cast
from unittest.mock import MagicMock, Mock, patch

import pytest

from datahub.metadata.schema_classes import (
    OtherSchemaClass,
    SchemaFieldClass,
    SchemaFieldDataTypeClass,
    SchemaMetadataClass,
    StringTypeClass,
)
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
    ColumnLineageInfo,
    ColumnRef,
    DownstreamColumnRef,
    SqlParsingResult,
)
from datahub.testing import mce_helpers
from datahub.utilities.urns.error import InvalidUrnError

_GOLDEN_DIR = pathlib.Path(__file__).parent / "lineage_client_golden"
_GOLDEN_DIR.mkdir(exist_ok=True)


@pytest.fixture
def mock_graph() -> Mock:
    graph = Mock()
    return graph


@pytest.fixture
def client(mock_graph: Mock) -> DataHubClient:
    return DataHubClient(graph=mock_graph)


def assert_client_golden(client: DataHubClient, golden_path: pathlib.Path) -> None:
    # Grab the MCPs from the most recent emit_mcps call on the mocked graph.
    mcps = client._graph.emit_mcps.call_args[0][0]  # type: ignore
    mce_helpers.check_goldens_stream(
        outputs=mcps,
        golden_path=golden_path,
        ignore_order=False,
    )


def test_get_fuzzy_column_lineage(client: DataHubClient) -> None:
    """Test the fuzzy column lineage matching algorithm."""
    # Test cases
    test_cases = [
        # Case 1: Exact matches
        {
            "upstream_fields": {"id", "name", "email"},
            "downstream_fields": {"id", "name", "phone"},
            "expected": {"id": ["id"], "name": ["name"]},
        },
        # Case 2: Case-insensitive matches
        {
            "upstream_fields": {"ID", "Name", "Email"},
            "downstream_fields": {"id", "name", "phone"},
            "expected": {"id": ["ID"], "name": ["Name"]},
        },
        # Case 3: Camel case to snake case
        {
            "upstream_fields": {"id", "user_id", "full_name"},
            "downstream_fields": {"id", "userId", "fullName"},
            "expected": {
                "id": ["id"],
                "userId": ["user_id"],
                "fullName": ["full_name"],
            },
        },
        # Case 4: Snake case to camel case
        {
            "upstream_fields": {"id", "userId", "fullName"},
            "downstream_fields": {"id", "user_id", "full_name"},
            "expected": {
                "id": ["id"],
                "user_id": ["userId"],
                "full_name": ["fullName"],
            },
        },
        # Case 5: Mixed matches
        {
            "upstream_fields": {"id", "customer_id", "user_name"},
            "downstream_fields": {"id", "customerId", "address"},
            "expected": {"id": ["id"], "customerId": ["customer_id"]},
        },
        # Case 6: Mixed matches with different casing
        {
            "upstream_fields": {"id", "customer_id", "userName", "address_id"},
            "downstream_fields": {"id", "customerId", "user_name", "user_address"},
            "expected": {
                "id": ["id"],
                "customerId": ["customer_id"],
                "user_name": ["userName"],
            },  # user_address <> address_id shouldn't match
        },
    ]
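
    # A rough sketch of the matching we assume _get_fuzzy_column_lineage
    # performs (illustrative only, not the actual implementation): normalize
    # both field names to lowercase snake_case and pair them on equality, e.g.
    #   re.sub(r"(?<!^)(?=[A-Z])", "_", "userId").lower()  ->  "user_id"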

    # Run test cases
    for i, test_case in enumerate(test_cases):
        result = client.lineage._get_fuzzy_column_lineage(
            cast(Set[str], test_case["upstream_fields"]),
            cast(Set[str], test_case["downstream_fields"]),
        )
        assert result == test_case["expected"], (
            f"Test case {i + 1} failed: {result} != {test_case['expected']}"
        )


def test_get_strict_column_lineage(client: DataHubClient) -> None:
    """Test the strict column lineage matching algorithm."""
    # Define test cases
    test_cases = [
        # Case 1: Exact matches
        {
            "upstream_fields": {"id", "name", "email"},
            "downstream_fields": {"id", "name", "phone"},
            "expected": {"id": ["id"], "name": ["name"]},
        },
        # Case 2: No matches
        {
            "upstream_fields": {"col1", "col2", "col3"},
            "downstream_fields": {"col4", "col5", "col6"},
            "expected": {},
        },
        # Case 3: Case mismatch (should still match)
        {
            "upstream_fields": {"ID", "Name", "Email"},
            "downstream_fields": {"id", "name", "email"},
            "expected": {"id": ["ID"], "name": ["Name"], "email": ["Email"]},
        },
    ]
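
    # Unlike the fuzzy variant, strict matching appears to pair only fields
    # whose names are identical up to casing; no camelCase/snake_case
    # conversion is expected (Case 3 documents the case-insensitivity).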

    # Run test cases
    for i, test_case in enumerate(test_cases):
        result = client.lineage._get_strict_column_lineage(
            cast(Set[str], test_case["upstream_fields"]),
            cast(Set[str], test_case["downstream_fields"]),
        )
        assert result == test_case["expected"], f"Test case {i + 1} failed"


def test_add_dataset_copy_lineage_auto_fuzzy(client: DataHubClient) -> None:
    """Test auto fuzzy column lineage mapping."""
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"

    # Create upstream and downstream schemas
    upstream_schema = SchemaMetadataClass(
        schemaName="upstream_table",
        platform="urn:li:dataPlatform:snowflake",
        version=1,
        hash="1234567890",
        platformSchema=OtherSchemaClass(rawSchema=""),
        fields=[
            SchemaFieldClass(
                fieldPath="id",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="user_id",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="address",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="age",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
        ],
    )

    downstream_schema = SchemaMetadataClass(
        schemaName="downstream_table",
        platform="urn:li:dataPlatform:snowflake",
        version=1,
        hash="1234567890",
        platformSchema=OtherSchemaClass(rawSchema=""),
        fields=[
            SchemaFieldClass(
                fieldPath="id",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="userId",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="score",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
        ],
    )
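
    # With these schemas, fuzzy matching should pair id -> id and
    # user_id -> userId; address, age, and score have no counterpart, so they
    # should stay unmapped (the golden file captures the exact output).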

    # Patch the schema lookup so the test needs no live graph
    with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
        # Return the sorted field paths of whichever schema the urn refers to
        mock_method.side_effect = lambda urn: sorted(
            {
                field.fieldPath
                for field in (
                    upstream_schema if "upstream" in str(urn) else downstream_schema
                ).fields
            }
        )

        # Now use client.lineage with the patched method
        client.lineage.add_dataset_copy_lineage(
            upstream=upstream,
            downstream=downstream,
            column_lineage="auto_fuzzy",
        )

    # Use golden file for assertion
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_fuzzy_golden.json")


def test_add_dataset_copy_lineage_auto_strict(client: DataHubClient) -> None:
    """Test strict column lineage with field matches."""
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"

    # Create upstream and downstream schemas
    upstream_schema = SchemaMetadataClass(
        schemaName="upstream_table",
        platform="urn:li:dataPlatform:snowflake",
        version=1,
        hash="1234567890",
        platformSchema=OtherSchemaClass(rawSchema=""),
        fields=[
            SchemaFieldClass(
                fieldPath="id",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="name",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="user_id",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="address",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
        ],
    )

    downstream_schema = SchemaMetadataClass(
        schemaName="downstream_table",
        platform="urn:li:dataPlatform:snowflake",
        version=1,
        hash="1234567890",
        platformSchema=OtherSchemaClass(rawSchema=""),
        fields=[
            SchemaFieldClass(
                fieldPath="id",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="name",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="address",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
            SchemaFieldClass(
                fieldPath="score",
                type=SchemaFieldDataTypeClass(type=StringTypeClass()),
                nativeDataType="string",
            ),
        ],
    )
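
    # Strict matching should pair only the identically named fields here:
    # id, name, and address; user_id and score should stay unmapped.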

    # Patch the schema lookup so the test needs no live graph
    with patch.object(LineageClient, "_get_fields_from_dataset_urn") as mock_method:
        mock_method.side_effect = lambda urn: sorted(
            {
                field.fieldPath
                for field in (
                    upstream_schema if "upstream" in str(urn) else downstream_schema
                ).fields
            }
        )

        # Run the lineage function
        client.lineage.add_dataset_copy_lineage(
            upstream=upstream,
            downstream=downstream,
            column_lineage="auto_strict",
        )

    # Use golden file for assertion
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_copy_strict_golden.json")


def test_add_dataset_transform_lineage_basic(client: DataHubClient) -> None:
    """Test basic lineage without column mapping or query."""
    # Basic lineage test
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"

    client.lineage.add_dataset_transform_lineage(
        upstream=upstream,
        downstream=downstream,
    )
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_basic_golden.json")


def test_add_dataset_transform_lineage_complete(client: DataHubClient) -> None:
    """Test complete lineage with column mapping and query."""
    upstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    downstream = "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
    query_text = (
        "SELECT us_col1 as ds_col1, us_col2 + us_col3 as ds_col2 FROM upstream_table"
    )
    column_lineage: Dict[str, List[str]] = {
        "ds_col1": ["us_col1"],  # Simple 1:1 mapping
        "ds_col2": ["us_col2", "us_col3"],  # 2:1 mapping
    }
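
    # Each entry presumably becomes a fine-grained (column-level) lineage edge
    # in the emitted aspect: the downstream column on the left, its upstream
    # source columns on the right.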

    client.lineage.add_dataset_transform_lineage(
        upstream=upstream,
        downstream=downstream,
        query_text=query_text,
        column_lineage=column_lineage,
    )
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_complete_golden.json")


def test_add_dataset_lineage_from_sql(client: DataHubClient) -> None:
    """Test adding lineage from SQL parsing with a golden file."""
    # Create a minimal mock result with the necessary info
    mock_result = SqlParsingResult(
        in_tables=["urn:li:dataset:(urn:li:dataPlatform:snowflake,orders,PROD)"],
        out_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
        ],
        column_lineage=[],  # Simplified: only table-level lineage matters for this test
        query_type=QueryType.SELECT,
        debug_info=MagicMock(error=None, table_error=None),
    )

    # Simple SQL that would produce the expected lineage
    query_text = (
        "create table sales_summary as SELECT price, qty, unit_cost FROM orders"
    )
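
    # Patching at the definition site assumes add_dataset_lineage_from_sql
    # resolves the parser via the sqlglot_lineage module at call time; if it
    # imported the function directly, the patch target would need to change.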

    # Patch SQL parser and execute lineage creation
    with patch(
        "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
        return_value=mock_result,
    ):
        client.lineage.add_dataset_lineage_from_sql(
            query_text=query_text, platform="snowflake", env="PROD"
        )

    # Validate against golden file
    assert_client_golden(client, _GOLDEN_DIR / "test_lineage_from_sql_golden.json")


def test_add_dataset_lineage_from_sql_with_multiple_upstreams(
    client: DataHubClient,
) -> None:
    """Test adding lineage for a dataset with multiple upstreams."""
    # Create a minimal mock result with the necessary info
    mock_result = SqlParsingResult(
        in_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,products,PROD)",
        ],
        out_tables=[
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_summary,PROD)"
        ],
        column_lineage=[
            ColumnLineageInfo(
                downstream=DownstreamColumnRef(column="product_name"),
                upstreams=[
                    ColumnRef(
                        table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
                        column="product_name",
                    )
                ],
            ),
            ColumnLineageInfo(
                downstream=DownstreamColumnRef(column="total_quantity"),
                upstreams=[
                    ColumnRef(
                        table="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales,PROD)",
                        column="quantity",
                    )
                ],
            ),
        ],
        query_type=QueryType.SELECT,
        debug_info=MagicMock(error=None, table_error=None),
    )
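
    # Note: in this mocked result both column-level upstreams point at the
    # sales table, so products should appear in the golden file only as a
    # table-level upstream.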

    # Simple SQL that would produce the expected lineage
    query_text = """
    CREATE TABLE sales_summary AS
    SELECT
        p.product_name,
        SUM(s.quantity) as total_quantity
    FROM sales s
    JOIN products p ON s.product_id = p.id
    GROUP BY p.product_name
    """

    # Patch SQL parser and execute lineage creation
    with patch(
        "datahub.sql_parsing.sqlglot_lineage.create_lineage_sql_parsed_result",
        return_value=mock_result,
    ):
        client.lineage.add_dataset_lineage_from_sql(
            query_text=query_text, platform="snowflake", env="PROD"
        )

    # Validate against golden file
    assert_client_golden(
        client, _GOLDEN_DIR / "test_lineage_from_sql_multiple_upstreams_golden.json"
    )


def test_add_datajob_lineage(client: DataHubClient) -> None:
    """Test adding lineage for datajobs using DataJobPatchBuilder."""
    # Define URNs for the test, in the correct format
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    input_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
    )
    input_datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),upstream_job)"
    )
    output_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
    )
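
    # dataJob URNs embed the parent dataFlow URN plus a job id, so upstreams
    # can mix datasets and other datajobs, as below.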

    # Test adding both upstream and downstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn,
        upstreams=[input_dataset_urn, input_datajob_urn],
        downstreams=[output_dataset_urn],
    )

    # Validate lineage MCPs against golden file
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_lineage_golden.json")


def test_add_datajob_inputs_only(client: DataHubClient) -> None:
    """Test adding only inputs to a datajob."""
    # Define URNs for the test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),process_job)"
    )
    input_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
    )

    # Test adding just upstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn,
        upstreams=[input_dataset_urn],
    )

    # Validate lineage MCPs
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_inputs_only_golden.json")


def test_add_datajob_outputs_only(client: DataHubClient) -> None:
    """Test adding only outputs to a datajob."""
    # Define URNs for the test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    output_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)"
    )

    # Test adding just downstream connections
    client.lineage.add_datajob_lineage(
        datajob=datajob_urn, downstreams=[output_dataset_urn]
    )

    # Validate lineage MCPs
    assert_client_golden(client, _GOLDEN_DIR / "test_datajob_outputs_only_golden.json")


def test_add_datajob_lineage_validation(client: DataHubClient) -> None:
    """Test validation checks in add_datajob_lineage."""
    # Define URNs for the test
    datajob_urn = (
        "urn:li:dataJob:(urn:li:dataFlow:(airflow,example_dag,PROD),transform_job)"
    )
    invalid_urn = "urn:li:glossaryNode:something"

    # Test with invalid datajob URN
    with pytest.raises(
        InvalidUrnError,
        match="Passed an urn of type glossaryNode to the from_string method of DataJobUrn",
    ):
        client.lineage.add_datajob_lineage(
            datajob=invalid_urn,
            upstreams=[
                "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
            ],
        )

    # Test with invalid upstream URN
    with pytest.raises(InvalidUrnError):
        client.lineage.add_datajob_lineage(datajob=datajob_urn, upstreams=[invalid_urn])

    # Test with invalid downstream URN
    with pytest.raises(InvalidUrnError):
        client.lineage.add_datajob_lineage(
            datajob=datajob_urn, downstreams=[invalid_urn]
        )