diff --git a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py index 847c94beb79..f0240af398a 100644 --- a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py @@ -17,6 +17,7 @@ from typing import Iterable, Optional, Tuple from pyathena.sqlalchemy.base import AthenaDialect from sqlalchemy.engine.reflection import Inspector +from metadata.clients.aws_client import AWSClient from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema from metadata.generated.schema.entity.data.table import ( Column, @@ -53,6 +54,7 @@ from metadata.ingestion.source.database.common_db_source import ( from metadata.ingestion.source.database.external_table_lineage_mixin import ( ExternalTableLineageMixin, ) +from metadata.ingestion.source.database.glue.models import DatabasePage from metadata.utils import fqn from metadata.utils.logger import ingestion_logger from metadata.utils.sqlalchemy_utils import get_all_table_ddls, get_table_ddl @@ -112,6 +114,29 @@ class AthenaSource(ExternalTableLineageMixin, CommonDbSourceService): connection=self.service_connection ) self.external_location_map = {} + self.schema_description_map = {} + + def prepare(self): + """ + Prepare the source by fetching the schema descriptions from the AWS Glue service. + """ + try: + super().prepare() + glue_client = AWSClient(self.service_connection.awsConfig).get_glue_client() + paginator = glue_client.get_paginator("get_databases") + for page in paginator.paginate(): + database_page = DatabasePage(**page) + for database in database_page.DatabaseList or []: + if database.Description: + self.schema_description_map[ + database.Name + ] = database.Description + except Exception as exc: + logger.warning(f"Error preparing Athena source: {exc}") + logger.debug(traceback.format_exc()) + + def get_schema_description(self, schema_name: str) -> Optional[str]: + return self.schema_description_map.get(schema_name) def query_table_names_and_types( self, schema_name: str @@ -259,6 +284,7 @@ class AthenaSource(ExternalTableLineageMixin, CommonDbSourceService): ) ) + # pylint: disable=arguments-differ def get_table_description( self, schema_name: str, table_name: str, inspector: Inspector ) -> str: