Fix #10489: Handle unknown data types & store raw data type (#10563)

Mayur Singal 2023-03-23 11:41:29 +05:30 committed by GitHub
parent b2e1eed842
commit bbce9c5aa4
38 changed files with 1331 additions and 440 deletions

View File

@ -35,6 +35,7 @@ from metadata.ingestion.source.database.common_db_source import (
TableNameAndType,
)
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import is_complex_type
logger = ingestion_logger()
@ -104,15 +105,6 @@ def _get_column_type(self, type_):
return col_type(*args)
def is_complex(type_: str):
return (
type_.startswith("array")
or type_.startswith("map")
or type_.startswith("struct")
or type_.startswith("row")
)
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
"""
@ -129,7 +121,8 @@ def get_columns(self, connection, table_name, schema=None, **kw):
"default": None,
"autoincrement": False,
"comment": c.comment,
"raw_data_type": c.type if is_complex(c.type) else None,
"system_data_type": c.type,
"is_complex": is_complex_type(c.type),
"dialect_options": {"awsathena_partition": None},
}
for c in metadata.columns
@ -142,7 +135,8 @@ def get_columns(self, connection, table_name, schema=None, **kw):
"default": None,
"autoincrement": False,
"comment": c.comment,
"raw_data_type": c.type if is_complex(c.type) else None,
"system_data_type": c.type,
"is_complex": is_complex_type(c.type),
"dialect_options": {"awsathena_partition": True},
}
for c in metadata.partition_keys
@ -150,8 +144,22 @@ def get_columns(self, connection, table_name, schema=None, **kw):
return columns
# pylint: disable=unused-argument
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
"""
Gets the view definition
"""
full_view_name = f'"{view_name}"' if not schema else f'"{schema}"."{view_name}"'
res = connection.execute(f"SHOW CREATE VIEW {full_view_name}").fetchall()
if res:
return "\n".join(i[0] for i in res)
return None
AthenaDialect._get_column_type = _get_column_type # pylint: disable=protected-access
AthenaDialect.get_columns = get_columns
AthenaDialect.get_view_definition = get_view_definition
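
The shared is_complex_type helper imported above replaces the inline is_complex check removed earlier in this file; its implementation lives in metadata/utils/sqlalchemy_utils.py and is not part of this diff. A minimal sketch consistent with the removed prefix-based check (hypothetical, the real helper may cover more cases):

def is_complex_type(col_type: str) -> bool:
    # Hypothetical sketch only -- the real helper is defined in
    # metadata/utils/sqlalchemy_utils.py and may differ.
    return str(col_type).lower().startswith(("array", "map", "struct", "row"))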
class AthenaSource(CommonDbSourceService):

View File

@ -10,6 +10,8 @@
# limitations under the License.
"""Azure SQL source module"""
from sqlalchemy.dialects.mssql.base import MSDialect, ischema_names
from metadata.generated.schema.entity.services.connections.database.azureSQLConnection import (
AzureSQLConnection,
)
@ -20,7 +22,43 @@ from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.mssql.utils import (
get_columns,
get_table_comment,
get_view_definition,
)
from metadata.utils.sqlalchemy_utils import (
get_all_table_comments,
get_all_view_definitions,
)
ischema_names.update(
{
"nvarchar": create_sqlalchemy_type("NVARCHAR"),
"nchar": create_sqlalchemy_type("NCHAR"),
"ntext": create_sqlalchemy_type("NTEXT"),
"bit": create_sqlalchemy_type("BIT"),
"image": create_sqlalchemy_type("IMAGE"),
"binary": create_sqlalchemy_type("BINARY"),
"smallmoney": create_sqlalchemy_type("SMALLMONEY"),
"money": create_sqlalchemy_type("MONEY"),
"real": create_sqlalchemy_type("REAL"),
"smalldatetime": create_sqlalchemy_type("SMALLDATETIME"),
"datetime2": create_sqlalchemy_type("DATETIME2"),
"datetimeoffset": create_sqlalchemy_type("DATETIMEOFFSET"),
"sql_variant": create_sqlalchemy_type("SQL_VARIANT"),
"uniqueidentifier": create_sqlalchemy_type("UUID"),
"xml": create_sqlalchemy_type("XML"),
}
)
MSDialect.get_table_comment = get_table_comment
MSDialect.get_view_definition = get_view_definition
MSDialect.get_all_view_definitions = get_all_view_definitions
MSDialect.get_all_table_comments = get_all_table_comments
MSDialect.get_columns = get_columns
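
create_sqlalchemy_type is imported from column_type_parser.py, which is not shown in this commit. Judging from the UserDefinedType classes it replaces in the ClickHouse module further down, a minimal sketch could look like this (hypothetical, the actual factory may differ):

import sqlalchemy.types as sqltypes

def create_sqlalchemy_type(name: str):
    # Build a one-off type class whose __visit_name__ matches the given name,
    # so reflection can represent vendor-specific or unknown types.
    return type(name, (sqltypes.UserDefinedType,), {"__visit_name__": name})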
class AzuresqlSource(CommonDbSourceService):

View File

@ -20,6 +20,8 @@ from google.cloud.bigquery.client import Client
from google.cloud.datacatalog_v1 import PolicyTagManagerClient
from sqlalchemy import inspect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.sqltypes import Interval
from sqlalchemy.types import String
from sqlalchemy_bigquery import BigQueryDialect, _types
from sqlalchemy_bigquery._types import _get_sqla_column_type
@ -68,10 +70,25 @@ from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.utils import fqn
from metadata.utils.filters import filter_by_database
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import is_complex_type
class BQJSON(String):
"""The SQL JSON type."""
def get_col_spec(self, **kw): # pylint: disable=unused-argument
return "JSON"
logger = ingestion_logger()
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
_types._type_map["GEOGRAPHY"] = GEOGRAPHY # pylint: disable=protected-access
# pylint: disable=protected-access
_types._type_map.update(
{
"GEOGRAPHY": create_sqlalchemy_type("GEOGRAPHY"),
"JSON": BQJSON,
"INTERVAL": Interval,
}
)
def get_columns(bq_schema):
@ -80,16 +97,18 @@ def get_columns(bq_schema):
"""
col_list = []
for field in bq_schema:
col_type = _get_sqla_column_type(field)
col_obj = {
"name": field.name,
"type": _get_sqla_column_type(field),
"type": col_type,
"nullable": field.mode in ("NULLABLE", "REPEATED"),
"comment": field.description,
"default": None,
"precision": field.precision,
"scale": field.scale,
"max_length": field.max_length,
"raw_data_type": str(_get_sqla_column_type(field)),
"system_data_type": str(col_type),
"is_complex": is_complex_type(str(col_type)),
"policy_tags": None,
}
try:
@ -390,7 +409,7 @@ class BigquerySource(CommonDbSourceService):
return True, table_partition
return False, None
def parse_raw_data_type(self, raw_data_type):
def clean_raw_data_type(self, raw_data_type):
return raw_data_type.replace(", ", ",").replace(" ", ":").lower()
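
To illustrate the renamed helper above, the normalisation behaves as follows (the input string is a made-up example):

# "STRUCT<a INT64, b STRING>" -> "struct<a:int64,b:string>"
raw = "STRUCT<a INT64, b STRING>"
cleaned = raw.replace(", ", ",").replace(" ", ":").lower()
assert cleaned == "struct<a:int64,b:string>"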
def close(self):

View File

@ -13,9 +13,9 @@
from clickhouse_sqlalchemy.drivers.base import ClickHouseDialect, ischema_names
from clickhouse_sqlalchemy.drivers.http.transport import RequestsTransport, _get_type
from clickhouse_sqlalchemy.drivers.http.utils import parse_tsv
from clickhouse_sqlalchemy.types import Date
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import reflection
from sqlalchemy.sql.sqltypes import String
from sqlalchemy.util import warn
from metadata.generated.schema.entity.services.connections.database.clickhouseConnection import (
@ -32,6 +32,7 @@ from metadata.ingestion.source.database.clickhouse.queries import (
CLICKHOUSE_TABLE_COMMENTS,
CLICKHOUSE_VIEW_DEFINITIONS,
)
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import (
@ -43,46 +44,41 @@ from metadata.utils.sqlalchemy_utils import (
logger = ingestion_logger()
Map = create_sqlalchemy_type("Map")
Array = create_sqlalchemy_type("Array")
Enum = create_sqlalchemy_type("Enum")
Tuple = create_sqlalchemy_type("Tuple")
class AggregateFunction(String):
__visit_name__ = "AggregateFunction"
class Map(sqltypes.UserDefinedType): # pylint: disable=abstract-method
__visit_name__ = "Map"
class Array(sqltypes.UserDefinedType): # pylint: disable=abstract-method
__visit_name__ = "Array"
class Tuple(sqltypes.UserDefinedType): # pylint: disable=abstract-method
__visit_name__ = "Tuple"
class Enum(sqltypes.UserDefinedType): # pylint: disable=abstract-method
__visit_name__ = "Enum"
ischema_names.update(
{
"AggregateFunction": create_sqlalchemy_type("AggregateFunction"),
"Map": Map,
"Array": Array,
"Tuple": Tuple,
"Enum": Enum,
"Date32": Date,
"SimpleAggregateFunction": create_sqlalchemy_type("SimpleAggregateFunction"),
"Int256": create_sqlalchemy_type("BIGINT"),
"Int128": create_sqlalchemy_type("BIGINT"),
"Int64": create_sqlalchemy_type("BIGINT"),
"Int32": create_sqlalchemy_type("INTEGER"),
"Int16": create_sqlalchemy_type("SMALLINT"),
"Int8": create_sqlalchemy_type("SMALLINT"),
"UInt256": create_sqlalchemy_type("BIGINT"),
"UInt128": create_sqlalchemy_type("BIGINT"),
"UInt64": create_sqlalchemy_type("BIGINT"),
"UInt32": create_sqlalchemy_type("INTEGER"),
"UInt16": create_sqlalchemy_type("SMALLINT"),
"UInt8": create_sqlalchemy_type("SMALLINT"),
}
)
@reflection.cache
def _get_column_type(
self, name, spec
): # pylint: disable=protected-access,too-many-branches,too-many-return-statements
ischema_names.update(
{
"AggregateFunction": AggregateFunction,
"Map": Map,
"Array": Array,
"Tuple": Tuple,
"Enum": Enum,
}
)
ClickHouseDialect.ischema_names = ischema_names
if spec.startswith("Array"):
return self.ischema_names["Array"]
@ -123,6 +119,9 @@ def _get_column_type(
if spec.lower().startswith("aggregatefunction"):
return self.ischema_names["AggregateFunction"]
if spec.lower().startswith("simpleaggregatefunction"):
return self.ischema_names["SimpleAggregateFunction"]
try:
return self.ischema_names[spec]
except KeyError:
@ -208,19 +207,21 @@ def _get_column_info(
default_type, default_expression
)
raw_type = format_type.lower().replace("(", "<").replace(")", ">")
result = {
"name": name,
"type": col_type,
"nullable": format_type.startswith("Nullable("),
"default": col_default,
"comment": comment or None,
"system_data_type": raw_type,
}
raw_type = format_type.lower().replace("(", "<").replace(")", ">")
if col_type in [Map, Array, Tuple, Enum]:
result["display_type"] = raw_type
if col_type == Array:
result["raw_data_type"] = raw_type
result["is_complex"] = True
return result
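
For reference, the display-type normalisation above rewrites a ClickHouse format_type as shown below (the input is a hypothetical example):

format_type = "Map(String, UInt64)"
raw_type = format_type.lower().replace("(", "<").replace(")", ">")
assert raw_type == "map<string, uint64>"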

View File

@ -53,7 +53,7 @@ class ColumnTypeParser:
types.ARRAY: "ARRAY",
types.Boolean: "BOOLEAN",
types.CHAR: "CHAR",
types.CLOB: "BINARY",
types.CLOB: "CLOB",
types.Date: "DATE",
types.DATE: "DATE",
types.DateTime: "DATETIME",
@ -109,6 +109,7 @@ class ColumnTypeParser:
"ENUM": "ENUM",
"FLOAT": "FLOAT",
"FLOAT4": "FLOAT",
"FLOAT32": "FLOAT",
"FLOAT64": "DOUBLE",
"FLOAT8": "DOUBLE",
"GEOGRAPHY": "GEOGRAPHY",
@ -187,21 +188,56 @@ class ColumnTypeParser:
"VARIANT": "JSON",
"JSON": "JSON",
"JSONB": "JSON",
"XML": "BINARY",
"XMLTYPE": "BINARY",
"UUID": "UUID",
"POINT": "POINT",
"POLYGON": "POLYGON",
"POINT": "GEOMETRY",
"POLYGON": "GEOMETRY",
"AggregateFunction()": "AGGREGATEFUNCTION",
"BYTEA": "BYTEA",
"UNKNOWN": "UNKNOWN",
# redshift
"HLLSKETCH": "HLLSKETCH",
"SUPER": "SUPER",
# postgres
"BOX": "GEOMETRY",
"CIRCLE": "GEOMETRY",
"LINE": "GEOMETRY",
"LSEG": "GEOMETRY",
"PATH": "GEOMETRY",
"PG_LSN": "PG_LSN",
"PG_SNAPSHOT": "PG_SNAPSHOT",
"TSQUERY": "TSQUERY",
"TXID_SNAPSHOT": "TXID_SNAPSHOT",
"XML": "XML",
"TSVECTOR": "TSVECTOR",
"MACADDR": "MACADDR",
"MACADDR8": "MACADDR",
"CIDR": "CIDR",
"INET": "INET",
# ORACLE
"BINARY_DOUBLE": "DOUBLE",
"BINARY_FLOAT": "FLOAT",
"XMLTYPE": "XML",
"BFILE": "BINARY",
"CLOB": "CLOB",
"NCLOB": "CLOB",
"LONG": "LONG",
# clickhouse
"LOWCARDINALITY": "LOWCARDINALITY",
"DATETIME64": "DATETIME",
"SimpleAggregateFunction()": "AGGREGATEFUNCTION",
# Databricks
"VOID": "NULL",
# mysql
"TINYBLOB": "BLOB",
"LONGTEXT": "TEXT",
"TINYTEXT": "TEXT",
"YEAR": "YEAR",
}
_COMPLEX_TYPE = re.compile("^(struct|map|array|uniontype)")
_FIXED_DECIMAL = re.compile(r"(decimal|numeric)(\(\s*(\d+)\s*,\s*(\d+)\s*\))?")
_FIXED_STRING = re.compile(r"(var)?char\(\s*(\d+)\s*\)")
try:
# pylint: disable=import-outside-toplevel
from sqlalchemy.dialects.mssql import BIT
@ -225,7 +261,7 @@ class ColumnTypeParser:
if column_type_result:
return column_type_result
return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get("VARCHAR")
return ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE.get("UNKNOWN")
@staticmethod
def get_column_type_mapping(column_type: Any) -> str:
@ -245,7 +281,7 @@ class ColumnTypeParser:
def _parse_datatype_string(
data_type: str, **kwargs: Any # pylint: disable=unused-argument
) -> Union[object, Dict[str, object]]:
data_type = data_type.strip()
data_type = data_type.lower().strip()
data_type = data_type.replace(" ", "")
if data_type.startswith("array<"):
if data_type[-1] != ">":
@ -319,8 +355,6 @@ class ColumnTypeParser:
"dataType": ColumnTypeParser._SOURCE_TYPE_TO_OM_TYPE[dtype.upper()],
"dataTypeDisplay": dtype,
}
if ColumnTypeParser._FIXED_STRING.match(dtype):
return {"dataType": "STRING", "dataTypeDisplay": dtype}
if ColumnTypeParser._FIXED_DECIMAL.match(dtype):
match = ColumnTypeParser._FIXED_DECIMAL.match(dtype)
if match.group(2) is not None: # type: ignore
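
As a quick illustration of the decimal handling above and of the new UNKNOWN fallback (sample strings only, not taken from the commit):

import re

_FIXED_DECIMAL = re.compile(r"(decimal|numeric)(\(\s*(\d+)\s*,\s*(\d+)\s*\))?")

match = _FIXED_DECIMAL.match("decimal(10,2)")
assert match.group(2) == "(10,2)"  # the "(precision,scale)" part used above
assert (match.group(3), match.group(4)) == ("10", "2")
# An unmapped type string now resolves to "UNKNOWN" instead of silently becoming "VARCHAR".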

View File

@ -291,6 +291,7 @@ class DatabaseServiceSource(
by default no need to process table constraints
specially for non SQA sources
"""
yield from []
@abstractmethod
def yield_table(

View File

@ -34,6 +34,7 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.databricks.queries import (
DATABRICKS_GET_TABLE_COMMENTS,
@ -73,7 +74,16 @@ class MAP(String):
# overriding pyhive.sqlalchemy_hive._type_map
# mapping struct, array & map to custom classed instead of sqltypes.String
_type_map.update({"struct": STRUCT, "array": ARRAY, "map": MAP})
_type_map.update(
{
"struct": STRUCT,
"array": ARRAY,
"map": MAP,
"void": create_sqlalchemy_type("VOID"),
"interval": create_sqlalchemy_type("INTERVAL"),
"binary": create_sqlalchemy_type("BINARY"),
}
)
def _get_column_rows(self, connection, table_name, schema):
@ -109,6 +119,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
# Take out the more detailed type information
# e.g. 'map<int,int>' -> 'map'
# 'decimal(10,1)' -> decimal
raw_col_type = col_type
col_type = re.search(r"^\w+", col_type).group(0)
try:
coltype = _type_map[col_type]
@ -122,6 +133,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
"nullable": True,
"default": None,
"comment": _comment,
"system_data_type": raw_col_type,
}
if col_type in {"array", "struct", "map"}:
if db_name and schema:
@ -139,7 +151,8 @@ def get_columns(self, connection, table_name, schema=None, **kw):
).fetchall()
)
col_info["raw_data_type"] = rows["data_type"]
col_info["system_data_type"] = rows["data_type"]
col_info["is_complex"] = True
result.append(col_info)
return result
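
For illustration, the type-name extraction above strips parameters and element types like so (sample strings only):

import re

assert re.search(r"^\w+", "map<int,int>").group(0) == "map"
assert re.search(r"^\w+", "decimal(10,1)").group(0) == "decimal"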

View File

@ -365,7 +365,7 @@ class DeltalakeSource(DatabaseServiceSource):
f"Unexpected exception getting columns for [{table_name}]: {exc}"
)
return []
parsed_columns: [Column] = []
parsed_columns: List[Column] = []
partition_cols = False
for row in raw_columns:
col_name = row["col_name"]

View File

@ -170,6 +170,7 @@ class DynamodbSource(DatabaseServiceSource):
parsed_string["dataType"] = "UNION"
parsed_string["name"] = column["AttributeName"][:64]
parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
parsed_string["displayDataType"] = str(column["AttributeType"])
yield Column(**parsed_string)
except Exception as exc:
logger.debug(traceback.format_exc())

View File

@ -332,6 +332,7 @@ class GlueSource(DatabaseServiceSource):
parsed_string["dataTypeDisplay"] = str(column["Type"])
parsed_string["dataType"] = "UNION"
parsed_string["name"] = column["Name"][:64]
parsed_string["dataTypeDisplay"] = column["Type"]
parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
parsed_string["description"] = column.get("Comment")
yield Column(**parsed_string)

View File

@ -89,9 +89,8 @@ def get_columns(
"comment": comment,
"nullable": True,
"default": None,
"raw_data_type": col_raw_type
if col_type in complex_data_types
else None,
"system_data_type": col_raw_type,
"is_complex": col_type in complex_data_types,
}
)
return result

View File

@ -11,6 +11,9 @@
"""
MariaDB source module
"""
from sqlalchemy.dialects.mysql.base import ischema_names
from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser
from metadata.generated.schema.entity.services.connections.database.mariaDBConnection import (
MariaDBConnection,
)
@ -22,6 +25,14 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.mysql.utils import col_type_map, parse_column
ischema_names.update(col_type_map)
MySQLTableDefinitionParser._parse_column = ( # pylint: disable=protected-access
parse_column
)
class MariadbSource(CommonDbSourceService):

View File

@ -12,26 +12,7 @@
import traceback
from typing import Iterable
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.dialects.mssql import information_schema as ischema
from sqlalchemy.dialects.mssql.base import (
MSBinary,
MSChar,
MSDialect,
MSNChar,
MSNText,
MSNVarchar,
MSString,
MSText,
MSVarBinary,
_db_plus_owner,
)
from sqlalchemy.engine import reflection
from sqlalchemy.sql import func
from sqlalchemy.types import NVARCHAR
from sqlalchemy.util import compat
from sqlalchemy.dialects.mssql.base import MSDialect, ischema_names
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.services.connections.database.mssqlConnection import (
@ -44,11 +25,12 @@ from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.mssql.queries import (
MSSQL_ALL_VIEW_DEFINITIONS,
MSSQL_GET_COLUMN_COMMENTS,
MSSQL_GET_TABLE_COMMENTS,
from metadata.ingestion.source.database.mssql.utils import (
get_columns,
get_table_comment,
get_view_definition,
)
from metadata.utils import fqn
from metadata.utils.filters import filter_by_database
@ -56,207 +38,30 @@ from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import (
get_all_table_comments,
get_all_view_definitions,
get_table_comment_wrapper,
get_view_definition_wrapper,
)
logger = ingestion_logger()
@reflection.cache
def get_table_comment(
self, connection, table_name, schema=None, **kw
): # pylint: disable=unused-argument
return get_table_comment_wrapper(
self,
connection,
table_name=table_name,
schema=schema,
query=MSSQL_GET_TABLE_COMMENTS,
)
@reflection.cache
@_db_plus_owner
def get_columns(
self, connection, tablename, dbname, owner, schema, **kw
): # pylint: disable=unused-argument, too-many-locals, disable=too-many-branches, too-many-statements
"""
This function overrides to add support for column comments
"""
is_temp_table = tablename.startswith("#")
if is_temp_table:
(
owner,
tablename,
) = self._get_internal_temp_table_name( # pylint: disable=protected-access
connection, tablename
)
columns = ischema.mssql_temp_table_columns
else:
columns = ischema.columns
computed_cols = ischema.computed_columns
identity_cols = ischema.identity_columns
if owner:
whereclause = sql.and_(
columns.c.table_name == tablename,
columns.c.table_schema == owner,
)
full_name = columns.c.table_schema + "." + columns.c.table_name
else:
whereclause = columns.c.table_name == tablename
full_name = columns.c.table_name
join = columns.join(
computed_cols,
onclause=sql.and_(
computed_cols.c.object_id == func.object_id(full_name),
computed_cols.c.name == columns.c.column_name.collate("DATABASE_DEFAULT"),
),
isouter=True,
).join(
identity_cols,
onclause=sql.and_(
identity_cols.c.object_id == func.object_id(full_name),
identity_cols.c.name == columns.c.column_name.collate("DATABASE_DEFAULT"),
),
isouter=True,
)
if self._supports_nvarchar_max: # pylint: disable=protected-access
computed_definition = computed_cols.c.definition
else:
# tds_version 4.2 does not support NVARCHAR(MAX)
computed_definition = sql.cast(computed_cols.c.definition, NVARCHAR(4000))
s = ( # pylint: disable=invalid-name
sql.select(
columns,
computed_definition,
computed_cols.c.is_persisted,
identity_cols.c.is_identity,
identity_cols.c.seed_value,
identity_cols.c.increment_value,
)
.where(whereclause)
.select_from(join)
.order_by(columns.c.ordinal_position)
)
c = connection.execution_options( # pylint:disable=invalid-name
future_result=True
).execute(s)
cols = []
for row in c.mappings():
name = row[columns.c.column_name]
type_ = row[columns.c.data_type]
nullable = row[columns.c.is_nullable] == "YES"
charlen = row[columns.c.character_maximum_length]
numericprec = row[columns.c.numeric_precision]
numericscale = row[columns.c.numeric_scale]
default = row[columns.c.column_default]
collation = row[columns.c.collation_name]
definition = row[computed_definition]
is_persisted = row[computed_cols.c.is_persisted]
is_identity = row[identity_cols.c.is_identity]
identity_start = row[identity_cols.c.seed_value]
identity_increment = row[identity_cols.c.increment_value]
coltype = self.ischema_names.get(type_, None)
comment = None
kwargs = {}
if coltype in (
MSString,
MSChar,
MSNVarchar,
MSNChar,
MSText,
MSNText,
MSBinary,
MSVarBinary,
sqltypes.LargeBinary,
):
if charlen == -1:
charlen = None
kwargs["length"] = charlen
if collation:
kwargs["collation"] = collation
if coltype is None:
util.warn(f"Did not recognize type '{type_}' of column '{name}'")
coltype = sqltypes.NULLTYPE
else:
if issubclass(coltype, sqltypes.Numeric):
kwargs["precision"] = numericprec
if not issubclass(coltype, sqltypes.Float):
kwargs["scale"] = numericscale
coltype = coltype(**kwargs)
cdict = {
"name": name,
"type": coltype,
"nullable": nullable,
"default": default,
"autoincrement": is_identity is not None,
"comment": comment,
}
if definition is not None and is_persisted is not None:
cdict["computed"] = {
"sqltext": definition,
"persisted": is_persisted,
}
if is_identity is not None:
# identity_start and identity_increment are Decimal or None
if identity_start is None or identity_increment is None:
cdict["identity"] = {}
else:
if isinstance(coltype, sqltypes.BigInteger):
start = compat.long_type(identity_start)
increment = compat.long_type(identity_increment)
elif isinstance(coltype, sqltypes.Integer):
start = int(identity_start)
increment = int(identity_increment)
else:
start = identity_start
increment = identity_increment
cdict["identity"] = {
"start": start,
"increment": increment,
}
cols.append(cdict)
cursor = connection.execute(
MSSQL_GET_COLUMN_COMMENTS.format(schema_name=schema, table_name=tablename)
)
try:
for index, result in enumerate(cursor):
if result[2]:
cols[index]["comment"] = result[2]
except Exception:
logger.debug(traceback.format_exc())
return cols
@reflection.cache
@_db_plus_owner
def get_view_definition(
self, connection, viewname, dbname, owner, schema, **kw
): # pylint: disable=unused-argument
return get_view_definition_wrapper(
self,
connection,
table_name=viewname,
schema=owner,
query=MSSQL_ALL_VIEW_DEFINITIONS,
)
ischema_names.update(
{
"nvarchar": create_sqlalchemy_type("NVARCHAR"),
"nchar": create_sqlalchemy_type("NCHAR"),
"ntext": create_sqlalchemy_type("NTEXT"),
"bit": create_sqlalchemy_type("BIT"),
"image": create_sqlalchemy_type("IMAGE"),
"binary": create_sqlalchemy_type("BINARY"),
"smallmoney": create_sqlalchemy_type("SMALLMONEY"),
"money": create_sqlalchemy_type("MONEY"),
"real": create_sqlalchemy_type("REAL"),
"smalldatetime": create_sqlalchemy_type("SMALLDATETIME"),
"datetime2": create_sqlalchemy_type("DATETIME2"),
"datetimeoffset": create_sqlalchemy_type("DATETIMEOFFSET"),
"sql_variant": create_sqlalchemy_type("SQL_VARIANT"),
"uniqueidentifier": create_sqlalchemy_type("UUID"),
"xml": create_sqlalchemy_type("XML"),
}
)
MSDialect.get_table_comment = get_table_comment
MSDialect.get_view_definition = get_view_definition

View File

@ -0,0 +1,251 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MSSQL SQLAlchemy Helper Methods
"""
import traceback
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.dialects.mssql import information_schema as ischema
from sqlalchemy.dialects.mssql.base import (
MSBinary,
MSChar,
MSNChar,
MSNText,
MSNVarchar,
MSString,
MSText,
MSVarBinary,
_db_plus_owner,
)
from sqlalchemy.engine import reflection
from sqlalchemy.sql import func
from sqlalchemy.types import NVARCHAR
from sqlalchemy.util import compat
from metadata.ingestion.source.database.mssql.queries import (
MSSQL_ALL_VIEW_DEFINITIONS,
MSSQL_GET_COLUMN_COMMENTS,
MSSQL_GET_TABLE_COMMENTS,
)
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import (
get_display_datatype,
get_table_comment_wrapper,
get_view_definition_wrapper,
)
logger = ingestion_logger()
@reflection.cache
def get_table_comment(
self, connection, table_name, schema=None, **kw
): # pylint: disable=unused-argument
return get_table_comment_wrapper(
self,
connection,
table_name=table_name,
schema=schema,
query=MSSQL_GET_TABLE_COMMENTS,
)
@reflection.cache
@_db_plus_owner
def get_columns(
self, connection, tablename, dbname, owner, schema, **kw
): # pylint: disable=unused-argument, too-many-locals, disable=too-many-branches, too-many-statements
"""
Overridden from the MSSQL dialect to add support for column comments and the raw system data type
"""
is_temp_table = tablename.startswith("#")
if is_temp_table:
(
owner,
tablename,
) = self._get_internal_temp_table_name( # pylint: disable=protected-access
connection, tablename
)
columns = ischema.mssql_temp_table_columns
else:
columns = ischema.columns
computed_cols = ischema.computed_columns
identity_cols = ischema.identity_columns
if owner:
whereclause = sql.and_(
columns.c.table_name == tablename,
columns.c.table_schema == owner,
)
full_name = columns.c.table_schema + "." + columns.c.table_name
else:
whereclause = columns.c.table_name == tablename
full_name = columns.c.table_name
join = columns.join(
computed_cols,
onclause=sql.and_(
computed_cols.c.object_id == func.object_id(full_name),
computed_cols.c.name == columns.c.column_name.collate("DATABASE_DEFAULT"),
),
isouter=True,
).join(
identity_cols,
onclause=sql.and_(
identity_cols.c.object_id == func.object_id(full_name),
identity_cols.c.name == columns.c.column_name.collate("DATABASE_DEFAULT"),
),
isouter=True,
)
if self._supports_nvarchar_max: # pylint: disable=protected-access
computed_definition = computed_cols.c.definition
else:
# tds_version 4.2 does not support NVARCHAR(MAX)
computed_definition = sql.cast(computed_cols.c.definition, NVARCHAR(4000))
s = ( # pylint: disable=invalid-name
sql.select(
columns,
computed_definition,
computed_cols.c.is_persisted,
identity_cols.c.is_identity,
identity_cols.c.seed_value,
identity_cols.c.increment_value,
)
.where(whereclause)
.select_from(join)
.order_by(columns.c.ordinal_position)
)
c = connection.execution_options( # pylint:disable=invalid-name
future_result=True
).execute(s)
cols = []
for row in c.mappings():
name = row[columns.c.column_name]
type_ = row[columns.c.data_type]
nullable = row[columns.c.is_nullable] == "YES"
charlen = row[columns.c.character_maximum_length]
numericprec = row[columns.c.numeric_precision]
numericscale = row[columns.c.numeric_scale]
default = row[columns.c.column_default]
collation = row[columns.c.collation_name]
definition = row[computed_definition]
is_persisted = row[computed_cols.c.is_persisted]
is_identity = row[identity_cols.c.is_identity]
identity_start = row[identity_cols.c.seed_value]
identity_increment = row[identity_cols.c.increment_value]
coltype = self.ischema_names.get(type_, None)
comment = None
kwargs = {}
if coltype in (
MSString,
MSChar,
MSNVarchar,
MSNChar,
MSText,
MSNText,
MSBinary,
MSVarBinary,
sqltypes.LargeBinary,
):
if charlen == -1:
charlen = None
kwargs["length"] = charlen
if collation:
kwargs["collation"] = collation
precision = None
scale = None
if coltype is None:
util.warn(f"Did not recognize type '{type_}' of column '{name}'")
coltype = sqltypes.NULLTYPE
else:
if issubclass(coltype, sqltypes.Numeric):
kwargs["precision"] = numericprec
precision = numericprec
if not issubclass(coltype, sqltypes.Float):
kwargs["scale"] = numericscale
scale = numericscale
coltype = coltype(**kwargs)
raw_data_type = get_display_datatype(
type_, char_len=charlen, precision=precision, scale=scale
)
cdict = {
"name": name,
"type": coltype,
"system_data_type": raw_data_type,
"nullable": nullable,
"default": default,
"autoincrement": is_identity is not None,
"comment": comment,
}
if definition is not None and is_persisted is not None:
cdict["computed"] = {
"sqltext": definition,
"persisted": is_persisted,
}
if is_identity is not None:
# identity_start and identity_increment are Decimal or None
if identity_start is None or identity_increment is None:
cdict["identity"] = {}
else:
if isinstance(coltype, sqltypes.BigInteger):
start = compat.long_type(identity_start)
increment = compat.long_type(identity_increment)
elif isinstance(coltype, sqltypes.Integer):
start = int(identity_start)
increment = int(identity_increment)
else:
start = identity_start
increment = identity_increment
cdict["identity"] = {
"start": start,
"increment": increment,
}
cols.append(cdict)
cursor = connection.execute(
MSSQL_GET_COLUMN_COMMENTS.format(schema_name=schema, table_name=tablename)
)
try:
for index, result in enumerate(cursor):
if result[2]:
cols[index]["comment"] = result[2]
except Exception:
logger.debug(traceback.format_exc())
return cols
@reflection.cache
@_db_plus_owner
def get_view_definition(
self, connection, viewname, dbname, owner, schema, **kw
): # pylint: disable=unused-argument
return get_view_definition_wrapper(
self,
connection,
table_name=viewname,
schema=owner,
query=MSSQL_ALL_VIEW_DEFINITIONS,
)

View File

@ -9,6 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mysql source module"""
from sqlalchemy.dialects.mysql.base import ischema_names
from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
MysqlConnection,
)
@ -20,6 +23,14 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.mysql.utils import col_type_map, parse_column
ischema_names.update(col_type_map)
MySQLTableDefinitionParser._parse_column = ( # pylint: disable=protected-access
parse_column
)
class MysqlSource(CommonDbSourceService):

View File

@ -0,0 +1,153 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MySQL SQLAlchemy Helper Methods
"""
# pylint: disable=protected-access,too-many-branches,too-many-statements,too-many-locals
from sqlalchemy import util
from sqlalchemy.dialects.mysql.enumerated import ENUM, SET
from sqlalchemy.dialects.mysql.reflection import _strip_values
from sqlalchemy.dialects.mysql.types import DATETIME, TIME, TIMESTAMP
from sqlalchemy.sql import sqltypes
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.utils.sqlalchemy_utils import get_display_datatype
col_type_map = {
"bool": create_sqlalchemy_type("BOOL"),
"geometry": create_sqlalchemy_type("GEOMETRY"),
"point": create_sqlalchemy_type("GEOMETRY"),
"polygon": create_sqlalchemy_type("GEOMETRY"),
"linestring": create_sqlalchemy_type("GEOMETRY"),
"geomcollection": create_sqlalchemy_type("GEOMETRY"),
"multilinestring": create_sqlalchemy_type("GEOMETRY"),
"multipoint": create_sqlalchemy_type("GEOMETRY"),
"multipolygon": create_sqlalchemy_type("GEOMETRY"),
}
def parse_column(self, line, state):
"""
Overriding the dialect method to include system_data_type in response
Extract column details.
Falls back to a 'minimal support' variant if full parse fails.
:param line: Any column-bearing line from SHOW CREATE TABLE
"""
spec = None
re_match = self._re_column.match(line)
if re_match:
spec = re_match.groupdict()
spec["full"] = True
else:
re_match = self._re_column_loose.match(line)
if re_match:
spec = re_match.groupdict()
spec["full"] = False
if not spec:
util.warn(f"Unknown column definition {line}")
return
if not spec["full"]:
util.warn(f"Incomplete reflection of column definition {line}")
name, type_, args = spec["name"], spec["coltype"], spec["arg"]
try:
col_type = self.dialect.ischema_names[type_]
except KeyError:
util.warn(f"Did not recognize type '{type_}' of column '{name}'")
col_type = sqltypes.NullType
# Column type positional arguments eg. varchar(32)
if args is None or args == "":
type_args = []
elif args[0] == "'" and args[-1] == "'":
type_args = self._re_csv_str.findall(args)
else:
type_args = [int(v) for v in self._re_csv_int.findall(args)]
# Column type keyword options
type_kw = {}
if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
if type_args:
type_kw["fsp"] = type_args.pop(0)
for ikw in ("unsigned", "zerofill"):
if spec.get(ikw, False):
type_kw[ikw] = True
for ikw in ("charset", "collate"):
if spec.get(ikw, False):
type_kw[ikw] = spec[ikw]
if issubclass(col_type, (ENUM, SET)):
type_args = _strip_values(type_args)
if issubclass(col_type, SET) and "" in type_args:
type_kw["retrieve_as_bitwise"] = True
type_instance = col_type(*type_args, **type_kw)
col_kw = {}
# NOT NULL
col_kw["nullable"] = True
# this can be "NULL" in the case of TIMESTAMP
if spec.get("notnull", False) == "NOT NULL":
col_kw["nullable"] = False
# AUTO_INCREMENT
if spec.get("autoincr", False):
col_kw["autoincrement"] = True
elif issubclass(col_type, sqltypes.Integer):
col_kw["autoincrement"] = False
# DEFAULT
default = spec.get("default", None)
if default == "NULL":
# eliminates the need to deal with this later.
default = None
comment = spec.get("comment", None)
if comment is not None:
comment = comment.replace("\\\\", "\\").replace("''", "'")
sqltext = spec.get("generated")
if sqltext is not None:
computed = {"sqltext": sqltext}
persisted = spec.get("persistence")
if persisted is not None:
computed["persisted"] = persisted == "STORED"
col_kw["computed"] = computed
raw_type = get_display_datatype(
col_type=type_,
char_len=type_instance.length if hasattr(type_instance, "length") else None,
precision=type_instance.precision
if hasattr(type_instance, "precision")
else None,
scale=type_instance.scale if hasattr(type_instance, "scale") else None,
)
col_d = {
"name": name,
"type": type_instance,
"default": default,
"comment": comment,
"system_data_type": raw_type,
}
col_d.update(col_kw)
state.columns.append(col_d)
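
get_display_datatype is imported from metadata/utils/sqlalchemy_utils.py and is not part of this diff. Based on the call sites in this commit (bare type name plus optional length or precision/scale), a minimal sketch might be (hypothetical):

from typing import Optional

def get_display_datatype(
    col_type: str,
    char_len: Optional[int] = None,
    precision: Optional[int] = None,
    scale: Optional[int] = None,
) -> str:
    # Hypothetical sketch: rebuild a display string such as "varchar(32)"
    # or "decimal(10,2)" from the bare type name and its sizing attributes.
    if char_len is not None:
        return f"{col_type}({char_len})"
    if precision is not None and scale is not None:
        return f"{col_type}({precision},{scale})"
    if precision is not None:
        return f"{col_type}({precision})"
    return col_type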

View File

@ -8,9 +8,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=protected-access
"""Oracle source module"""
from sqlalchemy.dialects.oracle.base import OracleDialect
import re
from sqlalchemy import sql, util
from sqlalchemy.dialects.oracle.base import (
FLOAT,
INTEGER,
INTERVAL,
NUMBER,
TIMESTAMP,
OracleDialect,
ischema_names,
)
from sqlalchemy.engine import reflection
from sqlalchemy.sql import sqltypes
from metadata.generated.schema.entity.services.connections.database.oracleConnection import (
OracleConnection,
@ -22,10 +36,13 @@ from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.oracle.queries import (
ORACLE_ALL_TABLE_COMMENTS,
ORACLE_ALL_VIEW_DEFINITIONS,
ORACLE_GET_COLUMNS,
ORACLE_IDENTITY_TYPE,
)
from metadata.utils.sqlalchemy_utils import (
get_all_table_comments,
@ -34,6 +51,14 @@ from metadata.utils.sqlalchemy_utils import (
get_view_definition_wrapper,
)
ischema_names.update(
{
"ROWID": create_sqlalchemy_type("ROWID"),
"XMLTYPE": create_sqlalchemy_type("XMLTYPE"),
"INTERVAL YEAR TO MONTH": INTERVAL,
}
)
@reflection.cache
def get_table_comment(
@ -74,7 +99,136 @@ def get_view_definition(
)
def _get_col_type(
self, coltype, precision, scale, length, colname
): # pylint: disable=too-many-branches
raw_type = coltype
if coltype == "NUMBER":
if precision is None and scale == 0:
coltype = INTEGER()
else:
coltype = NUMBER(precision, scale)
if precision is not None:
if scale is not None:
raw_type += f"({precision},{scale})"
else:
raw_type += f"({precision})"
elif coltype == "FLOAT":
# TODO: support "precision" here as "binary_precision"
coltype = FLOAT()
elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"):
coltype = self.ischema_names.get(coltype)(length)
if length:
raw_type += f"({length})"
elif "WITH TIME ZONE" in coltype or "TIMESTAMP" in coltype:
coltype = TIMESTAMP(timezone=True)
elif "INTERVAL" in coltype:
coltype = INTERVAL()
else:
coltype = re.sub(r"\(\d+\)", "", coltype)
try:
coltype = self.ischema_names[coltype]
except KeyError:
util.warn(f"Did not recognize type '{coltype}' of column '{colname}'")
coltype = sqltypes.NULLTYPE
return coltype, raw_type
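
A worked illustration of the raw_type bookkeeping above, using hypothetical column metadata:

coltype, precision, scale = "NUMBER", 10, 2
raw_type = coltype
if precision is not None:
    raw_type += f"({precision},{scale})" if scale is not None else f"({precision})"
assert raw_type == "NUMBER(10,2)"  # the dialect pairs this with the NUMBER(10, 2) type object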
# pylint: disable=too-many-locals
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
"""
Dialect method overridden to add raw data type
kw arguments can be:
oracle_resolve_synonyms
dblink
"""
resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
dblink = kw.get("dblink", "")
info_cache = kw.get("info_cache")
(table_name, schema, dblink, _) = self._prepare_reflection_args(
connection,
table_name,
schema,
resolve_synonyms,
dblink,
info_cache=info_cache,
)
columns = []
char_length_col = "data_length"
if self._supports_char_length:
char_length_col = "char_length"
identity_cols = "NULL as default_on_null, NULL as identity_options"
if self.server_version_info >= (12,):
identity_cols = ORACLE_IDENTITY_TYPE.format(dblink=dblink)
params = {"table_name": table_name}
text = ORACLE_GET_COLUMNS.format(
dblink=dblink, char_length_col=char_length_col, identity_cols=identity_cols
)
if schema is not None:
params["owner"] = schema
text += " AND col.owner = :owner "
text += " ORDER BY col.column_id"
cols = connection.execute(sql.text(text), params)
for row in cols:
colname = self.normalize_name(row[0])
length = row[2]
nullable = row[5] == "Y"
default = row[6]
generated = row[8]
default_on_nul = row[9]
identity_options = row[10]
coltype, raw_coltype = self._get_col_type(
row.data_type, row.data_precision, row.data_scale, length, colname
)
computed = None
if generated == "YES":
computed = {"sqltext": default}
default = None
identity = None
if identity_options is not None:
identity = self._parse_identity_options(identity_options, default_on_nul)
default = None
cdict = {
"name": colname,
"type": coltype,
"nullable": nullable,
"default": default,
"autoincrement": "auto",
"comment": row.comments,
"system_data_type": raw_coltype,
}
if row.column_name.lower() == row.column_name:
cdict["quote"] = True
if computed is not None:
cdict["computed"] = computed
if identity is not None:
cdict["identity"] = identity
columns.append(cdict)
return columns
OracleDialect.get_table_comment = get_table_comment
OracleDialect.get_columns = get_columns
OracleDialect._get_col_type = _get_col_type
OracleDialect.get_view_definition = get_view_definition
OracleDialect.get_all_view_definitions = get_all_view_definitions
OracleDialect.get_all_table_comments = get_all_table_comments

View File

@ -30,3 +30,35 @@ SELECT
FROM all_views
where text is not null and owner not in ('SYSTEM', 'SYS')
"""
ORACLE_IDENTITY_TYPE = """\
col.default_on_null,
(
SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS
FROM ALL_TAB_IDENTITY_COLS{dblink} id
WHERE col.table_name = id.table_name
AND col.column_name = id.column_name
AND col.owner = id.owner
) AS identity_options
"""
ORACLE_GET_COLUMNS = """
SELECT
col.column_name,
col.data_type,
col.{char_length_col},
col.data_precision,
col.data_scale,
col.nullable,
col.data_default,
com.comments,
col.virtual_column,
{identity_cols}
FROM all_tab_cols{dblink} col
LEFT JOIN all_col_comments{dblink} com
ON col.table_name = com.table_name
AND col.column_name = com.column_name
AND col.owner = com.owner
WHERE col.table_name = CAST(:table_name AS VARCHAR2(128))
AND col.hidden_column = 'NO'
"""

View File

@ -19,7 +19,7 @@ from sqlalchemy import sql
from sqlalchemy.dialects.postgresql.base import PGDialect, ischema_names
from sqlalchemy.engine import reflection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.sqltypes import String
from sqlalchemy.sql import sqltypes
from metadata.generated.schema.api.classification.createClassification import (
CreateClassificationRequest,
@ -42,15 +42,18 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import (
CommonDbSourceService,
TableNameAndType,
)
from metadata.ingestion.source.database.postgres.queries import (
POSTGRES_COL_IDENTITY,
POSTGRES_GET_ALL_TABLE_PG_POLICY,
POSTGRES_GET_DB_NAMES,
POSTGRES_GET_TABLE_NAMES,
POSTGRES_PARTITION_DETAILS,
POSTGRES_SQL_COLUMNS,
POSTGRES_TABLE_COMMENTS,
POSTGRES_VIEW_DEFINITIONS,
)
@ -81,26 +84,27 @@ RELKIND_MAP = {
"f": TableType.Foreign,
}
GEOMETRY = create_sqlalchemy_type("GEOMETRY")
POINT = create_sqlalchemy_type("POINT")
POLYGON = create_sqlalchemy_type("POLYGON")
class GEOMETRY(String):
"""The SQL GEOMETRY type."""
__visit_name__ = "GEOMETRY"
class POINT(String):
"""The SQL POINT type."""
__visit_name__ = "POINT"
class POLYGON(String):
"""The SQL GEOMETRY type."""
__visit_name__ = "POLYGON"
ischema_names.update({"geometry": GEOMETRY, "point": POINT, "polygon": POLYGON})
ischema_names.update(
{
"geometry": GEOMETRY,
"point": POINT,
"polygon": POLYGON,
"box": create_sqlalchemy_type("BOX"),
"circle": create_sqlalchemy_type("CIRCLE"),
"line": create_sqlalchemy_type("LINE"),
"lseg": create_sqlalchemy_type("LSEG"),
"path": create_sqlalchemy_type("PATH"),
"pg_lsn": create_sqlalchemy_type("PG_LSN"),
"pg_snapshot": create_sqlalchemy_type("PG_SNAPSHOT"),
"tsquery": create_sqlalchemy_type("TSQUERY"),
"txid_snapshot": create_sqlalchemy_type("TXID_SNAPSHOT"),
"xml": create_sqlalchemy_type("XML"),
}
)
@reflection.cache
@ -116,6 +120,85 @@ def get_table_comment(
)
@reflection.cache
def get_columns( # pylint: disable=too-many-locals
self, connection, table_name, schema=None, **kw
):
"""
Overriding the dialect method to add system_data_type in response
"""
table_oid = self.get_table_oid(
connection, table_name, schema, info_cache=kw.get("info_cache")
)
generated = (
"a.attgenerated as generated"
if self.server_version_info >= (12,)
else "NULL as generated"
)
if self.server_version_info >= (10,):
# a.attidentity != '' is required or it will reflect also
# serial columns as identity.
identity = POSTGRES_COL_IDENTITY
else:
identity = "NULL as identity_options"
sql_col_query = POSTGRES_SQL_COLUMNS.format(
generated=generated,
identity=identity,
)
sql_col_query = (
sql.text(sql_col_query)
.bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
.columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
)
conn = connection.execute(sql_col_query, {"table_oid": table_oid})
rows = conn.fetchall()
# dictionary with (name, ) if default search path or (schema, name)
# as keys
domains = self._load_domains(connection) # pylint: disable=protected-access
# dictionary with (name, ) if default search path or (schema, name)
# as keys
enums = dict(
((rec["name"],), rec) if rec["visible"] else ((rec["schema"], rec["name"]), rec)
for rec in self._load_enums( # pylint: disable=protected-access
connection, schema="*"
)
)
# format columns
columns = []
for (
name,
format_type,
default_,
notnull,
table_oid,
comment,
generated,
identity,
) in rows:
column_info = self._get_column_info( # pylint: disable=protected-access
name,
format_type,
default_,
notnull,
domains,
enums,
schema,
comment,
generated,
identity,
)
column_info["system_data_type"] = format_type
columns.append(column_info)
return columns
PGDialect.get_all_table_comments = get_all_table_comments
PGDialect.get_table_comment = get_table_comment
@ -134,6 +217,7 @@ def get_view_definition(
PGDialect.get_view_definition = get_view_definition
PGDialect.get_columns = get_columns
PGDialect.get_all_view_definitions = get_all_view_definitions
PGDialect.ischema_names = ischema_names
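
Hedged usage sketch: once PGDialect.get_columns is patched as above, an ordinary SQLAlchemy inspection call surfaces the extra key. The connection string and table name below are placeholders, not part of this change:

from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://user:password@localhost:5432/demo")
inspector = inspect(engine)
for col in inspector.get_columns("my_table", schema="public"):
    # each column dict now carries e.g.
    # {'name': 'geom', 'type': GEOMETRY(), ..., 'system_data_type': 'geometry'}
    print(col["name"], col.get("system_data_type"))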

View File

@ -137,3 +137,44 @@ POSTGRES_SQL_STATEMENT_TEST = """
POSTGRES_GET_DB_NAMES = """
select datname from pg_catalog.pg_database
"""
POSTGRES_COL_IDENTITY = """\
(SELECT json_build_object(
'always', a.attidentity = 'a',
'start', s.seqstart,
'increment', s.seqincrement,
'minvalue', s.seqmin,
'maxvalue', s.seqmax,
'cache', s.seqcache,
'cycle', s.seqcycle)
FROM pg_catalog.pg_sequence s
JOIN pg_catalog.pg_class c on s.seqrelid = c."oid"
WHERE c.relkind = 'S'
AND a.attidentity != ''
AND s.seqrelid = pg_catalog.pg_get_serial_sequence(
a.attrelid::regclass::text, a.attname
)::regclass::oid
) as identity_options\
"""
POSTGRES_SQL_COLUMNS = """
SELECT a.attname,
pg_catalog.format_type(a.atttypid, a.atttypmod),
(
SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid)
FROM pg_catalog.pg_attrdef d
WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum
AND a.atthasdef
) AS DEFAULT,
a.attnotnull,
a.attrelid as table_oid,
pgd.description as comment,
{generated},
{identity}
FROM pg_catalog.pg_attribute a
LEFT JOIN pg_catalog.pg_description pgd ON (
pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum)
WHERE a.attrelid = :table_oid
AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
"""

View File

@ -85,6 +85,7 @@ def get_columns(
{
"name": row.Column,
"type": coltype,
"system_data_type": row.Type,
# newer Presto no longer includes this column
"nullable": getattr(row, "Null", True),
"default": None,
@ -93,7 +94,14 @@ def get_columns(
return result
@reflection.cache
# pylint: disable=unused-argument
def get_table_comment(self, connection, table_name, schema=None, **kw):
return {"text": None}
PrestoDialect.get_columns = get_columns
PrestoDialect.get_table_comment = get_table_comment
class PrestoSource(CommonDbSourceService):

View File

@ -26,6 +26,7 @@ from sqlalchemy.engine import reflection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import sqltypes
from sqlalchemy_redshift.dialect import (
REDSHIFT_ISCHEMA_NAMES,
RedshiftDialect,
RedshiftDialectMixin,
RelationKey,
@ -49,6 +50,7 @@ from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import (
CommonDbSourceService,
TableNameAndType,
@ -73,7 +75,10 @@ sa_version = Version(sa.__version__)
logger = ingestion_logger()
ischema_names = pg_ischema_names
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
ischema_names["geography"] = GEOGRAPHY
ischema_names.update({"binary varying": sqltypes.VARBINARY})
ischema_names.update(REDSHIFT_ISCHEMA_NAMES)
# pylint: disable=protected-access
@reflection.cache
@ -106,6 +111,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
)
column_info["distkey"] = col.distkey
column_info["sortkey"] = col.sortkey
column_info["system_data_type"] = col.format_type
columns.append(column_info)
return columns
@ -132,7 +138,7 @@ def _get_column_info(self, *args, **kwargs):
)._get_column_info(*args, **kwdrs)
# raw_data_type is not included in column_info as
# redhift doesn't suport compex data types directly
# redshift doesn't support complex data types directly
# https://docs.aws.amazon.com/redshift/latest/dg/c_Supported_data_types.html
if "info" not in column_info:

View File

@ -23,6 +23,7 @@ from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.table import (
Column,
Constraint,
DataType,
Table,
TableType,
)
@ -224,6 +225,7 @@ class SalesforceSource(DatabaseServiceSource):
name=column["name"],
description=column["label"],
dataType=self.column_type(column["type"].upper()),
dataTypeDisplay=column["type"],
constraint=col_constraint,
ordinalPosition=row_order,
dataLength=column["length"],
@ -233,9 +235,19 @@ class SalesforceSource(DatabaseServiceSource):
return columns
def column_type(self, column_type: str):
if column_type in {"ID", "PHONE", "CURRENCY"}:
return "INT"
return "VARCHAR"
if column_type in {
"ID",
"PHONE",
"EMAIL",
"ENCRYPTEDSTRING",
"COMBOBOX",
"URL",
"TEXTAREA",
"ADDRESS",
"REFERENCE",
}:
return DataType.VARCHAR.value
return DataType.UNKNOWN.value
def yield_view_lineage(self) -> Optional[Iterable[AddLineageRequest]]:
yield from []

View File

@ -11,6 +11,9 @@
"""
Singlestore source ingestion
"""
from sqlalchemy.dialects.mysql.base import ischema_names
from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser
from metadata.generated.schema.entity.services.connections.database.singleStoreConnection import (
SingleStoreConnection,
)
@ -22,6 +25,14 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.mysql.utils import col_type_map, parse_column
ischema_names.update(col_type_map)
MySQLTableDefinitionParser._parse_column = ( # pylint: disable=protected-access
parse_column
)
class SinglestoreSource(CommonDbSourceService):

View File

@ -18,7 +18,6 @@ from typing import Iterable, List, Optional, Tuple
import sqlparse
from snowflake.sqlalchemy.custom_types import VARIANT
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, ischema_names
from sqlalchemy.engine import reflection
from sqlalchemy.engine.reflection import Inspector
from sqlparse.sql import Function, Identifier
@ -44,20 +43,23 @@ from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.snowflake.queries import (
SNOWFLAKE_FETCH_ALL_TAGS,
SNOWFLAKE_GET_CLUSTER_KEY,
SNOWFLAKE_GET_COMMENTS,
SNOWFLAKE_GET_DATABASE_COMMENTS,
SNOWFLAKE_GET_SCHEMA_COMMENTS,
SNOWFLAKE_GET_TABLE_NAMES,
SNOWFLAKE_GET_VIEW_NAMES,
SNOWFLAKE_SESSION_TAG_QUERY,
)
from metadata.ingestion.source.database.snowflake.utils import (
get_schema_columns,
get_table_comment,
get_table_names,
get_unique_constraints,
get_view_definition,
get_view_names,
normalize_names,
)
from metadata.utils import fqn
from metadata.utils.filters import filter_by_database
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import (
get_all_table_comments,
get_table_comment_wrapper,
)
from metadata.utils.sqlalchemy_utils import get_all_table_comments
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
GEOMETRY = create_sqlalchemy_type("GEOMETRY")
@ -69,71 +71,6 @@ logger = ingestion_logger()
SnowflakeDialect._json_deserializer = json.loads # pylint: disable=protected-access
def get_table_names(self, connection, schema, **kw): # pylint: disable=unused-argument
cursor = connection.execute(SNOWFLAKE_GET_TABLE_NAMES.format(schema))
result = [self.normalize_name(row[0]) for row in cursor]
return result
def get_view_names(self, connection, schema, **kw): # pylint: disable=unused-argument
cursor = connection.execute(SNOWFLAKE_GET_VIEW_NAMES.format(schema))
result = [self.normalize_name(row[0]) for row in cursor]
return result
@reflection.cache
def get_view_definition( # pylint: disable=unused-argument
self, connection, view_name, schema=None, **kw
):
"""
Gets the view definition
"""
schema = schema or self.default_schema_name
if schema:
cursor = connection.execute(
"SHOW /* sqlalchemy:get_view_definition */ VIEWS "
f"LIKE '{view_name}' IN {schema}"
)
else:
cursor = connection.execute(
"SHOW /* sqlalchemy:get_view_definition */ VIEWS " f"LIKE '{view_name}'"
)
n2i = self.__class__._map_name_to_idx(cursor) # pylint: disable=protected-access
try:
ret = cursor.fetchone()
if ret:
return ret[n2i["text"]]
except Exception:
pass
return None
@reflection.cache
def get_table_comment(
self, connection, table_name, schema=None, **kw
): # pylint: disable=unused-argument
return get_table_comment_wrapper(
self,
connection,
table_name=table_name,
schema=schema,
query=SNOWFLAKE_GET_COMMENTS,
)
@reflection.cache
def get_unique_constraints( # pylint: disable=unused-argument
self, connection, table_name, schema=None, **kw
):
return []
def normalize_names(self, name): # pylint: disable=unused-argument
return name
SnowflakeDialect.get_table_names = get_table_names
SnowflakeDialect.get_view_names = get_view_names
SnowflakeDialect.get_all_table_comments = get_all_table_comments
@ -141,6 +78,9 @@ SnowflakeDialect.normalize_name = normalize_names
SnowflakeDialect.get_table_comment = get_table_comment
SnowflakeDialect.get_view_definition = get_view_definition
SnowflakeDialect.get_unique_constraints = get_unique_constraints
SnowflakeDialect._get_schema_columns = ( # pylint: disable=protected-access
get_schema_columns
)
class SnowflakeSource(CommonDbSourceService):

View File

@ -87,3 +87,23 @@ FROM information_schema.schemata
SNOWFLAKE_GET_DATABASE_COMMENTS = """
select DATABASE_NAME,COMMENT from information_schema.databases
"""
SNOWFLAKE_GET_SCHEMA_COLUMNS = """
SELECT /* sqlalchemy:_get_schema_columns */
ic.table_name,
ic.column_name,
ic.data_type,
ic.character_maximum_length,
ic.numeric_precision,
ic.numeric_scale,
ic.is_nullable,
ic.column_default,
ic.is_identity,
ic.comment,
ic.identity_start,
ic.identity_increment
FROM information_schema.columns ic
WHERE ic.table_schema=:table_schema
ORDER BY ic.ordinal_position
"""

View File

@ -0,0 +1,187 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module to define overridden dialect methods
"""
import sqlalchemy.types as sqltypes
from sqlalchemy import exc as sa_exc
from sqlalchemy import util as sa_util
from sqlalchemy.engine import reflection
from sqlalchemy.sql import text
from sqlalchemy.types import FLOAT
from metadata.ingestion.source.database.snowflake.queries import (
SNOWFLAKE_GET_COMMENTS,
SNOWFLAKE_GET_SCHEMA_COLUMNS,
SNOWFLAKE_GET_TABLE_NAMES,
SNOWFLAKE_GET_VIEW_NAMES,
)
from metadata.utils.sqlalchemy_utils import (
get_display_datatype,
get_table_comment_wrapper,
)
def get_table_names(self, connection, schema, **kw): # pylint: disable=unused-argument
cursor = connection.execute(SNOWFLAKE_GET_TABLE_NAMES.format(schema))
result = [self.normalize_name(row[0]) for row in cursor]
return result
def get_view_names(self, connection, schema, **kw): # pylint: disable=unused-argument
cursor = connection.execute(SNOWFLAKE_GET_VIEW_NAMES.format(schema))
result = [self.normalize_name(row[0]) for row in cursor]
return result
@reflection.cache
def get_view_definition( # pylint: disable=unused-argument
self, connection, view_name, schema=None, **kw
):
"""
Gets the view definition
"""
schema = schema or self.default_schema_name
if schema:
cursor = connection.execute(
"SHOW /* sqlalchemy:get_view_definition */ VIEWS "
f"LIKE '{view_name}' IN {schema}"
)
else:
cursor = connection.execute(
"SHOW /* sqlalchemy:get_view_definition */ VIEWS " f"LIKE '{view_name}'"
)
n2i = self.__class__._map_name_to_idx(cursor) # pylint: disable=protected-access
try:
ret = cursor.fetchone()
if ret:
return ret[n2i["text"]]
except Exception:
pass
return None
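# Usage note (sketch): Snowflake's `SHOW VIEWS LIKE '<name>' [IN <schema>]` returns one row
# per matching view whose "text" column carries the CREATE VIEW DDL; that is the value
# extracted above via `ret[n2i["text"]]`.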
@reflection.cache
def get_table_comment(
self, connection, table_name, schema=None, **kw
): # pylint: disable=unused-argument
return get_table_comment_wrapper(
self,
connection,
table_name=table_name,
schema=schema,
query=SNOWFLAKE_GET_COMMENTS,
)
@reflection.cache
def get_unique_constraints( # pylint: disable=unused-argument
self, connection, table_name, schema=None, **kw
):
return []
def normalize_names(self, name): # pylint: disable=unused-argument
return name
# pylint: disable=too-many-locals,protected-access
@reflection.cache
def get_schema_columns(self, connection, schema, **kw):
"""Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return
None, as it is cacheable and is an unexpected return type for this function"""
ans = {}
current_database, _ = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(current_database, schema)
try:
schema_primary_keys = self._get_schema_primary_keys(
connection, full_schema_name, **kw
)
result = connection.execute(
text(SNOWFLAKE_GET_SCHEMA_COLUMNS),
{"table_schema": self.denormalize_name(schema)},
)
except sa_exc.ProgrammingError as p_err:
if p_err.orig.errno == 90030:
# This means that there are too many tables in the schema; we need to go more granular
return None # None triggers _get_table_columns while staying cacheable
raise
for (
table_name,
column_name,
coltype,
character_maximum_length,
numeric_precision,
numeric_scale,
is_nullable,
column_default,
is_identity,
comment,
identity_start,
identity_increment,
) in result:
table_name = self.normalize_name(table_name)
column_name = self.normalize_name(column_name)
if table_name not in ans:
ans[table_name] = []
if column_name.startswith("sys_clustering_column"):
continue # ignoring clustering column
col_type = self.ischema_names.get(coltype, None)
col_type_kw = {}
if col_type is None:
sa_util.warn(
f"Did not recognize type '{coltype}' of column '{column_name}'"
)
col_type = sqltypes.NULLTYPE
else:
if issubclass(col_type, FLOAT):
col_type_kw["precision"] = numeric_precision
col_type_kw["decimal_return_scale"] = numeric_scale
elif issubclass(col_type, sqltypes.Numeric):
col_type_kw["precision"] = numeric_precision
col_type_kw["scale"] = numeric_scale
elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
col_type_kw["length"] = character_maximum_length
type_instance = col_type(**col_type_kw)
current_table_pks = schema_primary_keys.get(table_name)
ans[table_name].append(
{
"name": column_name,
"type": type_instance,
"nullable": is_nullable == "YES",
"default": column_default,
"autoincrement": is_identity == "YES",
"system_data_type": get_display_datatype(
coltype,
char_len=character_maximum_length,
precision=numeric_precision,
scale=numeric_scale,
),
"comment": comment,
"primary_key": (
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False,
}
)
if is_identity == "YES":
ans[table_name][-1]["identity"] = {
"start": identity_start,
"increment": identity_increment,
}
return ans
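# Result shape (illustrative values only):
#   {
#       "customers": [
#           {"name": "id", "type": DECIMAL(precision=38, scale=0), "nullable": False,
#            "default": None, "autoincrement": False,
#            "system_data_type": "DECIMAL(38,0)", "comment": None, "primary_key": True},
#           ...,
#       ],
#   }
# A None return is deliberate: it stays cacheable and signals the dialect to fall back to
# per-table column reflection (see the 90030 handling above).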

View File

@ -73,7 +73,7 @@ class SqlColumnHandlerMixin:
)
if col_type == "ARRAY":
if arr_data_type is None:
arr_data_type = DataType.VARCHAR.value
arr_data_type = DataType.UNKNOWN.value
data_type_display = f"array<{arr_data_type}>"
return data_type_display
@ -81,11 +81,13 @@ class SqlColumnHandlerMixin:
data_type_display = None
arr_data_type = None
parsed_string = None
if "raw_data_type" in column and column["raw_data_type"] is not None:
column["raw_data_type"] = self.parse_raw_data_type(column["raw_data_type"])
if not column["raw_data_type"].startswith(schema):
if column.get("system_data_type") and column.get("is_complex"):
column["system_data_type"] = self.clean_raw_data_type(
column["system_data_type"]
)
if not column["system_data_type"].startswith(schema):
parsed_string = ColumnTypeParser._parse_datatype_string( # pylint: disable=protected-access
column["raw_data_type"]
column["system_data_type"]
)
parsed_string["name"] = column["name"]
else:
@ -100,7 +102,7 @@ class SqlColumnHandlerMixin:
arr_data_type = ColumnTypeParser.get_column_type(arr_data_type[0])
data_type_display = column["type"]
if col_type == DataType.ARRAY.value and not arr_data_type:
arr_data_type = DataType.VARCHAR.value
arr_data_type = DataType.UNKNOWN.value
data_type_display = data_type_display or column.get("display_type")
return data_type_display, arr_data_type, parsed_string
@ -173,7 +175,7 @@ class SqlColumnHandlerMixin:
parsed_string["dataType"], column["type"]
)
parsed_string["description"] = column.get("comment")
if column["raw_data_type"] == "array":
if column["system_data_type"] == "array":
array_data_type_display = (
repr(column["type"])
.replace("(", "<")
@ -253,10 +255,10 @@ class SqlColumnHandlerMixin:
col_type, column["type"]
)
if col_type is None:
col_type = DataType.VARCHAR.name
col_type = DataType.UNKNOWN.name
data_type_display = col_type.lower()
logger.warning(
f"Unknown type {repr(column['type'])} mapped to VARCHAR: {column['name']}"
f"Unknown type {repr(column['type'])}: {column['name']}"
)
data_type_display = self._get_display_datatype(
data_type_display,
@ -271,9 +273,11 @@ class SqlColumnHandlerMixin:
# Passing whitespace if column name is an empty string
# since pydantic doesn't accept empty string
if column["name"] else " ",
description=column.get("comment", None),
description=column.get("comment"),
dataType=col_type,
dataTypeDisplay=data_type_display,
dataTypeDisplay=column.get(
"system_data_type", data_type_display
),
dataLength=col_data_length,
constraint=col_constraint,
children=children,
@ -331,5 +335,5 @@ class SqlColumnHandlerMixin:
constraint = Constraint.UNIQUE
return constraint
def parse_raw_data_type(self, raw_data_type):
def clean_raw_data_type(self, raw_data_type):
return raw_data_type

View File

@ -114,13 +114,16 @@ def _get_columns(
"type": col_type,
"nullable": True,
"comment": record.Comment,
"system_data_type": record.Type,
}
type_str = record.Type.strip().lower()
type_name, type_opts = get_type_name_and_opts(type_str)
if type_opts and type_name == ROW_DATA_TYPE:
column["raw_data_type"] = parse_row_data_type(type_str)
column["system_data_type"] = parse_row_data_type(type_str)
column["is_complex"] = True
elif type_opts and type_name == ARRAY_DATA_TYPE:
column["raw_data_type"] = parse_array_data_type(type_str)
column["system_data_type"] = parse_array_data_type(type_str)
column["is_complex"] = True
columns.append(column)
return columns
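# Illustrative flow, assuming Trino/Presto-style type strings:
#   record.Type = "row(name varchar, age integer)"  ->  type_name == ROW_DATA_TYPE, so
#   system_data_type is rewritten by parse_row_data_type (presumably into a struct<...>-style
#   string the column-type parser understands) and is_complex is set to True, letting the
#   column handler expand the column into children downstream.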

View File

@ -19,8 +19,7 @@ from typing import Iterable, Optional
from sqlalchemy import sql, util
from sqlalchemy.engine import reflection
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.sqltypes import String
from sqlalchemy_vertica.base import VerticaDialect
from sqlalchemy_vertica.base import VerticaDialect, ischema_names
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.services.connections.database.verticaConnection import (
@ -33,6 +32,7 @@ from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.vertica.queries import (
VERTICA_GET_COLUMNS,
@ -53,12 +53,12 @@ from metadata.utils.sqlalchemy_utils import (
logger = ingestion_logger()
class UUID(String):
"""The SQL UUID type."""
__visit_name__ = "UUID"
ischema_names.update(
{
"UUID": create_sqlalchemy_type("UUID"),
"GEOGRAPHY": create_sqlalchemy_type("GEOGRAPHY"),
}
)
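# Note (sketch of intent): create_sqlalchemy_type presumably builds a minimal SQLAlchemy
# type class with the given __visit_name__, so reflection can resolve Vertica's UUID and
# GEOGRAPHY columns instead of warning and falling back to NULLTYPE in the coltype lookup below.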
@reflection.cache
@ -142,7 +142,7 @@ def _get_column_info( # pylint: disable=too-many-locals,too-many-branches,too-m
args = (int(prec), int(scale))
else:
args = ()
elif attype == "integer":
elif attype == "integer" or attype.startswith("geography"):
args = ()
elif attype in ("timestamptz", "timetz"):
kwargs["timezone"] = True
@ -167,14 +167,13 @@ def _get_column_info( # pylint: disable=too-many-locals,too-many-branches,too-m
args = ()
elif charlen:
args = (int(charlen),)
self.ischema_names["UUID"] = UUID
if attype.upper() in self.ischema_names:
coltype = self.ischema_names[attype.upper()]
else:
coltype = None
if coltype:
coltype = coltype(*args, **kwargs)
coltype = coltype(*args, **kwargs) if callable(coltype) else coltype
else:
util.warn(f"Did not recognize type '{attype}' of column '{name}'")
coltype = sqltypes.NULLTYPE
@ -206,6 +205,7 @@ def _get_column_info( # pylint: disable=too-many-locals,too-many-branches,too-m
"name": name,
"type": coltype,
"nullable": nullable,
"system_data_type": format_type,
"default": default,
"autoincrement": autoincrement,
"comment": comment,

View File

@ -13,7 +13,7 @@
Module for sqlalchemy dialect utils
"""
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
from sqlalchemy.engine import Engine, reflection
@ -66,3 +66,26 @@ def get_schema_descriptions(engine: Engine, query: str):
for row in results:
schema_desc_map[row.schema_name] = row.comment
return schema_desc_map
def is_complex_type(col_type: str):
return (
col_type.lower().startswith("array")
or col_type.lower().startswith("map")
or col_type.lower().startswith("struct")
or col_type.lower().startswith("row")
)
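# Quick examples (illustrative):
#   is_complex_type("array<string>")           -> True
#   is_complex_type("STRUCT<a:int,b:string>")  -> True
#   is_complex_type("row(name varchar)")       -> True
#   is_complex_type("varchar(50)")             -> False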
def get_display_datatype(
col_type: str,
char_len: Optional[int],
precision: Optional[int],
scale: Optional[int],
):
if char_len or (precision is not None and scale is None):
length = char_len or precision
return f"{col_type}({str(length)})"
if scale is not None and precision is not None:
return f"{col_type}({str(precision)},{str(scale)})"
return col_type
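# Expected formatting (illustrative):
#   get_display_datatype("VARCHAR", 50, None, None)  -> "VARCHAR(50)"
#   get_display_datatype("NUMBER", None, 38, 0)      -> "NUMBER(38,0)"
#   get_display_datatype("NUMBER", None, 12, None)   -> "NUMBER(12)"
#   get_display_datatype("FLOAT", None, None, None)  -> "FLOAT"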

View File

@ -124,8 +124,8 @@
"dataType": "STRUCT"
},
{
"dataType": "VARCHAR",
"dataTypeDisplay": "VARCHAR",
"dataType": "UNKNOWN",
"dataTypeDisplay": "UNKNOWN",
"dataLength": 1
},
{

View File

@ -73,8 +73,8 @@ EXPTECTED_COLUMN_TYPE = [
"SMALLINT",
"LONGBLOB",
"JSON",
"POINT",
"VARCHAR",
"GEOMETRY",
"UNKNOWN",
]
root = os.path.dirname(__file__)

View File

@ -123,19 +123,19 @@ EXPECTED_DATA_MODELS = [
columns=[
Column(
name="customer_id",
dataType="VARCHAR",
dataType="UNKNOWN",
dataLength=1,
description="This is a unique identifier for a customer",
),
Column(
name="first_name",
dataType="VARCHAR",
dataType="UNKNOWN",
dataLength=1,
description="Customer's first name. PII.",
),
Column(
name="last_name",
dataType="VARCHAR",
dataType="UNKNOWN",
dataLength=1,
description="Customer's last name. PII.",
),
@ -174,7 +174,7 @@ EXPECTED_DATA_MODEL_NULL_DB = [
Column(
name="customer_id",
displayName=None,
dataType="VARCHAR",
dataType="UNKNOWN",
dataLength=1,
description="This is a unique identifier for an customer",
)

View File

@ -99,6 +99,7 @@ MOCK_COLUMN_VALUE = [
"nullable": True,
"default": None,
"autoincrement": False,
"system_data_type": "varchar(50)",
"comment": None,
},
{
@ -107,6 +108,7 @@ MOCK_COLUMN_VALUE = [
"nullable": True,
"default": None,
"autoincrement": False,
"system_data_type": "geometry",
"comment": None,
},
{
@ -115,6 +117,7 @@ MOCK_COLUMN_VALUE = [
"nullable": True,
"default": None,
"autoincrement": False,
"system_data_type": "point",
"comment": None,
},
{
@ -124,6 +127,7 @@ MOCK_COLUMN_VALUE = [
"default": None,
"autoincrement": False,
"comment": None,
"system_data_type": "polygon",
},
]
@ -137,7 +141,7 @@ EXPECTED_COLUMN_VALUE = [
dataLength=1,
precision=None,
scale=None,
dataTypeDisplay="VARCHAR(1)",
dataTypeDisplay="varchar(50)",
description=None,
fullyQualifiedName=None,
tags=None,
@ -156,7 +160,7 @@ EXPECTED_COLUMN_VALUE = [
dataLength=1,
precision=None,
scale=None,
dataTypeDisplay="GEOMETRY",
dataTypeDisplay="geometry",
description=None,
fullyQualifiedName=None,
tags=None,
@ -170,12 +174,12 @@ EXPECTED_COLUMN_VALUE = [
Column(
name="point_c",
displayName=None,
dataType=DataType.POINT,
dataType=DataType.GEOMETRY,
arrayDataType=None,
dataLength=1,
precision=None,
scale=None,
dataTypeDisplay="POINT",
dataTypeDisplay="point",
description=None,
fullyQualifiedName=None,
tags=None,
@ -189,12 +193,12 @@ EXPECTED_COLUMN_VALUE = [
Column(
name="polygon_c",
displayName=None,
dataType=DataType.POLYGON,
dataType=DataType.GEOMETRY,
arrayDataType=None,
dataLength=1,
precision=None,
scale=None,
dataTypeDisplay="POLYGON",
dataTypeDisplay="polygon",
description=None,
fullyQualifiedName=None,
tags=None,

View File

@ -117,7 +117,7 @@ EXPECTED_COLUMN_VALUE = [
dataLength=32000,
precision=None,
scale=None,
dataTypeDisplay=None,
dataTypeDisplay="textarea",
description="Contact Description",
fullyQualifiedName=None,
tags=None,
@ -136,7 +136,7 @@ EXPECTED_COLUMN_VALUE = [
dataLength=18,
precision=None,
scale=None,
dataTypeDisplay=None,
dataTypeDisplay="reference",
description="Owner ID",
fullyQualifiedName=None,
tags=None,
@ -150,12 +150,12 @@ EXPECTED_COLUMN_VALUE = [
Column(
name=ColumnName(__root__="Phone"),
displayName=None,
dataType=DataType.INT,
dataType=DataType.VARCHAR,
arrayDataType=None,
dataLength=0,
precision=None,
scale=None,
dataTypeDisplay=None,
dataTypeDisplay="phone",
description="Phone",
fullyQualifiedName=None,
tags=None,
@ -169,12 +169,12 @@ EXPECTED_COLUMN_VALUE = [
Column(
name=ColumnName(__root__="CreatedById"),
displayName=None,
dataType=DataType.VARCHAR,
dataType=DataType.UNKNOWN,
arrayDataType=None,
dataLength=18,
precision=None,
scale=None,
dataTypeDisplay=None,
dataTypeDisplay="anytype",
description="Created By ID",
fullyQualifiedName=None,
tags=None,
@ -420,7 +420,7 @@ SALESFORCE_FIELDS = [
("searchPrefilterable", False),
("soapType", "tns:ID"),
("sortable", True),
("type", "reference"),
("type", "anytype"),
("unique", False),
("updateable", False),
("writeRequiresMasterRead", False),
@ -429,7 +429,7 @@ SALESFORCE_FIELDS = [
]
EXPECTED_COLUMN_TYPE = ["VARCHAR", "VARCHAR", "INT", "VARCHAR"]
EXPECTED_COLUMN_TYPE = ["VARCHAR", "VARCHAR", "VARCHAR", "UNKNOWN"]
class SalesforceUnitTest(TestCase):

View File

@ -1764,8 +1764,8 @@ public abstract class EntityRepository<T extends EntityInterface> {
}
private void updateColumnDisplayName(Column origColumn, Column updatedColumn) throws JsonProcessingException {
if (operation.isPut() && !nullOrEmpty(origColumn.getDescription()) && updatedByBot()) {
// Revert the non-empty task description if being updated by a bot
if (operation.isPut() && !nullOrEmpty(origColumn.getDisplayName()) && updatedByBot()) {
// Revert the non-empty task display name if being updated by a bot
updatedColumn.setDisplayName(origColumn.getDisplayName());
return;
}

View File

@ -116,14 +116,30 @@
"UUID",
"VARIANT",
"GEOMETRY",
"POINT",
"POLYGON",
"BYTEA",
"AGGREGATEFUNCTION",
"ERROR",
"FIXED",
"RECORD",
"NULL"
"NULL",
"SUPER",
"HLLSKETCH",
"PG_LSN",
"PG_SNAPSHOT",
"TSQUERY",
"TXID_SNAPSHOT",
"XML",
"MACADDR",
"TSVECTOR",
"UNKNOWN",
"CIDR",
"INET",
"CLOB",
"ROWID",
"LOWCARDINALITY",
"YEAR",
"POINT",
"POLYGON"
]
},
"constraint": {