fix(ingestion/dremio): handle Dremio OOM errors when ingesting large amounts of metadata (#14883)

Jonny Dixon 2025-10-27 15:52:51 +00:00 committed by GitHub
parent 6af5182e9b
commit f05f3e40f2
7 changed files with 649 additions and 135 deletions
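
The crux of the change: job results are no longer paged by trusting the reported rowCount (which can be absent or misleading once the Dremio coordinator runs out of memory), and large result sets can be streamed row by row instead of being buffered whole. A minimal sketch of that pagination pattern, assuming a get_job_result(job_id, offset, limit)-style callable; the helper name and wiring are illustrative, not the exact DataHub API:

from typing import Any, Callable, Dict, Iterator

def iter_job_rows(
    get_job_result: Callable[[str, int, int], Dict[str, Any]],
    job_id: str,
    limit: int = 500,
) -> Iterator[Dict[str, Any]]:
    # Hypothetical helper: advance the offset by the number of rows actually
    # returned and stop on a short or empty page, so a missing or wrong
    # rowCount can neither loop forever nor force everything into memory.
    offset = 0
    while True:
        result = get_job_result(job_id, offset, limit)
        rows = result.get("rows")
        if not rows:  # missing 'rows' key, error payload, or no more data
            break
        yield from rows
        offset += len(rows)
        if len(rows) < limit:  # short page means the result set is exhausted
            break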

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
from enum import Enum
from itertools import product
from time import sleep, time
from typing import TYPE_CHECKING, Any, Deque, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from urllib.parse import quote
import requests
@@ -343,14 +343,149 @@ class DremioAPIOperations:
while True:
result = self.get_job_result(job_id, offset, limit)
rows.extend(result["rows"])
offset = offset + limit
if offset >= result["rowCount"]:
# Handle cases where API response doesn't contain 'rows' key
# This can happen with OOM errors or when no rows are returned
if "rows" not in result:
logger.warning(
f"API response for job {job_id} missing 'rows' key. "
f"Response keys: {list(result.keys())}"
)
# Check for error conditions
if "errorMessage" in result:
raise DremioAPIException(f"Query error: {result['errorMessage']}")
elif "message" in result:
logger.warning(
f"Query warning for job {job_id}: {result['message']}"
)
# Return empty list if no rows key and no error
break
# Handle empty rows response
result_rows = result["rows"]
if not result_rows:
logger.debug(
f"No more rows returned for job {job_id} at offset {offset}"
)
break
rows.extend(result_rows)
# Check actual returned rows to determine if we should continue
actual_rows_returned = len(result_rows)
if actual_rows_returned == 0:
logger.debug(f"Query returned no rows for job {job_id}")
break
offset = offset + actual_rows_returned
# If we got fewer rows than requested, we've reached the end
if actual_rows_returned < limit:
break
logger.info(f"Fetched {len(rows)} total rows for job {job_id}")
return rows
def _fetch_results_iter(self, job_id: str) -> Iterator[Dict]:
"""
Fetch job results in a streaming fashion to reduce memory usage.
Yields individual rows instead of collecting all in memory.
"""
limit = 500
offset = 0
total_rows_fetched = 0
while True:
result = self.get_job_result(job_id, offset, limit)
# Handle cases where API response doesn't contain 'rows' key
if "rows" not in result:
logger.warning(
f"API response for job {job_id} missing 'rows' key. "
f"Response keys: {list(result.keys())}"
)
# Check for error conditions
if "errorMessage" in result:
raise DremioAPIException(f"Query error: {result['errorMessage']}")
elif "message" in result:
logger.warning(
f"Query warning for job {job_id}: {result['message']}"
)
# Stop iteration if no rows key and no error
break
# Handle empty rows response
result_rows = result["rows"]
if not result_rows:
logger.debug(
f"No more rows returned for job {job_id} at offset {offset}"
)
break
# Yield individual rows instead of collecting them
for row in result_rows:
yield row
total_rows_fetched += 1
# Check actual returned rows to determine if we should continue
actual_rows_returned = len(result_rows)
if actual_rows_returned == 0:
logger.debug(f"Query returned no rows for job {job_id}")
break
offset = offset + actual_rows_returned
# If we got fewer rows than requested, we've reached the end
if actual_rows_returned < limit:
break
logger.info(f"Streamed {total_rows_fetched} total rows for job {job_id}")
def execute_query_iter(
self, query: str, timeout: int = 3600
) -> Iterator[Dict[str, Any]]:
"""Execute SQL query and return results as a streaming iterator"""
try:
with PerfTimer() as timer:
logger.info(f"Executing streaming query: {query}")
response = self.post(url="/sql", data=json.dumps({"sql": query}))
if "errorMessage" in response:
self.report.failure(
message="SQL Error", context=f"{response['errorMessage']}"
)
raise DremioAPIException(f"SQL Error: {response['errorMessage']}")
job_id = response["id"]
# Wait for job completion
start_time = time()
while True:
status = self.get_job_status(job_id)
if status["jobState"] == "COMPLETED":
break
elif status["jobState"] == "FAILED":
error_message = status.get("errorMessage", "Unknown error")
raise RuntimeError(f"Query failed: {error_message}")
elif status["jobState"] == "CANCELED":
raise RuntimeError("Query was canceled")
if time() - start_time > timeout:
self.cancel_query(job_id)
raise DremioAPIException(
f"Query execution timed out after {timeout} seconds"
)
sleep(3)
logger.info(
f"Query job completed in {timer.elapsed_seconds()} seconds, starting streaming"
)
# Return streaming iterator
return self._fetch_results_iter(job_id)
except requests.RequestException as e:
raise DremioAPIException("Error executing streaming query") from e
def cancel_query(self, job_id: str) -> None:
"""Cancel a running query"""
try:
@@ -499,8 +634,12 @@ class DremioAPIOperations:
return f"AND {operator}({field}, '{pattern_str}')"
def get_all_tables_and_columns(
self, containers: Deque["DremioContainer"]
) -> List[Dict]:
self, containers: Iterator["DremioContainer"]
) -> Iterator[Dict]:
"""
Memory-efficient streaming version that yields tables one at a time.
Reduces memory usage for large datasets by processing results as they come.
"""
if self.edition == DremioEdition.ENTERPRISE:
query_template = DremioSQLQueries.QUERY_DATASETS_EE
elif self.edition == DremioEdition.CLOUD:
@@ -517,21 +656,78 @@ class DremioAPIOperations:
self.deny_schema_pattern, schema_field, allow=False
)
all_tables_and_columns = []
# Process each container's results separately to avoid memory buildup
for schema in containers:
formatted_query = ""
try:
formatted_query = query_template.format(
schema_pattern=schema_condition,
deny_schema_pattern=deny_schema_condition,
container_name=schema.container_name.lower(),
)
all_tables_and_columns.extend(
self.execute_query(
query=formatted_query,
# Use streaming query execution
container_results = list(self.execute_query_iter(query=formatted_query))
if self.edition == DremioEdition.COMMUNITY:
# Process community edition results
formatted_tables = self.community_get_formatted_tables(
container_results
)
)
for table in formatted_tables:
yield table
else:
# Process enterprise/cloud edition results
column_dictionary: Dict[str, List[Dict]] = defaultdict(list)
table_metadata: Dict[str, Dict] = {}
for record in container_results:
if not record.get("COLUMN_NAME"):
continue
table_full_path = record.get("FULL_TABLE_PATH")
if not table_full_path:
continue
# Store column information
column_dictionary[table_full_path].append(
{
"name": record["COLUMN_NAME"],
"ordinal_position": record["ORDINAL_POSITION"],
"is_nullable": record["IS_NULLABLE"],
"data_type": record["DATA_TYPE"],
"column_size": record["COLUMN_SIZE"],
}
)
# Store table metadata (only once per table)
if table_full_path not in table_metadata:
table_metadata[table_full_path] = {
"TABLE_NAME": record.get("TABLE_NAME"),
"TABLE_SCHEMA": record.get("TABLE_SCHEMA"),
"VIEW_DEFINITION": record.get("VIEW_DEFINITION"),
"RESOURCE_ID": record.get("RESOURCE_ID"),
"LOCATION_ID": record.get("LOCATION_ID"),
"OWNER": record.get("OWNER"),
"OWNER_TYPE": record.get("OWNER_TYPE"),
"CREATED": record.get("CREATED"),
"FORMAT_TYPE": record.get("FORMAT_TYPE"),
}
# Yield tables one at a time
for table_path, table_info in table_metadata.items():
yield {
"TABLE_NAME": table_info.get("TABLE_NAME"),
"TABLE_SCHEMA": table_info.get("TABLE_SCHEMA"),
"COLUMNS": column_dictionary[table_path],
"VIEW_DEFINITION": table_info.get("VIEW_DEFINITION"),
"RESOURCE_ID": table_info.get("RESOURCE_ID"),
"LOCATION_ID": table_info.get("LOCATION_ID"),
"OWNER": table_info.get("OWNER"),
"OWNER_TYPE": table_info.get("OWNER_TYPE"),
"CREATED": table_info.get("CREATED"),
"FORMAT_TYPE": table_info.get("FORMAT_TYPE"),
}
except DremioAPIException as e:
self.report.warning(
message="Container has no tables or views",
@@ -539,71 +735,6 @@ class DremioAPIOperations:
exc=e,
)
tables = []
if self.edition == DremioEdition.COMMUNITY:
tables = self.community_get_formatted_tables(all_tables_and_columns)
else:
column_dictionary: Dict[str, List[Dict]] = defaultdict(list)
for record in all_tables_and_columns:
if not record.get("COLUMN_NAME"):
continue
table_full_path = record.get("FULL_TABLE_PATH")
if not table_full_path:
continue
column_dictionary[table_full_path].append(
{
"name": record["COLUMN_NAME"],
"ordinal_position": record["ORDINAL_POSITION"],
"is_nullable": record["IS_NULLABLE"],
"data_type": record["DATA_TYPE"],
"column_size": record["COLUMN_SIZE"],
}
)
distinct_tables_list = list(
{
tuple(
dictionary[key]
for key in (
"TABLE_SCHEMA",
"TABLE_NAME",
"FULL_TABLE_PATH",
"VIEW_DEFINITION",
"LOCATION_ID",
"OWNER",
"OWNER_TYPE",
"CREATED",
"FORMAT_TYPE",
)
if key in dictionary
): dictionary
for dictionary in all_tables_and_columns
}.values()
)
for table in distinct_tables_list:
tables.append(
{
"TABLE_NAME": table.get("TABLE_NAME"),
"TABLE_SCHEMA": table.get("TABLE_SCHEMA"),
"COLUMNS": column_dictionary[table["FULL_TABLE_PATH"]],
"VIEW_DEFINITION": table.get("VIEW_DEFINITION"),
"RESOURCE_ID": table.get("RESOURCE_ID"),
"LOCATION_ID": table.get("LOCATION_ID"),
"OWNER": table.get("OWNER"),
"OWNER_TYPE": table.get("OWNER_TYPE"),
"CREATED": table.get("CREATED"),
"FORMAT_TYPE": table.get("FORMAT_TYPE"),
}
)
return tables
def validate_schema_format(self, schema):
if "." in schema:
schema_path = self.get(
@@ -640,7 +771,10 @@ class DremioAPIOperations:
return parents_list
def extract_all_queries(self) -> List[Dict[str, Any]]:
def extract_all_queries(self) -> Iterator[Dict[str, Any]]:
"""
Memory-efficient streaming version for extracting query results.
"""
# Convert datetime objects to string format for SQL queries
start_timestamp_str = None
end_timestamp_str = None
@@ -661,7 +795,7 @@ class DremioAPIOperations:
end_timestamp_millis=end_timestamp_str,
)
return self.execute_query(query=jobs_query)
return self.execute_query_iter(query=jobs_query)
def get_tags_for_resource(self, resource_id: str) -> Optional[List[str]]:
"""

View File

@@ -1,4 +1,3 @@
import itertools
import logging
import re
import uuid
@@ -6,7 +5,7 @@ from collections import deque
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Deque, Dict, List, Optional
from typing import Any, Deque, Dict, Iterator, List, Optional
from sqlglot import parse_one
@@ -184,6 +183,7 @@ class DremioQuery:
return ""
def get_raw_query(self, sql_query: str) -> str:
"""Remove comments from SQL query using sqlglot parser."""
try:
parsed = parse_one(sql_query)
return parsed.sql(comments=False)
@@ -336,43 +336,26 @@ class DremioCatalog:
def __init__(self, dremio_api: DremioAPIOperations):
self.dremio_api = dremio_api
self.edition = dremio_api.edition
self.datasets: Deque[DremioDataset] = deque()
self.sources: Deque[DremioSourceContainer] = deque()
self.spaces: Deque[DremioSpace] = deque()
self.folders: Deque[DremioFolder] = deque()
self.glossary_terms: Deque[DremioGlossaryTerm] = deque()
self.queries: Deque[DremioQuery] = deque()
self.datasets_populated = False
self.containers_populated = False
self.queries_populated = False
def set_datasets(self) -> None:
if not self.datasets_populated:
self.set_containers()
def get_datasets(self) -> Iterator[DremioDataset]:
"""Get all Dremio datasets (tables and views) as an iterator."""
# Get containers directly without storing them
containers = self.get_containers()
containers: Deque[DremioContainer] = deque()
containers.extend(self.spaces) # Add DremioSpace elements
containers.extend(self.sources) # Add DremioSource elements
for dataset_details in self.dremio_api.get_all_tables_and_columns(containers):
dremio_dataset = DremioDataset(
dataset_details=dataset_details,
api_operations=self.dremio_api,
)
for dataset_details in self.dremio_api.get_all_tables_and_columns(
containers=containers
):
dremio_dataset = DremioDataset(
dataset_details=dataset_details,
api_operations=self.dremio_api,
)
self.datasets.append(dremio_dataset)
for glossary_term in dremio_dataset.glossary_terms:
if glossary_term not in self.glossary_terms:
self.glossary_terms.append(glossary_term)
self.datasets_populated = True
def get_datasets(self) -> Deque[DremioDataset]:
self.set_datasets()
return self.datasets
yield dremio_dataset
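
Because get_datasets is now a generator, every call re-issues the underlying catalog queries. A caller that needs more than one pass (for example metadata extraction followed by profiling) can either iterate twice and accept the repeated API calls, or materialize the iterator once at the cost of memory. A hedged sketch; catalog, emit_metadata and profile_dataset are illustrative names:

# Illustrative only: 'catalog' stands for a DremioCatalog instance.
datasets = list(catalog.get_datasets())  # one pass over the API, buffered in memory
for ds in datasets:
    emit_metadata(ds)      # hypothetical first consumer
for ds in datasets:
    profile_dataset(ds)    # second pass reuses the buffer instead of re-querying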
def set_containers(self) -> None:
if not self.containers_populated:
@@ -423,18 +406,50 @@ class DremioCatalog:
self.containers_populated = True
def get_containers(self) -> Deque:
self.set_containers()
return deque(itertools.chain(self.sources, self.spaces, self.folders))
def get_containers(self) -> Iterator[DremioContainer]:
"""Get all containers (sources, spaces, folders) as an iterator."""
for container in self.dremio_api.get_all_containers():
container_type = container.get("container_type")
if container_type == DremioEntityContainerType.SOURCE:
yield DremioSourceContainer(
container_name=container.get("name"),
location_id=container.get("id"),
path=[],
api_operations=self.dremio_api,
dremio_source_type=container.get("source_type") or "",
root_path=container.get("root_path"),
database_name=container.get("database_name"),
)
elif container_type == DremioEntityContainerType.SPACE:
yield DremioSpace(
container_name=container.get("name"),
location_id=container.get("id"),
path=[],
api_operations=self.dremio_api,
)
elif container_type == DremioEntityContainerType.FOLDER:
yield DremioFolder(
container_name=container.get("name"),
location_id=container.get("id"),
path=container.get("path"),
api_operations=self.dremio_api,
)
def get_sources(self) -> Deque[DremioSourceContainer]:
self.set_containers()
return self.sources
def get_sources(self) -> Iterator[DremioSourceContainer]:
"""Get all Dremio source containers (external data connections) as an iterator."""
for container in self.get_containers():
if isinstance(container, DremioSourceContainer):
yield container
def get_glossary_terms(self) -> Deque[DremioGlossaryTerm]:
self.set_datasets()
self.set_containers()
return self.glossary_terms
def get_glossary_terms(self) -> Iterator[DremioGlossaryTerm]:
"""Get all unique glossary terms (tags) from datasets."""
glossary_terms_seen = set()
for dataset in self.get_datasets():
for glossary_term in dataset.glossary_terms:
if glossary_term not in glossary_terms_seen:
glossary_terms_seen.add(glossary_term)
yield glossary_term
def is_valid_query(self, query: Dict[str, Any]) -> bool:
required_fields = [
@@ -447,6 +462,7 @@ class DremioCatalog:
return all(query.get(field) for field in required_fields)
def get_queries(self) -> Deque[DremioQuery]:
"""Get all valid Dremio queries for lineage analysis."""
for query in self.dremio_api.extract_all_queries():
if not self.is_valid_query(query):
continue

View File

@@ -17,6 +17,7 @@ from datahub.metadata.schema_classes import (
DatasetProfileClass,
QuantileClass,
)
from datahub.utilities.perf_timer import PerfTimer
logger = logging.getLogger(__name__)
@@ -64,8 +65,13 @@ class DremioProfiler:
)
return
profile_data = self.profile_table(full_table_name, columns)
profile_aspect = self.populate_profile_aspect(profile_data)
with PerfTimer() as timer:
profile_data = self.profile_table(full_table_name, columns)
profile_aspect = self.populate_profile_aspect(profile_data)
logger.info(
f"Profiled table {full_table_name} with {len(columns)} columns in {timer.elapsed_seconds():.2f} seconds"
)
if profile_aspect:
self.report.report_entity_profiled(dataset.resource_name)
@@ -131,7 +137,12 @@ class DremioProfiler:
def _profile_chunk(self, table_name: str, columns: List[Tuple[str, str]]) -> Dict:
profile_sql = self._build_profile_sql(table_name, columns)
try:
results = self.api_operations.execute_query(profile_sql)
with PerfTimer() as timer:
results = self.api_operations.execute_query(profile_sql)
logger.debug(
f"Profiling query for {table_name} ({len(columns)} columns) completed in {timer.elapsed_seconds():.2f} seconds"
)
return self._parse_profile_results(results, columns)
except DremioAPIException as e:
raise e
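
For reference, PerfTimer used above is the existing context manager from datahub.utilities.perf_timer; a minimal sketch of the timing pattern, with a hypothetical stand-in for the measured work:

import logging

from datahub.utilities.perf_timer import PerfTimer

logger = logging.getLogger(__name__)

def run_profiling_queries() -> None:
    ...  # hypothetical stand-in for the work being timed

with PerfTimer() as timer:
    run_profiling_queries()
logger.info(f"completed in {timer.elapsed_seconds():.2f} seconds")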

View File

@@ -55,7 +55,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
from datahub.ingestion.source_report.ingestion_stage import (
LINEAGE_EXTRACTION,
METADATA_EXTRACTION,
IngestionHighStage,
PROFILING,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
DatasetLineageTypeClass,
@@ -201,7 +201,7 @@ class DremioSource(StatefulIngestionSourceBase):
return "dremio"
def _build_source_map(self) -> Dict[str, DremioSourceMapEntry]:
dremio_sources = self.dremio_catalog.get_sources()
dremio_sources = list(self.dremio_catalog.get_sources())
source_mappings_config = self.config.source_mappings or []
source_map = build_dremio_source_map(dremio_sources, source_mappings_config)
@@ -242,9 +242,7 @@ class DremioSource(StatefulIngestionSourceBase):
)
# Process Datasets
datasets = self.dremio_catalog.get_datasets()
for dataset_info in datasets:
for dataset_info in self.dremio_catalog.get_datasets():
try:
yield from self.process_dataset(dataset_info)
logger.info(
@@ -258,10 +256,8 @@ class DremioSource(StatefulIngestionSourceBase):
exc=exc,
)
# Process Glossary Terms
glossary_terms = self.dremio_catalog.get_glossary_terms()
for glossary_term in glossary_terms:
# Process Glossary Terms using streaming
for glossary_term in self.dremio_catalog.get_glossary_terms():
try:
yield from self.process_glossary_term(glossary_term)
except Exception as exc:
@@ -283,14 +279,16 @@ class DremioSource(StatefulIngestionSourceBase):
# Profiling
if self.config.is_profiling_enabled():
with (
self.report.new_high_stage(IngestionHighStage.PROFILING),
self.report.new_stage(PROFILING),
ThreadPoolExecutor(
max_workers=self.config.profiling.max_workers
) as executor,
):
# Collect datasets for profiling
datasets_for_profiling = list(self.dremio_catalog.get_datasets())
future_to_dataset = {
executor.submit(self.generate_profiles, dataset): dataset
for dataset in datasets
for dataset in datasets_for_profiling
}
for future in as_completed(future_to_dataset):

View File

@@ -0,0 +1,281 @@
from collections import deque
from unittest.mock import Mock, patch
import pytest
from datahub.ingestion.source.dremio.dremio_api import (
DremioAPIException,
DremioAPIOperations,
DremioEdition,
)
from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig
from datahub.ingestion.source.dremio.dremio_reporting import DremioSourceReport
class TestDremioAPIPagination:
@pytest.fixture
def dremio_api(self, monkeypatch):
"""Setup mock Dremio API for testing"""
# Mock the requests.Session
mock_session = Mock()
monkeypatch.setattr("requests.Session", Mock(return_value=mock_session))
# Mock the authentication response
mock_session.post.return_value.json.return_value = {"token": "dummy-token"}
mock_session.post.return_value.status_code = 200
config = DremioSourceConfig(
hostname="dummy-host",
port=9047,
tls=False,
authentication_method="password",
username="dummy-user",
password="dummy-password",
)
report = DremioSourceReport()
api = DremioAPIOperations(config, report)
api.session = mock_session
return api
def test_fetch_all_results_missing_rows_key(self, dremio_api):
"""Test handling of API response missing 'rows' key"""
# Mock get_job_result to return response without 'rows' key
dremio_api.get_job_result = Mock(return_value={"message": "No data available"})
result = dremio_api._fetch_all_results("test-job-id")
# Should return empty list when no rows key is present
assert result == []
dremio_api.get_job_result.assert_called_once()
def test_fetch_all_results_with_error_message(self, dremio_api):
"""Test handling of API response with errorMessage"""
# Mock get_job_result to return error response
dremio_api.get_job_result = Mock(return_value={"errorMessage": "Out of memory"})
with pytest.raises(DremioAPIException, match="Query error: Out of memory"):
dremio_api._fetch_all_results("test-job-id")
def test_fetch_all_results_empty_rows(self, dremio_api):
"""Test handling of empty rows response"""
# Mock get_job_result to return empty rows
dremio_api.get_job_result = Mock(return_value={"rows": [], "rowCount": 0})
result = dremio_api._fetch_all_results("test-job-id")
assert result == []
dremio_api.get_job_result.assert_called_once()
def test_fetch_all_results_normal_case(self, dremio_api):
"""Test normal operation with valid rows"""
# Mock get_job_result to return valid data
# Single response: 2 rows returned, fewer than the 500-row page limit, so pagination stops after one call
mock_responses = [
{"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2},
]
dremio_api.get_job_result = Mock(side_effect=mock_responses)
result = dremio_api._fetch_all_results("test-job-id")
expected = [{"col1": "val1"}, {"col1": "val2"}]
assert result == expected
assert dremio_api.get_job_result.call_count == 1
def test_fetch_results_iter_missing_rows_key(self, dremio_api):
"""Test internal streaming method handling missing 'rows' key"""
# Mock get_job_result to return response without 'rows' key
dremio_api.get_job_result = Mock(return_value={"message": "No data available"})
result_iterator = dremio_api._fetch_results_iter("test-job-id")
results = list(result_iterator)
# Should return empty iterator when no rows key is present
assert results == []
def test_fetch_results_iter_with_error(self, dremio_api):
"""Test internal streaming method handling error response"""
# Mock get_job_result to return error response
dremio_api.get_job_result = Mock(return_value={"errorMessage": "Query timeout"})
result_iterator = dremio_api._fetch_results_iter("test-job-id")
with pytest.raises(DremioAPIException, match="Query error: Query timeout"):
list(result_iterator)
def test_fetch_results_iter_normal_case(self, dremio_api):
"""Test internal streaming method with valid data"""
# Mock get_job_result to return valid data in batches
mock_responses = [
{"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2},
{"rows": [], "rowCount": 2}, # Empty response to end iteration
]
dremio_api.get_job_result = Mock(side_effect=mock_responses)
result_iterator = dremio_api._fetch_results_iter("test-job-id")
results = list(result_iterator)
expected = [
{"col1": "val1"},
{"col1": "val2"},
]
assert results == expected
def test_execute_query_iter_success(self, dremio_api):
"""Test execute_query_iter with successful job completion"""
# Mock the POST response for job submission
dremio_api.post = Mock(return_value={"id": "job-123"})
# Mock job status progression
status_responses = [
{"jobState": "RUNNING"},
{"jobState": "RUNNING"},
{"jobState": "COMPLETED"},
]
dremio_api.get_job_status = Mock(side_effect=status_responses)
# Mock streaming results
dremio_api._fetch_results_iter = Mock(return_value=iter([{"col1": "val1"}]))
with patch("time.sleep"): # Skip actual sleep delays
result_iterator = dremio_api.execute_query_iter("SELECT * FROM test")
results = list(result_iterator)
assert results == [{"col1": "val1"}]
dremio_api.post.assert_called_once()
assert dremio_api.get_job_status.call_count == 3
def test_execute_query_iter_job_failure(self, dremio_api):
"""Test execute_query_iter with job failure"""
# Mock the POST response for job submission
dremio_api.post = Mock(return_value={"id": "job-123"})
# Mock job failure
dremio_api.get_job_status = Mock(
return_value={"jobState": "FAILED", "errorMessage": "SQL syntax error"}
)
with (
pytest.raises(RuntimeError, match="Query failed: SQL syntax error"),
patch("time.sleep"),
):
result_iterator = dremio_api.execute_query_iter("SELECT * FROM test")
list(result_iterator) # Force evaluation
def test_execute_query_iter_timeout(self, dremio_api):
"""Test execute_query_iter with timeout"""
# Mock the POST response for job submission
dremio_api.post = Mock(return_value={"id": "job-123"})
# Mock job that never completes
dremio_api.get_job_status = Mock(return_value={"jobState": "RUNNING"})
dremio_api.cancel_query = Mock()
# Mock time.time to simulate timeout - need to patch where it's imported
# First call: start_time = 0
# Second call: time() - start_time check = 3700 (triggers timeout)
mock_time_values = [0, 3700]
with (
patch(
"datahub.ingestion.source.dremio.dremio_api.time",
side_effect=mock_time_values,
),
patch("datahub.ingestion.source.dremio.dremio_api.sleep"),
pytest.raises(
DremioAPIException,
match="Query execution timed out after 3600 seconds",
),
):
result_iterator = dremio_api.execute_query_iter("SELECT * FROM test")
list(result_iterator)
dremio_api.cancel_query.assert_called_once_with("job-123")
def test_get_all_tables_and_columns(self, dremio_api):
"""Test streaming version of get_all_tables_and_columns"""
# Set up test data
dremio_api.edition = DremioEdition.ENTERPRISE
dremio_api.allow_schema_pattern = [".*"]
dremio_api.deny_schema_pattern = []
# Mock container
mock_container = Mock()
mock_container.container_name = "test_source"
containers = deque([mock_container])
# Mock streaming query results
mock_results = [
{
"COLUMN_NAME": "col1",
"FULL_TABLE_PATH": "test.table1",
"TABLE_NAME": "table1",
"TABLE_SCHEMA": "test",
"ORDINAL_POSITION": 1,
"IS_NULLABLE": "YES",
"DATA_TYPE": "VARCHAR",
"COLUMN_SIZE": 255,
"RESOURCE_ID": "res1",
}
]
dremio_api.execute_query_iter = Mock(return_value=iter(mock_results))
# Test streaming method
result_iterator = dremio_api.get_all_tables_and_columns(containers)
tables = list(result_iterator)
assert len(tables) == 1
table = tables[0]
assert table["TABLE_NAME"] == "table1"
assert table["TABLE_SCHEMA"] == "test"
assert len(table["COLUMNS"]) == 1
assert table["COLUMNS"][0]["name"] == "col1"
def test_extract_all_queries(self, dremio_api):
"""Test streaming version of extract_all_queries"""
dremio_api.edition = DremioEdition.ENTERPRISE
dremio_api.start_time = None
dremio_api.end_time = None
# Mock streaming query execution
mock_query_results = [
{"query_id": "q1", "sql": "SELECT * FROM table1"},
{"query_id": "q2", "sql": "SELECT * FROM table2"},
]
dremio_api.execute_query_iter = Mock(return_value=iter(mock_query_results))
result_iterator = dremio_api.extract_all_queries()
queries = list(result_iterator)
assert len(queries) == 2
assert queries[0]["query_id"] == "q1"
assert queries[1]["query_id"] == "q2"
dremio_api.execute_query_iter.assert_called_once()
def test_fetch_results_iter_incremental_yielding(self, dremio_api):
"""Test that internal iterator yields results incrementally and can be partially consumed"""
# Verify that streaming yields results one at a time and iterator state is maintained
mock_responses = [
{"rows": [{"id": i} for i in range(100)], "rowCount": 1000},
{"rows": [{"id": i} for i in range(100, 200)], "rowCount": 1000},
{"rows": [], "rowCount": 1000}, # End iteration
]
dremio_api.get_job_result = Mock(side_effect=mock_responses)
# Test that streaming version yields results incrementally
result_iterator = dremio_api._fetch_results_iter("test-job")
# Get first few results
first_batch = []
for i, result in enumerate(result_iterator):
first_batch.append(result)
if i >= 50: # Stop after getting 51 results
break
assert len(first_batch) == 51
# Verify we can still get more results (iterator not exhausted)
next_result = next(result_iterator, None)
assert next_result is not None

View File

@@ -0,0 +1,74 @@
from unittest.mock import Mock
import pytest
from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig
from datahub.ingestion.source.dremio.dremio_entities import DremioCatalog, DremioDataset
from datahub.ingestion.source.dremio.dremio_source import DremioSource
class TestDremioIteratorIntegration:
@pytest.fixture
def mock_config(self):
return DremioSourceConfig(
hostname="test-host",
port=9047,
tls=False,
username="test-user",
password="test-password",
)
@pytest.fixture
def mock_dremio_source(self, mock_config, monkeypatch):
mock_session = Mock()
monkeypatch.setattr("requests.Session", Mock(return_value=mock_session))
mock_session.post.return_value.json.return_value = {"token": "dummy-token"}
mock_session.post.return_value.status_code = 200
mock_ctx = Mock()
mock_ctx.run_id = "test-run-id"
source = DremioSource(mock_config, mock_ctx)
source.dremio_catalog = Mock(spec=DremioCatalog)
source.dremio_catalog.dremio_api = Mock()
return source
def test_source_uses_iterators_by_default(self, mock_dremio_source):
"""Test that source uses iterators by default"""
mock_dataset = Mock(spec=DremioDataset)
mock_dataset.path = ["test", "path"]
mock_dataset.resource_name = "test_table"
mock_dremio_source.dremio_catalog.get_datasets.return_value = iter(
[mock_dataset]
)
mock_dremio_source.dremio_catalog.get_containers.return_value = []
mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([])
mock_dremio_source.dremio_catalog.get_sources.return_value = []
mock_dremio_source.config.source_mappings = []
mock_dremio_source.process_dataset = Mock(return_value=iter([]))
list(mock_dremio_source.get_workunits_internal())
mock_dremio_source.dremio_catalog.get_datasets.assert_called_once()
mock_dremio_source.dremio_catalog.get_glossary_terms.assert_called_once()
def test_iterator_handles_exceptions_gracefully(self, mock_dremio_source):
"""Test that iterator handles exceptions without crashing"""
mock_dataset = Mock(spec=DremioDataset)
mock_dataset.path = ["test", "path"]
mock_dataset.resource_name = "test_table"
mock_dremio_source.dremio_catalog.get_datasets.return_value = iter(
[mock_dataset]
)
mock_dremio_source.dremio_catalog.get_containers.return_value = []
mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([])
mock_dremio_source.dremio_catalog.get_sources.return_value = []
mock_dremio_source.config.source_mappings = []
mock_dremio_source.process_dataset = Mock(
side_effect=Exception("Processing error")
)
list(mock_dremio_source.get_workunits_internal())
assert mock_dremio_source.report.num_datasets_failed > 0

View File

@@ -10,7 +10,7 @@ from datahub.ingestion.source.dremio.dremio_source import (
def test_build_source_map_simple():
# write unit test
"""Test basic source mapping functionality with simple configuration"""
config_mapping: List[DremioSourceMapping] = [
DremioSourceMapping(source_name="source1", platform="S3", env="PROD"),
DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"),
@@ -58,7 +58,7 @@ def test_build_source_map_simple():
def test_build_source_map_same_platform_multiple_sources():
# write unit test
"""Test source mapping with multiple sources using the same platform and complex scenarios"""
config_mapping: List[DremioSourceMapping] = [
DremioSourceMapping(source_name="source1", platform="S3", env="PROD"),
DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"),