Mirror of https://github.com/datahub-project/datahub.git, synced 2025-12-26 17:37:33 +00:00

feat(ingestion-openapi): patch support (#13282)

Co-authored-by: Sergio Gómez Villamor <sgomezvillamor@gmail.com>

parent c37eee18e6
commit 9b0634805a
@@ -1,9 +1,19 @@
 import json
 import shlex
-from typing import List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
 
 import requests
 from requests.auth import HTTPBasicAuth
 
+from datahub.emitter.aspect import JSON_CONTENT_TYPE, JSON_PATCH_CONTENT_TYPE
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.emitter.serialization_helper import pre_json_transform
+from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
+    MetadataChangeProposal,
+)
+from datahub.metadata.schema_classes import ChangeTypeClass
 
 
 def _decode_bytes(value: Union[str, bytes]) -> str:
     """Decode bytes to string, if necessary."""
@@ -40,3 +50,97 @@ def make_curl_command(
 
     fragments.append(url)
     return shlex.join(fragments)
+
+
+@dataclass
+class OpenApiRequest:
+    """Represents an OpenAPI request for entity operations."""
+
+    method: str
+    url: str
+    payload: List[Dict[str, Any]]
+
+    @classmethod
+    def from_mcp(
+        cls,
+        mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
+        gms_server: str,
+        async_flag: Optional[bool] = None,
+        async_default: bool = False,
+    ) -> Optional["OpenApiRequest"]:
+        """Factory method to create an OpenApiRequest from a MetadataChangeProposal."""
+        if not mcp.aspectName or (
+            mcp.changeType != ChangeTypeClass.DELETE and not mcp.aspect
+        ):
+            return None
+
+        resolved_async_flag = async_flag if async_flag is not None else async_default
+
+        method = "post"
+        url = f"{gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"
+        payload = []
+
+        if mcp.changeType == ChangeTypeClass.DELETE:
+            method = "delete"
+            url = f"{gms_server}/openapi/v3/entity/{mcp.entityType}/{mcp.entityUrn}"
+        else:
+            if mcp.aspect:
+                if mcp.changeType == ChangeTypeClass.PATCH:
+                    method = "patch"
+                    obj = mcp.aspect.to_obj()
+                    content_type = obj.get("contentType")
+                    if obj.get("value") and content_type == JSON_PATCH_CONTENT_TYPE:
+                        # Undo double serialization.
+                        obj = json.loads(obj["value"])
+                        patch_value = obj
+                    else:
+                        raise NotImplementedError(
+                            f"ChangeType {mcp.changeType} only supports content type {JSON_PATCH_CONTENT_TYPE}, found {content_type}."
+                        )
+
+                    if isinstance(patch_value, list):
+                        patch_value = {"patch": patch_value}
+
+                    payload = [
+                        {
+                            "urn": mcp.entityUrn,
+                            mcp.aspectName: {
+                                "value": patch_value,
+                                "systemMetadata": mcp.systemMetadata.to_obj()
+                                if mcp.systemMetadata
+                                else None,
+                            },
+                        }
+                    ]
+                else:
+                    if isinstance(mcp, MetadataChangeProposalWrapper):
+                        aspect_value = pre_json_transform(
+                            mcp.to_obj(simplified_structure=True)
+                        )["aspect"]["json"]
+                    else:
+                        obj = mcp.aspect.to_obj()
+                        content_type = obj.get("contentType")
+                        if obj.get("value") and content_type == JSON_CONTENT_TYPE:
+                            # Undo double serialization.
+                            obj = json.loads(obj["value"])
+                        elif content_type == JSON_PATCH_CONTENT_TYPE:
+                            raise NotImplementedError(
+                                f"ChangeType {mcp.changeType} does not support patch."
+                            )
+                        aspect_value = pre_json_transform(obj)
+
+                    payload = [
+                        {
+                            "urn": mcp.entityUrn,
+                            mcp.aspectName: {
+                                "value": aspect_value,
+                                "systemMetadata": mcp.systemMetadata.to_obj()
+                                if mcp.systemMetadata
+                                else None,
+                            },
+                        }
+                    ]
+            else:
+                raise ValueError(f"ChangeType {mcp.changeType} requires a value.")
+
+        return cls(method=method, url=url, payload=payload)
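
For illustration, a minimal sketch of how the new factory behaves for a patch-style MCP, mirroring the unit tests added later in this commit (the chart URN and title are just example values):

    from datahub.emitter.request_helper import OpenApiRequest
    from datahub.specific.chart import ChartPatchBuilder

    # Build a JSON-patch MCP and convert it to an OpenAPI request.
    mcp = next(
        iter(ChartPatchBuilder("urn:li:chart:(test,test)").set_title("Updated Title").build())
    )
    request = OpenApiRequest.from_mcp(mcp, "http://localhost:8080")

    assert request is not None
    assert request.method == "patch"  # JSON-patch aspects map to HTTP PATCH
    assert request.url == "http://localhost:8080/openapi/v3/entity/chart?async=false"
    assert request.payload[0]["chartInfo"]["value"]["patch"] == [
        {"op": "add", "path": "/title", "value": "Updated Title"}
    ]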
@@ -41,10 +41,9 @@ from datahub.configuration.common import (
     TraceTimeoutError,
     TraceValidationError,
 )
-from datahub.emitter.aspect import JSON_CONTENT_TYPE, JSON_PATCH_CONTENT_TYPE
 from datahub.emitter.generic_emitter import Emitter
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
-from datahub.emitter.request_helper import make_curl_command
+from datahub.emitter.request_helper import OpenApiRequest, make_curl_command
 from datahub.emitter.response_helper import (
     TraceData,
     extract_trace_data,
@@ -348,43 +347,24 @@ class DataHubRestEmitter(Closeable, Emitter):
         mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
         async_flag: Optional[bool] = None,
         async_default: bool = False,
-    ) -> Optional[Tuple[str, List[Dict[str, Any]]]]:
-        if mcp.aspect and mcp.aspectName:
-            resolved_async_flag = (
-                async_flag if async_flag is not None else async_default
-            )
-            url = f"{self._gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"
+    ) -> Optional[OpenApiRequest]:
+        """
+        Convert a MetadataChangeProposal to an OpenAPI request format.
 
-            if isinstance(mcp, MetadataChangeProposalWrapper):
-                aspect_value = pre_json_transform(
-                    mcp.to_obj(simplified_structure=True)
-                )["aspect"]["json"]
-            else:
-                obj = mcp.aspect.to_obj()
-                content_type = obj.get("contentType")
-                if obj.get("value") and content_type == JSON_CONTENT_TYPE:
-                    # Undo double serialization.
-                    obj = json.loads(obj["value"])
-                elif content_type == JSON_PATCH_CONTENT_TYPE:
-                    raise NotImplementedError(
-                        "Patches are not supported for OpenAPI ingestion. Set the endpoint to RESTLI."
-                    )
-                aspect_value = pre_json_transform(obj)
-            return (
-                url,
-                [
-                    {
-                        "urn": mcp.entityUrn,
-                        mcp.aspectName: {
-                            "value": aspect_value,
-                            "systemMetadata": mcp.systemMetadata.to_obj()
-                            if mcp.systemMetadata
-                            else None,
-                        },
-                    }
-                ],
-            )
-        return None
+        Args:
+            mcp: The metadata change proposal
+            async_flag: Optional flag to override async behavior
+            async_default: Default async behavior if not specified
+
+        Returns:
+            An OpenApiRequest object or None if the MCP doesn't have required fields
+        """
+        return OpenApiRequest.from_mcp(
+            mcp=mcp,
+            gms_server=self._gms_server,
+            async_flag=async_flag,
+            async_default=async_default,
+        )
 
     def emit(
         self,
@@ -448,7 +428,9 @@ class DataHubRestEmitter(Closeable, Emitter):
         if self._openapi_ingestion:
             request = self._to_openapi_request(mcp, async_flag, async_default=False)
             if request:
-                response = self._emit_generic(request[0], payload=request[1])
+                response = self._emit_generic(
+                    request.url, payload=request.payload, method=request.method
+                )
 
                 if self._should_trace(async_flag, trace_flag):
                     trace_data = extract_trace_data(response) if response else None
@@ -503,31 +485,36 @@ class DataHubRestEmitter(Closeable, Emitter):
         trace_timeout: Optional[timedelta] = timedelta(seconds=3600),
     ) -> int:
         """
-        1. Grouping MCPs by their entity URL
+        1. Grouping MCPs by their HTTP method and entity URL
         2. Breaking down large batches into smaller chunks based on both:
            * Total byte size (INGEST_MAX_PAYLOAD_BYTES)
            * Maximum number of items (BATCH_INGEST_MAX_PAYLOAD_LENGTH)
 
         The Chunk class encapsulates both the items and their byte size tracking
-        Serializing the items only once with json.dumps(request[1]) and reusing that
+        Serializing the items only once with json.dumps(request.payload) and reusing that
         The chunking logic handles edge cases (always accepting at least one item per chunk)
         The joining logic is efficient with a simple string concatenation
 
         :param mcps: metadata change proposals to transmit
         :param async_flag: the mode
         :param trace_flag: whether to trace the requests
         :param trace_timeout: timeout for tracing
         :return: number of requests
         """
-        # group by entity url
-        batches: Dict[str, List[_Chunk]] = defaultdict(
+        # Group by entity URL and HTTP method
+        batches: Dict[Tuple[str, str], List[_Chunk]] = defaultdict(
             lambda: [_Chunk(items=[])]
         )  # Initialize with one empty Chunk
 
         for mcp in mcps:
             request = self._to_openapi_request(mcp, async_flag, async_default=True)
             if request:
-                current_chunk = batches[request[0]][-1]  # Get the last chunk
-                # Only serialize once
-                serialized_item = json.dumps(request[1][0])
+                # Create a composite key with both method and URL
+                key = (request.method, request.url)
+                current_chunk = batches[key][-1]  # Get the last chunk
+
+                # Only serialize once - we're serializing a single payload item
+                serialized_item = json.dumps(request.payload[0])
                 item_bytes = len(serialized_item.encode())
 
                 # If adding this item would exceed max_bytes, create a new chunk
@@ -537,15 +524,17 @@ class DataHubRestEmitter(Closeable, Emitter):
                     or len(current_chunk.items) >= BATCH_INGEST_MAX_PAYLOAD_LENGTH
                 ):
                     new_chunk = _Chunk(items=[])
-                    batches[request[0]].append(new_chunk)
+                    batches[key].append(new_chunk)
                     current_chunk = new_chunk
 
                 current_chunk.add_item(serialized_item)
 
         responses = []
-        for url, chunks in batches.items():
+        for (method, url), chunks in batches.items():
             for chunk in chunks:
-                response = self._emit_generic(url, payload=_Chunk.join(chunk))
+                response = self._emit_generic(
+                    url, payload=_Chunk.join(chunk), method=method
+                )
                 responses.append(response)
 
         if self._should_trace(async_flag, trace_flag, async_default=True):
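
The batching above now keys chunks on (method, url) rather than on the URL alone, so PATCH and POST payloads aimed at the same endpoint are never mixed into one request. A standalone sketch of the grouping idea, assuming objects shaped like the OpenApiRequest dataclass (this is illustrative, not the emitter's actual helper):

    from collections import defaultdict
    from typing import Dict, List, Tuple

    def group_payloads(requests_) -> Dict[Tuple[str, str], List[dict]]:
        # Composite key: identical URLs with different HTTP methods stay separate.
        batches: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
        for r in requests_:
            batches[(r.method, r.url)].append(r.payload[0])
        return batches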
@@ -618,11 +607,13 @@ class DataHubRestEmitter(Closeable, Emitter):
         payload = json.dumps(snapshot)
         self._emit_generic(url, payload)
 
-    def _emit_generic(self, url: str, payload: Union[str, Any]) -> requests.Response:
+    def _emit_generic(
+        self, url: str, payload: Union[str, Any], method: str = "POST"
+    ) -> requests.Response:
         if not isinstance(payload, str):
             payload = json.dumps(payload)
 
-        curl_command = make_curl_command(self._session, "POST", url, payload)
+        curl_command = make_curl_command(self._session, method, url, payload)
         payload_size = len(payload)
         if payload_size > INGEST_MAX_PAYLOAD_BYTES:
             # since we know total payload size here, we could simply avoid sending such payload at all and report a warning, with current approach we are going to cause whole ingestion to fail
@@ -635,7 +626,8 @@ class DataHubRestEmitter(Closeable, Emitter):
                 curl_command,
             )
         try:
-            response = self._session.post(url, data=payload)
+            method_func = getattr(self._session, method.lower())
+            response = method_func(url, data=payload) if payload else method_func(url)
             response.raise_for_status()
             return response
         except HTTPError as e:
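
Since requests.Session exposes one bound method per HTTP verb (session.post, session.patch, session.delete, ...), the getattr dispatch above routes each OpenApiRequest to the right verb without a conditional chain. A minimal sketch of the same pattern (URL and payload are placeholders):

    import requests

    session = requests.Session()
    method = "patch"
    method_func = getattr(session, method.lower())  # resolves to session.patch
    # method_func("http://localhost:8080/openapi/v3/entity/chart?async=false", data="[...]")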
@@ -158,7 +158,9 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
             ca_certificate_path=self.config.ca_certificate_path,
             client_certificate_path=self.config.client_certificate_path,
             disable_ssl_verification=self.config.disable_ssl_verification,
-            openapi_ingestion=DEFAULT_REST_EMITTER_ENDPOINT == RestSinkEndpoint.OPENAPI,
+            openapi_ingestion=self.config.openapi_ingestion
+            if self.config.openapi_ingestion is not None
+            else (DEFAULT_REST_EMITTER_ENDPOINT == RestSinkEndpoint.OPENAPI),
             default_trace_mode=DEFAULT_REST_TRACE_MODE == RestTraceMode.ENABLED,
         )
 
@@ -17,3 +17,4 @@ class DatahubClientConfig(ConfigModel):
     ca_certificate_path: Optional[str] = None
     client_certificate_path: Optional[str] = None
     disable_ssl_verification: bool = False
+    openapi_ingestion: Optional[bool] = None
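
With this optional flag, callers can opt in to OpenAPI ingestion explicitly; leaving it as None falls back to the DEFAULT_REST_EMITTER_ENDPOINT comparison shown in the DataHubGraph hunk above. A minimal sketch, assuming the usual import locations for DataHubGraph and DatahubClientConfig and a placeholder server URL:

    from datahub.ingestion.graph.client import DataHubGraph
    from datahub.ingestion.graph.config import DatahubClientConfig

    graph = DataHubGraph(
        config=DatahubClientConfig(
            server="http://localhost:8080",
            openapi_ingestion=True,  # None defers to DEFAULT_REST_EMITTER_ENDPOINT
        )
    )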
@@ -0,0 +1,240 @@
+import json
+
+import pytest
+
+from datahub.emitter.aspect import JSON_CONTENT_TYPE
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.emitter.request_helper import (
+    OpenApiRequest,
+)
+from datahub.emitter.serialization_helper import pre_json_transform
+from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeProposal
+from datahub.metadata.schema_classes import (
+    AuditStampClass,
+    ChangeAuditStampsClass,
+    ChangeTypeClass,
+    ChartInfoClass,
+    GenericAspectClass,
+    SystemMetadataClass,
+)
+from datahub.specific.chart import ChartPatchBuilder
+
+GMS_SERVER = "http://localhost:8080"
+CHART_INFO = ChartInfoClass(
+    title="Test Chart",
+    description="Test Description",
+    lastModified=ChangeAuditStampsClass(
+        created=AuditStampClass(time=0, actor="urn:li:corpuser:datahub")
+    ),
+)
+
+
+def test_from_mcp_none_no_aspect():
+    """Test that from_mcp returns None when aspect is missing"""
+    mcp = MetadataChangeProposalWrapper(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        changeType=ChangeTypeClass.UPSERT,
+    )
+
+    request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
+    assert request is None
+
+
+def test_from_mcp_upsert():
+    """Test creating an OpenApiRequest from an UPSERT MCP"""
+    mcp = MetadataChangeProposalWrapper(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        aspect=CHART_INFO,
+        changeType=ChangeTypeClass.UPSERT,
+    )
+
+    request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
+
+    assert request is not None
+    assert request.method == "post"
+    assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
+    assert len(request.payload) == 1
+    assert request.payload[0]["urn"] == "urn:li:chart:(test,test)"
+    assert "chartInfo" in request.payload[0]
+    assert request.payload[0]["chartInfo"]["value"]["title"] == "Test Chart"
+    assert request.payload[0]["chartInfo"]["value"]["description"] == "Test Description"
+    assert request.payload[0]["chartInfo"]["systemMetadata"] is None
+
+
+def test_from_mcp_upsert_with_system_metadata():
+    """Test creating an OpenApiRequest from an UPSERT MCP with system metadata"""
+    mcp = MetadataChangeProposalWrapper(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        aspect=CHART_INFO,
+        changeType=ChangeTypeClass.UPSERT,
+        systemMetadata=SystemMetadataClass(runId="test-run-id"),
+    )
+
+    request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
+
+    assert request is not None
+    assert request.method == "post"
+    assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
+    assert len(request.payload) == 1
+    assert "chartInfo" in request.payload[0]
+    assert request.payload[0]["chartInfo"]["systemMetadata"]["runId"] == "test-run-id"
+
+
+def test_from_mcp_upsert_without_wrapper():
+    """Test creating an OpenApiRequest from an UPSERT MCP without wrapper"""
+    mcp_wrapper = MetadataChangeProposal(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        aspect=GenericAspectClass(
+            value=json.dumps(pre_json_transform(CHART_INFO.to_obj())).encode(),
+            contentType=JSON_CONTENT_TYPE,
+        ),
+        changeType=ChangeTypeClass.UPSERT,
+    )
+
+    request = OpenApiRequest.from_mcp(mcp_wrapper, GMS_SERVER)
+
+    assert request is not None
+    assert request.method == "post"
+    assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
+    assert len(request.payload) == 1
+    assert request.payload[0]["urn"] == "urn:li:chart:(test,test)"
+    assert "chartInfo" in request.payload[0]
+    assert request.payload[0]["chartInfo"]["value"]["title"] == "Test Chart"
+    assert request.payload[0]["chartInfo"]["value"]["description"] == "Test Description"
+
+
+def test_from_mcp_delete():
+    """Test creating an OpenApiRequest from a DELETE MCP"""
+    mcp = MetadataChangeProposalWrapper(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        changeType=ChangeTypeClass.DELETE,
+        aspect=None,
+    )
+
+    request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
+
+    assert request is not None
+    assert request.method == "delete"
+    assert (
+        request.url == f"{GMS_SERVER}/openapi/v3/entity/chart/urn:li:chart:(test,test)"
+    )
+    assert len(request.payload) == 0
+
+
+def test_from_mcp_patch():
+    """Test creating an OpenApiRequest from a PATCH MCP"""
+    patch_data = [{"op": "add", "path": "/title", "value": "Updated Title"}]
+    mcp = next(
+        iter(
+            ChartPatchBuilder("urn:li:chart:(test,test)")
+            .set_title("Updated Title")
+            .build()
+        )
+    )
+
+    request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
+
+    assert request is not None
+    assert request.method == "patch"
+    assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
+    assert len(request.payload) == 1
+    assert request.payload[0]["urn"] == "urn:li:chart:(test,test)"
+    assert "chartInfo" in request.payload[0]
+    assert request.payload[0]["chartInfo"]["value"]["patch"] == patch_data
+
+
+def test_patch_unsupported_operation():
+    """Test that PATCH with non-JSON_PATCH_CONTENT_TYPE raises NotImplementedError"""
+    mcp = next(
+        iter(
+            ChartPatchBuilder("urn:li:chart:(test,test)")
+            .set_title("Updated Title")
+            .build()
+        )
+    )
+    if mcp.aspect:
+        mcp.aspect.contentType = "application/json"  # Not JSON_PATCH_CONTENT_TYPE
+
+    with pytest.raises(NotImplementedError) as excinfo:
+        OpenApiRequest.from_mcp(mcp, GMS_SERVER)
+
+    assert "only supports content type application/json-patch+json" in str(
+        excinfo.value
+    )
+
+
+def test_upsert_incompatible_content_type():
+    """Test that UPSERT with JSON_PATCH_CONTENT_TYPE raises NotImplementedError"""
+    mcp = next(
+        iter(
+            ChartPatchBuilder("urn:li:chart:(test,test)")
+            .set_title("Updated Title")
+            .build()
+        )
+    )
+    mcp.changeType = ChangeTypeClass.UPSERT
+
+    with pytest.raises(NotImplementedError) as excinfo:
+        OpenApiRequest.from_mcp(mcp, GMS_SERVER)
+
+    assert "does not support patch" in str(excinfo.value)
+
+
+def test_from_mcp_async_flag():
+    """Test creating an OpenApiRequest with async flag specified"""
+    mcp = MetadataChangeProposalWrapper(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        aspect=CHART_INFO,
+        changeType=ChangeTypeClass.UPSERT,
+    )
+
+    request = OpenApiRequest.from_mcp(mcp, GMS_SERVER, async_flag=True)
+
+    assert request is not None
+    assert "async=true" in request.url
+
+
+def test_from_mcp_async_default():
+    """Test creating an OpenApiRequest with async_default=True"""
+    mcp = MetadataChangeProposalWrapper(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        aspect=CHART_INFO,
+        changeType=ChangeTypeClass.UPSERT,
+    )
+
+    request = OpenApiRequest.from_mcp(mcp, GMS_SERVER, async_default=True)
+
+    assert request is not None
+    assert "async=true" in request.url
+
+
+def test_from_mcp_async_flag_override():
+    """Test that async_flag overrides async_default"""
+    mcp = MetadataChangeProposalWrapper(
+        entityType="chart",
+        entityUrn="urn:li:chart:(test,test)",
+        aspectName="chartInfo",
+        aspect=CHART_INFO,
+        changeType=ChangeTypeClass.UPSERT,
+    )
+
+    request = OpenApiRequest.from_mcp(
+        mcp, GMS_SERVER, async_flag=False, async_default=True
+    )
+
+    assert request is not None
+    assert "async=false" in request.url
@@ -24,6 +24,10 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
     DatasetProfile,
     DatasetProperties,
 )
+from datahub.metadata.schema_classes import (
+    ChangeTypeClass,
+)
+from datahub.specific.dataset import DatasetPatchBuilder
 
 MOCK_GMS_ENDPOINT = "http://fakegmshost:8080"
 
@@ -115,6 +119,7 @@ def test_openapi_emitter_emit(openapi_emitter):
                 },
             }
         ],
+        method="post",
     )
 
@@ -750,3 +755,132 @@ def test_await_status_logging(openapi_emitter):
         with pytest.raises(TraceValidationError):
             openapi_emitter._await_status([trace], timedelta(seconds=10))
         mock_error.assert_called_once()
+
+
+def test_openapi_emitter_same_url_different_methods(openapi_emitter):
+    """Test handling of requests with same URL but different HTTP methods"""
+    with patch(
+        "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
+    ) as mock_emit:
+        items = [
+            # POST requests for updating
+            MetadataChangeProposalWrapper(
+                entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,UpdateMe{i},PROD)",
+                entityType="dataset",
+                aspectName="datasetProperties",
+                changeType=ChangeTypeClass.UPSERT,
+                aspect=DatasetProperties(name=f"Updated Dataset {i}"),
+            )
+            for i in range(2)
+        ] + [
+            # PATCH requests for patching
+            next(
+                iter(
+                    DatasetPatchBuilder(
+                        f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
+                    )
+                    .set_qualified_name(f"PatchMe{i}")
+                    .build()
+                )
+            )
+            for i in range(2)
+        ]
+
+        # Run the test
+        result = openapi_emitter.emit_mcps(items)
+
+        # Verify that we made 2 calls (one for each HTTP method)
+        assert result == 2
+        assert mock_emit.call_count == 2
+
+        # Check that calls were made with different methods but the same URL
+        calls = {}
+        for call in mock_emit.call_args_list:
+            method = call[1]["method"]
+            url = call[0][0]
+            calls[(method, url)] = call
+
+        assert (
+            "post",
+            f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true",
+        ) in calls
+        assert (
+            "patch",
+            f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true",
+        ) in calls
+
+
+def test_openapi_emitter_mixed_method_chunking(openapi_emitter):
+    """Test that chunking works correctly across different HTTP methods"""
+    with patch(
+        "datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
+    ) as mock_emit, patch(
+        "datahub.emitter.rest_emitter.BATCH_INGEST_MAX_PAYLOAD_LENGTH", 2
+    ):
+        # Create more items than the chunk size for each method
+        items = [
+            # POST items (4 items, should create 2 chunks)
+            MetadataChangeProposalWrapper(
+                entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)",
+                entityType="dataset",
+                aspectName="datasetProfile",
+                changeType=ChangeTypeClass.UPSERT,
+                aspect=DatasetProfile(rowCount=i, columnCount=15, timestampMillis=0),
+            )
+            for i in range(4)
+        ] + [
+            # PATCH items (3 items, should create 2 chunks)
+            next(
+                iter(
+                    DatasetPatchBuilder(
+                        f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
+                    )
+                    .set_qualified_name(f"PatchMe{i}")
+                    .build()
+                )
+            )
+            for i in range(3)
+        ]
+
+        # Run the test with a smaller chunk size to force multiple chunks
+        result = openapi_emitter.emit_mcps(items)
+
+        # Should have 4 chunks total:
+        # - 2 chunks for POST (4 items with max 2 per chunk)
+        # - 2 chunks for PATCH (3 items with max 2 per chunk)
+        assert result == 4
+        assert mock_emit.call_count == 4
+
+        # Count the calls by method and verify chunking
+        post_calls = [
+            call for call in mock_emit.call_args_list if call[1]["method"] == "post"
+        ]
+        patch_calls = [
+            call for call in mock_emit.call_args_list if call[1]["method"] == "patch"
+        ]
+
+        assert len(post_calls) == 2  # 2 chunks for POST
+        assert len(patch_calls) == 2  # 2 chunks for PATCH
+
+        # Verify first chunks have max size and last chunks have remainders
+        post_payloads = [json.loads(call[1]["payload"]) for call in post_calls]
+        patch_payloads = [json.loads(call[1]["payload"]) for call in patch_calls]
+
+        assert len(post_payloads[0]) == 2
+        assert len(post_payloads[1]) == 2
+        assert len(patch_payloads[0]) == 2
+        assert len(patch_payloads[1]) == 1
+
+        # Verify all post calls are to the dataset endpoint
+        for call in post_calls:
+            assert (
+                call[0][0]
+                == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
+            )
+
+        # Verify all patch calls are to the dataset endpoint
+        for call in patch_calls:
+            assert (
+                call[0][0]
+                == f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
+            )
@@ -29,11 +29,12 @@ def auth_session():
     auth_session.destroy()
 
 
-def build_graph_client(auth_session):
+def build_graph_client(auth_session, openapi_ingestion=False):
     print(auth_session.cookies)
     graph: DataHubGraph = DataHubGraph(
         config=DatahubClientConfig(
-            server=auth_session.gms_url(), token=auth_session.gms_token()
+            server=auth_session.gms_url(), token=auth_session.gms_token(),
+            openapi_ingestion=openapi_ingestion
         )
     )
     return graph
@@ -44,6 +45,11 @@ def graph_client(auth_session) -> DataHubGraph:
     return build_graph_client(auth_session)
 
 
+@pytest.fixture(scope="session")
+def openapi_graph_client(auth_session) -> DataHubGraph:
+    return build_graph_client(auth_session, openapi_ingestion=True)
+
+
 def pytest_sessionfinish(session, exitstatus):
     """whole test run finishes."""
     send_message(exitstatus)
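
The patch integration tests below reuse this fixture pair by parametrizing on the fixture name and resolving it at runtime, so every scenario runs against both the default client and the OpenAPI-backed one:

    @pytest.mark.parametrize(
        "client_fixture_name", ["graph_client", "openapi_graph_client"]
    )
    def test_example(request, client_fixture_name):
        graph_client = request.getfixturevalue(client_fixture_name)
        # ... exercise graph_client exactly as before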
@@ -1,6 +1,8 @@
 import time
 import uuid
 
+import pytest
+
 import datahub.metadata.schema_classes as models
 from datahub.emitter.mce_builder import make_data_job_urn, make_dataset_urn
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
@@ -30,27 +32,43 @@ def _make_test_datajob_urn(
 
 # Common Aspect Patch Tests
 # Ownership
-def test_datajob_ownership_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_datajob_ownership_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     datajob_urn = _make_test_datajob_urn()
     helper_test_ownership_patch(graph_client, datajob_urn, DataJobPatchBuilder)
 
 
 # Tags
-def test_datajob_tags_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_datajob_tags_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     helper_test_dataset_tags_patch(
         graph_client, _make_test_datajob_urn(), DataJobPatchBuilder
     )
 
 
 # Terms
-def test_dataset_terms_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_dataset_terms_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     helper_test_entity_terms_patch(
         graph_client, _make_test_datajob_urn(), DataJobPatchBuilder
     )
 
 
 # Custom Properties
-def test_custom_properties_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_custom_properties_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     orig_datajob_info = DataJobInfoClass(name="test_name", type="TestJobType")
     helper_test_custom_properties_patch(
         graph_client,
@@ -63,7 +81,11 @@ def test_custom_properties_patch(graph_client):
 
 # Specific Aspect Patch Tests
 # Input/Output
-def test_datajob_inputoutput_dataset_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_datajob_inputoutput_dataset_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     datajob_urn = _make_test_datajob_urn()
 
     other_dataset_urn = make_dataset_urn(
@@ -139,7 +161,11 @@ def test_datajob_inputoutput_dataset_patch(graph_client):
     )
 
 
-def test_datajob_multiple_inputoutput_dataset_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_datajob_multiple_inputoutput_dataset_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     """Test creating a data job with multiple input and output datasets and verifying the aspects."""
     # Create the data job
     datajob_urn = "urn:li:dataJob:(urn:li:dataFlow:(airflow,training,default),training)"
@@ -1,6 +1,8 @@
 import uuid
 from typing import Dict, Optional
 
+import pytest
+
 from datahub.emitter.mce_builder import make_dataset_urn, make_tag_urn, make_term_urn
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.graph.client import DataHubGraph
@@ -25,7 +27,11 @@ from tests.patch.common_patch_tests import (
 
 # Common Aspect Patch Tests
 # Ownership
-def test_dataset_ownership_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_dataset_ownership_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     dataset_urn = make_dataset_urn(
         platform="hive", name=f"SampleHiveDataset{uuid.uuid4()}", env="PROD"
     )
@@ -33,7 +39,11 @@ def test_dataset_ownership_patch(graph_client):
 
 
 # Tags
-def test_dataset_tags_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_dataset_tags_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     dataset_urn = make_dataset_urn(
         platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
     )
@@ -41,14 +51,22 @@ def test_dataset_tags_patch(graph_client):
 
 
 # Terms
-def test_dataset_terms_patch(graph_client):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_dataset_terms_patch(request, client_fixture_name):
+    graph_client = request.getfixturevalue(client_fixture_name)
     dataset_urn = make_dataset_urn(
         platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
     )
     helper_test_entity_terms_patch(graph_client, dataset_urn, DatasetPatchBuilder)
 
 
-def test_dataset_upstream_lineage_patch(graph_client: DataHubGraph):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_dataset_upstream_lineage_patch(request, client_fixture_name: DataHubGraph):
+    graph_client = request.getfixturevalue(client_fixture_name)
     dataset_urn = make_dataset_urn(
         platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
     )
@@ -137,7 +155,11 @@ def get_field_info(
     return None
 
 
-def test_field_terms_patch(graph_client: DataHubGraph):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_field_terms_patch(request, client_fixture_name: DataHubGraph):
+    graph_client = request.getfixturevalue(client_fixture_name)
     dataset_urn = make_dataset_urn(
         platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
     )
@ -195,7 +217,11 @@ def test_field_terms_patch(graph_client: DataHubGraph):
|
||||
assert len(field_info.glossaryTerms.terms) == 0
|
||||
|
||||
|
||||
def test_field_tags_patch(graph_client: DataHubGraph):
|
||||
@pytest.mark.parametrize(
|
||||
"client_fixture_name", ["graph_client", "openapi_graph_client"]
|
||||
)
|
||||
def test_field_tags_patch(request, client_fixture_name: DataHubGraph):
|
||||
graph_client = request.getfixturevalue(client_fixture_name)
|
||||
dataset_urn = make_dataset_urn(
|
||||
platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
|
||||
)
|
||||
@@ -286,7 +312,11 @@ def get_custom_properties(
     return dataset_properties.customProperties
 
 
-def test_custom_properties_patch(graph_client: DataHubGraph):
+@pytest.mark.parametrize(
+    "client_fixture_name", ["graph_client", "openapi_graph_client"]
+)
+def test_custom_properties_patch(request, client_fixture_name: DataHubGraph):
+    graph_client = request.getfixturevalue(client_fixture_name)
     dataset_urn = make_dataset_urn(
         platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
     )