diff --git a/ingestion/src/metadata/data_quality/validations/column/sqlalchemy/columnValuesToBeUnique.py b/ingestion/src/metadata/data_quality/validations/column/sqlalchemy/columnValuesToBeUnique.py
index e3d1dbfb403..8d4a21a5d77 100644
--- a/ingestion/src/metadata/data_quality/validations/column/sqlalchemy/columnValuesToBeUnique.py
+++ b/ingestion/src/metadata/data_quality/validations/column/sqlalchemy/columnValuesToBeUnique.py
@@ -15,9 +15,10 @@ Validator for column values to be unique test case
 
 from typing import Optional
 
-from sqlalchemy import Column, inspect
+from sqlalchemy import Column, inspect, literal_column
 from sqlalchemy.exc import SQLAlchemyError
 
+from metadata.profiler.orm.registry import Dialects
 from metadata.data_quality.validations.column.base.columnValuesToBeUnique import (
     BaseColumnValuesToBeUniqueValidator,
 )
@@ -57,7 +58,21 @@ class ColumnValuesToBeUniqueValidator(
         )  # type: ignore
 
         try:
-            self.value = dict(self.runner.dispatch_query_select_first(count, unique_count.scalar_subquery().label("uniqueCount")))  # type: ignore
+            # Oracle only: ask the runner for a GROUP BY on the literal `2`.
+            # NOTE(review): presumably needed so Oracle accepts the aggregate
+            # next to the scalar subquery; confirm against an Oracle instance.
+            if self.runner.dialect == Dialects.Oracle:
+                query_group_by_ = [literal_column("2")]
+            else:
+                query_group_by_ = None
+
+            self.value = dict(
+                self.runner.dispatch_query_select_first(
+                    count,
+                    unique_count.scalar_subquery().label("uniqueCount"),
+                    query_group_by_=query_group_by_,
+                )
+            )  # type: ignore
             res = self.value.get(Metrics.COUNT.name)
         except Exception as exc:
             raise SQLAlchemyError(exc)
diff --git a/ingestion/src/metadata/profiler/processor/runner.py b/ingestion/src/metadata/profiler/processor/runner.py
index 8826b0750b4..0851c531db4 100644
--- a/ingestion/src/metadata/profiler/processor/runner.py
+++ b/ingestion/src/metadata/profiler/processor/runner.py
@@ -23,7 +23,10 @@ from sqlalchemy.orm import DeclarativeMeta, Query, Session
 from sqlalchemy.orm.util import AliasedClass
 
 from metadata.utils.logger import query_runner_logger
-from metadata.utils.sqa_utils import get_query_filter_for_runner
+from metadata.utils.sqa_utils import (
+    get_query_filter_for_runner,
+    get_query_group_by_for_runner,
+)
 
 logger = query_runner_logger()
 
@@ -88,6 +91,11 @@ class QueryRunner:
         """Table name attribute access"""
         return self._session
 
+    @property
+    def dialect(self) -> str:
+        """Name of the dialect of the bind the session is attached to"""
+        return self._session.get_bind().dialect.name
+
     def _build_query(self, *entities, **kwargs) -> Query:
         """Build query object
 
@@ -106,11 +114,15 @@ class QueryRunner:
             **kwargs: kwargs to pass to the query
         """
         filter_ = get_query_filter_for_runner(kwargs)
+        group_by_ = get_query_group_by_for_runner(kwargs)
 
         query = self._build_query(*entities, **kwargs).select_from(self._dataset)
 
         if filter_ is not None:
-            return query.filter(filter_)
+            query = query.filter(filter_)
+
+        if group_by_ is not None:
+            query = query.group_by(*group_by_)
 
         return query
 
@@ -122,7 +134,7 @@ class QueryRunner:
             **kwargs: kwargs to pass to the query
         """
         filter_ = get_query_filter_for_runner(kwargs)
-
+        group_by_ = get_query_group_by_for_runner(kwargs)
         user_query = self._session.query(self._dataset).from_statement(
             text(f"{self.profile_sample_query}")
         )
@@ -130,7 +142,10 @@ class QueryRunner:
         query = self._build_query(*entities, **kwargs).select_from(user_query)
 
         if filter_ is not None:
-            return query.filter(filter_)
+            query = query.filter(filter_)
+
+        if group_by_ is not None:
+            query = query.group_by(*group_by_)
 
         return query
 
@@ -143,13 +158,21 @@ class QueryRunner:
             **kwargs: kwargs to pass to the query
         """
         filter_ = get_query_filter_for_runner(kwargs)
+        group_by_ = get_query_group_by_for_runner(kwargs)
 
         if self.profile_sample_query:
-            return self._select_from_user_query(*entities, **kwargs).first()
+            # get_query_group_by_for_runner popped the key from kwargs; pass it
+            # back so the user-query path applies the GROUP BY as well.
+            return self._select_from_user_query(
+                *entities, query_group_by_=group_by_, **kwargs
+            ).first()
         query = self._build_query(*entities, **kwargs).select_from(self.table)
 
         if filter_ is not None:
-            return query.filter(filter_).first()
+            query = query.filter(filter_)
+
+        if group_by_ is not None:
+            query = query.group_by(*group_by_)
 
         return query.first()
 
@@ -162,14 +185,22 @@ class QueryRunner:
             **kwargs: kwargs to pass to the query
         """
         filter_ = get_query_filter_for_runner(kwargs)
+        group_by_ = get_query_group_by_for_runner(kwargs)
 
         if self.profile_sample_query:
-            return self._select_from_user_query(*entities, **kwargs).all()
+            # get_query_group_by_for_runner popped the key from kwargs; pass it
+            # back so the user-query path applies the GROUP BY as well.
+            return self._select_from_user_query(
+                *entities, query_group_by_=group_by_, **kwargs
+            ).all()
 
         query = self._build_query(*entities, **kwargs).select_from(self.table)
 
         if filter_ is not None:
-            return query.filter(filter_).all()
+            query = query.filter(filter_)
+
+        if group_by_ is not None:
+            query = query.group_by(*group_by_)
 
         return query.all()
 
diff --git a/ingestion/src/metadata/utils/sqa_utils.py b/ingestion/src/metadata/utils/sqa_utils.py
index f5fed370793..142ca08bcb5 100644
--- a/ingestion/src/metadata/utils/sqa_utils.py
+++ b/ingestion/src/metadata/utils/sqa_utils.py
@@ -191,6 +191,21 @@ def get_query_filter_for_runner(kwargs: Dict) -> Optional[BinaryExpression]:
     return filter_
 
 
+def get_query_group_by_for_runner(kwargs: Dict) -> Optional[list]:
+    """Pop the group-by clause from the runner kwargs.
+
+    IMPORTANT: this mutates the dictionary passed in: the ``query_group_by_``
+    key is removed so the remaining kwargs can be forwarded as-is.
+
+    Args:
+        kwargs (Dict): keyword arguments received by the runner method
+    Returns:
+        the value stored under ``query_group_by_`` (the runner unpacks it into
+        ``Query.group_by``), or None when the key is absent
+    """
+    return kwargs.pop("query_group_by_", None)
+
+
 def handle_array(
     query: Query, column: Column, table: Union[DeclarativeMeta, AliasedClass]
 ) -> Query: