fix(ingest/datahub): Create Structured property templates in advance and batch processing (#13355)

Co-authored-by: Pedro Silva <pedro@acryl.io>
Tamas Nemeth authored 2025-05-26 14:05:17 +02:00, committed by GitHub
parent 4c6672213c
commit 2ffa84be5c
4 changed files with 466 additions and 36 deletions


@@ -118,6 +118,17 @@ class DataHubSourceConfig(StatefulIngestionConfigBase):
         "Useful if the source system has duplicate field paths in the db, but we're pushing to a system with server-side duplicate checking.",
     )
 
+    structured_properties_template_cache_invalidation_interval: int = Field(
+        hidden_from_docs=True,
+        default=60,
+        description="Interval in seconds to invalidate the structured properties template cache.",
+    )
+
+    query_timeout: Optional[int] = Field(
+        default=None,
+        description="Timeout for each query in seconds. ",
+    )
+
     @root_validator(skip_on_failure=True)
     def check_ingesting_data(cls, values):
         if (
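For context, both new options land in the `datahub` source's config block and are consumed by the database reader changes below. A minimal sketch of the `source.config` section, written here as a Python dict that mirrors the YAML recipe shape; only the two new keys come from this commit, and the connection values are illustrative placeholders, not taken from the change:

# Hypothetical `source.config` block for a DataHub-to-DataHub ingestion recipe.
datahub_source_config = {
    "database_connection": {
        "scheme": "mysql+pymysql",   # assumption: MySQL-backed GMS database
        "host_port": "localhost:3306",
        "database": "datahub",
        "username": "datahub",
        "password": "datahub",
    },
    # New in this change: per-query timeout in seconds (None disables it).
    "query_timeout": 300,
    # New in this change (hidden from docs): how long to wait for the
    # server-side structured-property template cache to invalidate.
    "structured_properties_template_cache_invalidation_interval": 60,
}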


@@ -1,10 +1,10 @@
-import contextlib
 import json
 import logging
+import time
 from datetime import datetime
 from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar
 
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, text
 
 from datahub.emitter.aspect import ASPECT_MAP
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
 
 # Should work for at least mysql, mariadb, postgres
 DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"
+DATE_FORMAT = "%Y-%m-%d"
 
 ROW = TypeVar("ROW", bound=Dict[str, Any])
@@ -85,6 +86,9 @@ class DataHubDatabaseReader:
             **connection_config.options,
         )
 
+        # Cache for available dates to avoid redundant queries
+        self.available_dates_cache: Optional[List[datetime]] = None
+
     @property
     def soft_deleted_urns_query(self) -> str:
         return f"""
@@ -100,14 +104,12 @@ class DataHubDatabaseReader:
         ORDER BY mav.urn
         """
 
-    @property
-    def query(self) -> str:
-        # May repeat rows for the same date
-        # Offset is generally 0, unless we repeat the same createdon twice
-
-        # Ensures stable order, chronological per (urn, aspect)
-        # Relies on createdon order to reflect version order
-        # Ordering of entries with the same createdon is handled by VersionOrderer
+    def query(self, set_structured_properties_filter: bool) -> str:
+        """
+        Main query that gets data for specified date range with appropriate filters.
+        """
+        structured_prop_filter = f" AND urn {'' if set_structured_properties_filter else 'NOT'} like 'urn:li:structuredProperty:%%'"
+
         return f"""
         SELECT *
         FROM (
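The `structured_prop_filter` fragment above is what splits the ingestion into a structured-properties pass and an everything-else pass. A small illustration of the two SQL fragments it produces (the `%%` is doubled so the driver's paramstyle escaping leaves a literal `%` for the SQL `LIKE`); this snippet is for illustration only and is not part of the change:

# What the f-string in query() expands to for each flag value.
for set_structured_properties_filter in (True, False):
    structured_prop_filter = (
        f" AND urn {'' if set_structured_properties_filter else 'NOT'}"
        " like 'urn:li:structuredProperty:%%'"
    )
    print(set_structured_properties_filter, "->", structured_prop_filter)

# True  ->  AND urn  like 'urn:li:structuredProperty:%%'
# False ->  AND urn NOT like 'urn:li:structuredProperty:%%'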
@@ -132,6 +134,7 @@ class DataHubDatabaseReader:
                 {"" if self.config.include_all_versions else "AND mav.version = 0"}
                 {"" if not self.config.exclude_aspects else "AND mav.aspect NOT IN %(exclude_aspects)s"}
                 AND mav.createdon >= %(since_createdon)s
+                AND mav.createdon < %(end_createdon)s
             ORDER BY
                 createdon,
                 urn,
@@ -139,50 +142,194 @@ class DataHubDatabaseReader:
                 version
         ) as t
         WHERE 1=1
-        {"" if self.config.include_soft_deleted_entities else "AND (removed = false or removed is NULL)"}
+        {"" if self.config.include_soft_deleted_entities else " AND (removed = false or removed is NULL)"}
+        {structured_prop_filter}
         ORDER BY
             createdon,
             urn,
             aspect,
             version
+        LIMIT %(limit)s
+        OFFSET %(offset)s
         """
 
+    def execute_with_params(
+        self, query: str, params: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
+        """Execute query with proper parameter binding that works with your database"""
+        with self.engine.connect() as conn:
+            result = conn.execute(query, params or {})
+            return [dict(row) for row in result.fetchall()]
+
     def execute_server_cursor(
         self, query: str, params: Dict[str, Any]
     ) -> Iterable[Dict[str, Any]]:
+        """Execute a query with server-side cursor"""
         with self.engine.connect() as conn:
             if self.engine.dialect.name in ["postgresql", "mysql", "mariadb"]:
                 with (
                     conn.begin()
                 ):  # Transaction required for PostgreSQL server-side cursor
-                    # Note that stream_results=True is mainly supported by PostgreSQL and MySQL-based dialects.
-                    # https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.Connection.execution_options.params.stream_results
+                    # Set query timeout at the connection level
+                    if self.config.query_timeout:
+                        if self.engine.dialect.name == "postgresql":
+                            conn.execute(
+                                text(
+                                    f"SET statement_timeout = {self.config.query_timeout * 1000}"
+                                )
+                            )  # milliseconds
+                        elif self.engine.dialect.name in ["mysql", "mariadb"]:
+                            conn.execute(
+                                text(
+                                    f"SET max_execution_time = {self.config.query_timeout * 1000}"
+                                )
+                            )  # milliseconds
+
+                    # Stream results with batch size
                     conn = conn.execution_options(
                         stream_results=True,
                         yield_per=self.config.database_query_batch_size,
                     )
+
+                    # Execute query - using native parameterization without text()
+                    # to maintain compatibility with your original code
                     result = conn.execute(query, params)
                     for row in result:
                         yield dict(row)
+                    return  # Success, exit the retry loop
             else:
                 raise ValueError(f"Unsupported dialect: {self.engine.dialect.name}")
 
     def _get_rows(
-        self, from_createdon: datetime, stop_time: datetime
+        self,
+        start_date: datetime,
+        end_date: datetime,
+        set_structured_properties_filter: bool,
+        limit: int,
     ) -> Iterable[Dict[str, Any]]:
-        params = {
-            "exclude_aspects": list(self.config.exclude_aspects),
-            "since_createdon": from_createdon.strftime(DATETIME_FORMAT),
-        }
-        yield from self.execute_server_cursor(self.query, params)
+        """
+        Retrieves data rows within a specified date range using pagination.
+
+        Implements a hybrid pagination strategy that switches between time-based and
+        offset-based approaches depending on the returned data. Uses server-side
+        cursors for efficient memory usage.
+
+        Note: May return duplicate rows across batch boundaries when multiple rows
+        share the same 'createdon' timestamp. This is expected behavior when
+        transitioning between pagination methods.
+
+        Args:
+            start_date: Beginning of date range (inclusive)
+            end_date: End of date range (exclusive)
+            set_structured_properties_filter: Whether to apply structured filtering
+            limit: Maximum rows to fetch per query
+
+        Returns:
+            An iterable of database rows as dictionaries
+        """
+        offset = 0
+        last_createdon = None
+        first_iteration = True
+
+        while True:
+            try:
+                # Set up query and parameters - using named parameters
+                query = self.query(set_structured_properties_filter)
+                params: Dict[str, Any] = {
+                    "since_createdon": start_date.strftime(DATETIME_FORMAT),
+                    "end_createdon": end_date.strftime(DATETIME_FORMAT),
+                    "limit": limit,
+                    "offset": offset,
+                }
+
+                # Add exclude_aspects if needed
+                if (
+                    hasattr(self.config, "exclude_aspects")
+                    and self.config.exclude_aspects
+                ):
+                    params["exclude_aspects"] = tuple(self.config.exclude_aspects)
+
+                logger.info(
+                    f"Querying data from {start_date.strftime(DATETIME_FORMAT)} to {end_date.strftime(DATETIME_FORMAT)} "
+                    f"with limit {limit} and offset {offset} (inclusive range)"
+                )
+
+                # Execute query with server-side cursor
+                rows = self.execute_server_cursor(query, params)
+
+                # Process and yield rows
+                rows_processed = 0
+                for row in rows:
+                    if first_iteration:
+                        start_date = row.get("createdon", start_date)
+                        first_iteration = False
+
+                    last_createdon = row.get("createdon")
+                    rows_processed += 1
+                    yield row
+
+                # If we processed fewer than the limit or no last_createdon, we're done
+                if rows_processed < limit or not last_createdon:
+                    break
+
+                # Update parameters for next iteration
+                if start_date != last_createdon:
+                    start_date = last_createdon
+                    offset = 0
+                else:
+                    offset += limit
+
+                logger.info(
+                    f"Processed {rows_processed} rows for date range {start_date} to {end_date}. Continuing to next batch."
+                )
+            except Exception as e:
+                logger.error(
+                    f"Error processing date range {start_date} to {end_date}: {str(e)}"
+                )
+                # Re-raise the exception after logging
+                raise
+
+    def get_all_aspects(
+        self, from_createdon: datetime, stop_time: datetime
+    ) -> Iterable[Tuple[MetadataChangeProposalWrapper, datetime]]:
+        logger.info("Fetching Structured properties aspects")
+        yield from self.get_aspects(
+            from_createdon=from_createdon,
+            stop_time=stop_time,
+            set_structured_properties_filter=True,
+        )
+        logger.info(
+            f"Waiting for {self.config.structured_properties_template_cache_invalidation_interval} seconds for structured properties cache to invalidate"
+        )
+        time.sleep(
+            self.config.structured_properties_template_cache_invalidation_interval
+        )
+        logger.info("Fetching aspects")
+        yield from self.get_aspects(
+            from_createdon=from_createdon,
+            stop_time=stop_time,
+            set_structured_properties_filter=False,
+        )
 
     def get_aspects(
-        self, from_createdon: datetime, stop_time: datetime
+        self,
+        from_createdon: datetime,
+        stop_time: datetime,
+        set_structured_properties_filter: bool = False,
     ) -> Iterable[Tuple[MetadataChangeProposalWrapper, datetime]]:
         orderer = VersionOrderer[Dict[str, Any]](
             enabled=self.config.include_all_versions
         )
-        rows = self._get_rows(from_createdon=from_createdon, stop_time=stop_time)
+        rows = self._get_rows(
+            start_date=from_createdon,
+            end_date=stop_time,
+            set_structured_properties_filter=set_structured_properties_filter,
+            limit=self.config.database_query_batch_size,
+        )
         for row in orderer(rows):
             mcp = self._parse_row(row)
             if mcp:
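The hybrid pagination described in the `_get_rows` docstring is easier to see in isolation: advance the time window whenever a page ends on a newer `createdon`, and fall back to an increasing offset when an entire page shares one timestamp. Below is a simplified, self-contained sketch of that strategy; `fetch_page` and `paginate_hybrid` are illustrative names standing in for the SQL query, not part of the change, and the first-iteration start-date adjustment from the real code is omitted:

from datetime import datetime
from typing import Any, Callable, Dict, Iterable, List


def paginate_hybrid(
    fetch_page: Callable[[datetime, datetime, int, int], List[Dict[str, Any]]],
    start: datetime,
    end: datetime,
    limit: int,
) -> Iterable[Dict[str, Any]]:
    """Yield rows ordered by createdon, moving the time window when possible and
    falling back to an offset when a whole page shares one timestamp."""
    offset = 0
    while True:
        page = fetch_page(start, end, limit, offset)
        last_createdon = None
        for row in page:
            last_createdon = row["createdon"]
            yield row
        if len(page) < limit or last_createdon is None:
            break  # a short page means the range is exhausted
        if last_createdon != start:
            start, offset = last_createdon, 0  # time-based step: advance the window
        else:
            offset += limit  # offset-based step: the same timestamp filled the page

Rows on the boundary timestamp can be re-fetched when switching steps, which is why the docstring calls out possible duplicates; the server-side deduplication mentioned in the config description absorbs them.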
@@ -190,23 +337,29 @@ class DataHubDatabaseReader:
     def get_soft_deleted_rows(self) -> Iterable[Dict[str, Any]]:
         """
-        Fetches all soft-deleted entities from the database.
+        Fetches all soft-deleted entities from the database using pagination.
 
         Yields:
             Row objects containing URNs of soft-deleted entities
         """
-        with self.engine.connect() as conn, contextlib.closing(
-            conn.connection.cursor()
-        ) as cursor:
-            logger.debug("Polling soft-deleted urns from database")
-            cursor.execute(self.soft_deleted_urns_query)
-            columns = [desc[0] for desc in cursor.description]
-            while True:
-                rows = cursor.fetchmany(self.config.database_query_batch_size)
-                if not rows:
-                    return
-                for row in rows:
-                    yield dict(zip(columns, row))
+        try:
+            params: Dict = {}
+
+            logger.debug("Fetching soft-deleted URNs")
+
+            # Use server-side cursor implementation
+            rows = self.execute_server_cursor(self.soft_deleted_urns_query, params)
+            processed_rows = 0
+            # Process and yield rows
+            for row in rows:
+                processed_rows += 1
+                yield row
+
+            logger.debug(f"Fetched batch of {processed_rows} soft-deleted URNs")
+        except Exception:
+            logger.exception("Error fetching soft-deleted row", exc_info=True)
+            raise
 
     def _parse_row(
         self, row: Dict[str, Any]
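Both the aspect query and the soft-delete query in `datahub_database_reader.py` now stream through `execute_server_cursor` instead of a raw DB-API cursor. A minimal sketch of the SQLAlchemy pattern it relies on; `stream_results` and `yield_per` are real SQLAlchemy execution options, while the DSN, table, and batch size below are placeholders for illustration:

from sqlalchemy import create_engine, text

# Placeholder DSN; the reader supports postgresql, mysql and mariadb dialects.
engine = create_engine("postgresql://user:pass@localhost:5432/datahub")

with engine.connect() as conn:
    with conn.begin():  # PostgreSQL needs a transaction for server-side cursors
        conn = conn.execution_options(stream_results=True, yield_per=2000)
        result = conn.execute(
            text("SELECT urn, aspect, createdon FROM metadata_aspect_v2 WHERE createdon >= :since"),
            {"since": "2024-01-01 00:00:00.000000"},
        )
        for row in result:
            # Rows arrive in batches of yield_per instead of being fully buffered.
            print(row)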


@@ -117,7 +117,7 @@ class DataHubSource(StatefulIngestionSourceBase):
     ) -> Iterable[MetadataWorkUnit]:
         logger.info(f"Fetching database aspects starting from {from_createdon}")
         progress = ProgressTimer(report_every=timedelta(seconds=60))
-        mcps = reader.get_aspects(from_createdon, self.report.stop_time)
+        mcps = reader.get_all_aspects(from_createdon, self.report.stop_time)
         for i, (mcp, createdon) in enumerate(mcps):
             if not self.urn_pattern.allowed(str(mcp.entityUrn)):
                 continue
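This one-line switch in the source is what enforces the ordering from the commit title: structured property definitions are emitted first and given time to land in the server-side template cache, and only then are the remaining aspects (which may reference those properties) emitted. A condensed sketch of that flow; `emit_in_two_phases` and `sink.write` are illustrative stand-ins, while the `get_aspects` keyword arguments match the reader API added above:

import time


def emit_in_two_phases(reader, sink, from_createdon, stop_time, cache_wait_seconds: int = 60) -> None:
    """Illustrative only: structured-property aspects first, then everything else."""
    for mcp, _createdon in reader.get_aspects(
        from_createdon=from_createdon,
        stop_time=stop_time,
        set_structured_properties_filter=True,   # only urn:li:structuredProperty:* rows
    ):
        sink.write(mcp)

    # structured_properties_template_cache_invalidation_interval in the real config
    time.sleep(cache_wait_seconds)

    for mcp, _createdon in reader.get_aspects(
        from_createdon=from_createdon,
        stop_time=stop_time,
        set_structured_properties_filter=False,  # everything that is not a structured property
    ):
        sink.write(mcp)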


@@ -1,8 +1,14 @@
-from typing import Any, Dict
+from datetime import datetime
+from typing import Any, Dict, List
+from unittest.mock import MagicMock, patch
 
 import pytest
 
-from datahub.ingestion.source.datahub.datahub_database_reader import VersionOrderer
+from datahub.ingestion.source.datahub.datahub_database_reader import (
+    DATETIME_FORMAT,
+    DataHubDatabaseReader,
+    VersionOrderer,
+)
 
 
 @pytest.fixture
@@ -39,3 +45,263 @@ def test_version_orderer_disabled(rows):
     orderer = VersionOrderer[Dict[str, Any]](enabled=False)
     ordered_rows = list(orderer(rows))
     assert ordered_rows == rows
+
+
+@pytest.fixture
+def mock_reader():
+    with patch(
+        "datahub.ingestion.source.datahub.datahub_database_reader.create_engine"
+    ) as mock_create_engine:
+        config = MagicMock()
+        connection_config = MagicMock()
+        report = MagicMock()
+
+        mock_engine = MagicMock()
+        mock_dialect = MagicMock()
+        mock_identifier_preparer = MagicMock()
+
+        mock_dialect.identifier_preparer = mock_identifier_preparer
+        mock_identifier_preparer.quote = lambda x: f'"{x}"'
+        mock_engine.dialect = mock_dialect
+        mock_create_engine.return_value = mock_engine
+
+        reader = DataHubDatabaseReader(config, connection_config, report)
+        reader.query = MagicMock(side_effect=reader.query)  # type: ignore
+        reader.execute_server_cursor = MagicMock()  # type: ignore
+        return reader
+
+
+def test_get_rows_for_date_range_no_rows(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    mock_reader.execute_server_cursor.return_value = []
+
+    # Execute
+    result = list(mock_reader._get_rows(start_date, end_date, False, 50))
+
+    # Assert
+    assert len(result) == 0
+    mock_reader.query.assert_called_once_with(False)
+    mock_reader.execute_server_cursor.assert_called_once()
+
+
+def test_get_rows_for_date_range_with_rows(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    mock_rows = [
+        {"urn": "urn1", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+        {"urn": "urn2", "metadata": "data2", "createdon": datetime(2023, 1, 1, 13, 0)},
+    ]
+    mock_reader.execute_server_cursor.return_value = mock_rows
+
+    # Execute
+    result = list(mock_reader._get_rows(start_date, end_date, False, 50))
+
+    # Assert
+    assert result == mock_rows
+    mock_reader.query.assert_called_once_with(False)
+    assert mock_reader.execute_server_cursor.call_count == 1
+
+
+def test_get_rows_for_date_range_pagination_same_timestamp(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    batch1 = [
+        {"urn": "urn1", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+        {"urn": "urn2", "metadata": "data2", "createdon": datetime(2023, 1, 1, 12, 0)},
+    ]
+    batch2 = [
+        {"urn": "urn3", "metadata": "data3", "createdon": datetime(2023, 1, 1, 12, 0)},
+    ]
+    batch3: List[Dict] = []
+    mock_reader.execute_server_cursor.side_effect = [batch1, batch2, batch3]
+
+    # Execute
+    result = list(mock_reader._get_rows(start_date, end_date, False, 2))
+
+    # Assert
+    assert len(result) == 3
+    assert result[0]["urn"] == "urn1"
+    assert result[1]["urn"] == "urn2"
+    assert result[2]["urn"] == "urn3"
+    assert mock_reader.execute_server_cursor.call_count == 2
+
+
+def test_get_rows_for_date_range_pagination_different_timestamp(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    batch1 = [
+        {"urn": "urn1", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+        {"urn": "urn2", "metadata": "data2", "createdon": datetime(2023, 1, 1, 13, 0)},
+    ]
+    batch2 = [
+        {"urn": "urn3", "metadata": "data3", "createdon": datetime(2023, 1, 1, 14, 0)},
+    ]
+    batch3: List[Dict] = []
+    mock_reader.execute_server_cursor.side_effect = [batch1, batch2, batch3]
+
+    # Execute
+    result = list(mock_reader._get_rows(start_date, end_date, False, 2))
+
+    # Assert
+    assert len(result) == 3
+    assert result[0]["urn"] == "urn1"
+    assert result[1]["urn"] == "urn2"
+    assert result[2]["urn"] == "urn3"
+    assert mock_reader.execute_server_cursor.call_count == 2
+
+
+def test_get_rows_for_date_range_duplicate_data_handling(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    batch1 = [
+        {"urn": "urn1", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+    ]
+    batch2 = [
+        {"urn": "urn2", "metadata": "data2", "createdon": datetime(2023, 1, 1, 13, 0)},
+    ]
+    batch3: List[Dict] = []
+    mock_reader.execute_server_cursor.side_effect = [batch1, batch2, batch3]
+
+    # Execute
+    result = list(mock_reader._get_rows(start_date, end_date, False, 1))
+
+    # Assert
+    assert len(result) == 2
+    assert result[0]["urn"] == "urn1"
+    assert result[1]["urn"] == "urn2"
+
+    # Check call parameters for each iteration
+    calls = mock_reader.execute_server_cursor.call_args_list
+    assert len(calls) == 3
+
+    # First call: initial parameters
+    first_call_params = calls[0][0][1]
+    assert first_call_params["since_createdon"] == start_date.strftime(DATETIME_FORMAT)
+    assert first_call_params["end_createdon"] == end_date.strftime(DATETIME_FORMAT)
+    assert first_call_params["limit"] == 1
+    assert first_call_params["offset"] == 0
+
+    # Second call: same createdon as the window start, so the offset increases
+    second_call_params = calls[1][0][1]
+    assert second_call_params["offset"] == 1
+    assert second_call_params["since_createdon"] == datetime(
+        2023, 1, 1, 12, 0
+    ).strftime(DATETIME_FORMAT)
+
+    # Third call: a newer timestamp advanced the window, so the offset resets to 0
+    third_call_params = calls[2][0][1]
+    assert third_call_params["offset"] == 0
+
+
+def test_get_rows_multiple_paging(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    batch1 = [
+        {"urn": "urn1", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+        {"urn": "urn2", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+        {"urn": "urn3", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+    ]
+    batch2 = [
+        {"urn": "urn4", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+        {"urn": "urn5", "metadata": "data1", "createdon": datetime(2023, 1, 1, 12, 0)},
+        {"urn": "urn6", "metadata": "data1", "createdon": datetime(2023, 1, 1, 13, 0)},
+    ]
+    batch3 = [
+        {"urn": "urn7", "metadata": "data1", "createdon": datetime(2023, 1, 1, 14, 0)},
+        {"urn": "urn8", "metadata": "data1", "createdon": datetime(2023, 1, 1, 14, 0)},
+        {"urn": "urn9", "metadata": "data1", "createdon": datetime(2023, 1, 1, 15, 0)},
+    ]
+    batch4 = [
+        {"urn": "urn10", "metadata": "data1", "createdon": datetime(2023, 1, 1, 16, 0)},
+    ]
+    mock_reader.execute_server_cursor.side_effect = [batch1, batch2, batch3, batch4]
+
+    # Execute
+    result = list(mock_reader._get_rows(start_date, end_date, False, 3))
+
+    # Assert
+    # In this case duplicate items are expected
+    assert len(result) == 10
+    assert result[0]["urn"] == "urn1"
+    assert result[1]["urn"] == "urn2"
+    assert result[2]["urn"] == "urn3"
+    assert result[3]["urn"] == "urn4"
+    assert result[4]["urn"] == "urn5"
+    assert result[5]["urn"] == "urn6"
+    assert result[6]["urn"] == "urn7"
+    assert result[7]["urn"] == "urn8"
+    assert result[8]["urn"] == "urn9"
+    assert result[9]["urn"] == "urn10"
+
+    # Check call parameters for each iteration
+    calls = mock_reader.execute_server_cursor.call_args_list
+    assert len(calls) == 4
+
+    # First call: initial parameters
+    first_call_params = calls[0][0][1]
+    assert first_call_params["since_createdon"] == start_date.strftime(DATETIME_FORMAT)
+    assert first_call_params["end_createdon"] == end_date.strftime(DATETIME_FORMAT)
+    assert first_call_params["limit"] == 3
+    assert first_call_params["offset"] == 0
+
+    # Second call: same createdon as the window start, so the offset increases
+    second_call_params = calls[1][0][1]
+    assert second_call_params["offset"] == 3
+    assert second_call_params["limit"] == 3
+    assert second_call_params["since_createdon"] == datetime(
+        2023, 1, 1, 12, 0
+    ).strftime(DATETIME_FORMAT)
+    assert first_call_params["end_createdon"] == end_date.strftime(DATETIME_FORMAT)
+
+    # Third call: a newer timestamp advanced the window, so the offset resets to 0
+    third_call_params = calls[2][0][1]
+    assert third_call_params["offset"] == 0
+    assert third_call_params["since_createdon"] == datetime(2023, 1, 1, 13, 0).strftime(
+        DATETIME_FORMAT
+    )
+
+    # Fourth call: the window advances again and the offset stays at 0
+    fourth_call_params = calls[3][0][1]
+    assert fourth_call_params["offset"] == 0
+    assert fourth_call_params["since_createdon"] == datetime(
+        2023, 1, 1, 15, 0
+    ).strftime(DATETIME_FORMAT)
+    assert fourth_call_params["limit"] == 3
+
+
+def test_get_rows_for_date_range_exception_handling(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    mock_reader.execute_server_cursor.side_effect = Exception("Test exception")
+
+    # Execute and Assert
+    with pytest.raises(Exception, match="Test exception"):
+        list(mock_reader._get_rows(start_date, end_date, False, 50))
+
+
+def test_get_rows_for_date_range_exclude_aspects(mock_reader):
+    # Setup
+    start_date = datetime(2023, 1, 1)
+    end_date = datetime(2023, 1, 2)
+    mock_reader.config.exclude_aspects = ["aspect1", "aspect2"]
+    mock_reader.execute_server_cursor.return_value = []
+
+    # Execute
+    list(mock_reader._get_rows(start_date, end_date, False, 50))
+
+    # Assert
+    called_params = mock_reader.execute_server_cursor.call_args[0][1]
+    assert "exclude_aspects" in called_params
+    assert called_params["exclude_aspects"] == ("aspect1", "aspect2")