[Lineage] Fix cross-service lineage: apply the service_names change to the methods that were missed (#23240)

* Apply the cross-database service_names change to the methods that were missed

* Handle string value passed to service_names
This commit is contained in:
Mohit Tilala 2025-09-04 20:38:05 +05:30 committed by GitHub
parent 70d9a1182e
commit 9b2b4d2452
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 163 additions and 21 deletions

View File

@ -271,7 +271,9 @@ def _replace_target_table(
) -> LineageParser:
try:
# Create a new target table instead of modifying the existing one
new_table = Table(expected_table_name.replace(DEFAULT_SCHEMA_NAME, ""))
# Replace "<default>." with empty string to handle schema prefix correctly
clean_table_name = expected_table_name.replace(f"{DEFAULT_SCHEMA_NAME}.", "")
new_table = LineageTable(clean_table_name)
# Create a new statement holder with the updated target table
stmt_holder = parser.parser._stmt_holders[0]
@ -312,12 +314,15 @@ def __process_udf_es_results(
source_table: Union[DataFunction, LineageTable],
database_name: Optional[str],
schema_name: Optional[str],
service_name: Optional[str],
service_names: Union[str, List[str]],
timeout_seconds: int,
column_lineage: dict,
es_result_entities: List[StoredProcedure],
procedure: Optional[StoredProcedure] = None,
):
if isinstance(service_names, str):
service_names = [service_names]
for entity in es_result_entities:
if (
entity.storedProcedureType == StoredProcedureType.UDF
@ -345,7 +350,7 @@ def __process_udf_es_results(
source,
database_name,
schema_name,
service_name,
service_names,
timeout_seconds,
column_lineage,
procedure or entity,
@ -358,19 +363,22 @@ def __process_udf_table_names(
source_table: Union[DataFunction, LineageTable],
database_name: Optional[str],
schema_name: Optional[str],
service_name: Optional[str],
service_names: Union[str, List[str]],
timeout_seconds: int,
column_lineage: dict,
procedure: Optional[StoredProcedure] = None,
):
if isinstance(service_names, str):
service_names = [service_names]
database_query, schema_query, table = get_table_fqn_from_query_name(
str(source_table)
)
function_fqn_string = build_es_fqn_search_string(
database_query or database_name,
schema_query or schema_name,
service_name,
table,
database_name=database_query or database_name,
schema_name=schema_query or schema_name,
service_name=service_names[0], # Use first service for table entity lookup
table_name=table,
)
es_result_entities: Optional[List[StoredProcedure]] = metadata.es_search_from_fqn(
entity_type=StoredProcedure,
@ -383,7 +391,7 @@ def __process_udf_table_names(
source_table,
database_name,
schema_name,
service_name,
service_names,
timeout_seconds,
column_lineage,
es_result_entities,
@ -398,7 +406,7 @@ def get_source_table_names(
source_table: Union[DataFunction, LineageTable],
database_name: Optional[str],
schema_name: Optional[str],
service_name: Optional[str],
service_names: Union[str, List[str]],
timeout_seconds: int,
column_lineage: dict,
procedure: Optional[StoredProcedure] = None,
@ -406,6 +414,9 @@ def get_source_table_names(
"""
Get source table names from DataFunction
"""
if isinstance(service_names, str):
service_names = [service_names]
try:
if not isinstance(source_table, DataFunction):
yield (
@ -420,7 +431,7 @@ def get_source_table_names(
source_table,
database_name,
schema_name,
service_name,
service_names,
timeout_seconds,
column_lineage,
procedure,
@ -594,7 +605,7 @@ def _create_lineage_by_table_name(
metadata: OpenMetadata,
from_table: str,
to_table: str,
service_name: str,
service_names: Union[str, List[str]],
database_name: Optional[str],
schema_name: Optional[str],
masked_query: str,
@ -607,10 +618,13 @@ def _create_lineage_by_table_name(
"""
This method is to create a lineage between two tables
"""
if isinstance(service_names, str):
service_names = [service_names]
try:
from_table_entities = get_table_entities_from_query(
metadata=metadata,
service_names=service_name,
service_names=service_names,
database_name=database_name,
database_schema=schema_name,
table_name=from_table,
@ -619,7 +633,7 @@ def _create_lineage_by_table_name(
to_table_entities = get_table_entities_from_query(
metadata=metadata,
service_names=service_name,
service_names=service_names,
database_name=database_name,
database_schema=schema_name,
table_name=to_table,
@ -761,7 +775,7 @@ def get_lineage_by_query(
source_table=source_table,
database_name=database_name,
schema_name=schema_name,
service_name=service_names,
service_names=service_names,
timeout_seconds=timeout_seconds,
column_lineage=column_lineage,
):
@ -769,7 +783,7 @@ def get_lineage_by_query(
metadata,
from_table=str(from_table_name),
to_table=str(intermediate_table),
service_name=service_names,
service_names=service_names,
database_name=database_name,
schema_name=schema_name,
masked_query=masked_query,
@ -784,7 +798,7 @@ def get_lineage_by_query(
metadata,
from_table=str(intermediate_table),
to_table=str(target_table),
service_name=service_names,
service_names=service_names,
database_name=database_name,
schema_name=schema_name,
masked_query=masked_query,
@ -802,7 +816,7 @@ def get_lineage_by_query(
source_table=source_table,
database_name=database_name,
schema_name=schema_name,
service_name=service_names,
service_names=service_names,
timeout_seconds=timeout_seconds,
column_lineage=column_lineage,
):
@ -810,7 +824,7 @@ def get_lineage_by_query(
metadata,
from_table=str(from_table_name),
to_table=str(target_table),
service_name=service_names,
service_names=service_names,
database_name=database_name,
schema_name=schema_name,
masked_query=masked_query,
@ -876,7 +890,7 @@ def get_lineage_via_table_entity(
source_table=from_table_name,
database_name=database_name,
schema_name=schema_name,
service_name=service_names,
service_names=service_names,
timeout_seconds=timeout_seconds,
column_lineage=column_lineage,
):
@ -884,7 +898,7 @@ def get_lineage_via_table_entity(
metadata,
from_table=str(source_table),
to_table=f"{schema_name}.{to_table_name}",
service_name=service_names,
service_names=service_names,
database_name=database_name,
schema_name=schema_name,
masked_query=masked_query,

View File

@ -651,3 +651,80 @@ class CrossDatabaseLineageSQLTest(TestCase):
# The actual lineage generation depends on the mocked dependencies
# but we can verify that the method executes without errors
self.assertIsInstance(result, list)
def test_build_es_fqn_search_string_kwargs(self):
    """
    Check that get_source_table_names reaches build_es_fqn_search_string
    with keyword arguments, both when service_names is a list and when
    it is a plain string
    """
    from unittest.mock import MagicMock, patch

    from collate_sqllineage.core.models import DataFunction

    from metadata.ingestion.lineage.sql_lineage import get_source_table_names

    metadata_mock = MagicMock()
    metadata_mock.es_search_from_fqn.return_value = None

    # A DataFunction source is what triggers the UDF processing
    udf_source = DataFunction("test_function")

    # (value passed as service_names, service_name expected in the lookup)
    scenarios = (
        # List of service names - this is the bug scenario; the first
        # service of the list is the one expected in the lookup
        (["service1", "service2"], "service1"),
        # Single service name given as a plain string
        ("single_service", "single_service"),
    )

    for service_names, expected_service in scenarios:
        # Patch build_es_fqn_search_string to capture how it is called
        with patch(
            "metadata.ingestion.lineage.sql_lineage.build_es_fqn_search_string"
        ) as mock_build:
            mock_build.return_value = "test.fqn.string"

            # Exhaust the generator so the lookup is actually performed
            for _ in get_source_table_names(
                metadata=metadata_mock,
                dialect=Dialect.ANSI,
                source_table=udf_source,
                database_name="test_db",
                schema_name="test_schema",
                service_names=service_names,
                timeout_seconds=30,
                column_lineage={},
                procedure=None,
            ):
                pass

            # The lookup must be made with keyword arguments and the
            # expected single service name
            mock_build.assert_called_with(
                database_name="test_db",
                schema_name="test_schema",
                service_name=expected_service,
                table_name="test_function",
            )

View File

@ -23,6 +23,7 @@ from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.lineage.models import Dialect
from metadata.ingestion.lineage.parser import LineageParser
from metadata.ingestion.lineage.sql_lineage import (
_replace_target_table,
get_column_lineage,
get_table_fqn_from_query_name,
populate_column_lineage_map,
@ -292,3 +293,53 @@ class SqlLineageTest(TestCase):
for i, query in enumerate(query_list):
self.assertEqual(mask_query(query[0], query[1]), expected_query_list[i])
def test_replace_target_table(self):
    """
    _replace_target_table must swap the parsed target table for the
    expected one and keep column lineage pointing at the new target
    """
    # Parse a CTAS statement whose target table name is only a dummy
    lineage_parser = LineageParser(
        "CREATE TABLE dummy_table_name AS SELECT id, name FROM source_table",
        dialect=Dialect.ANSI,
    )

    # Swap the dummy target for the expected table name
    _replace_target_table(lineage_parser, "actual_target_table")
    replaced_name = "<default>.actual_target_table"

    # Exactly one write target must remain, and it must carry the new name
    written_tables = list(lineage_parser.parser._stmt_holders[0].write)
    self.assertEqual(len(written_tables), 1)
    self.assertEqual(str(written_tables[0]), replaced_name)

    # Column lineage must still be available after the replacement
    lineage_paths = lineage_parser.parser.get_column_lineage()
    self.assertIsNotNone(lineage_paths)

    # The last column of every lineage path must belong to the new target
    for path in lineage_paths:
        self.assertEqual(str(path[-1].parent), replaced_name)
def test_replace_target_table_with_default_schema(self):
    """
    _replace_target_table must cope with an expected name that already
    carries the default schema prefix
    """
    lineage_parser = LineageParser(
        "CREATE TABLE dummy_table_name AS SELECT * FROM source_table",
        dialect=Dialect.ANSI,
    )

    # The replacement name is passed in with the default schema included
    _replace_target_table(lineage_parser, "<default>.actual_table")

    written_tables = list(lineage_parser.parser._stmt_holders[0].write)

    # A single target is expected. Its string form still shows <default>,
    # because LineageTable adds it for tables without a schema even after
    # the prefix was removed from the name
    self.assertEqual(len(written_tables), 1)
    self.assertEqual(str(written_tables[0]), "<default>.actual_table")