123 lines
4.8 KiB
Python

import types
from typing import Any, Callable, Dict, Optional, Type, cast
from unittest.mock import MagicMock, create_autospec
import pytest
from avrogen.dict_wrapper import DictWrapper
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
IngestionCheckpointingProviderBase,
)
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.graph.config import DatahubClientConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)
def validate_all_providers_have_committed_successfully(
pipeline: Pipeline, expected_providers: int
) -> None:
"""
makes sure the pipeline includes the desired number of providers
and for each of them verify it has successfully committed
"""
provider_count: int = 0
for _, provider in pipeline.ctx.get_committables():
provider_count += 1
assert isinstance(provider, IngestionCheckpointingProviderBase)
stateful_committable = cast(IngestionCheckpointingProviderBase, provider)
assert stateful_committable.has_successfully_committed()
assert stateful_committable.state_to_commit
assert provider_count == expected_providers
def run_and_get_pipeline(pipeline_config_dict: Dict[str, Any]) -> Pipeline:
pipeline = Pipeline.create(pipeline_config_dict)
pipeline.run()
pipeline.raise_from_status()
return pipeline
@pytest.fixture
def mock_datahub_graph():
class MockDataHubGraphContext:
pipeline_name: str = "test_pipeline"
run_id: str = "test_run"
def __init__(self) -> None:
"""
Create a new monkey-patched instance of the DataHubGraph graph client.
"""
# ensure this mock keeps the same api of the original class
self.mock_graph = create_autospec(DataHubGraph)
# Make server stateful ingestion capable
self.mock_graph.get_config.return_value = {"statefulIngestionCapable": True}
# Bind mock_graph's emit_mcp to testcase's monkey_patch_emit_mcp so that we can emulate emits.
self.mock_graph.emit_mcp = types.MethodType(
self.monkey_patch_emit_mcp, self.mock_graph
)
# Bind mock_graph's get_latest_timeseries_value to monkey_patch_get_latest_timeseries_value
self.mock_graph.get_latest_timeseries_value = types.MethodType(
self.monkey_patch_get_latest_timeseries_value, self.mock_graph
)
# Tracking for emitted mcps.
self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {}
def monkey_patch_emit_mcp(
self, graph_ref: MagicMock, mcpw: MetadataChangeProposalWrapper
) -> None:
"""
Mockey patched implementation of DatahubGraph.emit_mcp that caches the mcp locally in memory.
"""
# Cache the mcpw against the entityUrn
assert mcpw.entityUrn is not None
self.mcps_emitted[mcpw.entityUrn] = mcpw
def monkey_patch_get_latest_timeseries_value(
self,
graph_ref: MagicMock,
entity_urn: str,
aspect_type: Type[DictWrapper],
filter_criteria_map: Dict[str, str],
) -> Optional[DictWrapper]:
"""
Monkey patched implementation of DatahubGraph.get_latest_timeseries_value that returns the latest cached aspect
for a given entity urn.
"""
# Retrieve the cached mcpw and return its aspect value.
mcpw = self.mcps_emitted.get(entity_urn)
if mcpw:
return mcpw.aspect
return None
mock_datahub_graph_ctx = MockDataHubGraphContext()
return mock_datahub_graph_ctx.mock_graph
@pytest.fixture
def mock_datahub_graph_instance(
mock_datahub_graph: Callable[[DatahubClientConfig], DataHubGraph],
) -> DataHubGraph:
return mock_datahub_graph(DatahubClientConfig(server="http://fake.domain.local"))
def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
# TODO: This only works for stale entity removal. We need to generalize this.
stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
return stateful_source.state_provider.get_current_checkpoint(
StaleEntityRemovalHandler.compute_job_id(
getattr(stateful_source, "platform", "default")
)
)