diff --git a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py index 86361862efd..255d5ca0649 100644 --- a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py @@ -11,13 +11,18 @@ """Athena source module""" -from typing import Iterable +from typing import Iterable, Tuple from pyathena.sqlalchemy.base import AthenaDialect from sqlalchemy import types from sqlalchemy.engine import reflection +from sqlalchemy.engine.reflection import Inspector -from metadata.generated.schema.entity.data.table import TableType +from metadata.generated.schema.entity.data.table import ( + IntervalType, + TablePartition, + TableType, +) from metadata.generated.schema.entity.services.connections.database.athenaConnection import ( AthenaConnection, ) @@ -121,10 +126,14 @@ def get_columns(self, connection, table_name, schema=None, **kw): "comment": c.comment, "system_data_type": c.type, "is_complex": is_complex_type(c.type), - "dialect_options": {"awsathena_partition": None}, + "dialect_options": {"awsathena_partition": True}, } - for c in metadata.columns + for c in metadata.partition_keys ] + + if kw.get("only_partition_columns"): + return columns + columns += [ { "name": c.name, @@ -135,10 +144,11 @@ def get_columns(self, connection, table_name, schema=None, **kw): "comment": c.comment, "system_data_type": c.type, "is_complex": is_complex_type(c.type), - "dialect_options": {"awsathena_partition": True}, + "dialect_options": {"awsathena_partition": None}, } - for c in metadata.partition_keys + for c in metadata.columns ] + return columns @@ -185,3 +195,17 @@ class AthenaSource(CommonDbSourceService): TableNameAndType(name=name, type_=TableType.External) for name in self.inspector.get_table_names(schema_name) ] + + def get_table_partition_details( + self, table_name: str, schema_name: str, inspector: Inspector + ) -> Tuple[bool, TablePartition]: + columns = inspector.get_columns( + table_name=table_name, schema=schema_name, only_partition_columns=True + ) + if columns: + partition_details = TablePartition( + intervalType=IntervalType.COLUMN_VALUE.value, + columns=[column["name"] for column in columns], + ) + return True, partition_details + return False, None