diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py index be716377dd..fcd3cf62fe 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py @@ -7,7 +7,7 @@ from collections import defaultdict from enum import Enum from itertools import product from time import sleep, time -from typing import TYPE_CHECKING, Any, Deque, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union from urllib.parse import quote import requests @@ -343,14 +343,149 @@ class DremioAPIOperations: while True: result = self.get_job_result(job_id, offset, limit) - rows.extend(result["rows"]) - offset = offset + limit - if offset >= result["rowCount"]: + # Handle cases where API response doesn't contain 'rows' key + # This can happen with OOM errors or when no rows are returned + if "rows" not in result: + logger.warning( + f"API response for job {job_id} missing 'rows' key. " + f"Response keys: {list(result.keys())}" + ) + # Check for error conditions + if "errorMessage" in result: + raise DremioAPIException(f"Query error: {result['errorMessage']}") + elif "message" in result: + logger.warning( + f"Query warning for job {job_id}: {result['message']}" + ) + # Return empty list if no rows key and no error break + # Handle empty rows response + result_rows = result["rows"] + if not result_rows: + logger.debug( + f"No more rows returned for job {job_id} at offset {offset}" + ) + break + + rows.extend(result_rows) + + # Check actual returned rows to determine if we should continue + actual_rows_returned = len(result_rows) + if actual_rows_returned == 0: + logger.debug(f"Query returned no rows for job {job_id}") + break + + offset = offset + actual_rows_returned + # If we got fewer rows than requested, we've reached the end + if actual_rows_returned < limit: + break + + logger.info(f"Fetched {len(rows)} total rows for job {job_id}") return rows + def _fetch_results_iter(self, job_id: str) -> Iterator[Dict]: + """ + Fetch job results in a streaming fashion to reduce memory usage. + Yields individual rows instead of collecting all in memory. + """ + limit = 500 + offset = 0 + total_rows_fetched = 0 + + while True: + result = self.get_job_result(job_id, offset, limit) + + # Handle cases where API response doesn't contain 'rows' key + if "rows" not in result: + logger.warning( + f"API response for job {job_id} missing 'rows' key. " + f"Response keys: {list(result.keys())}" + ) + # Check for error conditions + if "errorMessage" in result: + raise DremioAPIException(f"Query error: {result['errorMessage']}") + elif "message" in result: + logger.warning( + f"Query warning for job {job_id}: {result['message']}" + ) + # Stop iteration if no rows key and no error + break + + # Handle empty rows response + result_rows = result["rows"] + if not result_rows: + logger.debug( + f"No more rows returned for job {job_id} at offset {offset}" + ) + break + + # Yield individual rows instead of collecting them + for row in result_rows: + yield row + total_rows_fetched += 1 + + # Check actual returned rows to determine if we should continue + actual_rows_returned = len(result_rows) + if actual_rows_returned == 0: + logger.debug(f"Query returned no rows for job {job_id}") + break + + offset = offset + actual_rows_returned + # If we got fewer rows than requested, we've reached the end + if actual_rows_returned < limit: + break + + logger.info(f"Streamed {total_rows_fetched} total rows for job {job_id}") + + def execute_query_iter( + self, query: str, timeout: int = 3600 + ) -> Iterator[Dict[str, Any]]: + """Execute SQL query and return results as a streaming iterator""" + try: + with PerfTimer() as timer: + logger.info(f"Executing streaming query: {query}") + response = self.post(url="/sql", data=json.dumps({"sql": query})) + + if "errorMessage" in response: + self.report.failure( + message="SQL Error", context=f"{response['errorMessage']}" + ) + raise DremioAPIException(f"SQL Error: {response['errorMessage']}") + + job_id = response["id"] + + # Wait for job completion + start_time = time() + while True: + status = self.get_job_status(job_id) + if status["jobState"] == "COMPLETED": + break + elif status["jobState"] == "FAILED": + error_message = status.get("errorMessage", "Unknown error") + raise RuntimeError(f"Query failed: {error_message}") + elif status["jobState"] == "CANCELED": + raise RuntimeError("Query was canceled") + + if time() - start_time > timeout: + self.cancel_query(job_id) + raise DremioAPIException( + f"Query execution timed out after {timeout} seconds" + ) + + sleep(3) + + logger.info( + f"Query job completed in {timer.elapsed_seconds()} seconds, starting streaming" + ) + + # Return streaming iterator + return self._fetch_results_iter(job_id) + + except requests.RequestException as e: + raise DremioAPIException("Error executing streaming query") from e + def cancel_query(self, job_id: str) -> None: """Cancel a running query""" try: @@ -499,8 +634,12 @@ class DremioAPIOperations: return f"AND {operator}({field}, '{pattern_str}')" def get_all_tables_and_columns( - self, containers: Deque["DremioContainer"] - ) -> List[Dict]: + self, containers: Iterator["DremioContainer"] + ) -> Iterator[Dict]: + """ + Memory-efficient streaming version that yields tables one at a time. + Reduces memory usage for large datasets by processing results as they come. + """ if self.edition == DremioEdition.ENTERPRISE: query_template = DremioSQLQueries.QUERY_DATASETS_EE elif self.edition == DremioEdition.CLOUD: @@ -517,21 +656,78 @@ class DremioAPIOperations: self.deny_schema_pattern, schema_field, allow=False ) - all_tables_and_columns = [] - + # Process each container's results separately to avoid memory buildup for schema in containers: - formatted_query = "" try: formatted_query = query_template.format( schema_pattern=schema_condition, deny_schema_pattern=deny_schema_condition, container_name=schema.container_name.lower(), ) - all_tables_and_columns.extend( - self.execute_query( - query=formatted_query, + + # Use streaming query execution + container_results = list(self.execute_query_iter(query=formatted_query)) + + if self.edition == DremioEdition.COMMUNITY: + # Process community edition results + formatted_tables = self.community_get_formatted_tables( + container_results ) - ) + for table in formatted_tables: + yield table + else: + # Process enterprise/cloud edition results + column_dictionary: Dict[str, List[Dict]] = defaultdict(list) + table_metadata: Dict[str, Dict] = {} + + for record in container_results: + if not record.get("COLUMN_NAME"): + continue + + table_full_path = record.get("FULL_TABLE_PATH") + if not table_full_path: + continue + + # Store column information + column_dictionary[table_full_path].append( + { + "name": record["COLUMN_NAME"], + "ordinal_position": record["ORDINAL_POSITION"], + "is_nullable": record["IS_NULLABLE"], + "data_type": record["DATA_TYPE"], + "column_size": record["COLUMN_SIZE"], + } + ) + + # Store table metadata (only once per table) + if table_full_path not in table_metadata: + table_metadata[table_full_path] = { + "TABLE_NAME": record.get("TABLE_NAME"), + "TABLE_SCHEMA": record.get("TABLE_SCHEMA"), + "VIEW_DEFINITION": record.get("VIEW_DEFINITION"), + "RESOURCE_ID": record.get("RESOURCE_ID"), + "LOCATION_ID": record.get("LOCATION_ID"), + "OWNER": record.get("OWNER"), + "OWNER_TYPE": record.get("OWNER_TYPE"), + "CREATED": record.get("CREATED"), + "FORMAT_TYPE": record.get("FORMAT_TYPE"), + } + + # Yield tables one at a time + for table_path, table_info in table_metadata.items(): + yield { + "TABLE_NAME": table_info.get("TABLE_NAME"), + "TABLE_SCHEMA": table_info.get("TABLE_SCHEMA"), + "COLUMNS": column_dictionary[table_path], + "VIEW_DEFINITION": table_info.get("VIEW_DEFINITION"), + "RESOURCE_ID": table_info.get("RESOURCE_ID"), + "LOCATION_ID": table_info.get("LOCATION_ID"), + "OWNER": table_info.get("OWNER"), + "OWNER_TYPE": table_info.get("OWNER_TYPE"), + "CREATED": table_info.get("CREATED"), + "FORMAT_TYPE": table_info.get("FORMAT_TYPE"), + } + except DremioAPIException as e: self.report.warning( message="Container has no tables or views", @@ -539,71 +735,6 @@ class DremioAPIOperations: exc=e, ) - tables = [] - - if self.edition == DremioEdition.COMMUNITY: - tables = self.community_get_formatted_tables(all_tables_and_columns) - - else: - column_dictionary: Dict[str, List[Dict]] = defaultdict(list) - - for record in all_tables_and_columns: - if not record.get("COLUMN_NAME"): - continue - - table_full_path = record.get("FULL_TABLE_PATH") - if not table_full_path: - continue - - column_dictionary[table_full_path].append( - { - "name": record["COLUMN_NAME"], - "ordinal_position": record["ORDINAL_POSITION"], - "is_nullable": record["IS_NULLABLE"], - "data_type": record["DATA_TYPE"], - "column_size": record["COLUMN_SIZE"], - } - ) - - distinct_tables_list = list( - { - tuple( - dictionary[key] - for key in ( - "TABLE_SCHEMA", - "TABLE_NAME", - "FULL_TABLE_PATH", - "VIEW_DEFINITION", - "LOCATION_ID", - "OWNER", - "OWNER_TYPE", - "CREATED", - "FORMAT_TYPE", - ) - if key in dictionary - ): dictionary - for dictionary in all_tables_and_columns - }.values() - ) - - for table in distinct_tables_list: - tables.append( - { - "TABLE_NAME": table.get("TABLE_NAME"), - "TABLE_SCHEMA": table.get("TABLE_SCHEMA"), - "COLUMNS": column_dictionary[table["FULL_TABLE_PATH"]], - "VIEW_DEFINITION": table.get("VIEW_DEFINITION"), - "RESOURCE_ID": table.get("RESOURCE_ID"), - "LOCATION_ID": table.get("LOCATION_ID"), - "OWNER": table.get("OWNER"), - "OWNER_TYPE": table.get("OWNER_TYPE"), - "CREATED": table.get("CREATED"), - "FORMAT_TYPE": table.get("FORMAT_TYPE"), - } - ) - - return tables - def validate_schema_format(self, schema): if "." in schema: schema_path = self.get( @@ -640,7 +771,10 @@ class DremioAPIOperations: return parents_list - def extract_all_queries(self) -> List[Dict[str, Any]]: + def extract_all_queries(self) -> Iterator[Dict[str, Any]]: + """ + Memory-efficient streaming version for extracting query results. + """ # Convert datetime objects to string format for SQL queries start_timestamp_str = None end_timestamp_str = None @@ -661,7 +795,7 @@ class DremioAPIOperations: end_timestamp_millis=end_timestamp_str, ) - return self.execute_query(query=jobs_query) + return self.execute_query_iter(query=jobs_query) def get_tags_for_resource(self, resource_id: str) -> Optional[List[str]]: """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py index 0ed2ad062e..b58eccdc9c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py @@ -1,4 +1,3 @@ -import itertools import logging import re import uuid @@ -6,7 +5,7 @@ from collections import deque from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Deque, Dict, List, Optional +from typing import Any, Deque, Dict, Iterator, List, Optional from sqlglot import parse_one @@ -184,6 +183,7 @@ class DremioQuery: return "" def get_raw_query(self, sql_query: str) -> str: + """Remove comments from SQL query using sqlglot parser.""" try: parsed = parse_one(sql_query) return parsed.sql(comments=False) @@ -336,43 +336,26 @@ class DremioCatalog: def __init__(self, dremio_api: DremioAPIOperations): self.dremio_api = dremio_api self.edition = dremio_api.edition - self.datasets: Deque[DremioDataset] = deque() self.sources: Deque[DremioSourceContainer] = deque() self.spaces: Deque[DremioSpace] = deque() self.folders: Deque[DremioFolder] = deque() - self.glossary_terms: Deque[DremioGlossaryTerm] = deque() self.queries: Deque[DremioQuery] = deque() - self.datasets_populated = False self.containers_populated = False self.queries_populated = False - def set_datasets(self) -> None: - if not self.datasets_populated: - self.set_containers() + def get_datasets(self) -> Iterator[DremioDataset]: + """Get all Dremio datasets (tables and views) as an iterator.""" + # Get containers directly without storing them + containers = self.get_containers() - containers: Deque[DremioContainer] = deque() - containers.extend(self.spaces) # Add DremioSpace elements - containers.extend(self.sources) # Add DremioSource elements + for dataset_details in self.dremio_api.get_all_tables_and_columns(containers): + dremio_dataset = DremioDataset( + dataset_details=dataset_details, + api_operations=self.dremio_api, + ) - for dataset_details in self.dremio_api.get_all_tables_and_columns( - containers=containers - ): - dremio_dataset = DremioDataset( - dataset_details=dataset_details, - api_operations=self.dremio_api, - ) - self.datasets.append(dremio_dataset) - - for glossary_term in dremio_dataset.glossary_terms: - if glossary_term not in self.glossary_terms: - self.glossary_terms.append(glossary_term) - - self.datasets_populated = True - - def get_datasets(self) -> Deque[DremioDataset]: - self.set_datasets() - return self.datasets + yield dremio_dataset def set_containers(self) -> None: if not self.containers_populated: @@ -423,18 +406,50 @@ class DremioCatalog: self.containers_populated = True - def get_containers(self) -> Deque: - self.set_containers() - return deque(itertools.chain(self.sources, self.spaces, self.folders)) + def get_containers(self) -> Iterator[DremioContainer]: + """Get all containers (sources, spaces, folders) as an iterator.""" + for container in self.dremio_api.get_all_containers(): + container_type = container.get("container_type") + if container_type == DremioEntityContainerType.SOURCE: + yield DremioSourceContainer( + container_name=container.get("name"), + location_id=container.get("id"), + path=[], + api_operations=self.dremio_api, + dremio_source_type=container.get("source_type") or "", + root_path=container.get("root_path"), + database_name=container.get("database_name"), + ) + elif container_type == DremioEntityContainerType.SPACE: + yield DremioSpace( + container_name=container.get("name"), + location_id=container.get("id"), + path=[], + api_operations=self.dremio_api, + ) + elif container_type == DremioEntityContainerType.FOLDER: + yield DremioFolder( + container_name=container.get("name"), + location_id=container.get("id"), + path=container.get("path"), + api_operations=self.dremio_api, + ) - def get_sources(self) -> Deque[DremioSourceContainer]: - self.set_containers() - return self.sources + def get_sources(self) -> Iterator[DremioSourceContainer]: + """Get all Dremio source containers (external data connections) as an iterator.""" + for container in self.get_containers(): + if isinstance(container, DremioSourceContainer): + yield container - def get_glossary_terms(self) -> Deque[DremioGlossaryTerm]: - self.set_datasets() - self.set_containers() - return self.glossary_terms + def get_glossary_terms(self) -> Iterator[DremioGlossaryTerm]: + """Get all unique glossary terms (tags) from datasets.""" + glossary_terms_seen = set() + + for dataset in self.get_datasets(): + for glossary_term in dataset.glossary_terms: + if glossary_term not in glossary_terms_seen: + glossary_terms_seen.add(glossary_term) + yield glossary_term def is_valid_query(self, query: Dict[str, Any]) -> bool: required_fields = [ @@ -447,6 +462,7 @@ class DremioCatalog: return all(query.get(field) for field in required_fields) def get_queries(self) -> Deque[DremioQuery]: + """Get all valid Dremio queries for lineage analysis.""" for query in self.dremio_api.extract_all_queries(): if not self.is_valid_query(query): continue diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py index 5332597ffc..3a357a6aec 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py @@ -17,6 +17,7 @@ from datahub.metadata.schema_classes import ( DatasetProfileClass, QuantileClass, ) +from datahub.utilities.perf_timer import PerfTimer logger = logging.getLogger(__name__) @@ -64,8 +65,13 @@ class DremioProfiler: ) return - profile_data = self.profile_table(full_table_name, columns) - profile_aspect = self.populate_profile_aspect(profile_data) + with PerfTimer() as timer: + profile_data = self.profile_table(full_table_name, columns) + profile_aspect = self.populate_profile_aspect(profile_data) + + logger.info( + f"Profiled table {full_table_name} with {len(columns)} columns in {timer.elapsed_seconds():.2f} seconds" + ) if profile_aspect: self.report.report_entity_profiled(dataset.resource_name) @@ -131,7 +137,12 @@ class DremioProfiler: def _profile_chunk(self, table_name: str, columns: List[Tuple[str, str]]) -> Dict: profile_sql = self._build_profile_sql(table_name, columns) try: - results = self.api_operations.execute_query(profile_sql) + with PerfTimer() as timer: + results = self.api_operations.execute_query(profile_sql) + + logger.debug( + f"Profiling query for {table_name} ({len(columns)} columns) completed in {timer.elapsed_seconds():.2f} seconds" + ) return self._parse_profile_results(results, columns) except DremioAPIException as e: raise e diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py index 38555782e3..50bf918f3c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py @@ -55,7 +55,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import ( from datahub.ingestion.source_report.ingestion_stage import ( LINEAGE_EXTRACTION, METADATA_EXTRACTION, - IngestionHighStage, + PROFILING, ) from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( DatasetLineageTypeClass, @@ -201,7 +201,7 @@ class DremioSource(StatefulIngestionSourceBase): return "dremio" def _build_source_map(self) -> Dict[str, DremioSourceMapEntry]: - dremio_sources = self.dremio_catalog.get_sources() + dremio_sources = list(self.dremio_catalog.get_sources()) source_mappings_config = self.config.source_mappings or [] source_map = build_dremio_source_map(dremio_sources, source_mappings_config) @@ -242,9 +242,7 @@ class DremioSource(StatefulIngestionSourceBase): ) # Process Datasets - datasets = self.dremio_catalog.get_datasets() - - for dataset_info in datasets: + for dataset_info in self.dremio_catalog.get_datasets(): try: yield from self.process_dataset(dataset_info) logger.info( @@ -258,10 +256,8 @@ class DremioSource(StatefulIngestionSourceBase): exc=exc, ) - # Process Glossary Terms - glossary_terms = self.dremio_catalog.get_glossary_terms() - - for glossary_term in glossary_terms: + # Process Glossary Terms using streaming + for glossary_term in self.dremio_catalog.get_glossary_terms(): try: yield from self.process_glossary_term(glossary_term) except Exception as exc: @@ -283,14 +279,16 @@ class DremioSource(StatefulIngestionSourceBase): # Profiling if self.config.is_profiling_enabled(): with ( - self.report.new_high_stage(IngestionHighStage.PROFILING), + self.report.new_stage(PROFILING), ThreadPoolExecutor( max_workers=self.config.profiling.max_workers ) as executor, ): + # Collect datasets for profiling + datasets_for_profiling = list(self.dremio_catalog.get_datasets()) future_to_dataset = { executor.submit(self.generate_profiles, dataset): dataset - for dataset in datasets + for dataset in datasets_for_profiling } for future in as_completed(future_to_dataset): diff --git a/metadata-ingestion/tests/unit/dremio/test_dremio_api_pagination.py b/metadata-ingestion/tests/unit/dremio/test_dremio_api_pagination.py new file mode 100644 index 0000000000..b1535fd8be --- /dev/null +++ b/metadata-ingestion/tests/unit/dremio/test_dremio_api_pagination.py @@ -0,0 +1,281 @@ +from collections import deque +from unittest.mock import Mock, patch + +import pytest + +from datahub.ingestion.source.dremio.dremio_api import ( + DremioAPIException, + DremioAPIOperations, + DremioEdition, +) +from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig +from datahub.ingestion.source.dremio.dremio_reporting import DremioSourceReport + + +class TestDremioAPIPagination: + @pytest.fixture + def dremio_api(self, monkeypatch): + """Setup mock Dremio API for testing""" + # Mock the requests.Session + mock_session = Mock() + monkeypatch.setattr("requests.Session", Mock(return_value=mock_session)) + + # Mock the authentication response + mock_session.post.return_value.json.return_value = {"token": "dummy-token"} + mock_session.post.return_value.status_code = 200 + + config = DremioSourceConfig( + hostname="dummy-host", + port=9047, + tls=False, + authentication_method="password", + username="dummy-user", + password="dummy-password", + ) + report = DremioSourceReport() + api = DremioAPIOperations(config, report) + api.session = mock_session + return api + + def test_fetch_all_results_missing_rows_key(self, dremio_api): + """Test handling of API response missing 'rows' key""" + # Mock get_job_result to return response without 'rows' key + dremio_api.get_job_result = Mock(return_value={"message": "No data available"}) + + result = dremio_api._fetch_all_results("test-job-id") + + # Should return empty list when no rows key is present + assert result == [] + dremio_api.get_job_result.assert_called_once() + + def test_fetch_all_results_with_error_message(self, dremio_api): + """Test handling of API response with errorMessage""" + # Mock get_job_result to return error response + dremio_api.get_job_result = Mock(return_value={"errorMessage": "Out of memory"}) + + with pytest.raises(DremioAPIException, match="Query error: Out of memory"): + dremio_api._fetch_all_results("test-job-id") + + def test_fetch_all_results_empty_rows(self, dremio_api): + """Test handling of empty rows response""" + # Mock get_job_result to return empty rows + dremio_api.get_job_result = Mock(return_value={"rows": [], "rowCount": 0}) + + result = dremio_api._fetch_all_results("test-job-id") + + assert result == [] + dremio_api.get_job_result.assert_called_once() + + def test_fetch_all_results_normal_case(self, dremio_api): + """Test normal operation with valid rows""" + # Mock get_job_result to return valid data + # First response: 2 rows, rowCount=2 (offset will be 2, which equals rowCount, so stops) + mock_responses = [ + {"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2}, + ] + dremio_api.get_job_result = Mock(side_effect=mock_responses) + + result = dremio_api._fetch_all_results("test-job-id") + + expected = [{"col1": "val1"}, {"col1": "val2"}] + assert result == expected + assert dremio_api.get_job_result.call_count == 1 + + def test_fetch_results_iter_missing_rows_key(self, dremio_api): + """Test internal streaming method handling missing 'rows' key""" + # Mock get_job_result to return response without 'rows' key + dremio_api.get_job_result = Mock(return_value={"message": "No data available"}) + + result_iterator = dremio_api._fetch_results_iter("test-job-id") + results = list(result_iterator) + + # Should return empty iterator when no rows key is present + assert results == [] + + def test_fetch_results_iter_with_error(self, dremio_api): + """Test internal streaming method handling error response""" + # Mock get_job_result to return error response + dremio_api.get_job_result = Mock(return_value={"errorMessage": "Query timeout"}) + + result_iterator = dremio_api._fetch_results_iter("test-job-id") + + with pytest.raises(DremioAPIException, match="Query error: Query timeout"): + list(result_iterator) + + def test_fetch_results_iter_normal_case(self, dremio_api): + """Test internal streaming method with valid data""" + # Mock get_job_result to return valid data in batches + mock_responses = [ + {"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2}, + {"rows": [], "rowCount": 2}, # Empty response to end iteration + ] + dremio_api.get_job_result = Mock(side_effect=mock_responses) + + result_iterator = dremio_api._fetch_results_iter("test-job-id") + results = list(result_iterator) + + expected = [ + {"col1": "val1"}, + {"col1": "val2"}, + ] + assert results == expected + + def test_execute_query_iter_success(self, dremio_api): + """Test execute_query_iter with successful job completion""" + # Mock the POST response for job submission + dremio_api.post = Mock(return_value={"id": "job-123"}) + + # Mock job status progression + status_responses = [ + {"jobState": "RUNNING"}, + {"jobState": "RUNNING"}, + {"jobState": "COMPLETED"}, + ] + dremio_api.get_job_status = Mock(side_effect=status_responses) + + # Mock streaming results + dremio_api._fetch_results_iter = Mock(return_value=iter([{"col1": "val1"}])) + + with patch("time.sleep"): # Skip actual sleep delays + result_iterator = dremio_api.execute_query_iter("SELECT * FROM test") + results = list(result_iterator) + + assert results == [{"col1": "val1"}] + dremio_api.post.assert_called_once() + assert dremio_api.get_job_status.call_count == 3 + + def test_execute_query_iter_job_failure(self, dremio_api): + """Test execute_query_iter with job failure""" + # Mock the POST response for job submission + dremio_api.post = Mock(return_value={"id": "job-123"}) + + # Mock job failure + dremio_api.get_job_status = Mock( + return_value={"jobState": "FAILED", "errorMessage": "SQL syntax error"} + ) + + with ( + pytest.raises(RuntimeError, match="Query failed: SQL syntax error"), + patch("time.sleep"), + ): + result_iterator = dremio_api.execute_query_iter("SELECT * FROM test") + list(result_iterator) # Force evaluation + + def test_execute_query_iter_timeout(self, dremio_api): + """Test execute_query_iter with timeout""" + # Mock the POST response for job submission + dremio_api.post = Mock(return_value={"id": "job-123"}) + + # Mock job that never completes + dremio_api.get_job_status = Mock(return_value={"jobState": "RUNNING"}) + dremio_api.cancel_query = Mock() + + # Mock time.time to simulate timeout - need to patch where it's imported + # First call: start_time = 0 + # Second call: time() - start_time check = 3700 (triggers timeout) + mock_time_values = [0, 3700] + + with ( + patch( + "datahub.ingestion.source.dremio.dremio_api.time", + side_effect=mock_time_values, + ), + patch("datahub.ingestion.source.dremio.dremio_api.sleep"), + pytest.raises( + DremioAPIException, + match="Query execution timed out after 3600 seconds", + ), + ): + result_iterator = dremio_api.execute_query_iter("SELECT * FROM test") + list(result_iterator) + + dremio_api.cancel_query.assert_called_once_with("job-123") + + def test_get_all_tables_and_columns(self, dremio_api): + """Test streaming version of get_all_tables_and_columns""" + from datahub.ingestion.source.dremio.dremio_api import DremioEdition + + # Set up test data + dremio_api.edition = DremioEdition.ENTERPRISE + dremio_api.allow_schema_pattern = [".*"] + dremio_api.deny_schema_pattern = [] + + # Mock container + mock_container = Mock() + mock_container.container_name = "test_source" + containers = deque([mock_container]) + + # Mock streaming query results + mock_results = [ + { + "COLUMN_NAME": "col1", + "FULL_TABLE_PATH": "test.table1", + "TABLE_NAME": "table1", + "TABLE_SCHEMA": "test", + "ORDINAL_POSITION": 1, + "IS_NULLABLE": "YES", + "DATA_TYPE": "VARCHAR", + "COLUMN_SIZE": 255, + "RESOURCE_ID": "res1", + } + ] + dremio_api.execute_query_iter = Mock(return_value=iter(mock_results)) + + # Test streaming method + result_iterator = dremio_api.get_all_tables_and_columns(containers) + tables = list(result_iterator) + + assert len(tables) == 1 + table = tables[0] + assert table["TABLE_NAME"] == "table1" + assert table["TABLE_SCHEMA"] == "test" + assert len(table["COLUMNS"]) == 1 + assert table["COLUMNS"][0]["name"] == "col1" + + def test_extract_all_queries(self, dremio_api): + """Test streaming version of extract_all_queries""" + + dremio_api.edition = DremioEdition.ENTERPRISE + dremio_api.start_time = None + dremio_api.end_time = None + + # Mock streaming query execution + mock_query_results = [ + {"query_id": "q1", "sql": "SELECT * FROM table1"}, + {"query_id": "q2", "sql": "SELECT * FROM table2"}, + ] + dremio_api.execute_query_iter = Mock(return_value=iter(mock_query_results)) + + result_iterator = dremio_api.extract_all_queries() + queries = list(result_iterator) + + assert len(queries) == 2 + assert queries[0]["query_id"] == "q1" + assert queries[1]["query_id"] == "q2" + dremio_api.execute_query_iter.assert_called_once() + + def test_fetch_results_iter_incremental_yielding(self, dremio_api): + """Test that internal iterator yields results incrementally and can be partially consumed""" + # Verify that streaming yields results one at a time and iterator state is maintained + + mock_responses = [ + {"rows": [{"id": i} for i in range(100)], "rowCount": 1000}, + {"rows": [{"id": i} for i in range(100, 200)], "rowCount": 1000}, + {"rows": [], "rowCount": 1000}, # End iteration + ] + dremio_api.get_job_result = Mock(side_effect=mock_responses) + + # Test that streaming version yields results incrementally + result_iterator = dremio_api._fetch_results_iter("test-job") + + # Get first few results + first_batch = [] + for i, result in enumerate(result_iterator): + first_batch.append(result) + if i >= 50: # Stop after getting 51 results + break + + assert len(first_batch) == 51 + # Verify we can still get more results (iterator not exhausted) + next_result = next(result_iterator, None) + assert next_result is not None diff --git a/metadata-ingestion/tests/unit/dremio/test_dremio_iterator_integration.py b/metadata-ingestion/tests/unit/dremio/test_dremio_iterator_integration.py new file mode 100644 index 0000000000..352eaabb40 --- /dev/null +++ b/metadata-ingestion/tests/unit/dremio/test_dremio_iterator_integration.py @@ -0,0 +1,74 @@ +from unittest.mock import Mock + +import pytest + +from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig +from datahub.ingestion.source.dremio.dremio_entities import DremioCatalog, DremioDataset +from datahub.ingestion.source.dremio.dremio_source import DremioSource + + +class TestDremioIteratorIntegration: + @pytest.fixture + def mock_config(self): + return DremioSourceConfig( + hostname="test-host", + port=9047, + tls=False, + username="test-user", + password="test-password", + ) + + @pytest.fixture + def mock_dremio_source(self, mock_config, monkeypatch): + mock_session = Mock() + monkeypatch.setattr("requests.Session", Mock(return_value=mock_session)) + mock_session.post.return_value.json.return_value = {"token": "dummy-token"} + mock_session.post.return_value.status_code = 200 + + mock_ctx = Mock() + mock_ctx.run_id = "test-run-id" + source = DremioSource(mock_config, mock_ctx) + + source.dremio_catalog = Mock(spec=DremioCatalog) + source.dremio_catalog.dremio_api = Mock() + + return source + + def test_source_uses_iterators_by_default(self, mock_dremio_source): + """Test that source uses iterators by default""" + mock_dataset = Mock(spec=DremioDataset) + mock_dataset.path = ["test", "path"] + mock_dataset.resource_name = "test_table" + + mock_dremio_source.dremio_catalog.get_datasets.return_value = iter( + [mock_dataset] + ) + mock_dremio_source.dremio_catalog.get_containers.return_value = [] + mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([]) + mock_dremio_source.dremio_catalog.get_sources.return_value = [] + mock_dremio_source.config.source_mappings = [] + mock_dremio_source.process_dataset = Mock(return_value=iter([])) + list(mock_dremio_source.get_workunits_internal()) + + mock_dremio_source.dremio_catalog.get_datasets.assert_called_once() + mock_dremio_source.dremio_catalog.get_glossary_terms.assert_called_once() + + def test_iterator_handles_exceptions_gracefully(self, mock_dremio_source): + """Test that iterator handles exceptions without crashing""" + mock_dataset = Mock(spec=DremioDataset) + mock_dataset.path = ["test", "path"] + mock_dataset.resource_name = "test_table" + + mock_dremio_source.dremio_catalog.get_datasets.return_value = iter( + [mock_dataset] + ) + mock_dremio_source.dremio_catalog.get_containers.return_value = [] + mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([]) + mock_dremio_source.dremio_catalog.get_sources.return_value = [] + mock_dremio_source.config.source_mappings = [] + mock_dremio_source.process_dataset = Mock( + side_effect=Exception("Processing error") + ) + list(mock_dremio_source.get_workunits_internal()) + + assert mock_dremio_source.report.num_datasets_failed > 0 diff --git a/metadata-ingestion/tests/unit/dremio/test_dremio_source_map.py b/metadata-ingestion/tests/unit/dremio/test_dremio_source_map.py index 8514c0c8ef..1a1e7d96ee 100644 --- a/metadata-ingestion/tests/unit/dremio/test_dremio_source_map.py +++ b/metadata-ingestion/tests/unit/dremio/test_dremio_source_map.py @@ -10,7 +10,7 @@ from datahub.ingestion.source.dremio.dremio_source import ( def test_build_source_map_simple(): - # write unit test + """Test basic source mapping functionality with simple configuration""" config_mapping: List[DremioSourceMapping] = [ DremioSourceMapping(source_name="source1", platform="S3", env="PROD"), DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"), @@ -58,7 +58,7 @@ def test_build_source_map_simple(): def test_build_source_map_same_platform_multiple_sources(): - # write unit test + """Test source mapping with multiple sources using the same platform and complex scenarios""" config_mapping: List[DremioSourceMapping] = [ DremioSourceMapping(source_name="source1", platform="S3", env="PROD"), DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"),