Mirror of https://github.com/datahub-project/datahub.git (synced 2025-11-06 22:04:16 +00:00)
fix(ingestion/dremio): handle dremio oom errors when ingesting large amount of metadata (#14883)
This commit is contained in:
parent 6af5182e9b
commit f05f3e40f2
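The change replaces eager accumulation of API results in Python lists with generator-based streaming throughout the connector. A minimal, illustrative sketch of that pattern (hypothetical names, not the connector's actual API):

from typing import Any, Dict, Iterator, List


def fetch_all(pages: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    # Eager version: every row from every page is held in memory at once.
    rows: List[Dict[str, Any]] = []
    for page in pages:
        rows.extend(page)
    return rows


def fetch_iter(pages: List[List[Dict[str, Any]]]) -> Iterator[Dict[str, Any]]:
    # Streaming version: rows are yielded one at a time, so peak memory
    # stays close to the size of a single page.
    for page in pages:
        yield from page


# The caller can process rows as they arrive instead of waiting for all of them.
for row in fetch_iter([[{"id": 1}], [{"id": 2}]]):
    print(row)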
@@ -7,7 +7,7 @@ from collections import defaultdict
 from enum import Enum
 from itertools import product
 from time import sleep, time
-from typing import TYPE_CHECKING, Any, Deque, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
 from urllib.parse import quote

 import requests
@@ -343,14 +343,149 @@ class DremioAPIOperations:

         while True:
             result = self.get_job_result(job_id, offset, limit)
-            rows.extend(result["rows"])
-
-            offset = offset + limit
-            if offset >= result["rowCount"]:
-                break
+            # Handle cases where API response doesn't contain 'rows' key
+            # This can happen with OOM errors or when no rows are returned
+            if "rows" not in result:
+                logger.warning(
+                    f"API response for job {job_id} missing 'rows' key. "
+                    f"Response keys: {list(result.keys())}"
+                )
+                # Check for error conditions
+                if "errorMessage" in result:
+                    raise DremioAPIException(f"Query error: {result['errorMessage']}")
+                elif "message" in result:
+                    logger.warning(
+                        f"Query warning for job {job_id}: {result['message']}"
+                    )
+                # Return empty list if no rows key and no error
+                break
+
+            # Handle empty rows response
+            result_rows = result["rows"]
+            if not result_rows:
+                logger.debug(
+                    f"No more rows returned for job {job_id} at offset {offset}"
+                )
+                break
+
+            rows.extend(result_rows)
+
+            # Check actual returned rows to determine if we should continue
+            actual_rows_returned = len(result_rows)
+            if actual_rows_returned == 0:
+                logger.debug(f"Query returned no rows for job {job_id}")
+                break
+
+            offset = offset + actual_rows_returned
+            # If we got fewer rows than requested, we've reached the end
+            if actual_rows_returned < limit:
+                break

+        logger.info(f"Fetched {len(rows)} total rows for job {job_id}")
         return rows

+    def _fetch_results_iter(self, job_id: str) -> Iterator[Dict]:
+        """
+        Fetch job results in a streaming fashion to reduce memory usage.
+        Yields individual rows instead of collecting all in memory.
+        """
+        limit = 500
+        offset = 0
+        total_rows_fetched = 0
+
+        while True:
+            result = self.get_job_result(job_id, offset, limit)
+
+            # Handle cases where API response doesn't contain 'rows' key
+            if "rows" not in result:
+                logger.warning(
+                    f"API response for job {job_id} missing 'rows' key. "
+                    f"Response keys: {list(result.keys())}"
+                )
+                # Check for error conditions
+                if "errorMessage" in result:
+                    raise DremioAPIException(f"Query error: {result['errorMessage']}")
+                elif "message" in result:
+                    logger.warning(
+                        f"Query warning for job {job_id}: {result['message']}"
+                    )
+                # Stop iteration if no rows key and no error
+                break
+
+            # Handle empty rows response
+            result_rows = result["rows"]
+            if not result_rows:
+                logger.debug(
+                    f"No more rows returned for job {job_id} at offset {offset}"
+                )
+                break
+
+            # Yield individual rows instead of collecting them
+            for row in result_rows:
+                yield row
+                total_rows_fetched += 1
+
+            # Check actual returned rows to determine if we should continue
+            actual_rows_returned = len(result_rows)
+            if actual_rows_returned == 0:
+                logger.debug(f"Query returned no rows for job {job_id}")
+                break
+
+            offset = offset + actual_rows_returned
+            # If we got fewer rows than requested, we've reached the end
+            if actual_rows_returned < limit:
+                break
+
+        logger.info(f"Streamed {total_rows_fetched} total rows for job {job_id}")
+
+    def execute_query_iter(
+        self, query: str, timeout: int = 3600
+    ) -> Iterator[Dict[str, Any]]:
+        """Execute SQL query and return results as a streaming iterator"""
+        try:
+            with PerfTimer() as timer:
+                logger.info(f"Executing streaming query: {query}")
+                response = self.post(url="/sql", data=json.dumps({"sql": query}))
+
+                if "errorMessage" in response:
+                    self.report.failure(
+                        message="SQL Error", context=f"{response['errorMessage']}"
+                    )
+                    raise DremioAPIException(f"SQL Error: {response['errorMessage']}")
+
+                job_id = response["id"]
+
+                # Wait for job completion
+                start_time = time()
+                while True:
+                    status = self.get_job_status(job_id)
+                    if status["jobState"] == "COMPLETED":
+                        break
+                    elif status["jobState"] == "FAILED":
+                        error_message = status.get("errorMessage", "Unknown error")
+                        raise RuntimeError(f"Query failed: {error_message}")
+                    elif status["jobState"] == "CANCELED":
+                        raise RuntimeError("Query was canceled")
+
+                    if time() - start_time > timeout:
+                        self.cancel_query(job_id)
+                        raise DremioAPIException(
+                            f"Query execution timed out after {timeout} seconds"
+                        )
+
+                    sleep(3)
+
+            logger.info(
+                f"Query job completed in {timer.elapsed_seconds()} seconds, starting streaming"
+            )
+
+            # Return streaming iterator
+            return self._fetch_results_iter(job_id)
+
+        except requests.RequestException as e:
+            raise DremioAPIException("Error executing streaming query") from e

     def cancel_query(self, job_id: str) -> None:
         """Cancel a running query"""
         try:
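A toy, self-contained re-implementation of the same pagination and termination logic, to make the stopping conditions explicit; get_job_result here is a stand-in for the Dremio job-result endpoint and the data is made up:

from typing import Any, Dict, Iterator, List

# Fake pages standing in for successive Dremio job-result responses.
PAGES: List[List[Dict[str, Any]]] = [
    [{"id": 0}, {"id": 1}],  # full batch
    [{"id": 2}],             # short batch signals the last page
]


def get_job_result(offset: int, limit: int) -> Dict[str, Any]:
    index = offset // limit
    rows = PAGES[index] if index < len(PAGES) else []
    return {"rows": rows, "rowCount": 3}


def fetch_results_iter(limit: int = 2) -> Iterator[Dict[str, Any]]:
    offset = 0
    while True:
        result = get_job_result(offset, limit)
        rows = result.get("rows", [])
        if not rows:           # missing or empty 'rows' ends the stream
            break
        yield from rows
        offset += len(rows)    # advance by what was actually returned
        if len(rows) < limit:  # fewer rows than requested: done
            break


print(list(fetch_results_iter()))  # [{'id': 0}, {'id': 1}, {'id': 2}]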
@@ -499,8 +634,12 @@ class DremioAPIOperations:
         return f"AND {operator}({field}, '{pattern_str}')"

     def get_all_tables_and_columns(
-        self, containers: Deque["DremioContainer"]
-    ) -> List[Dict]:
+        self, containers: Iterator["DremioContainer"]
+    ) -> Iterator[Dict]:
+        """
+        Memory-efficient streaming version that yields tables one at a time.
+        Reduces memory usage for large datasets by processing results as they come.
+        """
         if self.edition == DremioEdition.ENTERPRISE:
             query_template = DremioSQLQueries.QUERY_DATASETS_EE
         elif self.edition == DremioEdition.CLOUD:
@@ -517,21 +656,78 @@ class DremioAPIOperations:
                 self.deny_schema_pattern, schema_field, allow=False
             )

-        all_tables_and_columns = []
-
+        # Process each container's results separately to avoid memory buildup
         for schema in containers:
             formatted_query = ""
             try:
                 formatted_query = query_template.format(
                     schema_pattern=schema_condition,
                     deny_schema_pattern=deny_schema_condition,
                     container_name=schema.container_name.lower(),
                 )
-                all_tables_and_columns.extend(
-                    self.execute_query(
-                        query=formatted_query,
-                    )
-                )
+
+                # Use streaming query execution
+                container_results = list(self.execute_query_iter(query=formatted_query))
+
+                if self.edition == DremioEdition.COMMUNITY:
+                    # Process community edition results
+                    formatted_tables = self.community_get_formatted_tables(
+                        container_results
+                    )
+                    for table in formatted_tables:
+                        yield table
+                else:
+                    # Process enterprise/cloud edition results
+                    column_dictionary: Dict[str, List[Dict]] = defaultdict(list)
+                    table_metadata: Dict[str, Dict] = {}
+
+                    for record in container_results:
+                        if not record.get("COLUMN_NAME"):
+                            continue
+
+                        table_full_path = record.get("FULL_TABLE_PATH")
+                        if not table_full_path:
+                            continue
+
+                        # Store column information
+                        column_dictionary[table_full_path].append(
+                            {
+                                "name": record["COLUMN_NAME"],
+                                "ordinal_position": record["ORDINAL_POSITION"],
+                                "is_nullable": record["IS_NULLABLE"],
+                                "data_type": record["DATA_TYPE"],
+                                "column_size": record["COLUMN_SIZE"],
+                            }
+                        )
+
+                        # Store table metadata (only once per table)
+                        if table_full_path not in table_metadata:
+                            table_metadata[table_full_path] = {
+                                "TABLE_NAME": record.get("TABLE_NAME"),
+                                "TABLE_SCHEMA": record.get("TABLE_SCHEMA"),
+                                "VIEW_DEFINITION": record.get("VIEW_DEFINITION"),
+                                "RESOURCE_ID": record.get("RESOURCE_ID"),
+                                "LOCATION_ID": record.get("LOCATION_ID"),
+                                "OWNER": record.get("OWNER"),
+                                "OWNER_TYPE": record.get("OWNER_TYPE"),
+                                "CREATED": record.get("CREATED"),
+                                "FORMAT_TYPE": record.get("FORMAT_TYPE"),
+                            }
+
+                    # Yield tables one at a time
+                    for table_path, table_info in table_metadata.items():
+                        yield {
+                            "TABLE_NAME": table_info.get("TABLE_NAME"),
+                            "TABLE_SCHEMA": table_info.get("TABLE_SCHEMA"),
+                            "COLUMNS": column_dictionary[table_path],
+                            "VIEW_DEFINITION": table_info.get("VIEW_DEFINITION"),
+                            "RESOURCE_ID": table_info.get("RESOURCE_ID"),
+                            "LOCATION_ID": table_info.get("LOCATION_ID"),
+                            "OWNER": table_info.get("OWNER"),
+                            "OWNER_TYPE": table_info.get("OWNER_TYPE"),
+                            "CREATED": table_info.get("CREATED"),
+                            "FORMAT_TYPE": table_info.get("FORMAT_TYPE"),
+                        }
+
             except DremioAPIException as e:
                 self.report.warning(
                     message="Container has no tables or views",
@@ -539,71 +735,6 @@ class DremioAPIOperations:
                     exc=e,
                 )

-        tables = []
-
-        if self.edition == DremioEdition.COMMUNITY:
-            tables = self.community_get_formatted_tables(all_tables_and_columns)
-
-        else:
-            column_dictionary: Dict[str, List[Dict]] = defaultdict(list)
-
-            for record in all_tables_and_columns:
-                if not record.get("COLUMN_NAME"):
-                    continue
-
-                table_full_path = record.get("FULL_TABLE_PATH")
-                if not table_full_path:
-                    continue
-
-                column_dictionary[table_full_path].append(
-                    {
-                        "name": record["COLUMN_NAME"],
-                        "ordinal_position": record["ORDINAL_POSITION"],
-                        "is_nullable": record["IS_NULLABLE"],
-                        "data_type": record["DATA_TYPE"],
-                        "column_size": record["COLUMN_SIZE"],
-                    }
-                )
-
-            distinct_tables_list = list(
-                {
-                    tuple(
-                        dictionary[key]
-                        for key in (
-                            "TABLE_SCHEMA",
-                            "TABLE_NAME",
-                            "FULL_TABLE_PATH",
-                            "VIEW_DEFINITION",
-                            "LOCATION_ID",
-                            "OWNER",
-                            "OWNER_TYPE",
-                            "CREATED",
-                            "FORMAT_TYPE",
-                        )
-                        if key in dictionary
-                    ): dictionary
-                    for dictionary in all_tables_and_columns
-                }.values()
-            )
-
-            for table in distinct_tables_list:
-                tables.append(
-                    {
-                        "TABLE_NAME": table.get("TABLE_NAME"),
-                        "TABLE_SCHEMA": table.get("TABLE_SCHEMA"),
-                        "COLUMNS": column_dictionary[table["FULL_TABLE_PATH"]],
-                        "VIEW_DEFINITION": table.get("VIEW_DEFINITION"),
-                        "RESOURCE_ID": table.get("RESOURCE_ID"),
-                        "LOCATION_ID": table.get("LOCATION_ID"),
-                        "OWNER": table.get("OWNER"),
-                        "OWNER_TYPE": table.get("OWNER_TYPE"),
-                        "CREATED": table.get("CREATED"),
-                        "FORMAT_TYPE": table.get("FORMAT_TYPE"),
-                    }
-                )
-
-        return tables

     def validate_schema_format(self, schema):
         if "." in schema:
             schema_path = self.get(
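A small self-contained sketch of the per-container shaping done above: column records are grouped by FULL_TABLE_PATH and one dict per table is yielded. Field names follow the diff; the record values and the reduced output shape are illustrative:

from collections import defaultdict
from typing import Dict, Iterator, List

records = [
    {"FULL_TABLE_PATH": "src.t1", "TABLE_NAME": "t1", "COLUMN_NAME": "a"},
    {"FULL_TABLE_PATH": "src.t1", "TABLE_NAME": "t1", "COLUMN_NAME": "b"},
    {"FULL_TABLE_PATH": "src.t2", "TABLE_NAME": "t2", "COLUMN_NAME": "c"},
]


def tables_iter(rows: List[Dict]) -> Iterator[Dict]:
    columns: Dict[str, List[str]] = defaultdict(list)
    names: Dict[str, str] = {}
    for row in rows:
        path = row["FULL_TABLE_PATH"]
        columns[path].append(row["COLUMN_NAME"])
        names.setdefault(path, row["TABLE_NAME"])
    # One dict per table, emitted as soon as this container's rows are processed.
    for path, cols in columns.items():
        yield {"TABLE_NAME": names[path], "COLUMNS": cols}


print(list(tables_iter(records)))
# [{'TABLE_NAME': 't1', 'COLUMNS': ['a', 'b']}, {'TABLE_NAME': 't2', 'COLUMNS': ['c']}]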
@@ -640,7 +771,10 @@ class DremioAPIOperations:

         return parents_list

-    def extract_all_queries(self) -> List[Dict[str, Any]]:
+    def extract_all_queries(self) -> Iterator[Dict[str, Any]]:
+        """
+        Memory-efficient streaming version for extracting query results.
+        """
         # Convert datetime objects to string format for SQL queries
         start_timestamp_str = None
         end_timestamp_str = None
@@ -661,7 +795,7 @@ class DremioAPIOperations:
             end_timestamp_millis=end_timestamp_str,
         )

-        return self.execute_query(query=jobs_query)
+        return self.execute_query_iter(query=jobs_query)

     def get_tags_for_resource(self, resource_id: str) -> Optional[List[str]]:
         """
@@ -1,4 +1,3 @@
-import itertools
 import logging
 import re
 import uuid
@@ -6,7 +5,7 @@ from collections import deque
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Deque, Dict, List, Optional
+from typing import Any, Deque, Dict, Iterator, List, Optional

 from sqlglot import parse_one

@@ -184,6 +183,7 @@ class DremioQuery:
         return ""

     def get_raw_query(self, sql_query: str) -> str:
+        """Remove comments from SQL query using sqlglot parser."""
         try:
             parsed = parse_one(sql_query)
             return parsed.sql(comments=False)
@@ -336,43 +336,26 @@ class DremioCatalog:
     def __init__(self, dremio_api: DremioAPIOperations):
         self.dremio_api = dremio_api
         self.edition = dremio_api.edition
-        self.datasets: Deque[DremioDataset] = deque()
         self.sources: Deque[DremioSourceContainer] = deque()
         self.spaces: Deque[DremioSpace] = deque()
         self.folders: Deque[DremioFolder] = deque()
         self.glossary_terms: Deque[DremioGlossaryTerm] = deque()
         self.queries: Deque[DremioQuery] = deque()

-        self.datasets_populated = False
         self.containers_populated = False
         self.queries_populated = False

-    def set_datasets(self) -> None:
-        if not self.datasets_populated:
-            self.set_containers()
-
-            containers: Deque[DremioContainer] = deque()
-            containers.extend(self.spaces)  # Add DremioSpace elements
-            containers.extend(self.sources)  # Add DremioSource elements
-
-            for dataset_details in self.dremio_api.get_all_tables_and_columns(
-                containers=containers
-            ):
-                dremio_dataset = DremioDataset(
-                    dataset_details=dataset_details,
-                    api_operations=self.dremio_api,
-                )
-                self.datasets.append(dremio_dataset)
-
-                for glossary_term in dremio_dataset.glossary_terms:
-                    if glossary_term not in self.glossary_terms:
-                        self.glossary_terms.append(glossary_term)
-
-            self.datasets_populated = True
-
-    def get_datasets(self) -> Deque[DremioDataset]:
-        self.set_datasets()
-        return self.datasets
+    def get_datasets(self) -> Iterator[DremioDataset]:
+        """Get all Dremio datasets (tables and views) as an iterator."""
+        # Get containers directly without storing them
+        containers = self.get_containers()
+
+        for dataset_details in self.dremio_api.get_all_tables_and_columns(containers):
+            dremio_dataset = DremioDataset(
+                dataset_details=dataset_details,
+                api_operations=self.dremio_api,
+            )
+            yield dremio_dataset

     def set_containers(self) -> None:
         if not self.containers_populated:
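After this change the catalog side is essentially generator chaining: containers, tables, and datasets are produced lazily, so no intermediate deque of all datasets is kept. A hedged sketch with stand-in functions (not the real DremioCatalog types):

from typing import Dict, Iterator


def containers() -> Iterator[str]:
    # Stand-in for DremioCatalog.get_containers()
    yield from ["space_a", "source_b"]


def tables_for(container: str) -> Iterator[Dict]:
    # Stand-in for get_all_tables_and_columns() applied to a single container
    yield {"container": container, "table": f"{container}.t1"}


def datasets() -> Iterator[Dict]:
    # Chained generators: each dataset is built and handed on immediately.
    for container in containers():
        yield from tables_for(container)


for dataset in datasets():
    print(dataset["table"])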
@@ -423,18 +406,50 @@ class DremioCatalog:

         self.containers_populated = True

-    def get_containers(self) -> Deque:
-        self.set_containers()
-        return deque(itertools.chain(self.sources, self.spaces, self.folders))
+    def get_containers(self) -> Iterator[DremioContainer]:
+        """Get all containers (sources, spaces, folders) as an iterator."""
+        for container in self.dremio_api.get_all_containers():
+            container_type = container.get("container_type")
+            if container_type == DremioEntityContainerType.SOURCE:
+                yield DremioSourceContainer(
+                    container_name=container.get("name"),
+                    location_id=container.get("id"),
+                    path=[],
+                    api_operations=self.dremio_api,
+                    dremio_source_type=container.get("source_type") or "",
+                    root_path=container.get("root_path"),
+                    database_name=container.get("database_name"),
+                )
+            elif container_type == DremioEntityContainerType.SPACE:
+                yield DremioSpace(
+                    container_name=container.get("name"),
+                    location_id=container.get("id"),
+                    path=[],
+                    api_operations=self.dremio_api,
+                )
+            elif container_type == DremioEntityContainerType.FOLDER:
+                yield DremioFolder(
+                    container_name=container.get("name"),
+                    location_id=container.get("id"),
+                    path=container.get("path"),
+                    api_operations=self.dremio_api,
+                )

-    def get_sources(self) -> Deque[DremioSourceContainer]:
-        self.set_containers()
-        return self.sources
+    def get_sources(self) -> Iterator[DremioSourceContainer]:
+        """Get all Dremio source containers (external data connections) as an iterator."""
+        for container in self.get_containers():
+            if isinstance(container, DremioSourceContainer):
+                yield container

-    def get_glossary_terms(self) -> Deque[DremioGlossaryTerm]:
-        self.set_datasets()
-        self.set_containers()
-        return self.glossary_terms
+    def get_glossary_terms(self) -> Iterator[DremioGlossaryTerm]:
+        """Get all unique glossary terms (tags) from datasets."""
+        glossary_terms_seen = set()
+
+        for dataset in self.get_datasets():
+            for glossary_term in dataset.glossary_terms:
+                if glossary_term not in glossary_terms_seen:
+                    glossary_terms_seen.add(glossary_term)
+                    yield glossary_term

     def is_valid_query(self, query: Dict[str, Any]) -> bool:
         required_fields = [
@@ -447,6 +462,7 @@ class DremioCatalog:
         return all(query.get(field) for field in required_fields)

     def get_queries(self) -> Deque[DremioQuery]:
+        """Get all valid Dremio queries for lineage analysis."""
         for query in self.dremio_api.extract_all_queries():
            if not self.is_valid_query(query):
                continue
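get_glossary_terms above deduplicates while streaming instead of caching a deque of terms; a minimal standalone version of that pattern (toy data, hypothetical field name):

from typing import Dict, Iterable, Iterator, List


def unique_terms(datasets: Iterable[Dict[str, List[str]]]) -> Iterator[str]:
    seen = set()
    for dataset in datasets:
        for term in dataset.get("glossary_terms", []):
            if term not in seen:  # remember what was already yielded
                seen.add(term)
                yield term


data = [{"glossary_terms": ["pii", "finance"]}, {"glossary_terms": ["pii"]}]
print(list(unique_terms(data)))  # ['pii', 'finance']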
@@ -17,6 +17,7 @@ from datahub.metadata.schema_classes import (
     DatasetProfileClass,
     QuantileClass,
 )
+from datahub.utilities.perf_timer import PerfTimer

 logger = logging.getLogger(__name__)

@@ -64,8 +65,13 @@ class DremioProfiler:
             )
             return

-        profile_data = self.profile_table(full_table_name, columns)
-        profile_aspect = self.populate_profile_aspect(profile_data)
+        with PerfTimer() as timer:
+            profile_data = self.profile_table(full_table_name, columns)
+            profile_aspect = self.populate_profile_aspect(profile_data)
+
+        logger.info(
+            f"Profiled table {full_table_name} with {len(columns)} columns in {timer.elapsed_seconds():.2f} seconds"
+        )

         if profile_aspect:
             self.report.report_entity_profiled(dataset.resource_name)
@@ -131,7 +137,12 @@ class DremioProfiler:
     def _profile_chunk(self, table_name: str, columns: List[Tuple[str, str]]) -> Dict:
         profile_sql = self._build_profile_sql(table_name, columns)
         try:
-            results = self.api_operations.execute_query(profile_sql)
+            with PerfTimer() as timer:
+                results = self.api_operations.execute_query(profile_sql)
+
+            logger.debug(
+                f"Profiling query for {table_name} ({len(columns)} columns) completed in {timer.elapsed_seconds():.2f} seconds"
+            )
             return self._parse_profile_results(results, columns)
         except DremioAPIException as e:
             raise e
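PerfTimer is used above purely as a context manager around the profiling calls so elapsed wall-clock time can be logged per table and per query chunk. A minimal usage sketch (the timed work here is a placeholder):

from datahub.utilities.perf_timer import PerfTimer

with PerfTimer() as timer:
    total = sum(range(1_000_000))  # placeholder for profile_table(...)

print(f"finished in {timer.elapsed_seconds():.2f} seconds")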
@@ -55,7 +55,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
 from datahub.ingestion.source_report.ingestion_stage import (
     LINEAGE_EXTRACTION,
-    METADATA_EXTRACTION,
+    IngestionHighStage,
     PROFILING,
 )
 from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
     DatasetLineageTypeClass,
@@ -201,7 +201,7 @@ class DremioSource(StatefulIngestionSourceBase):
         return "dremio"

     def _build_source_map(self) -> Dict[str, DremioSourceMapEntry]:
-        dremio_sources = self.dremio_catalog.get_sources()
+        dremio_sources = list(self.dremio_catalog.get_sources())
         source_mappings_config = self.config.source_mappings or []

         source_map = build_dremio_source_map(dremio_sources, source_mappings_config)
@@ -242,9 +242,7 @@ class DremioSource(StatefulIngestionSourceBase):
             )

             # Process Datasets
-            datasets = self.dremio_catalog.get_datasets()
-
-            for dataset_info in datasets:
+            for dataset_info in self.dremio_catalog.get_datasets():
                 try:
                     yield from self.process_dataset(dataset_info)
                     logger.info(
@@ -258,10 +256,8 @@ class DremioSource(StatefulIngestionSourceBase):
                         exc=exc,
                     )

-            # Process Glossary Terms
-            glossary_terms = self.dremio_catalog.get_glossary_terms()
-
-            for glossary_term in glossary_terms:
+            # Process Glossary Terms using streaming
+            for glossary_term in self.dremio_catalog.get_glossary_terms():
                 try:
                     yield from self.process_glossary_term(glossary_term)
                 except Exception as exc:
@@ -283,14 +279,16 @@ class DremioSource(StatefulIngestionSourceBase):
         # Profiling
         if self.config.is_profiling_enabled():
             with (
+                self.report.new_high_stage(IngestionHighStage.PROFILING),
                 self.report.new_stage(PROFILING),
                 ThreadPoolExecutor(
                     max_workers=self.config.profiling.max_workers
                 ) as executor,
             ):
+                # Collect datasets for profiling
+                datasets_for_profiling = list(self.dremio_catalog.get_datasets())
                 future_to_dataset = {
                     executor.submit(self.generate_profiles, dataset): dataset
-                    for dataset in datasets
+                    for dataset in datasets_for_profiling
                 }

                 for future in as_completed(future_to_dataset):
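The few remaining list(...) calls above are deliberate: a generator can only be consumed once, and both the source map and the profiling executor need the full collection up front (one future is submitted per dataset). A short illustration of why that materialization is needed:

def rows():
    yield {"id": 1}
    yield {"id": 2}


gen = rows()
print(list(gen))  # [{'id': 1}, {'id': 2}]
print(list(gen))  # [] -- a generator is exhausted after one pass

# Materializing once allows repeated use, e.g. submitting one task per item.
datasets_for_profiling = list(rows())
print(len(datasets_for_profiling))  # 2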
@@ -0,0 +1,281 @@
from collections import deque
from unittest.mock import Mock, patch

import pytest

from datahub.ingestion.source.dremio.dremio_api import (
    DremioAPIException,
    DremioAPIOperations,
    DremioEdition,
)
from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig
from datahub.ingestion.source.dremio.dremio_reporting import DremioSourceReport


class TestDremioAPIPagination:
    @pytest.fixture
    def dremio_api(self, monkeypatch):
        """Setup mock Dremio API for testing"""
        # Mock the requests.Session
        mock_session = Mock()
        monkeypatch.setattr("requests.Session", Mock(return_value=mock_session))

        # Mock the authentication response
        mock_session.post.return_value.json.return_value = {"token": "dummy-token"}
        mock_session.post.return_value.status_code = 200

        config = DremioSourceConfig(
            hostname="dummy-host",
            port=9047,
            tls=False,
            authentication_method="password",
            username="dummy-user",
            password="dummy-password",
        )
        report = DremioSourceReport()
        api = DremioAPIOperations(config, report)
        api.session = mock_session
        return api

    def test_fetch_all_results_missing_rows_key(self, dremio_api):
        """Test handling of API response missing 'rows' key"""
        # Mock get_job_result to return response without 'rows' key
        dremio_api.get_job_result = Mock(return_value={"message": "No data available"})

        result = dremio_api._fetch_all_results("test-job-id")

        # Should return empty list when no rows key is present
        assert result == []
        dremio_api.get_job_result.assert_called_once()

    def test_fetch_all_results_with_error_message(self, dremio_api):
        """Test handling of API response with errorMessage"""
        # Mock get_job_result to return error response
        dremio_api.get_job_result = Mock(return_value={"errorMessage": "Out of memory"})

        with pytest.raises(DremioAPIException, match="Query error: Out of memory"):
            dremio_api._fetch_all_results("test-job-id")

    def test_fetch_all_results_empty_rows(self, dremio_api):
        """Test handling of empty rows response"""
        # Mock get_job_result to return empty rows
        dremio_api.get_job_result = Mock(return_value={"rows": [], "rowCount": 0})

        result = dremio_api._fetch_all_results("test-job-id")

        assert result == []
        dremio_api.get_job_result.assert_called_once()

    def test_fetch_all_results_normal_case(self, dremio_api):
        """Test normal operation with valid rows"""
        # Mock get_job_result to return valid data
        # First response: 2 rows, rowCount=2 (offset will be 2, which equals rowCount, so stops)
        mock_responses = [
            {"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2},
        ]
        dremio_api.get_job_result = Mock(side_effect=mock_responses)

        result = dremio_api._fetch_all_results("test-job-id")

        expected = [{"col1": "val1"}, {"col1": "val2"}]
        assert result == expected
        assert dremio_api.get_job_result.call_count == 1

    def test_fetch_results_iter_missing_rows_key(self, dremio_api):
        """Test internal streaming method handling missing 'rows' key"""
        # Mock get_job_result to return response without 'rows' key
        dremio_api.get_job_result = Mock(return_value={"message": "No data available"})

        result_iterator = dremio_api._fetch_results_iter("test-job-id")
        results = list(result_iterator)

        # Should return empty iterator when no rows key is present
        assert results == []

    def test_fetch_results_iter_with_error(self, dremio_api):
        """Test internal streaming method handling error response"""
        # Mock get_job_result to return error response
        dremio_api.get_job_result = Mock(return_value={"errorMessage": "Query timeout"})

        result_iterator = dremio_api._fetch_results_iter("test-job-id")

        with pytest.raises(DremioAPIException, match="Query error: Query timeout"):
            list(result_iterator)

    def test_fetch_results_iter_normal_case(self, dremio_api):
        """Test internal streaming method with valid data"""
        # Mock get_job_result to return valid data in batches
        mock_responses = [
            {"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2},
            {"rows": [], "rowCount": 2},  # Empty response to end iteration
        ]
        dremio_api.get_job_result = Mock(side_effect=mock_responses)

        result_iterator = dremio_api._fetch_results_iter("test-job-id")
        results = list(result_iterator)

        expected = [
            {"col1": "val1"},
            {"col1": "val2"},
        ]
        assert results == expected

    def test_execute_query_iter_success(self, dremio_api):
        """Test execute_query_iter with successful job completion"""
        # Mock the POST response for job submission
        dremio_api.post = Mock(return_value={"id": "job-123"})

        # Mock job status progression
        status_responses = [
            {"jobState": "RUNNING"},
            {"jobState": "RUNNING"},
            {"jobState": "COMPLETED"},
        ]
        dremio_api.get_job_status = Mock(side_effect=status_responses)

        # Mock streaming results
        dremio_api._fetch_results_iter = Mock(return_value=iter([{"col1": "val1"}]))

        with patch("time.sleep"):  # Skip actual sleep delays
            result_iterator = dremio_api.execute_query_iter("SELECT * FROM test")
            results = list(result_iterator)

        assert results == [{"col1": "val1"}]
        dremio_api.post.assert_called_once()
        assert dremio_api.get_job_status.call_count == 3

    def test_execute_query_iter_job_failure(self, dremio_api):
        """Test execute_query_iter with job failure"""
        # Mock the POST response for job submission
        dremio_api.post = Mock(return_value={"id": "job-123"})

        # Mock job failure
        dremio_api.get_job_status = Mock(
            return_value={"jobState": "FAILED", "errorMessage": "SQL syntax error"}
        )

        with (
            pytest.raises(RuntimeError, match="Query failed: SQL syntax error"),
            patch("time.sleep"),
        ):
            result_iterator = dremio_api.execute_query_iter("SELECT * FROM test")
            list(result_iterator)  # Force evaluation

    def test_execute_query_iter_timeout(self, dremio_api):
        """Test execute_query_iter with timeout"""
        # Mock the POST response for job submission
        dremio_api.post = Mock(return_value={"id": "job-123"})

        # Mock job that never completes
        dremio_api.get_job_status = Mock(return_value={"jobState": "RUNNING"})
        dremio_api.cancel_query = Mock()

        # Mock time.time to simulate timeout - need to patch where it's imported
        # First call: start_time = 0
        # Second call: time() - start_time check = 3700 (triggers timeout)
        mock_time_values = [0, 3700]

        with (
            patch(
                "datahub.ingestion.source.dremio.dremio_api.time",
                side_effect=mock_time_values,
            ),
            patch("datahub.ingestion.source.dremio.dremio_api.sleep"),
            pytest.raises(
                DremioAPIException,
                match="Query execution timed out after 3600 seconds",
            ),
        ):
            result_iterator = dremio_api.execute_query_iter("SELECT * FROM test")
            list(result_iterator)

        dremio_api.cancel_query.assert_called_once_with("job-123")

    def test_get_all_tables_and_columns(self, dremio_api):
        """Test streaming version of get_all_tables_and_columns"""
        from datahub.ingestion.source.dremio.dremio_api import DremioEdition

        # Set up test data
        dremio_api.edition = DremioEdition.ENTERPRISE
        dremio_api.allow_schema_pattern = [".*"]
        dremio_api.deny_schema_pattern = []

        # Mock container
        mock_container = Mock()
        mock_container.container_name = "test_source"
        containers = deque([mock_container])

        # Mock streaming query results
        mock_results = [
            {
                "COLUMN_NAME": "col1",
                "FULL_TABLE_PATH": "test.table1",
                "TABLE_NAME": "table1",
                "TABLE_SCHEMA": "test",
                "ORDINAL_POSITION": 1,
                "IS_NULLABLE": "YES",
                "DATA_TYPE": "VARCHAR",
                "COLUMN_SIZE": 255,
                "RESOURCE_ID": "res1",
            }
        ]
        dremio_api.execute_query_iter = Mock(return_value=iter(mock_results))

        # Test streaming method
        result_iterator = dremio_api.get_all_tables_and_columns(containers)
        tables = list(result_iterator)

        assert len(tables) == 1
        table = tables[0]
        assert table["TABLE_NAME"] == "table1"
        assert table["TABLE_SCHEMA"] == "test"
        assert len(table["COLUMNS"]) == 1
        assert table["COLUMNS"][0]["name"] == "col1"

    def test_extract_all_queries(self, dremio_api):
        """Test streaming version of extract_all_queries"""

        dremio_api.edition = DremioEdition.ENTERPRISE
        dremio_api.start_time = None
        dremio_api.end_time = None

        # Mock streaming query execution
        mock_query_results = [
            {"query_id": "q1", "sql": "SELECT * FROM table1"},
            {"query_id": "q2", "sql": "SELECT * FROM table2"},
        ]
        dremio_api.execute_query_iter = Mock(return_value=iter(mock_query_results))

        result_iterator = dremio_api.extract_all_queries()
        queries = list(result_iterator)

        assert len(queries) == 2
        assert queries[0]["query_id"] == "q1"
        assert queries[1]["query_id"] == "q2"
        dremio_api.execute_query_iter.assert_called_once()

    def test_fetch_results_iter_incremental_yielding(self, dremio_api):
        """Test that internal iterator yields results incrementally and can be partially consumed"""
        # Verify that streaming yields results one at a time and iterator state is maintained

        mock_responses = [
            {"rows": [{"id": i} for i in range(100)], "rowCount": 1000},
            {"rows": [{"id": i} for i in range(100, 200)], "rowCount": 1000},
            {"rows": [], "rowCount": 1000},  # End iteration
        ]
        dremio_api.get_job_result = Mock(side_effect=mock_responses)

        # Test that streaming version yields results incrementally
        result_iterator = dremio_api._fetch_results_iter("test-job")

        # Get first few results
        first_batch = []
        for i, result in enumerate(result_iterator):
            first_batch.append(result)
            if i >= 50:  # Stop after getting 51 results
                break

        assert len(first_batch) == 51
        # Verify we can still get more results (iterator not exhausted)
        next_result = next(result_iterator, None)
        assert next_result is not None
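The pagination tests above lean on unittest.mock's side_effect: when given a list, each call to the mock returns the next element, which is how successive paginated responses from get_job_result are simulated. A tiny demonstration:

from unittest.mock import Mock

fake = Mock(side_effect=[{"rows": [1, 2]}, {"rows": []}])
print(fake())  # {'rows': [1, 2]}
print(fake())  # {'rows': []}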
@@ -0,0 +1,74 @@
from unittest.mock import Mock

import pytest

from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig
from datahub.ingestion.source.dremio.dremio_entities import DremioCatalog, DremioDataset
from datahub.ingestion.source.dremio.dremio_source import DremioSource


class TestDremioIteratorIntegration:
    @pytest.fixture
    def mock_config(self):
        return DremioSourceConfig(
            hostname="test-host",
            port=9047,
            tls=False,
            username="test-user",
            password="test-password",
        )

    @pytest.fixture
    def mock_dremio_source(self, mock_config, monkeypatch):
        mock_session = Mock()
        monkeypatch.setattr("requests.Session", Mock(return_value=mock_session))
        mock_session.post.return_value.json.return_value = {"token": "dummy-token"}
        mock_session.post.return_value.status_code = 200

        mock_ctx = Mock()
        mock_ctx.run_id = "test-run-id"
        source = DremioSource(mock_config, mock_ctx)

        source.dremio_catalog = Mock(spec=DremioCatalog)
        source.dremio_catalog.dremio_api = Mock()

        return source

    def test_source_uses_iterators_by_default(self, mock_dremio_source):
        """Test that source uses iterators by default"""
        mock_dataset = Mock(spec=DremioDataset)
        mock_dataset.path = ["test", "path"]
        mock_dataset.resource_name = "test_table"

        mock_dremio_source.dremio_catalog.get_datasets.return_value = iter(
            [mock_dataset]
        )
        mock_dremio_source.dremio_catalog.get_containers.return_value = []
        mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([])
        mock_dremio_source.dremio_catalog.get_sources.return_value = []
        mock_dremio_source.config.source_mappings = []
        mock_dremio_source.process_dataset = Mock(return_value=iter([]))
        list(mock_dremio_source.get_workunits_internal())

        mock_dremio_source.dremio_catalog.get_datasets.assert_called_once()
        mock_dremio_source.dremio_catalog.get_glossary_terms.assert_called_once()

    def test_iterator_handles_exceptions_gracefully(self, mock_dremio_source):
        """Test that iterator handles exceptions without crashing"""
        mock_dataset = Mock(spec=DremioDataset)
        mock_dataset.path = ["test", "path"]
        mock_dataset.resource_name = "test_table"

        mock_dremio_source.dremio_catalog.get_datasets.return_value = iter(
            [mock_dataset]
        )
        mock_dremio_source.dremio_catalog.get_containers.return_value = []
        mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([])
        mock_dremio_source.dremio_catalog.get_sources.return_value = []
        mock_dremio_source.config.source_mappings = []
        mock_dremio_source.process_dataset = Mock(
            side_effect=Exception("Processing error")
        )
        list(mock_dremio_source.get_workunits_internal())

        assert mock_dremio_source.report.num_datasets_failed > 0
@@ -10,7 +10,7 @@ from datahub.ingestion.source.dremio.dremio_source import (


 def test_build_source_map_simple():
-    # write unit test
+    """Test basic source mapping functionality with simple configuration"""
     config_mapping: List[DremioSourceMapping] = [
         DremioSourceMapping(source_name="source1", platform="S3", env="PROD"),
         DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"),
@@ -58,7 +58,7 @@ def test_build_source_map_simple():


 def test_build_source_map_same_platform_multiple_sources():
-    # write unit test
+    """Test source mapping with multiple sources using the same platform and complex scenarios"""
     config_mapping: List[DremioSourceMapping] = [
         DremioSourceMapping(source_name="source1", platform="S3", env="PROD"),
         DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"),