Mirror of https://github.com/datahub-project/datahub.git (synced 2025-10-31 10:49:00 +00:00)
	fix(ingestion/dremio): handle dremio oom errors when ingesting large amount of metadata (#14883)
This commit is contained in:
parent 6af5182e9b
commit f05f3e40f2
| @@ -7,7 +7,7 @@ from collections import defaultdict | ||||
| from enum import Enum | ||||
| from itertools import product | ||||
| from time import sleep, time | ||||
| from typing import TYPE_CHECKING, Any, Deque, Dict, List, Optional, Union | ||||
| from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union | ||||
| from urllib.parse import quote | ||||
| 
 | ||||
| import requests | ||||
| @@ -343,14 +343,149 @@ class DremioAPIOperations: | ||||
| 
 | ||||
|         while True: | ||||
|             result = self.get_job_result(job_id, offset, limit) | ||||
|             rows.extend(result["rows"]) | ||||
| 
 | ||||
|             offset = offset + limit | ||||
|             if offset >= result["rowCount"]: | ||||
|             # Handle cases where API response doesn't contain 'rows' key | ||||
|             # This can happen with OOM errors or when no rows are returned | ||||
|             if "rows" not in result: | ||||
|                 logger.warning( | ||||
|                     f"API response for job {job_id} missing 'rows' key. " | ||||
|                     f"Response keys: {list(result.keys())}" | ||||
|                 ) | ||||
|                 # Check for error conditions | ||||
|                 if "errorMessage" in result: | ||||
|                     raise DremioAPIException(f"Query error: {result['errorMessage']}") | ||||
|                 elif "message" in result: | ||||
|                     logger.warning( | ||||
|                         f"Query warning for job {job_id}: {result['message']}" | ||||
|                     ) | ||||
|                 # Stop paging; return whatever rows were accumulated so far | ||||
|                 break | ||||
| 
 | ||||
|             # Handle empty rows response | ||||
|             result_rows = result["rows"] | ||||
|             if not result_rows: | ||||
|                 logger.debug( | ||||
|                     f"No more rows returned for job {job_id} at offset {offset}" | ||||
|                 ) | ||||
|                 break | ||||
| 
 | ||||
|             rows.extend(result_rows) | ||||
| 
 | ||||
|             # Check actual returned rows to determine if we should continue | ||||
|             actual_rows_returned = len(result_rows) | ||||
|             if actual_rows_returned == 0: | ||||
|                 logger.debug(f"Query returned no rows for job {job_id}") | ||||
|                 break | ||||
| 
 | ||||
|             offset = offset + actual_rows_returned | ||||
|             # If we got fewer rows than requested, we've reached the end | ||||
|             if actual_rows_returned < limit: | ||||
|                 break | ||||
| 
 | ||||
|         logger.info(f"Fetched {len(rows)} total rows for job {job_id}") | ||||
|         return rows | ||||
| 
 | ||||
|     def _fetch_results_iter(self, job_id: str) -> Iterator[Dict]: | ||||
|         """ | ||||
|         Fetch job results in a streaming fashion to reduce memory usage. | ||||
|         Yields individual rows instead of collecting all in memory. | ||||
|         """ | ||||
|         limit = 500 | ||||
|         offset = 0 | ||||
|         total_rows_fetched = 0 | ||||
| 
 | ||||
|         while True: | ||||
|             result = self.get_job_result(job_id, offset, limit) | ||||
| 
 | ||||
|             # Handle cases where API response doesn't contain 'rows' key | ||||
|             if "rows" not in result: | ||||
|                 logger.warning( | ||||
|                     f"API response for job {job_id} missing 'rows' key. " | ||||
|                     f"Response keys: {list(result.keys())}" | ||||
|                 ) | ||||
|                 # Check for error conditions | ||||
|                 if "errorMessage" in result: | ||||
|                     raise DremioAPIException(f"Query error: {result['errorMessage']}") | ||||
|                 elif "message" in result: | ||||
|                     logger.warning( | ||||
|                         f"Query warning for job {job_id}: {result['message']}" | ||||
|                     ) | ||||
|                 # Stop iteration if no rows key and no error | ||||
|                 break | ||||
| 
 | ||||
|             # Handle empty rows response | ||||
|             result_rows = result["rows"] | ||||
|             if not result_rows: | ||||
|                 logger.debug( | ||||
|                     f"No more rows returned for job {job_id} at offset {offset}" | ||||
|                 ) | ||||
|                 break | ||||
| 
 | ||||
|             # Yield individual rows instead of collecting them | ||||
|             for row in result_rows: | ||||
|                 yield row | ||||
|                 total_rows_fetched += 1 | ||||
| 
 | ||||
|             # Check actual returned rows to determine if we should continue | ||||
|             actual_rows_returned = len(result_rows) | ||||
|             if actual_rows_returned == 0: | ||||
|                 logger.debug(f"Query returned no rows for job {job_id}") | ||||
|                 break | ||||
| 
 | ||||
|             offset = offset + actual_rows_returned | ||||
|             # If we got fewer rows than requested, we've reached the end | ||||
|             if actual_rows_returned < limit: | ||||
|                 break | ||||
| 
 | ||||
|         logger.info(f"Streamed {total_rows_fetched} total rows for job {job_id}") | ||||
| 
 | ||||
|     def execute_query_iter( | ||||
|         self, query: str, timeout: int = 3600 | ||||
|     ) -> Iterator[Dict[str, Any]]: | ||||
|         """Execute SQL query and return results as a streaming iterator""" | ||||
|         try: | ||||
|             with PerfTimer() as timer: | ||||
|                 logger.info(f"Executing streaming query: {query}") | ||||
|                 response = self.post(url="/sql", data=json.dumps({"sql": query})) | ||||
| 
 | ||||
|                 if "errorMessage" in response: | ||||
|                     self.report.failure( | ||||
|                         message="SQL Error", context=f"{response['errorMessage']}" | ||||
|                     ) | ||||
|                     raise DremioAPIException(f"SQL Error: {response['errorMessage']}") | ||||
| 
 | ||||
|                 job_id = response["id"] | ||||
| 
 | ||||
|                 # Wait for job completion | ||||
|                 start_time = time() | ||||
|                 while True: | ||||
|                     status = self.get_job_status(job_id) | ||||
|                     if status["jobState"] == "COMPLETED": | ||||
|                         break | ||||
|                     elif status["jobState"] == "FAILED": | ||||
|                         error_message = status.get("errorMessage", "Unknown error") | ||||
|                         raise RuntimeError(f"Query failed: {error_message}") | ||||
|                     elif status["jobState"] == "CANCELED": | ||||
|                         raise RuntimeError("Query was canceled") | ||||
| 
 | ||||
|                     if time() - start_time > timeout: | ||||
|                         self.cancel_query(job_id) | ||||
|                         raise DremioAPIException( | ||||
|                             f"Query execution timed out after {timeout} seconds" | ||||
|                         ) | ||||
| 
 | ||||
|                     sleep(3) | ||||
| 
 | ||||
|                 logger.info( | ||||
|                     f"Query job completed in {timer.elapsed_seconds()} seconds, starting streaming" | ||||
|                 ) | ||||
| 
 | ||||
|                 # Return streaming iterator | ||||
|                 return self._fetch_results_iter(job_id) | ||||
| 
 | ||||
|         except requests.RequestException as e: | ||||
|             raise DremioAPIException("Error executing streaming query") from e | ||||
| 
 | ||||
|     def cancel_query(self, job_id: str) -> None: | ||||
|         """Cancel a running query""" | ||||
|         try: | ||||
| @@ -499,8 +634,12 @@ class DremioAPIOperations: | ||||
|         return f"AND {operator}({field}, '{pattern_str}')" | ||||
| 
 | ||||
|     def get_all_tables_and_columns( | ||||
|         self, containers: Deque["DremioContainer"] | ||||
|     ) -> List[Dict]: | ||||
|         self, containers: Iterator["DremioContainer"] | ||||
|     ) -> Iterator[Dict]: | ||||
|         """ | ||||
|         Memory-efficient streaming version that yields tables one at a time. | ||||
|         Reduces memory usage for large datasets by processing results as they come. | ||||
|         """ | ||||
|         if self.edition == DremioEdition.ENTERPRISE: | ||||
|             query_template = DremioSQLQueries.QUERY_DATASETS_EE | ||||
|         elif self.edition == DremioEdition.CLOUD: | ||||
| @@ -517,21 +656,78 @@ class DremioAPIOperations: | ||||
|             self.deny_schema_pattern, schema_field, allow=False | ||||
|         ) | ||||
| 
 | ||||
|         all_tables_and_columns = [] | ||||
| 
 | ||||
|         # Process each container's results separately to avoid memory buildup | ||||
|         for schema in containers: | ||||
|             formatted_query = "" | ||||
|             try: | ||||
|                 formatted_query = query_template.format( | ||||
|                     schema_pattern=schema_condition, | ||||
|                     deny_schema_pattern=deny_schema_condition, | ||||
|                     container_name=schema.container_name.lower(), | ||||
|                 ) | ||||
|                 all_tables_and_columns.extend( | ||||
|                     self.execute_query( | ||||
|                         query=formatted_query, | ||||
| 
 | ||||
|                 # Use streaming query execution | ||||
|                 container_results = list(self.execute_query_iter(query=formatted_query)) | ||||
| 
 | ||||
|                 if self.edition == DremioEdition.COMMUNITY: | ||||
|                     # Process community edition results | ||||
|                     formatted_tables = self.community_get_formatted_tables( | ||||
|                         container_results | ||||
|                     ) | ||||
|                 ) | ||||
|                     for table in formatted_tables: | ||||
|                         yield table | ||||
|                 else: | ||||
|                     # Process enterprise/cloud edition results | ||||
|                     column_dictionary: Dict[str, List[Dict]] = defaultdict(list) | ||||
|                     table_metadata: Dict[str, Dict] = {} | ||||
| 
 | ||||
|                     for record in container_results: | ||||
|                         if not record.get("COLUMN_NAME"): | ||||
|                             continue | ||||
| 
 | ||||
|                         table_full_path = record.get("FULL_TABLE_PATH") | ||||
|                         if not table_full_path: | ||||
|                             continue | ||||
| 
 | ||||
|                         # Store column information | ||||
|                         column_dictionary[table_full_path].append( | ||||
|                             { | ||||
|                                 "name": record["COLUMN_NAME"], | ||||
|                                 "ordinal_position": record["ORDINAL_POSITION"], | ||||
|                                 "is_nullable": record["IS_NULLABLE"], | ||||
|                                 "data_type": record["DATA_TYPE"], | ||||
|                                 "column_size": record["COLUMN_SIZE"], | ||||
|                             } | ||||
|                         ) | ||||
| 
 | ||||
|                         # Store table metadata (only once per table) | ||||
|                         if table_full_path not in table_metadata: | ||||
|                             table_metadata[table_full_path] = { | ||||
|                                 "TABLE_NAME": record.get("TABLE_NAME"), | ||||
|                                 "TABLE_SCHEMA": record.get("TABLE_SCHEMA"), | ||||
|                                 "VIEW_DEFINITION": record.get("VIEW_DEFINITION"), | ||||
|                                 "RESOURCE_ID": record.get("RESOURCE_ID"), | ||||
|                                 "LOCATION_ID": record.get("LOCATION_ID"), | ||||
|                                 "OWNER": record.get("OWNER"), | ||||
|                                 "OWNER_TYPE": record.get("OWNER_TYPE"), | ||||
|                                 "CREATED": record.get("CREATED"), | ||||
|                                 "FORMAT_TYPE": record.get("FORMAT_TYPE"), | ||||
|                             } | ||||
| 
 | ||||
|                     # Yield tables one at a time | ||||
|                     for table_path, table_info in table_metadata.items(): | ||||
|                         yield { | ||||
|                             "TABLE_NAME": table_info.get("TABLE_NAME"), | ||||
|                             "TABLE_SCHEMA": table_info.get("TABLE_SCHEMA"), | ||||
|                             "COLUMNS": column_dictionary[table_path], | ||||
|                             "VIEW_DEFINITION": table_info.get("VIEW_DEFINITION"), | ||||
|                             "RESOURCE_ID": table_info.get("RESOURCE_ID"), | ||||
|                             "LOCATION_ID": table_info.get("LOCATION_ID"), | ||||
|                             "OWNER": table_info.get("OWNER"), | ||||
|                             "OWNER_TYPE": table_info.get("OWNER_TYPE"), | ||||
|                             "CREATED": table_info.get("CREATED"), | ||||
|                             "FORMAT_TYPE": table_info.get("FORMAT_TYPE"), | ||||
|                         } | ||||
| 
 | ||||
|             except DremioAPIException as e: | ||||
|                 self.report.warning( | ||||
|                     message="Container has no tables or views", | ||||
| @@ -539,71 +735,6 @@ class DremioAPIOperations: | ||||
|                     exc=e, | ||||
|                 ) | ||||
| 
 | ||||
|         tables = [] | ||||
| 
 | ||||
|         if self.edition == DremioEdition.COMMUNITY: | ||||
|             tables = self.community_get_formatted_tables(all_tables_and_columns) | ||||
| 
 | ||||
|         else: | ||||
|             column_dictionary: Dict[str, List[Dict]] = defaultdict(list) | ||||
| 
 | ||||
|             for record in all_tables_and_columns: | ||||
|                 if not record.get("COLUMN_NAME"): | ||||
|                     continue | ||||
| 
 | ||||
|                 table_full_path = record.get("FULL_TABLE_PATH") | ||||
|                 if not table_full_path: | ||||
|                     continue | ||||
| 
 | ||||
|                 column_dictionary[table_full_path].append( | ||||
|                     { | ||||
|                         "name": record["COLUMN_NAME"], | ||||
|                         "ordinal_position": record["ORDINAL_POSITION"], | ||||
|                         "is_nullable": record["IS_NULLABLE"], | ||||
|                         "data_type": record["DATA_TYPE"], | ||||
|                         "column_size": record["COLUMN_SIZE"], | ||||
|                     } | ||||
|                 ) | ||||
| 
 | ||||
|             distinct_tables_list = list( | ||||
|                 { | ||||
|                     tuple( | ||||
|                         dictionary[key] | ||||
|                         for key in ( | ||||
|                             "TABLE_SCHEMA", | ||||
|                             "TABLE_NAME", | ||||
|                             "FULL_TABLE_PATH", | ||||
|                             "VIEW_DEFINITION", | ||||
|                             "LOCATION_ID", | ||||
|                             "OWNER", | ||||
|                             "OWNER_TYPE", | ||||
|                             "CREATED", | ||||
|                             "FORMAT_TYPE", | ||||
|                         ) | ||||
|                         if key in dictionary | ||||
|                     ): dictionary | ||||
|                     for dictionary in all_tables_and_columns | ||||
|                 }.values() | ||||
|             ) | ||||
| 
 | ||||
|             for table in distinct_tables_list: | ||||
|                 tables.append( | ||||
|                     { | ||||
|                         "TABLE_NAME": table.get("TABLE_NAME"), | ||||
|                         "TABLE_SCHEMA": table.get("TABLE_SCHEMA"), | ||||
|                         "COLUMNS": column_dictionary[table["FULL_TABLE_PATH"]], | ||||
|                         "VIEW_DEFINITION": table.get("VIEW_DEFINITION"), | ||||
|                         "RESOURCE_ID": table.get("RESOURCE_ID"), | ||||
|                         "LOCATION_ID": table.get("LOCATION_ID"), | ||||
|                         "OWNER": table.get("OWNER"), | ||||
|                         "OWNER_TYPE": table.get("OWNER_TYPE"), | ||||
|                         "CREATED": table.get("CREATED"), | ||||
|                         "FORMAT_TYPE": table.get("FORMAT_TYPE"), | ||||
|                     } | ||||
|                 ) | ||||
| 
 | ||||
|         return tables | ||||
| 
 | ||||
|     def validate_schema_format(self, schema): | ||||
|         if "." in schema: | ||||
|             schema_path = self.get( | ||||
| @@ -640,7 +771,10 @@ class DremioAPIOperations: | ||||
| 
 | ||||
|         return parents_list | ||||
| 
 | ||||
|     def extract_all_queries(self) -> List[Dict[str, Any]]: | ||||
|     def extract_all_queries(self) -> Iterator[Dict[str, Any]]: | ||||
|         """ | ||||
|         Memory-efficient streaming version for extracting query results. | ||||
|         """ | ||||
|         # Convert datetime objects to string format for SQL queries | ||||
|         start_timestamp_str = None | ||||
|         end_timestamp_str = None | ||||
| @ -661,7 +795,7 @@ class DremioAPIOperations: | ||||
|                 end_timestamp_millis=end_timestamp_str, | ||||
|             ) | ||||
| 
 | ||||
|         return self.execute_query(query=jobs_query) | ||||
|         return self.execute_query_iter(query=jobs_query) | ||||
| 
 | ||||
|     def get_tags_for_resource(self, resource_id: str) -> Optional[List[str]]: | ||||
|         """ | ||||
|  | ||||
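For orientation, here is a minimal sketch of how the new streaming entry point would be consumed (hedged: the `api` instance and the SQL text are illustrative assumptions; only `execute_query_iter` itself comes from the change above). The call blocks until the job completes, then yields rows as 500-row pages are fetched, so peak memory stays bounded by a single page rather than the full result set.

from typing import Any, Dict, Iterator

def count_rows(api: "DremioAPIOperations", sql: str) -> int:
    # execute_query_iter submits the job, polls until completion, and then
    # yields one row dict at a time; nothing beyond the current page is
    # buffered, which is the point of the OOM fix.
    rows: Iterator[Dict[str, Any]] = api.execute_query_iter(query=sql)
    return sum(1 for _ in rows)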
| @@ -1,4 +1,3 @@ | ||||
| import itertools | ||||
| import logging | ||||
| import re | ||||
| import uuid | ||||
| @@ -6,7 +5,7 @@ from collections import deque | ||||
| from dataclasses import dataclass | ||||
| from datetime import datetime | ||||
| from enum import Enum | ||||
| from typing import Any, Deque, Dict, List, Optional | ||||
| from typing import Any, Deque, Dict, Iterator, List, Optional | ||||
| 
 | ||||
| from sqlglot import parse_one | ||||
| 
 | ||||
| @@ -184,6 +183,7 @@ class DremioQuery: | ||||
|         return "" | ||||
| 
 | ||||
|     def get_raw_query(self, sql_query: str) -> str: | ||||
|         """Remove comments from SQL query using sqlglot parser.""" | ||||
|         try: | ||||
|             parsed = parse_one(sql_query) | ||||
|             return parsed.sql(comments=False) | ||||
| @@ -336,43 +336,26 @@ class DremioCatalog: | ||||
|     def __init__(self, dremio_api: DremioAPIOperations): | ||||
|         self.dremio_api = dremio_api | ||||
|         self.edition = dremio_api.edition | ||||
|         self.datasets: Deque[DremioDataset] = deque() | ||||
|         self.sources: Deque[DremioSourceContainer] = deque() | ||||
|         self.spaces: Deque[DremioSpace] = deque() | ||||
|         self.folders: Deque[DremioFolder] = deque() | ||||
|         self.glossary_terms: Deque[DremioGlossaryTerm] = deque() | ||||
|         self.queries: Deque[DremioQuery] = deque() | ||||
| 
 | ||||
|         self.datasets_populated = False | ||||
|         self.containers_populated = False | ||||
|         self.queries_populated = False | ||||
| 
 | ||||
|     def set_datasets(self) -> None: | ||||
|         if not self.datasets_populated: | ||||
|             self.set_containers() | ||||
|     def get_datasets(self) -> Iterator[DremioDataset]: | ||||
|         """Get all Dremio datasets (tables and views) as an iterator.""" | ||||
|         # Get containers directly without storing them | ||||
|         containers = self.get_containers() | ||||
| 
 | ||||
|             containers: Deque[DremioContainer] = deque() | ||||
|             containers.extend(self.spaces)  # Add DremioSpace elements | ||||
|             containers.extend(self.sources)  # Add DremioSource elements | ||||
|         for dataset_details in self.dremio_api.get_all_tables_and_columns(containers): | ||||
|             dremio_dataset = DremioDataset( | ||||
|                 dataset_details=dataset_details, | ||||
|                 api_operations=self.dremio_api, | ||||
|             ) | ||||
| 
 | ||||
|             for dataset_details in self.dremio_api.get_all_tables_and_columns( | ||||
|                 containers=containers | ||||
|             ): | ||||
|                 dremio_dataset = DremioDataset( | ||||
|                     dataset_details=dataset_details, | ||||
|                     api_operations=self.dremio_api, | ||||
|                 ) | ||||
|                 self.datasets.append(dremio_dataset) | ||||
| 
 | ||||
|                 for glossary_term in dremio_dataset.glossary_terms: | ||||
|                     if glossary_term not in self.glossary_terms: | ||||
|                         self.glossary_terms.append(glossary_term) | ||||
| 
 | ||||
|             self.datasets_populated = True | ||||
| 
 | ||||
|     def get_datasets(self) -> Deque[DremioDataset]: | ||||
|         self.set_datasets() | ||||
|         return self.datasets | ||||
|             yield dremio_dataset | ||||
| 
 | ||||
|     def set_containers(self) -> None: | ||||
|         if not self.containers_populated: | ||||
| @@ -423,18 +406,50 @@ class DremioCatalog: | ||||
| 
 | ||||
|         self.containers_populated = True | ||||
| 
 | ||||
|     def get_containers(self) -> Deque: | ||||
|         self.set_containers() | ||||
|         return deque(itertools.chain(self.sources, self.spaces, self.folders)) | ||||
|     def get_containers(self) -> Iterator[DremioContainer]: | ||||
|         """Get all containers (sources, spaces, folders) as an iterator.""" | ||||
|         for container in self.dremio_api.get_all_containers(): | ||||
|             container_type = container.get("container_type") | ||||
|             if container_type == DremioEntityContainerType.SOURCE: | ||||
|                 yield DremioSourceContainer( | ||||
|                     container_name=container.get("name"), | ||||
|                     location_id=container.get("id"), | ||||
|                     path=[], | ||||
|                     api_operations=self.dremio_api, | ||||
|                     dremio_source_type=container.get("source_type") or "", | ||||
|                     root_path=container.get("root_path"), | ||||
|                     database_name=container.get("database_name"), | ||||
|                 ) | ||||
|             elif container_type == DremioEntityContainerType.SPACE: | ||||
|                 yield DremioSpace( | ||||
|                     container_name=container.get("name"), | ||||
|                     location_id=container.get("id"), | ||||
|                     path=[], | ||||
|                     api_operations=self.dremio_api, | ||||
|                 ) | ||||
|             elif container_type == DremioEntityContainerType.FOLDER: | ||||
|                 yield DremioFolder( | ||||
|                     container_name=container.get("name"), | ||||
|                     location_id=container.get("id"), | ||||
|                     path=container.get("path"), | ||||
|                     api_operations=self.dremio_api, | ||||
|                 ) | ||||
| 
 | ||||
|     def get_sources(self) -> Deque[DremioSourceContainer]: | ||||
|         self.set_containers() | ||||
|         return self.sources | ||||
|     def get_sources(self) -> Iterator[DremioSourceContainer]: | ||||
|         """Get all Dremio source containers (external data connections) as an iterator.""" | ||||
|         for container in self.get_containers(): | ||||
|             if isinstance(container, DremioSourceContainer): | ||||
|                 yield container | ||||
| 
 | ||||
|     def get_glossary_terms(self) -> Deque[DremioGlossaryTerm]: | ||||
|         self.set_datasets() | ||||
|         self.set_containers() | ||||
|         return self.glossary_terms | ||||
|     def get_glossary_terms(self) -> Iterator[DremioGlossaryTerm]: | ||||
|         """Get all unique glossary terms (tags) from datasets.""" | ||||
|         glossary_terms_seen = set() | ||||
| 
 | ||||
|         for dataset in self.get_datasets(): | ||||
|             for glossary_term in dataset.glossary_terms: | ||||
|                 if glossary_term not in glossary_terms_seen: | ||||
|                     glossary_terms_seen.add(glossary_term) | ||||
|                     yield glossary_term | ||||
| 
 | ||||
|     def is_valid_query(self, query: Dict[str, Any]) -> bool: | ||||
|         required_fields = [ | ||||
| @@ -447,6 +462,7 @@ class DremioCatalog: | ||||
|         return all(query.get(field) for field in required_fields) | ||||
| 
 | ||||
|     def get_queries(self) -> Deque[DremioQuery]: | ||||
|         """Get all valid Dremio queries for lineage analysis.""" | ||||
|         for query in self.dremio_api.extract_all_queries(): | ||||
|             if not self.is_valid_query(query): | ||||
|                 continue | ||||
|  | ||||
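A short, hypothetical usage sketch of the reworked catalog iterators (assumes an already-constructed `DremioCatalog` named `catalog`; `resource_name` is the dataset attribute also exercised by the tests below):

def print_catalog_summary(catalog: "DremioCatalog") -> None:
    # get_datasets() is now a generator: each DremioDataset is constructed
    # and yielded as the underlying API pages stream in, instead of being
    # queued in deques on the catalog first. Note that a second pass such
    # as get_glossary_terms() re-runs the dataset queries.
    for dataset in catalog.get_datasets():
        print(dataset.resource_name)
    for term in catalog.get_glossary_terms():
        print(term)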
| @@ -17,6 +17,7 @@ from datahub.metadata.schema_classes import ( | ||||
|     DatasetProfileClass, | ||||
|     QuantileClass, | ||||
| ) | ||||
| from datahub.utilities.perf_timer import PerfTimer | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| @@ -64,8 +65,13 @@ class DremioProfiler: | ||||
|             ) | ||||
|             return | ||||
| 
 | ||||
|         profile_data = self.profile_table(full_table_name, columns) | ||||
|         profile_aspect = self.populate_profile_aspect(profile_data) | ||||
|         with PerfTimer() as timer: | ||||
|             profile_data = self.profile_table(full_table_name, columns) | ||||
|             profile_aspect = self.populate_profile_aspect(profile_data) | ||||
| 
 | ||||
|         logger.info( | ||||
|             f"Profiled table {full_table_name} with {len(columns)} columns in {timer.elapsed_seconds():.2f} seconds" | ||||
|         ) | ||||
| 
 | ||||
|         if profile_aspect: | ||||
|             self.report.report_entity_profiled(dataset.resource_name) | ||||
| @@ -131,7 +137,12 @@ class DremioProfiler: | ||||
|     def _profile_chunk(self, table_name: str, columns: List[Tuple[str, str]]) -> Dict: | ||||
|         profile_sql = self._build_profile_sql(table_name, columns) | ||||
|         try: | ||||
|             results = self.api_operations.execute_query(profile_sql) | ||||
|             with PerfTimer() as timer: | ||||
|                 results = self.api_operations.execute_query(profile_sql) | ||||
| 
 | ||||
|             logger.debug( | ||||
|                 f"Profiling query for {table_name} ({len(columns)} columns) completed in {timer.elapsed_seconds():.2f} seconds" | ||||
|             ) | ||||
|             return self._parse_profile_results(results, columns) | ||||
|         except DremioAPIException as e: | ||||
|             raise e | ||||
|  | ||||
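The profiler change above only wraps existing calls in `PerfTimer`; a minimal sketch of the same timing pattern, with `expensive_step` as a hypothetical stand-in for the profiled query:

import logging

from datahub.utilities.perf_timer import PerfTimer

logger = logging.getLogger(__name__)

def expensive_step() -> int:
    # Hypothetical stand-in for an expensive profiling query.
    return sum(range(1_000_000))

def timed_step() -> None:
    # elapsed_seconds() is read after the with-block exits, matching the
    # usage introduced in the profiler above.
    with PerfTimer() as timer:
        expensive_step()
    logger.info(f"Step finished in {timer.elapsed_seconds():.2f} seconds")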
| @@ -55,7 +55,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import ( | ||||
| from datahub.ingestion.source_report.ingestion_stage import ( | ||||
|     LINEAGE_EXTRACTION, | ||||
|     METADATA_EXTRACTION, | ||||
|     IngestionHighStage, | ||||
|     PROFILING, | ||||
| ) | ||||
| from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( | ||||
|     DatasetLineageTypeClass, | ||||
| @@ -201,7 +201,7 @@ class DremioSource(StatefulIngestionSourceBase): | ||||
|         return "dremio" | ||||
| 
 | ||||
|     def _build_source_map(self) -> Dict[str, DremioSourceMapEntry]: | ||||
|         dremio_sources = self.dremio_catalog.get_sources() | ||||
|         dremio_sources = list(self.dremio_catalog.get_sources()) | ||||
|         source_mappings_config = self.config.source_mappings or [] | ||||
| 
 | ||||
|         source_map = build_dremio_source_map(dremio_sources, source_mappings_config) | ||||
| @@ -242,9 +242,7 @@ class DremioSource(StatefulIngestionSourceBase): | ||||
|                     ) | ||||
| 
 | ||||
|             # Process Datasets | ||||
|             datasets = self.dremio_catalog.get_datasets() | ||||
| 
 | ||||
|             for dataset_info in datasets: | ||||
|             for dataset_info in self.dremio_catalog.get_datasets(): | ||||
|                 try: | ||||
|                     yield from self.process_dataset(dataset_info) | ||||
|                     logger.info( | ||||
| @@ -258,10 +256,8 @@ class DremioSource(StatefulIngestionSourceBase): | ||||
|                         exc=exc, | ||||
|                     ) | ||||
| 
 | ||||
|             # Process Glossary Terms | ||||
|             glossary_terms = self.dremio_catalog.get_glossary_terms() | ||||
| 
 | ||||
|             for glossary_term in glossary_terms: | ||||
|             # Process Glossary Terms using streaming | ||||
|             for glossary_term in self.dremio_catalog.get_glossary_terms(): | ||||
|                 try: | ||||
|                     yield from self.process_glossary_term(glossary_term) | ||||
|                 except Exception as exc: | ||||
| @@ -283,14 +279,16 @@ class DremioSource(StatefulIngestionSourceBase): | ||||
|             # Profiling | ||||
|             if self.config.is_profiling_enabled(): | ||||
|                 with ( | ||||
|                     self.report.new_high_stage(IngestionHighStage.PROFILING), | ||||
|                     self.report.new_stage(PROFILING), | ||||
|                     ThreadPoolExecutor( | ||||
|                         max_workers=self.config.profiling.max_workers | ||||
|                     ) as executor, | ||||
|                 ): | ||||
|                     # Collect datasets for profiling | ||||
|                     datasets_for_profiling = list(self.dremio_catalog.get_datasets()) | ||||
|                     future_to_dataset = { | ||||
|                         executor.submit(self.generate_profiles, dataset): dataset | ||||
|                         for dataset in datasets | ||||
|                         for dataset in datasets_for_profiling | ||||
|                     } | ||||
| 
 | ||||
|                     for future in as_completed(future_to_dataset): | ||||
|  | ||||
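Since `get_datasets()` is now a one-shot generator, the profiling stage above first materializes it with `list()` before fanning out to the thread pool. A hedged sketch of that pattern (`profile_fn` and the worker count are illustrative):

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Iterable

def profile_in_parallel(
    datasets: Iterable[Any],
    profile_fn: Callable[[Any], None],
    max_workers: int = 4,
) -> None:
    # Snapshot the one-shot generator so the profiling pass has its own
    # copy of the datasets, independent of the earlier metadata iteration.
    snapshot = list(datasets)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(profile_fn, ds): ds for ds in snapshot}
        for future in as_completed(futures):
            future.result()  # surface any per-dataset profiling error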
| @@ -0,0 +1,281 @@ | ||||
| from collections import deque | ||||
| from unittest.mock import Mock, patch | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| from datahub.ingestion.source.dremio.dremio_api import ( | ||||
|     DremioAPIException, | ||||
|     DremioAPIOperations, | ||||
|     DremioEdition, | ||||
| ) | ||||
| from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig | ||||
| from datahub.ingestion.source.dremio.dremio_reporting import DremioSourceReport | ||||
| 
 | ||||
| 
 | ||||
| class TestDremioAPIPagination: | ||||
|     @pytest.fixture | ||||
|     def dremio_api(self, monkeypatch): | ||||
|         """Setup mock Dremio API for testing""" | ||||
|         # Mock the requests.Session | ||||
|         mock_session = Mock() | ||||
|         monkeypatch.setattr("requests.Session", Mock(return_value=mock_session)) | ||||
| 
 | ||||
|         # Mock the authentication response | ||||
|         mock_session.post.return_value.json.return_value = {"token": "dummy-token"} | ||||
|         mock_session.post.return_value.status_code = 200 | ||||
| 
 | ||||
|         config = DremioSourceConfig( | ||||
|             hostname="dummy-host", | ||||
|             port=9047, | ||||
|             tls=False, | ||||
|             authentication_method="password", | ||||
|             username="dummy-user", | ||||
|             password="dummy-password", | ||||
|         ) | ||||
|         report = DremioSourceReport() | ||||
|         api = DremioAPIOperations(config, report) | ||||
|         api.session = mock_session | ||||
|         return api | ||||
| 
 | ||||
|     def test_fetch_all_results_missing_rows_key(self, dremio_api): | ||||
|         """Test handling of API response missing 'rows' key""" | ||||
|         # Mock get_job_result to return response without 'rows' key | ||||
|         dremio_api.get_job_result = Mock(return_value={"message": "No data available"}) | ||||
| 
 | ||||
|         result = dremio_api._fetch_all_results("test-job-id") | ||||
| 
 | ||||
|         # Should return empty list when no rows key is present | ||||
|         assert result == [] | ||||
|         dremio_api.get_job_result.assert_called_once() | ||||
| 
 | ||||
|     def test_fetch_all_results_with_error_message(self, dremio_api): | ||||
|         """Test handling of API response with errorMessage""" | ||||
|         # Mock get_job_result to return error response | ||||
|         dremio_api.get_job_result = Mock(return_value={"errorMessage": "Out of memory"}) | ||||
| 
 | ||||
|         with pytest.raises(DremioAPIException, match="Query error: Out of memory"): | ||||
|             dremio_api._fetch_all_results("test-job-id") | ||||
| 
 | ||||
|     def test_fetch_all_results_empty_rows(self, dremio_api): | ||||
|         """Test handling of empty rows response""" | ||||
|         # Mock get_job_result to return empty rows | ||||
|         dremio_api.get_job_result = Mock(return_value={"rows": [], "rowCount": 0}) | ||||
| 
 | ||||
|         result = dremio_api._fetch_all_results("test-job-id") | ||||
| 
 | ||||
|         assert result == [] | ||||
|         dremio_api.get_job_result.assert_called_once() | ||||
| 
 | ||||
|     def test_fetch_all_results_normal_case(self, dremio_api): | ||||
|         """Test normal operation with valid rows""" | ||||
|         # Mock get_job_result to return valid data | ||||
|         # Single response with 2 rows, fewer than the 500-row limit, so fetching stops after one call | ||||
|         mock_responses = [ | ||||
|             {"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2}, | ||||
|         ] | ||||
|         dremio_api.get_job_result = Mock(side_effect=mock_responses) | ||||
| 
 | ||||
|         result = dremio_api._fetch_all_results("test-job-id") | ||||
| 
 | ||||
|         expected = [{"col1": "val1"}, {"col1": "val2"}] | ||||
|         assert result == expected | ||||
|         assert dremio_api.get_job_result.call_count == 1 | ||||
| 
 | ||||
|     def test_fetch_results_iter_missing_rows_key(self, dremio_api): | ||||
|         """Test internal streaming method handling missing 'rows' key""" | ||||
|         # Mock get_job_result to return response without 'rows' key | ||||
|         dremio_api.get_job_result = Mock(return_value={"message": "No data available"}) | ||||
| 
 | ||||
|         result_iterator = dremio_api._fetch_results_iter("test-job-id") | ||||
|         results = list(result_iterator) | ||||
| 
 | ||||
|         # Should return empty iterator when no rows key is present | ||||
|         assert results == [] | ||||
| 
 | ||||
|     def test_fetch_results_iter_with_error(self, dremio_api): | ||||
|         """Test internal streaming method handling error response""" | ||||
|         # Mock get_job_result to return error response | ||||
|         dremio_api.get_job_result = Mock(return_value={"errorMessage": "Query timeout"}) | ||||
| 
 | ||||
|         result_iterator = dremio_api._fetch_results_iter("test-job-id") | ||||
| 
 | ||||
|         with pytest.raises(DremioAPIException, match="Query error: Query timeout"): | ||||
|             list(result_iterator) | ||||
| 
 | ||||
|     def test_fetch_results_iter_normal_case(self, dremio_api): | ||||
|         """Test internal streaming method with valid data""" | ||||
|         # Mock get_job_result to return valid data in batches | ||||
|         mock_responses = [ | ||||
|             {"rows": [{"col1": "val1"}, {"col1": "val2"}], "rowCount": 2}, | ||||
|             {"rows": [], "rowCount": 2},  # Empty response to end iteration | ||||
|         ] | ||||
|         dremio_api.get_job_result = Mock(side_effect=mock_responses) | ||||
| 
 | ||||
|         result_iterator = dremio_api._fetch_results_iter("test-job-id") | ||||
|         results = list(result_iterator) | ||||
| 
 | ||||
|         expected = [ | ||||
|             {"col1": "val1"}, | ||||
|             {"col1": "val2"}, | ||||
|         ] | ||||
|         assert results == expected | ||||
| 
 | ||||
|     def test_execute_query_iter_success(self, dremio_api): | ||||
|         """Test execute_query_iter with successful job completion""" | ||||
|         # Mock the POST response for job submission | ||||
|         dremio_api.post = Mock(return_value={"id": "job-123"}) | ||||
| 
 | ||||
|         # Mock job status progression | ||||
|         status_responses = [ | ||||
|             {"jobState": "RUNNING"}, | ||||
|             {"jobState": "RUNNING"}, | ||||
|             {"jobState": "COMPLETED"}, | ||||
|         ] | ||||
|         dremio_api.get_job_status = Mock(side_effect=status_responses) | ||||
| 
 | ||||
|         # Mock streaming results | ||||
|         dremio_api._fetch_results_iter = Mock(return_value=iter([{"col1": "val1"}])) | ||||
| 
 | ||||
|         with patch("time.sleep"):  # Skip actual sleep delays | ||||
|             result_iterator = dremio_api.execute_query_iter("SELECT * FROM test") | ||||
|             results = list(result_iterator) | ||||
| 
 | ||||
|         assert results == [{"col1": "val1"}] | ||||
|         dremio_api.post.assert_called_once() | ||||
|         assert dremio_api.get_job_status.call_count == 3 | ||||
| 
 | ||||
|     def test_execute_query_iter_job_failure(self, dremio_api): | ||||
|         """Test execute_query_iter with job failure""" | ||||
|         # Mock the POST response for job submission | ||||
|         dremio_api.post = Mock(return_value={"id": "job-123"}) | ||||
| 
 | ||||
|         # Mock job failure | ||||
|         dremio_api.get_job_status = Mock( | ||||
|             return_value={"jobState": "FAILED", "errorMessage": "SQL syntax error"} | ||||
|         ) | ||||
| 
 | ||||
|         with ( | ||||
|             pytest.raises(RuntimeError, match="Query failed: SQL syntax error"), | ||||
|             patch("time.sleep"), | ||||
|         ): | ||||
|             result_iterator = dremio_api.execute_query_iter("SELECT * FROM test") | ||||
|             list(result_iterator)  # Force evaluation | ||||
| 
 | ||||
|     def test_execute_query_iter_timeout(self, dremio_api): | ||||
|         """Test execute_query_iter with timeout""" | ||||
|         # Mock the POST response for job submission | ||||
|         dremio_api.post = Mock(return_value={"id": "job-123"}) | ||||
| 
 | ||||
|         # Mock job that never completes | ||||
|         dremio_api.get_job_status = Mock(return_value={"jobState": "RUNNING"}) | ||||
|         dremio_api.cancel_query = Mock() | ||||
| 
 | ||||
|         # Mock time.time to simulate timeout - need to patch where it's imported | ||||
|         # First call: start_time = 0 | ||||
|         # Second call: time() - start_time check = 3700 (triggers timeout) | ||||
|         mock_time_values = [0, 3700] | ||||
| 
 | ||||
|         with ( | ||||
|             patch( | ||||
|                 "datahub.ingestion.source.dremio.dremio_api.time", | ||||
|                 side_effect=mock_time_values, | ||||
|             ), | ||||
|             patch("datahub.ingestion.source.dremio.dremio_api.sleep"), | ||||
|             pytest.raises( | ||||
|                 DremioAPIException, | ||||
|                 match="Query execution timed out after 3600 seconds", | ||||
|             ), | ||||
|         ): | ||||
|             result_iterator = dremio_api.execute_query_iter("SELECT * FROM test") | ||||
|             list(result_iterator) | ||||
| 
 | ||||
|         dremio_api.cancel_query.assert_called_once_with("job-123") | ||||
| 
 | ||||
|     def test_get_all_tables_and_columns(self, dremio_api): | ||||
|         """Test streaming version of get_all_tables_and_columns""" | ||||
|         from datahub.ingestion.source.dremio.dremio_api import DremioEdition | ||||
| 
 | ||||
|         # Set up test data | ||||
|         dremio_api.edition = DremioEdition.ENTERPRISE | ||||
|         dremio_api.allow_schema_pattern = [".*"] | ||||
|         dremio_api.deny_schema_pattern = [] | ||||
| 
 | ||||
|         # Mock container | ||||
|         mock_container = Mock() | ||||
|         mock_container.container_name = "test_source" | ||||
|         containers = deque([mock_container]) | ||||
| 
 | ||||
|         # Mock streaming query results | ||||
|         mock_results = [ | ||||
|             { | ||||
|                 "COLUMN_NAME": "col1", | ||||
|                 "FULL_TABLE_PATH": "test.table1", | ||||
|                 "TABLE_NAME": "table1", | ||||
|                 "TABLE_SCHEMA": "test", | ||||
|                 "ORDINAL_POSITION": 1, | ||||
|                 "IS_NULLABLE": "YES", | ||||
|                 "DATA_TYPE": "VARCHAR", | ||||
|                 "COLUMN_SIZE": 255, | ||||
|                 "RESOURCE_ID": "res1", | ||||
|             } | ||||
|         ] | ||||
|         dremio_api.execute_query_iter = Mock(return_value=iter(mock_results)) | ||||
| 
 | ||||
|         # Test streaming method | ||||
|         result_iterator = dremio_api.get_all_tables_and_columns(containers) | ||||
|         tables = list(result_iterator) | ||||
| 
 | ||||
|         assert len(tables) == 1 | ||||
|         table = tables[0] | ||||
|         assert table["TABLE_NAME"] == "table1" | ||||
|         assert table["TABLE_SCHEMA"] == "test" | ||||
|         assert len(table["COLUMNS"]) == 1 | ||||
|         assert table["COLUMNS"][0]["name"] == "col1" | ||||
| 
 | ||||
|     def test_extract_all_queries(self, dremio_api): | ||||
|         """Test streaming version of extract_all_queries""" | ||||
| 
 | ||||
|         dremio_api.edition = DremioEdition.ENTERPRISE | ||||
|         dremio_api.start_time = None | ||||
|         dremio_api.end_time = None | ||||
| 
 | ||||
|         # Mock streaming query execution | ||||
|         mock_query_results = [ | ||||
|             {"query_id": "q1", "sql": "SELECT * FROM table1"}, | ||||
|             {"query_id": "q2", "sql": "SELECT * FROM table2"}, | ||||
|         ] | ||||
|         dremio_api.execute_query_iter = Mock(return_value=iter(mock_query_results)) | ||||
| 
 | ||||
|         result_iterator = dremio_api.extract_all_queries() | ||||
|         queries = list(result_iterator) | ||||
| 
 | ||||
|         assert len(queries) == 2 | ||||
|         assert queries[0]["query_id"] == "q1" | ||||
|         assert queries[1]["query_id"] == "q2" | ||||
|         dremio_api.execute_query_iter.assert_called_once() | ||||
| 
 | ||||
|     def test_fetch_results_iter_incremental_yielding(self, dremio_api): | ||||
|         """Test that internal iterator yields results incrementally and can be partially consumed""" | ||||
|         # Verify that streaming yields results one at a time and iterator state is maintained | ||||
| 
 | ||||
|         mock_responses = [ | ||||
|             {"rows": [{"id": i} for i in range(100)], "rowCount": 1000}, | ||||
|             {"rows": [{"id": i} for i in range(100, 200)], "rowCount": 1000}, | ||||
|             {"rows": [], "rowCount": 1000},  # End iteration | ||||
|         ] | ||||
|         dremio_api.get_job_result = Mock(side_effect=mock_responses) | ||||
| 
 | ||||
|         # Test that streaming version yields results incrementally | ||||
|         result_iterator = dremio_api._fetch_results_iter("test-job") | ||||
| 
 | ||||
|         # Get first few results | ||||
|         first_batch = [] | ||||
|         for i, result in enumerate(result_iterator): | ||||
|             first_batch.append(result) | ||||
|             if i >= 50:  # Stop after getting 51 results | ||||
|                 break | ||||
| 
 | ||||
|         assert len(first_batch) == 51 | ||||
|         # Verify we can still get more results (iterator not exhausted) | ||||
|         next_result = next(result_iterator, None) | ||||
|         assert next_result is not None | ||||
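These pagination tests all exercise the same termination rule; a compact, self-contained restatement of it (the default of 500 mirrors the hard-coded page size in `_fetch_results_iter`):

def should_fetch_next_page(rows_returned: int, limit: int = 500) -> bool:
    # The fetch loops stop on an empty or short page and continue only
    # when a full page suggests more rows may remain.
    return rows_returned == limit

assert should_fetch_next_page(500)       # full page: keep paging
assert not should_fetch_next_page(120)   # short page: end of results
assert not should_fetch_next_page(0)     # empty page: end of results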
| @@ -0,0 +1,74 @@ | ||||
| from unittest.mock import Mock | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| from datahub.ingestion.source.dremio.dremio_config import DremioSourceConfig | ||||
| from datahub.ingestion.source.dremio.dremio_entities import DremioCatalog, DremioDataset | ||||
| from datahub.ingestion.source.dremio.dremio_source import DremioSource | ||||
| 
 | ||||
| 
 | ||||
| class TestDremioIteratorIntegration: | ||||
|     @pytest.fixture | ||||
|     def mock_config(self): | ||||
|         return DremioSourceConfig( | ||||
|             hostname="test-host", | ||||
|             port=9047, | ||||
|             tls=False, | ||||
|             username="test-user", | ||||
|             password="test-password", | ||||
|         ) | ||||
| 
 | ||||
|     @pytest.fixture | ||||
|     def mock_dremio_source(self, mock_config, monkeypatch): | ||||
|         mock_session = Mock() | ||||
|         monkeypatch.setattr("requests.Session", Mock(return_value=mock_session)) | ||||
|         mock_session.post.return_value.json.return_value = {"token": "dummy-token"} | ||||
|         mock_session.post.return_value.status_code = 200 | ||||
| 
 | ||||
|         mock_ctx = Mock() | ||||
|         mock_ctx.run_id = "test-run-id" | ||||
|         source = DremioSource(mock_config, mock_ctx) | ||||
| 
 | ||||
|         source.dremio_catalog = Mock(spec=DremioCatalog) | ||||
|         source.dremio_catalog.dremio_api = Mock() | ||||
| 
 | ||||
|         return source | ||||
| 
 | ||||
|     def test_source_uses_iterators_by_default(self, mock_dremio_source): | ||||
|         """Test that source uses iterators by default""" | ||||
|         mock_dataset = Mock(spec=DremioDataset) | ||||
|         mock_dataset.path = ["test", "path"] | ||||
|         mock_dataset.resource_name = "test_table" | ||||
| 
 | ||||
|         mock_dremio_source.dremio_catalog.get_datasets.return_value = iter( | ||||
|             [mock_dataset] | ||||
|         ) | ||||
|         mock_dremio_source.dremio_catalog.get_containers.return_value = [] | ||||
|         mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([]) | ||||
|         mock_dremio_source.dremio_catalog.get_sources.return_value = [] | ||||
|         mock_dremio_source.config.source_mappings = [] | ||||
|         mock_dremio_source.process_dataset = Mock(return_value=iter([])) | ||||
|         list(mock_dremio_source.get_workunits_internal()) | ||||
| 
 | ||||
|         mock_dremio_source.dremio_catalog.get_datasets.assert_called_once() | ||||
|         mock_dremio_source.dremio_catalog.get_glossary_terms.assert_called_once() | ||||
| 
 | ||||
|     def test_iterator_handles_exceptions_gracefully(self, mock_dremio_source): | ||||
|         """Test that iterator handles exceptions without crashing""" | ||||
|         mock_dataset = Mock(spec=DremioDataset) | ||||
|         mock_dataset.path = ["test", "path"] | ||||
|         mock_dataset.resource_name = "test_table" | ||||
| 
 | ||||
|         mock_dremio_source.dremio_catalog.get_datasets.return_value = iter( | ||||
|             [mock_dataset] | ||||
|         ) | ||||
|         mock_dremio_source.dremio_catalog.get_containers.return_value = [] | ||||
|         mock_dremio_source.dremio_catalog.get_glossary_terms.return_value = iter([]) | ||||
|         mock_dremio_source.dremio_catalog.get_sources.return_value = [] | ||||
|         mock_dremio_source.config.source_mappings = [] | ||||
|         mock_dremio_source.process_dataset = Mock( | ||||
|             side_effect=Exception("Processing error") | ||||
|         ) | ||||
|         list(mock_dremio_source.get_workunits_internal()) | ||||
| 
 | ||||
|         assert mock_dremio_source.report.num_datasets_failed > 0 | ||||
| @@ -10,7 +10,7 @@ from datahub.ingestion.source.dremio.dremio_source import ( | ||||
| 
 | ||||
| 
 | ||||
| def test_build_source_map_simple(): | ||||
|     # write unit test | ||||
|     """Test basic source mapping functionality with simple configuration""" | ||||
|     config_mapping: List[DremioSourceMapping] = [ | ||||
|         DremioSourceMapping(source_name="source1", platform="S3", env="PROD"), | ||||
|         DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"), | ||||
| @@ -58,7 +58,7 @@ def test_build_source_map_simple(): | ||||
| 
 | ||||
| 
 | ||||
| def test_build_source_map_same_platform_multiple_sources(): | ||||
|     # write unit test | ||||
|     """Test source mapping with multiple sources using the same platform and complex scenarios""" | ||||
|     config_mapping: List[DremioSourceMapping] = [ | ||||
|         DremioSourceMapping(source_name="source1", platform="S3", env="PROD"), | ||||
|         DremioSourceMapping(source_name="source2", platform="redshift", env="DEV"), | ||||
|  | ||||