mirror of
				https://github.com/open-metadata/OpenMetadata.git
				synced 2025-11-04 04:29:13 +00:00 
			
		
		
		
	* get table ddl for athena tables * changes in method to get all table ddls * external table/container lineage for athena * column lineage for external table lineage * unittest for athena * pyformat changes * add external table lineage unit test * fix unittest with pydantic v2 changes * fix unittest formatting * fix code smell
This commit is contained in:
		
							parent
							
								
									afec7703cc
								
							
						
					
					
						commit
						3f5bc1948d
					
				@ -12,12 +12,9 @@
 | 
				
			|||||||
"""Athena source module"""
 | 
					"""Athena source module"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
from copy import deepcopy
 | 
					from typing import Iterable, Optional, Tuple
 | 
				
			||||||
from typing import Dict, Iterable, List, Optional, Tuple
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from pyathena.sqlalchemy.base import AthenaDialect
 | 
					from pyathena.sqlalchemy.base import AthenaDialect
 | 
				
			||||||
from sqlalchemy import types
 | 
					 | 
				
			||||||
from sqlalchemy.engine import reflection
 | 
					 | 
				
			||||||
from sqlalchemy.engine.reflection import Inspector
 | 
					from sqlalchemy.engine.reflection import Inspector
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
 | 
					from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
 | 
				
			||||||
@ -42,18 +39,33 @@ from metadata.ingestion.api.models import Either
 | 
				
			|||||||
from metadata.ingestion.api.steps import InvalidSourceException
 | 
					from metadata.ingestion.api.steps import InvalidSourceException
 | 
				
			||||||
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
 | 
					from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
 | 
				
			||||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
 | 
					from metadata.ingestion.ometa.ometa_api import OpenMetadata
 | 
				
			||||||
from metadata.ingestion.source import sqa_types
 | 
					 | 
				
			||||||
from metadata.ingestion.source.database.athena.client import AthenaLakeFormationClient
 | 
					from metadata.ingestion.source.database.athena.client import AthenaLakeFormationClient
 | 
				
			||||||
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
 | 
					from metadata.ingestion.source.database.athena.utils import (
 | 
				
			||||||
 | 
					    _get_column_type,
 | 
				
			||||||
 | 
					    get_columns,
 | 
				
			||||||
 | 
					    get_table_options,
 | 
				
			||||||
 | 
					    get_view_definition,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from metadata.ingestion.source.database.common_db_source import (
 | 
					from metadata.ingestion.source.database.common_db_source import (
 | 
				
			||||||
    CommonDbSourceService,
 | 
					    CommonDbSourceService,
 | 
				
			||||||
    TableNameAndType,
 | 
					    TableNameAndType,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					from metadata.ingestion.source.database.external_table_lineage_mixin import (
 | 
				
			||||||
 | 
					    ExternalTableLineageMixin,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from metadata.utils import fqn
 | 
					from metadata.utils import fqn
 | 
				
			||||||
from metadata.utils.logger import ingestion_logger
 | 
					from metadata.utils.logger import ingestion_logger
 | 
				
			||||||
from metadata.utils.sqlalchemy_utils import is_complex_type
 | 
					from metadata.utils.sqlalchemy_utils import get_all_table_ddls, get_table_ddl
 | 
				
			||||||
from metadata.utils.tag_utils import get_ometa_tag_and_classification
 | 
					from metadata.utils.tag_utils import get_ometa_tag_and_classification
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					AthenaDialect._get_column_type = _get_column_type  # pylint: disable=protected-access
 | 
				
			||||||
 | 
					AthenaDialect.get_columns = get_columns
 | 
				
			||||||
 | 
					AthenaDialect.get_view_definition = get_view_definition
 | 
				
			||||||
 | 
					AthenaDialect.get_table_options = get_table_options
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Inspector.get_all_table_ddls = get_all_table_ddls
 | 
				
			||||||
 | 
					Inspector.get_table_ddl = get_table_ddl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = ingestion_logger()
 | 
					logger = ingestion_logger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ATHENA_TAG = "ATHENA TAG"
 | 
					ATHENA_TAG = "ATHENA TAG"
 | 
				
			||||||
@ -72,165 +84,7 @@ ATHENA_INTERVAL_TYPE_MAP = {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _get_column_type(self, type_):
 | 
					class AthenaSource(ExternalTableLineageMixin, CommonDbSourceService):
 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Function overwritten from AthenaDialect
 | 
					 | 
				
			||||||
    to add custom SQA typing.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    type_ = type_.replace(" ", "").lower()
 | 
					 | 
				
			||||||
    match = self._pattern_column_type.match(type_)  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
    if match:
 | 
					 | 
				
			||||||
        name = match.group(1).lower()
 | 
					 | 
				
			||||||
        length = match.group(2)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        name = type_.lower()
 | 
					 | 
				
			||||||
        length = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    args = []
 | 
					 | 
				
			||||||
    col_map = {
 | 
					 | 
				
			||||||
        "boolean": types.BOOLEAN,
 | 
					 | 
				
			||||||
        "float": types.FLOAT,
 | 
					 | 
				
			||||||
        "double": types.FLOAT,
 | 
					 | 
				
			||||||
        "real": types.FLOAT,
 | 
					 | 
				
			||||||
        "tinyint": types.INTEGER,
 | 
					 | 
				
			||||||
        "smallint": types.INTEGER,
 | 
					 | 
				
			||||||
        "integer": types.INTEGER,
 | 
					 | 
				
			||||||
        "int": types.INTEGER,
 | 
					 | 
				
			||||||
        "bigint": types.BIGINT,
 | 
					 | 
				
			||||||
        "string": types.String,
 | 
					 | 
				
			||||||
        "date": types.DATE,
 | 
					 | 
				
			||||||
        "timestamp": types.TIMESTAMP,
 | 
					 | 
				
			||||||
        "binary": types.BINARY,
 | 
					 | 
				
			||||||
        "varbinary": types.BINARY,
 | 
					 | 
				
			||||||
        "array": types.ARRAY,
 | 
					 | 
				
			||||||
        "json": types.JSON,
 | 
					 | 
				
			||||||
        "struct": sqa_types.SQAStruct,
 | 
					 | 
				
			||||||
        "row": sqa_types.SQAStruct,
 | 
					 | 
				
			||||||
        "map": sqa_types.SQAMap,
 | 
					 | 
				
			||||||
        "decimal": types.DECIMAL,
 | 
					 | 
				
			||||||
        "varchar": types.VARCHAR,
 | 
					 | 
				
			||||||
        "char": types.CHAR,
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    if name in ["decimal", "char", "varchar"]:
 | 
					 | 
				
			||||||
        col_type = col_map[name]
 | 
					 | 
				
			||||||
        if length:
 | 
					 | 
				
			||||||
            args = [int(l) for l in length.split(",")]
 | 
					 | 
				
			||||||
    elif type_.startswith("array"):
 | 
					 | 
				
			||||||
        parsed_type = (
 | 
					 | 
				
			||||||
            ColumnTypeParser._parse_datatype_string(  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
                type_
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        col_type = col_map["array"]
 | 
					 | 
				
			||||||
        if parsed_type["arrayDataType"].lower().startswith("array"):
 | 
					 | 
				
			||||||
            # as OpenMetadata doesn't store any details on children of array, we put
 | 
					 | 
				
			||||||
            # in type as string as default to avoid Array item_type required issue
 | 
					 | 
				
			||||||
            # from sqlalchemy types
 | 
					 | 
				
			||||||
            args = [types.String]
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            args = [col_map.get(parsed_type.get("arrayDataType").lower(), types.String)]
 | 
					 | 
				
			||||||
    elif col_map.get(name):
 | 
					 | 
				
			||||||
        col_type = col_map.get(name)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        logger.warning(f"Did not recognize type '{type_}'")
 | 
					 | 
				
			||||||
        col_type = types.NullType
 | 
					 | 
				
			||||||
    return col_type(*args)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _get_projection_details(
 | 
					 | 
				
			||||||
    columns: List[Dict], projection_parameters: Dict
 | 
					 | 
				
			||||||
) -> List[Dict]:
 | 
					 | 
				
			||||||
    """Get the projection details for the columns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        columns (List[Dict]): list of columns
 | 
					 | 
				
			||||||
        projection_parameters (Dict): projection parameters
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if not projection_parameters:
 | 
					 | 
				
			||||||
        return columns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    columns = deepcopy(columns)
 | 
					 | 
				
			||||||
    for col in columns:
 | 
					 | 
				
			||||||
        projection_details = next(
 | 
					 | 
				
			||||||
            ({k: v} for k, v in projection_parameters.items() if k == col["name"]), None
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        if projection_details:
 | 
					 | 
				
			||||||
            col["projection_type"] = projection_details[col["name"]]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return columns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@reflection.cache
 | 
					 | 
				
			||||||
def get_columns(self, connection, table_name, schema=None, **kw):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Method to handle table columns
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    metadata = self._get_table(  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
        connection, table_name, schema=schema, **kw
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    columns = [
 | 
					 | 
				
			||||||
        {
 | 
					 | 
				
			||||||
            "name": c.name,
 | 
					 | 
				
			||||||
            "type": self._get_column_type(c.type),  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
            "nullable": True,
 | 
					 | 
				
			||||||
            "default": None,
 | 
					 | 
				
			||||||
            "autoincrement": False,
 | 
					 | 
				
			||||||
            "comment": c.comment,
 | 
					 | 
				
			||||||
            "system_data_type": c.type,
 | 
					 | 
				
			||||||
            "is_complex": is_complex_type(c.type),
 | 
					 | 
				
			||||||
            "dialect_options": {"awsathena_partition": True},
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        for c in metadata.partition_keys
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if kw.get("only_partition_columns"):
 | 
					 | 
				
			||||||
        # Return projected partition information to set partition type in `get_table_partition_details`
 | 
					 | 
				
			||||||
        # projected partition fields are stored in the form of `projection.<field_name>.type` as a table parameter
 | 
					 | 
				
			||||||
        projection_parameters = {
 | 
					 | 
				
			||||||
            key_.split(".")[1]: value_
 | 
					 | 
				
			||||||
            for key_, value_ in metadata.parameters.items()
 | 
					 | 
				
			||||||
            if key_.startswith("projection") and key_.endswith("type")
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        columns = _get_projection_details(columns, projection_parameters)
 | 
					 | 
				
			||||||
        return columns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    columns += [
 | 
					 | 
				
			||||||
        {
 | 
					 | 
				
			||||||
            "name": c.name,
 | 
					 | 
				
			||||||
            "type": self._get_column_type(c.type),  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
            "nullable": True,
 | 
					 | 
				
			||||||
            "default": None,
 | 
					 | 
				
			||||||
            "autoincrement": False,
 | 
					 | 
				
			||||||
            "comment": c.comment,
 | 
					 | 
				
			||||||
            "system_data_type": c.type,
 | 
					 | 
				
			||||||
            "is_complex": is_complex_type(c.type),
 | 
					 | 
				
			||||||
            "dialect_options": {"awsathena_partition": None},
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        for c in metadata.columns
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return columns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# pylint: disable=unused-argument
 | 
					 | 
				
			||||||
@reflection.cache
 | 
					 | 
				
			||||||
def get_view_definition(self, connection, view_name, schema=None, **kw):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Gets the view definition
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    full_view_name = f'"{view_name}"' if not schema else f'"{schema}"."{view_name}"'
 | 
					 | 
				
			||||||
    res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall()
 | 
					 | 
				
			||||||
    if res:
 | 
					 | 
				
			||||||
        return "\n".join(i[0] for i in res)
 | 
					 | 
				
			||||||
    return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
AthenaDialect._get_column_type = _get_column_type  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
AthenaDialect.get_columns = get_columns
 | 
					 | 
				
			||||||
AthenaDialect.get_view_definition = get_view_definition
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AthenaSource(CommonDbSourceService):
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Implements the necessary methods to extract
 | 
					    Implements the necessary methods to extract
 | 
				
			||||||
    Database metadata from Athena Source
 | 
					    Database metadata from Athena Source
 | 
				
			||||||
@ -257,6 +111,7 @@ class AthenaSource(CommonDbSourceService):
 | 
				
			|||||||
        self.athena_lake_formation_client = AthenaLakeFormationClient(
 | 
					        self.athena_lake_formation_client = AthenaLakeFormationClient(
 | 
				
			||||||
            connection=self.service_connection
 | 
					            connection=self.service_connection
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        self.external_location_map = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def query_table_names_and_types(
 | 
					    def query_table_names_and_types(
 | 
				
			||||||
        self, schema_name: str
 | 
					        self, schema_name: str
 | 
				
			||||||
@ -395,3 +250,23 @@ class AthenaSource(CommonDbSourceService):
 | 
				
			|||||||
                        stackTrace=traceback.format_exc(),
 | 
					                        stackTrace=traceback.format_exc(),
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_table_description(
					        self, schema_name: str, table_name: str, inspector: Inspector
					    ) -> str:
					        """
					        Return the comment text of a table and remember its storage location.

					        Besides reading the comment, the table options are fetched and the
					        ``awsathena_location`` value is stored in ``self.external_location_map``
					        under the ``(database, schema, table)`` key.

					        Args:
					            schema_name: schema the table belongs to
					            table_name: name of the table
					            inspector: inspector used to read the comment and the options

					        Returns:
					            The ``text`` entry of the table comment, or ``None`` when either
					            lookup raised.
					        """
					        try:
					            comment_info: dict = inspector.get_table_comment(table_name, schema_name)
					            options = inspector.get_table_options(table_name, schema_name)
					            location = options.get("awsathena_location")
					            location_key = (self.context.get().database, schema_name, table_name)
					            self.external_location_map[location_key] = location
					        # Catch any exception without breaking the ingestion
					        except Exception as exc:  # pylint: disable=broad-except
					            logger.debug(traceback.format_exc())
					            logger.warning(
					                f"Table description error for table [{schema_name}.{table_name}]: {exc}"
					            )
					        else:
					            # Only reached when both the comment and the options were read
					            return comment_info.get("text")
					        return None
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										191
									
								
								ingestion/src/metadata/ingestion/source/database/athena/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								ingestion/src/metadata/ingestion/source/database/athena/utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,191 @@
 | 
				
			|||||||
 | 
					#  Copyright 2021 Collate
 | 
				
			||||||
 | 
					#  Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					#  you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					#  You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#  Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					#  distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					#  See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					#  limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""Athena utils module"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from copy import deepcopy
 | 
				
			||||||
 | 
					from typing import Dict, List, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pyathena.sqlalchemy.util import _HashableDict
 | 
				
			||||||
 | 
					from sqlalchemy import types
 | 
				
			||||||
 | 
					from sqlalchemy.engine import reflection
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from metadata.ingestion.source import sqa_types
 | 
				
			||||||
 | 
					from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
 | 
				
			||||||
 | 
					from metadata.utils.sqlalchemy_utils import is_complex_type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@reflection.cache
					def _get_column_type(self, type_):
					    """
					    Function overwritten from AthenaDialect
					    to add custom SQA typing.

					    Args:
					        type_: raw column type string; spaces are removed and the value is
					            lower-cased before it is matched.

					    Returns:
					        An instance of the mapped SQLAlchemy / custom SQA type. Types that
					        are not in the mapping are logged and returned as ``types.NullType``.
					    """
					    # NOTE(review): this helper receives no connection / ``info_cache``
					    # argument, so ``reflection.cache`` presumably never caches anything
					    # here -- confirm and consider dropping the decorator.
					    type_ = type_.replace(" ", "").lower()
					    match = self._pattern_column_type.match(type_)  # pylint: disable=protected-access
					    if match:
					        name = match.group(1).lower()
					        length = match.group(2)
					    else:
					        name = type_.lower()
					        length = None

					    args = []
					    col_map = {
					        "boolean": types.BOOLEAN,
					        "float": types.FLOAT,
					        "double": types.FLOAT,
					        "real": types.FLOAT,
					        "tinyint": types.INTEGER,
					        "smallint": types.INTEGER,
					        "integer": types.INTEGER,
					        "int": types.INTEGER,
					        "bigint": types.BIGINT,
					        "string": types.String,
					        "date": types.DATE,
					        "timestamp": types.TIMESTAMP,
					        "binary": types.BINARY,
					        "varbinary": types.BINARY,
					        "array": types.ARRAY,
					        "json": types.JSON,
					        "struct": sqa_types.SQAStruct,
					        "row": sqa_types.SQAStruct,
					        "map": sqa_types.SQAMap,
					        "decimal": types.DECIMAL,
					        "varchar": types.VARCHAR,
					        "char": types.CHAR,
					    }
					    if name in ["decimal", "char", "varchar"]:
					        col_type = col_map[name]
					        if length:
					            # e.g. "10,2" -> [10, 2] (precision/scale or length)
					            args = [int(part) for part in length.split(",")]
					    elif type_.startswith("array"):
					        parsed_type = (
					            ColumnTypeParser._parse_datatype_string(  # pylint: disable=protected-access
					                type_
					            )
					        )
					        col_type = col_map["array"]
					        # The parser result may carry no item type; fall back to an empty
					        # name (-> String) instead of raising KeyError / AttributeError.
					        item_type = (parsed_type.get("arrayDataType") or "").lower()
					        if item_type.startswith("array"):
					            # as OpenMetadata doesn't store any details on children of array, we put
					            # in type as string as default to avoid Array item_type required issue
					            # from sqlalchemy types
					            args = [types.String]
					        else:
					            args = [col_map.get(item_type, types.String)]
					    elif col_map.get(name):
					        col_type = col_map.get(name)
					    else:
					        # This module defines no module-level ``logger``; fetch the ingestion
					        # logger here so an unknown type is reported instead of raising
					        # NameError.
					        # pylint: disable=import-outside-toplevel
					        from metadata.utils.logger import ingestion_logger

					        ingestion_logger().warning(f"Did not recognize type '{type_}'")
					        col_type = types.NullType
					    return col_type(*args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# pylint: disable=unused-argument
 | 
				
			||||||
 | 
					def _get_projection_details(
 | 
				
			||||||
 | 
					    columns: List[Dict], projection_parameters: Dict
 | 
				
			||||||
 | 
					) -> List[Dict]:
 | 
				
			||||||
 | 
					    """Get the projection details for the columns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        columns (List[Dict]): list of columns
 | 
				
			||||||
 | 
					        projection_parameters (Dict): projection parameters
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if not projection_parameters:
 | 
				
			||||||
 | 
					        return columns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    columns = deepcopy(columns)
 | 
				
			||||||
 | 
					    for col in columns:
 | 
				
			||||||
 | 
					        projection_details = next(
 | 
				
			||||||
 | 
					            ({k: v} for k, v in projection_parameters.items() if k == col["name"]), None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if projection_details:
 | 
				
			||||||
 | 
					            col["projection_type"] = projection_details[col["name"]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return columns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@reflection.cache
					def get_columns(self, connection, table_name, schema=None, **kw):
					    """
					    Method to handle table columns

					    Partition keys are listed first (flagged with ``awsathena_partition``
					    set to ``True``), followed by the regular columns (flag ``None``).
					    When ``only_partition_columns`` is passed in ``kw``, only the partition
					    columns are returned, enriched with their projection type.
					    """
					    table_metadata = self._get_table(  # pylint: disable=protected-access
					        connection, table_name, schema=schema, **kw
					    )
					    partition_columns = [
					        _athena_column_entry(self, key, True)
					        for key in table_metadata.partition_keys
					    ]

					    if kw.get("only_partition_columns"):
					        # Return projected partition information to set partition type in `get_table_partition_details`
					        # projected partition fields are stored in the form of `projection.<field_name>.type` as a table parameter
					        projection_types = {}
					        for param_name, param_value in table_metadata.parameters.items():
					            if param_name.startswith("projection") and param_name.endswith("type"):
					                projection_types[param_name.split(".")[1]] = param_value
					        return _get_projection_details(partition_columns, projection_types)

					    regular_columns = [
					        _athena_column_entry(self, column, None)
					        for column in table_metadata.columns
					    ]
					    return partition_columns + regular_columns


					def _athena_column_entry(dialect, column, partition_flag):
					    """Build the reflection dict describing a single column."""
					    return {
					        "name": column.name,
					        "type": dialect._get_column_type(  # pylint: disable=protected-access
					            column.type
					        ),
					        "nullable": True,
					        "default": None,
					        "autoincrement": False,
					        "comment": column.comment,
					        "system_data_type": column.type,
					        "is_complex": is_complex_type(column.type),
					        "dialect_options": {"awsathena_partition": partition_flag},
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@reflection.cache
					def get_view_definition(self, connection, view_name, schema=None, **kw):
					    """
					    Gets the view definition

					    Runs ``SHOW CREATE VIEW`` for the (optionally schema-qualified) view and
					    joins the first column of every returned row with newlines. Returns
					    ``None`` when the statement yields no rows.
					    """
					    if schema:
					        qualified_name = f'"{schema}"."{view_name}"'
					    else:
					        qualified_name = f'"{view_name}"'

					    rows = connection.execute(f"SHOW CREATE VIEW {qualified_name}").fetchall()
					    if not rows:
					        return None
					    return "\n".join(row[0] for row in rows)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_table_options(
					    self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw
					):
					    """
					    Return the table-level options of a table.

					    The values are read from the table metadata returned by
					    ``self._get_table``; the serde and table properties are wrapped in
					    ``_HashableDict``.
					    """
					    table_metadata = self._get_table(connection, table_name, schema=schema, **kw)

					    location = table_metadata.location
					    compression = table_metadata.compression
					    row_format = table_metadata.row_format
					    file_format = table_metadata.file_format
					    serde_properties = _HashableDict(table_metadata.serde_properties)
					    table_properties = _HashableDict(table_metadata.table_properties)

					    return {
					        "awsathena_location": location,
					        "awsathena_compression": compression,
					        "awsathena_row_format": row_format,
					        "awsathena_file_format": file_format,
					        "awsathena_serdeproperties": serde_properties,
					        "awsathena_tblproperties": table_properties,
					    }
 | 
				
			||||||
@ -14,13 +14,20 @@ External Table Lineage Mixin
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
from abc import ABC
 | 
					from abc import ABC
 | 
				
			||||||
from typing import Iterable
 | 
					from typing import Iterable, List, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
 | 
					from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
 | 
				
			||||||
 | 
					from metadata.generated.schema.entity.data.container import ContainerDataModel
 | 
				
			||||||
from metadata.generated.schema.entity.data.table import Table
 | 
					from metadata.generated.schema.entity.data.table import Table
 | 
				
			||||||
from metadata.generated.schema.type.entityLineage import EntitiesEdge
 | 
					from metadata.generated.schema.type.entityLineage import (
 | 
				
			||||||
 | 
					    ColumnLineage,
 | 
				
			||||||
 | 
					    EntitiesEdge,
 | 
				
			||||||
 | 
					    LineageDetails,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from metadata.generated.schema.type.entityLineage import Source as LineageSource
 | 
				
			||||||
from metadata.generated.schema.type.entityReference import EntityReference
 | 
					from metadata.generated.schema.type.entityReference import EntityReference
 | 
				
			||||||
from metadata.ingestion.api.models import Either
 | 
					from metadata.ingestion.api.models import Either
 | 
				
			||||||
 | 
					from metadata.ingestion.lineage.sql_lineage import get_column_fqn
 | 
				
			||||||
from metadata.utils import fqn
 | 
					from metadata.utils import fqn
 | 
				
			||||||
from metadata.utils.logger import ingestion_logger
 | 
					from metadata.utils.logger import ingestion_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -39,7 +46,7 @@ class ExternalTableLineageMixin(ABC):
 | 
				
			|||||||
        for table_qualified_tuple, location in self.external_location_map.items() or []:
 | 
					        for table_qualified_tuple, location in self.external_location_map.items() or []:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                location_entity = self.metadata.es_search_container_by_path(
 | 
					                location_entity = self.metadata.es_search_container_by_path(
 | 
				
			||||||
                    full_path=location
 | 
					                    full_path=location, fields="dataModel"
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                database_name, schema_name, table_name = table_qualified_tuple
 | 
					                database_name, schema_name, table_name = table_qualified_tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -63,6 +70,12 @@ class ExternalTableLineageMixin(ABC):
 | 
				
			|||||||
                    and table_entity
 | 
					                    and table_entity
 | 
				
			||||||
                    and table_entity[0]
 | 
					                    and table_entity[0]
 | 
				
			||||||
                ):
 | 
					                ):
 | 
				
			||||||
 | 
					                    columns_list = [
 | 
				
			||||||
 | 
					                        column.name.root for column in table_entity[0].columns
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					                    columns_lineage = self._get_column_lineage(
 | 
				
			||||||
 | 
					                        location_entity[0].dataModel, table_entity[0], columns_list
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
                    yield Either(
 | 
					                    yield Either(
 | 
				
			||||||
                        right=AddLineageRequest(
 | 
					                        right=AddLineageRequest(
 | 
				
			||||||
                            edge=EntitiesEdge(
 | 
					                            edge=EntitiesEdge(
 | 
				
			||||||
@ -74,9 +87,51 @@ class ExternalTableLineageMixin(ABC):
 | 
				
			|||||||
                                    id=table_entity[0].id,
 | 
					                                    id=table_entity[0].id,
 | 
				
			||||||
                                    type="table",
 | 
					                                    type="table",
 | 
				
			||||||
                                ),
 | 
					                                ),
 | 
				
			||||||
 | 
					                                lineageDetails=LineageDetails(
 | 
				
			||||||
 | 
					                                    source=LineageSource.ExternalTableLineage,
 | 
				
			||||||
 | 
					                                    columnsLineage=columns_lineage,
 | 
				
			||||||
 | 
					                                ),
 | 
				
			||||||
                            )
 | 
					                            )
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
            except Exception as exc:
 | 
					            except Exception as exc:
 | 
				
			||||||
                logger.warning(f"Failed to yield external table lineage due to - {exc}")
 | 
					                logger.warning(f"Failed to yield external table lineage due to - {exc}")
 | 
				
			||||||
                logger.debug(traceback.format_exc())
 | 
					                logger.debug(traceback.format_exc())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _get_data_model_column_fqn(
 | 
				
			||||||
 | 
					        self, data_model_entity: ContainerDataModel, column: str
 | 
				
			||||||
 | 
					    ) -> Optional[str]:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Get fqn of column if exist in data model entity
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not data_model_entity:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        for entity_column in data_model_entity.columns:
 | 
				
			||||||
 | 
					            if entity_column.displayName.lower() == column.lower():
 | 
				
			||||||
 | 
					                return entity_column.fullyQualifiedName.root
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _get_column_lineage(
					        self,
					        data_model_entity: ContainerDataModel,
					        table_entity: Table,
					        columns_list: List[str],
					    ) -> List[ColumnLineage]:
					        """
					        Build column-level lineage from a container data model to a table.

					        For every name in `columns_list`, the source column FQN is looked up
					        in `data_model_entity` and the target column FQN in `table_entity`.
					        A ColumnLineage entry is produced only when both sides resolve.

					        Args:
					            data_model_entity: data model of the container the table reads from.
					            table_entity: table entity receiving the lineage.
					            columns_list: column names to map; None or empty yields no lineage.

					        Returns:
					            The list of resolved ColumnLineage entries (possibly empty). A column
					            whose lookup raises is logged at debug level and skipped.
					        """
					        column_lineage = []
					        for field in columns_list or []:
					            # Guard each column on its own so one failing lookup neither drops
					            # the entries already collected nor makes the method return None
					            try:
					                from_column = self._get_data_model_column_fqn(
					                    data_model_entity=data_model_entity, column=field
					                )
					                to_column = get_column_fqn(table_entity=table_entity, column=field)
					                if from_column and to_column:
					                    column_lineage.append(
					                        ColumnLineage(fromColumns=[from_column], toColumn=to_column)
					                    )
					            except Exception as exc:
					                logger.debug(f"Error to get column lineage: {exc}")
					                logger.debug(traceback.format_exc())
					        return column_lineage
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										325
									
								
								ingestion/tests/unit/topology/database/test_athena.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										325
									
								
								ingestion/tests/unit/topology/database/test_athena.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,325 @@
 | 
				
			|||||||
 | 
					#  Copyright 2021 Collate
 | 
				
			||||||
 | 
					#  Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					#  you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					#  You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#  Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					#  distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					#  See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					#  limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					Test athena source
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					from unittest.mock import patch
 | 
				
			||||||
 | 
					from uuid import UUID
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pydantic import AnyUrl
 | 
				
			||||||
 | 
					from sqlalchemy.engine.reflection import Inspector
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
 | 
				
			||||||
 | 
					from metadata.generated.schema.entity.data.container import (
 | 
				
			||||||
 | 
					    Container,
 | 
				
			||||||
 | 
					    ContainerDataModel,
 | 
				
			||||||
 | 
					    FileFormat,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from metadata.generated.schema.entity.data.database import Database
 | 
				
			||||||
 | 
					from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
 | 
				
			||||||
 | 
					from metadata.generated.schema.entity.data.table import (
 | 
				
			||||||
 | 
					    Column,
 | 
				
			||||||
 | 
					    ColumnName,
 | 
				
			||||||
 | 
					    Constraint,
 | 
				
			||||||
 | 
					    DataType,
 | 
				
			||||||
 | 
					    Table,
 | 
				
			||||||
 | 
					    TableType,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from metadata.generated.schema.entity.services.databaseService import (
 | 
				
			||||||
 | 
					    DatabaseConnection,
 | 
				
			||||||
 | 
					    DatabaseService,
 | 
				
			||||||
 | 
					    DatabaseServiceType,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from metadata.generated.schema.entity.services.storageService import StorageServiceType
 | 
				
			||||||
 | 
					from metadata.generated.schema.metadataIngestion.workflow import (
 | 
				
			||||||
 | 
					    OpenMetadataWorkflowConfig,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from metadata.generated.schema.type.basic import (
 | 
				
			||||||
 | 
					    EntityName,
 | 
				
			||||||
 | 
					    FullyQualifiedEntityName,
 | 
				
			||||||
 | 
					    Href,
 | 
				
			||||||
 | 
					    Markdown,
 | 
				
			||||||
 | 
					    SourceUrl,
 | 
				
			||||||
 | 
					    Timestamp,
 | 
				
			||||||
 | 
					    Uuid,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from metadata.generated.schema.type.entityLineage import ColumnLineage
 | 
				
			||||||
 | 
					from metadata.generated.schema.type.entityReference import EntityReference
 | 
				
			||||||
 | 
					from metadata.ingestion.api.models import Either
 | 
				
			||||||
 | 
					from metadata.ingestion.source.database.athena.metadata import AthenaSource
 | 
				
			||||||
 | 
					from metadata.ingestion.source.database.common_db_source import TableNameAndType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					EXPECTED_DATABASE_NAMES = ["mydatabase"]
 | 
				
			||||||
 | 
					MOCK_DATABASE_SCHEMA = DatabaseSchema(
 | 
				
			||||||
 | 
					    id="2aaa012e-099a-11ed-861d-0242ac120056",
 | 
				
			||||||
 | 
					    name="sample_instance",
 | 
				
			||||||
 | 
					    fullyQualifiedName="sample_athena_schema.sample_db.sample_instance",
 | 
				
			||||||
 | 
					    displayName="default",
 | 
				
			||||||
 | 
					    description="",
 | 
				
			||||||
 | 
					    database=EntityReference(
 | 
				
			||||||
 | 
					        id="2aaa012e-099a-11ed-861d-0242ac120002",
 | 
				
			||||||
 | 
					        type="database",
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    service=EntityReference(
 | 
				
			||||||
 | 
					        id="85811038-099a-11ed-861d-0242ac120002",
 | 
				
			||||||
 | 
					        type="databaseService",
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					MOCK_DATABASE_SERVICE = DatabaseService(
 | 
				
			||||||
 | 
					    id="85811038-099a-11ed-861d-0242ac120002",
 | 
				
			||||||
 | 
					    name="sample_athena_service",
 | 
				
			||||||
 | 
					    connection=DatabaseConnection(),
 | 
				
			||||||
 | 
					    serviceType=DatabaseServiceType.Glue,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					MOCK_DATABASE = Database(
 | 
				
			||||||
 | 
					    id="2aaa012e-099a-11ed-861d-0242ac120002",
 | 
				
			||||||
 | 
					    name="sample_db",
 | 
				
			||||||
 | 
					    fullyQualifiedName="test_athena.sample_db",
 | 
				
			||||||
 | 
					    displayName="sample_db",
 | 
				
			||||||
 | 
					    description="",
 | 
				
			||||||
 | 
					    service=EntityReference(
 | 
				
			||||||
 | 
					        id="85811038-099a-11ed-861d-0242ac120002",
 | 
				
			||||||
 | 
					        type="databaseService",
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					MOCK_TABLE_NAME = "sample_table"
 | 
				
			||||||
 | 
					EXPECTED_DATABASES = [
 | 
				
			||||||
 | 
					    Either(
 | 
				
			||||||
 | 
					        right=CreateDatabaseRequest(
 | 
				
			||||||
 | 
					            name=EntityName("sample_db"),
 | 
				
			||||||
 | 
					            service=FullyQualifiedEntityName("sample_athena_service"),
 | 
				
			||||||
 | 
					            default=False,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					EXPECTED_QUERY_TABLE_NAMES_TYPES = [
 | 
				
			||||||
 | 
					    TableNameAndType(name="sample_table", type_=TableType.External)
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					MOCK_LOCATION_ENTITY = [
 | 
				
			||||||
 | 
					    Container(
 | 
				
			||||||
 | 
					        id=Uuid(UUID("9c489754-bb60-435b-b2a5-0e43100cf950")),
 | 
				
			||||||
 | 
					        name=EntityName("dbt-testing/mayur/customers.csv"),
 | 
				
			||||||
 | 
					        fullyQualifiedName=FullyQualifiedEntityName(
 | 
				
			||||||
 | 
					            's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv"'
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        updatedAt=Timestamp(1717070902713),
 | 
				
			||||||
 | 
					        updatedBy="admin",
 | 
				
			||||||
 | 
					        href=Href(
 | 
				
			||||||
 | 
					            root=AnyUrl(
 | 
				
			||||||
 | 
					                "http://localhost:8585/api/v1/containers/9c489754-bb60-435b-b2a5-0e43100cf950",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        service=EntityReference(
 | 
				
			||||||
 | 
					            id=Uuid(UUID("dd91cca3-cc54-4776-9efa-48f845cdfb92")),
 | 
				
			||||||
 | 
					            type="storageService",
 | 
				
			||||||
 | 
					            name="s3_local",
 | 
				
			||||||
 | 
					            fullyQualifiedName="s3_local",
 | 
				
			||||||
 | 
					            description=Markdown(""),
 | 
				
			||||||
 | 
					            displayName="s3_local",
 | 
				
			||||||
 | 
					            deleted=False,
 | 
				
			||||||
 | 
					            href=Href(
 | 
				
			||||||
 | 
					                root=AnyUrl(
 | 
				
			||||||
 | 
					                    "http://localhost:8585/api/v1/services/storageServices/dd91cca3-cc54-4776-9efa-48f845cdfb92",
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        dataModel=ContainerDataModel(
 | 
				
			||||||
 | 
					            isPartitioned=False,
 | 
				
			||||||
 | 
					            columns=[
 | 
				
			||||||
 | 
					                Column(
 | 
				
			||||||
 | 
					                    name=ColumnName("CUSTOMERID"),
 | 
				
			||||||
 | 
					                    displayName="CUSTOMERID",
 | 
				
			||||||
 | 
					                    dataType=DataType.INT,
 | 
				
			||||||
 | 
					                    dataTypeDisplay="INT",
 | 
				
			||||||
 | 
					                    fullyQualifiedName=FullyQualifiedEntityName(
 | 
				
			||||||
 | 
					                        's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID'
 | 
				
			||||||
 | 
					                    ),
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        prefix="/dbt-testing/mayur/customers.csv",
 | 
				
			||||||
 | 
					        numberOfObjects=2103.0,
 | 
				
			||||||
 | 
					        size=652260394.0,
 | 
				
			||||||
 | 
					        fileFormats=[FileFormat.csv],
 | 
				
			||||||
 | 
					        serviceType=StorageServiceType.S3,
 | 
				
			||||||
 | 
					        deleted=False,
 | 
				
			||||||
 | 
					        sourceUrl=SourceUrl(
 | 
				
			||||||
 | 
					            "https://s3.console.aws.amazon.com/s3/buckets/awsdatalake-testing?region=us-east-2&prefix=dbt-testing/mayur/customers.csv/&showversions=false"
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        fullPath="s3://awsdatalake-testing/dbt-testing/mayur/customers.csv",
 | 
				
			||||||
 | 
					        sourceHash="22b1c2f2e7feeaa8f37c6649e01f027d",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MOCK_TABLE_ENTITY = [
 | 
				
			||||||
 | 
					    Table(
 | 
				
			||||||
 | 
					        id=Uuid(UUID("2c040cf8-432d-4597-9517-4794d6142da3")),
 | 
				
			||||||
 | 
					        name=EntityName("demo_data_ext_tbl3"),
 | 
				
			||||||
 | 
					        fullyQualifiedName=FullyQualifiedEntityName(
 | 
				
			||||||
 | 
					            "local_athena.demo.default.demo_data_ext_tbl3"
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        updatedAt=Timestamp(1717071974350),
 | 
				
			||||||
 | 
					        updatedBy="admin",
 | 
				
			||||||
 | 
					        href=Href(
 | 
				
			||||||
 | 
					            root=AnyUrl(
 | 
				
			||||||
 | 
					                "http://localhost:8585/api/v1/tables/2c040cf8-432d-4597-9517-4794d6142da3",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        tableType=TableType.Regular,
 | 
				
			||||||
 | 
					        columns=[
 | 
				
			||||||
 | 
					            Column(
 | 
				
			||||||
 | 
					                name=ColumnName("CUSTOMERID"),
 | 
				
			||||||
 | 
					                dataType=DataType.INT,
 | 
				
			||||||
 | 
					                dataLength=1,
 | 
				
			||||||
 | 
					                dataTypeDisplay="int",
 | 
				
			||||||
 | 
					                fullyQualifiedName=FullyQualifiedEntityName(
 | 
				
			||||||
 | 
					                    "local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID"
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                constraint=Constraint.NULL,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        databaseSchema=EntityReference(
 | 
				
			||||||
 | 
					            id=Uuid(UUID("b03b0229-8a9f-497a-a675-74cb24a9be74")),
 | 
				
			||||||
 | 
					            type="databaseSchema",
 | 
				
			||||||
 | 
					            name="default",
 | 
				
			||||||
 | 
					            fullyQualifiedName="local_athena.demo.default",
 | 
				
			||||||
 | 
					            displayName="default",
 | 
				
			||||||
 | 
					            deleted=False,
 | 
				
			||||||
 | 
					            href=Href(
 | 
				
			||||||
 | 
					                root=AnyUrl(
 | 
				
			||||||
 | 
					                    "http://localhost:8585/api/v1/databaseSchemas/b03b0229-8a9f-497a-a675-74cb24a9be74",
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        database=EntityReference(
 | 
				
			||||||
 | 
					            id=Uuid(UUID("f054c55c-34bf-4c5f-addd-5cc26c7c832a")),
 | 
				
			||||||
 | 
					            type="database",
 | 
				
			||||||
 | 
					            name="demo",
 | 
				
			||||||
 | 
					            fullyQualifiedName="local_athena.demo",
 | 
				
			||||||
 | 
					            displayName="demo",
 | 
				
			||||||
 | 
					            deleted=False,
 | 
				
			||||||
 | 
					            href=Href(
 | 
				
			||||||
 | 
					                root=AnyUrl(
 | 
				
			||||||
 | 
					                    "http://localhost:8585/api/v1/databases/f054c55c-34bf-4c5f-addd-5cc26c7c832a",
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        service=EntityReference(
 | 
				
			||||||
 | 
					            id=Uuid(UUID("5e98afd3-7257-4c35-a560-f4c25b0f4b97")),
 | 
				
			||||||
 | 
					            type="databaseService",
 | 
				
			||||||
 | 
					            name="local_athena",
 | 
				
			||||||
 | 
					            fullyQualifiedName="local_athena",
 | 
				
			||||||
 | 
					            displayName="local_athena",
 | 
				
			||||||
 | 
					            deleted=False,
 | 
				
			||||||
 | 
					            href=Href(
 | 
				
			||||||
 | 
					                root=AnyUrl(
 | 
				
			||||||
 | 
					                    "http://localhost:8585/api/v1/services/databaseServices/5e98afd3-7257-4c35-a560-f4c25b0f4b97",
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        serviceType=DatabaseServiceType.Athena,
 | 
				
			||||||
 | 
					        deleted=False,
 | 
				
			||||||
 | 
					        sourceHash="824e80b1c79b0c4ae0acd99d2338e149",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					EXPECTED_COLUMN_LINEAGE = [
 | 
				
			||||||
 | 
					    ColumnLineage(
 | 
				
			||||||
 | 
					        fromColumns=[
 | 
				
			||||||
 | 
					            FullyQualifiedEntityName(
 | 
				
			||||||
 | 
					                's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID'
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        toColumn=FullyQualifiedEntityName(
 | 
				
			||||||
 | 
					            "local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID"
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					mock_athena_config = {
 | 
				
			||||||
 | 
					    "source": {
 | 
				
			||||||
 | 
					        "type": "Athena",
 | 
				
			||||||
 | 
					        "serviceName": "test_athena",
 | 
				
			||||||
 | 
					        "serviceConnection": {
 | 
				
			||||||
 | 
					            "config": {
 | 
				
			||||||
 | 
					                "type": "Athena",
 | 
				
			||||||
 | 
					                "databaseName": "mydatabase",
 | 
				
			||||||
 | 
					                "awsConfig": {
 | 
				
			||||||
 | 
					                    "awsAccessKeyId": "dummy",
 | 
				
			||||||
 | 
					                    "awsSecretAccessKey": "dummy",
 | 
				
			||||||
 | 
					                    "awsRegion": "us-east-2",
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                "s3StagingDir": "https://s3-directory-for-datasource.com",
 | 
				
			||||||
 | 
					                "workgroup": "workgroup name",
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        "sourceConfig": {"config": {"type": "DatabaseMetadata"}},
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    "sink": {"type": "metadata-rest", "config": {}},
 | 
				
			||||||
 | 
					    "workflowConfig": {
 | 
				
			||||||
 | 
					        "openMetadataServerConfig": {
 | 
				
			||||||
 | 
					            "hostPort": "http://localhost:8585/api",
 | 
				
			||||||
 | 
					            "authProvider": "openmetadata",
 | 
				
			||||||
 | 
					            "securityConfig": {
 | 
				
			||||||
 | 
					                "jwtToken": "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestAthenaService(unittest.TestCase):
					    """
					    Unit tests for AthenaSource created from `mock_athena_config`
					    """

					    @patch(
					        "metadata.ingestion.source.database.database_service.DatabaseServiceSource.test_connection"
					    )
					    def __init__(self, methodName, test_connection) -> None:
					        super().__init__(methodName)
					        # test_connection is patched, so creating the source below does not
					        # run the real connection test
					        test_connection.return_value = False
					        # model_validate is the pydantic v2 replacement for the deprecated parse_obj
					        self.config = OpenMetadataWorkflowConfig.model_validate(mock_athena_config)
					        self.athena_source = AthenaSource.create(
					            mock_athena_config["source"],
					            self.config.workflowConfig.openMetadataServerConfig,
					        )
					        # Seed the source context with the mock names the tests rely on
					        self.athena_source.context.get().__dict__[
					            "database_schema"
					        ] = MOCK_DATABASE_SCHEMA.name.root
					        self.athena_source.context.get().__dict__[
					            "database_service"
					        ] = MOCK_DATABASE_SERVICE.name.root
					        self.athena_source.context.get().__dict__["database"] = MOCK_DATABASE.name.root

					    def test_get_database_name(self):
					        """Database names come from the configured databaseName."""
					        self.assertEqual(
					            list(self.athena_source.get_database_names()), EXPECTED_DATABASE_NAMES
					        )

					    def test_query_table_names_and_types(self):
					        """Tables reported by the inspector are returned as external tables."""
					        with patch.object(Inspector, "get_table_names", return_value=[MOCK_TABLE_NAME]):
					            self.assertEqual(
					                self.athena_source.query_table_names_and_types(
					                    MOCK_DATABASE_SCHEMA.name.root
					                ),
					                EXPECTED_QUERY_TABLE_NAMES_TYPES,
					            )

					    def test_yield_database(self):
					        """yield_database emits the expected CreateDatabaseRequest."""
					        self.assertEqual(
					            list(
					                self.athena_source.yield_database(database_name=MOCK_DATABASE.name.root)
					            ),
					            EXPECTED_DATABASES,
					        )

					    def test_column_lineage(self):
					        """Container data model columns are mapped onto the table columns."""
					        columns_list = [column.name.root for column in MOCK_TABLE_ENTITY[0].columns]
					        column_lineage = self.athena_source._get_column_lineage(
					            MOCK_LOCATION_ENTITY[0].dataModel, MOCK_TABLE_ENTITY[0], columns_list
					        )
					        self.assertEqual(column_lineage, EXPECTED_COLUMN_LINEAGE)
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user