import json
import os
from datetime import timedelta
from typing import Any, Dict
from unittest.mock import ANY, MagicMock, Mock, patch

import pytest
from requests import Response, Session

from datahub.configuration.common import (
    ConfigurationError,
    TraceTimeoutError,
    TraceValidationError,
)
from datahub.emitter import rest_emitter
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.response_helper import TraceData
from datahub.emitter.rest_emitter import (
    BATCH_INGEST_MAX_PAYLOAD_LENGTH,
    INGEST_MAX_PAYLOAD_BYTES,
    DataHubRestEmitter,
    DatahubRestEmitter,
    EmitMode,
    RequestsSessionConfig,
    RestSinkEndpoint,
    logger,
)
from datahub.errors import APITracingWarning
from datahub.ingestion.graph.config import ClientMode
from datahub.metadata.com.linkedin.pegasus2avro.common import (
    Status,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
    DatasetProfile,
    DatasetProperties,
)
from datahub.metadata.schema_classes import (
    ChangeTypeClass,
)
from datahub.specific.dataset import DatasetPatchBuilder
from datahub.utilities.server_config_util import RestServiceConfig

MOCK_GMS_ENDPOINT = "http://fakegmshost:8080"


@pytest.fixture
def sample_config() -> Dict[str, Any]:
    return {
        "models": {},
        "patchCapable": True,
        "versions": {
            "acryldata/datahub": {
                "version": "v1.0.1",
                "commit": "dc127c5f031d579732899ccd81a53a3514dc4a6d",
            }
        },
        "managedIngestion": {"defaultCliVersion": "1.0.0.2", "enabled": True},
        "statefulIngestionCapable": True,
        "supportsImpactAnalysis": True,
        "timeZone": "GMT",
        "telemetry": {"enabledCli": True, "enabledIngestion": False},
        "datasetUrnNameCasing": False,
        "retention": "true",
        "datahub": {"serverType": "dev"},
        "noCode": "true",
    }


@pytest.fixture
def mock_session():
    session = MagicMock(spec=Session)
    return session


@pytest.fixture
def mock_response(sample_config):
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = sample_config
    response.text = str(sample_config)
    return response


class TestDataHubRestEmitter:
    @pytest.fixture
    def openapi_emitter(self) -> DataHubRestEmitter:
        openapi_emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True)

        # Set the underlying private attribute directly
        openapi_emitter._server_config = RestServiceConfig(
            raw_config={
                "versions": {
                    "acryldata/datahub": {
                        "version": "v1.0.1rc0"  # Supports OpenAPI & tracing
                    }
                }
            }
        )

        return openapi_emitter

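    # The fixture above seeds _server_config directly rather than letting the
    # emitter fetch /config over HTTP, so no network call is needed; the
    # pinned server version (v1.0.1rc0) is one that advertises OpenAPI
    # ingestion and API tracing support, which the tests below depend on.
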
    def test_datahub_rest_emitter_missing_gms_server(self):
        """Test that emitter raises ConfigurationError when gms_server is not provided."""
        # Empty string
        with pytest.raises(ConfigurationError) as excinfo:
            DataHubRestEmitter("")
        assert "gms server is required" in str(excinfo.value)

    def test_connection_error_reraising(self):
        """Test that test_connection() properly re-raises ConfigurationError."""
        # Create a basic emitter
        emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT)

        # Mock the session's get method to raise ConfigurationError
        with patch.object(emitter._session, "get") as mock_get:
            # Configure the mock to raise ConfigurationError
            mock_error = ConfigurationError("Connection failed")
            mock_get.side_effect = mock_error

            # Verify that the exception is re-raised
            with pytest.raises(ConfigurationError) as excinfo:
                emitter.test_connection()

            # Verify it's the same exception object (not a new one)
            assert excinfo.value is mock_error

    def test_datahub_rest_emitter_construction(self) -> None:
        emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT)
        assert emitter._session_config.timeout == rest_emitter._DEFAULT_TIMEOUT_SEC
        assert (
            emitter._session_config.retry_status_codes
            == rest_emitter._DEFAULT_RETRY_STATUS_CODES
        )
        assert (
            emitter._session_config.retry_max_times
            == rest_emitter._DEFAULT_RETRY_MAX_TIMES
        )

    def test_datahub_rest_emitter_timeout_construction(self) -> None:
        emitter = DatahubRestEmitter(
            MOCK_GMS_ENDPOINT, connect_timeout_sec=2, read_timeout_sec=4
        )
        assert emitter._session_config.timeout == (2, 4)

    def test_datahub_rest_emitter_general_timeout_construction(self) -> None:
        emitter = DatahubRestEmitter(
            MOCK_GMS_ENDPOINT, timeout_sec=2, read_timeout_sec=4
        )
        assert emitter._session_config.timeout == (2, 4)

    def test_datahub_rest_emitter_retry_construction(self) -> None:
        emitter = DatahubRestEmitter(
            MOCK_GMS_ENDPOINT,
            retry_status_codes=[418],
            retry_max_times=42,
        )
        assert emitter._session_config.retry_status_codes == [418]
        assert emitter._session_config.retry_max_times == 42

    def test_datahub_rest_emitter_extra_params(self) -> None:
        emitter = DatahubRestEmitter(
            MOCK_GMS_ENDPOINT, extra_headers={"key1": "value1", "key2": "value2"}
        )
        assert emitter._session.headers.get("key1") == "value1"
        assert emitter._session.headers.get("key2") == "value2"

    def test_openapi_emitter_emit(self, openapi_emitter):
        item = MetadataChangeProposalWrapper(
            entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
            aspect=DatasetProfile(
                rowCount=2000,
                columnCount=15,
                timestampMillis=1626995099686,
            ),
        )

        with patch.object(openapi_emitter, "_emit_generic") as mock_method:
            openapi_emitter.emit_mcp(item)

            mock_method.assert_called_once_with(
                f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false",
                payload=[
                    {
                        "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
                        "datasetProfile": {
                            "value": {
                                "rowCount": 2000,
                                "columnCount": 15,
                                "timestampMillis": 1626995099686,
                                "partitionSpec": {
                                    "partition": "FULL_TABLE_SNAPSHOT",
                                    "type": "FULL_TABLE",
                                },
                            },
                            "systemMetadata": {
                                "lastObserved": ANY,
                                "runId": "no-run-id-provided",
                                "lastRunId": "no-run-id-provided",
                                "properties": {
                                    "clientId": "acryl-datahub",
                                    "clientVersion": "1!0.0.0.dev0",
                                },
                            },
                            "headers": {},
                        },
                    }
                ],
                method="post",
            )

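    # The assertion above pins down the OpenAPI v3 batch envelope used by the
    # emitter: a JSON list of {"urn": ..., "<aspectName>": {"value": ...,
    # "systemMetadata": ..., "headers": ...}} objects POSTed to
    # /openapi/v3/entity/<entityType>?async=<true|false>.
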
    def test_openapi_emitter_emit_mcps(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            items = [
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)",
                    aspect=DatasetProfile(
                        rowCount=2000 + i,
                        columnCount=15,
                        timestampMillis=1626995099686,
                    ),
                )
                for i in range(3)
            ]

            result = openapi_emitter.emit_mcps(items)

            assert result == 1

            # Single chunk test - all items should be in one batch
            mock_emit.assert_called_once()
            call_args = mock_emit.call_args
            assert (
                call_args[0][0]
                == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false"
            )
            assert isinstance(call_args[1]["payload"], str)  # Should be JSON string

    def test_openapi_emitter_emit_mcps_max_bytes(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Create a large payload that will force chunking
            large_payload = "x" * (
                INGEST_MAX_PAYLOAD_BYTES // 2
            )  # Each item will be about half the max size
            items = [
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,LargePayload{i},PROD)",
                    aspect=DatasetProperties(name=large_payload),
                )
                for i in range(
                    3
                )  # This should create at least 2 chunks given the payload size
            ]

            openapi_emitter.emit_mcps(items)

            # Verify multiple chunks were created
            assert mock_emit.call_count > 1

            # Verify each chunk's payload is within size limits
            for call in mock_emit.call_args_list:
                args = call[1]
                assert "payload" in args
                payload = args["payload"]
                assert isinstance(payload, str)  # Should be a JSON string
                assert len(payload.encode()) <= INGEST_MAX_PAYLOAD_BYTES

                # Verify the payload structure
                payload_data = json.loads(payload)
                assert payload_data[0]["urn"].startswith("urn:li:dataset")
                assert "datasetProperties" in payload_data[0]

    def test_openapi_emitter_emit_mcps_max_items(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Create more items than BATCH_INGEST_MAX_PAYLOAD_LENGTH
            items = [
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Item{i},PROD)",
                    aspect=DatasetProfile(
                        rowCount=i,
                        columnCount=15,
                        timestampMillis=1626995099686,
                    ),
                )
                for i in range(
                    BATCH_INGEST_MAX_PAYLOAD_LENGTH + 2
                )  # Create 2 more than the max
            ]

            openapi_emitter.emit_mcps(items)

            # Verify exactly two chunks were created
            assert mock_emit.call_count == 2

            # Check the first chunk
            first_call = mock_emit.call_args_list[0]
            first_payload = json.loads(first_call[1]["payload"])
            assert len(first_payload) == BATCH_INGEST_MAX_PAYLOAD_LENGTH

            # Check the second chunk
            second_call = mock_emit.call_args_list[1]
            second_payload = json.loads(second_call[1]["payload"])
            assert len(second_payload) == 2  # Should have the remaining 2 items

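    # A minimal sketch, not the emitter's actual implementation, of the
    # chunking rule the two tests above exercise: a batch closes once it
    # holds `max_items` entries, or when adding the next serialized item
    # would push the JSON payload past `max_bytes`.
    @staticmethod
    def _illustrative_chunking(serialized_items, max_items, max_bytes):
        batch = []
        size = 2  # account for the enclosing "[]" of the JSON array
        for item in serialized_items:
            item_size = len(item.encode()) + 1  # +1 for the separating comma
            if batch and (len(batch) >= max_items or size + item_size > max_bytes):
                yield batch
                batch, size = [], 2
            batch.append(item)
            size += item_size
        if batch:
            yield batch
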
    def test_openapi_emitter_emit_mcps_multiple_entity_types(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Create items for two different entity types
            dataset_items = [
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)",
                    aspect=DatasetProfile(
                        rowCount=i,
                        columnCount=15,
                        timestampMillis=1626995099686,
                    ),
                )
                for i in range(2)
            ]

            dashboard_items = [
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dashboard:(looker,dashboards.{i})",
                    aspect=Status(removed=False),
                )
                for i in range(2)
            ]

            items = dataset_items + dashboard_items
            result = openapi_emitter.emit_mcps(items)

            # Should return number of unique entity URLs
            assert result == 2
            assert mock_emit.call_count == 2

            # Check that calls were made with different URLs but correct payloads
            calls = {
                call[0][0]: json.loads(call[1]["payload"])
                for call in mock_emit.call_args_list
            }

            # Verify each URL got the right aspects
            for url, payload in calls.items():
                if "datasetProfile" in payload[0]:
                    assert url.endswith("dataset?async=false")
                    assert len(payload) == 2
                    assert all("datasetProfile" in item for item in payload)
                else:
                    assert url.endswith("dashboard?async=false")
                    assert len(payload) == 2
                    assert all("status" in item for item in payload)

    def test_openapi_emitter_emit_mcp_with_tracing(self, openapi_emitter):
        """Test emitting a single MCP with tracing enabled"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Mock the response for the initial emit
            mock_response = Mock(spec=Response)
            mock_response.status_code = 200
            mock_response.headers = {
                "traceparent": "00-00063609cb934b9d0d4e6a7d6d5e1234-1234567890abcdef-01"
            }
            mock_response.json.return_value = [
                {
                    "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
                    "datasetProfile": {},
                }
            ]
            mock_emit.return_value = mock_response

            # Create test item
            item = MetadataChangeProposalWrapper(
                entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
                aspect=DatasetProfile(
                    rowCount=2000,
                    columnCount=15,
                    timestampMillis=1626995099686,
                ),
            )

            # Set up mock for trace verification responses
            trace_responses = [
                # First check - pending
                {
                    "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": {
                        "datasetProfile": {
                            "success": True,
                            "primaryStorage": {"writeStatus": "PENDING"},
                            "searchStorage": {"writeStatus": "PENDING"},
                        }
                    }
                },
                # Second check - completed
                {
                    "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": {
                        "datasetProfile": {
                            "success": True,
                            "primaryStorage": {"writeStatus": "ACTIVE_STATE"},
                            "searchStorage": {"writeStatus": "ACTIVE_STATE"},
                        }
                    }
                },
            ]

            def side_effect(*args, **kwargs):
                if "trace/write" in args[0]:
                    mock_trace_response = Mock(spec=Response)
                    mock_trace_response.json.return_value = trace_responses.pop(0)
                    return mock_trace_response
                return mock_response

            mock_emit.side_effect = side_effect

            # Emit with tracing enabled
            openapi_emitter.emit_mcp(
                item,
                emit_mode=EmitMode.ASYNC_WAIT,
                wait_timeout=timedelta(seconds=10),
            )

            # Verify initial emit call
            assert (
                mock_emit.call_args_list[0][0][0]
                == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
            )

            # Verify trace verification calls
            trace_calls = [
                call for call in mock_emit.call_args_list if "trace/write" in call[0][0]
            ]
            assert len(trace_calls) == 2
            assert "00063609cb934b9d0d4e6a7d6d5e1234" in trace_calls[0][0][0]

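    # The "traceparent" header used above follows the W3C Trace Context
    # format (version-traceid-spanid-flags); the 32-hex-digit trace id
    # segment is what ends up in the trace/write URL, which is why the final
    # assertion looks for it there.
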
    def test_openapi_emitter_emit_mcps_with_tracing(self, openapi_emitter):
        """Test emitting multiple MCPs with tracing enabled"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Create test items
            items = [
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)",
                    aspect=Status(removed=False),
                )
                for i in range(2)
            ]

            # Mock responses for initial emit
            emit_responses = []
            mock_resp = Mock(spec=Response)
            mock_resp.status_code = 200
            mock_resp.headers = {"traceparent": "test-trace"}
            mock_resp.json.return_value = [
                {
                    "urn": f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)",
                    "status": {"removed": False},
                }
                for i in range(2)
            ]
            emit_responses.append(mock_resp)

            # Mock trace verification responses
            trace_responses = [
                # First check - all pending
                {
                    f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)": {
                        "status": {
                            "success": True,
                            "primaryStorage": {"writeStatus": "PENDING"},
                            "searchStorage": {"writeStatus": "PENDING"},
                        }
                    }
                    for i in range(2)
                },
                # Second check - all completed
                {
                    f"urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount{i},PROD)": {
                        "status": {
                            "success": True,
                            "primaryStorage": {"writeStatus": "ACTIVE_STATE"},
                            "searchStorage": {"writeStatus": "ACTIVE_STATE"},
                        }
                    }
                    for i in range(2)
                },
            ]

            response_iter = iter(emit_responses)
            trace_response_iter = iter(trace_responses)

            def side_effect(*args, **kwargs):
                if "trace/write" in args[0]:
                    mock_trace_response = Mock(spec=Response)
                    mock_trace_response.json.return_value = next(trace_response_iter)
                    return mock_trace_response
                return next(response_iter)

            mock_emit.side_effect = side_effect

            # Emit with tracing enabled
            result = openapi_emitter.emit_mcps(
                items,
                emit_mode=EmitMode.ASYNC_WAIT,
                wait_timeout=timedelta(seconds=10),
            )

            assert result == 1  # Should return number of unique entity URLs

            # Verify initial emit calls
            emit_calls = [
                call
                for call in mock_emit.call_args_list
                if "entity/dataset" in call[0][0]
            ]
            assert len(emit_calls) == 1
            assert (
                emit_calls[0][0][0]
                == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
            )

            # Verify trace verification calls
            trace_calls = [
                call for call in mock_emit.call_args_list if "trace/write" in call[0][0]
            ]
            assert len(trace_calls) == 2

    def test_openapi_emitter_trace_timeout(self, openapi_emitter):
        """Test that tracing properly handles timeouts"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Mock initial emit response
            mock_response = Mock(spec=Response)
            mock_response.status_code = 200
            mock_response.headers = {"traceparent": "test-trace-123"}
            mock_response.json.return_value = [
                {
                    "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
                    "datasetProfile": {},
                }
            ]

            # Mock trace verification to always return pending
            mock_trace_response = Mock(spec=Response)
            mock_trace_response.json.return_value = {
                "urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)": {
                    "datasetProfile": {
                        "success": True,
                        "primaryStorage": {"writeStatus": "PENDING"},
                        "searchStorage": {"writeStatus": "PENDING"},
                    }
                }
            }

            def side_effect(*args, **kwargs):
                return (
                    mock_trace_response if "trace/write" in args[0] else mock_response
                )

            mock_emit.side_effect = side_effect

            item = MetadataChangeProposalWrapper(
                entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
                aspect=DatasetProfile(
                    rowCount=2000,
                    columnCount=15,
                    timestampMillis=1626995099686,
                ),
            )

            # Emit with very short timeout
            with pytest.raises(TraceTimeoutError) as exc_info:
                openapi_emitter.emit_mcp(
                    item,
                    emit_mode=EmitMode.ASYNC_WAIT,
                    wait_timeout=timedelta(milliseconds=1),
                )

            assert "Timeout waiting for async write completion" in str(exc_info.value)

    def test_openapi_emitter_missing_trace_header(self, openapi_emitter):
        """Test behavior when trace header is missing"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Mock response without trace header
            mock_response = Mock(spec=Response)
            mock_response.status_code = 200
            mock_response.headers = {}  # No traceparent header
            mock_response.json.return_value = [
                {
                    "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,MissingTraceHeader,PROD)",
                    "status": {"removed": False},
                }
            ]
            mock_emit.return_value = mock_response

            item = MetadataChangeProposalWrapper(
                entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,MissingTraceHeader,PROD)",
                aspect=Status(removed=False),
            )

            # Should not raise an exception, but emit an APITracingWarning.
            with pytest.warns(APITracingWarning):
                openapi_emitter.emit_mcp(
                    item,
                    emit_mode=EmitMode.ASYNC_WAIT,
                    wait_timeout=timedelta(seconds=10),
                )

    def test_openapi_emitter_invalid_status_code(self, openapi_emitter):
        """Test behavior when response has non-200 status code"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Mock response with error status code
            mock_response = Mock(spec=Response)
            mock_response.status_code = 500
            mock_response.headers = {"traceparent": "test-trace-123"}
            mock_response.json.return_value = [
                {
                    "urn": "urn:li:dataset:(urn:li:dataPlatform:mysql,InvalidStatusCode,PROD)",
                    "datasetProfile": {},
                }
            ]
            mock_emit.return_value = mock_response

            item = MetadataChangeProposalWrapper(
                entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,InvalidStatusCode,PROD)",
                aspect=DatasetProfile(
                    rowCount=2000,
                    columnCount=15,
                    timestampMillis=1626995099686,
                ),
            )

            # Should not raise exception but log warning
            openapi_emitter.emit_mcp(
                item,
                emit_mode=EmitMode.ASYNC_WAIT,
                wait_timeout=timedelta(seconds=10),
            )

    def test_openapi_emitter_trace_failure(self, openapi_emitter):
        """Test handling of trace verification failures"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            test_urn = "urn:li:dataset:(urn:li:dataPlatform:mysql,TraceFailure,PROD)"

            # Create initial emit response
            emit_response = Mock(spec=Response)
            emit_response.status_code = 200
            emit_response.headers = {
                "traceparent": "00-00063609cb934b9d0d4e6a7d6d5e1234-1234567890abcdef-01"
            }
            emit_response.json.return_value = [{"urn": test_urn, "datasetProfile": {}}]

            # Create trace verification response
            trace_response = Mock(spec=Response)
            trace_response.status_code = 200
            trace_response.headers = {}
            trace_response.json.return_value = {
                test_urn: {
                    "datasetProfile": {
                        "success": False,
                        "error": "Failed to write to storage",
                        "primaryStorage": {"writeStatus": "ERROR"},
                        "searchStorage": {"writeStatus": "ERROR"},
                    }
                }
            }

            def side_effect(*args, **kwargs):
                if "trace/write" in args[0]:
                    return trace_response
                return emit_response

            mock_emit.side_effect = side_effect

            item = MetadataChangeProposalWrapper(
                entityUrn=test_urn,
                aspect=DatasetProfile(
                    rowCount=2000,
                    columnCount=15,
                    timestampMillis=1626995099686,
                ),
            )

            with pytest.raises(TraceValidationError) as exc_info:
                openapi_emitter.emit_mcp(
                    item,
                    emit_mode=EmitMode.ASYNC_WAIT,
                    wait_timeout=timedelta(seconds=10),
                )

            error_message = str(exc_info.value)

            # Check for key error message components
            assert "Unable to validate async write" in error_message
            assert "to DataHub GMS" in error_message
            assert "Failed to write to storage" in error_message
            assert "primaryStorage" in error_message
            assert "writeStatus" in error_message
            assert "'ERROR'" in error_message

            # Verify trace was actually called
            trace_calls = [
                call for call in mock_emit.call_args_list if "trace/write" in call[0][0]
            ]
            assert len(trace_calls) > 0
            assert "00063609cb934b9d0d4e6a7d6d5e1234" in trace_calls[0][0][0]

    def test_await_status_empty_trace_data(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            openapi_emitter._await_status([], timedelta(seconds=10))
            assert not mock_emit.called

    def test_await_status_successful_completion(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            trace = TraceData(
                trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]}
            )

            mock_response = Mock(
                json=lambda: {
                    "urn:li:dataset:test": {
                        "status": {
                            "success": True,
                            "primaryStorage": {"writeStatus": "ACTIVE_STATE"},
                            "searchStorage": {"writeStatus": "TRACE_NOT_IMPLEMENTED"},
                        }
                    }
                }
            )
            mock_emit.return_value = mock_response

            openapi_emitter._await_status([trace], timedelta(seconds=10))
            assert not trace.data  # Should be empty after successful completion

    def test_await_status_timeout(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            trace = TraceData(
                trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]}
            )

            mock_response = Mock()
            mock_response.json.return_value = {
                "urn:li:dataset:test": {
                    "status": {
                        "success": True,
                        "primaryStorage": {"writeStatus": "PENDING"},
                        "searchStorage": {"writeStatus": "PENDING"},
                    }
                }
            }
            mock_emit.return_value = mock_response

            with pytest.raises(TraceTimeoutError) as exc_info:
                openapi_emitter._await_status([trace], timedelta(seconds=0.1))
            assert "Timeout waiting for async write completion" in str(exc_info.value)

    def test_await_status_persistence_failure(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            trace = TraceData(
                trace_id="test-trace-id", data={"urn:li:dataset:test": ["status"]}
            )

            mock_response = Mock()
            mock_response.json.return_value = {
                "urn:li:dataset:test": {
                    "status": {
                        "success": False,
                        "primaryStorage": {"writeStatus": "ERROR"},
                        "searchStorage": {"writeStatus": "PENDING"},
                    }
                }
            }
            mock_emit.return_value = mock_response

            with pytest.raises(TraceValidationError) as exc_info:
                openapi_emitter._await_status([trace], timedelta(seconds=10))
            assert "Persistence failure" in str(exc_info.value)

    def test_await_status_multiple_aspects(self, openapi_emitter):
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            trace = TraceData(
                trace_id="test-trace-id",
                data={"urn:li:dataset:test": ["status", "schema"]},
            )

            mock_response = Mock()
            mock_response.json.return_value = {
                "urn:li:dataset:test": {
                    "status": {
                        "success": True,
                        "primaryStorage": {"writeStatus": "ACTIVE_STATE"},
                        "searchStorage": {"writeStatus": "ACTIVE_STATE"},
                    },
                    "schema": {
                        "success": True,
                        "primaryStorage": {"writeStatus": "HISTORIC_STATE"},
                        "searchStorage": {"writeStatus": "NO_OP"},
                    },
                }
            }
            mock_emit.return_value = mock_response

            openapi_emitter._await_status([trace], timedelta(seconds=10))
            assert not trace.data

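    # Write statuses seen across the _await_status tests: PENDING keeps the
    # poll loop going; ERROR with success=False raises TraceValidationError;
    # ACTIVE_STATE, HISTORIC_STATE, NO_OP, and TRACE_NOT_IMPLEMENTED are all
    # accepted as terminal, after which the aspect is dropped from the
    # trace's pending data.
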
    def test_await_status_logging(self, openapi_emitter):
        with patch.object(logger, "debug") as mock_debug, patch.object(
            logger, "error"
        ) as mock_error, patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            # Test empty trace data logging
            openapi_emitter._await_status([], timedelta(seconds=10))
            mock_debug.assert_called_once_with("No trace data to verify")

            # Test error logging
            trace = TraceData(trace_id="test-id", data={"urn:test": ["status"]})
            mock_emit.side_effect = TraceValidationError("Test error")
            with pytest.raises(TraceValidationError):
                openapi_emitter._await_status([trace], timedelta(seconds=10))
            mock_error.assert_called_once()

    def test_openapi_emitter_same_url_different_methods(self, openapi_emitter):
        """Test handling of requests with same URL but different HTTP methods"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit:
            items = [
                # POST requests for upserts
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,UpdateMe{i},PROD)",
                    entityType="dataset",
                    aspectName="datasetProperties",
                    changeType=ChangeTypeClass.UPSERT,
                    aspect=DatasetProperties(name=f"Updated Dataset {i}"),
                )
                for i in range(2)
            ] + [
                # PATCH requests for partial updates
                next(
                    iter(
                        DatasetPatchBuilder(
                            f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
                        )
                        .set_qualified_name(f"PatchMe{i}")
                        .build()
                    )
                )
                for i in range(2)
            ]

            # Run the test
            result = openapi_emitter.emit_mcps(items)

            # Verify that we made 2 calls (one for each HTTP method)
            assert result == 2
            assert mock_emit.call_count == 2

            # Check that calls were made with different methods but the same URL
            calls = {}
            for call in mock_emit.call_args_list:
                method = call[1]["method"]
                url = call[0][0]
                calls[(method, url)] = call

            assert (
                "post",
                f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false",
            ) in calls
            assert (
                "patch",
                f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false",
            ) in calls

    def test_openapi_emitter_mixed_method_chunking(self, openapi_emitter):
        """Test that chunking works correctly across different HTTP methods"""
        with patch(
            "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
        ) as mock_emit, patch(
            "datahub.emitter.rest_emitter.BATCH_INGEST_MAX_PAYLOAD_LENGTH", 2
        ):
            # Create more items than the chunk size for each method
            items = [
                # POST items (4 items, should create 2 chunks)
                MetadataChangeProposalWrapper(
                    entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)",
                    entityType="dataset",
                    aspectName="datasetProfile",
                    changeType=ChangeTypeClass.UPSERT,
                    aspect=DatasetProfile(
                        rowCount=i, columnCount=15, timestampMillis=0
                    ),
                )
                for i in range(4)
            ] + [
                # PATCH items (3 items, should create 2 chunks)
                next(
                    iter(
                        DatasetPatchBuilder(
                            f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
                        )
                        .set_qualified_name(f"PatchMe{i}")
                        .build()
                    )
                )
                for i in range(3)
            ]

            # Run the test with a smaller chunk size to force multiple chunks
            result = openapi_emitter.emit_mcps(items)

            # Should have 4 chunks total:
            # - 2 chunks for POST (4 items with max 2 per chunk)
            # - 2 chunks for PATCH (3 items with max 2 per chunk)
            assert result == 4
            assert mock_emit.call_count == 4

            # Count the calls by method and verify chunking
            post_calls = [
                call for call in mock_emit.call_args_list if call[1]["method"] == "post"
            ]
            patch_calls = [
                call
                for call in mock_emit.call_args_list
                if call[1]["method"] == "patch"
            ]

            assert len(post_calls) == 2  # 2 chunks for POST
            assert len(patch_calls) == 2  # 2 chunks for PATCH

            # Verify first chunks have max size and last chunks have remainders
            post_payloads = [json.loads(call[1]["payload"]) for call in post_calls]
            patch_payloads = [json.loads(call[1]["payload"]) for call in patch_calls]

            assert len(post_payloads[0]) == 2
            assert len(post_payloads[1]) == 2
            assert len(patch_payloads[0]) == 2
            assert len(patch_payloads[1]) == 1

            # Verify all post calls are to the dataset endpoint
            for call in post_calls:
                assert (
                    call[0][0]
                    == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false"
                )

            # Verify all patch calls are to the dataset endpoint
            for call in patch_calls:
                assert (
                    call[0][0]
                    == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=false"
                )

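    # Note on the two tests above: UPSERT MCPs serialize to HTTP POST while
    # patch-builder MCPs serialize to HTTP PATCH, so batching keys on the
    # (method, URL) pair - chunks are never mixed across methods even when
    # they target the same entity endpoint, and chunk limits apply per
    # method.
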
    def test_openapi_sync_full_emit_mode(self, openapi_emitter):
        """Test that SYNC_WAIT emit mode correctly sets async=false URL parameter and sync header"""

        # Create a test MCP
        test_mcp = MetadataChangeProposalWrapper(
            entityUrn="urn:li:dataset:(test,sync_full,PROD)",
            aspect=DatasetProfile(
                rowCount=500,
                columnCount=10,
                timestampMillis=1626995099686,
            ),
        )

        # Test with SYNC_WAIT emit mode
        with patch.object(openapi_emitter, "_emit_generic") as mock_emit:
            # Configure mock to return a simple response
            mock_response = Mock(spec=Response)
            mock_response.status_code = 200
            mock_response.headers = {}
            mock_emit.return_value = mock_response

            # Call emit_mcp with SYNC_WAIT mode
            openapi_emitter.emit_mcp(test_mcp, emit_mode=EmitMode.SYNC_WAIT)

            # Verify _emit_generic was called
            mock_emit.assert_called_once()

            # Check URL contains async=false
            url = mock_emit.call_args[0][0]
            assert "async=false" in url

            # Check payload contains the synchronization header
            payload = mock_emit.call_args[1]["payload"]
            # Convert payload to dict if it's a JSON string
            if isinstance(payload, str):
                payload = json.loads(payload)

            # Verify the headers in the payload
            assert isinstance(payload, list)
            assert len(payload) > 0
            assert "headers" in payload[0]["datasetProfile"]
            assert (
                payload[0]["datasetProfile"]["headers"].get(
                    "X-DataHub-Sync-Index-Update"
                )
                == "true"
            )


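# EmitMode semantics exercised by the class above: ASYNC_WAIT emits with
# ?async=true and then polls the trace/write endpoint until the write is
# confirmed, while SYNC_WAIT emits with ?async=false plus an
# X-DataHub-Sync-Index-Update: "true" header on each aspect so the search
# index is updated synchronously.

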
class TestOpenApiModeSelection:
    def test_sdk_client_mode_no_env_var(self, mock_session, mock_response):
        """Test that SDK client mode defaults to OpenAPI when no env var is present"""
        mock_session.get.return_value = mock_response

        # Ensure no env vars
        with patch.dict(os.environ, {}, clear=True):
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
            emitter._session = mock_session
            emitter.test_connection()
            assert emitter._openapi_ingestion is True

    def test_non_sdk_client_mode_no_env_var(self, mock_session, mock_response):
        """Test that non-SDK client modes default to RestLi when no env var is present"""
        mock_session.get.return_value = mock_response

        # Ensure no env vars
        with patch.dict(os.environ, {}, clear=True):
            # Test INGESTION mode
            emitter = DataHubRestEmitter(
                MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION
            )
            emitter._session = mock_session
            emitter.test_connection()
            assert emitter._openapi_ingestion is False

            # Test CLI mode
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.CLI)
            emitter._session = mock_session
            emitter.test_connection()
            assert emitter._openapi_ingestion is False

    def test_env_var_restli_overrides_sdk_mode(self, mock_session, mock_response):
        """Test that env var set to RESTLI overrides SDK client mode default"""
        mock_session.get.return_value = mock_response

        with patch.dict(
            os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "RESTLI"}, clear=True
        ), patch(
            "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
            RestSinkEndpoint.RESTLI,
        ):
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
            emitter._session = mock_session
            emitter.test_connection()
            assert emitter._openapi_ingestion is False

    def test_env_var_openapi_any_client_mode(self, mock_session, mock_response):
        """Test that env var set to OPENAPI enables OpenAPI for any client mode"""
        mock_session.get.return_value = mock_response

        with patch.dict(
            os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "OPENAPI"}, clear=True
        ), patch(
            "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
            RestSinkEndpoint.OPENAPI,
        ):
            # Test INGESTION mode
            emitter = DataHubRestEmitter(
                MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION
            )
            emitter._session = mock_session
            emitter.test_connection()
            assert emitter._openapi_ingestion is True

            # Test CLI mode
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.CLI)
            emitter._session = mock_session
            emitter.test_connection()
            assert emitter._openapi_ingestion is True

            # Test SDK mode
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
            emitter._session = mock_session
            emitter.test_connection()
            assert emitter._openapi_ingestion is True

    def test_constructor_param_true_overrides_all(self):
        """Test that explicit constructor parameter True overrides all other settings"""
        with patch.dict(
            os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "RESTLI"}, clear=True
        ), patch(
            "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
            RestSinkEndpoint.RESTLI,
        ):
            # Even with env var and non-SDK mode, constructor param should win
            emitter = DataHubRestEmitter(
                MOCK_GMS_ENDPOINT,
                client_mode=ClientMode.INGESTION,
                openapi_ingestion=True,
            )
            assert emitter._openapi_ingestion is True

    def test_constructor_param_false_overrides_all(self):
        """Test that explicit constructor parameter False overrides all other settings"""
        with patch.dict(
            os.environ, {"DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT": "OPENAPI"}, clear=True
        ), patch(
            "datahub.emitter.rest_emitter.DEFAULT_REST_EMITTER_ENDPOINT",
            RestSinkEndpoint.OPENAPI,
        ):
            # Even with env var and SDK mode, constructor param should win
            emitter = DataHubRestEmitter(
                MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK, openapi_ingestion=False
            )
            assert emitter._openapi_ingestion is False

    def test_debug_logging(self, mock_session, mock_response):
        """Test that debug logging is called with correct protocol information"""
        mock_session.get.return_value = mock_response

        with patch("datahub.emitter.rest_emitter.logger") as mock_logger:
            # Test OpenAPI logging
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True)
            emitter._session = mock_session
            emitter.test_connection()
            mock_logger.debug.assert_any_call("Using OpenAPI for ingestion.")

            # Test RestLi logging
            mock_logger.reset_mock()
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=False)
            emitter._session = mock_session
            emitter.test_connection()
            mock_logger.debug.assert_any_call("Using Restli for ingestion.")


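# Endpoint-selection precedence exercised above, highest first: the explicit
# openapi_ingestion constructor argument, then the
# DATAHUB_REST_EMITTER_DEFAULT_ENDPOINT environment variable, then the
# client-mode default (SDK uses OpenAPI; CLI and INGESTION use RestLi).

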
class TestOpenApiIntegration:
    def test_sdk_mode_uses_openapi_by_default(self, mock_session, mock_response):
        """Test that SDK mode uses OpenAPI by default for emit_mcp"""
        mock_session.get.return_value = mock_response

        with patch.dict("os.environ", {}, clear=True):
            emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, client_mode=ClientMode.SDK)
            emitter._session = mock_session
            emitter.test_connection()

            # Verify _openapi_ingestion was set correctly
            assert emitter._openapi_ingestion is True

            # Create test MCP
            test_mcp = MetadataChangeProposalWrapper(
                entityUrn="urn:li:dataset:(test,sdk,PROD)",
                aspect=Status(removed=False),
            )

            # Mock _emit_generic to inspect what URL is used
            with patch.object(emitter, "_emit_generic") as mock_emit:
                emitter.emit_mcp(test_mcp)

                # Check that OpenAPI URL format was used
                mock_emit.assert_called_once()
                url = mock_emit.call_args[0][0]
                assert "openapi" in url
                assert url.startswith(f"{MOCK_GMS_ENDPOINT}/openapi")

    def test_ingestion_mode_uses_restli_by_default(self, mock_session, mock_response):
        """Test that INGESTION mode uses RestLi by default for emit_mcp"""
        mock_session.get.return_value = mock_response

        with patch.dict("os.environ", {}, clear=True):
            emitter = DataHubRestEmitter(
                MOCK_GMS_ENDPOINT, client_mode=ClientMode.INGESTION
            )
            emitter._session = mock_session
            emitter.test_connection()

            # Verify _openapi_ingestion was set correctly
            assert emitter._openapi_ingestion is False

            # Create test MCP
            test_mcp = MetadataChangeProposalWrapper(
                entityUrn="urn:li:dataset:(test,ingestion,PROD)",
                aspect=Status(removed=False),
            )

            # Mock _emit_generic to inspect what URL is used
            with patch.object(emitter, "_emit_generic") as mock_emit:
                emitter.emit_mcp(test_mcp)

                # Check that RestLi URL format was used (not OpenAPI)
                mock_emit.assert_called_once()
                url = mock_emit.call_args[0][0]
                assert "openapi" not in url
                assert "aspects?action=ingestProposal" in url

    def test_explicit_openapi_parameter_uses_openapi_api(self):
        """Test that explicit openapi_ingestion=True uses OpenAPI regardless of mode"""
        emitter = DataHubRestEmitter(
            MOCK_GMS_ENDPOINT,
            client_mode=ClientMode.INGESTION,  # Would normally use RestLi
            openapi_ingestion=True,  # Override to use OpenAPI
        )

        # Verify _openapi_ingestion was set correctly
        assert emitter._openapi_ingestion is True

        # Create test MCP
        test_mcp = MetadataChangeProposalWrapper(
            entityUrn="urn:li:dataset:(test,explicit,PROD)",
            aspect=DatasetProfile(rowCount=100, columnCount=10, timestampMillis=0),
        )

        # Mock _emit_generic to inspect what URL is used
        with patch.object(emitter, "_emit_generic") as mock_emit:
            emitter.emit_mcp(test_mcp)

            # Check that OpenAPI URL format was used
            mock_emit.assert_called_once()
            url = mock_emit.call_args[0][0]
            assert "openapi" in url

            # Verify the payload is formatted for OpenAPI
            payload = mock_emit.call_args[1]["payload"]
            assert isinstance(payload, list) or (
                isinstance(payload, str) and payload.startswith("[")
            )

    def test_openapi_batch_endpoint_selection(self):
        """Test that OpenAPI batch operations use correct endpoints based on entity type"""
        emitter = DataHubRestEmitter(MOCK_GMS_ENDPOINT, openapi_ingestion=True)

        # Create MCPs with different entity types
        dataset_mcp = MetadataChangeProposalWrapper(
            entityUrn="urn:li:dataset:(test,batch1,PROD)",
            entityType="dataset",
            aspect=Status(removed=False),
        )

        dashboard_mcp = MetadataChangeProposalWrapper(
            entityUrn="urn:li:dashboard:(test,batch2)",
            entityType="dashboard",
            aspect=Status(removed=False),
        )

        # Mock _emit_generic to inspect URLs used
        with patch.object(emitter, "_emit_generic") as mock_emit:
            # Configure mock to return appropriate responses
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.headers = {}
            mock_response.json.return_value = []
            mock_emit.return_value = mock_response

            # Emit batch of different entity types
            emitter.emit_mcps([dataset_mcp, dashboard_mcp])

            # Verify we made two calls to different endpoints
            assert mock_emit.call_count == 2

            # Get the URLs from the calls
            calls = mock_emit.call_args_list
            urls = [call[0][0] for call in calls]

            # Verify we called both entity endpoints
            assert any("entity/dataset" in url for url in urls)
            assert any("entity/dashboard" in url for url in urls)


class TestRequestsSessionConfig:
    def test_get_client_mode_from_session(self):
        """Test extracting ClientMode from session headers with various inputs."""
        # Case 1: Session with valid ClientMode in headers (string)
        session = Session()
        session.headers.update({"X-DataHub-Client-Mode": "SDK"})
        mode = RequestsSessionConfig.get_client_mode_from_session(session)
        assert mode == ClientMode.SDK

        # Case 2: Session with valid ClientMode in headers (bytes)
        session = Session()
        session.headers.update({"X-DataHub-Client-Mode": b"INGESTION"})
        mode = RequestsSessionConfig.get_client_mode_from_session(session)
        assert mode == ClientMode.INGESTION

        # Case 3: Session with no ClientMode header
        session = Session()
        mode = RequestsSessionConfig.get_client_mode_from_session(session)
        assert mode is None

        # Case 4: Session with invalid ClientMode value
        session = Session()
        session.headers.update({"X-DataHub-Client-Mode": "INVALID_MODE"})
        mode = RequestsSessionConfig.get_client_mode_from_session(session)
        assert mode is None

        # Case 5: Session with empty ClientMode value
        session = Session()
        session.headers.update({"X-DataHub-Client-Mode": ""})
        mode = RequestsSessionConfig.get_client_mode_from_session(session)
        assert mode is None

        # Case 6: Test with exception during processing
        mock_session = Mock()
        mock_session.headers = {
            "X-DataHub-Client-Mode": object()
        }  # Will cause error when decoded
        mode = RequestsSessionConfig.get_client_mode_from_session(mock_session)
        assert mode is None

        # Case 7: Different capitalization should not match
        session = Session()
        session.headers.update({"X-DataHub-Client-Mode": "sdk"})  # lowercase
        mode = RequestsSessionConfig.get_client_mode_from_session(session)
        assert mode is None

        # Case 8: Test with all available ClientMode values
        for client_mode in ClientMode:
            session = Session()
            session.headers.update({"X-DataHub-Client-Mode": client_mode.name})
            mode = RequestsSessionConfig.get_client_mode_from_session(session)
            assert mode == client_mode