mirror of
				https://github.com/open-metadata/OpenMetadata.git
				synced 2025-10-26 00:04:52 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			70 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			70 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #  Copyright 2025 Collate
 | |
| #  Licensed under the Collate Community License, Version 1.0 (the "License");
 | |
| #  you may not use this file except in compliance with the License.
 | |
| #  You may obtain a copy of the License at
 | |
| #  https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
 | |
| #  Unless required by applicable law or agreed to in writing, software
 | |
| #  distributed under the License is distributed on an "AS IS" BASIS,
 | |
| #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| #  See the License for the specific language governing permissions and
 | |
| #  limitations under the License.
 | |
| """
 | |
| Helper module to handle data sampling for the profiler
 | |
| """
 | |
| from sqlalchemy import Column, event, text
 | |
| from sqlalchemy.orm import scoped_session, sessionmaker
 | |
| 
 | |
| from metadata.ingestion.source.database.databricks.connection import (
 | |
|     get_connection as databricks_get_connection,
 | |
| )
 | |
| from metadata.profiler.orm.types.custom_array import CustomArray
 | |
| from metadata.sampler.sqlalchemy.sampler import SQASampler
 | |
| 
 | |
| 
 | |
class DatabricksSamplerInterface(SQASampler):
    """SQLAlchemy-based sampler specialised for Databricks.

    On top of ``SQASampler`` this class:

    * builds its own connection through the Databricks ``get_connection``
      helper and binds a scoped session factory to it;
    * issues ``USE CATALOG`` at the start of every session transaction;
    * truncates array columns at query level via ``slice``.
    """

    def __init__(self, *args, **kwargs):
        """Initialize with a single Databricks connection.

        All positional and keyword arguments are forwarded unchanged to the
        parent initializer.
        """
        super().__init__(*args, **kwargs)
        self.connection = databricks_get_connection(self.service_connection_config)
        session_maker = sessionmaker(bind=self.connection)

        # The listener is attached to the session factory, so every session
        # it produces selects the configured catalog when a transaction begins.
        @event.listens_for(session_maker, "after_begin")
        def set_catalog(session, transaction, connection):
            # Safely quote the catalog name to prevent SQL injection
            quoted_catalog = connection.dialect.identifier_preparer.quote(
                self.service_connection_config.catalog
            )
            # NOTE(review): passing a plain string to ``Connection.execute`` is
            # deprecated in SQLAlchemy 1.4 and removed in 2.0 -- confirm the
            # pinned SQLAlchemy version before upgrading.
            connection.execute(f"USE CATALOG {quoted_catalog};")

        self.session_factory = scoped_session(session_maker)

    def get_client(self):
        """Return the parent's client (the SQA session) after setting its catalog."""
        client = super().get_client()
        # NOTE(review): ``set_catalog`` is not defined in this class (the
        # function of the same name in ``__init__`` is a local event listener);
        # presumably it is inherited from ``SQASampler`` or a mixin -- confirm.
        self.set_catalog(client)
        return client

    def _handle_array_column(self, column: Column) -> bool:
        """Return True when the column's type is a ``CustomArray``."""
        return isinstance(column.type, CustomArray)

    @staticmethod
    def _quote_databricks_identifier(identifier) -> str:
        """Wrap an identifier in backticks, doubling any embedded backtick."""
        escaped = str(identifier).replace("`", "``")
        return f"`{escaped}`"

    def _get_slice_expression(self, column: Column):
        """Generate SQL expression to slice array elements at query level

        The number of elements kept is taken from
        ``self._get_max_array_elements()``.

        Args:
            column: Column whose array values should be truncated

        Returns:
            SQLAlchemy ``text`` clause that yields NULL for NULL arrays and
            otherwise the leading elements of the array, aliased to the
            column's label
        """
        max_elements = self._get_max_array_elements()
        # Escape backticks so that unusual column names cannot terminate the
        # delimited identifier early and corrupt (or inject into) the query.
        quoted_name = self._quote_databricks_identifier(column.name)
        quoted_label = self._quote_databricks_identifier(column._label)
        return text(
            f"""
        CASE 
            WHEN {quoted_name} IS NULL THEN NULL
            ELSE slice({quoted_name}, 1, {max_elements})
        END AS {quoted_label}
        """
        )
 | 
