From eb906589fdfac3dba07e6772d95ca993051c29e5 Mon Sep 17 00:00:00 2001
From: Pere Miquel Brull
Date: Tue, 22 Mar 2022 15:55:44 +0100
Subject: [PATCH] Fix #3525 - Profiler breaks on Postgres data (#3583)

Fix #3525 - Profiler breaks on Postgres data (#3583)

* histogram: abort when the computed step is 0 or None (empty table).
* registry: add a Dialects registry holding the dialect names and use
  it in the @compiles registrations of StdDevFn, ConcatFn, LenFn and
  RandomNumFn; also register LENGTH() for the postgresql dialect.
* registry: add sqlalchemy.JSON to NOT_COMPUTE.
* converter: map BYTES / UUID to CustomTypes.<X>.value.
* column_type_parser: add DOUBLE_PRECISION, drop the JSON entry.
* postgres: _is_partition now takes (table_name, schema, inspector).
* tests: add a histogram test on an empty table.
---
Reviewer notes (text in this section is not part of the commit message):

* This patch was recovered from a copy whose newlines and indentation had
  been collapsed. Line breaks and blank context lines were rebuilt from the
  hunk line counts; leading whitespace is inferred from Python syntax and
  the spacing inside the SQL string in postgres.py cannot be recovered.
  Apply with `git apply --ignore-whitespace` and review the result.
* Two changes relative to the original commit:
  - registry.py: Dialects.Trino is "trino" (was "Trino"); @compiles matches
    on the lowercase dialect name, so the capitalised key could never match.
  - test_metrics.py: test_histogram_empty profiles EmptyUser.age instead of
    User.age, so the column belongs to the table under test.
  The post-image blob ids on the `index` lines of those two files therefore
  no longer correspond to the patched content.
* Not addressed here, worth a follow-up:
  - stddev.py, concat.py and random_num.py still import DatabaseServiceType;
    it may be unused after this change (rest of the files not visible here).
  - concat.py carries a docstring about "the squared STD" that looks
    copy-pasted from stddev.py.
  - Dialects.DynamoDB = "dynamoDB" is the only other non-lowercase value.
  - Dialects subclasses SQLAlchemy's Enum type, not enum.Enum, so members
    are plain str/bytes attributes; switching to enum.Enum would change
    what is passed to @compiles.

 .../src/metadata/ingestion/source/postgres.py |  4 +--
 .../orm_profiler/metrics/static/histogram.py  |  4 +--
 .../orm_profiler/metrics/static/stddev.py     |  6 ++--
 .../metadata/orm_profiler/orm/converter.py    |  4 +--
 .../orm_profiler/orm/functions/concat.py      |  7 ++--
 .../orm_profiler/orm/functions/length.py      | 13 +++----
 .../orm_profiler/orm/functions/random_num.py  |  7 ++--
 .../src/metadata/orm_profiler/orm/registry.py | 34 ++++++++++++++++++-
 .../src/metadata/utils/column_type_parser.py  |  2 +-
 ingestion/tests/unit/profiler/test_metrics.py | 30 ++++++++++++++++++++++++++++++
 10 files changed, 86 insertions(+), 25 deletions(-)

diff --git a/ingestion/src/metadata/ingestion/source/postgres.py b/ingestion/src/metadata/ingestion/source/postgres.py
index bd2ed44c0d3..bebddbdbed7 100644
--- a/ingestion/src/metadata/ingestion/source/postgres.py
+++ b/ingestion/src/metadata/ingestion/source/postgres.py
@@ -52,7 +52,7 @@ class PostgresSource(SQLSource):
     def get_status(self) -> SourceStatus:
         return self.status
 
-    def _is_partition(self, table_name: str, schema_name: str) -> bool:
+    def _is_partition(self, table_name: str, schema: str, inspector) -> bool:
         cur = self.pgconn.cursor()
         cur.execute(
             """
@@ -62,7 +62,7 @@ class PostgresSource(SQLSource):
                 WHERE c.relname = %s
                   AND n.nspname = %s
             """,
-            (table_name, schema_name),
+            (table_name, schema),
         )
         is_partition = cur.fetchone()[0]
         return is_partition
diff --git a/ingestion/src/metadata/orm_profiler/metrics/static/histogram.py b/ingestion/src/metadata/orm_profiler/metrics/static/histogram.py
index e2bb690f0e7..a7a2b3f87eb 100644
--- a/ingestion/src/metadata/orm_profiler/metrics/static/histogram.py
+++ b/ingestion/src/metadata/orm_profiler/metrics/static/histogram.py
@@ -69,9 +69,9 @@ class Histogram(QueryMetric):
 
         step = dict(bins.first())["step"]
 
-        if step == 0:
+        if not step:  # step == 0 or None for empty tables
             logger.debug(
-                f"MIN({col.name}) == MAX({col.name}). Aborting histogram computation."
+                f"MIN({col.name}) == MAX({col.name}) or EMPTY table. Aborting histogram computation."
             )
             return None
 
diff --git a/ingestion/src/metadata/orm_profiler/metrics/static/stddev.py b/ingestion/src/metadata/orm_profiler/metrics/static/stddev.py
index 6718ef23c35..d2f45474f3d 100644
--- a/ingestion/src/metadata/orm_profiler/metrics/static/stddev.py
+++ b/ingestion/src/metadata/orm_profiler/metrics/static/stddev.py
@@ -20,7 +20,7 @@ from metadata.generated.schema.entity.services.databaseService import (
     DatabaseServiceType,
 )
 from metadata.orm_profiler.metrics.core import CACHE, StaticMetric, _label
-from metadata.orm_profiler.orm.registry import is_quantifiable
+from metadata.orm_profiler.orm.registry import Dialects, is_quantifiable
 from metadata.orm_profiler.utils import logger
 
 logger = logger()
@@ -36,12 +36,12 @@ def _(element, compiler, **kw):
     return "STDDEV_POP(%s)" % compiler.process(element.clauses, **kw)
 
 
-@compiles(StdDevFn, DatabaseServiceType.MSSQL.value.lower())
+@compiles(StdDevFn, Dialects.MSSQL)
 def _(element, compiler, **kw):
     return "STDEVP(%s)" % compiler.process(element.clauses, **kw)
 
 
-@compiles(StdDevFn, DatabaseServiceType.SQLite.value.lower())  # Needed for unit tests
+@compiles(StdDevFn, Dialects.SQLite)  # Needed for unit tests
 def _(element, compiler, **kw):
     """
     This actually returns the squared STD, but as
diff --git a/ingestion/src/metadata/orm_profiler/orm/converter.py b/ingestion/src/metadata/orm_profiler/orm/converter.py
index b0405c57ca4..7f0a22c1b4f 100644
--- a/ingestion/src/metadata/orm_profiler/orm/converter.py
+++ b/ingestion/src/metadata/orm_profiler/orm/converter.py
@@ -32,7 +32,7 @@ _TYPE_MAP = {
     DataType.INT: sqlalchemy.INT,
     DataType.BIGINT: sqlalchemy.BIGINT,
     DataType.BYTEINT: sqlalchemy.SMALLINT,
-    DataType.BYTES: CustomTypes.BYTES,
+    DataType.BYTES: CustomTypes.BYTES.value,
     DataType.FLOAT: sqlalchemy.FLOAT,
     DataType.DOUBLE: sqlalchemy.DECIMAL,
     DataType.DECIMAL: sqlalchemy.DECIMAL,
@@ -61,7 +61,7 @@ _TYPE_MAP = {
     # DataType.GEOGRAPHY: ...,
     DataType.ENUM: sqlalchemy.Enum,
     DataType.JSON: sqlalchemy.JSON,
-    DataType.UUID: CustomTypes.UUID,
+    DataType.UUID: CustomTypes.UUID.value,
 }
 
 
diff --git a/ingestion/src/metadata/orm_profiler/orm/functions/concat.py b/ingestion/src/metadata/orm_profiler/orm/functions/concat.py
index c542f96e187..0559b2142db 100644
--- a/ingestion/src/metadata/orm_profiler/orm/functions/concat.py
+++ b/ingestion/src/metadata/orm_profiler/orm/functions/concat.py
@@ -19,6 +19,7 @@ from metadata.generated.schema.entity.services.databaseService import (
     DatabaseServiceType,
 )
 from metadata.orm_profiler.metrics.core import CACHE
+from metadata.orm_profiler.orm.registry import Dialects
 from metadata.orm_profiler.utils import logger
 
 logger = logger()
@@ -33,9 +34,9 @@ def _(element, compiler, **kw):
     return "CONCAT(%s)" % compiler.process(element.clauses, **kw)
 
 
-@compiles(ConcatFn, DatabaseServiceType.Redshift.value.lower())
-@compiles(ConcatFn, DatabaseServiceType.SQLite.value.lower())
-@compiles(ConcatFn, DatabaseServiceType.Vertica.value.lower())
+@compiles(ConcatFn, Dialects.Redshift)
+@compiles(ConcatFn, Dialects.SQLite)
+@compiles(ConcatFn, Dialects.Vertica)
 def _(element, compiler, **kw):
     """
     This actually returns the squared STD, but as
diff --git a/ingestion/src/metadata/orm_profiler/orm/functions/length.py b/ingestion/src/metadata/orm_profiler/orm/functions/length.py
index d8c486f1f04..41d6d80b46a 100644
--- a/ingestion/src/metadata/orm_profiler/orm/functions/length.py
+++ b/ingestion/src/metadata/orm_profiler/orm/functions/length.py
@@ -15,10 +15,8 @@ Define Length function
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql.functions import FunctionElement
 
-from metadata.generated.schema.entity.services.databaseService import (
-    DatabaseServiceType,
-)
 from metadata.orm_profiler.metrics.core import CACHE
+from metadata.orm_profiler.orm.registry import Dialects
 from metadata.orm_profiler.utils import logger
 
 logger = logger()
@@ -33,10 +31,9 @@ def _(element, compiler, **kw):
     return "LEN(%s)" % compiler.process(element.clauses, **kw)
 
 
-@compiles(LenFn, DatabaseServiceType.SQLite.value.lower())
-@compiles(LenFn, DatabaseServiceType.Vertica.value.lower())
-@compiles(
-    LenFn, DatabaseServiceType.Hive.value.lower().encode()
-)  # For some reason hive's dialect is in bytes...
+@compiles(LenFn, Dialects.SQLite)
+@compiles(LenFn, Dialects.Vertica)
+@compiles(LenFn, Dialects.Hive)  # For some reason hive's dialect is in bytes...
+@compiles(LenFn, Dialects.Postgres)
 def _(element, compiler, **kw):
     return "LENGTH(%s)" % compiler.process(element.clauses, **kw)
diff --git a/ingestion/src/metadata/orm_profiler/orm/functions/random_num.py b/ingestion/src/metadata/orm_profiler/orm/functions/random_num.py
index b2a89f49055..d0518fe47a9 100644
--- a/ingestion/src/metadata/orm_profiler/orm/functions/random_num.py
+++ b/ingestion/src/metadata/orm_profiler/orm/functions/random_num.py
@@ -23,6 +23,7 @@ from metadata.generated.schema.entity.services.databaseService import (
     DatabaseServiceType,
 )
 from metadata.orm_profiler.metrics.core import CACHE
+from metadata.orm_profiler.orm.registry import Dialects
 from metadata.orm_profiler.utils import logger
 
 logger = logger()
@@ -43,12 +44,12 @@ def _(*_, **__):
     return "ABS(RANDOM()) * 100"
 
 
-@compiles(RandomNumFn, DatabaseServiceType.MySQL.value.lower())
+@compiles(RandomNumFn, Dialects.MySQL)
 def _(*_, **__):
     return "ABS(RAND()) * 100"
 
 
-@compiles(RandomNumFn, DatabaseServiceType.SQLite.value.lower())
+@compiles(RandomNumFn, Dialects.SQLite)
 def _(*_, **__):
     """
     SQLite random returns a number between -9223372036854775808
@@ -57,7 +58,7 @@ def _(*_, **__):
     return "ABS(RANDOM()) % 100"
 
 
-@compiles(RandomNumFn, DatabaseServiceType.MSSQL.value.lower())
+@compiles(RandomNumFn, Dialects.MSSQL)
 def _(*_, **__):
     """
     MSSQL RANDOM() function returns the same single
diff --git a/ingestion/src/metadata/orm_profiler/orm/registry.py b/ingestion/src/metadata/orm_profiler/orm/registry.py
index a6900d2df46..9561c9319a0 100644
--- a/ingestion/src/metadata/orm_profiler/orm/registry.py
+++ b/ingestion/src/metadata/orm_profiler/orm/registry.py
@@ -15,7 +15,7 @@ without having an import mess
 """
 import sqlalchemy
 from sqlalchemy import Integer, Numeric
-from sqlalchemy.sql.sqltypes import Concatenable
+from sqlalchemy.sql.sqltypes import Concatenable, Enum
 
 from metadata.orm_profiler.orm.types.hex_byte_string import HexByteString
 from metadata.orm_profiler.orm.types.uuid import UUIDString
@@ -27,6 +27,37 @@ class CustomTypes(TypeRegistry):
     UUID = UUIDString
 
 
+class Dialects(Enum):
+    """
+    Map the service types from DatabaseServiceType
+    to the dialect scheme name used for ingesting
+    and profiling data.
+    """
+
+    Hive = b"hive"
+    Postgres = "postgresql"
+    BigQuery = "bigquery"
+    MySQL = "mysql"
+    Redshift = "redshift"
+    Snowflake = "snowflake"
+    MSSQL = "mssql"
+    Oracle = "oracle"
+    Athena = "athena"
+    Presto = "presto"
+    Trino = "trino"
+    Vertica = "vertica"
+    Glue = "glue"
+    MariaDB = "mariadb"
+    Druid = "druid"
+    Db2 = "db2"
+    ClickHouse = "clickhouse"
+    Databricks = "databricks"
+    DynamoDB = "dynamoDB"
+    AzureSQL = "azuresql"
+    SingleStore = "singlestore"
+    SQLite = "sqlite"
+
+
 # Sometimes we want to skip certain types for computing metrics.
 # If the type is NULL, then we won't run the metric execution
 # in the profiler.
@@ -34,6 +65,7 @@ class CustomTypes(TypeRegistry):
 NOT_COMPUTE = {
     sqlalchemy.types.NullType,
     sqlalchemy.ARRAY,
+    sqlalchemy.JSON,
 }
 
 
diff --git a/ingestion/src/metadata/utils/column_type_parser.py b/ingestion/src/metadata/utils/column_type_parser.py
index 236cfb37628..efeb2a63fc5 100644
--- a/ingestion/src/metadata/utils/column_type_parser.py
+++ b/ingestion/src/metadata/utils/column_type_parser.py
@@ -70,6 +70,7 @@ class ColumnTypeParser:
         "DATETIMEOFFSET": "DATETIME",
         "DECIMAL": "DECIMAL",
         "DOUBLE PRECISION": "DOUBLE",
+        "DOUBLE_PRECISION": "DOUBLE",
         "DOUBLE": "DOUBLE",
         "ENUM": "ENUM",
         "FLOAT": "FLOAT",
@@ -101,7 +102,6 @@ class ColumnTypeParser:
         "INTERVAL DAY TO SECOND": "INTERVAL",
         "INTERVAL YEAR TO MONTH": "INTERVAL",
         "INTERVAL": "INTERVAL",
-        "JSON": "JSON",
         "LONG RAW": "BINARY",
         "LONG VARCHAR": "VARCHAR",
         "LONGBLOB": "LONGBLOB",
diff --git a/ingestion/tests/unit/profiler/test_metrics.py b/ingestion/tests/unit/profiler/test_metrics.py
index 9bd5ab32a74..c96ec0b3263 100644
--- a/ingestion/tests/unit/profiler/test_metrics.py
+++ b/ingestion/tests/unit/profiler/test_metrics.py
@@ -574,3 +574,33 @@ class MetricsTest(TestCase):
         )
 
         assert res.get(User.name.name)[Metrics.COUNT_IN_SET.name] == 3
+
+    def test_histogram_empty(self):
+        """
+        Run the histogram on an empty table
+        """
+
+        class EmptyUser(Base):
+            __tablename__ = "empty_users"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(256))
+            fullname = Column(String(256))
+            nickname = Column(String(256))
+            comments = Column(TEXT)
+            age = Column(Integer)
+
+        EmptyUser.__table__.create(bind=self.engine)
+
+        hist = add_props(bins=5)(Metrics.HISTOGRAM.value)
+        res = (
+            Profiler(
+                hist,
+                session=self.session,
+                table=EmptyUser,
+                use_cols=[EmptyUser.age],
+            )
+            .execute()
+            ._column_results
+        )
+
+        assert res.get(EmptyUser.age.name).get(Metrics.HISTOGRAM.name) is None