diff --git a/ingestion/setup.py b/ingestion/setup.py
index 150419b3a89..f523db7dc84 100644
--- a/ingestion/setup.py
+++ b/ingestion/setup.py
@@ -61,7 +61,7 @@ base_plugins = {
 plugins: Dict[str, Set[str]] = {
     "amundsen": {"neo4j~=4.4.0"},
     "athena": {"PyAthena[SQLAlchemy]"},
-    "bigquery": {"openmetadata-sqlalchemy-bigquery==0.2.2"},
+    "bigquery": {"sqlalchemy-bigquery==1.2.2"},
     "bigquery-usage": {"google-cloud-logging", "cachetools"},
     "docker": {"docker==5.0.3"},
     "dbt": {},
@@ -69,10 +69,11 @@ plugins: Dict[str, Set[str]] = {
     "elasticsearch": {"elasticsearch~=7.13.1"},
     "glue": {"boto3~=1.19.12"},
     "hive": {
-        "openmetadata-sqlalchemy-hive==0.2.0",
+        "pyhive~=0.6.3",
         "thrift~=0.13.0",
         "sasl==0.3.1",
         "thrift-sasl==0.4.3",
+        "presto-types-parser==0.0.2"
     },
     "kafka": {"confluent_kafka>=1.5.0", "fastavro>=1.2.0"},
     "ldap-users": {"ldap3==2.9.1"},
@@ -86,12 +87,12 @@ plugins: Dict[str, Set[str]] = {
     "postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"},
     "redash": {"redash-toolbelt==0.1.4"},
     "redshift": {
-        "openmetadata-sqlalchemy-redshift==0.2.1",
+        "sqlalchemy-redshift==0.8.9",
         "psycopg2-binary",
         "GeoAlchemy2",
     },
     "redshift-usage": {
-        "openmetadata-sqlalchemy-redshift==0.2.1",
+        "sqlalchemy-redshift==0.8.9",
         "psycopg2-binary",
         "GeoAlchemy2",
     },
diff --git a/ingestion/src/metadata/ingestion/source/bigquery.py b/ingestion/src/metadata/ingestion/source/bigquery.py
index adaf7b33378..047f1221d15 100644
--- a/ingestion/src/metadata/ingestion/source/bigquery.py
+++ b/ingestion/src/metadata/ingestion/source/bigquery.py
@@ -14,6 +14,40 @@ from typing import Optional, Tuple

 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
+from metadata.utils.column_helpers import create_sqlalchemy_type
+from sqlalchemy_bigquery import _types
+from sqlalchemy_bigquery._struct import STRUCT
+from sqlalchemy_bigquery._types import (
+    _get_sqla_column_type,
+    _get_transitive_schema_fields,
+)
+
+GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
+_types._type_map["GEOGRAPHY"] = GEOGRAPHY
+
+
+def get_columns(bq_schema):
+    fields = _get_transitive_schema_fields(bq_schema)
+    col_list = []
+    for field in fields:
+        col_obj = {
+            "name": field.name,
+            "type": _get_sqla_column_type(field)
+            if "STRUCT" or "RECORD" not in field
+            else STRUCT,
+            "nullable": field.mode == "NULLABLE" or field.mode == "REPEATED",
+            "comment": field.description,
+            "default": None,
+            "precision": field.precision,
+            "scale": field.scale,
+            "max_length": field.max_length,
+            "raw_data_type": repr(_get_sqla_column_type(field)),
+        }
+        col_list.append(col_obj)
+    return col_list
+
+
+_types.get_columns = get_columns

 
 class BigQueryConfig(SQLConnectionConfig, SQLSource):
diff --git a/ingestion/src/metadata/ingestion/source/bigquery_usage.py b/ingestion/src/metadata/ingestion/source/bigquery_usage.py
index de9e9c853ba..71f359d9dbe 100644
--- a/ingestion/src/metadata/ingestion/source/bigquery_usage.py
+++ b/ingestion/src/metadata/ingestion/source/bigquery_usage.py
@@ -97,6 +97,9 @@ class BigqueryUsageSource(Source[TableQuery]):
                         jobStats["startTime"][0:19], "%Y-%m-%dT%H:%M:%S"
                     ).strftime("%Y-%m-%d %H:%M:%S")
                 )
+                logger.debug(
+                    f"Query :{statementType}:{queryConfig['query']}"
+                )
                 tq = TableQuery(
                     query=statementType,
                     user_name=entry.resource.labels["project_id"],
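Note on the bigquery.py hunk: the patched get_columns is installed by assignment onto sqlalchemy_bigquery._types, and its type selection uses the expression "STRUCT" or "RECORD" not in field, which Python evaluates to the always-truthy string "STRUCT"; the STRUCT fallback branch can therefore never be taken. A minimal sketch of the presumably intended check, assuming field is a google.cloud.bigquery SchemaField (the helper name below is illustrative only, not part of the change):

def _column_type_for(field, mapped_type, struct_type):
    # SchemaField.field_type reports "RECORD" (alias "STRUCT") for nested columns;
    # only those should fall back to the STRUCT construct, everything else keeps
    # the type produced by _get_sqla_column_type.
    if field.field_type in ("RECORD", "STRUCT"):
        return struct_type
    return mapped_type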
diff --git a/ingestion/src/metadata/ingestion/source/hive.py b/ingestion/src/metadata/ingestion/source/hive.py
index faf983990fb..db7d1011da9
--- a/ingestion/src/metadata/ingestion/source/hive.py
+++ b/ingestion/src/metadata/ingestion/source/hive.py
@@ -9,18 +9,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from typing import Optional

-from pyhive import hive  # noqa: F401
-from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp
-
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
-from metadata.utils.column_helpers import register_custom_type
+from pyhive.sqlalchemy_hive import HiveDialect, _type_map
+from sqlalchemy import types, util

-register_custom_type(HiveDate, "DATE")
-register_custom_type(HiveTimestamp, "TIME")
-register_custom_type(HiveDecimal, "NUMBER")
+complex_data_types = ["struct", "map", "array", "union"]
+
+
+def get_columns(self, connection, table_name, schema=None, **kw):
+    rows = self._get_table_columns(connection, table_name, schema)
+    rows = [[col.strip() if col else None for col in row] for row in rows]
+    rows = [row for row in rows if row[0] and row[0] != "# col_name"]
+    result = []
+    for (col_name, col_type, _comment) in rows:
+        if col_name == "# Partition Information":
+            break
+        col_raw_type = col_type
+        col_type = re.search(r"^\w+", col_type).group(0)
+        try:
+            coltype = _type_map[col_type]
+        except KeyError:
+            util.warn(
+                "Did not recognize type '%s' of column '%s'" % (col_type, col_name)
+            )
+            coltype = types.NullType
+
+        result.append(
+            {
+                "name": col_name,
+                "type": coltype,
+                "nullable": True,
+                "default": None,
+                "raw_data_type": col_raw_type
+                if col_type in complex_data_types
+                else None,
+            }
+        )
+    return result
+
+
+HiveDialect.get_columns = get_columns

 
 class HiveConfig(SQLConnectionConfig):
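The overridden HiveDialect.get_columns keeps the full Hive type string in raw_data_type for complex types so the generic SQL source can expand their children later, while only the leading keyword is used for the _type_map lookup. A small illustration of that split, reusing the regex and list from the hunk (the helper name is illustrative):

import re

complex_data_types = ["struct", "map", "array", "union"]


def split_hive_type(col_type: str):
    # The outer keyword drives the _type_map lookup; the raw string is kept
    # only for complex types that need further parsing downstream.
    outer = re.search(r"^\w+", col_type).group(0)
    return outer, (col_type if outer in complex_data_types else None)


# split_hive_type("array<struct<a:int,b:string>>") -> ("array", "array<struct<a:int,b:string>>")
# split_hive_type("varchar(64)")                   -> ("varchar", None)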
diff --git a/ingestion/src/metadata/ingestion/source/redshift.py b/ingestion/src/metadata/ingestion/source/redshift.py
index 38bab3b0b95..7b170c8c31e 100644
--- a/ingestion/src/metadata/ingestion/source/redshift.py
+++ b/ingestion/src/metadata/ingestion/source/redshift.py
@@ -10,8 +10,20 @@
 # limitations under the License.

 import logging
+import re
+from collections import defaultdict
 from typing import Optional

+import sqlalchemy as sa
+from packaging.version import Version
+
+sa_version = Version(sa.__version__)
+
+from sqlalchemy import inspect
+from sqlalchemy.engine import reflection
+from sqlalchemy.types import CHAR, VARCHAR, NullType
+from sqlalchemy_redshift.dialect import RedshiftDialectMixin, RelationKey
+
 from metadata.ingestion.api.source import SourceStatus
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
@@ -19,6 +31,212 @@ from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource

 logger = logging.getLogger(__name__)

+@reflection.cache
+def get_table_names(self, connection, schema=None, **kw):
+    return self._get_table_or_view_names(["r", "e"], connection, schema, **kw)
+
+
+@reflection.cache
+def get_view_names(self, connection, schema=None, **kw):
+    return self._get_table_or_view_names(["v"], connection, schema, **kw)
+
+
+@reflection.cache
+def _get_table_or_view_names(self, relkinds, connection, schema=None, **kw):
+    default_schema = inspect(connection).default_schema_name
+    if not schema:
+        schema = default_schema
+    info_cache = kw.get("info_cache")
+    all_relations = self._get_all_relation_info(connection, info_cache=info_cache)
+    relation_names = []
+    for key, relation in all_relations.items():
+        if key.schema == schema and relation.relkind in relkinds:
+            relation_names.append(key.name)
+    return relation_names
+
+
+def _get_column_info(self, *args, **kwargs):
+    kw = kwargs.copy()
+    encode = kw.pop("encode", None)
+    if sa_version >= Version("1.3.16"):
+        kw["generated"] = ""
+    if sa_version < Version("1.4.0") and "identity" in kw:
+        del kw["identity"]
+    elif sa_version >= Version("1.4.0") and "identity" not in kw:
+        kw["identity"] = None
+    column_info = super(RedshiftDialectMixin, self)._get_column_info(*args, **kw)
+    column_info["raw_data_type"] = kw["format_type"]
+
+    if isinstance(column_info["type"], VARCHAR):
+        if column_info["type"].length is None:
+            column_info["type"] = NullType()
+    if re.match("char", column_info["raw_data_type"]):
+        column_info["type"] = CHAR
+
+    if "info" not in column_info:
+        column_info["info"] = {}
+    if encode and encode != "none":
+        column_info["info"]["encode"] = encode
+    return column_info
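The _get_column_info override has to reconcile reflection keyword arguments that changed across SQLAlchemy releases; a compact sketch of that version gating, mirroring the packaging.version comparisons above (the function name is illustrative):

import sqlalchemy as sa
from packaging.version import Version

sa_version = Version(sa.__version__)


def adjust_reflection_kwargs(kw: dict) -> dict:
    # Mirrors the gating above: 1.3.16+ expects "generated", while "identity"
    # is dropped before 1.4.0 and defaulted to None from 1.4.0 on.
    if sa_version >= Version("1.3.16"):
        kw.setdefault("generated", "")
    if sa_version < Version("1.4.0"):
        kw.pop("identity", None)
    else:
        kw.setdefault("identity", None)
    return kw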
+
+
+@reflection.cache
+def _get_all_relation_info(self, connection, **kw):
+    result = connection.execute(
+        """
+        SELECT
+          c.relkind,
+          n.oid as "schema_oid",
+          n.nspname as "schema",
+          c.oid as "rel_oid",
+          c.relname,
+          CASE c.reldiststyle
+            WHEN 0 THEN 'EVEN' WHEN 1 THEN 'KEY' WHEN 8 THEN 'ALL' END
+            AS "diststyle",
+          c.relowner AS "owner_id",
+          u.usename AS "owner_name",
+          TRIM(TRAILING ';' FROM pg_catalog.pg_get_viewdef(c.oid, true))
+            AS "view_definition",
+          pg_catalog.array_to_string(c.relacl, '\n') AS "privileges"
+        FROM pg_catalog.pg_class c
+        LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+        JOIN pg_catalog.pg_user u ON u.usesysid = c.relowner
+        WHERE c.relkind IN ('r', 'v', 'm', 'S', 'f')
+          AND n.nspname !~ '^pg_'
+        ORDER BY c.relkind, n.oid, n.nspname;
+        """
+    )
+    relations = {}
+    for rel in result:
+        key = RelationKey(rel.relname, rel.schema, connection)
+        relations[key] = rel
+
+    result = connection.execute(
+        """
+        SELECT
+          schemaname as "schema",
+          tablename as "relname",
+          'e' as relkind
+        FROM svv_external_tables;
+        """
+    )
+    for rel in result:
+        key = RelationKey(rel.relname, rel.schema, connection)
+        relations[key] = rel
+    return relations
+
+
+@reflection.cache
+def _get_schema_column_info(self, connection, schema=None, **kw):
+    schema_clause = "AND schema = '{schema}'".format(schema=schema) if schema else ""
+    all_columns = defaultdict(list)
+    with connection.connect() as cc:
+        result = cc.execute(
+            """
+            SELECT
+              n.nspname as "schema",
+              c.relname as "table_name",
+              att.attname as "name",
+              format_encoding(att.attencodingtype::integer) as "encode",
+              format_type(att.atttypid, att.atttypmod) as "type",
+              att.attisdistkey as "distkey",
+              att.attsortkeyord as "sortkey",
+              att.attnotnull as "notnull",
+              pg_catalog.col_description(att.attrelid, att.attnum)
+                as "comment",
+              adsrc,
+              attnum,
+              pg_catalog.format_type(att.atttypid, att.atttypmod),
+              pg_catalog.pg_get_expr(ad.adbin, ad.adrelid) AS DEFAULT,
+              n.oid as "schema_oid",
+              c.oid as "table_oid"
+            FROM pg_catalog.pg_class c
+            LEFT JOIN pg_catalog.pg_namespace n
+              ON n.oid = c.relnamespace
+            JOIN pg_catalog.pg_attribute att
+              ON att.attrelid = c.oid
+            LEFT JOIN pg_catalog.pg_attrdef ad
+              ON (att.attrelid, att.attnum) = (ad.adrelid, ad.adnum)
+            WHERE n.nspname !~ '^pg_'
+              AND att.attnum > 0
+              AND NOT att.attisdropped
+              {schema_clause}
+            UNION
+            SELECT
+              view_schema as "schema",
+              view_name as "table_name",
+              col_name as "name",
+              null as "encode",
+              col_type as "type",
+              null as "distkey",
+              0 as "sortkey",
+              null as "notnull",
+              null as "comment",
+              null as "adsrc",
+              null as "attnum",
+              col_type as "format_type",
+              null as "default",
+              null as "schema_oid",
+              null as "table_oid"
+            FROM pg_get_late_binding_view_cols() cols(
+              view_schema name,
+              view_name name,
+              col_name name,
+              col_type varchar,
+              col_num int)
+            WHERE 1 {schema_clause}
+            UNION
+            SELECT schemaname AS "schema",
+              tablename AS "table_name",
+              columnname AS "name",
+              null AS "encode",
+              -- Spectrum represents data types differently.
+              -- Standardize, so we can infer types.
+              CASE
+                WHEN external_type = 'int' THEN 'integer'
+                ELSE
+                  replace(
+                    replace(external_type, 'decimal', 'numeric'),
+                    'varchar', 'character varying')
+              END
+                AS "type",
+              null AS "distkey",
+              0 AS "sortkey",
+              null AS "notnull",
+              null AS "comment",
+              null AS "adsrc",
+              null AS "attnum",
+              CASE
+                WHEN external_type = 'int' THEN 'integer'
+                ELSE
+                  replace(
+                    replace(external_type, 'decimal', 'numeric'),
+                    'varchar', 'character varying')
+              END
+                AS "format_type",
+              null AS "default",
+              null AS "schema_oid",
+              null AS "table_oid"
+            FROM svv_external_columns
+            ORDER BY "schema", "table_name", "attnum";
+            """.format(
+                schema_clause=schema_clause
+            )
+        )
+        for col in result:
+            key = RelationKey(col.table_name, col.schema, connection)
+            all_columns[key].append(col)
+    return dict(all_columns)
+
+
+RedshiftDialectMixin._get_table_or_view_names = _get_table_or_view_names
+RedshiftDialectMixin.get_view_names = get_view_names
+RedshiftDialectMixin.get_table_names = get_table_names
+RedshiftDialectMixin._get_column_info = _get_column_info
+RedshiftDialectMixin._get_all_relation_info = _get_all_relation_info
+RedshiftDialectMixin._get_schema_column_info = _get_schema_column_info
+
+
 class RedshiftConfig(SQLConnectionConfig):
     scheme = "redshift+psycopg2"
     where_clause: Optional[str] = None
diff --git a/ingestion/src/metadata/ingestion/source/snowflake.py b/ingestion/src/metadata/ingestion/source/snowflake.py
index 29dad7dc688..855bef3f87b 100644
--- a/ingestion/src/metadata/ingestion/source/snowflake.py
+++ b/ingestion/src/metadata/ingestion/source/snowflake.py
@@ -11,15 +11,15 @@

 from typing import Optional

-from snowflake.sqlalchemy import custom_types
-
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
-from metadata.utils.column_helpers import register_custom_type
+from metadata.utils.column_helpers import create_sqlalchemy_type
+from snowflake.sqlalchemy.custom_types import VARIANT
+from snowflake.sqlalchemy.snowdialect import ischema_names

-register_custom_type(custom_types.TIMESTAMP_TZ, "TIME")
-register_custom_type(custom_types.TIMESTAMP_LTZ, "TIME")
-register_custom_type(custom_types.TIMESTAMP_NTZ, "TIME")
+GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
+ischema_names["VARIANT"] = VARIANT
+ischema_names["GEOGRAPHY"] = GEOGRAPHY

 
 class SnowflakeConfig(SQLConnectionConfig):
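The snowflake.py hunk works through ischema_names, the lookup table the Snowflake dialect consults when reflecting column types; names missing from it surface as unrecognised types. A sketch of the same registration pattern for one more type name (GEOMETRY is used here purely as an example and is not part of this change):

from metadata.utils.column_helpers import create_sqlalchemy_type
from snowflake.sqlalchemy.snowdialect import ischema_names

# Hypothetical extra registration, following the VARIANT/GEOGRAPHY pattern above.
GEOMETRY = create_sqlalchemy_type("GEOMETRY")
ischema_names.setdefault("GEOMETRY", GEOMETRY)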
diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py
index e9d423f5466..57edabcfb25 100644
--- a/ingestion/src/metadata/ingestion/source/sql_source.py
+++ b/ingestion/src/metadata/ingestion/source/sql_source.py
@@ -22,11 +22,6 @@ from datetime import datetime
 from typing import Dict, Iterable, List, Optional, Tuple
 from urllib.parse import quote_plus

-from pydantic import SecretStr
-from sqlalchemy import create_engine
-from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.inspection import inspect
-
 from metadata.generated.schema.entity.data.database import Database
 from metadata.generated.schema.entity.data.table import (
     Column,
@@ -54,6 +49,10 @@ from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.utils.column_helpers import check_column_complex_type, get_column_type
 from metadata.utils.helpers import get_database_service_or_create
+from pydantic import SecretStr
+from sqlalchemy import create_engine
+from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.inspection import inspect

 logger: logging.Logger = logging.getLogger(__name__)

@@ -271,7 +270,8 @@ class SQLSource(Source[OMetaDatabaseAndTable]):

     def next_record(self) -> Iterable[Entity]:
         inspector = inspect(self.engine)
-        for schema in inspector.get_schema_names():
+        schema_names = inspector.get_schema_names()
+        for schema in schema_names:
             # clear any previous source database state
             self.database_source_state.clear()
             if not self.sql_config.schema_filter_pattern.included(schema):
@@ -292,7 +292,8 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
         Scrape an SQL schema and prepare Database and Table OpenMetadata Entities
         """
-        for table_name in inspector.get_table_names(schema):
+        tables = inspector.get_table_names(schema)
+        for table_name in tables:
             try:
                 schema, table_name = self.standardize_schema_table_names(
                     schema, table_name
@@ -303,8 +304,6 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                         "Table pattern not allowed",
                     )
                     continue
-                self.status.scanned(f"{self.config.get_service_name()}.{table_name}")
-
                 description = _get_table_description(schema, table_name, inspector)
                 fqn = f"{self.config.service_name}.{schema}.{table_name}"
                 self.database_source_state.add(fqn)
@@ -338,10 +337,15 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                     table=table_entity, database=self._get_database(schema)
                 )
                 yield table_and_db
-            # Catch any errors during the ingestion and continue
-            except Exception as err:  # pylint: disable=broad-except
+                self.status.scanned(
+                    "{}.{}".format(self.config.get_service_name(), table_name)
+                )
+            except Exception as err:
+                traceback.print_exc()
                 logger.error(err)
-                self.status.warnings.append(f"{self.config.service_name}.{table_name}")
+                self.status.failures.append(
+                    "{}.{}".format(self.config.service_name, table_name)
+                )
                 continue

     def fetch_views(
f"{schema}.{model_name}" + except Exception as err: + print(err) self.data_models[model_fqdn] = model def _parse_data_model_upstream(self, mnode): @@ -507,7 +516,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]): def _get_database(self, schema: str) -> Database: return Database( - name=schema, + name=schema.replace(".", "_DOT_"), service=EntityReference(id=self.service.id, type=self.config.service_type), ) @@ -561,48 +570,61 @@ class SQLSource(Source[OMetaDatabaseAndTable]): dataset_name = f"{schema}.{table}" table_columns = [] + columns = inspector.get_columns(table, schema) try: - for row_order, column in enumerate(inspector.get_columns(table, schema)): - if "." in column["name"]: - logger.info(f"Found '.' in {column['name']}") - column["name"] = column["name"].replace(".", "_DOT_") - children = None - data_type_display = None - col_data_length = None - arr_data_type = None - if "raw_data_type" in column and column["raw_data_type"] is not None: - ( - col_type, - data_type_display, - arr_data_type, - children, - ) = check_column_complex_type( - self.status, - dataset_name, - column["raw_data_type"], - column["name"], - ) - else: - col_type = get_column_type( - self.status, dataset_name, column["type"] - ) - if col_type == "ARRAY" and re.match( - r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"]) - ): - arr_data_type = re.match( - r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"]) - ).groups() - data_type_display = column["type"] - - col_constraint = self._get_column_constraints( - column, pk_columns, unique_columns - ) - - if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}: - col_data_length = column["type"].length - if col_data_length is None: - col_data_length = 1 + for row_order, column in enumerate(columns): try: + if "." in column["name"]: + logger.info( + f"Found '.' in {column['name']}, changing '.' 
@@ -561,48 +570,61 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
         dataset_name = f"{schema}.{table}"
         table_columns = []
+        columns = inspector.get_columns(table, schema)
         try:
-            for row_order, column in enumerate(inspector.get_columns(table, schema)):
-                if "." in column["name"]:
-                    logger.info(f"Found '.' in {column['name']}")
-                    column["name"] = column["name"].replace(".", "_DOT_")
-                children = None
-                data_type_display = None
-                col_data_length = None
-                arr_data_type = None
-                if "raw_data_type" in column and column["raw_data_type"] is not None:
-                    (
-                        col_type,
-                        data_type_display,
-                        arr_data_type,
-                        children,
-                    ) = check_column_complex_type(
-                        self.status,
-                        dataset_name,
-                        column["raw_data_type"],
-                        column["name"],
-                    )
-                else:
-                    col_type = get_column_type(
-                        self.status, dataset_name, column["type"]
-                    )
-                    if col_type == "ARRAY" and re.match(
-                        r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"])
-                    ):
-                        arr_data_type = re.match(
-                            r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"])
-                        ).groups()
-                        data_type_display = column["type"]
-
-                col_constraint = self._get_column_constraints(
-                    column, pk_columns, unique_columns
-                )
-
-                if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
-                    col_data_length = column["type"].length
-                    if col_data_length is None:
-                        col_data_length = 1
+            for row_order, column in enumerate(columns):
                 try:
+                    if "." in column["name"]:
+                        logger.info(
+                            f"Found '.' in {column['name']}, changing '.' to '_DOT_'"
+                        )
+                        column["name"] = column["name"].replace(".", "_DOT_")
+                    children = None
+                    data_type_display = None
+                    col_data_length = None
+                    arr_data_type = None
+                    if (
+                        "raw_data_type" in column
+                        and column["raw_data_type"] is not None
+                    ):
+                        column["raw_data_type"] = self.parse_raw_data_type(
+                            column["raw_data_type"]
+                        )
+                        (
+                            col_type,
+                            data_type_display,
+                            arr_data_type,
+                            children,
+                        ) = check_column_complex_type(
+                            self.status,
+                            dataset_name,
+                            column["raw_data_type"],
+                            column["name"],
+                        )
+                    else:
+                        col_type = get_column_type(
+                            self.status, dataset_name, column["type"]
+                        )
+                        if col_type == "ARRAY" and re.match(
+                            r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"])
+                        ):
+                            arr_data_type = re.match(
+                                r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"])
+                            ).groups()
+                            data_type_display = column["type"]
+                    if repr(column["type"]).upper().startswith("ARRAY("):
+                        arr_data_type = "STRUCT"
+                        data_type_display = (
+                            repr(column["type"])
+                            .replace("(", "<")
+                            .replace(")", ">")
+                            .lower()
+                        )
+                    col_constraint = self._get_column_constraints(
+                        column, pk_columns, unique_columns
+                    )
+                    if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
+                        col_data_length = column["type"].length
                     if col_type == "NULL":
                         col_type = "VARCHAR"
                         data_type_display = "varchar"
@@ -613,23 +635,24 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                         name=column["name"],
                         description=column.get("comment", None),
                         dataType=col_type,
-                        dataTypeDisplay=f"{col_type}({col_data_length})"
+                        dataTypeDisplay="{}({})".format(
+                            col_type, 1 if col_data_length is None else col_data_length
+                        )
                         if data_type_display is None
                         else f"{data_type_display}",
-                        dataLength=col_data_length,
+                        dataLength=1 if col_data_length is None else col_data_length,
                         constraint=col_constraint,
-                        ordinalPosition=row_order + 1,  # enumerate starts at 0
-                        children=children,
+                        ordinalPosition=row_order,
+                        children=children if children is not None else None,
                         arrayDataType=arr_data_type,
                     )
-                except Exception as err:  # pylint: disable=broad-except
-                    logger.error(traceback.format_exc())
+                except Exception as err:
                     logger.error(traceback.print_exc())
                     logger.error(f"{err} : {column}")
                     continue
                 table_columns.append(om_column)
             return table_columns
-        except Exception as err:  # pylint: disable=broad-except
+        except Exception as err:
             logger.error(f"{repr(err)}: {table} {err}")
             return None
@@ -661,6 +684,9 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
         logger.debug(f"Finished profiling {dataset_name}")
         return profile

+    def parse_raw_data_type(self, raw_data_type):
+        return raw_data_type
+
     def _build_database_state(self, schema_fqdn: str) -> [EntityReference]:
         after = None
         tables = []
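parse_raw_data_type, added at the end of sql_source.py, is a deliberate no-op hook: _get_columns now routes every raw_data_type through it before check_column_complex_type parses it. A hypothetical override in a concrete source, shown only to illustrate the extension point (the class and the substitution below are not part of this change):

from metadata.ingestion.source.sql_source import SQLSource


class ExampleSource(SQLSource):
    def parse_raw_data_type(self, raw_data_type):
        # Normalise a dialect-specific spelling before complex-type parsing.
        return raw_data_type.replace("integer", "int")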
+ "__repr__": lambda self: f"{name}()", + }, + ) + return sqlalchemy_type + + _column_type_mapping: Dict[Type[types.TypeEngine], str] = { types.Integer: "INT", types.Numeric: "INT", @@ -123,6 +127,8 @@ _column_string_mapping = { "XML": "BINARY", "XMLTYPE": "BINARY", "CURSOR": "BINARY", + "TIMESTAMP_LTZ": "TIMESTAMP", + "TIMESTAMP_TZ": "TIMESTAMP", } _known_unknown_column_types: Set[Type[types.TypeEngine]] = { diff --git a/profiler/src/openmetadata/databases/postgres.py b/profiler/src/openmetadata/databases/postgres.py index 538031268c4..203002a267c 100644 --- a/profiler/src/openmetadata/databases/postgres.py +++ b/profiler/src/openmetadata/databases/postgres.py @@ -14,7 +14,6 @@ from openmetadata.common.database_common import ( DatabaseCommon, SQLConnectionConfig, SQLExpressions, - register_custom_type, ) logger = logging.getLogger(__name__)