Simplify topology & update context management (#13196)

This commit is contained in:
Pere Miquel Brull 2023-09-15 09:44:42 +02:00 committed by GitHub
parent 047ab980cc
commit 442528267c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 132 additions and 85 deletions

View File

@ -13,11 +13,15 @@ Mixin to be used by service sources to dynamically
generate the _run based on their topology. generate the _run based on their topology.
""" """
import traceback import traceback
from functools import singledispatchmethod
from typing import Any, Generic, Iterable, List, TypeVar from typing import Any, Generic, Iterable, List, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.classification.tag import Tag
from metadata.ingestion.api.models import Either, Entity from metadata.ingestion.api.models import Either, Entity
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.models.topology import ( from metadata.ingestion.models.topology import (
NodeStage, NodeStage,
ServiceTopology, ServiceTopology,
@ -134,7 +138,7 @@ class TopologyRunnerMixin(Generic[C]):
""" """
yield from self.process_nodes(get_topology_root(self.topology)) yield from self.process_nodes(get_topology_root(self.topology))
def update_context(self, key: str, value: Any) -> None: def _replace_context(self, key: str, value: Any) -> None:
""" """
Update the key of the context with the given value Update the key of the context with the given value
:param key: element to update from the source context :param key: element to update from the source context
@ -142,7 +146,7 @@ class TopologyRunnerMixin(Generic[C]):
""" """
self.context.__dict__[key] = value self.context.__dict__[key] = value
def append_context(self, key: str, value: Any) -> None: def _append_context(self, key: str, value: Any) -> None:
""" """
Update the key of the context with the given value Update the key of the context with the given value
:param key: element to update from the source context :param key: element to update from the source context
@ -172,6 +176,115 @@ class TopologyRunnerMixin(Generic[C]):
*context_names, entity_request.name.__root__ *context_names, entity_request.name.__root__
) )
def update_context(self, stage: NodeStage, entity: Entity):
    """Record *entity* in the topology context under the stage's context key.

    Stages flagged with ``cache_all`` accumulate every processed entity
    (appended to a list); all other stages keep only the latest value.
    Stages without a ``context`` key track no state and are a no-op.
    """
    if not stage.context:
        # Nothing to record for stages that do not declare a context key.
        return
    if stage.cache_all:
        self._append_context(key=stage.context, value=entity)
    else:
        self._replace_context(key=stage.context, value=entity)
@singledispatchmethod
def yield_and_update_context(
    self,
    right: C,
    stage: NodeStage,
    entity_request: Either[C],
) -> Iterable[Either[Entity]]:
    """
    Handle the process of yielding the request and validating
    that everything was properly updated.

    The default implementation is based on a get_by_name validation:
    after the request is yielded to the sink, the created entity is
    fetched back from the server by its FQN and stored in the context.

    Dispatch is on the type of ``right`` (the unwrapped request payload);
    specialized overloads are registered below for lineage and tags.

    :param right: the successful (right) payload of the entity request
    :param stage: topology node stage being processed
    :param entity_request: the Either wrapper forwarded to the sink
    """
    entity = None
    entity_fqn = self.fqn_from_context(stage=stage, entity_request=right)
    # we get entity from OM if we do not want to overwrite existing data in OM
    if not stage.overwrite and not self._is_force_overwrite_enabled():
        entity = self.metadata.get_by_name(
            entity=stage.type_,
            fqn=entity_fqn,
            fields=["*"],  # Get all the available data from the Entity
        )
    # if entity does not exist in OM, or we want to overwrite, we will yield the entity_request
    if entity is None:
        # Retry a few reads — presumably the server may not have finished
        # processing the yielded request yet (TODO confirm ack timing).
        tries = 3
        while not entity and tries > 0:
            yield entity_request
            entity = self.metadata.get_by_name(
                entity=stage.type_,
                fqn=entity_fqn,
                fields=["*"],  # Get all the available data from the Entity
            )
            tries -= 1
    # We have ack the sink waiting for a response, but got nothing back
    if stage.must_return and entity is None:
        # Safe access to Entity Request name
        raise MissingExpectedEntityAckException(
            f"Missing ack back from [{stage.type_.__name__}: {getattr(entity_request, 'name')}] - "
            "Possible causes are changes in the server Fernet key or mismatched JSON Schemas "
            "for the service connection."
        )
    # Store whatever the server returned (possibly None when nullable).
    self.update_context(stage=stage, entity=entity)
@yield_and_update_context.register
def _(
    self,
    right: AddLineageRequest,
    stage: NodeStage,
    entity_request: Either[C],
) -> Iterable[Either[Entity]]:
    """
    Lineage Implementation for the context information.

    There is no simple (efficient) validation to make sure that this specific
    lineage has been properly drawn. We'll skip the process for now.

    :param right: the AddLineageRequest payload of the entity request
    :param stage: topology node stage being processed
    :param entity_request: the Either wrapper forwarded to the sink
    """
    yield entity_request
    # Store the request itself: lineage is not read back from the server.
    self.update_context(stage=stage, entity=right)
@yield_and_update_context.register
def _(
    self,
    right: OMetaTagAndClassification,
    stage: NodeStage,
    entity_request: Either[C],
) -> Iterable[Either[Entity]]:
    """
    Tag implementation for the context information.

    The ingestion is validated by reading the Tag back by its FQN,
    retrying a few times in case the server has not yet processed
    the yielded request.

    :param right: tag + classification payload of the entity request
    :param stage: topology node stage being processed
    :param entity_request: the Either wrapper forwarded to the sink
    """
    tag = None
    # Yield inside the loop only: the first iteration always sends the
    # request to the sink once. A second unconditional yield before the
    # loop would make the sink process the same tag request twice, and
    # would be inconsistent with the default implementation above.
    tries = 3
    while not tag and tries > 0:
        yield entity_request
        tag = self.metadata.get_by_name(
            entity=Tag,
            fqn=fqn.build(
                metadata=self.metadata,
                entity_type=Tag,
                classification_name=right.tag_request.classification.__root__,
                tag_name=right.tag_request.name.__root__,
            ),
        )
        tries -= 1
    # We have ack the sink waiting for a response, but got nothing back
    if stage.must_return and tag is None:
        # Safe access to Entity Request name
        raise MissingExpectedEntityAckException(
            f"Missing ack back from [Tag: {right.tag_request.name}] - "
            "Possible causes are changes in the server Fernet key or mismatched JSON Schemas "
            "for the service connection."
        )
    # We want to keep the full payload in the context
    self.update_context(stage=stage, entity=right)
def sink_request( def sink_request(
self, stage: NodeStage, entity_request: Either[C] self, stage: NodeStage, entity_request: Either[C]
) -> Iterable[Either[Entity]]: ) -> Iterable[Either[Entity]]:
@ -197,49 +310,14 @@ class TopologyRunnerMixin(Generic[C]):
# We need to acknowledge that the Entity has been properly sent to the server # We need to acknowledge that the Entity has been properly sent to the server
# to update the context # to update the context
if stage.ack_sink: if stage.context:
entity = None yield from self.yield_and_update_context(
entity, stage=stage, entity_request=entity_request
entity_fqn = self.fqn_from_context(
stage=stage, entity_request=entity_request.right
) )
# we get entity from OM if we do not want to overwrite existing data in OM
if not stage.overwrite and not self._is_force_overwrite_enabled():
entity = self.metadata.get_by_name(
entity=stage.type_,
fqn=entity_fqn,
fields=["*"], # Get all the available data from the Entity
)
# if entity does not exist in OM, or we want to overwrite, we will yield the entity_request
if entity is None:
tries = 3
while not entity and tries > 0:
yield entity_request
entity = self.metadata.get_by_name(
entity=stage.type_,
fqn=entity_fqn,
fields=["*"], # Get all the available data from the Entity
)
tries -= 1
# We have ack the sink waiting for a response, but got nothing back
if stage.must_return and entity is None:
# Safe access to Entity Request name
raise MissingExpectedEntityAckException(
f"Missing ack back from [{stage.type_.__name__}: {getattr(entity_request, 'name')}] - "
"Possible causes are changes in the server Fernet key or mismatched JSON Schemas "
"for the service connection."
)
else: else:
yield entity_request yield entity_request
if stage.context and not stage.cache_all:
self.update_context(key=stage.context, value=entity)
if stage.context and stage.cache_all:
self.append_context(key=stage.context, value=entity)
else: else:
# if entity_request.right is None, means that we have a Left. We yield the Either and # if entity_request.right is None, means that we have a Left. We yield the Either and
# let the step take care of the # let the step take care of the

View File

@ -32,8 +32,9 @@ class NodeStage(BaseModel, Generic[T]):
type_: Type[T] # Entity type type_: Type[T] # Entity type
processor: str # has the producer results as an argument. Here is where filters happen processor: str # has the producer results as an argument. Here is where filters happen
context: Optional[str] = None # context key storing stage state, if needed context: Optional[
ack_sink: bool = True # Validate that the request is present in OM and update the context with the results str
] = None # context key storing stage state, if needed. This requires us to ACK the ingestion
nullable: bool = False # The yielded value can be null nullable: bool = False # The yielded value can be null
must_return: bool = False # The sink MUST return a value back after ack. Useful to validate services are correct. must_return: bool = False # The sink MUST return a value back after ack. Useful to validate services are correct.
cache_all: bool = ( cache_all: bool = (

View File

@ -105,9 +105,7 @@ class DashboardServiceTopology(ServiceTopology):
), ),
NodeStage( NodeStage(
type_=OMetaTagAndClassification, type_=OMetaTagAndClassification,
context="tags",
processor="yield_tag", processor="yield_tag",
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],
@ -169,7 +167,6 @@ class DashboardServiceTopology(ServiceTopology):
context="lineage", context="lineage",
processor="yield_dashboard_lineage", processor="yield_dashboard_lineage",
consumer=["dashboard_service"], consumer=["dashboard_service"],
ack_sink=False,
nullable=True, nullable=True,
), ),
NodeStage( NodeStage(
@ -177,7 +174,6 @@ class DashboardServiceTopology(ServiceTopology):
context="usage", context="usage",
processor="yield_dashboard_usage", processor="yield_dashboard_usage",
consumer=["dashboard_service"], consumer=["dashboard_service"],
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],

View File

@ -126,7 +126,6 @@ class DatabaseServiceTopology(ServiceTopology):
type_=OMetaTagAndClassification, type_=OMetaTagAndClassification,
context="tags", context="tags",
processor="yield_database_schema_tag_details", processor="yield_database_schema_tag_details",
ack_sink=False,
nullable=True, nullable=True,
cache_all=True, cache_all=True,
), ),
@ -147,7 +146,6 @@ class DatabaseServiceTopology(ServiceTopology):
type_=OMetaTagAndClassification, type_=OMetaTagAndClassification,
context="tags", context="tags",
processor="yield_table_tag_details", processor="yield_table_tag_details",
ack_sink=False,
nullable=True, nullable=True,
cache_all=True, cache_all=True,
), ),
@ -159,9 +157,7 @@ class DatabaseServiceTopology(ServiceTopology):
), ),
NodeStage( NodeStage(
type_=OMetaLifeCycleData, type_=OMetaLifeCycleData,
context="life_cycle",
processor="yield_life_cycle_data", processor="yield_life_cycle_data",
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],
@ -182,17 +178,15 @@ class DatabaseServiceTopology(ServiceTopology):
producer="get_stored_procedure_queries", producer="get_stored_procedure_queries",
stages=[ stages=[
NodeStage( NodeStage(
type_=AddLineageRequest, # TODO: Fix context management for multiple types type_=AddLineageRequest,
processor="yield_procedure_lineage", processor="yield_procedure_lineage",
context="stored_procedure_query_lineage", # Used to flag if the query has had processed lineage context="stored_procedure_query_lineage", # Used to flag if the query has had processed lineage
nullable=True, nullable=True,
ack_sink=False,
), ),
NodeStage( NodeStage(
type_=Query, type_=Query,
processor="yield_procedure_query", processor="yield_procedure_query",
nullable=True, nullable=True,
ack_sink=False,
), ),
], ],
) )

View File

@ -65,7 +65,6 @@ class DbtServiceTopology(ServiceTopology):
NodeStage( NodeStage(
type_=DbtFiles, type_=DbtFiles,
processor="validate_dbt_files", processor="validate_dbt_files",
ack_sink=False,
nullable=True, nullable=True,
) )
], ],
@ -82,14 +81,12 @@ class DbtServiceTopology(ServiceTopology):
type_=OMetaTagAndClassification, type_=OMetaTagAndClassification,
context="tags", context="tags",
processor="yield_dbt_tags", processor="yield_dbt_tags",
ack_sink=False,
nullable=True, nullable=True,
cache_all=True, cache_all=True,
), ),
NodeStage( NodeStage(
type_=DataModelLink, type_=DataModelLink,
processor="yield_data_models", processor="yield_data_models",
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],
@ -100,17 +97,14 @@ class DbtServiceTopology(ServiceTopology):
NodeStage( NodeStage(
type_=AddLineageRequest, type_=AddLineageRequest,
processor="create_dbt_lineage", processor="create_dbt_lineage",
ack_sink=False,
), ),
NodeStage( NodeStage(
type_=AddLineageRequest, type_=AddLineageRequest,
processor="create_dbt_query_lineage", processor="create_dbt_query_lineage",
ack_sink=False,
), ),
NodeStage( NodeStage(
type_=DataModelLink, type_=DataModelLink,
processor="process_dbt_descriptions", processor="process_dbt_descriptions",
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],
@ -121,17 +115,14 @@ class DbtServiceTopology(ServiceTopology):
NodeStage( NodeStage(
type_=CreateTestDefinitionRequest, type_=CreateTestDefinitionRequest,
processor="create_dbt_tests_definition", processor="create_dbt_tests_definition",
ack_sink=False,
), ),
NodeStage( NodeStage(
type_=CreateTestCaseRequest, type_=CreateTestCaseRequest,
processor="create_dbt_test_case", processor="create_dbt_test_case",
ack_sink=False,
), ),
NodeStage( NodeStage(
type_=TestCaseResult, type_=TestCaseResult,
processor="add_dbt_test_result", processor="add_dbt_test_result",
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],

View File

@ -133,12 +133,10 @@ class StoredProcedureMixin:
) -> Iterable[Either[AddLineageRequest]]: ) -> Iterable[Either[AddLineageRequest]]:
"""Add procedure lineage from its query""" """Add procedure lineage from its query"""
self.update_context(key="stored_procedure_query_lineage", value=False)
if self.is_lineage_query( if self.is_lineage_query(
query_type=query_by_procedure.query_type, query_type=query_by_procedure.query_type,
query_text=query_by_procedure.query_text, query_text=query_by_procedure.query_text,
): ):
self.update_context(key="stored_procedure_query_lineage", value=True)
for either_lineage in get_lineage_by_query( for either_lineage in get_lineage_by_query(
self.metadata, self.metadata,

View File

@ -95,11 +95,9 @@ class MessagingServiceTopology(ServiceTopology):
), ),
NodeStage( NodeStage(
type_=TopicSampleData, type_=TopicSampleData,
context="topic_sample_data",
processor="yield_topic_sample_data", processor="yield_topic_sample_data",
consumer=["messaging_service"], consumer=["messaging_service"],
nullable=True, nullable=True,
ack_sink=False,
), ),
], ],
) )

View File

@ -129,7 +129,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
def get_tasks(self, pipeline_details: dict) -> List[Task]: def get_tasks(self, pipeline_details: dict) -> List[Task]:
task_list = [] task_list = []
self.append_context(key="job_id_list", value=pipeline_details["job_id"]) self._append_context(key="job_id_list", value=pipeline_details["job_id"])
downstream_tasks = self.get_downstream_tasks( downstream_tasks = self.get_downstream_tasks(
pipeline_details["settings"].get("tasks") pipeline_details["settings"].get("tasks")

View File

@ -81,7 +81,6 @@ class PipelineServiceTopology(ServiceTopology):
type_=OMetaTagAndClassification, type_=OMetaTagAndClassification,
context="tags", context="tags",
processor="yield_tag", processor="yield_tag",
ack_sink=False,
nullable=True, nullable=True,
), ),
NodeStage( NodeStage(
@ -92,18 +91,14 @@ class PipelineServiceTopology(ServiceTopology):
), ),
NodeStage( NodeStage(
type_=OMetaPipelineStatus, type_=OMetaPipelineStatus,
context="pipeline_status",
processor="yield_pipeline_status", processor="yield_pipeline_status",
consumer=["pipeline_service"], consumer=["pipeline_service"],
nullable=True, nullable=True,
ack_sink=False,
), ),
NodeStage( NodeStage(
type_=AddLineageRequest, type_=AddLineageRequest,
context="lineage",
processor="yield_pipeline_lineage", processor="yield_pipeline_lineage",
consumer=["pipeline_service"], consumer=["pipeline_service"],
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],

View File

@ -91,10 +91,8 @@ class SearchServiceTopology(ServiceTopology):
), ),
NodeStage( NodeStage(
type_=OMetaIndexSampleData, type_=OMetaIndexSampleData,
context="search_index_sample_data",
processor="yield_search_index_sample_data", processor="yield_search_index_sample_data",
consumer=["search_service"], consumer=["search_service"],
ack_sink=False,
nullable=True, nullable=True,
), ),
], ],

View File

@ -30,9 +30,7 @@ class MockTopology(ServiceTopology):
stages=[ stages=[
NodeStage( NodeStage(
type_=int, type_=int,
context="numbers",
processor="yield_numbers", processor="yield_numbers",
ack_sink=False,
) )
], ],
children=["strings"], children=["strings"],
@ -42,9 +40,7 @@ class MockTopology(ServiceTopology):
stages=[ stages=[
NodeStage( NodeStage(
type_=str, type_=str,
context="strings",
processor="yield_strings", processor="yield_strings",
ack_sink=False,
consumer=["numbers"], consumer=["numbers"],
) )
], ],
@ -69,21 +65,23 @@ class MockSource(TopologyRunnerMixin):
def yield_numbers(number: int): def yield_numbers(number: int):
yield Either(right=number + 1) yield Either(right=number + 1)
def yield_strings(self, my_str: str): @staticmethod
yield Either(right=my_str + str(self.context.numbers)) def yield_strings(my_str: str):
yield Either(right=my_str)
class TopologyRunnerTest(TestCase): class TopologyRunnerTest(TestCase):
"""Validate filter patterns""" """Validate filter patterns"""
def test_node_and_stage(self): @staticmethod
def test_node_and_stage():
source = MockSource() source = MockSource()
processed = list(source._iter()) processed = list(source._iter())
assert [either.right for either in processed] == [ assert [either.right for either in processed] == [
2, 2,
"abc2", "abc",
"def2", "def",
3, 3,
"abc3", "abc",
"def3", "def",
] ]