From 58b11669aa618b2ae893794c117f4fba968e044c Mon Sep 17 00:00:00 2001 From: Imri Paran Date: Thu, 10 Oct 2024 21:37:29 +0200 Subject: [PATCH] cherry picked be82086e2542d2d176ac66e0bf11100646448b4f --- .../table_diff_params_setter.py | 69 +++++++++++++------ 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index ec373acc745..8acb57e033b 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -13,8 +13,6 @@ from ast import literal_eval from typing import List, Optional from urllib.parse import urlparse -from sqlalchemy.engine import Engine - from metadata.data_quality.validations.models import ( Column, TableDiffRuntimeParameters, @@ -51,12 +49,26 @@ class TableDiffParamsSetter(RuntimeParameterSetter): } def get_parameters(self, test_case) -> TableDiffRuntimeParameters: - service1: Engine = get_connection(self.service_connection_config) + service1_url = ( + str(get_connection(self.service_connection_config).url) + if self.service_connection_config + else None + ) + service1: DatabaseService = self.ometa_client.get_by_id( + DatabaseService, self.table_entity.service.id, nullable=False + ) table2_fqn = self.get_parameter(test_case, "table2") + if table2_fqn is None: + raise ValueError("table2 not set") table2: Table = self.ometa_client.get_by_name( Table, fqn=table2_fqn, nullable=False ) - service2 = self.get_service2_url(service1, table2, test_case) + service2_url = ( + service1_url if table2.service == self.table_entity.service else None + ) + service2: DatabaseService = self.ometa_client.get_by_id( + DatabaseService, table2.service.id, nullable=False + ) key_columns = self.get_key_columns(test_case) extra_columns = self.get_extra_columns(key_columns, test_case) return TableDiffRuntimeParameters( @@ -65,7 +77,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter): self.table_entity.fullyQualifiedName.root ), serviceUrl=self.get_data_diff_url( - str(service1.url), self.table_entity.fullyQualifiedName.root + service1, + self.table_entity.fullyQualifiedName.root, + override_url=service1_url, ), columns=self.filter_relevant_columns( self.table_entity.columns, key_columns, extra_columns @@ -73,7 +87,12 @@ class TableDiffParamsSetter(RuntimeParameterSetter): ), table2=TableParameter( path=self.get_data_diff_table_path(table2_fqn), - serviceUrl=self.get_data_diff_url(service2, table2_fqn), + serviceUrl=self.get_data_diff_url( + service2, + table2_fqn, + override_url=self.get_parameter(test_case, "service2Url") + or service2_url, + ), columns=self.filter_relevant_columns( table2.columns, key_columns, extra_columns ), @@ -99,19 +118,6 @@ class TableDiffParamsSetter(RuntimeParameterSetter): where_clauses = [f"({x})" for x in where_clauses] return " AND ".join(where_clauses) - def get_service2_url(self, service1, table2, test_case): - service2 = self.get_parameter(test_case, "service2Url") - if service2 is not None: - pass - elif self.table_entity.service.id == table2.service.id: - service2 = str(service1.url) - else: - table2_service = self.ometa_client.get_by_id( - DatabaseService, table2.service.id - ) - service2 = str(get_connection(table2_service.connection.config).url) - return service2 - def get_extra_columns( self, key_columns: List[str], test_case ) -> Optional[List[str]]: @@ -161,16 +167,35 @@ class TableDiffParamsSetter(RuntimeParameterSetter): ) @staticmethod - def get_data_diff_url(service_url: str, table_fqn) -> str: - url = urlparse(service_url) + def get_data_diff_url( + db_service: DatabaseService, table_fqn, override_url: Optional[str] = None + ) -> str: + """Get the url for the data diff service. + + Args: + db_service (DatabaseService): The database service entity + table_fqn (str): The fully qualified name of the table + override_url (Optional[str], optional): Override the url. Defaults to None. + + Returns: + str: The url for the data diff service + """ + source_url = ( + str(get_connection(db_service.connection.config).url) + if not override_url + else override_url + ) + url = urlparse(source_url) # remove the driver name from the url because table-diff doesn't support it kwargs = {"scheme": url.scheme.split("+")[0]} service, database, schema, table = fqn.split( # pylint: disable=unused-variable table_fqn ) # path needs to include the database AND schema in some of the connectors + if hasattr(db_service.connection.config, "supportsDatabase"): + kwargs["path"] = f"/{database}" if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake}: - kwargs["path"] = f"/{database}/{schema}" + kwargs["path"] += f"/{schema}" return url._replace(**kwargs).geturl() @staticmethod