diff --git a/ingestion/src/metadata/ingestion/api/topology_runner.py b/ingestion/src/metadata/ingestion/api/topology_runner.py
index f3e5d754077..0cd74be40a7 100644
--- a/ingestion/src/metadata/ingestion/api/topology_runner.py
+++ b/ingestion/src/metadata/ingestion/api/topology_runner.py
@@ -13,11 +13,15 @@ Mixin to be used by service sources to dynamically
 generate the _run based on their topology.
 """
 import traceback
+from functools import singledispatchmethod
 from typing import Any, Generic, Iterable, List, TypeVar
 
 from pydantic import BaseModel
 
+from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
+from metadata.generated.schema.entity.classification.tag import Tag
 from metadata.ingestion.api.models import Either, Entity
+from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
 from metadata.ingestion.models.topology import (
     NodeStage,
     ServiceTopology,
@@ -134,7 +138,7 @@ class TopologyRunnerMixin(Generic[C]):
         """
         yield from self.process_nodes(get_topology_root(self.topology))
 
-    def update_context(self, key: str, value: Any) -> None:
+    def _replace_context(self, key: str, value: Any) -> None:
         """
         Update the key of the context with the given value
         :param key: element to update from the source context
@@ -142,7 +146,7 @@ class TopologyRunnerMixin(Generic[C]):
         """
         self.context.__dict__[key] = value
 
-    def append_context(self, key: str, value: Any) -> None:
+    def _append_context(self, key: str, value: Any) -> None:
         """
         Update the key of the context with the given value
         :param key: element to update from the source context
@@ -172,6 +176,115 @@ class TopologyRunnerMixin(Generic[C]):
             *context_names, entity_request.name.__root__
         )
 
+    def update_context(self, stage: NodeStage, entity: Entity):
+        """Append or update context"""
+        if stage.context and not stage.cache_all:
+            self._replace_context(key=stage.context, value=entity)
+        if stage.context and stage.cache_all:
+            self._append_context(key=stage.context, value=entity)
+
+    @singledispatchmethod
+    def yield_and_update_context(
+        self,
+        right: C,
+        stage: NodeStage,
+        entity_request: Either[C],
+    ) -> Iterable[Either[Entity]]:
+        """
+        Handle the process of yielding the request and validating
+        that everything was properly updated.
+
+        The default implementation is based on a get_by_name validation
+        """
+
+        entity = None
+
+        entity_fqn = self.fqn_from_context(stage=stage, entity_request=right)
+
+        # we get entity from OM if we do not want to overwrite existing data in OM
+        if not stage.overwrite and not self._is_force_overwrite_enabled():
+            entity = self.metadata.get_by_name(
+                entity=stage.type_,
+                fqn=entity_fqn,
+                fields=["*"],  # Get all the available data from the Entity
+            )
+        # if entity does not exist in OM, or we want to overwrite, we will yield the entity_request
+        if entity is None:
+            tries = 3
+            while not entity and tries > 0:
+                yield entity_request
+                entity = self.metadata.get_by_name(
+                    entity=stage.type_,
+                    fqn=entity_fqn,
+                    fields=["*"],  # Get all the available data from the Entity
+                )
+                tries -= 1
+
+        # We have asked the sink for an ack, but got nothing back
+        if stage.must_return and entity is None:
+            # Safe access to Entity Request name
+            raise MissingExpectedEntityAckException(
+                f"Missing ack back from [{stage.type_.__name__}: {getattr(entity_request, 'name')}] - "
+                "Possible causes are changes in the server Fernet key or mismatched JSON Schemas "
+                "for the service connection."
+            )
+
+        self.update_context(stage=stage, entity=entity)
+
+    @yield_and_update_context.register
+    def _(
+        self,
+        right: AddLineageRequest,
+        stage: NodeStage,
+        entity_request: Either[C],
+    ) -> Iterable[Either[Entity]]:
+        """
+        Lineage Implementation for the context information.
+
+        There is no simple (efficient) validation to make sure that this specific
+        lineage has been properly drawn. We'll skip the process for now.
+        """
+        yield entity_request
+        self.update_context(stage=stage, entity=right)
+
+    @yield_and_update_context.register
+    def _(
+        self,
+        right: OMetaTagAndClassification,
+        stage: NodeStage,
+        entity_request: Either[C],
+    ) -> Iterable[Either[Entity]]:
+        """Tag implementation for the context information"""
+        yield entity_request
+
+        tag = None
+
+        tries = 3
+        while not tag and tries > 0:
+            yield entity_request
+            tag = self.metadata.get_by_name(
+                entity=Tag,
+                fqn=fqn.build(
+                    metadata=self.metadata,
+                    entity_type=Tag,
+                    classification_name=right.tag_request.classification.__root__,
+                    tag_name=right.tag_request.name.__root__,
+                ),
+            )
+            tries -= 1
+
+        # We have asked the sink for an ack, but got nothing back
+        if stage.must_return and tag is None:
+            # Safe access to Entity Request name
+            raise MissingExpectedEntityAckException(
+                f"Missing ack back from [Tag: {right.tag_request.name}] - "
+                "Possible causes are changes in the server Fernet key or mismatched JSON Schemas "
+                "for the service connection."
+            )
+
+        # We want to keep the full payload in the context
+        self.update_context(stage=stage, entity=right)
+
     def sink_request(
         self, stage: NodeStage, entity_request: Either[C]
     ) -> Iterable[Either[Entity]]:
@@ -197,49 +310,14 @@ class TopologyRunnerMixin(Generic[C]):
                 # We need to acknowledge that the Entity has been properly sent to the server
                 # to update the context
-                if stage.ack_sink:
-                    entity = None
-
-                    entity_fqn = self.fqn_from_context(
-                        stage=stage, entity_request=entity_request.right
+                if stage.context:
+                    yield from self.yield_and_update_context(
+                        entity, stage=stage, entity_request=entity_request
                     )
-                    # we get entity from OM if we do not want to overwrite existing data in OM
-                    if not stage.overwrite and not self._is_force_overwrite_enabled():
-                        entity = self.metadata.get_by_name(
-                            entity=stage.type_,
-                            fqn=entity_fqn,
-                            fields=["*"],  # Get all the available data from the Entity
-                        )
-                    # if entity does not exist in OM, or we want to overwrite, we will yield the entity_request
-                    if entity is None:
-                        tries = 3
-                        while not entity and tries > 0:
-                            yield entity_request
-                            entity = self.metadata.get_by_name(
-                                entity=stage.type_,
-                                fqn=entity_fqn,
-                                fields=["*"],  # Get all the available data from the Entity
-                            )
-                            tries -= 1
-
-                    # We have ack the sink waiting for a response, but got nothing back
-                    if stage.must_return and entity is None:
-                        # Safe access to Entity Request name
-                        raise MissingExpectedEntityAckException(
-                            f"Missing ack back from [{stage.type_.__name__}: {getattr(entity_request, 'name')}] - "
-                            "Possible causes are changes in the server Fernet key or mismatched JSON Schemas "
-                            "for the service connection."
-                        )
-
                 else:
                     yield entity_request
 
-                if stage.context and not stage.cache_all:
-                    self.update_context(key=stage.context, value=entity)
-                if stage.context and stage.cache_all:
-                    self.append_context(key=stage.context, value=entity)
-
             else:  # if entity_request.right is None, means that we have a Left. We yield the Either and
                 # let the step take care of the
diff --git a/ingestion/src/metadata/ingestion/models/topology.py b/ingestion/src/metadata/ingestion/models/topology.py
index fd1456c1641..410a3ec82e7 100644
--- a/ingestion/src/metadata/ingestion/models/topology.py
+++ b/ingestion/src/metadata/ingestion/models/topology.py
@@ -32,8 +32,9 @@ class NodeStage(BaseModel, Generic[T]):
 
     type_: Type[T]  # Entity type
     processor: str  # has the producer results as an argument. Here is where filters happen
-    context: Optional[str] = None  # context key storing stage state, if needed
-    ack_sink: bool = True  # Validate that the request is present in OM and update the context with the results
+    context: Optional[
+        str
+    ] = None  # context key storing stage state, if needed. This requires us to ACK the ingestion
     nullable: bool = False  # The yielded value can be null
     must_return: bool = False  # The sink MUST return a value back after ack. Useful to validate services are correct.
     cache_all: bool = (
diff --git a/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py b/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py
index d7167dc9907..1d40c5dea1c 100644
--- a/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py
+++ b/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py
@@ -105,9 +105,7 @@ class DashboardServiceTopology(ServiceTopology):
             ),
             NodeStage(
                 type_=OMetaTagAndClassification,
-                context="tags",
                 processor="yield_tag",
-                ack_sink=False,
                 nullable=True,
             ),
         ],
@@ -169,7 +167,6 @@ class DashboardServiceTopology(ServiceTopology):
                 context="lineage",
                 processor="yield_dashboard_lineage",
                 consumer=["dashboard_service"],
-                ack_sink=False,
                 nullable=True,
             ),
             NodeStage(
@@ -177,7 +174,6 @@ class DashboardServiceTopology(ServiceTopology):
                 context="usage",
                 processor="yield_dashboard_usage",
                 consumer=["dashboard_service"],
-                ack_sink=False,
                 nullable=True,
             ),
         ],
diff --git a/ingestion/src/metadata/ingestion/source/database/database_service.py b/ingestion/src/metadata/ingestion/source/database/database_service.py
index 4f9e1479c70..5fceebc63c8 100644
--- a/ingestion/src/metadata/ingestion/source/database/database_service.py
+++ b/ingestion/src/metadata/ingestion/source/database/database_service.py
@@ -126,7 +126,6 @@ class DatabaseServiceTopology(ServiceTopology):
                 type_=OMetaTagAndClassification,
                 context="tags",
                 processor="yield_database_schema_tag_details",
-                ack_sink=False,
                 nullable=True,
                 cache_all=True,
             ),
@@ -147,7 +146,6 @@ class DatabaseServiceTopology(ServiceTopology):
                 type_=OMetaTagAndClassification,
                 context="tags",
                 processor="yield_table_tag_details",
-                ack_sink=False,
                 nullable=True,
                 cache_all=True,
             ),
@@ -159,9 +157,7 @@ class DatabaseServiceTopology(ServiceTopology):
             ),
             NodeStage(
                 type_=OMetaLifeCycleData,
-                context="life_cycle",
                 processor="yield_life_cycle_data",
-                ack_sink=False,
                 nullable=True,
             ),
         ],
@@ -182,17 +178,15 @@ class DatabaseServiceTopology(ServiceTopology):
         producer="get_stored_procedure_queries",
         stages=[
             NodeStage(
-                type_=AddLineageRequest,  # TODO: Fix context management for multiple types
+                type_=AddLineageRequest,
                 processor="yield_procedure_lineage",
                 context="stored_procedure_query_lineage",  # Used to flag if the query has had processed lineage
                 nullable=True,
-                ack_sink=False,
             ),
             NodeStage(
                 type_=Query,
                 processor="yield_procedure_query",
                 nullable=True,
-                ack_sink=False,
             ),
         ],
     )
diff --git a/ingestion/src/metadata/ingestion/source/database/dbt/dbt_service.py b/ingestion/src/metadata/ingestion/source/database/dbt/dbt_service.py
index d42b582809c..53649724680 100644
--- a/ingestion/src/metadata/ingestion/source/database/dbt/dbt_service.py
+++ b/ingestion/src/metadata/ingestion/source/database/dbt/dbt_service.py
@@ -65,7 +65,6 @@ class DbtServiceTopology(ServiceTopology):
             NodeStage(
                 type_=DbtFiles,
                 processor="validate_dbt_files",
-                ack_sink=False,
                 nullable=True,
             )
         ],
@@ -82,14 +81,12 @@ class DbtServiceTopology(ServiceTopology):
                 type_=OMetaTagAndClassification,
                 context="tags",
                 processor="yield_dbt_tags",
-                ack_sink=False,
                 nullable=True,
                 cache_all=True,
             ),
             NodeStage(
                 type_=DataModelLink,
                 processor="yield_data_models",
-                ack_sink=False,
                 nullable=True,
             ),
         ],
@@ -100,17 +97,14 @@ class DbtServiceTopology(ServiceTopology):
             NodeStage(
                 type_=AddLineageRequest,
                 processor="create_dbt_lineage",
-                ack_sink=False,
             ),
             NodeStage(
                 type_=AddLineageRequest,
                 processor="create_dbt_query_lineage",
-                ack_sink=False,
             ),
             NodeStage(
                 type_=DataModelLink,
                 processor="process_dbt_descriptions",
-                ack_sink=False,
                 nullable=True,
             ),
         ],
@@ -121,17 +115,14 @@ class DbtServiceTopology(ServiceTopology):
             NodeStage(
                 type_=CreateTestDefinitionRequest,
                 processor="create_dbt_tests_definition",
-                ack_sink=False,
             ),
             NodeStage(
                 type_=CreateTestCaseRequest,
                 processor="create_dbt_test_case",
-                ack_sink=False,
             ),
             NodeStage(
                 type_=TestCaseResult,
                 processor="add_dbt_test_result",
-                ack_sink=False,
                 nullable=True,
             ),
         ],
diff --git a/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py b/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py
index 2acbe327fd5..9905782f06e 100644
--- a/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py
+++ b/ingestion/src/metadata/ingestion/source/database/stored_procedures_mixin.py
@@ -133,12 +133,10 @@ class StoredProcedureMixin:
     ) -> Iterable[Either[AddLineageRequest]]:
         """Add procedure lineage from its query"""
 
-        self.update_context(key="stored_procedure_query_lineage", value=False)
         if self.is_lineage_query(
             query_type=query_by_procedure.query_type,
             query_text=query_by_procedure.query_text,
         ):
-            self.update_context(key="stored_procedure_query_lineage", value=True)
 
             for either_lineage in get_lineage_by_query(
                 self.metadata,
diff --git a/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py b/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py
index 38a5c55ccdc..380ec9bc3d2 100644
--- a/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py
+++ b/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py
@@ -95,11 +95,9 @@ class MessagingServiceTopology(ServiceTopology):
             ),
             NodeStage(
                 type_=TopicSampleData,
-                context="topic_sample_data",
                 processor="yield_topic_sample_data",
                 consumer=["messaging_service"],
                 nullable=True,
-                ack_sink=False,
             ),
         ],
     )
diff --git a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
index f913b78bb4d..e60fa1448e7 100644
--- a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
@@ -129,7 +129,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
 
     def get_tasks(self, pipeline_details: dict) -> List[Task]:
         task_list = []
-        self.append_context(key="job_id_list", value=pipeline_details["job_id"])
+        self._append_context(key="job_id_list", value=pipeline_details["job_id"])
 
         downstream_tasks = self.get_downstream_tasks(
             pipeline_details["settings"].get("tasks")
diff --git a/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py b/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py
index 6027c351247..cf13868d2dd 100644
--- a/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py
+++ b/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py
@@ -81,7 +81,6 @@ class PipelineServiceTopology(ServiceTopology):
                 type_=OMetaTagAndClassification,
                 context="tags",
                 processor="yield_tag",
-                ack_sink=False,
                 nullable=True,
             ),
             NodeStage(
@@ -92,18 +91,14 @@ class PipelineServiceTopology(ServiceTopology):
             ),
             NodeStage(
                 type_=OMetaPipelineStatus,
-                context="pipeline_status",
                 processor="yield_pipeline_status",
                 consumer=["pipeline_service"],
                 nullable=True,
-                ack_sink=False,
             ),
             NodeStage(
                 type_=AddLineageRequest,
-                context="lineage",
                 processor="yield_pipeline_lineage",
                 consumer=["pipeline_service"],
-                ack_sink=False,
                 nullable=True,
             ),
         ],
diff --git a/ingestion/src/metadata/ingestion/source/search/search_service.py b/ingestion/src/metadata/ingestion/source/search/search_service.py
index fb68d08bcc4..044fcf2819d 100644
--- a/ingestion/src/metadata/ingestion/source/search/search_service.py
+++ b/ingestion/src/metadata/ingestion/source/search/search_service.py
@@ -91,10 +91,8 @@ class SearchServiceTopology(ServiceTopology):
             ),
             NodeStage(
                 type_=OMetaIndexSampleData,
-                context="search_index_sample_data",
                 processor="yield_search_index_sample_data",
                 consumer=["search_service"],
-                ack_sink=False,
                 nullable=True,
             ),
         ],
diff --git a/ingestion/tests/unit/topology/test_runner.py b/ingestion/tests/unit/topology/test_runner.py
index 382b77f4b69..57ed13931e6 100644
--- a/ingestion/tests/unit/topology/test_runner.py
+++ b/ingestion/tests/unit/topology/test_runner.py
@@ -30,9 +30,7 @@ class MockTopology(ServiceTopology):
         stages=[
             NodeStage(
                 type_=int,
-                context="numbers",
                 processor="yield_numbers",
-                ack_sink=False,
             )
         ],
         children=["strings"],
@@ -42,9 +40,7 @@ class MockTopology(ServiceTopology):
        stages=[
             NodeStage(
                 type_=str,
-                context="strings",
                 processor="yield_strings",
-                ack_sink=False,
                 consumer=["numbers"],
             )
         ],
@@ -69,21 +65,23 @@ class MockSource(TopologyRunnerMixin):
     def yield_numbers(number: int):
         yield Either(right=number + 1)
 
-    def yield_strings(self, my_str: str):
-        yield Either(right=my_str + str(self.context.numbers))
+    @staticmethod
+    def yield_strings(my_str: str):
+        yield Either(right=my_str)
 
 
 class TopologyRunnerTest(TestCase):
     """Validate filter patterns"""
 
-    def test_node_and_stage(self):
+    @staticmethod
+    def test_node_and_stage():
         source = MockSource()
         processed = list(source._iter())
         assert [either.right for either in processed] == [
             2,
-            "abc2",
-            "def2",
+            "abc",
+            "def",
             3,
-            "abc3",
-            "def3",
+            "abc",
+            "def",
         ]
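
Note on the dispatch mechanism (illustrative only, not part of the patch): yield_and_update_context relies on functools.singledispatchmethod, which selects a registered implementation based on the runtime type of the first argument after self. The sketch below shows that selection with hypothetical stand-in classes; in the real mixin the implementations are generators yielding Either objects, while here they return strings purely to show which branch the dispatcher picks.

# Illustrative sketch, assuming nothing beyond the standard library.
# LineageRequest and TagAndClassification are hypothetical stand-ins for
# AddLineageRequest and OMetaTagAndClassification.
from functools import singledispatchmethod


class LineageRequest:
    pass


class TagAndClassification:
    pass


class Runner:
    @singledispatchmethod
    def yield_and_update_context(self, right, stage=None, entity_request=None):
        # Fallback: any type without a registered implementation lands here
        # (the patch validates these entities with get_by_name and retries).
        return "default"

    @yield_and_update_context.register
    def _(self, right: LineageRequest, stage=None, entity_request=None):
        # Lineage has no cheap validation, so the patch just yields and updates the context.
        return "lineage"

    @yield_and_update_context.register
    def _(self, right: TagAndClassification, stage=None, entity_request=None):
        # Tags are validated by building the classification.tag FQN and fetching it back.
        return "tag"


runner = Runner()
assert runner.yield_and_update_context(object()) == "default"
assert runner.yield_and_update_context(LineageRequest()) == "lineage"
assert runner.yield_and_update_context(TagAndClassification()) == "tag"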