From 06b5c10bcfd85a118697f13308850839d36ea54c Mon Sep 17 00:00:00 2001 From: IceS2 Date: Tue, 26 Aug 2025 11:03:40 +0200 Subject: [PATCH] FIXES #20807: Fix Oracle DataDiff and Change Oracle Connection to BaseConnection (#23020) * Fix Oracle DataDiff and Change Oracle Connection to BaseConnection * Add small unittest * Fix Test * Fix logic to avoid denormalizing table/schema names for other engines (cherry picked from commit a696fe0111171c3079c5840c28a00073fae25003) --- .../base_diff_params_setter.py | 31 ++- .../validations/table/sqlalchemy/tableDiff.py | 1 + .../source/database/oracle/connection.py | 258 +++++++++++------- .../source/database/oracle/service_spec.py | 2 + .../test_base_diff_params_setter.py | 62 +++++ .../tests/unit/test_source_connection.py | 29 +- 6 files changed, 263 insertions(+), 120 deletions(-) create mode 100644 ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_base_diff_params_setter.py diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py index 8335ceb1c60..47da1a71aeb 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/base_diff_params_setter.py @@ -3,16 +3,24 @@ from typing import List, Optional, Set, Type, Union from urllib.parse import urlparse +from sqlalchemy.engine import make_url + from metadata.data_quality.validations.models import Column, TableParameter from metadata.generated.schema.entity.data.table import Table -from metadata.generated.schema.entity.services.databaseService import DatabaseService +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseService, + DatabaseServiceType, +) from metadata.generated.schema.entity.services.serviceType import ServiceType from
metadata.ingestion.connections.connection import BaseConnection from metadata.ingestion.source.connections import get_connection -from metadata.profiler.orm.registry import Dialects +from metadata.profiler.orm.registry import Dialects, PythonDialects from metadata.utils import fqn from metadata.utils.collections import CaseInsensitiveList from metadata.utils.importer import get_module_dir, import_from_module +from metadata.utils.logger import test_suite_logger + +logger = test_suite_logger() # TODO: Refactor to avoid the circular import that makes us unable to use the BaseSpec class and the helper methods. @@ -68,7 +76,9 @@ class BaseTableParameter: """ return TableParameter( database_service_type=service.serviceType, - path=self.get_data_diff_table_path(entity.fullyQualifiedName.root), + path=self.get_data_diff_table_path( + entity.fullyQualifiedName.root, service.serviceType + ), serviceUrl=self.get_data_diff_url( service, entity.fullyQualifiedName.root, @@ -85,7 +95,9 @@ class BaseTableParameter: ) @staticmethod - def get_data_diff_table_path(table_fqn: str) -> str: + def get_data_diff_table_path( + table_fqn: str, service_type: DatabaseServiceType + ) -> str: """Get the data diff table path. Args: @@ -95,6 +107,17 @@ class BaseTableParameter: str """ _, _, schema, table = fqn.split(table_fqn) + try: + dialect = PythonDialects[service_type.name].value + if dialect in (Dialects.Oracle): + url = make_url(f"{dialect}://") + dialect_instance = url.get_dialect()() + table = dialect_instance.denormalize_name(name=table) + schema = dialect_instance.denormalize_name(name=schema) + except Exception as e: + logger.debug( + f"[Data Diff]: Error denormalizing table and schema names. 
Skipping denormalization\n{e}" + ) return fqn._build( # pylint: disable=protected-access "___SERVICE___", "__DATABASE__", schema, table ).replace("___SERVICE___.__DATABASE__.", "") diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 157a5b2df57..c2c262f0a61 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -192,6 +192,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): ) return result except UnsupportedDialectError as e: + logger.warning(f"[Data Diff]: Unsupported dialect: {e}") result = TestCaseResult( timestamp=self.execution_date, # type: ignore testCaseStatus=TestCaseStatus.Aborted, diff --git a/ingestion/src/metadata/ingestion/source/database/oracle/connection.py b/ingestion/src/metadata/ingestion/source/database/oracle/connection.py index f06df2a0a7e..77f44717224 100644 --- a/ingestion/src/metadata/ingestion/source/database/oracle/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/oracle/connection.py @@ -14,6 +14,7 @@ Source connection handler """ import os import sys +from copy import deepcopy from typing import Optional from urllib.parse import quote_plus @@ -26,7 +27,9 @@ from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) from metadata.generated.schema.entity.services.connections.database.oracleConnection import ( - OracleConnection, + OracleConnection as OracleConnectionConfig, +) +from metadata.generated.schema.entity.services.connections.database.oracleConnection import ( OracleDatabaseSchema, OracleServiceName, OracleTNSConnection, @@ -39,6 +42,8 @@ from metadata.ingestion.connections.builders import ( get_connection_args_common, get_connection_options_dict, ) +from metadata.ingestion.connections.connection import 
BaseConnection +from metadata.ingestion.connections.secrets import connection_with_options_secrets from metadata.ingestion.connections.test_connections import test_connection_db_common from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.database.oracle.queries import ( @@ -55,123 +60,174 @@ LD_LIB_ENV = "LD_LIBRARY_PATH" logger = ingestion_logger() -def get_connection_url(connection: OracleConnection) -> str: - """ - Build the URL and handle driver version at system level - """ +class OracleConnection(BaseConnection[OracleConnectionConfig, Engine]): + def __init__(self, connection: OracleConnectionConfig): + super().__init__(connection) - oracledb.version = CX_ORACLE_LIB_VERSION - sys.modules["cx_Oracle"] = oracledb + def _get_client(self) -> Engine: + """ + Create connection + """ + try: + if self.service_connection.instantClientDirectory: + logger.info( + f"Initializing Oracle thick client at {self.service_connection.instantClientDirectory}" + ) + os.environ[LD_LIB_ENV] = self.service_connection.instantClientDirectory + oracledb.init_oracle_client( + lib_dir=self.service_connection.instantClientDirectory + ) + except DatabaseError as err: + logger.info(f"Could not initialize Oracle thick client: {err}") - url = f"{connection.scheme.value}://" - if connection.username: - url += f"{quote_plus(connection.username)}" - if not connection.password: - connection.password = SecretStr("") - url += f":{quote_plus(connection.password.get_secret_value())}" - url += "@" - - url = _handle_connection_type(url=url, connection=connection) - - options = get_connection_options_dict(connection) - if options: - params = "&".join( - f"{key}={quote_plus(value)}" for (key, value) in options.items() if value + return create_generic_db_connection( + connection=self.service_connection, + get_connection_url_fn=self.get_connection_url, + get_connection_args_fn=get_connection_args_common, ) - if isinstance(connection.oracleConnectionType, 
OracleServiceName): - url = f"{url}&{params}" - else: - url = f"{url}?{params}" - return url + def test_connection( + self, + metadata: OpenMetadata, + automation_workflow: Optional[AutomationWorkflow] = None, + timeout_seconds: Optional[int] = THREE_MIN, + ) -> TestConnectionResult: + """ + Test connection. This can be executed either as part + of a metadata workflow or during an Automation Workflow + """ + def test_oracle_package_access(engine): + try: + schema_name = engine.execute(ORACLE_GET_SCHEMA).scalar() + return ORACLE_GET_STORED_PACKAGES.format(schema=schema_name) + except Exception as e: + raise OraclePackageAccessError( + f"Failed to access Oracle stored packages: {e}" + ) -def _handle_connection_type(url: str, connection: OracleConnection) -> str: - """ - Depending on the oracle connection type, we need to handle the URL differently - """ + test_conn_queries = { + "CheckAccess": CHECK_ACCESS_TO_ALL, + "PackageAccess": test_oracle_package_access(self.client), + } - if isinstance(connection.oracleConnectionType, OracleTNSConnection): - # ref https://stackoverflow.com/questions/14140902/using-oracle-service-names-with-sqlalchemy - url += connection.oracleConnectionType.oracleTNSConnection - return url - - # If not TNS, we add the hostPort - url += connection.hostPort - - if isinstance(connection.oracleConnectionType, OracleDatabaseSchema): - url += ( - f"/{connection.oracleConnectionType.databaseSchema}" - if connection.oracleConnectionType.databaseSchema - else "" + return test_connection_db_common( + metadata=metadata, + engine=self.client, + service_connection=self.service_connection, + automation_workflow=automation_workflow, + queries=test_conn_queries, + timeout_seconds=timeout_seconds, ) - return url - if isinstance(connection.oracleConnectionType, OracleServiceName): - url = f"{url}/?service_name={connection.oracleConnectionType.oracleServiceName}" - return url + def get_connection_dict(self) -> dict: + """ + Return the connection dictionary for 
this service. + """ + url = self.client.url + connection_copy = deepcopy(self.service_connection) - raise ValueError(f"Unknown connection type {connection.oracleConnectionType}") + connection_dict = { + "driver": url.drivername, + "host": f"{url.host}:{url.port}", # This is the format expected by data-diff. If we start using this for something else, we need to change it and modify the data-diff code. + "user": url.username, + } + # Add password if present in the connection + if connection_copy.password: + connection_dict["password"] = connection_copy.password.get_secret_value() -def get_connection(connection: OracleConnection) -> Engine: - """ - Create connection - """ - try: - if connection.instantClientDirectory: - logger.info( - f"Initializing Oracle thick client at {connection.instantClientDirectory}" + # Add connection type specific information + if isinstance(connection_copy.oracleConnectionType, OracleDatabaseSchema): + connection_dict[ + "database" + ] = connection_copy.oracleConnectionType.databaseSchema + elif isinstance(connection_copy.oracleConnectionType, OracleServiceName): + connection_dict[ + "database" + ] = connection_copy.oracleConnectionType.oracleServiceName + elif isinstance(connection_copy.oracleConnectionType, OracleTNSConnection): + connection_dict[ + "host" + ] = connection_copy.oracleConnectionType.oracleTNSConnection + + # Add connection options if present + if connection_copy.connectionOptions and connection_copy.connectionOptions.root: + connection_with_options_secrets(lambda: connection_copy) + connection_dict.update(connection_copy.connectionOptions.root) + + # Add connection arguments if present + if ( + connection_copy.connectionArguments + and connection_copy.connectionArguments.root + ): + connection_dict.update(get_connection_args_common(connection_copy)) + + return connection_dict + + @staticmethod + def get_connection_url(connection: OracleConnectionConfig) -> str: + """ + Build the URL and handle driver version at system 
level + """ + + oracledb.version = CX_ORACLE_LIB_VERSION + sys.modules["cx_Oracle"] = oracledb + + url = f"{connection.scheme.value}://" + if connection.username: + url += f"{quote_plus(connection.username)}" + if not connection.password: + connection.password = SecretStr("") + url += f":{quote_plus(connection.password.get_secret_value())}" + url += "@" + + url = OracleConnection._handle_connection_type(url=url, connection=connection) + + options = get_connection_options_dict(connection) + if options: + params = "&".join( + f"{key}={quote_plus(value)}" + for (key, value) in options.items() + if value ) - os.environ[LD_LIB_ENV] = connection.instantClientDirectory - oracledb.init_oracle_client() - except DatabaseError as err: - logger.info(f"Could not initialize Oracle thick client: {err}") + if isinstance(connection.oracleConnectionType, OracleServiceName): + url = f"{url}&{params}" + else: + url = f"{url}?{params}" - return create_generic_db_connection( - connection=connection, - get_connection_url_fn=get_connection_url, - get_connection_args_fn=get_connection_args_common, - ) + return url + + @staticmethod + def _handle_connection_type(url: str, connection: OracleConnectionConfig) -> str: + """ + Depending on the oracle connection type, we need to handle the URL differently + """ + + if isinstance(connection.oracleConnectionType, OracleTNSConnection): + # ref https://stackoverflow.com/questions/14140902/using-oracle-service-names-with-sqlalchemy + url += connection.oracleConnectionType.oracleTNSConnection + return url + + # If not TNS, we add the hostPort + url += connection.hostPort + + if isinstance(connection.oracleConnectionType, OracleDatabaseSchema): + url += ( + f"/{connection.oracleConnectionType.databaseSchema}" + if connection.oracleConnectionType.databaseSchema + else "" + ) + return url + + if isinstance(connection.oracleConnectionType, OracleServiceName): + url = f"{url}/?service_name={connection.oracleConnectionType.oracleServiceName}" + return url + 
+ raise ValueError(f"Unknown connection type {connection.oracleConnectionType}") class OraclePackageAccessError(Exception): """ Raised when unable to access Oracle stored packages """ - - -def test_connection( - metadata: OpenMetadata, - engine: Engine, - service_connection: OracleConnection, - automation_workflow: Optional[AutomationWorkflow] = None, - timeout_seconds: Optional[int] = THREE_MIN, -) -> TestConnectionResult: - """ - Test connection. This can be executed either as part - of a metadata workflow or during an Automation Workflow - """ - - def test_oracle_package_access(engine): - try: - schema_name = engine.execute(ORACLE_GET_SCHEMA).scalar() - return ORACLE_GET_STORED_PACKAGES.format(schema=schema_name) - except Exception as e: - raise OraclePackageAccessError( - f"Failed to access Oracle stored packages: {e}" - ) - - test_conn_queries = { - "CheckAccess": CHECK_ACCESS_TO_ALL, - "PackageAccess": test_oracle_package_access(engine), - } - - return test_connection_db_common( - metadata=metadata, - engine=engine, - service_connection=service_connection, - automation_workflow=automation_workflow, - queries=test_conn_queries, - timeout_seconds=timeout_seconds, - ) diff --git a/ingestion/src/metadata/ingestion/source/database/oracle/service_spec.py b/ingestion/src/metadata/ingestion/source/database/oracle/service_spec.py index 3c89f916211..13e4ef81675 100644 --- a/ingestion/src/metadata/ingestion/source/database/oracle/service_spec.py +++ b/ingestion/src/metadata/ingestion/source/database/oracle/service_spec.py @@ -1,3 +1,4 @@ +from metadata.ingestion.source.database.oracle.connection import OracleConnection from metadata.ingestion.source.database.oracle.lineage import OracleLineageSource from metadata.ingestion.source.database.oracle.metadata import OracleSource from metadata.ingestion.source.database.oracle.usage import OracleUsageSource @@ -7,4 +8,5 @@ ServiceSpec = DefaultDatabaseSpec( metadata_source_class=OracleSource, 
lineage_source_class=OracleLineageSource, usage_source_class=OracleUsageSource, + connection_class=OracleConnection, ) diff --git a/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_base_diff_params_setter.py b/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_base_diff_params_setter.py new file mode 100644 index 00000000000..68823daa25f --- /dev/null +++ b/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_base_diff_params_setter.py @@ -0,0 +1,62 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test Oracle DataDiff parameter setter functionality""" + +from unittest.mock import patch + +from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import ( + BaseTableParameter, +) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseServiceType, +) + + +def test_get_data_diff_table_path_oracle_denormalization(): + """Test Oracle table path generation with proper denormalization (lowercase to uppercase)""" + table_fqn = "oracle_service.testdb.schema_name.table_name" + service_type = DatabaseServiceType.Oracle + + result = BaseTableParameter.get_data_diff_table_path(table_fqn, service_type) + + # Oracle should denormalize names (lowercase to uppercase) + expected = "SCHEMA_NAME.TABLE_NAME" + assert result == expected + + +def test_get_data_diff_table_path_oracle_denormalization_fallback(): + """Test Oracle table path generation with denormalization error fallback""" + table_fqn = "oracle_service.testdb.schema_name.table_name" + service_type = DatabaseServiceType.Oracle + + with patch( + "metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter.make_url" + ) as mock_make_url: + mock_make_url.side_effect = Exception("Dialect error") + + result = BaseTableParameter.get_data_diff_table_path(table_fqn, service_type) + + # Should fallback to original names without denormalization + expected = "schema_name.table_name" + assert result == expected + + +def test_get_data_diff_table_path_mysql_no_denormalization(): + """Test MySQL table path generation (no denormalization needed)""" + table_fqn = "mysql_service.testdb.schema_name.table_name" + service_type = DatabaseServiceType.Mysql + + result = BaseTableParameter.get_data_diff_table_path(table_fqn, service_type) + + # MySQL should preserve original case + expected = "schema_name.table_name" + assert result == expected diff --git a/ingestion/tests/unit/test_source_connection.py b/ingestion/tests/unit/test_source_connection.py index 
d703f999202..00c0fb80bcb 100644 --- a/ingestion/tests/unit/test_source_connection.py +++ b/ingestion/tests/unit/test_source_connection.py @@ -69,7 +69,9 @@ from metadata.generated.schema.entity.services.connections.database.mysqlConnect MySQLScheme, ) from metadata.generated.schema.entity.services.connections.database.oracleConnection import ( - OracleConnection, + OracleConnection as OracleConnectionConfig, +) +from metadata.generated.schema.entity.services.connections.database.oracleConnection import ( OracleDatabaseSchema, OracleScheme, OracleServiceName, @@ -116,6 +118,7 @@ from metadata.ingestion.connections.builders import ( get_connection_args_common, get_connection_url_common, ) +from metadata.ingestion.source.database.oracle.connection import OracleConnection from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection from metadata.ingestion.source.database.trino.connection import TrinoConnection @@ -1139,14 +1142,10 @@ class SourceConnectionTest(TestCase): assert expected_url == get_connection_url(presto_conn_obj) def test_oracle_url(self): - from metadata.ingestion.source.database.oracle.connection import ( - get_connection_url, - ) - # oracle with db expected_url = "oracle+cx_oracle://admin:password@localhost:1541/testdb" - oracle_conn_obj = OracleConnection( + oracle_conn_obj = OracleConnectionConfig( username="admin", password="password", hostPort="localhost:1541", @@ -1154,21 +1153,21 @@ class SourceConnectionTest(TestCase): oracleConnectionType=OracleDatabaseSchema(databaseSchema="testdb"), ) - assert expected_url == get_connection_url(oracle_conn_obj) + assert expected_url == OracleConnection.get_connection_url(oracle_conn_obj) # oracle with service name expected_url = ( "oracle+cx_oracle://admin:password@localhost:1541/?service_name=testdb" ) - oracle_conn_obj = OracleConnection( + oracle_conn_obj = OracleConnectionConfig( username="admin", password="password", hostPort="localhost:1541", 
scheme=OracleScheme.oracle_cx_oracle, oracleConnectionType=OracleServiceName(oracleServiceName="testdb"), ) - assert expected_url == get_connection_url(oracle_conn_obj) + assert expected_url == OracleConnection.get_connection_url(oracle_conn_obj) # oracle with db & connection options expected_url = [ @@ -1176,7 +1175,7 @@ class SourceConnectionTest(TestCase): "oracle+cx_oracle://admin:password@localhost:1541/testdb?test_key_1=test_value_1&test_key_2=test_value_2", ] - oracle_conn_obj = OracleConnection( + oracle_conn_obj = OracleConnectionConfig( username="admin", password="password", hostPort="localhost:1541", @@ -1186,7 +1185,7 @@ class SourceConnectionTest(TestCase): test_key_1="test_value_1", test_key_2="test_value_2" ), ) - assert get_connection_url(oracle_conn_obj) in expected_url + assert OracleConnection.get_connection_url(oracle_conn_obj) in expected_url # oracle with service name & connection options expected_url = [ @@ -1194,7 +1193,7 @@ class SourceConnectionTest(TestCase): "oracle+cx_oracle://admin:password@localhost:1541/?service_name=testdb&test_key_1=test_value_1&test_key_2=test_value_2", ] - oracle_conn_obj = OracleConnection( + oracle_conn_obj = OracleConnectionConfig( username="admin", password="password", hostPort="localhost:1541", @@ -1204,7 +1203,7 @@ class SourceConnectionTest(TestCase): test_key_1="test_value_1", test_key_2="test_value_2" ), ) - assert get_connection_url(oracle_conn_obj) in expected_url + assert OracleConnection.get_connection_url(oracle_conn_obj) in expected_url tns_connection = ( "(DESCRIPTION=(ADDRESS_LIST=(ADDRESS=(PROTOCOL=TCP)" @@ -1212,7 +1211,7 @@ class SourceConnectionTest(TestCase): ) expected_url = f"oracle+cx_oracle://admin:password@{tns_connection}" - oracle_conn_obj = OracleConnection( + oracle_conn_obj = OracleConnectionConfig( username="admin", password="password", hostPort="localhost:1541", # We will ignore it here @@ -1220,7 +1219,7 @@ class SourceConnectionTest(TestCase): oracleTNSConnection=tns_connection 
), ) - assert get_connection_url(oracle_conn_obj) == expected_url + assert OracleConnection.get_connection_url(oracle_conn_obj) == expected_url def test_exasol_url(self): from metadata.ingestion.source.database.exasol.connection import (