import os
from functools import partial
from typing import cast
from unittest import mock

import pandas as pd
import pytest

from datahub.configuration.common import AllowDenyPattern, DynamicTypedConfig
from datahub.ingestion.glossary.classifier import (
    ClassificationConfig,
    DynamicTypedClassifierConfig,
)
from datahub.ingestion.glossary.datahub_classifier import DataHubClassifierConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from tests.integration.snowflake.common import default_query_results

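# Number of synthetic sample values supplied to the classifier for each column.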
NUM_SAMPLE_VALUES = 100
TEST_CLASSIFY_PERFORMANCE = os.environ.get("DATAHUB_TEST_CLASSIFY_PERFORMANCE")

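# Every sampled value is an email address, so the classifier is expected to
# report the Email_Address info type for every column it inspects.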
sample_values = ["abc@xyz.com" for _ in range(NUM_SAMPLE_VALUES)]


# Run with --durations=0 to show the timings for different combinations
@pytest.mark.skipif(
    TEST_CLASSIFY_PERFORMANCE is None,
    reason="DATAHUB_TEST_CLASSIFY_PERFORMANCE env variable is not configured",
)
@pytest.mark.parametrize(
    "num_workers,num_cols_per_table,num_tables",
    [(w, c, t) for w in [1, 2, 4, 6, 8] for c in [5, 10, 40, 80] for t in [1]],
)
def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tables):
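    """Measure classification performance across worker counts and table shapes.

    Snowflake connectivity and sample value retrieval are mocked out, so the
    reported timings mainly reflect the classification step itself.
    """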
    with mock.patch("snowflake.connector.connect") as mock_connect, mock.patch(
        "datahub.ingestion.source.snowflake.snowflake_v2.SnowflakeV2Source.get_sample_values_for_table"
    ) as mock_sample_values:
        sf_connection = mock.MagicMock()
        sf_cursor = mock.MagicMock()
        mock_connect.return_value = sf_connection
        sf_connection.cursor.return_value = sf_cursor

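        # default_query_results returns canned metadata responses sized to the
        # requested number of tables and columns, so no live Snowflake account
        # is needed.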
        sf_cursor.execute.side_effect = partial(
            default_query_results, num_tables=num_tables, num_cols=num_cols_per_table
        )

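        # Return a fixed DataFrame of sample values instead of querying the
        # table, keeping data sampling out of the measured time.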
        mock_sample_values.return_value = pd.DataFrame(
            data={f"col_{i}": sample_values for i in range(1, num_cols_per_table + 1)}
        )

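        # Classifier settings are held constant; only the worker count and the
        # table/column counts vary across the parametrized runs.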
        datahub_classifier_config = DataHubClassifierConfig(
            confidence_level_threshold=0.58,
        )

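        # Minimal snowflake -> blackhole pipeline: technical schema only, with
        # lineage, usage, and operational stats disabled, and classification
        # enabled using the parametrized number of workers.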
        pipeline = Pipeline(
            config=PipelineConfig(
                source=SourceConfig(
                    type="snowflake",
                    config=SnowflakeV2Config(
                        account_id="ABC12345.ap-south-1.aws",
                        username="TST_USR",
                        password="TST_PWD",
                        match_fully_qualified_names=True,
                        schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]),
                        include_technical_schema=True,
                        include_table_lineage=False,
                        include_column_lineage=False,
                        include_usage_stats=False,
                        include_operational_stats=False,
                        classification=ClassificationConfig(
                            enabled=True,
                            max_workers=num_workers,
                            classifiers=[
                                DynamicTypedClassifierConfig(
                                    type="datahub", config=datahub_classifier_config
                                )
                            ],
                        ),
                    ),
                ),
                sink=DynamicTypedConfig(type="blackhole", config={}),
            )
        )
        pipeline.run()
        pipeline.pretty_print_summary()
        pipeline.raise_from_status()

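        # Every synthetic table should be classified, and Email_Address should
        # be detected on every column of every table.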
        source_report = pipeline.source.get_report()
        assert isinstance(source_report, SnowflakeV2Report)
        assert (
            cast(SnowflakeV2Report, source_report).num_tables_classification_found
            == num_tables
        )
        assert (
            len(
                cast(SnowflakeV2Report, source_report).info_types_detected[
                    "Email_Address"
                ]
            )
            == num_tables * num_cols_per_table
        )