diff --git a/ingestion/src/metadata/ingestion/lineage/sql_lineage.py b/ingestion/src/metadata/ingestion/lineage/sql_lineage.py index 6ae79a2f7d2..91968e19dca 100644 --- a/ingestion/src/metadata/ingestion/lineage/sql_lineage.py +++ b/ingestion/src/metadata/ingestion/lineage/sql_lineage.py @@ -271,7 +271,9 @@ def _replace_target_table( ) -> LineageParser: try: # Create a new target table instead of modifying the existing one - new_table = Table(expected_table_name.replace(DEFAULT_SCHEMA_NAME, "")) + # Replace "<default>." with empty string to handle schema prefix correctly + clean_table_name = expected_table_name.replace(f"{DEFAULT_SCHEMA_NAME}.", "") + new_table = LineageTable(clean_table_name) # Create a new statement holder with the updated target table stmt_holder = parser.parser._stmt_holders[0] @@ -312,12 +314,15 @@ def __process_udf_es_results( source_table: Union[DataFunction, LineageTable], database_name: Optional[str], schema_name: Optional[str], - service_name: Optional[str], + service_names: Union[str, List[str]], timeout_seconds: int, column_lineage: dict, es_result_entities: List[StoredProcedure], procedure: Optional[StoredProcedure] = None, ): + if isinstance(service_names, str): + service_names = [service_names] + for entity in es_result_entities: if ( entity.storedProcedureType == StoredProcedureType.UDF @@ -345,7 +350,7 @@ def __process_udf_es_results( source, database_name, schema_name, - service_name, + service_names, timeout_seconds, column_lineage, procedure or entity, @@ -358,19 +363,22 @@ def __process_udf_table_names( source_table: Union[DataFunction, LineageTable], database_name: Optional[str], schema_name: Optional[str], - service_name: Optional[str], + service_names: Union[str, List[str]], timeout_seconds: int, column_lineage: dict, procedure: Optional[StoredProcedure] = None, ): + if isinstance(service_names, str): + service_names = [service_names] + database_query, schema_query, table = get_table_fqn_from_query_name( str(source_table) ) 
function_fqn_string = build_es_fqn_search_string( - database_query or database_name, - schema_query or schema_name, - service_name, - table, + database_name=database_query or database_name, + schema_name=schema_query or schema_name, + service_name=service_names[0], # Use first service for table entity lookup + table_name=table, ) es_result_entities: Optional[List[StoredProcedure]] = metadata.es_search_from_fqn( entity_type=StoredProcedure, @@ -383,7 +391,7 @@ def __process_udf_table_names( source_table, database_name, schema_name, - service_name, + service_names, timeout_seconds, column_lineage, es_result_entities, @@ -398,7 +406,7 @@ def get_source_table_names( source_table: Union[DataFunction, LineageTable], database_name: Optional[str], schema_name: Optional[str], - service_name: Optional[str], + service_names: Union[str, List[str]], timeout_seconds: int, column_lineage: dict, procedure: Optional[StoredProcedure] = None, @@ -406,6 +414,9 @@ def get_source_table_names( """ Get source table names from DataFunction """ + if isinstance(service_names, str): + service_names = [service_names] + try: if not isinstance(source_table, DataFunction): yield ( @@ -420,7 +431,7 @@ def get_source_table_names( source_table, database_name, schema_name, - service_name, + service_names, timeout_seconds, column_lineage, procedure, @@ -594,7 +605,7 @@ def _create_lineage_by_table_name( metadata: OpenMetadata, from_table: str, to_table: str, - service_name: str, + service_names: Union[str, List[str]], database_name: Optional[str], schema_name: Optional[str], masked_query: str, @@ -607,10 +618,13 @@ def _create_lineage_by_table_name( """ This method is to create a lineage between two tables """ + if isinstance(service_names, str): + service_names = [service_names] + try: from_table_entities = get_table_entities_from_query( metadata=metadata, - service_names=service_name, + service_names=service_names, database_name=database_name, database_schema=schema_name, table_name=from_table, @@ 
-619,7 +633,7 @@ def _create_lineage_by_table_name( to_table_entities = get_table_entities_from_query( metadata=metadata, - service_names=service_name, + service_names=service_names, database_name=database_name, database_schema=schema_name, table_name=to_table, @@ -761,7 +775,7 @@ def get_lineage_by_query( source_table=source_table, database_name=database_name, schema_name=schema_name, - service_name=service_names, + service_names=service_names, timeout_seconds=timeout_seconds, column_lineage=column_lineage, ): @@ -769,7 +783,7 @@ def get_lineage_by_query( metadata, from_table=str(from_table_name), to_table=str(intermediate_table), - service_name=service_names, + service_names=service_names, database_name=database_name, schema_name=schema_name, masked_query=masked_query, @@ -784,7 +798,7 @@ def get_lineage_by_query( metadata, from_table=str(intermediate_table), to_table=str(target_table), - service_name=service_names, + service_names=service_names, database_name=database_name, schema_name=schema_name, masked_query=masked_query, @@ -802,7 +816,7 @@ def get_lineage_by_query( source_table=source_table, database_name=database_name, schema_name=schema_name, - service_name=service_names, + service_names=service_names, timeout_seconds=timeout_seconds, column_lineage=column_lineage, ): @@ -810,7 +824,7 @@ def get_lineage_by_query( metadata, from_table=str(from_table_name), to_table=str(target_table), - service_name=service_names, + service_names=service_names, database_name=database_name, schema_name=schema_name, masked_query=masked_query, @@ -876,7 +890,7 @@ def get_lineage_via_table_entity( source_table=from_table_name, database_name=database_name, schema_name=schema_name, - service_name=service_names, + service_names=service_names, timeout_seconds=timeout_seconds, column_lineage=column_lineage, ): @@ -884,7 +898,7 @@ def get_lineage_via_table_entity( metadata, from_table=str(source_table), to_table=f"{schema_name}.{to_table_name}", - service_name=service_names, + 
service_names=service_names, database_name=database_name, schema_name=schema_name, masked_query=masked_query, diff --git a/ingestion/tests/unit/lineage/test_cross_database_lineage_sql.py b/ingestion/tests/unit/lineage/test_cross_database_lineage_sql.py index 729a008263e..48f1fa825a7 100644 --- a/ingestion/tests/unit/lineage/test_cross_database_lineage_sql.py +++ b/ingestion/tests/unit/lineage/test_cross_database_lineage_sql.py @@ -651,3 +651,80 @@ class CrossDatabaseLineageSQLTest(TestCase): # The actual lineage generation depends on the mocked dependencies # but we can verify that the method executes without errors self.assertIsInstance(result, list) + + def test_build_es_fqn_search_string_kwargs(self): + """ + Test that build_es_fqn_search_string is called with keyword arguments + and handles service_names list correctly via get_source_table_names + """ + from unittest.mock import MagicMock, patch + + from collate_sqllineage.core.models import DataFunction + + from metadata.ingestion.lineage.sql_lineage import get_source_table_names + + mock_metadata = MagicMock() + mock_metadata.es_search_from_fqn.return_value = None + + # Test with a DataFunction to trigger the UDF processing + source_table = DataFunction("test_function") + + # Mock build_es_fqn_search_string to capture how it's called + with patch( + "metadata.ingestion.lineage.sql_lineage.build_es_fqn_search_string" + ) as mock_build: + mock_build.return_value = "test.fqn.string" + + # Test with list of service names - this is the bug scenario + service_names = ["service1", "service2"] + list( + get_source_table_names( + metadata=mock_metadata, + dialect=Dialect.ANSI, + source_table=source_table, + database_name="test_db", + schema_name="test_schema", + service_names=service_names, + timeout_seconds=30, + column_lineage={}, + procedure=None, + ) + ) + + # Verify build_es_fqn_search_string was called with keyword arguments + # and the first service name from the list + mock_build.assert_called_with( + 
database_name="test_db", + schema_name="test_schema", + service_name="service1", # Should use first service from list + table_name="test_function", + ) + + # Test with single service name + with patch( + "metadata.ingestion.lineage.sql_lineage.build_es_fqn_search_string" + ) as mock_build: + mock_build.return_value = "test.fqn.string" + + service_names = "single_service" + list( + get_source_table_names( + metadata=mock_metadata, + dialect=Dialect.ANSI, + source_table=source_table, + database_name="test_db", + schema_name="test_schema", + service_names=service_names, + timeout_seconds=30, + column_lineage={}, + procedure=None, + ) + ) + + # Should handle string service name correctly + mock_build.assert_called_with( + database_name="test_db", + schema_name="test_schema", + service_name="single_service", + table_name="test_function", + ) diff --git a/ingestion/tests/unit/lineage/test_sql_lineage.py b/ingestion/tests/unit/lineage/test_sql_lineage.py index c44575ef841..4921e7e9df7 100644 --- a/ingestion/tests/unit/lineage/test_sql_lineage.py +++ b/ingestion/tests/unit/lineage/test_sql_lineage.py @@ -23,6 +23,7 @@ from metadata.ingestion.lineage.masker import mask_query from metadata.ingestion.lineage.models import Dialect from metadata.ingestion.lineage.parser import LineageParser from metadata.ingestion.lineage.sql_lineage import ( + _replace_target_table, get_column_lineage, get_table_fqn_from_query_name, populate_column_lineage_map, @@ -292,3 +293,53 @@ class SqlLineageTest(TestCase): for i, query in enumerate(query_list): self.assertEqual(mask_query(query[0], query[1]), expected_query_list[i]) + + def test_replace_target_table(self): + """ + Test the _replace_target_table function + """ + # Create a LineageParser with a dummy UDF query + query = "CREATE TABLE dummy_table_name AS SELECT id, name FROM source_table" + parser = LineageParser(query, dialect=Dialect.ANSI) + + # Replace the target table with the expected name + expected_table_name = "actual_target_table" 
+ _replace_target_table(parser, expected_table_name) + + # Verify the target table has been replaced + stmt_holder = parser.parser._stmt_holders[0] + target_tables = list(stmt_holder.write) + + # Check that we have exactly one target table with the expected name + self.assertEqual(len(target_tables), 1) + self.assertEqual(str(target_tables[0]), "<default>.actual_target_table") + + # Verify column lineage is preserved + column_lineage = parser.parser.get_column_lineage() + self.assertIsNotNone(column_lineage) + + # Check that column lineage points to the new target table + for col_lineage in column_lineage: + target_column = col_lineage[-1] + self.assertEqual(str(target_column.parent), "<default>.actual_target_table") + + def test_replace_target_table_with_default_schema(self): + """ + Test _replace_target_table with default schema removal + """ + # Create a LineageParser with a query + query = "CREATE TABLE dummy_table_name AS SELECT * FROM source_table" + parser = LineageParser(query, dialect=Dialect.ANSI) + + # Replace with a name containing default schema + expected_table_name = "<default>.actual_table" + _replace_target_table(parser, expected_table_name) + + # Verify the target table name is correct + # Note: LineageTable always adds <default> for tables without schema + stmt_holder = parser.parser._stmt_holders[0] + target_tables = list(stmt_holder.write) + + self.assertEqual(len(target_tables), 1) + # LineageTable will add <default> back even after we remove it + self.assertEqual(str(target_tables[0]), "<default>.actual_table")