mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-11-02 19:48:17 +00:00
* Get table DDL for Athena tables; add a method to fetch all table DDLs at once. * Add external table/container lineage for Athena, including column-level lineage for external tables. * Add unit tests for Athena and for external table lineage. * Apply pyformat changes; fix unit tests for Pydantic v2 and formatting; fix code smell.
This commit is contained in:
parent
afec7703cc
commit
3f5bc1948d
@ -12,12 +12,9 @@
|
||||
"""Athena source module"""
|
||||
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
from pyathena.sqlalchemy.base import AthenaDialect
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
||||
@ -42,18 +39,33 @@ from metadata.ingestion.api.models import Either
|
||||
from metadata.ingestion.api.steps import InvalidSourceException
|
||||
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.ingestion.source import sqa_types
|
||||
from metadata.ingestion.source.database.athena.client import AthenaLakeFormationClient
|
||||
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
|
||||
from metadata.ingestion.source.database.athena.utils import (
|
||||
_get_column_type,
|
||||
get_columns,
|
||||
get_table_options,
|
||||
get_view_definition,
|
||||
)
|
||||
from metadata.ingestion.source.database.common_db_source import (
|
||||
CommonDbSourceService,
|
||||
TableNameAndType,
|
||||
)
|
||||
from metadata.ingestion.source.database.external_table_lineage_mixin import (
|
||||
ExternalTableLineageMixin,
|
||||
)
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
from metadata.utils.sqlalchemy_utils import is_complex_type
|
||||
from metadata.utils.sqlalchemy_utils import get_all_table_ddls, get_table_ddl
|
||||
from metadata.utils.tag_utils import get_ometa_tag_and_classification
|
||||
|
||||
# Monkey-patch the PyAthena dialect so SQLAlchemy reflection uses the
# OpenMetadata-specific implementations from athena.utils (custom SQA
# typing, partition-aware columns, view definitions, table options).
AthenaDialect._get_column_type = _get_column_type  # pylint: disable=protected-access
AthenaDialect.get_columns = get_columns
AthenaDialect.get_view_definition = get_view_definition
AthenaDialect.get_table_options = get_table_options

# Expose the shared DDL helpers on the Inspector so table DDL can be
# fetched during ingestion.
Inspector.get_all_table_ddls = get_all_table_ddls
Inspector.get_table_ddl = get_table_ddl
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
ATHENA_TAG = "ATHENA TAG"
|
||||
@ -72,165 +84,7 @@ ATHENA_INTERVAL_TYPE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _get_column_type(self, type_):
    """
    Override of ``AthenaDialect._get_column_type`` that adds custom
    SQLAlchemy typing (structs, maps, arrays) on top of the stock mapping.
    Unknown types fall back to ``NullType``.
    """
    type_ = type_.replace(" ", "").lower()
    type_match = self._pattern_column_type.match(  # pylint: disable=protected-access
        type_
    )
    if type_match:
        name = type_match.group(1).lower()
        length = type_match.group(2)
    else:
        name, length = type_.lower(), None

    sqa_type_map = {
        "boolean": types.BOOLEAN,
        "float": types.FLOAT,
        "double": types.FLOAT,
        "real": types.FLOAT,
        "tinyint": types.INTEGER,
        "smallint": types.INTEGER,
        "integer": types.INTEGER,
        "int": types.INTEGER,
        "bigint": types.BIGINT,
        "string": types.String,
        "date": types.DATE,
        "timestamp": types.TIMESTAMP,
        "binary": types.BINARY,
        "varbinary": types.BINARY,
        "array": types.ARRAY,
        "json": types.JSON,
        "struct": sqa_types.SQAStruct,
        "row": sqa_types.SQAStruct,
        "map": sqa_types.SQAMap,
        "decimal": types.DECIMAL,
        "varchar": types.VARCHAR,
        "char": types.CHAR,
    }

    type_args = []
    if name in ("decimal", "char", "varchar"):
        col_type = sqa_type_map[name]
        if length:
            # e.g. "decimal(10,2)" -> [10, 2]
            type_args = [int(part) for part in length.split(",")]
    elif type_.startswith("array"):
        parsed_type = (
            ColumnTypeParser._parse_datatype_string(  # pylint: disable=protected-access
                type_
            )
        )
        col_type = sqa_type_map["array"]
        if parsed_type["arrayDataType"].lower().startswith("array"):
            # as OpenMetadata doesn't store any details on children of array,
            # we default the item type to string to avoid the SQA Array
            # item_type-required issue
            type_args = [types.String]
        else:
            type_args = [
                sqa_type_map.get(parsed_type.get("arrayDataType").lower(), types.String)
            ]
    elif sqa_type_map.get(name):
        col_type = sqa_type_map.get(name)
    else:
        logger.warning(f"Did not recognize type '{type_}'")
        col_type = types.NullType
    return col_type(*type_args)
|
||||
|
||||
|
||||
def _get_projection_details(
|
||||
columns: List[Dict], projection_parameters: Dict
|
||||
) -> List[Dict]:
|
||||
"""Get the projection details for the columns
|
||||
|
||||
Args:
|
||||
columns (List[Dict]): list of columns
|
||||
projection_parameters (Dict): projection parameters
|
||||
"""
|
||||
if not projection_parameters:
|
||||
return columns
|
||||
|
||||
columns = deepcopy(columns)
|
||||
for col in columns:
|
||||
projection_details = next(
|
||||
({k: v} for k, v in projection_parameters.items() if k == col["name"]), None
|
||||
)
|
||||
if projection_details:
|
||||
col["projection_type"] = projection_details[col["name"]]
|
||||
|
||||
return columns
|
||||
|
||||
|
||||
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
    """
    Build SQLAlchemy column dicts for an Athena table.

    Partition keys come first, flagged through the ``awsathena_partition``
    dialect option; when only partition columns are requested, projected
    partition metadata is attached and the regular columns are skipped.
    """
    metadata = self._get_table(  # pylint: disable=protected-access
        connection, table_name, schema=schema, **kw
    )

    def _as_column_dict(col, is_partition):
        # Shared shape for partition and data columns; only the
        # awsathena_partition flag differs (True vs None).
        return {
            "name": col.name,
            "type": self._get_column_type(col.type),  # pylint: disable=protected-access
            "nullable": True,
            "default": None,
            "autoincrement": False,
            "comment": col.comment,
            "system_data_type": col.type,
            "is_complex": is_complex_type(col.type),
            "dialect_options": {
                "awsathena_partition": True if is_partition else None
            },
        }

    columns = [_as_column_dict(col, True) for col in metadata.partition_keys]

    if kw.get("only_partition_columns"):
        # Return projected partition information to set partition type in
        # `get_table_partition_details`; projected partition fields are stored
        # as `projection.<field_name>.type` table parameters.
        projection_parameters = {
            key_.split(".")[1]: value_
            for key_, value_ in metadata.parameters.items()
            if key_.startswith("projection") and key_.endswith("type")
        }
        return _get_projection_details(columns, projection_parameters)

    columns.extend(_as_column_dict(col, False) for col in metadata.columns)
    return columns
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
    """
    Return the ``SHOW CREATE VIEW`` text for a view, or ``None`` when the
    engine returns no rows.

    Identifiers are double-quoted with embedded double quotes doubled, so a
    view or schema name containing ``"`` can no longer break the statement.
    """

    def _quote(identifier):
        # Standard SQL identifier quoting: double any embedded double quote
        return '"{}"'.format(identifier.replace('"', '""'))

    full_view_name = (
        _quote(view_name) if not schema else f"{_quote(schema)}.{_quote(view_name)}"
    )
    res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall()
    if res:
        return "\n".join(row[0] for row in res)
    return None
|
||||
|
||||
|
||||
# Monkey-patch the PyAthena dialect so SQLAlchemy reflection uses the
# OpenMetadata-specific implementations defined above.
AthenaDialect._get_column_type = _get_column_type  # pylint: disable=protected-access
AthenaDialect.get_columns = get_columns
AthenaDialect.get_view_definition = get_view_definition
|
||||
|
||||
|
||||
class AthenaSource(CommonDbSourceService):
|
||||
class AthenaSource(ExternalTableLineageMixin, CommonDbSourceService):
|
||||
"""
|
||||
Implements the necessary methods to extract
|
||||
Database metadata from Athena Source
|
||||
@ -257,6 +111,7 @@ class AthenaSource(CommonDbSourceService):
|
||||
self.athena_lake_formation_client = AthenaLakeFormationClient(
|
||||
connection=self.service_connection
|
||||
)
|
||||
self.external_location_map = {}
|
||||
|
||||
def query_table_names_and_types(
|
||||
self, schema_name: str
|
||||
@ -395,3 +250,23 @@ class AthenaSource(CommonDbSourceService):
|
||||
stackTrace=traceback.format_exc(),
|
||||
)
|
||||
)
|
||||
|
||||
def get_table_description(
    self, schema_name: str, table_name: str, inspector: "Inspector"
) -> Optional[str]:
    """
    Fetch the table comment and, as a side effect, record the table's
    external storage location for external-table lineage.

    Args:
        schema_name: schema the table belongs to
        table_name: table to describe
        inspector: SQLAlchemy inspector bound to the Athena engine

    Returns:
        The table comment text, or None when reflection fails.
    """
    description = None
    try:
        table_info: dict = inspector.get_table_comment(table_name, schema_name)
        table_option = inspector.get_table_options(table_name, schema_name)
        # Remember the S3 location keyed by (database, schema, table) so the
        # external-table lineage mixin can link the table to its container
        self.external_location_map[
            (self.context.get().database, schema_name, table_name)
        ] = table_option.get("awsathena_location")
    # Catch any exception without breaking the ingestion
    except Exception as exc:  # pylint: disable=broad-except
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Table description error for table [{schema_name}.{table_name}]: {exc}"
        )
    else:
        description = table_info.get("text")
    return description
|
||||
|
||||
191
ingestion/src/metadata/ingestion/source/database/athena/utils.py
Normal file
191
ingestion/src/metadata/ingestion/source/database/athena/utils.py
Normal file
@ -0,0 +1,191 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Athena utils module"""
|
||||
|
||||
import logging
from copy import deepcopy
from typing import Dict, List, Optional

from pyathena.sqlalchemy.util import _HashableDict
from sqlalchemy import types
from sqlalchemy.engine import reflection

from metadata.ingestion.source import sqa_types
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.utils.sqlalchemy_utils import is_complex_type
|
||||
|
||||
|
||||
@reflection.cache
def _get_column_type(self, type_):
    """
    Function overwritten from AthenaDialect to add custom SQA typing.

    Maps an Athena type string (e.g. ``varchar(10)``, ``array<int>``) to an
    instantiated SQLAlchemy type; unknown types fall back to ``NullType``.
    """
    type_ = type_.replace(" ", "").lower()
    match = self._pattern_column_type.match(type_)  # pylint: disable=protected-access
    if match:
        name = match.group(1).lower()
        length = match.group(2)
    else:
        name = type_.lower()
        length = None

    args = []
    col_map = {
        "boolean": types.BOOLEAN,
        "float": types.FLOAT,
        "double": types.FLOAT,
        "real": types.FLOAT,
        "tinyint": types.INTEGER,
        "smallint": types.INTEGER,
        "integer": types.INTEGER,
        "int": types.INTEGER,
        "bigint": types.BIGINT,
        "string": types.String,
        "date": types.DATE,
        "timestamp": types.TIMESTAMP,
        "binary": types.BINARY,
        "varbinary": types.BINARY,
        "array": types.ARRAY,
        "json": types.JSON,
        "struct": sqa_types.SQAStruct,
        "row": sqa_types.SQAStruct,
        "map": sqa_types.SQAMap,
        "decimal": types.DECIMAL,
        "varchar": types.VARCHAR,
        "char": types.CHAR,
    }
    if name in ["decimal", "char", "varchar"]:
        col_type = col_map[name]
        if length:
            # e.g. "decimal(10,2)" -> [10, 2]
            args = [int(part) for part in length.split(",")]
    elif type_.startswith("array"):
        parsed_type = (
            ColumnTypeParser._parse_datatype_string(  # pylint: disable=protected-access
                type_
            )
        )
        col_type = col_map["array"]
        if parsed_type["arrayDataType"].lower().startswith("array"):
            # as OpenMetadata doesn't store any details on children of array, we put
            # in type as string as default to avoid Array item_type required issue
            # from sqlalchemy types
            args = [types.String]
        else:
            args = [col_map.get(parsed_type.get("arrayDataType").lower(), types.String)]
    elif col_map.get(name):
        col_type = col_map.get(name)
    else:
        # This module defines no project-level `logger`; use a stdlib logger
        # so an unrecognized type emits a warning instead of raising
        # NameError on the undefined name `logger`.
        logging.getLogger(__name__).warning(f"Did not recognize type '{type_}'")
        col_type = types.NullType
    return col_type(*args)
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def _get_projection_details(
|
||||
columns: List[Dict], projection_parameters: Dict
|
||||
) -> List[Dict]:
|
||||
"""Get the projection details for the columns
|
||||
|
||||
Args:
|
||||
columns (List[Dict]): list of columns
|
||||
projection_parameters (Dict): projection parameters
|
||||
"""
|
||||
if not projection_parameters:
|
||||
return columns
|
||||
|
||||
columns = deepcopy(columns)
|
||||
for col in columns:
|
||||
projection_details = next(
|
||||
({k: v} for k, v in projection_parameters.items() if k == col["name"]), None
|
||||
)
|
||||
if projection_details:
|
||||
col["projection_type"] = projection_details[col["name"]]
|
||||
|
||||
return columns
|
||||
|
||||
|
||||
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
    """
    Build SQLAlchemy column dicts for an Athena table.

    Partition keys come first, flagged through the ``awsathena_partition``
    dialect option; when only partition columns are requested, projected
    partition metadata is attached and the regular columns are skipped.
    """
    metadata = self._get_table(  # pylint: disable=protected-access
        connection, table_name, schema=schema, **kw
    )

    def _as_column_dict(col, is_partition):
        # Shared shape for partition and data columns; only the
        # awsathena_partition flag differs (True vs None).
        return {
            "name": col.name,
            "type": self._get_column_type(col.type),  # pylint: disable=protected-access
            "nullable": True,
            "default": None,
            "autoincrement": False,
            "comment": col.comment,
            "system_data_type": col.type,
            "is_complex": is_complex_type(col.type),
            "dialect_options": {
                "awsathena_partition": True if is_partition else None
            },
        }

    columns = [_as_column_dict(col, True) for col in metadata.partition_keys]

    if kw.get("only_partition_columns"):
        # Return projected partition information to set partition type in
        # `get_table_partition_details`; projected partition fields are stored
        # as `projection.<field_name>.type` table parameters.
        projection_parameters = {
            key_.split(".")[1]: value_
            for key_, value_ in metadata.parameters.items()
            if key_.startswith("projection") and key_.endswith("type")
        }
        return _get_projection_details(columns, projection_parameters)

    columns.extend(_as_column_dict(col, False) for col in metadata.columns)
    return columns
|
||||
|
||||
|
||||
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
    """
    Return the ``SHOW CREATE VIEW`` text for a view, or ``None`` when the
    engine returns no rows.

    Identifiers are double-quoted with embedded double quotes doubled, so a
    view or schema name containing ``"`` can no longer break the statement.
    """

    def _quote(identifier):
        # Standard SQL identifier quoting: double any embedded double quote
        return '"{}"'.format(identifier.replace('"', '""'))

    full_view_name = (
        _quote(view_name) if not schema else f"{_quote(schema)}.{_quote(view_name)}"
    )
    res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall()
    if res:
        return "\n".join(row[0] for row in res)
    return None
|
||||
|
||||
|
||||
def get_table_options(
    self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw
):
    """
    Collect Athena-specific table options (storage location, compression,
    row/file formats, serde and table properties) from the reflected table.
    """
    metadata = self._get_table(connection, table_name, schema=schema, **kw)
    options = {
        "awsathena_location": metadata.location,
        "awsathena_compression": metadata.compression,
        "awsathena_row_format": metadata.row_format,
        "awsathena_file_format": metadata.file_format,
    }
    # Wrap dict-valued properties so the reflection cache can hash them
    options["awsathena_serdeproperties"] = _HashableDict(metadata.serde_properties)
    options["awsathena_tblproperties"] = _HashableDict(metadata.table_properties)
    return options
|
||||
@ -14,13 +14,20 @@ External Table Lineage Mixin
|
||||
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from typing import Iterable
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
|
||||
from metadata.generated.schema.entity.data.container import ContainerDataModel
|
||||
from metadata.generated.schema.entity.data.table import Table
|
||||
from metadata.generated.schema.type.entityLineage import EntitiesEdge
|
||||
from metadata.generated.schema.type.entityLineage import (
|
||||
ColumnLineage,
|
||||
EntitiesEdge,
|
||||
LineageDetails,
|
||||
)
|
||||
from metadata.generated.schema.type.entityLineage import Source as LineageSource
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.api.models import Either
|
||||
from metadata.ingestion.lineage.sql_lineage import get_column_fqn
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
@ -39,7 +46,7 @@ class ExternalTableLineageMixin(ABC):
|
||||
for table_qualified_tuple, location in self.external_location_map.items() or []:
|
||||
try:
|
||||
location_entity = self.metadata.es_search_container_by_path(
|
||||
full_path=location
|
||||
full_path=location, fields="dataModel"
|
||||
)
|
||||
database_name, schema_name, table_name = table_qualified_tuple
|
||||
|
||||
@ -63,6 +70,12 @@ class ExternalTableLineageMixin(ABC):
|
||||
and table_entity
|
||||
and table_entity[0]
|
||||
):
|
||||
columns_list = [
|
||||
column.name.root for column in table_entity[0].columns
|
||||
]
|
||||
columns_lineage = self._get_column_lineage(
|
||||
location_entity[0].dataModel, table_entity[0], columns_list
|
||||
)
|
||||
yield Either(
|
||||
right=AddLineageRequest(
|
||||
edge=EntitiesEdge(
|
||||
@ -74,9 +87,51 @@ class ExternalTableLineageMixin(ABC):
|
||||
id=table_entity[0].id,
|
||||
type="table",
|
||||
),
|
||||
lineageDetails=LineageDetails(
|
||||
source=LineageSource.ExternalTableLineage,
|
||||
columnsLineage=columns_lineage,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to yield external table lineage due to - {exc}")
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
def _get_data_model_column_fqn(
|
||||
self, data_model_entity: ContainerDataModel, column: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get fqn of column if exist in data model entity
|
||||
"""
|
||||
if not data_model_entity:
|
||||
return None
|
||||
for entity_column in data_model_entity.columns:
|
||||
if entity_column.displayName.lower() == column.lower():
|
||||
return entity_column.fullyQualifiedName.root
|
||||
return None
|
||||
|
||||
def _get_column_lineage(
    self,
    data_model_entity: "ContainerDataModel",
    table_entity: "Table",
    columns_list: List[str],
) -> List["ColumnLineage"]:
    """
    Build column-level lineage from the container data model to the table.

    Args:
        data_model_entity: data model of the source container
        table_entity: destination table entity
        columns_list: column names of the destination table

    Returns:
        List[ColumnLineage]: one entry per column found on both sides.
        Always a list — previously the except branch fell through and
        implicitly returned None, contradicting the annotation and breaking
        callers that iterate the result.
    """
    column_lineage = []
    try:
        for field in columns_list or []:
            from_column = self._get_data_model_column_fqn(
                data_model_entity=data_model_entity, column=field
            )
            to_column = get_column_fqn(table_entity=table_entity, column=field)
            if from_column and to_column:
                column_lineage.append(
                    ColumnLineage(fromColumns=[from_column], toColumn=to_column)
                )
    except Exception as exc:  # pylint: disable=broad-except
        logger.debug(f"Error to get column lineage: {exc}")
        logger.debug(traceback.format_exc())
    return column_lineage
|
||||
|
||||
325
ingestion/tests/unit/topology/database/test_athena.py
Normal file
325
ingestion/tests/unit/topology/database/test_athena.py
Normal file
@ -0,0 +1,325 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Test athena source
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AnyUrl
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
|
||||
from metadata.generated.schema.entity.data.container import (
|
||||
Container,
|
||||
ContainerDataModel,
|
||||
FileFormat,
|
||||
)
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
||||
from metadata.generated.schema.entity.data.table import (
|
||||
Column,
|
||||
ColumnName,
|
||||
Constraint,
|
||||
DataType,
|
||||
Table,
|
||||
TableType,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.databaseService import (
|
||||
DatabaseConnection,
|
||||
DatabaseService,
|
||||
DatabaseServiceType,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.storageService import StorageServiceType
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
OpenMetadataWorkflowConfig,
|
||||
)
|
||||
from metadata.generated.schema.type.basic import (
|
||||
EntityName,
|
||||
FullyQualifiedEntityName,
|
||||
Href,
|
||||
Markdown,
|
||||
SourceUrl,
|
||||
Timestamp,
|
||||
Uuid,
|
||||
)
|
||||
from metadata.generated.schema.type.entityLineage import ColumnLineage
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.api.models import Either
|
||||
from metadata.ingestion.source.database.athena.metadata import AthenaSource
|
||||
from metadata.ingestion.source.database.common_db_source import TableNameAndType
|
||||
|
||||
EXPECTED_DATABASE_NAMES = ["mydatabase"]
|
||||
MOCK_DATABASE_SCHEMA = DatabaseSchema(
|
||||
id="2aaa012e-099a-11ed-861d-0242ac120056",
|
||||
name="sample_instance",
|
||||
fullyQualifiedName="sample_athena_schema.sample_db.sample_instance",
|
||||
displayName="default",
|
||||
description="",
|
||||
database=EntityReference(
|
||||
id="2aaa012e-099a-11ed-861d-0242ac120002",
|
||||
type="database",
|
||||
),
|
||||
service=EntityReference(
|
||||
id="85811038-099a-11ed-861d-0242ac120002",
|
||||
type="databaseService",
|
||||
),
|
||||
)
|
||||
MOCK_DATABASE_SERVICE = DatabaseService(
|
||||
id="85811038-099a-11ed-861d-0242ac120002",
|
||||
name="sample_athena_service",
|
||||
connection=DatabaseConnection(),
|
||||
serviceType=DatabaseServiceType.Glue,
|
||||
)
|
||||
MOCK_DATABASE = Database(
|
||||
id="2aaa012e-099a-11ed-861d-0242ac120002",
|
||||
name="sample_db",
|
||||
fullyQualifiedName="test_athena.sample_db",
|
||||
displayName="sample_db",
|
||||
description="",
|
||||
service=EntityReference(
|
||||
id="85811038-099a-11ed-861d-0242ac120002",
|
||||
type="databaseService",
|
||||
),
|
||||
)
|
||||
MOCK_TABLE_NAME = "sample_table"
|
||||
EXPECTED_DATABASES = [
|
||||
Either(
|
||||
right=CreateDatabaseRequest(
|
||||
name=EntityName("sample_db"),
|
||||
service=FullyQualifiedEntityName("sample_athena_service"),
|
||||
default=False,
|
||||
),
|
||||
)
|
||||
]
|
||||
EXPECTED_QUERY_TABLE_NAMES_TYPES = [
|
||||
TableNameAndType(name="sample_table", type_=TableType.External)
|
||||
]
|
||||
MOCK_LOCATION_ENTITY = [
|
||||
Container(
|
||||
id=Uuid(UUID("9c489754-bb60-435b-b2a5-0e43100cf950")),
|
||||
name=EntityName("dbt-testing/mayur/customers.csv"),
|
||||
fullyQualifiedName=FullyQualifiedEntityName(
|
||||
's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv"'
|
||||
),
|
||||
updatedAt=Timestamp(1717070902713),
|
||||
updatedBy="admin",
|
||||
href=Href(
|
||||
root=AnyUrl(
|
||||
"http://localhost:8585/api/v1/containers/9c489754-bb60-435b-b2a5-0e43100cf950",
|
||||
)
|
||||
),
|
||||
service=EntityReference(
|
||||
id=Uuid(UUID("dd91cca3-cc54-4776-9efa-48f845cdfb92")),
|
||||
type="storageService",
|
||||
name="s3_local",
|
||||
fullyQualifiedName="s3_local",
|
||||
description=Markdown(""),
|
||||
displayName="s3_local",
|
||||
deleted=False,
|
||||
href=Href(
|
||||
root=AnyUrl(
|
||||
"http://localhost:8585/api/v1/services/storageServices/dd91cca3-cc54-4776-9efa-48f845cdfb92",
|
||||
)
|
||||
),
|
||||
),
|
||||
dataModel=ContainerDataModel(
|
||||
isPartitioned=False,
|
||||
columns=[
|
||||
Column(
|
||||
name=ColumnName("CUSTOMERID"),
|
||||
displayName="CUSTOMERID",
|
||||
dataType=DataType.INT,
|
||||
dataTypeDisplay="INT",
|
||||
fullyQualifiedName=FullyQualifiedEntityName(
|
||||
's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID'
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
prefix="/dbt-testing/mayur/customers.csv",
|
||||
numberOfObjects=2103.0,
|
||||
size=652260394.0,
|
||||
fileFormats=[FileFormat.csv],
|
||||
serviceType=StorageServiceType.S3,
|
||||
deleted=False,
|
||||
sourceUrl=SourceUrl(
|
||||
"https://s3.console.aws.amazon.com/s3/buckets/awsdatalake-testing?region=us-east-2&prefix=dbt-testing/mayur/customers.csv/&showversions=false"
|
||||
),
|
||||
fullPath="s3://awsdatalake-testing/dbt-testing/mayur/customers.csv",
|
||||
sourceHash="22b1c2f2e7feeaa8f37c6649e01f027d",
|
||||
)
|
||||
]
|
||||
|
||||
MOCK_TABLE_ENTITY = [
|
||||
Table(
|
||||
id=Uuid(UUID("2c040cf8-432d-4597-9517-4794d6142da3")),
|
||||
name=EntityName("demo_data_ext_tbl3"),
|
||||
fullyQualifiedName=FullyQualifiedEntityName(
|
||||
"local_athena.demo.default.demo_data_ext_tbl3"
|
||||
),
|
||||
updatedAt=Timestamp(1717071974350),
|
||||
updatedBy="admin",
|
||||
href=Href(
|
||||
root=AnyUrl(
|
||||
"http://localhost:8585/api/v1/tables/2c040cf8-432d-4597-9517-4794d6142da3",
|
||||
)
|
||||
),
|
||||
tableType=TableType.Regular,
|
||||
columns=[
|
||||
Column(
|
||||
name=ColumnName("CUSTOMERID"),
|
||||
dataType=DataType.INT,
|
||||
dataLength=1,
|
||||
dataTypeDisplay="int",
|
||||
fullyQualifiedName=FullyQualifiedEntityName(
|
||||
"local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID"
|
||||
),
|
||||
constraint=Constraint.NULL,
|
||||
),
|
||||
],
|
||||
databaseSchema=EntityReference(
|
||||
id=Uuid(UUID("b03b0229-8a9f-497a-a675-74cb24a9be74")),
|
||||
type="databaseSchema",
|
||||
name="default",
|
||||
fullyQualifiedName="local_athena.demo.default",
|
||||
displayName="default",
|
||||
deleted=False,
|
||||
href=Href(
|
||||
root=AnyUrl(
|
||||
"http://localhost:8585/api/v1/databaseSchemas/b03b0229-8a9f-497a-a675-74cb24a9be74",
|
||||
)
|
||||
),
|
||||
),
|
||||
database=EntityReference(
|
||||
id=Uuid(UUID("f054c55c-34bf-4c5f-addd-5cc26c7c832a")),
|
||||
type="database",
|
||||
name="demo",
|
||||
fullyQualifiedName="local_athena.demo",
|
||||
displayName="demo",
|
||||
deleted=False,
|
||||
href=Href(
|
||||
root=AnyUrl(
|
||||
"http://localhost:8585/api/v1/databases/f054c55c-34bf-4c5f-addd-5cc26c7c832a",
|
||||
)
|
||||
),
|
||||
),
|
||||
service=EntityReference(
|
||||
id=Uuid(UUID("5e98afd3-7257-4c35-a560-f4c25b0f4b97")),
|
||||
type="databaseService",
|
||||
name="local_athena",
|
||||
fullyQualifiedName="local_athena",
|
||||
displayName="local_athena",
|
||||
deleted=False,
|
||||
href=Href(
|
||||
root=AnyUrl(
|
||||
"http://localhost:8585/api/v1/services/databaseServices/5e98afd3-7257-4c35-a560-f4c25b0f4b97",
|
||||
)
|
||||
),
|
||||
),
|
||||
serviceType=DatabaseServiceType.Athena,
|
||||
deleted=False,
|
||||
sourceHash="824e80b1c79b0c4ae0acd99d2338e149",
|
||||
)
|
||||
]
|
||||
EXPECTED_COLUMN_LINEAGE = [
|
||||
ColumnLineage(
|
||||
fromColumns=[
|
||||
FullyQualifiedEntityName(
|
||||
's3_local.awsdatalake-testing."dbt-testing/mayur/customers.csv".CUSTOMERID'
|
||||
)
|
||||
],
|
||||
toColumn=FullyQualifiedEntityName(
|
||||
"local_athena.demo.default.demo_data_ext_tbl3.CUSTOMERID"
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_athena_config = {
|
||||
"source": {
|
||||
"type": "Athena",
|
||||
"serviceName": "test_athena",
|
||||
"serviceConnection": {
|
||||
"config": {
|
||||
"type": "Athena",
|
||||
"databaseName": "mydatabase",
|
||||
"awsConfig": {
|
||||
"awsAccessKeyId": "dummy",
|
||||
"awsSecretAccessKey": "dummy",
|
||||
"awsRegion": "us-east-2",
|
||||
},
|
||||
"s3StagingDir": "https://s3-directory-for-datasource.com",
|
||||
"workgroup": "workgroup name",
|
||||
}
|
||||
},
|
||||
"sourceConfig": {"config": {"type": "DatabaseMetadata"}},
|
||||
},
|
||||
"sink": {"type": "metadata-rest", "config": {}},
|
||||
"workflowConfig": {
|
||||
"openMetadataServerConfig": {
|
||||
"hostPort": "http://localhost:8585/api",
|
||||
"authProvider": "openmetadata",
|
||||
"securityConfig": {
|
||||
"jwtToken": "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAthenaService(unittest.TestCase):
    """Validates AthenaSource metadata extraction against mocked responses."""

    @patch(
        "metadata.ingestion.source.database.database_service.DatabaseServiceSource.test_connection"
    )
    def __init__(self, methodName, test_connection) -> None:
        super().__init__(methodName)
        # Disable the real connection test so the source builds offline
        test_connection.return_value = False
        self.config = OpenMetadataWorkflowConfig.parse_obj(mock_athena_config)
        self.athena_source = AthenaSource.create(
            mock_athena_config["source"],
            self.config.workflowConfig.openMetadataServerConfig,
        )
        # Seed the topology context as if schema, service, and database had
        # already been ingested upstream
        self.athena_source.context.get().__dict__[
            "database_schema"
        ] = MOCK_DATABASE_SCHEMA.name.root
        self.athena_source.context.get().__dict__[
            "database_service"
        ] = MOCK_DATABASE_SERVICE.name.root
        self.athena_source.context.get().__dict__["database"] = MOCK_DATABASE.name.root

    def test_get_database_name(self):
        # The databaseName from the connection config drives the database list
        assert list(self.athena_source.get_database_names()) == EXPECTED_DATABASE_NAMES

    def test_query_table_names_and_types(self):
        # Every Athena table is reported with TableType.External
        with patch.object(Inspector, "get_table_names", return_value=[MOCK_TABLE_NAME]):
            assert (
                self.athena_source.query_table_names_and_types(
                    MOCK_DATABASE_SCHEMA.name.root
                )
                == EXPECTED_QUERY_TABLE_NAMES_TYPES
            )

    def test_yield_database(self):
        # yield_database wraps the CreateDatabaseRequest in an Either
        assert (
            list(
                self.athena_source.yield_database(database_name=MOCK_DATABASE.name.root)
            )
            == EXPECTED_DATABASES
        )

    def test_column_lineage(self):
        # Column lineage maps container data-model columns to table columns
        # by (case-insensitive) display name
        columns_list = [column.name.root for column in MOCK_TABLE_ENTITY[0].columns]
        column_lineage = self.athena_source._get_column_lineage(
            MOCK_LOCATION_ENTITY[0].dataModel, MOCK_TABLE_ENTITY[0], columns_list
        )
        assert column_lineage == EXPECTED_COLUMN_LINEAGE
|
||||
Loading…
x
Reference in New Issue
Block a user