import pathlib
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Iterable, List, Optional, cast
from unittest.mock import MagicMock, patch

import pytest
from freezegun import freeze_time
from typing_extensions import Self

from datahub.configuration.common import DynamicTypedConfig
from datahub.ingestion.api.committable import CommitPolicy, Committable
from datahub.ingestion.api.common import RecordEnvelope
from datahub.ingestion.api.decorators import platform_name
from datahub.ingestion.api.sink import Sink, SinkReport, WriteCallback
from datahub.ingestion.api.source import Source, SourceReport
from datahub.ingestion.api.transform import Transformer
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import get_default_graph
from datahub.ingestion.graph.config import ClientMode, DatahubClientConfig
from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
from datahub.ingestion.sink.datahub_rest import DatahubRestSink
from datahub.metadata.com.linkedin.pegasus2avro.mxe import SystemMetadata
from datahub.metadata.schema_classes import (
    DatasetPropertiesClass,
    DatasetSnapshotClass,
    MetadataChangeEventClass,
    StatusClass,
)
from datahub.utilities.server_config_util import RestServiceConfig
from tests.test_helpers.click_helpers import run_datahub_cmd
from tests.test_helpers.sink_helpers import RecordingSinkReport

FROZEN_TIME = "2020-04-14 07:00:00"

# TODO: It seems like one of these tests writes to ~/.datahubenv or otherwise sets
# some global config, which impacts other tests.
pytestmark = pytest.mark.random_order(disabled=True)


@pytest.fixture
def mock_server_config():
    # Create a mock RestServiceConfig with the desired raw_config
    config = RestServiceConfig(raw_config={"noCode": True})
    return config


class TestPipeline:
    @pytest.fixture(autouse=True)
    def clear_cache(self):
        yield
        get_default_graph.cache_clear()

    @patch("confluent_kafka.Consumer", autospec=True)
    @patch(
        "datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
    )
    @patch("datahub.ingestion.sink.console.ConsoleSink.close", autospec=True)
    @freeze_time(FROZEN_TIME)
    def test_configure(self, mock_sink, mock_source, mock_consumer):
        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "kafka",
                    "config": {"connection": {"bootstrap": "fake-dns-name:9092"}},
                },
                "sink": {"type": "console"},
            }
        )
        pipeline.run()
        pipeline.raise_from_status()
        pipeline.pretty_print_summary()
        mock_source.assert_called_once()
        mock_sink.assert_called_once()

    @freeze_time(FROZEN_TIME)
    @patch("datahub.emitter.rest_emitter.DataHubRestEmitter.fetch_server_config")
    @patch(
        "datahub.cli.config_utils.load_client_config",
        return_value=DatahubClientConfig(server="http://fake-gms-server:8080"),
    )
    def test_configure_without_sink(
        self, mock_load_client_config, mock_fetch_config, mock_server_config
    ):
        mock_fetch_config.return_value = mock_server_config

        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "file",
                    "config": {"path": "test_file.json"},
                },
            }
        )
        # assert that the default sink is a DatahubRestSink
        assert isinstance(pipeline.sink, DatahubRestSink)
        assert pipeline.sink.config.server == "http://fake-gms-server:8080"
        assert pipeline.sink.config.token is None
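
    # If the CLI has a system auth header configured, the default rest sink
    # should pick it up and attach it to the emitter's session.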
    @freeze_time(FROZEN_TIME)
    @patch("datahub.emitter.rest_emitter.DataHubRestEmitter.fetch_server_config")
    @patch(
        "datahub.cli.config_utils.load_client_config",
        return_value=DatahubClientConfig(server="http://fake-internal-server:8080"),
    )
    @patch(
        "datahub.cli.config_utils.get_system_auth",
        return_value="Basic user:pass",
    )
    def test_configure_without_sink_use_system_auth(
        self,
        mock_get_system_auth,
        mock_load_client_config,
        mock_fetch_config,
        mock_server_config,
    ):
        mock_fetch_config.return_value = mock_server_config

        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "file",
                    "config": {"path": "test_file.json"},
                },
            }
        )
        # assert that the default sink is a DatahubRestSink
        assert isinstance(pipeline.sink, DatahubRestSink)
        assert pipeline.sink.config.server == "http://fake-internal-server:8080"
        assert pipeline.sink.config.token is None
        assert (
            pipeline.sink.emitter._session.headers["Authorization"]
            == "Basic user:pass"
        )

    @freeze_time(FROZEN_TIME)
    @patch("datahub.emitter.rest_emitter.DataHubRestEmitter.fetch_server_config")
    def test_configure_with_rest_sink_initializes_graph(
        self, mock_fetch_config, mock_server_config
    ):
        mock_fetch_config.return_value = mock_server_config
        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "file",
                    "config": {"path": "test_events.json"},
                },
                "sink": {
                    "type": "datahub-rest",
                    "config": {
                        "server": "http://somehost.someplace.some:8080",
                        "token": "foo",
                    },
                },
            },
            # We don't want to make actual API calls during this test.
            report_to=None,
        )
        # assert that the sink config is for a DatahubRestSink
        assert isinstance(pipeline.config.sink, DynamicTypedConfig)
        assert pipeline.config.sink.type == "datahub-rest"
        assert pipeline.config.sink.config == {
            "server": "http://somehost.someplace.some:8080",
            "token": "foo",
        }
        assert pipeline.ctx.graph is not None, "DataHubGraph should be initialized"
        assert pipeline.ctx.graph.config.server == pipeline.config.sink.config["server"]
        assert pipeline.ctx.graph.config.token == pipeline.config.sink.config["token"]

    @freeze_time(FROZEN_TIME)
    @patch("datahub.emitter.rest_emitter.DataHubRestEmitter.fetch_server_config")
    def test_configure_with_rest_sink_with_additional_props_initializes_graph(
        self, mock_fetch_config, mock_server_config
    ):
        mock_fetch_config.return_value = mock_server_config
        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "file",
                    "config": {"path": "test_events.json"},
                },
                "sink": {
                    "type": "datahub-rest",
                    "config": {
                        "server": "http://somehost.someplace.some:8080",
                        "token": "foo",
                        "mode": "sync",
                    },
                },
            }
        )
        # assert that the sink config is for a DatahubRestSink
        assert isinstance(pipeline.config.sink, DynamicTypedConfig)
        assert pipeline.config.sink.type == "datahub-rest"
        assert pipeline.config.sink.config == {
            "server": "http://somehost.someplace.some:8080",
            "token": "foo",
            "mode": "sync",
        }
        assert pipeline.ctx.graph is not None, "DataHubGraph should be initialized"
        assert pipeline.ctx.graph.config.server == pipeline.config.sink.config["server"]
        assert pipeline.ctx.graph.config.token == pipeline.config.sink.config["token"]

    @freeze_time(FROZEN_TIME)
    @patch(
        "datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
    )
    def test_configure_with_file_sink_does_not_init_graph(self, mock_source, tmp_path):
        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "kafka",
                    "config": {"connection": {"bootstrap": "localhost:9092"}},
                },
                "sink": {
                    "type": "file",
                    "config": {
                        "filename": str(tmp_path / "test.json"),
                    },
                },
            }
        )
        # assert that the sink config is for a file sink, not a DatahubRestSink
        assert isinstance(pipeline.config.sink, DynamicTypedConfig)
        assert pipeline.config.sink.type == "file"
        assert pipeline.config.sink.config == {"filename": str(tmp_path / "test.json")}
        assert pipeline.ctx.graph is None, "DataHubGraph should not be initialized"

    @freeze_time(FROZEN_TIME)
    def test_run_including_fake_transformation(self):
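        # End-to-end run: the FakeSource's single MCE should arrive at the
        # RecordingSink with the Status aspect appended by the transformer.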
        pipeline = Pipeline.create(
            {
                "source": {"type": "tests.unit.api.test_pipeline.FakeSource"},
                "transformers": [
                    {"type": "tests.unit.api.test_pipeline.AddStatusRemovedTransformer"}
                ],
                "sink": {"type": "tests.test_helpers.sink_helpers.RecordingSink"},
                "run_id": "pipeline_test",
            }
        )
        pipeline.run()
        pipeline.raise_from_status()

        expected_mce = get_initial_mce()

        dataset_snapshot = cast(DatasetSnapshotClass, expected_mce.proposedSnapshot)
        dataset_snapshot.aspects.append(get_status_removed_aspect())

        sink_report: RecordingSinkReport = cast(
            RecordingSinkReport, pipeline.sink.get_report()
        )
        assert len(sink_report.received_records) == 1
        assert expected_mce == sink_report.received_records[0].record

    @freeze_time(FROZEN_TIME)
    def test_run_including_registered_transformation(self):
        # This is not testing functionality, but just the transformer registration system.
        pipeline = Pipeline.create(
            {
                "source": {"type": "tests.unit.api.test_pipeline.FakeSource"},
                "transformers": [
                    {
                        "type": "simple_add_dataset_ownership",
                        "config": {
                            "owner_urns": ["urn:li:corpuser:foo"],
                            "ownership_type": "urn:li:ownershipType:__system__technical_owner",
                        },
                    }
                ],
                "sink": {"type": "tests.test_helpers.sink_helpers.RecordingSink"},
            }
        )
        assert pipeline

    @pytest.mark.parametrize(
        "source,strict_warnings,exit_code",
        [
            pytest.param(
                "FakeSource",
                False,
                0,
            ),
            pytest.param(
                "FakeSourceWithWarnings",
                False,
                0,
            ),
            pytest.param(
                "FakeSourceWithWarnings",
                True,
                1,
            ),
        ],
    )
    @freeze_time(FROZEN_TIME)
    def test_pipeline_return_code(self, tmp_path, source, strict_warnings, exit_code):
        config_file: pathlib.Path = tmp_path / "test.yml"
        config_file.write_text(
            f"""
---
run_id: pipeline_test
source:
  type: tests.unit.api.test_pipeline.{source}
  config: {{}}
sink:
  type: console
"""
        )

        res = run_datahub_cmd(
            [
                "ingest",
                "-c",
                f"{config_file}",
                *(("--strict-warnings",) if strict_warnings else ()),
            ],
            tmp_path=tmp_path,
            check_result=False,
        )
        assert res.exit_code == exit_code, res.stdout

    @pytest.mark.parametrize(
        "commit_policy,source,should_commit",
        [
            pytest.param(
                CommitPolicy.ALWAYS,
                "FakeSource",
                True,
                id="ALWAYS-no-warnings-no-errors",
            ),
            pytest.param(
                CommitPolicy.ON_NO_ERRORS,
                "FakeSource",
                True,
                id="ON_NO_ERRORS-no-warnings-no-errors",
            ),
            pytest.param(
                CommitPolicy.ON_NO_ERRORS_AND_NO_WARNINGS,
                "FakeSource",
                True,
                id="ON_NO_ERRORS_AND_NO_WARNINGS-no-warnings-no-errors",
            ),
            pytest.param(
                CommitPolicy.ALWAYS,
                "FakeSourceWithWarnings",
                True,
                id="ALWAYS-with-warnings",
            ),
            pytest.param(
                CommitPolicy.ON_NO_ERRORS,
                "FakeSourceWithWarnings",
                True,
                id="ON_NO_ERRORS-with-warnings",
            ),
            pytest.param(
                CommitPolicy.ON_NO_ERRORS_AND_NO_WARNINGS,
                "FakeSourceWithWarnings",
                False,
                id="ON_NO_ERRORS_AND_NO_WARNINGS-with-warnings",
            ),
            pytest.param(
                CommitPolicy.ALWAYS,
                "FakeSourceWithFailures",
                True,
                id="ALWAYS-with-errors",
            ),
            pytest.param(
                CommitPolicy.ON_NO_ERRORS,
                "FakeSourceWithFailures",
                False,
                id="ON_NO_ERRORS-with-errors",
            ),
            pytest.param(
                CommitPolicy.ON_NO_ERRORS_AND_NO_WARNINGS,
                "FakeSourceWithFailures",
                False,
                id="ON_NO_ERRORS_AND_NO_WARNINGS-with-errors",
            ),
        ],
    )
    @freeze_time(FROZEN_TIME)
    def test_pipeline_process_commits(self, commit_policy, source, should_commit):
        pipeline = Pipeline.create(
            {
                "source": {"type": f"tests.unit.api.test_pipeline.{source}"},
                "sink": {"type": "console"},
                "run_id": "pipeline_test",
            }
        )

        class FakeCommittable(Committable):
            def __init__(self, commit_policy: CommitPolicy):
                self.name = "test_checkpointer"
                self.commit_policy = commit_policy

            def commit(self) -> None:
                pass

        fake_committable: Committable = FakeCommittable(commit_policy)
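        # Wrap commit() in a mock so the test can count invocations without
        # changing the committable's behavior.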
        with patch.object(
            FakeCommittable, "commit", wraps=fake_committable.commit
        ) as mock_commit:
            pipeline.ctx.register_checkpointer(fake_committable)

            pipeline.run()
            # check that we called the commit method once, only if should_commit is True
            if should_commit:
                mock_commit.assert_called_once()
            else:
                mock_commit.assert_not_called()

    @freeze_time(FROZEN_TIME)
    @patch("datahub.emitter.rest_emitter.DataHubRestEmitter.fetch_server_config")
    def test_pipeline_graph_has_expected_client_mode_and_component(
        self, mock_fetch_config, mock_server_config
    ):
        mock_fetch_config.return_value = mock_server_config
        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "file",
                    "config": {"path": "test_events.json"},
                },
                "sink": {
                    "type": "datahub-rest",
                    "config": {
                        "server": "http://somehost.someplace.some:8080",
                        "token": "foo",
                    },
                },
            }
        )
        # Assert that the DataHubGraph is initialized with expected properties
        assert pipeline.ctx.graph is not None, "DataHubGraph should be initialized"

        # Check that graph has the expected client mode and datahub_component
        assert pipeline.ctx.graph.config.client_mode == ClientMode.INGESTION
        assert pipeline.ctx.graph.config.datahub_component is None

        # Check that the graph was configured with the expected server and token
        assert pipeline.ctx.graph.config.server == "http://somehost.someplace.some:8080"
        assert pipeline.ctx.graph.config.token == "foo"

    @freeze_time(FROZEN_TIME)
    @patch("datahub.emitter.rest_emitter.DataHubRestEmitter.fetch_server_config")
    @patch(
        "datahub.cli.config_utils.load_client_config",
        return_value=DatahubClientConfig(server="http://fake-gms-server:8080"),
    )
    def test_pipeline_graph_client_mode(
        self, mock_load_client_config, mock_fetch_config, mock_server_config
    ):
        """Test that the graph created in Pipeline has the correct client_mode."""
        mock_fetch_config.return_value = mock_server_config

        # Mock the DataHubGraph context manager and test_connection method
        mock_graph = MagicMock()
        mock_graph.__enter__.return_value = mock_graph

        # Create a patch that replaces the DataHubGraph constructor
        with patch(
            "datahub.ingestion.run.pipeline.DataHubGraph", return_value=mock_graph
        ) as mock_graph_class:
            # Create a pipeline with datahub_api config
            pipeline = Pipeline.create(
                {
                    "source": {
                        "type": "file",
                        "config": {"path": "test_events.json"},
                    },
                    "datahub_api": {
                        "server": "http://datahub-gms:8080",
                        "token": "test_token",
                    },
                }
            )

            # Verify DataHubGraph was called with the correct parameters
            assert mock_graph_class.call_count == 1

            # Get the arguments passed to DataHubGraph
            config_arg = mock_graph_class.call_args[0][0]

            # Assert the config is a DatahubClientConfig with the expected values
            assert isinstance(config_arg, DatahubClientConfig)
            assert config_arg.server == "http://datahub-gms:8080"
            assert config_arg.token == "test_token"
            assert config_arg.client_mode == ClientMode.INGESTION
            assert config_arg.datahub_component is None

            # Verify the graph has been stored in the pipeline context
            assert pipeline.ctx.graph is mock_graph


class AddStatusRemovedTransformer(Transformer):
    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> "Transformer":
        return cls()

    def transform(
        self, record_envelopes: Iterable[RecordEnvelope]
    ) -> Iterable[RecordEnvelope]:
        for record_envelope in record_envelopes:
            if isinstance(record_envelope.record, MetadataChangeEventClass):
                assert isinstance(
                    record_envelope.record.proposedSnapshot, DatasetSnapshotClass
                )
                record_envelope.record.proposedSnapshot.aspects.append(
                    get_status_removed_aspect()
                )
            yield record_envelope


@platform_name("fake")
class FakeSource(Source):
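    """Minimal test source: emits a single MCE work unit built by get_initial_mce()."""
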
    def __init__(self, ctx: PipelineContext):
        super().__init__(ctx)
        self.source_report = SourceReport()
        self.work_units: List[MetadataWorkUnit] = [
            MetadataWorkUnit(id="workunit-1", mce=get_initial_mce())
        ]

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Self:
        assert not config_dict
        return cls(ctx)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        return self.work_units

    def get_report(self) -> SourceReport:
        return self.source_report

    def close(self):
        pass


@platform_name("fake")
class FakeSourceWithWarnings(FakeSource):
    def __init__(self, ctx: PipelineContext):
        super().__init__(ctx)
        self.source_report.report_warning("test_warning", "warning_text")

    def get_report(self) -> SourceReport:
        return self.source_report


@platform_name("fake")
class FakeSourceWithFailures(FakeSource):
    def __init__(self, ctx: PipelineContext):
        super().__init__(ctx)
        self.source_report.report_failure("test_failure", "failure_text")

    def get_report(self) -> SourceReport:
        return self.source_report


def get_initial_mce() -> MetadataChangeEventClass:
    return MetadataChangeEventClass(
        proposedSnapshot=DatasetSnapshotClass(
            urn="urn:li:dataset:(urn:li:dataPlatform:test_platform,test,PROD)",
            aspects=[
                DatasetPropertiesClass(
                    description="test.description",
                )
            ],
        ),
        systemMetadata=SystemMetadata(
            lastObserved=1586847600000, runId="pipeline_test"
        ),
    )


def get_status_removed_aspect() -> StatusClass:
    return StatusClass(removed=False)


class RealisticSinkReport(SinkReport):
    """Sink report that simulates the real DatahubRestSinkReport behavior"""

    def __init__(self):
        super().__init__()
        self.pending_requests = (
            0  # Start with 0, will be incremented by async operations
        )


class RealisticDatahubRestSink(Sink):
    """
    Realistic simulation of DatahubRestSink that uses an executor like the real
    implementation. This simulates the async behavior that can cause timing issues.
    """

    def __init__(self):
        self.report = RealisticSinkReport()
        self.executor = ThreadPoolExecutor(
            max_workers=2, thread_name_prefix="sink-worker"
        )
        self.close_called = False
        self.shutdown_complete = False

    def write_record_async(
        self,
        record_envelope: RecordEnvelope,
        write_callback: Optional[WriteCallback] = None,
    ) -> None:
        """Simulate async record writing like the real DatahubRestSink"""
        # Increment pending requests when starting async work
        self.report.pending_requests += 1
        print(
            f"📝 Starting async write - pending_requests now: {self.report.pending_requests}"
        )

        # Submit async work to executor
        future = self.executor.submit(self._process_record, record_envelope)

        # Add callback to decrement pending requests when done
        future.add_done_callback(self._on_request_complete)

    def _process_record(self, record_envelope: RecordEnvelope) -> None:
        """Simulate processing a record (like sending HTTP request)"""
        # Simulate network delay
        time.sleep(0.1)
        print(f"🌐 Processed record: {record_envelope.record}")

    def _on_request_complete(self, future: Future) -> None:
        """Callback when async request completes"""
        self.report.pending_requests -= 1
        print(
            f"✅ Request completed - pending_requests now: {self.report.pending_requests}"
        )

    def close(self):
        """Simulate the real DatahubRestSink close() method"""
        print("🔧 Starting sink close()...")
        self.close_called = True

        # Simulate the executor.shutdown() behavior
        print(
            f"⏳ Shutting down executor with {self.report.pending_requests} pending requests..."
        )
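        # shutdown(wait=True) blocks until every queued future has run, so the
        # done-callbacks have decremented pending_requests before close() returns.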
        self.executor.shutdown(wait=True)  # Wait for all pending requests to complete

        # After shutdown, pending_requests should be 0
        self.report.pending_requests = 0
        self.shutdown_complete = True
        print("✅ Sink close() completed - all pending requests processed")


class MockReporter:
    """Mock reporter that tracks when completion notification is called"""

    def __init__(self, sink=None):
        self.completion_called = False
        self.completion_pending_requests = None
        self.sink = sink
        self.start_called = False

    def on_start(self, ctx):
        """Mock on_start method"""
        self.start_called = True
        print("📊 MockReporter.on_start() called")

    def on_completion(self, status, report, ctx):
        self.completion_called = True

        # Check pending requests at the time of completion notification
        if (
            self.sink
            and hasattr(self.sink, "report")
            and hasattr(self.sink.report, "pending_requests")
        ):
            self.completion_pending_requests = self.sink.report.pending_requests
            print(
                f"📊 Completion notification sees {self.completion_pending_requests} pending requests"
            )


class TestSinkReportTimingOnClose:
    """Test class for validating sink report timing when sink.close() is called"""

    def _create_real_pipeline_with_realistic_sink(self):
        """Create a real Pipeline instance with a realistic sink using Pipeline.create()"""
        from datahub.ingestion.run.pipeline import Pipeline

        # Create realistic sink
        sink = RealisticDatahubRestSink()

        # Create pipeline using Pipeline.create() like the existing tests do.
        # Use the demo data source, which is simpler and doesn't require external dependencies.
        pipeline = Pipeline.create(
            {
                "source": {
                    "type": "demo-data",
                    "config": {},
                },
                "sink": {"type": "console"},  # Use console sink to avoid network issues
            }
        )

        # Replace the sink with our realistic sink and register it with the exit_stack.
        # This ensures the context manager calls close() on our realistic sink.
        pipeline.sink = pipeline.exit_stack.enter_context(sink)

        return pipeline, sink

    def _add_pending_requests_to_sink(
        self, sink: RealisticDatahubRestSink, count: int = 3
    ) -> int:
        """Add some pending requests to the sink to simulate async work"""
        print(f"📝 Adding {count} pending requests to sink...")

        for i in range(count):
            sink.write_record_async(RecordEnvelope(f"record_{i}", metadata={}))

        # Give some time for work to start
        time.sleep(0.05)

        print(f"📊 Current pending requests: {sink.report.pending_requests}")
        return sink.report.pending_requests

    def test_sink_report_timing_on_close(self):
        """Test that validates completion notification runs with 0 pending requests after sink.close()"""
        print("\n🧪 Testing sink report timing on close...")

        # Create test pipeline with realistic sink
        pipeline, sink = self._create_real_pipeline_with_realistic_sink()

        # Add pending requests to simulate async work
        pending_count = self._add_pending_requests_to_sink(sink, 3)
        assert pending_count > 0, "❌ Expected pending requests to be added"

        # Create mock reporter to track completion
        reporter = MockReporter(sink=sink)
        pipeline.reporters = [reporter]

        # Create a wrapper that checks pending requests before calling the real method
        original_notify = pipeline._notify_reporters_on_ingestion_completion

        def notify_with_pending_check():
            print("🔔 Calling _notify_reporters_on_ingestion_completion()...")

            # Check pending requests before notifying reporters
            if hasattr(pipeline.sink, "report") and hasattr(
                pipeline.sink.report, "pending_requests"
            ):
                pending_requests = pipeline.sink.report.pending_requests
                print(
                    f"📊 Pending requests when _notify_reporters_on_ingestion_completion() runs: {pending_requests}"
                )

                # This is the key assertion - there should be no pending requests
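                # (by this point the exit stack has already closed the sink and
                # drained its executor, so anything non-zero means a regression)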
                assert pending_requests == 0, (
                    f"❌ Expected 0 pending requests when _notify_reporters_on_ingestion_completion() runs, "
                    f"but found {pending_requests}. This indicates a timing issue."
                )
                print(
                    "✅ No pending requests when _notify_reporters_on_ingestion_completion() runs"
                )

            # Call original method
            original_notify()

        # Replace the method temporarily
        pipeline._notify_reporters_on_ingestion_completion = notify_with_pending_check

        # Run the REAL Pipeline.run() method with the CORRECT behavior
        # (completion notification after context manager)
        print(
            "📢 Running REAL Pipeline.run() with correct timing (completion notification after context manager)..."
        )
        pipeline.run()

        # Verify the fix: sink.close() was called by the context manager before
        # _notify_reporters_on_ingestion_completion
        assert sink.close_called, "❌ Sink close() should be called by context manager"
        print("✅ Sink close() was called by context manager")

        assert reporter.completion_called, "❌ Completion notification should be called"
        print("✅ Completion notification was called")

        # Verify no pending requests at completion (the key fix)
        assert reporter.completion_pending_requests == 0, (
            f"❌ Expected 0 pending requests at completion, got {reporter.completion_pending_requests}. "
            "This indicates the timing fix is not working."
        )
        print("✅ No pending requests at completion notification")

        # Verify the sink's pending_requests is also 0
        assert sink.report.pending_requests == 0, (
            f"❌ Sink should have 0 pending requests after close() completes, got {sink.report.pending_requests}"
        )
        print("✅ Sink has 0 pending requests after close()")

        print(
            "🎉 Sink report timing fix test passed! The fix works with realistic async behavior."
        )