# datahub/smoke-test/tests/lineage/test_lineage.py
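"""Smoke tests for DataHub lineage: build synthetic multi-hop lineage scenarios
(dataset -> query -> dataset and dataset -> job -> dataset) and verify that
searchAcrossLineage returns the expected impacted entities and paths."""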


import logging
import time
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import networkx as nx
import pydantic
import pytest
from pydantic import BaseModel, ConfigDict
import datahub.emitter.mce_builder as builder
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DataHubGraph
from datahub.metadata.schema_classes import (
AuditStampClass,
ChangeAuditStampsClass,
ChartInfoClass,
DataFlowInfoClass,
DataJobInfoClass,
DataJobInputOutputClass,
DatasetLineageTypeClass,
DatasetPropertiesClass,
EdgeClass,
FineGrainedLineageClass as FineGrainedLineage,
FineGrainedLineageDownstreamTypeClass as FineGrainedLineageDownstreamType,
FineGrainedLineageUpstreamTypeClass as FineGrainedLineageUpstreamType,
OtherSchemaClass,
QueryLanguageClass,
QueryPropertiesClass,
QuerySourceClass,
QueryStatementClass,
SchemaFieldClass,
SchemaFieldDataTypeClass,
SchemaMetadataClass,
StringTypeClass,
UpstreamClass,
UpstreamLineageClass,
)
from datahub.utilities.urns.dataset_urn import DatasetUrn
from datahub.utilities.urns.urn import Urn
from tests.utils import ingest_file_via_rest, wait_for_writes_to_sync
logger = logging.getLogger(__name__)
class DeleteAgent:
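    """Minimal interface for deleting an entity by URN; subclasses decide how."""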
def delete_entity(self, urn: str) -> None:
pass
class DataHubGraphDeleteAgent(DeleteAgent):
def __init__(self, graph: DataHubGraph):
self.graph = graph
def delete_entity(self, urn: str) -> None:
self.graph.delete_entity(urn, hard=True)
class DataHubConsoleDeleteAgent(DeleteAgent):
def delete_entity(self, urn: str) -> None:
print(f"Would delete {urn}")
class DataHubConsoleEmitter:
def emit_mcp(self, mcp: MetadataChangeProposalWrapper) -> None:
print(mcp)
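# Sentinel hop count meaning "no hop limit" (traverse the full lineage graph).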
INFINITE_HOPS: int = -1
def ingest_tableau_cll_via_rest(auth_session) -> None:
ingest_file_via_rest(
auth_session,
"tests/lineage/tableau_cll_mcps.json",
)
def search_across_lineage(
graph: DataHubGraph,
main_entity: str,
hops: int = INFINITE_HOPS,
direction: str = "UPSTREAM",
convert_schema_fields_to_datasets: bool = True,
):
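    """Run a searchAcrossLineage GraphQL query for main_entity and return the raw response.
    Optionally groups schema-field results under their parent dataset and, when hops is
    finite, restricts results to exactly that many hops via a "degree" filter.
    """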
def _explain_sal_result(result: dict) -> str:
explain = ""
entities = [
x["entity"]["urn"] for x in result["searchAcrossLineage"]["searchResults"]
]
number_of_results = len(entities)
explain += f"Number of results: {number_of_results}\n"
explain += "Entities: "
try:
for e in entities:
explain += f"\t{e.replace('urn:li:', '')}\n"
for entity in entities:
paths = [
x["paths"][0]["path"]
for x in result["searchAcrossLineage"]["searchResults"]
if x["entity"]["urn"] == entity
]
explain += f"Paths for entity {entity}: "
for path in paths:
explain += (
"\t"
+ " -> ".join(
[
x["urn"]
.replace("urn:li:schemaField", "field")
.replace("urn:li:dataset", "dataset")
.replace("urn:li:dataPlatform", "platform")
for x in path
]
)
+ "\n"
)
        except Exception:
            # Building the explanation is best-effort debug output; never fail because of it.
            pass
return explain
variable: dict[str, Any] = {
"input": (
{
"urn": main_entity,
"query": "*",
"direction": direction,
"searchFlags": {
"groupingSpec": {
"groupingCriteria": [
{
"baseEntityType": "SCHEMA_FIELD",
"groupingEntityType": "DATASET",
},
]
},
"skipCache": True,
},
}
if convert_schema_fields_to_datasets
else {
"urn": main_entity,
"query": "*",
"direction": direction,
"searchFlags": {
"skipCache": True,
},
}
)
}
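    # A finite hop count is expressed as an exact-match filter on the "degree" field.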
if hops != INFINITE_HOPS:
variable["input"].update(
{
"orFilters": [
{
"and": [
{
"field": "degree",
"condition": "EQUAL",
"values": ["{}".format(hops)],
"negated": False,
}
]
}
]
}
)
result = graph.execute_graphql(
"""
query($input: SearchAcrossLineageInput!) {
searchAcrossLineage(input: $input)
{
searchResults {
entity {
urn
}
paths {
path {
urn
}
}
}
}
}
""",
variables=variable,
)
print(f"Query -> Entity {main_entity} with hops {hops} and direction {direction}")
print(result)
print(_explain_sal_result(result))
return result
class Direction(Enum):
UPSTREAM = "UPSTREAM"
DOWNSTREAM = "DOWNSTREAM"
def opposite(self):
if self == Direction.UPSTREAM:
return Direction.DOWNSTREAM
else:
return Direction.UPSTREAM
class Path(BaseModel):
path: List[str]
def add_node(self, node: str) -> None:
self.path.append(node)
    def __hash__(self) -> int:
        return hash(".".join(self.path))
class LineageExpectation(BaseModel):
direction: Direction
main_entity: str
hops: int
impacted_entities: Dict[str, List[Path]]
class ImpactQuery(BaseModel):
main_entity: str
hops: int
direction: Direction
upconvert_schema_fields_to_datasets: bool
def __hash__(self) -> int:
raw_string = (
f"{self.main_entity}{self.hops}{self.direction}"
+ f"{self.upconvert_schema_fields_to_datasets}"
)
        return hash(raw_string)
class ScenarioExpectation:
"""
This class stores the expectations for the lineage of a scenario. It is used
to store the pre-materialized expectations for all datasets and schema
fields across all hops and directions possible. This makes it easy to check
that the results of a lineage query match the expectations.
"""
def __init__(self):
self._graph = nx.DiGraph()
def __simplify(self, urn_or_list: Union[str, List[str]]) -> str:
if isinstance(urn_or_list, list):
return ",".join([self.__simplify(x) for x in urn_or_list])
else:
return (
urn_or_list.replace("urn:li:schemaField", "F")
.replace("urn:li:dataset", "D")
.replace("urn:li:dataPlatform", "P")
.replace("urn:li:query", "Q")
)
def extend_impacted_entities(
self,
direction: Direction,
parent_entity: str,
child_entity: str,
path_extension: Optional[List[str]] = None,
) -> None:
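        # Edges are always recorded parent -> child (downstream); upstream queries are
        # served later by walking the reversed graph in get_expectation_for_query.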
via_node = path_extension[0] if path_extension else None
if via_node:
self._graph.add_edge(parent_entity, child_entity, via=via_node)
else:
self._graph.add_edge(parent_entity, child_entity)
def generate_query_expectation_pairs(
self, max_hops: int
) -> Iterable[Tuple[ImpactQuery, LineageExpectation]]:
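        """Yield (query, expectation) pairs for every node and direction, for hop counts
        1 through max_hops - 1 plus the unbounded (INFINITE_HOPS) case."""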
upconvert_options = [
True
] # TODO: Add False once search-across-lineage supports returning schema fields
for main_entity in self._graph.nodes():
for direction in [Direction.UPSTREAM, Direction.DOWNSTREAM]:
for upconvert_schema_fields_to_datasets in upconvert_options:
possible_hops = [h for h in range(1, max_hops)] + [INFINITE_HOPS]
for hops in possible_hops:
query = ImpactQuery(
main_entity=main_entity,
hops=hops,
direction=direction,
upconvert_schema_fields_to_datasets=upconvert_schema_fields_to_datasets,
)
yield query, self.get_expectation_for_query(query)
def get_expectation_for_query(self, query: ImpactQuery) -> LineageExpectation:
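        """Compute the expected impacted entities and lineage paths for a query by
        walking shortest paths in the expectation graph and expanding "via" nodes."""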
graph_to_walk = (
self._graph
if query.direction == Direction.DOWNSTREAM
else self._graph.reverse()
)
entity_paths = nx.shortest_path(graph_to_walk, source=query.main_entity)
lineage_expectation = LineageExpectation(
direction=query.direction,
main_entity=query.main_entity,
hops=query.hops,
impacted_entities={},
)
for entity, paths in entity_paths.items():
if entity == query.main_entity:
continue
if query.hops != INFINITE_HOPS and len(paths) != (
query.hops + 1
): # +1 because the path includes the main entity
                print(f"Skipping {entity}: it is not exactly {query.hops} hops away")
continue
path_graph = nx.path_graph(paths)
expanded_path: List[str] = []
via_entity = None
for ea in path_graph.edges():
expanded_path.append(ea[0])
if "via" in graph_to_walk.edges[ea[0], ea[1]]:
via_entity = graph_to_walk.edges[ea[0], ea[1]]["via"]
expanded_path.append(via_entity)
if via_entity and not via_entity.startswith(
"urn:li:query"
): # Transient nodes like queries are not included as impacted entities
if via_entity not in lineage_expectation.impacted_entities:
lineage_expectation.impacted_entities[via_entity] = []
                        via_path = Path(path=list(expanded_path))
                        if via_path not in lineage_expectation.impacted_entities[via_entity]:
                            lineage_expectation.impacted_entities[via_entity].append(via_path)
expanded_path.append(paths[-1])
if entity not in lineage_expectation.impacted_entities:
lineage_expectation.impacted_entities[entity] = []
lineage_expectation.impacted_entities[entity].append(
Path(path=expanded_path)
)
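        # Optionally roll schema-field expectations up to their parent datasets, mirroring
        # the SCHEMA_FIELD -> DATASET grouping requested in search_across_lineage.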
if query.upconvert_schema_fields_to_datasets:
entries_to_add: Dict[str, List[Path]] = {}
entries_to_remove = []
for impacted_entity in lineage_expectation.impacted_entities:
if impacted_entity.startswith("urn:li:schemaField"):
impacted_dataset_entity = Urn.create_from_string(
impacted_entity
).entity_ids[0]
if impacted_dataset_entity in entries_to_add:
entries_to_add[impacted_dataset_entity].extend(
lineage_expectation.impacted_entities[impacted_entity]
)
else:
entries_to_add[impacted_dataset_entity] = (
lineage_expectation.impacted_entities[impacted_entity]
)
entries_to_remove.append(impacted_entity)
for impacted_entity in entries_to_remove:
del lineage_expectation.impacted_entities[impacted_entity]
lineage_expectation.impacted_entities.update(entries_to_add)
return lineage_expectation
class Scenario(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class LineageStyle(Enum):
DATASET_QUERY_DATASET = "DATASET_QUERY_DATASET"
DATASET_JOB_DATASET = "DATASET_JOB_DATASET"
lineage_style: LineageStyle
default_platform: str = "mysql"
default_transformation_platform: str = "airflow"
hop_platform_map: Dict[int, str] = {}
hop_transformation_map: Dict[int, str] = {}
num_hops: int = 1
default_datasets_at_each_hop: int = 2
default_dataset_fanin: int = 2 # Number of datasets that feed into a transformation
default_column_fanin: int = 2 # Number of columns that feed into a transformation
    default_dataset_fanout: int = 1  # Number of datasets that a transformation feeds into
default_column_fanout: int = 1 # Number of columns that a transformation feeds into
# num_upstream_datasets: int = 2
# num_downstream_datasets: int = 1
default_dataset_prefix: str = "librarydb."
hop_dataset_prefix_map: Dict[int, str] = {}
query_id: str = "guid-guid-guid"
query_string: str = "SELECT * FROM foo"
transformation_job: str = "job1"
transformation_flow: str = "flow1"
_generated_urns: Set[str] = set()
expectations: ScenarioExpectation = pydantic.Field(
default_factory=ScenarioExpectation
)
def get_column_name(self, column_index: int) -> str:
return f"column_{column_index}"
def set_upstream_dataset_prefix(self, dataset):
self.upstream_dataset_prefix = dataset
def set_downstream_dataset_prefix(self, dataset):
self.downstream_dataset_prefix = dataset
def set_transformation_query(self, query: str) -> None:
self.transformation_query = query
def set_transformation_job(self, job: str) -> None:
self.transformation_job = job
def set_transformation_flow(self, flow: str) -> None:
self.transformation_flow = flow
def get_transformation_job_urn(self, hop_index: int) -> str:
return builder.make_data_job_urn(
orchestrator=self.default_transformation_platform,
flow_id=f"layer_{hop_index}_{self.transformation_flow}",
job_id=self.transformation_job,
cluster="PROD",
)
def get_transformation_query_urn(self, hop_index: int = 0) -> str:
return f"urn:li:query:{self.query_id}_{hop_index}" # TODO - add hop index to query id
def get_transformation_flow_urn(self, hop_index: int) -> str:
return builder.make_data_flow_urn(
orchestrator=self.default_transformation_platform,
flow_id=f"layer_{hop_index}_{self.transformation_flow}",
cluster="PROD",
)
def get_upstream_dataset_urns(self, hop_index: int) -> List[str]:
return [
self.get_dataset_urn(hop_index=hop_index, index=i)
for i in range(self.default_dataset_fanin)
]
def get_dataset_urn(self, hop_index: int, index: int) -> str:
platform = self.hop_platform_map.get(hop_index, self.default_platform)
prefix = self.hop_dataset_prefix_map.get(
index, f"{self.default_dataset_prefix}layer_{hop_index}."
)
return builder.make_dataset_urn(platform, f"{prefix}{index}")
def get_column_urn(
self, hop_index: int, dataset_index: int, column_index: int = 0
) -> str:
return builder.make_schema_field_urn(
self.get_dataset_urn(hop_index, dataset_index),
self.get_column_name(column_index),
)
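    # Upstream columns live on the hop_index layer; downstream columns on hop_index + 1.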
def get_upstream_column_urn(
self, hop_index: int, dataset_index: int, column_index: int = 0
) -> str:
return builder.make_schema_field_urn(
self.get_dataset_urn(hop_index, dataset_index),
self.get_column_name(column_index),
)
def get_downstream_column_urn(
self, hop_index: int, dataset_index: int, column_index: int = 0
) -> str:
return builder.make_schema_field_urn(
self.get_dataset_urn(hop_index + 1, dataset_index),
self.get_column_name(column_index),
)
def get_downstream_dataset_urns(self, hop_index: int) -> List[str]:
return [
self.get_dataset_urn(hop_index + 1, i)
for i in range(self.default_dataset_fanout)
]
def get_lineage_mcps(self) -> Iterable[MetadataChangeProposalWrapper]:
for hop_index in range(0, self.num_hops):
yield from self.get_lineage_mcps_for_hop(hop_index)
def get_lineage_mcps_for_hop(
self, hop_index: int
) -> Iterable[MetadataChangeProposalWrapper]:
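        """Emit lineage aspects for one hop and record the matching expectations, using
        either the dataset -> job -> dataset or dataset -> query -> dataset style."""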
assert self.expectations is not None
if self.lineage_style == Scenario.LineageStyle.DATASET_JOB_DATASET:
fine_grained_lineage = FineGrainedLineage(
upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
upstreams=[
self.get_upstream_column_urn(hop_index, dataset_index, 0)
for dataset_index in range(self.default_dataset_fanin)
],
downstreamType=FineGrainedLineageDownstreamType.FIELD,
downstreams=[
self.get_downstream_column_urn(hop_index, dataset_index, 0)
for dataset_index in range(self.default_dataset_fanout)
],
)
datajob_io = DataJobInputOutputClass(
inputDatasets=self.get_upstream_dataset_urns(hop_index),
outputDatasets=self.get_downstream_dataset_urns(hop_index),
inputDatajobs=[], # not supporting job -> job lineage for now
fineGrainedLineages=[fine_grained_lineage],
)
yield MetadataChangeProposalWrapper(
entityUrn=self.get_transformation_job_urn(hop_index),
aspect=datajob_io,
)
# Add field level expectations
for upstream_field_urn in fine_grained_lineage.upstreams or []:
for downstream_field_urn in fine_grained_lineage.downstreams or []:
self.expectations.extend_impacted_entities(
Direction.DOWNSTREAM,
upstream_field_urn,
downstream_field_urn,
path_extension=[
self.get_transformation_job_urn(hop_index),
downstream_field_urn,
],
)
# Add table level expectations
for upstream_dataset_urn in datajob_io.inputDatasets:
# No path extension, because we don't use via nodes for dataset -> dataset edges
self.expectations.extend_impacted_entities(
Direction.DOWNSTREAM,
upstream_dataset_urn,
self.get_transformation_job_urn(hop_index),
)
for downstream_dataset_urn in datajob_io.outputDatasets:
self.expectations.extend_impacted_entities(
Direction.DOWNSTREAM,
self.get_transformation_job_urn(hop_index),
downstream_dataset_urn,
)
if self.lineage_style == Scenario.LineageStyle.DATASET_QUERY_DATASET:
# we emit upstream lineage from the downstream dataset
for downstream_dataset_index in range(self.default_dataset_fanout):
mcp_entity_urn = self.get_dataset_urn(
hop_index + 1, downstream_dataset_index
)
fine_grained_lineages = [
FineGrainedLineage(
upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
upstreams=[
self.get_upstream_column_urn(
hop_index, d_i, upstream_col_index
)
for d_i in range(self.default_dataset_fanin)
],
downstreamType=FineGrainedLineageDownstreamType.FIELD,
downstreams=[
self.get_downstream_column_urn(
hop_index,
downstream_dataset_index,
downstream_col_index,
)
for downstream_col_index in range(
self.default_column_fanout
)
],
query=self.get_transformation_query_urn(hop_index),
)
for upstream_col_index in range(self.default_column_fanin)
]
upstream_lineage = UpstreamLineageClass(
upstreams=[
UpstreamClass(
dataset=self.get_dataset_urn(hop_index, i),
type=DatasetLineageTypeClass.TRANSFORMED,
query=self.get_transformation_query_urn(hop_index),
)
for i in range(self.default_dataset_fanin)
],
fineGrainedLineages=fine_grained_lineages,
)
for fine_grained_lineage in fine_grained_lineages:
# Add field level expectations
for upstream_field_urn in fine_grained_lineage.upstreams or []:
for downstream_field_urn in (
fine_grained_lineage.downstreams or []
):
self.expectations.extend_impacted_entities(
Direction.DOWNSTREAM,
upstream_field_urn,
downstream_field_urn,
path_extension=[
self.get_transformation_query_urn(hop_index),
downstream_field_urn,
],
)
# Add table level expectations
for upstream_dataset in upstream_lineage.upstreams:
self.expectations.extend_impacted_entities(
Direction.DOWNSTREAM,
upstream_dataset.dataset,
mcp_entity_urn,
path_extension=[
self.get_transformation_query_urn(hop_index),
mcp_entity_urn,
],
)
yield MetadataChangeProposalWrapper(
entityUrn=mcp_entity_urn,
aspect=upstream_lineage,
)
def get_entity_mcps(self) -> Iterable[MetadataChangeProposalWrapper]:
        for hop_index in range(
            0, self.num_hops + 1
        ):  # entities are generated for every layer, including the final hop's outputs
for mcp in self.get_entity_mcps_for_hop(hop_index):
assert mcp.entityUrn
self._generated_urns.add(mcp.entityUrn)
yield mcp
def get_entity_mcps_for_hop(
self, hop_index: int
) -> Iterable[MetadataChangeProposalWrapper]:
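        """Emit supporting entities for one hop: flow/job or query metadata (depending
        on lineage style) plus schema and properties aspects for each dataset."""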
if self.lineage_style == Scenario.LineageStyle.DATASET_JOB_DATASET:
# Construct the DataJobInfo aspect with the job -> flow lineage.
dataflow_urn = self.get_transformation_flow_urn(hop_index)
dataflow_info = DataFlowInfoClass(
name=self.transformation_flow.title() + " Flow"
)
dataflow_info_mcp = MetadataChangeProposalWrapper(
entityUrn=dataflow_urn,
aspect=dataflow_info,
)
yield dataflow_info_mcp
datajob_info = DataJobInfoClass(
name=self.transformation_job.title() + " Job",
type="AIRFLOW",
flowUrn=dataflow_urn,
)
# Construct a MetadataChangeProposalWrapper object with the DataJobInfo aspect.
# NOTE: This will overwrite all of the existing dataJobInfo aspect information associated with this job.
datajob_info_mcp = MetadataChangeProposalWrapper(
entityUrn=self.get_transformation_job_urn(hop_index),
aspect=datajob_info,
)
yield datajob_info_mcp
if self.lineage_style == Scenario.LineageStyle.DATASET_QUERY_DATASET:
query_urn = self.get_transformation_query_urn(hop_index=hop_index)
fake_auditstamp = AuditStampClass(
time=int(time.time() * 1000),
actor="urn:li:corpuser:datahub",
)
query_properties = QueryPropertiesClass(
statement=QueryStatementClass(
value=self.query_string,
language=QueryLanguageClass.SQL,
),
source=QuerySourceClass.SYSTEM,
created=fake_auditstamp,
lastModified=fake_auditstamp,
)
query_info_mcp = MetadataChangeProposalWrapper(
entityUrn=query_urn,
aspect=query_properties,
)
yield query_info_mcp
# Generate schema and properties mcps for all datasets
for dataset_index in range(self.default_datasets_at_each_hop):
dataset_urn = DatasetUrn.from_string(
self.get_dataset_urn(hop_index, dataset_index)
)
yield from MetadataChangeProposalWrapper.construct_many(
entityUrn=str(dataset_urn),
aspects=[
SchemaMetadataClass(
schemaName=str(dataset_urn),
platform=builder.make_data_platform_urn(dataset_urn.platform),
version=0,
hash="",
platformSchema=OtherSchemaClass(rawSchema=""),
fields=[
SchemaFieldClass(
fieldPath=self.get_column_name(i),
type=SchemaFieldDataTypeClass(type=StringTypeClass()),
nativeDataType="string",
)
for i in range(self.default_column_fanin)
],
),
DatasetPropertiesClass(
name=dataset_urn.name,
),
],
)
def cleanup(self, delete_agent: DeleteAgent) -> None:
"""Delete all entities created by this scenario."""
for urn in self._generated_urns:
delete_agent.delete_entity(urn)
def test_expectation(self, graph: DataHubGraph) -> bool:
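        """Assert that every generated entity exists and that searchAcrossLineage results
        match the pre-materialized expectations for all hop counts and both directions."""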
print("Testing expectation...")
assert self.expectations is not None
try:
for hop_index in range(self.num_hops):
for dataset_urn in self.get_upstream_dataset_urns(hop_index):
assert graph.exists(dataset_urn) is True
for dataset_urn in self.get_downstream_dataset_urns(hop_index):
assert graph.exists(dataset_urn) is True
if self.lineage_style == Scenario.LineageStyle.DATASET_JOB_DATASET:
assert graph.exists(self.get_transformation_job_urn(hop_index)) is True
assert graph.exists(self.get_transformation_flow_urn(hop_index)) is True
if self.lineage_style == Scenario.LineageStyle.DATASET_QUERY_DATASET:
assert (
graph.exists(self.get_transformation_query_urn(hop_index)) is True
)
wait_for_writes_to_sync() # Wait for the graph to update
            # Check that lineage is correct for every dataset and schema field, for all hop
            # counts and both directions. Expectations are pre-materialized, so we only need
            # to compare the query results against them.
for (
query,
expectation,
) in self.expectations.generate_query_expectation_pairs(self.num_hops):
                impacted_entities_expectation = set(expectation.impacted_entities.keys())
if len(impacted_entities_expectation) == 0:
continue
result = search_across_lineage(
graph,
query.main_entity,
query.hops,
query.direction.value,
query.upconvert_schema_fields_to_datasets,
)
                impacted_entities = {
                    x["entity"]["urn"]
                    for x in result["searchAcrossLineage"]["searchResults"]
                }
                assert impacted_entities == impacted_entities_expectation, (
                    f"Expected impacted entities to be {impacted_entities_expectation}, found {impacted_entities}"
                )
search_results = result["searchAcrossLineage"]["searchResults"]
                for impacted_entity in impacted_entities:
                    impacted_entity_paths: List[Path] = []
                    entity_paths_response = [
                        x["paths"]
                        for x in search_results
                        if x["entity"]["urn"] == impacted_entity
                    ]
                    for path_response in entity_paths_response:
                        for p in path_response:
                            impacted_entity_paths.append(
                                Path(path=[x["urn"] for x in p["path"]])
                            )
                    assert len(impacted_entity_paths) == len(
                        expectation.impacted_entities[impacted_entity]
                    ), (
                        f"Expected {len(expectation.impacted_entities[impacted_entity])} paths for {impacted_entity}, found {len(impacted_entity_paths)}"
                    )
                    # Path ordering is not asserted; only the set of paths must match.
                    assert set(impacted_entity_paths) == set(
                        expectation.impacted_entities[impacted_entity]
                    ), (
                        f"Expected impacted entity paths to be {expectation.impacted_entities[impacted_entity]}, found {impacted_entity_paths}"
                    )
print("Test passed!")
return True
except AssertionError as e:
print("Test failed!")
raise e
return False
# @tenacity.retry(
# stop=tenacity.stop_after_attempt(sleep_times), wait=tenacity.wait_fixed(sleep_sec)
# )
@pytest.mark.parametrize(
"lineage_style",
[
Scenario.LineageStyle.DATASET_QUERY_DATASET,
Scenario.LineageStyle.DATASET_JOB_DATASET,
],
)
@pytest.mark.parametrize(
"graph_level",
[
1,
2,
3,
# TODO - convert this to range of 1 to 10 to make sure we can handle large graphs
],
)
def test_lineage_via_node(
graph_client: DataHubGraph, lineage_style: Scenario.LineageStyle, graph_level: int
) -> None:
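    """End-to-end test: emit a synthetic scenario, then verify lineage through via nodes
    (queries or jobs) for graphs of increasing depth."""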
scenario: Scenario = Scenario(
hop_platform_map={0: "mysql", 1: "snowflake"},
lineage_style=lineage_style,
num_hops=graph_level,
default_dataset_prefix=f"{lineage_style.value}.",
)
# Create an emitter to the GMS REST API.
emitter = graph_client
# emitter = DataHubConsoleEmitter()
# Emit metadata!
for mcp in scenario.get_entity_mcps():
emitter.emit_mcp(mcp)
    for mcp in scenario.get_lineage_mcps():
        emitter.emit_mcp(mcp)
wait_for_writes_to_sync()
try:
scenario.test_expectation(graph_client)
finally:
scenario.cleanup(DataHubGraphDeleteAgent(graph_client))
@pytest.fixture(scope="module")
def chart_urn_fixture():
return "urn:li:chart:(tableau,2241f3d6-df8d-b515-9c0c-f5e5b347b26e)"
@pytest.fixture(scope="module")
def intermediates_fixture():
return [
"urn:li:dataset:(urn:li:dataPlatform:tableau,6bd53e72-9fe4-ea86-3d23-14b826c13fa5,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:tableau,1c5653d6-c448-0850-108b-5c78aeaf6b51,PROD)",
]
@pytest.fixture(scope="module")
def destination_urn_fixture():
return "urn:li:dataset:(urn:li:dataPlatform:external,sales target %28us%29.xlsx.sheet1,PROD)"
@pytest.fixture(scope="module", autouse=False)
def ingest_multipath_metadata(
graph_client: DataHubGraph,
chart_urn_fixture,
intermediates_fixture,
destination_urn_fixture,
):
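    """Ingest a chart whose two intermediate Tableau datasets share a single upstream
    dataset, so that dataset is reachable from the chart via two distinct lineage paths."""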
fake_auditstamp = AuditStampClass(
time=int(time.time() * 1000),
actor="urn:li:corpuser:datahub",
)
chart_urn = chart_urn_fixture
intermediates = intermediates_fixture
destination_urn = destination_urn_fixture
for mcp in MetadataChangeProposalWrapper.construct_many(
entityUrn=destination_urn,
aspects=[
DatasetPropertiesClass(
name="sales target (us).xlsx.sheet1",
),
],
):
graph_client.emit_mcp(mcp)
for intermediate in intermediates:
for mcp in MetadataChangeProposalWrapper.construct_many(
entityUrn=intermediate,
aspects=[
DatasetPropertiesClass(
name="intermediate",
),
UpstreamLineageClass(
upstreams=[
UpstreamClass(
dataset=destination_urn,
type="TRANSFORMED",
)
]
),
],
):
graph_client.emit_mcp(mcp)
for mcp in MetadataChangeProposalWrapper.construct_many(
entityUrn=chart_urn,
aspects=[
ChartInfoClass(
title="chart",
description="chart",
lastModified=ChangeAuditStampsClass(created=fake_auditstamp),
inputEdges=[
EdgeClass(
destinationUrn=intermediate_entity,
sourceUrn=chart_urn,
)
for intermediate_entity in intermediates
],
)
],
):
graph_client.emit_mcp(mcp)
wait_for_writes_to_sync()
yield
for urn in [chart_urn] + intermediates + [destination_urn]:
graph_client.delete_entity(urn, hard=True)
wait_for_writes_to_sync()
# TODO: Reenable once fixed
# def test_simple_lineage_multiple_paths(
# graph_client: DataHubGraph,
# ingest_multipath_metadata,
# chart_urn_fixture,
# intermediates_fixture,
# destination_urn_fixture,
# ):
# chart_urn = chart_urn_fixture
# intermediates = intermediates_fixture
# destination_urn = destination_urn_fixture
# results = search_across_lineage(
# graph_client,
# chart_urn,
# direction="UPSTREAM",
# convert_schema_fields_to_datasets=True,
# )
# assert destination_urn in [
# x["entity"]["urn"] for x in results["searchAcrossLineage"]["searchResults"]
# ]
# for search_result in results["searchAcrossLineage"]["searchResults"]:
# if search_result["entity"]["urn"] == destination_urn:
# assert (
# len(search_result["paths"]) == 2
# ) # 2 paths from the chart to the dataset
# for path in search_result["paths"]:
# assert len(path["path"]) == 3
# assert path["path"][-1]["urn"] == destination_urn
# assert path["path"][0]["urn"] == chart_urn
# assert path["path"][1]["urn"] in intermediates