feat(ingest): remove source config from DatahubIngestionCheckpoint (#6722)

This commit is contained in:
Harshal Sheth 2022-12-14 12:39:21 -05:00 committed by GitHub
parent 8109b8b567
commit 2f95719dba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 11 additions and 52 deletions

View File

@ -99,7 +99,6 @@ class Checkpoint(Generic[StateType]):
pipeline_name: str pipeline_name: str
platform_instance_id: str platform_instance_id: str
run_id: str run_id: str
config: ConfigModel
state: StateType state: StateType
@classmethod @classmethod
@ -107,20 +106,10 @@ class Checkpoint(Generic[StateType]):
cls, cls,
job_name: str, job_name: str,
checkpoint_aspect: Optional[DatahubIngestionCheckpointClass], checkpoint_aspect: Optional[DatahubIngestionCheckpointClass],
config_class: Type[ConfigModel],
state_class: Type[StateType], state_class: Type[StateType],
) -> Optional["Checkpoint"]: ) -> Optional["Checkpoint"]:
if checkpoint_aspect is None: if checkpoint_aspect is None:
return None return None
try:
# Construct the config
config_as_dict = json.loads(checkpoint_aspect.config)
config_obj = config_class.parse_obj(config_as_dict)
except Exception as e:
# Failure to load config is probably okay...config structure has changed.
logger.warning(
"Failed to construct checkpoint's config from checkpoint aspect. %s", e
)
else: else:
try: try:
if checkpoint_aspect.state.serde == "utf-8": if checkpoint_aspect.state.serde == "utf-8":
@ -153,7 +142,6 @@ class Checkpoint(Generic[StateType]):
pipeline_name=checkpoint_aspect.pipelineName, pipeline_name=checkpoint_aspect.pipelineName,
platform_instance_id=checkpoint_aspect.platformInstanceId, platform_instance_id=checkpoint_aspect.platformInstanceId,
run_id=checkpoint_aspect.runId, run_id=checkpoint_aspect.runId,
config=config_obj,
state=state_obj, state=state_obj,
) )
logger.info( logger.info(
@ -230,7 +218,7 @@ class Checkpoint(Generic[StateType]):
pipelineName=self.pipeline_name, pipelineName=self.pipeline_name,
platformInstanceId=self.platform_instance_id, platformInstanceId=self.platform_instance_id,
runId=self.run_id, runId=self.run_id,
config=self.config.json(), config="",
state=checkpoint_state, state=checkpoint_state,
) )
return checkpoint_aspect return checkpoint_aspect

View File

@ -3,7 +3,6 @@ from typing import Optional, cast
import pydantic import pydantic
from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.ingestion_job_state_provider import JobId from datahub.ingestion.api.ingestion_job_state_provider import JobId
from datahub.ingestion.source.state.checkpoint import Checkpoint from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.stateful_ingestion_base import ( from datahub.ingestion.source.state.stateful_ingestion_base import (
@ -42,17 +41,14 @@ class RedundantRunSkipHandler(
def __init__( def __init__(
self, self,
source: StatefulIngestionSourceBase, source: StatefulIngestionSourceBase,
config: Optional[StatefulIngestionConfigBase], config: StatefulIngestionConfigBase[StatefulRedundantRunSkipConfig],
pipeline_name: Optional[str], pipeline_name: Optional[str],
run_id: str, run_id: str,
): ):
self.source = source self.source = source
self.config = config self.stateful_ingestion_config: Optional[
self.stateful_ingestion_config = ( StatefulRedundantRunSkipConfig
cast(StatefulRedundantRunSkipConfig, self.config.stateful_ingestion) ] = config.stateful_ingestion
if self.config
else None
)
self.pipeline_name = pipeline_name self.pipeline_name = pipeline_name
self.run_id = run_id self.run_id = run_id
self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured() self.checkpointing_enabled: bool = source.is_stateful_ingestion_configured()
@ -100,14 +96,12 @@ class RedundantRunSkipHandler(
if not self.is_checkpointing_enabled() or self._ignore_new_state(): if not self.is_checkpointing_enabled() or self._ignore_new_state():
return None return None
assert self.config is not None
assert self.pipeline_name is not None assert self.pipeline_name is not None
return Checkpoint( return Checkpoint(
job_name=self.job_id, job_name=self.job_id,
pipeline_name=self.pipeline_name, pipeline_name=self.pipeline_name,
platform_instance_id=self.source.get_platform_instance_id(), platform_instance_id=self.source.get_platform_instance_id(),
run_id=self.run_id, run_id=self.run_id,
config=cast(ConfigModel, self.config),
state=BaseUsageCheckpointState( state=BaseUsageCheckpointState(
begin_timestamp_millis=self.INVALID_TIMESTAMP_VALUE, begin_timestamp_millis=self.INVALID_TIMESTAMP_VALUE,
end_timestamp_millis=self.INVALID_TIMESTAMP_VALUE, end_timestamp_millis=self.INVALID_TIMESTAMP_VALUE,

View File

@ -5,7 +5,6 @@ from typing import Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar
import pydantic import pydantic
from datahub.configuration.common import ConfigModel
from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.ingestion_job_state_provider import JobId from datahub.ingestion.api.ingestion_job_state_provider import JobId
from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.api.workunit import MetadataWorkUnit
@ -140,21 +139,18 @@ class StaleEntityRemovalHandler(
def __init__( def __init__(
self, self,
source: StatefulIngestionSourceBase, source: StatefulIngestionSourceBase,
config: Optional[StatefulIngestionConfigBase], config: StatefulIngestionConfigBase[StatefulStaleMetadataRemovalConfig],
state_type_class: Type[StaleEntityCheckpointStateBase], state_type_class: Type[StaleEntityCheckpointStateBase],
pipeline_name: Optional[str], pipeline_name: Optional[str],
run_id: str, run_id: str,
): ):
self.config = config
self.source = source self.source = source
self.state_type_class = state_type_class self.state_type_class = state_type_class
self.pipeline_name = pipeline_name self.pipeline_name = pipeline_name
self.run_id = run_id self.run_id = run_id
self.stateful_ingestion_config = ( self.stateful_ingestion_config: Optional[
cast(StatefulStaleMetadataRemovalConfig, self.config.stateful_ingestion) StatefulStaleMetadataRemovalConfig
if self.config ] = config.stateful_ingestion
else None
)
self.checkpointing_enabled: bool = ( self.checkpointing_enabled: bool = (
True True
if ( if (
@ -217,7 +213,6 @@ class StaleEntityRemovalHandler(
pipeline_name=self.pipeline_name, pipeline_name=self.pipeline_name,
platform_instance_id=self.source.get_platform_instance_id(), platform_instance_id=self.source.get_platform_instance_id(),
run_id=self.run_id, run_id=self.run_id,
config=cast(ConfigModel, self.config),
state=self.state_type_class(), state=self.state_type_class(),
) )
return None return None

View File

@ -116,7 +116,6 @@ class StatefulIngestionSourceBase(Source):
) -> None: ) -> None:
super().__init__(ctx) super().__init__(ctx)
self.stateful_ingestion_config = config.stateful_ingestion self.stateful_ingestion_config = config.stateful_ingestion
self.source_config_type = type(config)
self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {} self.last_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {} self.cur_checkpoints: Dict[JobId, Optional[Checkpoint]] = {}
self.run_summaries_to_report: Dict[JobId, DatahubIngestionRunSummaryClass] = {} self.run_summaries_to_report: Dict[JobId, DatahubIngestionRunSummaryClass] = {}
@ -246,7 +245,6 @@ class StatefulIngestionSourceBase(Source):
last_checkpoint = Checkpoint[StateType].create_from_checkpoint_aspect( last_checkpoint = Checkpoint[StateType].create_from_checkpoint_aspect(
job_name=job_id, job_name=job_id,
checkpoint_aspect=last_checkpoint_aspect, checkpoint_aspect=last_checkpoint_aspect,
config_class=self.source_config_type,
state_class=checkpoint_state_class, state_class=checkpoint_state_class,
) )
return last_checkpoint return last_checkpoint

View File

@ -454,7 +454,6 @@ def test_dbt_state_backward_compatibility(
pipeline_name=dbt_source.ctx.pipeline_name, pipeline_name=dbt_source.ctx.pipeline_name,
platform_instance_id=dbt_source.get_platform_instance_id(), platform_instance_id=dbt_source.get_platform_instance_id(),
run_id=dbt_source.ctx.run_id, run_id=dbt_source.ctx.run_id,
config=dbt_source.config,
state=sql_state, state=sql_state,
) )

View File

@ -184,10 +184,7 @@ def test_kafka_ingest_with_stateful(
== f"urn:li:dataset:(urn:li:dataPlatform:kafka,{platform_instance}.{kafka_ctx.topics[0]},PROD)" == f"urn:li:dataset:(urn:li:dataPlatform:kafka,{platform_instance}.{kafka_ctx.topics[0]},PROD)"
) )
# 4. Checkpoint configuration should be the same. # 4. Validate that all providers have committed successfully.
assert checkpoint1.config == checkpoint2.config
# 5. Validate that all providers have committed successfully.
# NOTE: The following validation asserts for presence of state as well # NOTE: The following validation asserts for presence of state as well
# and validates reporting. # and validates reporting.
validate_all_providers_have_committed_successfully( validate_all_providers_have_committed_successfully(

View File

@ -14,7 +14,6 @@ from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
JobId, JobId,
JobStateKey, JobStateKey,
) )
from datahub.ingestion.source.sql.postgres import PostgresConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.sql_common_state import ( from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState, BaseSQLAlchemyCheckpointState,
@ -124,7 +123,6 @@ class TestDatahubIngestionCheckpointProvider(unittest.TestCase):
pipeline_name=self.pipeline_name, pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id, platform_instance_id=self.platform_instance_id,
run_id=self.run_id, run_id=self.run_id,
config=PostgresConfig(host_port="localhost:5432"),
state=job1_state_obj, state=job1_state_obj,
) )
# Job2 - Checkpoint with a BaseUsageCheckpointState state # Job2 - Checkpoint with a BaseUsageCheckpointState state
@ -136,7 +134,6 @@ class TestDatahubIngestionCheckpointProvider(unittest.TestCase):
pipeline_name=self.pipeline_name, pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id, platform_instance_id=self.platform_instance_id,
run_id=self.run_id, run_id=self.run_id,
config=PostgresConfig(host_port="localhost:5432"),
state=job2_state_obj, state=job2_state_obj,
) )
@ -171,7 +168,6 @@ class TestDatahubIngestionCheckpointProvider(unittest.TestCase):
job_name=self.job_names[0], job_name=self.job_names[0],
checkpoint_aspect=last_state[self.job_names[0]], checkpoint_aspect=last_state[self.job_names[0]],
state_class=type(job1_state_obj), state_class=type(job1_state_obj),
config_class=type(job1_checkpoint.config),
) )
self.assertEqual(job1_last_checkpoint, job1_checkpoint) self.assertEqual(job1_last_checkpoint, job1_checkpoint)
@ -180,6 +176,5 @@ class TestDatahubIngestionCheckpointProvider(unittest.TestCase):
job_name=self.job_names[1], job_name=self.job_names[1],
checkpoint_aspect=last_state[self.job_names[1]], checkpoint_aspect=last_state[self.job_names[1]],
state_class=type(job2_state_obj), state_class=type(job2_state_obj),
config_class=type(job2_checkpoint.config),
) )
self.assertEqual(job2_last_checkpoint, job2_checkpoint) self.assertEqual(job2_last_checkpoint, job2_checkpoint)

View File

@ -5,8 +5,6 @@ import pydantic
import pytest import pytest
from datahub.emitter.mce_builder import make_dataset_urn from datahub.emitter.mce_builder import make_dataset_urn
from datahub.ingestion.source.sql.postgres import PostgresConfig
from datahub.ingestion.source.sql.sql_common import BasicSQLAlchemyConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
from datahub.ingestion.source.state.sql_common_state import ( from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState, BaseSQLAlchemyCheckpointState,
@ -22,7 +20,6 @@ test_pipeline_name: str = "test_pipeline"
test_platform_instance_id: str = "test_platform_instance_1" test_platform_instance_id: str = "test_platform_instance_1"
test_job_name: str = "test_job_1" test_job_name: str = "test_job_1"
test_run_id: str = "test_run_1" test_run_id: str = "test_run_1"
test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234")
def _assert_checkpoint_deserialization( def _assert_checkpoint_deserialization(
@ -34,7 +31,7 @@ def _assert_checkpoint_deserialization(
timestampMillis=int(datetime.utcnow().timestamp() * 1000), timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=test_pipeline_name, pipelineName=test_pipeline_name,
platformInstanceId=test_platform_instance_id, platformInstanceId=test_platform_instance_id,
config=test_source_config.json(), config="",
state=serialized_checkpoint_state, state=serialized_checkpoint_state,
runId=test_run_id, runId=test_run_id,
) )
@ -44,7 +41,6 @@ def _assert_checkpoint_deserialization(
job_name=test_job_name, job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect, checkpoint_aspect=checkpoint_aspect,
state_class=type(expected_checkpoint_state), state_class=type(expected_checkpoint_state),
config_class=PostgresConfig,
) )
expected_checkpoint_obj = Checkpoint( expected_checkpoint_obj = Checkpoint(
@ -52,7 +48,6 @@ def _assert_checkpoint_deserialization(
pipeline_name=test_pipeline_name, pipeline_name=test_pipeline_name,
platform_instance_id=test_platform_instance_id, platform_instance_id=test_platform_instance_id,
run_id=test_run_id, run_id=test_run_id,
config=test_source_config,
state=expected_checkpoint_state, state=expected_checkpoint_state,
) )
assert checkpoint_obj == expected_checkpoint_obj assert checkpoint_obj == expected_checkpoint_obj
@ -127,7 +122,6 @@ def test_serde_idempotence(state_obj):
pipeline_name=test_pipeline_name, pipeline_name=test_pipeline_name,
platform_instance_id=test_platform_instance_id, platform_instance_id=test_platform_instance_id,
run_id=test_run_id, run_id=test_run_id,
config=test_source_config,
state=state_obj, state=state_obj,
) )
@ -142,7 +136,6 @@ def test_serde_idempotence(state_obj):
job_name=test_job_name, job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect, checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj), state_class=type(state_obj),
config_class=PostgresConfig,
) )
assert orig_checkpoint_obj == serde_checkpoint_obj assert orig_checkpoint_obj == serde_checkpoint_obj