From f390bac375636be0ffcc1be11daf1c44a7cc7cc8 Mon Sep 17 00:00:00 2001 From: Sriharsha Chintalapani Date: Sat, 22 Jan 2022 10:02:58 -0800 Subject: [PATCH] Fix 2270: Column Type Parser (#2271) * Fix 2270: Column Type Parser * Fix 2270: Column Type Parser * Added checks to allow arrayDataType and dataTypeDisplay * Modified - arrayDataType and dataTypeDisplay * Update sql_source.py * Update sql_source.py * file formatted * Modified according to column_type_parser.py * modified / refactored / deleted the files * Added Tests, modified sql_source * file formatted * Added missing datatypes * Added Tests * Added Tests * Added Tests - refactored expected output into a json file * file formatted * Sample Data Updated Co-authored-by: Ayush Shah --- .../src/metadata/ingestion/source/amundsen.py | 33 +- .../src/metadata/ingestion/source/bigquery.py | 17 +- .../src/metadata/ingestion/source/glue.py | 37 +-- .../metadata/ingestion/source/sample_data.py | 2 - .../metadata/ingestion/source/snowflake.py | 2 +- .../metadata/ingestion/source/sql_source.py | 130 ++++---- .../src/metadata/utils/column_helpers.py | 306 ------------------ .../src/metadata/utils/column_type_parser.py | 303 +++++++++++++++++ .../expected_output_column_parser.json | 110 +++++++ .../tests/unit/test_column_type_parser.py | 32 ++ .../{datatypes_test.py => test_datatypes.py} | 12 +- 11 files changed, 553 insertions(+), 431 deletions(-) delete mode 100644 ingestion/src/metadata/utils/column_helpers.py create mode 100644 ingestion/src/metadata/utils/column_type_parser.py create mode 100644 ingestion/tests/unit/resources/expected_output_column_parser.json create mode 100644 ingestion/tests/unit/test_column_type_parser.py rename ingestion/tests/unit/{datatypes_test.py => test_datatypes.py} (84%) diff --git a/ingestion/src/metadata/ingestion/source/amundsen.py b/ingestion/src/metadata/ingestion/source/amundsen.py index 060d02d96c7..24dc8635658 100644 --- a/ingestion/src/metadata/ingestion/source/amundsen.py +++ b/ingestion/src/metadata/ingestion/source/amundsen.py @@ -10,15 +10,11 @@ # limitations under the License. import logging -import re -import textwrap import traceback import uuid from dataclasses import dataclass, field from typing import Iterable, List, Optional -from pydantic import SecretStr - from metadata.config.common import ConfigModel from metadata.generated.schema.api.services.createDatabaseService import ( CreateDatabaseServiceEntityRequest, @@ -38,13 +34,14 @@ from metadata.ingestion.models.user import User from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.source.neo4j_helper import Neo4JConfig, Neo4jHelper -from metadata.utils.column_helpers import check_column_complex_type, get_column_type +from metadata.utils.column_type_parser import ColumnTypeParser from metadata.utils.helpers import get_dashboard_service_or_create from metadata.utils.sql_queries import ( NEO4J_AMUNDSEN_DASHBOARD_QUERY, NEO4J_AMUNDSEN_TABLE_QUERY, NEO4J_AMUNDSEN_USER_QUERY, ) +from pydantic import SecretStr logger: logging.Logger = logging.getLogger(__name__) @@ -155,28 +152,10 @@ class AmundsenSource(Source[Entity]): # Amundsen merges the length into type itself. 
Instead of making changes to our generic type builder # we will do a type match and see if it matches any primitive types and return a type data_type = self.get_type_primitive_type(data_type) - ( - col_type, - data_type_display, - arr_data_type, - children, - ) = check_column_complex_type( - self.status, table["name"], data_type, name - ) - - col = Column( - name=name, - description=description, - dataType=col_type, - dataTypeDisplay="{}({})".format(col_type, 1) - if data_type_display is None - else f"{data_type_display}", - children=children, - arrayDataType=arr_data_type, - ordinalPosition=row_order, - dataLength=1, - ) - row_order += 1 + parsed_string = ColumnTypeParser._parse_datatype_string(data_type) + parsed_string["name"] = name + parsed_string["dataLength"] = 1 + col = Column(**parsed_string) columns.append(col) fqn = f"{service_name}.{database.name}.{table['schema']}.{table['name']}" diff --git a/ingestion/src/metadata/ingestion/source/bigquery.py b/ingestion/src/metadata/ingestion/source/bigquery.py index 64aa3fdec17..9fd57d75dca 100644 --- a/ingestion/src/metadata/ingestion/source/bigquery.py +++ b/ingestion/src/metadata/ingestion/source/bigquery.py @@ -13,6 +13,10 @@ import os from typing import Optional, Tuple, Any import json, tempfile, logging +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig +from metadata.ingestion.source.sql_source import SQLSource +from metadata.ingestion.source.sql_source_common import SQLConnectionConfig +from metadata.utils.column_type_parser import create_sqlalchemy_type from sqlalchemy_bigquery import _types from sqlalchemy_bigquery._struct import STRUCT from sqlalchemy_bigquery._types import ( @@ -20,11 +24,6 @@ from sqlalchemy_bigquery._types import ( _get_transitive_schema_fields, ) -from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig -from metadata.ingestion.source.sql_source import SQLSource -from metadata.ingestion.source.sql_source_common import SQLConnectionConfig -from metadata.utils.column_helpers import create_sqlalchemy_type - GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY") _types._type_map["GEOGRAPHY"] = GEOGRAPHY @@ -44,7 +43,7 @@ def get_columns(bq_schema): "precision": field.precision, "scale": field.scale, "max_length": field.max_length, - "raw_data_type": repr(_get_sqla_column_type(field)), + "raw_data_type": str(_get_sqla_column_type(field)), } col_list.append(col_obj) return col_list @@ -75,9 +74,9 @@ class BigquerySource(SQLSource): metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict) if config.options.get("credentials", None): cred_path = create_credential_temp_file(config.options.get("credentials")) - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path - del config.options["credentials"] - config.options["credentials_path"] = cred_path + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path + config.options["credentials_path"] = cred_path + del config.options["credentials"] return cls(config, metadata_config, ctx) def close(self): diff --git a/ingestion/src/metadata/ingestion/source/glue.py b/ingestion/src/metadata/ingestion/source/glue.py index ea1c2fc18b9..739836eaed5 100644 --- a/ingestion/src/metadata/ingestion/source/glue.py +++ b/ingestion/src/metadata/ingestion/source/glue.py @@ -29,7 +29,7 @@ from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.source.sql_source_common import SQLSourceStatus from metadata.utils.aws_client import 
AWSClient, AWSClientConfigModel -from metadata.utils.column_helpers import check_column_complex_type +from metadata.utils.column_type_parser import ColumnTypeParser from metadata.utils.helpers import ( get_database_service_or_create, get_pipeline_service_or_create, @@ -75,7 +75,7 @@ class GlueSource(Source[Entity]): "serviceType": "Glue", "pipelineUrl": self.config.endpoint_url if self.config.endpoint_url is not None - else f"https://glue.{ self.config.region_name }.amazonaws.com", + else f"https://glue.{self.config.region_name}.amazonaws.com", }, metadata_config, ) @@ -116,31 +116,20 @@ class GlueSource(Source[Entity]): yield from self.ingest_pipelines() def get_columns(self, column_data): - row_order = 0 - for column in column_data["Columns"]: + for index, column in enumerate(column_data["Columns"]): if column["Type"].lower().startswith("union"): column["Type"] = column["Type"].replace(" ", "") - ( - col_type, - data_type_display, - arr_data_type, - children, - ) = check_column_complex_type( - self.status, self.dataset_name, column["Type"].lower(), column["Name"] + parsed_string = ColumnTypeParser._parse_datatype_string( + column["Type"].lower() ) - yield Column( - name=column["Name"].replace(".", "_DOT_")[:128], - description="", - dataType=col_type, - dataTypeDisplay="{}({})".format(col_type, 1) - if data_type_display is None - else f"{data_type_display}", - ordinalPosition=row_order, - children=children, - arrayDataType=arr_data_type, - dataLength=1, - ) - row_order += 1 + if isinstance(parsed_string, list): + parsed_string = {} + parsed_string["dataTypeDisplay"] = str(column["Type"]) + parsed_string["dataType"] = "UNION" + parsed_string["name"] = column["Name"][:64] + parsed_string["ordinalPosition"] = index + parsed_string["dataLength"] = parsed_string.get("dataLength", 1) + yield Column(**parsed_string) def ingest_tables(self, next_tables_token=None) -> Iterable[OMetaDatabaseAndTable]: try: diff --git a/ingestion/src/metadata/ingestion/source/sample_data.py b/ingestion/src/metadata/ingestion/source/sample_data.py index 7b270e51e46..a65a1e36bbd 100644 --- a/ingestion/src/metadata/ingestion/source/sample_data.py +++ b/ingestion/src/metadata/ingestion/source/sample_data.py @@ -169,11 +169,9 @@ class SampleDataSource(Source[Entity]): self.config = config self.metadata_config = metadata_config self.metadata = OpenMetadata(metadata_config) - print("hi") self.storage_service_json = json.load( open(self.config.sample_data_folder + "/locations/service.json", "r") ) - print("hello") self.locations = json.load( open(self.config.sample_data_folder + "/locations/locations.json", "r") ) diff --git a/ingestion/src/metadata/ingestion/source/snowflake.py b/ingestion/src/metadata/ingestion/source/snowflake.py index b6373a339ba..a6a572cdc69 100644 --- a/ingestion/src/metadata/ingestion/source/snowflake.py +++ b/ingestion/src/metadata/ingestion/source/snowflake.py @@ -17,7 +17,7 @@ from snowflake.sqlalchemy.snowdialect import ischema_names from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.source.sql_source import SQLSource from metadata.ingestion.source.sql_source_common import SQLConnectionConfig -from metadata.utils.column_helpers import create_sqlalchemy_type +from metadata.utils.column_type_parser import create_sqlalchemy_type GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY") ischema_names["VARIANT"] = VARIANT diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py index 
c03db0ed779..314552a5339 100644 --- a/ingestion/src/metadata/ingestion/source/sql_source.py +++ b/ingestion/src/metadata/ingestion/source/sql_source.py @@ -18,10 +18,6 @@ import traceback import uuid from typing import Dict, Iterable, List, Optional, Tuple -from sqlalchemy import create_engine -from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.inspection import inspect - from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.table import ( Column, @@ -43,8 +39,11 @@ from metadata.ingestion.source.sql_source_common import ( SQLConnectionConfig, SQLSourceStatus, ) -from metadata.utils.column_helpers import check_column_complex_type, get_column_type +from metadata.utils.column_type_parser import ColumnTypeParser from metadata.utils.helpers import get_database_service_or_create +from sqlalchemy import create_engine +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.inspection import inspect logger: logging.Logger = logging.getLogger(__name__) @@ -387,7 +386,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]): ccolumn = ccolumns[key] try: ctype = ccolumn["type"] - col_type = get_column_type(self.status, model_name, ctype) + col_type = ColumnTypeParser.get_column_type(ctype) description = manifest_columns.get(key.lower(), {}).get( "description", None ) @@ -460,21 +459,20 @@ class SQLSource(Source[OMetaDatabaseAndTable]): if "column_names" in constraint.keys() ] - dataset_name = f"{schema}.{table}" table_columns = [] columns = inspector.get_columns(table, schema) try: for row_order, column in enumerate(columns): try: + col_dict = None if "." in column["name"]: - logger.info( - f"Found '.' in {column['name']}, changing '.' to '_DOT_'" - ) column["name"] = column["name"].replace(".", "_DOT_") children = None data_type_display = None col_data_length = None arr_data_type = None + parsed_string = None + print(column["raw_data_type"]) if ( "raw_data_type" in column and column["raw_data_type"] is not None @@ -482,21 +480,12 @@ class SQLSource(Source[OMetaDatabaseAndTable]): column["raw_data_type"] = self.parse_raw_data_type( column["raw_data_type"] ) - ( - col_type, - data_type_display, - arr_data_type, - children, - ) = check_column_complex_type( - self.status, - dataset_name, - column["raw_data_type"], - column["name"], + parsed_string = ColumnTypeParser._parse_datatype_string( + column["raw_data_type"] ) + parsed_string["name"] = column["name"] else: - col_type = get_column_type( - self.status, dataset_name, column["type"] - ) + col_type = ColumnTypeParser.get_column_type(column["type"]) if col_type == "ARRAY" and re.match( r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"]) ): @@ -504,40 +493,64 @@ class SQLSource(Source[OMetaDatabaseAndTable]): r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"]) ).groups() data_type_display = column["type"] - if repr(column["type"]).upper().startswith("ARRAY("): - arr_data_type = "STRUCT" - data_type_display = ( - repr(column["type"]) - .replace("(", "<") - .replace(")", ">") - .lower() + if parsed_string is None: + col_type = ColumnTypeParser.get_column_type(column["type"]) + col_constraint = self._get_column_constraints( + column, pk_columns, unique_columns ) - col_constraint = self._get_column_constraints( - column, pk_columns, unique_columns - ) - if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}: - col_data_length = column["type"].length - if col_type == "NULL": - col_type = "VARCHAR" - data_type_display = "varchar" - logger.warning( - f"Unknown type 
{column['type']} mapped to VARCHAR: {column['name']}" + col_data_length = self._check_col_length( + col_type, column["type"] ) - om_column = Column( - name=column["name"], - description=column.get("comment", None), - dataType=col_type, - dataTypeDisplay="{}({})".format( - col_type, 1 if col_data_length is None else col_data_length + if col_type == "NULL": + col_type = "VARCHAR" + data_type_display = "varchar" + logger.warning( + f"Unknown type {column['type']} mapped to VARCHAR: {column['name']} {column['type']}" + ) + col_data_length = ( + 1 if col_data_length is None else col_data_length ) - if data_type_display is None - else f"{data_type_display}", - dataLength=1 if col_data_length is None else col_data_length, - constraint=col_constraint, - ordinalPosition=row_order, - children=children if children is not None else None, - arrayDataType=arr_data_type, - ) + dataTypeDisplay = ( + f"{data_type_display}" + if data_type_display + else "{}({})".format(col_type, col_data_length) + ) + om_column = Column( + name=column["name"], + description=column.get("comment", None), + dataType=col_type, + dataTypeDisplay=dataTypeDisplay, + dataLength=col_data_length, + constraint=col_constraint, + ordinalPosition=row_order, + children=children if children else None, + arrayDataType=arr_data_type, + ) + else: + parsed_string["dataLength"] = self._check_col_length( + parsed_string["dataType"], column["type"] + ) + if column["raw_data_type"] == "array": + array_data_type_display = ( + repr(column["type"]) + .replace("(", "<") + .replace(")", ">") + .replace("=", ":") + .replace("<>", "") + .lower() + ) + parsed_string[ + "dataTypeDisplay" + ] = f"{array_data_type_display}" + parsed_string[ + "arrayDataType" + ] = ColumnTypeParser._parse_primitive_datatype_string( + array_data_type_display[6:-1] + )[ + "dataType" + ] + col_dict = Column(**parsed_string) + om_column = col_dict except Exception as err: logger.error(traceback.print_exc()) logger.error(f"{err} : {column}") @@ -548,6 +561,15 @@ class SQLSource(Source[OMetaDatabaseAndTable]): logger.error(f"{repr(err)}: {table} {err}") return None + def _check_col_length(self, datatype, col_raw_type): + if datatype.upper() in { + "CHAR", + "VARCHAR", + "BINARY", + "VARBINARY", + }: + return col_raw_type.length if col_raw_type.length else 1 + def run_data_profiler(self, table: str, schema: str) -> TableProfile: """ Run the profiler for a table in a schema. 
diff --git a/ingestion/src/metadata/utils/column_helpers.py b/ingestion/src/metadata/utils/column_helpers.py deleted file mode 100644 index 9d6df790433..00000000000 --- a/ingestion/src/metadata/utils/column_helpers.py +++ /dev/null @@ -1,306 +0,0 @@ -import re -from typing import Any, Dict, Optional, Set, Type - -from metadata.ingestion.api.source import SourceStatus -from sqlalchemy.sql import sqltypes as types -from sqlalchemy.types import TypeEngine - - -def register_custom_str_type(tp: str, output: str) -> None: - _column_string_mapping[tp] = output - - -def create_sqlalchemy_type(name: str): - sqlalchemy_type = type( - name, - (TypeEngine,), - { - "__repr__": lambda self: f"{name}()", - }, - ) - return sqlalchemy_type - - -_column_type_mapping: Dict[Type[types.TypeEngine], str] = { - types.Integer: "INT", - types.Numeric: "INT", - types.Boolean: "BOOLEAN", - types.Enum: "ENUM", - types._Binary: "BYTES", - types.LargeBinary: "BYTES", - types.PickleType: "BYTES", - types.ARRAY: "ARRAY", - types.VARCHAR: "VARCHAR", - types.String: "STRING", - types.Date: "DATE", - types.DATE: "DATE", - types.Time: "TIME", - types.DateTime: "DATETIME", - types.DATETIME: "DATETIME", - types.TIMESTAMP: "TIMESTAMP", - types.NullType: "NULL", - types.JSON: "JSON", - types.CHAR: "CHAR", - types.DECIMAL: "DECIMAL", - types.Interval: "INTERVAL", -} - -_column_string_mapping = { - "ARRAY": "ARRAY", - "BIGINT": "BIGINT", - "BIGNUMERIC": "NUMERIC", - "BIGSERIAL": "BIGINT", - "BINARY": "BINARY", - "BIT": "INT", - "BLOB": "BLOB", - "BOOL": "BOOLEAN", - "BOOLEAN": "BOOLEAN", - "BPCHAR": "CHAR", - "BYTEINT": "BYTEINT", - "BYTES": "BYTES", - "CHAR": "CHAR", - "CHARACTER VARYING": "VARCHAR", - "DATE": "DATE", - "DATETIME": "DATETIME", - "DATETIME2": "DATETIME", - "DATETIMEOFFSET": "DATETIME", - "DECIMAL": "DECIMAL", - "DOUBLE PRECISION": "DOUBLE", - "DOUBLE": "DOUBLE", - "ENUM": "ENUM", - "FLOAT4": "FLOAT", - "FLOAT64": "DOUBLE", - "FLOAT8": "DOUBLE", - "GEOGRAPHY": "GEOGRAPHY", - "HYPERLOGLOG": "BINARY", - "IMAGE": "BINARY", - "INT": "INT", - "INT2": "SMALLINT", - "INT4": "INT", - "INT64": "BIGINT", - "INT8": "BIGINT", - "INTEGER": "INT", - "INTERVAL DAY TO SECOND": "INTERVAL", - "INTERVAL YEAR TO MONTH": "INTERVAL", - "INTERVAL": "INTERVAL", - "JSON": "JSON", - "LONG RAW": "BINARY", - "LONG VARCHAR": "VARCHAR", - "LONGBLOB": "LONGBLOB", - "MAP": "MAP", - "MEDIUMBLOB": "MEDIUMBLOB", - "MEDIUMINT": "INT", - "MEDIUMTEXT": "MEDIUMTEXT", - "MONEY": "NUMBER", - "NCHAR": "CHAR", - "NTEXT": "TEXT", - "NULL": "NULL", - "NUMBER": "NUMBER", - "NUMERIC": "NUMBER", - "NVARCHAR": "VARCHAR", - "OBJECT": "JSON", - "RAW": "BINARY", - "REAL": "FLOAT", - "ROWID": "VARCHAR", - "ROWVERSION": "NUMBER", - "SET": "SET", - "SMALLDATETIME": "DATETIME", - "SMALLINT": "SMALLINT", - "SMALLMONEY": "NUMBER", - "SMALLSERIAL": "SMALLINT", - "SQL_VARIANT": "VARBINARY", - "STRING": "STRING", - "STRUCT": "STRUCT", - "TABLE": "BINARY", - "TEXT": "TEXT", - "TIME": "TIME", - "TIMESTAMP WITHOUT TIME ZONE": "TIMESTAMP", - "TIMESTAMP": "TIMESTAMP", - "TIMESTAMPTZ": "TIMESTAMP", - "TIMETZ": "TIMESTAMP", - "TINYINT": "TINYINT", - "UNION": "UNION", - "UROWID": "VARCHAR", - "VARBINARY": "VARBINARY", - "VARCHAR": "VARCHAR", - "VARIANT": "JSON", - "XML": "BINARY", - "XMLTYPE": "BINARY", - "CURSOR": "BINARY", - "TIMESTAMP_NTZ": "TIMESTAMP", - "TIMESTAMP_LTZ": "TIMESTAMP", - "TIMESTAMP_TZ": "TIMESTAMP", -} - -_known_unknown_column_types: Set[Type[types.TypeEngine]] = { - types.Interval, - types.CLOB, -} - - -def check_column_complex_type( - status: SourceStatus, 
dataset_name: str, column_raw_type: Any, col_name: str -): - arr_data_type = None - col_obj = None - data_type_display = None - children = None - if column_raw_type.startswith("struct<"): - col_type = "STRUCT" - col_obj = _handle_complex_data_types( - status, - dataset_name, - f"{col_name}:{column_raw_type}", - ) - if "children" in col_obj and col_obj["children"] is not None: - children = col_obj["children"] - elif column_raw_type.startswith("map<"): - col_type = "MAP" - col_obj = _handle_complex_data_types( - status, - dataset_name, - f"{col_name}:{column_raw_type}", - ) - elif column_raw_type.startswith("array<"): - col_type = "ARRAY" - arr_data_type = re.match(r"(?:array<)(\w*)(?:.*)", column_raw_type) - arr_data_type = arr_data_type.groups()[0].upper() - elif column_raw_type.startswith("uniontype<") or column_raw_type.startswith( - "union<" - ): - col_type = "UNION" - else: - col_type = get_column_type(status, dataset_name, column_raw_type.split("(")[0]) - data_type_display = col_type - if data_type_display is None: - data_type_display = column_raw_type - return col_type, data_type_display, arr_data_type, children - - -def get_column_type(status: SourceStatus, dataset_name: str, column_type: Any) -> str: - type_class: Optional[str] = None - for sql_type in _column_type_mapping.keys(): - if isinstance(column_type, sql_type): - type_class = _column_type_mapping[sql_type] - break - if type_class is None or type_class == "NULL": - for sql_type in _known_unknown_column_types: - if isinstance(column_type, sql_type): - type_class = "NULL" - break - for col_type in _column_string_mapping.keys(): - if str(column_type).split("(")[0].split("<")[0].upper() in col_type: - type_class = _column_string_mapping.get(col_type) - break - else: - type_class = None - if type_class is None or type_class == "NULL": - status.warning( - dataset_name, f"unable to map type {column_type!r} to metadata schema" - ) - type_class = "NULL" - return type_class - - -def get_last_index(nested_str): - counter = 1 - for index, i in enumerate(nested_str): - if i == ">": - counter -= 1 - elif i == "<": - counter += 1 - if counter == 0: - break - index = index - counter - return index - - -def get_array_type(col_type): - col = {} - col["dataType"] = "ARRAY" - col_type = col_type[: get_last_index(col_type) + 2] - col["dataTypeDisplay"] = col_type - col["arrayDataType"] = ( - re.match(r"(?:array<)(\w*)(?:.*)", col_type).groups()[0].upper() - ) - return col - - -def _handle_complex_data_types(status, dataset_name, raw_type: str, level=0): - col = {} - if re.match(r"([\w\s]*)(:)(.*)", raw_type): - name, col_type = raw_type.lstrip("<").split(":", 1) - col["name"] = name - else: - col["name"] = f"field_{level}" - if raw_type.startswith("struct<"): - col_type = raw_type - else: - col_type = raw_type.lstrip("<").split(":", 1)[0] - if col_type.startswith("struct<"): - children = [] - struct_type, col_type = re.match(r"(struct<)(.*)", col_type).groups() - pluck_index = get_last_index(col_type) - pluck_nested = col_type[: pluck_index + 1] - col["dataTypeDisplay"] = struct_type + pluck_nested - while pluck_nested != "": - col["dataType"] = "STRUCT" - plucked = col_type[: get_last_index(col_type)] - counter = 0 - continue_next = False - for index, type in enumerate(plucked.split(",")): - if continue_next: - continue_next = False - continue - if re.match(r"(\w*)(:)(struct)(.*)", type): - col_name, datatype, rest = re.match( - r"(\w*)(?::)(struct)(.*)", ",".join(plucked.split(",")[index:]) - ).groups() - type = 
f"{col_name}:{datatype}{rest[:get_last_index(rest) + 2]}" - elif type.startswith("struct"): - datatype, rest = re.match( - r"(struct)(.*)", ",".join(plucked.split(",")[index:]) - ).groups() - type = f"{datatype}{rest[:get_last_index(rest) + 2]}" - elif re.match(r"([\w\s]*)(:?)(map)(.*)", type): - get_map_type = ",".join(plucked.split(",")[index:]) - type, col_type = re.match( - r"([\w]*:?map<[\w,]*>)(.*)", get_map_type - ).groups() - continue_next = True - elif re.match(r"([\w\s]*)(:?)(uniontype)(.*)", type): - get_union_type = ",".join(plucked.split(",")[index:]) - type, col_type = re.match( - r"([\w\s]*:?uniontype<[\w\s,]*>)(.*)", get_union_type - ).groups() - continue_next = True - children.append( - _handle_complex_data_types(status, dataset_name, type, counter) - ) - if plucked.endswith(type): - break - counter += 1 - pluck_nested = col_type[get_last_index(col_type) + 3 :] - col["children"] = children - elif col_type.startswith("array"): - col.update(get_array_type(col_type)) - elif col_type.startswith("map"): - col["dataType"] = "MAP" - col["dataTypeDisplay"] = col_type - elif col_type.startswith("uniontype"): - col["dataType"] = "UNION" - col["dataTypeDisplay"] = col_type - else: - if re.match(r"(?:[\w\s]*)(?:\()([\d]*)(?:\))", col_type): - col["dataLength"] = re.match( - r"(?:[\w\s]*)(?:\()([\d]*)(?:\))", col_type - ).groups()[0] - else: - col["dataLength"] = 1 - col["dataType"] = get_column_type( - status, - dataset_name, - re.match(r"([\w\s]*)(?:.*)", col_type).groups()[0], - ) - col["dataTypeDisplay"] = col_type.rstrip(">") - return col diff --git a/ingestion/src/metadata/utils/column_type_parser.py b/ingestion/src/metadata/utils/column_type_parser.py new file mode 100644 index 00000000000..743a8868bc8 --- /dev/null +++ b/ingestion/src/metadata/utils/column_type_parser.py @@ -0,0 +1,303 @@ +import re +from typing import Any, Dict, Optional, Type +from typing import List, Union + +from sqlalchemy.sql import sqltypes as types +from sqlalchemy.types import TypeEngine + + +def create_sqlalchemy_type(name: str): + sqlalchemy_type = type( + name, + (TypeEngine,), + { + "__repr__": lambda self: f"{name}()", + }, + ) + return sqlalchemy_type + + +class ColumnTypeParser: + _BRACKETS = {"(": ")", "[": "]", "{": "}", "<": ">"} + + _COLUMN_TYPE_MAPPING: Dict[Type[types.TypeEngine], str] = { + types.ARRAY: "ARRAY", + types.Boolean: "BOOLEAN", + types.CHAR: "CHAR", + types.CLOB: "BINARY", + types.Date: "DATE", + types.DATE: "DATE", + types.DateTime: "DATETIME", + types.DATETIME: "DATETIME", + types.DECIMAL: "DECIMAL", + types.Enum: "ENUM", + types.Interval: "INTERVAL", + types.JSON: "JSON", + types.LargeBinary: "BYTES", + types.NullType: "NULL", + types.Numeric: "INT", + types.PickleType: "BYTES", + types.String: "STRING", + types.Time: "TIME", + types.TIMESTAMP: "TIMESTAMP", + types.VARCHAR: "VARCHAR", + types.BINARY: "BINARY", + types.INTEGER: "INT", + types.Integer: "INT", + types.BigInteger: "BIGINT", + } + + _SOURCE_TYPE_TO_OM_TYPE = { + "ARRAY": "ARRAY", + "BIGINT": "BIGINT", + "BIGNUMERIC": "NUMERIC", + "BIGSERIAL": "BIGINT", + "BINARY": "BINARY", + "BIT": "INT", + "BLOB": "BLOB", + "BOOL": "BOOLEAN", + "BOOLEAN": "BOOLEAN", + "BPCHAR": "CHAR", + "BYTEINT": "BYTEINT", + "BYTES": "BYTES", + "CHAR": "CHAR", + "CHARACTER VARYING": "VARCHAR", + "CURSOR": "BINARY", + "DATE": "DATE", + "DATETIME": "DATETIME", + "DATETIME2": "DATETIME", + "DATETIMEOFFSET": "DATETIME", + "DECIMAL": "DECIMAL", + "DOUBLE PRECISION": "DOUBLE", + "DOUBLE": "DOUBLE", + "ENUM": "ENUM", + "FLOAT4": "FLOAT", + 
"FLOAT64": "DOUBLE", + "FLOAT8": "DOUBLE", + "GEOGRAPHY": "GEOGRAPHY", + "HYPERLOGLOG": "BINARY", + "IMAGE": "BINARY", + "INT": "INT", + "INT2": "SMALLINT", + "INT4": "INT", + "INT64": "BIGINT", + "INT8": "BIGINT", + "INTEGER": "INT", + "INTERVAL DAY TO SECOND": "INTERVAL", + "INTERVAL YEAR TO MONTH": "INTERVAL", + "INTERVAL": "INTERVAL", + "JSON": "JSON", + "LONG RAW": "BINARY", + "LONG VARCHAR": "VARCHAR", + "LONGBLOB": "LONGBLOB", + "MAP": "MAP", + "MEDIUMBLOB": "MEDIUMBLOB", + "MEDIUMINT": "INT", + "MEDIUMTEXT": "MEDIUMTEXT", + "MONEY": "NUMBER", + "NCHAR": "CHAR", + "NTEXT": "TEXT", + "NULL": "NULL", + "NUMBER": "NUMBER", + "NUMERIC": "NUMERIC", + "NVARCHAR": "VARCHAR", + "OBJECT": "JSON", + "RAW": "BINARY", + "REAL": "FLOAT", + "RECORD": "STRUCT", + "ROWID": "VARCHAR", + "ROWVERSION": "NUMBER", + "SET": "SET", + "SMALLDATETIME": "DATETIME", + "SMALLINT": "SMALLINT", + "SMALLMONEY": "NUMBER", + "SMALLSERIAL": "SMALLINT", + "SQL_VARIANT": "VARBINARY", + "STRING": "STRING", + "STRUCT": "STRUCT", + "TABLE": "BINARY", + "TEXT": "TEXT", + "TIME": "TIME", + "TIMESTAMP WITHOUT TIME ZONE": "TIMESTAMP", + "TIMESTAMP": "TIMESTAMP", + "TIMESTAMPTZ": "TIMESTAMP", + "TIMESTAMP_NTZ": "TIMESTAMP", + "TIMESTAMP_LTZ": "TIMESTAMP", + "TIMESTAMP_TZ": "TIMESTAMP", + "TIMETZ": "TIMESTAMP", + "TINYINT": "TINYINT", + "UNION": "UNION", + "UROWID": "VARCHAR", + "VARBINARY": "VARBINARY", + "VARCHAR": "VARCHAR", + "VARIANT": "JSON", + "XML": "BINARY", + "XMLTYPE": "BINARY", + } + + _COMPLEX_TYPE = re.compile("^(struct|map|array|uniontype)") + + _FIXED_DECIMAL = re.compile(r"(decimal|numeric)(\(\s*(\d+)\s*,\s*(\d+)\s*\))?") + + _FIXED_STRING = re.compile(r"(var)?char\(\s*(\d+)\s*\)") + + @staticmethod + def get_column_type(column_type: Any) -> str: + type_class: Optional[str] = None + for sql_type in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys(): + if str(column_type) in sql_type: + type_class = ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE[sql_type] + break + if type_class is None or type_class == "NULL": + for col_type in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys(): + if str(column_type).split("(")[0].split("<")[0].upper() in col_type: + type_class = ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(col_type) + break + else: + type_class = None + return type_class + + @staticmethod + def _parse_datatype_string( + s: str, **kwargs: Any + ) -> Union[object, Dict[str, object]]: + s = s.strip() + if s.startswith("array<"): + if s[-1] != ">": + raise ValueError("expected '>' found: %s" % s) + arr_data_type = ColumnTypeParser._parse_primitive_datatype_string(s[6:-1])[ + "dataType" + ] + return { + "dataType": "ARRAY", + "arrayDataType": arr_data_type, + "dataTypeDisplay": s, + } + elif s.startswith("map<"): + if s[-1] != ">": + raise ValueError("expected '>' found: %s" % s) + parts = ColumnTypeParser._ignore_brackets_split(s[4:-1], ",") + if len(parts) != 2: + raise ValueError( + "The map type string format is: 'map', " + + "but got: %s" % s + ) + kt = ColumnTypeParser._parse_datatype_string(parts[0]) + vt = ColumnTypeParser._parse_datatype_string(parts[1]) + return {"dataType": "MAP", "dataTypeDisplay": s} + elif s.startswith("uniontype<") or s.startswith("union<"): + if s[-1] != ">": + raise ValueError("'>' should be the last char, but got: %s" % s) + parts = ColumnTypeParser._ignore_brackets_split(s[10:-1], ",") + t = [] + for part in parts: + if part.startswith("struct<"): + t.append(ColumnTypeParser._parse_datatype_string(part)) + else: + t.append(ColumnTypeParser._parse_datatype_string(part)) + return t + elif 
s.startswith("struct<"): + if s[-1] != ">": + raise ValueError("expected '>', found: %s" % s) + return ColumnTypeParser._parse_struct_fields_string(s[7:-1]) + elif ":" in s: + return ColumnTypeParser._parse_struct_fields_string(s) + else: + return ColumnTypeParser._parse_primitive_datatype_string(s) + + @staticmethod + def _parse_struct_fields_string(s: str) -> Dict[str, object]: + parts = ColumnTypeParser._ignore_brackets_split(s, ",") + columns = [] + for part in parts: + name_and_type = ColumnTypeParser._ignore_brackets_split(part, ":") + if len(name_and_type) != 2: + raise ValueError( + "expected format is: 'field_name:field_type', " + + "but got: %s" % part + ) + field_name = name_and_type[0].strip() + if field_name.startswith("`"): + if field_name[-1] != "`": + raise ValueError("'`' should be the last char, but got: %s" % s) + field_name = field_name[1:-1] + field_type = ColumnTypeParser._parse_datatype_string(name_and_type[1]) + field_type["name"] = field_name + columns.append(field_type) + return { + "children": columns, + "dataTypeDisplay": "struct<{}>".format(s), + "dataType": "STRUCT", + } + + @staticmethod + def _parse_primitive_datatype_string(s: str) -> Dict[str, object]: + if s.upper() in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys(): + return { + "dataType": ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE[s.upper()], + "dataTypeDisplay": s, + } + elif ColumnTypeParser._FIXED_STRING.match(s): + m = ColumnTypeParser._FIXED_STRING.match(s) + return {"type": "STRING", "dataTypeDisplay": s} + elif ColumnTypeParser._FIXED_DECIMAL.match(s): + m = ColumnTypeParser._FIXED_DECIMAL.match(s) + if m.group(2) is not None: # type: ignore + return { + "dataType": ColumnTypeParser.get_column_type(m.group(0)), + "dataTypeDisplay": s, + "dataLength": int(m.group(3)), # type: ignore + } + else: + return { + "dataType": ColumnTypeParser.get_column_type(m.group(0)), + "dataTypeDisplay": s, + } + elif s == "date": + return {"dataType": "DATE", "dataTypeDisplay": s} + elif s == "timestamp": + return {"dataType": "TIMESTAMP", "dataTypeDisplay": s} + else: + dataType = ColumnTypeParser.get_column_type(s) + if not dataType: + return {"dataType": "NULL", "dataTypeDisplay": s} + else: + dataLength = 1 + if re.match(".*(\([\w]*\))", s): + dataLength = re.match(".*\(([\w]*)\)", s).groups()[0] + return { + "dataType": dataType, + "dataTypeDisplay": dataType, + "dataLength": dataLength, + } + + @staticmethod + def _ignore_brackets_split(s: str, separator: str) -> List[str]: + parts = [] + buf = "" + level = 0 + for c in s: + if c in ColumnTypeParser._BRACKETS.keys(): + level += 1 + buf += c + elif c in ColumnTypeParser._BRACKETS.values(): + if level == 0: + raise ValueError("Brackets are not correctly paired: %s" % s) + level -= 1 + buf += c + elif c == separator and level > 0: + buf += c + elif c == separator: + parts.append(buf) + buf = "" + else: + buf += c + + if len(buf) == 0: + raise ValueError("The %s cannot be the last char: %s" % (separator, s)) + parts.append(buf) + return parts + + @staticmethod + def is_primitive_om_type(raw_type: str) -> bool: + return not ColumnTypeParser._COMPLEX_TYPE.match(raw_type) diff --git a/ingestion/tests/unit/resources/expected_output_column_parser.json b/ingestion/tests/unit/resources/expected_output_column_parser.json new file mode 100644 index 00000000000..7a0a0f429f9 --- /dev/null +++ b/ingestion/tests/unit/resources/expected_output_column_parser.json @@ -0,0 +1,110 @@ +{ + "data": [ + { + "dataType": "ARRAY", + "arrayDataType": "STRING", + "dataTypeDisplay": "array" 
+ }, + { + "children": [ + { + "dataType": "INT", + "dataTypeDisplay": "int", + "name": "a" + }, + { + "dataType": "STRING", + "dataTypeDisplay": "string", + "name": "b" + } + ], + "dataTypeDisplay": "struct", + "dataType": "STRUCT" + }, + { + "children": [ + { + "children": [ + { + "dataType": "ARRAY", + "arrayDataType": "STRING", + "dataTypeDisplay": "array", + "name": "b" + }, + { + "dataType": "BIGINT", + "dataTypeDisplay": "bigint", + "name": "c" + } + ], + "dataTypeDisplay": "struct,c:bigint>", + "dataType": "STRUCT", + "name": "a" + } + ], + "dataTypeDisplay": "struct,c:bigint>>", + "dataType": "STRUCT" + }, + { + "children": [ + { + "dataType": "ARRAY", + "arrayDataType": "STRING", + "dataTypeDisplay": "array", + "name": "a" + } + ], + "dataTypeDisplay": "struct>", + "dataType": "STRUCT" + }, + { + "children": [ + { + "dataType": "ARRAY", + "arrayDataType": "STRUCT", + "dataTypeDisplay": "array>>", + "name": "bigquerytestdatatype51" + } + ], + "dataTypeDisplay": "struct>>>", + "dataType": "STRUCT" + }, + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "dataType": "STRING", + "dataTypeDisplay": "string", + "name": "record_4" + } + ], + "dataTypeDisplay": "struct", + "dataType": "STRUCT", + "name": "record_3" + } + ], + "dataTypeDisplay": "struct>", + "dataType": "STRUCT", + "name": "record_2" + } + ], + "dataTypeDisplay": "struct>>", + "dataType": "STRUCT", + "name": "record_1" + } + ], + "dataTypeDisplay": "struct>>>", + "dataType": "STRUCT" + }, + { + "dataType": "ARRAY", + "arrayDataType": "STRUCT", + "dataTypeDisplay": "array>>" + } + ] +} \ No newline at end of file diff --git a/ingestion/tests/unit/test_column_type_parser.py b/ingestion/tests/unit/test_column_type_parser.py new file mode 100644 index 00000000000..9d6b0243c02 --- /dev/null +++ b/ingestion/tests/unit/test_column_type_parser.py @@ -0,0 +1,32 @@ +import os +from unittest import TestCase + +from metadata.utils.column_type_parser import ColumnTypeParser + +COLUMN_TYPE_PARSE = [ + "array", + "struct", + "struct,c:bigint>>", + "struct>", + "struct>>>", + "struct>>>", + "array>>", +] +root = os.path.dirname(__file__) +import json + +try: + with open(os.path.join(root, "resources/expected_output_column_parser.json")) as f: + EXPECTED_OUTPUT = json.loads(f.read())["data"] +except Exception as err: + print(err) + + +class ColumnTypeParseTest(TestCase): + def test_check_datatype_support(self): + for index, parse_string in enumerate(COLUMN_TYPE_PARSE): + parsed_string = ColumnTypeParser._parse_datatype_string(parse_string) + self.assertTrue( + True if parsed_string == EXPECTED_OUTPUT[index] else False, + msg=f"{index}: {COLUMN_TYPE_PARSE[index]} : {parsed_string}", + ) diff --git a/ingestion/tests/unit/datatypes_test.py b/ingestion/tests/unit/test_datatypes.py similarity index 84% rename from ingestion/tests/unit/datatypes_test.py rename to ingestion/tests/unit/test_datatypes.py index d033ebdb7fa..8c7577eb869 100644 --- a/ingestion/tests/unit/datatypes_test.py +++ b/ingestion/tests/unit/test_datatypes.py @@ -1,11 +1,8 @@ -import unittest from unittest import TestCase -from metadata.ingestion.api.source import SourceStatus -from metadata.ingestion.source.sql_source import SQLSourceStatus -from metadata.utils.column_helpers import get_column_type +from metadata.utils.column_type_parser import ColumnTypeParser -SQLTYPES = [ +SQLTYPES = { "ARRAY", "BIGINT", "BIGNUMERIC", @@ -93,14 +90,13 @@ SQLTYPES = [ "TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ", -] +} class DataTypeTest(TestCase): def 
test_check_datatype_support(self):
-        status = SQLSourceStatus()
         for types in SQLTYPES:
             with self.subTest(line=types):
-                col_type = get_column_type(status, "Unit Test", types)
+                col_type = ColumnTypeParser.get_column_type(types)
                 col_type = True if col_type != "NULL" else False
                 self.assertTrue(col_type, msg=types)
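
Note: a minimal usage sketch of the ColumnTypeParser introduced by this patch, mirroring how the updated amundsen.py and glue.py sources build a Column from the parsed dictionary. The input string and the column name "tags" are hypothetical; the module paths are the ones added above. struct<...> inputs additionally return a "children" list, and union types return a list that glue.py maps to dataType UNION.

from metadata.generated.schema.entity.data.table import Column
from metadata.utils.column_type_parser import ColumnTypeParser

# Parse a raw type string into the keyword arguments expected by Column.
raw_type = "array<string>"  # hypothetical input
parsed = ColumnTypeParser._parse_datatype_string(raw_type)
# parsed == {"dataType": "ARRAY", "arrayDataType": "STRING", "dataTypeDisplay": "array<string>"}

# Callers add the source-specific fields before constructing the entity,
# exactly as the updated sources do (amundsen.py sets name and dataLength,
# glue.py also sets ordinalPosition).
parsed["name"] = "tags"
parsed["dataLength"] = parsed.get("dataLength", 1)
column = Column(**parsed)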