FIXES #20807: Fix Oracle DataDiff and Change Oracle Connection to BaseConnection (#23020)

* Fix Oracle DataDiff and Change Oracle Connection to BaseConnection

* Add small unittest

* Fix Test

* Fix logic, to void other engines to denormalize table/schema names

(cherry picked from commit a696fe0111171c3079c5840c28a00073fae25003)
This commit is contained in:
IceS2 2025-08-26 11:03:40 +02:00 committed by OpenMetadata Release Bot
parent a6cbb125df
commit 06b5c10bcf
6 changed files with 263 additions and 120 deletions

View File

@ -3,16 +3,24 @@
from typing import List, Optional, Set, Type, Union
from urllib.parse import urlparse
from sqlalchemy.engine import make_url
from metadata.data_quality.validations.models import Column, TableParameter
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.databaseService import (
DatabaseService,
DatabaseServiceType,
)
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.profiler.orm.registry import Dialects, PythonDialects
from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.importer import get_module_dir, import_from_module
from metadata.utils.logger import test_suite_logger
logger = test_suite_logger()
# TODO: Refactor to avoid the circular import that makes us unable to use the BaseSpec class and the helper methods.
@ -68,7 +76,9 @@ class BaseTableParameter:
"""
return TableParameter(
database_service_type=service.serviceType,
path=self.get_data_diff_table_path(entity.fullyQualifiedName.root),
path=self.get_data_diff_table_path(
entity.fullyQualifiedName.root, service.serviceType
),
serviceUrl=self.get_data_diff_url(
service,
entity.fullyQualifiedName.root,
@ -85,7 +95,9 @@ class BaseTableParameter:
)
@staticmethod
def get_data_diff_table_path(table_fqn: str) -> str:
def get_data_diff_table_path(
table_fqn: str, service_type: DatabaseServiceType
) -> str:
"""Get the data diff table path.
Args:
@ -95,6 +107,17 @@ class BaseTableParameter:
str
"""
_, _, schema, table = fqn.split(table_fqn)
try:
dialect = PythonDialects[service_type.name].value
if dialect in (Dialects.Oracle):
url = make_url(f"{dialect}://")
dialect_instance = url.get_dialect()()
table = dialect_instance.denormalize_name(name=table)
schema = dialect_instance.denormalize_name(name=schema)
except Exception as e:
logger.debug(
f"[Data Diff]: Error denormalizing table and schema names. Skipping denormalization\n{e}"
)
return fqn._build( # pylint: disable=protected-access
"___SERVICE___", "__DATABASE__", schema, table
).replace("___SERVICE___.__DATABASE__.", "")

View File

@ -192,6 +192,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
)
return result
except UnsupportedDialectError as e:
logger.warning(f"[Data Diff]: Unsupported dialect: {e}")
result = TestCaseResult(
timestamp=self.execution_date, # type: ignore
testCaseStatus=TestCaseStatus.Aborted,

View File

@ -14,6 +14,7 @@ Source connection handler
"""
import os
import sys
from copy import deepcopy
from typing import Optional
from urllib.parse import quote_plus
@ -26,7 +27,9 @@ from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
from metadata.generated.schema.entity.services.connections.database.oracleConnection import (
OracleConnection,
OracleConnection as OracleConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.oracleConnection import (
OracleDatabaseSchema,
OracleServiceName,
OracleTNSConnection,
@ -39,6 +42,8 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common,
get_connection_options_dict,
)
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.secrets import connection_with_options_secrets
from metadata.ingestion.connections.test_connections import test_connection_db_common
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.oracle.queries import (
@ -55,7 +60,113 @@ LD_LIB_ENV = "LD_LIBRARY_PATH"
logger = ingestion_logger()
def get_connection_url(connection: OracleConnection) -> str:
class OracleConnection(BaseConnection[OracleConnectionConfig, Engine]):
def __init__(self, connection: OracleConnectionConfig):
super().__init__(connection)
def _get_client(self) -> Engine:
"""
Create connection
"""
try:
if self.service_connection.instantClientDirectory:
logger.info(
f"Initializing Oracle thick client at {self.service_connection.instantClientDirectory}"
)
os.environ[LD_LIB_ENV] = self.service_connection.instantClientDirectory
oracledb.init_oracle_client(
lib_dir=self.service_connection.instantClientDirectory
)
except DatabaseError as err:
logger.info(f"Could not initialize Oracle thick client: {err}")
return create_generic_db_connection(
connection=self.service_connection,
get_connection_url_fn=self.get_connection_url,
get_connection_args_fn=get_connection_args_common,
)
def test_connection(
self,
metadata: OpenMetadata,
automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
"""
Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow
"""
def test_oracle_package_access(engine):
try:
schema_name = engine.execute(ORACLE_GET_SCHEMA).scalar()
return ORACLE_GET_STORED_PACKAGES.format(schema=schema_name)
except Exception as e:
raise OraclePackageAccessError(
f"Failed to access Oracle stored packages: {e}"
)
test_conn_queries = {
"CheckAccess": CHECK_ACCESS_TO_ALL,
"PackageAccess": test_oracle_package_access(self.client),
}
return test_connection_db_common(
metadata=metadata,
engine=self.client,
service_connection=self.service_connection,
automation_workflow=automation_workflow,
queries=test_conn_queries,
timeout_seconds=timeout_seconds,
)
def get_connection_dict(self) -> dict:
"""
Return the connection dictionary for this service.
"""
url = self.client.url
connection_copy = deepcopy(self.service_connection)
connection_dict = {
"driver": url.drivername,
"host": f"{url.host}:{url.port}", # This is the format expected by data-diff. If we start using this for something else, we need to change it and modify the data-diff code.
"user": url.username,
}
# Add password if present in the connection
if connection_copy.password:
connection_dict["password"] = connection_copy.password.get_secret_value()
# Add connection type specific information
if isinstance(connection_copy.oracleConnectionType, OracleDatabaseSchema):
connection_dict[
"database"
] = connection_copy.oracleConnectionType.databaseSchema
elif isinstance(connection_copy.oracleConnectionType, OracleServiceName):
connection_dict[
"database"
] = connection_copy.oracleConnectionType.oracleServiceName
elif isinstance(connection_copy.oracleConnectionType, OracleTNSConnection):
connection_dict[
"host"
] = connection_copy.oracleConnectionType.oracleTNSConnection
# Add connection options if present
if connection_copy.connectionOptions and connection_copy.connectionOptions.root:
connection_with_options_secrets(lambda: connection_copy)
connection_dict.update(connection_copy.connectionOptions.root)
# Add connection arguments if present
if (
connection_copy.connectionArguments
and connection_copy.connectionArguments.root
):
connection_dict.update(get_connection_args_common(connection_copy))
return connection_dict
@staticmethod
def get_connection_url(connection: OracleConnectionConfig) -> str:
"""
Build the URL and handle driver version at system level
"""
@ -71,12 +182,14 @@ def get_connection_url(connection: OracleConnection) -> str:
url += f":{quote_plus(connection.password.get_secret_value())}"
url += "@"
url = _handle_connection_type(url=url, connection=connection)
url = OracleConnection._handle_connection_type(url=url, connection=connection)
options = get_connection_options_dict(connection)
if options:
params = "&".join(
f"{key}={quote_plus(value)}" for (key, value) in options.items() if value
f"{key}={quote_plus(value)}"
for (key, value) in options.items()
if value
)
if isinstance(connection.oracleConnectionType, OracleServiceName):
url = f"{url}&{params}"
@ -85,8 +198,8 @@ def get_connection_url(connection: OracleConnection) -> str:
return url
def _handle_connection_type(url: str, connection: OracleConnection) -> str:
@staticmethod
def _handle_connection_type(url: str, connection: OracleConnectionConfig) -> str:
"""
Depending on the oracle connection type, we need to handle the URL differently
"""
@ -114,64 +227,7 @@ def _handle_connection_type(url: str, connection: OracleConnection) -> str:
raise ValueError(f"Unknown connection type {connection.oracleConnectionType}")
def get_connection(connection: OracleConnection) -> Engine:
"""
Create connection
"""
try:
if connection.instantClientDirectory:
logger.info(
f"Initializing Oracle thick client at {connection.instantClientDirectory}"
)
os.environ[LD_LIB_ENV] = connection.instantClientDirectory
oracledb.init_oracle_client()
except DatabaseError as err:
logger.info(f"Could not initialize Oracle thick client: {err}")
return create_generic_db_connection(
connection=connection,
get_connection_url_fn=get_connection_url,
get_connection_args_fn=get_connection_args_common,
)
class OraclePackageAccessError(Exception):
"""
Raised when unable to access Oracle stored packages
"""
def test_connection(
metadata: OpenMetadata,
engine: Engine,
service_connection: OracleConnection,
automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
"""
Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow
"""
def test_oracle_package_access(engine):
try:
schema_name = engine.execute(ORACLE_GET_SCHEMA).scalar()
return ORACLE_GET_STORED_PACKAGES.format(schema=schema_name)
except Exception as e:
raise OraclePackageAccessError(
f"Failed to access Oracle stored packages: {e}"
)
test_conn_queries = {
"CheckAccess": CHECK_ACCESS_TO_ALL,
"PackageAccess": test_oracle_package_access(engine),
}
return test_connection_db_common(
metadata=metadata,
engine=engine,
service_connection=service_connection,
automation_workflow=automation_workflow,
queries=test_conn_queries,
timeout_seconds=timeout_seconds,
)

View File

@ -1,3 +1,4 @@
from metadata.ingestion.source.database.oracle.connection import OracleConnection
from metadata.ingestion.source.database.oracle.lineage import OracleLineageSource
from metadata.ingestion.source.database.oracle.metadata import OracleSource
from metadata.ingestion.source.database.oracle.usage import OracleUsageSource
@ -7,4 +8,5 @@ ServiceSpec = DefaultDatabaseSpec(
metadata_source_class=OracleSource,
lineage_source_class=OracleLineageSource,
usage_source_class=OracleUsageSource,
connection_class=OracleConnection,
)

View File

@ -0,0 +1,62 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Oracle DataDiff parameter setter functionality"""
from unittest.mock import patch
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
BaseTableParameter,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
def test_get_data_diff_table_path_oracle_denormalization():
"""Test Oracle table path generation with proper denormalization (lowercase to uppercase)"""
table_fqn = "oracle_service.testdb.schema_name.table_name"
service_type = DatabaseServiceType.Oracle
result = BaseTableParameter.get_data_diff_table_path(table_fqn, service_type)
# Oracle should denormalize names (lowercase to uppercase)
expected = "SCHEMA_NAME.TABLE_NAME"
assert result == expected
def test_get_data_diff_table_path_oracle_denormalization_fallback():
"""Test Oracle table path generation with denormalization error fallback"""
table_fqn = "oracle_service.testdb.schema_name.table_name"
service_type = DatabaseServiceType.Oracle
with patch(
"metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter.make_url"
) as mock_make_url:
mock_make_url.side_effect = Exception("Dialect error")
result = BaseTableParameter.get_data_diff_table_path(table_fqn, service_type)
# Should fallback to original names without denormalization
expected = "schema_name.table_name"
assert result == expected
def test_get_data_diff_table_path_mysql_no_denormalization():
"""Test MySQL table path generation (no denormalization needed)"""
table_fqn = "mysql_service.testdb.schema_name.table_name"
service_type = DatabaseServiceType.Mysql
result = BaseTableParameter.get_data_diff_table_path(table_fqn, service_type)
# MySQL should preserve original case
expected = "schema_name.table_name"
assert result == expected

View File

@ -69,7 +69,9 @@ from metadata.generated.schema.entity.services.connections.database.mysqlConnect
MySQLScheme,
)
from metadata.generated.schema.entity.services.connections.database.oracleConnection import (
OracleConnection,
OracleConnection as OracleConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.oracleConnection import (
OracleDatabaseSchema,
OracleScheme,
OracleServiceName,
@ -116,6 +118,7 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common,
get_connection_url_common,
)
from metadata.ingestion.source.database.oracle.connection import OracleConnection
from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection
from metadata.ingestion.source.database.trino.connection import TrinoConnection
@ -1139,14 +1142,10 @@ class SourceConnectionTest(TestCase):
assert expected_url == get_connection_url(presto_conn_obj)
def test_oracle_url(self):
from metadata.ingestion.source.database.oracle.connection import (
get_connection_url,
)
# oracle with db
expected_url = "oracle+cx_oracle://admin:password@localhost:1541/testdb"
oracle_conn_obj = OracleConnection(
oracle_conn_obj = OracleConnectionConfig(
username="admin",
password="password",
hostPort="localhost:1541",
@ -1154,21 +1153,21 @@ class SourceConnectionTest(TestCase):
oracleConnectionType=OracleDatabaseSchema(databaseSchema="testdb"),
)
assert expected_url == get_connection_url(oracle_conn_obj)
assert expected_url == OracleConnection.get_connection_url(oracle_conn_obj)
# oracle with service name
expected_url = (
"oracle+cx_oracle://admin:password@localhost:1541/?service_name=testdb"
)
oracle_conn_obj = OracleConnection(
oracle_conn_obj = OracleConnectionConfig(
username="admin",
password="password",
hostPort="localhost:1541",
scheme=OracleScheme.oracle_cx_oracle,
oracleConnectionType=OracleServiceName(oracleServiceName="testdb"),
)
assert expected_url == get_connection_url(oracle_conn_obj)
assert expected_url == OracleConnection.get_connection_url(oracle_conn_obj)
# oracle with db & connection options
expected_url = [
@ -1176,7 +1175,7 @@ class SourceConnectionTest(TestCase):
"oracle+cx_oracle://admin:password@localhost:1541/testdb?test_key_1=test_value_1&test_key_2=test_value_2",
]
oracle_conn_obj = OracleConnection(
oracle_conn_obj = OracleConnectionConfig(
username="admin",
password="password",
hostPort="localhost:1541",
@ -1186,7 +1185,7 @@ class SourceConnectionTest(TestCase):
test_key_1="test_value_1", test_key_2="test_value_2"
),
)
assert get_connection_url(oracle_conn_obj) in expected_url
assert OracleConnection.get_connection_url(oracle_conn_obj) in expected_url
# oracle with service name & connection options
expected_url = [
@ -1194,7 +1193,7 @@ class SourceConnectionTest(TestCase):
"oracle+cx_oracle://admin:password@localhost:1541/?service_name=testdb&test_key_1=test_value_1&test_key_2=test_value_2",
]
oracle_conn_obj = OracleConnection(
oracle_conn_obj = OracleConnectionConfig(
username="admin",
password="password",
hostPort="localhost:1541",
@ -1204,7 +1203,7 @@ class SourceConnectionTest(TestCase):
test_key_1="test_value_1", test_key_2="test_value_2"
),
)
assert get_connection_url(oracle_conn_obj) in expected_url
assert OracleConnection.get_connection_url(oracle_conn_obj) in expected_url
tns_connection = (
"(DESCRIPTION=(ADDRESS_LIST=(ADDRESS=(PROTOCOL=TCP)"
@ -1212,7 +1211,7 @@ class SourceConnectionTest(TestCase):
)
expected_url = f"oracle+cx_oracle://admin:password@{tns_connection}"
oracle_conn_obj = OracleConnection(
oracle_conn_obj = OracleConnectionConfig(
username="admin",
password="password",
hostPort="localhost:1541", # We will ignore it here
@ -1220,7 +1219,7 @@ class SourceConnectionTest(TestCase):
oracleTNSConnection=tns_connection
),
)
assert get_connection_url(oracle_conn_obj) == expected_url
assert OracleConnection.get_connection_url(oracle_conn_obj) == expected_url
def test_exasol_url(self):
from metadata.ingestion.source.database.exasol.connection import (