diff --git a/metadata-ingestion/src/datahub/ingestion/api/common.py b/metadata-ingestion/src/datahub/ingestion/api/common.py
index 0c0855557c..778bd11961 100644
--- a/metadata-ingestion/src/datahub/ingestion/api/common.py
+++ b/metadata-ingestion/src/datahub/ingestion/api/common.py
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Tuple, Type
 
 from datahub.emitter.mce_builder import set_dataset_urn_to_lower
 from datahub.ingestion.api.committable import Committable
-from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
+from datahub.ingestion.graph.client import DataHubGraph
 
 if TYPE_CHECKING:
     from datahub.ingestion.run.pipeline import PipelineConfig
@@ -43,22 +43,19 @@ class PipelineContext:
     def __init__(
         self,
         run_id: str,
-        datahub_api: Optional["DatahubClientConfig"] = None,
+        graph: Optional[DataHubGraph] = None,
        pipeline_name: Optional[str] = None,
         dry_run: bool = False,
         preview_mode: bool = False,
         pipeline_config: Optional["PipelineConfig"] = None,
     ) -> None:
         self.pipeline_config = pipeline_config
+        self.graph = graph
         self.run_id = run_id
         self.pipeline_name = pipeline_name
         self.dry_run_mode = dry_run
         self.preview_mode = preview_mode
         self.checkpointers: Dict[str, Committable] = {}
-        try:
-            self.graph = DataHubGraph(datahub_api) if datahub_api is not None else None
-        except Exception as e:
-            raise Exception(f"Failed to connect to DataHub: {e}") from e
 
         self._set_dataset_urn_to_lower_if_needed()
diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
index fe4911729a..8cabc0cccf 100644
--- a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
+++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
@@ -27,6 +27,7 @@ from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
 from datahub.ingestion.api.source import Extractor, Source
 from datahub.ingestion.api.transform import Transformer
 from datahub.ingestion.extractor.extractor_registry import extractor_registry
+from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.reporting.reporting_provider_registry import (
     reporting_provider_registry,
 )
@@ -183,10 +184,15 @@ class Pipeline:
         self.last_time_printed = int(time.time())
         self.cli_report = CliReport()
 
+        self.graph = None
+        with _add_init_error_context("connect to DataHub"):
+            if self.config.datahub_api:
+                self.graph = DataHubGraph(self.config.datahub_api)
+
         with _add_init_error_context("set up framework context"):
             self.ctx = PipelineContext(
                 run_id=self.config.run_id,
-                datahub_api=self.config.datahub_api,
+                graph=self.graph,
                 pipeline_name=self.config.pipeline_name,
                 dry_run=dry_run,
                 preview_mode=preview_mode,
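
Illustrative sketch (not part of the patch): after this change the graph client is built once by Pipeline and injected into PipelineContext, which no longer connects on its own. The server URL and run id below are placeholder values.

    from datahub.ingestion.api.common import PipelineContext
    from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

    # Connect once, up front; connection errors now surface in the
    # "connect to DataHub" init step instead of inside context setup.
    graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

    # The context simply stores the client; passing graph=None means
    # "no DataHub connection" (e.g. for purely file-based runs).
    ctx = PipelineContext(run_id="demo-run", graph=graph)
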
diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
index 40b6f47235..2b1ef8ec50 100644
--- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
+++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from datetime import timedelta
 from enum import auto
 from threading import BoundedSemaphore
-from typing import Union, cast
+from typing import Union
 
 from datahub.cli.cli_utils import set_env_variables_override_config
 from datahub.configuration.common import (
@@ -126,8 +126,7 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
 
     def handle_work_unit_start(self, workunit: WorkUnit) -> None:
         if isinstance(workunit, MetadataWorkUnit):
-            mwu: MetadataWorkUnit = cast(MetadataWorkUnit, workunit)
-            self.treat_errors_as_warnings = mwu.treat_errors_as_warnings
+            self.treat_errors_as_warnings = workunit.treat_errors_as_warnings
 
     def handle_work_unit_end(self, workunit: WorkUnit) -> None:
         pass
diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py
index 9538aadd5e..458c359949 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py
@@ -1331,6 +1331,3 @@ class LookerDashboardSource(TestableSource, StatefulIngestionSourceBase):
 
     def get_report(self) -> SourceReport:
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()
diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py
index 55615f2c1d..e9dbb6d63f 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py
@@ -2163,6 +2163,3 @@ class LookMLSource(StatefulIngestionSourceBase):
 
     def get_report(self):
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()
diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py
index 0603699723..8cc52cfec3 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py
@@ -1228,7 +1228,7 @@ class PowerBiDashboardSource(StatefulIngestionSourceBase):
         # Because job_id is used as dictionary key, we have to set a new job_id
         # Refer to https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py#L390
         self.stale_entity_removal_handler.set_job_id(workspace.id)
-        self.register_stateful_ingestion_usecase_handler(
+        self.state_provider.register_stateful_ingestion_usecase_handler(
             self.stale_entity_removal_handler
         )
diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py
index 4c81ed16fb..148b865ca4 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py
@@ -1283,6 +1283,3 @@ class VerticaSource(SQLAlchemySource):
                 return each["owner_name"]
 
         return None
-
-    def close(self):
-        self.prepare_for_commit()
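
The close() overrides deleted above become redundant once StatefulIngestionSourceBase.close() (added later in this patch) calls prepare_for_commit() through the state provider. A sketch of the resulting pattern, using a hypothetical MySource:

    from datahub.ingestion.source.state.stateful_ingestion_base import (
        StatefulIngestionSourceBase,
    )

    class MySource(StatefulIngestionSourceBase):
        # No close() override needed anymore: the base class close() runs
        # self.state_provider.prepare_for_commit() before super().close().
        ...
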
diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py b/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py
index e7df1f9116..65d3428111 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py
@@ -1,11 +1,9 @@
-from typing import Any, Dict, Iterable, List, Type
+from typing import Any, Dict, Iterable, List, Tuple, Type
 
 import pydantic
 
 from datahub.emitter.mce_builder import make_assertion_urn, make_container_urn
-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityCheckpointStateBase,
-)
+from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
 from datahub.utilities.checkpoint_state_util import CheckpointStateUtil
 from datahub.utilities.dedup_list import deduplicate_list
 from datahub.utilities.urns.urn import guess_entity_type
@@ -56,7 +54,7 @@ def pydantic_state_migrator(mapping: Dict[str, str]) -> classmethod:
     return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)
 
 
-class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
+class GenericCheckpointState(CheckpointStateBase):
     urns: List[str] = pydantic.Field(default_factory=list)
 
     # We store a bit of extra internal-only state so that we can keep the urns list deduplicated.
@@ -85,11 +83,15 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointSt
         self.urns = deduplicate_list(self.urns)
         self._urns_set = set(self.urns)
 
-    @classmethod
-    def get_supported_types(cls) -> List[str]:
-        return ["*"]
-
     def add_checkpoint_urn(self, type: str, urn: str) -> None:
+        """
+        Adds an urn to the checkpoint state.
+
+        :param type: Deprecated parameter, has no effect.
+        :param urn: The urn string.
+        """
+        # TODO: Deprecate the `type` parameter and remove it.
         if urn not in self._urns_set:
             self.urns.append(urn)
             self._urns_set.add(urn)
@@ -97,9 +99,18 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointSt
     def get_urns_not_in(
         self, type: str, other_checkpoint_state: "GenericCheckpointState"
     ) -> Iterable[str]:
+        """
+        Gets the urns present in this checkpoint but not in other_checkpoint_state for the given type.
+
+        :param type: Deprecated. Set to "*".
+        :param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
+        :return: an iterable over the set of urns present in this checkpoint state but not in other_checkpoint_state.
+        """
+
         diff = set(self.urns) - set(other_checkpoint_state.urns)
 
         # To maintain backwards compatibility, we provide this filtering mechanism.
+        # TODO: Deprecate the `type` parameter and remove it.
         if type == "*":
             yield from diff
         elif type == "topic":
@@ -110,6 +121,36 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointSt
     def get_percent_entities_changed(
         self, old_checkpoint_state: "GenericCheckpointState"
     ) -> float:
-        return StaleEntityCheckpointStateBase.compute_percent_entities_changed(
+        """
+        Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
+
+        :param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
+        :return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
+        """
+        return compute_percent_entities_changed(
             [(self.urns, old_checkpoint_state.urns)]
         )
+
+
+def compute_percent_entities_changed(
+    new_old_entity_list: List[Tuple[List[str], List[str]]]
+) -> float:
+    old_count_all = 0
+    overlap_count_all = 0
+    for new_entities, old_entities in new_old_entity_list:
+        (overlap_count, old_count, _,) = get_entity_overlap_and_cardinalities(
+            new_entities=new_entities, old_entities=old_entities
+        )
+        overlap_count_all += overlap_count
+        old_count_all += old_count
+    if old_count_all:
+        return (1 - overlap_count_all / old_count_all) * 100.0
+    return 0.0
+
+
+def get_entity_overlap_and_cardinalities(
+    new_entities: List[str], old_entities: List[str]
+) -> Tuple[int, int, int]:
+    new_set = set(new_entities)
+    old_set = set(old_entities)
+    return len(new_set.intersection(old_set)), len(old_set), len(new_set)
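
A worked example of the formula above, (1 - |overlap| / |old|) * 100.0: one of the two old urns survives into the new run, so 50% of the previously-seen entities changed.

    from datahub.ingestion.source.state.entity_removal_state import (
        compute_percent_entities_changed,
    )

    new_urns = ["urn:a", "urn:c"]  # "urn:c" is new
    old_urns = ["urn:a", "urn:b"]  # "urn:b" disappeared
    # overlap = 1, old count = 2 -> (1 - 1/2) * 100.0 = 50.0
    assert compute_percent_entities_changed([(new_urns, old_urns)]) == 50.0
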
diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py
index 853abe51df..9883bc2b8e 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py
@@ -40,15 +40,17 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointSt
         pipeline_name: Optional[str],
         run_id: str,
     ):
-        self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             ProfilingStatefulIngestionConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)
 
     def _ignore_old_state(self) -> bool:
         if (
@@ -91,7 +93,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointSt
     def get_current_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return None
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(ProfilingCheckpointState, cur_checkpoint.state)
         return cur_state
@@ -108,7 +110,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointSt
     def get_last_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_old_state():
             return None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
             self.job_id, ProfilingCheckpointState
         )
         if last_checkpoint and last_checkpoint.state:
diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py
index 69daaafdb3..459dbe0ce0 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py
@@ -46,14 +46,17 @@ class RedundantRunSkipHandler(
         run_id: str,
     ):
         self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             StatefulRedundantRunSkipConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)
 
     def _ignore_old_state(self) -> bool:
         if (
@@ -114,7 +117,7 @@
     ) -> None:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(BaseUsageCheckpointState, cur_checkpoint.state)
         cur_state.begin_timestamp_millis = start_time_millis
@@ -125,7 +128,7 @@
             return False
         # Determine from the last check point state
         last_successful_pipeline_run_end_time_millis: Optional[int] = None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
             self.job_id, BaseUsageCheckpointState
         )
         if last_checkpoint and last_checkpoint.state:
diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py
index 9303a6af26..c7d535174b 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py
@@ -1,25 +1,14 @@
 import logging
-from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import (
-    Dict,
-    Generic,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TypeVar,
-    cast,
-)
+from typing import Dict, Iterable, Optional, Set, Type, cast
 
 import pydantic
 
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId
 from datahub.ingestion.api.workunit import MetadataWorkUnit
-from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionConfig,
     StatefulIngestionConfigBase,
@@ -61,102 +50,26 @@ class StaleEntityRemovalSourceReport(StatefulIngestionReport):
         self.soft_deleted_stale_entities.append(urn)
 
 
-Derived = TypeVar("Derived", bound=CheckpointStateBase)
-
-
-class StaleEntityCheckpointStateBase(CheckpointStateBase, ABC, Generic[Derived]):
-    """
-    Defines the abstract interface for the checkpoint states that are used for stale entity removal.
-    Examples include sql_common state for tracking table and & view urns,
-    dbt that tracks node & assertion urns, kafka state tracking topic urns.
-    """
-
-    @classmethod
-    @abstractmethod
-    def get_supported_types(cls) -> List[str]:
-        pass
-
-    @abstractmethod
-    def add_checkpoint_urn(self, type: str, urn: str) -> None:
-        """
-        Adds an urn into the list used for tracking the type.
-        :param type: The type of the urn such as a 'table', 'view',
-        'node', 'topic', 'assertion' that the concrete sub-class understands.
-        :param urn: The urn string
-        :return: None.
-        """
-        pass
-
-    @abstractmethod
-    def get_urns_not_in(
-        self, type: str, other_checkpoint_state: Derived
-    ) -> Iterable[str]:
-        """
-        Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
-        :param type: The type of the urn such as a 'table', 'view',
-        'node', 'topic', 'assertion' that the concrete sub-class understands.
-        :param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
-        :return: an iterable to the set of urns present in this checkpoing state but not in the other_checkpoint.
-        """
-        pass
-
-    @abstractmethod
-    def get_percent_entities_changed(self, old_checkpoint_state: Derived) -> float:
-        """
-        Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
-        :param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
-        :return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
-        """
-        pass
-
-    @staticmethod
-    def compute_percent_entities_changed(
-        new_old_entity_list: List[Tuple[List[str], List[str]]]
-    ) -> float:
-        old_count_all = 0
-        overlap_count_all = 0
-        for new_entities, old_entities in new_old_entity_list:
-            (
-                overlap_count,
-                old_count,
-                _,
-            ) = StaleEntityCheckpointStateBase.get_entity_overlap_and_cardinalities(
-                new_entities=new_entities, old_entities=old_entities
-            )
-            overlap_count_all += overlap_count
-            old_count_all += old_count
-        if old_count_all:
-            return (1 - overlap_count_all / old_count_all) * 100.0
-        return 0.0
-
-    @staticmethod
-    def get_entity_overlap_and_cardinalities(
-        new_entities: List[str], old_entities: List[str]
-    ) -> Tuple[int, int, int]:
-        new_set = set(new_entities)
-        old_set = set(old_entities)
-        return len(new_set.intersection(old_set)), len(old_set), len(new_set)
-
-
 class StaleEntityRemovalHandler(
-    StatefulIngestionUsecaseHandlerBase[StaleEntityCheckpointStateBase]
+    StatefulIngestionUsecaseHandlerBase["GenericCheckpointState"]
 ):
     """
     The stateful ingestion helper class that handles stale entity removal.
 
     This contains the generic logic for all sources that need to support stale entity removal for all the states
-    derived from StaleEntityCheckpointStateBase. This uses the template method pattern on CRTP based derived state
-    class hierarchies.
+    derived from GenericCheckpointState.
""" def __init__( self, source: StatefulIngestionSourceBase, config: StatefulIngestionConfigBase[StatefulStaleMetadataRemovalConfig], - state_type_class: Type[StaleEntityCheckpointStateBase], + state_type_class: Type["GenericCheckpointState"], pipeline_name: Optional[str], run_id: str, ): self.source = source + self.state_provider = source.state_provider + self.state_type_class = state_type_class self.pipeline_name = pipeline_name self.run_id = run_id @@ -166,7 +79,7 @@ class StaleEntityRemovalHandler( self.checkpointing_enabled: bool = ( True if ( - source.is_stateful_ingestion_configured() + self.state_provider.is_stateful_ingestion_configured() and self.stateful_ingestion_config and self.stateful_ingestion_config.remove_stale_metadata ) @@ -174,7 +87,7 @@ class StaleEntityRemovalHandler( ) self._job_id = self._init_job_id() self._urns_to_skip: Set[str] = set() - self.source.register_stateful_ingestion_usecase_handler(self) + self.state_provider.register_stateful_ingestion_usecase_handler(self) @classmethod def compute_job_id( @@ -246,7 +159,7 @@ class StaleEntityRemovalHandler( ) return None - def _create_soft_delete_workunit(self, urn: str, type: str) -> MetadataWorkUnit: + def _create_soft_delete_workunit(self, urn: str) -> MetadataWorkUnit: logger.info(f"Soft-deleting stale entity - {urn}") mcp = MetadataChangeProposalWrapper( entityUrn=urn, @@ -278,20 +191,16 @@ class StaleEntityRemovalHandler( if not self.is_checkpointing_enabled() or self._ignore_old_state(): return logger.debug("Checking for stale entity removal.") - last_checkpoint: Optional[Checkpoint] = self.source.get_last_checkpoint( + last_checkpoint: Optional[Checkpoint] = self.state_provider.get_last_checkpoint( self.job_id, self.state_type_class ) if not last_checkpoint: return - cur_checkpoint = self.source.get_current_checkpoint(self.job_id) + cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id) assert cur_checkpoint is not None # Get the underlying states - last_checkpoint_state = cast( - StaleEntityCheckpointStateBase, last_checkpoint.state - ) - cur_checkpoint_state = cast( - StaleEntityCheckpointStateBase, cur_checkpoint.state - ) + last_checkpoint_state = cast(GenericCheckpointState, last_checkpoint.state) + cur_checkpoint_state = cast(GenericCheckpointState, cur_checkpoint.state) # Check if the entity delta is below the fail-safe threshold. 
         entity_difference_percent = cur_checkpoint_state.get_percent_entities_changed(
@@ -316,21 +225,20 @@
             return
 
         # Everything looks good, emit the soft-deletion workunits
-        for type in self.state_type_class.get_supported_types():
-            for urn in last_checkpoint_state.get_urns_not_in(
-                type=type, other_checkpoint_state=cur_checkpoint_state
-            ):
-                if urn in self._urns_to_skip:
-                    logger.debug(
-                        f"Not soft-deleting entity {urn} since it is in urns_to_skip"
-                    )
-                    continue
-                yield self._create_soft_delete_workunit(urn, type)
+        for urn in last_checkpoint_state.get_urns_not_in(
+            type="*", other_checkpoint_state=cur_checkpoint_state
+        ):
+            if urn in self._urns_to_skip:
+                logger.debug(
+                    f"Not soft-deleting entity {urn} since it is in urns_to_skip"
+                )
+                continue
+            yield self._create_soft_delete_workunit(urn)
 
     def add_entity_to_state(self, type: str, urn: str) -> None:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
-        cur_state = cast(StaleEntityCheckpointStateBase, cur_checkpoint.state)
+        cur_state = cast(GenericCheckpointState, cur_checkpoint.state)
         cur_state.add_checkpoint_urn(type=type, urn=urn)
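
A rough sketch of the diffing the handler now performs directly on GenericCheckpointState (the type argument survives only for backwards compatibility):

    from datahub.ingestion.source.state.entity_removal_state import (
        GenericCheckpointState,
    )

    last = GenericCheckpointState(urns=["urn:li:dataset:one", "urn:li:dataset:two"])
    cur = GenericCheckpointState(urns=["urn:li:dataset:one"])

    # Anything in the last run but not in the current one is considered stale.
    stale = list(last.get_urns_not_in(type="*", other_checkpoint_state=cur))
    assert stale == ["urn:li:dataset:two"]

    # One of the two previously-seen entities changed -> 50%.
    assert cur.get_percent_entities_changed(last) == 50.0
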
diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py
index 9df1e531a7..0c7a8349bc 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py
@@ -27,10 +27,7 @@ from datahub.ingestion.source.state.use_case_handler import (
 from datahub.ingestion.source.state_provider.state_provider_registry import (
     ingestion_checkpoint_provider_registry,
 )
-from datahub.metadata.schema_classes import (
-    DatahubIngestionCheckpointClass,
-    DatahubIngestionRunSummaryClass,
-)
+from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -171,13 +168,8 @@ class StatefulIngestionSourceBase(Source):
         ctx: PipelineContext,
     ) -> None:
         super().__init__(ctx)
-        self.stateful_ingestion_config = config.stateful_ingestion
-        self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
-        self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
-        self.run_summaries_to_report: Dict[JobId, DatahubIngestionRunSummaryClass] = {}
         self.report: StatefulIngestionReport = StatefulIngestionReport()
-        self._initialize_checkpointing_state_provider()
-        self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
+        self.state_provider = StateProviderWrapper(config.stateful_ingestion, ctx)
 
     def warn(self, log: logging.Logger, key: str, reason: str) -> None:
         self.report.report_warning(key, reason)
@@ -187,6 +179,26 @@ class StatefulIngestionSourceBase(Source):
         self.report.report_failure(key, reason)
         log.error(f"{key} => {reason}")
 
+    def close(self) -> None:
+        self.state_provider.prepare_for_commit()
+        super().close()
+
+
+class StateProviderWrapper:
+    def __init__(
+        self,
+        config: Optional[StatefulIngestionConfig],
+        ctx: PipelineContext,
+    ) -> None:
+        self.ctx = ctx
+        self.stateful_ingestion_config = config
+
+        self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
+        self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
+        self.report: StatefulIngestionReport = StatefulIngestionReport()
+        self._initialize_checkpointing_state_provider()
+        self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
+
     #
     # Checkpointing specific support.
     #
@@ -383,7 +395,3 @@ class StatefulIngestionSourceBase(Source):
     def prepare_for_commit(self) -> None:
         """NOTE: Sources should call this method from their close method."""
         self._prepare_checkpoint_states_for_commit()
-
-    def close(self) -> None:
-        self.prepare_for_commit()
-        super().close()
diff --git a/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py b/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py
index 42b9da3b98..7005bc2e44 100644
--- a/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py
+++ b/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py
@@ -1,17 +1,16 @@
 import json
 import pathlib
 from functools import partial
-from typing import List, Optional, cast
+from typing import List
 from unittest.mock import patch
 
 from freezegun import freeze_time
 
 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.identity.azure_ad import AzureADConfig, AzureADSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.identity.azure_ad import AzureADConfig
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
 
@@ -328,15 +327,6 @@ def test_azure_source_ingestion_disabled(pytestconfig, mock_datahub_graph, tmp_p
     )
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    azure_ad_source = cast(AzureADSource, pipeline.source)
-    return azure_ad_source.get_current_checkpoint(
-        azure_ad_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_azure_ad_stateful_ingestion(
     pytestconfig, tmp_path, mock_time, mock_datahub_graph
diff --git a/metadata-ingestion/tests/integration/iceberg/test_iceberg.py b/metadata-ingestion/tests/integration/iceberg/test_iceberg.py
index 032ba93b1b..b26b574e54 100644
--- a/metadata-ingestion/tests/integration/iceberg/test_iceberg.py
+++ b/metadata-ingestion/tests/integration/iceberg/test_iceberg.py
@@ -1,5 +1,5 @@
 from pathlib import PosixPath
-from typing import Any, Dict, Optional, Union, cast
+from typing import Any, Dict, Union
 from unittest.mock import patch
 
 import pytest
@@ -8,11 +8,9 @@ from iceberg.core.filesystem.file_status import FileStatus
 from iceberg.core.filesystem.local_filesystem import LocalFileSystem
 
 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.iceberg.iceberg import IcebergSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -22,15 +20,6 @@ GMS_PORT = 8080
 GMS_SERVER = f"http://localhost:{GMS_PORT}"
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    iceberg_source = cast(IcebergSource, pipeline.source)
-    return iceberg_source.get_current_checkpoint(
-        iceberg_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_iceberg_ingest(pytestconfig, tmp_path, mock_time):
diff --git a/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py b/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py
index b229fc8bee..8efea3ced6 100644
--- a/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py
+++ b/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py
@@ -1,6 +1,6 @@
 import subprocess
 import time
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 from unittest import mock
 
 import pytest
@@ -8,12 +8,12 @@ import requests
 from freezegun import freeze_time
 
 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.click_helpers import run_datahub_cmd
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
 
@@ -489,14 +489,3 @@ def test_kafka_connect_ingest_stateful(
         "urn:li:dataJob:(urn:li:dataFlow:(kafka-connect,connect-instance-1.mysql_source2,PROD),librarydb.member)",
     ]
     assert sorted(deleted_job_urns) == sorted(difference_job_urns)
-
-
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    from datahub.ingestion.source.kafka_connect import KafkaConnectSource
-
-    kafka_connect_source = cast(KafkaConnectSource, pipeline.source)
-    return kafka_connect_source.get_current_checkpoint(
-        kafka_connect_source.stale_entity_removal_handler.job_id
-    )
diff --git a/metadata-ingestion/tests/integration/kafka/test_kafka_state.py b/metadata-ingestion/tests/integration/kafka/test_kafka_state.py
index 52940696df..6dfc0427f7 100644
--- a/metadata-ingestion/tests/integration/kafka/test_kafka_state.py
+++ b/metadata-ingestion/tests/integration/kafka/test_kafka_state.py
@@ -1,17 +1,14 @@
 import time
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List
 from unittest.mock import patch
 
 import pytest
 from confluent_kafka.admin import AdminClient, NewTopic
 from freezegun import freeze_time
 
-from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.kafka import KafkaSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -81,15 +78,6 @@ class KafkaTopicsCxtManager:
         self.delete_kafka_topics(self.topics)
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    kafka_source = cast(KafkaSource, pipeline.source)
-    return kafka_source.get_current_checkpoint(
-        kafka_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_kafka_ingest_with_stateful(
diff --git a/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py b/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py
index 1f6fa3d405..c11dda8a25 100644
--- a/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py
+++ b/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py
@@ -1,18 +1,15 @@
 import pathlib
 import time
-from typing import Optional, cast
 from unittest import mock
 
 import pytest
 from freezegun import freeze_time
 
 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.ldap import LDAPSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
 
@@ -90,15 +87,6 @@ def ldap_ingest_common(
     return pipeline
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    ldap_source = cast(LDAPSource, pipeline.source)
-    return ldap_source.get_current_checkpoint(
-        ldap_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_ldap_stateful(
diff --git a/metadata-ingestion/tests/integration/looker/test_looker.py b/metadata-ingestion/tests/integration/looker/test_looker.py
index 3e1618ecfc..c6f6583740 100644
--- a/metadata-ingestion/tests/integration/looker/test_looker.py
+++ b/metadata-ingestion/tests/integration/looker/test_looker.py
@@ -27,11 +27,10 @@ from datahub.ingestion.source.looker.looker_query_model import (
     LookViewField,
     UserViewField,
 )
-from datahub.ingestion.source.looker.looker_source import LookerDashboardSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
 
@@ -719,12 +718,3 @@ def test_looker_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_
     assert len(difference_dashboard_urns) == 1
     deleted_dashboard_urns = ["urn:li:dashboard:(looker,dashboards.11)"]
     assert sorted(deleted_dashboard_urns) == sorted(difference_dashboard_urns)
-
-
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    dbt_source = cast(LookerDashboardSource, pipeline.source)
-    return dbt_source.get_current_checkpoint(
-        dbt_source.stale_entity_removal_handler.job_id
-    )
diff --git a/metadata-ingestion/tests/integration/lookml/test_lookml.py b/metadata-ingestion/tests/integration/lookml/test_lookml.py
index 173b0983ed..85eb4dcd92 100644
--- a/metadata-ingestion/tests/integration/lookml/test_lookml.py
+++ b/metadata-ingestion/tests/integration/lookml/test_lookml.py
@@ -1,6 +1,6 @@
 import logging
 import pathlib
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 from unittest import mock
 
 import pydantic
@@ -15,10 +15,8 @@ from datahub.ingestion.source.file import read_metadata_file
 from datahub.ingestion.source.looker.lookml_source import (
     LookerModel,
     LookerRefinementResolver,
-    LookMLSource,
     LookMLSourceConfig,
 )
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.metadata.schema_classes import (
     DatasetSnapshotClass,
@@ -27,6 +25,7 @@ from datahub.metadata.schema_classes import (
 )
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
 
@@ -847,15 +846,6 @@ def test_lookml_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_
     assert sorted(deleted_dataset_urns) == sorted(difference_dataset_urns)
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    dbt_source = cast(LookMLSource, pipeline.source)
-    return dbt_source.get_current_checkpoint(
-        dbt_source.stale_entity_removal_handler.job_id
-    )
-
-
 def test_lookml_base_folder():
     fake_api = {
         "base_url": "https://filler.cloud.looker.com",
diff --git a/metadata-ingestion/tests/integration/okta/test_okta.py b/metadata-ingestion/tests/integration/okta/test_okta.py
index 04f78efacf..926100ae5b 100644
--- a/metadata-ingestion/tests/integration/okta/test_okta.py
+++ b/metadata-ingestion/tests/integration/okta/test_okta.py
@@ -1,7 +1,6 @@
 import asyncio
 import pathlib
 from functools import partial
-from typing import Optional, cast
 from unittest.mock import Mock, patch
 
 import jsonpickle
@@ -10,11 +9,10 @@ from freezegun import freeze_time
 from okta.models import Group, User
 
 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.identity.okta import OktaConfig, OktaSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.identity.okta import OktaConfig
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
 
@@ -203,15 +201,6 @@ def test_okta_source_custom_user_name_regex(pytestconfig, mock_datahub_graph, tm
     )
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    azure_ad_source = cast(OktaSource, pipeline.source)
-    return azure_ad_source.get_current_checkpoint(
-        azure_ad_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta"
diff --git a/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py b/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py
index 8741442645..077b48ca17 100644
--- a/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py
+++ b/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py
@@ -216,9 +216,11 @@ def get_current_checkpoint_from_pipeline(
 ) -> Dict[JobId, Optional[Checkpoint[Any]]]:
     powerbi_source = cast(PowerBiDashboardSource, pipeline.source)
     checkpoints = {}
-    for job_id in powerbi_source._usecase_handlers.keys():
+    for job_id in powerbi_source.state_provider._usecase_handlers.keys():
         # for multi-workspace checkpoint, every good checkpoint will have an unique workspaceid suffix
-        checkpoints[job_id] = powerbi_source.get_current_checkpoint(job_id)
+        checkpoints[job_id] = powerbi_source.state_provider.get_current_checkpoint(
+            job_id
+        )
     return checkpoints
diff --git a/metadata-ingestion/tests/integration/superset/test_superset.py b/metadata-ingestion/tests/integration/superset/test_superset.py
index 656819dbcd..bc299e3651 100644
--- a/metadata-ingestion/tests/integration/superset/test_superset.py
+++ b/metadata-ingestion/tests/integration/superset/test_superset.py
@@ -1,15 +1,13 @@
-from typing import Any, Dict, Optional, cast
+from typing import Any, Dict
 from unittest.mock import patch
 
 import pytest
 from freezegun import freeze_time
 
 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
-from datahub.ingestion.source.superset import SupersetSource
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -19,15 +17,6 @@ GMS_PORT = 8080
 GMS_SERVER = f"http://localhost:{GMS_PORT}"
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    superset_source = cast(SupersetSource, pipeline.source)
-    return superset_source.get_current_checkpoint(
-        superset_source.stale_entity_removal_handler.job_id
-    )
-
-
 def register_mock_api(request_mock: Any, override_data: dict = {}) -> None:
     api_vs_response = {
         "mock://mock-domain.superset.com/api/v1/security/login": {
diff --git a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py
index 496a7b7486..e6af282976 100644
--- a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py
+++ b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py
@@ -2,7 +2,7 @@ import json
 import logging
 import pathlib
 import sys
-from typing import Optional, cast
+from typing import cast
 from unittest import mock
 
 import pytest
@@ -17,8 +17,6 @@ from tableauserverclient.models import (
 
 from datahub.configuration.source_common import DEFAULT_ENV
 from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.tableau import TableauConfig, TableauSource
 from datahub.ingestion.source.tableau_common import (
     TableauLineageOverrides,
@@ -31,6 +29,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
 from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
 
@@ -253,15 +252,6 @@ def tableau_ingest_common(
     return pipeline
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    tableau_source = cast(TableauSource, pipeline.source)
-    return tableau_source.get_current_checkpoint(
-        tableau_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph):
diff --git a/metadata-ingestion/tests/test_helpers/state_helpers.py b/metadata-ingestion/tests/test_helpers/state_helpers.py
index e4a5ffc007..c2d1fa3d50 100644
--- a/metadata-ingestion/tests/test_helpers/state_helpers.py
+++ b/metadata-ingestion/tests/test_helpers/state_helpers.py
@@ -11,6 +11,14 @@ from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
 )
 from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.run.pipeline import Pipeline
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.state.stale_entity_removal_handler import (
+    StaleEntityRemovalHandler,
+)
+from datahub.ingestion.source.state.stateful_ingestion_base import (
+    StatefulIngestionSourceBase,
+)
 
 
 def validate_all_providers_have_committed_successfully(
@@ -91,3 +99,15 @@ def mock_datahub_graph():
     mock_datahub_graph_ctx = MockDataHubGraphContext()
 
     return mock_datahub_graph_ctx.mock_graph
+
+
+def get_current_checkpoint_from_pipeline(
+    pipeline: Pipeline,
+) -> Optional[Checkpoint[GenericCheckpointState]]:
+    # TODO: This only works for stale entity removal. We need to generalize this.
+
+    stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
+    stale_entity_removal_handler: StaleEntityRemovalHandler = stateful_source.stale_entity_removal_handler  # type: ignore
+    return stateful_source.state_provider.get_current_checkpoint(
+        stale_entity_removal_handler.job_id
+    )
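
With the helper centralized, a stateful-ingestion test reduces to roughly the following (a sketch; run_and_get_pipeline and pipeline_config_dict are the usual per-test fixtures):

    from tests.test_helpers.state_helpers import (
        get_current_checkpoint_from_pipeline,
        run_and_get_pipeline,
    )

    pipeline = run_and_get_pipeline(pipeline_config_dict)
    checkpoint = get_current_checkpoint_from_pipeline(pipeline)
    assert checkpoint is not None
    assert checkpoint.state.urns  # the urns captured for stale-entity removal
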
diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py
index cfada6c3a4..aeb3ffd87c 100644
--- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py
+++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py
@@ -2,8 +2,8 @@ from typing import Dict, List, Tuple
 
 import pytest
 
-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityCheckpointStateBase,
+from datahub.ingestion.source.state.entity_removal_state import (
+    compute_percent_entities_changed,
 )
 
 OldNewEntLists = List[Tuple[List[str], List[str]]]
@@ -43,9 +43,5 @@ old_new_ent_tests: Dict[str, Tuple[OldNewEntLists, float]] = {
 def test_change_percent(
     new_old_entity_list: OldNewEntLists, expected_percent_change: float
 ) -> None:
-    actual_percent_change = (
-        StaleEntityCheckpointStateBase.compute_percent_entities_changed(
-            new_old_entity_list
-        )
-    )
+    actual_percent_change = compute_percent_entities_changed(new_old_entity_list)
     assert actual_percent_change == expected_percent_change
diff --git a/metadata-ingestion/tests/unit/test_glue_source.py b/metadata-ingestion/tests/unit/test_glue_source.py
index 294bea9cdf..23dc3b97e0 100644
--- a/metadata-ingestion/tests/unit/test_glue_source.py
+++ b/metadata-ingestion/tests/unit/test_glue_source.py
@@ -10,10 +10,8 @@ from freezegun import freeze_time
 
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
-from datahub.ingestion.run.pipeline import Pipeline
 from datahub.ingestion.sink.file import write_metadata_file
 from datahub.ingestion.source.aws.glue import GlueSource, GlueSourceConfig
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.sql_common_state import (
     BaseSQLAlchemyCheckpointState,
 )
@@ -26,6 +24,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -240,15 +239,6 @@ def test_config_without_platform():
     assert source.platform == "glue"
 
 
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    glue_source = cast(GlueSource, pipeline.source)
-    return glue_source.get_current_checkpoint(
-        glue_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     test_resources_dir = pytestconfig.rootpath / "tests/unit/glue"
diff --git a/smoke-test/tests/test_stateful_ingestion.py b/smoke-test/tests/test_stateful_ingestion.py
index e7b012a788..bf423eacff 100644
--- a/smoke-test/tests/test_stateful_ingestion.py
+++ b/smoke-test/tests/test_stateful_ingestion.py
@@ -1,13 +1,13 @@
 from typing import Any, Dict, Optional, cast
 
-from sqlalchemy import create_engine
-from sqlalchemy.sql import text
-
 from datahub.ingestion.api.committable import StatefulCommittable
 from datahub.ingestion.run.pipeline import Pipeline
 from datahub.ingestion.source.sql.mysql import MySQLConfig, MySQLSource
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from sqlalchemy import create_engine
+from sqlalchemy.sql import text
+
 from tests.utils import (
     get_gms_url,
     get_mysql_password,
@@ -49,8 +49,9 @@ def test_stateful_ingestion(wait_for_healthchecks):
     def get_current_checkpoint_from_pipeline(
         pipeline: Pipeline,
     ) -> Optional[Checkpoint[GenericCheckpointState]]:
+        # TODO: Refactor to use the helper method in the metadata-ingestion tests, instead of copying it here.
        mysql_source = cast(MySQLSource, pipeline.source)
-        return mysql_source.get_current_checkpoint(
+        return mysql_source.state_provider.get_current_checkpoint(
             mysql_source.stale_entity_removal_handler.job_id
         )