diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py
index 441a9505c08..b3b7c3bece2 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py
@@ -113,6 +113,7 @@ class OMetaLineageMixin(Generic[T]):
         Add lineage relationship between two entities and returns
         the entity information of the origin node
         """
+        data = deepcopy(data)
         try:
             patch_op_success = False
             if check_patch and data.edge.lineageDetails:
diff --git a/ingestion/src/metadata/ingestion/source/pipeline/dbtcloud/metadata.py b/ingestion/src/metadata/ingestion/source/pipeline/dbtcloud/metadata.py
index 8036836d54f..420a8d452df 100644
--- a/ingestion/src/metadata/ingestion/source/pipeline/dbtcloud/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/pipeline/dbtcloud/metadata.py
@@ -168,13 +168,6 @@ class DbtcloudSource(PipelineServiceSource):
                 entity=Pipeline, fqn=pipeline_fqn
             )
 
-            lineage_details = LineageDetails(
-                pipeline=EntityReference(
-                    id=pipeline_entity.id.root, type="pipeline"
-                ),
-                source=LineageSource.PipelineLineage,
-            )
-
             dbt_models = self.client.get_model_details(
                 job_id=pipeline_details.id, run_id=self.context.get().latest_run_id
             )
@@ -222,6 +215,13 @@ class DbtcloudSource(PipelineServiceSource):
                     if from_entity is None:
                         continue
 
+                    lineage_details = LineageDetails(
+                        pipeline=EntityReference(
+                            id=pipeline_entity.id.root, type="pipeline"
+                        ),
+                        source=LineageSource.PipelineLineage,
+                    )
+
                     yield Either(
                         right=AddLineageRequest(
                             edge=EntitiesEdge(
diff --git a/ingestion/tests/unit/topology/pipeline/test_dbtcloud.py b/ingestion/tests/unit/topology/pipeline/test_dbtcloud.py
index 1b2de349d4e..ef64cdeb6a9 100644
--- a/ingestion/tests/unit/topology/pipeline/test_dbtcloud.py
+++ b/ingestion/tests/unit/topology/pipeline/test_dbtcloud.py
@@ -19,6 +19,10 @@ from unittest.mock import patch
 
 from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
 from metadata.generated.schema.entity.data.pipeline import Pipeline, Task
+from metadata.generated.schema.entity.data.table import Table
+from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
+    OpenMetadataConnection,
+)
 from metadata.generated.schema.entity.services.pipelineService import (
     PipelineConnection,
     PipelineService,
@@ -41,6 +45,7 @@ from metadata.ingestion.source.pipeline.dbtcloud.metadata import DbtcloudSource
 from metadata.ingestion.source.pipeline.dbtcloud.models import (
     DBTJob,
     DBTJobList,
+    DBTModel,
     DBTSchedule,
 )
 from metadata.ingestion.source.pipeline.pipeline_service import PipelineUsage
@@ -549,6 +554,11 @@ class DBTCloudUnitTest(TestCase):
         self.dbtcloud.metadata = OpenMetadata(
             config.workflowConfig.openMetadataServerConfig
         )
+        self.metadata = OpenMetadata(
+            OpenMetadataConnection.model_validate(
+                mock_dbtcloud_config["workflowConfig"]["openMetadataServerConfig"]
+            )
+        )
 
     @patch("metadata.ingestion.source.pipeline.dbtcloud.client.DBTCloudClient.get_jobs")
     def test_get_pipelines_list(self, get_jobs):
@@ -567,8 +577,19 @@ class DBTCloudUnitTest(TestCase):
         assert self.dbtcloud.client.project_ids == EXPECTED_PROJECT_FILTERS
 
     def test_pipelines(self):
+        """
+        Test pipeline creation
+        """
         pipeline = list(self.dbtcloud.yield_pipeline(EXPECTED_JOB_DETAILS))[0].right
-        assert pipeline == EXPECTED_CREATED_PIPELINES
+
+        # Compare individual fields instead of entire objects
+        self.assertEqual(pipeline.name, EXPECTED_CREATED_PIPELINES.name)
+        self.assertEqual(pipeline.description, EXPECTED_CREATED_PIPELINES.description)
+        self.assertEqual(pipeline.sourceUrl, EXPECTED_CREATED_PIPELINES.sourceUrl)
+        self.assertEqual(
+            pipeline.scheduleInterval, EXPECTED_CREATED_PIPELINES.scheduleInterval
+        )
+        self.assertEqual(pipeline.service, EXPECTED_CREATED_PIPELINES.service)
 
     def test_yield_pipeline_usage(self):
         """
@@ -783,3 +804,269 @@ class DBTCloudUnitTest(TestCase):
         self.assertIsNotNone(
             list(self.dbtcloud.yield_pipeline_usage(EXPECTED_JOB_DETAILS))[0].left
         )
+
+    def test_get_model_details(self):
+        """
+        Test getting model details from DBT Cloud
+        """
+        # Mock the graphql client's post method
+        with patch.object(self.dbtcloud.client.graphql_client, "post") as mock_post:
+            # Set up mock return value
+            mock_post.return_value = {
+                "data": {
+                    "job": {
+                        "models": [
+                            {
+                                "uniqueId": "model.dbt_test_new.model_32",
+                                "name": "model_32",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                                "dependsOn": [
+                                    "model.dbt_test_new.model_15",
+                                    "model.dbt_test_new.model_11",
+                                ],
+                            },
+                            {
+                                "uniqueId": "model.dbt_test_new.model_15",
+                                "name": "model_15",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                                "dependsOn": None,
+                            },
+                            {
+                                "uniqueId": "model.dbt_test_new.model_11",
+                                "name": "model_11",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                                "dependsOn": None,
+                            },
+                        ]
+                    }
+                }
+            }
+
+            # Call the method
+            models = self.dbtcloud.client.get_model_details(
+                70403103936332, 70403110257794
+            )
+
+            # Verify we got the expected models
+            self.assertEqual(len(models), 3)
+
+            # Verify the first model (model_32)
+            model_32 = next(m for m in models if m.name == "model_32")
+            self.assertEqual(model_32.database, "dev")
+            self.assertEqual(model_32.dbtschema, "dbt_test_new")
+            self.assertEqual(len(model_32.dependsOn), 2)
+            self.assertIn("model.dbt_test_new.model_15", model_32.dependsOn)
+            self.assertIn("model.dbt_test_new.model_11", model_32.dependsOn)
+
+            # Test error case
+            mock_post.side_effect = Exception("Test error")
+            error_models = self.dbtcloud.client.get_model_details(
+                70403103936332, 70403110257794
+            )
+            self.assertIsNone(error_models)
+
+    def test_get_models_and_seeds_details(self):
+        """
+        Test getting models and seeds details from DBT Cloud
+        """
+        # Mock the graphql client's post method
+        with patch.object(self.dbtcloud.client.graphql_client, "post") as mock_post:
+            # Set up mock return value
+            mock_post.return_value = {
+                "data": {
+                    "job": {
+                        "models": [
+                            {
+                                "uniqueId": "model.dbt_test_new.model_32",
+                                "name": "model_32",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                                "dependsOn": [
+                                    "model.dbt_test_new.model_15",
+                                    "model.dbt_test_new.model_11",
+                                ],
+                            },
+                            {
+                                "uniqueId": "model.dbt_test_new.model_15",
+                                "name": "model_15",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                                "dependsOn": None,
+                            },
+                            {
+                                "uniqueId": "model.dbt_test_new.model_11",
+                                "name": "model_11",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                                "dependsOn": None,
+                            },
+                        ],
+                        "seeds": [
+                            {
+                                "uniqueId": "seed.dbt_test_new.raw_payments",
+                                "name": "raw_payments",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                            },
+                            {
+                                "uniqueId": "seed.dbt_test_new.raw_orders",
+                                "name": "raw_orders",
+                                "schema": "dbt_test_new",
+                                "database": "dev",
+                            },
+                        ],
+                    }
+                }
+            }
+
+            # Call the method
+            models_and_seeds = self.dbtcloud.client.get_models_and_seeds_details(
+                70403103936332, 70403110257794
+            )
+
+            # Verify we got the expected models and seeds
+            self.assertEqual(len(models_and_seeds), 5)
+
+            # Verify the first model (model_32)
+            model_32 = next(m for m in models_and_seeds if m.name == "model_32")
+            self.assertEqual(model_32.database, "dev")
+            self.assertEqual(model_32.dbtschema, "dbt_test_new")
+            self.assertEqual(len(model_32.dependsOn), 2)
+            self.assertIn("model.dbt_test_new.model_15", model_32.dependsOn)
+            self.assertIn("model.dbt_test_new.model_11", model_32.dependsOn)
+
+            # Verify seeds
+            seeds = [m for m in models_and_seeds if m.uniqueId.startswith("seed.")]
+            self.assertEqual(len(seeds), 2)
+            self.assertIn("raw_payments", [s.name for s in seeds])
+            self.assertIn("raw_orders", [s.name for s in seeds])
+
+            # Test error case
+            mock_post.side_effect = Exception("Test error")
+            error_models = self.dbtcloud.client.get_models_and_seeds_details(
+                70403103936332, 70403110257794
+            )
+            self.assertIsNone(error_models)
+
+    def test_error_handling_in_lineage(self):
+        """
+        Test error handling in lineage generation
+        """
+        # Mock the context with latest run ID
+        self.dbtcloud.context.get().__dict__["latest_run_id"] = 70403110257794
+
+        # Mock metadata.get_by_name to raise an exception
+        with patch.object(
+            OpenMetadata, "get_by_name", side_effect=Exception("Test error")
+        ):
+            # Get the lineage details
+            lineage_details = list(
+                self.dbtcloud.yield_pipeline_lineage_details(EXPECTED_JOB_DETAILS)
+            )
+
+            # Verify we got an error
+            self.assertEqual(len(lineage_details), 1)
+            self.assertIsNotNone(lineage_details[0].left)
+            self.assertIn("Test error", lineage_details[0].left.error)
+
+    def test_yield_pipeline_lineage_details(self):
+        """
+        Test the lineage details generation from DBT Cloud models
+        """
+        # Mock the context with latest run ID
+        self.dbtcloud.context.get().__dict__["latest_run_id"] = 70403110257794
+        self.dbtcloud.context.get().__dict__["pipeline"] = "New job"
+        self.dbtcloud.context.get().__dict__[
+            "pipeline_service"
+        ] = "dbtcloud_pipeline_test"
+
+        # Mock the source config for lineage
+        self.dbtcloud.source_config.lineageInformation = type(
+            "obj", (object,), {"dbServiceNames": ["local_redshift"]}
+        )
+
+        # Create mock entities
+        mock_pipeline = Pipeline(
+            id=uuid.uuid4(),
+            name="New job",
+            fullyQualifiedName="dbtcloud_pipeline_test.New job",
+            service=EntityReference(id=uuid.uuid4(), type="pipelineService"),
+        )
+
+        # Create source and target tables
+        mock_source_table = Table(
+            id=uuid.uuid4(),
+            name="model_15",
+            fullyQualifiedName="local_redshift.dev.dbt_test_new.model_15",
+            database=EntityReference(id=uuid.uuid4(), type="database"),
+            columns=[],
+            databaseSchema=EntityReference(id=uuid.uuid4(), type="databaseSchema"),
+        )
+
+        mock_target_table = Table(
+            id=uuid.uuid4(),
+            name="model_32",
+            fullyQualifiedName="local_redshift.dev.dbt_test_new.model_32",
+            database=EntityReference(id=uuid.uuid4(), type="database"),
+            columns=[],
+            databaseSchema=EntityReference(id=uuid.uuid4(), type="databaseSchema"),
+        )
+
+        # Patch the metadata's get_by_name method
+        with patch.object(self.dbtcloud.metadata, "get_by_name") as mock_get_by_name:
+
+            def get_by_name_side_effect(entity, fqn):
+                if entity == Pipeline:
+                    # Handle both string FQN and FullyQualifiedEntityName
+                    if isinstance(fqn, str):
+                        if fqn == "dbtcloud_pipeline_test.New job":
+                            return mock_pipeline
+                    elif isinstance(fqn, FullyQualifiedEntityName):
+                        if fqn.root == "dbtcloud_pipeline_test.New job":
+                            return mock_pipeline
+                elif entity == Table:
+                    if "model_15" in fqn:
+                        return mock_source_table
+                    elif "model_32" in fqn:
+                        return mock_target_table
+                return "None data testing"
+
+            mock_get_by_name.side_effect = get_by_name_side_effect
+
+            # Mock the client's model and seed detail methods
+            with patch.object(
+                self.dbtcloud.client, "get_models_and_seeds_details"
+            ) as mock_get_parents, patch.object(
+                self.dbtcloud.client, "get_model_details"
+            ) as mock_get_models:
+
+                mock_get_parents.return_value = [
+                    DBTModel(
+                        uniqueId="model.dbt_test_new.model_15",
+                        name="model_15",
+                        dbtschema="dbt_test_new",
+                        database="dev",
+                        dependsOn=None,
+                    )
+                ]
+
+                mock_get_models.return_value = [
+                    DBTModel(
+                        uniqueId="model.dbt_test_new.model_32",
+                        name="model_32",
+                        dbtschema="dbt_test_new",
+                        database="dev",
+                        dependsOn=["model.dbt_test_new.model_15"],
+                    )
+                ]
+
+                # Get the lineage details
+                lineage_details = list(
+                    self.dbtcloud.yield_pipeline_lineage_details(EXPECTED_JOB_DETAILS)
+                )
+
+                # Verify we got exactly one lineage edge
+                self.assertEqual(len(lineage_details), 1)