214 lines
7.2 KiB
Python
Raw Normal View History

# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests utils function for the profiler
"""
from datetime import datetime
from unittest import TestCase
import pytest
from sqlalchemy import Column
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.sqltypes import Integer, String
from metadata.profiler.metrics.hybrid.histogram import Histogram
from metadata.profiler.metrics.system.queries.snowflake import (
get_snowflake_system_queries,
)
from metadata.profiler.metrics.system.system import recursive_dic
from metadata.utils.profiler_utils import (
get_identifiers_from_string,
get_value_from_cache,
set_cache,
)
from metadata.utils.sqa_utils import is_array
from .conftest import Row
Base = declarative_base()
class Users(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(30))
fullname = Column(String)
class TestHistogramUtils(TestCase):
@classmethod
def setUpClass(cls):
cls.histogram = Histogram()
def test_histogram_label_formatter_positive(self):
"""test label formatter for histogram"""
formatted_label = self.histogram._format_bin_labels(18927, 23456)
assert formatted_label == "18.93K to 23.46K"
formatted_label = self.histogram._format_bin_labels(18927)
assert formatted_label == "18.93K and up"
def test_histogram_label_formatter_negative(self):
"""test label formatter for histogram for negative numbers"""
formatted_label = self.histogram._format_bin_labels(-18927, -23456)
assert formatted_label == "-18.93K to -23.46K"
formatted_label = self.histogram._format_bin_labels(-18927)
assert formatted_label == "-18.93K and up"
def test_histogram_label_formatter_none(self):
"""test label formatter for histogram for None"""
formatted_label = self.histogram._format_bin_labels(None)
assert formatted_label == "null and up"
def test_histogram_label_formatter_zero(self):
"""test label formatter for histogram with zero"""
formatted_label = self.histogram._format_bin_labels(0)
assert formatted_label == "0 and up"
def test_histogram_label_formatter_nines(self):
"""test label formatter for histogram for nines"""
formatted_label = self.histogram._format_bin_labels(99999999)
assert formatted_label == "100.00M and up"
def test_histogram_label_formatter_floats(self):
"""test label formatter for histogram for floats"""
formatted_label = self.histogram._format_bin_labels(167893.98542, 194993.98542)
assert formatted_label == "167.89K to 194.99K"
def test_is_array():
"""test is array function"""
kwargs = {}
assert is_array(kwargs) is False
assert not kwargs
kwargs = {"is_array": True, "array_col": "name"}
assert kwargs["is_array"] is True
assert is_array(kwargs) is True
assert kwargs["array_col"] == "name"
assert len(kwargs) == 1
kwargs = {"is_array": False, "array_col": "name"}
assert kwargs["is_array"] is False
assert is_array(kwargs) is False
assert not kwargs
def test_get_snowflake_system_queries():
"""Test get snowflake system queries"""
row = Row(
query_id="1",
query_type="INSERT",
start_time=datetime.now(),
query_text="INSERT INTO DATABASE.SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
)
query_result = get_snowflake_system_queries(row, "DATABASE", "SCHEMA") # type: ignore
assert query_result
assert query_result.query_id == "1"
assert query_result.query_type == "INSERT"
assert query_result.database_name == "database"
assert query_result.schema_name == "schema"
assert query_result.table_name == "table1"
row = Row(
query_id=1,
query_type="INSERT",
start_time=datetime.now(),
query_text="INSERT INTO SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
)
query_result = get_snowflake_system_queries(row, "DATABASE", "SCHEMA") # type: ignore
assert not query_result
@pytest.mark.parametrize(
"query, expected",
[
(
"INSERT INTO DATABASE.SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
"INSERT",
),
(
"INSERT OVERWRITE INTO DATABASE.SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
"INSERT",
),
(
"MERGE INTO DATABASE.SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
"MERGE",
),
("DELETE FROM DATABASE.SCHEMA.TABLE1 WHERE val = 9999", "MERGE"),
("UPDATE DATABASE.SCHEMA.TABLE1 SET col1 = 1 WHERE val = 9999", "UPDATE"),
],
)
def test_get_snowflake_system_queries_all_dll(query, expected):
"""test we ca get all ddl queries
reference https://docs.snowflake.com/en/sql-reference/sql-dml
"""
row = Row(
query_id=1,
query_type=expected,
start_time=datetime.now(),
query_text=query,
)
query_result = get_snowflake_system_queries(row, "DATABASE", "SCHEMA") # type: ignore
assert query_result
assert query_result.query_type == expected
assert query_result.database_name == "database"
assert query_result.schema_name == "schema"
assert query_result.table_name == "table1"
@pytest.mark.parametrize(
"identifier, expected",
[
("DATABASE.SCHEMA.TABLE1", ("DATABASE", "SCHEMA", "TABLE1")),
('DATABASE.SCHEMA."TABLE.DOT"', ("DATABASE", "SCHEMA", "TABLE.DOT")),
('DATABASE."SCHEMA.DOT".TABLE', ("DATABASE", "SCHEMA.DOT", "TABLE")),
('"DATABASE.DOT".SCHEMA.TABLE', ("DATABASE.DOT", "SCHEMA", "TABLE")),
('DATABASE."SCHEMA.DOT"."TABLE.DOT"', ("DATABASE", "SCHEMA.DOT", "TABLE.DOT")),
('"DATABASE.DOT"."SCHEMA.DOT".TABLE', ("DATABASE.DOT", "SCHEMA.DOT", "TABLE")),
(
'"DATABASE.DOT"."SCHEMA.DOT"."TABLE.DOT"',
("DATABASE.DOT", "SCHEMA.DOT", "TABLE.DOT"),
),
],
)
def test_get_identifiers_from_string(identifier, expected):
"""test get identifiers from string"""
assert get_identifiers_from_string(identifier) == expected
def test_cache_func():
"""test get and set cache"""
cache_dict = recursive_dic()
cache_value = [1, 2, 3, 4, 5]
new_cache_value = [6, 7, 8, 9, 10]
cache = get_value_from_cache(cache_dict, "key1.key2.key3")
assert not cache
set_cache(cache_dict, "key1.key2.key3", cache_value)
cache = get_value_from_cache(cache_dict, "key1.key2.key3")
assert cache == cache_value
# calling set_cache on the same key will reset the cache
set_cache(cache_dict, "key1.key2.key3", new_cache_value)
cache = get_value_from_cache(cache_dict, "key1.key2.key3")
assert cache == new_cache_value