diff --git a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py index 9e694a6c067..03ad197e4c9 100644 --- a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py @@ -12,12 +12,9 @@ """Athena source module""" import traceback -from copy import deepcopy -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple from pyathena.sqlalchemy.base import AthenaDialect -from sqlalchemy import types -from sqlalchemy.engine import reflection from sqlalchemy.engine.reflection import Inspector from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema @@ -42,18 +39,33 @@ from metadata.ingestion.api.models import Either from metadata.ingestion.api.steps import InvalidSourceException from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.ingestion.source import sqa_types from metadata.ingestion.source.database.athena.client import AthenaLakeFormationClient -from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser +from metadata.ingestion.source.database.athena.utils import ( + _get_column_type, + get_columns, + get_table_options, + get_view_definition, +) from metadata.ingestion.source.database.common_db_source import ( CommonDbSourceService, TableNameAndType, ) +from metadata.ingestion.source.database.external_table_lineage_mixin import ( + ExternalTableLineageMixin, +) from metadata.utils import fqn from metadata.utils.logger import ingestion_logger -from metadata.utils.sqlalchemy_utils import is_complex_type +from metadata.utils.sqlalchemy_utils import get_all_table_ddls, get_table_ddl from metadata.utils.tag_utils import get_ometa_tag_and_classification +AthenaDialect._get_column_type = _get_column_type # 
pylint: disable=protected-access +AthenaDialect.get_columns = get_columns +AthenaDialect.get_view_definition = get_view_definition +AthenaDialect.get_table_options = get_table_options + +Inspector.get_all_table_ddls = get_all_table_ddls +Inspector.get_table_ddl = get_table_ddl + logger = ingestion_logger() ATHENA_TAG = "ATHENA TAG" @@ -72,165 +84,7 @@ ATHENA_INTERVAL_TYPE_MAP = { } -def _get_column_type(self, type_): - """ - Function overwritten from AthenaDialect - to add custom SQA typing. - """ - type_ = type_.replace(" ", "").lower() - match = self._pattern_column_type.match(type_) # pylint: disable=protected-access - if match: - name = match.group(1).lower() - length = match.group(2) - else: - name = type_.lower() - length = None - - args = [] - col_map = { - "boolean": types.BOOLEAN, - "float": types.FLOAT, - "double": types.FLOAT, - "real": types.FLOAT, - "tinyint": types.INTEGER, - "smallint": types.INTEGER, - "integer": types.INTEGER, - "int": types.INTEGER, - "bigint": types.BIGINT, - "string": types.String, - "date": types.DATE, - "timestamp": types.TIMESTAMP, - "binary": types.BINARY, - "varbinary": types.BINARY, - "array": types.ARRAY, - "json": types.JSON, - "struct": sqa_types.SQAStruct, - "row": sqa_types.SQAStruct, - "map": sqa_types.SQAMap, - "decimal": types.DECIMAL, - "varchar": types.VARCHAR, - "char": types.CHAR, - } - if name in ["decimal", "char", "varchar"]: - col_type = col_map[name] - if length: - args = [int(l) for l in length.split(",")] - elif type_.startswith("array"): - parsed_type = ( - ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access - type_ - ) - ) - col_type = col_map["array"] - if parsed_type["arrayDataType"].lower().startswith("array"): - # as OpenMetadata doesn't store any details on children of array, we put - # in type as string as default to avoid Array item_type required issue - # from sqlalchemy types - args = [types.String] - else: - args = [col_map.get(parsed_type.get("arrayDataType").lower(), 
types.String)] - elif col_map.get(name): - col_type = col_map.get(name) - else: - logger.warning(f"Did not recognize type '{type_}'") - col_type = types.NullType - return col_type(*args) - - -def _get_projection_details( - columns: List[Dict], projection_parameters: Dict -) -> List[Dict]: - """Get the projection details for the columns - - Args: - columns (List[Dict]): list of columns - projection_parameters (Dict): projection parameters - """ - if not projection_parameters: - return columns - - columns = deepcopy(columns) - for col in columns: - projection_details = next( - ({k: v} for k, v in projection_parameters.items() if k == col["name"]), None - ) - if projection_details: - col["projection_type"] = projection_details[col["name"]] - - return columns - - -@reflection.cache -def get_columns(self, connection, table_name, schema=None, **kw): - """ - Method to handle table columns - """ - metadata = self._get_table( # pylint: disable=protected-access - connection, table_name, schema=schema, **kw - ) - columns = [ - { - "name": c.name, - "type": self._get_column_type(c.type), # pylint: disable=protected-access - "nullable": True, - "default": None, - "autoincrement": False, - "comment": c.comment, - "system_data_type": c.type, - "is_complex": is_complex_type(c.type), - "dialect_options": {"awsathena_partition": True}, - } - for c in metadata.partition_keys - ] - - if kw.get("only_partition_columns"): - # Return projected partition information to set partition type in `get_table_partition_details` - # projected partition fields are stored in the form of `projection..type` as a table parameter - projection_parameters = { - key_.split(".")[1]: value_ - for key_, value_ in metadata.parameters.items() - if key_.startswith("projection") and key_.endswith("type") - } - columns = _get_projection_details(columns, projection_parameters) - return columns - - columns += [ - { - "name": c.name, - "type": self._get_column_type(c.type), # pylint: disable=protected-access - 
"nullable": True, - "default": None, - "autoincrement": False, - "comment": c.comment, - "system_data_type": c.type, - "is_complex": is_complex_type(c.type), - "dialect_options": {"awsathena_partition": None}, - } - for c in metadata.columns - ] - - return columns - - -# pylint: disable=unused-argument -@reflection.cache -def get_view_definition(self, connection, view_name, schema=None, **kw): - """ - Gets the view definition - """ - full_view_name = f'"{view_name}"' if not schema else f'"{schema}"."{view_name}"' - res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall() - if res: - return "\n".join(i[0] for i in res) - return None - - -AthenaDialect._get_column_type = _get_column_type # pylint: disable=protected-access -AthenaDialect.get_columns = get_columns -AthenaDialect.get_view_definition = get_view_definition - - -class AthenaSource(CommonDbSourceService): +class AthenaSource(ExternalTableLineageMixin, CommonDbSourceService): """ Implements the necessary methods to extract Database metadata from Athena Source @@ -257,6 +111,7 @@ class AthenaSource(CommonDbSourceService): self.athena_lake_formation_client = AthenaLakeFormationClient( connection=self.service_connection ) + self.external_location_map = {} def query_table_names_and_types( self, schema_name: str @@ -395,3 +250,23 @@ class AthenaSource(CommonDbSourceService): stackTrace=traceback.format_exc(), ) ) + + def get_table_description( + self, schema_name: str, table_name: str, inspector: Inspector + ) -> str: + description = None + try: + table_info: dict = inspector.get_table_comment(table_name, schema_name) + table_option = inspector.get_table_options(table_name, schema_name) + self.external_location_map[ + (self.context.get().database, schema_name, table_name) + ] = table_option.get("awsathena_location") + # Catch any exception without breaking the ingestion + except Exception as exc: # pylint: disable=broad-except + logger.debug(traceback.format_exc()) + logger.warning( + f"Table 
description error for table [{schema_name}.{table_name}]: {exc}" + ) + else: + description = table_info.get("text") + return description diff --git a/ingestion/src/metadata/ingestion/source/database/athena/utils.py b/ingestion/src/metadata/ingestion/source/database/athena/utils.py new file mode 100644 index 00000000000..3a9c0cfd38b --- /dev/null +++ b/ingestion/src/metadata/ingestion/source/database/athena/utils.py @@ -0,0 +1,194 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Athena utils module""" + +from copy import deepcopy +from typing import Dict, List, Optional + +from pyathena.sqlalchemy.util import _HashableDict +from sqlalchemy import types +from sqlalchemy.engine import reflection + +from metadata.ingestion.source import sqa_types +from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser +from metadata.utils.logger import ingestion_logger +from metadata.utils.sqlalchemy_utils import is_complex_type + +logger = ingestion_logger() + + +@reflection.cache +def _get_column_type(self, type_): + """ + Function overwritten from AthenaDialect + to add custom SQA typing. 
+ """ + type_ = type_.replace(" ", "").lower() + match = self._pattern_column_type.match(type_) # pylint: disable=protected-access + if match: + name = match.group(1).lower() + length = match.group(2) + else: + name = type_.lower() + length = None + + args = [] + col_map = { + "boolean": types.BOOLEAN, + "float": types.FLOAT, + "double": types.FLOAT, + "real": types.FLOAT, + "tinyint": types.INTEGER, + "smallint": types.INTEGER, + "integer": types.INTEGER, + "int": types.INTEGER, + "bigint": types.BIGINT, + "string": types.String, + "date": types.DATE, + "timestamp": types.TIMESTAMP, + "binary": types.BINARY, + "varbinary": types.BINARY, + "array": types.ARRAY, + "json": types.JSON, + "struct": sqa_types.SQAStruct, + "row": sqa_types.SQAStruct, + "map": sqa_types.SQAMap, + "decimal": types.DECIMAL, + "varchar": types.VARCHAR, + "char": types.CHAR, + } + if name in ["decimal", "char", "varchar"]: + col_type = col_map[name] + if length: + args = [int(l) for l in length.split(",")] + elif type_.startswith("array"): + parsed_type = ( + ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access + type_ + ) + ) + col_type = col_map["array"] + if parsed_type["arrayDataType"].lower().startswith("array"): + # as OpenMetadata doesn't store any details on children of array, we put + # in type as string as default to avoid Array item_type required issue + # from sqlalchemy types + args = [types.String] + else: + args = [col_map.get(parsed_type.get("arrayDataType").lower(), types.String)] + elif col_map.get(name): + col_type = col_map.get(name) + else: + logger.warning(f"Did not recognize type '{type_}'") + col_type = types.NullType + return col_type(*args) + + +# pylint: disable=unused-argument +def _get_projection_details( + columns: List[Dict], projection_parameters: Dict +) -> List[Dict]: + """Get the projection details for the columns + + Args: + columns (List[Dict]): list of columns + projection_parameters (Dict): projection parameters + """ + if not 
projection_parameters: + return columns + + columns = deepcopy(columns) + for col in columns: + projection_details = next( + ({k: v} for k, v in projection_parameters.items() if k == col["name"]), None + ) + if projection_details: + col["projection_type"] = projection_details[col["name"]] + + return columns + + +@reflection.cache +def get_columns(self, connection, table_name, schema=None, **kw): + """ + Method to handle table columns + """ + metadata = self._get_table( # pylint: disable=protected-access + connection, table_name, schema=schema, **kw + ) + columns = [ + { + "name": c.name, + "type": self._get_column_type(c.type), # pylint: disable=protected-access + "nullable": True, + "default": None, + "autoincrement": False, + "comment": c.comment, + "system_data_type": c.type, + "is_complex": is_complex_type(c.type), + "dialect_options": {"awsathena_partition": True}, + } + for c in metadata.partition_keys + ] + + if kw.get("only_partition_columns"): + # Return projected partition information to set partition type in `get_table_partition_details` + # projected partition fields are stored in the form of `projection..type` as a table parameter + projection_parameters = { + key_.split(".")[1]: value_ + for key_, value_ in metadata.parameters.items() + if key_.startswith("projection") and key_.endswith("type") + } + columns = _get_projection_details(columns, projection_parameters) + return columns + + columns += [ + { + "name": c.name, + "type": self._get_column_type(c.type), # pylint: disable=protected-access + "nullable": True, + "default": None, + "autoincrement": False, + "comment": c.comment, + "system_data_type": c.type, + "is_complex": is_complex_type(c.type), + "dialect_options": {"awsathena_partition": None}, + } + for c in metadata.columns + ] + + return columns + + +@reflection.cache +def get_view_definition(self, connection, view_name, schema=None, **kw): + """ + Gets the view definition + """ + full_view_name = f'"{view_name}"' if not schema else 
f'"{schema}"."{view_name}"' + res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall() + if res: + return "\n".join(i[0] for i in res) + return None + + +def get_table_options( + self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw +): + metadata = self._get_table(connection, table_name, schema=schema, **kw) + return { + "awsathena_location": metadata.location, + "awsathena_compression": metadata.compression, + "awsathena_row_format": metadata.row_format, + "awsathena_file_format": metadata.file_format, + "awsathena_serdeproperties": _HashableDict(metadata.serde_properties), + "awsathena_tblproperties": _HashableDict(metadata.table_properties), + } diff --git a/ingestion/src/metadata/ingestion/source/database/external_table_lineage_mixin.py b/ingestion/src/metadata/ingestion/source/database/external_table_lineage_mixin.py index 06c7c6ca270..2e2458e25f1 100644 --- a/ingestion/src/metadata/ingestion/source/database/external_table_lineage_mixin.py +++ b/ingestion/src/metadata/ingestion/source/database/external_table_lineage_mixin.py @@ -14,13 +14,20 @@ External Table Lineage Mixin import traceback from abc import ABC -from typing import Iterable +from typing import Iterable, List, Optional from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest +from metadata.generated.schema.entity.data.container import ContainerDataModel from metadata.generated.schema.entity.data.table import Table -from metadata.generated.schema.type.entityLineage import EntitiesEdge +from metadata.generated.schema.type.entityLineage import ( + ColumnLineage, + EntitiesEdge, + LineageDetails, +) +from metadata.generated.schema.type.entityLineage import Source as LineageSource from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.api.models import Either +from metadata.ingestion.lineage.sql_lineage import get_column_fqn from metadata.utils import fqn from metadata.utils.logger import 
ingestion_logger @@ -39,7 +46,7 @@ class ExternalTableLineageMixin(ABC): for table_qualified_tuple, location in self.external_location_map.items() or []: try: location_entity = self.metadata.es_search_container_by_path( - full_path=location + full_path=location, fields="dataModel" ) database_name, schema_name, table_name = table_qualified_tuple @@ -63,6 +70,12 @@ class ExternalTableLineageMixin(ABC): and table_entity and table_entity[0] ): + columns_list = [ + column.name.root for column in table_entity[0].columns + ] + columns_lineage = self._get_column_lineage( + location_entity[0].dataModel, table_entity[0], columns_list + ) yield Either( right=AddLineageRequest( edge=EntitiesEdge( @@ -74,9 +87,51 @@ class ExternalTableLineageMixin(ABC): id=table_entity[0].id, type="table", ), + lineageDetails=LineageDetails( + source=LineageSource.ExternalTableLineage, + columnsLineage=columns_lineage, + ), ) ) ) except Exception as exc: logger.warning(f"Failed to yield external table lineage due to - {exc}") logger.debug(traceback.format_exc()) + + def _get_data_model_column_fqn( + self, data_model_entity: ContainerDataModel, column: str + ) -> Optional[str]: + """ + Get fqn of column if exist in data model entity + """ + if not data_model_entity: + return None + for entity_column in data_model_entity.columns: + if entity_column.displayName.lower() == column.lower(): + return entity_column.fullyQualifiedName.root + return None + + def _get_column_lineage( + self, + data_model_entity: ContainerDataModel, + table_entity: Table, + columns_list: List[str], + ) -> List[ColumnLineage]: + """ + Get the column lineage + """ + try: + column_lineage = [] + for field in columns_list or []: + from_column = self._get_data_model_column_fqn( + data_model_entity=data_model_entity, column=field + ) + to_column = get_column_fqn(table_entity=table_entity, column=field) + if from_column and to_column: + column_lineage.append( + ColumnLineage(fromColumns=[from_column], toColumn=to_column) + ) + 
return column_lineage + except Exception as exc: + logger.debug(f"Error to get column lineage: {exc}") + logger.debug(traceback.format_exc()) diff --git a/ingestion/tests/unit/topology/database/test_athena.py b/ingestion/tests/unit/topology/database/test_athena.py new file mode 100644 index 00000000000..506d1ca2804 --- /dev/null +++ b/ingestion/tests/unit/topology/database/test_athena.py @@ -0,0 +1,325 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test athena source +""" + +import unittest +from unittest.mock import patch +from uuid import UUID + +from pydantic import AnyUrl +from sqlalchemy.engine.reflection import Inspector + +from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest +from metadata.generated.schema.entity.data.container import ( + Container, + ContainerDataModel, + FileFormat, +) +from metadata.generated.schema.entity.data.database import Database +from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema +from metadata.generated.schema.entity.data.table import ( + Column, + ColumnName, + Constraint, + DataType, + Table, + TableType, +) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseConnection, + DatabaseService, + DatabaseServiceType, +) +from metadata.generated.schema.entity.services.storageService import StorageServiceType +from metadata.generated.schema.metadataIngestion.workflow import ( + OpenMetadataWorkflowConfig, +) +from 
metadata.generated.schema.type.basic import ( + EntityName, + FullyQualifiedEntityName, + Href, + Markdown, + SourceUrl, + Timestamp, + Uuid, +) +from metadata.generated.schema.type.entityLineage import ColumnLineage +from metadata.generated.schema.type.entityReference import EntityReference +from metadata.ingestion.api.models import Either +from metadata.ingestion.source.database.athena.metadata import AthenaSource +from metadata.ingestion.source.database.common_db_source import TableNameAndType + +EXPECTED_DATABASE_NAMES = ["mydatabase"] +MOCK_DATABASE_SCHEMA = DatabaseSchema( + id="2aaa012e-099a-11ed-861d-0242ac120056", + name="sample_instance", + fullyQualifiedName="sample_athena_schema.sample_db.sample_instance", + displayName="default", + description="", + database=EntityReference( + id="2aaa012e-099a-11ed-861d-0242ac120002", + type="database", + ), + service=EntityReference( + id="85811038-099a-11ed-861d-0242ac120002", + type="databaseService", + ), +) +MOCK_DATABASE_SERVICE = DatabaseService( + id="85811038-099a-11ed-861d-0242ac120002", + name="sample_athena_service", + connection=DatabaseConnection(), + serviceType=DatabaseServiceType.Glue, +) +MOCK_DATABASE = Database( + id="2aaa012e-099a-11ed-861d-0242ac120002", + name="sample_db", + fullyQualifiedName="test_athena.sample_db", + displayName="sample_db", + description="", + service=EntityReference( + id="85811038-099a-11ed-861d-0242ac120002", + type="databaseService", + ), +) +MOCK_TABLE_NAME = "sample_table" +EXPECTED_DATABASES = [ + Either( + right=CreateDatabaseRequest( + name=EntityName("sample_db"), + service=FullyQualifiedEntityName("sample_athena_service"), + default=False, + ), + ) +] +EXPECTED_QUERY_TABLE_NAMES_TYPES = [ + TableNameAndType(name="sample_table", type_=TableType.External) +] +MOCK_LOCATION_ENTITY = [ + Container( + id=Uuid(UUID("9c489754-bb60-435b-b2a5-0e43100cf950")), + name=EntityName("dbt-testing/mayur/customers.csv"), + fullyQualifiedName=FullyQualifiedEntityName( + 
's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv"' + ), + updatedAt=Timestamp(1717070902713), + updatedBy="admin", + href=Href( + root=AnyUrl( + "http://localhost:8585/api/v1/containers/9c489754-bb60-435b-b2a5-0e43100cf950", + ) + ), + service=EntityReference( + id=Uuid(UUID("dd91cca3-cc54-4776-9efa-48f845cdfb92")), + type="storageService", + name="s3_local", + fullyQualifiedName="s3_local", + description=Markdown(""), + displayName="s3_local", + deleted=False, + href=Href( + root=AnyUrl( + "http://localhost:8585/api/v1/services/storageServices/dd91cca3-cc54-4776-9efa-48f845cdfb92", + ) + ), + ), + dataModel=ContainerDataModel( + isPartitioned=False, + columns=[ + Column( + name=ColumnName("CUSTOMERID"), + displayName="CUSTOMERID", + dataType=DataType.INT, + dataTypeDisplay="INT", + fullyQualifiedName=FullyQualifiedEntityName( + 's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID' + ), + ), + ], + ), + prefix="/dbt-testing/mayur/customers.csv", + numberOfObjects=2103.0, + size=652260394.0, + fileFormats=[FileFormat.csv], + serviceType=StorageServiceType.S3, + deleted=False, + sourceUrl=SourceUrl( + "https://s3.console.aws.amazon.com/s3/buckets/awsdatalake-testing?region=us-east-2&prefix=dbt-testing/mayur/customers.csv/&showversions=false" + ), + fullPath="s3://awsdatalake-testing/dbt-testing/mayur/customers.csv", + sourceHash="22b1c2f2e7feeaa8f37c6649e01f027d", + ) +] + +MOCK_TABLE_ENTITY = [ + Table( + id=Uuid(UUID("2c040cf8-432d-4597-9517-4794d6142da3")), + name=EntityName("demo_data_ext_tbl3"), + fullyQualifiedName=FullyQualifiedEntityName( + "local_athena.demo.default.demo_data_ext_tbl3" + ), + updatedAt=Timestamp(1717071974350), + updatedBy="admin", + href=Href( + root=AnyUrl( + "http://localhost:8585/api/v1/tables/2c040cf8-432d-4597-9517-4794d6142da3", + ) + ), + tableType=TableType.Regular, + columns=[ + Column( + name=ColumnName("CUSTOMERID"), + dataType=DataType.INT, + dataLength=1, + dataTypeDisplay="int", + 
fullyQualifiedName=FullyQualifiedEntityName( + "local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID" + ), + constraint=Constraint.NULL, + ), + ], + databaseSchema=EntityReference( + id=Uuid(UUID("b03b0229-8a9f-497a-a675-74cb24a9be74")), + type="databaseSchema", + name="default", + fullyQualifiedName="local_athena.demo.default", + displayName="default", + deleted=False, + href=Href( + root=AnyUrl( + "http://localhost:8585/api/v1/databaseSchemas/b03b0229-8a9f-497a-a675-74cb24a9be74", + ) + ), + ), + database=EntityReference( + id=Uuid(UUID("f054c55c-34bf-4c5f-addd-5cc26c7c832a")), + type="database", + name="demo", + fullyQualifiedName="local_athena.demo", + displayName="demo", + deleted=False, + href=Href( + root=AnyUrl( + "http://localhost:8585/api/v1/databases/f054c55c-34bf-4c5f-addd-5cc26c7c832a", + ) + ), + ), + service=EntityReference( + id=Uuid(UUID("5e98afd3-7257-4c35-a560-f4c25b0f4b97")), + type="databaseService", + name="local_athena", + fullyQualifiedName="local_athena", + displayName="local_athena", + deleted=False, + href=Href( + root=AnyUrl( + "http://localhost:8585/api/v1/services/databaseServices/5e98afd3-7257-4c35-a560-f4c25b0f4b97", + ) + ), + ), + serviceType=DatabaseServiceType.Athena, + deleted=False, + sourceHash="824e80b1c79b0c4ae0acd99d2338e149", + ) +] +EXPECTED_COLUMN_LINEAGE = [ + ColumnLineage( + fromColumns=[ + FullyQualifiedEntityName( + 's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID' + ) + ], + toColumn=FullyQualifiedEntityName( + "local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID" + ), + ) +] + +mock_athena_config = { + "source": { + "type": "Athena", + "serviceName": "test_athena", + "serviceConnection": { + "config": { + "type": "Athena", + "databaseName": "mydatabase", + "awsConfig": { + "awsAccessKeyId": "dummy", + "awsSecretAccessKey": "dummy", + "awsRegion": "us-east-2", + }, + "s3StagingDir": "https://s3-directory-for-datasource.com", + "workgroup": "workgroup name", + } + }, + 
"sourceConfig": {"config": {"type": "DatabaseMetadata"}}, + }, + "sink": {"type": "metadata-rest", "config": {}}, + "workflowConfig": { + "openMetadataServerConfig": { + "hostPort": "http://localhost:8585/api", + "authProvider": "openmetadata", + "securityConfig": { + "jwtToken": "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg" + }, + } + }, +} + + +class TestAthenaService(unittest.TestCase): + @patch( + "metadata.ingestion.source.database.database_service.DatabaseServiceSource.test_connection" + ) + def __init__(self, methodName, test_connection) -> None: + super().__init__(methodName) + test_connection.return_value = False + self.config = OpenMetadataWorkflowConfig.parse_obj(mock_athena_config) + self.athena_source = AthenaSource.create( + mock_athena_config["source"], + self.config.workflowConfig.openMetadataServerConfig, + ) + self.athena_source.context.get().__dict__[ + "database_schema" + ] = MOCK_DATABASE_SCHEMA.name.root + self.athena_source.context.get().__dict__[ + "database_service" + ] = MOCK_DATABASE_SERVICE.name.root + self.athena_source.context.get().__dict__["database"] = MOCK_DATABASE.name.root + + def test_get_database_name(self): + assert list(self.athena_source.get_database_names()) == EXPECTED_DATABASE_NAMES + + def test_query_table_names_and_types(self): + with patch.object(Inspector, "get_table_names", return_value=[MOCK_TABLE_NAME]): + assert ( + self.athena_source.query_table_names_and_types( + 
MOCK_DATABASE_SCHEMA.name.root + ) + == EXPECTED_QUERY_TABLE_NAMES_TYPES + ) + + def test_yield_database(self): + assert ( + list( + self.athena_source.yield_database(database_name=MOCK_DATABASE.name.root) + ) + == EXPECTED_DATABASES + ) + + def test_column_lineage(self): + columns_list = [column.name.root for column in MOCK_TABLE_ENTITY[0].columns] + column_lineage = self.athena_source._get_column_lineage( + MOCK_LOCATION_ENTITY[0].dataModel, MOCK_TABLE_ENTITY[0], columns_list + ) + assert column_lineage == EXPECTED_COLUMN_LINEAGE