diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
--- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
+++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -1176,7 +1176,25 @@ def _try_extract_select(
             statement = sqlglot.exp.Select().select("*").from_(statement)
     elif isinstance(statement, sqlglot.exp.Insert):
         # TODO Need to map column renames in the expressions part of the statement.
-        statement = statement.expression
+        # An INSERT ... SELECT keeps its WITH clause on the Insert node, so taking
+        # only `statement.expression` would drop the CTE definitions. Re-attach
+        # them to the extracted query so CTE references still resolve.
+        original_ctes = statement.ctes
+        # The CTEs share a single WITH node, which carries the RECURSIVE flag.
+        is_recursive = any(
+            cte.parent is not None and cte.parent.args.get("recursive")
+            for cte in original_ctes
+        )
+        statement = statement.expression
+        if isinstance(statement, sqlglot.exp.Query):
+            for cte in original_ctes:
+                # Pass the alias node rather than `cte.alias` (the bare name) so
+                # identifier quoting and column lists like `cte(a, b)` survive.
+                statement = statement.with_(
+                    alias=cte.args["alias"].copy(),
+                    as_=cte.this,
+                    recursive=is_recursive,
+                )
     elif isinstance(statement, sqlglot.exp.Update):
         # Assumption: the output table is already captured in the modified tables list.
         statement = _extract_select_from_update(statement)
diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_insert_with_cte.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_insert_with_cte.json
new file mode 100644
index 0000000000..9c4de5e739
--- /dev/null
+++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_insert_with_cte.json
@@ -0,0 +1,72 @@
+{
+    "query_type": "INSERT",
+    "query_type_props": {},
+    "query_fingerprint": "195448498ded7a1b4df767cf0a5ec53e2fa4c7b011234bafe0a60ff9d7d11c1d",
+    "in_tables": [
+        "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)"
+    ],
+    "out_tables": [
+        "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)"
+    ],
+    "column_lineage": [
+        {
+            "downstream": {
+                "table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)",
+                "column": "id",
+                "column_type": null,
+                "native_column_type": null
+            },
+            "upstreams": [
+                {
+                    "table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)",
+                    "column": "id"
+                }
+            ],
+            "logic": {
+                "is_direct_copy": true,
+                "column_logic": "[source_table].[id] AS [id]"
+            }
+        },
+        {
+            "downstream": {
+                "table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)",
+                "column": "name",
+                "column_type": null,
+                "native_column_type": null
+            },
+            "upstreams": [
+                {
+                    "table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)",
+                    "column": "name"
+                }
+            ],
+            "logic": {
+                "is_direct_copy": true,
+                "column_logic": "[source_table].[name] AS [name]"
+            }
+        },
+        {
+            "downstream": {
+                "table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.target_table,PROD)",
+                "column": "value",
+                "column_type": null,
+                "native_column_type": null
+            },
+            "upstreams": [
+                {
+                    "table": "urn:li:dataset:(urn:li:dataPlatform:tsql,db.schema.source_table,PROD)",
+                    "column": "value"
+                }
+            ],
+            "logic": {
+                "is_direct_copy": true,
+                "column_logic": "[source_table].[value] AS [value]"
+            }
+        }
+    ],
+    "joins": [],
+    "debug_info": {
+        "confidence": 0.2,
+        "generalized_statement": "WITH temp_cte AS (SELECT id AS id, name AS name, value AS value FROM db.schema.source_table) INSERT INTO db.schema.target_table (id, name, value) SELECT id, name, value FROM temp_cte"
+    }
+}
\ No newline at end of file
diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py
index 592c08bb40..ce97ecf13f 100644
--- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py
+++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py
@@ -199,6 +199,21 @@ insert into downstream (a, c) select a, c from upstream2
     )
 
 
+def test_insert_with_cte() -> None:
+    assert_sql_result(
+        """
+WITH temp_cte AS (
+    SELECT id, name, value
+    FROM db.schema.source_table
+)
+INSERT INTO db.schema.target_table (id, name, value)
+SELECT id, name, value FROM temp_cte
+""",
+        dialect="tsql",
+        expected_file=RESOURCE_DIR / "test_insert_with_cte.json",
+    )
+
+
 def test_select_with_full_col_name() -> None:
     # In this case, `widget` is a struct column.
     # This also tests the `default_db` functionality.