Fix: refactor get_column_type (#9559)

* Fix: refactor get_column_type

* Fix: added changes as per comments

* Fix: pylint

* Fix: minor changes

* Fix: minor changes
This commit is contained in:
NiharDoshi99 2023-01-03 10:28:38 +05:30 committed by GitHub
parent 749d850043
commit 8f69386076
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 17 deletions

View File

@ -213,17 +213,33 @@ class ColumnTypeParser:
@staticmethod
def get_column_type(column_type: Any) -> str:
# Resolve ``column_type`` to a type name using ColumnTypeParser's two lookup
# tables, falling back to the table entry for "VARCHAR".
#
# NOTE(review): this span comes from a rendered diff with indentation and
# +/- markers stripped, so two bodies appear back to back: a nested-``if``
# chain followed by a flat sequence of helper calls. Presumably the first is
# the removed implementation and the second its replacement -- confirm
# against the repository before relying on either.
#
# Both bodies appear to try the same lookups in the same order:
#   1. _COLUMN_TYPE_MAPPING keyed by type(column_type)
#   2. _SOURCE_TYPE_TO_OM_TYPE keyed by str(column_type)
#   3. _SOURCE_TYPE_TO_OM_TYPE keyed by the upper-cased text that precedes
#      the first "(" and then the first "<"
#   4. _SOURCE_TYPE_TO_OM_TYPE["VARCHAR"] when none of the above is truthy
#
# Nested-if variant: every ``if not`` performs one lookup and the matching
# ``return`` repeats that same lookup.
if not ColumnTypeParser._COLUMN_TYPE_MAPPING.get(type(column_type)):
if not ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(str(column_type)):
if not ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(
str(column_type).split("(", maxsplit=1)[0].split("<")[0].upper()
):
return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get("VARCHAR")
return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(
str(column_type).split("(", maxsplit=1)[0].split("<")[0].upper()
)
return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(str(column_type))
return ColumnTypeParser._COLUMN_TYPE_MAPPING.get(type(column_type))
# Flat variant: each helper performs one lookup; the first truthy result is
# returned, so a falsy mapped value is treated the same as a missing key.
column_type_result = ColumnTypeParser.get_column_type_mapping(column_type)
if column_type_result:
return column_type_result
column_type_result = ColumnTypeParser.get_source_type_mapping(column_type)
if column_type_result:
return column_type_result
column_type_result = ColumnTypeParser.get_source_type_containes_brackets(
column_type
)
if column_type_result:
return column_type_result
# Nothing matched: fall back to whatever the source-type table holds for "VARCHAR".
return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get("VARCHAR")
@staticmethod
def get_column_type_mapping(column_type: Any) -> str:
    """
    Look up the runtime class of ``column_type`` in ``_COLUMN_TYPE_MAPPING``.

    Args:
        column_type: object whose ``type(...)`` is used as the lookup key.

    Returns:
        The mapped value, or ``None`` when the class has no entry.
    """
    lookup_key = type(column_type)
    class_mapping = ColumnTypeParser._COLUMN_TYPE_MAPPING
    return class_mapping.get(lookup_key)
@staticmethod
def get_source_type_mapping(column_type: Any) -> str:
    """
    Look up ``str(column_type)`` verbatim in ``_SOURCE_TYPE_TO_OM_TYPE``.

    The text is used exactly as produced by ``str``; no case folding or
    trimming is applied here.

    Args:
        column_type: value whose string form is used as the lookup key.

    Returns:
        The mapped value, or ``None`` when the string has no entry.
    """
    type_name = str(column_type)
    source_mapping = ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE
    return source_mapping.get(type_name)
@staticmethod
def get_source_type_containes_brackets(column_type: Any) -> str:
    """
    Look up the bare type name of a bracketed type string.

    The key is ``str(column_type)`` cut at the first ``(``, then cut at the
    first ``<``, then upper-cased; it is looked up in
    ``_SOURCE_TYPE_TO_OM_TYPE``.

    Args:
        column_type: value whose string form is reduced to a lookup key.

    Returns:
        The mapped value, or ``None`` when the reduced key has no entry.
    """
    raw_name = str(column_type)
    # partition(sep)[0] yields everything before the first sep, or the whole
    # string when sep is absent.
    before_paren = raw_name.partition("(")[0]
    before_angle = before_paren.partition("<")[0]
    lookup_key = before_angle.upper()
    return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(lookup_key)
@staticmethod
def _parse_datatype_string(

View File

@ -1,6 +1,22 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test column type in column_type_parser
"""
import json
import os
from unittest import TestCase
from sqlalchemy.sql import sqltypes as types
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.utils.ansi import print_ansi_encoded_string
@ -26,11 +42,50 @@ COLUMN_TYPE_PARSE = [
"string",
"uniontype<int,double,array<string>,struct<a:int,b:string>>",
]
# Each row pairs a source type name (first item) with the value the test in
# this file expects ColumnTypeParser.get_column_type to return for it (second
# item). The table is transposed into two parallel lists so index i of
# COLUMN_TYPE lines up with index i of EXPTECTED_COLUMN_TYPE.
COLUMN_TYPE, EXPTECTED_COLUMN_TYPE = (
    list(column)
    for column in zip(
        *(
            ("ARRAY", "ARRAY"),
            ("BIGINT", "BIGINT"),
            ("BINARY VARYING", "VARBINARY"),
            ("CURSOR", "BINARY"),
            ("DATETIME", "DATETIME"),
            ("DATETIMEOFFSET", "DATETIME"),
            ("GEOGRAPHY", "GEOGRAPHY"),
            ("INT2", "SMALLINT"),
            ("INT8", "BIGINT"),
            ("INT128", "BIGINT"),
            ("UINT2", "SMALLINT"),
            ("LONGBLOB", "LONGBLOB"),
            ("JSONB", "JSON"),
            ("POINT", "POINT"),
            ("Random1", "VARCHAR"),
        )
    )
)
# Directory containing this test module; the fixture path below is built
# relative to it.
root = os.path.dirname(__file__)
# NOTE(review): ``json`` is also imported at the top of the file; this second
# import is presumably the removed side of the diff -- confirm.
import json
# Load the expected parser output: the "data" entry of
# resources/expected_output_column_parser.json becomes EXPECTED_OUTPUT.
try:
# NOTE(review): two ``with open`` headers appear back to back. They open the
# same path and differ only in the explicit encoding="UTF-8" argument, so they
# look like the removed and added sides of the diff -- confirm which one the
# file actually keeps.
with open(os.path.join(root, "resources/expected_output_column_parser.json")) as f:
with open(
os.path.join(root, "resources/expected_output_column_parser.json"),
encoding="UTF-8",
) as f:
EXPECTED_OUTPUT = json.loads(f.read())["data"]
# Any failure while opening or parsing the fixture is printed rather than
# raised; EXPECTED_OUTPUT is not assigned on that path.
except Exception as exc:
print_ansi_encoded_string(message=exc)
@ -39,8 +94,16 @@ except Exception as exc:
# Test case covering ColumnTypeParser._parse_datatype_string and
# ColumnTypeParser.get_column_type against the module-level fixture lists.
class ColumnTypeParseTest(TestCase):
# Parses every entry of COLUMN_TYPE_PARSE and compares the result with the
# entry at the same index in EXPECTED_OUTPUT.
def test_check_datatype_support(self):
for index, parse_string in enumerate(COLUMN_TYPE_PARSE):
# NOTE(review): two variants of the parse-and-assert step follow back to
# back, and the first ``assertTrue(`` is never closed before the second
# variant starts. Presumably the first is the removed side of the diff and
# the second the added side -- confirm against the repository.
parsed_string = ColumnTypeParser._parse_datatype_string(parse_string)
self.assertTrue(
True if parsed_string == EXPECTED_OUTPUT[index] else False,
msg=f"{index}: {COLUMN_TYPE_PARSE[index]} : {parsed_string}",
# Second variant: same comparison without the redundant ``True if ... else
# False`` wrapper, and the failure message reuses the loop variable.
parsed_string = ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access
parse_string
)
self.assertTrue(
parsed_string == EXPECTED_OUTPUT[index],
msg=f"{index}: {parse_string} : {parsed_string}",
)
# Checks that get_column_type maps each COLUMN_TYPE entry to the entry at the
# same index in EXPTECTED_COLUMN_TYPE.
def test_check_column_type(self):
# Guard against the two parallel lists drifting out of step before indexing.
self.assertEqual(len(COLUMN_TYPE), len(EXPTECTED_COLUMN_TYPE))
for index, column in enumerate(COLUMN_TYPE):
column_type = ColumnTypeParser.get_column_type(column_type=column)
self.assertEqual(EXPTECTED_COLUMN_TYPE[index], column_type)