mirror of
https://github.com/datahub-project/datahub.git
synced 2025-11-07 15:04:01 +00:00
feat(sdk): add get_lineage (#13654)
This commit is contained in:
parent
01357940b1
commit
e169b4ac05
24
metadata-ingestion/examples/library/get_column_lineage.py
Normal file
24
metadata-ingestion/examples/library/get_column_lineage.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F

client = DataHubClient.from_env()

dataset_urn = DatasetUrn(platform="snowflake", name="downstream_table")

# Restrict results to Snowflake datasets.
lineage_filter = F.and_(
    F.platform("snowflake"),
    F.entity_type("dataset"),
)

# Column-level lineage for the flow:
# - pass source_urn plus source_column to scope lineage to one column, or
# - pass a schemaField URN directly as source_urn, e.g.
#   source_urn="urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table),id)"
downstream_column_lineage = client.lineage.get_lineage(
    source_urn=dataset_urn,
    source_column="id",
    direction="downstream",
    max_hops=1,
    filter=lineage_filter,
)

print(downstream_column_lineage)
17
metadata-ingestion/examples/library/get_lineage_basic.py
Normal file
17
metadata-ingestion/examples/library/get_lineage_basic.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_filters import FilterDsl as F

client = DataHubClient.from_env()

# Walk up to two hops downstream, keeping only Airflow data jobs.
downstream_lineage = client.lineage.get_lineage(
    source_urn=DatasetUrn(platform="snowflake", name="downstream_table"),
    direction="downstream",
    max_hops=2,
    filter=F.and_(
        F.platform("airflow"),
        F.entity_type("dataJob"),
    ),
)

print(downstream_lineage)
@ -2,9 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import difflib
|
import difflib
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
Dict,
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
@ -18,12 +21,7 @@ from typing_extensions import assert_never, deprecated
|
|||||||
import datahub.metadata.schema_classes as models
|
import datahub.metadata.schema_classes as models
|
||||||
from datahub.emitter.mcp import MetadataChangeProposalWrapper
|
from datahub.emitter.mcp import MetadataChangeProposalWrapper
|
||||||
from datahub.errors import SdkUsageError
|
from datahub.errors import SdkUsageError
|
||||||
from datahub.metadata.urns import (
|
from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn, SchemaFieldUrn, Urn
|
||||||
DataJobUrn,
|
|
||||||
DatasetUrn,
|
|
||||||
QueryUrn,
|
|
||||||
Urn,
|
|
||||||
)
|
|
||||||
from datahub.sdk._shared import (
|
from datahub.sdk._shared import (
|
||||||
ChartUrnOrStr,
|
ChartUrnOrStr,
|
||||||
DashboardUrnOrStr,
|
DashboardUrnOrStr,
|
||||||
@ -32,6 +30,8 @@ from datahub.sdk._shared import (
|
|||||||
)
|
)
|
||||||
from datahub.sdk._utils import DEFAULT_ACTOR_URN
|
from datahub.sdk._utils import DEFAULT_ACTOR_URN
|
||||||
from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
|
from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
|
||||||
|
from datahub.sdk.search_client import compile_filters
|
||||||
|
from datahub.sdk.search_filters import Filter, FilterDsl
|
||||||
from datahub.specific.chart import ChartPatchBuilder
|
from datahub.specific.chart import ChartPatchBuilder
|
||||||
from datahub.specific.dashboard import DashboardPatchBuilder
|
from datahub.specific.dashboard import DashboardPatchBuilder
|
||||||
from datahub.specific.datajob import DataJobPatchBuilder
|
from datahub.specific.datajob import DataJobPatchBuilder
|
||||||
@ -53,9 +53,29 @@ _empty_audit_stamp = models.AuditStampClass(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LineagePath:
    """One schema-field node on a column-level lineage path."""

    # URN of the schema field on the path.
    urn: str
    # Name of the dataset that owns the field.
    entity_name: str
    # Field path of the column; None when not applicable.
    column_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LineageResult:
    """A single entity reached by a lineage traversal."""

    # URN of the connected entity.
    urn: str
    # Entity type as reported by GraphQL (e.g. "DATASET", "DATA_JOB").
    type: str
    # Number of hops between the source and this entity.
    hops: int
    # Which way the traversal ran from the source entity.
    direction: Literal["upstream", "downstream"]
    # Platform name, when the response carries one.
    platform: Optional[str] = None
    name: Optional[str] = None
    description: Optional[str] = None
    # Column-level path nodes; populated only for column lineage queries.
    paths: Optional[List[LineagePath]] = None
|
||||||
|
|
||||||
|
|
||||||
class LineageClient:
|
class LineageClient:
|
||||||
def __init__(self, client: DataHubClient):
|
def __init__(self, client: DataHubClient):
|
||||||
self._client = client
|
self._client = client
|
||||||
|
self._graph = client._graph
|
||||||
|
|
||||||
def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
|
def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
|
||||||
schema_metadata = self._client._graph.get_aspect(
|
schema_metadata = self._client._graph.get_aspect(
|
||||||
@ -700,3 +720,242 @@ class LineageClient:
|
|||||||
|
|
||||||
# Apply the changes to the entity
|
# Apply the changes to the entity
|
||||||
self._client.entities.update(patch_builder)
|
self._client.entities.update(patch_builder)
|
||||||
|
|
||||||
|
def get_lineage(
    self,
    *,
    source_urn: Union[str, Urn],
    source_column: Optional[str] = None,
    direction: Literal["upstream", "downstream"] = "upstream",
    max_hops: int = 1,
    filter: Optional[Filter] = None,
    count: int = 500,
) -> List[LineageResult]:
    """Retrieve lineage entities connected to a source entity.

    Args:
        source_urn: Source URN for the lineage search.
        source_column: Source column for the lineage search.
        direction: Direction of lineage traversal.
        max_hops: Maximum number of hops to traverse.
        filter: Filters to apply to the lineage search.
        count: Maximum number of results to return.

    Returns:
        List of lineage results.

    Raises:
        SdkUsageError: For invalid filter values.
    """
    # Normalize the input to a concrete Urn instance before building the query.
    parsed_urn = Urn.from_string(source_urn)

    # Build the GraphQL variables in a dedicated helper, then run the scroll
    # query and convert the raw payload into LineageResult objects.
    query_variables = self._process_input_variables(
        parsed_urn, source_column, filter, direction, max_hops, count
    )
    return self._execute_lineage_query(query_variables, direction)
|
||||||
|
|
||||||
|
def _process_input_variables(
    self,
    source_urn: Urn,
    source_column: Optional[str] = None,
    filters: Optional[Filter] = None,
    direction: Literal["upstream", "downstream"] = "upstream",
    max_hops: int = 1,
    count: int = 500,
) -> Dict[str, Any]:
    """
    Process filters and prepare GraphQL query variables for lineage search.

    Args:
        source_urn: Source URN for the lineage search
        source_column: Source column for the lineage search
        filters: Optional filters to apply
        direction: Direction of lineage traversal
        max_hops: Maximum number of hops to traverse
        count: Maximum number of results to return

    Returns:
        Dictionary of GraphQL query variables

    Raises:
        SdkUsageError for invalid filter values
    """
    if max_hops > 2:
        # FIX: the original warning was a triple-quoted literal that embedded
        # raw newlines and source indentation into the log output. Emit one
        # cleanly formatted message instead; the key phrase is kept stable
        # because tests match on it.
        logger.warning(
            "If `max_hops` is more than 2, the search will try to find the "
            "full lineage graph. By default, only 500 results are shown. "
            "You can change the `count` to get more or fewer results."
        )

    # The graph indexes lineage degree as "1", "2", and "3+"; anything beyond
    # two hops means "search the full graph".
    max_hop_values = (
        [str(hop) for hop in range(1, max_hops + 1)]
        if max_hops <= 2
        else ["1", "2", "3+"]
    )

    degree_filter = FilterDsl.custom_filter(
        field="degree",
        condition="EQUAL",
        values=max_hop_values,
    )

    # Combine the caller's filter (if any) with the mandatory degree filter.
    filters_with_max_hops = (
        FilterDsl.and_(degree_filter, filters)
        if filters is not None
        else degree_filter
    )

    types, compiled_filters = compile_filters(filters_with_max_hops)

    variables: Dict[str, Any] = {
        "input": {
            "urn": str(source_urn),
            "direction": direction.upper(),
            "count": count,
            "types": types,
            "orFilters": compiled_filters,
        }
    }

    # Column-level lineage: group results at the schema-field level.
    if isinstance(source_urn, SchemaFieldUrn) or source_column:
        variables["input"]["searchFlags"] = {
            "groupingSpec": {
                "groupingCriteria": {
                    "baseEntityType": "SCHEMA_FIELD",
                    "groupingEntityType": "SCHEMA_FIELD",
                }
            }
        }
        # When the caller passed a dataset urn plus a column, rewrite the
        # search urn to the corresponding schema-field urn. (A SchemaFieldUrn
        # source is already set correctly above — no reassignment needed.)
        if source_column and not isinstance(source_urn, SchemaFieldUrn):
            variables["input"]["urn"] = str(SchemaFieldUrn(source_urn, source_column))

    return variables
|
||||||
|
|
||||||
|
def _execute_lineage_query(
    self,
    variables: Dict[str, Any],
    direction: Literal["upstream", "downstream"],
) -> List[LineageResult]:
    """Run the scrollAcrossLineage GraphQL query and collect every page."""
    graphql_query = """
    query scrollAcrossLineage($input: ScrollAcrossLineageInput!) {
        scrollAcrossLineage(input: $input) {
            nextScrollId
            searchResults {
                degree
                entity {
                    urn
                    type
                    ... on Dataset {
                        name
                        platform {
                            name
                        }
                        properties {
                            description
                        }
                    }
                    ... on DataJob {
                        jobId
                        dataPlatformInstance {
                            platform {
                                name
                            }
                        }
                        properties {
                            name
                            description
                        }
                    }
                }
                paths {
                    path {
                        urn
                        type
                    }
                }
            }
        }
    }
    """

    results: List[LineageResult] = []
    scroll_id: Optional[str] = None

    # Page through the scroll API until the server stops returning a cursor.
    while True:
        if scroll_id:
            variables["input"]["scrollId"] = scroll_id

        response = self._graph.execute_graphql(graphql_query, variables=variables)
        page = response["scrollAcrossLineage"]
        scroll_id = page.get("nextScrollId")

        for hit in page["searchResults"]:
            results.append(
                self._create_lineage_result(hit["entity"], hit, direction)
            )

        if not scroll_id:
            break

    return results
||||||
|
|
||||||
|
def _create_lineage_result(
    self,
    entity: Dict[str, Any],
    entry: Dict[str, Any],
    direction: Literal["upstream", "downstream"],
) -> LineageResult:
    """Create a LineageResult from entity and entry data.

    Args:
        entity: The ``entity`` payload of one search result.
        entry: The full search-result entry (carries ``degree`` and ``paths``).
        direction: Direction of the traversal that produced this result.

    Returns:
        A populated LineageResult.
    """
    # FIX: GraphQL returns explicit null for fragment fields that don't apply
    # to the entity's type (e.g. no "platform" on a DataJob), so
    # `entity.get("platform", {})` can yield None and crash on the chained
    # `.get`. Use `or {}` to stay None-safe.
    platform = (entity.get("platform") or {}).get("name") or (
        (entity.get("dataPlatformInstance") or {}).get("platform") or {}
    ).get("name")

    result = LineageResult(
        urn=entity["urn"],
        type=entity["type"],
        hops=entry["degree"],
        direction=direction,
        platform=platform,
    )

    properties = entity.get("properties") or {}
    if properties:
        result.name = properties.get("name", "")
        result.description = properties.get("description", "")

    result.paths = []
    if "paths" in entry:
        # Process each path in the lineage graph (None-safe: the key may be
        # present with a null value).
        for path in entry["paths"] or []:
            for path_entry in path["path"]:
                # Only include schema fields in the path (exclude other types
                # like Query).
                if path_entry["type"] == "SCHEMA_FIELD":
                    schema_field_urn = SchemaFieldUrn.from_string(path_entry["urn"])
                    result.paths.append(
                        LineagePath(
                            urn=path_entry["urn"],
                            entity_name=DatasetUrn.from_string(
                                schema_field_urn.parent
                            ).name,
                            column_name=schema_field_urn.field_path,
                        )
                    )

    return result
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import pathlib
|
import pathlib
|
||||||
from typing import Dict, List, Sequence, Set, cast
|
from typing import Dict, List, Optional, Sequence, Set, cast
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from datahub.errors import SdkUsageError
|
from datahub.errors import SdkUsageError
|
||||||
from datahub.sdk.main_client import DataHubClient
|
from datahub.sdk.main_client import DataHubClient
|
||||||
|
from datahub.sdk.search_filters import FilterDsl as F
|
||||||
from datahub.sql_parsing.sql_parsing_common import QueryType
|
from datahub.sql_parsing.sql_parsing_common import QueryType
|
||||||
from datahub.sql_parsing.sqlglot_lineage import (
|
from datahub.sql_parsing.sqlglot_lineage import (
|
||||||
ColumnLineageInfo,
|
ColumnLineageInfo,
|
||||||
@ -127,38 +128,8 @@ def test_get_strict_column_lineage(
|
|||||||
cast(Set[str], upstream_fields),
|
cast(Set[str], upstream_fields),
|
||||||
cast(Set[str], downstream_fields),
|
cast(Set[str], downstream_fields),
|
||||||
)
|
)
|
||||||
assert result == expected, f"Test failed: {result} != {expected}"
|
|
||||||
"""Test the strict column lineage matching algorithm."""
|
"""Test the strict column lineage matching algorithm."""
|
||||||
test_cases = [
|
assert result == expected, f"Test failed: {result} != {expected}"
|
||||||
# Exact matches
|
|
||||||
{
|
|
||||||
"upstream_fields": {"id", "name", "email"},
|
|
||||||
"downstream_fields": {"id", "name", "phone"},
|
|
||||||
"expected": {"id": ["id"], "name": ["name"]},
|
|
||||||
},
|
|
||||||
# No matches
|
|
||||||
{
|
|
||||||
"upstream_fields": {"col1", "col2", "col3"},
|
|
||||||
"downstream_fields": {"col4", "col5", "col6"},
|
|
||||||
"expected": {},
|
|
||||||
},
|
|
||||||
# Case mismatch (should match)
|
|
||||||
{
|
|
||||||
"upstream_fields": {"ID", "Name", "Email"},
|
|
||||||
"downstream_fields": {"id", "name", "email"},
|
|
||||||
"expected": {"id": ["ID"], "name": ["Name"], "email": ["Email"]},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Run test cases
|
|
||||||
for test_case in test_cases:
|
|
||||||
result = client.lineage._get_strict_column_lineage(
|
|
||||||
cast(Set[str], test_case["upstream_fields"]),
|
|
||||||
cast(Set[str], test_case["downstream_fields"]),
|
|
||||||
)
|
|
||||||
assert result == test_case["expected"], (
|
|
||||||
f"Test failed: {result} != {test_case['expected']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_infer_lineage_from_sql(client: DataHubClient) -> None:
|
def test_infer_lineage_from_sql(client: DataHubClient) -> None:
|
||||||
@ -444,3 +415,284 @@ def test_add_lineage_invalid_parameter_combinations(client: DataHubClient) -> No
|
|||||||
downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
|
downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)",
|
||||||
column_lineage={"target_col": ["source_col"]},
|
column_lineage={"target_col": ["source_col"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lineage_basic(client: DataHubClient) -> None:
|
||||||
|
"""Test basic lineage retrieval with default parameters."""
|
||||||
|
source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
|
||||||
|
|
||||||
|
# Mock GraphQL response
|
||||||
|
mock_response = {
|
||||||
|
"scrollAcrossLineage": {
|
||||||
|
"nextScrollId": None,
|
||||||
|
"searchResults": [
|
||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
|
||||||
|
"type": "DATASET",
|
||||||
|
"platform": {"name": "snowflake"},
|
||||||
|
"properties": {
|
||||||
|
"name": "upstream_table",
|
||||||
|
"description": "Upstream source table",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degree": 1,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the GraphQL execution method
|
||||||
|
with patch.object(client._graph, "execute_graphql", return_value=mock_response):
|
||||||
|
results = client.lineage.get_lineage(source_urn=source_urn)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
assert len(results) == 1
|
||||||
|
assert (
|
||||||
|
results[0].urn
|
||||||
|
== "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
|
||||||
|
)
|
||||||
|
assert results[0].type == "DATASET"
|
||||||
|
assert results[0].name == "upstream_table"
|
||||||
|
assert results[0].platform == "snowflake"
|
||||||
|
assert results[0].direction == "upstream"
|
||||||
|
assert results[0].hops == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lineage_with_entity_type_filters(client: DataHubClient) -> None:
|
||||||
|
"""Test lineage retrieval with entity type and platform filters."""
|
||||||
|
source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
|
||||||
|
|
||||||
|
# Mock GraphQL response with multiple entity types
|
||||||
|
mock_response = {
|
||||||
|
"scrollAcrossLineage": {
|
||||||
|
"nextScrollId": None,
|
||||||
|
"searchResults": [
|
||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
|
||||||
|
"type": "DATASET",
|
||||||
|
"platform": {"name": "snowflake"},
|
||||||
|
"properties": {
|
||||||
|
"name": "upstream_table",
|
||||||
|
"description": "Upstream source table",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degree": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the GraphQL execution method to return results for multiple calls
|
||||||
|
with patch.object(client._graph, "execute_graphql", return_value=mock_response):
|
||||||
|
results = client.lineage.get_lineage(
|
||||||
|
source_urn=source_urn,
|
||||||
|
filter=F.entity_type("dataset"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
assert len(results) == 1
|
||||||
|
assert {r.type for r in results} == {"DATASET"}
|
||||||
|
assert {r.platform for r in results} == {"snowflake"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lineage_downstream(client: DataHubClient) -> None:
|
||||||
|
"""Test downstream lineage retrieval."""
|
||||||
|
source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
|
||||||
|
|
||||||
|
# Mock GraphQL response for downstream lineage
|
||||||
|
mock_response = {
|
||||||
|
"scrollAcrossLineage": {
|
||||||
|
"nextScrollId": None,
|
||||||
|
"searchResults": [
|
||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)",
|
||||||
|
"type": "DATASET",
|
||||||
|
"properties": {
|
||||||
|
"name": "downstream_table",
|
||||||
|
"description": "Downstream target table",
|
||||||
|
"platform": {"name": "snowflake"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degree": 1,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the GraphQL execution method
|
||||||
|
with patch.object(client._graph, "execute_graphql", return_value=mock_response):
|
||||||
|
results = client.lineage.get_lineage(
|
||||||
|
source_urn=source_urn,
|
||||||
|
direction="downstream",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
assert len(results) == 1
|
||||||
|
assert (
|
||||||
|
results[0].urn
|
||||||
|
== "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
|
||||||
|
)
|
||||||
|
assert results[0].direction == "downstream"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lineage_multiple_hops(
|
||||||
|
client: DataHubClient, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test lineage retrieval with multiple hops."""
|
||||||
|
source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
|
||||||
|
|
||||||
|
# Mock GraphQL response with multiple hops
|
||||||
|
mock_response = {
|
||||||
|
"scrollAcrossLineage": {
|
||||||
|
"nextScrollId": None,
|
||||||
|
"searchResults": [
|
||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table1,PROD)",
|
||||||
|
"type": "DATASET",
|
||||||
|
"properties": {
|
||||||
|
"name": "upstream_table1",
|
||||||
|
"description": "First upstream table",
|
||||||
|
"platform": {"name": "snowflake"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degree": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table2,PROD)",
|
||||||
|
"type": "DATASET",
|
||||||
|
"properties": {
|
||||||
|
"name": "upstream_table2",
|
||||||
|
"description": "Second upstream table",
|
||||||
|
"platform": {"name": "snowflake"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degree": 2,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the GraphQL execution method
|
||||||
|
with patch.object(client._graph, "execute_graphql", return_value=mock_response):
|
||||||
|
results = client.lineage.get_lineage(source_urn=source_urn, max_hops=2)
|
||||||
|
|
||||||
|
# check warning if logged when max_hops > 2
|
||||||
|
with patch.object(
|
||||||
|
client._graph, "execute_graphql", return_value=mock_response
|
||||||
|
), caplog.at_level("WARNING"):
|
||||||
|
client.lineage.get_lineage(source_urn=source_urn, max_hops=3)
|
||||||
|
assert any(
|
||||||
|
"the search will try to find the full lineage graph" in msg
|
||||||
|
for msg in caplog.messages
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0].hops == 1
|
||||||
|
assert results[0].type == "DATASET"
|
||||||
|
assert results[1].hops == 2
|
||||||
|
assert results[1].type == "DATASET"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lineage_no_results(client: DataHubClient) -> None:
    """Test lineage retrieval with no results."""
    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"

    # An empty scroll page: no hits and no next cursor.
    mock_response: Dict[str, Dict[str, Optional[List]]] = {
        "scrollAcrossLineage": {"nextScrollId": None, "searchResults": []}
    }

    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
        results = client.lineage.get_lineage(source_urn=source_urn)

    # No search results should yield an empty list.
    assert not results
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lineage_column_lineage_with_source_column(client: DataHubClient) -> None:
|
||||||
|
"""Test lineage retrieval with column lineage."""
|
||||||
|
source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
|
||||||
|
source_column = "source_column"
|
||||||
|
|
||||||
|
# Mock GraphQL response with column lineage
|
||||||
|
mock_response = {
|
||||||
|
"scrollAcrossLineage": {
|
||||||
|
"nextScrollId": None,
|
||||||
|
"searchResults": [
|
||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
|
||||||
|
"type": "DATASET",
|
||||||
|
"platform": {"name": "snowflake"},
|
||||||
|
"properties": {
|
||||||
|
"name": "upstream_table",
|
||||||
|
"description": "Upstream source table",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degree": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the GraphQL execution method
|
||||||
|
with patch.object(client._graph, "execute_graphql", return_value=mock_response):
|
||||||
|
results = client.lineage.get_lineage(
|
||||||
|
source_urn=source_urn,
|
||||||
|
source_column=source_column,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
assert len(results) == 1
|
||||||
|
assert (
|
||||||
|
results[0].urn
|
||||||
|
== "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lineage_column_lineage_with_schema_field_urn(
|
||||||
|
client: DataHubClient,
|
||||||
|
) -> None:
|
||||||
|
"""Test lineage retrieval with column lineage."""
|
||||||
|
source_urn = "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD),source_column)"
|
||||||
|
|
||||||
|
# Mock GraphQL response with column lineage
|
||||||
|
mock_response = {
|
||||||
|
"scrollAcrossLineage": {
|
||||||
|
"nextScrollId": None,
|
||||||
|
"searchResults": [
|
||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
|
||||||
|
"type": "DATASET",
|
||||||
|
"platform": {"name": "snowflake"},
|
||||||
|
"properties": {
|
||||||
|
"name": "upstream_table",
|
||||||
|
"description": "Upstream source table",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degree": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the GraphQL execution method
|
||||||
|
with patch.object(client._graph, "execute_graphql", return_value=mock_response):
|
||||||
|
results = client.lineage.get_lineage(
|
||||||
|
source_urn=source_urn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
assert len(results) == 1
|
||||||
|
assert (
|
||||||
|
results[0].urn
|
||||||
|
== "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)"
|
||||||
|
)
|
||||||
|
|||||||
222
smoke-test/tests/lineage/test_lineage_sdk.py
Normal file
222
smoke-test/tests/lineage/test_lineage_sdk.py
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
from typing import Dict, Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from datahub.ingestion.graph.client import DataHubGraph
|
||||||
|
from datahub.metadata.urns import SchemaFieldUrn
|
||||||
|
from datahub.sdk.dataset import Dataset
|
||||||
|
from datahub.sdk.lineage_client import LineageResult
|
||||||
|
from datahub.sdk.main_client import DataHubClient
|
||||||
|
from datahub.sdk.search_filters import FilterDsl as F
|
||||||
|
from tests.utils import wait_for_writes_to_sync
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def test_client(graph_client: DataHubGraph) -> DataHubClient:
|
||||||
|
return DataHubClient(graph=graph_client)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def test_datasets(
|
||||||
|
test_client: DataHubClient,
|
||||||
|
) -> Generator[Dict[str, Dataset], None, None]:
|
||||||
|
datasets = {
|
||||||
|
"upstream": Dataset(
|
||||||
|
platform="snowflake",
|
||||||
|
name="test_lineage_upstream_001",
|
||||||
|
schema=[("name", "string"), ("id", "int")],
|
||||||
|
),
|
||||||
|
"downstream1": Dataset(
|
||||||
|
platform="snowflake",
|
||||||
|
name="test_lineage_downstream_001",
|
||||||
|
schema=[("name", "string"), ("id", "int")],
|
||||||
|
),
|
||||||
|
"downstream2": Dataset(
|
||||||
|
platform="snowflake",
|
||||||
|
name="test_lineage_downstream_002",
|
||||||
|
schema=[("name", "string"), ("id", "int")],
|
||||||
|
),
|
||||||
|
"downstream3": Dataset(
|
||||||
|
platform="mysql",
|
||||||
|
name="test_lineage_downstream_003",
|
||||||
|
schema=[("name", "string"), ("id", "int")],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
for entity in datasets.values():
|
||||||
|
test_client._graph.delete_entity(str(entity.urn), hard=True)
|
||||||
|
for entity in datasets.values():
|
||||||
|
test_client.entities.upsert(entity)
|
||||||
|
|
||||||
|
# Add lineage
|
||||||
|
test_client.lineage.add_lineage(
|
||||||
|
upstream=str(datasets["upstream"].urn),
|
||||||
|
downstream=str(datasets["downstream1"].urn),
|
||||||
|
column_lineage=True,
|
||||||
|
)
|
||||||
|
test_client.lineage.add_lineage(
|
||||||
|
upstream=str(datasets["downstream1"].urn),
|
||||||
|
downstream=str(datasets["downstream2"].urn),
|
||||||
|
column_lineage=True,
|
||||||
|
)
|
||||||
|
test_client.lineage.add_lineage(
|
||||||
|
upstream=str(datasets["downstream2"].urn),
|
||||||
|
downstream=str(datasets["downstream3"].urn),
|
||||||
|
column_lineage=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
wait_for_writes_to_sync()
|
||||||
|
|
||||||
|
yield datasets
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
for entity in datasets.values():
|
||||||
|
try:
|
||||||
|
test_client._graph.delete_entity(str(entity.urn), hard=True)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Could not delete entity {entity.urn}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_lineage_results(
    lineage_result: LineageResult,
    hops=None,
    direction=None,
    platform=None,
    urn=None,
    paths_len=None,
):
    """Assert that a lineage result matches every expectation provided.

    Any keyword left as None is skipped.
    """
    expectations = [
        (hops, lineage_result.hops),
        (direction, lineage_result.direction),
        (platform, lineage_result.platform),
        (urn, lineage_result.urn),
    ]
    for expected, actual in expectations:
        if expected is not None:
            assert actual == expected
    # Path length is only checked when the result actually carries paths.
    if paths_len is not None and lineage_result.paths is not None:
        assert len(lineage_result.paths) == paths_len
|
||||||
|
|
||||||
|
|
||||||
|
def test_table_level_lineage(
|
||||||
|
test_client: DataHubClient, test_datasets: Dict[str, Dataset]
|
||||||
|
):
|
||||||
|
table_lineage_results = test_client.lineage.get_lineage(
|
||||||
|
source_urn=str(test_datasets["upstream"].urn),
|
||||||
|
direction="downstream",
|
||||||
|
max_hops=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(table_lineage_results) == 3
|
||||||
|
urns = {r.urn for r in table_lineage_results}
|
||||||
|
expected = {
|
||||||
|
str(test_datasets["downstream1"].urn),
|
||||||
|
str(test_datasets["downstream2"].urn),
|
||||||
|
str(test_datasets["downstream3"].urn),
|
||||||
|
}
|
||||||
|
assert urns == expected
|
||||||
|
|
||||||
|
table_lineage_results = sorted(table_lineage_results, key=lambda x: x.hops)
|
||||||
|
validate_lineage_results(
|
||||||
|
table_lineage_results[0],
|
||||||
|
hops=1,
|
||||||
|
platform="snowflake",
|
||||||
|
urn=str(test_datasets["downstream1"].urn),
|
||||||
|
paths_len=0,
|
||||||
|
)
|
||||||
|
validate_lineage_results(
|
||||||
|
table_lineage_results[1],
|
||||||
|
hops=2,
|
||||||
|
platform="snowflake",
|
||||||
|
urn=str(test_datasets["downstream2"].urn),
|
||||||
|
paths_len=0,
|
||||||
|
)
|
||||||
|
validate_lineage_results(
|
||||||
|
table_lineage_results[2],
|
||||||
|
hops=3,
|
||||||
|
platform="mysql",
|
||||||
|
urn=str(test_datasets["downstream3"].urn),
|
||||||
|
paths_len=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_column_level_lineage(
|
||||||
|
test_client: DataHubClient, test_datasets: Dict[str, Dataset]
|
||||||
|
):
|
||||||
|
column_lineage_results = test_client.lineage.get_lineage(
|
||||||
|
source_urn=str(test_datasets["upstream"].urn),
|
||||||
|
source_column="id",
|
||||||
|
direction="downstream",
|
||||||
|
max_hops=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(column_lineage_results) == 3
|
||||||
|
column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops)
|
||||||
|
validate_lineage_results(
|
||||||
|
column_lineage_results[0],
|
||||||
|
hops=1,
|
||||||
|
urn=str(test_datasets["downstream1"].urn),
|
||||||
|
paths_len=2,
|
||||||
|
)
|
||||||
|
validate_lineage_results(
|
||||||
|
column_lineage_results[1],
|
||||||
|
hops=2,
|
||||||
|
urn=str(test_datasets["downstream2"].urn),
|
||||||
|
paths_len=3,
|
||||||
|
)
|
||||||
|
validate_lineage_results(
|
||||||
|
column_lineage_results[2],
|
||||||
|
hops=3,
|
||||||
|
urn=str(test_datasets["downstream3"].urn),
|
||||||
|
paths_len=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filtered_column_level_lineage(
|
||||||
|
test_client: DataHubClient, test_datasets: Dict[str, Dataset]
|
||||||
|
):
|
||||||
|
filtered_column_lineage_results = test_client.lineage.get_lineage(
|
||||||
|
source_urn=str(test_datasets["upstream"].urn),
|
||||||
|
source_column="id",
|
||||||
|
direction="downstream",
|
||||||
|
max_hops=3,
|
||||||
|
filter=F.and_(F.platform("mysql"), F.entity_type("dataset")),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(filtered_column_lineage_results) == 1
|
||||||
|
validate_lineage_results(
|
||||||
|
filtered_column_lineage_results[0],
|
||||||
|
hops=3,
|
||||||
|
platform="mysql",
|
||||||
|
urn=str(test_datasets["downstream3"].urn),
|
||||||
|
paths_len=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_column_level_lineage_from_schema_field(
|
||||||
|
test_client: DataHubClient, test_datasets: Dict[str, Dataset]
|
||||||
|
):
|
||||||
|
source_schema_field = SchemaFieldUrn(test_datasets["upstream"].urn, "id")
|
||||||
|
column_lineage_results = test_client.lineage.get_lineage(
|
||||||
|
source_urn=str(source_schema_field), direction="downstream", max_hops=3
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(column_lineage_results) == 3
|
||||||
|
column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops)
|
||||||
|
validate_lineage_results(
|
||||||
|
column_lineage_results[0],
|
||||||
|
hops=1,
|
||||||
|
urn=str(test_datasets["downstream1"].urn),
|
||||||
|
paths_len=2,
|
||||||
|
)
|
||||||
|
validate_lineage_results(
|
||||||
|
column_lineage_results[1],
|
||||||
|
hops=2,
|
||||||
|
urn=str(test_datasets["downstream2"].urn),
|
||||||
|
paths_len=3,
|
||||||
|
)
|
||||||
|
validate_lineage_results(
|
||||||
|
column_lineage_results[2],
|
||||||
|
hops=3,
|
||||||
|
urn=str(test_datasets["downstream3"].urn),
|
||||||
|
paths_len=4,
|
||||||
|
)
|
||||||
Loading…
x
Reference in New Issue
Block a user