mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-12-25 14:38:29 +00:00
Fix #5108 - Add Map and Struct support for Athena & Block them for profiler (#5181)
This commit is contained in:
parent
c4a0ced7ec
commit
0005bc1292
@ -8,6 +8,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
from pyathena.sqlalchemy_athena import AthenaDialect
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine import Inspector
|
||||
|
||||
from metadata.generated.schema.entity.services.connections.database.athenaConnection import (
|
||||
AthenaConnection,
|
||||
@ -21,12 +26,73 @@ from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
Source as WorkflowSource,
|
||||
)
|
||||
from metadata.ingestion.api.source import InvalidSourceException
|
||||
from metadata.ingestion.source import sqa_types
|
||||
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
# Module-level logger; used below to report Athena types that cannot be mapped
logger = ingestion_logger()
|
||||
|
||||
|
||||
def _get_column_type(self, type_):
    """
    Function overwritten from AthenaDialect
    to add custom SQA typing.

    The raw type string is matched against the dialect's
    ``_pattern_column_type``: group 1 is used as the (lower-cased) type
    name and group 2 as its parameters. When the pattern does not match,
    the whole lower-cased string is used as the name.

    :param type_: raw column type string as reported by Athena
    :return: an instantiated SQLAlchemy type. Struct/row and map are
        returned as the custom ``SQAStruct`` / ``SQAMap`` wrappers, and
        unrecognized names fall back to ``NullType`` after a warning.
    """
    match = self._pattern_column_type.match(type_)
    if match:
        name = match.group(1).lower()
        length = match.group(2)
    else:
        name = type_.lower()
        length = None

    args = []
    if name == "boolean":
        col_type = types.BOOLEAN
    elif name in ("float", "double", "real"):
        col_type = types.FLOAT
    elif name in ("tinyint", "smallint", "integer", "int"):
        col_type = types.INTEGER
    elif name == "bigint":
        col_type = types.BIGINT
    elif name == "decimal":
        col_type = types.DECIMAL
        if length:
            # Accept both "precision,scale" and a bare "precision" instead
            # of failing to unpack when no scale is given.
            precision, _, scale = length.partition(",")
            args = [int(precision)]
            if scale:
                args.append(int(scale))
    elif name == "char":
        col_type = types.CHAR
        if length:
            args = [int(length)]
    elif name == "varchar":
        col_type = types.VARCHAR
        if length:
            args = [int(length)]
    elif name == "string":
        col_type = types.String
    elif name == "date":
        col_type = types.DATE
    elif name == "timestamp":
        col_type = types.TIMESTAMP
    elif name in ("binary", "varbinary"):
        col_type = types.BINARY
    elif name == "array":
        col_type = types.ARRAY
        # sqlalchemy's ARRAY requires an item_type; instantiating it with no
        # arguments raises TypeError. The element type is not resolved here,
        # so String is used as a neutral placeholder.
        args = [types.String]
    elif name == "json":
        col_type = types.JSON
    elif name in ("struct", "row"):
        col_type = sqa_types.SQAStruct
    elif name == "map":
        col_type = sqa_types.SQAMap
    else:
        # Logger.warn is a deprecated alias; use warning()
        logger.warning(f"Did not recognize type '{type_}'")
        col_type = types.NullType
    return col_type(*args)
|
||||
|
||||
|
||||
# Monkey-patch the PyAthena dialect at import time so that it resolves column
# types through the custom _get_column_type defined above.
AthenaDialect._get_column_type = _get_column_type
|
||||
|
||||
|
||||
class AthenaSource(CommonDbSourceService):
|
||||
def __init__(self, config, metadata_config):
|
||||
super().__init__(config, metadata_config)
|
||||
@ -41,3 +107,13 @@ class AthenaSource(CommonDbSourceService):
|
||||
)
|
||||
|
||||
return cls(config, metadata_config)
|
||||
|
||||
def get_table_names(
    self, schema: str, inspector: Inspector
) -> Optional[Iterable[Tuple[str, str]]]:
    """
    Lazily yield ``(name, type)`` pairs for the given schema.

    Tables are emitted first, then views, each only when the matching
    ``includeTables`` / ``includeViews`` flag of ``self.source_config``
    is truthy.
    """
    if self.source_config.includeTables:
        # All tables in Athena are External
        yield from (
            (table_name, "External")
            for table_name in inspector.get_table_names(schema)
        )
    if self.source_config.includeViews:
        yield from (
            (view_name, "View") for view_name in inspector.get_view_names(schema)
        )
|
||||
|
||||
@ -158,7 +158,7 @@ class CommonDbSourceService(DBTSource, SqlColumnHandler, SqlAlchemySource):
|
||||
|
||||
def get_table_names(
|
||||
self, schema: str, inspector: Inspector
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
) -> Optional[Iterable[Tuple[str, str]]]:
|
||||
if self.source_config.includeTables:
|
||||
for table in inspector.get_table_names(schema):
|
||||
yield table, "Regular"
|
||||
|
||||
@ -90,7 +90,9 @@ class SqlAlchemySource(Source, ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_table_names(self, schema: str, inspector: Inspector) -> Optional[List[str]]:
|
||||
def get_table_names(
|
||||
self, schema: str, inspector: Inspector
|
||||
) -> Optional[Iterable[Tuple[str, str]]]:
|
||||
"""
|
||||
Method to fetch table & view names
|
||||
"""
|
||||
|
||||
28
ingestion/src/metadata/ingestion/source/sqa_types.py
Normal file
28
ingestion/src/metadata/ingestion/source/sqa_types.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Define custom types as wrappers on top of
|
||||
existing SQA types to have a bridge between
|
||||
SQA dialects and OM rich type system
|
||||
"""
|
||||
from sqlalchemy import types
|
||||
|
||||
|
||||
class SQAMap(types.String):
    """
    Wrapper on top of the SQA String type
    used to identify Map columns
    """
|
||||
|
||||
|
||||
class SQAStruct(types.String):
    """
    Wrapper on top of the SQA String type
    used to identify Struct columns
    """
|
||||
@ -17,6 +17,7 @@ import sqlalchemy
|
||||
from sqlalchemy import Integer, Numeric
|
||||
from sqlalchemy.sql.sqltypes import Concatenable, Enum
|
||||
|
||||
from metadata.ingestion.source import sqa_types
|
||||
from metadata.orm_profiler.orm.types.hex_byte_string import HexByteString
|
||||
from metadata.orm_profiler.orm.types.uuid import UUIDString
|
||||
from metadata.orm_profiler.registry import TypeRegistry
|
||||
@ -66,6 +67,8 @@ NOT_COMPUTE = {
|
||||
sqlalchemy.types.NullType,
|
||||
sqlalchemy.ARRAY,
|
||||
sqlalchemy.JSON,
|
||||
sqa_types.SQAMap,
|
||||
sqa_types.SQAStruct,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,8 @@ from typing import Any, Dict, List, Type, Union
|
||||
from sqlalchemy.sql import sqltypes as types
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from metadata.ingestion.source import sqa_types
|
||||
|
||||
|
||||
def create_sqlalchemy_type(name: str):
|
||||
sqlalchemy_type = type(
|
||||
@ -45,6 +47,9 @@ class ColumnTypeParser:
|
||||
types.Integer: "INT",
|
||||
types.BigInteger: "BIGINT",
|
||||
types.VARBINARY: "VARBINARY",
|
||||
# Custom wrapper types enriching SQA type system
|
||||
sqa_types.SQAMap: "MAP",
|
||||
sqa_types.SQAStruct: "STRUCT",
|
||||
}
|
||||
|
||||
_SOURCE_TYPE_TO_OM_TYPE = {
|
||||
|
||||
@ -15,10 +15,12 @@ Test Profiler behavior
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import sqlalchemy.types
|
||||
from sqlalchemy import Column, Integer, String, create_engine
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from metadata.generated.schema.entity.data.table import ColumnProfile
|
||||
from metadata.ingestion.source import sqa_types
|
||||
from metadata.orm_profiler.metrics.core import add_props
|
||||
from metadata.orm_profiler.metrics.registry import Metrics
|
||||
from metadata.orm_profiler.profiler.core import MissingMetricException, Profiler
|
||||
@ -132,3 +134,33 @@ class ProfilerTest(TestCase):
|
||||
Profiler(
|
||||
like, like_ratio, session=self.session, table=User, use_cols=[User.age]
|
||||
)
|
||||
|
||||
def test_skipped_types(self):
    """
    Columns of types the profiler does not support
    must not produce any column result
    """

    class NotCompute(Base):
        __tablename__ = "not_compute"
        id = Column(Integer, primary_key=True)
        null_col = Column(sqlalchemy.types.NULLTYPE)
        array_col = Column(sqlalchemy.ARRAY(Integer, dimensions=2))
        json_col = Column(sqlalchemy.JSON)
        map_col = Column(sqa_types.SQAMap)
        struct_col = Column(sqa_types.SQAStruct)

    # Every non-key column of the table is of a type expected to be skipped
    unsupported_cols = [
        getattr(NotCompute, col_name)
        for col_name in (
            "null_col",
            "array_col",
            "json_col",
            "map_col",
            "struct_col",
        )
    ]

    profiler = Profiler(
        Metrics.COUNT.value,
        session=self.session,
        table=NotCompute,
        use_cols=unsupported_cols,
    )

    assert not profiler.column_results
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user