import time
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Dict, Iterable, List, Optional, cast
from unittest import mock

import pydantic
import pytest
from freezegun import freeze_time
from pydantic import Field

from datahub.api.entities.dataprocess.dataprocess_instance import DataProcessInstance
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.source_common import DatasetSourceConfigMixin
from datahub.emitter.mce_builder import DEFAULT_ENV
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceReport
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
    StaleEntityRemovalHandler,
    StaleEntityRemovalSourceReport,
    StatefulStaleMetadataRemovalConfig,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
    StatefulIngestionConfigBase,
    StatefulIngestionSourceBase,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import (
    DataProcessInstanceProperties,
)
from datahub.metadata.schema_classes import (
    AuditStampClass,
    DataPlatformInstanceClass,
    StatusClass,
)
from datahub.metadata.urns import DataPlatformUrn, QueryUrn
from datahub.testing import mce_helpers
from datahub.utilities.urns.dataset_urn import DatasetUrn
from tests.test_helpers.state_helpers import (
    get_current_checkpoint_from_pipeline,
    validate_all_providers_have_committed_successfully,
)

FROZEN_TIME = "2020-04-14 07:00:00"

dummy_datasets: List = ["dummy_dataset1", "dummy_dataset2", "dummy_dataset3"]


@dataclass
class DummySourceReport(StaleEntityRemovalSourceReport):
    datasets_scanned: int = 0
    filtered_datasets: List[str] = dataclass_field(default_factory=list)

    def report_datasets_scanned(self, count: int = 1) -> None:
        self.datasets_scanned += count

    def report_datasets_dropped(self, model: str) -> None:
        self.filtered_datasets.append(model)


class DummySourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin):
    dataset_patterns: AllowDenyPattern = Field(
        default=AllowDenyPattern.allow_all(),
        description="Regex patterns for datasets to filter in ingestion.",
    )
    # Configuration for stateful ingestion
    stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field(
        default=None, description="Dummy source ingestion config."
    )
    report_failure: bool = Field(
        default=False,
        description="Should this dummy source report a failure.",
    )
    dpi_id_to_ingest: Optional[str] = Field(
        default=None,
        description="Data process instance id to ingest.",
    )
    query_id_to_ingest: Optional[str] = Field(
        default=None, description="Query id to ingest."
    )


class DummySource(StatefulIngestionSourceBase):
    """
    A dummy source which only extracts the dummy datasets.
    """

    source_config: DummySourceConfig
    reporter: DummySourceReport

    def __init__(self, config: DummySourceConfig, ctx: PipelineContext):
        super().__init__(config, ctx)
        self.source_config = config
        self.reporter = DummySourceReport()
        # Create and register the stateful ingestion use-case handler.
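        # Note: the handler records the URNs emitted in this run in the checkpoint
        # state and, at commit time, soft-deletes entities that were present in the
        # previous run's checkpoint but were not emitted in this run.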
        self.stale_entity_removal_handler = StaleEntityRemovalHandler.create(
            self, self.source_config, self.ctx
        )

    @classmethod
    def create(cls, config_dict, ctx):
        config = DummySourceConfig.parse_obj(config_dict)
        return cls(config, ctx)

    def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
        return [
            *super().get_workunit_processors(),
            self.stale_entity_removal_handler.workunit_processor,
        ]

    def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
        for dataset in dummy_datasets:
            if not self.source_config.dataset_patterns.allowed(dataset):
                self.reporter.report_datasets_dropped(dataset)
                continue
            else:
                self.reporter.report_datasets_scanned()

            dataset_urn = DatasetUrn.create_from_ids(
                platform_id="postgres",
                table_name=dataset,
                env=DEFAULT_ENV,
            )

            yield MetadataChangeProposalWrapper(
                entityUrn=str(dataset_urn),
                aspect=StatusClass(removed=False),
            ).as_workunit()

        if self.source_config.dpi_id_to_ingest:
            dpi = DataProcessInstance(
                id=self.source_config.dpi_id_to_ingest,
                orchestrator="dummy",
            )

            yield MetadataChangeProposalWrapper(
                entityUrn=str(dpi.urn),
                aspect=DataProcessInstanceProperties(
                    name=dpi.id,
                    created=AuditStampClass(
                        time=int(time.time() * 1000),
                        actor="urn:li:corpuser:datahub",
                    ),
                    type=dpi.type,
                ),
            ).as_workunit()

        if self.source_config.query_id_to_ingest:
            yield MetadataChangeProposalWrapper(
                entityUrn=QueryUrn(self.source_config.query_id_to_ingest).urn(),
                aspect=DataPlatformInstanceClass(
                    platform=DataPlatformUrn("bigquery").urn()
                ),
            ).as_workunit()

        if self.source_config.report_failure:
            self.reporter.report_failure("Dummy error", "Error")

    def get_report(self) -> SourceReport:
        return self.reporter


@pytest.fixture(scope="module")
def mock_generic_checkpoint_state():
    with mock.patch(
        "datahub.ingestion.source.state.entity_removal_state.GenericCheckpointState"
    ) as mock_checkpoint_state:
        checkpoint_state = mock_checkpoint_state.return_value
        checkpoint_state.serde.return_value = "utf-8"
        yield mock_checkpoint_state


@freeze_time(FROZEN_TIME)
def test_stateful_ingestion(pytestconfig, tmp_path, mock_time):
    # test stateful ingestion using dummy source
    state_file_name: str = "checkpoint_state_mces.json"
    golden_state_file_name: str = "golden_test_checkpoint_state.json"
    golden_state_file_name_after_deleted: str = (
        "golden_test_checkpoint_state_after_deleted.json"
    )
    output_file_name: str = "dummy_mces.json"
    golden_file_name: str = "golden_test_stateful_ingestion.json"
    output_file_name_after_deleted: str = "dummy_mces_stateful_after_deleted.json"
    golden_file_name_after_deleted: str = (
        "golden_test_stateful_ingestion_after_deleted.json"
    )

    test_resources_dir = pytestconfig.rootpath / "tests/unit/stateful_ingestion/state"

    base_pipeline_config = {
        "run_id": "dummy-test-stateful-ingestion",
        "pipeline_name": "dummy_stateful",
        "source": {
            "type": "tests.unit.stateful_ingestion.state.test_stateful_ingestion.DummySource",
            "config": {
                "stateful_ingestion": {
                    "enabled": True,
                    "remove_stale_metadata": True,
                    "fail_safe_threshold": 100,
                    "state_provider": {
                        "type": "file",
                        "config": {
                            "filename": f"{tmp_path}/{state_file_name}",
                        },
                    },
                },
                "dpi_id_to_ingest": "job1",
                "query_id_to_ingest": "query1",
            },
        },
        "sink": {
            "type": "file",
            "config": {},
        },
    }

    with mock.patch(
        "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj"
    ) as mock_state, mock.patch(
        "datahub.ingestion.source.state.stale_entity_removal_handler.STATEFUL_INGESTION_IGNORED_ENTITY_TYPES",
        {},
        # Second mock is to imitate earlier behavior where entity type check was not present when adding
        # entity to state
    ):
        mock_state.return_value = GenericCheckpointState(serde="utf-8")
        pipeline_run1 = None
        pipeline_run1_config: Dict[str, Dict[str, Dict[str, Any]]] = dict(  # type: ignore
            base_pipeline_config  # type: ignore
        )
        pipeline_run1_config["sink"]["config"]["filename"] = (
            f"{tmp_path}/{output_file_name}"
        )
        pipeline_run1 = Pipeline.create(pipeline_run1_config)
        pipeline_run1.run()
        pipeline_run1.raise_from_status()
        pipeline_run1.pretty_print_summary()

        # validate both dummy source mces and checkpoint state mces files
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / output_file_name,
            golden_path=f"{test_resources_dir}/{golden_file_name}",
        )
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / state_file_name,
            golden_path=f"{test_resources_dir}/{golden_state_file_name}",
        )
        checkpoint1 = get_current_checkpoint_from_pipeline(pipeline_run1)
        assert checkpoint1
        assert checkpoint1.state

    with mock.patch(
        "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj"
    ) as mock_state:
        mock_state.return_value = GenericCheckpointState(serde="utf-8")
        pipeline_run2 = None
        pipeline_run2_config: Dict[str, Dict[str, Dict[str, Any]]] = dict(
            base_pipeline_config  # type: ignore
        )
        pipeline_run2_config["source"]["config"]["dataset_patterns"] = {
            "allow": ["dummy_dataset1", "dummy_dataset2"],
        }
        pipeline_run2_config["source"]["config"]["dpi_id_to_ingest"] = "job2"
        pipeline_run2_config["source"]["config"]["query_id_to_ingest"] = "query2"
        pipeline_run2_config["sink"]["config"]["filename"] = (
            f"{tmp_path}/{output_file_name_after_deleted}"
        )
        pipeline_run2 = Pipeline.create(pipeline_run2_config)
        pipeline_run2.run()
        pipeline_run2.raise_from_status()
        pipeline_run2.pretty_print_summary()

        # validate both updated dummy source mces and checkpoint state mces files after deleting dataset
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / output_file_name_after_deleted,
            golden_path=f"{test_resources_dir}/{golden_file_name_after_deleted}",
        )
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / state_file_name,
            golden_path=f"{test_resources_dir}/{golden_state_file_name_after_deleted}",
        )
        checkpoint2 = get_current_checkpoint_from_pipeline(pipeline_run2)
        assert checkpoint2
        assert checkpoint2.state

    # Validate that all providers have committed successfully.
    validate_all_providers_have_committed_successfully(
        pipeline=pipeline_run1, expected_providers=1
    )
    validate_all_providers_have_committed_successfully(
        pipeline=pipeline_run2, expected_providers=1
    )

    # Perform all assertions on the states.
    # The deleted table should not be
    # part of the second state.
    state1 = cast(GenericCheckpointState, checkpoint1.state)
    state2 = cast(GenericCheckpointState, checkpoint2.state)

    difference_dataset_urns = list(
        state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2)
    )
    # the difference in dataset urns is the dataset that is no longer allowed to be ingested
    assert len(difference_dataset_urns) == 1
    deleted_dataset_urns: List[str] = [
        "urn:li:dataset:(urn:li:dataPlatform:postgres,dummy_dataset3,PROD)",
    ]
    assert sorted(deleted_dataset_urns) == sorted(difference_dataset_urns)

    report = pipeline_run2.source.get_report()
    assert isinstance(report, StaleEntityRemovalSourceReport)

    # assert that the report's last ingestion state records the non-deletable entity urns
    non_deletable_urns: List[str] = [
        "urn:li:dataProcessInstance:478810e859f870a54f72c681f41af619",
        "urn:li:query:query1",
    ]
    assert sorted(non_deletable_urns) == sorted(
        report.last_state_non_deletable_entities
    )


@freeze_time(FROZEN_TIME)
def test_stateful_ingestion_failure(pytestconfig, tmp_path, mock_time):
    # test stateful ingestion using the dummy source, with the pipeline execution failing in the second ingestion
    state_file_name: str = "checkpoint_state_mces_failure.json"
    golden_state_file_name: str = "golden_test_checkpoint_state_failure.json"
    golden_state_file_name_after_deleted: str = (
        "golden_test_checkpoint_state_after_deleted_failure.json"
    )
    output_file_name: str = "dummy_mces_failure.json"
    golden_file_name: str = "golden_test_stateful_ingestion_failure.json"
    output_file_name_after_deleted: str = (
        "dummy_mces_stateful_after_deleted_failure.json"
    )
    golden_file_name_after_deleted: str = (
        "golden_test_stateful_ingestion_after_deleted_failure.json"
    )

    test_resources_dir = pytestconfig.rootpath / "tests/unit/stateful_ingestion/state"

    base_pipeline_config = {
        "run_id": "dummy-test-stateful-ingestion",
        "pipeline_name": "dummy_stateful",
        "source": {
            "type": "tests.unit.stateful_ingestion.state.test_stateful_ingestion.DummySource",
            "config": {
                "stateful_ingestion": {
                    "enabled": True,
                    "remove_stale_metadata": True,
                    "state_provider": {
                        "type": "file",
                        "config": {
                            "filename": f"{tmp_path}/{state_file_name}",
                        },
                    },
                },
            },
        },
        "sink": {
            "type": "file",
            "config": {},
        },
    }

    with mock.patch(
        "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj"
    ) as mock_state:
        mock_state.return_value = GenericCheckpointState(serde="utf-8")
        pipeline_run1 = None
        pipeline_run1_config: Dict[str, Dict[str, Dict[str, Any]]] = dict(  # type: ignore
            base_pipeline_config  # type: ignore
        )
        pipeline_run1_config["sink"]["config"]["filename"] = (
            f"{tmp_path}/{output_file_name}"
        )
        pipeline_run1 = Pipeline.create(pipeline_run1_config)
        pipeline_run1.run()
        pipeline_run1.raise_from_status()
        pipeline_run1.pretty_print_summary()

        # validate both dummy source mces and checkpoint state mces files
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / output_file_name,
            golden_path=f"{test_resources_dir}/{golden_file_name}",
        )
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / state_file_name,
            golden_path=f"{test_resources_dir}/{golden_state_file_name}",
        )
        checkpoint1 = get_current_checkpoint_from_pipeline(pipeline_run1)
        assert checkpoint1
        assert checkpoint1.state

    with mock.patch(
        "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj"
    ) as mock_state:
        mock_state.return_value = GenericCheckpointState(serde="utf-8")
        pipeline_run2 = None
        pipeline_run2_config: Dict[str, Dict[str, Dict[str, Any]]] = dict(
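            # Note: dict() only makes a shallow copy of the shared base config, so the
            # nested "source"/"sink" sections below are the same objects used in run 1.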
            base_pipeline_config  # type: ignore
        )
        pipeline_run2_config["source"]["config"]["dataset_patterns"] = {
            "allow": ["dummy_dataset1", "dummy_dataset2"],
        }
        pipeline_run2_config["source"]["config"]["report_failure"] = True
        pipeline_run2_config["sink"]["config"]["filename"] = (
            f"{tmp_path}/{output_file_name_after_deleted}"
        )
        pipeline_run2 = Pipeline.create(pipeline_run2_config)
        pipeline_run2.run()
        pipeline_run2.pretty_print_summary()

        # validate both updated dummy source mces and checkpoint state mces files after deleting dataset
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / output_file_name_after_deleted,
            golden_path=f"{test_resources_dir}/{golden_file_name_after_deleted}",
        )
        mce_helpers.check_golden_file(
            pytestconfig,
            output_path=tmp_path / state_file_name,
            golden_path=f"{test_resources_dir}/{golden_state_file_name_after_deleted}",
        )
        checkpoint2 = get_current_checkpoint_from_pipeline(pipeline_run2)
        assert checkpoint2
        assert checkpoint2.state

    # Validate that all providers have committed successfully.
    validate_all_providers_have_committed_successfully(
        pipeline=pipeline_run1, expected_providers=1
    )
    validate_all_providers_have_committed_successfully(
        pipeline=pipeline_run2, expected_providers=1
    )

    # Perform assertions on the states. The deleted table should still be
    # part of the second state, since the second pipeline run failed.
    state1 = cast(GenericCheckpointState, checkpoint1.state)
    state2 = cast(GenericCheckpointState, checkpoint2.state)
    assert state1 == state2
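

# Usage note: DummySource is registered via its import path
# ("tests.unit.stateful_ingestion.state.test_stateful_ingestion.DummySource"), so these
# tests are meant to be run from the repository root (assuming a standard dev setup):
#   pytest tests/unit/stateful_ingestion/state/test_stateful_ingestion.py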