feat(sdk): add get_lineage (#13654)

Hyejin Yoon 2025-06-06 12:34:52 +09:00 committed by GitHub
parent 01357940b1
commit e169b4ac05
5 changed files with 812 additions and 38 deletions


@@ -0,0 +1,24 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F

client = DataHubClient.from_env()

dataset_urn = DatasetUrn(platform="snowflake", name="downstream_table")

# Get column lineage for the entire flow:
# pass source_urn and source_column to get lineage for a specific column.
# Alternatively, pass a schemaField URN as source_urn,
# e.g. source_urn="urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table),id)"
downstream_column_lineage = client.lineage.get_lineage(
    source_urn=dataset_urn,
    source_column="id",
    direction="downstream",
    max_hops=1,
    filter=F.and_(
        F.platform("snowflake"),
        F.entity_type("dataset"),
    ),
)

print(downstream_column_lineage)
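As an aside (not part of this commit): assuming the LineageResult and LineagePath dataclasses added in lineage_client.py below, the returned results could be inspected like this:

# Sketch: walk each lineage result and its column-level paths.
for result in downstream_column_lineage:
    print(f"{result.hops} hop(s) {result.direction}: {result.urn} ({result.type})")
    for path in result.paths or []:
        print(f"  via column {path.entity_name}.{path.column_name}")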


@@ -0,0 +1,17 @@
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F

client = DataHubClient.from_env()

downstream_lineage = client.lineage.get_lineage(
    source_urn=DatasetUrn(platform="snowflake", name="downstream_table"),
    direction="downstream",
    max_hops=2,
    filter=F.and_(
        F.platform("airflow"),
        F.entity_type("dataJob"),
    ),
)

print(downstream_lineage)
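A related sketch, not part of this commit: with max_hops above 2, get_lineage logs a warning and searches the full lineage graph, capped by count (default 500). The count value below is an arbitrary illustration.

full_upstream_lineage = client.lineage.get_lineage(
    source_urn=DatasetUrn(platform="snowflake", name="downstream_table"),
    direction="upstream",
    max_hops=3,  # > 2 triggers the full-graph search (and logs a warning)
    count=1000,  # raise the default cap of 500 if the graph is larger
)
print(f"found {len(full_upstream_lineage)} upstream entities")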


@@ -2,9 +2,12 @@ from __future__ import annotations
 import difflib
 import logging
+from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
+    Dict,
     List,
     Literal,
     Optional,
@@ -18,12 +21,7 @@ from typing_extensions import assert_never, deprecated
 import datahub.metadata.schema_classes as models
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.errors import SdkUsageError
-from datahub.metadata.urns import (
-    DataJobUrn,
-    DatasetUrn,
-    QueryUrn,
-    Urn,
-)
+from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn, SchemaFieldUrn, Urn
 from datahub.sdk._shared import (
     ChartUrnOrStr,
     DashboardUrnOrStr,
@@ -32,6 +30,8 @@ from datahub.sdk._shared import (
 )
 from datahub.sdk._utils import DEFAULT_ACTOR_URN
 from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
+from datahub.sdk.search_client import compile_filters
+from datahub.sdk.search_filters import Filter, FilterDsl
 from datahub.specific.chart import ChartPatchBuilder
 from datahub.specific.dashboard import DashboardPatchBuilder
 from datahub.specific.datajob import DataJobPatchBuilder
@@ -53,9 +53,29 @@ _empty_audit_stamp = models.AuditStampClass(
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class LineagePath:
+    urn: str
+    entity_name: str
+    column_name: Optional[str] = None
+
+
+@dataclass
+class LineageResult:
+    urn: str
+    type: str
+    hops: int
+    direction: Literal["upstream", "downstream"]
+    platform: Optional[str] = None
+    name: Optional[str] = None
+    description: Optional[str] = None
+    paths: Optional[List[LineagePath]] = None
+
+
 class LineageClient:
     def __init__(self, client: DataHubClient):
         self._client = client
+        self._graph = client._graph
 
     def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
         schema_metadata = self._client._graph.get_aspect(
@@ -700,3 +720,242 @@ class LineageClient:
         # Apply the changes to the entity
         self._client.entities.update(patch_builder)
 
+    def get_lineage(
+        self,
+        *,
+        source_urn: Union[str, Urn],
+        source_column: Optional[str] = None,
+        direction: Literal["upstream", "downstream"] = "upstream",
+        max_hops: int = 1,
+        filter: Optional[Filter] = None,
+        count: int = 500,
+    ) -> List[LineageResult]:
+        """Retrieve lineage entities connected to a source entity.
+
+        Args:
+            source_urn: Source URN for the lineage search.
+            source_column: Source column for the lineage search.
+            direction: Direction of the lineage traversal.
+            max_hops: Maximum number of hops to traverse.
+            filter: Filters to apply to the lineage search.
+            count: Maximum number of results to return.
+
+        Returns:
+            A list of lineage results.
+
+        Raises:
+            SdkUsageError: If the filter values are invalid.
+        """
+        # Validate and convert the input URN
+        source_urn = Urn.from_string(source_urn)
+
+        # Prepare the GraphQL query variables in a separate method
+        variables = self._process_input_variables(
+            source_urn, source_column, filter, direction, max_hops, count
+        )
+        return self._execute_lineage_query(variables, direction)
+
+    def _process_input_variables(
+        self,
+        source_urn: Urn,
+        source_column: Optional[str] = None,
+        filters: Optional[Filter] = None,
+        direction: Literal["upstream", "downstream"] = "upstream",
+        max_hops: int = 1,
+        count: int = 500,
+    ) -> Dict[str, Any]:
+        """Process filters and prepare GraphQL query variables for the lineage search.
+
+        Args:
+            source_urn: Source URN for the lineage search.
+            source_column: Source column for the lineage search.
+            filters: Optional filters to apply.
+            direction: Direction of the lineage traversal.
+            max_hops: Maximum number of hops to traverse.
+            count: Maximum number of results to return.
+
+        Returns:
+            A dictionary of GraphQL query variables.
+
+        Raises:
+            SdkUsageError: If the filter values are invalid.
+        """
+        # Warn if max_hops is greater than 2
+        if max_hops > 2:
+            logger.warning(
+                """If `max_hops` is more than 2, the search will try to find the full lineage graph.
+                By default, only 500 results are shown.
+                You can change the `count` to get more or fewer results.
+                """
+            )
+
+        # Determine the hop values for the degree filter
+        max_hop_values = (
+            [str(hop) for hop in range(1, max_hops + 1)]
+            if max_hops <= 2
+            else ["1", "2", "3+"]
+        )
+
+        degree_filter = FilterDsl.custom_filter(
+            field="degree",
+            condition="EQUAL",
+            values=max_hop_values,
+        )
+
+        filters_with_max_hops = (
+            FilterDsl.and_(degree_filter, filters)
+            if filters is not None
+            else degree_filter
+        )
+        types, compiled_filters = compile_filters(filters_with_max_hops)
+
+        # Prepare the base variables
+        variables: Dict[str, Any] = {
+            "input": {
+                "urn": str(source_urn),
+                "direction": direction.upper(),
+                "count": count,
+                "types": types,
+                "orFilters": compiled_filters,
+            }
+        }
+
+        # If a column is provided, update the variables to use the schemaField urn
+        if isinstance(source_urn, SchemaFieldUrn) or source_column:
+            variables["input"]["searchFlags"] = {
+                "groupingSpec": {
+                    "groupingCriteria": {
+                        "baseEntityType": "SCHEMA_FIELD",
+                        "groupingEntityType": "SCHEMA_FIELD",
+                    }
+                }
+            }
+            if isinstance(source_urn, SchemaFieldUrn):
+                variables["input"]["urn"] = str(source_urn)
+            elif source_column:
+                variables["input"]["urn"] = str(SchemaFieldUrn(source_urn, source_column))
+
+        return variables
+
+    def _execute_lineage_query(
+        self,
+        variables: Dict[str, Any],
+        direction: Literal["upstream", "downstream"],
+    ) -> List[LineageResult]:
+        """Execute the GraphQL query and process the results."""
+        graphql_query = """
+            query scrollAcrossLineage($input: ScrollAcrossLineageInput!) {
+                scrollAcrossLineage(input: $input) {
+                    nextScrollId
+                    searchResults {
+                        degree
+                        entity {
+                            urn
+                            type
+                            ... on Dataset {
+                                name
+                                platform {
+                                    name
+                                }
+                                properties {
+                                    description
+                                }
+                            }
+                            ... on DataJob {
+                                jobId
+                                dataPlatformInstance {
+                                    platform {
+                                        name
+                                    }
+                                }
+                                properties {
+                                    name
+                                    description
+                                }
+                            }
+                        }
+                        paths {
+                            path {
+                                urn
+                                type
+                            }
+                        }
+                    }
+                }
+            }
+        """
+
+        results: List[LineageResult] = []
+        first_iter = True
+        scroll_id: Optional[str] = None
+
+        while first_iter or scroll_id:
+            first_iter = False
+            # Carry the scroll ID forward to fetch the next page
+            if scroll_id:
+                variables["input"]["scrollId"] = scroll_id
+
+            # Execute the GraphQL query
+            response = self._graph.execute_graphql(graphql_query, variables=variables)
+            data = response["scrollAcrossLineage"]
+            scroll_id = data.get("nextScrollId")
+
+            # Process the search results
+            for entry in data["searchResults"]:
+                entity = entry["entity"]
+                result = self._create_lineage_result(entity, entry, direction)
+                results.append(result)
+
+        return results
+
+    def _create_lineage_result(
+        self,
+        entity: Dict[str, Any],
+        entry: Dict[str, Any],
+        direction: Literal["upstream", "downstream"],
+    ) -> LineageResult:
+        """Create a LineageResult from entity and entry data."""
+        platform = entity.get("platform", {}).get("name") or entity.get(
+            "dataPlatformInstance", {}
+        ).get("platform", {}).get("name")
+
+        result = LineageResult(
+            urn=entity["urn"],
+            type=entity["type"],
+            hops=entry["degree"],
+            direction=direction,
+            platform=platform,
+        )
+
+        properties = entity.get("properties", {})
+        if properties:
+            result.name = properties.get("name", "")
+            result.description = properties.get("description", "")
+
+        result.paths = []
+        if "paths" in entry:
+            # Process each path in the lineage graph
+            for path in entry["paths"]:
+                for path_entry in path["path"]:
+                    # Only include schema fields in the path (exclude other types like Query)
+                    if path_entry["type"] == "SCHEMA_FIELD":
+                        schema_field_urn = SchemaFieldUrn.from_string(path_entry["urn"])
+                        result.paths.append(
+                            LineagePath(
+                                urn=path_entry["urn"],
+                                entity_name=DatasetUrn.from_string(
+                                    schema_field_urn.parent
+                                ).name,
+                                column_name=schema_field_urn.field_path,
+                            )
+                        )
+
+        return result
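For reference, and not part of the diff: the hop bucketing that `_process_input_variables` feeds into the "degree" filter can be illustrated in isolation. The helper name `degree_values` is hypothetical.

from typing import List


def degree_values(max_hops: int) -> List[str]:
    # max_hops of 1 or 2 maps to exact hop counts; anything larger
    # falls back to ["1", "2", "3+"], i.e. the full lineage graph.
    if max_hops <= 2:
        return [str(hop) for hop in range(1, max_hops + 1)]
    return ["1", "2", "3+"]


assert degree_values(1) == ["1"]
assert degree_values(2) == ["1", "2"]
assert degree_values(5) == ["1", "2", "3+"]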


@@ -1,11 +1,12 @@
 import pathlib
-from typing import Dict, List, Sequence, Set, cast
+from typing import Dict, List, Optional, Sequence, Set, cast
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 
 from datahub.errors import SdkUsageError
 from datahub.sdk.main_client import DataHubClient
+from datahub.sdk.search_filters import FilterDsl as F
 from datahub.sql_parsing.sql_parsing_common import QueryType
 from datahub.sql_parsing.sqlglot_lineage import (
     ColumnLineageInfo,
@@ -127,38 +128,8 @@ def test_get_strict_column_lineage(
         cast(Set[str], upstream_fields),
         cast(Set[str], downstream_fields),
     )
-    assert result == expected, f"Test failed: {result} != {expected}"
     """Test the strict column lineage matching algorithm."""
-    test_cases = [
-        # Exact matches
-        {
-            "upstream_fields": {"id", "name", "email"},
-            "downstream_fields": {"id", "name", "phone"},
-            "expected": {"id": ["id"], "name": ["name"]},
-        },
-        # No matches
-        {
-            "upstream_fields": {"col1", "col2", "col3"},
-            "downstream_fields": {"col4", "col5", "col6"},
-            "expected": {},
-        },
-        # Case mismatch (should match)
-        {
-            "upstream_fields": {"ID", "Name", "Email"},
-            "downstream_fields": {"id", "name", "email"},
-            "expected": {"id": ["ID"], "name": ["Name"], "email": ["Email"]},
-        },
-    ]
-
-    # Run test cases
-    for test_case in test_cases:
-        result = client.lineage._get_strict_column_lineage(
-            cast(Set[str], test_case["upstream_fields"]),
-            cast(Set[str], test_case["downstream_fields"]),
-        )
-        assert result == test_case["expected"], (
-            f"Test failed: {result} != {test_case['expected']}"
-        )
+    assert result == expected, f"Test failed: {result} != {expected}"
 
 
 def test_infer_lineage_from_sql(client: DataHubClient) -> None:
@@ -444,3 +415,284 @@ def test_add_lineage_invalid_parameter_combinations(client: DataHubClient) -> None:
             downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
             column_lineage={"target_col": ["source_col"]},
         )
+
+
+def test_get_lineage_basic(client: DataHubClient) -> None:
+    """Test basic lineage retrieval with default parameters."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table",
+                            "description": "Upstream source table",
+                        },
+                    },
+                    "degree": 1,
+                }
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(source_urn=source_urn)
+
+    # Validate results
+    assert len(results) == 1
+    assert (
+        results[0].urn
+        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
+    )
+    assert results[0].type == "DATASET"
+    assert results[0].name == "upstream_table"
+    assert results[0].platform == "snowflake"
+    assert results[0].direction == "upstream"
+    assert results[0].hops == 1
+
+
+def test_get_lineage_with_entity_type_filters(client: DataHubClient) -> None:
+    """Test lineage retrieval with entity type and platform filters."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table",
+                            "description": "Upstream source table",
+                        },
+                    },
+                    "degree": 1,
+                },
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(
+            source_urn=source_urn,
+            filter=F.entity_type("dataset"),
+        )
+
+    # Validate results
+    assert len(results) == 1
+    assert {r.type for r in results} == {"DATASET"}
+    assert {r.platform for r in results} == {"snowflake"}
+
+
+def test_get_lineage_downstream(client: DataHubClient) -> None:
+    """Test downstream lineage retrieval."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response for downstream lineage
+    # (platform sits at the entity level, matching the Dataset fragment)
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "downstream_table",
+                            "description": "Downstream target table",
+                        },
+                    },
+                    "degree": 1,
+                }
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(
+            source_urn=source_urn,
+            direction="downstream",
+        )
+
+    # Validate results
+    assert len(results) == 1
+    assert (
+        results[0].urn
+        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
+    )
+    assert results[0].direction == "downstream"
+
+
+def test_get_lineage_multiple_hops(
+    client: DataHubClient, caplog: pytest.LogCaptureFixture
+) -> None:
+    """Test lineage retrieval with multiple hops."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response with multiple hops
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table1,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table1",
+                            "description": "First upstream table",
+                        },
+                    },
+                    "degree": 1,
+                },
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table2,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table2",
+                            "description": "Second upstream table",
+                        },
+                    },
+                    "degree": 2,
+                },
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(source_urn=source_urn, max_hops=2)
+
+    # Check that a warning is logged when max_hops > 2
+    with patch.object(
+        client._graph, "execute_graphql", return_value=mock_response
+    ), caplog.at_level("WARNING"):
+        client.lineage.get_lineage(source_urn=source_urn, max_hops=3)
+        assert any(
+            "the search will try to find the full lineage graph" in msg
+            for msg in caplog.messages
+        )
+
+    # Validate results
+    assert len(results) == 2
+    assert results[0].hops == 1
+    assert results[0].type == "DATASET"
+    assert results[1].hops == 2
+    assert results[1].type == "DATASET"
+
+
+def test_get_lineage_no_results(client: DataHubClient) -> None:
+    """Test lineage retrieval with no results."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response with no results
+    mock_response: Dict[str, Dict[str, Optional[List]]] = {
+        "scrollAcrossLineage": {"nextScrollId": None, "searchResults": []}
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(source_urn=source_urn)
+
+    # Validate results
+    assert len(results) == 0
+
+
+def test_get_lineage_column_lineage_with_source_column(client: DataHubClient) -> None:
+    """Test column lineage retrieval for a specific source column."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+    source_column = "source_column"
+
+    # Mock GraphQL response with column lineage
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table",
+                            "description": "Upstream source table",
+                        },
+                    },
+                    "degree": 1,
+                },
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(
+            source_urn=source_urn,
+            source_column=source_column,
+        )
+
+    # Validate results
+    assert len(results) == 1
+    assert (
+        results[0].urn
+        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
+    )
+
+
+def test_get_lineage_column_lineage_with_schema_field_urn(
+    client: DataHubClient,
+) -> None:
+    """Test column lineage retrieval starting from a schemaField urn."""
+    source_urn = "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD),source_column)"
+
+    # Mock GraphQL response with column lineage
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table",
+                            "description": "Upstream source table",
+                        },
+                    },
+                    "degree": 1,
+                },
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(
+            source_urn=source_urn,
+        )
+
+    # Validate results
+    assert len(results) == 1
+    assert (
+        results[0].urn
+        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
+    )
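One gap in the unit tests above is pagination. A hedged sketch of how the nextScrollId loop could be exercised with the same mocking pattern; the test name, helper, and payloads are hypothetical and not part of this commit.

def test_get_lineage_pagination_sketch(client: DataHubClient) -> None:
    """Hypothetical: results from two scroll pages are stitched together."""

    def page(scroll_id, name):
        # Minimal single-result page; entities without platform/properties
        # still produce a LineageResult with those fields left unset.
        return {
            "scrollAcrossLineage": {
                "nextScrollId": scroll_id,
                "searchResults": [
                    {
                        "entity": {
                            "urn": f"urn:li:dataset:(urn:li:dataPlatform:snowflake,{name},PROD)",
                            "type": "DATASET",
                        },
                        "degree": 1,
                    }
                ],
            }
        }

    # The first response carries a scroll id, so the client issues a second query.
    with patch.object(
        client._graph,
        "execute_graphql",
        side_effect=[page("scroll-1", "upstream_a"), page(None, "upstream_b")],
    ):
        results = client.lineage.get_lineage(
            source_urn="urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
        )

    assert [r.urn for r in results] == [
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_a,PROD)",
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_b,PROD)",
    ]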


@@ -0,0 +1,222 @@
from typing import Dict, Generator

import pytest

from datahub.ingestion.graph.client import DataHubGraph
from datahub.metadata.urns import SchemaFieldUrn
from datahub.sdk.dataset import Dataset
from datahub.sdk.lineage_client import LineageResult
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F
from tests.utils import wait_for_writes_to_sync


@pytest.fixture(scope="module")
def test_client(graph_client: DataHubGraph) -> DataHubClient:
    return DataHubClient(graph=graph_client)


@pytest.fixture(scope="module")
def test_datasets(
    test_client: DataHubClient,
) -> Generator[Dict[str, Dataset], None, None]:
    datasets = {
        "upstream": Dataset(
            platform="snowflake",
            name="test_lineage_upstream_001",
            schema=[("name", "string"), ("id", "int")],
        ),
        "downstream1": Dataset(
            platform="snowflake",
            name="test_lineage_downstream_001",
            schema=[("name", "string"), ("id", "int")],
        ),
        "downstream2": Dataset(
            platform="snowflake",
            name="test_lineage_downstream_002",
            schema=[("name", "string"), ("id", "int")],
        ),
        "downstream3": Dataset(
            platform="mysql",
            name="test_lineage_downstream_003",
            schema=[("name", "string"), ("id", "int")],
        ),
    }

    for entity in datasets.values():
        test_client._graph.delete_entity(str(entity.urn), hard=True)
    for entity in datasets.values():
        test_client.entities.upsert(entity)

    # Add lineage: upstream -> downstream1 -> downstream2 -> downstream3
    test_client.lineage.add_lineage(
        upstream=str(datasets["upstream"].urn),
        downstream=str(datasets["downstream1"].urn),
        column_lineage=True,
    )
    test_client.lineage.add_lineage(
        upstream=str(datasets["downstream1"].urn),
        downstream=str(datasets["downstream2"].urn),
        column_lineage=True,
    )
    test_client.lineage.add_lineage(
        upstream=str(datasets["downstream2"].urn),
        downstream=str(datasets["downstream3"].urn),
        column_lineage=True,
    )
    wait_for_writes_to_sync()

    yield datasets

    # Cleanup
    for entity in datasets.values():
        try:
            test_client._graph.delete_entity(str(entity.urn), hard=True)
        except Exception as e:
            raise Exception(f"Could not delete entity {entity.urn}: {e}")


def validate_lineage_results(
    lineage_result: LineageResult,
    hops=None,
    direction=None,
    platform=None,
    urn=None,
    paths_len=None,
):
    if hops is not None:
        assert lineage_result.hops == hops
    if direction is not None:
        assert lineage_result.direction == direction
    if platform is not None:
        assert lineage_result.platform == platform
    if urn is not None:
        assert lineage_result.urn == urn
    if paths_len is not None and lineage_result.paths is not None:
        assert len(lineage_result.paths) == paths_len


def test_table_level_lineage(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    table_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(test_datasets["upstream"].urn),
        direction="downstream",
        max_hops=3,
    )
    assert len(table_lineage_results) == 3

    urns = {r.urn for r in table_lineage_results}
    expected = {
        str(test_datasets["downstream1"].urn),
        str(test_datasets["downstream2"].urn),
        str(test_datasets["downstream3"].urn),
    }
    assert urns == expected

    table_lineage_results = sorted(table_lineage_results, key=lambda x: x.hops)
    validate_lineage_results(
        table_lineage_results[0],
        hops=1,
        platform="snowflake",
        urn=str(test_datasets["downstream1"].urn),
        paths_len=0,
    )
    validate_lineage_results(
        table_lineage_results[1],
        hops=2,
        platform="snowflake",
        urn=str(test_datasets["downstream2"].urn),
        paths_len=0,
    )
    validate_lineage_results(
        table_lineage_results[2],
        hops=3,
        platform="mysql",
        urn=str(test_datasets["downstream3"].urn),
        paths_len=0,
    )


def test_column_level_lineage(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    column_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(test_datasets["upstream"].urn),
        source_column="id",
        direction="downstream",
        max_hops=3,
    )
    assert len(column_lineage_results) == 3

    column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops)
    validate_lineage_results(
        column_lineage_results[0],
        hops=1,
        urn=str(test_datasets["downstream1"].urn),
        paths_len=2,
    )
    validate_lineage_results(
        column_lineage_results[1],
        hops=2,
        urn=str(test_datasets["downstream2"].urn),
        paths_len=3,
    )
    validate_lineage_results(
        column_lineage_results[2],
        hops=3,
        urn=str(test_datasets["downstream3"].urn),
        paths_len=4,
    )


def test_filtered_column_level_lineage(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    filtered_column_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(test_datasets["upstream"].urn),
        source_column="id",
        direction="downstream",
        max_hops=3,
        filter=F.and_(F.platform("mysql"), F.entity_type("dataset")),
    )
    assert len(filtered_column_lineage_results) == 1
    validate_lineage_results(
        filtered_column_lineage_results[0],
        hops=3,
        platform="mysql",
        urn=str(test_datasets["downstream3"].urn),
        paths_len=4,
    )


def test_column_level_lineage_from_schema_field(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    source_schema_field = SchemaFieldUrn(test_datasets["upstream"].urn, "id")
    column_lineage_results = test_client.lineage.get_lineage(
        source_urn=str(source_schema_field), direction="downstream", max_hops=3
    )
    assert len(column_lineage_results) == 3

    column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops)
    validate_lineage_results(
        column_lineage_results[0],
        hops=1,
        urn=str(test_datasets["downstream1"].urn),
        paths_len=2,
    )
    validate_lineage_results(
        column_lineage_results[1],
        hops=2,
        urn=str(test_datasets["downstream2"].urn),
        paths_len=3,
    )
    validate_lineage_results(
        column_lineage_results[2],
        hops=3,
        urn=str(test_datasets["downstream3"].urn),
        paths_len=4,
    )
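A possible follow-on check, not part of this commit: a hedged sketch asserting the traversal is symmetric when walking upstream from the last dataset in the chain. The test name is hypothetical; it reuses the fixtures above.

def test_upstream_direction_sketch(
    test_client: DataHubClient, test_datasets: Dict[str, Dataset]
):
    # Hypothetical symmetry check: starting from the end of the
    # upstream -> downstream1 -> downstream2 -> downstream3 chain and
    # walking upstream should reach the other three datasets.
    upstream_results = test_client.lineage.get_lineage(
        source_urn=str(test_datasets["downstream3"].urn),
        direction="upstream",
        max_hops=3,
    )
    assert {r.urn for r in upstream_results} == {
        str(test_datasets["upstream"].urn),
        str(test_datasets["downstream1"].urn),
        str(test_datasets["downstream2"].urn),
    }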