Fix 2270: Column Type Parser (#2271)

* Fix 2270: Column Type Parser

* Fix 2270: Column Type Parser

* Added checks to allow arrayDataType and dataTypeDisplay

* Modified - arrayDataType and dataTypeDisplay

* Update sql_source.py

* Update sql_source.py

* file formatted

* Modified according to column_type_parser.py

* modified / refactored / deleted the files

* Added Tests, modified sql_source

* file formatted

* Added missing datatypes

* Added Tests

* Added Tests

* Added Tests - refactored expected output into a json file

* file formatted

* Sample Data Updated

Co-authored-by: Ayush Shah <ayush@getcollate.io>
Sriharsha Chintalapani 2022-01-22 10:02:58 -08:00 committed by GitHub
parent e2602b81fb
commit f390bac375
11 changed files with 553 additions and 431 deletions

View File

@@ -10,15 +10,11 @@
# limitations under the License.
import logging
- import re
- import textwrap
import traceback
import uuid
from dataclasses import dataclass, field
from typing import Iterable, List, Optional
- from pydantic import SecretStr
from metadata.config.common import ConfigModel
from metadata.generated.schema.api.services.createDatabaseService import (
CreateDatabaseServiceEntityRequest,
@@ -38,13 +34,14 @@ from metadata.ingestion.models.user import User
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.neo4j_helper import Neo4JConfig, Neo4jHelper
- from metadata.utils.column_helpers import check_column_complex_type, get_column_type
+ from metadata.utils.column_type_parser import ColumnTypeParser
from metadata.utils.helpers import get_dashboard_service_or_create
from metadata.utils.sql_queries import (
NEO4J_AMUNDSEN_DASHBOARD_QUERY,
NEO4J_AMUNDSEN_TABLE_QUERY,
NEO4J_AMUNDSEN_USER_QUERY,
)
+ from pydantic import SecretStr
logger: logging.Logger = logging.getLogger(__name__)
@@ -155,28 +152,10 @@ class AmundsenSource(Source[Entity]):
# Amundsen merges the length into type itself. Instead of making changes to our generic type builder
# we will do a type match and see if it matches any primitive types and return a type
data_type = self.get_type_primitive_type(data_type)
- (
- col_type,
- data_type_display,
- arr_data_type,
- children,
- ) = check_column_complex_type(
- self.status, table["name"], data_type, name
- )
- col = Column(
- name=name,
- description=description,
- dataType=col_type,
- dataTypeDisplay="{}({})".format(col_type, 1)
- if data_type_display is None
- else f"{data_type_display}",
- children=children,
- arrayDataType=arr_data_type,
- ordinalPosition=row_order,
- dataLength=1,
- )
- row_order += 1
+ parsed_string = ColumnTypeParser._parse_datatype_string(data_type)
+ parsed_string["name"] = name
+ parsed_string["dataLength"] = 1
+ col = Column(**parsed_string)
columns.append(col)
fqn = f"{service_name}.{database.name}.{table['schema']}.{table['name']}"

View File

@@ -13,6 +13,10 @@ import os
from typing import Optional, Tuple, Any
import json, tempfile, logging
+ from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
+ from metadata.ingestion.source.sql_source import SQLSource
+ from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
+ from metadata.utils.column_type_parser import create_sqlalchemy_type
from sqlalchemy_bigquery import _types
from sqlalchemy_bigquery._struct import STRUCT
from sqlalchemy_bigquery._types import (
@@ -20,11 +24,6 @@ from sqlalchemy_bigquery._types import (
_get_transitive_schema_fields,
)
- from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
- from metadata.ingestion.source.sql_source import SQLSource
- from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
- from metadata.utils.column_helpers import create_sqlalchemy_type
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
_types._type_map["GEOGRAPHY"] = GEOGRAPHY
@@ -44,7 +43,7 @@ def get_columns(bq_schema):
"precision": field.precision,
"scale": field.scale,
"max_length": field.max_length,
- "raw_data_type": repr(_get_sqla_column_type(field)),
+ "raw_data_type": str(_get_sqla_column_type(field)),
}
col_list.append(col_obj)
return col_list
@@ -75,9 +74,9 @@ class BigquerySource(SQLSource):
metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
if config.options.get("credentials", None):
cred_path = create_credential_temp_file(config.options.get("credentials"))
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
- del config.options["credentials"]
- config.options["credentials_path"] = cred_path
+ config.options["credentials_path"] = cred_path
+ del config.options["credentials"]
return cls(config, metadata_config, ctx)
def close(self):

View File

@@ -29,7 +29,7 @@ from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source_common import SQLSourceStatus
from metadata.utils.aws_client import AWSClient, AWSClientConfigModel
- from metadata.utils.column_helpers import check_column_complex_type
+ from metadata.utils.column_type_parser import ColumnTypeParser
from metadata.utils.helpers import (
get_database_service_or_create,
get_pipeline_service_or_create,
@@ -75,7 +75,7 @@ class GlueSource(Source[Entity]):
"serviceType": "Glue",
"pipelineUrl": self.config.endpoint_url
if self.config.endpoint_url is not None
- else f"https://glue.{ self.config.region_name }.amazonaws.com",
+ else f"https://glue.{self.config.region_name}.amazonaws.com",
},
metadata_config,
)
@@ -116,31 +116,20 @@ class GlueSource(Source[Entity]):
yield from self.ingest_pipelines()
def get_columns(self, column_data):
- row_order = 0
- for column in column_data["Columns"]:
+ for index, column in enumerate(column_data["Columns"]):
if column["Type"].lower().startswith("union"):
column["Type"] = column["Type"].replace(" ", "")
- (
- col_type,
- data_type_display,
- arr_data_type,
- children,
- ) = check_column_complex_type(
- self.status, self.dataset_name, column["Type"].lower(), column["Name"]
- )
- yield Column(
- name=column["Name"].replace(".", "_DOT_")[:128],
- description="",
- dataType=col_type,
- dataTypeDisplay="{}({})".format(col_type, 1)
- if data_type_display is None
- else f"{data_type_display}",
- ordinalPosition=row_order,
- children=children,
- arrayDataType=arr_data_type,
- dataLength=1,
- )
- row_order += 1
+ parsed_string = ColumnTypeParser._parse_datatype_string(
+ column["Type"].lower()
+ )
+ if isinstance(parsed_string, list):
+ parsed_string = {}
+ parsed_string["dataTypeDisplay"] = str(column["Type"])
+ parsed_string["dataType"] = "UNION"
+ parsed_string["name"] = column["Name"][:64]
+ parsed_string["ordinalPosition"] = index
+ parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
+ yield Column(**parsed_string)
def ingest_tables(self, next_tables_token=None) -> Iterable[OMetaDatabaseAndTable]:
try:
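As a rough, hedged walk-through (not part of the commit) of the dictionary the new get_columns builds for a plain Glue column; the column entry is made up, and the union-specific handling shown above is skipped here:

# Hypothetical example of the new Glue column path; values are assumed.
from metadata.utils.column_type_parser import ColumnTypeParser

column = {"Name": "customer_id", "Type": "string"}  # assumed Glue column entry
parsed_string = ColumnTypeParser._parse_datatype_string(column["Type"].lower())
# -> {"dataType": "STRING", "dataTypeDisplay": "string"}
parsed_string["name"] = column["Name"][:64]
parsed_string["ordinalPosition"] = 0
parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
# get_columns() then yields Column(**parsed_string)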

View File

@@ -169,11 +169,9 @@ class SampleDataSource(Source[Entity]):
self.config = config
self.metadata_config = metadata_config
self.metadata = OpenMetadata(metadata_config)
- print("hi")
self.storage_service_json = json.load(
open(self.config.sample_data_folder + "/locations/service.json", "r")
)
- print("hello")
self.locations = json.load(
open(self.config.sample_data_folder + "/locations/locations.json", "r")
)

View File

@@ -17,7 +17,7 @@ from snowflake.sqlalchemy.snowdialect import ischema_names
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLSource
from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
- from metadata.utils.column_helpers import create_sqlalchemy_type
+ from metadata.utils.column_type_parser import create_sqlalchemy_type
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
ischema_names["VARIANT"] = VARIANT

View File

@@ -18,10 +18,6 @@ import traceback
import uuid
from typing import Dict, Iterable, List, Optional, Tuple
- from sqlalchemy import create_engine
- from sqlalchemy.engine.reflection import Inspector
- from sqlalchemy.inspection import inspect
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.table import (
Column,
@@ -43,8 +39,11 @@ from metadata.ingestion.source.sql_source_common import (
SQLConnectionConfig,
SQLSourceStatus,
)
- from metadata.utils.column_helpers import check_column_complex_type, get_column_type
+ from metadata.utils.column_type_parser import ColumnTypeParser
from metadata.utils.helpers import get_database_service_or_create
+ from sqlalchemy import create_engine
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.inspection import inspect
logger: logging.Logger = logging.getLogger(__name__)
@@ -387,7 +386,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
ccolumn = ccolumns[key]
try:
ctype = ccolumn["type"]
- col_type = get_column_type(self.status, model_name, ctype)
+ col_type = ColumnTypeParser.get_column_type(ctype)
description = manifest_columns.get(key.lower(), {}).get(
"description", None
)
@@ -460,21 +459,20 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
if "column_names" in constraint.keys()
]
- dataset_name = f"{schema}.{table}"
table_columns = []
columns = inspector.get_columns(table, schema)
try:
for row_order, column in enumerate(columns):
try:
+ col_dict = None
if "." in column["name"]:
- logger.info(
- f"Found '.' in {column['name']}, changing '.' to '_DOT_'"
- )
column["name"] = column["name"].replace(".", "_DOT_")
children = None
data_type_display = None
col_data_length = None
arr_data_type = None
+ parsed_string = None
+ print(column["raw_data_type"])
if (
"raw_data_type" in column
and column["raw_data_type"] is not None
@@ -482,21 +480,12 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
column["raw_data_type"] = self.parse_raw_data_type(
column["raw_data_type"]
)
- (
- col_type,
- data_type_display,
- arr_data_type,
- children,
- ) = check_column_complex_type(
- self.status,
- dataset_name,
- column["raw_data_type"],
- column["name"],
- )
+ parsed_string = ColumnTypeParser._parse_datatype_string(
+ column["raw_data_type"]
+ )
+ parsed_string["name"] = column["name"]
else:
- col_type = get_column_type(
- self.status, dataset_name, column["type"]
- )
+ col_type = ColumnTypeParser.get_column_type(column["type"])
if col_type == "ARRAY" and re.match(
r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"])
):
@@ -504,40 +493,64 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"])
).groups()
data_type_display = column["type"]
- if repr(column["type"]).upper().startswith("ARRAY("):
- arr_data_type = "STRUCT"
- data_type_display = (
- repr(column["type"])
- .replace("(", "<")
- .replace(")", ">")
- .lower()
- )
- col_constraint = self._get_column_constraints(
- column, pk_columns, unique_columns
- )
- if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
- col_data_length = column["type"].length
- if col_type == "NULL":
- col_type = "VARCHAR"
- data_type_display = "varchar"
- logger.warning(
- f"Unknown type {column['type']} mapped to VARCHAR: {column['name']}"
- )
- om_column = Column(
- name=column["name"],
- description=column.get("comment", None),
- dataType=col_type,
- dataTypeDisplay="{}({})".format(
- col_type, 1 if col_data_length is None else col_data_length
- )
- if data_type_display is None
- else f"{data_type_display}",
- dataLength=1 if col_data_length is None else col_data_length,
- constraint=col_constraint,
- ordinalPosition=row_order,
- children=children if children is not None else None,
- arrayDataType=arr_data_type,
- )
+ if parsed_string is None:
+ col_type = ColumnTypeParser.get_column_type(column["type"])
+ col_constraint = self._get_column_constraints(
+ column, pk_columns, unique_columns
+ )
+ col_data_length = self._check_col_length(
+ col_type, column["type"]
+ )
+ if col_type == "NULL":
+ col_type = "VARCHAR"
+ data_type_display = "varchar"
+ logger.warning(
+ f"Unknown type {column['type']} mapped to VARCHAR: {column['name']} {column['type']}"
+ )
+ col_data_length = (
+ 1 if col_data_length is None else col_data_length
+ )
+ dataTypeDisplay = (
+ f"{data_type_display}"
+ if data_type_display
+ else "{}({})".format(col_type, col_data_length)
+ )
+ om_column = Column(
+ name=column["name"],
+ description=column.get("comment", None),
+ dataType=col_type,
+ dataTypeDisplay=dataTypeDisplay,
+ dataLength=col_data_length,
+ constraint=col_constraint,
+ ordinalPosition=row_order,
+ children=children if children else None,
+ arrayDataType=arr_data_type,
+ )
+ else:
+ parsed_string["dataLength"] = self._check_col_length(
+ parsed_string["dataType"], column["type"]
+ )
+ if column["raw_data_type"] == "array":
+ array_data_type_display = (
+ repr(column["type"])
+ .replace("(", "<")
+ .replace(")", ">")
+ .replace("=", ":")
+ .replace("<>", "")
+ .lower()
+ )
+ parsed_string[
+ "dataTypeDisplay"
+ ] = f"{array_data_type_display}"
+ parsed_string[
+ "arrayDataType"
+ ] = ColumnTypeParser._parse_primitive_datatype_string(
+ array_data_type_display[6:-1]
+ )[
+ "dataType"
+ ]
+ col_dict = Column(**parsed_string)
+ om_column = col_dict
except Exception as err:
logger.error(traceback.print_exc())
logger.error(f"{err} : {column}")
@@ -548,6 +561,15 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
logger.error(f"{repr(err)}: {table} {err}")
return None
+ def _check_col_length(self, datatype, col_raw_type):
+ if datatype.upper() in {
+ "CHAR",
+ "VARCHAR",
+ "BINARY",
+ "VARBINARY",
+ }:
+ return col_raw_type.length if col_raw_type.length else 1
def run_data_profiler(self, table: str, schema: str) -> TableProfile:
"""
Run the profiler for a table in a schema.
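A small standalone restatement (for illustration, not the class method itself) of how the new _check_col_length helper behaves for character and binary types versus everything else; the SQLAlchemy types below are just example inputs:

# Standalone sketch of the helper's behaviour; sqlalchemy types used only for illustration.
from sqlalchemy import INTEGER, VARCHAR

def check_col_length(datatype, col_raw_type):
    if datatype.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
        return col_raw_type.length if col_raw_type.length else 1

print(check_col_length("VARCHAR", VARCHAR(255)))  # 255
print(check_col_length("VARCHAR", VARCHAR()))     # 1, no explicit length declared
print(check_col_length("INT", INTEGER()))         # None, other types fall through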

View File

@@ -1,306 +0,0 @@
import re
from typing import Any, Dict, Optional, Set, Type
from metadata.ingestion.api.source import SourceStatus
from sqlalchemy.sql import sqltypes as types
from sqlalchemy.types import TypeEngine
def register_custom_str_type(tp: str, output: str) -> None:
_column_string_mapping[tp] = output
def create_sqlalchemy_type(name: str):
sqlalchemy_type = type(
name,
(TypeEngine,),
{
"__repr__": lambda self: f"{name}()",
},
)
return sqlalchemy_type
_column_type_mapping: Dict[Type[types.TypeEngine], str] = {
types.Integer: "INT",
types.Numeric: "INT",
types.Boolean: "BOOLEAN",
types.Enum: "ENUM",
types._Binary: "BYTES",
types.LargeBinary: "BYTES",
types.PickleType: "BYTES",
types.ARRAY: "ARRAY",
types.VARCHAR: "VARCHAR",
types.String: "STRING",
types.Date: "DATE",
types.DATE: "DATE",
types.Time: "TIME",
types.DateTime: "DATETIME",
types.DATETIME: "DATETIME",
types.TIMESTAMP: "TIMESTAMP",
types.NullType: "NULL",
types.JSON: "JSON",
types.CHAR: "CHAR",
types.DECIMAL: "DECIMAL",
types.Interval: "INTERVAL",
}
_column_string_mapping = {
"ARRAY": "ARRAY",
"BIGINT": "BIGINT",
"BIGNUMERIC": "NUMERIC",
"BIGSERIAL": "BIGINT",
"BINARY": "BINARY",
"BIT": "INT",
"BLOB": "BLOB",
"BOOL": "BOOLEAN",
"BOOLEAN": "BOOLEAN",
"BPCHAR": "CHAR",
"BYTEINT": "BYTEINT",
"BYTES": "BYTES",
"CHAR": "CHAR",
"CHARACTER VARYING": "VARCHAR",
"DATE": "DATE",
"DATETIME": "DATETIME",
"DATETIME2": "DATETIME",
"DATETIMEOFFSET": "DATETIME",
"DECIMAL": "DECIMAL",
"DOUBLE PRECISION": "DOUBLE",
"DOUBLE": "DOUBLE",
"ENUM": "ENUM",
"FLOAT4": "FLOAT",
"FLOAT64": "DOUBLE",
"FLOAT8": "DOUBLE",
"GEOGRAPHY": "GEOGRAPHY",
"HYPERLOGLOG": "BINARY",
"IMAGE": "BINARY",
"INT": "INT",
"INT2": "SMALLINT",
"INT4": "INT",
"INT64": "BIGINT",
"INT8": "BIGINT",
"INTEGER": "INT",
"INTERVAL DAY TO SECOND": "INTERVAL",
"INTERVAL YEAR TO MONTH": "INTERVAL",
"INTERVAL": "INTERVAL",
"JSON": "JSON",
"LONG RAW": "BINARY",
"LONG VARCHAR": "VARCHAR",
"LONGBLOB": "LONGBLOB",
"MAP": "MAP",
"MEDIUMBLOB": "MEDIUMBLOB",
"MEDIUMINT": "INT",
"MEDIUMTEXT": "MEDIUMTEXT",
"MONEY": "NUMBER",
"NCHAR": "CHAR",
"NTEXT": "TEXT",
"NULL": "NULL",
"NUMBER": "NUMBER",
"NUMERIC": "NUMBER",
"NVARCHAR": "VARCHAR",
"OBJECT": "JSON",
"RAW": "BINARY",
"REAL": "FLOAT",
"ROWID": "VARCHAR",
"ROWVERSION": "NUMBER",
"SET": "SET",
"SMALLDATETIME": "DATETIME",
"SMALLINT": "SMALLINT",
"SMALLMONEY": "NUMBER",
"SMALLSERIAL": "SMALLINT",
"SQL_VARIANT": "VARBINARY",
"STRING": "STRING",
"STRUCT": "STRUCT",
"TABLE": "BINARY",
"TEXT": "TEXT",
"TIME": "TIME",
"TIMESTAMP WITHOUT TIME ZONE": "TIMESTAMP",
"TIMESTAMP": "TIMESTAMP",
"TIMESTAMPTZ": "TIMESTAMP",
"TIMETZ": "TIMESTAMP",
"TINYINT": "TINYINT",
"UNION": "UNION",
"UROWID": "VARCHAR",
"VARBINARY": "VARBINARY",
"VARCHAR": "VARCHAR",
"VARIANT": "JSON",
"XML": "BINARY",
"XMLTYPE": "BINARY",
"CURSOR": "BINARY",
"TIMESTAMP_NTZ": "TIMESTAMP",
"TIMESTAMP_LTZ": "TIMESTAMP",
"TIMESTAMP_TZ": "TIMESTAMP",
}
_known_unknown_column_types: Set[Type[types.TypeEngine]] = {
types.Interval,
types.CLOB,
}
def check_column_complex_type(
status: SourceStatus, dataset_name: str, column_raw_type: Any, col_name: str
):
arr_data_type = None
col_obj = None
data_type_display = None
children = None
if column_raw_type.startswith("struct<"):
col_type = "STRUCT"
col_obj = _handle_complex_data_types(
status,
dataset_name,
f"{col_name}:{column_raw_type}",
)
if "children" in col_obj and col_obj["children"] is not None:
children = col_obj["children"]
elif column_raw_type.startswith("map<"):
col_type = "MAP"
col_obj = _handle_complex_data_types(
status,
dataset_name,
f"{col_name}:{column_raw_type}",
)
elif column_raw_type.startswith("array<"):
col_type = "ARRAY"
arr_data_type = re.match(r"(?:array<)(\w*)(?:.*)", column_raw_type)
arr_data_type = arr_data_type.groups()[0].upper()
elif column_raw_type.startswith("uniontype<") or column_raw_type.startswith(
"union<"
):
col_type = "UNION"
else:
col_type = get_column_type(status, dataset_name, column_raw_type.split("(")[0])
data_type_display = col_type
if data_type_display is None:
data_type_display = column_raw_type
return col_type, data_type_display, arr_data_type, children
def get_column_type(status: SourceStatus, dataset_name: str, column_type: Any) -> str:
type_class: Optional[str] = None
for sql_type in _column_type_mapping.keys():
if isinstance(column_type, sql_type):
type_class = _column_type_mapping[sql_type]
break
if type_class is None or type_class == "NULL":
for sql_type in _known_unknown_column_types:
if isinstance(column_type, sql_type):
type_class = "NULL"
break
for col_type in _column_string_mapping.keys():
if str(column_type).split("(")[0].split("<")[0].upper() in col_type:
type_class = _column_string_mapping.get(col_type)
break
else:
type_class = None
if type_class is None or type_class == "NULL":
status.warning(
dataset_name, f"unable to map type {column_type!r} to metadata schema"
)
type_class = "NULL"
return type_class
def get_last_index(nested_str):
counter = 1
for index, i in enumerate(nested_str):
if i == ">":
counter -= 1
elif i == "<":
counter += 1
if counter == 0:
break
index = index - counter
return index
def get_array_type(col_type):
col = {}
col["dataType"] = "ARRAY"
col_type = col_type[: get_last_index(col_type) + 2]
col["dataTypeDisplay"] = col_type
col["arrayDataType"] = (
re.match(r"(?:array<)(\w*)(?:.*)", col_type).groups()[0].upper()
)
return col
def _handle_complex_data_types(status, dataset_name, raw_type: str, level=0):
col = {}
if re.match(r"([\w\s]*)(:)(.*)", raw_type):
name, col_type = raw_type.lstrip("<").split(":", 1)
col["name"] = name
else:
col["name"] = f"field_{level}"
if raw_type.startswith("struct<"):
col_type = raw_type
else:
col_type = raw_type.lstrip("<").split(":", 1)[0]
if col_type.startswith("struct<"):
children = []
struct_type, col_type = re.match(r"(struct<)(.*)", col_type).groups()
pluck_index = get_last_index(col_type)
pluck_nested = col_type[: pluck_index + 1]
col["dataTypeDisplay"] = struct_type + pluck_nested
while pluck_nested != "":
col["dataType"] = "STRUCT"
plucked = col_type[: get_last_index(col_type)]
counter = 0
continue_next = False
for index, type in enumerate(plucked.split(",")):
if continue_next:
continue_next = False
continue
if re.match(r"(\w*)(:)(struct)(.*)", type):
col_name, datatype, rest = re.match(
r"(\w*)(?::)(struct)(.*)", ",".join(plucked.split(",")[index:])
).groups()
type = f"{col_name}:{datatype}{rest[:get_last_index(rest) + 2]}"
elif type.startswith("struct"):
datatype, rest = re.match(
r"(struct)(.*)", ",".join(plucked.split(",")[index:])
).groups()
type = f"{datatype}{rest[:get_last_index(rest) + 2]}"
elif re.match(r"([\w\s]*)(:?)(map)(.*)", type):
get_map_type = ",".join(plucked.split(",")[index:])
type, col_type = re.match(
r"([\w]*:?map<[\w,]*>)(.*)", get_map_type
).groups()
continue_next = True
elif re.match(r"([\w\s]*)(:?)(uniontype)(.*)", type):
get_union_type = ",".join(plucked.split(",")[index:])
type, col_type = re.match(
r"([\w\s]*:?uniontype<[\w\s,]*>)(.*)", get_union_type
).groups()
continue_next = True
children.append(
_handle_complex_data_types(status, dataset_name, type, counter)
)
if plucked.endswith(type):
break
counter += 1
pluck_nested = col_type[get_last_index(col_type) + 3 :]
col["children"] = children
elif col_type.startswith("array"):
col.update(get_array_type(col_type))
elif col_type.startswith("map"):
col["dataType"] = "MAP"
col["dataTypeDisplay"] = col_type
elif col_type.startswith("uniontype"):
col["dataType"] = "UNION"
col["dataTypeDisplay"] = col_type
else:
if re.match(r"(?:[\w\s]*)(?:\()([\d]*)(?:\))", col_type):
col["dataLength"] = re.match(
r"(?:[\w\s]*)(?:\()([\d]*)(?:\))", col_type
).groups()[0]
else:
col["dataLength"] = 1
col["dataType"] = get_column_type(
status,
dataset_name,
re.match(r"([\w\s]*)(?:.*)", col_type).groups()[0],
)
col["dataTypeDisplay"] = col_type.rstrip(">")
return col

View File

@@ -0,0 +1,303 @@
import re
from typing import Any, Dict, Optional, Type
from typing import List, Union
from sqlalchemy.sql import sqltypes as types
from sqlalchemy.types import TypeEngine
def create_sqlalchemy_type(name: str):
sqlalchemy_type = type(
name,
(TypeEngine,),
{
"__repr__": lambda self: f"{name}()",
},
)
return sqlalchemy_type
class ColumnTypeParser:
_BRACKETS = {"(": ")", "[": "]", "{": "}", "<": ">"}
_COLUMN_TYPE_MAPPING: Dict[Type[types.TypeEngine], str] = {
types.ARRAY: "ARRAY",
types.Boolean: "BOOLEAN",
types.CHAR: "CHAR",
types.CLOB: "BINARY",
types.Date: "DATE",
types.DATE: "DATE",
types.DateTime: "DATETIME",
types.DATETIME: "DATETIME",
types.DECIMAL: "DECIMAL",
types.Enum: "ENUM",
types.Interval: "INTERVAL",
types.JSON: "JSON",
types.LargeBinary: "BYTES",
types.NullType: "NULL",
types.Numeric: "INT",
types.PickleType: "BYTES",
types.String: "STRING",
types.Time: "TIME",
types.TIMESTAMP: "TIMESTAMP",
types.VARCHAR: "VARCHAR",
types.BINARY: "BINARY",
types.INTEGER: "INT",
types.Integer: "INT",
types.BigInteger: "BIGINT",
}
_SOURCE_TYPE_TO_OM_TYPE = {
"ARRAY": "ARRAY",
"BIGINT": "BIGINT",
"BIGNUMERIC": "NUMERIC",
"BIGSERIAL": "BIGINT",
"BINARY": "BINARY",
"BIT": "INT",
"BLOB": "BLOB",
"BOOL": "BOOLEAN",
"BOOLEAN": "BOOLEAN",
"BPCHAR": "CHAR",
"BYTEINT": "BYTEINT",
"BYTES": "BYTES",
"CHAR": "CHAR",
"CHARACTER VARYING": "VARCHAR",
"CURSOR": "BINARY",
"DATE": "DATE",
"DATETIME": "DATETIME",
"DATETIME2": "DATETIME",
"DATETIMEOFFSET": "DATETIME",
"DECIMAL": "DECIMAL",
"DOUBLE PRECISION": "DOUBLE",
"DOUBLE": "DOUBLE",
"ENUM": "ENUM",
"FLOAT4": "FLOAT",
"FLOAT64": "DOUBLE",
"FLOAT8": "DOUBLE",
"GEOGRAPHY": "GEOGRAPHY",
"HYPERLOGLOG": "BINARY",
"IMAGE": "BINARY",
"INT": "INT",
"INT2": "SMALLINT",
"INT4": "INT",
"INT64": "BIGINT",
"INT8": "BIGINT",
"INTEGER": "INT",
"INTERVAL DAY TO SECOND": "INTERVAL",
"INTERVAL YEAR TO MONTH": "INTERVAL",
"INTERVAL": "INTERVAL",
"JSON": "JSON",
"LONG RAW": "BINARY",
"LONG VARCHAR": "VARCHAR",
"LONGBLOB": "LONGBLOB",
"MAP": "MAP",
"MEDIUMBLOB": "MEDIUMBLOB",
"MEDIUMINT": "INT",
"MEDIUMTEXT": "MEDIUMTEXT",
"MONEY": "NUMBER",
"NCHAR": "CHAR",
"NTEXT": "TEXT",
"NULL": "NULL",
"NUMBER": "NUMBER",
"NUMERIC": "NUMERIC",
"NVARCHAR": "VARCHAR",
"OBJECT": "JSON",
"RAW": "BINARY",
"REAL": "FLOAT",
"RECORD": "STRUCT",
"ROWID": "VARCHAR",
"ROWVERSION": "NUMBER",
"SET": "SET",
"SMALLDATETIME": "DATETIME",
"SMALLINT": "SMALLINT",
"SMALLMONEY": "NUMBER",
"SMALLSERIAL": "SMALLINT",
"SQL_VARIANT": "VARBINARY",
"STRING": "STRING",
"STRUCT": "STRUCT",
"TABLE": "BINARY",
"TEXT": "TEXT",
"TIME": "TIME",
"TIMESTAMP WITHOUT TIME ZONE": "TIMESTAMP",
"TIMESTAMP": "TIMESTAMP",
"TIMESTAMPTZ": "TIMESTAMP",
"TIMESTAMP_NTZ": "TIMESTAMP",
"TIMESTAMP_LTZ": "TIMESTAMP",
"TIMESTAMP_TZ": "TIMESTAMP",
"TIMETZ": "TIMESTAMP",
"TINYINT": "TINYINT",
"UNION": "UNION",
"UROWID": "VARCHAR",
"VARBINARY": "VARBINARY",
"VARCHAR": "VARCHAR",
"VARIANT": "JSON",
"XML": "BINARY",
"XMLTYPE": "BINARY",
}
_COMPLEX_TYPE = re.compile("^(struct|map|array|uniontype)")
_FIXED_DECIMAL = re.compile(r"(decimal|numeric)(\(\s*(\d+)\s*,\s*(\d+)\s*\))?")
_FIXED_STRING = re.compile(r"(var)?char\(\s*(\d+)\s*\)")
@staticmethod
def get_column_type(column_type: Any) -> str:
type_class: Optional[str] = None
for sql_type in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys():
if str(column_type) in sql_type:
type_class = ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE[sql_type]
break
if type_class is None or type_class == "NULL":
for col_type in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys():
if str(column_type).split("(")[0].split("<")[0].upper() in col_type:
type_class = ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get(col_type)
break
else:
type_class = None
return type_class
@staticmethod
def _parse_datatype_string(
s: str, **kwargs: Any
) -> Union[object, Dict[str, object]]:
s = s.strip()
if s.startswith("array<"):
if s[-1] != ">":
raise ValueError("expected '>' found: %s" % s)
arr_data_type = ColumnTypeParser._parse_primitive_datatype_string(s[6:-1])[
"dataType"
]
return {
"dataType": "ARRAY",
"arrayDataType": arr_data_type,
"dataTypeDisplay": s,
}
elif s.startswith("map<"):
if s[-1] != ">":
raise ValueError("expected '>' found: %s" % s)
parts = ColumnTypeParser._ignore_brackets_split(s[4:-1], ",")
if len(parts) != 2:
raise ValueError(
"The map type string format is: 'map<key_type,value_type>', "
+ "but got: %s" % s
)
kt = ColumnTypeParser._parse_datatype_string(parts[0])
vt = ColumnTypeParser._parse_datatype_string(parts[1])
return {"dataType": "MAP", "dataTypeDisplay": s}
elif s.startswith("uniontype<") or s.startswith("union<"):
if s[-1] != ">":
raise ValueError("'>' should be the last char, but got: %s" % s)
parts = ColumnTypeParser._ignore_brackets_split(s[10:-1], ",")
t = []
for part in parts:
if part.startswith("struct<"):
t.append(ColumnTypeParser._parse_datatype_string(part))
else:
t.append(ColumnTypeParser._parse_datatype_string(part))
return t
elif s.startswith("struct<"):
if s[-1] != ">":
raise ValueError("expected '>', found: %s" % s)
return ColumnTypeParser._parse_struct_fields_string(s[7:-1])
elif ":" in s:
return ColumnTypeParser._parse_struct_fields_string(s)
else:
return ColumnTypeParser._parse_primitive_datatype_string(s)
@staticmethod
def _parse_struct_fields_string(s: str) -> Dict[str, object]:
parts = ColumnTypeParser._ignore_brackets_split(s, ",")
columns = []
for part in parts:
name_and_type = ColumnTypeParser._ignore_brackets_split(part, ":")
if len(name_and_type) != 2:
raise ValueError(
"expected format is: 'field_name:field_type', "
+ "but got: %s" % part
)
field_name = name_and_type[0].strip()
if field_name.startswith("`"):
if field_name[-1] != "`":
raise ValueError("'`' should be the last char, but got: %s" % s)
field_name = field_name[1:-1]
field_type = ColumnTypeParser._parse_datatype_string(name_and_type[1])
field_type["name"] = field_name
columns.append(field_type)
return {
"children": columns,
"dataTypeDisplay": "struct<{}>".format(s),
"dataType": "STRUCT",
}
@staticmethod
def _parse_primitive_datatype_string(s: str) -> Dict[str, object]:
if s.upper() in ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.keys():
return {
"dataType": ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE[s.upper()],
"dataTypeDisplay": s,
}
elif ColumnTypeParser._FIXED_STRING.match(s):
m = ColumnTypeParser._FIXED_STRING.match(s)
return {"type": "STRING", "dataTypeDisplay": s}
elif ColumnTypeParser._FIXED_DECIMAL.match(s):
m = ColumnTypeParser._FIXED_DECIMAL.match(s)
if m.group(2) is not None: # type: ignore
return {
"dataType": ColumnTypeParser.get_column_type(m.group(0)),
"dataTypeDisplay": s,
"dataLength": int(m.group(3)), # type: ignore
}
else:
return {
"dataType": ColumnTypeParser.get_column_type(m.group(0)),
"dataTypeDisplay": s,
}
elif s == "date":
return {"dataType": "DATE", "dataTypeDisplay": s}
elif s == "timestamp":
return {"dataType": "TIMESTAMP", "dataTypeDisplay": s}
else:
dataType = ColumnTypeParser.get_column_type(s)
if not dataType:
return {"dataType": "NULL", "dataTypeDisplay": s}
else:
dataLength = 1
if re.match(".*(\([\w]*\))", s):
dataLength = re.match(".*\(([\w]*)\)", s).groups()[0]
return {
"dataType": dataType,
"dataTypeDisplay": dataType,
"dataLength": dataLength,
}
@staticmethod
def _ignore_brackets_split(s: str, separator: str) -> List[str]:
parts = []
buf = ""
level = 0
for c in s:
if c in ColumnTypeParser._BRACKETS.keys():
level += 1
buf += c
elif c in ColumnTypeParser._BRACKETS.values():
if level == 0:
raise ValueError("Brackets are not correctly paired: %s" % s)
level -= 1
buf += c
elif c == separator and level > 0:
buf += c
elif c == separator:
parts.append(buf)
buf = ""
else:
buf += c
if len(buf) == 0:
raise ValueError("The %s cannot be the last char: %s" % (separator, s))
parts.append(buf)
return parts
@staticmethod
def is_primitive_om_type(raw_type: str) -> bool:
return not ColumnTypeParser._COMPLEX_TYPE.match(raw_type)
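A short usage sketch of the parser added above; the calls mirror the new unit tests further down, and the outputs in the comments are abbreviated from the expected-output fixture:

# Examples based on the tests in this PR; outputs abbreviated in comments.
from metadata.utils.column_type_parser import ColumnTypeParser

ColumnTypeParser.get_column_type("NVARCHAR")
# -> "VARCHAR"

ColumnTypeParser._parse_datatype_string("struct<a:int,b:string>")
# -> {"dataType": "STRUCT", "dataTypeDisplay": "struct<a:int,b:string>",
#     "children": [{"name": "a", "dataType": "INT", ...},
#                  {"name": "b", "dataType": "STRING", ...}]}

ColumnTypeParser.is_primitive_om_type("map<string,int>")
# -> False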

View File

@@ -0,0 +1,110 @@
{
"data": [
{
"dataType": "ARRAY",
"arrayDataType": "STRING",
"dataTypeDisplay": "array<string>"
},
{
"children": [
{
"dataType": "INT",
"dataTypeDisplay": "int",
"name": "a"
},
{
"dataType": "STRING",
"dataTypeDisplay": "string",
"name": "b"
}
],
"dataTypeDisplay": "struct<a:int,b:string>",
"dataType": "STRUCT"
},
{
"children": [
{
"children": [
{
"dataType": "ARRAY",
"arrayDataType": "STRING",
"dataTypeDisplay": "array<string>",
"name": "b"
},
{
"dataType": "BIGINT",
"dataTypeDisplay": "bigint",
"name": "c"
}
],
"dataTypeDisplay": "struct<b:array<string>,c:bigint>",
"dataType": "STRUCT",
"name": "a"
}
],
"dataTypeDisplay": "struct<a:struct<b:array<string>,c:bigint>>",
"dataType": "STRUCT"
},
{
"children": [
{
"dataType": "ARRAY",
"arrayDataType": "STRING",
"dataTypeDisplay": "array<string>",
"name": "a"
}
],
"dataTypeDisplay": "struct<a:array<string>>",
"dataType": "STRUCT"
},
{
"children": [
{
"dataType": "ARRAY",
"arrayDataType": "STRUCT",
"dataTypeDisplay": "array<struct<bigquery_test_datatype_511:array<string>>>",
"name": "bigquerytestdatatype51"
}
],
"dataTypeDisplay": "struct<bigquerytestdatatype51:array<struct<bigquery_test_datatype_511:array<string>>>>",
"dataType": "STRUCT"
},
{
"children": [
{
"children": [
{
"children": [
{
"children": [
{
"dataType": "STRING",
"dataTypeDisplay": "string",
"name": "record_4"
}
],
"dataTypeDisplay": "struct<record_4:string>",
"dataType": "STRUCT",
"name": "record_3"
}
],
"dataTypeDisplay": "struct<record_3:struct<record_4:string>>",
"dataType": "STRUCT",
"name": "record_2"
}
],
"dataTypeDisplay": "struct<record_2:struct<record_3:struct<record_4:string>>>",
"dataType": "STRUCT",
"name": "record_1"
}
],
"dataTypeDisplay": "struct<record_1:struct<record_2:struct<record_3:struct<record_4:string>>>>",
"dataType": "STRUCT"
},
{
"dataType": "ARRAY",
"arrayDataType": "STRUCT",
"dataTypeDisplay": "array<struct<check_datatype:array<string>>>"
}
]
}

View File

@@ -0,0 +1,32 @@
import os
from unittest import TestCase
from metadata.utils.column_type_parser import ColumnTypeParser
COLUMN_TYPE_PARSE = [
"array<string>",
"struct<a:int,b:string>",
"struct<a:struct<b:array<string>,c:bigint>>",
"struct<a:array<string>>",
"struct<bigquerytestdatatype51:array<struct<bigquery_test_datatype_511:array<string>>>>",
"struct<record_1:struct<record_2:struct<record_3:struct<record_4:string>>>>",
"array<struct<check_datatype:array<string>>>",
]
root = os.path.dirname(__file__)
import json
try:
with open(os.path.join(root, "resources/expected_output_column_parser.json")) as f:
EXPECTED_OUTPUT = json.loads(f.read())["data"]
except Exception as err:
print(err)
class ColumnTypeParseTest(TestCase):
def test_check_datatype_support(self):
for index, parse_string in enumerate(COLUMN_TYPE_PARSE):
parsed_string = ColumnTypeParser._parse_datatype_string(parse_string)
self.assertTrue(
True if parsed_string == EXPECTED_OUTPUT[index] else False,
msg=f"{index}: {COLUMN_TYPE_PARSE[index]} : {parsed_string}",
)

View File

@@ -1,11 +1,8 @@
- import unittest
from unittest import TestCase
- from metadata.ingestion.api.source import SourceStatus
- from metadata.ingestion.source.sql_source import SQLSourceStatus
- from metadata.utils.column_helpers import get_column_type
+ from metadata.utils.column_type_parser import ColumnTypeParser
- SQLTYPES = [
+ SQLTYPES = {
"ARRAY",
"BIGINT",
"BIGNUMERIC",
@@ -93,14 +90,13 @@ SQLTYPES = [
"TIMESTAMP_NTZ",
"TIMESTAMP_LTZ",
"TIMESTAMP_TZ",
- ]
+ }
class DataTypeTest(TestCase):
def test_check_datatype_support(self):
- status = SQLSourceStatus()
for types in SQLTYPES:
with self.subTest(line=types):
- col_type = get_column_type(status, "Unit Test", types)
+ col_type = ColumnTypeParser.get_column_type(types)
col_type = True if col_type != "NULL" else False
self.assertTrue(col_type, msg=types)