cherry picked be82086e2542d2d176ac66e0bf11100646448b4f

This commit is contained in:
Imri Paran 2024-10-10 21:37:29 +02:00 committed by sushi30
parent 76228dcf45
commit 58b11669aa

View File

@ -13,8 +13,6 @@ from ast import literal_eval
from typing import List, Optional
from urllib.parse import urlparse
from sqlalchemy.engine import Engine
from metadata.data_quality.validations.models import (
Column,
TableDiffRuntimeParameters,
@ -51,12 +49,26 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
}
def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
service1: Engine = get_connection(self.service_connection_config)
service1_url = (
str(get_connection(self.service_connection_config).url)
if self.service_connection_config
else None
)
service1: DatabaseService = self.ometa_client.get_by_id(
DatabaseService, self.table_entity.service.id, nullable=False
)
table2_fqn = self.get_parameter(test_case, "table2")
if table2_fqn is None:
raise ValueError("table2 not set")
table2: Table = self.ometa_client.get_by_name(
Table, fqn=table2_fqn, nullable=False
)
service2 = self.get_service2_url(service1, table2, test_case)
service2_url = (
service1_url if table2.service == self.table_entity.service else None
)
service2: DatabaseService = self.ometa_client.get_by_id(
DatabaseService, table2.service.id, nullable=False
)
key_columns = self.get_key_columns(test_case)
extra_columns = self.get_extra_columns(key_columns, test_case)
return TableDiffRuntimeParameters(
@ -65,7 +77,9 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
self.table_entity.fullyQualifiedName.root
),
serviceUrl=self.get_data_diff_url(
str(service1.url), self.table_entity.fullyQualifiedName.root
service1,
self.table_entity.fullyQualifiedName.root,
override_url=service1_url,
),
columns=self.filter_relevant_columns(
self.table_entity.columns, key_columns, extra_columns
@ -73,7 +87,12 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
),
table2=TableParameter(
path=self.get_data_diff_table_path(table2_fqn),
serviceUrl=self.get_data_diff_url(service2, table2_fqn),
serviceUrl=self.get_data_diff_url(
service2,
table2_fqn,
override_url=self.get_parameter(test_case, "service2Url")
or service2_url,
),
columns=self.filter_relevant_columns(
table2.columns, key_columns, extra_columns
),
@ -99,19 +118,6 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
where_clauses = [f"({x})" for x in where_clauses]
return " AND ".join(where_clauses)
def get_service2_url(self, service1, table2, test_case):
service2 = self.get_parameter(test_case, "service2Url")
if service2 is not None:
pass
elif self.table_entity.service.id == table2.service.id:
service2 = str(service1.url)
else:
table2_service = self.ometa_client.get_by_id(
DatabaseService, table2.service.id
)
service2 = str(get_connection(table2_service.connection.config).url)
return service2
def get_extra_columns(
self, key_columns: List[str], test_case
) -> Optional[List[str]]:
@ -161,16 +167,35 @@ class TableDiffParamsSetter(RuntimeParameterSetter):
)
@staticmethod
def get_data_diff_url(service_url: str, table_fqn) -> str:
url = urlparse(service_url)
def get_data_diff_url(
db_service: DatabaseService, table_fqn, override_url: Optional[str] = None
) -> str:
"""Get the url for the data diff service.
Args:
db_service (DatabaseService): The database service entity
table_fqn (str): The fully qualified name of the table
override_url (Optional[str], optional): Override the url. Defaults to None.
Returns:
str: The url for the data diff service
"""
source_url = (
str(get_connection(db_service.connection.config).url)
if not override_url
else override_url
)
url = urlparse(source_url)
# remove the driver name from the url because table-diff doesn't support it
kwargs = {"scheme": url.scheme.split("+")[0]}
service, database, schema, table = fqn.split( # pylint: disable=unused-variable
table_fqn
)
# path needs to include the database AND schema in some of the connectors
if hasattr(db_service.connection.config, "supportsDatabase"):
kwargs["path"] = f"/{database}"
if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake}:
kwargs["path"] = f"/{database}/{schema}"
kwargs["path"] += f"/{schema}"
return url._replace(**kwargs).geturl()
@staticmethod