diff --git a/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py b/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py index c3b5a82721..6367adde2c 100644 --- a/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py @@ -29,13 +29,16 @@ def dataset() -> None: name="upsert", ) @click.option("-f", "--file", required=True, type=click.Path(exists=True)) +@click.option( + "-n", "--dry-run", type=bool, is_flag=True, default=False, help="Perform a dry run" +) @upgrade.check_upgrade @telemetry.with_telemetry() -def upsert(file: Path) -> None: +def upsert(file: Path, dry_run: bool) -> None: """Upsert attributes to a Dataset in DataHub.""" # Call the sync command with to_datahub=True to perform the upsert operation ctx = click.get_current_context() - ctx.invoke(sync, file=str(file), to_datahub=True) + ctx.invoke(sync, file=str(file), dry_run=dry_run, to_datahub=True) @dataset.command( @@ -167,11 +170,16 @@ def file(lintcheck: bool, lintfix: bool, file: str) -> None: ) @click.option("-f", "--file", required=True, type=click.Path(exists=True)) @click.option("--to-datahub/--from-datahub", required=True, is_flag=True) +@click.option( + "-n", "--dry-run", type=bool, is_flag=True, default=False, help="Perform a dry run" +) @upgrade.check_upgrade @telemetry.with_telemetry() -def sync(file: str, to_datahub: bool) -> None: +def sync(file: str, to_datahub: bool, dry_run: bool) -> None: """Sync a Dataset file to/from DataHub""" + dry_run_prefix = "[dry-run]: " if dry_run else "" # prefix to use in messages + failures: List[str] = [] with get_default_graph() as graph: datasets = Dataset.from_yaml(file) @@ -189,7 +197,7 @@ def sync(file: str, to_datahub: bool) -> None: click.secho( "\n\t- ".join( [ - f"Skipping Dataset {dataset.urn} due to missing entity references: " + f"{dry_run_prefix}Skipping Dataset {dataset.urn} due to missing entity references: " ] + missing_entity_references ), @@ -199,13 +207,18 @@ 
def sync(file: str, to_datahub: bool) -> None: continue try: for mcp in dataset.generate_mcp(): - graph.emit(mcp) - click.secho(f"Update succeeded for urn {dataset.urn}.", fg="green") + if not dry_run: + graph.emit(mcp) + click.secho( + f"{dry_run_prefix}Update succeeded for urn {dataset.urn}.", + fg="green", + ) except Exception as e: click.secho( - f"Update failed for id {id}. due to {e}", + f"{dry_run_prefix}Update failed for {dataset.urn} due to {e}", fg="red", ) + failures.append(dataset.urn) else: # Sync from DataHub if graph.exists(dataset.urn): @@ -215,13 +228,16 @@ def sync(file: str, to_datahub: bool) -> None: existing_dataset: Dataset = Dataset.from_datahub( graph=graph, urn=dataset.urn, config=dataset_get_config ) - existing_dataset.to_yaml(Path(file)) + if not dry_run: + existing_dataset.to_yaml(Path(file)) + else: + click.secho(f"{dry_run_prefix}Will update file {file}") else: - click.secho(f"Dataset {dataset.urn} does not exist") + click.secho(f"{dry_run_prefix}Dataset {dataset.urn} does not exist") failures.append(dataset.urn) if failures: click.secho( - f"\nFailed to sync the following Datasets: {', '.join(failures)}", + f"\n{dry_run_prefix}Failed to sync the following Datasets: {', '.join(failures)}", fg="red", ) raise click.Abort() diff --git a/metadata-ingestion/tests/unit/cli/dataset/test_dataset_cmd.py b/metadata-ingestion/tests/unit/cli/dataset/test_dataset_cmd.py index f29caac531..318f0a3319 100644 --- a/metadata-ingestion/tests/unit/cli/dataset/test_dataset_cmd.py +++ b/metadata-ingestion/tests/unit/cli/dataset/test_dataset_cmd.py @@ -27,6 +27,32 @@ def test_yaml_file(): temp_file.unlink() +@pytest.fixture +def invalid_value_yaml_file(): + """Creates a temporary yaml file - correctly formatted but bad datatype for testing.""" + invalid_content = """ +## This file is intentionally malformed +- id: user.badformat + platform: hive + schema: + fields: + - id: ip + type: bad_type + description: The IP address + """ + + # Create a temporary file + 
temp_file = TEST_RESOURCES_DIR / "invalid_dataset.yaml.tmp" + with open(temp_file, "w") as f: + f.write(invalid_content) + + yield temp_file + + # Clean up + if temp_file.exists(): + temp_file.unlink() + + @pytest.fixture def malformed_yaml_file(): """Creates a temporary malformed yaml file for testing.""" @@ -217,3 +243,115 @@ class TestDatasetCli: # Verify both dataset instances had to_yaml called mock_dataset1.to_yaml.assert_called_once() mock_dataset2.to_yaml.assert_called_once() + + @patch("datahub.cli.specific.dataset_cli.get_default_graph") + def test_dry_run_sync(self, mock_get_default_graph, test_yaml_file): + mock_graph = MagicMock() + mock_graph.exists.return_value = True + mock_get_default_graph.return_value.__enter__.return_value = mock_graph + + runner = CliRunner() + result = runner.invoke( + dataset, ["sync", "--dry-run", "--to-datahub", "-f", str(test_yaml_file)] + ) + + # Verify + assert result.exit_code == 0 + assert not mock_graph.emit.called + + @patch("datahub.cli.specific.dataset_cli.get_default_graph") + def test_dry_run_sync_fail_bad_type( + self, mock_get_default_graph, invalid_value_yaml_file + ): + mock_graph = MagicMock() + mock_graph.exists.return_value = True + mock_get_default_graph.return_value.__enter__.return_value = mock_graph + + runner = CliRunner() + result = runner.invoke( + dataset, + ["sync", "--dry-run", "--to-datahub", "-f", str(invalid_value_yaml_file)], + ) + + # Verify + assert result.exit_code != 0 + assert not mock_graph.emit.called + assert "Type bad_type is not a valid primitive type" in result.output + + @patch("datahub.cli.specific.dataset_cli.get_default_graph") + def test_dry_run_sync_fail_missing_ref( + self, mock_get_default_graph, test_yaml_file + ): + mock_graph = MagicMock() + mock_graph.exists.return_value = False + mock_get_default_graph.return_value.__enter__.return_value = mock_graph + + runner = CliRunner() + result = runner.invoke( + dataset, ["sync", "--dry-run", 
"--to-datahub", "-f", str(test_yaml_file)] + ) + + # Verify + assert result.exit_code != 0 + assert not mock_graph.emit.called + assert "missing entity reference" in result.output + + @patch("datahub.cli.specific.dataset_cli.get_default_graph") + def test_run_sync(self, mock_get_default_graph, test_yaml_file): + mock_graph = MagicMock() + mock_graph.exists.return_value = True + mock_get_default_graph.return_value.__enter__.return_value = mock_graph + + runner = CliRunner() + result = runner.invoke( + dataset, ["sync", "--to-datahub", "-f", str(test_yaml_file)] + ) + + # Verify + assert result.exit_code == 0 + assert mock_graph.emit.called + + @patch("datahub.cli.specific.dataset_cli.get_default_graph") + def test_run_sync_fail(self, mock_get_default_graph, invalid_value_yaml_file): + mock_graph = MagicMock() + mock_graph.exists.return_value = True + mock_get_default_graph.return_value.__enter__.return_value = mock_graph + + runner = CliRunner() + result = runner.invoke( + dataset, ["sync", "--to-datahub", "-f", str(invalid_value_yaml_file)] + ) + + # Verify + assert result.exit_code != 0 + assert not mock_graph.emit.called + assert "is not a valid primitive type" in result.output + + @patch("datahub.cli.specific.dataset_cli.get_default_graph") + def test_run_upsert_fail(self, mock_get_default_graph, invalid_value_yaml_file): + mock_graph = MagicMock() + mock_graph.exists.return_value = True + mock_get_default_graph.return_value.__enter__.return_value = mock_graph + + runner = CliRunner() + result = runner.invoke(dataset, ["upsert", "-f", str(invalid_value_yaml_file)]) + + # Verify + assert result.exit_code != 0 + assert not mock_graph.emit.called + assert "is not a valid primitive type" in result.output + + @patch("datahub.cli.specific.dataset_cli.get_default_graph") + def test_sync_from_datahub_fail(self, mock_get_default_graph, test_yaml_file): + mock_graph = MagicMock() + mock_graph.exists.return_value = False + 
mock_get_default_graph.return_value.__enter__.return_value = mock_graph + + runner = CliRunner() + result = runner.invoke( + dataset, ["sync", "--dry-run", "--from-datahub", "-f", str(test_yaml_file)] + ) + + # Verify + assert result.exit_code != 0 + assert "does not exist" in result.output