diff --git a/ingestion/src/metadata/ingestion/source/database/athena.py b/ingestion/src/metadata/ingestion/source/database/athena.py index f6324fc67d6..bf5a5229934 100644 --- a/ingestion/src/metadata/ingestion/source/database/athena.py +++ b/ingestion/src/metadata/ingestion/source/database/athena.py @@ -15,6 +15,7 @@ from typing import Iterable, Optional, Tuple from pyathena.sqlalchemy_athena import AthenaDialect from sqlalchemy import types +from sqlalchemy.engine import reflection from metadata.generated.schema.entity.data.table import Table, TableType from metadata.generated.schema.entity.services.connections.database.athenaConnection import ( @@ -28,6 +29,7 @@ from metadata.generated.schema.metadataIngestion.workflow import ( ) from metadata.ingestion.api.source import InvalidSourceException from metadata.ingestion.source import sqa_types +from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser from metadata.ingestion.source.database.common_db_source import CommonDbSourceService from metadata.utils import fqn from metadata.utils.filters import filter_by_table @@ -41,6 +43,7 @@ def _get_column_type(self, type_): Function overwritten from AthenaDialect to add custom SQA typing. """ + type_ = type_.replace(" ", "").lower() match = self._pattern_column_type.match(type_) # pylint: disable=protected-access if match: name = match.group(1).lower() @@ -84,6 +87,14 @@ def _get_column_type(self, type_): col_type = types.VARCHAR if length: args = [int(length)] + elif type_.startswith("array"): + parsed_type = ( + ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access + type_ + ) + ) + col_type = col_map["array"] + args = [col_map.get(parsed_type.get("arrayDataType").lower(), types.String)] elif col_map.get(name): col_type = col_map.get(name) else: @@ -92,7 +103,54 @@ def _get_column_type(self, type_): return col_type(*args) +def is_complex(type_: str): + return ( + type_.startswith("array") + or type_.startswith("map") + or type_.startswith("struct") + or type_.startswith("row") + ) + + +@reflection.cache +def get_columns(self, connection, table_name, schema=None, **kw): + """ + Method to handle table columns + """ + metadata = self._get_table( # pylint: disable=protected-access + connection, table_name, schema=schema, **kw + ) + columns = [ + { + "name": c.name, + "type": self._get_column_type(c.type), # pylint: disable=protected-access + "nullable": True, + "default": None, + "autoincrement": False, + "comment": c.comment, + "raw_data_type": c.type if is_complex(c.type) else None, + "dialect_options": {"awsathena_partition": None}, + } + for c in metadata.columns + ] + columns += [ + { + "name": c.name, + "type": self._get_column_type(c.type), # pylint: disable=protected-access + "nullable": True, + "default": None, + "autoincrement": False, + "comment": c.comment, + "raw_data_type": c.type if is_complex(c.type) else None, + "dialect_options": {"awsathena_partition": True}, + } + for c in metadata.partition_keys + ] + return columns + + AthenaDialect._get_column_type = _get_column_type # pylint: disable=protected-access +AthenaDialect.get_columns = get_columns class AthenaSource(CommonDbSourceService):