[SAP HANA] Prevent exponential processing in lineage parsing and use full name for filtering (#23484)

* Prevent exponential processing in lineage parsing

* Use full name of views for filtering

* pylint fix - isort
Mohit Tilala authored on 2025-09-22 19:46:34 +05:30 (committed via GitHub)
parent 71f993a2fc
commit d1e60acd2a
3 changed files with 215 additions and 97 deletions
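The first bullet is the interesting one: target -> source traversal of a calculation view walks a DAG, and when many columns reconverge on shared intermediate nodes, a naive recursive walk re-expands each shared node once per path, so the work grows exponentially with nesting depth. Below is a minimal sketch of the effect and of the visited-set fix; the graph and all names (traverse, parents, n0/a0/b0, ...) are hypothetical, plain Python rather than connector code.

    from typing import Dict, List, Optional, Set


    def traverse(
        node: str, parents: Dict[str, List[str]], visited: Optional[Set[str]] = None
    ) -> int:
        """Walk target -> source edges, returning how many nodes were expanded"""
        if visited is not None:
            if node in visited:
                return 0
            visited.add(node)
        expanded = 1
        for parent in parents.get(node, []):
            expanded += traverse(parent, parents, visited)
        return expanded


    # Ten stacked "diamonds": each level forks into two parents that
    # reconverge on a single shared node one level up
    parents: Dict[str, List[str]] = {}
    for i in range(10):
        parents[f"n{i}"] = [f"a{i}", f"b{i}"]
        parents[f"a{i}"] = [f"n{i + 1}"]
        parents[f"b{i}"] = [f"n{i + 1}"]

    print(traverse("n0", parents))         # 4093 expansions: doubles at every level
    print(traverse("n0", parents, set()))  # 31 expansions: one per node

The diff below applies the same idea, keying the visited set on the (datasource name, column) pair so each column's subtree is expanded at most once.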

ingestion/src/metadata/ingestion/source/database/saphana/cdata_parser.py

@@ -340,40 +340,6 @@ class ParsedLineage(BaseModel):
     )


-def _get_column_datasources(
-    entry: ET.Element, datasource_map: Optional[DataSourceMap] = None
-) -> Set[DataSource]:
-    """Read a DataSource from the CDATA XML"""
-    if (
-        datasource_map
-        and entry.get(CDATAKeys.COLUMN_OBJECT_NAME.value) in datasource_map
-    ):
-        # If the datasource is in the map, we'll traverse all intermediate logical
-        # datasources until we arrive at a table or view.
-        # Note that we can have multiple sources for a single column, e.g., columns
-        # coming from a JOIN
-        return set(
-            _traverse_ds(
-                current_column=entry.get(CDATAKeys.COLUMN_NAME.value),
-                ds_origin_list=[],
-                current_ds=datasource_map[
-                    entry.get(CDATAKeys.COLUMN_OBJECT_NAME.value)
-                ],
-                datasource_map=datasource_map,
-            )
-        )
-    # If we don't have any logical sources (projections, aggregations, etc.),
-    # we'll stick to a single table origin
-    return {
-        DataSource(
-            name=entry.get(CDATAKeys.COLUMN_OBJECT_NAME.value),
-            location=entry.get(CDATAKeys.SCHEMA_NAME.value),
-            source_type=ViewType.DATA_BASE_TABLE,
-        )
-    }


 def _get_column_datasources_with_names(
     entry: ET.Element, datasource_map: Optional[DataSourceMap] = None
 ) -> List[Tuple[DataSource, str, Optional[str]]]:
@@ -392,6 +358,7 @@ def _get_column_datasources_with_names(
             current_ds=datasource_map[entry.get(CDATAKeys.COLUMN_OBJECT_NAME.value)],
             datasource_map=datasource_map,
             formula=None,
+            _visited=set(),
         )
         return ds_col_pairs
@@ -409,56 +376,13 @@
     ]


-def _traverse_ds(
-    current_column: str,
-    ds_origin_list: List[DataSource],
-    current_ds: DataSource,
-    datasource_map: Optional[DataSourceMap],
-) -> List[DataSource]:
-    """
-    Traverse the ds dict jumping from target -> source columns and getting the right parent.
-    We keep inspecting current datasources and will append to the origin list the ones
-    that are not LOGICAL
-    """
-    if current_ds.source_type != ViewType.LOGICAL:
-        ds_origin_list.append(current_ds)
-    else:
-        # Based on our current column, find the parents from the mappings in the current_ds
-        current_ds_mapping: Optional[DataSourceMapping] = current_ds.mapping.get(
-            current_column
-        )
-        if current_ds_mapping:
-            for parent in current_ds_mapping.parents:
-                parent_ds = datasource_map.get(parent.parent)
-                if not parent_ds:
-                    raise CDATAParsingError(
-                        f"Can't find parent [{parent.parent}] for column [{current_column}]"
-                    )
-                # Traverse from the source column in the parent mapping
-                _traverse_ds(
-                    current_column=parent.source,
-                    ds_origin_list=ds_origin_list,
-                    current_ds=parent_ds,
-                    datasource_map=datasource_map,
-                )
-        else:
-            logger.info(
-                f"Can't find mapping for column [{current_column}] in [{current_ds}]. "
-                f"This might be a constant or derived column."
-            )
-    return ds_origin_list


 def _traverse_ds_with_columns(
     current_column: str,
     ds_origin_list: List[Tuple[DataSource, str, Optional[str]]],
     current_ds: DataSource,
     datasource_map: Optional[DataSourceMap],
     formula: Optional[str] = None,
+    _visited: Optional[set] = None,
 ) -> List[Tuple[DataSource, str, Optional[str]]]:
     """
     Traverse the ds dict jumping from target -> source columns and getting the right parent.
@@ -466,6 +390,16 @@ def _traverse_ds_with_columns(
     that are not LOGICAL, along with the final column name and formula.
     Returns a list of tuples (DataSource, column_name, formula).
     """
+    # Create the visited set lazily so each top-level call gets a fresh one
+    if _visited is None:
+        _visited = set()
+
+    # Create the visit key for this node
+    visit_key = (current_ds.name, current_column)
+
+    # If we've already processed this (datasource, column) pair, stop here
+    if visit_key in _visited:
+        return ds_origin_list
+
+    # Mark the node as visited
+    _visited.add(visit_key)
+
     if current_ds.source_type != ViewType.LOGICAL:
         # This is a final datasource, append it with the current column name and formula
         ds_origin_list.append((current_ds, current_column, formula))
@@ -496,6 +430,7 @@ def _traverse_ds_with_columns(
                     current_ds=parent_ds,
                     datasource_map=datasource_map,
                     formula=formula,
+                    _visited=_visited,
                 )
         else:
             logger.info(
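One detail in the new `_visited` parameter deserves a note: the set must be created per top-level call, never shared. Python evaluates default argument values once, at function definition time, so a `def f(x=set())` default silently shares state across calls; hence the `None` default plus the in-body guard above. A tiny generic demonstration (standalone Python, all names invented, not connector code):

    def bad(visited=set()):  # one shared set, created when the function is defined
        visited.add(len(visited))
        return len(visited)


    def good(visited=None):  # a fresh set per call
        if visited is None:
            visited = set()
        visited.add(len(visited))
        return len(visited)


    print(bad(), bad())    # 1 2 -- state leaks between unrelated calls
    print(good(), good())  # 1 1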

ingestion/src/metadata/ingestion/source/database/saphana/lineage.py

@@ -111,14 +111,15 @@ class SaphanaLineageSource(Source):
             if filter_by_table(
                 self.source_config.tableFilterPattern,
-                lineage_model.object_name,
+                lineage_model.name,
             ):
                 self.status.filter(
-                    lineage_model.object_name,
+                    lineage_model.name,
                     "View Object Filtered Out",
                 )
                 continue
+            logger.debug(f"Processing lineage for view: {lineage_model.name}")
             yield from self.parse_cdata(
                 metadata=self.metadata, lineage_model=lineage_model
             )
@@ -156,11 +157,11 @@ class SaphanaLineageSource(Source):
         except Exception as exc:
             error = (
                 f"Error parsing CDATA XML for {lineage_model.object_suffix} at "
-                + f"{lineage_model.package_id}/{lineage_model.object_name} due to [{exc}]"
+                + f"{lineage_model.name} due to [{exc}]"
             )
             self.status.failed(
                 error=StackTraceError(
-                    name=lineage_model.object_name,
+                    name=lineage_model.name,
                     error=error,
                     stackTrace=traceback.format_exc(),
                 )
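To see why passing the full name matters for filtering: the include/exclude patterns in a tableFilterPattern are regexes matched against whatever name the connector hands over, so a package-qualified pattern like `com.example.package/CV_INCLUDE.*` can only match once `lineage_model.name` (the `package_id/object_name` form) is used. A rough standalone approximation in plain `re`, assuming `re.match`-style anchoring like the `filter_by_table` helper used above:

    import re

    includes = ["com.example.package/CV_INCLUDE.*"]  # pattern from the test below

    full_name = "com.example.package/CV_INCLUDE_VIEW"  # package_id/object_name
    bare_name = "CV_INCLUDE_VIEW"                      # object_name only (old behavior)

    print(any(re.match(p, full_name) for p in includes))  # True: view is kept
    print(any(re.match(p, bare_name) for p in includes))  # False: a package-qualified
                                                          # include can never match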

ingestion/tests/unit/topology/database/test_saphana.py

@@ -11,15 +11,38 @@
 """
 Test SAP Hana source
 """
+import xml.etree.ElementTree as ET
 from pathlib import Path
+from unittest.mock import MagicMock, Mock, create_autospec, patch

+from metadata.generated.schema.entity.services.connections.database.sapHana.sapHanaSQLConnection import (
+    SapHanaSQLConnection,
+)
+from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
+    SapHanaConnection,
+)
+from metadata.generated.schema.entity.services.databaseService import DatabaseConnection
+from metadata.generated.schema.metadataIngestion.databaseServiceMetadataPipeline import (
+    DatabaseServiceMetadataPipeline,
+)
+from metadata.generated.schema.metadataIngestion.workflow import (
+    Source as WorkflowSource,
+)
+from metadata.generated.schema.metadataIngestion.workflow import SourceConfig
+from metadata.generated.schema.type.filterPattern import FilterPattern
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.database.saphana.cdata_parser import (
     ColumnMapping,
     DataSource,
     DataSourceMapping,
     ParentSource,
     ParsedLineage,
     ViewType,
+    _parse_cv_data_sources,
+    _traverse_ds_with_columns,
     parse_registry,
 )
+from metadata.ingestion.source.database.saphana.lineage import SaphanaLineageSource

 RESOURCES_DIR = Path(__file__).parent.parent.parent / "resources" / "saphana"
@@ -166,8 +189,6 @@ def test_parse_cv() -> None:
 def test_schema_mapping_in_datasource():
     """Test that DataSource correctly handles schema mapping for DATA_BASE_TABLE type"""
-    from unittest.mock import MagicMock, patch
-
     # Create a mock engine and connection
     mock_engine = MagicMock()
     mock_conn = MagicMock()
@@ -223,7 +244,6 @@ def test_schema_mapping_in_datasource():
 def test_parsed_lineage_with_schema_mapping():
     """Test that ParsedLineage.to_request passes engine parameter correctly"""
-    from unittest.mock import MagicMock, patch

     # Create a simple parsed lineage
     ds = DataSource(
@@ -435,12 +455,6 @@ def test_analytic_view_formula_column_source_mapping() -> None:
 def test_formula_columns_reference_correct_layer():
     """Test that formula columns reference the correct calculation view layer"""
-    import xml.etree.ElementTree as ET
-
-    from metadata.ingestion.source.database.saphana.cdata_parser import (
-        _parse_cv_data_sources,
-    )
-
     # Load the complex star join view XML
     with open(
         RESOURCES_DIR / "custom" / "cdata_calculation_view_star_join_complex.xml"
@@ -487,12 +501,6 @@ def test_formula_columns_reference_correct_layer():
 def test_projection_formula_columns():
     """Test that projection view formula columns reference the correct layer"""
-    import xml.etree.ElementTree as ET
-
-    from metadata.ingestion.source.database.saphana.cdata_parser import (
-        _parse_cv_data_sources,
-    )
-
     with open(
         RESOURCES_DIR / "custom" / "cdata_calculation_view_star_join_complex.xml"
     ) as file:
@@ -727,3 +735,177 @@ def test_formula_parsing_comprehensive():
     assert (
         complex_calc and complex_calc.formula == '"PRICE" * 1.1 + 10'
     ), "Complex formula not preserved"
+
+
+def test_circular_reference_prevention() -> None:
+    """Test that we handle circular references without infinite recursion
+
+    While SAP HANA doesn't actually create circular references in calculation views,
+    this test ensures our visited tracking works properly. The same mechanism that
+    prevents infinite loops here also prevents exponential processing in complex
+    calculation view hierarchies.
+
+    TODO: Add test for the actual exponential processing scenario
+    """
+    # Create a scenario with circular dependencies
+    datasource_map = {
+        "TestView": DataSource(
+            name="TestView",
+            location=None,
+            source_type=ViewType.LOGICAL,
+            mapping={
+                "ColumnA": DataSourceMapping(
+                    target="ColumnA",
+                    parents=[ParentSource(source="ColumnB", parent="TestView")],
+                    formula='"ColumnB" + 1',
+                ),
+                "ColumnB": DataSourceMapping(
+                    target="ColumnB",
+                    parents=[ParentSource(source="ColumnA", parent="TestView")],
+                    formula='"ColumnA" - 1',
+                ),
+            },
+        ),
+    }
+
+    # Track function calls
+    call_count = 0
+    original_traverse = _traverse_ds_with_columns
+
+    def counting_traverse(*args, **kwargs):
+        nonlocal call_count
+        call_count += 1
+        return original_traverse(*args, **kwargs)
+
+    with patch(
+        "metadata.ingestion.source.database.saphana.cdata_parser._traverse_ds_with_columns",
+        side_effect=counting_traverse,
+    ):
+        ds_origin_list = []
+        current_ds = datasource_map["TestView"]
+
+        _traverse_ds_with_columns(
+            current_column="ColumnA",
+            ds_origin_list=ds_origin_list,
+            current_ds=current_ds,
+            datasource_map=datasource_map,
+        )
+
+    # With circular reference prevention, we should visit each node only once;
+    # without prevention, this would recurse infinitely
+    assert (
+        call_count <= 3
+    ), f"Too many function calls: {call_count} (indicates circular recursion)"
+
+
+def test_sap_hana_lineage_filter_pattern() -> None:
+    """
+    Test that SAP HANA lineage source filters views based on
+    the full package_id/object_name format.
+    """
+    mock_metadata = create_autospec(OpenMetadata)
+    mock_metadata.get_by_name = Mock(return_value=None)
+
+    mock_config = WorkflowSource(
+        type="saphana-lineage",
+        serviceName="test_sap_hana",
+        serviceConnection=DatabaseConnection(
+            config=SapHanaConnection(
+                connection=SapHanaSQLConnection(
+                    username="test", password="test", hostPort="localhost:39015"
+                )
+            )
+        ),
+        sourceConfig=SourceConfig(
+            config=DatabaseServiceMetadataPipeline(
+                tableFilterPattern=FilterPattern(
+                    includes=["com.example.package/CV_INCLUDE.*"],
+                    excludes=[".*/CV_EXCLUDE.*"],
+                )
+            )
+        ),
+    )
+
+    with patch(
+        "metadata.ingestion.source.database.saphana.lineage.get_ssl_connection"
+    ) as mock_get_engine:
+        mock_engine = MagicMock()
+        mock_connection = MagicMock()
+        mock_get_engine.return_value = mock_engine
+        mock_engine.connect.return_value.__enter__ = Mock(return_value=mock_connection)
+        mock_engine.connect.return_value.__exit__ = Mock()
+
+        mock_rows = [
+            {
+                "PACKAGE_ID": "com.example.package",
+                "OBJECT_NAME": "CV_INCLUDE_VIEW",
+                "OBJECT_SUFFIX": "calculationview",
+                "CDATA": "<dummy/>",
+            },
+            {
+                "PACKAGE_ID": "com.example.package",
+                "OBJECT_NAME": "CV_EXCLUDE_VIEW",
+                "OBJECT_SUFFIX": "calculationview",
+                "CDATA": "<dummy/>",
+            },
+            {
+                "PACKAGE_ID": "com.example.package",
+                "OBJECT_NAME": "CV_OTHER_VIEW",
+                "OBJECT_SUFFIX": "calculationview",
+                "CDATA": "<dummy/>",
+            },
+            {
+                "PACKAGE_ID": "com.example.package",
+                "OBJECT_NAME": "CV_INCLUDE_ANOTHER",
+                "OBJECT_SUFFIX": "calculationview",
+                "CDATA": "<dummy/>",
+            },
+        ]
+
+        class MockRow(dict):
+            """Row stub that accepts both original and lower-cased keys"""
+
+            def __init__(self, data):
+                lowercase_data = {k.lower(): v for k, v in data.items()}
+                super().__init__(lowercase_data)
+                self._data = data
+
+            def __getitem__(self, key):
+                if key in self._data:
+                    return self._data[key]
+                return super().__getitem__(key.lower())
+
+            def keys(self):
+                return [k.lower() for k in self._data.keys()]
+
+            def get(self, key, default=None):
+                try:
+                    return self[key]
+                except KeyError:
+                    return default
+
+        mock_result = [MockRow(row_dict) for row_dict in mock_rows]
+
+        mock_execution = MagicMock()
+        mock_execution.__iter__ = Mock(return_value=iter(mock_result))
+        mock_connection.execution_options.return_value.execute.return_value = (
+            mock_execution
+        )
+
+        source = SaphanaLineageSource(config=mock_config, metadata=mock_metadata)
+
+        processed_views = []
+
+        def mock_parse_cdata(metadata, lineage_model):
+            processed_views.append(lineage_model.object_name)
+            return iter([])
+
+        with patch.object(source, "parse_cdata", side_effect=mock_parse_cdata):
+            list(source._iter())
+
+        assert "CV_INCLUDE_VIEW" in processed_views
+        assert "CV_INCLUDE_ANOTHER" in processed_views
+        assert "CV_EXCLUDE_VIEW" not in processed_views
+        assert "CV_OTHER_VIEW" not in processed_views
+        assert len(processed_views) == 2
+        assert len(source.status.filtered) == 2
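For reference, the `MockRow` shim above emulates result rows whose columns can be read with either the original upper-case keys or their lower-cased forms. A quick hypothetical usage sketch, mirroring the class defined in the test:

    row = MockRow({"PACKAGE_ID": "com.example.package", "OBJECT_NAME": "CV_INCLUDE_VIEW"})

    assert row["PACKAGE_ID"] == "com.example.package"  # original key
    assert row["package_id"] == "com.example.package"  # lower-cased key
    assert row.get("missing", "fallback") == "fallback"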