2025-02-16 22:23:57 -08:00

178 lines
5.2 KiB
Python

import json
import pathlib
import tempfile
import uuid
from typing import Dict, List, Mapping, Sequence, Set
from unittest.mock import Mock, patch
import dagster._core.utils
from dagster import (
DagsterInstance,
In,
Out,
SkipReason,
build_run_status_sensor_context,
build_sensor_context,
job,
op,
)
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.repository_definition import (
RepositoryData,
RepositoryDefinition,
)
from dagster._core.definitions.resource_definition import ResourceDefinition
from freezegun import freeze_time
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig
from datahub.testing.compare_metadata_json import assert_metadata_files_equal
from datahub_dagster_plugin.client.dagster_generator import DatahubDagsterSourceConfig
from datahub_dagster_plugin.sensors.datahub_sensors import (
DatahubSensors,
make_datahub_sensor,
)
# Fixed timestamp passed to freeze_time by the tests in this module.
FROZEN_TIME = "2024-07-11 07:00:00"

# How many run ids make_new_run_id_mock has produced so far.
call_num = 0


def make_new_run_id_mock() -> str:
    """Return the next predictable run id.

    Every call increments the module-level ``call_num`` counter and returns
    ``"test_run_id_<counter>"``; the first call therefore returns
    ``"test_run_id_1"``.
    """
    global call_num
    call_num = call_num + 1
    return "test_run_id_" + str(call_num)
# Rebind dagster's run-id factory to the counter-based stub as soon as this
# module is imported.
# NOTE(review): presumably this keeps generated run ids stable for the
# golden-file comparison — confirm. The original function is never restored in
# the visible code, so the replacement stays in effect for the whole process.
dagster._core.utils.make_new_run_id = make_new_run_id_mock
@patch("datahub_dagster_plugin.sensors.datahub_sensors.DataHubGraph", autospec=True)
def test_datahub_sensor(mock_emit):
    """Evaluating the DataHub sensor on an empty repository returns a SkipReason.

    ``mock_emit`` is the patched ``DataHubGraph`` class from the sensors
    module, injected by the ``patch`` decorator.
    """
    ephemeral_instance = DagsterInstance.ephemeral()

    class DummyRepositoryData(RepositoryData):
        """Minimal RepositoryData stub: no sensors, jobs or resources."""

        def __init__(self):
            self.sensors = []

        def get_all_jobs(self) -> Sequence[JobDefinition]:
            """Return the jobs of the repository (always empty here)."""
            return []

        def get_top_level_resources(self) -> Mapping[str, ResourceDefinition]:
            """Return the top-level resources as a mapping (always empty here)."""
            return {}

        def get_env_vars_by_top_level_resource(self) -> Mapping[str, Set[str]]:
            """Return env vars per top-level resource (always empty here)."""
            return {}

    # Sensor context bound to a repository that contains nothing.
    sensor_context = build_sensor_context(
        instance=ephemeral_instance,
        repository_def=RepositoryDefinition(
            name="testRepository", repository_data=DummyRepositoryData()
        ),
    )

    # Instantiating the patched DataHubGraph class now yields a plain Mock.
    mock_emit.return_value = Mock()

    sensor_config = DatahubDagsterSourceConfig(
        datahub_client_config=DatahubClientConfig(server="http://localhost:8081"),
        dagster_url="http://localhost:3000",
    )
    sensor = make_datahub_sensor(sensor_config)

    assert isinstance(sensor(sensor_context), SkipReason)
# Predictable stand-ins for uuid.uuid4(): "uuid_0" through "uuid_9999".
TEST_UUIDS = [f"uuid_{index}" for index in range(10000)]
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
@patch("datahub_dagster_plugin.sensors.datahub_sensors.DataHubGraph", autospec=True)
@freeze_time(FROZEN_TIME)
def test_emit_metadata(mock_emit: Mock, mock_uuid: Mock) -> None:
    """Run a two-op job and compare the emitted MCPs against the golden file.

    The job is executed in process, a run-status sensor context is built from
    its success event, and ``DatahubSensors._emit_metadata`` is invoked with
    it. Every ``MetadataChangeProposalWrapper`` passed as the first positional
    argument to a method of the mocked graph is serialized to JSON and compared
    with ``tests/unit/golden/golden_test_emit_metadata_mcps.json`` (a path
    relative to the current working directory).

    Args:
        mock_emit: Patched ``DataHubGraph`` class from the sensors module.
        mock_uuid: Patched ``uuid.uuid4`` that yields the ``TEST_UUIDS`` values.
    """
    graph_mock = Mock()
    mock_emit.return_value = graph_mock

    @op(
        out={
            "result": Out(
                metadata={
                    "datahub.outputs": [
                        "urn:li:dataset:(urn:li:dataPlatform:snowflake,tableB,PROD)"
                    ]
                }
            )
        }
    )
    def extract():
        return [1, 2, 3, 4]

    @op(
        ins={
            "data": In(
                metadata={
                    "datahub.inputs": [
                        "urn:li:dataset:(urn:li:dataPlatform:snowflake,tableA,PROD)"
                    ]
                }
            )
        }
    )
    def transform(data):
        return [str(item) for item in data]

    @job
    def etl():
        transform(extract())

    ephemeral_instance = DagsterInstance.ephemeral()
    execution = etl.execute_in_process(
        instance=ephemeral_instance, run_id="12345678123456781234567812345678"
    )

    # Sensor context built from the finished run and its success event.
    sensor_context = build_run_status_sensor_context(
        sensor_name="my_email_sensor",
        dagster_instance=ephemeral_instance,
        dagster_run=execution.dagster_run,
        dagster_event=execution.get_run_success_event(),
    )

    # Everything that touches tmp_path stays inside the block: the directory
    # is removed as soon as the context manager exits.
    with tempfile.TemporaryDirectory() as tmp_path:
        DatahubSensors()._emit_metadata(sensor_context)

        # Keep only calls whose first positional argument is an MCP wrapper.
        mcpws: List[Dict] = [
            recorded.args[0].to_obj(simplified_structure=True)
            for recorded in graph_mock.method_calls
            if recorded.args
            and isinstance(recorded.args[0], MetadataChangeProposalWrapper)
        ]

        output_file = pathlib.Path(tmp_path) / "test_emit_metadata_mcps.json"
        with open(output_file, "w") as handle:
            handle.write(json.dumps(mcpws, indent=2))

        assert_metadata_files_equal(
            output_path=output_file,
            golden_path=pathlib.Path(
                "tests/unit/golden/golden_test_emit_metadata_mcps.json"
            ),
            ignore_paths=["root[*]['systemMetadata']['created']"],
        )