219 lines
9.1 KiB
Python
Raw Permalink Normal View History

import tempfile
import types
import unittest
from typing import Dict, List, Optional, Type
from unittest.mock import MagicMock, patch
from avrogen.dict_wrapper import DictWrapper
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
CheckpointJobStateType,
IngestionCheckpointingProviderBase,
JobId,
)
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionConfig,
StateProviderWrapper,
)
from datahub.ingestion.source.state.usage_common_state import (
BaseTimeWindowCheckpointState,
)
from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
DatahubIngestionCheckpointingProvider,
)
from datahub.ingestion.source.state_provider.file_ingestion_checkpointing_provider import (
FileIngestionCheckpointingProvider,
)
from tests.test_helpers.type_helpers import assert_not_null
class TestIngestionCheckpointProviders(unittest.TestCase):
    """Round-trip tests for the ingestion checkpointing providers.

    Verifies that checkpoint state committed through both the DataHub
    graph-backed provider and the file-backed provider can be retrieved
    intact, and that StateProviderWrapper wires up (or skips) a
    checkpointing state provider depending on its config and on whether a
    graph client is available on the pipeline context.
    """

    # Static members shared by all tests.
    pipeline_name: str = "test_pipeline"
    job_names: List[JobId] = [JobId("job1"), JobId("job2")]
    run_id: str = "test_run"

    def setUp(self) -> None:
        self._setup_mock_graph()
        self._create_providers()

    def _setup_mock_graph(self) -> None:
        """
        Setup monkey-patched graph client.
        """
        self.patcher = patch(
            "datahub.ingestion.graph.client.DataHubGraph", autospec=True
        )
        self.addCleanup(self.patcher.stop)
        self.mock_graph = self.patcher.start()
        # Make server stateful ingestion capable.
        self.mock_graph.get_config.return_value = {"statefulIngestionCapable": True}
        # Bind mock_graph's emit_mcp to the testcase's monkey_patch_emit_mcp
        # so that we can emulate emits.
        self.mock_graph.emit_mcp = types.MethodType(
            self.monkey_patch_emit_mcp, self.mock_graph
        )
        # Bind mock_graph's get_latest_timeseries_value to
        # monkey_patch_get_latest_timeseries_value so reads hit the local cache.
        self.mock_graph.get_latest_timeseries_value = types.MethodType(
            self.monkey_patch_get_latest_timeseries_value, self.mock_graph
        )
        # Tracking for emitted mcps, keyed by entityUrn.
        self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {}

    def _create_providers(self) -> None:
        """Instantiate one provider of each kind against the mocked graph."""
        ctx: PipelineContext = PipelineContext(
            run_id=self.run_id, pipeline_name=self.pipeline_name
        )
        ctx.graph = self.mock_graph
        self.providers: List[IngestionCheckpointingProviderBase] = [
            DatahubIngestionCheckpointingProvider.create({}, ctx),
            FileIngestionCheckpointingProvider.create(
                {"filename": f"{tempfile.mkdtemp()}/checkpoint_mces.json"},
                ctx,
            ),
        ]

    def monkey_patch_emit_mcp(
        self, graph_ref: MagicMock, mcpw: MetadataChangeProposalWrapper
    ) -> None:
        """
        Monkey-patched implementation of DatahubGraph.emit_mcp that caches the
        mcp locally in memory.
        """
        self.assertIsNotNone(graph_ref)
        if mcpw.aspectName != "status":
            self.assertEqual(mcpw.entityType, "dataJob")
            self.assertEqual(mcpw.aspectName, "datahubIngestionCheckpoint")
        # Cache the mcpw against the entityUrn.
        assert mcpw.entityUrn is not None
        self.mcps_emitted[mcpw.entityUrn] = mcpw

    def monkey_patch_get_latest_timeseries_value(
        self,
        graph_ref: MagicMock,
        entity_urn: str,
        aspect_type: Type[DictWrapper],
        filter_criteria_map: Dict[str, str],
    ) -> Optional[DictWrapper]:
        """
        Monkey-patched implementation of DatahubGraph.get_latest_timeseries_value
        that returns the latest cached aspect for a given entity urn.
        """
        self.assertIsNotNone(graph_ref)
        self.assertEqual(aspect_type, CheckpointJobStateType)
        self.assertEqual(
            filter_criteria_map,
            {
                "pipelineName": self.pipeline_name,
            },
        )
        # Retrieve the cached mcpw and return its aspect value.
        mcpw = self.mcps_emitted.get(entity_urn)
        if mcpw:
            return mcpw.aspect
        return None

    def test_providers(self) -> None:
        """Commit two job checkpoints through each provider and verify the
        retrieved state matches what was committed."""
        self.assertEqual(len(self.providers), 2)
        for provider in self.providers:
            assert provider
            # 1. Create the individual job checkpoints with appropriate states.
            # Job1 - Checkpoint with a BaseSQLAlchemyCheckpointState state
            job1_state_obj = BaseSQLAlchemyCheckpointState()
            job1_checkpoint = Checkpoint(
                job_name=self.job_names[0],
                pipeline_name=self.pipeline_name,
                run_id=self.run_id,
                state=job1_state_obj,
            )
            # Job2 - Checkpoint with a BaseTimeWindowCheckpointState state
            job2_state_obj = BaseTimeWindowCheckpointState(
                begin_timestamp_millis=10, end_timestamp_millis=100
            )
            job2_checkpoint = Checkpoint(
                job_name=self.job_names[1],
                pipeline_name=self.pipeline_name,
                run_id=self.run_id,
                state=job2_state_obj,
            )
            # 2. Set the provider's state_to_commit.
            provider.state_to_commit = {
                # NOTE: state_to_commit accepts only the aspect version of the checkpoint.
                self.job_names[0]: assert_not_null(
                    job1_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20)
                ),
                self.job_names[1]: assert_not_null(
                    job2_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20)
                ),
            }
            # 3. Perform the commit.
            # NOTE: This will commit the state to
            # in-memory self.mcps_emitted because of the monkey-patching for the
            # datahub ingestion checkpointer provider, and to a temp directory
            # json file for the file ingestion checkpointer provider.
            provider.commit()
            self.assertTrue(provider.committed)
            # 4. Get last committed state. This must match what has been committed earlier.
            # NOTE: This will retrieve the state from where it is committed.
            job1_last_state = provider.get_latest_checkpoint(
                self.pipeline_name, self.job_names[0]
            )
            job2_last_state = provider.get_latest_checkpoint(
                self.pipeline_name, self.job_names[1]
            )
            # 5. Validate individual job checkpoint state values that have been
            # committed and retrieved against the original values.
            self.assertIsNotNone(job1_last_state)
            job1_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
                job_name=self.job_names[0],
                # assert_not_null keeps Optional-narrowing consistent with the
                # state_to_commit construction above (non-None already asserted).
                checkpoint_aspect=assert_not_null(job1_last_state),
                state_class=type(job1_state_obj),
            )
            self.assertEqual(job1_last_checkpoint, job1_checkpoint)
            self.assertIsNotNone(job2_last_state)
            job2_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
                job_name=self.job_names[1],
                checkpoint_aspect=assert_not_null(job2_last_state),
                state_class=type(job2_state_obj),
            )
            self.assertEqual(job2_last_checkpoint, job2_checkpoint)

    def test_state_provider_wrapper_with_config_provided(self) -> None:
        """With an explicit config, the provider exists iff enabled=True."""
        # stateful_ingestion_config.enabled as true
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        ctx.graph = self.mock_graph
        state_provider = StateProviderWrapper(
            StatefulIngestionConfig(enabled=True), ctx
        )
        assert state_provider.stateful_ingestion_config
        assert state_provider.ingestion_checkpointing_state_provider
        # stateful_ingestion_config.enabled as false
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        ctx.graph = self.mock_graph
        state_provider = StateProviderWrapper(
            StatefulIngestionConfig(enabled=False), ctx
        )
        assert state_provider.stateful_ingestion_config
        assert not state_provider.ingestion_checkpointing_state_provider

    def test_state_provider_wrapper_with_config_not_provided(self) -> None:
        """Without a config, the provider exists iff a graph client is set."""
        # graph object is present
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        ctx.graph = self.mock_graph
        state_provider = StateProviderWrapper(None, ctx)
        assert state_provider.stateful_ingestion_config
        assert state_provider.ingestion_checkpointing_state_provider
        # graph object is none
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        state_provider = StateProviderWrapper(None, ctx)
        assert not state_provider.stateful_ingestion_config
        assert not state_provider.ingestion_checkpointing_state_provider