Teddy ef131d7e20
MINOR: Wrong attribute name in SampleConfig model (#19641)
* fix: wrong attribute name in SampleConfig model

* fix: test attribute

* fix: failing tests

* fix: trino filter error + adjust test to take into account null value

* fix: mssql and azuresql tablesample on views
2025-02-04 10:40:40 +01:00

178 lines
6.0 KiB
Python

from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import Column, Integer
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.selectable import CTE
from metadata.generated.schema.entity.data.table import Column as EntityColumn
from metadata.generated.schema.entity.data.table import (
ColumnName,
DataType,
PartitionIntervalTypes,
PartitionProfilerConfig,
ProfileSampleType,
SamplingMethodType,
Table,
)
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeConnection,
)
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
)
from metadata.sampler.models import SampleConfig
from metadata.sampler.sqlalchemy.sampler import SQASampler
from metadata.sampler.sqlalchemy.snowflake.sampler import SnowflakeSampler
Base = declarative_base()
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
@patch.object(SQASampler, "build_table_orm", return_value=User)
class SampleTest(TestCase):
@classmethod
@patch.object(SQASampler, "build_table_orm", return_value=User)
def setUpClass(cls, sampler_mock):
cls.table_entity = Table(
id=uuid4(),
name="user",
columns=[
EntityColumn(
name=ColumnName("id"),
dataType=DataType.INT,
),
],
)
cls.snowflake_conn = SnowflakeConnection(
username="myuser", account="myaccount", warehouse="mywarehouse"
)
sampler = SQASampler(
service_connection_config=cls.snowflake_conn,
ometa_client=None,
entity=None,
)
cls.sqa_profiler_interface = SQAProfilerInterface(
cls.snowflake_conn,
None,
cls.table_entity,
None,
sampler,
5,
43200,
)
cls.session = cls.sqa_profiler_interface.session
def test_omit_sampling_method_type(self, sampler_mock):
"""
use BERNOULLI if sampling method type is not specified.
"""
sampler = SnowflakeSampler(
service_connection_config=self.snowflake_conn,
ometa_client=None,
entity=self.table_entity,
sample_config=SampleConfig(
profileSampleType=ProfileSampleType.PERCENTAGE, profileSample=50.0
),
)
query: CTE = sampler.get_sample_query()
expected_query = (
"WITH users_rnd AS \n(SELECT users_1.id AS id \n"
"FROM users AS users_1 TABLESAMPLE bernoulli(50.0))\n "
"SELECT users_rnd.id \nFROM users_rnd"
)
assert (
expected_query.casefold()
== str(query.compile(compile_kwargs={"literal_binds": True})).casefold()
)
def test_specify_sampling_method_type(self, sampler_mock):
"""
use specified sampling method type.
"""
for sampling_method_type in [
SamplingMethodType.SYSTEM,
SamplingMethodType.BERNOULLI,
]:
sampler = SnowflakeSampler(
service_connection_config=self.snowflake_conn,
ometa_client=None,
entity=self.table_entity,
sample_config=SampleConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=50.0,
samplingMethodType=sampling_method_type,
),
)
query: CTE = sampler.get_sample_query()
expected_query = (
"WITH users_rnd AS \n(SELECT users_1.id AS id \n"
f"FROM users AS users_1 TABLESAMPLE {sampling_method_type.value}(50.0))\n "
"SELECT users_rnd.id \nFROM users_rnd"
)
assert (
expected_query.casefold()
== str(query.compile(compile_kwargs={"literal_binds": True})).casefold()
)
def test_row_sampling(self, sampler_mock):
"""
use ROW sampling if profile sample type is ROW.
"""
sampler = SnowflakeSampler(
service_connection_config=self.snowflake_conn,
ometa_client=None,
entity=self.table_entity,
sample_config=SampleConfig(
profileSampleType=ProfileSampleType.ROWS, profileSample=50
),
)
query: CTE = sampler.get_sample_query()
expected_query = (
"WITH users_rnd AS \n(SELECT users_1.id AS id "
"\nFROM users AS users_1 TABLESAMPLE ROW(50 ROWS))\n "
"SELECT users_rnd.id \nFROM users_rnd"
)
assert (
expected_query.casefold()
== str(query.compile(compile_kwargs={"literal_binds": True})).casefold()
)
def test_sampling_with_partition(self, sampler_mock):
"""
use specified partition columns.
"""
sampler = SnowflakeSampler(
service_connection_config=self.snowflake_conn,
ometa_client=None,
entity=self.table_entity,
sample_config=SampleConfig(
profileSampleType=ProfileSampleType.PERCENTAGE,
profileSample=50.0,
),
partition_details=PartitionProfilerConfig(
enablePartitioning=True,
partitionColumnName="id",
partitionIntervalType=PartitionIntervalTypes.COLUMN_VALUE,
partitionValues=["1", "2"],
),
)
query: CTE = sampler.get_sample_query()
expected_query = (
"WITH users_rnd AS \n(SELECT users_1.id AS id \n"
"FROM users AS users_1 TABLESAMPLE bernoulli(50.0) "
"\nWHERE id IN ('1', '2'))\n SELECT users_rnd.id \nFROM users_rnd"
)
assert (
expected_query.casefold()
== str(query.compile(compile_kwargs={"literal_binds": True})).casefold()
)