# datahub/smoke-test/tests/data_process_instance/test_data_process_instance.py
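"""Smoke test for DataProcessInstance (DPI) entities.

Ingests a synthetic DPI run (properties, run event, platform instance,
subtypes, container, ML training metadata, and input/output lineage) through
a file sink plus REST, then validates the indexed fields via the GraphQL
scrollAcrossEntities search API.
"""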


import logging
import os
import tempfile
from random import randint

import pytest

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.sink import NoopWriteCallback
from datahub.ingestion.sink.file import FileSink, FileSinkConfig
from datahub.metadata.schema_classes import (
    AuditStampClass,
    ContainerClass,
    ContainerPropertiesClass,
    DataPlatformInstanceClass,
    DataPlatformInstancePropertiesClass,
    DataProcessInstancePropertiesClass,
    DataProcessInstanceRunEventClass,
    MLHyperParamClass,
    MLMetricClass,
    MLTrainingRunPropertiesClass,
    SubTypesClass,
    TimeWindowSizeClass,
)
from tests.utils import (
    delete_urns_from_file,
    ingest_file_via_rest,
    wait_for_writes_to_sync,
)

logger = logging.getLogger(__name__)

# Generate unique DPI ID
dpi_id = f"test-pipeline-run-{randint(1000, 9999)}"
dpi_urn = f"urn:li:dataProcessInstance:{dpi_id}"


class FileEmitter:
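    """Write emitted MCPs to a local JSON file via DataHub's FileSink so the
    file can be ingested over REST afterwards."""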
    def __init__(self, filename: str) -> None:
        self.sink: FileSink = FileSink(
            ctx=PipelineContext(run_id="create_test_data"),
            config=FileSinkConfig(filename=filename),
        )

    def emit(self, event):
        self.sink.write_record_async(
            record_envelope=RecordEnvelope(record=event, metadata={}),
            write_callback=NoopWriteCallback(),
        )

    def close(self):
        self.sink.close()


def create_status_mcp(entity_urn: str):
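    """Build an MCP that marks the given entity as present (not soft-deleted)."""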
    return MetadataChangeProposalWrapper(
        entityUrn=entity_urn,
        aspect=models.StatusClass(removed=False),
    )


def create_test_data(filename: str):
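    """Write MCPs for the referenced entities and the test DPI's aspects to
    `filename` as JSON."""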
    input_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:kafka,my_input_dataset,PROD)"
    )
    input_model_urn = "urn:li:mlModel:(urn:li:dataPlatform:mlflow,my_input_model,PROD)"
    output_dataset_urn = (
        "urn:li:dataset:(urn:li:dataPlatform:kafka,my_output_dataset,PROD)"
    )
    output_model_urn = (
        "urn:li:mlModel:(urn:li:dataPlatform:mlflow,my_output_model,PROD)"
    )
    data_platform_instance_urn = (
        "urn:li:dataPlatformInstance:(urn:li:dataPlatform:airflow,1234567890)"
    )
    container_urn = "urn:li:container:testGroup1"

    # Mark every referenced entity as present so lineage edges resolve
    mcps = [
        create_status_mcp(urn)
        for urn in [
            input_dataset_urn,
            input_model_urn,
            output_dataset_urn,
            output_model_urn,
            data_platform_instance_urn,
        ]
    ]
    mcps.append(
        MetadataChangeProposalWrapper(
            entityUrn=container_urn,
            aspect=ContainerPropertiesClass(name="testGroup1"),
        )
    )
    mcps.append(
        MetadataChangeProposalWrapper(
            entityUrn=data_platform_instance_urn,
            aspect=DataPlatformInstancePropertiesClass(name="my process instance"),
        )
    )
    # construct_many fans the DPI URN out over its list of aspects
    mcps += MetadataChangeProposalWrapper.construct_many(
        entityUrn=dpi_urn,
        aspects=[
            # Properties aspect
            DataProcessInstancePropertiesClass(
                name="Test Pipeline Run",
                type="BATCH_SCHEDULED",
                created=AuditStampClass(
                    time=1640692800000, actor="urn:li:corpuser:datahub"
                ),
            ),
            # Run Event aspect
            DataProcessInstanceRunEventClass(
                timestampMillis=1704067200000,
                eventGranularity=TimeWindowSizeClass(unit="WEEK", multiple=1),
                status="COMPLETE",
            ),
            # Platform Instance aspect
            DataPlatformInstanceClass(
                platform="urn:li:dataPlatform:airflow",
                instance=data_platform_instance_urn,
            ),
            # SubTypes aspect
            SubTypesClass(typeNames=["TEST", "BATCH_JOB"]),
            # Container aspect
            ContainerClass(container=container_urn),
            # ML Training Run Properties aspect
            MLTrainingRunPropertiesClass(
                id="test-training-run-123",
                trainingMetrics=[
                    MLMetricClass(
                        name="accuracy",
                        description="accuracy of the model",
                        value="0.95",
                    ),
                    MLMetricClass(
                        name="loss",
                        description="loss of the model",
                        value="0.05",
                    ),
                ],
                hyperParams=[
                    MLHyperParamClass(
                        name="learningRate",
                        description="rate of learning",
                        value="0.001",
                    ),
                    MLHyperParamClass(
                        name="batchSize",
                        description="size of the batch",
                        value="32",
                    ),
                ],
                outputUrls=["s3://my-bucket/ml/output"],
            ),
            # Input / output lineage aspects
            models.DataProcessInstanceInputClass(
                inputs=[input_dataset_urn, input_model_urn]
            ),
            models.DataProcessInstanceOutputClass(
                outputs=[output_dataset_urn, output_model_urn]
            ),
        ],
    )

    file_emitter = FileEmitter(filename)
    for mcp in mcps:
        file_emitter.emit(mcp)
    file_emitter.close()
@pytest.fixture(scope="module", autouse=False)
def ingest_cleanup_data(auth_session, graph_client, request):
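    """Ingest the DPI test data before this module's tests and delete the
    ingested URNs afterwards."""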
new_file, filename = tempfile.mkstemp(suffix=".json")
try:
create_test_data(filename)
print("ingesting data process instance test data")
ingest_file_via_rest(auth_session, filename)
wait_for_writes_to_sync()
yield
print("removing data process instance test data")
delete_urns_from_file(graph_client, filename)
wait_for_writes_to_sync()
finally:
os.remove(filename)


def test_search_dpi(auth_session, ingest_cleanup_data):
"""Test DPI search and validation of returned fields using GraphQL."""
json = {
"query": """query scrollAcrossEntities($input: ScrollAcrossEntitiesInput!) {
scrollAcrossEntities(input: $input) {
nextScrollId
count
total
searchResults {
entity {
... on DataProcessInstance {
urn
properties {
name
externalUrl
}
dataPlatformInstance {
platform {
urn
name
}
}
subTypes {
typeNames
}
container {
urn
}
mlTrainingRunProperties {
id
trainingMetrics {
name
value
}
hyperParams {
name
value
}
outputUrls
}
}
}
}
}
}""",
"variables": {
"input": {"types": ["DATA_PROCESS_INSTANCE"], "query": dpi_id, "count": 10}
},
}
    response = auth_session.post(
        f"{auth_session.frontend_url()}/api/v2/graphql", json=json
    )
    response.raise_for_status()
    res_data = response.json()

    # Basic response structure validation
    assert res_data, "Response should not be empty"
    assert "data" in res_data, "Response should contain 'data' field"
    logger.debug("Response data: %s", res_data)
    assert "scrollAcrossEntities" in res_data["data"], (
        "Response should contain 'scrollAcrossEntities' field"
    )
    search_results = res_data["data"]["scrollAcrossEntities"]
    assert "searchResults" in search_results, (
        "Response should contain 'searchResults' field"
    )
    results = search_results["searchResults"]
    assert len(results) > 0, "Should find at least one result"

    # Find our test entity among the results
    test_entity = None
    for result in results:
        if result["entity"]["urn"] == dpi_urn:
            test_entity = result["entity"]
            break
    assert test_entity is not None, f"Should find test entity with URN {dpi_urn}"
    # Validate fields
    props = test_entity["properties"]
    assert props["name"] == "Test Pipeline Run"

    platform_instance = test_entity["dataPlatformInstance"]
    assert platform_instance["platform"]["urn"] == "urn:li:dataPlatform:airflow"

    sub_types = test_entity["subTypes"]
    assert set(sub_types["typeNames"]) == {"TEST", "BATCH_JOB"}

    container = test_entity["container"]
    assert container["urn"] == "urn:li:container:testGroup1"

    ml_props = test_entity["mlTrainingRunProperties"]
    assert ml_props["id"] == "test-training-run-123"
    assert ml_props["trainingMetrics"][0] == {"name": "accuracy", "value": "0.95"}
    assert ml_props["trainingMetrics"][1] == {"name": "loss", "value": "0.05"}
    assert ml_props["hyperParams"][0] == {"name": "learningRate", "value": "0.001"}
    assert ml_props["hyperParams"][1] == {"name": "batchSize", "value": "32"}
    assert ml_props["outputUrls"][0] == "s3://my-bucket/ml/output"