# datahub/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py

import json
import os
from datetime import timedelta
from unittest.mock import ANY, Mock, patch
import pytest
from requests import Response, Session
from datahub.configuration.common import (
ConfigurationError,
TraceTimeoutError,
TraceValidationError,
)
from datahub.emitter import rest_emitter
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.response_helper import TraceData
from datahub.emitter.rest_emitter import (
BATCH_INGEST_MAX_PAYLOAD_LENGTH,
INGEST_MAX_PAYLOAD_BYTES,
DataHubRestEmitter,
DatahubRestEmitter,
RequestsSessionConfig,
RestSinkEndpoint,
logger,
)
from datahub.errors import APITracingWarning
from datahub.ingestion.graph.config import ClientMode
from datahub.metadata.com.linkedin.pegasus2avro.common import (
Status,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
DatasetProfile,
DatasetProperties,
)
from datahub.metadata.schema_classes import (
ChangeTypeClass,
)
from datahub.specific.dataset import DatasetPatchBuilder
from datahub.utilities.server_config_util import RestServiceConfig
MOCK_GMS_ENDPOINT = "http://fakegmshost:8080"
class TestDataHubRestEmitter:
@pytest.fixture
def openapi_emitter(self) -> DataHubRestEmitter:
return DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True)
def test_datahub_rest_emitter_missing_gms_server(self):
"""Test that emitter raises ConfigurationError when gms_server is not provided."""
# An empty gms_server string should be rejected
with pytest.raises(ConfigurationError) as excinfo:
DataHubRestEmitter("")
assert "gms server is required" in str(excinfo.value)
def test_connection_error_reraising(self):
"""Test that test_connection() properly re-raises ConfigurationError."""
# Create a basic emitter
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT)
# Mock RestServiceConfig.fetch_config to raise ConfigurationError
with patch.object(RestServiceConfig, "fetch_config") as mock_fetch_config:
# Configure the mock to raise ConfigurationError
mock_error = ConfigurationError("Connection failed")
mock_fetch_config.side_effect = mock_error
# Verify that the exception is re-raised
with pytest.raises(ConfigurationError) as excinfo:
emitter.test_connection()
# Verify it's the same exception object (not a new one)
assert excinfo.value is mock_error
def test_datahub_rest_emitter_construction(self) -> None:
emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT)
assert emitter._session_config.timeout == rest_emitter._DEFAULT_TIMEOUT_SEC
assert (
emitter._session_config.retry_status_codes
== rest_emitter._DEFAULT_RETRY_STATUS_CODES
)
assert (
emitter._session_config.retry_max_times
== rest_emitter._DEFAULT_RETRY_MAX_TIMES
)
def test_datahub_rest_emitter_timeout_construction(self) -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT, connect_timeout_sec=2, read_timeout_sec=4
)
assert emitter._session_config.timeout == (2, 4)
def test_datahub_rest_emitter_general_timeout_construction(self) -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT, timeout_sec=2, read_timeout_sec=4
)
assert emitter._session_config.timeout == (2, 4)
def test_datahub_rest_emitter_retry_construction(self) -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT,
retry_status_codes=[418],
retry_max_times=42,
)
assert emitter._session_config.retry_status_codes == [418]
assert emitter._session_config.retry_max_times == 42
def test_datahub_rest_emitter_extra_params(self) -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT, extra_headers={"key1": "value1", "key2": "value2"}
)
assert emitter._session.headers.get("key1") == "value1"
assert emitter._session.headers.get("key2") == "value2"
def test_openapi_emitter_emit(self, openapi_emitter):
item = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
aspect=DatasetProfile(
rowCount=2000,
columnCount=15,
timestampMillis=1626995099686,
),
)
with patch.object(openapi_emitter, "_emit_generic") as mock_method:
openapi_emitter.emit_mcp(item)
mock_method.assert_called_once_with(
f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false",
payload=[
{
"urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
"datasetProfile": {
"value": {
"rowCount": 2000,
"columnCount": 15,
"timestampMillis": 1626995099686,
"partitionSpec": {
"partition": "FULL_TABLE_SNAPSHOT",
"type": "FULL_TABLE",
},
},
"systemMetadata": {
"lastObserved": ANY,
"runId": "no-run-id-provided",
"lastRunId": "no-run-id-provided",
"properties": {
"clientId": "acryl-datahub",
"clientVersion": "1!0.0.0.dev0",
},
},
},
}
],
method="post",
)
def test_openapi_emitter_emit_mcps(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
items = [
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)",
aspect=DatasetProfile(
rowCount=2000 + i,
columnCount=15,
timestampMillis=1626995099686,
),
)
for i in range(3)
]
result = openapi_emitter.emit_mcps(items)
assert result == 1
# Single chunk test - all items should be in one batch
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert (
call_args[0][0]
== f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
)
assert isinstance(call_args[1]["payload"], str) # Should be JSON string
def test_openapi_emitter_emit_mcps_max_bytes(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Create a large payload that will force chunking
large_payload = "x" * (
INGEST_MAX_PAYLOAD_BYTES // 2
) # Each item will be about half max size
items = [
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,LargePayload{i},PROD)",
aspect=DatasetProperties(name=large_payload),
)
for i in range(
3
) # This should create at least 2 chunks given the payload size
]
openapi_emitter.emit_mcps(items)
# Verify multiple chunks were created
assert mock_emit.call_count > 1
# Verify each chunk's payload is within size limits
for call in mock_emit.call_args_list:
args = call[1]
assert "payload" in args
payload = args["payload"]
assert isinstance(payload, str) # Should be JSON string
assert len(payload.encode()) <= INGEST_MAX_PAYLOAD_BYTES
# Verify the payload structure
payload_data = json.loads(payload)
assert payload_data[0]["urn"].startswith("urn:li:dataset")
assert "datasetProperties" in payload_data[0]
def test_openapi_emitter_emit_mcps_max_items(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Create more items than BATCH_INGEST_MAX_PAYLOAD_LENGTH
items = [
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Item{i},PROD)",
aspect=DatasetProfile(
rowCount=i,
columnCount=15,
timestampMillis=1626995099686,
),
)
for i in range(
BATCH_INGEST_MAX_PAYLOAD_LENGTH + 2
) # Create 2 more than max
]
openapi_emitter.emit_mcps(items)
# Verify multiple chunks were created
assert mock_emit.call_count == 2
# Check first chunk
first_call = mock_emit.call_args_list[0]
first_payload = json.loads(first_call[1]["payload"])
assert len(first_payload) == BATCH_INGEST_MAX_PAYLOAD_LENGTH
# Check second chunk
second_call = mock_emit.call_args_list[1]
second_payload = json.loads(second_call[1]["payload"])
assert len(second_payload) == 2 # Should have the remaining 2 items
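# Worked out, the chunk-count arithmetic the two assertions above encode:
# n = BATCH_INGEST_MAX_PAYLOAD_LENGTH + 2 items chunked by count yields
# ceil(n / BATCH_INGEST_MAX_PAYLOAD_LENGTH) == 2 requests: one full chunk
# plus one holding the remaining 2 items.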
def test_openapi_emitter_emit_mcps_multiple_entity_types(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Create items for two different entity types
dataset_items = [
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)",
aspect=DatasetProfile(
rowCount=i,
columnCount=15,
timestampMillis=1626995099686,
),
)
for i in range(2)
]
dashboard_items = [
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dashboard:(looker,dashboards.{i})",
aspect=Status(removed=False),
)
for i in range(2)
]
items = dataset_items + dashboard_items
result = openapi_emitter.emit_mcps(items)
# Should return number of unique entity URLs
assert result == 2
assert mock_emit.call_count == 2
# Check that calls were made with different URLs but correct payloads
calls = {
call[0][0]: json.loads(call[1]["payload"])
for call in mock_emit.call_args_list
}
# Verify each URL got the right aspects
for url, payload in calls.items():
if "datasetProfile" in payload[0]:
assert url.endswith("dataset?async=true")
assert len(payload) == 2
assert all("datasetProfile" in item for item in payload)
else:
assert url.endswith("dashboard?async=true")
assert len(payload) == 2
assert all("status" in item for item in payload)
def test_openapi_emitter_emit_mcp_with_tracing(self, openapi_emitter):
"""Test emitting a single MCP with tracing enabled"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Mock the response for the initial emit
mock_response = Mock(spec=Response)
mock_response.status_code = 200
mock_response.headers = {"traceparent": "test-trace-123"}
mock_response.json.return_value = [
{
"urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
"datasetProfile": {},
}
]
mock_emit.return_value = mock_response
# Create test item
item = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
aspect=DatasetProfile(
rowCount=2000,
columnCount=15,
timestampMillis=1626995099686,
),
)
# Set up mock for trace verification responses
trace_responses = [
# First check - pending
{
"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": {
"datasetProfile": {
"success": True,
"primaryStorage": {"writeStatus": "PENDING"},
"searchStorage": {"writeStatus": "PENDING"},
}
}
},
# Second check - completed
{
"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": {
"datasetProfile": {
"success": True,
"primaryStorage": {"writeStatus": "ACTIVE_STATE"},
"searchStorage": {"writeStatus": "ACTIVE_STATE"},
}
}
},
]
def side_effect(*args, **kwargs):
if "trace/write" in args[0]:
mock_trace_response = Mock(spec=Response)
mock_trace_response.json.return_value = trace_responses.pop(0)
return mock_trace_response
return mock_response
mock_emit.side_effect = side_effect
# Emit with tracing enabled
openapi_emitter.emit_mcp(
item,
async_flag=True,
trace_flag=True,
trace_timeout=timedelta(seconds=10),
)
# Verify initial emit call
assert (
mock_emit.call_args_list[0][0][0]
== f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
)
# Verify trace verification calls
trace_calls = [
call for call in mock_emit.call_args_list if "trace/write" in call[0][0]
]
assert len(trace_calls) == 2
assert "test-trace-123" in trace_calls[0][0][0]
def test_openapi_emitter_emit_mcps_with_tracing(self, openapi_emitter):
"""Test emitting multiple MCPs with tracing enabled"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Create test items
items = [
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)",
aspect=Status(removed=False),
)
for i in range(2)
]
# Mock responses for initial emit
emit_responses = []
mock_resp = Mock(spec=Response)
mock_resp.status_code = 200
mock_resp.headers = {"traceparent": "test-trace"}
mock_resp.json.return_value = [
{
"urn": f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)",
"status": {"removed": False},
}
for i in range(2)
]
emit_responses.append(mock_resp)
# Mock trace verification responses
trace_responses = [
# First check - all pending
{
f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)": {
"status": {
"success": True,
"primaryStorage": {"writeStatus": "PENDING"},
"searchStorage": {"writeStatus": "PENDING"},
}
}
for i in range(2)
},
# Second check - all completed
{
f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)": {
"status": {
"success": True,
"primaryStorage": {"writeStatus": "ACTIVE_STATE"},
"searchStorage": {"writeStatus": "ACTIVE_STATE"},
}
}
for i in range(2)
},
]
response_iter = iter(emit_responses)
trace_response_iter = iter(trace_responses)
def side_effect(*args, **kwargs):
if "trace/write" in args[0]:
mock_trace_response = Mock(spec=Response)
mock_trace_response.json.return_value = next(trace_response_iter)
return mock_trace_response
return next(response_iter)
mock_emit.side_effect = side_effect
# Emit with tracing enabled
result = openapi_emitter.emit_mcps(
items,
async_flag=True,
trace_flag=True,
trace_timeout=timedelta(seconds=10),
)
assert result == 1 # Should return number of unique entity URLs
# Verify initial emit calls
emit_calls = [
call
for call in mock_emit.call_args_list
if "entity/dataset" in call[0][0]
]
assert len(emit_calls) == 1
assert (
emit_calls[0][0][0]
== f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
)
# Verify trace verification calls
trace_calls = [
call for call in mock_emit.call_args_list if "trace/write" in call[0][0]
]
assert len(trace_calls) == 2
def test_openapi_emitter_trace_timeout(self, openapi_emitter):
"""Test that tracing properly handles timeouts"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Mock initial emit response
mock_response = Mock(spec=Response)
mock_response.status_code = 200
mock_response.headers = {"traceparent": "test-trace-123"}
mock_response.json.return_value = [
{
"urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
"datasetProfile": {},
}
]
# Mock trace verification to always return pending
mock_trace_response = Mock(spec=Response)
mock_trace_response.json.return_value = {
"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": {
"datasetProfile": {
"success": True,
"primaryStorage": {"writeStatus": "PENDING"},
"searchStorage": {"writeStatus": "PENDING"},
}
}
}
def side_effect(*args, **kwargs):
return (
mock_trace_response if "trace/write" in args[0] else mock_response
)
mock_emit.side_effect = side_effect
item = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
aspect=DatasetProfile(
rowCount=2000,
columnCount=15,
timestampMillis=1626995099686,
),
)
# Emit with very short timeout
with pytest.raises(TraceTimeoutError) as exc_info:
openapi_emitter.emit_mcp(
item,
async_flag=True,
trace_flag=True,
trace_timeout=timedelta(milliseconds=1),
)
assert "Timeout waiting for async write completion" in str(exc_info.value)
def test_openapi_emitter_missing_trace_header(self, openapi_emitter):
"""Test behavior when trace header is missing"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Mock response without trace header
mock_response = Mock(spec=Response)
mock_response.status_code = 200
mock_response.headers = {} # No traceparent header
mock_response.json.return_value = [
{
"urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,MissingTraceHeader,PROD)",
"status": {"removed": False},
}
]
mock_emit.return_value = mock_response
item = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,MissingTraceHeader,PROD)",
aspect=Status(removed=False),
)
# Should not raise an exception; an APITracingWarning is emitted instead.
with pytest.warns(APITracingWarning):
openapi_emitter.emit_mcp(
item,
async_flag=True,
trace_flag=True,
trace_timeout=timedelta(seconds=10),
)
def test_openapi_emitter_invalid_status_code(self, openapi_emitter):
"""Test behavior when response has non-200 status code"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Mock response with error status code
mock_response = Mock(spec=Response)
mock_response.status_code = 500
mock_response.headers = {"traceparent": "test-trace-123"}
mock_response.json.return_value = [
{
"urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,InvalidStatusCode,PROD)",
"datasetProfile": {},
}
]
mock_emit.return_value = mock_response
item = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,InvalidStatusCode,PROD)",
aspect=DatasetProfile(
rowCount=2000,
columnCount=15,
timestampMillis=1626995099686,
),
)
# Should not raise an exception, only log a warning
openapi_emitter.emit_mcp(
item,
async_flag=True,
trace_flag=True,
trace_timeout=timedelta(seconds=10),
)
def test_openapi_emitter_trace_failure(self, openapi_emitter):
"""Test handling of trace verification failures"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
test_urn = "urn:li:dataset:(urn:li:dataPlatform:mysql,TraceFailure,PROD)"
# Create initial emit response
emit_response = Mock(spec=Response)
emit_response.status_code = 200
emit_response.headers = {"traceparent": "test-trace-123"}
emit_response.json.return_value = [{"urn": test_urn, "datasetProfile": {}}]
# Create trace verification response
trace_response = Mock(spec=Response)
trace_response.status_code = 200
trace_response.headers = {}
trace_response.json.return_value = {
test_urn: {
"datasetProfile": {
"success": False,
"error": "Failed to write to storage",
"primaryStorage": {"writeStatus": "ERROR"},
"searchStorage": {"writeStatus": "ERROR"},
}
}
}
def side_effect(*args, **kwargs):
if "trace/write" in args[0]:
return trace_response
return emit_response
mock_emit.side_effect = side_effect
item = MetadataChangeProposalWrapper(
entityUrn=test_urn,
aspect=DatasetProfile(
rowCount=2000,
columnCount=15,
timestampMillis=1626995099686,
),
)
with pytest.raises(TraceValidationError) as exc_info:
openapi_emitter.emit_mcp(
item,
async_flag=True,
trace_flag=True,
trace_timeout=timedelta(seconds=10),
)
assert "Unable to validate async write to DataHub GMS" in str(
exc_info.value
)
# Verify the error details are included
assert "Failed to write to storage" in str(exc_info.value)
# Verify trace was actually called
trace_calls = [
call for call in mock_emit.call_args_list if "trace/write" in call[0][0]
]
assert len(trace_calls) > 0
assert "test-trace-123" in trace_calls[0][0][0]
def test_await_status_empty_trace_data(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
openapi_emitter._await_status([], timedelta(seconds=10))
mock_emit.assert_not_called()
def test_await_status_successful_completion(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
trace = TraceData(
trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]}
)
mock_response = Mock(
json=lambda: {
"urn:li:dataset:test": {
"status": {
"success": True,
"primaryStorage": {"writeStatus": "ACTIVE_STATE"},
"searchStorage": {"writeStatus": "TRACE_NOT_IMPLEMENTED"},
}
}
}
)
mock_emit.return_value = mock_response
openapi_emitter._await_status([trace], timedelta(seconds=10))
assert not trace.data # Should be empty after successful completion
def test_await_status_timeout(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
trace = TraceData(
trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]}
)
mock_response = Mock()
mock_response.json.return_value = {
"urn:li:dataset:test": {
"status": {
"success": True,
"primaryStorage": {"writeStatus": "PENDING"},
"searchStorage": {"writeStatus": "PENDING"},
}
}
}
mock_emit.return_value = mock_response
with pytest.raises(TraceTimeoutError) as exc_info:
openapi_emitter._await_status([trace], timedelta(seconds=0.1))
assert "Timeout waiting for async write completion" in str(exc_info.value)
def test_await_status_persistence_failure(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
trace = TraceData(
trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]}
)
mock_response = Mock()
mock_response.json.return_value = {
"urn:li:dataset:test": {
"status": {
"success": False,
"primaryStorage": {"writeStatus": "ERROR"},
"searchStorage": {"writeStatus": "PENDING"},
}
}
}
mock_emit.return_value = mock_response
with pytest.raises(TraceValidationError) as exc_info:
openapi_emitter._await_status([trace], timedelta(seconds=10))
assert "Persistence failure" in str(exc_info.value)
def test_await_status_multiple_aspects(self, openapi_emitter):
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
trace = TraceData(
trace_id="test-trace-id",
data={"urn:li:dataset:test": ["status", "schema"]},
)
mock_response = Mock()
mock_response.json.return_value = {
"urn:li:dataset:test": {
"status": {
"success": True,
"primaryStorage": {"writeStatus": "ACTIVE_STATE"},
"searchStorage": {"writeStatus": "ACTIVE_STATE"},
},
"schema": {
"success": True,
"primaryStorage": {"writeStatus": "HISTORIC_STATE"},
"searchStorage": {"writeStatus": "NO_OP"},
},
}
}
mock_emit.return_value = mock_response
openapi_emitter._await_status([trace], timedelta(seconds=10))
assert not trace.data
def test_await_status_logging(self, openapi_emitter):
with patch.object(logger, "debug") as mock_debug, patch.object(
logger, "error"
) as mock_error, patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
# Test empty trace data logging
openapi_emitter._await_status([], timedelta(seconds=10))
mock_debug.assert_called_once_with("No trace data to verify")
# Test error logging
trace = TraceData(trace_id="test-id", data={"urn:test": ["status"]})
mock_emit.side_effect = TraceValidationError("Test error")
with pytest.raises(TraceValidationError):
openapi_emitter._await_status([trace], timedelta(seconds=10))
mock_error.assert_called_once()
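# A hedged, illustrative restatement of the polling contract that the
# _await_status tests above pin down. This is NOT the emitter's real
# implementation: the endpoint path is assumed from the "trace/write" URLs
# asserted earlier, and the terminal write states are inferred from the
# statuses that let these tests complete.
def _poll_trace_until_complete(emitter, trace_id, aspects_by_urn, timeout):
    import time

    terminal = {"ACTIVE_STATE", "HISTORIC_STATE", "NO_OP", "TRACE_NOT_IMPLEMENTED"}
    deadline = time.monotonic() + timeout.total_seconds()
    while aspects_by_urn:
        response = emitter._emit_generic(
            f"{MOCK_GMS_ENDPOINT}/openapi/v3/trace/write/{trace_id}",
            payload=json.dumps(aspects_by_urn),
        )
        for urn, aspect_statuses in response.json().items():
            for aspect, status in aspect_statuses.items():
                if not status.get("success", False):
                    # Mirrors the "Persistence failure" path tested above.
                    raise TraceValidationError(f"Write failed: {urn} / {aspect}")
                if (
                    status["primaryStorage"]["writeStatus"] in terminal
                    and status["searchStorage"]["writeStatus"] in terminal
                    and aspect in aspects_by_urn.get(urn, [])
                ):
                    aspects_by_urn[urn].remove(aspect)
        aspects_by_urn = {urn: a for urn, a in aspects_by_urn.items() if a}
        if aspects_by_urn and time.monotonic() > deadline:
            raise TraceTimeoutError("Timeout waiting for async write completion")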
def test_openapi_emitter_same_url_different_methods(self, openapi_emitter):
"""Test handling of requests with same URL but different HTTP methods"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
items = [
# POST requests for updating
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,UpdateMe{i},PROD)",
entityType="dataset",
aspectName="datasetProperties",
changeType=ChangeTypeClass.UPSERT,
aspect=DatasetProperties(name=f"Updated Dataset {i}"),
)
for i in range(2)
] + [
# PATCH requests for partial updates
next(
iter(
DatasetPatchBuilder(
f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
)
.set_qualified_name(f"PatchMe{i}")
.build()
)
)
for i in range(2)
]
# Run the test
result = openapi_emitter.emit_mcps(items)
# Verify that we made 2 calls (one for each HTTP method)
assert result == 2
assert mock_emit.call_count == 2
# Check that calls were made with different methods but the same URL
calls = {}
for call in mock_emit.call_args_list:
method = call[1]["method"]
url = call[0][0]
calls[(method, url)] = call
assert (
"post",
f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true",
) in calls
assert (
"patch",
f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true",
) in calls
def test_openapi_emitter_mixed_method_chunking(self, openapi_emitter):
"""Test that chunking works correctly across different HTTP methods"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit, patch(
"datahub.emitter.rest_emitter.BATCH_INGEST_MAX_PAYLOAD_LENGTH", 2
):
# Create more items than the chunk size for each method
items = [
# POST items (4 items, should create 2 chunks)
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)",
entityType="dataset",
aspectName="datasetProfile",
changeType=ChangeTypeClass.UPSERT,
aspect=DatasetProfile(
rowCount=i, columnCount=15, timestampMillis=0
),
)
for i in range(4)
] + [
# PATCH items (3 items, should create 2 chunks)
next(
iter(
DatasetPatchBuilder(
f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
)
.set_qualified_name(f"PatchMe{i}")
.build()
)
)
for i in range(3)
]
# Run the test; BATCH_INGEST_MAX_PAYLOAD_LENGTH is patched to 2 above,
# forcing multiple chunks per method
result = openapi_emitter.emit_mcps(items)
# Should have 4 chunks total:
# - 2 chunks for POST (4 items with max 2 per chunk)
# - 2 chunks for PATCH (3 items with max 2 per chunk)
assert result == 4
assert mock_emit.call_count == 4
# Count the calls by method and verify chunking
post_calls = [
call for call in mock_emit.call_args_list if call[1]["method"] == "post"
]
patch_calls = [
call
for call in mock_emit.call_args_list
if call[1]["method"] == "patch"
]
assert len(post_calls) == 2 # 2 chunks for POST
assert len(patch_calls) == 2 # 2 chunks for PATCH
# Verify first chunks have max size and last chunks have remainders
post_payloads = [json.loads(call[1]["payload"]) for call in post_calls]
patch_payloads = [json.loads(call[1]["payload"]) for call in patch_calls]
assert len(post_payloads[0]) == 2
assert len(post_payloads[1]) == 2
assert len(patch_payloads[0]) == 2
assert len(patch_payloads[1]) == 1
# Verify all post calls are to the dataset endpoint
for call in post_calls:
assert (
call[0][0]
== f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
)
# Verify all patch calls are to the dataset endpoint
for call in patch_calls:
assert (
call[0][0]
== f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
)
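# Hedged sketch of the routing rule the two tests above pin down: MCPs are
# grouped by (HTTP method, entity endpoint URL) and each group is chunked
# independently, so POSTed upserts and PATCH partial updates against the
# same entity URL never share a request. Illustrative helper, not the
# emitter's code; it assumes patch-builder output carries
# ChangeTypeClass.PATCH.
def _route_key(mcp):
    method = "patch" if mcp.changeType == ChangeTypeClass.PATCH else "post"
    return method, f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/{mcp.entityType}?async=true"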
class TestOpenApiModeSelection:
def test_sdk_client_mode_no_env_var(self):
"""Test that SDK client mode defaults to OpenAPI when no env var is present"""
# Ensure no env vars
with patch.dict(os.environ, {}, clear=True), patch(
"datahub.emitter.rest_emitter.RestServiceConfig"
) as mock_config_class:
# Setup the mock config instance
mock_config_instance = Mock()
mock_config_instance.supports_feature.return_value = True
mock_config_class.return_value = mock_config_instance
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
emitter.test_connection()
assert emitter._openapi_ingestion is True
def test_non_sdk_client_mode_no_env_var(self):
"""Test that non-SDK client modes default to RestLi when no env var is present"""
# Ensure no env vars
with patch.dict(os.environ, {}, clear=True), patch(
"datahub.emitter.rest_emitter.RestServiceConfig"
) as mock_config_class:
# Setup the mock config instance
mock_config_instance = Mock()
mock_config_instance.supports_feature.return_value = True
mock_config_class.return_value = mock_config_instance
# Test INGESTION mode
emitter = DataHubRestEmitter(
MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION
)
emitter.test_connection()
assert emitter._openapi_ingestion is False
# Test CLI mode
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.CLI)
emitter.test_connection()
assert emitter._openapi_ingestion is False
def test_env_var_restli_overrides_sdk_mode(self):
"""Test that env var set to RESTLI overrides SDK client mode default"""
with patch.dict(
os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "RESTLI"}, clear=True
), patch(
"datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
RestSinkEndpoint.RESTLI,
), patch("datahub.emitter.rest_emitter.RestServiceConfig") as mock_config_class:
# Setup the mock config instance
mock_config_instance = Mock()
mock_config_instance.supports_feature.return_value = True
mock_config_class.return_value = mock_config_instance
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
emitter.test_connection()
assert emitter._openapi_ingestion is False
def test_env_var_openapi_any_client_mode(self):
"""Test that env var set to OPENAPI enables OpenAPI for any client mode"""
with patch.dict(
os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "OPENAPI"}, clear=True
), patch(
"datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
RestSinkEndpoint.OPENAPI,
), patch("datahub.emitter.rest_emitter.RestServiceConfig") as mock_config_class:
# Setup the mock config instance
mock_config_instance = Mock()
mock_config_instance.supports_feature.return_value = True
mock_config_class.return_value = mock_config_instance
# Test INGESTION mode
emitter = DataHubRestEmitter(
MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION
)
emitter.test_connection()
assert emitter._openapi_ingestion is True
# Test CLI mode
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.CLI)
emitter.test_connection()
assert emitter._openapi_ingestion is True
# Test SDK mode
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
emitter.test_connection()
assert emitter._openapi_ingestion is True
def test_constructor_param_true_overrides_all(self):
"""Test that explicit constructor parameter True overrides all other settings"""
with patch.dict(
os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "RESTLI"}, clear=True
), patch(
"datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
RestSinkEndpoint.RESTLI,
):
# Even with env var and non-SDK mode, constructor param should win
emitter = DataHubRestEmitter(
MOCK_GMS_ENDPOINT,
client_mode=ClientMode.INGESTION,
openapi_ingestion=True,
)
assert emitter._openapi_ingestion is True
def test_constructor_param_false_overrides_all(self):
"""Test that explicit constructor parameter False overrides all other settings"""
with patch.dict(
os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "OPENAPI"}, clear=True
), patch(
"datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
RestSinkEndpoint.OPENAPI,
):
# Even with env var and SDK mode, constructor param should win
emitter = DataHubRestEmitter(
MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK, openapi_ingestion=False
)
assert emitter._openapi_ingestion is False
def test_debug_logging(self):
"""Test that debug logging is called with correct protocol information"""
with patch("datahub.emitter.rest_emitter.logger") as mock_logger, patch(
"datahub.emitter.rest_emitter.RestServiceConfig"
) as mock_config_class:
# Setup the mock config instance
mock_config_instance = Mock()
mock_config_instance.supports_feature.return_value = True
mock_config_class.return_value = mock_config_instance
# Test OpenAPI logging
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True)
emitter.test_connection()
mock_logger.debug.assert_called_with("Using OpenAPI for ingestion.")
# Test RestLi logging
mock_logger.reset_mock()
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=False)
emitter.test_connection()
mock_logger.debug.assert_called_with("Using Restli for ingestion.")
class TestOpenApiIntegration:
def test_sdk_mode_uses_openapi_by_default(self):
"""Test that SDK mode uses OpenAPI by default for emit_mcp"""
with patch.dict("os.environ", {}, clear=True), patch(
"datahub.emitter.rest_emitter.RestServiceConfig"
) as mock_config_class:
# Setup the mock config instance
mock_config_instance = Mock()
mock_config_instance.supports_feature.return_value = True
mock_config_class.return_value = mock_config_instance
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
emitter.test_connection()
# Verify _openapi_ingestion was set correctly
assert emitter._openapi_ingestion is True
# Create test MCP
test_mcp = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(test,sdk,PROD)",
aspect=Status(removed=False),
)
# Mock _emit_generic to inspect what URL is used
with patch.object(emitter, "_emit_generic") as mock_emit:
emitter.emit_mcp(test_mcp)
# Check that OpenAPI URL format was used
mock_emit.assert_called_once()
url = mock_emit.call_args[0][0]
assert "openapi" in url
assert url.startswith(f"{MOCK_GMS_ENDPOINT}/openapi")
def test_ingestion_mode_uses_restli_by_default(self):
"""Test that INGESTION mode uses RestLi by default for emit_mcp"""
with patch.dict("os.environ", {}, clear=True), patch(
"datahub.emitter.rest_emitter.RestServiceConfig"
) as mock_config_class:
# Setup the mock config instance
mock_config_instance = Mock()
mock_config_instance.supports_feature.return_value = True
mock_config_class.return_value = mock_config_instance
emitter = DataHubRestEmitter(
MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION
)
emitter.test_connection()
# Verify _openapi_ingestion was set correctly
assert emitter._openapi_ingestion is False
# Create test MCP
test_mcp = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(test,ingestion,PROD)",
aspect=Status(removed=False),
)
# Mock _emit_generic to inspect what URL is used
with patch.object(emitter, "_emit_generic") as mock_emit:
emitter.emit_mcp(test_mcp)
# Check that RestLi URL format was used (not OpenAPI)
mock_emit.assert_called_once()
url = mock_emit.call_args[0][0]
assert "openapi" not in url
assert "aspects?action=ingestProposal" in url
def test_explicit_openapi_parameter_uses_openapi_api(self):
"""Test that explicit openapi_ingestion=True uses OpenAPI regardless of mode"""
emitter = DataHubRestEmitter(
MOCK_GMS_ENDPOINT,
client_mode=ClientMode.INGESTION, # Would normally use RestLi
openapi_ingestion=True, # Override to use OpenAPI
)
# Verify _openapi_ingestion was set correctly
assert emitter._openapi_ingestion is True
# Create test MCP
test_mcp = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(test,explicit,PROD)",
aspect=DatasetProfile(rowCount=100, columnCount=10, timestampMillis=0),
)
# Mock _emit_generic to inspect what URL is used
with patch.object(emitter, "_emit_generic") as mock_emit:
emitter.emit_mcp(test_mcp)
# Check that OpenAPI URL format was used
mock_emit.assert_called_once()
url = mock_emit.call_args[0][0]
assert "openapi" in url
# Verify the payload is formatted for OpenAPI
payload = mock_emit.call_args[1]["payload"]
assert isinstance(payload, list) or (
isinstance(payload, str) and payload.startswith("[")
)
def test_openapi_batch_endpoint_selection(self):
"""Test that OpenAPI batch operations use correct endpoints based on entity type"""
emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True)
# Create MCPs with different entity types
dataset_mcp = MetadataChangeProposalWrapper(
entityUrn="urn:li:dataset:(test,batch1,PROD)",
entityType="dataset",
aspect=Status(removed=False),
)
dashboard_mcp = MetadataChangeProposalWrapper(
entityUrn="urn:li:dashboard:(test,batch2)",
entityType="dashboard",
aspect=Status(removed=False),
)
# Mock _emit_generic to inspect URLs used
with patch.object(emitter, "_emit_generic") as mock_emit:
# Configure mock to return appropriate responses
mock_response = Mock()
mock_response.status_code = 200
mock_response.headers = {}
mock_response.json.return_value = []
mock_emit.return_value = mock_response
# Emit batch of different entity types
emitter.emit_mcps([dataset_mcp, dashboard_mcp])
# Verify we made two calls to different endpoints
assert mock_emit.call_count == 2
# Get the URLs from the calls
calls = mock_emit.call_args_list
urls = [call[0][0] for call in calls]
# Verify we called both entity endpoints
assert any("entity/dataset" in url for url in urls)
assert any("entity/dashboard" in url for url in urls)
class TestRequestsSessionConfig:
def test_get_client_mode_from_session(self):
"""Test extracting ClientMode from session headers with various inputs."""
# Case 1: Session with valid ClientMode in headers (string)
session = Session()
session.headers.update({"X-DataHub-Client-Mode": "SDK"})
mode = RequestsSessionConfig.get_client_mode_from_session(session)
assert mode == ClientMode.SDK
# Case 2: Session with valid ClientMode in headers (bytes)
session = Session()
session.headers.update({"X-DataHub-Client-Mode": b"INGESTION"})
mode = RequestsSessionConfig.get_client_mode_from_session(session)
assert mode == ClientMode.INGESTION
# Case 3: Session with no ClientMode header
session = Session()
mode = RequestsSessionConfig.get_client_mode_from_session(session)
assert mode is None
# Case 4: Session with invalid ClientMode value
session = Session()
session.headers.update({"X-DataHub-Client-Mode": "INVALID_MODE"})
mode = RequestsSessionConfig.get_client_mode_from_session(session)
assert mode is None
# Case 5: Session with empty ClientMode value
session = Session()
session.headers.update({"X-DataHub-Client-Mode": ""})
mode = RequestsSessionConfig.get_client_mode_from_session(session)
assert mode is None
# Case 6: Test with exception during processing
mock_session = Mock()
mock_session.headers = {
"X-DataHub-Client-Mode": object()
} # Will cause error when decoded
mode = RequestsSessionConfig.get_client_mode_from_session(mock_session)
assert mode is None
# Case 7: Different capitalization should not match
session = Session()
session.headers.update({"X-DataHub-Client-Mode": "sdk"}) # lowercase
mode = RequestsSessionConfig.get_client_mode_from_session(session)
assert mode is None
# Case 8: Test with all available ClientMode values
for client_mode in ClientMode:
session = Session()
session.headers.update({"X-DataHub-Client-Mode": client_mode.name})
mode = RequestsSessionConfig.get_client_mode_from_session(session)
assert mode == client_mode
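# A hedged sketch of the header-parsing contract the eight cases above pin
# down; this is NOT the real get_client_mode_from_session, just the rule it
# appears to follow: decode bytes, require an exact case-sensitive
# ClientMode member name, and map every failure to None.
def _parse_client_mode(raw):
    try:
        if isinstance(raw, bytes):
            raw = raw.decode()
        return ClientMode[raw]  # KeyError for unknown or lowercased names
    except Exception:
        return None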