Fix #14676: Athena S3 Lineage (#16426)

* get table ddl for athena tables

* changes in method to get all table ddls

* external table/container lineage for athena

* column lineage for external table lineage

* unittest for athena

* pyformat changes

* add external table lineage unit test

* fix unittest with pydantic v2 changes

* fix unittest formatting

* fix code smell
harshsoni2024 2024-06-26 19:53:36 +05:30 committed by GitHub
parent afec7703cc
commit 3f5bc1948d
4 changed files with 615 additions and 169 deletions


@@ -12,12 +12,9 @@
"""Athena source module"""
import traceback
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Iterable, Optional, Tuple
from pyathena.sqlalchemy.base import AthenaDialect
from sqlalchemy import types
from sqlalchemy.engine import reflection
from sqlalchemy.engine.reflection import Inspector
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
@@ -42,18 +39,33 @@ from metadata.ingestion.api.models import Either
from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source import sqa_types
from metadata.ingestion.source.database.athena.client import AthenaLakeFormationClient
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.ingestion.source.database.athena.utils import (
_get_column_type,
get_columns,
get_table_options,
get_view_definition,
)
from metadata.ingestion.source.database.common_db_source import (
CommonDbSourceService,
TableNameAndType,
)
from metadata.ingestion.source.database.external_table_lineage_mixin import (
ExternalTableLineageMixin,
)
from metadata.utils import fqn
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import is_complex_type
from metadata.utils.sqlalchemy_utils import get_all_table_ddls, get_table_ddl
from metadata.utils.tag_utils import get_ometa_tag_and_classification
AthenaDialect._get_column_type = _get_column_type # pylint: disable=protected-access
AthenaDialect.get_columns = get_columns
AthenaDialect.get_view_definition = get_view_definition
AthenaDialect.get_table_options = get_table_options
Inspector.get_all_table_ddls = get_all_table_ddls
Inspector.get_table_ddl = get_table_ddl
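# A minimal sketch of the monkey-patching pattern applied above, assuming a
# SHOW CREATE TABLE based implementation (the real get_table_ddl lives in
# metadata.utils.sqlalchemy_utils): assigning a plain function onto the
# Inspector class makes it callable on every inspector instance.
#
#     def get_table_ddl(self, connection, table_name, schema=None):
#         name = f"{schema}.{table_name}" if schema else table_name
#         rows = connection.execute(f"SHOW CREATE TABLE {name}").fetchall()
#         return "\n".join(row[0] for row in rows) if rows else None
#
#     Inspector.get_table_ddl = get_table_ddl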
logger = ingestion_logger()
ATHENA_TAG = "ATHENA TAG"
@@ -72,165 +84,7 @@ ATHENA_INTERVAL_TYPE_MAP = {
}
def _get_column_type(self, type_):
"""
Function overwritten from AthenaDialect
to add custom SQA typing.
"""
type_ = type_.replace(" ", "").lower()
match = self._pattern_column_type.match(type_) # pylint: disable=protected-access
if match:
name = match.group(1).lower()
length = match.group(2)
else:
name = type_.lower()
length = None
args = []
col_map = {
"boolean": types.BOOLEAN,
"float": types.FLOAT,
"double": types.FLOAT,
"real": types.FLOAT,
"tinyint": types.INTEGER,
"smallint": types.INTEGER,
"integer": types.INTEGER,
"int": types.INTEGER,
"bigint": types.BIGINT,
"string": types.String,
"date": types.DATE,
"timestamp": types.TIMESTAMP,
"binary": types.BINARY,
"varbinary": types.BINARY,
"array": types.ARRAY,
"json": types.JSON,
"struct": sqa_types.SQAStruct,
"row": sqa_types.SQAStruct,
"map": sqa_types.SQAMap,
"decimal": types.DECIMAL,
"varchar": types.VARCHAR,
"char": types.CHAR,
}
if name in ["decimal", "char", "varchar"]:
col_type = col_map[name]
if length:
args = [int(l) for l in length.split(",")]
elif type_.startswith("array"):
parsed_type = (
ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access
type_
)
)
col_type = col_map["array"]
if parsed_type["arrayDataType"].lower().startswith("array"):
# as OpenMetadata doesn't store any details on children of array, we put
# in type as string as default to avoid Array item_type required issue
# from sqlalchemy types
args = [types.String]
else:
args = [col_map.get(parsed_type.get("arrayDataType").lower(), types.String)]
elif col_map.get(name):
col_type = col_map.get(name)
else:
logger.warning(f"Did not recognize type '{type_}'")
col_type = types.NullType
return col_type(*args)
def _get_projection_details(
columns: List[Dict], projection_parameters: Dict
) -> List[Dict]:
"""Get the projection details for the columns
Args:
columns (List[Dict]): list of columns
projection_parameters (Dict): projection parameters
"""
if not projection_parameters:
return columns
columns = deepcopy(columns)
for col in columns:
projection_details = next(
({k: v} for k, v in projection_parameters.items() if k == col["name"]), None
)
if projection_details:
col["projection_type"] = projection_details[col["name"]]
return columns
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
"""
Method to handle table columns
"""
metadata = self._get_table( # pylint: disable=protected-access
connection, table_name, schema=schema, **kw
)
columns = [
{
"name": c.name,
"type": self._get_column_type(c.type), # pylint: disable=protected-access
"nullable": True,
"default": None,
"autoincrement": False,
"comment": c.comment,
"system_data_type": c.type,
"is_complex": is_complex_type(c.type),
"dialect_options": {"awsathena_partition": True},
}
for c in metadata.partition_keys
]
if kw.get("only_partition_columns"):
# Return projected partition information to set partition type in `get_table_partition_details`
# projected partition fields are stored in the form of `projection.<field_name>.type` as a table parameter
projection_parameters = {
key_.split(".")[1]: value_
for key_, value_ in metadata.parameters.items()
if key_.startswith("projection") and key_.endswith("type")
}
columns = _get_projection_details(columns, projection_parameters)
return columns
columns += [
{
"name": c.name,
"type": self._get_column_type(c.type), # pylint: disable=protected-access
"nullable": True,
"default": None,
"autoincrement": False,
"comment": c.comment,
"system_data_type": c.type,
"is_complex": is_complex_type(c.type),
"dialect_options": {"awsathena_partition": None},
}
for c in metadata.columns
]
return columns
# pylint: disable=unused-argument
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
"""
Gets the view definition
"""
full_view_name = f'"{view_name}"' if not schema else f'"{schema}"."{view_name}"'
res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall()
if res:
return "\n".join(i[0] for i in res)
return None
AthenaDialect._get_column_type = _get_column_type # pylint: disable=protected-access
AthenaDialect.get_columns = get_columns
AthenaDialect.get_view_definition = get_view_definition
class AthenaSource(CommonDbSourceService):
class AthenaSource(ExternalTableLineageMixin, CommonDbSourceService):
"""
Implements the necessary methods to extract
Database metadata from Athena Source
@@ -257,6 +111,7 @@ class AthenaSource(CommonDbSourceService):
self.athena_lake_formation_client = AthenaLakeFormationClient(
connection=self.service_connection
)
self.external_location_map = {}
def query_table_names_and_types(
self, schema_name: str
@@ -395,3 +250,23 @@ stackTrace=traceback.format_exc(),
stackTrace=traceback.format_exc(),
)
)
def get_table_description(
self, schema_name: str, table_name: str, inspector: Inspector
) -> str:
description = None
try:
table_info: dict = inspector.get_table_comment(table_name, schema_name)
table_option = inspector.get_table_options(table_name, schema_name)
self.external_location_map[
(self.context.get().database, schema_name, table_name)
] = table_option.get("awsathena_location")
# Catch any exception without breaking the ingestion
except Exception as exc: # pylint: disable=broad-except
logger.debug(traceback.format_exc())
logger.warning(
f"Table description error for table [{schema_name}.{table_name}]: {exc}"
)
else:
description = table_info.get("text")
return description
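    # Illustratively, after this method runs for an external table, the map
    # holds entries keyed by the ingestion context (the values below are
    # taken from the unit test fixtures added in this commit):
    #
    #     self.external_location_map = {
    #         ("demo", "default", "demo_data_ext_tbl3"):
    #             "s3://awsdatalake-testing/dbt-testing/mayur/customers.csv",
    #     }
    #
    # The ExternalTableLineageMixin later matches each location against the
    # ingested storage containers to yield lineage.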


@@ -0,0 +1,191 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Athena utils module"""
from copy import deepcopy
from typing import Dict, List, Optional
from pyathena.sqlalchemy.util import _HashableDict
from sqlalchemy import types
from sqlalchemy.engine import reflection
from metadata.ingestion.source import sqa_types
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import is_complex_type

logger = ingestion_logger()
@reflection.cache
def _get_column_type(self, type_):
"""
Function overwritten from AthenaDialect
to add custom SQA typing.
"""
type_ = type_.replace(" ", "").lower()
match = self._pattern_column_type.match(type_) # pylint: disable=protected-access
if match:
name = match.group(1).lower()
length = match.group(2)
else:
name = type_.lower()
length = None
args = []
col_map = {
"boolean": types.BOOLEAN,
"float": types.FLOAT,
"double": types.FLOAT,
"real": types.FLOAT,
"tinyint": types.INTEGER,
"smallint": types.INTEGER,
"integer": types.INTEGER,
"int": types.INTEGER,
"bigint": types.BIGINT,
"string": types.String,
"date": types.DATE,
"timestamp": types.TIMESTAMP,
"binary": types.BINARY,
"varbinary": types.BINARY,
"array": types.ARRAY,
"json": types.JSON,
"struct": sqa_types.SQAStruct,
"row": sqa_types.SQAStruct,
"map": sqa_types.SQAMap,
"decimal": types.DECIMAL,
"varchar": types.VARCHAR,
"char": types.CHAR,
}
if name in ["decimal", "char", "varchar"]:
col_type = col_map[name]
if length:
args = [int(l) for l in length.split(",")]
elif type_.startswith("array"):
parsed_type = (
ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access
type_
)
)
col_type = col_map["array"]
if parsed_type["arrayDataType"].lower().startswith("array"):
# as OpenMetadata doesn't store any details on children of array, we put
# in type as string as default to avoid Array item_type required issue
# from sqlalchemy types
args = [types.String]
else:
args = [col_map.get(parsed_type.get("arrayDataType").lower(), types.String)]
elif col_map.get(name):
col_type = col_map.get(name)
else:
logger.warning(f"Did not recognize type '{type_}'")
col_type = types.NullType
return col_type(*args)
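# A few illustrative calls, assuming an AthenaDialect instance `dialect`:
#
#     dialect._get_column_type("varchar(255)")   # -> VARCHAR(255)
#     dialect._get_column_type("decimal(10,2)")  # -> DECIMAL(10, 2)
#     dialect._get_column_type("array<int>")     # -> ARRAY(INTEGER)
#     dialect._get_column_type("no_such_type")   # -> NullType, with a warning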
# pylint: disable=unused-argument
def _get_projection_details(
columns: List[Dict], projection_parameters: Dict
) -> List[Dict]:
"""Get the projection details for the columns
Args:
columns (List[Dict]): list of columns
projection_parameters (Dict): projection parameters
"""
if not projection_parameters:
return columns
columns = deepcopy(columns)
for col in columns:
projection_details = next(
({k: v} for k, v in projection_parameters.items() if k == col["name"]), None
)
if projection_details:
col["projection_type"] = projection_details[col["name"]]
return columns
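# Illustration with hypothetical values: given
#
#     columns = [{"name": "dt"}, {"name": "region"}]
#     projection_parameters = {"dt": "date"}
#
# the function returns a deep copy in which only the matching column gains
# a projection type:
#
#     [{"name": "dt", "projection_type": "date"}, {"name": "region"}]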
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
"""
Method to handle table columns
"""
metadata = self._get_table( # pylint: disable=protected-access
connection, table_name, schema=schema, **kw
)
columns = [
{
"name": c.name,
"type": self._get_column_type(c.type), # pylint: disable=protected-access
"nullable": True,
"default": None,
"autoincrement": False,
"comment": c.comment,
"system_data_type": c.type,
"is_complex": is_complex_type(c.type),
"dialect_options": {"awsathena_partition": True},
}
for c in metadata.partition_keys
]
if kw.get("only_partition_columns"):
# Return projected partition information to set partition type in `get_table_partition_details`
# projected partition fields are stored in the form of `projection.<field_name>.type` as a table parameter
projection_parameters = {
key_.split(".")[1]: value_
for key_, value_ in metadata.parameters.items()
if key_.startswith("projection") and key_.endswith("type")
}
columns = _get_projection_details(columns, projection_parameters)
return columns
columns += [
{
"name": c.name,
"type": self._get_column_type(c.type), # pylint: disable=protected-access
"nullable": True,
"default": None,
"autoincrement": False,
"comment": c.comment,
"system_data_type": c.type,
"is_complex": is_complex_type(c.type),
"dialect_options": {"awsathena_partition": None},
}
for c in metadata.columns
]
return columns
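# The two paths above, assuming an Inspector wired to this dialect and
# hypothetical table names:
#
#     inspector.get_columns("orders", schema="sales")
#     # -> partition keys (awsathena_partition=True) followed by regular columns
#
#     inspector.get_columns("orders", schema="sales", only_partition_columns=True)
#     # -> partition keys only, enriched with projection.<field>.type parameters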
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
"""
Gets the view definition
"""
full_view_name = f'"{view_name}"' if not schema else f'"{schema}"."{view_name}"'
res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall()
if res:
return "\n".join(i[0] for i in res)
return None
def get_table_options(
self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw
):
metadata = self._get_table(connection, table_name, schema=schema, **kw)
return {
"awsathena_location": metadata.location,
"awsathena_compression": metadata.compression,
"awsathena_row_format": metadata.row_format,
"awsathena_file_format": metadata.file_format,
"awsathena_serdeproperties": _HashableDict(metadata.serde_properties),
"awsathena_tblproperties": _HashableDict(metadata.table_properties),
}
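# Example of the returned mapping (illustrative values):
#
#     {
#         "awsathena_location": "s3://bucket/path/orders/",
#         "awsathena_compression": "SNAPPY",
#         "awsathena_row_format": None,
#         "awsathena_file_format": "parquet",
#         "awsathena_serdeproperties": _HashableDict({...}),
#         "awsathena_tblproperties": _HashableDict({...}),
#     }
#
# AthenaSource.get_table_description() consumes only "awsathena_location"
# when building the external lineage map.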


@@ -14,13 +14,20 @@ External Table Lineage Mixin
import traceback
from abc import ABC
from typing import Iterable
from typing import Iterable, List, Optional
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.container import ContainerDataModel
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.generated.schema.type.entityLineage import (
ColumnLineage,
EntitiesEdge,
LineageDetails,
)
from metadata.generated.schema.type.entityLineage import Source as LineageSource
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.models import Either
from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.utils import fqn
from metadata.utils.logger import ingestion_logger
@@ -39,7 +46,7 @@ class ExternalTableLineageMixin(ABC):
for table_qualified_tuple, location in self.external_location_map.items() or []:
try:
location_entity = self.metadata.es_search_container_by_path(
full_path=location
full_path=location, fields="dataModel"
)
database_name, schema_name, table_name = table_qualified_tuple
@@ -63,6 +70,12 @@
and table_entity
and table_entity[0]
):
columns_list = [
column.name.root for column in table_entity[0].columns
]
columns_lineage = self._get_column_lineage(
location_entity[0].dataModel, table_entity[0], columns_list
)
yield Either(
right=AddLineageRequest(
edge=EntitiesEdge(
@@ -74,9 +87,51 @@
id=table_entity[0].id,
type="table",
),
lineageDetails=LineageDetails(
source=LineageSource.ExternalTableLineage,
columnsLineage=columns_lineage,
),
)
)
)
except Exception as exc:
logger.warning(f"Failed to yield external table lineage due to - {exc}")
logger.debug(traceback.format_exc())
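    # Shape of one yielded request (illustrative; IDs elided, and the
    # fromEntity references the storage container matched by location):
    #
    #     AddLineageRequest(
    #         edge=EntitiesEdge(
    #             fromEntity=EntityReference(id=..., type="container"),
    #             toEntity=EntityReference(id=..., type="table"),
    #             lineageDetails=LineageDetails(
    #                 source=LineageSource.ExternalTableLineage,
    #                 columnsLineage=[...],
    #             ),
    #         )
    #     )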
def _get_data_model_column_fqn(
self, data_model_entity: ContainerDataModel, column: str
) -> Optional[str]:
"""
Get fqn of column if exist in data model entity
"""
if not data_model_entity:
return None
for entity_column in data_model_entity.columns:
if entity_column.displayName.lower() == column.lower():
return entity_column.fullyQualifiedName.root
return None
def _get_column_lineage(
self,
data_model_entity: ContainerDataModel,
table_entity: Table,
columns_list: List[str],
) -> List[ColumnLineage]:
"""
Get the column lineage
"""
try:
column_lineage = []
for field in columns_list or []:
from_column = self._get_data_model_column_fqn(
data_model_entity=data_model_entity, column=field
)
to_column = get_column_fqn(table_entity=table_entity, column=field)
if from_column and to_column:
column_lineage.append(
ColumnLineage(fromColumns=[from_column], toColumn=to_column)
)
return column_lineage
except Exception as exc:
logger.debug(f"Error to get column lineage: {exc}")
logger.debug(traceback.format_exc())
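    # A toy analogue of the matching performed above (hypothetical helper,
    # not part of the mixin): pair container columns with table columns by
    # case-insensitive name, as _get_data_model_column_fqn does via
    # displayName.
    #
    #     def pair_columns(container_cols, table_cols):
    #         by_name = {c.lower(): c for c in container_cols}
    #         return [
    #             (by_name[t.lower()], t)
    #             for t in table_cols
    #             if t.lower() in by_name
    #         ]
    #
    #     pair_columns(["CUSTOMERID"], ["customerid"])
    #     # -> [("CUSTOMERID", "customerid")]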


@@ -0,0 +1,325 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test athena source
"""
import unittest
from unittest.mock import patch
from uuid import UUID
from pydantic import AnyUrl
from sqlalchemy.engine.reflection import Inspector
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.entity.data.container import (
Container,
ContainerDataModel,
FileFormat,
)
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import (
Column,
ColumnName,
Constraint,
DataType,
Table,
TableType,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseService,
DatabaseServiceType,
)
from metadata.generated.schema.entity.services.storageService import StorageServiceType
from metadata.generated.schema.metadataIngestion.workflow import (
OpenMetadataWorkflowConfig,
)
from metadata.generated.schema.type.basic import (
EntityName,
FullyQualifiedEntityName,
Href,
Markdown,
SourceUrl,
Timestamp,
Uuid,
)
from metadata.generated.schema.type.entityLineage import ColumnLineage
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.models import Either
from metadata.ingestion.source.database.athena.metadata import AthenaSource
from metadata.ingestion.source.database.common_db_source import TableNameAndType
EXPECTED_DATABASE_NAMES = ["mydatabase"]
MOCK_DATABASE_SCHEMA = DatabaseSchema(
id="2aaa012e-099a-11ed-861d-0242ac120056",
name="sample_instance",
fullyQualifiedName="sample_athena_schema.sample_db.sample_instance",
displayName="default",
description="",
database=EntityReference(
id="2aaa012e-099a-11ed-861d-0242ac120002",
type="database",
),
service=EntityReference(
id="85811038-099a-11ed-861d-0242ac120002",
type="databaseService",
),
)
MOCK_DATABASE_SERVICE = DatabaseService(
id="85811038-099a-11ed-861d-0242ac120002",
name="sample_athena_service",
connection=DatabaseConnection(),
serviceType=DatabaseServiceType.Glue,
)
MOCK_DATABASE = Database(
id="2aaa012e-099a-11ed-861d-0242ac120002",
name="sample_db",
fullyQualifiedName="test_athena.sample_db",
displayName="sample_db",
description="",
service=EntityReference(
id="85811038-099a-11ed-861d-0242ac120002",
type="databaseService",
),
)
MOCK_TABLE_NAME = "sample_table"
EXPECTED_DATABASES = [
Either(
right=CreateDatabaseRequest(
name=EntityName("sample_db"),
service=FullyQualifiedEntityName("sample_athena_service"),
default=False,
),
)
]
EXPECTED_QUERY_TABLE_NAMES_TYPES = [
TableNameAndType(name="sample_table", type_=TableType.External)
]
MOCK_LOCATION_ENTITY = [
Container(
id=Uuid(UUID("9c489754-bb60-435b-b2a5-0e43100cf950")),
name=EntityName("dbt-testing/mayur/customers.csv"),
fullyQualifiedName=FullyQualifiedEntityName(
's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv"'
),
updatedAt=Timestamp(1717070902713),
updatedBy="admin",
href=Href(
root=AnyUrl(
"http://localhost:8585/api/v1/containers/9c489754-bb60-435b-b2a5-0e43100cf950",
)
),
service=EntityReference(
id=Uuid(UUID("dd91cca3-cc54-4776-9efa-48f845cdfb92")),
type="storageService",
name="s3_local",
fullyQualifiedName="s3_local",
description=Markdown(""),
displayName="s3_local",
deleted=False,
href=Href(
root=AnyUrl(
"http://localhost:8585/api/v1/services/storageServices/dd91cca3-cc54-4776-9efa-48f845cdfb92",
)
),
),
dataModel=ContainerDataModel(
isPartitioned=False,
columns=[
Column(
name=ColumnName("CUSTOMERID"),
displayName="CUSTOMERID",
dataType=DataType.INT,
dataTypeDisplay="INT",
fullyQualifiedName=FullyQualifiedEntityName(
's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID'
),
),
],
),
prefix="/dbt-testing/mayur/customers.csv",
numberOfObjects=2103.0,
size=652260394.0,
fileFormats=[FileFormat.csv],
serviceType=StorageServiceType.S3,
deleted=False,
sourceUrl=SourceUrl(
"https://s3.console.aws.amazon.com/s3/buckets/awsdatalake-testing?region=us-east-2&prefix=dbt-testing/mayur/customers.csv/&showversions=false"
),
fullPath="s3://awsdatalake-testing/dbt-testing/mayur/customers.csv",
sourceHash="22b1c2f2e7feeaa8f37c6649e01f027d",
)
]
MOCK_TABLE_ENTITY = [
Table(
id=Uuid(UUID("2c040cf8-432d-4597-9517-4794d6142da3")),
name=EntityName("demo_data_ext_tbl3"),
fullyQualifiedName=FullyQualifiedEntityName(
"local_athena.demo.default.demo_data_ext_tbl3"
),
updatedAt=Timestamp(1717071974350),
updatedBy="admin",
href=Href(
root=AnyUrl(
"http://localhost:8585/api/v1/tables/2c040cf8-432d-4597-9517-4794d6142da3",
)
),
tableType=TableType.Regular,
columns=[
Column(
name=ColumnName("CUSTOMERID"),
dataType=DataType.INT,
dataLength=1,
dataTypeDisplay="int",
fullyQualifiedName=FullyQualifiedEntityName(
"local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID"
),
constraint=Constraint.NULL,
),
],
databaseSchema=EntityReference(
id=Uuid(UUID("b03b0229-8a9f-497a-a675-74cb24a9be74")),
type="databaseSchema",
name="default",
fullyQualifiedName="local_athena.demo.default",
displayName="default",
deleted=False,
href=Href(
root=AnyUrl(
"http://localhost:8585/api/v1/databaseSchemas/b03b0229-8a9f-497a-a675-74cb24a9be74",
)
),
),
database=EntityReference(
id=Uuid(UUID("f054c55c-34bf-4c5f-addd-5cc26c7c832a")),
type="database",
name="demo",
fullyQualifiedName="local_athena.demo",
displayName="demo",
deleted=False,
href=Href(
root=AnyUrl(
"http://localhost:8585/api/v1/databases/f054c55c-34bf-4c5f-addd-5cc26c7c832a",
)
),
),
service=EntityReference(
id=Uuid(UUID("5e98afd3-7257-4c35-a560-f4c25b0f4b97")),
type="databaseService",
name="local_athena",
fullyQualifiedName="local_athena",
displayName="local_athena",
deleted=False,
href=Href(
root=AnyUrl(
"http://localhost:8585/api/v1/services/databaseServices/5e98afd3-7257-4c35-a560-f4c25b0f4b97",
)
),
),
serviceType=DatabaseServiceType.Athena,
deleted=False,
sourceHash="824e80b1c79b0c4ae0acd99d2338e149",
)
]
EXPECTED_COLUMN_LINEAGE = [
ColumnLineage(
fromColumns=[
FullyQualifiedEntityName(
's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID'
)
],
toColumn=FullyQualifiedEntityName(
"local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID"
),
)
]
mock_athena_config = {
"source": {
"type": "Athena",
"serviceName": "test_athena",
"serviceConnection": {
"config": {
"type": "Athena",
"databaseName": "mydatabase",
"awsConfig": {
"awsAccessKeyId": "dummy",
"awsSecretAccessKey": "dummy",
"awsRegion": "us-east-2",
},
"s3StagingDir": "https://s3-directory-for-datasource.com",
"workgroup": "workgroup name",
}
},
"sourceConfig": {"config": {"type": "DatabaseMetadata"}},
},
"sink": {"type": "metadata-rest", "config": {}},
"workflowConfig": {
"openMetadataServerConfig": {
"hostPort": "http://localhost:8585/api",
"authProvider": "openmetadata",
"securityConfig": {
"jwtToken": "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
},
}
},
}
class TestAthenaService(unittest.TestCase):
@patch(
"metadata.ingestion.source.database.database_service.DatabaseServiceSource.test_connection"
)
def __init__(self, methodName, test_connection) -> None:
super().__init__(methodName)
test_connection.return_value = False
self.config = OpenMetadataWorkflowConfig.parse_obj(mock_athena_config)
self.athena_source = AthenaSource.create(
mock_athena_config["source"],
self.config.workflowConfig.openMetadataServerConfig,
)
self.athena_source.context.get().__dict__[
"database_schema"
] = MOCK_DATABASE_SCHEMA.name.root
self.athena_source.context.get().__dict__[
"database_service"
] = MOCK_DATABASE_SERVICE.name.root
self.athena_source.context.get().__dict__["database"] = MOCK_DATABASE.name.root
def test_get_database_name(self):
assert list(self.athena_source.get_database_names()) == EXPECTED_DATABASE_NAMES
def test_query_table_names_and_types(self):
with patch.object(Inspector, "get_table_names", return_value=[MOCK_TABLE_NAME]):
assert (
self.athena_source.query_table_names_and_types(
MOCK_DATABASE_SCHEMA.name.root
)
== EXPECTED_QUERY_TABLE_NAMES_TYPES
)
def test_yield_database(self):
assert (
list(
self.athena_source.yield_database(database_name=MOCK_DATABASE.name.root)
)
== EXPECTED_DATABASES
)
def test_column_lineage(self):
columns_list = [column.name.root for column in MOCK_TABLE_ENTITY[0].columns]
column_lineage = self.athena_source._get_column_lineage(
MOCK_LOCATION_ENTITY[0].dataModel, MOCK_TABLE_ENTITY[0], columns_list
)
assert column_lineage == EXPECTED_COLUMN_LINEAGE