321 lines
10 KiB
Python
Raw Normal View History

# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test Sample behavior
"""
import os
from unittest import TestCase
from sqlalchemy import TEXT, Column, Integer, String, func
from sqlalchemy.orm import declarative_base
from metadata.generated.schema.entity.services.connections.database.sqliteConnection import (
SQLiteConnection,
SQLiteScheme,
)
from metadata.orm_profiler.interfaces.sqa_profiler_interface import SQAProfilerInterface
from metadata.orm_profiler.metrics.registry import Metrics
from metadata.orm_profiler.orm.registry import CustomTypes
from metadata.orm_profiler.profiler.core import Profiler
from metadata.orm_profiler.profiler.sampler import Sampler
Base = declarative_base()
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(256))
fullname = Column(String(256))
nickname = Column(String(256))
comments = Column(TEXT)
age = Column(Integer)
class SampleTest(TestCase):
"""
Run checks on different metrics
"""
db_path = os.path.join(
os.path.dirname(__file__), f"{os.path.splitext(__file__)[0]}.db"
)
sqlite_conn = SQLiteConnection(
scheme=SQLiteScheme.sqlite_pysqlite,
databaseMode=db_path + "?check_same_thread=False",
)
sqa_profiler_interface = SQAProfilerInterface(sqlite_conn)
engine = sqa_profiler_interface.session.get_bind()
session = sqa_profiler_interface.session
@classmethod
def setUpClass(cls) -> None:
"""
Prepare Ingredients
"""
User.__table__.create(bind=cls.engine)
# Insert 30 rows
for i in range(10):
data = [
User(
name="John",
fullname="John Doe",
nickname="johnny b goode",
comments="no comments",
age=30,
),
User(
name="Jane",
fullname="Jone Doe",
nickname=None,
comments="maybe some comments",
age=31,
),
User(
name="John",
fullname="John Doe",
nickname=None,
comments=None,
age=None,
),
]
cls.session.add_all(data)
cls.session.commit()
def test_random_sampler(self):
"""
The random sampler should be able to
generate a random subset of data
"""
sampler = Sampler(session=self.session, table=User, profile_sample=50.0)
random_sample = sampler.random_sample()
res = self.session.query(func.count()).select_from(random_sample).first()
assert res[0] < 30
def test_sample_property(self):
"""
Sample property should be properly generated
"""
self.sqa_profiler_interface.create_sampler(User, profile_sample=50.0)
# Randomly pick table_count to init the Profiler, we don't care for this test
table_count = Metrics.ROW_COUNT.value
profiler = Profiler(
table_count,
profiler_interface=self.sqa_profiler_interface,
table=User,
)
res = self.session.query(func.count()).select_from(profiler.sample).first()
assert res[0] < 30
def test_table_row_count(self):
"""
Profile sample should be ignored in row count
"""
table_count = Metrics.ROW_COUNT.value
profiler = Profiler(
table_count,
profiler_interface=self.sqa_profiler_interface,
table=User,
profile_sample=50.0,
)
res = profiler.execute()._table_results
assert res.get(Metrics.ROW_COUNT.name) == 30
def test_random_sample_count(self):
"""
Check we can properly sample data.
There's a random component, so we cannot ensure to always
get 15 rows, but for sure we should get less than 30.
"""
count = Metrics.COUNT.value
profiler = Profiler(
count,
profiler_interface=self.sqa_profiler_interface,
table=User,
use_cols=[User.name],
profile_sample=50,
)
res = profiler.execute()._column_results
assert res.get(User.name.name)[Metrics.COUNT.name] < 30
profiler = Profiler(
count,
profiler_interface=self.sqa_profiler_interface,
table=User,
profile_sample=100.0,
use_cols=[User.name],
)
res = profiler.execute()._column_results
assert res.get(User.name.name)[Metrics.COUNT.name] == 30
def test_random_sample_histogram(self):
"""
Histogram should run correctly
"""
hist = Metrics.HISTOGRAM.value
profiler = Profiler(
hist,
profiler_interface=self.sqa_profiler_interface,
table=User,
use_cols=[User.id],
profile_sample=50.0,
)
res = profiler.execute()._column_results
# The sum of all frequencies should be sampled
assert sum(res.get(User.id.name)[Metrics.HISTOGRAM.name]["frequencies"]) < 30
profiler = Profiler(
hist,
profiler_interface=self.sqa_profiler_interface,
table=User,
use_cols=[User.id],
profile_sample=100.0,
)
res = profiler.execute()._column_results
# The sum of all frequencies should be sampled
assert sum(res.get(User.id.name)[Metrics.HISTOGRAM.name]["frequencies"]) == 30.0
def test_random_sample_unique_count(self):
"""
Unique count should run correctly
"""
self.sqa_profiler_interface.create_sampler(
User,
profile_sample=50.0,
)
self.sqa_profiler_interface.create_runner(User)
hist = Metrics.UNIQUE_COUNT.value
profiler = Profiler(
hist,
profiler_interface=self.sqa_profiler_interface,
table=User,
use_cols=[User.name],
)
res = profiler.execute()._column_results
# As we repeat data, we expect 0 unique counts.
# This tests might very rarely, fail, depending on the sampled random data.
assert res.get(User.name.name)[Metrics.UNIQUE_COUNT.name] <= 1
self.sqa_profiler_interface.create_sampler(
User,
profile_sample=100.0,
)
self.sqa_profiler_interface.create_runner(User)
profiler = Profiler(
hist,
profiler_interface=self.sqa_profiler_interface,
table=User,
use_cols=[User.name],
)
res = profiler.execute()._column_results
# As we repeat data, we expect 0 unique counts.
# This tests might very rarely, fail, depending on the sampled random data.
assert res.get(User.name.name)[Metrics.UNIQUE_COUNT.name] == 0
def test_sample_data(self):
"""
We should be able to pick up sample data from the sampler
"""
sampler = Sampler(session=self.session, table=User)
sample_data = sampler.fetch_sample_data()
assert len(sample_data.columns) == 6
assert len(sample_data.rows) == 30
# Order matters, this is how we'll present the data
names = [str(col.__root__) for col in sample_data.columns]
assert names == ["id", "name", "fullname", "nickname", "comments", "age"]
def test_sample_data_binary(self):
"""
We should be able to pick up sample data from the sampler
"""
class UserBinary(Base):
__tablename__ = "users_binary"
id = Column(Integer, primary_key=True)
name = Column(String(256))
fullname = Column(String(256))
nickname = Column(String(256))
comments = Column(TEXT)
age = Column(Integer)
password_hash = Column(CustomTypes.BYTES.value)
UserBinary.__table__.create(bind=self.engine)
for i in range(10):
data = [
UserBinary(
name="John",
fullname="John Doe",
nickname="johnny b goode",
comments="no comments",
age=30,
password_hash=b"foo",
),
]
self.session.add_all(data)
self.session.commit()
sampler = Sampler(session=self.session, table=UserBinary)
sample_data = sampler.fetch_sample_data()
assert len(sample_data.columns) == 7
assert len(sample_data.rows) == 10
names = [str(col.__root__) for col in sample_data.columns]
assert names == [
"id",
"name",
"fullname",
"nickname",
"comments",
"age",
"password_hash",
]
assert type(sample_data.rows[0][6]) == str
UserBinary.__table__.drop(bind=self.engine)
def test_sample_from_user_query(self):
"""
Test sample data are returned based on user query
"""
stmt = "SELECT id, name FROM users"
sampler = Sampler(session=self.session, table=User, profile_sample_query=stmt)
sample_data = sampler.fetch_sample_data()
assert len(sample_data.columns) == 2
names = [col.__root__ for col in sample_data.columns]
assert names == ["id", "name"]
@classmethod
def tearDownClass(cls) -> None:
os.remove(cls.db_path)
return super().tearDownClass()