refactor(ingest/stateful): remove most remaining state classes (#6791)

Harshal Sheth 2022-12-19 13:40:48 -05:00 committed by GitHub
parent 14a00f4098
commit 47be95689e
12 changed files with 33 additions and 215 deletions

View File

@@ -5,19 +5,7 @@ from abc import abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import auto
-from typing import (
-    Any,
-    Callable,
-    ClassVar,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-    cast,
-)
+from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Union
import pydantic
from pydantic import root_validator, validator
@@ -41,7 +29,6 @@ from datahub.ingestion.api.decorators import (
    platform_name,
    support_status,
)
-from datahub.ingestion.api.ingestion_job_state_provider import JobId
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.sql.sql_types import (
    BIGQUERY_TYPES_MAP,
@@ -52,11 +39,7 @@ from datahub.ingestion.source.sql.sql_types import (
    resolve_postgres_modified_type,
    resolve_trino_modified_type,
)
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.dbt_state import DbtCheckpointState
-from datahub.ingestion.source.state.sql_common_state import (
-    BaseSQLAlchemyCheckpointState,
-)
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
    StaleEntityRemovalHandler,
    StaleEntityRemovalSourceReport,
@@ -65,7 +48,6 @@ from datahub.ingestion.source.state.stale_entity_removal_handler import (
from datahub.ingestion.source.state.stateful_ingestion_base import (
    StatefulIngestionConfigBase,
    StatefulIngestionSourceBase,
-    StateType,
)
from datahub.metadata.com.linkedin.pegasus2avro.common import (
    AuditStamp,
@@ -684,42 +666,11 @@ class DBTSourceBase(StatefulIngestionSourceBase):
        self.stale_entity_removal_handler = StaleEntityRemovalHandler(
            source=self,
            config=self.config,
-            state_type_class=DbtCheckpointState,
+            state_type_class=GenericCheckpointState,
            pipeline_name=self.ctx.pipeline_name,
            run_id=self.ctx.run_id,
        )

-    def get_last_checkpoint(
-        self, job_id: JobId, checkpoint_state_class: Type[StateType]
-    ) -> Optional[Checkpoint]:
-        last_checkpoint: Optional[Checkpoint]
-        is_conversion_required: bool = False
-        try:
-            # Best case: the last checkpoint state is already a DbtCheckpointState.
-            last_checkpoint = super(DBTSourceBase, self).get_last_checkpoint(
-                job_id, checkpoint_state_class
-            )
-        except Exception as e:
-            # Backward compatibility for the old dbt ingestion source, which saved
-            # dbt nodes in a BaseSQLAlchemyCheckpointState.
-            last_checkpoint = super(DBTSourceBase, self).get_last_checkpoint(
-                job_id, BaseSQLAlchemyCheckpointState  # type: ignore
-            )
-            logger.debug(
-                f"Found BaseSQLAlchemyCheckpointState as checkpoint state (got {e})."
-            )
-            is_conversion_required = True
-
-        if last_checkpoint is not None and is_conversion_required:
-            # Map the BaseSQLAlchemyCheckpointState onto a DbtCheckpointState.
-            dbt_checkpoint_state: DbtCheckpointState = DbtCheckpointState()
-            dbt_checkpoint_state.urns = (
-                cast(BaseSQLAlchemyCheckpointState, last_checkpoint.state)
-            ).urns
-            last_checkpoint.state = dbt_checkpoint_state
-
-        return last_checkpoint
-
    def create_test_entity_mcps(
        self,
        test_nodes: List[DBTNode],
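
The deleted get_last_checkpoint override above existed only to migrate old checkpoints saved as BaseSQLAlchemyCheckpointState into DbtCheckpointState. With every source now persisting the same GenericCheckpointState shape, the base-class implementation can load any prior checkpoint directly. A minimal sketch of the idea in plain pydantic (GenericState is an illustrative stand-in, not the real datahub class):

# sketch.py (illustrative only)
from typing import List

import pydantic


class GenericState(pydantic.BaseModel):  # stand-in for GenericCheckpointState
    urns: List[str] = []


# A payload written by the old SQL-style state and one written by the old
# dbt-style state both carry a urns list, so both parse into the same model
# and no per-source conversion shim is needed.
payload = {"urns": ["urn:li:dataset:(urn:li:dataPlatform:dbt,pagila.public.actor,PROD)"]}
assert GenericState.parse_obj(payload).urns == payload["urns"]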

View File

@@ -27,7 +27,7 @@ from datahub.ingestion.api.decorators import (
from datahub.ingestion.api.registry import import_path
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
-from datahub.ingestion.source.state.kafka_state import KafkaCheckpointState
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
    StaleEntityRemovalHandler,
    StaleEntityRemovalSourceReport,
@@ -145,7 +145,7 @@ class KafkaSource(StatefulIngestionSourceBase):
        self.stale_entity_removal_handler = StaleEntityRemovalHandler(
            source=self,
            config=self.source_config,
-            state_type_class=KafkaCheckpointState,
+            state_type_class=GenericCheckpointState,
            pipeline_name=self.ctx.pipeline_name,
            run_id=self.ctx.run_id,
        )

View File

@@ -16,7 +16,7 @@ from datahub.ingestion.api.decorators import (
    support_status,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
-from datahub.ingestion.source.state.ldap_state import LdapCheckpointState
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
    StaleEntityRemovalHandler,
    StaleEntityRemovalSourceReport,
@@ -186,7 +186,7 @@ class LDAPSource(StatefulIngestionSourceBase):
        self.stale_entity_removal_handler = StaleEntityRemovalHandler(
            source=self,
            config=self.config,
-            state_type_class=LdapCheckpointState,
+            state_type_class=GenericCheckpointState,
            pipeline_name=self.ctx.pipeline_name,
            run_id=self.ctx.run_id,
        )

View File

@@ -1,8 +0,0 @@
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
-
-
-class DbtCheckpointState(GenericCheckpointState):
-    """
-    Class for representing the checkpoint state for DBT sources.
-    Stores all nodes and assertions being ingested and is used to remove any stale entities.
-    """

View File

@@ -1,8 +0,0 @@
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
-
-
-class KafkaCheckpointState(GenericCheckpointState):
-    """
-    This class represents the checkpoint state for Kafka based sources.
-    Stores all the topics being ingested and is used to remove any stale entities.
-    """

View File

@@ -1,8 +0,0 @@
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
-
-
-class LdapCheckpointState(GenericCheckpointState):
-    """
-    Base class for representing the checkpoint state for all LDAP based sources.
-    Stores all corpuser and corpGroup entities being ingested and is used to remove any stale entities.
-    """

View File

@@ -1,9 +1,3 @@
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState

-
-class BaseSQLAlchemyCheckpointState(GenericCheckpointState):
-    """
-    Base class for representing the checkpoint state for all SQLAlchemy based sources.
-    Stores all tables and views being ingested and is used to remove any stale entities.
-    Subclasses can define additional state as appropriate.
-    """
+BaseSQLAlchemyCheckpointState = GenericCheckpointState
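
Replacing the empty subclass with a module-level alias keeps legacy imports working: importing BaseSQLAlchemyCheckpointState from sql_common_state still succeeds, and because the alias is the very same class object, isinstance and identity checks agree with GenericCheckpointState. A quick check, assuming a post-change datahub install:

from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.sql_common_state import (
    BaseSQLAlchemyCheckpointState,
)

# The old name and the new name resolve to the same class object.
assert BaseSQLAlchemyCheckpointState is GenericCheckpointState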

View File

@@ -1,32 +1,24 @@
import dataclasses
from dataclasses import dataclass
from os import PathLike
-from typing import Any, Dict, Optional, Type, Union, cast
-from unittest.mock import MagicMock, patch
+from typing import Any, Dict, Optional, Union, cast
+from unittest.mock import patch

import pytest
import requests_mock
from freezegun import freeze_time

from datahub.configuration.common import DynamicTypedConfig
-from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
-from datahub.ingestion.source.dbt.dbt_common import (
-    DBTEntitiesEnabled,
-    EmitDirective,
-    StatefulIngestionSourceBase,
-)
+from datahub.ingestion.source.dbt.dbt_common import DBTEntitiesEnabled, EmitDirective
from datahub.ingestion.source.dbt.dbt_core import DBTCoreConfig, DBTCoreSource
from datahub.ingestion.source.sql.sql_types import (
    TRINO_SQL_TYPES_MAP,
    resolve_trino_modified_type,
)
-from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
-from datahub.ingestion.source.state.dbt_state import DbtCheckpointState
-from datahub.ingestion.source.state.sql_common_state import (
-    BaseSQLAlchemyCheckpointState,
-)
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers import mce_helpers
from tests.test_helpers.state_helpers import (
    run_and_get_pipeline,
@@ -249,7 +241,7 @@ def test_dbt_ingest(dbt_test_config, pytestconfig, tmp_path, mock_time, **kwargs
def get_current_checkpoint_from_pipeline(
    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
+) -> Optional[Checkpoint[GenericCheckpointState]]:
    dbt_source = cast(DBTCoreSource, pipeline.source)
    return dbt_source.get_current_checkpoint(
        dbt_source.stale_entity_removal_handler.job_id
@@ -357,8 +349,8 @@ def test_dbt_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
    # Perform all assertions on the states. The deleted table should not be
    # part of the second state.
-    state1 = cast(DbtCheckpointState, checkpoint1.state)
-    state2 = cast(DbtCheckpointState, checkpoint2.state)
+    state1 = checkpoint1.state
+    state2 = checkpoint2.state

    difference_urns = list(
        state1.get_urns_not_in(type="*", other_checkpoint_state=state2)
    )
@@ -387,101 +379,6 @@ def test_dbt_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
    )


@pytest.mark.integration
@freeze_time(FROZEN_TIME)
-def test_dbt_state_backward_compatibility(
-    pytestconfig, tmp_path, mock_time, mock_datahub_graph
-):
-    test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt"
-    manifest_path = f"{test_resources_dir}/dbt_manifest.json"
-    catalog_path = f"{test_resources_dir}/dbt_catalog.json"
-    sources_path = f"{test_resources_dir}/dbt_sources.json"
-
-    stateful_config: Dict[str, Any] = {
-        "stateful_ingestion": {
-            "enabled": True,
-            "remove_stale_metadata": True,
-            "fail_safe_threshold": 100.0,
-            "state_provider": {
-                "type": "datahub",
-                "config": {"datahub_api": {"server": GMS_SERVER}},
-            },
-        },
-    }
-    scd_config: Dict[str, Any] = {
-        "manifest_path": manifest_path,
-        "catalog_path": catalog_path,
-        "sources_path": sources_path,
-        "target_platform": "postgres",
-        # This will bypass the check in the get_workunits function of dbt.py.
-        "write_semantics": "OVERRIDE",
-        "owner_extraction_pattern": r"^@(?P<owner>(.*))",
-        # Enable stateful ingestion.
-        **stateful_config,
-    }
-    pipeline_config_dict: Dict[str, Any] = {
-        "source": {
-            "type": "dbt",
-            "config": scd_config,
-        },
-        "sink": {
-            # We are not really interested in the resulting events for this test.
-            "type": "console"
-        },
-        "pipeline_name": "statefulpipeline",
-    }
-
-    def get_fake_base_sql_alchemy_checkpoint_state(
-        job_id: JobId, checkpoint_state_class: Type[CheckpointStateBase]
-    ) -> Optional[Checkpoint]:
-        if checkpoint_state_class is DbtCheckpointState:
-            raise Exception(
-                "DBT source will call this function again with BaseSQLAlchemyCheckpointState"
-            )
-
-        sql_state = BaseSQLAlchemyCheckpointState()
-        urn1 = "urn:li:dataset:(urn:li:dataPlatform:dbt,pagila.public.actor,PROD)"
-        urn2 = "urn:li:dataset:(urn:li:dataPlatform:postgres,pagila.public.actor,PROD)"
-        sql_state.add_checkpoint_urn(type="table", urn=urn1)
-        sql_state.add_checkpoint_urn(type="table", urn=urn2)
-
-        assert dbt_source.ctx.pipeline_name is not None
-        return Checkpoint(
-            job_name=dbt_source.stale_entity_removal_handler.job_id,
-            pipeline_name=dbt_source.ctx.pipeline_name,
-            platform_instance_id=dbt_source.get_platform_instance_id(),
-            run_id=dbt_source.ctx.run_id,
-            state=sql_state,
-        )
-
-    with patch(
-        "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph",
-        mock_datahub_graph,
-    ) as mock_checkpoint, patch.object(
-        StatefulIngestionSourceBase,
-        "get_last_checkpoint",
-        MagicMock(side_effect=get_fake_base_sql_alchemy_checkpoint_state),
-    ) as mock_source_base_get_last_checkpoint:
-        mock_checkpoint.return_value = mock_datahub_graph
-        pipeline = Pipeline.create(pipeline_config_dict)
-        dbt_source = cast(DBTCoreSource, pipeline.source)
-
-        last_checkpoint = dbt_source.get_last_checkpoint(
-            dbt_source.stale_entity_removal_handler.job_id, DbtCheckpointState
-        )
-        mock_source_base_get_last_checkpoint.assert_called()
-
-        # Our fake method returns a BaseSQLAlchemyCheckpointState; however, it
-        # should get converted to a DbtCheckpointState.
-        assert last_checkpoint is not None and isinstance(
-            last_checkpoint.state, DbtCheckpointState
-        )
-
-        pipeline.run()
-        pipeline.raise_from_status()
-
-
-@pytest.mark.integration
-@freeze_time(FROZEN_TIME)
def test_dbt_tests(pytestconfig, tmp_path, mock_time, **kwargs):
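
The deleted test exercised the conversion shim removed from dbt_common.py above; with the shim gone there is nothing left to test. Its scenario is now covered by the shared class directly: the same urns it used to inject via a fake BaseSQLAlchemyCheckpointState can go straight into GenericCheckpointState. A sketch, assuming a post-change datahub install:

from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState

# Build a checkpoint state directly; no per-source state class or
# get_last_checkpoint override is involved anymore.
state = GenericCheckpointState()
state.add_checkpoint_urn(
    type="table",
    urn="urn:li:dataset:(urn:li:dataPlatform:dbt,pagila.public.actor,PROD)",
)
state.add_checkpoint_urn(
    type="table",
    urn="urn:li:dataset:(urn:li:dataPlatform:postgres,pagila.public.actor,PROD)",
)
assert len(state.urns) == 2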

View File

@@ -9,7 +9,7 @@ from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.kafka import KafkaSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.kafka_state import KafkaCheckpointState
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers.docker_helpers import wait_for_port
from tests.test_helpers.state_helpers import (
    run_and_get_pipeline,
@@ -83,7 +83,7 @@ class KafkaTopicsCxtManager:
def get_current_checkpoint_from_pipeline(
    pipeline: Pipeline,
-) -> Optional[Checkpoint[KafkaCheckpointState]]:
+) -> Optional[Checkpoint[GenericCheckpointState]]:
    kafka_source = cast(KafkaSource, pipeline.source)
    return kafka_source.get_current_checkpoint(
        kafka_source.stale_entity_removal_handler.job_id

View File

@@ -8,7 +8,7 @@ from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.ldap import LDAPSource
from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.ldap_state import LdapCheckpointState
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from tests.test_helpers import mce_helpers
from tests.test_helpers.docker_helpers import wait_for_port
from tests.test_helpers.state_helpers import (
@@ -91,7 +91,7 @@ def ldap_ingest_common(
def get_current_checkpoint_from_pipeline(
    pipeline: Pipeline,
-) -> Optional[Checkpoint]:
+) -> Optional[Checkpoint[GenericCheckpointState]]:
    ldap_source = cast(LDAPSource, pipeline.source)
    return ldap_source.get_current_checkpoint(
        ldap_source.stale_entity_removal_handler.job_id
@@ -154,8 +154,8 @@ def test_ldap_stateful(
        pipeline=pipeline_run2, expected_providers=1
    )

-    state1 = cast(LdapCheckpointState, checkpoint1.state)
-    state2 = cast(LdapCheckpointState, checkpoint2.state)
+    state1 = checkpoint1.state
+    state2 = checkpoint2.state

    difference_dataset_urns = list(
        state1.get_urns_not_in(type="corpuser", other_checkpoint_state=state2)
@@ -194,8 +194,8 @@ def test_ldap_stateful(
    assert checkpoint4
    assert checkpoint4.state

-    state3 = cast(LdapCheckpointState, checkpoint3.state)
-    state4 = cast(LdapCheckpointState, checkpoint4.state)
+    state3 = checkpoint3.state
+    state4 = checkpoint4.state

    difference_dataset_urns = list(
        state3.get_urns_not_in(type="corpGroup", other_checkpoint_state=state4)

View File

@@ -1,13 +1,13 @@
from datahub.emitter.mce_builder import make_dataset_urn
-from datahub.ingestion.source.state.kafka_state import KafkaCheckpointState
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState


def test_kafka_common_state() -> None:
-    state1 = KafkaCheckpointState()
+    state1 = GenericCheckpointState()
    test_topic_urn = make_dataset_urn("kafka", "test_topic1", "test")
    state1.add_checkpoint_urn(type="topic", urn=test_topic_urn)

-    state2 = KafkaCheckpointState()
+    state2 = GenericCheckpointState()
    topic_urns_diff = list(
        state1.get_urns_not_in(type="topic", other_checkpoint_state=state2)
@@ -16,7 +16,7 @@ def test_kafka_common_state() -> None:
def test_kafka_state_migration() -> None:
-    state = KafkaCheckpointState.parse_obj(
+    state = GenericCheckpointState.parse_obj(
        {
            "encoded_topic_urns": [
                "kafka||test_topic1||test",

View File

@@ -1,11 +1,11 @@
import pytest

-from datahub.ingestion.source.state.ldap_state import LdapCheckpointState
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState


@pytest.fixture
def other_checkpoint_state():
-    state = LdapCheckpointState()
+    state = GenericCheckpointState()
    state.add_checkpoint_urn("corpuser", "urn:li:corpuser:user1")
    state.add_checkpoint_urn("corpuser", "urn:li:corpuser:user2")
    state.add_checkpoint_urn("corpuser", "urn:li:corpuser:user3")
@@ -18,7 +18,7 @@ def other_checkpoint_state():
def test_add_checkpoint_urn():
-    state = LdapCheckpointState()
+    state = GenericCheckpointState()
    assert len(state.urns) == 0

    state.add_checkpoint_urn("corpuser", "urn:li:corpuser:user1")
    assert len(state.urns) == 1
@@ -27,7 +27,7 @@ def test_add_checkpoint_urn():
def test_get_urns_not_in(other_checkpoint_state):
-    oldstate = LdapCheckpointState()
+    oldstate = GenericCheckpointState()
    oldstate.add_checkpoint_urn("corpuser", "urn:li:corpuser:user1")
    oldstate.add_checkpoint_urn("corpuser", "urn:li:corpuser:user2")
    oldstate.add_checkpoint_urn("corpuser", "urn:li:corpuser:user4")
@@ -44,7 +44,7 @@ def test_get_urns_not_in(other_checkpoint_state):
def test_get_percent_entities_changed(other_checkpoint_state):
-    oldstate = LdapCheckpointState()
+    oldstate = GenericCheckpointState()
    oldstate.add_checkpoint_urn("corpuser", "urn:li:corpuser:user1")
    oldstate.add_checkpoint_urn("corpuser", "urn:li:corpuser:user2")
    oldstate.add_checkpoint_urn("corpuser", "urn:li:corpuser:user4")
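
These unit tests also pin down the API that the stale-entity fail-safe relies on. A short usage sketch tying them together (assuming, as the test above suggests, that get_percent_entities_changed takes the previous run's state and returns a percentage):

from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState

old_state = GenericCheckpointState()
for user in ["user1", "user2", "user3", "user4"]:
    old_state.add_checkpoint_urn("corpuser", f"urn:li:corpuser:{user}")

new_state = GenericCheckpointState()
new_state.add_checkpoint_urn("corpuser", "urn:li:corpuser:user1")

# If this exceeds the configured fail_safe_threshold (100.0 in the dbt test
# config above), stale-entity deletion is skipped to avoid mass soft-deletes.
percent_changed = new_state.get_percent_entities_changed(old_state)
print(f"{percent_changed}% of entities changed between runs")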