refactor(ingest): simplify stateful ingestion provider interface (#8104)

Co-authored-by: Tamas Nemeth <treff7es@gmail.com>
Harshal Sheth authored on 2023-05-24 01:27:57 +05:30, committed by GitHub
parent afd65e16fb
commit b0f8c3de1e
27 changed files with 179 additions and 323 deletions

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Tuple, Type
 from datahub.emitter.mce_builder import set_dataset_urn_to_lower
 from datahub.ingestion.api.committable import Committable
-from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
+from datahub.ingestion.graph.client import DataHubGraph

 if TYPE_CHECKING:
     from datahub.ingestion.run.pipeline import PipelineConfig
@@ -43,22 +43,19 @@ class PipelineContext:
     def __init__(
         self,
         run_id: str,
-        datahub_api: Optional["DatahubClientConfig"] = None,
+        graph: Optional[DataHubGraph] = None,
         pipeline_name: Optional[str] = None,
         dry_run: bool = False,
         preview_mode: bool = False,
         pipeline_config: Optional["PipelineConfig"] = None,
     ) -> None:
         self.pipeline_config = pipeline_config
+        self.graph = graph
         self.run_id = run_id
         self.pipeline_name = pipeline_name
         self.dry_run_mode = dry_run
         self.preview_mode = preview_mode
         self.checkpointers: Dict[str, Committable] = {}
-        try:
-            self.graph = DataHubGraph(datahub_api) if datahub_api is not None else None
-        except Exception as e:
-            raise Exception(f"Failed to connect to DataHub: {e}") from e

         self._set_dataset_urn_to_lower_if_needed()
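
Net effect of this hunk: PipelineContext no longer constructs the graph client (and can no longer fail on connect); the caller hands it a ready DataHubGraph. A minimal sketch of the new wiring, assuming a GMS reachable at localhost:8080:

    from datahub.ingestion.api.common import PipelineContext
    from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

    # Connection errors now surface at graph construction time, in the
    # caller, rather than inside the context's __init__.
    graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
    ctx = PipelineContext(run_id="demo-run", graph=graph, pipeline_name="demo")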

View File

@@ -27,6 +27,7 @@ from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
 from datahub.ingestion.api.source import Extractor, Source
 from datahub.ingestion.api.transform import Transformer
 from datahub.ingestion.extractor.extractor_registry import extractor_registry
+from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.reporting.reporting_provider_registry import (
     reporting_provider_registry,
 )
@@ -183,10 +184,15 @@ class Pipeline:
         self.last_time_printed = int(time.time())
         self.cli_report = CliReport()

+        self.graph = None
+        with _add_init_error_context("connect to DataHub"):
+            if self.config.datahub_api:
+                self.graph = DataHubGraph(self.config.datahub_api)
+
         with _add_init_error_context("set up framework context"):
             self.ctx = PipelineContext(
                 run_id=self.config.run_id,
-                datahub_api=self.config.datahub_api,
+                graph=self.graph,
                 pipeline_name=self.config.pipeline_name,
                 dry_run=dry_run,
                 preview_mode=preview_mode,

View File

@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from datetime import timedelta
 from enum import auto
 from threading import BoundedSemaphore
-from typing import Union, cast
+from typing import Union

 from datahub.cli.cli_utils import set_env_variables_override_config
 from datahub.configuration.common import (
@@ -126,8 +126,7 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):

     def handle_work_unit_start(self, workunit: WorkUnit) -> None:
         if isinstance(workunit, MetadataWorkUnit):
-            mwu: MetadataWorkUnit = cast(MetadataWorkUnit, workunit)
-            self.treat_errors_as_warnings = mwu.treat_errors_as_warnings
+            self.treat_errors_as_warnings = workunit.treat_errors_as_warnings

     def handle_work_unit_end(self, workunit: WorkUnit) -> None:
         pass
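
The dropped cast was dead weight: an isinstance check already narrows the static type inside the branch. A standalone illustration of the pattern, using stand-in classes:

    class WorkUnit:
        pass

    class MetadataWorkUnit(WorkUnit):
        treat_errors_as_warnings: bool = False

    def handle_work_unit_start(workunit: WorkUnit) -> None:
        if isinstance(workunit, MetadataWorkUnit):
            # mypy narrows `workunit` to MetadataWorkUnit here, so the
            # attribute access type-checks without an explicit cast.
            print(workunit.treat_errors_as_warnings)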

View File

@@ -1331,6 +1331,3 @@ class LookerDashboardSource(TestableSource, StatefulIngestionSourceBase):

     def get_report(self) -> SourceReport:
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()

View File

@@ -2163,6 +2163,3 @@ class LookMLSource(StatefulIngestionSourceBase):

     def get_report(self):
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()

View File

@@ -1228,7 +1228,7 @@ class PowerBiDashboardSource(StatefulIngestionSourceBase):
             # Because job_id is used as dictionary key, we have to set a new job_id
             # Refer to https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py#L390
             self.stale_entity_removal_handler.set_job_id(workspace.id)
-            self.register_stateful_ingestion_usecase_handler(
+            self.state_provider.register_stateful_ingestion_usecase_handler(
                 self.stale_entity_removal_handler
             )

View File

@@ -1283,6 +1283,3 @@ class VerticaSource(SQLAlchemySource):
                 return each["owner_name"]
         return None
-
-    def close(self):
-        self.prepare_for_commit()

View File

@@ -1,11 +1,9 @@
-from typing import Any, Dict, Iterable, List, Type
+from typing import Any, Dict, Iterable, List, Tuple, Type

 import pydantic

 from datahub.emitter.mce_builder import make_assertion_urn, make_container_urn
-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityCheckpointStateBase,
-)
+from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
 from datahub.utilities.checkpoint_state_util import CheckpointStateUtil
 from datahub.utilities.dedup_list import deduplicate_list
 from datahub.utilities.urns.urn import guess_entity_type
@@ -56,7 +54,7 @@ def pydantic_state_migrator(mapping: Dict[str, str]) -> classmethod:
     return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)

-class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
+class GenericCheckpointState(CheckpointStateBase):
     urns: List[str] = pydantic.Field(default_factory=list)

     # We store a bit of extra internal-only state so that we can keep the urns list deduplicated.
@@ -85,11 +83,15 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
         self.urns = deduplicate_list(self.urns)
         self._urns_set = set(self.urns)

-    @classmethod
-    def get_supported_types(cls) -> List[str]:
-        return ["*"]
-
     def add_checkpoint_urn(self, type: str, urn: str) -> None:
+        """
+        Adds an urn into the list used for tracking the type.
+        :param type: Deprecated parameter, has no effect.
+        :param urn: The urn string
+        """
+        # TODO: Deprecate the `type` parameter and remove it.
         if urn not in self._urns_set:
             self.urns.append(urn)
             self._urns_set.add(urn)
@@ -97,9 +99,18 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
     def get_urns_not_in(
         self, type: str, other_checkpoint_state: "GenericCheckpointState"
     ) -> Iterable[str]:
+        """
+        Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
+        :param type: Deprecated. Set to "*".
+        :param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
+        :return: an iterable to the set of urns present in this checkpoint state but not in the other_checkpoint.
+        """
         diff = set(self.urns) - set(other_checkpoint_state.urns)

         # To maintain backwards compatibility, we provide this filtering mechanism.
+        # TODO: Deprecate the `type` parameter and remove it.
         if type == "*":
             yield from diff
         elif type == "topic":
@@ -110,6 +121,36 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
     def get_percent_entities_changed(
         self, old_checkpoint_state: "GenericCheckpointState"
     ) -> float:
-        return StaleEntityCheckpointStateBase.compute_percent_entities_changed(
+        """
+        Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
+        :param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
+        :return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
+        """
+        return compute_percent_entities_changed(
             [(self.urns, old_checkpoint_state.urns)]
         )
+
+
+def compute_percent_entities_changed(
+    new_old_entity_list: List[Tuple[List[str], List[str]]]
+) -> float:
+    old_count_all = 0
+    overlap_count_all = 0
+    for new_entities, old_entities in new_old_entity_list:
+        (overlap_count, old_count, _,) = get_entity_overlap_and_cardinalities(
+            new_entities=new_entities, old_entities=old_entities
+        )
+        overlap_count_all += overlap_count
+        old_count_all += old_count
+
+    if old_count_all:
+        return (1 - overlap_count_all / old_count_all) * 100.0
+    return 0.0
+
+
+def get_entity_overlap_and_cardinalities(
+    new_entities: List[str], old_entities: List[str]
+) -> Tuple[int, int, int]:
+    new_set = set(new_entities)
+    old_set = set(old_entities)
+    return len(new_set.intersection(old_set)), len(old_set), len(new_set)
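
With the abstract base gone, the concrete state type carries the diff and percent-change logic itself. A small sketch of the resulting behavior, assuming GenericCheckpointState can be built directly from a urn list (it is a pydantic model):

    from datahub.ingestion.source.state.entity_removal_state import (
        GenericCheckpointState,
    )

    old_state = GenericCheckpointState(urns=["urn:li:chart:(looker,1)", "urn:li:chart:(looker,2)"])
    new_state = GenericCheckpointState(urns=["urn:li:chart:(looker,2)", "urn:li:chart:(looker,3)"])

    # `type` is vestigial; "*" matches everything.
    stale = set(old_state.get_urns_not_in(type="*", other_checkpoint_state=new_state))
    assert stale == {"urn:li:chart:(looker,1)"}

    # One of the two old urns is gone -> 50% changed.
    assert new_state.get_percent_entities_changed(old_state) == 50.0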

View File

@@ -40,15 +40,17 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointState]):
         pipeline_name: Optional[str],
         run_id: str,
     ):
-        self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             ProfilingStatefulIngestionConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)

     def _ignore_old_state(self) -> bool:
         if (
@@ -91,7 +93,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointState]):
     def get_current_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return None
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(ProfilingCheckpointState, cur_checkpoint.state)
         return cur_state
@@ -108,7 +110,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointState]):
     def get_last_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_old_state():
             return None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
             self.job_id, ProfilingCheckpointState
         )
         if last_checkpoint and last_checkpoint.state:

View File

@@ -46,14 +46,17 @@ class RedundantRunSkipHandler(
         run_id: str,
     ):
         self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             StatefulRedundantRunSkipConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)

     def _ignore_old_state(self) -> bool:
         if (
@@ -114,7 +117,7 @@ class RedundantRunSkipHandler(
     ) -> None:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(BaseUsageCheckpointState, cur_checkpoint.state)
         cur_state.begin_timestamp_millis = start_time_millis
@@ -125,7 +128,7 @@ class RedundantRunSkipHandler(
             return False
         # Determine from the last check point state
         last_successful_pipeline_run_end_time_millis: Optional[int] = None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
            self.job_id, BaseUsageCheckpointState
         )
         if last_checkpoint and last_checkpoint.state:

View File

@@ -1,25 +1,14 @@
 import logging
-from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import (
-    Dict,
-    Generic,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TypeVar,
-    cast,
-)
+from typing import Dict, Iterable, Optional, Set, Type, cast

 import pydantic

 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId
 from datahub.ingestion.api.workunit import MetadataWorkUnit
-from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionConfig,
     StatefulIngestionConfigBase,
@@ -61,102 +50,26 @@ class StaleEntityRemovalSourceReport(StatefulIngestionReport):
         self.soft_deleted_stale_entities.append(urn)

-Derived = TypeVar("Derived", bound=CheckpointStateBase)
-
-
-class StaleEntityCheckpointStateBase(CheckpointStateBase, ABC, Generic[Derived]):
-    """
-    Defines the abstract interface for the checkpoint states that are used for stale entity removal.
-    Examples include sql_common state for tracking table and & view urns,
-    dbt that tracks node & assertion urns, kafka state tracking topic urns.
-    """
-
-    @classmethod
-    @abstractmethod
-    def get_supported_types(cls) -> List[str]:
-        pass
-
-    @abstractmethod
-    def add_checkpoint_urn(self, type: str, urn: str) -> None:
-        """
-        Adds an urn into the list used for tracking the type.
-        :param type: The type of the urn such as a 'table', 'view',
-        'node', 'topic', 'assertion' that the concrete sub-class understands.
-        :param urn: The urn string
-        :return: None.
-        """
-        pass
-
-    @abstractmethod
-    def get_urns_not_in(
-        self, type: str, other_checkpoint_state: Derived
-    ) -> Iterable[str]:
-        """
-        Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
-        :param type: The type of the urn such as a 'table', 'view',
-        'node', 'topic', 'assertion' that the concrete sub-class understands.
-        :param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
-        :return: an iterable to the set of urns present in this checkpoing state but not in the other_checkpoint.
-        """
-        pass
-
-    @abstractmethod
-    def get_percent_entities_changed(self, old_checkpoint_state: Derived) -> float:
-        """
-        Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
-        :param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
-        :return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
-        """
-        pass
-
-    @staticmethod
-    def compute_percent_entities_changed(
-        new_old_entity_list: List[Tuple[List[str], List[str]]]
-    ) -> float:
-        old_count_all = 0
-        overlap_count_all = 0
-        for new_entities, old_entities in new_old_entity_list:
-            (
-                overlap_count,
-                old_count,
-                _,
-            ) = StaleEntityCheckpointStateBase.get_entity_overlap_and_cardinalities(
-                new_entities=new_entities, old_entities=old_entities
-            )
-            overlap_count_all += overlap_count
-            old_count_all += old_count
-
-        if old_count_all:
-            return (1 - overlap_count_all / old_count_all) * 100.0
-        return 0.0
-
-    @staticmethod
-    def get_entity_overlap_and_cardinalities(
-        new_entities: List[str], old_entities: List[str]
-    ) -> Tuple[int, int, int]:
-        new_set = set(new_entities)
-        old_set = set(old_entities)
-        return len(new_set.intersection(old_set)), len(old_set), len(new_set)
-
-
 class StaleEntityRemovalHandler(
-    StatefulIngestionUsecaseHandlerBase[StaleEntityCheckpointStateBase]
+    StatefulIngestionUsecaseHandlerBase["GenericCheckpointState"]
 ):
     """
     The stateful ingestion helper class that handles stale entity removal.
     This contains the generic logic for all sources that need to support stale entity removal for all the states
-    derived from StaleEntityCheckpointStateBase. This uses the template method pattern on CRTP based derived state
-    class hierarchies.
+    derived from GenericCheckpointState.
     """

     def __init__(
         self,
         source: StatefulIngestionSourceBase,
         config: StatefulIngestionConfigBase[StatefulStaleMetadataRemovalConfig],
-        state_type_class: Type[StaleEntityCheckpointStateBase],
+        state_type_class: Type["GenericCheckpointState"],
         pipeline_name: Optional[str],
         run_id: str,
     ):
         self.source = source
+        self.state_provider = source.state_provider
         self.state_type_class = state_type_class
         self.pipeline_name = pipeline_name
         self.run_id = run_id
@@ -166,7 +79,7 @@ class StaleEntityRemovalHandler(
         self.checkpointing_enabled: bool = (
             True
             if (
-                source.is_stateful_ingestion_configured()
+                self.state_provider.is_stateful_ingestion_configured()
                 and self.stateful_ingestion_config
                 and self.stateful_ingestion_config.remove_stale_metadata
             )
@@ -174,7 +87,7 @@ class StaleEntityRemovalHandler(
         )
         self._job_id = self._init_job_id()
         self._urns_to_skip: Set[str] = set()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)

     @classmethod
     def compute_job_id(
@@ -246,7 +159,7 @@ class StaleEntityRemovalHandler(
         )
         return None

-    def _create_soft_delete_workunit(self, urn: str, type: str) -> MetadataWorkUnit:
+    def _create_soft_delete_workunit(self, urn: str) -> MetadataWorkUnit:
         logger.info(f"Soft-deleting stale entity - {urn}")
         mcp = MetadataChangeProposalWrapper(
             entityUrn=urn,
@@ -278,20 +191,16 @@ class StaleEntityRemovalHandler(
         if not self.is_checkpointing_enabled() or self._ignore_old_state():
             return
         logger.debug("Checking for stale entity removal.")
-        last_checkpoint: Optional[Checkpoint] = self.source.get_last_checkpoint(
+        last_checkpoint: Optional[Checkpoint] = self.state_provider.get_last_checkpoint(
             self.job_id, self.state_type_class
         )
         if not last_checkpoint:
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None

         # Get the underlying states
-        last_checkpoint_state = cast(
-            StaleEntityCheckpointStateBase, last_checkpoint.state
-        )
-        cur_checkpoint_state = cast(
-            StaleEntityCheckpointStateBase, cur_checkpoint.state
-        )
+        last_checkpoint_state = cast(GenericCheckpointState, last_checkpoint.state)
+        cur_checkpoint_state = cast(GenericCheckpointState, cur_checkpoint.state)

         # Check if the entity delta is below the fail-safe threshold.
         entity_difference_percent = cur_checkpoint_state.get_percent_entities_changed(
@@ -316,21 +225,20 @@ class StaleEntityRemovalHandler(
             return

         # Everything looks good, emit the soft-deletion workunits
-        for type in self.state_type_class.get_supported_types():
-            for urn in last_checkpoint_state.get_urns_not_in(
-                type=type, other_checkpoint_state=cur_checkpoint_state
-            ):
-                if urn in self._urns_to_skip:
-                    logger.debug(
-                        f"Not soft-deleting entity {urn} since it is in urns_to_skip"
-                    )
-                    continue
-                yield self._create_soft_delete_workunit(urn, type)
+        for urn in last_checkpoint_state.get_urns_not_in(
+            type="*", other_checkpoint_state=cur_checkpoint_state
+        ):
+            if urn in self._urns_to_skip:
+                logger.debug(
+                    f"Not soft-deleting entity {urn} since it is in urns_to_skip"
+                )
+                continue
+            yield self._create_soft_delete_workunit(urn)

     def add_entity_to_state(self, type: str, urn: str) -> None:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
-        cur_state = cast(StaleEntityCheckpointStateBase, cur_checkpoint.state)
+        cur_state = cast(GenericCheckpointState, cur_checkpoint.state)
         cur_state.add_checkpoint_urn(type=type, urn=urn)
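
For sources, the handler's public surface is unchanged; only its plumbing now goes through source.state_provider. A hypothetical call site inside a stateful source (the urn and the self.config/self.ctx attributes are assumptions for illustration):

    from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
    from datahub.ingestion.source.state.stale_entity_removal_handler import (
        StaleEntityRemovalHandler,
    )

    # Assumes `self` is a StatefulIngestionSourceBase subclass with the usual
    # config/ctx attributes.
    handler = StaleEntityRemovalHandler(
        source=self,
        config=self.config,
        state_type_class=GenericCheckpointState,
        pipeline_name=self.ctx.pipeline_name,
        run_id=self.ctx.run_id,
    )
    # Record every entity emitted this run; at commit time the handler diffs
    # this against the previous checkpoint and soft-deletes what disappeared.
    handler.add_entity_to_state(
        type="dataset",  # kept for compatibility; no longer drives the diff
        urn="urn:li:dataset:(urn:li:dataPlatform:mysql,librarydb.member,PROD)",
    )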

View File

@@ -27,10 +27,7 @@ from datahub.ingestion.source.state.use_case_handler import (
 from datahub.ingestion.source.state_provider.state_provider_registry import (
     ingestion_checkpoint_provider_registry,
 )
-from datahub.metadata.schema_classes import (
-    DatahubIngestionCheckpointClass,
-    DatahubIngestionRunSummaryClass,
-)
+from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass

 logger: logging.Logger = logging.getLogger(__name__)
@@ -171,13 +168,8 @@ class StatefulIngestionSourceBase(Source):
         ctx: PipelineContext,
     ) -> None:
         super().__init__(ctx)
-        self.stateful_ingestion_config = config.stateful_ingestion
-        self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
-        self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
-        self.run_summaries_to_report: Dict[JobId, DatahubIngestionRunSummaryClass] = {}
         self.report: StatefulIngestionReport = StatefulIngestionReport()
-        self._initialize_checkpointing_state_provider()
-        self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
+        self.state_provider = StateProviderWrapper(config.stateful_ingestion, ctx)

     def warn(self, log: logging.Logger, key: str, reason: str) -> None:
         self.report.report_warning(key, reason)
@@ -187,6 +179,26 @@ class StatefulIngestionSourceBase(Source):
         self.report.report_failure(key, reason)
         log.error(f"{key} => {reason}")

+    def close(self) -> None:
+        self.state_provider.prepare_for_commit()
+        super().close()
+
+
+class StateProviderWrapper:
+    def __init__(
+        self,
+        config: Optional[StatefulIngestionConfig],
+        ctx: PipelineContext,
+    ) -> None:
+        self.ctx = ctx
+        self.stateful_ingestion_config = config
+
+        self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
+        self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
+        self.report: StatefulIngestionReport = StatefulIngestionReport()
+        self._initialize_checkpointing_state_provider()
+        self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
+
     #
     # Checkpointing specific support.
     #
@@ -383,7 +395,3 @@
     def prepare_for_commit(self) -> None:
         """NOTE: Sources should call this method from their close method."""
         self._prepare_checkpoint_states_for_commit()
-
-    def close(self) -> None:
-        self.prepare_for_commit()
-        super().close()
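
The extraction means all checkpointing state lives in one object that sources merely hold. A rough sketch of what a source subclass looks like after this change (config type elided; assumes the base-class wiring shown in this diff):

    from typing import Iterable

    from datahub.ingestion.api.common import PipelineContext
    from datahub.ingestion.api.workunit import MetadataWorkUnit
    from datahub.ingestion.source.state.stateful_ingestion_base import (
        StatefulIngestionSourceBase,
    )

    class MySource(StatefulIngestionSourceBase):
        def __init__(self, config, ctx: PipelineContext):
            super().__init__(config, ctx)  # sets up self.state_provider
            # Handlers now register against self.state_provider, not the source.

        def get_workunits(self) -> Iterable[MetadataWorkUnit]:
            yield from []

        # No close() override needed: the base close() now calls
        # self.state_provider.prepare_for_commit() before super().close(),
        # which is why the per-source close() methods above were deleted.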

View File

@@ -1,17 +1,16 @@
 import json
 import pathlib
 from functools import partial
-from typing import List, Optional, cast
+from typing import List
 from unittest.mock import patch

 from freezegun import freeze_time

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.identity.azure_ad import AzureADConfig, AzureADSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.identity.azure_ad import AzureADConfig
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -328,15 +327,6 @@ def test_azure_source_ingestion_disabled(pytestconfig, mock_datahub_graph, tmp_path):
     )

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    azure_ad_source = cast(AzureADSource, pipeline.source)
-    return azure_ad_source.get_current_checkpoint(
-        azure_ad_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_azure_ad_stateful_ingestion(
     pytestconfig, tmp_path, mock_time, mock_datahub_graph

View File

@@ -1,5 +1,5 @@
 from pathlib import PosixPath
-from typing import Any, Dict, Optional, Union, cast
+from typing import Any, Dict, Union
 from unittest.mock import patch

 import pytest
@@ -8,11 +8,9 @@ from iceberg.core.filesystem.file_status import FileStatus
 from iceberg.core.filesystem.local_filesystem import LocalFileSystem

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.iceberg.iceberg import IcebergSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -22,15 +20,6 @@ GMS_PORT = 8080
 GMS_SERVER = f"http://localhost:{GMS_PORT}"

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    iceberg_source = cast(IcebergSource, pipeline.source)
-    return iceberg_source.get_current_checkpoint(
-        iceberg_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_iceberg_ingest(pytestconfig, tmp_path, mock_time):

View File

@@ -1,6 +1,6 @@
 import subprocess
 import time
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 from unittest import mock

 import pytest
@@ -8,12 +8,12 @@ import requests
 from freezegun import freeze_time

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.click_helpers import run_datahub_cmd
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -489,14 +489,3 @@ def test_kafka_connect_ingest_stateful(
         "urn:li:dataJob:(urn:li:dataFlow:(kafka-connect,connect-instance-1.mysql_source2,PROD),librarydb.member)",
     ]
     assert sorted(deleted_job_urns) == sorted(difference_job_urns)
-
-
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    from datahub.ingestion.source.kafka_connect import KafkaConnectSource
-
-    kafka_connect_source = cast(KafkaConnectSource, pipeline.source)
-    return kafka_connect_source.get_current_checkpoint(
-        kafka_connect_source.stale_entity_removal_handler.job_id
-    )

View File

@@ -1,17 +1,14 @@
 import time
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List
 from unittest.mock import patch

 import pytest
 from confluent_kafka.admin import AdminClient, NewTopic
 from freezegun import freeze_time

-from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.kafka import KafkaSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -81,15 +78,6 @@ class KafkaTopicsCxtManager:
         self.delete_kafka_topics(self.topics)

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    kafka_source = cast(KafkaSource, pipeline.source)
-    return kafka_source.get_current_checkpoint(
-        kafka_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_kafka_ingest_with_stateful(

View File

@@ -1,18 +1,15 @@
 import pathlib
 import time
-from typing import Optional, cast
 from unittest import mock

 import pytest
 from freezegun import freeze_time

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.ldap import LDAPSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -90,15 +87,6 @@ def ldap_ingest_common(
     return pipeline

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    ldap_source = cast(LDAPSource, pipeline.source)
-    return ldap_source.get_current_checkpoint(
-        ldap_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_ldap_stateful(

View File

@@ -27,11 +27,10 @@ from datahub.ingestion.source.looker.looker_query_model import (
     LookViewField,
     UserViewField,
 )
-from datahub.ingestion.source.looker.looker_source import LookerDashboardSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -719,12 +718,3 @@ def test_looker_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     assert len(difference_dashboard_urns) == 1
     deleted_dashboard_urns = ["urn:li:dashboard:(looker,dashboards.11)"]
     assert sorted(deleted_dashboard_urns) == sorted(difference_dashboard_urns)
-
-
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    dbt_source = cast(LookerDashboardSource, pipeline.source)
-    return dbt_source.get_current_checkpoint(
-        dbt_source.stale_entity_removal_handler.job_id
-    )

View File

@@ -1,6 +1,6 @@
 import logging
 import pathlib
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 from unittest import mock

 import pydantic
@@ -15,10 +15,8 @@ from datahub.ingestion.source.file import read_metadata_file
 from datahub.ingestion.source.looker.lookml_source import (
     LookerModel,
     LookerRefinementResolver,
-    LookMLSource,
     LookMLSourceConfig,
 )
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.metadata.schema_classes import (
     DatasetSnapshotClass,
@@ -27,6 +25,7 @@ from datahub.metadata.schema_classes import (
 )
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -847,15 +846,6 @@ def test_lookml_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     assert sorted(deleted_dataset_urns) == sorted(difference_dataset_urns)

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    dbt_source = cast(LookMLSource, pipeline.source)
-    return dbt_source.get_current_checkpoint(
-        dbt_source.stale_entity_removal_handler.job_id
-    )
-
-
 def test_lookml_base_folder():
     fake_api = {
         "base_url": "https://filler.cloud.looker.com",

View File

@@ -1,7 +1,6 @@
 import asyncio
 import pathlib
 from functools import partial
-from typing import Optional, cast
 from unittest.mock import Mock, patch

 import jsonpickle
@@ -10,11 +9,10 @@ from freezegun import freeze_time
 from okta.models import Group, User

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.identity.okta import OktaConfig, OktaSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.identity.okta import OktaConfig
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -203,15 +201,6 @@ def test_okta_source_custom_user_name_regex(pytestconfig, mock_datahub_graph, tmp_path):
     )

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    azure_ad_source = cast(OktaSource, pipeline.source)
-    return azure_ad_source.get_current_checkpoint(
-        azure_ad_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta"

View File

@@ -216,9 +216,11 @@ def get_current_checkpoint_from_pipeline(
 ) -> Dict[JobId, Optional[Checkpoint[Any]]]:
     powerbi_source = cast(PowerBiDashboardSource, pipeline.source)
     checkpoints = {}
-    for job_id in powerbi_source._usecase_handlers.keys():
+    for job_id in powerbi_source.state_provider._usecase_handlers.keys():
         # for multi-workspace checkpoint, every good checkpoint will have an unique workspaceid suffix
-        checkpoints[job_id] = powerbi_source.get_current_checkpoint(job_id)
+        checkpoints[job_id] = powerbi_source.state_provider.get_current_checkpoint(
+            job_id
+        )
     return checkpoints

View File

@@ -1,15 +1,13 @@
-from typing import Any, Dict, Optional, cast
+from typing import Any, Dict
 from unittest.mock import patch

 import pytest
 from freezegun import freeze_time

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
-from datahub.ingestion.source.superset import SupersetSource
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -19,15 +17,6 @@ GMS_PORT = 8080
 GMS_SERVER = f"http://localhost:{GMS_PORT}"

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    superset_source = cast(SupersetSource, pipeline.source)
-    return superset_source.get_current_checkpoint(
-        superset_source.stale_entity_removal_handler.job_id
-    )
-
-
 def register_mock_api(request_mock: Any, override_data: dict = {}) -> None:
     api_vs_response = {
         "mock://mock-domain.superset.com/api/v1/security/login": {

View File

@@ -2,7 +2,7 @@ import json
 import logging
 import pathlib
 import sys
-from typing import Optional, cast
+from typing import cast
 from unittest import mock

 import pytest
@@ -17,8 +17,6 @@ from tableauserverclient.models import (
 from datahub.configuration.source_common import DEFAULT_ENV
 from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.tableau import TableauConfig, TableauSource
 from datahub.ingestion.source.tableau_common import (
     TableauLineageOverrides,
@@ -31,6 +29,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
 from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -253,15 +252,6 @@ def tableau_ingest_common(
     return pipeline

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    tableau_source = cast(TableauSource, pipeline.source)
-    return tableau_source.get_current_checkpoint(
-        tableau_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph):

View File

@@ -11,6 +11,14 @@ from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
 )
 from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.run.pipeline import Pipeline
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.state.stale_entity_removal_handler import (
+    StaleEntityRemovalHandler,
+)
+from datahub.ingestion.source.state.stateful_ingestion_base import (
+    StatefulIngestionSourceBase,
+)

 def validate_all_providers_have_committed_successfully(
@@ -91,3 +99,15 @@ def mock_datahub_graph():
     mock_datahub_graph_ctx = MockDataHubGraphContext()
     return mock_datahub_graph_ctx.mock_graph
+
+
+def get_current_checkpoint_from_pipeline(
+    pipeline: Pipeline,
+) -> Optional[Checkpoint[GenericCheckpointState]]:
+    # TODO: This only works for stale entity removal. We need to generalize this.
+    stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
+    stale_entity_removal_handler: StaleEntityRemovalHandler = stateful_source.stale_entity_removal_handler  # type: ignore
+    return stateful_source.state_provider.get_current_checkpoint(
+        stale_entity_removal_handler.job_id
+    )
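
With the helper centralized here, each per-test copy shown earlier collapses to an import. A hedged usage sketch (the config dict is a placeholder; real tests pass their source/sink recipe):

    from tests.test_helpers.state_helpers import (
        get_current_checkpoint_from_pipeline,
        run_and_get_pipeline,
    )

    pipeline_config_dict: dict = {}  # placeholder recipe
    pipeline = run_and_get_pipeline(pipeline_config_dict)
    checkpoint = get_current_checkpoint_from_pipeline(pipeline)
    assert checkpoint is not None and checkpoint.state is not None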

View File

@@ -2,8 +2,8 @@ from typing import Dict, List, Tuple

 import pytest

-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityCheckpointStateBase,
+from datahub.ingestion.source.state.entity_removal_state import (
+    compute_percent_entities_changed,
 )

 OldNewEntLists = List[Tuple[List[str], List[str]]]
@@ -43,9 +43,5 @@ old_new_ent_tests: Dict[str, Tuple[OldNewEntLists, float]] = {
 def test_change_percent(
     new_old_entity_list: OldNewEntLists, expected_percent_change: float
 ) -> None:
-    actual_percent_change = (
-        StaleEntityCheckpointStateBase.compute_percent_entities_changed(
-            new_old_entity_list
-        )
-    )
+    actual_percent_change = compute_percent_entities_changed(new_old_entity_list)
     assert actual_percent_change == expected_percent_change
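
The test's expectations follow directly from the formula in the docstring: percent = (1 - overlap / |old|) * 100, summed across the (new, old) list pairs. A tiny standalone check:

    from datahub.ingestion.source.state.entity_removal_state import (
        compute_percent_entities_changed,
    )

    # Two old entities, one of which survives into the new run -> 50% changed.
    assert compute_percent_entities_changed([(["a", "b"], ["b", "c"])]) == 50.0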

View File

@@ -10,10 +10,8 @@ from freezegun import freeze_time

 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
-from datahub.ingestion.run.pipeline import Pipeline
 from datahub.ingestion.sink.file import write_metadata_file
 from datahub.ingestion.source.aws.glue import GlueSource, GlueSourceConfig
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.sql_common_state import (
     BaseSQLAlchemyCheckpointState,
 )
@@ -26,6 +24,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -240,15 +239,6 @@ def test_config_without_platform():
     assert source.platform == "glue"

-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    glue_source = cast(GlueSource, pipeline.source)
-    return glue_source.get_current_checkpoint(
-        glue_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     test_resources_dir = pytestconfig.rootpath / "tests/unit/glue"

View File

@@ -1,13 +1,13 @@
 from typing import Any, Dict, Optional, cast

-from sqlalchemy import create_engine
-from sqlalchemy.sql import text
-
 from datahub.ingestion.api.committable import StatefulCommittable
 from datahub.ingestion.run.pipeline import Pipeline
 from datahub.ingestion.source.sql.mysql import MySQLConfig, MySQLSource
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from sqlalchemy import create_engine
+from sqlalchemy.sql import text
+
 from tests.utils import (
     get_gms_url,
     get_mysql_password,
@@ -49,8 +49,9 @@ def test_stateful_ingestion(wait_for_healthchecks):
 def get_current_checkpoint_from_pipeline(
     pipeline: Pipeline,
 ) -> Optional[Checkpoint[GenericCheckpointState]]:
+    # TODO: Refactor to use the helper method in the metadata-ingestion tests, instead of copying it here.
     mysql_source = cast(MySQLSource, pipeline.source)
-    return mysql_source.get_current_checkpoint(
+    return mysql_source.state_provider.get_current_checkpoint(
         mysql_source.stale_entity_removal_handler.job_id
     )