diff --git a/metadata-ingestion/src/datahub/ingestion/source/vertexai/vertexai.py b/metadata-ingestion/src/datahub/ingestion/source/vertexai/vertexai.py index 915cfffc32..f3e1ba37a7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/vertexai/vertexai.py +++ b/metadata-ingestion/src/datahub/ingestion/source/vertexai/vertexai.py @@ -250,7 +250,7 @@ class VertexAISource(Source): task_meta.state = task_detail.state task_meta.start_time = task_detail.start_time task_meta.create_time = task_detail.create_time - if task_detail.end_time: + if task_detail.end_time and task_meta.start_time: task_meta.end_time = task_detail.end_time task_meta.duration = int( ( @@ -498,12 +498,14 @@ class VertexAISource(Source): if len(executions) == 1: create_time = executions[0].create_time update_time = executions[0].update_time - duration = update_time.timestamp() * 1000 - create_time.timestamp() * 1000 - return int(create_time.timestamp() * 1000), int(duration) + if create_time and update_time: + duration = ( + update_time.timestamp() * 1000 - create_time.timestamp() * 1000 + ) + return int(create_time.timestamp() * 1000), int(duration) # When no execution context started, start time and duration are not available # When multiple execution contexts stared on a run, not unable to know which context to use for create_time and duration - else: - return None, None + return None, None def _get_run_result_status(self, status: str) -> Union[str, RunResultTypeClass]: if status == "COMPLETE": @@ -542,9 +544,11 @@ class VertexAISource(Source): ) -> Iterable[MetadataChangeProposalWrapper]: create_time = execution.create_time update_time = execution.update_time - duration = datetime_to_ts_millis(update_time) - datetime_to_ts_millis( - create_time - ) + duration = None + if create_time and update_time: + duration = datetime_to_ts_millis(update_time) - datetime_to_ts_millis( + create_time + ) result_status: Union[str, RunResultTypeClass] = get_execution_result_status( execution.state ) @@ -558,7 +562,7 @@ class VertexAISource(Source): DataProcessInstancePropertiesClass( name=execution.name, created=AuditStampClass( - time=datetime_to_ts_millis(create_time), + time=datetime_to_ts_millis(create_time) if create_time else 0, actor="urn:li:corpuser:datahub", ), externalUrl=self._make_artifact_external_url( @@ -573,7 +577,9 @@ class VertexAISource(Source): ( DataProcessInstanceRunEventClass( status=DataProcessRunStatusClass.COMPLETE, - timestampMillis=datetime_to_ts_millis(create_time), + timestampMillis=datetime_to_ts_millis(create_time) + if create_time + else 0, result=DataProcessInstanceRunResultClass( type=result_status, nativeResultType=self.platform, diff --git a/metadata-ingestion/tests/unit/test_vertexai_source.py b/metadata-ingestion/tests/unit/test_vertexai_source.py index 92a2caab3d..e5c04289a3 100644 --- a/metadata-ingestion/tests/unit/test_vertexai_source.py +++ b/metadata-ingestion/tests/unit/test_vertexai_source.py @@ -1,10 +1,12 @@ import contextlib -from datetime import timedelta +from datetime import datetime, timedelta, timezone from typing import List -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest -from google.cloud.aiplatform import Experiment, ExperimentRun +from google.cloud.aiplatform import Experiment, ExperimentRun, PipelineJob +from google.cloud.aiplatform_v1 import PipelineTaskDetail +from google.cloud.aiplatform_v1.types import PipelineJob as PipelineJobType import datahub.emitter.mce_builder as builder from datahub.emitter.mcp import MetadataChangeProposalWrapper @@ -820,3 +822,149 @@ def test_make_job_urn(source: VertexAISource) -> None: source._make_training_job_urn(mock_training_job) == f"{builder.make_data_process_instance_urn(source._make_vertexai_job_name(mock_training_job.name))}" ) + + +def test_pipeline_task_with_none_start_time(source: VertexAISource) -> None: + """Test that pipeline tasks with None start_time don't crash the ingestion.""" + mock_pipeline_job = MagicMock(spec=PipelineJob) + mock_pipeline_job.name = "test_pipeline_none_timestamps" + mock_pipeline_job.resource_name = ( + "projects/123/locations/us-central1/pipelineJobs/789" + ) + mock_pipeline_job.labels = {} + mock_pipeline_job.create_time = datetime.fromtimestamp(1647878400, tz=timezone.utc) + mock_pipeline_job.update_time = datetime.fromtimestamp(1647878500, tz=timezone.utc) + mock_pipeline_job.location = "us-west2" + + gca_resource = MagicMock(spec=PipelineJobType) + mock_pipeline_job.gca_resource = gca_resource + + task_detail = MagicMock(spec=PipelineTaskDetail) + task_detail.task_name = "incomplete_task" + task_detail.task_id = 123 + task_detail.state = MagicMock() + task_detail.start_time = None + task_detail.create_time = datetime.fromtimestamp(1647878400, tz=timezone.utc) + task_detail.end_time = datetime.fromtimestamp(1647878600, tz=timezone.utc) + + mock_pipeline_job.task_details = [task_detail] + gca_resource.pipeline_spec = { + "root": { + "dag": { + "tasks": { + "incomplete_task": { + "componentRef": {"name": "comp-incomplete"}, + "taskInfo": {"name": "incomplete_task"}, + } + } + } + } + } + + with contextlib.ExitStack() as exit_stack: + mock = exit_stack.enter_context( + patch("google.cloud.aiplatform.PipelineJob.list") + ) + mock.return_value = [mock_pipeline_job] + + actual_mcps = list(source._get_pipelines_mcps()) + + task_run_mcps = [ + mcp + for mcp in actual_mcps + if isinstance(mcp.aspect, DataProcessInstancePropertiesClass) + and "incomplete_task" in mcp.aspect.name + ] + + assert len(task_run_mcps) > 0 + + +def test_pipeline_task_with_none_end_time(source: VertexAISource) -> None: + """Test that pipeline tasks with None end_time don't crash the ingestion.""" + mock_pipeline_job = MagicMock(spec=PipelineJob) + mock_pipeline_job.name = "test_pipeline_no_end_time" + mock_pipeline_job.resource_name = ( + "projects/123/locations/us-central1/pipelineJobs/790" + ) + mock_pipeline_job.labels = {} + mock_pipeline_job.create_time = datetime.fromtimestamp(1647878400, tz=timezone.utc) + mock_pipeline_job.update_time = datetime.fromtimestamp(1647878500, tz=timezone.utc) + mock_pipeline_job.location = "us-west2" + + gca_resource = MagicMock(spec=PipelineJobType) + mock_pipeline_job.gca_resource = gca_resource + + task_detail = MagicMock(spec=PipelineTaskDetail) + task_detail.task_name = "running_task" + task_detail.task_id = 124 + task_detail.state = MagicMock() + task_detail.start_time = datetime.fromtimestamp(1647878400, tz=timezone.utc) + task_detail.create_time = datetime.fromtimestamp(1647878400, tz=timezone.utc) + task_detail.end_time = None + + mock_pipeline_job.task_details = [task_detail] + gca_resource.pipeline_spec = { + "root": { + "dag": { + "tasks": { + "running_task": { + "componentRef": {"name": "comp-running"}, + "taskInfo": {"name": "running_task"}, + } + } + } + } + } + + with contextlib.ExitStack() as exit_stack: + mock = exit_stack.enter_context( + patch("google.cloud.aiplatform.PipelineJob.list") + ) + mock.return_value = [mock_pipeline_job] + + actual_mcps = list(source._get_pipelines_mcps()) + + task_run_mcps = [ + mcp + for mcp in actual_mcps + if isinstance(mcp.aspect, DataProcessInstancePropertiesClass) + and "running_task" in mcp.aspect.name + ] + + assert len(task_run_mcps) > 0 + + +def test_experiment_run_with_none_timestamps(source: VertexAISource) -> None: + """Test that experiment runs with None create_time/update_time don't crash.""" + mock_exp = gen_mock_experiment() + source.experiments = [mock_exp] + + mock_exp_run = MagicMock(spec=ExperimentRun) + mock_exp_run.name = "test_run_none_timestamps" + mock_exp_run.get_state.return_value = "COMPLETE" + mock_exp_run.get_params.return_value = {} + mock_exp_run.get_metrics.return_value = {} + + mock_execution = MagicMock() + mock_execution.name = "test_execution" + mock_execution.create_time = None + mock_execution.update_time = None + mock_execution.state = "COMPLETE" + mock_execution.get_input_artifacts.return_value = [] + mock_execution.get_output_artifacts.return_value = [] + + mock_exp_run.get_executions.return_value = [mock_execution] + + with patch("google.cloud.aiplatform.ExperimentRun.list") as mock_list: + mock_list.return_value = [mock_exp_run] + + actual_mcps = list(source._get_experiment_runs_mcps()) + + run_mcps = [ + mcp + for mcp in actual_mcps + if isinstance(mcp.aspect, DataProcessInstancePropertiesClass) + and "test_run_none_timestamps" in mcp.aspect.name + ] + + assert len(run_mcps) > 0