# datahub/smoke-test/tests/data_process_instance/test_data_process_instance.py

import logging
import os
import tempfile
from random import randint
import pytest
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.sink import NoopWriteCallback
from datahub.ingestion.sink.file import FileSink, FileSinkConfig
from datahub.metadata.schema_classes import (
AuditStampClass,
ContainerClass,
ContainerPropertiesClass,
DataPlatformInstanceClass,
DataPlatformInstancePropertiesClass,
DataProcessInstanceKeyClass,
DataProcessInstancePropertiesClass,
DataProcessInstanceRunEventClass,
MLHyperParamClass,
MLMetricClass,
MLTrainingRunPropertiesClass,
SubTypesClass,
TimeWindowSizeClass,
)
from tests.utils import (
delete_urns_from_file,
ingest_file_via_rest,
wait_for_writes_to_sync,
)
# Module-level logger for this test module.
logger = logging.getLogger(__name__)

# Randomize the run id so repeated test runs do not collide on the same URN.
dpi_id = "test-pipeline-run-%d" % randint(1000, 9999)
dpi_urn = "urn:li:dataProcessInstance:" + dpi_id
class FileEmitter:
    """Thin wrapper around a DataHub ``FileSink`` for writing MCPs to a JSON file."""

    def __init__(self, filename: str) -> None:
        # The sink serializes every emitted record into `filename`.
        self.sink: FileSink = FileSink(
            config=FileSinkConfig(filename=filename),
            ctx=PipelineContext(run_id="create_test_data"),
        )

    def emit(self, event) -> None:
        """Queue a single metadata event for asynchronous write to the sink."""
        envelope = RecordEnvelope(record=event, metadata={})
        self.sink.write_record_async(
            record_envelope=envelope,
            write_callback=NoopWriteCallback(),
        )

    def close(self) -> None:
        """Flush any pending writes and close the underlying file."""
        self.sink.close()
def create_test_data(filename: str):
    """Write a fixture set of MCPs for one DataProcessInstance to *filename*.

    Emits the key, properties, run-event, platform-instance, subTypes,
    container, and ML training-run aspects for the module-level ``dpi_urn``,
    plus the referenced dataPlatformInstance and container entities, so the
    file can be ingested via REST and later reused for URN-based cleanup.
    """
    mcps = [
        # Key aspect
        MetadataChangeProposalWrapper(
            entityType="dataProcessInstance",
            entityUrn=dpi_urn,
            aspectName="dataProcessInstanceKey",
            aspect=DataProcessInstanceKeyClass(id=dpi_id),
        ),
        # Properties aspect — name asserted later by test_search_dpi.
        MetadataChangeProposalWrapper(
            entityType="dataProcessInstance",
            entityUrn=dpi_urn,
            aspectName="dataProcessInstanceProperties",
            aspect=DataProcessInstancePropertiesClass(
                name="Test Pipeline Run",
                type="BATCH_SCHEDULED",
                created=AuditStampClass(
                    time=1640692800000, actor="urn:li:corpuser:datahub"
                ),
            ),
        ),
        # Run Event aspect
        MetadataChangeProposalWrapper(
            entityType="dataProcessInstance",
            entityUrn=dpi_urn,
            aspectName="dataProcessInstanceRunEvent",
            aspect=DataProcessInstanceRunEventClass(
                timestampMillis=1704067200000,
                eventGranularity=TimeWindowSizeClass(unit="WEEK", multiple=1),
                status="COMPLETE",
            ),
        ),
        # Platform Instance aspect
        MetadataChangeProposalWrapper(
            entityType="dataProcessInstance",
            entityUrn=dpi_urn,
            aspectName="dataPlatformInstance",
            aspect=DataPlatformInstanceClass(
                platform="urn:li:dataPlatform:airflow",
                instance="urn:li:dataPlatformInstance:(urn:li:dataPlatform:airflow,1234567890)",
            ),
        ),
        # Properties for the platform-instance entity referenced above.
        MetadataChangeProposalWrapper(
            entityType="dataPlatformInstance",
            entityUrn="urn:li:dataPlatformInstance:(urn:li:dataPlatform:airflow,1234567890)",
            aspectName="dataPlatformInstanceProperties",
            aspect=DataPlatformInstancePropertiesClass(
                name="my process instance",
            ),
        ),
        # SubTypes aspect
        MetadataChangeProposalWrapper(
            entityType="dataProcessInstance",
            entityUrn=dpi_urn,
            aspectName="subTypes",
            aspect=SubTypesClass(typeNames=["TEST", "BATCH_JOB"]),
        ),
        # Container aspect
        MetadataChangeProposalWrapper(
            entityType="dataProcessInstance",
            entityUrn=dpi_urn,
            aspectName="container",
            aspect=ContainerClass(container="urn:li:container:testGroup1"),
        ),
        # Properties for the container entity referenced above.
        MetadataChangeProposalWrapper(
            entityType="container",
            entityUrn="urn:li:container:testGroup1",
            aspectName="containerProperties",
            aspect=ContainerPropertiesClass(name="testGroup1"),
        ),
        # ML Training Run Properties aspect — metrics/hyperparams asserted
        # verbatim by test_search_dpi.
        MetadataChangeProposalWrapper(
            entityType="dataProcessInstance",
            entityUrn=dpi_urn,
            aspectName="mlTrainingRunProperties",
            aspect=MLTrainingRunPropertiesClass(
                id="test-training-run-123",
                trainingMetrics=[
                    MLMetricClass(
                        name="accuracy",
                        description="accuracy of the model",
                        value="0.95",
                    ),
                    MLMetricClass(
                        name="loss",
                        description="accuracy loss of the model",
                        value="0.05",
                    ),
                ],
                hyperParams=[
                    MLHyperParamClass(
                        name="learningRate",
                        description="rate of learning",
                        value="0.001",
                    ),
                    MLHyperParamClass(
                        name="batchSize", description="size of the batch", value="32"
                    ),
                ],
                outputUrls=["s3://my-bucket/ml/output"],
            ),
        ),
    ]
    # Serialize all MCPs to the file; close() flushes pending async writes.
    file_emitter = FileEmitter(filename)
    for mcp in mcps:
        file_emitter.emit(mcp)
    file_emitter.close()
@pytest.fixture(scope="module", autouse=False)
def ingest_cleanup_data(auth_session, graph_client, request):
    """Module-scoped fixture: ingest the DPI fixture data, yield, then clean up.

    Writes the MCP fixture file, ingests it via REST, and waits for writes to
    sync before tests run; after the tests it deletes the ingested URNs and
    removes the temp file.
    """
    fd, filename = tempfile.mkstemp(suffix=".json")
    # mkstemp returns an *open* OS-level file descriptor; close it right away
    # so it isn't leaked — create_test_data reopens the path by name itself.
    os.close(fd)
    try:
        create_test_data(filename)
        print("ingesting data process instance test data")
        ingest_file_via_rest(auth_session, filename)
        wait_for_writes_to_sync()
        yield
        print("removing data process instance test data")
        delete_urns_from_file(graph_client, filename)
        wait_for_writes_to_sync()
    finally:
        os.remove(filename)
@pytest.mark.integration
def test_search_dpi(auth_session, ingest_cleanup_data):
    """Test DPI search and validation of returned fields using GraphQL.

    Scrolls across DATA_PROCESS_INSTANCE entities for the module-level
    ``dpi_id``, locates the ingested test entity by URN, and asserts that the
    properties, platform instance, subTypes, container, and ML training-run
    fields round-trip exactly as emitted by ``create_test_data``.
    """
    # GraphQL payload; `json` here is the request body dict, not the stdlib
    # module (which is not imported in this file).
    json = {
        "query": """query scrollAcrossEntities($input: ScrollAcrossEntitiesInput!) {
            scrollAcrossEntities(input: $input) {
                nextScrollId
                count
                total
                searchResults {
                    entity {
                        ... on DataProcessInstance {
                            urn
                            properties {
                                name
                                externalUrl
                            }
                            dataPlatformInstance {
                                platform {
                                    urn
                                    name
                                }
                            }
                            subTypes {
                                typeNames
                            }
                            container {
                                urn
                            }
                            platform {
                                urn
                                name
                                properties {
                                    type
                                }
                            }
                            mlTrainingRunProperties {
                                id
                                trainingMetrics {
                                    name
                                    value
                                }
                                hyperParams {
                                    name
                                    value
                                }
                                outputUrls
                            }
                        }
                    }
                }
            }
        }""",
        "variables": {
            # Query by the randomized dpi_id so only this run's entity matches.
            "input": {"types": ["DATA_PROCESS_INSTANCE"], "query": dpi_id, "count": 10}
        },
    }
    response = auth_session.post(
        f"{auth_session.frontend_url()}/api/v2/graphql", json=json
    )
    response.raise_for_status()
    res_data = response.json()
    # Basic response structure validation
    assert res_data, "Response should not be empty"
    assert "data" in res_data, "Response should contain 'data' field"
    print("RESPONSE DATA:" + str(res_data))
    assert "scrollAcrossEntities" in res_data["data"], (
        "Response should contain 'scrollAcrossEntities' field"
    )
    search_results = res_data["data"]["scrollAcrossEntities"]
    assert "searchResults" in search_results, (
        "Response should contain 'searchResults' field"
    )
    results = search_results["searchResults"]
    assert len(results) > 0, "Should find at least one result"
    # Find our test entity — the search may return other matches, so scan for
    # the exact URN rather than assuming it is first.
    test_entity = None
    for result in results:
        if result["entity"]["urn"] == dpi_urn:
            test_entity = result["entity"]
            break
    assert test_entity is not None, f"Should find test entity with URN {dpi_urn}"
    # Validate fields against the values emitted in create_test_data.
    props = test_entity["properties"]
    assert props["name"] == "Test Pipeline Run"
    platform_instance = test_entity["dataPlatformInstance"]
    assert platform_instance["platform"]["urn"] == "urn:li:dataPlatform:airflow"
    sub_types = test_entity["subTypes"]
    # Set comparison: GraphQL makes no ordering guarantee for typeNames.
    assert set(sub_types["typeNames"]) == {"TEST", "BATCH_JOB"}
    container = test_entity["container"]
    assert container["urn"] == "urn:li:container:testGroup1"
    ml_props = test_entity["mlTrainingRunProperties"]
    assert ml_props["id"] == "test-training-run-123"
    assert ml_props["trainingMetrics"][0] == {"name": "accuracy", "value": "0.95"}
    assert ml_props["trainingMetrics"][1] == {"name": "loss", "value": "0.05"}
    assert ml_props["hyperParams"][0] == {"name": "learningRate", "value": "0.001"}
    assert ml_props["hyperParams"][1] == {"name": "batchSize", "value": "32"}
    assert ml_props["outputUrls"][0] == "s3://my-bucket/ml/output"