From 8f693860767d024c491fa76aa1d42ef4dcb0470f Mon Sep 17 00:00:00 2001 From: NiharDoshi99 <51595473+NiharDoshi99@users.noreply.github.com> Date: Tue, 3 Jan 2023 10:28:38 +0530 Subject: [PATCH] Fix: refactor get_column_type (#9559) * Fix: refactor get_column_type * Fix: added changes as per comments * Fix: pylint * Fix: minor changes * Fix: minor changes --- .../source/database/column_type_parser.py | 38 +++++++--- .../tests/unit/test_column_type_parser.py | 75 +++++++++++++++++-- 2 files changed, 96 insertions(+), 17 deletions(-) diff --git a/ingestion/src/metadata/ingestion/source/database/column_type_parser.py b/ingestion/src/metadata/ingestion/source/database/column_type_parser.py index 8a29e68435c..a2349a2a114 100644 --- a/ingestion/src/metadata/ingestion/source/database/column_type_parser.py +++ b/ingestion/src/metadata/ingestion/source/database/column_type_parser.py @@ -213,17 +213,33 @@ class ColumnTypeParser: @staticmethod def get_column_type(column_type: Any) -> str: - if not ColumnTypeParser._COLUMN_TYPE_MAPPING.get(type(column_type)): - if not ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(str(column_type)): - if not ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get( - str(column_type).split("(", maxsplit=1)[0].split("<")[0].upper() - ): - return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get("VARCHAR") - return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get( - str(column_type).split("(", maxsplit=1)[0].split("<")[0].upper() - ) - return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(str(column_type)) - return ColumnTypeParser._COLUMN_TYPE_MAPPING.get(type(column_type)) + column_type_result = ColumnTypeParser.get_column_type_mapping(column_type) + if column_type_result: + return column_type_result + column_type_result = ColumnTypeParser.get_source_type_mapping(column_type) + if column_type_result: + return column_type_result + column_type_result = ColumnTypeParser.get_source_type_containes_brackets( + column_type + ) + if column_type_result: + return column_type_result + + return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get("VARCHAR") + + @staticmethod + def get_column_type_mapping(column_type: Any) -> str: + return ColumnTypeParser._COLUMN_TYPE_MAPPING.get(type(column_type), None) + + @staticmethod + def get_source_type_mapping(column_type: Any) -> str: + return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(str(column_type), None) + + @staticmethod + def get_source_type_containes_brackets(column_type: Any) -> str: + return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get( + str(column_type).split("(", maxsplit=1)[0].split("<")[0].upper(), None + ) @staticmethod def _parse_datatype_string( diff --git a/ingestion/tests/unit/test_column_type_parser.py b/ingestion/tests/unit/test_column_type_parser.py index 9792152c9f2..ef5484e9c03 100644 --- a/ingestion/tests/unit/test_column_type_parser.py +++ b/ingestion/tests/unit/test_column_type_parser.py @@ -1,6 +1,22 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test column type in column_type_parser +""" +import json import os from unittest import TestCase +from sqlalchemy.sql import sqltypes as types + from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser from metadata.utils.ansi import print_ansi_encoded_string @@ -26,11 +42,50 @@ COLUMN_TYPE_PARSE = [ "string", "uniontype,struct>", ] + +COLUMN_TYPE = [ + "ARRAY", + "BIGINT", + "BINARY VARYING", + "CURSOR", + "DATETIME", + "DATETIMEOFFSET", + "GEOGRAPHY", + "INT2", + "INT8", + "INT128", + "UINT2", + "LONGBLOB", + "JSONB", + "POINT", + "Random1", +] + +EXPTECTED_COLUMN_TYPE = [ + "ARRAY", + "BIGINT", + "VARBINARY", + "BINARY", + "DATETIME", + "DATETIME", + "GEOGRAPHY", + "SMALLINT", + "BIGINT", + "BIGINT", + "SMALLINT", + "LONGBLOB", + "JSON", + "POINT", + "VARCHAR", +] root = os.path.dirname(__file__) -import json + try: - with open(os.path.join(root, "resources/expected_output_column_parser.json")) as f: + with open( + os.path.join(root, "resources/expected_output_column_parser.json"), + encoding="UTF-8", + ) as f: EXPECTED_OUTPUT = json.loads(f.read())["data"] except Exception as exc: print_ansi_encoded_string(message=exc) @@ -39,8 +94,16 @@ except Exception as exc: class ColumnTypeParseTest(TestCase): def test_check_datatype_support(self): for index, parse_string in enumerate(COLUMN_TYPE_PARSE): - parsed_string = ColumnTypeParser._parse_datatype_string(parse_string) - self.assertTrue( - True if parsed_string == EXPECTED_OUTPUT[index] else False, - msg=f"{index}: {COLUMN_TYPE_PARSE[index]} : {parsed_string}", + parsed_string = ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access + parse_string ) + self.assertTrue( + parsed_string == EXPECTED_OUTPUT[index], + msg=f"{index}: {parse_string} : {parsed_string}", + ) + + def test_check_column_type(self): + self.assertEqual(len(COLUMN_TYPE), len(EXPTECTED_COLUMN_TYPE)) + for index, column in enumerate(COLUMN_TYPE): + column_type = ColumnTypeParser.get_column_type(column_type=column) + self.assertEqual(EXPTECTED_COLUMN_TYPE[index], column_type)