Reflection Cache Implementation (#2016)

* Reflection Cache for Bigquery and Redshift

* Overrode a few sqlalchemy packages

* Added Geography Support

* Reformatted files

* Implemented error handling for DBT models

* Geography type added as a custom sqlalchemy datatype

* GEOGRAPHY and VARIANT added as custom sql types

* Implemented file formatting using black

Authored by Ayush Shah on 2022-01-11 14:58:03 +05:30; committed via GitHub.
parent cf6f438531
commit f379b35279
9 changed files with 430 additions and 111 deletions
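The "reflection cache" in the title refers to SQLAlchemy's reflection.cache decorator, which the patched Redshift dialect methods below use so that repeated lookups within one inspection pass reuse a single catalog query instead of re-running it per table. A minimal, self-contained sketch of the mechanism (not code from this commit; DemoDialect is a made-up class for illustration):

from sqlalchemy.engine import reflection


class DemoDialect:
    calls = 0

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        # Pretend this runs an expensive catalog query.
        DemoDialect.calls += 1
        return ["orders", "users"]


dialect = DemoDialect()
info_cache = {}  # SQLAlchemy's Inspector shares one such dict per inspection
dialect.get_table_names(None, schema="public", info_cache=info_cache)
dialect.get_table_names(None, schema="public", info_cache=info_cache)
print(DemoDialect.calls)  # 1 -- the second call is answered from info_cache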

View File

@@ -61,7 +61,7 @@ base_plugins = {
 plugins: Dict[str, Set[str]] = {
     "amundsen": {"neo4j~=4.4.0"},
     "athena": {"PyAthena[SQLAlchemy]"},
-    "bigquery": {"openmetadata-sqlalchemy-bigquery==0.2.2"},
+    "bigquery": {"sqlalchemy-bigquery==1.2.2"},
     "bigquery-usage": {"google-cloud-logging", "cachetools"},
     "docker": {"docker==5.0.3"},
     "dbt": {},
@@ -69,10 +69,11 @@ plugins: Dict[str, Set[str]] = {
     "elasticsearch": {"elasticsearch~=7.13.1"},
     "glue": {"boto3~=1.19.12"},
     "hive": {
-        "openmetadata-sqlalchemy-hive==0.2.0",
+        "pyhive~=0.6.3",
         "thrift~=0.13.0",
         "sasl==0.3.1",
         "thrift-sasl==0.4.3",
+        "presto-types-parser==0.0.2"
     },
     "kafka": {"confluent_kafka>=1.5.0", "fastavro>=1.2.0"},
     "ldap-users": {"ldap3==2.9.1"},
@@ -86,12 +87,12 @@ plugins: Dict[str, Set[str]] = {
     "postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"},
     "redash": {"redash-toolbelt==0.1.4"},
     "redshift": {
-        "openmetadata-sqlalchemy-redshift==0.2.1",
+        "sqlalchemy-redshift==0.8.9",
         "psycopg2-binary",
         "GeoAlchemy2",
     },
     "redshift-usage": {
-        "openmetadata-sqlalchemy-redshift==0.2.1",
+        "sqlalchemy-redshift==0.8.9",
         "psycopg2-binary",
         "GeoAlchemy2",
     },

View File

@@ -14,6 +14,40 @@ from typing import Optional, Tuple
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
+from metadata.utils.column_helpers import create_sqlalchemy_type
+from sqlalchemy_bigquery import _types
+from sqlalchemy_bigquery._struct import STRUCT
+from sqlalchemy_bigquery._types import (
+    _get_sqla_column_type,
+    _get_transitive_schema_fields,
+)
+
+GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
+_types._type_map["GEOGRAPHY"] = GEOGRAPHY
+
+
+def get_columns(bq_schema):
+    fields = _get_transitive_schema_fields(bq_schema)
+    col_list = []
+    for field in fields:
+        col_obj = {
+            "name": field.name,
+            "type": _get_sqla_column_type(field)
+            if "STRUCT" or "RECORD" not in field
+            else STRUCT,
+            "nullable": field.mode == "NULLABLE" or field.mode == "REPEATED",
+            "comment": field.description,
+            "default": None,
+            "precision": field.precision,
+            "scale": field.scale,
+            "max_length": field.max_length,
+            "raw_data_type": repr(_get_sqla_column_type(field)),
+        }
+        col_list.append(col_obj)
+    return col_list
+
+
+_types.get_columns = get_columns
+
+
 class BigQueryConfig(SQLConnectionConfig, SQLSource):
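In the patched get_columns above, a column is treated as nullable when its BigQuery mode is NULLABLE or REPEATED. A tiny standalone illustration of that rule (not part of the commit; assumes google-cloud-bigquery is installed):

from google.cloud.bigquery import SchemaField

fields = [
    SchemaField("id", "INTEGER", mode="REQUIRED"),
    SchemaField("name", "STRING", mode="NULLABLE"),
    SchemaField("tags", "STRING", mode="REPEATED"),  # arrays are REPEATED fields
]
for field in fields:
    nullable = field.mode == "NULLABLE" or field.mode == "REPEATED"
    print(field.name, field.mode, nullable)
# id REQUIRED False
# name NULLABLE True
# tags REPEATED True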

View File

@@ -97,6 +97,9 @@ class BigqueryUsageSource(Source[TableQuery]):
                     jobStats["startTime"][0:19], "%Y-%m-%dT%H:%M:%S"
                 ).strftime("%Y-%m-%d %H:%M:%S")
             )
+            logger.debug(
+                f"Query :{statementType}:{queryConfig['query']}"
+            )
             tq = TableQuery(
                 query=statementType,
                 user_name=entry.resource.labels["project_id"],

View File

@@ -9,18 +9,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from typing import Optional
 
-from pyhive import hive  # noqa: F401
-from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp
-
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
-from metadata.utils.column_helpers import register_custom_type
+from pyhive.sqlalchemy_hive import HiveDialect, _type_map
+from sqlalchemy import types, util
 
-register_custom_type(HiveDate, "DATE")
-register_custom_type(HiveTimestamp, "TIME")
-register_custom_type(HiveDecimal, "NUMBER")
+complex_data_types = ["struct", "map", "array", "union"]
+
+
+def get_columns(self, connection, table_name, schema=None, **kw):
+    rows = self._get_table_columns(connection, table_name, schema)
+    rows = [[col.strip() if col else None for col in row] for row in rows]
+    rows = [row for row in rows if row[0] and row[0] != "# col_name"]
+    result = []
+    for (col_name, col_type, _comment) in rows:
+        if col_name == "# Partition Information":
+            break
+        col_raw_type = col_type
+        col_type = re.search(r"^\w+", col_type).group(0)
+        try:
+            coltype = _type_map[col_type]
+        except KeyError:
+            util.warn(
+                "Did not recognize type '%s' of column '%s'" % (col_type, col_name)
+            )
+            coltype = types.NullType
+        result.append(
+            {
+                "name": col_name,
+                "type": coltype,
+                "nullable": True,
+                "default": None,
+                "raw_data_type": col_raw_type
+                if col_type in complex_data_types
+                else None,
+            }
+        )
+    return result
+
+
+HiveDialect.get_columns = get_columns
+
+
 class HiveConfig(SQLConnectionConfig):
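The patched Hive get_columns keeps the full raw type string only for complex types, and trims everything else down to the leading type word. A small standalone check of the regex and the complex_data_types gate used above (not from the commit):

import re

complex_data_types = ["struct", "map", "array", "union"]

for hive_type in ["int", "map<string,int>", "array<struct<a:int>>"]:
    base = re.search(r"^\w+", hive_type).group(0)             # leading word only
    raw = hive_type if base in complex_data_types else None   # keep full spec for complex types
    print(base, raw)
# int None
# map map<string,int>
# array array<struct<a:int>>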

View File

@@ -10,8 +10,20 @@
 # limitations under the License.
 
 import logging
+import re
+from collections import defaultdict
 from typing import Optional
 
+import sqlalchemy as sa
+from packaging.version import Version
+
+sa_version = Version(sa.__version__)
+
+from sqlalchemy import inspect
+from sqlalchemy.engine import reflection
+from sqlalchemy.types import CHAR, VARCHAR, NullType
+from sqlalchemy_redshift.dialect import RedshiftDialectMixin, RelationKey
+
 from metadata.ingestion.api.source import SourceStatus
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
@@ -19,6 +31,212 @@ from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
 
 logger = logging.getLogger(__name__)
 
+
+@reflection.cache
+def get_table_names(self, connection, schema=None, **kw):
+    return self._get_table_or_view_names(["r", "e"], connection, schema, **kw)
+
+
+@reflection.cache
+def get_view_names(self, connection, schema=None, **kw):
+    return self._get_table_or_view_names(["v"], connection, schema, **kw)
+
+
+@reflection.cache
+def _get_table_or_view_names(self, relkinds, connection, schema=None, **kw):
+    default_schema = inspect(connection).default_schema_name
+    if not schema:
+        schema = default_schema
+    info_cache = kw.get("info_cache")
+    all_relations = self._get_all_relation_info(connection, info_cache=info_cache)
+    relation_names = []
+    for key, relation in all_relations.items():
+        if key.schema == schema and relation.relkind in relkinds:
+            relation_names.append(key.name)
+    return relation_names
+
+
+def _get_column_info(self, *args, **kwargs):
+    kw = kwargs.copy()
+    encode = kw.pop("encode", None)
+    if sa_version >= Version("1.3.16"):
+        kw["generated"] = ""
+    if sa_version < Version("1.4.0") and "identity" in kw:
+        del kw["identity"]
+    elif sa_version >= Version("1.4.0") and "identity" not in kw:
+        kw["identity"] = None
+    column_info = super(RedshiftDialectMixin, self)._get_column_info(*args, **kw)
+    column_info["raw_data_type"] = kw["format_type"]
+    if isinstance(column_info["type"], VARCHAR):
+        if column_info["type"].length is None:
+            column_info["type"] = NullType()
+    if re.match("char", column_info["raw_data_type"]):
+        column_info["type"] = CHAR
+    if "info" not in column_info:
+        column_info["info"] = {}
+    if encode and encode != "none":
+        column_info["info"]["encode"] = encode
+    return column_info
+
+
+@reflection.cache
+def _get_all_relation_info(self, connection, **kw):
+    result = connection.execute(
+        """
+        SELECT
+            c.relkind,
+            n.oid as "schema_oid",
+            n.nspname as "schema",
+            c.oid as "rel_oid",
+            c.relname,
+            CASE c.reldiststyle
+                WHEN 0 THEN 'EVEN' WHEN 1 THEN 'KEY' WHEN 8 THEN 'ALL' END
+                AS "diststyle",
+            c.relowner AS "owner_id",
+            u.usename AS "owner_name",
+            TRIM(TRAILING ';' FROM pg_catalog.pg_get_viewdef(c.oid, true))
+                AS "view_definition",
+            pg_catalog.array_to_string(c.relacl, '\n') AS "privileges"
+        FROM pg_catalog.pg_class c
+        LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+        JOIN pg_catalog.pg_user u ON u.usesysid = c.relowner
+        WHERE c.relkind IN ('r', 'v', 'm', 'S', 'f')
+            AND n.nspname !~ '^pg_'
+        ORDER BY c.relkind, n.oid, n.nspname;
+        """
+    )
+    relations = {}
+    for rel in result:
+        key = RelationKey(rel.relname, rel.schema, connection)
+        relations[key] = rel
+
+    result = connection.execute(
+        """
+        SELECT
+            schemaname as "schema",
+            tablename as "relname",
+            'e' as relkind
+        FROM svv_external_tables;
+        """
+    )
+    for rel in result:
+        key = RelationKey(rel.relname, rel.schema, connection)
+        relations[key] = rel
+    return relations
+
+
+@reflection.cache
+def _get_schema_column_info(self, connection, schema=None, **kw):
+    schema_clause = "AND schema = '{schema}'".format(schema=schema) if schema else ""
+    all_columns = defaultdict(list)
+    with connection.connect() as cc:
+        result = cc.execute(
+            """
+            SELECT
+                n.nspname as "schema",
+                c.relname as "table_name",
+                att.attname as "name",
+                format_encoding(att.attencodingtype::integer) as "encode",
+                format_type(att.atttypid, att.atttypmod) as "type",
+                att.attisdistkey as "distkey",
+                att.attsortkeyord as "sortkey",
+                att.attnotnull as "notnull",
+                pg_catalog.col_description(att.attrelid, att.attnum)
+                    as "comment",
+                adsrc,
+                attnum,
+                pg_catalog.format_type(att.atttypid, att.atttypmod),
+                pg_catalog.pg_get_expr(ad.adbin, ad.adrelid) AS DEFAULT,
+                n.oid as "schema_oid",
+                c.oid as "table_oid"
+            FROM pg_catalog.pg_class c
+            LEFT JOIN pg_catalog.pg_namespace n
+                ON n.oid = c.relnamespace
+            JOIN pg_catalog.pg_attribute att
+                ON att.attrelid = c.oid
+            LEFT JOIN pg_catalog.pg_attrdef ad
+                ON (att.attrelid, att.attnum) = (ad.adrelid, ad.adnum)
+            WHERE n.nspname !~ '^pg_'
+                AND att.attnum > 0
+                AND NOT att.attisdropped
+                {schema_clause}
+            UNION
+            SELECT
+                view_schema as "schema",
+                view_name as "table_name",
+                col_name as "name",
+                null as "encode",
+                col_type as "type",
+                null as "distkey",
+                0 as "sortkey",
+                null as "notnull",
+                null as "comment",
+                null as "adsrc",
+                null as "attnum",
+                col_type as "format_type",
+                null as "default",
+                null as "schema_oid",
+                null as "table_oid"
+            FROM pg_get_late_binding_view_cols() cols(
+                view_schema name,
+                view_name name,
+                col_name name,
+                col_type varchar,
+                col_num int)
+            WHERE 1 {schema_clause}
+            UNION
+            SELECT schemaname AS "schema",
+                tablename AS "table_name",
+                columnname AS "name",
+                null AS "encode",
+                -- Spectrum represents data types differently.
+                -- Standardize, so we can infer types.
+                CASE
+                    WHEN external_type = 'int' THEN 'integer'
+                    ELSE
+                        replace(
+                            replace(external_type, 'decimal', 'numeric'),
+                            'varchar', 'character varying')
+                END
+                    AS "type",
+                null AS "distkey",
+                0 AS "sortkey",
+                null AS "notnull",
+                null AS "comment",
+                null AS "adsrc",
+                null AS "attnum",
+                CASE
+                    WHEN external_type = 'int' THEN 'integer'
+                    ELSE
+                        replace(
+                            replace(external_type, 'decimal', 'numeric'),
+                            'varchar', 'character varying')
+                END
+                    AS "format_type",
+                null AS "default",
+                null AS "schema_oid",
+                null AS "table_oid"
+            FROM svv_external_columns
+            ORDER BY "schema", "table_name", "attnum";
+            """.format(
+                schema_clause=schema_clause
+            )
+        )
+        for col in result:
+            key = RelationKey(col.table_name, col.schema, connection)
+            all_columns[key].append(col)
+    return dict(all_columns)
+
+
+RedshiftDialectMixin._get_table_or_view_names = _get_table_or_view_names
+RedshiftDialectMixin.get_view_names = get_view_names
+RedshiftDialectMixin.get_table_names = get_table_names
+RedshiftDialectMixin._get_column_info = _get_column_info
+RedshiftDialectMixin._get_all_relation_info = _get_all_relation_info
+RedshiftDialectMixin._get_schema_column_info = _get_schema_column_info
+
+
 class RedshiftConfig(SQLConnectionConfig):
     scheme = "redshift+psycopg2"
     where_clause: Optional[str] = None
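The kwargs juggling in the patched _get_column_info tracks changes in the signature of the underlying SQLAlchemy dialect method across versions. A standalone illustration of that version gate (not from the commit; assumes sqlalchemy and packaging are installed):

import sqlalchemy as sa
from packaging.version import Version

sa_version = Version(sa.__version__)
kw = {"format_type": "character varying(256)", "identity": None}
if sa_version >= Version("1.3.16"):
    kw["generated"] = ""      # newer releases expect a 'generated' argument
if sa_version < Version("1.4.0") and "identity" in kw:
    del kw["identity"]        # pre-1.4 releases have no 'identity' argument
print(sa_version, sorted(kw))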

View File

@@ -11,15 +11,15 @@
 from typing import Optional
 
-from snowflake.sqlalchemy import custom_types
-
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
-from metadata.utils.column_helpers import register_custom_type
+from metadata.utils.column_helpers import create_sqlalchemy_type
+from snowflake.sqlalchemy.custom_types import VARIANT
+from snowflake.sqlalchemy.snowdialect import ischema_names
 
-register_custom_type(custom_types.TIMESTAMP_TZ, "TIME")
-register_custom_type(custom_types.TIMESTAMP_LTZ, "TIME")
-register_custom_type(custom_types.TIMESTAMP_NTZ, "TIME")
+GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
+ischema_names["VARIANT"] = VARIANT
+ischema_names["GEOGRAPHY"] = GEOGRAPHY
 
 
 class SnowflakeConfig(SQLConnectionConfig):

View File

@@ -22,11 +22,6 @@ from datetime import datetime
 from typing import Dict, Iterable, List, Optional, Tuple
 from urllib.parse import quote_plus
 
-from pydantic import SecretStr
-from sqlalchemy import create_engine
-from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.inspection import inspect
-
 from metadata.generated.schema.entity.data.database import Database
 from metadata.generated.schema.entity.data.table import (
     Column,
@@ -54,6 +49,10 @@ from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.utils.column_helpers import check_column_complex_type, get_column_type
 from metadata.utils.helpers import get_database_service_or_create
+from pydantic import SecretStr
+from sqlalchemy import create_engine
+from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.inspection import inspect
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -271,7 +270,8 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
 
     def next_record(self) -> Iterable[Entity]:
         inspector = inspect(self.engine)
-        for schema in inspector.get_schema_names():
+        schema_names = inspector.get_schema_names()
+        for schema in schema_names:
             # clear any previous source database state
             self.database_source_state.clear()
             if not self.sql_config.schema_filter_pattern.included(schema):
@@ -292,7 +292,8 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
         Scrape an SQL schema and prepare Database and Table
         OpenMetadata Entities
         """
-        for table_name in inspector.get_table_names(schema):
+        tables = inspector.get_table_names(schema)
+        for table_name in tables:
             try:
                 schema, table_name = self.standardize_schema_table_names(
                     schema, table_name
@@ -303,8 +304,6 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                         "Table pattern not allowed",
                     )
                     continue
-                self.status.scanned(f"{self.config.get_service_name()}.{table_name}")
-
                 description = _get_table_description(schema, table_name, inspector)
                 fqn = f"{self.config.service_name}.{schema}.{table_name}"
                 self.database_source_state.add(fqn)
@@ -338,10 +337,15 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                     table=table_entity, database=self._get_database(schema)
                 )
                 yield table_and_db
-            # Catch any errors during the ingestion and continue
-            except Exception as err:  # pylint: disable=broad-except
+                self.status.scanned(
+                    "{}.{}".format(self.config.get_service_name(), table_name)
+                )
+            except Exception as err:
+                traceback.print_exc()
                 logger.error(err)
-                self.status.warnings.append(f"{self.config.service_name}.{table_name}")
+                self.status.failures.append(
+                    "{}.{}".format(self.config.service_name, table_name)
+                )
                 continue
 
     def fetch_views(
@@ -426,31 +430,36 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
         }
         for key, mnode in manifest_entities.items():
-            name = mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
-            cnode = catalog_entities.get(key)
-            columns = (
-                self._parse_data_model_columns(name, mnode, cnode) if cnode else []
-            )
-            if mnode["resource_type"] == "test":
-                continue
-            upstream_nodes = self._parse_data_model_upstream(mnode)
-            model_name = (
-                mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
-            )
-            model_name = model_name.replace(".", "_DOT_")
-            schema = mnode["schema"]
-            raw_sql = mnode.get("raw_sql", "")
-            model = DataModel(
-                modelType=ModelType.DBT,
-                description=mnode.get("description", ""),
-                path=f"{mnode['root_path']}/{mnode['original_file_path']}",
-                rawSql=raw_sql,
-                sql=mnode.get("compiled_sql", raw_sql),
-                columns=columns,
-                upstream=upstream_nodes,
-            )
-            model_fqdn = f"{schema}.{model_name}"
+            try:
+                name = mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
+                cnode = catalog_entities.get(key)
+                columns = (
+                    self._parse_data_model_columns(name, mnode, cnode)
+                    if cnode
+                    else []
+                )
+                if mnode["resource_type"] == "test":
+                    continue
+                upstream_nodes = self._parse_data_model_upstream(mnode)
+                model_name = (
+                    mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
+                )
+                model_name = model_name.replace(".", "_DOT_")
+                schema = mnode["schema"]
+                raw_sql = mnode.get("raw_sql", "")
+                model = DataModel(
+                    modelType=ModelType.DBT,
+                    description=mnode.get("description", ""),
+                    path=f"{mnode['root_path']}/{mnode['original_file_path']}",
+                    rawSql=raw_sql,
+                    sql=mnode.get("compiled_sql", raw_sql),
+                    columns=columns,
+                    upstream=upstream_nodes,
+                )
+                model_fqdn = f"{schema}.{model_name}"
+            except Exception as err:
+                print(err)
             self.data_models[model_fqdn] = model
 
     def _parse_data_model_upstream(self, mnode):
@@ -507,7 +516,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
 
     def _get_database(self, schema: str) -> Database:
         return Database(
-            name=schema,
+            name=schema.replace(".", "_DOT_"),
             service=EntityReference(id=self.service.id, type=self.config.service_type),
         )
@@ -561,48 +570,61 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
         dataset_name = f"{schema}.{table}"
         table_columns = []
+        columns = inspector.get_columns(table, schema)
         try:
-            for row_order, column in enumerate(inspector.get_columns(table, schema)):
-                if "." in column["name"]:
-                    logger.info(f"Found '.' in {column['name']}")
-                    column["name"] = column["name"].replace(".", "_DOT_")
-                children = None
-                data_type_display = None
-                col_data_length = None
-                arr_data_type = None
-                if "raw_data_type" in column and column["raw_data_type"] is not None:
-                    (
-                        col_type,
-                        data_type_display,
-                        arr_data_type,
-                        children,
-                    ) = check_column_complex_type(
-                        self.status,
-                        dataset_name,
-                        column["raw_data_type"],
-                        column["name"],
-                    )
-                else:
-                    col_type = get_column_type(
-                        self.status, dataset_name, column["type"]
-                    )
-                    if col_type == "ARRAY" and re.match(
-                        r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"])
-                    ):
-                        arr_data_type = re.match(
-                            r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"])
-                        ).groups()
-                        data_type_display = column["type"]
-                col_constraint = self._get_column_constraints(
-                    column, pk_columns, unique_columns
-                )
-                if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
-                    col_data_length = column["type"].length
-                    if col_data_length is None:
-                        col_data_length = 1
+            for row_order, column in enumerate(columns):
                 try:
+                    if "." in column["name"]:
+                        logger.info(
+                            f"Found '.' in {column['name']}, changing '.' to '_DOT_'"
+                        )
+                        column["name"] = column["name"].replace(".", "_DOT_")
+                    children = None
+                    data_type_display = None
+                    col_data_length = None
+                    arr_data_type = None
+                    if (
+                        "raw_data_type" in column
+                        and column["raw_data_type"] is not None
+                    ):
+                        column["raw_data_type"] = self.parse_raw_data_type(
+                            column["raw_data_type"]
+                        )
+                        (
+                            col_type,
+                            data_type_display,
+                            arr_data_type,
+                            children,
+                        ) = check_column_complex_type(
+                            self.status,
+                            dataset_name,
+                            column["raw_data_type"],
+                            column["name"],
+                        )
+                    else:
+                        col_type = get_column_type(
+                            self.status, dataset_name, column["type"]
+                        )
+                        if col_type == "ARRAY" and re.match(
+                            r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"])
+                        ):
+                            arr_data_type = re.match(
+                                r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"])
+                            ).groups()
+                            data_type_display = column["type"]
+                    if repr(column["type"]).upper().startswith("ARRAY("):
+                        arr_data_type = "STRUCT"
+                        data_type_display = (
+                            repr(column["type"])
+                            .replace("(", "<")
+                            .replace(")", ">")
+                            .lower()
+                        )
+                    col_constraint = self._get_column_constraints(
+                        column, pk_columns, unique_columns
+                    )
+                    if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
+                        col_data_length = column["type"].length
                     if col_type == "NULL":
                         col_type = "VARCHAR"
                         data_type_display = "varchar"
@@ -613,23 +635,24 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                         name=column["name"],
                         description=column.get("comment", None),
                         dataType=col_type,
-                        dataTypeDisplay=f"{col_type}({col_data_length})"
+                        dataTypeDisplay="{}({})".format(
+                            col_type, 1 if col_data_length is None else col_data_length
+                        )
                         if data_type_display is None
                         else f"{data_type_display}",
-                        dataLength=col_data_length,
+                        dataLength=1 if col_data_length is None else col_data_length,
                         constraint=col_constraint,
-                        ordinalPosition=row_order + 1,  # enumerate starts at 0
-                        children=children,
+                        ordinalPosition=row_order,
+                        children=children if children is not None else None,
                         arrayDataType=arr_data_type,
                     )
-                except Exception as err:  # pylint: disable=broad-except
-                    logger.error(traceback.format_exc())
+                except Exception as err:
                     logger.error(traceback.print_exc())
                     logger.error(f"{err} : {column}")
                     continue
                 table_columns.append(om_column)
             return table_columns
-        except Exception as err:  # pylint: disable=broad-except
+        except Exception as err:
             logger.error(f"{repr(err)}: {table} {err}")
             return None
@@ -661,6 +684,9 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
             logger.debug(f"Finished profiling {dataset_name}")
         return profile
 
+    def parse_raw_data_type(self, raw_data_type):
+        return raw_data_type
+
     def _build_database_state(self, schema_fqdn: str) -> [EntityReference]:
         after = None
         tables = []
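Two of the less obvious pieces in the column-parsing changes above are the '.' sanitisation (OpenMetadata fully qualified names are dot-separated, so dots inside column names are rewritten) and the ARRAY regex that extracts the element type from a rendered SQLAlchemy type string. A standalone illustration of both (not part of the commit):

import re

name = "payload.user.id"
if "." in name:
    name = name.replace(".", "_DOT_")
print(name)  # payload_DOT_user_DOT_id

type_str = "ARRAY(INTEGER())"  # stand-in for str(column["type"])
match = re.match(r"(?:\w*)(?:[(]*)(\w*)(?:.*)", type_str)
print(match.groups())  # ('INTEGER',)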

View File

@@ -1,22 +1,26 @@
 import re
 from typing import Any, Dict, Optional, Set, Type
 
-from sqlalchemy.sql import sqltypes as types
-
 from metadata.ingestion.api.source import SourceStatus
+from sqlalchemy.sql import sqltypes as types
+from sqlalchemy.types import TypeEngine
 
 
-def register_custom_type(tp: Type[types.TypeEngine], output: str = None) -> None:
-    if output:
-        _column_type_mapping[tp] = output
-    else:
-        _known_unknown_column_types.add(tp)
-
-
 def register_custom_str_type(tp: str, output: str) -> None:
     _column_string_mapping[tp] = output
 
 
+def create_sqlalchemy_type(name: str):
+    sqlalchemy_type = type(
+        name,
+        (TypeEngine,),
+        {
+            "__repr__": lambda self: f"{name}()",
+        },
+    )
+    return sqlalchemy_type
+
+
 _column_type_mapping: Dict[Type[types.TypeEngine], str] = {
     types.Integer: "INT",
     types.Numeric: "INT",
@@ -123,6 +127,8 @@ _column_string_mapping = {
     "XML": "BINARY",
     "XMLTYPE": "BINARY",
     "CURSOR": "BINARY",
+    "TIMESTAMP_LTZ": "TIMESTAMP",
+    "TIMESTAMP_TZ": "TIMESTAMP",
 }
 
 _known_unknown_column_types: Set[Type[types.TypeEngine]] = {
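create_sqlalchemy_type returns a deliberately minimal artifact: a named TypeEngine subclass with a predictable repr, which is all the dialect type maps in the BigQuery and Snowflake sources need. A quick standalone check (a sketch, assuming SQLAlchemy is installed; the helper is re-declared here so the snippet runs on its own):

from sqlalchemy.types import TypeEngine

def create_sqlalchemy_type(name: str):
    # mirrors the helper added above
    return type(name, (TypeEngine,), {"__repr__": lambda self: f"{name}()"})

GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
print(GEOGRAPHY.__name__)                 # GEOGRAPHY
print(repr(GEOGRAPHY()))                  # GEOGRAPHY()
print(issubclass(GEOGRAPHY, TypeEngine))  # True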

View File

@@ -14,7 +14,6 @@ from openmetadata.common.database_common import (
     DatabaseCommon,
     SQLConnectionConfig,
     SQLExpressions,
-    register_custom_type,
 )
 
 logger = logging.getLogger(__name__)