Fix #5108 - Add Map and Struct support for Athena & Block them for profiler (#5181)

Fix #5108 - Add Map and Struct support for Athena & Block them for profiler (#5181)
This commit is contained in:
Pere Miquel Brull 2022-05-30 06:53:16 +02:00 committed by GitHub
parent c4a0ced7ec
commit 0005bc1292
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 148 additions and 2 deletions

View File

@ -8,6 +8,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Optional, Tuple
from pyathena.sqlalchemy_athena import AthenaDialect
from sqlalchemy import types
from sqlalchemy.engine import Inspector
from metadata.generated.schema.entity.services.connections.database.athenaConnection import (
AthenaConnection,
@ -21,12 +26,73 @@ from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source import sqa_types
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
def _get_column_type(self, type_):
"""
Function overwritten from AthenaDialect
to add custom SQA typing.
"""
match = self._pattern_column_type.match(type_)
if match:
name = match.group(1).lower()
length = match.group(2)
else:
name = type_.lower()
length = None
args = []
if name in ["boolean"]:
col_type = types.BOOLEAN
elif name in ["float", "double", "real"]:
col_type = types.FLOAT
elif name in ["tinyint", "smallint", "integer", "int"]:
col_type = types.INTEGER
elif name in ["bigint"]:
col_type = types.BIGINT
elif name in ["decimal"]:
col_type = types.DECIMAL
if length:
precision, scale = length.split(",")
args = [int(precision), int(scale)]
elif name in ["char"]:
col_type = types.CHAR
if length:
args = [int(length)]
elif name in ["varchar"]:
col_type = types.VARCHAR
if length:
args = [int(length)]
elif name in ["string"]:
col_type = types.String
elif name in ["date"]:
col_type = types.DATE
elif name in ["timestamp"]:
col_type = types.TIMESTAMP
elif name in ["binary", "varbinary"]:
col_type = types.BINARY
elif name in ["array"]:
col_type = types.ARRAY
elif name in ["json"]:
col_type = types.JSON
elif name in ["struct", "row"]:
col_type = sqa_types.SQAStruct
elif name in ["map"]:
col_type = sqa_types.SQAMap
else:
logger.warn(f"Did not recognize type '{type_}'")
col_type = types.NullType
return col_type(*args)
AthenaDialect._get_column_type = _get_column_type
class AthenaSource(CommonDbSourceService):
def __init__(self, config, metadata_config):
super().__init__(config, metadata_config)
@ -41,3 +107,13 @@ class AthenaSource(CommonDbSourceService):
)
return cls(config, metadata_config)
def get_table_names(
self, schema: str, inspector: Inspector
) -> Optional[Iterable[Tuple[str, str]]]:
if self.source_config.includeTables:
for table in inspector.get_table_names(schema):
yield table, "External" # All tables in Athena are External
if self.source_config.includeViews:
for view in inspector.get_view_names(schema):
yield view, "View"

View File

@ -158,7 +158,7 @@ class CommonDbSourceService(DBTSource, SqlColumnHandler, SqlAlchemySource):
def get_table_names(
self, schema: str, inspector: Inspector
) -> Optional[Tuple[str, str]]:
) -> Optional[Iterable[Tuple[str, str]]]:
if self.source_config.includeTables:
for table in inspector.get_table_names(schema):
yield table, "Regular"

View File

@ -90,7 +90,9 @@ class SqlAlchemySource(Source, ABC):
"""
@abstractmethod
def get_table_names(self, schema: str, inspector: Inspector) -> Optional[List[str]]:
def get_table_names(
self, schema: str, inspector: Inspector
) -> Optional[Iterable[Tuple[str, str]]]:
"""
Method to fetch table & view names
"""

View File

@ -0,0 +1,28 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Define custom types as wrappers on top of
existing SQA types to have a bridge between
SQA dialects and OM rich type system
"""
from sqlalchemy import types
class SQAMap(types.String):
"""
Custom Map type definition
"""
class SQAStruct(types.String):
"""
Custom Struct type definition
"""

View File

@ -17,6 +17,7 @@ import sqlalchemy
from sqlalchemy import Integer, Numeric
from sqlalchemy.sql.sqltypes import Concatenable, Enum
from metadata.ingestion.source import sqa_types
from metadata.orm_profiler.orm.types.hex_byte_string import HexByteString
from metadata.orm_profiler.orm.types.uuid import UUIDString
from metadata.orm_profiler.registry import TypeRegistry
@ -66,6 +67,8 @@ NOT_COMPUTE = {
sqlalchemy.types.NullType,
sqlalchemy.ARRAY,
sqlalchemy.JSON,
sqa_types.SQAMap,
sqa_types.SQAStruct,
}

View File

@ -4,6 +4,8 @@ from typing import Any, Dict, List, Type, Union
from sqlalchemy.sql import sqltypes as types
from sqlalchemy.types import TypeEngine
from metadata.ingestion.source import sqa_types
def create_sqlalchemy_type(name: str):
sqlalchemy_type = type(
@ -45,6 +47,9 @@ class ColumnTypeParser:
types.Integer: "INT",
types.BigInteger: "BIGINT",
types.VARBINARY: "VARBINARY",
# Custom wrapper types enriching SQA type system
sqa_types.SQAMap: "MAP",
sqa_types.SQAStruct: "STRUCT",
}
_SOURCE_TYPE_TO_OM_TYPE = {

View File

@ -15,10 +15,12 @@ Test Profiler behavior
from unittest import TestCase
import pytest
import sqlalchemy.types
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base
from metadata.generated.schema.entity.data.table import ColumnProfile
from metadata.ingestion.source import sqa_types
from metadata.orm_profiler.metrics.core import add_props
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.profiler.core import MissingMetricException, Profiler
@ -132,3 +134,33 @@ class ProfilerTest(TestCase):
Profiler(
like, like_ratio, session=self.session, table=User, use_cols=[User.age]
)
def test_skipped_types(self):
"""
Check that we are properly skipping computations for
not supported types
"""
class NotCompute(Base):
__tablename__ = "not_compute"
id = Column(Integer, primary_key=True)
null_col = Column(sqlalchemy.types.NULLTYPE)
array_col = Column(sqlalchemy.ARRAY(Integer, dimensions=2))
json_col = Column(sqlalchemy.JSON)
map_col = Column(sqa_types.SQAMap)
struct_col = Column(sqa_types.SQAStruct)
profiler = Profiler(
Metrics.COUNT.value,
session=self.session,
table=NotCompute,
use_cols=[
NotCompute.null_col,
NotCompute.array_col,
NotCompute.json_col,
NotCompute.map_col,
NotCompute.struct_col,
],
)
assert not profiler.column_results