refactor(ingest): simplify stateful ingestion provider interface (#8104)

Co-authored-by: Tamas Nemeth <treff7es@gmail.com>
Harshal Sheth authored 2023-05-24 01:27:57 +05:30, committed by GitHub
parent afd65e16fb
commit b0f8c3de1e
27 changed files with 179 additions and 323 deletions

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Tuple, Type
from datahub.emitter.mce_builder import set_dataset_urn_to_lower
from datahub.ingestion.api.committable import Committable
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.ingestion.graph.client import DataHubGraph
if TYPE_CHECKING:
from datahub.ingestion.run.pipeline import PipelineConfig
@@ -43,22 +43,19 @@ class PipelineContext:
def __init__(
self,
run_id: str,
datahub_api: Optional["DatahubClientConfig"] = None,
graph: Optional[DataHubGraph] = None,
pipeline_name: Optional[str] = None,
dry_run: bool = False,
preview_mode: bool = False,
pipeline_config: Optional["PipelineConfig"] = None,
) -> None:
self.pipeline_config = pipeline_config
self.graph = graph
self.run_id = run_id
self.pipeline_name = pipeline_name
self.dry_run_mode = dry_run
self.preview_mode = preview_mode
self.checkpointers: Dict[str, Committable] = {}
try:
self.graph = DataHubGraph(datahub_api) if datahub_api is not None else None
except Exception as e:
raise Exception(f"Failed to connect to DataHub: {e}") from e
self._set_dataset_urn_to_lower_if_needed()
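
With this hunk, PipelineContext no longer builds its own DataHubGraph from a DatahubClientConfig (and the old try/except around the connection goes away); callers hand it an already-constructed graph. A minimal sketch of the new construction path, assuming a local GMS endpoint and that DatahubClientConfig takes a server URL:

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

# The caller now owns graph construction, including its error handling.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
ctx = PipelineContext(run_id="demo-run", graph=graph, pipeline_name="demo")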

View File

@@ -27,6 +27,7 @@ from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
from datahub.ingestion.api.source import Extractor, Source
from datahub.ingestion.api.transform import Transformer
from datahub.ingestion.extractor.extractor_registry import extractor_registry
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.reporting.reporting_provider_registry import (
reporting_provider_registry,
)
@@ -183,10 +184,15 @@ class Pipeline:
self.last_time_printed = int(time.time())
self.cli_report = CliReport()
self.graph = None
with _add_init_error_context("connect to DataHub"):
if self.config.datahub_api:
self.graph = DataHubGraph(self.config.datahub_api)
with _add_init_error_context("set up framework context"):
self.ctx = PipelineContext(
run_id=self.config.run_id,
datahub_api=self.config.datahub_api,
graph=self.graph,
pipeline_name=self.config.pipeline_name,
dry_run=dry_run,
preview_mode=preview_mode,
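
Graph construction moves here, wrapped in an error context that is separate from the framework-context setup. A hedged sketch of what a helper like _add_init_error_context presumably does; the real implementation lives in pipeline.py and may differ:

import contextlib
from typing import Iterator

@contextlib.contextmanager
def _add_init_error_context(step: str) -> Iterator[None]:
    # Assumed behavior: tag any failure with the init step that produced it,
    # preserving the old "Failed to connect to DataHub: ..." style of message
    # that the removed PipelineContext try/except used to raise.
    try:
        yield
    except Exception as e:
        raise RuntimeError(f"Failed to {step}: {e}") from e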

View File

@@ -7,7 +7,7 @@ from dataclasses import dataclass
from datetime import timedelta
from enum import auto
from threading import BoundedSemaphore
from typing import Union, cast
from typing import Union
from datahub.cli.cli_utils import set_env_variables_override_config
from datahub.configuration.common import (
@@ -126,8 +126,7 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
def handle_work_unit_start(self, workunit: WorkUnit) -> None:
if isinstance(workunit, MetadataWorkUnit):
mwu: MetadataWorkUnit = cast(MetadataWorkUnit, workunit)
self.treat_errors_as_warnings = mwu.treat_errors_as_warnings
self.treat_errors_as_warnings = workunit.treat_errors_as_warnings
def handle_work_unit_end(self, workunit: WorkUnit) -> None:
pass
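
The cast() removed here was redundant: after an isinstance() check, type checkers already narrow the variable inside the branch. A minimal free-standing illustration of the narrowing (not the sink code itself):

from datahub.ingestion.api.workunit import MetadataWorkUnit

def handle_start(workunit: object) -> None:
    if isinstance(workunit, MetadataWorkUnit):
        # mypy narrows `workunit` to MetadataWorkUnit in this branch,
        # so no explicit cast() is needed to read its attributes.
        treat_errors_as_warnings = workunit.treat_errors_as_warnings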

View File

@@ -1331,6 +1331,3 @@ class LookerDashboardSource(TestableSource, StatefulIngestionSourceBase):
def get_report(self) -> SourceReport:
return self.reporter
def close(self):
self.prepare_for_commit()

View File

@@ -2163,6 +2163,3 @@ class LookMLSource(StatefulIngestionSourceBase):
def get_report(self):
return self.reporter
def close(self):
self.prepare_for_commit()

View File

@@ -1228,7 +1228,7 @@ class PowerBiDashboardSource(StatefulIngestionSourceBase):
# Because job_id is used as a dictionary key, we have to set a new job_id
# Refer to https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py#L390
self.stale_entity_removal_handler.set_job_id(workspace.id)
self.register_stateful_ingestion_usecase_handler(
self.state_provider.register_stateful_ingestion_usecase_handler(
self.stale_entity_removal_handler
)

View File

@@ -1283,6 +1283,3 @@ class VerticaSource(SQLAlchemySource):
return each["owner_name"]
return None
def close(self):
self.prepare_for_commit()
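
This close() override, like the identical ones deleted from the Looker and LookML sources above, is no longer needed: the base class now prepares checkpoint states for commit itself. The stateful_ingestion_base.py hunk later in this diff adds exactly this:

class StatefulIngestionSourceBase(Source):
    def close(self) -> None:
        self.state_provider.prepare_for_commit()
        super().close()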

View File

@@ -1,11 +1,9 @@
from typing import Any, Dict, Iterable, List, Type
from typing import Any, Dict, Iterable, List, Tuple, Type
import pydantic
from datahub.emitter.mce_builder import make_assertion_urn, make_container_urn
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityCheckpointStateBase,
)
from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
from datahub.utilities.checkpoint_state_util import CheckpointStateUtil
from datahub.utilities.dedup_list import deduplicate_list
from datahub.utilities.urns.urn import guess_entity_type
@@ -56,7 +54,7 @@ def pydantic_state_migrator(mapping: Dict[str, str]) -> classmethod:
return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)
class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointState"]):
class GenericCheckpointState(CheckpointStateBase):
urns: List[str] = pydantic.Field(default_factory=list)
# We store a bit of extra internal-only state so that we can keep the urns list deduplicated.
@@ -85,11 +83,15 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointSt
self.urns = deduplicate_list(self.urns)
self._urns_set = set(self.urns)
@classmethod
def get_supported_types(cls) -> List[str]:
return ["*"]
def add_checkpoint_urn(self, type: str, urn: str) -> None:
"""
Adds an urn into the list used for tracking the type.
:param type: Deprecated parameter, has no effect.
:param urn: The urn string
"""
# TODO: Deprecate the `type` parameter and remove it.
if urn not in self._urns_set:
self.urns.append(urn)
self._urns_set.add(urn)
@@ -97,9 +99,18 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointSt
def get_urns_not_in(
self, type: str, other_checkpoint_state: "GenericCheckpointState"
) -> Iterable[str]:
"""
Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
:param type: Deprecated. Set to "*".
:param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
:return: an iterable over the set of urns present in this checkpoint state but not in the other_checkpoint.
"""
diff = set(self.urns) - set(other_checkpoint_state.urns)
# To maintain backwards compatibility, we provide this filtering mechanism.
# TODO: Deprecate the `type` parameter and remove it.
if type == "*":
yield from diff
elif type == "topic":
@@ -110,6 +121,36 @@ class GenericCheckpointState(StaleEntityCheckpointStateBase["GenericCheckpointSt
def get_percent_entities_changed(
self, old_checkpoint_state: "GenericCheckpointState"
) -> float:
return StaleEntityCheckpointStateBase.compute_percent_entities_changed(
"""
Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
:param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
:return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
"""
return compute_percent_entities_changed(
[(self.urns, old_checkpoint_state.urns)]
)
def compute_percent_entities_changed(
new_old_entity_list: List[Tuple[List[str], List[str]]]
) -> float:
old_count_all = 0
overlap_count_all = 0
for new_entities, old_entities in new_old_entity_list:
(overlap_count, old_count, _,) = get_entity_overlap_and_cardinalities(
new_entities=new_entities, old_entities=old_entities
)
overlap_count_all += overlap_count
old_count_all += old_count
if old_count_all:
return (1 - overlap_count_all / old_count_all) * 100.0
return 0.0
def get_entity_overlap_and_cardinalities(
new_entities: List[str], old_entities: List[str]
) -> Tuple[int, int, int]:
new_set = set(new_entities)
old_set = set(old_entities)
return len(new_set.intersection(old_set)), len(old_set), len(new_set)
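
compute_percent_entities_changed is now a plain module-level function (previously a static method on StaleEntityCheckpointStateBase). A worked example with hypothetical urns: three of four old entities survive into the new state, so 25% changed:

from datahub.ingestion.source.state.entity_removal_state import (
    compute_percent_entities_changed,
)

old_urns = ["urn:li:dataset:a", "urn:li:dataset:b", "urn:li:dataset:c", "urn:li:dataset:d"]
new_urns = ["urn:li:dataset:a", "urn:li:dataset:b", "urn:li:dataset:c", "urn:li:dataset:x"]

# overlap = 3, old count = 4, so (1 - 3/4) * 100.0 = 25.0
assert compute_percent_entities_changed([(new_urns, old_urns)]) == 25.0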

View File

@@ -40,15 +40,17 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointSt
pipeline_name: Optional[str],
run_id: str,
):
self.source = source
self.state_provider = source.state_provider
self.stateful_ingestion_config: Optional[
ProfilingStatefulIngestionConfig
] = config.stateful_ingestion
self.pipeline_name = pipeline_name
self.run_id = run_id
self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
self.checkpointing_enabled: bool = (
self.state_provider.is_stateful_ingestion_configured()
)
self._job_id = self._init_job_id()
self.source.register_stateful_ingestion_usecase_handler(self)
self.state_provider.register_stateful_ingestion_usecase_handler(self)
def _ignore_old_state(self) -> bool:
if (
@@ -91,7 +93,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointSt
def get_current_state(self) -> Optional[ProfilingCheckpointState]:
if not self.is_checkpointing_enabled() or self._ignore_new_state():
return None
cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
assert cur_checkpoint is not None
cur_state = cast(ProfilingCheckpointState, cur_checkpoint.state)
return cur_state
@@ -108,7 +110,7 @@ class ProfilingHandler(StatefulIngestionUsecaseHandlerBase[ProfilingCheckpointSt
def get_last_state(self) -> Optional[ProfilingCheckpointState]:
if not self.is_checkpointing_enabled() or self._ignore_old_state():
return None
last_checkpoint = self.source.get_last_checkpoint(
last_checkpoint = self.state_provider.get_last_checkpoint(
self.job_id, ProfilingCheckpointState
)
if last_checkpoint and last_checkpoint.state:
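
The pattern in this hunk recurs in the RedundantRunSkipHandler and StaleEntityRemovalHandler hunks below and is the heart of the refactor: use-case handlers no longer call checkpoint methods on the source itself; they hold its state_provider and delegate. A minimal sketch with a hypothetical handler class:

from datahub.ingestion.source.state.stateful_ingestion_base import (
    StatefulIngestionSourceBase,
)

class ExampleUsecaseHandler:  # hypothetical, for illustration only
    def __init__(self, source: StatefulIngestionSourceBase):
        # Grab the wrapper once; all checkpoint reads/writes go through it.
        self.state_provider = source.state_provider
        self.checkpointing_enabled = (
            self.state_provider.is_stateful_ingestion_configured()
        )
        self.state_provider.register_stateful_ingestion_usecase_handler(self)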

View File

@@ -46,14 +46,17 @@ class RedundantRunSkipHandler(
run_id: str,
):
self.source = source
self.state_provider = source.state_provider
self.stateful_ingestion_config: Optional[
StatefulRedundantRunSkipConfig
] = config.stateful_ingestion
self.pipeline_name = pipeline_name
self.run_id = run_id
self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
self.checkpointing_enabled: bool = (
self.state_provider.is_stateful_ingestion_configured()
)
self._job_id = self._init_job_id()
self.source.register_stateful_ingestion_usecase_handler(self)
self.state_provider.register_stateful_ingestion_usecase_handler(self)
def _ignore_old_state(self) -> bool:
if (
@@ -114,7 +117,7 @@ class RedundantRunSkipHandler(
) -> None:
if not self.is_checkpointing_enabled() or self._ignore_new_state():
return
cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
assert cur_checkpoint is not None
cur_state = cast(BaseUsageCheckpointState, cur_checkpoint.state)
cur_state.begin_timestamp_millis = start_time_millis
@@ -125,7 +128,7 @@ class RedundantRunSkipHandler(
return False
# Determine from the last checkpoint state
last_successful_pipeline_run_end_time_millis: Optional[int] = None
last_checkpoint = self.source.get_last_checkpoint(
last_checkpoint = self.state_provider.get_last_checkpoint(
self.job_id, BaseUsageCheckpointState
)
if last_checkpoint and last_checkpoint.state:
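
Beyond the provider change, this handler's job is unchanged: skip a run whose time window was already covered by the last successful run. A hedged sketch of that decision as a handler method; begin_timestamp_millis appears in the hunk above, while the matching end_timestamp_millis field is an assumption:

from typing import cast

def should_skip_this_run(self, cur_start_time_millis: int) -> bool:
    last_checkpoint = self.state_provider.get_last_checkpoint(
        self.job_id, BaseUsageCheckpointState
    )
    if last_checkpoint and last_checkpoint.state:
        last_state = cast(BaseUsageCheckpointState, last_checkpoint.state)
        # Skip if the new window starts before the last successful run ended.
        return cur_start_time_millis <= last_state.end_timestamp_millis  # assumed field
    return False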

View File

@@ -1,25 +1,14 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import (
Dict,
Generic,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
cast,
)
from typing import Dict, Iterable, Optional, Set, Type, cast
import pydantic
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionConfig,
StatefulIngestionConfigBase,
@@ -61,102 +50,26 @@ class StaleEntityRemovalSourceReport(StatefulIngestionReport):
self.soft_deleted_stale_entities.append(urn)
Derived = TypeVar("Derived", bound=CheckpointStateBase)
class StaleEntityCheckpointStateBase(CheckpointStateBase, ABC, Generic[Derived]):
"""
Defines the abstract interface for the checkpoint states that are used for stale entity removal.
Examples include sql_common state for tracking table & view urns,
dbt that tracks node & assertion urns, kafka state tracking topic urns.
"""
@classmethod
@abstractmethod
def get_supported_types(cls) -> List[str]:
pass
@abstractmethod
def add_checkpoint_urn(self, type: str, urn: str) -> None:
"""
Adds an urn into the list used for tracking the type.
:param type: The type of the urn such as a 'table', 'view',
'node', 'topic', 'assertion' that the concrete sub-class understands.
:param urn: The urn string
:return: None.
"""
pass
@abstractmethod
def get_urns_not_in(
self, type: str, other_checkpoint_state: Derived
) -> Iterable[str]:
"""
Gets the urns present in this checkpoint but not the other_checkpoint for the given type.
:param type: The type of the urn such as a 'table', 'view',
'node', 'topic', 'assertion' that the concrete sub-class understands.
:param other_checkpoint_state: the checkpoint state to compute the urn set difference against.
:return: an iterable over the set of urns present in this checkpoint state but not in the other_checkpoint.
"""
pass
@abstractmethod
def get_percent_entities_changed(self, old_checkpoint_state: Derived) -> float:
"""
Returns the percentage of entities that have changed relative to `old_checkpoint_state`.
:param old_checkpoint_state: the old checkpoint state to compute the relative change percent against.
:return: (1-|intersection(self, old_checkpoint_state)| / |old_checkpoint_state|) * 100.0
"""
pass
@staticmethod
def compute_percent_entities_changed(
new_old_entity_list: List[Tuple[List[str], List[str]]]
) -> float:
old_count_all = 0
overlap_count_all = 0
for new_entities, old_entities in new_old_entity_list:
(
overlap_count,
old_count,
_,
) = StaleEntityCheckpointStateBase.get_entity_overlap_and_cardinalities(
new_entities=new_entities, old_entities=old_entities
)
overlap_count_all += overlap_count
old_count_all += old_count
if old_count_all:
return (1 - overlap_count_all / old_count_all) * 100.0
return 0.0
@staticmethod
def get_entity_overlap_and_cardinalities(
new_entities: List[str], old_entities: List[str]
) -> Tuple[int, int, int]:
new_set = set(new_entities)
old_set = set(old_entities)
return len(new_set.intersection(old_set)), len(old_set), len(new_set)
class StaleEntityRemovalHandler(
StatefulIngestionUsecaseHandlerBase[StaleEntityCheckpointStateBase]
StatefulIngestionUsecaseHandlerBase["GenericCheckpointState"]
):
"""
The stateful ingestion helper class that handles stale entity removal.
This contains the generic logic for all sources that need to support stale entity removal for all the states
derived from StaleEntityCheckpointStateBase. This uses the template method pattern on CRTP based derived state
class hierarchies.
derived from GenericCheckpointState.
"""
def __init__(
self,
source: StatefulIngestionSourceBase,
config: StatefulIngestionConfigBase[StatefulStaleMetadataRemovalConfig],
state_type_class: Type[StaleEntityCheckpointStateBase],
state_type_class: Type["GenericCheckpointState"],
pipeline_name: Optional[str],
run_id: str,
):
self.source = source
self.state_provider = source.state_provider
self.state_type_class = state_type_class
self.pipeline_name = pipeline_name
self.run_id = run_id
@@ -166,7 +79,7 @@ class StaleEntityRemovalHandler(
self.checkpointing_enabled: bool = (
True
if (
source.is_stateful_ingestion_configured()
self.state_provider.is_stateful_ingestion_configured()
and self.stateful_ingestion_config
and self.stateful_ingestion_config.remove_stale_metadata
)
@@ -174,7 +87,7 @@ class StaleEntityRemovalHandler(
)
self._job_id = self._init_job_id()
self._urns_to_skip: Set[str] = set()
self.source.register_stateful_ingestion_usecase_handler(self)
self.state_provider.register_stateful_ingestion_usecase_handler(self)
@classmethod
def compute_job_id(
@@ -246,7 +159,7 @@ class StaleEntityRemovalHandler(
)
return None
def _create_soft_delete_workunit(self, urn: str, type: str) -> MetadataWorkUnit:
def _create_soft_delete_workunit(self, urn: str) -> MetadataWorkUnit:
logger.info(f"Soft-deleting stale entity - {urn}")
mcp = MetadataChangeProposalWrapper(
entityUrn=urn,
@@ -278,20 +191,16 @@ class StaleEntityRemovalHandler(
if not self.is_checkpointing_enabled() or self._ignore_old_state():
return
logger.debug("Checking for stale entity removal.")
last_checkpoint: Optional[Checkpoint] = self.source.get_last_checkpoint(
last_checkpoint: Optional[Checkpoint] = self.state_provider.get_last_checkpoint(
self.job_id, self.state_type_class
)
if not last_checkpoint:
return
cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
assert cur_checkpoint is not None
# Get the underlying states
last_checkpoint_state = cast(
StaleEntityCheckpointStateBase, last_checkpoint.state
)
cur_checkpoint_state = cast(
StaleEntityCheckpointStateBase, cur_checkpoint.state
)
last_checkpoint_state = cast(GenericCheckpointState, last_checkpoint.state)
cur_checkpoint_state = cast(GenericCheckpointState, cur_checkpoint.state)
# Check if the entity delta is below the fail-safe threshold.
entity_difference_percent = cur_checkpoint_state.get_percent_entities_changed(
@@ -316,21 +225,20 @@ class StaleEntityRemovalHandler(
return
# Everything looks good, emit the soft-deletion workunits
for type in self.state_type_class.get_supported_types():
for urn in last_checkpoint_state.get_urns_not_in(
type=type, other_checkpoint_state=cur_checkpoint_state
):
if urn in self._urns_to_skip:
logger.debug(
f"Not soft-deleting entity {urn} since it is in urns_to_skip"
)
continue
yield self._create_soft_delete_workunit(urn, type)
for urn in last_checkpoint_state.get_urns_not_in(
type="*", other_checkpoint_state=cur_checkpoint_state
):
if urn in self._urns_to_skip:
logger.debug(
f"Not soft-deleting entity {urn} since it is in urns_to_skip"
)
continue
yield self._create_soft_delete_workunit(urn)
def add_entity_to_state(self, type: str, urn: str) -> None:
if not self.is_checkpointing_enabled() or self._ignore_new_state():
return
cur_checkpoint = self.source.get_current_checkpoint(self.job_id)
cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
assert cur_checkpoint is not None
cur_state = cast(StaleEntityCheckpointStateBase, cur_checkpoint.state)
cur_state = cast(GenericCheckpointState, cur_checkpoint.state)
cur_state.add_checkpoint_urn(type=type, urn=urn)
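
With a single concrete state type, the per-type outer loop over get_supported_types() disappears and the handler queries everything with type="*". A condensed sketch of the resulting soft-delete flow (method names taken from this hunk; the fail-safe threshold check is omitted):

last_checkpoint = self.state_provider.get_last_checkpoint(
    self.job_id, self.state_type_class
)
cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id)
if last_checkpoint and cur_checkpoint:
    last_state = cast(GenericCheckpointState, last_checkpoint.state)
    cur_state = cast(GenericCheckpointState, cur_checkpoint.state)
    # Anything present in the last run's state but absent now is stale.
    for urn in last_state.get_urns_not_in(type="*", other_checkpoint_state=cur_state):
        if urn not in self._urns_to_skip:
            yield self._create_soft_delete_workunit(urn)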

View File

@@ -27,10 +27,7 @@ from datahub.ingestion.source.state.use_case_handler import (
from datahub.ingestion.source.state_provider.state_provider_registry import (
ingestion_checkpoint_provider_registry,
)
from datahub.metadata.schema_classes import (
DatahubIngestionCheckpointClass,
DatahubIngestionRunSummaryClass,
)
from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass
logger: logging.Logger = logging.getLogger(__name__)
@@ -171,13 +168,8 @@ class StatefulIngestionSourceBase(Source):
ctx: PipelineContext,
) -> None:
super().__init__(ctx)
self.stateful_ingestion_config = config.stateful_ingestion
self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self.run_summaries_to_report: Dict[JobId, DatahubIngestionRunSummaryClass] = {}
self.report: StatefulIngestionReport = StatefulIngestionReport()
self._initialize_checkpointing_state_provider()
self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
self.state_provider = StateProviderWrapper(config.stateful_ingestion, ctx)
def warn(self, log: logging.Logger, key: str, reason: str) -> None:
self.report.report_warning(key, reason)
@@ -187,6 +179,26 @@ class StatefulIngestionSourceBase(Source):
self.report.report_failure(key, reason)
log.error(f"{key} => {reason}")
def close(self) -> None:
self.state_provider.prepare_for_commit()
super().close()
class StateProviderWrapper:
def __init__(
self,
config: Optional[StatefulIngestionConfig],
ctx: PipelineContext,
) -> None:
self.ctx = ctx
self.stateful_ingestion_config = config
self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self.report: StatefulIngestionReport = StatefulIngestionReport()
self._initialize_checkpointing_state_provider()
self._usecase_handlers: Dict[JobId, StatefulIngestionUsecaseHandlerBase] = {}
#
# Checkpointing specific support.
#
@@ -383,7 +395,3 @@ class StatefulIngestionSourceBase(Source):
def prepare_for_commit(self) -> None:
"""NOTE: Sources should call this method from their close method."""
self._prepare_checkpoint_states_for_commit()
def close(self) -> None:
self.prepare_for_commit()
super().close()
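
StateProviderWrapper lifts all checkpointing state (last/current checkpoints, the provider registry, the use-case handler map) out of the source class; sources keep the same capabilities but reach them through self.state_provider. A minimal sketch with a hypothetical source and job id:

from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId
from datahub.ingestion.source.state.stateful_ingestion_base import (
    StatefulIngestionSourceBase,
)

class ExampleSource(StatefulIngestionSourceBase):  # hypothetical
    def example(self, job_id: JobId) -> None:
        # The wrapper exposes the same checkpoint API the source used to own.
        if self.state_provider.is_stateful_ingestion_configured():
            cur_checkpoint = self.state_provider.get_current_checkpoint(job_id)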

View File

@@ -1,17 +1,16 @@
import json
import pathlib
from functools import partial
from typing import List, Optional, cast
from typing import List
from unittest.mock import patch
from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.identity.azure_ad import AzureADConfig, AzureADSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.identity.azure_ad import AzureADConfig
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -328,15 +327,6 @@ def test_azure_source_ingestion_disabled(pytestconfig, mock_datahub_graph, tmp_p
)
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
azure_ad_source = cast(AzureADSource, pipeline.source)
return azure_ad_source.get_current_checkpoint(
azure_ad_source.stale_entity_removal_handler.job_id
)
@freeze_time(FROZEN_TIME)
def test_azure_ad_stateful_ingestion(
pytestconfig, tmp_path, mock_time, mock_datahub_graph

View File

@@ -1,5 +1,5 @@
from pathlib import PosixPath
from typing import Any, Dict, Optional, Union, cast
from typing import Any, Dict, Union
from unittest.mock import patch
import pytest
@@ -8,11 +8,9 @@ from iceberg.core.filesystem.file_status import FileStatus
from iceberg.core.filesystem.local_filesystem import LocalFileSystem
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.iceberg.iceberg import IcebergSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
run_and_get_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -22,15 +20,6 @@ GMS_PORT = 8080
GMS_SERVER = f"http://localhost:{GMS_PORT}"
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
iceberg_source = cast(IcebergSource, pipeline.source)
return iceberg_source.get_current_checkpoint(
iceberg_source.stale_entity_removal_handler.job_id
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_iceberg_ingest(pytestconfig, tmp_path, mock_time):

View File

@@ -1,6 +1,6 @@
import subprocess
import time
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, cast
from unittest import mock
import pytest
@@ -8,12 +8,12 @@ import requests
from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers import mce_helpers
from tests.test_helpers.click_helpers import run_datahub_cmd
from tests.test_helpers.docker_helpers import wait_for_port
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -489,14 +489,3 @@ def test_kafka_connect_ingest_stateful(
"urn:li:dataJob:(urn:li:dataFlow:(kafka-connect,connect-instance-1.mysql_source2,PROD),librarydb.member)",
]
assert sorted(deleted_job_urns) == sorted(difference_job_urns)
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint]:
from datahub.ingestion.source.kafka_connect import KafkaConnectSource
kafka_connect_source = cast(KafkaConnectSource, pipeline.source)
return kafka_connect_source.get_current_checkpoint(
kafka_connect_source.stale_entity_removal_handler.job_id
)

View File

@@ -1,17 +1,14 @@
import time
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List
from unittest.mock import patch
import pytest
from confluent_kafka.admin import AdminClient, NewTopic
from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.kafka import KafkaSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers.docker_helpers import wait_for_port
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
run_and_get_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -81,15 +78,6 @@ class KafkaTopicsCxtManager:
self.delete_kafka_topics(self.topics)
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
kafka_source = cast(KafkaSource, pipeline.source)
return kafka_source.get_current_checkpoint(
kafka_source.stale_entity_removal_handler.job_id
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_kafka_ingest_with_stateful(

View File

@@ -1,18 +1,15 @@
import pathlib
import time
from typing import Optional, cast
from unittest import mock
import pytest
from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.ldap import LDAPSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers import mce_helpers
from tests.test_helpers.docker_helpers import wait_for_port
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -90,15 +87,6 @@ def ldap_ingest_common(
return pipeline
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
ldap_source = cast(LDAPSource, pipeline.source)
return ldap_source.get_current_checkpoint(
ldap_source.stale_entity_removal_handler.job_id
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_ldap_stateful(

View File

@@ -27,11 +27,10 @@ from datahub.ingestion.source.looker.looker_query_model import (
LookViewField,
UserViewField,
)
from datahub.ingestion.source.looker.looker_source import LookerDashboardSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -719,12 +718,3 @@ def test_looker_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_
assert len(difference_dashboard_urns) == 1
deleted_dashboard_urns = ["urn:li:dashboard:(looker,dashboards.11)"]
assert sorted(deleted_dashboard_urns) == sorted(difference_dashboard_urns)
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint]:
dbt_source = cast(LookerDashboardSource, pipeline.source)
return dbt_source.get_current_checkpoint(
dbt_source.stale_entity_removal_handler.job_id
)

View File

@@ -1,6 +1,6 @@
import logging
import pathlib
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, cast
from unittest import mock
import pydantic
@@ -15,10 +15,8 @@ from datahub.ingestion.source.file import read_metadata_file
from datahub.ingestion.source.looker.lookml_source import (
LookerModel,
LookerRefinementResolver,
LookMLSource,
LookMLSourceConfig,
)
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.metadata.schema_classes import (
DatasetSnapshotClass,
@@ -27,6 +25,7 @@ from datahub.metadata.schema_classes import (
)
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -847,15 +846,6 @@ def test_lookml_ingest_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_
assert sorted(deleted_dataset_urns) == sorted(difference_dataset_urns)
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint]:
dbt_source = cast(LookMLSource, pipeline.source)
return dbt_source.get_current_checkpoint(
dbt_source.stale_entity_removal_handler.job_id
)
def test_lookml_base_folder():
fake_api = {
"base_url": "https://filler.cloud.looker.com",

View File

@@ -1,7 +1,6 @@
import asyncio
import pathlib
from functools import partial
from typing import Optional, cast
from unittest.mock import Mock, patch
import jsonpickle
@@ -10,11 +9,10 @@ from freezegun import freeze_time
from okta.models import Group, User
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.identity.okta import OktaConfig, OktaSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.identity.okta import OktaConfig
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -203,15 +201,6 @@ def test_okta_source_custom_user_name_regex(pytestconfig, mock_datahub_graph, tm
)
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
azure_ad_source = cast(OktaSource, pipeline.source)
return azure_ad_source.get_current_checkpoint(
azure_ad_source.stale_entity_removal_handler.job_id
)
@freeze_time(FROZEN_TIME)
def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta"

View File

@@ -216,9 +216,11 @@ def get_current_checkpoint_from_pipeline(
) -> Dict[JobId, Optional[Checkpoint[Any]]]:
powerbi_source = cast(PowerBiDashboardSource, pipeline.source)
checkpoints = {}
for job_id in powerbi_source._usecase_handlers.keys():
for job_id in powerbi_source.state_provider._usecase_handlers.keys():
# for multi-workspace checkpoints, every good checkpoint will have a unique workspace id suffix
checkpoints[job_id] = powerbi_source.get_current_checkpoint(job_id)
checkpoints[job_id] = powerbi_source.state_provider.get_current_checkpoint(
job_id
)
return checkpoints

View File

@@ -1,15 +1,13 @@
from typing import Any, Dict, Optional, cast
from typing import Any, Dict
from unittest.mock import patch
import pytest
from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.superset import SupersetSource
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
run_and_get_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -19,15 +17,6 @@ GMS_PORT = 8080
GMS_SERVER = f"http://localhost:{GMS_PORT}"
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
superset_source = cast(SupersetSource, pipeline.source)
return superset_source.get_current_checkpoint(
superset_source.stale_entity_removal_handler.job_id
)
def register_mock_api(request_mock: Any, override_data: dict = {}) -> None:
api_vs_response = {
"mock://mock-domain.superset.com/api/v1/security/login": {

View File

@@ -2,7 +2,7 @@ import json
import logging
import pathlib
import sys
from typing import Optional, cast
from typing import cast
from unittest import mock
import pytest
@@ -17,8 +17,6 @@ from tableauserverclient.models import (
from datahub.configuration.source_common import DEFAULT_ENV
from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.tableau import TableauConfig, TableauSource
from datahub.ingestion.source.tableau_common import (
TableauLineageOverrides,
@@ -31,6 +29,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -253,15 +252,6 @@ def tableau_ingest_common(
return pipeline
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
tableau_source = cast(TableauSource, pipeline.source)
return tableau_source.get_current_checkpoint(
tableau_source.stale_entity_removal_handler.job_id
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph):

View File

@@ -11,6 +11,14 @@ from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
)
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)
def validate_all_providers_have_committed_successfully(
@@ -91,3 +99,15 @@ def mock_datahub_graph():
mock_datahub_graph_ctx = MockDataHubGraphContext()
return mock_datahub_graph_ctx.mock_graph
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
# TODO: This only works for stale entity removal. We need to generalize this.
stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
stale_entity_removal_handler: StaleEntityRemovalHandler = stateful_source.stale_entity_removal_handler # type: ignore
return stateful_source.state_provider.get_current_checkpoint(
stale_entity_removal_handler.job_id
)
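
All of the per-test copies deleted above now route through this single helper. A typical call site in the migrated tests looks roughly like this (the urns inspection assumes a GenericCheckpointState, which is what the helper's return type promises):

from tests.test_helpers.state_helpers import get_current_checkpoint_from_pipeline

checkpoint = get_current_checkpoint_from_pipeline(pipeline)
assert checkpoint is not None
urns = checkpoint.state.urns  # GenericCheckpointState keeps a deduplicated urn list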

View File

@@ -2,8 +2,8 @@ from typing import Dict, List, Tuple
import pytest
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityCheckpointStateBase,
from datahub.ingestion.source.state.entity_removal_state import (
compute_percent_entities_changed,
)
OldNewEntLists = List[Tuple[List[str], List[str]]]
@@ -43,9 +43,5 @@ old_new_ent_tests: Dict[str, Tuple[OldNewEntLists, float]] = {
def test_change_percent(
new_old_entity_list: OldNewEntLists, expected_percent_change: float
) -> None:
actual_percent_change = (
StaleEntityCheckpointStateBase.compute_percent_entities_changed(
new_old_entity_list
)
)
actual_percent_change = compute_percent_entities_changed(new_old_entity_list)
assert actual_percent_change == expected_percent_change

View File

@@ -10,10 +10,8 @@ from freezegun import freeze_time
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.sink.file import write_metadata_file
from datahub.ingestion.source.aws.glue import GlueSource, GlueSourceConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
)
@@ -26,6 +24,7 @@ from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
run_and_get_pipeline,
validate_all_providers_have_committed_successfully,
)
@@ -240,15 +239,6 @@ def test_config_without_platform():
assert source.platform == "glue"
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint]:
glue_source = cast(GlueSource, pipeline.source)
return glue_source.get_current_checkpoint(
glue_source.stale_entity_removal_handler.job_id
)
@freeze_time(FROZEN_TIME)
def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
test_resources_dir = pytestconfig.rootpath / "tests/unit/glue"

View File

@@ -1,13 +1,13 @@
from typing import Any, Dict, Optional, cast
from sqlalchemy import create_engine
from sqlalchemy.sql import text
from datahub.ingestion.api.committable import StatefulCommittable
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.sql.mysql import MySQLConfig, MySQLSource
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from sqlalchemy import create_engine
from sqlalchemy.sql import text
from tests.utils import (
get_gms_url,
get_mysql_password,
@@ -49,8 +49,9 @@ def test_stateful_ingestion(wait_for_healthchecks):
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
# TODO: Refactor to use the helper method in the metadata-ingestion tests, instead of copying it here.
mysql_source = cast(MySQLSource, pipeline.source)
return mysql_source.get_current_checkpoint(
return mysql_source.state_provider.get_current_checkpoint(
mysql_source.stale_entity_removal_handler.job_id
)