|
|
|
import tempfile
|
|
|
|
import types
|
|
|
|
import unittest
|
|
|
|
from typing import Dict, List, Optional, Type
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
|
|
|
from avrogen.dict_wrapper import DictWrapper
|
|
|
|
|
|
|
|
from datahub.emitter.mcp import MetadataChangeProposalWrapper
|
|
|
|
from datahub.ingestion.api.common import PipelineContext
|
|
|
|
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
|
|
|
|
CheckpointJobStateType,
|
|
|
|
IngestionCheckpointingProviderBase,
|
|
|
|
JobId,
|
|
|
|
)
|
|
|
|
from datahub.ingestion.source.state.checkpoint import Checkpoint
|
|
|
|
from datahub.ingestion.source.state.sql_common_state import (
|
|
|
|
BaseSQLAlchemyCheckpointState,
|
|
|
|
)
|
|
|
|
from datahub.ingestion.source.state.stateful_ingestion_base import (
|
|
|
|
StatefulIngestionConfig,
|
|
|
|
StateProviderWrapper,
|
|
|
|
)
|
|
|
|
from datahub.ingestion.source.state.usage_common_state import (
|
|
|
|
BaseTimeWindowCheckpointState,
|
|
|
|
)
|
|
|
|
from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
|
|
|
|
DatahubIngestionCheckpointingProvider,
|
|
|
|
)
|
|
|
|
from datahub.ingestion.source.state_provider.file_ingestion_checkpointing_provider import (
|
|
|
|
FileIngestionCheckpointingProvider,
|
|
|
|
)
|
|
|
|
from tests.test_helpers.type_helpers import assert_not_null
|
|
|
|
|
|
|
|
|
|
|
|
class TestIngestionCheckpointProviders(unittest.TestCase):
    """
    Round-trip tests for the ingestion checkpointing providers.

    Exercises both the DataHub-backed provider (against a monkey-patched graph
    client that caches emitted MCPs in memory) and the file-backed provider
    (against a temp-directory JSON file): commit checkpoint state, read it back,
    and verify it matches what was committed.
    """

    # Static members for the tests
    pipeline_name: str = "test_pipeline"
    job_names: List[JobId] = [JobId("job1"), JobId("job2")]
    run_id: str = "test_run"

    def setUp(self) -> None:
        # Order matters: _create_providers wires the providers to the mock
        # graph created by _setup_mock_graph.
        self._setup_mock_graph()
        self._create_providers()

    def _setup_mock_graph(self) -> None:
        """
        Setup monkey-patched graph client.

        Patches DataHubGraph with an autospec'd MagicMock, then rebinds its
        emit_mcp / get_latest_timeseries_value to the monkey_patch_* methods
        below so emits land in (and reads come from) self.mcps_emitted.
        """
        self.patcher = patch(
            "datahub.ingestion.graph.client.DataHubGraph", autospec=True
        )
        self.addCleanup(self.patcher.stop)
        self.mock_graph = self.patcher.start()
        # Make server stateful ingestion capable
        self.mock_graph.get_config.return_value = {"statefulIngestionCapable": True}
        # Bind mock_graph's emit_mcp to testcase's monkey_patch_emit_mcp so that we can emulate emits.
        self.mock_graph.emit_mcp = types.MethodType(
            self.monkey_patch_emit_mcp, self.mock_graph
        )
        # Bind mock_graph's get_latest_timeseries_value to monkey_patch_get_latest_timeseries_value
        self.mock_graph.get_latest_timeseries_value = types.MethodType(
            self.monkey_patch_get_latest_timeseries_value, self.mock_graph
        )
        # Tracking for emitted mcps.
        self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {}

    def _create_providers(self) -> None:
        """Create one provider of each kind, sharing a pipeline context bound to the mock graph."""
        ctx: PipelineContext = PipelineContext(
            run_id=self.run_id, pipeline_name=self.pipeline_name
        )
        ctx.graph = self.mock_graph
        self.providers: List[IngestionCheckpointingProviderBase] = [
            DatahubIngestionCheckpointingProvider.create({}, ctx),
            FileIngestionCheckpointingProvider.create(
                # Fresh temp dir per test run, so the file provider starts empty.
                {"filename": f"{tempfile.mkdtemp()}/checkpoint_mces.json"},
                ctx,
            ),
        ]

    def monkey_patch_emit_mcp(
        self, graph_ref: MagicMock, mcpw: MetadataChangeProposalWrapper
    ) -> None:
        """
        Monkey-patched implementation of DatahubGraph.emit_mcp that caches the mcp locally in memory.
        """
        self.assertIsNotNone(graph_ref)
        # Checkpoint emits must target the dataJob entity's checkpoint aspect;
        # "status" aspects (soft-delete markers) are allowed through unchecked.
        if mcpw.aspectName != "status":
            self.assertEqual(mcpw.entityType, "dataJob")
            self.assertEqual(mcpw.aspectName, "datahubIngestionCheckpoint")
        # Cache the mcpw against the entityUrn
        assert mcpw.entityUrn is not None
        self.mcps_emitted[mcpw.entityUrn] = mcpw

    def monkey_patch_get_latest_timeseries_value(
        self,
        graph_ref: MagicMock,
        entity_urn: str,
        aspect_type: Type[DictWrapper],
        filter_criteria_map: Dict[str, str],
    ) -> Optional[DictWrapper]:
        """
        Monkey patched implementation of DatahubGraph.get_latest_timeseries_value that returns the latest cached aspect
        for a given entity urn.
        """
        self.assertIsNotNone(graph_ref)
        # The provider is expected to query for checkpoint state filtered by
        # exactly this pipeline name and nothing else.
        self.assertEqual(aspect_type, CheckpointJobStateType)
        self.assertEqual(
            filter_criteria_map,
            {
                "pipelineName": self.pipeline_name,
            },
        )
        # Retrieve the cached mcpw and return its aspect value.
        mcpw = self.mcps_emitted.get(entity_urn)
        if mcpw:
            return mcpw.aspect
        return None

    def test_providers(self) -> None:
        """Commit two job checkpoints through each provider and read them back intact."""
        self.assertEqual(len(self.providers), 2)
        for provider in self.providers:
            assert provider
            # 1. Create the individual job checkpoints with appropriate states.
            # Job1 - Checkpoint with a BaseSQLAlchemyCheckpointState state
            job1_state_obj = BaseSQLAlchemyCheckpointState()
            job1_checkpoint = Checkpoint(
                job_name=self.job_names[0],
                pipeline_name=self.pipeline_name,
                run_id=self.run_id,
                state=job1_state_obj,
            )
            # Job2 - Checkpoint with a BaseTimeWindowCheckpointState state
            job2_state_obj = BaseTimeWindowCheckpointState(
                begin_timestamp_millis=10, end_timestamp_millis=100
            )
            job2_checkpoint = Checkpoint(
                job_name=self.job_names[1],
                pipeline_name=self.pipeline_name,
                run_id=self.run_id,
                state=job2_state_obj,
            )

            # 2. Set the provider's state_to_commit.
            provider.state_to_commit = {
                # NOTE: state_to_commit accepts only the aspect version of the checkpoint.
                self.job_names[0]: assert_not_null(
                    job1_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20)
                ),
                self.job_names[1]: assert_not_null(
                    job2_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20)
                ),
            }

            # 3. Perform the commit
            # NOTE: This will commit the state to
            # In-memory self.mcps_emitted because of the monkey-patching for datahub ingestion checkpointer provider.
            # And to temp directory json file for file ingestion checkpointer provider.
            provider.commit()
            self.assertTrue(provider.committed)

            # 4. Get last committed state. This must match what has been committed earlier.
            # NOTE: This will retrieve the state from where it is committed.
            job1_last_state = provider.get_latest_checkpoint(
                self.pipeline_name, self.job_names[0]
            )
            job2_last_state = provider.get_latest_checkpoint(
                self.pipeline_name, self.job_names[1]
            )

            # 5. Validate individual job checkpoint state values that have been committed and retrieved
            # against the original values.
            self.assertIsNotNone(job1_last_state)
            job1_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
                job_name=self.job_names[0],
                checkpoint_aspect=job1_last_state,
                state_class=type(job1_state_obj),
            )
            self.assertEqual(job1_last_checkpoint, job1_checkpoint)

            self.assertIsNotNone(job2_last_state)
            job2_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
                job_name=self.job_names[1],
                checkpoint_aspect=job2_last_state,
                state_class=type(job2_state_obj),
            )
            self.assertEqual(job2_last_checkpoint, job2_checkpoint)

    def test_state_provider_wrapper_with_config_provided(self) -> None:
        """An explicit StatefulIngestionConfig enables the state provider iff enabled=True."""
        # stateful_ingestion_config.enabled as true
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        ctx.graph = self.mock_graph
        state_provider = StateProviderWrapper(
            StatefulIngestionConfig(enabled=True), ctx
        )
        assert state_provider.stateful_ingestion_config
        assert state_provider.ingestion_checkpointing_state_provider
        # stateful_ingestion_config.enabled as false
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        ctx.graph = self.mock_graph
        state_provider = StateProviderWrapper(
            StatefulIngestionConfig(enabled=False), ctx
        )
        # A config is still present, but no checkpointing provider is created.
        assert state_provider.stateful_ingestion_config
        assert not state_provider.ingestion_checkpointing_state_provider

    def test_state_provider_wrapper_with_config_not_provided(self) -> None:
        """With no explicit config, a state provider is set up only when a graph client exists."""
        # graph object is present
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        ctx.graph = self.mock_graph
        state_provider = StateProviderWrapper(None, ctx)
        assert state_provider.stateful_ingestion_config
        assert state_provider.ingestion_checkpointing_state_provider
        # graph object is none
        ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name)
        state_provider = StateProviderWrapper(None, ctx)
        assert not state_provider.stateful_ingestion_config
        assert not state_provider.ingestion_checkpointing_state_provider