import json import os from datetime import timedelta from typing import Any, Dict from unittest.mock import ANY, MagicMock, Mock, patch import pytest from requests import Response, Session from datahub.configuration.common import ( ConfigurationError, TraceTimeoutError, TraceValidationError, ) from datahub.emitter import rest_emitter from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.emitter.response_helper import TraceData from datahub.emitter.rest_emitter import ( BATCH_INGEST_MAX_PAYLOAD_LENGTH, INGEST_MAX_PAYLOAD_BYTES, DataHubRestEmitter, DatahubRestEmitter, EmitMode, RequestsSessionConfig, RestSinkEndpoint, logger, ) from datahub.errors import APITracingWarning from datahub.ingestion.graph.config import ClientMode from datahub.metadata.com.linkedin.pegasus2avro.common import ( Status, ) from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( DatasetProfile, DatasetProperties, ) from datahub.metadata.schema_classes import ( ChangeTypeClass, ) from datahub.specific.dataset import DatasetPatchBuilder from datahub.utilities.server_config_util import RestServiceConfig MOCK_GMS_ENDPOINT = "http://fakegmshost:8080" @pytest.fixture def sample_config() -> Dict[str, Any]: return { "models": {}, "patchCapable": True, "versions": { "acryldata/datahub": { "version": "v1.0.1", "commit": "dc127c5f031d579732899ccd81a53a3514dc4a6d", } }, "managedIngestion": {"defaultCliVersion": "1.0.0.2", "enabled": True}, "statefulIngestionCapable": True, "supportsImpactAnalysis": True, "timeZone": "GMT", "telemetry": {"enabledCli": True, "enabledIngestion": False}, "datasetUrnNameCasing": False, "retention": "true", "datahub": {"serverType": "dev"}, "noCode": "true", } @pytest.fixture def mock_session(): session = MagicMock(spec=Session) return session @pytest.fixture def mock_response(sample_config): response = MagicMock() response.status_code = 200 response.json.return_value = sample_config response.text = str(sample_config) return response class TestDataHubRestEmitter: @pytest.fixture def openapi_emitter(self) -> DataHubRestEmitter: openapi_emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True) # Set the underlying private attribute directly openapi_emitter._server_config = RestServiceConfig( raw_config={ "versions": { "acryldata/datahub": { "version": "v1.0.1rc0" # Supports OpenApi & Tracing } } } ) return openapi_emitter def test_datahub_rest_emitter_missing_gms_server(self): """Test that emitter raises ConfigurationError when gms_server is not provided.""" # Case 1: Empty string with pytest.raises(ConfigurationError) as excinfo: DataHubRestEmitter("") assert "gms server is required" in str(excinfo.value) def test_connection_error_reraising(self): """Test that test_connection() properly re-raises ConfigurationError.""" # Create a basic emitter emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT) # Mock the session's get method to raise ConfigurationError with patch.object(emitter._session, "get") as mock_get: # Configure the mock to raise ConfigurationError mock_error = ConfigurationError("Connection failed") mock_get.side_effect = mock_error # Verify that the exception is re-raised with pytest.raises(ConfigurationError) as excinfo: emitter.test_connection() # Verify it's the same exception object (not a new one) assert excinfo.value is mock_error def test_datahub_rest_emitter_construction(self) -> None: emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT) assert emitter._session_config.timeout == rest_emitter._DEFAULT_TIMEOUT_SEC assert ( emitter._session_config.retry_status_codes == rest_emitter._DEFAULT_RETRY_STATUS_CODES ) assert ( emitter._session_config.retry_max_times == rest_emitter._DEFAULT_RETRY_MAX_TIMES ) def test_datahub_rest_emitter_timeout_construction(self) -> None: emitter = DatahubRestEmitter( MOCK_GMS_ENDPOINT, connect_timeout_sec=2, read_timeout_sec=4 ) assert emitter._session_config.timeout == (2, 4) def test_datahub_rest_emitter_general_timeout_construction(self) -> None: emitter = DatahubRestEmitter( MOCK_GMS_ENDPOINT, timeout_sec=2, read_timeout_sec=4 ) assert emitter._session_config.timeout == (2, 4) def test_datahub_rest_emitter_retry_construction(self) -> None: emitter = DatahubRestEmitter( MOCK_GMS_ENDPOINT, retry_status_codes=[418], retry_max_times=42, ) assert emitter._session_config.retry_status_codes == [418] assert emitter._session_config.retry_max_times == 42 def test_datahub_rest_emitter_extra_params(self) -> None: emitter = DatahubRestEmitter( MOCK_GMS_ENDPOINT, extra_headers={"key1": "value1", "key2": "value2"} ) assert emitter._session.headers.get("key1") == "value1" assert emitter._session.headers.get("key2") == "value2" def test_openapi_emitter_emit(self, openapi_emitter): item = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)", aspect=DatasetProfile( rowCount=2000, columnCount=15, timestampMillis=1626995099686, ), ) with patch.object(openapi_emitter, "_emit_generic") as mock_method: openapi_emitter.emit_mcp(item) mock_method.assert_called_once_with( f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false", payload=[ { "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)", "datasetProfile": { "value": { "rowCount": 2000, "columnCount": 15, "timestampMillis": 1626995099686, "partitionSpec": { "partition": "FULL_TABLE_SNAPSHOT", "type": "FULL_TABLE", }, }, "systemMetadata": { "lastObserved": ANY, "runId": "no-run-id-provided", "lastRunId": "no-run-id-provided", "properties": { "clientId": "acryl-datahub", "clientVersion": "1!0.0.0.dev0", }, }, "headers": {}, }, } ], method="post", ) def test_openapi_emitter_emit_mcps(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: items = [ MetadataChangeProposalWrapper( entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)", aspect=DatasetProfile( rowCount=2000 + i, columnCount=15, timestampMillis=1626995099686, ), ) for i in range(3) ] result = openapi_emitter.emit_mcps(items) assert result == 1 # Single chunk test - all items should be in one batch mock_emit.assert_called_once() call_args = mock_emit.call_args assert ( call_args[0][0] == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false" ) assert isinstance(call_args[1]["payload"], str) # Should be JSON string def test_openapi_emitter_emit_mcps_max_bytes(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Create a large payload that will force chunking large_payload = "x" * ( INGEST_MAX_PAYLOAD_BYTES // 2 ) # Each item will be about half max size items = [ MetadataChangeProposalWrapper( entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,LargePayload{i},PROD)", aspect=DatasetProperties(name=large_payload), ) for i in range( 3 ) # This should create at least 2 chunks given the payload size ] openapi_emitter.emit_mcps(items) # Verify multiple chunks were created assert mock_emit.call_count > 1 # Verify each chunk's payload is within size limits for call in mock_emit.call_args_list: args = call[1] assert "payload" in args payload = args["payload"] assert isinstance(payload, str) # Should be JSON string assert len(payload.encode()) <= INGEST_MAX_PAYLOAD_BYTES # Verify the payload structure payload_data = json.loads(payload) assert payload_data[0]["urn"].startswith("urn:li:dataset") assert "datasetProperties" in payload_data[0] def test_openapi_emitter_emit_mcps_max_items(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Create more items than BATCH_INGEST_MAX_PAYLOAD_LENGTH items = [ MetadataChangeProposalWrapper( entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Item{i},PROD)", aspect=DatasetProfile( rowCount=i, columnCount=15, timestampMillis=1626995099686, ), ) for i in range( BATCH_INGEST_MAX_PAYLOAD_LENGTH + 2 ) # Create 2 more than max ] openapi_emitter.emit_mcps(items) # Verify multiple chunks were created assert mock_emit.call_count == 2 # Check first chunk first_call = mock_emit.call_args_list[0] first_payload = json.loads(first_call[1]["payload"]) assert len(first_payload) == BATCH_INGEST_MAX_PAYLOAD_LENGTH # Check second chunk second_call = mock_emit.call_args_list[1] second_payload = json.loads(second_call[1]["payload"]) assert len(second_payload) == 2 # Should have the remaining 2 items def test_openapi_emitter_emit_mcps_multiple_entity_types(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Create items for two different entity types dataset_items = [ MetadataChangeProposalWrapper( entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)", aspect=DatasetProfile( rowCount=i, columnCount=15, timestampMillis=1626995099686, ), ) for i in range(2) ] dashboard_items = [ MetadataChangeProposalWrapper( entityUrn=f"urn:li:dashboard:(looker,dashboards.{i})", aspect=Status(removed=False), ) for i in range(2) ] items = dataset_items + dashboard_items result = openapi_emitter.emit_mcps(items) # Should return number of unique entity URLs assert result == 2 assert mock_emit.call_count == 2 # Check that calls were made with different URLs but correct payloads calls = { call[0][0]: json.loads(call[1]["payload"]) for call in mock_emit.call_args_list } # Verify each URL got the right aspects for url, payload in calls.items(): if "datasetProfile" in payload[0]: assert url.endswith("dataset?async=false") assert len(payload) == 2 assert all("datasetProfile" in item for item in payload) else: assert url.endswith("dashboard?async=false") assert len(payload) == 2 assert all("status" in item for item in payload) def test_openapi_emitter_emit_mcp_with_tracing(self, openapi_emitter): """Test emitting a single MCP with tracing enabled""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Mock the response for the initial emit mock_response = Mock(spec=Response) mock_response.status_code = 200 mock_response.headers = { "traceparent": "00-00063609cb934b9d0d4e6a7d6d5e1234-1234567890abcdef-01" } mock_response.json.return_value = [ { "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)", "datasetProfile": {}, } ] mock_emit.return_value = mock_response # Create test item item = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)", aspect=DatasetProfile( rowCount=2000, columnCount=15, timestampMillis=1626995099686, ), ) # Set up mock for trace verification responses trace_responses = [ # First check - pending { "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": { "datasetProfile": { "success": True, "primaryStorage": {"writeStatus": "PENDING"}, "searchStorage": {"writeStatus": "PENDING"}, } } }, # Second check - completed { "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": { "datasetProfile": { "success": True, "primaryStorage": {"writeStatus": "ACTIVE_STATE"}, "searchStorage": {"writeStatus": "ACTIVE_STATE"}, } } }, ] def side_effect(*args, **kwargs): if "trace/write" in args[0]: mock_trace_response = Mock(spec=Response) mock_trace_response.json.return_value = trace_responses.pop(0) return mock_trace_response return mock_response mock_emit.side_effect = side_effect # Emit with tracing enabled openapi_emitter.emit_mcp( item, emit_mode=EmitMode.ASYNC_WAIT, wait_timeout=timedelta(seconds=10), ) # Verify initial emit call assert ( mock_emit.call_args_list[0][0][0] == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true" ) # Verify trace verification calls trace_calls = [ call for call in mock_emit.call_args_list if "trace/write" in call[0][0] ] assert len(trace_calls) == 2 assert "00063609cb934b9d0d4e6a7d6d5e1234" in trace_calls[0][0][0] def test_openapi_emitter_emit_mcps_with_tracing(self, openapi_emitter): """Test emitting multiple MCPs with tracing enabled""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Create test items items = [ MetadataChangeProposalWrapper( entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)", aspect=Status(removed=False), ) for i in range(2) ] # Mock responses for initial emit emit_responses = [] mock_resp = Mock(spec=Response) mock_resp.status_code = 200 mock_resp.headers = {"traceparent": "test-trace"} mock_resp.json.return_value = [ { "urn": f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)", "status": {"removed": False}, } for i in range(2) ] emit_responses.append(mock_resp) # Mock trace verification responses trace_responses = [ # First check - all pending { f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)": { "status": { "success": True, "primaryStorage": {"writeStatus": "PENDING"}, "searchStorage": {"writeStatus": "PENDING"}, } } for i in range(2) }, # Second check - all completed { f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)": { "status": { "success": True, "primaryStorage": {"writeStatus": "ACTIVE_STATE"}, "searchStorage": {"writeStatus": "ACTIVE_STATE"}, } } for i in range(2) }, ] response_iter = iter(emit_responses) trace_response_iter = iter(trace_responses) def side_effect(*args, **kwargs): if "trace/write" in args[0]: mock_trace_response = Mock(spec=Response) mock_trace_response.json.return_value = next(trace_response_iter) return mock_trace_response return next(response_iter) mock_emit.side_effect = side_effect # Emit with tracing enabled result = openapi_emitter.emit_mcps( items, emit_mode=EmitMode.ASYNC_WAIT, wait_timeout=timedelta(seconds=10), ) assert result == 1 # Should return number of unique entity URLs # Verify initial emit calls emit_calls = [ call for call in mock_emit.call_args_list if "entity/dataset" in call[0][0] ] assert len(emit_calls) == 1 assert ( emit_calls[0][0][0] == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true" ) # Verify trace verification calls trace_calls = [ call for call in mock_emit.call_args_list if "trace/write" in call[0][0] ] assert len(trace_calls) == 2 def test_openapi_emitter_trace_timeout(self, openapi_emitter): """Test that tracing properly handles timeouts""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Mock initial emit response mock_response = Mock(spec=Response) mock_response.status_code = 200 mock_response.headers = {"traceparent": "test-trace-123"} mock_response.json.return_value = [ { "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)", "datasetProfile": {}, } ] # Mock trace verification to always return pending mock_trace_response = Mock(spec=Response) mock_trace_response.json.return_value = { "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": { "datasetProfile": { "success": True, "primaryStorage": {"writeStatus": "PENDING"}, "searchStorage": {"writeStatus": "PENDING"}, } } } def side_effect(*args, **kwargs): return ( mock_trace_response if "trace/write" in args[0] else mock_response ) mock_emit.side_effect = side_effect item = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)", aspect=DatasetProfile( rowCount=2000, columnCount=15, timestampMillis=1626995099686, ), ) # Emit with very short timeout with pytest.raises(TraceTimeoutError) as exc_info: openapi_emitter.emit_mcp( item, emit_mode=EmitMode.ASYNC_WAIT, wait_timeout=timedelta(milliseconds=1), ) assert "Timeout waiting for async write completion" in str(exc_info.value) def test_openapi_emitter_missing_trace_header(self, openapi_emitter): """Test behavior when trace header is missing""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Mock response without trace header mock_response = Mock(spec=Response) mock_response.status_code = 200 mock_response.headers = {} # No traceparent header mock_response.json.return_value = [ { "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,MissingTraceHeader,PROD)", "status": {"removed": False}, } ] mock_emit.return_value = mock_response item = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,MissingTraceHeader,PROD)", aspect=Status(removed=False), ) # Should not raise exception but log a warning. with pytest.warns(APITracingWarning): openapi_emitter.emit_mcp( item, emit_mode=EmitMode.ASYNC_WAIT, wait_timeout=timedelta(seconds=10), ) def test_openapi_emitter_invalid_status_code(self, openapi_emitter): """Test behavior when response has non-200 status code""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Mock response with error status code mock_response = Mock(spec=Response) mock_response.status_code = 500 mock_response.headers = {"traceparent": "test-trace-123"} mock_response.json.return_value = [ { "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,InvalidStatusCode,PROD)", "datasetProfile": {}, } ] mock_emit.return_value = mock_response item = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,InvalidStatusCode,PROD)", aspect=DatasetProfile( rowCount=2000, columnCount=15, timestampMillis=1626995099686, ), ) # Should not raise exception but log warning openapi_emitter.emit_mcp( item, emit_mode=EmitMode.ASYNC_WAIT, wait_timeout=timedelta(seconds=10), ) def test_openapi_emitter_trace_failure(self, openapi_emitter): """Test handling of trace verification failures""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: test_urn = "urn:li:dataset:(urn:li:dataPlatform:mysql,TraceFailure,PROD)" # Create initial emit response emit_response = Mock(spec=Response) emit_response.status_code = 200 emit_response.headers = { "traceparent": "00-00063609cb934b9d0d4e6a7d6d5e1234-1234567890abcdef-01" } emit_response.json.return_value = [{"urn": test_urn, "datasetProfile": {}}] # Create trace verification response trace_response = Mock(spec=Response) trace_response.status_code = 200 trace_response.headers = {} trace_response.json.return_value = { test_urn: { "datasetProfile": { "success": False, "error": "Failed to write to storage", "primaryStorage": {"writeStatus": "ERROR"}, "searchStorage": {"writeStatus": "ERROR"}, } } } def side_effect(*args, **kwargs): if "trace/write" in args[0]: return trace_response return emit_response mock_emit.side_effect = side_effect item = MetadataChangeProposalWrapper( entityUrn=test_urn, aspect=DatasetProfile( rowCount=2000, columnCount=15, timestampMillis=1626995099686, ), ) with pytest.raises(TraceValidationError) as exc_info: openapi_emitter.emit_mcp( item, emit_mode=EmitMode.ASYNC_WAIT, wait_timeout=timedelta(seconds=10), ) error_message = str(exc_info.value) # Check for key error message components assert "Unable to validate async write" in error_message assert "to DataHub GMS" in error_message assert "Failed to write to storage" in error_message assert "primaryStorage" in error_message assert "writeStatus" in error_message assert "'ERROR'" in error_message # Verify trace was actually called trace_calls = [ call for call in mock_emit.call_args_list if "trace/write" in call[0][0] ] assert len(trace_calls) > 0 assert "00063609cb934b9d0d4e6a7d6d5e1234" in trace_calls[0][0][0] def test_await_status_empty_trace_data(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: openapi_emitter._await_status([], timedelta(seconds=10)) assert not mock_emit._emit_generic.called def test_await_status_successful_completion(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: trace = TraceData( trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]} ) mock_response = Mock( json=lambda: { "urn:li:dataset:test": { "status": { "success": True, "primaryStorage": {"writeStatus": "ACTIVE_STATE"}, "searchStorage": {"writeStatus": "TRACE_NOT_IMPLEMENTED"}, } } } ) mock_emit.return_value = mock_response openapi_emitter._await_status([trace], timedelta(seconds=10)) assert not trace.data # Should be empty after successful completion def test_await_status_timeout(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: trace = TraceData( trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]} ) mock_response = Mock() mock_response.json.return_value = { "urn:li:dataset:test": { "status": { "success": True, "primaryStorage": {"writeStatus": "PENDING"}, "searchStorage": {"writeStatus": "PENDING"}, } } } mock_emit.return_value = mock_response with pytest.raises(TraceTimeoutError) as exc_info: openapi_emitter._await_status([trace], timedelta(seconds=0.1)) assert "Timeout waiting for async write completion" in str(exc_info.value) def test_await_status_persistence_failure(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: trace = TraceData( trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]} ) mock_response = Mock() mock_response.json.return_value = { "urn:li:dataset:test": { "status": { "success": False, "primaryStorage": {"writeStatus": "ERROR"}, "searchStorage": {"writeStatus": "PENDING"}, } } } mock_emit.return_value = mock_response with pytest.raises(TraceValidationError) as exc_info: openapi_emitter._await_status([trace], timedelta(seconds=10)) assert "Persistence failure" in str(exc_info.value) def test_await_status_multiple_aspects(self, openapi_emitter): with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: trace = TraceData( trace_id="test-trace-id", data={"urn:li:dataset:test": ["status", "schema"]}, ) mock_response = Mock() mock_response.json.return_value = { "urn:li:dataset:test": { "status": { "success": True, "primaryStorage": {"writeStatus": "ACTIVE_STATE"}, "searchStorage": {"writeStatus": "ACTIVE_STATE"}, }, "schema": { "success": True, "primaryStorage": {"writeStatus": "HISTORIC_STATE"}, "searchStorage": {"writeStatus": "NO_OP"}, }, } } mock_emit.return_value = mock_response openapi_emitter._await_status([trace], timedelta(seconds=10)) assert not trace.data def test_await_status_logging(self, openapi_emitter): with patch.object(logger, "debug") as mock_debug, patch.object( logger, "error" ) as mock_error, patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: # Test empty trace data logging openapi_emitter._await_status([], timedelta(seconds=10)) mock_debug.assert_called_once_with("No trace data to verify") # Test error logging trace = TraceData(trace_id="test-id", data={"urn:test": ["status"]}) mock_emit.side_effect = TraceValidationError("Test error") with pytest.raises(TraceValidationError): openapi_emitter._await_status([trace], timedelta(seconds=10)) mock_error.assert_called_once() def test_openapi_emitter_same_url_different_methods(self, openapi_emitter): """Test handling of requests with same URL but different HTTP methods""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit: items = [ # POST requests for updating MetadataChangeProposalWrapper( entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,UpdateMe{i},PROD)", entityType="dataset", aspectName="datasetProperties", changeType=ChangeTypeClass.UPSERT, aspect=DatasetProperties(name=f"Updated Dataset {i}"), ) for i in range(2) ] + [ # PATCH requests for fetching next( iter( DatasetPatchBuilder( f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)" ) .set_qualified_name(f"PatchMe{i}") .build() ) ) for i in range(2) ] # Run the test result = openapi_emitter.emit_mcps(items) # Verify that we made 2 calls (one for each HTTP method) assert result == 2 assert mock_emit.call_count == 2 # Check that calls were made with different methods but the same URL calls = {} for call in mock_emit.call_args_list: method = call[1]["method"] url = call[0][0] calls[(method, url)] = call assert ( "post", f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false", ) in calls assert ( "patch", f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false", ) in calls def test_openapi_emitter_mixed_method_chunking(self, openapi_emitter): """Test that chunking works correctly across different HTTP methods""" with patch( "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic" ) as mock_emit, patch( "datahub.emitter.rest_emitter.BATCH_INGEST_MAX_PAYLOAD_LENGTH", 2 ): # Create more items than the chunk size for each method items = [ # POST items (4 items, should create 2 chunks) MetadataChangeProposalWrapper( entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)", entityType="dataset", aspectName="datasetProfile", changeType=ChangeTypeClass.UPSERT, aspect=DatasetProfile( rowCount=i, columnCount=15, timestampMillis=0 ), ) for i in range(4) ] + [ # PATCH items (3 items, should create 2 chunks) next( iter( DatasetPatchBuilder( f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)" ) .set_qualified_name(f"PatchMe{i}") .build() ) ) for i in range(3) ] # Run the test with a smaller chunk size to force multiple chunks result = openapi_emitter.emit_mcps(items) # Should have 4 chunks total: # - 2 chunks for POST (4 items with max 2 per chunk) # - 2 chunks for PATCH (3 items with max 2 per chunk) assert result == 4 assert mock_emit.call_count == 4 # Count the calls by method and verify chunking post_calls = [ call for call in mock_emit.call_args_list if call[1]["method"] == "post" ] patch_calls = [ call for call in mock_emit.call_args_list if call[1]["method"] == "patch" ] assert len(post_calls) == 2 # 2 chunks for POST assert len(patch_calls) == 2 # 2 chunks for PATCH # Verify first chunks have max size and last chunks have remainders post_payloads = [json.loads(call[1]["payload"]) for call in post_calls] patch_payloads = [json.loads(call[1]["payload"]) for call in patch_calls] assert len(post_payloads[0]) == 2 assert len(post_payloads[1]) == 2 assert len(patch_payloads[0]) == 2 assert len(patch_payloads[1]) == 1 # Verify all post calls are to the dataset endpoint for call in post_calls: assert ( call[0][0] == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false" ) # Verify all patch calls are to the dataset endpoint for call in patch_calls: assert ( call[0][0] == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false" ) def test_openapi_sync_full_emit_mode(self, openapi_emitter): """Test that SYNC_WAIT emit mode correctly sets async=false URL parameter and sync header""" # Create a test MCP test_mcp = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(test,sync_full,PROD)", aspect=DatasetProfile( rowCount=500, columnCount=10, timestampMillis=1626995099686, ), ) # Test with SYNC_WAIT emit mode with patch.object(openapi_emitter, "_emit_generic") as mock_emit: # Configure mock to return a simple response mock_response = Mock(spec=Response) mock_response.status_code = 200 mock_response.headers = {} mock_emit.return_value = mock_response # Call emit_mcp with SYNC_WAIT mode openapi_emitter.emit_mcp(test_mcp, emit_mode=EmitMode.SYNC_WAIT) # Verify _emit_generic was called mock_emit.assert_called_once() # Check URL contains async=false url = mock_emit.call_args[0][0] assert "async=false" in url # Check payload contains the synchronization header payload = mock_emit.call_args[1]["payload"] # Convert payload to dict if it's a JSON string if isinstance(payload, str): payload = json.loads(payload) # Verify the headers in the payload assert isinstance(payload, list) assert len(payload) > 0 assert "headers" in payload[0]["datasetProfile"] assert ( payload[0]["datasetProfile"]["headers"].get( "X-DataHub-Sync-Index-Update" ) == "true" ) class TestOpenApiModeSelection: def test_sdk_client_mode_no_env_var(self, mock_session, mock_response): """Test that SDK client mode defaults to OpenAPI when no env var is present""" mock_session.get.return_value = mock_response # Ensure no env vars with patch.dict(os.environ, {}, clear=True): emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK) emitter._session = mock_session emitter.test_connection() assert emitter._openapi_ingestion is True def test_non_sdk_client_mode_no_env_var(self, mock_session, mock_response): """Test that non-SDK client modes default to RestLi when no env var is present""" mock_session.get.return_value = mock_response # Ensure no env vars with patch.dict(os.environ, {}, clear=True): # Test INGESTION mode emitter = DataHubRestEmitter( MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION ) emitter._session = mock_session emitter.test_connection() assert emitter._openapi_ingestion is False # Test CLI mode emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.CLI) emitter._session = mock_session emitter.test_connection() assert emitter._openapi_ingestion is False def test_env_var_restli_overrides_sdk_mode(self, mock_session, mock_response): """Test that env var set to RESTLI overrides SDK client mode default""" mock_session.get.return_value = mock_response with patch.dict( os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "RESTLI"}, clear=True ), patch( "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT", RestSinkEndpoint.RESTLI, ): emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK) emitter._session = mock_session emitter.test_connection() assert emitter._openapi_ingestion is False def test_env_var_openapi_any_client_mode(self, mock_session, mock_response): """Test that env var set to OPENAPI enables OpenAPI for any client mode""" mock_session.get.return_value = mock_response with patch.dict( os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "OPENAPI"}, clear=True ), patch( "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT", RestSinkEndpoint.OPENAPI, ): # Test INGESTION mode emitter = DataHubRestEmitter( MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION ) emitter._session = mock_session emitter.test_connection() assert emitter._openapi_ingestion is True # Test CLI mode emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.CLI) emitter._session = mock_session emitter.test_connection() assert emitter._openapi_ingestion is True # Test SDK mode emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK) emitter._session = mock_session emitter.test_connection() assert emitter._openapi_ingestion is True def test_constructor_param_true_overrides_all(self): """Test that explicit constructor parameter True overrides all other settings""" with patch.dict( os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "RESTLI"}, clear=True ), patch( "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT", RestSinkEndpoint.RESTLI, ): # Even with env var and non-SDK mode, constructor param should win emitter = DataHubRestEmitter( MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION, openapi_ingestion=True, ) assert emitter._openapi_ingestion is True def test_constructor_param_false_overrides_all(self): """Test that explicit constructor parameter False overrides all other settings""" with patch.dict( os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "OPENAPI"}, clear=True ), patch( "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT", RestSinkEndpoint.OPENAPI, ): # Even with env var and SDK mode, constructor param should win emitter = DataHubRestEmitter( MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK, openapi_ingestion=False ) assert emitter._openapi_ingestion is False def test_debug_logging(self, mock_session, mock_response): """Test that debug logging is called with correct protocol information""" mock_session.get.return_value = mock_response with patch("datahub.emitter.rest_emitter.logger") as mock_logger: # Test OpenAPI logging emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True) emitter._session = mock_session emitter.test_connection() mock_logger.debug.assert_any_call("Using OpenAPI for ingestion.") # Test RestLi logging mock_logger.reset_mock() emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=False) emitter._session = mock_session emitter.test_connection() mock_logger.debug.assert_any_call("Using Restli for ingestion.") class TestOpenApiIntegration: def test_sdk_mode_uses_openapi_by_default(self, mock_session, mock_response): """Test that SDK mode uses OpenAPI by default for emit_mcp""" mock_session.get.return_value = mock_response with patch.dict("os.environ", {}, clear=True): emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK) emitter._session = mock_session emitter.test_connection() # Verify _openapi_ingestion was set correctly assert emitter._openapi_ingestion is True # Create test MCP test_mcp = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(test,sdk,PROD)", aspect=Status(removed=False), ) # Mock _emit_generic to inspect what URL is used with patch.object(emitter, "_emit_generic") as mock_emit: emitter.emit_mcp(test_mcp) # Check that OpenAPI URL format was used mock_emit.assert_called_once() url = mock_emit.call_args[0][0] assert "openapi" in url assert url.startswith(f"{MOCK_GMS_ENDPOINT}/openapi") def test_ingestion_mode_uses_restli_by_default(self, mock_session, mock_response): """Test that INGESTION mode uses RestLi by default for emit_mcp""" mock_session.get.return_value = mock_response with patch.dict("os.environ", {}, clear=True): emitter = DataHubRestEmitter( MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION ) emitter._session = mock_session emitter.test_connection() # Verify _openapi_ingestion was set correctly assert emitter._openapi_ingestion is False # Create test MCP test_mcp = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(test,ingestion,PROD)", aspect=Status(removed=False), ) # Mock _emit_generic to inspect what URL is used with patch.object(emitter, "_emit_generic") as mock_emit: emitter.emit_mcp(test_mcp) # Check that RestLi URL format was used (not OpenAPI) mock_emit.assert_called_once() url = mock_emit.call_args[0][0] assert "openapi" not in url assert "aspects?action=ingestProposal" in url def test_explicit_openapi_parameter_uses_openapi_api(self): """Test that explicit openapi_ingestion=True uses OpenAPI regardless of mode""" emitter = DataHubRestEmitter( MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION, # Would normally use RestLi openapi_ingestion=True, # Override to use OpenAPI ) # Verify _openapi_ingestion was set correctly assert emitter._openapi_ingestion is True # Create test MCP test_mcp = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(test,explicit,PROD)", aspect=DatasetProfile(rowCount=100, columnCount=10, timestampMillis=0), ) # Mock _emit_generic to inspect what URL is used with patch.object(emitter, "_emit_generic") as mock_emit: emitter.emit_mcp(test_mcp) # Check that OpenAPI URL format was used mock_emit.assert_called_once() url = mock_emit.call_args[0][0] assert "openapi" in url # Verify the payload is formatted for OpenAPI payload = mock_emit.call_args[1]["payload"] assert isinstance(payload, list) or ( isinstance(payload, str) and payload.startswith("[") ) def test_openapi_batch_endpoint_selection(self): """Test that OpenAPI batch operations use correct endpoints based on entity type""" emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True) # Create MCPs with different entity types dataset_mcp = MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(test,batch1,PROD)", entityType="dataset", aspect=Status(removed=False), ) dashboard_mcp = MetadataChangeProposalWrapper( entityUrn="urn:li:dashboard:(test,batch2)", entityType="dashboard", aspect=Status(removed=False), ) # Mock _emit_generic to inspect URLs used with patch.object(emitter, "_emit_generic") as mock_emit: # Configure mock to return appropriate responses mock_response = Mock() mock_response.status_code = 200 mock_response.headers = {} mock_response.json.return_value = [] mock_emit.return_value = mock_response # Emit batch of different entity types emitter.emit_mcps([dataset_mcp, dashboard_mcp]) # Verify we made two calls to different endpoints assert mock_emit.call_count == 2 # Get the URLs from the calls calls = mock_emit.call_args_list urls = [call[0][0] for call in calls] # Verify we called both entity endpoints assert any("entity/dataset" in url for url in urls) assert any("entity/dashboard" in url for url in urls) class TestRequestsSessionConfig: def test_get_client_mode_from_session(self): """Test extracting ClientMode from session headers with various inputs.""" # Case 1: Session with valid ClientMode in headers (string) session = Session() session.headers.update({"X-DataHub-Client-Mode": "SDK"}) mode = RequestsSessionConfig.get_client_mode_from_session(session) assert mode == ClientMode.SDK # Case 2: Session with valid ClientMode in headers (bytes) session = Session() session.headers.update({"X-DataHub-Client-Mode": b"INGESTION"}) mode = RequestsSessionConfig.get_client_mode_from_session(session) assert mode == ClientMode.INGESTION # Case 3: Session with no ClientMode header session = Session() mode = RequestsSessionConfig.get_client_mode_from_session(session) assert mode is None # Case 4: Session with invalid ClientMode value session = Session() session.headers.update({"X-DataHub-Client-Mode": "INVALID_MODE"}) mode = RequestsSessionConfig.get_client_mode_from_session(session) assert mode is None # Case 5: Session with empty ClientMode value session = Session() session.headers.update({"X-DataHub-Client-Mode": ""}) mode = RequestsSessionConfig.get_client_mode_from_session(session) assert mode is None # Case 6: Test with exception during processing mock_session = Mock() mock_session.headers = { "X-DataHub-Client-Mode": object() } # Will cause error when decoded mode = RequestsSessionConfig.get_client_mode_from_session(mock_session) assert mode is None # Case 7: Different capitalization should not match session = Session() session.headers.update({"X-DataHub-Client-Mode": "sdk"}) # lowercase mode = RequestsSessionConfig.get_client_mode_from_session(session) assert mode is None # Case 8: Test with all available ClientMode values for client_mode in ClientMode: session = Session() session.headers.update({"X-DataHub-Client-Mode": client_mode.name}) mode = RequestsSessionConfig.get_client_mode_from_session(session) assert mode == client_mode