diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py index 29804aed8f7..f0b24931b11 100644 --- a/ingestion/src/metadata/ingestion/source/sql_source.py +++ b/ingestion/src/metadata/ingestion/source/sql_source.py @@ -576,7 +576,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]): col_dict = Column(**parsed_string) try: if ( - self.config.enable_policy_tags + hasattr(self.config, "enable_policy_tags") and "policy_tags" in column and column["policy_tags"] ): diff --git a/ingestion/src/metadata/utils/column_type_parser.py b/ingestion/src/metadata/utils/column_type_parser.py index 1d098ef8cc5..983bd6eb4c9 100644 --- a/ingestion/src/metadata/utils/column_type_parser.py +++ b/ingestion/src/metadata/utils/column_type_parser.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Type, Union from sqlalchemy.sql import sqltypes as types from sqlalchemy.types import TypeEngine @@ -156,21 +156,17 @@ class ColumnTypeParser: @staticmethod def get_column_type(column_type: Any) -> str: - type_class: Optional[str] = None - if isinstance(column_type, types.NullType): - return "NULL" - for sql_type in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys(): - if str(column_type) == sql_type: - type_class = ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE[sql_type] - break - if type_class is None or type_class == "NULL": - for col_type in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys(): - if str(column_type).split("(")[0].split("<")[0].upper() in col_type: - type_class = ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(col_type) - break - else: - type_class = None - return type_class + if not ColumnTypeParser._COLUMN_TYPE_MAPPING.get(type(column_type)): + if not ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(str(column_type)): + if not ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get( + str(column_type).split("(")[0].split("<")[0].upper() + ): + return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get("VARCHAR") + return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get( + str(column_type).split("(")[0].split("<")[0].upper() + ) + return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(str(column_type)) + return ColumnTypeParser._COLUMN_TYPE_MAPPING.get(type(column_type)) @staticmethod def _parse_datatype_string(