Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

149 lines
4.7 KiB
Python
Raw Normal View History

from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import Column, Integer
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.selectable import CTE
from metadata.generated.schema.entity.data.table import Column as EntityColumn
from metadata.generated.schema.entity.data.table import (
ColumnName,
DataType,
PartitionIntervalTypes,
PartitionProfilerConfig,
ProfileSampleType,
SamplingMethodType,
Table,
)
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.profiler.interface.sqlalchemy.profiler_interface import (
SQAProfilerInterface,
)
from metadata.sampler.models import SampleConfig
from metadata.sampler.sqlalchemy.postgres.sampler import PostgresSampler
from metadata.sampler.sqlalchemy.sampler import SQASampler
# Declarative base that the ORM fixture below is registered against.
Base = declarative_base()


class User(Base):
    """Minimal ORM model used as the table fixture for the sampler tests.

    It is the object returned by the patched ``SQASampler.build_table_orm``
    in this module, so the samplers under test build their queries against
    the ``users`` table and its single ``id`` column.
    """

    __tablename__ = "users"

    # Only column of the fixture table; also its primary key.
    id = Column(Integer, primary_key=True)
@patch.object(SQASampler, "build_table_orm", return_value=User)
class SampleTest(TestCase):
    """Check the SQL that ``PostgresSampler.get_sample_query`` generates.

    No query is executed by these tests: each one compiles the returned
    query with literal binds and compares the resulting SQL text,
    case-insensitively, with the expected ``TABLESAMPLE`` statement.

    ``SQASampler.build_table_orm`` is patched to return the ``User`` model.
    The class-level ``patch.object`` decorator only wraps ``test*`` methods
    (and passes the mock as their extra ``sampler_mock`` argument), which is
    why ``setUpClass`` carries its own patch.
    """

    @classmethod
    @patch.object(SQASampler, "build_table_orm", return_value=User)
    def setUpClass(cls, sampler_mock):
        """Build the table entity, connection config and profiler interface
        shared by all the tests of the class."""
        cls.table_entity = Table(
            id=uuid4(),
            name="user",
            columns=[
                EntityColumn(
                    name=ColumnName("id"),
                    dataType=DataType.INT,
                ),
            ],
        )
        cls.psql_conn = PostgresConnection(
            username="test",
            hostPort="localhost:5432",
            database="test",
        )
        cls.sampler = SQASampler(
            service_connection_config=cls.psql_conn,
            ometa_client=None,
            entity=None,
        )
        # NOTE(review): arguments are positional; the trailing 5 and 43200
        # are presumably the thread count and a timeout in seconds --
        # confirm against the SQAProfilerInterface signature.
        cls.sqa_profiler_interface = SQAProfilerInterface(
            cls.psql_conn,
            None,
            cls.table_entity,
            None,
            cls.sampler,
            5,
            43200,
        )
        cls.session = cls.sqa_profiler_interface.session

    def _assert_sample_query(self, sampler, expected_query):
        """Assert that ``sampler`` generates ``expected_query``.

        The sample query is compiled with ``literal_binds`` so that bound
        values are rendered inline, and both sides are case-folded before
        being compared. ``assertEqual`` is used so that a mismatch reports
        both SQL strings.
        """
        query: CTE = sampler.get_sample_query()
        compiled_query = str(query.compile(compile_kwargs={"literal_binds": True}))
        self.assertEqual(expected_query.casefold(), compiled_query.casefold())

    def test_sampling_default(self, sampler_mock):
        """
        Test sampling without an explicit sampling method: the expected
        query falls back to ``TABLESAMPLE bernoulli``.
        """
        sampler = PostgresSampler(
            service_connection_config=self.psql_conn,
            ometa_client=None,
            entity=self.table_entity,
            sample_config=SampleConfig(
                profile_sample_type=ProfileSampleType.PERCENTAGE,
                profile_sample=50.0,
            ),
        )
        expected_query = (
            "SELECT users_1.id \nFROM users AS users_1 TABLESAMPLE bernoulli(50.0)"
        )
        self._assert_sample_query(sampler, expected_query)

    def test_sampling(self, sampler_mock):
        """
        Test sampling with each explicitly configured sampling method
        (SYSTEM and BERNOULLI).
        """
        for sampling_method_type in [
            SamplingMethodType.SYSTEM,
            SamplingMethodType.BERNOULLI,
        ]:
            # One sub-test per method, so that a failure names the method
            # that produced the wrong SQL.
            with self.subTest(sampling_method_type=sampling_method_type):
                sampler = PostgresSampler(
                    service_connection_config=self.psql_conn,
                    ometa_client=None,
                    entity=self.table_entity,
                    sample_config=SampleConfig(
                        profile_sample_type=ProfileSampleType.PERCENTAGE,
                        profile_sample=50.0,
                        sampling_method_type=sampling_method_type,
                    ),
                )
                expected_query = (
                    "SELECT users_1.id \nFROM users AS users_1 "
                    f"TABLESAMPLE {sampling_method_type.value}(50.0)"
                )
                self._assert_sample_query(sampler, expected_query)

    def test_sampling_with_partition(self, sampler_mock):
        """
        Test sampling with partition: the partition values are expected as
        an ``IN`` filter on top of the sampled table.
        """
        sampler = PostgresSampler(
            service_connection_config=self.psql_conn,
            ometa_client=None,
            entity=self.table_entity,
            sample_config=SampleConfig(
                profile_sample_type=ProfileSampleType.PERCENTAGE,
                profile_sample=50.0,
            ),
            partition_details=PartitionProfilerConfig(
                enablePartitioning=True,
                partitionColumnName="id",
                partitionIntervalType=PartitionIntervalTypes.COLUMN_VALUE,
                partitionValues=["1", "2"],
            ),
        )
        expected_query = (
            "SELECT users_1.id \nFROM users AS users_1 "
            "TABLESAMPLE bernoulli(50.0) \nWHERE id IN ('1', '2')"
        )
        self._assert_sample_query(sampler, expected_query)