Fix Get Connection (#22033)

* Fix Get Connection

* Fix Data Diff Get Connection
This commit is contained in:
IceS2 2025-07-04 14:04:53 +02:00 committed by GitHub
parent 79df730b85
commit f97a40da6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 8 deletions

View File

@ -43,8 +43,10 @@ class ServiceSpecPatch:
def get_data_diff_class(self) -> Type["BaseTableParameter"]:
return import_from_module(self.service_spec.data_diff)
def get_connection_class(self) -> Type[BaseConnection]:
return import_from_module(self.service_spec.connection_class)
def get_connection_class(self) -> Optional[Type[BaseConnection]]:
if self.service_spec.connection_class:
return import_from_module(self.service_spec.connection_class)
return None
class BaseTableParameter:
@ -99,18 +101,23 @@ class BaseTableParameter:
@staticmethod
def _get_service_connection_config(
db_service: DatabaseService,
service_connection_config,
) -> Optional[Union[str, dict]]:
"""
Get the connection dictionary for the service.
"""
service_connection_config = db_service.connection.config
service_spec_patch = ServiceSpecPatch(
ServiceType.Database, service_connection_config.type.value.lower()
)
try:
connection_class = service_spec_patch.get_connection_class()
if not connection_class:
return (
str(get_connection(service_connection_config).url)
if service_connection_config
else None
)
connection = connection_class(service_connection_config)
return connection.get_connection_dict()
except (ValueError, AttributeError, NotImplementedError):
@ -137,7 +144,9 @@ class BaseTableParameter:
str: The url for the data diff service
"""
source_url = (
BaseTableParameter._get_service_connection_config(db_service)
BaseTableParameter._get_service_connection_config(
db_service.connection.config
)
if not override_url
else override_url
)

View File

@ -58,7 +58,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
DatabaseService, self.table_entity.service.id, nullable=False
)
service1_url = BaseTableParameter._get_service_connection_config(service1)
service1_url = BaseTableParameter._get_service_connection_config(
self.service_connection_config
)
table2_fqn = self.get_parameter(test_case, "table2")
if table2_fqn is None:

View File

@ -27,7 +27,7 @@ from metadata.generated.schema.entity.services.connections.database.common.basic
BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
MysqlConnection,
MysqlConnection as MySQLConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult,
@ -49,7 +49,7 @@ from metadata.ingestion.source.database.mysql.queries import (
from metadata.utils.constants import THREE_MIN
class MySQLConnection(BaseConnection[MysqlConnection, Engine]):
class MySQLConnection(BaseConnection[MySQLConnectionConfig, Engine]):
def _get_client(self) -> Engine:
"""
Return the SQLAlchemy Engine for MySQL.