mirror of https://github.com/datahub-project/datahub.git

refactor(ingest): simplify stateful ingestion provider interface (#8104)

Co-authored-by: Tamas Nemeth <treff7es@gmail.com>

parent afd65e16fb · commit b0f8c3de1e
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Tuple, Type

 from datahub.emitter.mce_builder import set_dataset_urn_to_lower
 from datahub.ingestion.api.committable import Committable
-from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
+from datahub.ingestion.graph.client import DataHubGraph

 if TYPE_CHECKING:
     from datahub.ingestion.run.pipeline import PipelineConfig
@@ -43,22 +43,19 @@ class PipelineContext:
     def __init__(
         self,
         run_id: str,
-        datahub_api: Optional["DatahubClientConfig"] = None,
+        graph: Optional[DataHubGraph] = None,
         pipeline_name: Optional[str] = None,
         dry_run: bool = False,
         preview_mode: bool = False,
         pipeline_config: Optional["PipelineConfig"] = None,
     ) -> None:
         self.pipeline_config = pipeline_config
+        self.graph = graph
         self.run_id = run_id
         self.pipeline_name = pipeline_name
         self.dry_run_mode = dry_run
         self.preview_mode = preview_mode
         self.checkpointers: Dict[str, Committable] = {}
-        try:
-            self.graph = DataHubGraph(datahub_api) if datahub_api is not None else None
-        except Exception as e:
-            raise Exception(f"Failed to connect to DataHub: {e}") from e

         self._set_dataset_urn_to_lower_if_needed()
@@ -27,6 +27,7 @@ from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
 from datahub.ingestion.api.source import Extractor, Source
 from datahub.ingestion.api.transform import Transformer
 from datahub.ingestion.extractor.extractor_registry import extractor_registry
+from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.reporting.reporting_provider_registry import (
     reporting_provider_registry,
 )
@@ -183,10 +184,15 @@ class Pipeline:
         self.last_time_printed = int(time.time())
         self.cli_report = CliReport()

+        self.graph = None
+        with _add_init_error_context("connect to DataHub"):
+            if self.config.datahub_api:
+                self.graph = DataHubGraph(self.config.datahub_api)
+
         with _add_init_error_context("set up framework context"):
             self.ctx = PipelineContext(
                 run_id=self.config.run_id,
-                datahub_api=self.config.datahub_api,
+                graph=self.graph,
                 pipeline_name=self.config.pipeline_name,
                 dry_run=dry_run,
                 preview_mode=preview_mode,
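The two hunks above move DataHubGraph construction out of PipelineContext and into Pipeline: the pipeline builds the client once and hands the ready-made instance to the context. A minimal sketch of the new wiring, assuming a locally running GMS (the server URL and run id below are placeholders, not part of the commit):

```python
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

# Build the graph client up front, as Pipeline now does internally.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

# The context receives the ready-made client instead of a DatahubClientConfig,
# so it no longer needs its own try/except connection logic.
ctx = PipelineContext(run_id="demo-run", graph=graph)
```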
@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from datetime import timedelta
 from enum import auto
 from threading import BoundedSemaphore
-from typing import Union, cast
+from typing import Union

 from datahub.cli.cli_utils import set_env_variables_override_config
 from datahub.configuration.common import (
@@ -126,8 +126,7 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):

     def handle_work_unit_start(self, workunit: WorkUnit) -> None:
         if isinstance(workunit, MetadataWorkUnit):
-            mwu: MetadataWorkUnit = cast(MetadataWorkUnit, workunit)
-            self.treat_errors_as_warnings = mwu.treat_errors_as_warnings
+            self.treat_errors_as_warnings = workunit.treat_errors_as_warnings

     def handle_work_unit_end(self, workunit: WorkUnit) -> None:
         pass
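This sink cleanup relies on type narrowing: inside the `isinstance(workunit, MetadataWorkUnit)` branch the checker already treats `workunit` as a MetadataWorkUnit, so the explicit `cast` was redundant. A standalone sketch of the same idiom (the toy classes here are illustrative, not the real ones):

```python
class WorkUnit:
    """Toy stand-in for the real WorkUnit base class."""


class MetadataWorkUnit(WorkUnit):
    treat_errors_as_warnings: bool = False


def handle(workunit: WorkUnit) -> None:
    if isinstance(workunit, MetadataWorkUnit):
        # mypy narrows `workunit` to MetadataWorkUnit here; no cast() needed.
        print(workunit.treat_errors_as_warnings)
```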
@@ -1331,6 +1331,3 @@ class LookerDashboardSource(TestableSource, StatefulIngestionSourceBase):

     def get_report(self) -> SourceReport:
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()
@@ -2163,6 +2163,3 @@ class LookMLSource(StatefulIngestionSourceBase):

     def get_report(self):
         return self.reporter
-
-    def close(self):
-        self.prepare_for_commit()
@@ -1228,7 +1228,7 @@ class PowerBiDashboardSource(StatefulIngestionSourceBase):
             # Because job_id is used as dictionary key, we have to set a new job_id
             # Refer to https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py#L390
             self.stale_entity_removal_handler.set_job_id(workspace.id)
-            self.register_stateful_ingestion_usecase_handler(
+            self.state_provider.register_stateful_ingestion_usecase_handler(
                 self.stale_entity_removal_handler
             )
@@ -1283,6 +1283,3 @@ class VerticaSource(SQLAlchemySource):
                 return each["owner_name"]

         return None
-
-    def close(self):
-        self.prepare_for_commit()
@@ -1,11 +1,9 @@
-from typing import Any, Dict, Iterable, List, Type
+from typing import Any, Dict, Iterable, List, Tuple, Type

 import pydantic

 from datahub.emitter.mce_builder import make_assertion_urn, make_container_urn
-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityCheckpointStateBase,
-)
+from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
 from datahub.utilities.checkpoint_state_util import CheckpointStateUtil
 from datahub.utilities.dedup_list import deduplicate_list
 from datahub.utilities.urns.urn import guess_entity_type
@@ -56,7 +54,7 @@ def pydantic_state_migrator(mapping: Dict[str, str]) -> classmethod:
     return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)


-class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
+class GenericCheckpointState(CheckpointStateBase):
     urns: List[str] = pydantic.Field(default_factory=list)

     # We store a bit of extra internal-only state so that we can keep the urns list deduplicated.
@@ -85,11 +83,15 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
         self.urns = deduplicate_list(self.urns)
         self._urns_set = set(self.urns)

-    @classmethod
-    def get_supported_types(cls) -> List[str]:
-        return ["*"]
-
     def add_checkpoint_urn(self, type: str, urn: str) -> None:
+        """
+        Adds an urn into the list used for tracking the type.
+
+        :param type: Deprecated parameter, has no effect.
+        :param urn: The urn string
+        """
+
+        # TODO: Deprecate the `type` parameter and remove it.
         if urn not in self._urns_set:
             self.urns.append(urn)
             self._urns_set.add(urn)
@@ -97,9 +99,18 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
     def get_urns_not_in(
         self, type: str, other_checkpoint_state: "GenericCheckpointState"
     ) -> Iterable[str]:
+        """
+        Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
+
+        :param type: Deprecated. Set to "*".
+        :param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
+        :return: an iterable to the set of urns present in this checkpoint state but not in the other_checkpoint.
+        """
+
         diff = set(self.urns) - set(other_checkpoint_state.urns)

         # To maintain backwards compatibility, we provide this filtering mechanism.
+        # TODO: Deprecate the `type` parameter and remove it.
         if type == "*":
             yield from diff
         elif type == "topic":
@@ -110,6 +121,36 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
     def get_percent_entities_changed(
         self, old_checkpoint_state: "GenericCheckpointState"
     ) -> float:
-        return StaleEntityCheckpointStateBase.compute_percent_entities_changed(
+        """
+        Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
+
+        :param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
+        :return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
+        """
+        return compute_percent_entities_changed(
             [(self.urns, old_checkpoint_state.urns)]
         )
+
+
+def compute_percent_entities_changed(
+    new_old_entity_list: List[Tuple[List[str], List[str]]]
+) -> float:
+    old_count_all = 0
+    overlap_count_all = 0
+    for new_entities, old_entities in new_old_entity_list:
+        (overlap_count, old_count, _,) = get_entity_overlap_and_cardinalities(
+            new_entities=new_entities, old_entities=old_entities
+        )
+        overlap_count_all += overlap_count
+        old_count_all += old_count
+    if old_count_all:
+        return (1 - overlap_count_all / old_count_all) * 100.0
+    return 0.0
+
+
+def get_entity_overlap_and_cardinalities(
+    new_entities: List[str], old_entities: List[str]
+) -> Tuple[int, int, int]:
+    new_set = set(new_entities)
+    old_set = set(old_entities)
+    return len(new_set.intersection(old_set)), len(old_set), len(new_set)
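The hunk above demotes the change computation from a static method on the deleted abstract class to the plain module-level `compute_percent_entities_changed`, which reports how much of the old state disappeared: (1 - |overlap| / |old|) * 100. A quick worked example against the function as added above (the urn values are made up):

```python
from datahub.ingestion.source.state.entity_removal_state import (
    compute_percent_entities_changed,
)

old = ["urn:li:dataset:a", "urn:li:dataset:b", "urn:li:dataset:c", "urn:li:dataset:d"]
new = ["urn:li:dataset:a", "urn:li:dataset:b", "urn:li:dataset:e"]

# The overlap is {a, b}: 2 of the 4 old entities survive,
# so (1 - 2/4) * 100.0 = 50.0 percent changed.
print(compute_percent_entities_changed([(new, old)]))  # 50.0
```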
@@ -40,15 +40,17 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointState]):
         pipeline_name: Optional[str],
         run_id: str,
     ):
-        self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             ProfilingStatefulIngestionConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)

     def _ignore_old_state(self) -> bool:
         if (
@@ -91,7 +93,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointState]):
     def get_current_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return None
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(ProfilingCheckpointState, cur_checkpoint.state)
         return cur_state
@@ -108,7 +110,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointState]):
     def get_last_state(self) -> Optional[ProfilingCheckpointState]:
         if not self.is_checkpointing_enabled() or self._ignore_old_state():
             return None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
             self.job_id, ProfilingCheckpointState
         )
         if last_checkpoint and last_checkpoint.state:
@@ -46,14 +46,17 @@ class RedundantRunSkipHandler(
         run_id: str,
     ):
         self.source = source
+        self.state_provider = source.state_provider
         self.stateful_ingestion_config: Optional[
             StatefulRedundantRunSkipConfig
         ] = config.stateful_ingestion
         self.pipeline_name = pipeline_name
         self.run_id = run_id
-        self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
+        self.checkpointing_enabled: bool = (
+            self.state_provider.is_stateful_ingestion_configured()
+        )
         self._job_id = self._init_job_id()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)

     def _ignore_old_state(self) -> bool:
         if (
@@ -114,7 +117,7 @@ class RedundantRunSkipHandler(
     ) -> None:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         cur_state = cast(BaseUsageCheckpointState, cur_checkpoint.state)
         cur_state.begin_timestamp_millis = start_time_millis
@@ -125,7 +128,7 @@ class RedundantRunSkipHandler(
             return False
         # Determine from the last check point state
         last_successful_pipeline_run_end_time_millis: Optional[int] = None
-        last_checkpoint = self.source.get_last_checkpoint(
+        last_checkpoint = self.state_provider.get_last_checkpoint(
             self.job_id, BaseUsageCheckpointState
         )
         if last_checkpoint and last_checkpoint.state:
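Both handlers above now follow the same shape: read the source's `state_provider`, query checkpointing through it, and self-register with it. A condensed sketch of that shared pattern (the `MiniHandler` class and its methods are illustrative names, not part of the commit):

```python
# Illustrative sketch of the post-refactor handler pattern, assuming `source`
# is a StatefulIngestionSourceBase subclass with a `state_provider` attribute.
class MiniHandler:
    def __init__(self, source, job_id):
        # Handlers talk to the StateProviderWrapper, never to the source itself.
        self.state_provider = source.state_provider
        self.job_id = job_id
        self.state_provider.register_stateful_ingestion_usecase_handler(self)

    def last_state(self, state_class):
        # Checkpoint reads go through the wrapper as well.
        return self.state_provider.get_last_checkpoint(self.job_id, state_class)
```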
@@ -1,25 +1,14 @@
 import logging
-from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import (
-    Dict,
-    Generic,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TypeVar,
-    cast,
-)
+from typing import Dict, Iterable, Optional, Set, Type, cast

 import pydantic

 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId
 from datahub.ingestion.api.workunit import MetadataWorkUnit
-from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionConfig,
     StatefulIngestionConfigBase,
@@ -61,102 +50,26 @@ class StaleEntityRemovalSourceReport(StatefulIngestionReport):
         self.soft_deleted_stale_entities.append(urn)


-Derived = TypeVar("Derived", bound=CheckpointStateBase)
-
-
-class StaleEntityCheckpointStateBase(CheckpointStateBase, ABC, Generic[Derived]):
-    """
-    Defines the abstract interface for the checkpoint states that are used for stale entity removal.
-    Examples include sql_common state for tracking table and & view urns,
-    dbt that tracks node & assertion urns, kafka state tracking topic urns.
-    """
-
-    @classmethod
-    @abstractmethod
-    def get_supported_types(cls) -> List[str]:
-        pass
-
-    @abstractmethod
-    def add_checkpoint_urn(self, type: str, urn: str) -> None:
-        """
-        Adds an urn into the list used for tracking the type.
-        :param type: The type of the urn such as a 'table', 'view',
-        'node', 'topic', 'assertion' that the concrete sub-class understands.
-        :param urn: The urn string
-        :return: None.
-        """
-        pass
-
-    @abstractmethod
-    def get_urns_not_in(
-        self, type: str, other_checkpoint_state: Derived
-    ) -> Iterable[str]:
-        """
-        Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
-        :param type: The type of the urn such as a 'table', 'view',
-        'node', 'topic', 'assertion' that the concrete sub-class understands.
-        :param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
-        :return: an iterable to the set of urns present in this checkpoing state but not in the other_checkpoint.
-        """
-        pass
-
-    @abstractmethod
-    def get_percent_entities_changed(self, old_checkpoint_state: Derived) -> float:
-        """
-        Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
-        :param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
-        :return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
-        """
-        pass
-
-    @staticmethod
-    def compute_percent_entities_changed(
-        new_old_entity_list: List[Tuple[List[str], List[str]]]
-    ) -> float:
-        old_count_all = 0
-        overlap_count_all = 0
-        for new_entities, old_entities in new_old_entity_list:
-            (
-                overlap_count,
-                old_count,
-                _,
-            ) = StaleEntityCheckpointStateBase.get_entity_overlap_and_cardinalities(
-                new_entities=new_entities, old_entities=old_entities
-            )
-            overlap_count_all += overlap_count
-            old_count_all += old_count
-        if old_count_all:
-            return (1 - overlap_count_all / old_count_all) * 100.0
-        return 0.0
-
-    @staticmethod
-    def get_entity_overlap_and_cardinalities(
-        new_entities: List[str], old_entities: List[str]
-    ) -> Tuple[int, int, int]:
-        new_set = set(new_entities)
-        old_set = set(old_entities)
-        return len(new_set.intersection(old_set)), len(old_set), len(new_set)
-
-
 class StaleEntityRemovalHandler(
-    StatefulIngestionUsecaseHandlerBase[StaleEntityCheckpointStateBase]
+    StatefulIngestionUsecaseHandlerBase["GenericCheckpointState"]
 ):
     """
     The stateful ingestion helper class that handles stale entity removal.
     This contains the generic logic for all sources that need to support stale entity removal for all the states
-    derived from StaleEntityCheckpointStateBase. This uses the template method pattern on CRTP based derived state
-    class hierarchies.
+    derived from GenericCheckpointState.
     """

     def __init__(
         self,
         source: StatefulIngestionSourceBase,
         config: StatefulIngestionConfigBase[StatefulStaleMetadataRemovalConfig],
-        state_type_class: Type[StaleEntityCheckpointStateBase],
+        state_type_class: Type["GenericCheckpointState"],
         pipeline_name: Optional[str],
         run_id: str,
     ):
         self.source = source
+        self.state_provider = source.state_provider
+
         self.state_type_class = state_type_class
         self.pipeline_name = pipeline_name
         self.run_id = run_id
@@ -166,7 +79,7 @@ class StaleEntityRemovalHandler(
         self.checkpointing_enabled: bool = (
             True
             if (
-                source.is_stateful_ingestion_configured()
+                self.state_provider.is_stateful_ingestion_configured()
                 and self.stateful_ingestion_config
                 and self.stateful_ingestion_config.remove_stale_metadata
             )
@@ -174,7 +87,7 @@ class StaleEntityRemovalHandler(
         )
         self._job_id = self._init_job_id()
         self._urns_to_skip: Set[str] = set()
-        self.source.register_stateful_ingestion_usecase_handler(self)
+        self.state_provider.register_stateful_ingestion_usecase_handler(self)

     @classmethod
     def compute_job_id(
@@ -246,7 +159,7 @@ class StaleEntityRemovalHandler(
         )
         return None

-    def _create_soft_delete_workunit(self, urn: str, type: str) -> MetadataWorkUnit:
+    def _create_soft_delete_workunit(self, urn: str) -> MetadataWorkUnit:
         logger.info(f"Soft-deleting stale entity - {urn}")
         mcp = MetadataChangeProposalWrapper(
             entityUrn=urn,
@@ -278,20 +191,16 @@ class StaleEntityRemovalHandler(
         if not self.is_checkpointing_enabled() or self._ignore_old_state():
             return
         logger.debug("Checking for stale entity removal.")
-        last_checkpoint: Optional[Checkpoint] = self.source.get_last_checkpoint(
+        last_checkpoint: Optional[Checkpoint] = self.state_provider.get_last_checkpoint(
             self.job_id, self.state_type_class
         )
         if not last_checkpoint:
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
         # Get the underlying states
-        last_checkpoint_state = cast(
-            StaleEntityCheckpointStateBase, last_checkpoint.state
-        )
-        cur_checkpoint_state = cast(
-            StaleEntityCheckpointStateBase, cur_checkpoint.state
-        )
+        last_checkpoint_state = cast(GenericCheckpointState, last_checkpoint.state)
+        cur_checkpoint_state = cast(GenericCheckpointState, cur_checkpoint.state)

         # Check if the entity delta is below the fail-safe threshold.
         entity_difference_percent = cur_checkpoint_state.get_percent_entities_changed(
@@ -316,21 +225,20 @@ class StaleEntityRemovalHandler(
             return

         # Everything looks good, emit the soft-deletion workunits
-        for type in self.state_type_class.get_supported_types():
-            for urn in last_checkpoint_state.get_urns_not_in(
-                type=type, other_checkpoint_state=cur_checkpoint_state
-            ):
-                if urn in self._urns_to_skip:
-                    logger.debug(
-                        f"Not soft-deleting entity {urn} since it is in urns_to_skip"
-                    )
-                    continue
-                yield self._create_soft_delete_workunit(urn, type)
+        for urn in last_checkpoint_state.get_urns_not_in(
+            type="*", other_checkpoint_state=cur_checkpoint_state
+        ):
+            if urn in self._urns_to_skip:
+                logger.debug(
+                    f"Not soft-deleting entity {urn} since it is in urns_to_skip"
+                )
+                continue
+            yield self._create_soft_delete_workunit(urn)

     def add_entity_to_state(self, type: str, urn: str) -> None:
         if not self.is_checkpointing_enabled() or self._ignore_new_state():
             return
-        cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
+        cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
         assert cur_checkpoint is not None
-        cur_state = cast(StaleEntityCheckpointStateBase, cur_checkpoint.state)
+        cur_state = cast(GenericCheckpointState, cur_checkpoint.state)
         cur_state.add_checkpoint_urn(type=type, urn=urn)
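With the abstract state class and the per-type loop gone, stale-entity detection reduces to a set difference over GenericCheckpointState urn lists. A toy illustration (the urn values are invented, and it assumes GenericCheckpointState can be constructed directly from its `urns` field):

```python
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState

last_run = GenericCheckpointState(urns=["urn:li:dataset:a", "urn:li:dataset:b"])
this_run = GenericCheckpointState(urns=["urn:li:dataset:b"])

# Urns present last run but absent now are the soft-deletion candidates.
# The `type` parameter is deprecated; "*" matches every entity type.
stale = list(last_run.get_urns_not_in(type="*", other_checkpoint_state=this_run))
print(stale)  # ['urn:li:dataset:a']
```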
@@ -27,10 +27,7 @@ from datahub.ingestion.source.state.use_case_handler import (
 from datahub.ingestion.source.state_provider.state_provider_registry import (
     ingestion_checkpoint_provider_registry,
 )
-from datahub.metadata.schema_classes import (
-    DatahubIngestionCheckpointClass,
-    DatahubIngestionRunSummaryClass,
-)
+from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass

 logger: logging.Logger = logging.getLogger(__name__)
@@ -171,13 +168,8 @@ class StatefulIngestionSourceBase(Source):
         ctx: PipelineContext,
     ) -> None:
         super().__init__(ctx)
-        self.stateful_ingestion_config = config.stateful_ingestion
-        self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
-        self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
-        self.run_summaries_to_report: Dict[JobId, DatahubIngestionRunSummaryClass] = {}
         self.report: StatefulIngestionReport = StatefulIngestionReport()
-        self._initialize_checkpointing_state_provider()
-        self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
+        self.state_provider = StateProviderWrapper(config.stateful_ingestion, ctx)

     def warn(self, log: logging.Logger, key: str, reason: str) -> None:
         self.report.report_warning(key, reason)
@@ -187,6 +179,26 @@ class StatefulIngestionSourceBase(Source):
         self.report.report_failure(key, reason)
         log.error(f"{key} => {reason}")

+    def close(self) -> None:
+        self.state_provider.prepare_for_commit()
+        super().close()
+
+
+class StateProviderWrapper:
+    def __init__(
+        self,
+        config: Optional[StatefulIngestionConfig],
+        ctx: PipelineContext,
+    ) -> None:
+        self.ctx = ctx
+        self.stateful_ingestion_config = config
+
+        self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
+        self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
+        self.report: StatefulIngestionReport = StatefulIngestionReport()
+        self._initialize_checkpointing_state_provider()
+        self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
+
     #
     # Checkpointing specific support.
     #
@@ -383,7 +395,3 @@ class StatefulIngestionSourceBase(Source):
     def prepare_for_commit(self) -> None:
         """NOTE: Sources should call this method from their close method."""
         self._prepare_checkpoint_states_for_commit()
-
-    def close(self) -> None:
-        self.prepare_for_commit()
-        super().close()
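The new `StateProviderWrapper` now owns the checkpoint dictionaries, the provider initialization, and the use-case handler map that previously lived on the source base class; the source keeps a thin `state_provider` attribute plus a `close()` that commits through it. A hedged sketch of how code reaches checkpoints after the move (the `inspect_state` helper and its job id argument are placeholders):

```python
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stateful_ingestion_base import (
    StatefulIngestionSourceBase,
)


def inspect_state(source: StatefulIngestionSourceBase, job_id: JobId) -> None:
    # All checkpoint bookkeeping is reached through the wrapper now; the
    # base close() calls state_provider.prepare_for_commit() automatically.
    last = source.state_provider.get_last_checkpoint(job_id, GenericCheckpointState)
    cur = source.state_provider.get_current_checkpoint(job_id)
    print(last, cur)
```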
@@ -1,17 +1,16 @@
 import json
 import pathlib
 from functools import partial
-from typing import List, Optional, cast
+from typing import List
 from unittest.mock import patch

 from freezegun import freeze_time

-from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.identity.azure_ad import AzureADConfig, AzureADSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.identity.azure_ad import AzureADConfig
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -328,15 +327,6 @@ def test_azure_source_ingestion_disabled(pytestconfig, mock_datahub_graph, tmp_path):
     )


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    azure_ad_source = cast(AzureADSource, pipeline.source)
-    return azure_ad_source.get_current_checkpoint(
-        azure_ad_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_azure_ad_stateful_ingestion(
     pytestconfig, tmp_path, mock_time, mock_datahub_graph
@@ -1,5 +1,5 @@
 from pathlib import PosixPath
-from typing import Any, Dict, Optional, Union, cast
+from typing import Any, Dict, Union
 from unittest.mock import patch

 import pytest
@@ -8,11 +8,9 @@ from iceberg.core.filesystem.file_status import FileStatus
 from iceberg.core.filesystem.local_filesystem import LocalFileSystem

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.iceberg.iceberg import IcebergSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -22,15 +20,6 @@ GMS_PORT = 8080
 GMS_SERVER = f"http://localhost:{GMS_PORT}"


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    iceberg_source = cast(IcebergSource, pipeline.source)
-    return iceberg_source.get_current_checkpoint(
-        iceberg_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_iceberg_ingest(pytestconfig, tmp_path, mock_time):
@@ -1,6 +1,6 @@
 import subprocess
 import time
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 from unittest import mock

 import pytest
@@ -8,12 +8,12 @@ import requests
 from freezegun import freeze_time

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.click_helpers import run_datahub_cmd
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -489,14 +489,3 @@ def test_kafka_connect_ingest_stateful(
             "urn:li:dataJob:(urn:li:dataFlow:(kafka-connect,connect-instance-1.mysql_source2,PROD),librarydb.member)",
         ]
         assert sorted(deleted_job_urns) == sorted(difference_job_urns)
-
-
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    from datahub.ingestion.source.kafka_connect import KafkaConnectSource
-
-    kafka_connect_source = cast(KafkaConnectSource, pipeline.source)
-    return kafka_connect_source.get_current_checkpoint(
-        kafka_connect_source.stale_entity_removal_handler.job_id
-    )
@@ -1,17 +1,14 @@
 import time
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List
 from unittest.mock import patch

 import pytest
 from confluent_kafka.admin import AdminClient, NewTopic
 from freezegun import freeze_time

-from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.kafka import KafkaSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -81,15 +78,6 @@ class KafkaTopicsCxtManager:
         self.delete_kafka_topics(self.topics)


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    kafka_source = cast(KafkaSource, pipeline.source)
-    return kafka_source.get_current_checkpoint(
-        kafka_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_kafka_ingest_with_stateful(
@@ -1,18 +1,15 @@
 import pathlib
 import time
-from typing import Optional, cast
 from unittest import mock

 import pytest
 from freezegun import freeze_time

-from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.ldap import LDAPSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.docker_helpers import wait_for_port
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -90,15 +87,6 @@ def ldap_ingest_common(
     return pipeline


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    ldap_source = cast(LDAPSource, pipeline.source)
-    return ldap_source.get_current_checkpoint(
-        ldap_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_ldap_stateful(
@@ -27,11 +27,10 @@ from datahub.ingestion.source.looker.looker_query_model import (
     LookViewField,
     UserViewField,
 )
 from datahub.ingestion.source.looker.looker_source import LookerDashboardSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -719,12 +718,3 @@ def test_looker_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     assert len(difference_dashboard_urns) == 1
     deleted_dashboard_urns = ["urn:li:dashboard:(looker,dashboards.11)"]
     assert sorted(deleted_dashboard_urns) == sorted(difference_dashboard_urns)
-
-
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    dbt_source = cast(LookerDashboardSource, pipeline.source)
-    return dbt_source.get_current_checkpoint(
-        dbt_source.stale_entity_removal_handler.job_id
-    )
@@ -1,6 +1,6 @@
 import logging
 import pathlib
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 from unittest import mock

 import pydantic
@@ -15,10 +15,8 @@ from datahub.ingestion.source.file import read_metadata_file
 from datahub.ingestion.source.looker.lookml_source import (
     LookerModel,
     LookerRefinementResolver,
-    LookMLSource,
     LookMLSourceConfig,
 )
-from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.metadata.schema_classes import (
     DatasetSnapshotClass,
@@ -27,6 +25,7 @@ from datahub.metadata.schema_classes import (
 )
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -847,15 +846,6 @@ def test_lookml_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     assert sorted(deleted_dataset_urns) == sorted(difference_dataset_urns)


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    dbt_source = cast(LookMLSource, pipeline.source)
-    return dbt_source.get_current_checkpoint(
-        dbt_source.stale_entity_removal_handler.job_id
-    )
-
-
 def test_lookml_base_folder():
     fake_api = {
         "base_url": "https://filler.cloud.looker.com",
@@ -1,7 +1,6 @@
 import asyncio
 import pathlib
 from functools import partial
-from typing import Optional, cast
 from unittest.mock import Mock, patch

 import jsonpickle
@@ -10,11 +9,10 @@ from freezegun import freeze_time
 from okta.models import Group, User

 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.identity.okta import OktaConfig, OktaSource
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.identity.okta import OktaConfig
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -203,15 +201,6 @@ def test_okta_source_custom_user_name_regex(pytestconfig, mock_datahub_graph, tmp_path):
     )


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    azure_ad_source = cast(OktaSource, pipeline.source)
-    return azure_ad_source.get_current_checkpoint(
-        azure_ad_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta"
@@ -216,9 +216,11 @@ def get_current_checkpoint_from_pipeline(
 ) -> Dict[JobId, Optional[Checkpoint[Any]]]:
     powerbi_source = cast(PowerBiDashboardSource, pipeline.source)
     checkpoints = {}
-    for job_id in powerbi_source._usecase_handlers.keys():
+    for job_id in powerbi_source.state_provider._usecase_handlers.keys():
         # for multi-workspace checkpoint, every good checkpoint will have an unique workspaceid suffix
-        checkpoints[job_id] = powerbi_source.get_current_checkpoint(job_id)
+        checkpoints[job_id] = powerbi_source.state_provider.get_current_checkpoint(
+            job_id
+        )
     return checkpoints
@@ -1,15 +1,13 @@
-from typing import Any, Dict, Optional, cast
+from typing import Any, Dict
 from unittest.mock import patch

 import pytest
 from freezegun import freeze_time

-from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.superset import SupersetSource
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -19,15 +17,6 @@ GMS_PORT = 8080
 GMS_SERVER = f"http://localhost:{GMS_PORT}"


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    superset_source = cast(SupersetSource, pipeline.source)
-    return superset_source.get_current_checkpoint(
-        superset_source.stale_entity_removal_handler.job_id
-    )
-
-
 def register_mock_api(request_mock: Any, override_data: dict = {}) -> None:
     api_vs_response = {
         "mock://mock-domain.superset.com/api/v1/security/login": {
@@ -2,7 +2,7 @@ import json
 import logging
 import pathlib
 import sys
-from typing import Optional, cast
+from typing import cast
 from unittest import mock

 import pytest
@@ -17,8 +17,6 @@ from tableauserverclient.models import (

 from datahub.configuration.source_common import DEFAULT_ENV
 from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
 from datahub.ingestion.source.tableau import TableauConfig, TableauSource
 from datahub.ingestion.source.tableau_common import (
     TableauLineageOverrides,
@@ -31,6 +29,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
 )
 from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -253,15 +252,6 @@ def tableau_ingest_common(
     return pipeline


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    tableau_source = cast(TableauSource, pipeline.source)
-    return tableau_source.get_current_checkpoint(
-        tableau_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph):
@@ -11,6 +11,14 @@ from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
 )
 from datahub.ingestion.graph.client import DataHubGraph
+from datahub.ingestion.run.pipeline import Pipeline
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.state.stale_entity_removal_handler import (
+    StaleEntityRemovalHandler,
+)
+from datahub.ingestion.source.state.stateful_ingestion_base import (
+    StatefulIngestionSourceBase,
+)


 def validate_all_providers_have_committed_successfully(
@@ -91,3 +99,15 @@ def mock_datahub_graph():

     mock_datahub_graph_ctx = MockDataHubGraphContext()
     return mock_datahub_graph_ctx.mock_graph
+
+
+def get_current_checkpoint_from_pipeline(
+    pipeline: Pipeline,
+) -> Optional[Checkpoint[GenericCheckpointState]]:
+    # TODO: This only works for stale entity removal. We need to generalize this.
+
+    stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
+    stale_entity_removal_handler: StaleEntityRemovalHandler = stateful_source.stale_entity_removal_handler  # type: ignore
+    return stateful_source.state_provider.get_current_checkpoint(
+        stale_entity_removal_handler.job_id
+    )
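All of the per-source test copies of `get_current_checkpoint_from_pipeline` collapse into this one shared helper. A sketch of how a test consumes it now (the pipeline recipe is elided and hypothetical):

```python
from tests.test_helpers.state_helpers import (
    get_current_checkpoint_from_pipeline,
    run_and_get_pipeline,
)

# Placeholder recipe: a real test points this at a source with stateful
# ingestion enabled and a DataHub sink.
pipeline_config_dict: dict = {}

pipeline = run_and_get_pipeline(pipeline_config_dict)

# Read the current stale-entity-removal checkpoint via the shared helper.
checkpoint = get_current_checkpoint_from_pipeline(pipeline)
assert checkpoint is not None
```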
@@ -2,8 +2,8 @@ from typing import Dict, List, Tuple

 import pytest

-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityCheckpointStateBase,
+from datahub.ingestion.source.state.entity_removal_state import (
+    compute_percent_entities_changed,
 )

 OldNewEntLists = List[Tuple[List[str], List[str]]]
@@ -43,9 +43,5 @@ old_new_ent_tests: Dict[str, Tuple[OldNewEntLists, float]] = {
 def test_change_percent(
     new_old_entity_list: OldNewEntLists, expected_percent_change: float
 ) -> None:
-    actual_percent_change = (
-        StaleEntityCheckpointStateBase.compute_percent_entities_changed(
-            new_old_entity_list
-        )
-    )
+    actual_percent_change = compute_percent_entities_changed(new_old_entity_list)
     assert actual_percent_change == expected_percent_change
@@ -10,10 +10,8 @@ from freezegun import freeze_time

 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
-from datahub.ingestion.run.pipeline import Pipeline
 from datahub.ingestion.sink.file import write_metadata_file
 from datahub.ingestion.source.aws.glue import GlueSource, GlueSourceConfig
-from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.sql_common_state import (
     BaseSQLAlchemyCheckpointState,
 )
@@ -26,6 +24,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 )
 from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
 from tests.test_helpers import mce_helpers
 from tests.test_helpers.state_helpers import (
+    get_current_checkpoint_from_pipeline,
     run_and_get_pipeline,
     validate_all_providers_have_committed_successfully,
 )
@@ -240,15 +239,6 @@ def test_config_without_platform():
     assert source.platform == "glue"


-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
-    glue_source = cast(GlueSource, pipeline.source)
-    return glue_source.get_current_checkpoint(
-        glue_source.stale_entity_removal_handler.job_id
-    )
-
-
 @freeze_time(FROZEN_TIME)
 def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
     test_resources_dir = pytestconfig.rootpath / "tests/unit/glue"
@@ -1,13 +1,13 @@
 from typing import Any, Dict, Optional, cast

-from sqlalchemy import create_engine
-from sqlalchemy.sql import text
-
 from datahub.ingestion.api.committable import StatefulCommittable
 from datahub.ingestion.run.pipeline import Pipeline
 from datahub.ingestion.source.sql.mysql import MySQLConfig, MySQLSource
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from sqlalchemy import create_engine
+from sqlalchemy.sql import text

 from tests.utils import (
     get_gms_url,
     get_mysql_password,
@@ -49,8 +49,9 @@ def test_stateful_ingestion(wait_for_healthchecks):
     def get_current_checkpoint_from_pipeline(
         pipeline: Pipeline,
     ) -> Optional[Checkpoint[GenericCheckpointState]]:
+        # TODO: Refactor to use the helper method in the metadata-ingestion tests, instead of copying it here.
         mysql_source = cast(MySQLSource, pipeline.source)
-        return mysql_source.get_current_checkpoint(
+        return mysql_source.state_provider.get_current_checkpoint(
             mysql_source.stale_entity_removal_handler.job_id
         )