feat(ingestion-openapi): patch support (#13282)

Co-authored-by: Sergio Gómez Villamor <sgomezvillamor@gmail.com>
david-leifker 2025-04-25 13:54:28 -05:00 committed by GitHub
parent c37eee18e6
commit 9b0634805a
9 changed files with 604 additions and 69 deletions

View File

@@ -1,9 +1,19 @@
import json
import shlex
from typing import List, Optional, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import requests
from requests.auth import HTTPBasicAuth
from datahub.emitter.aspect import JSON_CONTENT_TYPE, JSON_PATCH_CONTENT_TYPE
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.serialization_helper import pre_json_transform
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
MetadataChangeProposal,
)
from datahub.metadata.schema_classes import ChangeTypeClass
def _decode_bytes(value: Union[str, bytes]) -> str:
"""Decode bytes to string, if necessary."""
@@ -40,3 +50,97 @@ def make_curl_command(
fragments.append(url)
return shlex.join(fragments)
@dataclass
class OpenApiRequest:
"""Represents an OpenAPI request for entity operations."""
method: str
url: str
payload: List[Dict[str, Any]]
@classmethod
def from_mcp(
cls,
mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
gms_server: str,
async_flag: Optional[bool] = None,
async_default: bool = False,
) -> Optional["OpenApiRequest"]:
"""Factory method to create an OpenApiRequest from a MetadataChangeProposal."""
if not mcp.aspectName or (
mcp.changeType != ChangeTypeClass.DELETE and not mcp.aspect
):
return None
resolved_async_flag = async_flag if async_flag is not None else async_default
method = "post"
url = f"{gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"
payload = []
if mcp.changeType == ChangeTypeClass.DELETE:
method = "delete"
url = f"{gms_server}/openapi/v3/entity/{mcp.entityType}/{mcp.entityUrn}"
else:
if mcp.aspect:
if mcp.changeType == ChangeTypeClass.PATCH:
method = "patch"
obj = mcp.aspect.to_obj()
content_type = obj.get("contentType")
if obj.get("value") and content_type == JSON_PATCH_CONTENT_TYPE:
# Undo double serialization.
obj = json.loads(obj["value"])
patch_value = obj
else:
raise NotImplementedError(
f"ChangeType {mcp.changeType} only supports context type {JSON_PATCH_CONTENT_TYPE}, found {content_type}."
)
if isinstance(patch_value, list):
patch_value = {"patch": patch_value}
payload = [
{
"urn": mcp.entityUrn,
mcp.aspectName: {
"value": patch_value,
"systemMetadata": mcp.systemMetadata.to_obj()
if mcp.systemMetadata
else None,
},
}
]
else:
if isinstance(mcp, MetadataChangeProposalWrapper):
aspect_value = pre_json_transform(
mcp.to_obj(simplified_structure=True)
)["aspect"]["json"]
else:
obj = mcp.aspect.to_obj()
content_type = obj.get("contentType")
if obj.get("value") and content_type == JSON_CONTENT_TYPE:
# Undo double serialization.
obj = json.loads(obj["value"])
elif content_type == JSON_PATCH_CONTENT_TYPE:
raise NotImplementedError(
f"ChangeType {mcp.changeType} does not support patch."
)
aspect_value = pre_json_transform(obj)
payload = [
{
"urn": mcp.entityUrn,
mcp.aspectName: {
"value": aspect_value,
"systemMetadata": mcp.systemMetadata.to_obj()
if mcp.systemMetadata
else None,
},
}
]
else:
raise ValueError(f"ChangeType {mcp.changeType} requires a value.")
return cls(method=method, url=url, payload=payload)
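
For orientation, a minimal sketch of the new request path, mirroring the unit tests added below (the chart URN and server address are illustrative):

from datahub.emitter.request_helper import OpenApiRequest
from datahub.specific.chart import ChartPatchBuilder

# Build a PATCH MCP with the existing patch builder, then convert it into
# an OpenAPI v3 request; from_mcp picks the HTTP verb, URL, and payload shape.
mcp = next(
    iter(ChartPatchBuilder("urn:li:chart:(test,test)").set_title("New Title").build())
)
request = OpenApiRequest.from_mcp(mcp, "http://localhost:8080")
if request:
    assert request.method == "patch"
    # request.url -> http://localhost:8080/openapi/v3/entity/chart?async=false
    # request.payload[0]["chartInfo"]["value"]["patch"] holds the JSON Patch ops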

View File

@@ -41,10 +41,9 @@ from datahub.configuration.common import (
TraceTimeoutError,
TraceValidationError,
)
from datahub.emitter.aspect import JSON_CONTENT_TYPE, JSON_PATCH_CONTENT_TYPE
from datahub.emitter.generic_emitter import Emitter
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.request_helper import make_curl_command
from datahub.emitter.request_helper import OpenApiRequest, make_curl_command
from datahub.emitter.response_helper import (
TraceData,
extract_trace_data,
@@ -348,43 +347,24 @@ class DataHubRestEmitter(Closeable, Emitter):
mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
async_flag: Optional[bool] = None,
async_default: bool = False,
) -> Optional[Tuple[str, List[Dict[str, Any]]]]:
if mcp.aspect and mcp.aspectName:
resolved_async_flag = (
async_flag if async_flag is not None else async_default
)
url = f"{self._gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"
) -> Optional[OpenApiRequest]:
"""
Convert a MetadataChangeProposal to an OpenAPI request format.
if isinstance(mcp, MetadataChangeProposalWrapper):
aspect_value = pre_json_transform(
mcp.to_obj(simplified_structure=True)
)["aspect"]["json"]
else:
obj = mcp.aspect.to_obj()
content_type = obj.get("contentType")
if obj.get("value") and content_type == JSON_CONTENT_TYPE:
# Undo double serialization.
obj = json.loads(obj["value"])
elif content_type == JSON_PATCH_CONTENT_TYPE:
raise NotImplementedError(
"Patches are not supported for OpenAPI ingestion. Set the endpoint to RESTLI."
)
aspect_value = pre_json_transform(obj)
return (
url,
[
{
"urn": mcp.entityUrn,
mcp.aspectName: {
"value": aspect_value,
"systemMetadata": mcp.systemMetadata.to_obj()
if mcp.systemMetadata
else None,
},
}
],
)
return None
Args:
mcp: The metadata change proposal
async_flag: Optional flag to override async behavior
async_default: Default async behavior if not specified
Returns:
An OpenApiRequest object or None if the MCP doesn't have required fields
"""
return OpenApiRequest.from_mcp(
mcp=mcp,
gms_server=self._gms_server,
async_flag=async_flag,
async_default=async_default,
)
def emit(
self,
@@ -448,7 +428,9 @@ class DataHubRestEmitter(Closeable, Emitter):
if self._openapi_ingestion:
request = self._to_openapi_request(mcp, async_flag, async_default=False)
if request:
response = self._emit_generic(request[0], payload=request[1])
response = self._emit_generic(
request.url, payload=request.payload, method=request.method
)
if self._should_trace(async_flag, trace_flag):
trace_data = extract_trace_data(response) if response else None
@@ -503,31 +485,36 @@ class DataHubRestEmitter(Closeable, Emitter):
trace_timeout: Optional[timedelta] = timedelta(seconds=3600),
) -> int:
"""
1. Grouping MCPs by their entity URL
1. Grouping MCPs by their HTTP method and entity URL
2. Breaking down large batches into smaller chunks based on both:
* Total byte size (INGEST_MAX_PAYLOAD_BYTES)
* Maximum number of items (BATCH_INGEST_MAX_PAYLOAD_LENGTH)
The Chunk class encapsulates both the items and their byte size tracking
Serializing the items only once with json.dumps(request[1]) and reusing that
Serializing the items only once with json.dumps(request.payload) and reusing that
The chunking logic handles edge cases (always accepting at least one item per chunk)
The joining logic is efficient with a simple string concatenation
:param mcps: metadata change proposals to transmit
:param async_flag: the mode
:param trace_flag: whether to trace the requests
:param trace_timeout: timeout for tracing
:return: number of requests
"""
# group by entity url
batches: Dict[str, List[_Chunk]] = defaultdict(
# Group by entity URL and HTTP method
batches: Dict[Tuple[str, str], List[_Chunk]] = defaultdict(
lambda: [_Chunk(items=[])]
) # Initialize with one empty Chunk
for mcp in mcps:
request = self._to_openapi_request(mcp, async_flag, async_default=True)
if request:
current_chunk = batches[request[0]][-1] # Get the last chunk
# Only serialize once
serialized_item = json.dumps(request[1][0])
# Create a composite key with both method and URL
key = (request.method, request.url)
current_chunk = batches[key][-1] # Get the last chunk
# Only serialize once - we're serializing a single payload item
serialized_item = json.dumps(request.payload[0])
item_bytes = len(serialized_item.encode())
# If adding this item would exceed max_bytes, create a new chunk
@@ -537,15 +524,17 @@ class DataHubRestEmitter(Closeable, Emitter):
or len(current_chunk.items) >= BATCH_INGEST_MAX_PAYLOAD_LENGTH
):
new_chunk = _Chunk(items=[])
batches[request[0]].append(new_chunk)
batches[key].append(new_chunk)
current_chunk = new_chunk
current_chunk.add_item(serialized_item)
responses = []
for url, chunks in batches.items():
for (method, url), chunks in batches.items():
for chunk in chunks:
response = self._emit_generic(url, payload=_Chunk.join(chunk))
response = self._emit_generic(
url, payload=_Chunk.join(chunk), method=method
)
responses.append(response)
if self._should_trace(async_flag, trace_flag, async_default=True):
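
A standalone sketch of the grouping rule above: batches are keyed on the (method, url) pair, so PATCH and POST traffic to the same endpoint never shares a chunk (URLs are illustrative):

from collections import defaultdict
from typing import Dict, List, Tuple

batches: Dict[Tuple[str, str], List[str]] = defaultdict(list)
for method, url, item in [
    ("post", "http://gms/openapi/v3/entity/dataset?async=true", "{...}"),
    ("patch", "http://gms/openapi/v3/entity/dataset?async=true", "{...}"),
]:
    batches[(method, url)].append(item)  # same URL, but two distinct keys
assert len(batches) == 2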
@@ -618,11 +607,13 @@ class DataHubRestEmitter(Closeable, Emitter):
payload = json.dumps(snapshot)
self._emit_generic(url, payload)
def _emit_generic(self, url: str, payload: Union[str, Any]) -> requests.Response:
def _emit_generic(
self, url: str, payload: Union[str, Any], method: str = "POST"
) -> requests.Response:
if not isinstance(payload, str):
payload = json.dumps(payload)
curl_command = make_curl_command(self._session, "POST", url, payload)
curl_command = make_curl_command(self._session, method, url, payload)
payload_size = len(payload)
if payload_size > INGEST_MAX_PAYLOAD_BYTES:
# Since we know the total payload size here, we could avoid sending such a payload at all and report a warning; with the current approach, the whole ingestion fails
@@ -635,7 +626,8 @@ class DataHubRestEmitter(Closeable, Emitter):
curl_command,
)
try:
response = self._session.post(url, data=payload)
method_func = getattr(self._session, method.lower())
response = method_func(url, data=payload) if payload else method_func(url)
response.raise_for_status()
return response
except HTTPError as e:
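
The dynamic dispatch above relies on requests.Session exposing one bound method per HTTP verb; a minimal sketch of the same pattern (the endpoint is illustrative, and the live call is left commented so the snippet runs offline):

import requests

session = requests.Session()
for method in ("post", "patch", "delete"):
    # Resolve the verb string to the matching bound method on the session,
    # exactly as _emit_generic does with getattr above.
    method_func = getattr(session, method)
    assert callable(method_func)
# response = getattr(session, "patch")("http://localhost:8080/openapi/v3/entity/chart", data="[]")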

View File

@@ -158,7 +158,9 @@ class DataHubGraph(DatahubRestEmitter, EntityVersioningAPI):
ca_certificate_path=self.config.ca_certificate_path,
client_certificate_path=self.config.client_certificate_path,
disable_ssl_verification=self.config.disable_ssl_verification,
openapi_ingestion=DEFAULT_REST_EMITTER_ENDPOINT == RestSinkEndpoint.OPENAPI,
openapi_ingestion=self.config.openapi_ingestion
if self.config.openapi_ingestion is not None
else (DEFAULT_REST_EMITTER_ENDPOINT == RestSinkEndpoint.OPENAPI),
default_trace_mode=DEFAULT_REST_TRACE_MODE == RestTraceMode.ENABLED,
)

View File

@@ -17,3 +17,4 @@ class DatahubClientConfig(ConfigModel):
ca_certificate_path: Optional[str] = None
client_certificate_path: Optional[str] = None
disable_ssl_verification: bool = False
openapi_ingestion: Optional[bool] = None
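
A minimal sketch of opting in to OpenAPI ingestion through the new config flag, mirroring the integration-test fixture below (the server address is illustrative, and the import path for DatahubClientConfig is an assumption; it has historically been importable from datahub.ingestion.graph.client):

from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

graph = DataHubGraph(
    config=DatahubClientConfig(
        server="http://localhost:8080",  # illustrative GMS address
        openapi_ingestion=True,  # explicit opt-in; None defers to the env-driven default
    )
)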

View File

@@ -0,0 +1,240 @@
import json
import pytest
from datahub.emitter.aspect import JSON_CONTENT_TYPE
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.request_helper import (
OpenApiRequest,
)
from datahub.emitter.serialization_helper import pre_json_transform
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeProposal
from datahub.metadata.schema_classes import (
AuditStampClass,
ChangeAuditStampsClass,
ChangeTypeClass,
ChartInfoClass,
GenericAspectClass,
SystemMetadataClass,
)
from datahub.specific.chart import ChartPatchBuilder
GMS_SERVER = "http://localhost:8080"
CHART_INFO = ChartInfoClass(
title="Test Chart",
description="Test Description",
lastModified=ChangeAuditStampsClass(
created=AuditStampClass(time=0, actor="urn:li:corpuser:datahub")
),
)
def test_from_mcp_none_no_aspect():
"""Test that from_mcp returns None when aspect is missing"""
mcp = MetadataChangeProposalWrapper(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
changeType=ChangeTypeClass.UPSERT,
)
request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
assert request is None
def test_from_mcp_upsert():
"""Test creating an OpenApiRequest from an UPSERT MCP"""
mcp = MetadataChangeProposalWrapper(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
aspect=CHART_INFO,
changeType=ChangeTypeClass.UPSERT,
)
request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
assert request is not None
assert request.method == "post"
assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
assert len(request.payload) == 1
assert request.payload[0]["urn"] == "urn:li:chart:(test,test)"
assert "chartInfo" in request.payload[0]
assert request.payload[0]["chartInfo"]["value"]["title"] == "Test Chart"
assert request.payload[0]["chartInfo"]["value"]["description"] == "Test Description"
assert request.payload[0]["chartInfo"]["systemMetadata"] is None
def test_from_mcp_upsert_with_system_metadata():
"""Test creating an OpenApiRequest from an UPSERT MCP with system metadata"""
mcp = MetadataChangeProposalWrapper(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
aspect=CHART_INFO,
changeType=ChangeTypeClass.UPSERT,
systemMetadata=SystemMetadataClass(runId="test-run-id"),
)
request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
assert request is not None
assert request.method == "post"
assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
assert len(request.payload) == 1
assert "chartInfo" in request.payload[0]
assert request.payload[0]["chartInfo"]["systemMetadata"]["runId"] == "test-run-id"
def test_from_mcp_upsert_without_wrapper():
"""Test creating an OpenApiRequest from an UPSERT MCP without wrapper"""
mcp_wrapper = MetadataChangeProposal(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
aspect=GenericAspectClass(
value=json.dumps(pre_json_transform(CHART_INFO.to_obj())).encode(),
contentType=JSON_CONTENT_TYPE,
),
changeType=ChangeTypeClass.UPSERT,
)
request = OpenApiRequest.from_mcp(mcp_wrapper, GMS_SERVER)
assert request is not None
assert request.method == "post"
assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
assert len(request.payload) == 1
assert request.payload[0]["urn"] == "urn:li:chart:(test,test)"
assert "chartInfo" in request.payload[0]
assert request.payload[0]["chartInfo"]["value"]["title"] == "Test Chart"
assert request.payload[0]["chartInfo"]["value"]["description"] == "Test Description"
def test_from_mcp_delete():
"""Test creating an OpenApiRequest from a DELETE MCP"""
mcp = MetadataChangeProposalWrapper(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
changeType=ChangeTypeClass.DELETE,
aspect=None,
)
request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
assert request is not None
assert request.method == "delete"
assert (
request.url == f"{GMS_SERVER}/openapi/v3/entity/chart/urn:li:chart:(test,test)"
)
assert len(request.payload) == 0
def test_from_mcp_patch():
"""Test creating an OpenApiRequest from a PATCH MCP"""
patch_data = [{"op": "add", "path": "/title", "value": "Updated Title"}]
mcp = next(
iter(
ChartPatchBuilder("urn:li:chart:(test,test)")
.set_title("Updated Title")
.build()
)
)
request = OpenApiRequest.from_mcp(mcp, GMS_SERVER)
assert request is not None
assert request.method == "patch"
assert request.url == f"{GMS_SERVER}/openapi/v3/entity/chart?async=false"
assert len(request.payload) == 1
assert request.payload[0]["urn"] == "urn:li:chart:(test,test)"
assert "chartInfo" in request.payload[0]
assert request.payload[0]["chartInfo"]["value"]["patch"] == patch_data
def test_patch_unsupported_operation():
"""Test that PATCH with non-JSON_PATCH_CONTENT_TYPE raises NotImplementedError"""
mcp = next(
iter(
ChartPatchBuilder("urn:li:chart:(test,test)")
.set_title("Updated Title")
.build()
)
)
if mcp.aspect:
mcp.aspect.contentType = "application/json" # Not JSON_PATCH_CONTENT_TYPE
with pytest.raises(NotImplementedError) as excinfo:
OpenApiRequest.from_mcp(mcp, GMS_SERVER)
assert "only supports context type application/json-patch+json" in str(
excinfo.value
)
def test_upsert_incompatible_content_type():
"""Test that UPSERT with JSON_PATCH_CONTENT_TYPE raises NotImplementedError"""
mcp = next(
iter(
ChartPatchBuilder("urn:li:chart:(test,test)")
.set_title("Updated Title")
.build()
)
)
mcp.changeType = ChangeTypeClass.UPSERT
with pytest.raises(NotImplementedError) as excinfo:
OpenApiRequest.from_mcp(mcp, GMS_SERVER)
assert "does not support patch" in str(excinfo.value)
def test_from_mcp_async_flag():
"""Test creating an OpenApiRequest with async flag specified"""
mcp = MetadataChangeProposalWrapper(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
aspect=CHART_INFO,
changeType=ChangeTypeClass.UPSERT,
)
request = OpenApiRequest.from_mcp(mcp, GMS_SERVER, async_flag=True)
assert request is not None
assert "async=true" in request.url
def test_from_mcp_async_default():
"""Test creating an OpenApiRequest with async_default=True"""
mcp = MetadataChangeProposalWrapper(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
aspect=CHART_INFO,
changeType=ChangeTypeClass.UPSERT,
)
request = OpenApiRequest.from_mcp(mcp, GMS_SERVER, async_default=True)
assert request is not None
assert "async=true" in request.url
def test_from_mcp_async_flag_override():
"""Test that async_flag overrides async_default"""
mcp = MetadataChangeProposalWrapper(
entityType="chart",
entityUrn="urn:li:chart:(test,test)",
aspectName="chartInfo",
aspect=CHART_INFO,
changeType=ChangeTypeClass.UPSERT,
)
request = OpenApiRequest.from_mcp(
mcp, GMS_SERVER, async_flag=False, async_default=True
)
assert request is not None
assert "async=false" in request.url

View File

@@ -24,6 +24,10 @@ from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
DatasetProfile,
DatasetProperties,
)
from datahub.metadata.schema_classes import (
ChangeTypeClass,
)
from datahub.specific.dataset import DatasetPatchBuilder
MOCK_GMS_ENDPOINT = "http://fakegmshost:8080"
@@ -115,6 +119,7 @@ def test_openapi_emitter_emit(openapi_emitter):
},
}
],
method="post",
)
@@ -750,3 +755,132 @@ def test_await_status_logging(openapi_emitter):
with pytest.raises(TraceValidationError):
openapi_emitter._await_status([trace], timedelta(seconds=10))
mock_error.assert_called_once()
def test_openapi_emitter_same_url_different_methods(openapi_emitter):
"""Test handling of requests with same URL but different HTTP methods"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit:
items = [
# POST requests for updating
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,UpdateMe{i},PROD)",
entityType="dataset",
aspectName="datasetProperties",
changeType=ChangeTypeClass.UPSERT,
aspect=DatasetProperties(name=f"Updated Dataset {i}"),
)
for i in range(2)
] + [
# PATCH requests for partial updates
next(
iter(
DatasetPatchBuilder(
f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
)
.set_qualified_name(f"PatchMe{i}")
.build()
)
)
for i in range(2)
]
# Run the test
result = openapi_emitter.emit_mcps(items)
# Verify that we made 2 calls (one for each HTTP method)
assert result == 2
assert mock_emit.call_count == 2
# Check that calls were made with different methods but the same URL
calls = {}
for call in mock_emit.call_args_list:
method = call[1]["method"]
url = call[0][0]
calls[(method, url)] = call
assert (
"post",
f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true",
) in calls
assert (
"patch",
f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true",
) in calls
def test_openapi_emitter_mixed_method_chunking(openapi_emitter):
"""Test that chunking works correctly across different HTTP methods"""
with patch(
"datahub.emitter.rest_emitter.DataHubRestEmitter._emit_generic"
) as mock_emit, patch(
"datahub.emitter.rest_emitter.BATCH_INGEST_MAX_PAYLOAD_LENGTH", 2
):
# Create more items than the chunk size for each method
items = [
# POST items (4 items, should create 2 chunks)
MetadataChangeProposalWrapper(
entityUrn=f"urn:li:dataset:(urn:li:dataPlatform:mysql,Dataset{i},PROD)",
entityType="dataset",
aspectName="datasetProfile",
changeType=ChangeTypeClass.UPSERT,
aspect=DatasetProfile(rowCount=i, columnCount=15, timestampMillis=0),
)
for i in range(4)
] + [
# PATCH items (3 items, should create 2 chunks)
next(
iter(
DatasetPatchBuilder(
f"urn:li:dataset:(urn:li:dataPlatform:mysql,PatchMe{i},PROD)"
)
.set_qualified_name(f"PatchMe{i}")
.build()
)
)
for i in range(3)
]
# Run the test with a smaller chunk size to force multiple chunks
result = openapi_emitter.emit_mcps(items)
# Should have 4 chunks total:
# - 2 chunks for POST (4 items with max 2 per chunk)
# - 2 chunks for PATCH (3 items with max 2 per chunk)
assert result == 4
assert mock_emit.call_count == 4
# Count the calls by method and verify chunking
post_calls = [
call for call in mock_emit.call_args_list if call[1]["method"] == "post"
]
patch_calls = [
call for call in mock_emit.call_args_list if call[1]["method"] == "patch"
]
assert len(post_calls) == 2 # 2 chunks for POST
assert len(patch_calls) == 2 # 2 chunks for PATCH
# Verify first chunks have max size and last chunks have remainders
post_payloads = [json.loads(call[1]["payload"]) for call in post_calls]
patch_payloads = [json.loads(call[1]["payload"]) for call in patch_calls]
assert len(post_payloads[0]) == 2
assert len(post_payloads[1]) == 2
assert len(patch_payloads[0]) == 2
assert len(patch_payloads[1]) == 1
# Verify all post calls are to the dataset endpoint
for call in post_calls:
assert (
call[0][0]
== f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
)
# Verify all patch calls are to the dataset endpoint
for call in patch_calls:
assert (
call[0][0]
== f"{MOCK_GMS_ENDPOINT}/openapi/v3/entity/dataset?async=true"
)

View File

@@ -29,11 +29,12 @@ def auth_session():
auth_session.destroy()
def build_graph_client(auth_session):
def build_graph_client(auth_session, openapi_ingestion=False):
print(auth_session.cookies)
graph: DataHubGraph = DataHubGraph(
config=DatahubClientConfig(
server=auth_session.gms_url(), token=auth_session.gms_token()
server=auth_session.gms_url(), token=auth_session.gms_token(),
openapi_ingestion=openapi_ingestion
)
)
return graph
@@ -44,6 +45,11 @@ def graph_client(auth_session) -> DataHubGraph:
return build_graph_client(auth_session)
@pytest.fixture(scope="session")
def openapi_graph_client(auth_session) -> DataHubGraph:
return build_graph_client(auth_session, openapi_ingestion=True)
def pytest_sessionfinish(session, exitstatus):
"""whole test run finishes."""
send_message(exitstatus)

View File

@@ -1,6 +1,8 @@
import time
import uuid
import pytest
import datahub.metadata.schema_classes as models
from datahub.emitter.mce_builder import make_data_job_urn, make_dataset_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
@@ -30,27 +32,43 @@ def _make_test_datajob_urn(
# Common Aspect Patch Tests
# Ownership
def test_datajob_ownership_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_datajob_ownership_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
datajob_urn = _make_test_datajob_urn()
helper_test_ownership_patch(graph_client, datajob_urn, DataJobPatchBuilder)
# Tags
def test_datajob_tags_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_datajob_tags_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
helper_test_dataset_tags_patch(
graph_client, _make_test_datajob_urn(), DataJobPatchBuilder
)
# Terms
def test_dataset_terms_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_dataset_terms_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
helper_test_entity_terms_patch(
graph_client, _make_test_datajob_urn(), DataJobPatchBuilder
)
# Custom Properties
def test_custom_properties_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_custom_properties_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
orig_datajob_info = DataJobInfoClass(name="test_name", type="TestJobType")
helper_test_custom_properties_patch(
graph_client,
@@ -63,7 +81,11 @@ def test_custom_properties_patch(graph_client):
# Specific Aspect Patch Tests
# Input/Output
def test_datajob_inputoutput_dataset_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_datajob_inputoutput_dataset_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
datajob_urn = _make_test_datajob_urn()
other_dataset_urn = make_dataset_urn(
@@ -139,7 +161,11 @@ def test_datajob_inputoutput_dataset_patch(graph_client):
)
def test_datajob_multiple_inputoutput_dataset_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_datajob_multiple_inputoutput_dataset_patch(request, client_fixture_name):
"""Test creating a data job with multiple input and output datasets and verifying the aspects."""
graph_client = request.getfixturevalue(client_fixture_name)
# Create the data job
datajob_urn = "urn:li:dataJob:(urn:li:dataFlow:(airflow,training,default),training)"

View File

@@ -1,6 +1,8 @@
import uuid
from typing import Dict, Optional
import pytest
from datahub.emitter.mce_builder import make_dataset_urn, make_tag_urn, make_term_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DataHubGraph
@@ -25,7 +27,11 @@ from tests.patch.common_patch_tests import (
# Common Aspect Patch Tests
# Ownership
def test_dataset_ownership_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_dataset_ownership_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
dataset_urn = make_dataset_urn(
platform="hive", name=f"SampleHiveDataset{uuid.uuid4()}", env="PROD"
)
@@ -33,7 +39,11 @@ def test_dataset_ownership_patch(graph_client):
# Tags
def test_dataset_tags_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_dataset_tags_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
dataset_urn = make_dataset_urn(
platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
)
@@ -41,14 +51,22 @@ def test_dataset_tags_patch(graph_client):
# Terms
def test_dataset_terms_patch(graph_client):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_dataset_terms_patch(request, client_fixture_name):
graph_client = request.getfixturevalue(client_fixture_name)
dataset_urn = make_dataset_urn(
platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
)
helper_test_entity_terms_patch(graph_client, dataset_urn, DatasetPatchBuilder)
def test_dataset_upstream_lineage_patch(graph_client: DataHubGraph):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_dataset_upstream_lineage_patch(request, client_fixture_name: str):
graph_client = request.getfixturevalue(client_fixture_name)
dataset_urn = make_dataset_urn(
platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
)
@@ -137,7 +155,11 @@ def get_field_info(
return None
def test_field_terms_patch(graph_client: DataHubGraph):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_field_terms_patch(request, client_fixture_name: str):
graph_client = request.getfixturevalue(client_fixture_name)
dataset_urn = make_dataset_urn(
platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
)
@@ -195,7 +217,11 @@ def test_field_terms_patch(graph_client):
assert len(field_info.glossaryTerms.terms) == 0
def test_field_tags_patch(graph_client: DataHubGraph):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_field_tags_patch(request, client_fixture_name: str):
graph_client = request.getfixturevalue(client_fixture_name)
dataset_urn = make_dataset_urn(
platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
)
@@ -286,7 +312,11 @@ def test_custom_properties_patch(graph_client):
return dataset_properties.customProperties
def test_custom_properties_patch(graph_client: DataHubGraph):
@pytest.mark.parametrize(
"client_fixture_name", ["graph_client", "openapi_graph_client"]
)
def test_custom_properties_patch(request, client_fixture_name: str):
graph_client = request.getfixturevalue(client_fixture_name)
dataset_urn = make_dataset_urn(
platform="hive", name=f"SampleHiveDataset-{uuid.uuid4()}", env="PROD"
)