From e169b4ac05d1ebf3c5940ea3512ee14bc3e505f5 Mon Sep 17 00:00:00 2001
From: Hyejin Yoon <0327jane@gmail.com>
Date: Fri, 6 Jun 2025 12:34:52 +0900
Subject: [PATCH] feat(sdk): add get_lineage (#13654)

---
 .../examples/library/get_column_lineage.py    |  24 ++
 .../examples/library/get_lineage_basic.py     |  17 +
 .../src/datahub/sdk/lineage_client.py         | 271 ++++++++++++++-
 .../tests/unit/sdk_v2/test_lineage_client.py  | 316 ++++++++++++++++--
 smoke-test/tests/lineage/test_lineage_sdk.py  | 222 ++++++++++++
 5 files changed, 812 insertions(+), 38 deletions(-)
 create mode 100644 metadata-ingestion/examples/library/get_column_lineage.py
 create mode 100644 metadata-ingestion/examples/library/get_lineage_basic.py
 create mode 100644 smoke-test/tests/lineage/test_lineage_sdk.py

diff --git a/metadata-ingestion/examples/library/get_column_lineage.py b/metadata-ingestion/examples/library/get_column_lineage.py
new file mode 100644
index 0000000000..7ba96321c5
--- /dev/null
+++ b/metadata-ingestion/examples/library/get_column_lineage.py
@@ -0,0 +1,24 @@
+from datahub.metadata.urns import DatasetUrn
+from datahub.sdk.main_client import DataHubClient
+from datahub.sdk.search_filters import FilterDsl as F
+
+client = DataHubClient.from_env()
+
+dataset_urn = DatasetUrn(platform="snowflake", name="downstream_table")
+
+# Get column-level lineage for a dataset.
+# Pass source_urn together with source_column to get lineage for a specific column.
+# Alternatively, pass a schema field URN as source_urn, e.g.
+# source_urn="urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table),id)"
+downstream_column_lineage = client.lineage.get_lineage(
+    source_urn=dataset_urn,
+    source_column="id",
+    direction="downstream",
+    max_hops=1,
+    filter=F.and_(
+        F.platform("snowflake"),
+        F.entity_type("dataset"),
+    ),
+)
+
+print(downstream_column_lineage)
diff --git a/metadata-ingestion/examples/library/get_lineage_basic.py b/metadata-ingestion/examples/library/get_lineage_basic.py
new file mode 100644
index 0000000000..5884f8ff3f
--- /dev/null
+++ b/metadata-ingestion/examples/library/get_lineage_basic.py
@@ -0,0 +1,17 @@
+from datahub.metadata.urns import DatasetUrn
+from datahub.sdk.main_client import DataHubClient
+from datahub.sdk.search_filters import FilterDsl as F
+
+client = DataHubClient.from_env()
+
+downstream_lineage = client.lineage.get_lineage(
+    source_urn=DatasetUrn(platform="snowflake", name="downstream_table"),
+    direction="downstream",
+    max_hops=2,
+    filter=F.and_(
+        F.platform("airflow"),
+        F.entity_type("dataJob"),
+    ),
+)
+
+print(downstream_lineage)
diff --git a/metadata-ingestion/src/datahub/sdk/lineage_client.py b/metadata-ingestion/src/datahub/sdk/lineage_client.py
index 1df6c7a6ab..ee27f48cb4 100644
--- a/metadata-ingestion/src/datahub/sdk/lineage_client.py
+++ b/metadata-ingestion/src/datahub/sdk/lineage_client.py
@@ -2,9 +2,12 @@ from __future__ import annotations
 
 import difflib
 import logging
+from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
+    Dict,
     List,
     Literal,
     Optional,
@@ -18,12 +21,7 @@ from typing_extensions import assert_never, deprecated
 import datahub.metadata.schema_classes as models
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.errors import SdkUsageError
-from datahub.metadata.urns import (
-    DataJobUrn,
-    DatasetUrn,
-    QueryUrn,
-    Urn,
-)
+from datahub.metadata.urns import DataJobUrn, DatasetUrn, QueryUrn, SchemaFieldUrn, Urn
 from datahub.sdk._shared import (
     ChartUrnOrStr,
     DashboardUrnOrStr,
@@ -32,6 
+30,8 @@ from datahub.sdk._shared import (
 )
 from datahub.sdk._utils import DEFAULT_ACTOR_URN
 from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
+from datahub.sdk.search_client import compile_filters
+from datahub.sdk.search_filters import Filter, FilterDsl
 from datahub.specific.chart import ChartPatchBuilder
 from datahub.specific.dashboard import DashboardPatchBuilder
 from datahub.specific.datajob import DataJobPatchBuilder
@@ -53,9 +53,29 @@ _empty_audit_stamp = models.AuditStampClass(
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class LineagePath:
+    urn: str
+    entity_name: str
+    column_name: Optional[str] = None
+
+
+@dataclass
+class LineageResult:
+    urn: str
+    type: str
+    hops: int
+    direction: Literal["upstream", "downstream"]
+    platform: Optional[str] = None
+    name: Optional[str] = None
+    description: Optional[str] = None
+    paths: Optional[List[LineagePath]] = None
+
+
 class LineageClient:
     def __init__(self, client: DataHubClient):
         self._client = client
+        self._graph = client._graph
 
     def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
         schema_metadata = self._client._graph.get_aspect(
@@ -700,3 +720,242 @@ class LineageClient:
 
         # Apply the changes to the entity
         self._client.entities.update(patch_builder)
+
+    def get_lineage(
+        self,
+        *,
+        source_urn: Union[str, Urn],
+        source_column: Optional[str] = None,
+        direction: Literal["upstream", "downstream"] = "upstream",
+        max_hops: int = 1,
+        filter: Optional[Filter] = None,
+        count: int = 500,
+    ) -> List[LineageResult]:
+        """
+        Retrieve lineage entities connected to a source entity.
+
+        Args:
+            source_urn: Source URN for the lineage search
+            source_column: Source column for the lineage search
+            direction: Direction of lineage traversal
+            max_hops: Maximum number of hops to traverse
+            filter: Filters to apply to the lineage search
+            count: Maximum number of results to return
+
+        Returns:
+            List of lineage results
+
+        Raises:
+            SdkUsageError: If the filter values are invalid.
+        """
+        # Validate and convert the input URN
+        source_urn = Urn.from_string(source_urn)
+        # Build the GraphQL query variables
+        variables = self._process_input_variables(
+            source_urn, source_column, filter, direction, max_hops, count
+        )
+
+        return self._execute_lineage_query(variables, direction)
+
+    def _process_input_variables(
+        self,
+        source_urn: Urn,
+        source_column: Optional[str] = None,
+        filters: Optional[Filter] = None,
+        direction: Literal["upstream", "downstream"] = "upstream",
+        max_hops: int = 1,
+        count: int = 500,
+    ) -> Dict[str, Any]:
+        """
+        Process filters and prepare GraphQL query variables for lineage search.
+
+        Args:
+            source_urn: Source URN for the lineage search
+            source_column: Source column for the lineage search
+            filters: Optional filters to apply
+            direction: Direction of lineage traversal
+            max_hops: Maximum number of hops to traverse
+            count: Maximum number of results to return
+
+        Returns:
+            Dictionary of GraphQL query variables
+
+        Raises:
+            SdkUsageError: If the filter values are invalid.
+        """
+
+        # Warn when max_hops is greater than 2, since the search then expands
+        # to the full lineage graph.
+        if max_hops > 2:
+            logger.warning(
+                "If `max_hops` is more than 2, the search will try to find the full lineage graph. "
+                "By default, only 500 results are returned; change `count` to return more or fewer results."
+            )
+
+        # Determine hop values
+        max_hop_values = (
+            [str(hop) for hop in range(1, max_hops + 1)]
+            if max_hops <= 2
+            else ["1", "2", "3+"]
+        )
+
+        degree_filter = FilterDsl.custom_filter(
+            field="degree",
+            condition="EQUAL",
+            values=max_hop_values,
+        )
+
+        filters_with_max_hops = (
+            FilterDsl.and_(degree_filter, filters)
+            if filters is not None
+            else degree_filter
+        )
+
+        types, compiled_filters = compile_filters(filters_with_max_hops)
+
+        # Prepare base variables
+        variables: Dict[str, Any] = {
+            "input": {
+                "urn": str(source_urn),
+                "direction": direction.upper(),
+                "count": count,
+                "types": types,
+                "orFilters": compiled_filters,
+            }
+        }
+
+        # If a column is provided, group results at the schema-field level and
+        # search from the corresponding schema field URN.
+        if isinstance(source_urn, SchemaFieldUrn) or source_column:
+            variables["input"]["searchFlags"] = {
+                "groupingSpec": {
+                    "groupingCriteria": {
+                        "baseEntityType": "SCHEMA_FIELD",
+                        "groupingEntityType": "SCHEMA_FIELD",
+                    }
+                }
+            }
+            if isinstance(source_urn, SchemaFieldUrn):
+                variables["input"]["urn"] = str(source_urn)
+            elif source_column:
+                variables["input"]["urn"] = str(SchemaFieldUrn(source_urn, source_column))
+
+        return variables
+
+    def _execute_lineage_query(
+        self,
+        variables: Dict[str, Any],
+        direction: Literal["upstream", "downstream"],
+    ) -> List[LineageResult]:
+        """Execute the GraphQL query and process results."""
+        graphql_query = """
+            query scrollAcrossLineage($input: ScrollAcrossLineageInput!) {
+                scrollAcrossLineage(input: $input) {
+                    nextScrollId
+                    searchResults {
+                        degree
+                        entity {
+                            urn
+                            type
+                            ... on Dataset {
+                                name
+                                platform {
+                                    name
+                                }
+                                properties {
+                                    description
+                                }
+                            }
+                            ... on DataJob {
+                                jobId
+                                dataPlatformInstance {
+                                    platform {
+                                        name
+                                    }
+                                }
+                                properties {
+                                    name
+                                    description
+                                }
+                            }
+                        }
+                        paths {
+                            path {
+                                urn
+                                type
+                            }
+                        }
+                    }
+                }
+            }
+        """
+
+        results: List[LineageResult] = []
+
+        first_iter = True
+        scroll_id: Optional[str] = None
+
+        while first_iter or scroll_id:
+            first_iter = False
+
+            # Update the scroll ID if applicable
+            if scroll_id:
+                variables["input"]["scrollId"] = scroll_id
+
+            # Execute the GraphQL query
+            response = self._graph.execute_graphql(graphql_query, variables=variables)
+            data = response["scrollAcrossLineage"]
+            scroll_id = data.get("nextScrollId")
+
+            # Process search results
+            for entry in data["searchResults"]:
+                entity = entry["entity"]
+
+                result = self._create_lineage_result(entity, entry, direction)
+                results.append(result)
+
+        return results
+
+    def _create_lineage_result(
+        self,
+        entity: Dict[str, Any],
+        entry: Dict[str, Any],
+        direction: Literal["upstream", "downstream"],
+    ) -> LineageResult:
+        """Create a LineageResult from entity and entry data."""
+        # Datasets carry the platform at the entity level; data jobs carry it
+        # under dataPlatformInstance. Either value may be missing or null.
+        platform = (entity.get("platform") or {}).get("name") or (
+            (entity.get("dataPlatformInstance") or {}).get("platform") or {}
+        ).get("name")
+
+        result = LineageResult(
+            urn=entity["urn"],
+            type=entity["type"],
+            hops=entry["degree"],
+            direction=direction,
+            platform=platform,
+        )
+
+        properties = entity.get("properties", {})
+        if properties:
+            result.name = properties.get("name", "")
+            result.description = properties.get("description", "")
+        if not result.name:
+            # Datasets return `name` at the entity level rather than in `properties`.
+            result.name = entity.get("name")
+
+        result.paths = []
+        if "paths" in entry:
+            # Process each path in the lineage graph
+            for path in entry["paths"]:
+                for path_entry in path["path"]:
+                    # Only include schema fields in the path (exclude other types like Query)
+                    if path_entry["type"] == "SCHEMA_FIELD":
+                        schema_field_urn = SchemaFieldUrn.from_string(path_entry["urn"])
+                        
result.paths.append( + LineagePath( + urn=path_entry["urn"], + entity_name=DatasetUrn.from_string( + schema_field_urn.parent + ).name, + column_name=schema_field_urn.field_path, + ) + ) + + return result diff --git a/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py b/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py index 28c64eea7e..fbb00b7d60 100644 --- a/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py +++ b/metadata-ingestion/tests/unit/sdk_v2/test_lineage_client.py @@ -1,11 +1,12 @@ import pathlib -from typing import Dict, List, Sequence, Set, cast +from typing import Dict, List, Optional, Sequence, Set, cast from unittest.mock import MagicMock, Mock, patch import pytest from datahub.errors import SdkUsageError from datahub.sdk.main_client import DataHubClient +from datahub.sdk.search_filters import FilterDsl as F from datahub.sql_parsing.sql_parsing_common import QueryType from datahub.sql_parsing.sqlglot_lineage import ( ColumnLineageInfo, @@ -127,38 +128,8 @@ def test_get_strict_column_lineage( cast(Set[str], upstream_fields), cast(Set[str], downstream_fields), ) - assert result == expected, f"Test failed: {result} != {expected}" """Test the strict column lineage matching algorithm.""" - test_cases = [ - # Exact matches - { - "upstream_fields": {"id", "name", "email"}, - "downstream_fields": {"id", "name", "phone"}, - "expected": {"id": ["id"], "name": ["name"]}, - }, - # No matches - { - "upstream_fields": {"col1", "col2", "col3"}, - "downstream_fields": {"col4", "col5", "col6"}, - "expected": {}, - }, - # Case mismatch (should match) - { - "upstream_fields": {"ID", "Name", "Email"}, - "downstream_fields": {"id", "name", "email"}, - "expected": {"id": ["ID"], "name": ["Name"], "email": ["Email"]}, - }, - ] - - # Run test cases - for test_case in test_cases: - result = client.lineage._get_strict_column_lineage( - cast(Set[str], test_case["upstream_fields"]), - cast(Set[str], test_case["downstream_fields"]), - ) - assert result == test_case["expected"], ( - f"Test failed: {result} != {test_case['expected']}" - ) + assert result == expected, f"Test failed: {result} != {expected}" def test_infer_lineage_from_sql(client: DataHubClient) -> None: @@ -444,3 +415,284 @@ def test_add_lineage_invalid_parameter_combinations(client: DataHubClient) -> No downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,target_table,PROD)", column_lineage={"target_col": ["source_col"]}, ) + + +def test_get_lineage_basic(client: DataHubClient) -> None: + """Test basic lineage retrieval with default parameters.""" + source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)" + + # Mock GraphQL response + mock_response = { + "scrollAcrossLineage": { + "nextScrollId": None, + "searchResults": [ + { + "entity": { + "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)", + "type": "DATASET", + "platform": {"name": "snowflake"}, + "properties": { + "name": "upstream_table", + "description": "Upstream source table", + }, + }, + "degree": 1, + } + ], + } + } + + # Patch the GraphQL execution method + with patch.object(client._graph, "execute_graphql", return_value=mock_response): + results = client.lineage.get_lineage(source_urn=source_urn) + + # Validate results + assert len(results) == 1 + assert ( + results[0].urn + == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)" + ) + assert results[0].type == "DATASET" + assert results[0].name == "upstream_table" + assert results[0].platform == "snowflake" + assert 
results[0].direction == "upstream"
+    assert results[0].hops == 1
+
+
+def test_get_lineage_with_entity_type_filters(client: DataHubClient) -> None:
+    """Test lineage retrieval with entity type and platform filters."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table",
+                            "description": "Upstream source table",
+                        },
+                    },
+                    "degree": 1,
+                },
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(
+            source_urn=source_urn,
+            filter=F.entity_type("dataset"),
+        )
+
+    # Validate results
+    assert len(results) == 1
+    assert {r.type for r in results} == {"DATASET"}
+    assert {r.platform for r in results} == {"snowflake"}
+
+
+def test_get_lineage_downstream(client: DataHubClient) -> None:
+    """Test downstream lineage retrieval."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response for downstream lineage; platform sits at the
+    # entity level, matching the shape returned by the GraphQL query.
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "downstream_table",
+                            "description": "Downstream target table",
+                        },
+                    },
+                    "degree": 1,
+                }
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(
+            source_urn=source_urn,
+            direction="downstream",
+        )
+
+    # Validate results
+    assert len(results) == 1
+    assert (
+        results[0].urn
+        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,downstream_table,PROD)"
+    )
+    assert results[0].direction == "downstream"
+
+
+def test_get_lineage_multiple_hops(
+    client: DataHubClient, caplog: pytest.LogCaptureFixture
+) -> None:
+    """Test lineage retrieval with multiple hops."""
+    source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)"
+
+    # Mock GraphQL response with multiple hops
+    mock_response = {
+        "scrollAcrossLineage": {
+            "nextScrollId": None,
+            "searchResults": [
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table1,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table1",
+                            "description": "First upstream table",
+                        },
+                    },
+                    "degree": 1,
+                },
+                {
+                    "entity": {
+                        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table2,PROD)",
+                        "type": "DATASET",
+                        "platform": {"name": "snowflake"},
+                        "properties": {
+                            "name": "upstream_table2",
+                            "description": "Second upstream table",
+                        },
+                    },
+                    "degree": 2,
+                },
+            ],
+        }
+    }
+
+    # Patch the GraphQL execution method
+    with patch.object(client._graph, "execute_graphql", return_value=mock_response):
+        results = client.lineage.get_lineage(source_urn=source_urn, max_hops=2)
+
+    # Check that a warning is logged when max_hops > 2
+    with patch.object(
+        client._graph, "execute_graphql", return_value=mock_response
+    ), caplog.at_level("WARNING"):
+        
client.lineage.get_lineage(source_urn=source_urn, max_hops=3) + assert any( + "the search will try to find the full lineage graph" in msg + for msg in caplog.messages + ) + + # Validate results + assert len(results) == 2 + assert results[0].hops == 1 + assert results[0].type == "DATASET" + assert results[1].hops == 2 + assert results[1].type == "DATASET" + + +def test_get_lineage_no_results(client: DataHubClient) -> None: + """Test lineage retrieval with no results.""" + source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)" + + # Mock GraphQL response with no results + mock_response: Dict[str, Dict[str, Optional[List]]] = { + "scrollAcrossLineage": {"nextScrollId": None, "searchResults": []} + } + + # Patch the GraphQL execution method + with patch.object(client._graph, "execute_graphql", return_value=mock_response): + results = client.lineage.get_lineage(source_urn=source_urn) + + # Validate results + assert len(results) == 0 + + +def test_get_lineage_column_lineage_with_source_column(client: DataHubClient) -> None: + """Test lineage retrieval with column lineage.""" + source_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD)" + source_column = "source_column" + + # Mock GraphQL response with column lineage + mock_response = { + "scrollAcrossLineage": { + "nextScrollId": None, + "searchResults": [ + { + "entity": { + "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)", + "type": "DATASET", + "platform": {"name": "snowflake"}, + "properties": { + "name": "upstream_table", + "description": "Upstream source table", + }, + }, + "degree": 1, + }, + ], + } + } + + # Patch the GraphQL execution method + with patch.object(client._graph, "execute_graphql", return_value=mock_response): + results = client.lineage.get_lineage( + source_urn=source_urn, + source_column=source_column, + ) + + # Validate results + assert len(results) == 1 + assert ( + results[0].urn + == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)" + ) + + +def test_get_lineage_column_lineage_with_schema_field_urn( + client: DataHubClient, +) -> None: + """Test lineage retrieval with column lineage.""" + source_urn = "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,source_table,PROD),source_column)" + + # Mock GraphQL response with column lineage + mock_response = { + "scrollAcrossLineage": { + "nextScrollId": None, + "searchResults": [ + { + "entity": { + "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)", + "type": "DATASET", + "platform": {"name": "snowflake"}, + "properties": { + "name": "upstream_table", + "description": "Upstream source table", + }, + }, + "degree": 1, + }, + ], + } + } + + # Patch the GraphQL execution method + with patch.object(client._graph, "execute_graphql", return_value=mock_response): + results = client.lineage.get_lineage( + source_urn=source_urn, + ) + + # Validate results + assert len(results) == 1 + assert ( + results[0].urn + == "urn:li:dataset:(urn:li:dataPlatform:snowflake,upstream_table,PROD)" + ) diff --git a/smoke-test/tests/lineage/test_lineage_sdk.py b/smoke-test/tests/lineage/test_lineage_sdk.py new file mode 100644 index 0000000000..cca84d4cc7 --- /dev/null +++ b/smoke-test/tests/lineage/test_lineage_sdk.py @@ -0,0 +1,222 @@ +from typing import Dict, Generator + +import pytest + +from datahub.ingestion.graph.client import DataHubGraph +from datahub.metadata.urns import SchemaFieldUrn +from datahub.sdk.dataset import Dataset +from datahub.sdk.lineage_client 
import LineageResult +from datahub.sdk.main_client import DataHubClient +from datahub.sdk.search_filters import FilterDsl as F +from tests.utils import wait_for_writes_to_sync + + +@pytest.fixture(scope="module") +def test_client(graph_client: DataHubGraph) -> DataHubClient: + return DataHubClient(graph=graph_client) + + +@pytest.fixture(scope="module") +def test_datasets( + test_client: DataHubClient, +) -> Generator[Dict[str, Dataset], None, None]: + datasets = { + "upstream": Dataset( + platform="snowflake", + name="test_lineage_upstream_001", + schema=[("name", "string"), ("id", "int")], + ), + "downstream1": Dataset( + platform="snowflake", + name="test_lineage_downstream_001", + schema=[("name", "string"), ("id", "int")], + ), + "downstream2": Dataset( + platform="snowflake", + name="test_lineage_downstream_002", + schema=[("name", "string"), ("id", "int")], + ), + "downstream3": Dataset( + platform="mysql", + name="test_lineage_downstream_003", + schema=[("name", "string"), ("id", "int")], + ), + } + + for entity in datasets.values(): + test_client._graph.delete_entity(str(entity.urn), hard=True) + for entity in datasets.values(): + test_client.entities.upsert(entity) + + # Add lineage + test_client.lineage.add_lineage( + upstream=str(datasets["upstream"].urn), + downstream=str(datasets["downstream1"].urn), + column_lineage=True, + ) + test_client.lineage.add_lineage( + upstream=str(datasets["downstream1"].urn), + downstream=str(datasets["downstream2"].urn), + column_lineage=True, + ) + test_client.lineage.add_lineage( + upstream=str(datasets["downstream2"].urn), + downstream=str(datasets["downstream3"].urn), + column_lineage=True, + ) + + wait_for_writes_to_sync() + + yield datasets + + # Cleanup + for entity in datasets.values(): + try: + test_client._graph.delete_entity(str(entity.urn), hard=True) + except Exception as e: + raise Exception(f"Could not delete entity {entity.urn}: {e}") + + +def validate_lineage_results( + lineage_result: LineageResult, + hops=None, + direction=None, + platform=None, + urn=None, + paths_len=None, +): + if hops is not None: + assert lineage_result.hops == hops + if direction is not None: + assert lineage_result.direction == direction + if platform is not None: + assert lineage_result.platform == platform + if urn is not None: + assert lineage_result.urn == urn + if paths_len is not None and lineage_result.paths is not None: + assert len(lineage_result.paths) == paths_len + + +def test_table_level_lineage( + test_client: DataHubClient, test_datasets: Dict[str, Dataset] +): + table_lineage_results = test_client.lineage.get_lineage( + source_urn=str(test_datasets["upstream"].urn), + direction="downstream", + max_hops=3, + ) + + assert len(table_lineage_results) == 3 + urns = {r.urn for r in table_lineage_results} + expected = { + str(test_datasets["downstream1"].urn), + str(test_datasets["downstream2"].urn), + str(test_datasets["downstream3"].urn), + } + assert urns == expected + + table_lineage_results = sorted(table_lineage_results, key=lambda x: x.hops) + validate_lineage_results( + table_lineage_results[0], + hops=1, + platform="snowflake", + urn=str(test_datasets["downstream1"].urn), + paths_len=0, + ) + validate_lineage_results( + table_lineage_results[1], + hops=2, + platform="snowflake", + urn=str(test_datasets["downstream2"].urn), + paths_len=0, + ) + validate_lineage_results( + table_lineage_results[2], + hops=3, + platform="mysql", + urn=str(test_datasets["downstream3"].urn), + paths_len=0, + ) + + +def test_column_level_lineage( + 
test_client: DataHubClient, test_datasets: Dict[str, Dataset] +): + column_lineage_results = test_client.lineage.get_lineage( + source_urn=str(test_datasets["upstream"].urn), + source_column="id", + direction="downstream", + max_hops=3, + ) + + assert len(column_lineage_results) == 3 + column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops) + validate_lineage_results( + column_lineage_results[0], + hops=1, + urn=str(test_datasets["downstream1"].urn), + paths_len=2, + ) + validate_lineage_results( + column_lineage_results[1], + hops=2, + urn=str(test_datasets["downstream2"].urn), + paths_len=3, + ) + validate_lineage_results( + column_lineage_results[2], + hops=3, + urn=str(test_datasets["downstream3"].urn), + paths_len=4, + ) + + +def test_filtered_column_level_lineage( + test_client: DataHubClient, test_datasets: Dict[str, Dataset] +): + filtered_column_lineage_results = test_client.lineage.get_lineage( + source_urn=str(test_datasets["upstream"].urn), + source_column="id", + direction="downstream", + max_hops=3, + filter=F.and_(F.platform("mysql"), F.entity_type("dataset")), + ) + + assert len(filtered_column_lineage_results) == 1 + validate_lineage_results( + filtered_column_lineage_results[0], + hops=3, + platform="mysql", + urn=str(test_datasets["downstream3"].urn), + paths_len=4, + ) + + +def test_column_level_lineage_from_schema_field( + test_client: DataHubClient, test_datasets: Dict[str, Dataset] +): + source_schema_field = SchemaFieldUrn(test_datasets["upstream"].urn, "id") + column_lineage_results = test_client.lineage.get_lineage( + source_urn=str(source_schema_field), direction="downstream", max_hops=3 + ) + + assert len(column_lineage_results) == 3 + column_lineage_results = sorted(column_lineage_results, key=lambda x: x.hops) + validate_lineage_results( + column_lineage_results[0], + hops=1, + urn=str(test_datasets["downstream1"].urn), + paths_len=2, + ) + validate_lineage_results( + column_lineage_results[1], + hops=2, + urn=str(test_datasets["downstream2"].urn), + paths_len=3, + ) + validate_lineage_results( + column_lineage_results[2], + hops=3, + urn=str(test_datasets["downstream3"].urn), + paths_len=4, + )
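
For reviewers: a minimal consumption sketch of the `LineageResult` / `LineagePath` dataclasses this patch introduces, using only the API added above. The dataset name and column are illustrative placeholders, not real entities:

```python
from datahub.metadata.urns import DatasetUrn
from datahub.sdk.main_client import DataHubClient

client = DataHubClient.from_env()

# "downstream_table" and "id" are placeholder names for this sketch.
results = client.lineage.get_lineage(
    source_urn=DatasetUrn(platform="snowflake", name="downstream_table"),
    source_column="id",
    direction="upstream",
    max_hops=2,
)

for result in results:
    # Each result is a LineageResult dataclass, not a raw GraphQL dict.
    print(f"{result.hops} hop(s) {result.direction}: {result.urn} ({result.platform})")
    # `paths` is populated for column-level lineage; each entry is a LineagePath.
    for path in result.paths or []:
        print(f"  via {path.entity_name}.{path.column_name}")
```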