# datahub/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py

import os
from functools import partial
from typing import cast
from unittest import mock

import pandas as pd
import pytest

from datahub.configuration.common import AllowDenyPattern, DynamicTypedConfig
from datahub.ingestion.glossary.classifier import (
    ClassificationConfig,
    DynamicTypedClassifierConfig,
)
from datahub.ingestion.glossary.datahub_classifier import DataHubClassifierConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report

from tests.integration.snowflake.common import default_query_results

NUM_SAMPLE_VALUES = 100

TEST_CLASSIFY_PERFORMANCE = os.environ.get("DATAHUB_TEST_CLASSIFY_PERFORMANCE")
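# Every sample value is an email-like string, so the classifier is expected to tag
# each sampled column as Email_Address.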
sample_values = ["abc@xyz.com" for _ in range(NUM_SAMPLE_VALUES)]
# Run with --durations=0 to show the timings for different combinations
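# The test runs only when the DATAHUB_TEST_CLASSIFY_PERFORMANCE env variable is set;
# the parametrization sweeps the number of classification workers and the number of
# columns per table over a single table.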
@pytest.mark.skipif(
    TEST_CLASSIFY_PERFORMANCE is None,
    reason="DATAHUB_TEST_CLASSIFY_PERFORMANCE env variable is not configured",
)
@pytest.mark.parametrize(
    "num_workers,num_cols_per_table,num_tables",
    [(w, c, t) for w in [1, 2, 4, 6, 8] for c in [5, 10, 40, 80] for t in [1]],
)
def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tables):
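    # Mock out the Snowflake connection and sample-value fetching so the run exercises
    # only metadata extraction and classification, not a live warehouse.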
    with mock.patch("snowflake.connector.connect") as mock_connect, mock.patch(
        "datahub.ingestion.source.snowflake.snowflake_v2.SnowflakeV2Source.get_sample_values_for_table"
    ) as mock_sample_values:
        sf_connection = mock.MagicMock()
        sf_cursor = mock.MagicMock()
        mock_connect.return_value = sf_connection
        sf_connection.cursor.return_value = sf_cursor
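        # Fake the query responses to describe num_tables tables with num_cols_per_table
        # columns each, and return the email samples as the per-table sample values.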
        sf_cursor.execute.side_effect = partial(
            default_query_results, num_tables=num_tables, num_cols=num_cols_per_table
        )
        mock_sample_values.return_value = pd.DataFrame(
            data={f"col_{i}": sample_values for i in range(1, num_cols_per_table + 1)}
        )
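        # Classify with the DataHub classifier at an explicit 0.58 confidence threshold.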
        datahub_classifier_config = DataHubClassifierConfig(
            confidence_level_threshold=0.58,
        )
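        # Wire the mocked Snowflake source into a pipeline that discards output via the
        # blackhole sink; lineage and usage extraction are disabled so only schema
        # ingestion plus classification run.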
        pipeline = Pipeline(
            config=PipelineConfig(
                source=SourceConfig(
                    type="snowflake",
                    config=SnowflakeV2Config(
                        account_id="ABC12345.ap-south-1.aws",
                        username="TST_USR",
                        password="TST_PWD",
                        match_fully_qualified_names=True,
                        schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]),
                        include_technical_schema=True,
                        include_table_lineage=False,
                        include_column_lineage=False,
                        include_usage_stats=False,
                        include_operational_stats=False,
                        classification=ClassificationConfig(
                            enabled=True,
                            max_workers=num_workers,
                            classifiers=[
                                DynamicTypedClassifierConfig(
                                    type="datahub", config=datahub_classifier_config
                                )
                            ],
                        ),
                    ),
                ),
                sink=DynamicTypedConfig(type="blackhole", config={}),
            )
        )
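        # Execute the ingestion; raise_from_status fails the test if the pipeline
        # reported any errors.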
        pipeline.run()
        pipeline.pretty_print_summary()
        pipeline.raise_from_status()
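        # Every table should be reported as classified, and each of its columns should
        # be detected as Email_Address.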
        source_report = pipeline.source.get_report()
        assert isinstance(source_report, SnowflakeV2Report)
        assert (
            cast(SnowflakeV2Report, source_report).num_tables_classification_found
            == num_tables
        )
        assert (
            len(
                cast(SnowflakeV2Report, source_report).info_types_detected[
                    "Email_Address"
                ]
            )
            == num_tables * num_cols_per_table
        )