Mirror of https://github.com/datahub-project/datahub.git
Synced 2025-11-03 20:27:50 +00:00
feat(sdk): add get_lineage (#13654)
This commit is contained in:
parent
01357940b1
commit
e169b4ac05
24 metadata-ingestion/examples/library/get_column_lineage.py Normal file

@@ -0,0 +1,24 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F

client = DataHubClient.from_env()

dataset_urn = DatasetUrn(platform="snowflake", name="downstream_table")

# Get column lineage for the entire flow.
# Pass source_urn and source_column to get lineage for a specific column.
# Alternatively, pass a schemaField URN as source_urn, e.g.
# source_urn="urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table),id)"
downstream_column_lineage = client.lineage.get_lineage(
    source_urn=dataset_urn,
    source_column="id",
    direction="downstream",
    max_hops=1,
    filter=F.and_(
        F.platform("snowflake"),
        F.entity_type("dataset"),
    ),
)

print(downstream_column_lineage)
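As the comments in this example note, a schemaField URN can also be passed directly as source_urn. A minimal sketch of that variant, reusing the URN format from the comment above (the dataset name and column are illustrative):

from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()

# Equivalent to passing source_urn=dataset_urn together with source_column="id".
column_lineage = client.lineage.get_lineage(
    source_urn="urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table),id)",
    direction="downstream",
    max_hops=1,
)

print(column_lineage)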
17 metadata-ingestion/examples/library/get_lineage_basic.py Normal file

@@ -0,0 +1,17 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F

client = DataHubClient.from_env()

downstream_lineage = client.lineage.get_lineage(
    source_urn=DatasetUrn(platform="snowflake", name="downstream_table"),
    direction="downstream",
    max_hops=2,
    filter=F.and_(
        F.platform("airflow"),
        F.entity_type("dataJob"),
    ),
)

print(downstream_lineage)
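Since get_lineage returns a list of LineageResult dataclasses (defined in lineage_client.py below), individual fields can be inspected instead of printing the whole list. A minimal sketch continuing from the example above:

for result in downstream_lineage:
    # Each result carries the entity URN, its type, the hop distance from the
    # source, and (when available) the platform, name, and description.
    print(result.hops, result.type, result.platform, result.urn)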
@@ -2,9 +2,12 @@ from __future__ import annotations

import difflib
import logging
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,

@@ -18,12 +21,7 @@ from typing_extensions import assert_never, deprecated

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.errors import SdkUsageError
from datahub.metadata.urns import (
    DataJobUrn,
    DatasetUrn,
    QueryUrn,
    Urn,
)
from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn, SchemaFieldUrn, Urn
from datahub.sdk._shared import (
    ChartUrnOrStr,
    DashboardUrnOrStr,

@@ -32,6 +30,8 @@ from datahub.sdk._shared import (

)
from datahub.sdk._utils import DEFAULT_ACTOR_URN
from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
from datahub.sdk.search_client import compile_filters
from datahub.sdk.search_filters import Filter, FilterDsl
from datahub.specific.chart import ChartPatchBuilder
from datahub.specific.dashboard import DashboardPatchBuilder
from datahub.specific.datajob import DataJobPatchBuilder
@@ -53,9 +53,29 @@ _empty_audit_stamp = models.AuditStampClass(

logger = logging.getLogger(__name__)


@dataclass
class LineagePath:
    urn: str
    entity_name: str
    column_name: Optional[str] = None


@dataclass
class LineageResult:
    urn: str
    type: str
    hops: int
    direction: Literal["upstream", "downstream"]
    platform: Optional[str] = None
    name: Optional[str] = None
    description: Optional[str] = None
    paths: Optional[List[LineagePath]] = None


class LineageClient:
    def __init__(self, client: DataHubClient):
        self._client = client
        self._graph = client._graph

    def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
        schema_metadata = self._client._graph.get_aspect(
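For orientation, a one-hop upstream result assembled from the dataclasses above might look like the following (a hand-built sketch; the values are illustrative and mirror the unit-test mocks further down):

result = LineageResult(
    urn="urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
    type="DATASET",
    hops=1,
    direction="upstream",
    platform="snowflake",
    name="upstream_table",
    description="Upstream source table",
    paths=[],  # column-level paths are filled in only for schema-field queries
)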
@@ -700,3 +720,242 @@ class LineageClient:

        # Apply the changes to the entity
        self._client.entities.update(patch_builder)

    def get_lineage(
        self,
        *,
        source_urn: Union[str, Urn],
        source_column: Optional[str] = None,
        direction: Literal["upstream", "downstream"] = "upstream",
        max_hops: int = 1,
        filter: Optional[Filter] = None,
        count: int = 500,
    ) -> List[LineageResult]:
        """Retrieve lineage entities connected to a source entity.

        Args:
            source_urn: Source URN for the lineage search.
            source_column: Source column for the lineage search.
            direction: Direction of lineage traversal.
            max_hops: Maximum number of hops to traverse.
            filter: Filters to apply to the lineage search.
            count: Maximum number of results to return.

        Returns:
            A list of lineage results.

        Raises:
            SdkUsageError: If the filter values are invalid.
        """
        # Validate and convert the input URN.
        source_urn = Urn.from_string(source_urn)

        # Prepare the GraphQL query variables in a separate method.
        variables = self._process_input_variables(
            source_urn, source_column, filter, direction, max_hops, count
        )

        return self._execute_lineage_query(variables, direction)

    def _process_input_variables(
        self,
        source_urn: Urn,
        source_column: Optional[str] = None,
        filters: Optional[Filter] = None,
        direction: Literal["upstream", "downstream"] = "upstream",
        max_hops: int = 1,
        count: int = 500,
    ) -> Dict[str, Any]:
        """Process filters and prepare GraphQL query variables for the lineage search.

        Args:
            source_urn: Source URN for the lineage search.
            source_column: Source column for the lineage search.
            filters: Optional filters to apply.
            direction: Direction of lineage traversal.
            max_hops: Maximum number of hops to traverse.
            count: Maximum number of results to return.

        Returns:
            A dictionary of GraphQL query variables.

        Raises:
            SdkUsageError: If the filter values are invalid.
        """

        # Warn if max_hops is greater than 2.
        if max_hops > 2:
            logger.warning(
                """If `max_hops` is more than 2, the search will try to find the full lineage graph.
                By default, only 500 results are shown.
                You can change the `count` to get more or fewer results.
                """
            )

        # Determine hop values; the degree facet only distinguishes 1, 2, and 3+.
        max_hop_values = (
            [str(hop) for hop in range(1, max_hops + 1)]
            if max_hops <= 2
            else ["1", "2", "3+"]
        )

        degree_filter = FilterDsl.custom_filter(
            field="degree",
            condition="EQUAL",
            values=max_hop_values,
        )

        filters_with_max_hops = (
            FilterDsl.and_(degree_filter, filters)
            if filters is not None
            else degree_filter
        )

        types, compiled_filters = compile_filters(filters_with_max_hops)

        # Prepare the base variables.
        variables: Dict[str, Any] = {
            "input": {
                "urn": str(source_urn),
                "direction": direction.upper(),
                "count": count,
                "types": types,
                "orFilters": compiled_filters,
            }
        }

        # If a column is provided, update the variables to target the schema field URN.
        if isinstance(source_urn, SchemaFieldUrn) or source_column:
            variables["input"]["searchFlags"] = {
                "groupingSpec": {
                    "groupingCriteria": {
                        "baseEntityType": "SCHEMA_FIELD",
                        "groupingEntityType": "SCHEMA_FIELD",
                    }
                }
            }
            if isinstance(source_urn, SchemaFieldUrn):
                variables["input"]["urn"] = str(source_urn)
            elif source_column:
                variables["input"]["urn"] = str(SchemaFieldUrn(source_urn, source_column))

        return variables

    def _execute_lineage_query(
        self,
        variables: Dict[str, Any],
        direction: Literal["upstream", "downstream"],
    ) -> List[LineageResult]:
        """Execute the GraphQL query and process the results."""
        # GraphQL query used to scroll across lineage results.
        graphql_query = """
            query scrollAcrossLineage($input: ScrollAcrossLineageInput!) {
                scrollAcrossLineage(input: $input) {
                    nextScrollId
                    searchResults {
                        degree
                        entity {
                            urn
                            type
                            ... on Dataset {
                                name
                                platform {
                                    name
                                }
                                properties {
                                    description
                                }
                            }
                            ... on DataJob {
                                jobId
                                dataPlatformInstance {
                                    platform {
                                        name
                                    }
                                }
                                properties {
                                    name
                                    description
                                }
                            }
                        }
                        paths {
                            path {
                                urn
                                type
                            }
                        }
                    }
                }
            }
        """

        results: List[LineageResult] = []

        first_iter = True
        scroll_id: Optional[str] = None

        # Page through results until the server stops returning a scroll ID.
        while first_iter or scroll_id:
            first_iter = False

            # Update the scroll ID if applicable
            if scroll_id:
                variables["input"]["scrollId"] = scroll_id

            # Execute the GraphQL query
            response = self._graph.execute_graphql(graphql_query, variables=variables)
            data = response["scrollAcrossLineage"]
            scroll_id = data.get("nextScrollId")

            # Process the search results
            for entry in data["searchResults"]:
                entity = entry["entity"]

                result = self._create_lineage_result(entity, entry, direction)
                results.append(result)

        return results

    def _create_lineage_result(
        self,
        entity: Dict[str, Any],
        entry: Dict[str, Any],
        direction: Literal["upstream", "downstream"],
    ) -> LineageResult:
        """Create a LineageResult from entity and entry data."""
        # Datasets expose the platform directly; DataJobs expose it via
        # dataPlatformInstance (see the GraphQL query above).
        platform = entity.get("platform", {}).get("name") or entity.get(
            "dataPlatformInstance", {}
        ).get("platform", {}).get("name")

        result = LineageResult(
            urn=entity["urn"],
            type=entity["type"],
            hops=entry["degree"],
            direction=direction,
            platform=platform,
        )

        properties = entity.get("properties", {})
        if properties:
            result.name = properties.get("name", "")
            result.description = properties.get("description", "")

        result.paths = []
        if "paths" in entry:
            # Process each path in the lineage graph
            for path in entry["paths"]:
                for path_entry in path["path"]:
                    # Only include schema fields in the path (exclude other types like Query)
                    if path_entry["type"] == "SCHEMA_FIELD":
                        schema_field_urn = SchemaFieldUrn.from_string(path_entry["urn"])
                        result.paths.append(
                            LineagePath(
                                urn=path_entry["urn"],
                                entity_name=DatasetUrn.from_string(
                                    schema_field_urn.parent
                                ).name,
                                column_name=schema_field_urn.field_path,
                            )
                        )

        return result

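To make the variable construction in _process_input_variables concrete, this is roughly the payload produced for a one-hop, column-level call (a sketch; the types and orFilters values come from compile_filters and are abbreviated here):

# Roughly the variables for:
#   get_lineage(source_urn=dataset_urn, source_column="id",
#               direction="downstream", max_hops=1)
variables = {
    "input": {
        "urn": "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD),id)",
        "direction": "DOWNSTREAM",
        "count": 500,
        "types": [...],      # entity types compiled from the filter
        "orFilters": [...],  # compiled filters, including degree == "1"
        "searchFlags": {
            "groupingSpec": {
                "groupingCriteria": {
                    "baseEntityType": "SCHEMA_FIELD",
                    "groupingEntityType": "SCHEMA_FIELD",
                }
            }
        },
    }
}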
@@ -1,11 +1,12 @@

import pathlib
from typing import Dict, List, Sequence, Set, cast
from typing import Dict, List, Optional, Sequence, Set, cast
from unittest.mock import MagicMock, Mock, patch

import pytest

from datahub.errors import SdkUsageError
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
    ColumnLineageInfo,
@@ -127,38 +128,8 @@ def test_get_strict_column_lineage(

        cast(Set[str], upstream_fields),
        cast(Set[str], downstream_fields),
    )
    assert result == expected, f"Test failed: {result} != {expected}"
    """Test the strict column lineage matching algorithm."""
    test_cases = [
        # Exact matches
        {
            "upstream_fields": {"id", "name", "email"},
            "downstream_fields": {"id", "name", "phone"},
            "expected": {"id": ["id"], "name": ["name"]},
        },
        # No matches
        {
            "upstream_fields": {"col1", "col2", "col3"},
            "downstream_fields": {"col4", "col5", "col6"},
            "expected": {},
        },
        # Case mismatch (should match)
        {
            "upstream_fields": {"ID", "Name", "Email"},
            "downstream_fields": {"id", "name", "email"},
            "expected": {"id": ["ID"], "name": ["Name"], "email": ["Email"]},
        },
    ]

    # Run test cases
    for test_case in test_cases:
        result = client.lineage._get_strict_column_lineage(
            cast(Set[str], test_case["upstream_fields"]),
            cast(Set[str], test_case["downstream_fields"]),
        )
        assert result == test_case["expected"], (
            f"Test failed: {result} != {test_case['expected']}"
        )

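The test cases above pin down the matching behavior of _get_strict_column_lineage: exact, case-insensitive column-name matches. A hypothetical re-implementation sketch of that contract (not the SDK's actual code):

from typing import Dict, List, Set

def strict_column_match(upstream: Set[str], downstream: Set[str]) -> Dict[str, List[str]]:
    # Map each downstream column to the upstream column whose name matches
    # case-insensitively, mirroring the expected values in the tests above.
    lower_upstream = {c.lower(): c for c in upstream}
    return {
        d: [lower_upstream[d.lower()]]
        for d in downstream
        if d.lower() in lower_upstream
    }
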
def test_infer_lineage_from_sql(client: DataHubClient) -> None:
@@ -444,3 +415,284 @@ def test_add_lineage_invalid_parameter_combinations(client: DataHubClient) -> None:

            downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
            column_lineage={"target_col": ["source_col"]},
        )


def test_get_lineage_basic(client: DataHubClient) -> None:
    """Test basic lineage retrieval with default parameters."""
    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"

    # Mock GraphQL response
    mock_response = {
        "scrollAcrossLineage": {
            "nextScrollId": None,
            "searchResults": [
                {
                    "entity": {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
                        "type": "DATASET",
                        "platform": {"name": "snowflake"},
                        "properties": {
                            "name": "upstream_table",
                            "description": "Upstream source table",
                        },
                    },
                    "degree": 1,
                }
            ],
        }
    }

    # Patch the GraphQL execution method
    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(source_urn=source_urn)

    # Validate results
    assert len(results) == 1
    assert (
        results[0].urn
        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    )
    assert results[0].type == "DATASET"
    assert results[0].name == "upstream_table"
    assert results[0].platform == "snowflake"
    assert results[0].direction == "upstream"
    assert results[0].hops == 1

def test_get_lineage_with_entity_type_filters(client: DataHubClient) -> None:
    """Test lineage retrieval with entity type and platform filters."""
    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"

    # Mock GraphQL response with multiple entity types
    mock_response = {
        "scrollAcrossLineage": {
            "nextScrollId": None,
            "searchResults": [
                {
                    "entity": {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
                        "type": "DATASET",
                        "platform": {"name": "snowflake"},
                        "properties": {
                            "name": "upstream_table",
                            "description": "Upstream source table",
                        },
                    },
                    "degree": 1,
                },
            ],
        }
    }

    # Patch the GraphQL execution method to return results for multiple calls
    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(
            source_urn=source_urn,
            filter=F.entity_type("dataset"),
        )

    # Validate results
    assert len(results) == 1
    assert {r.type for r in results} == {"DATASET"}
    assert {r.platform for r in results} == {"snowflake"}

def test_get_lineage_downstream(client: DataHubClient) -> None:
    """Test downstream lineage retrieval."""
    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"

    # Mock GraphQL response for downstream lineage
    mock_response = {
        "scrollAcrossLineage": {
            "nextScrollId": None,
            "searchResults": [
                {
                    "entity": {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)",
                        "type": "DATASET",
                        "platform": {"name": "snowflake"},
                        "properties": {
                            "name": "downstream_table",
                            "description": "Downstream target table",
                        },
                    },
                    "degree": 1,
                }
            ],
        }
    }

    # Patch the GraphQL execution method
    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(
            source_urn=source_urn,
            direction="downstream",
        )

    # Validate results
    assert len(results) == 1
    assert (
        results[0].urn
        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
    )
    assert results[0].direction == "downstream"

def test_get_lineage_multiple_hops(
    client: DataHubClient, caplog: pytest.LogCaptureFixture
) -> None:
    """Test lineage retrieval with multiple hops."""
    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"

    # Mock GraphQL response with multiple hops
    mock_response = {
        "scrollAcrossLineage": {
            "nextScrollId": None,
            "searchResults": [
                {
                    "entity": {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table1,PROD)",
                        "type": "DATASET",
                        "platform": {"name": "snowflake"},
                        "properties": {
                            "name": "upstream_table1",
                            "description": "First upstream table",
                        },
                    },
                    "degree": 1,
                },
                {
                    "entity": {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table2,PROD)",
                        "type": "DATASET",
                        "platform": {"name": "snowflake"},
                        "properties": {
                            "name": "upstream_table2",
                            "description": "Second upstream table",
                        },
                    },
                    "degree": 2,
                },
            ],
        }
    }

    # Patch the GraphQL execution method
    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(source_urn=source_urn, max_hops=2)

    # Check that a warning is logged when max_hops > 2
    with patch.object(
        client._graph, "execute_graphql", return_value=mock_response
    ), caplog.at_level("WARNING"):
        client.lineage.get_lineage(source_urn=source_urn, max_hops=3)
        assert any(
            "the search will try to find the full lineage graph" in msg
            for msg in caplog.messages
        )

    # Validate results
    assert len(results) == 2
    assert results[0].hops == 1
    assert results[0].type == "DATASET"
    assert results[1].hops == 2
    assert results[1].type == "DATASET"

def test_get_lineage_no_results(client: DataHubClient) -> None:
    """Test lineage retrieval with no results."""
    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"

    # Mock GraphQL response with no results
    mock_response: Dict[str, Dict[str, Optional[List]]] = {
        "scrollAcrossLineage": {"nextScrollId": None, "searchResults": []}
    }

    # Patch the GraphQL execution method
    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(source_urn=source_urn)

    # Validate results
    assert len(results) == 0

def test_get_lineage_column_lineage_with_source_column(client: DataHubClient) -> None:
    """Test lineage retrieval with column lineage via source_column."""
    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
    source_column = "source_column"

    # Mock GraphQL response with column lineage
    mock_response = {
        "scrollAcrossLineage": {
            "nextScrollId": None,
            "searchResults": [
                {
                    "entity": {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
                        "type": "DATASET",
                        "platform": {"name": "snowflake"},
                        "properties": {
                            "name": "upstream_table",
                            "description": "Upstream source table",
                        },
                    },
                    "degree": 1,
                },
            ],
        }
    }

    # Patch the GraphQL execution method
    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(
            source_urn=source_urn,
            source_column=source_column,
        )

    # Validate results
    assert len(results) == 1
    assert (
        results[0].urn
        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    )

def test_get_lineage_column_lineage_with_schema_field_urn(
    client: DataHubClient,
) -> None:
    """Test lineage retrieval with column lineage via a schemaField URN."""
    source_urn = "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD),source_column)"

    # Mock GraphQL response with column lineage
    mock_response = {
        "scrollAcrossLineage": {
            "nextScrollId": None,
            "searchResults": [
                {
                    "entity": {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
                        "type": "DATASET",
                        "platform": {"name": "snowflake"},
                        "properties": {
                            "name": "upstream_table",
                            "description": "Upstream source table",
                        },
                    },
                    "degree": 1,
                },
            ],
        }
    }

    # Patch the GraphQL execution method
    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(
            source_urn=source_urn,
        )

    # Validate results
    assert len(results) == 1
    assert (
        results[0].urn
        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
    )

222 smoke-test/tests/lineage/test_lineage_sdk.py Normal file

@@ -0,0 +1,222 @@
from typing import Dict, Generator

import pytest

from datahub.ingestion.graph.client import DataHubGraph
from datahub.metadata.urns import SchemaFieldUrn
from datahub.sdk.dataset import Dataset
from datahub.sdk.lineage_client import LineageResult
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F
from tests.utils import wait_for_writes_to_sync


@pytest.fixture(scope="module")
def test_client(graph_client: DataHubGraph) -> DataHubClient:
    return DataHubClient(graph=graph_client)


@pytest.fixture(scope="module")
def test_datasets(
    test_client: DataHubClient,
) -> Generator[Dict[str, Dataset], None, None]:
    datasets = {
        "upstream": Dataset(
            platform="snowflake",
            name="test_lineage_upstream_001",
            schema=[("name", "string"), ("id", "int")],
        ),
        "downstream1": Dataset(
            platform="snowflake",
            name="test_lineage_downstream_001",
            schema=[("name", "string"), ("id", "int")],
        ),
        "downstream2": Dataset(
            platform="snowflake",
            name="test_lineage_downstream_002",
            schema=[("name", "string"), ("id", "int")],
        ),
        "downstream3": Dataset(
            platform="mysql",
            name="test_lineage_downstream_003",
            schema=[("name", "string"), ("id", "int")],
        ),
    }

    # Hard-delete any leftovers from earlier runs, then create the datasets.
    for entity in datasets.values():
        test_client._graph.delete_entity(str(entity.urn), hard=True)
    for entity in datasets.values():
        test_client.entities.upsert(entity)

    # Add lineage: upstream -> downstream1 -> downstream2 -> downstream3
    test_client.lineage.add_lineage(
        upstream=str(datasets["upstream"].urn),
        downstream=str(datasets["downstream1"].urn),
        column_lineage=True,
    )
    test_client.lineage.add_lineage(
        upstream=str(datasets["downstream1"].urn),
        downstream=str(datasets["downstream2"].urn),
        column_lineage=True,
    )
    test_client.lineage.add_lineage(
        upstream=str(datasets["downstream2"].urn),
        downstream=str(datasets["downstream3"].urn),
        column_lineage=True,
    )

    wait_for_writes_to_sync()

    yield datasets

    # Cleanup
    for entity in datasets.values():
        try:
            test_client._graph.delete_entity(str(entity.urn), hard=True)
        except Exception as e:
            raise Exception(f"Could not delete entity {entity.urn}: {e}") from e

def validate_lineage_results(
    lineage_result: LineageResult,
    hops=None,
    direction=None,
    platform=None,
    urn=None,
    paths_len=None,
):
    if hops is not None:
        assert lineage_result.hops == hops
    if direction is not None:
        assert lineage_result.direction == direction
    if platform is not None:
        assert lineage_result.platform == platform
    if urn is not None:
        assert lineage_result.urn == urn
    if paths_len is not None and lineage_result.paths is not None:
        assert len(lineage_result.paths) == paths_len

def test_table_level_lineage(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    table_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(test_datasets["upstream"].urn),
        direction="downstream",
        max_hops=3,
    )

    assert len(table_lineage_results) == 3
    urns = {r.urn for r in table_lineage_results}
    expected = {
        str(test_datasets["downstream1"].urn),
        str(test_datasets["downstream2"].urn),
        str(test_datasets["downstream3"].urn),
    }
    assert urns == expected

    table_lineage_results = sorted(table_lineage_results, key=lambda x: x.hops)
    validate_lineage_results(
        table_lineage_results[0],
        hops=1,
        platform="snowflake",
        urn=str(test_datasets["downstream1"].urn),
        paths_len=0,
    )
    validate_lineage_results(
        table_lineage_results[1],
        hops=2,
        platform="snowflake",
        urn=str(test_datasets["downstream2"].urn),
        paths_len=0,
    )
    validate_lineage_results(
        table_lineage_results[2],
        hops=3,
        platform="mysql",
        urn=str(test_datasets["downstream3"].urn),
        paths_len=0,
    )

def test_column_level_lineage(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    column_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(test_datasets["upstream"].urn),
        source_column="id",
        direction="downstream",
        max_hops=3,
    )

    assert len(column_lineage_results) == 3
    column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops)
    validate_lineage_results(
        column_lineage_results[0],
        hops=1,
        urn=str(test_datasets["downstream1"].urn),
        paths_len=2,
    )
    validate_lineage_results(
        column_lineage_results[1],
        hops=2,
        urn=str(test_datasets["downstream2"].urn),
        paths_len=3,
    )
    validate_lineage_results(
        column_lineage_results[2],
        hops=3,
        urn=str(test_datasets["downstream3"].urn),
        paths_len=4,
    )

def test_filtered_column_level_lineage(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    filtered_column_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(test_datasets["upstream"].urn),
        source_column="id",
        direction="downstream",
        max_hops=3,
        filter=F.and_(F.platform("mysql"), F.entity_type("dataset")),
    )

    assert len(filtered_column_lineage_results) == 1
    validate_lineage_results(
        filtered_column_lineage_results[0],
        hops=3,
        platform="mysql",
        urn=str(test_datasets["downstream3"].urn),
        paths_len=4,
    )

def test_column_level_lineage_from_schema_field(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    source_schema_field = SchemaFieldUrn(test_datasets["upstream"].urn, "id")
    column_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(source_schema_field), direction="downstream", max_hops=3
    )

    assert len(column_lineage_results) == 3
    column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops)
    validate_lineage_results(
        column_lineage_results[0],
        hops=1,
        urn=str(test_datasets["downstream1"].urn),
        paths_len=2,
    )
    validate_lineage_results(
        column_lineage_results[1],
        hops=2,
        urn=str(test_datasets["downstream2"].urn),
        paths_len=3,
    )
    validate_lineage_results(
        column_lineage_results[2],
        hops=3,
        urn=str(test_datasets["downstream3"].urn),
        paths_len=4,
    )
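For orientation, the fixture above wires up the chain these smoke tests assert against (a summary note, not part of the test file):

# upstream (snowflake) -> downstream1 (snowflake) -> downstream2 (snowflake) -> downstream3 (mysql)
# Hops from "upstream": downstream1 = 1, downstream2 = 2, downstream3 = 3.
# Column-level paths gain one schema field per hop, hence paths_len = 2, 3, 4.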