feat(classification): allow parallelisation to reduce time (#8368)
This commit is contained in:
parent bf47d65412
commit e67f811034
@@ -10,6 +10,7 @@ Note that a `.` is used to denote nested fields in the YAML recipe.
| ----------------- | -------- | ---------------------------------------- | -------------------------------------------------------------------------- | ----------------------------------------------------------- |
| enabled           |          | boolean                                  | Whether classification should be used to auto-detect glossary terms.        | False                                                        |
| sample_size       |          | int                                      | Number of sample values used for classification.                            | 100                                                          |
+| max_workers      |          | int                                      | Number of worker threads to use for classification. Set to 1 to disable.    | Number of CPU cores or 4                                     |
| info_type_to_term |          | Dict[str, string]                        | Optional mapping to provide glossary term identifier for info type.         | By default, info type is used as glossary term identifier.   |
| classifiers       |          | Array of object                          | Classifiers to use to auto-detect glossary terms. If more than one classifier is provided, infotype predictions from the classifier defined later in the sequence take precedence. | [{'type': 'datahub', 'config': None}] |
| table_pattern     |          | AllowDenyPattern (see below for fields)  | Regex patterns to filter tables for classification. This is used in combination with other patterns in the parent config. Specify a regex to match the entire table name in `database.schema.table` format, e.g. to match all tables starting with customer in the Customer database and public schema, use the regex 'Customer.public.customer.*'. | {'allow': ['.*'], 'deny': [], 'ignoreCase': True} |
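For reference, the same options can also be exercised programmatically; below is a minimal sketch that builds the classification block with the config classes that appear later in this diff (the specific values are illustrative, not authoritative defaults):

```python
from datahub.ingestion.glossary.classifier import (
    ClassificationConfig,
    DynamicTypedClassifierConfig,
)

# Illustrative values: enable classification, sample 100 values per column,
# and fan the work out to 4 worker processes (set max_workers=1 to disable
# parallel classification).
classification = ClassificationConfig(
    enabled=True,
    sample_size=100,
    max_workers=4,
    classifiers=[DynamicTypedClassifierConfig(type="datahub", config=None)],
)
```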
@@ -1,6 +1,8 @@
+import concurrent.futures
import logging
from dataclasses import dataclass, field
-from typing import Dict, List
+from math import ceil
+from typing import Dict, Iterable, List, Optional

from datahub_classify.helper_classes import ColumnInfo, Metadata
from pydantic import Field
@@ -108,15 +110,23 @@ class ClassificationHandler:
            return None

        logger.debug(f"Classifying Table {dataset_name}")

        self.report.num_tables_classification_attempted += 1
        field_terms: Dict[str, str] = {}
        with PerfTimer() as timer:
            try:
                for classifier in self.classifiers:
-                    column_info_with_proposals = classifier.classify(column_infos)
-                    self.extract_field_wise_terms(
-                        field_terms, column_info_with_proposals
-                    )
+                    column_infos_with_proposals: Iterable[ColumnInfo]
+                    if self.config.classification.max_workers > 1:
+                        column_infos_with_proposals = self.async_classify(
+                            classifier, column_infos
+                        )
+                    else:
+                        column_infos_with_proposals = classifier.classify(column_infos)
+
+                    for column_info_proposal in column_infos_with_proposals:
+                        self.update_field_terms(field_terms, column_info_proposal)
+
            except Exception:
                self.report.num_tables_classification_failed += 1
                raise
@@ -130,6 +140,44 @@ class ClassificationHandler:
            self.report.num_tables_classified += 1
            self.populate_terms_in_schema_metadata(schema_metadata, field_terms)

+    def update_field_terms(
+        self, field_terms: Dict[str, str], col_info: ColumnInfo
+    ) -> None:
+        term = self.get_terms_for_column(col_info)
+        if term:
+            field_terms[col_info.metadata.name] = term
+
+    def async_classify(
+        self, classifier: Classifier, columns: List[ColumnInfo]
+    ) -> Iterable[ColumnInfo]:
+        num_columns = len(columns)
+        BATCH_SIZE = 5  # Number of columns passed to classify api at a time
+
+        logger.debug(
+            f"Will Classify {num_columns} column(s) with {self.config.classification.max_workers} worker(s) with batch size {BATCH_SIZE}."
+        )
+
+        with concurrent.futures.ProcessPoolExecutor(
+            max_workers=self.config.classification.max_workers,
+        ) as executor:
+            column_info_proposal_futures = [
+                executor.submit(
+                    classifier.classify,
+                    columns[
+                        (i * BATCH_SIZE) : min(i * BATCH_SIZE + BATCH_SIZE, num_columns)
+                    ],
+                )
+                for i in range(ceil(num_columns / BATCH_SIZE))
+            ]
+
+        return [
+            column_with_proposal
+            for proposal_future in concurrent.futures.as_completed(
+                column_info_proposal_futures
+            )
+            for column_with_proposal in proposal_future.result()
+        ]
+
    def populate_terms_in_schema_metadata(
        self,
        schema_metadata: SchemaMetadata,
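The core pattern in async_classify above is a batched fan-out over a process pool, collecting results as they complete. Below is a standalone sketch of the same idea; classify_batch is a made-up stand-in for classifier.classify, and only top-level functions are submitted because callables and arguments must be picklable for a ProcessPoolExecutor:

```python
import concurrent.futures
from math import ceil
from typing import List

BATCH_SIZE = 5  # columns handed to each worker task, mirroring async_classify


def classify_batch(columns: List[str]) -> List[str]:
    # Stand-in for classifier.classify(); the real work is CPU-bound scoring.
    return [c.upper() for c in columns]


def classify_parallel(columns: List[str], max_workers: int) -> List[str]:
    n = len(columns)
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                classify_batch, columns[i * BATCH_SIZE : min((i + 1) * BATCH_SIZE, n)]
            )
            for i in range(ceil(n / BATCH_SIZE))
        ]
    # as_completed yields futures in completion order, so the output order is
    # not guaranteed to match the input order.
    return [c for f in concurrent.futures.as_completed(futures) for c in f.result()]


if __name__ == "__main__":
    print(classify_parallel([f"col_{i}" for i in range(12)], max_workers=4))
```

A process pool rather than a thread pool sidesteps the GIL for CPU-bound classification, at the cost of pickling each batch of columns to and from the worker processes.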
@@ -154,25 +202,20 @@ class ClassificationHandler:
            ),
        )

-    def extract_field_wise_terms(
-        self,
-        field_terms: Dict[str, str],
-        column_info_with_proposals: List[ColumnInfo],
-    ) -> None:
-        for col_info in column_info_with_proposals:
-            if not col_info.infotype_proposals:
-                continue
-            infotype_proposal = max(
-                col_info.infotype_proposals, key=lambda p: p.confidence_level
-            )
-            self.report.info_types_detected.setdefault(
-                infotype_proposal.infotype, LossyList()
-            ).append(f"{col_info.metadata.dataset_name}.{col_info.metadata.name}")
-            field_terms[
-                col_info.metadata.name
-            ] = self.config.classification.info_type_to_term.get(
-                infotype_proposal.infotype, infotype_proposal.infotype
-            )
+    def get_terms_for_column(self, col_info: ColumnInfo) -> Optional[str]:
+        if not col_info.infotype_proposals:
+            return None
+        infotype_proposal = max(
+            col_info.infotype_proposals, key=lambda p: p.confidence_level
+        )
+        self.report.info_types_detected.setdefault(
+            infotype_proposal.infotype, LossyList()
+        ).append(f"{col_info.metadata.dataset_name}.{col_info.metadata.name}")
+        term = self.config.classification.info_type_to_term.get(
+            infotype_proposal.infotype, infotype_proposal.infotype
+        )
+
+        return term

    def get_columns_to_classify(
        self,
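get_terms_for_column keeps only the highest-confidence infotype proposal and maps it to a glossary term via info_type_to_term, falling back to the infotype name itself. A tiny illustration of that lookup follows; the mapping and the (infotype, confidence) tuples are made up for the example, whereas real proposals are datahub_classify objects with a confidence_level attribute:

```python
# Hypothetical recipe mapping from infotype to glossary term identifier.
info_type_to_term = {"Email_Address": "Classification.PII.Email"}

# Stand-in proposals as (infotype, confidence) pairs.
proposals = [("Full_Name", 0.60), ("Email_Address", 0.85)]

infotype, _ = max(proposals, key=lambda p: p[1])  # highest confidence wins
term = info_type_to_term.get(infotype, infotype)  # fall back to the infotype itself
print(term)  # -> Classification.PII.Email
```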
@@ -1,3 +1,4 @@
+import os
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@@ -36,6 +37,11 @@ class ClassificationConfig(ConfigModel):
        default=100, description="Number of sample values used for classification."
    )

+    max_workers: int = Field(
+        default=(os.cpu_count() or 4),
+        description="Number of worker threads to use for classification. Set to 1 to disable.",
+    )
+
    table_pattern: AllowDenyPattern = Field(
        default=AllowDenyPattern.allow_all(),
        description="Regex patterns to filter tables for classification. This is used in combination with other patterns in parent config. Specify regex to match the entire table name in `database.schema.table` format. e.g. to match all tables starting with customer in Customer database and public schema, use the regex 'Customer.public.customer.*'",
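A note on the default: because the pydantic Field default is evaluated when the module is imported, max_workers resolves to os.cpu_count(), i.e. the machine's logical core count, and falls back to 4 on platforms where Python cannot determine it (os.cpu_count() returns None in that case).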
@@ -173,4 +173,5 @@ class DataHubClassifier(Classifier):
            infotypes=self.config.info_types,
            minimum_values_threshold=self.config.minimum_values_threshold,
        )

        return columns
@@ -872,8 +872,8 @@ class SnowflakeV2Source(
            self.gen_schema_metadata(table, schema_name, db_name)

    def fetch_sample_data_for_classification(
-        self, table, schema_name, db_name, dataset_name
-    ):
+        self, table: SnowflakeTable, schema_name: str, db_name: str, dataset_name: str
+    ) -> None:
        if (
            table.columns
            and self.config.classification.enabled
@@ -1225,7 +1225,12 @@ class SnowflakeV2Source(
        )
        return foreign_keys

-    def classify_snowflake_table(self, table, dataset_name, schema_metadata):
+    def classify_snowflake_table(
+        self,
+        table: Union[SnowflakeTable, SnowflakeView],
+        dataset_name: str,
+        schema_metadata: SchemaMetadata,
+    ) -> None:
        if (
            isinstance(table, SnowflakeTable)
            and self.config.classification.enabled
@@ -1255,6 +1260,9 @@ class SnowflakeV2Source(
                    "Failed to classify table columns",
                    dataset_name,
                )
+            finally:
+                # Cleaning up sample_data fetched for classification
+                table.sample_data = None

    def get_report(self) -> SourceReport:
        return self.report
@@ -1470,7 +1478,7 @@ class SnowflakeV2Source(
        df = pd.DataFrame(dat, columns=[col.name for col in cur.description])
        time_taken = timer.elapsed_seconds()
        logger.debug(
-            f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name}; took {time_taken:.3f} seconds"
+            f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};{df.shape[0]} rows; took {time_taken:.3f} seconds"
        )

        return df
@@ -14,7 +14,13 @@ NUM_OPS = 10
FROZEN_TIME = "2022-06-07 17:00:00"


-def default_query_results(query):  # noqa: C901
+def default_query_results(  # noqa: C901
+    query,
+    num_tables=NUM_TABLES,
+    num_views=NUM_VIEWS,
+    num_cols=NUM_COLS,
+    num_ops=NUM_OPS,
+):
    if query == SnowflakeQuery.current_account():
        return [{"CURRENT_ACCOUNT()": "ABC12345"}]
    if query == SnowflakeQuery.current_region():
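Turning the fixture's counts into keyword parameters means tests can pre-bind them while the mocked cursor still invokes the callable with just the query, which is what the new performance test further down does via functools.partial. A minimal sketch of that pattern (the stub body and its default values are placeholders):

```python
from functools import partial
from unittest import mock


def default_query_results(query, num_tables=10, num_cols=10):
    ...  # placeholder: return canned rows sized by the keyword arguments


sf_cursor = mock.MagicMock()
# Pre-bind the sizes; the mock then calls the side_effect with only `query`.
sf_cursor.execute.side_effect = partial(
    default_query_results, num_tables=1, num_cols=80
)
```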
@@ -79,7 +85,7 @@ def default_query_results(query):  # noqa: C901
"COMMENT": "Comment for Table",
"CLUSTERING_KEY": None,
}
-for tbl_idx in range(1, NUM_TABLES + 1)
+for tbl_idx in range(1, num_tables + 1)
]
elif query == SnowflakeQuery.show_views_for_schema("TEST_SCHEMA", "TEST_DB"):
return [

@@ -90,7 +96,7 @@ def default_query_results(query):  # noqa: C901
"comment": "Comment for View",
"text": None,
}
-for view_idx in range(1, NUM_VIEWS + 1)
+for view_idx in range(1, num_views + 1)
]
elif query == SnowflakeQuery.columns_for_schema("TEST_SCHEMA", "TEST_DB"):
raise Exception("Information schema query returned too much data")

@@ -99,13 +105,13 @@ def default_query_results(query):  # noqa: C901
SnowflakeQuery.columns_for_table(
"TABLE_{}".format(tbl_idx), "TEST_SCHEMA", "TEST_DB"
)
-for tbl_idx in range(1, NUM_TABLES + 1)
+for tbl_idx in range(1, num_tables + 1)
],
*[
SnowflakeQuery.columns_for_table(
"VIEW_{}".format(view_idx), "TEST_SCHEMA", "TEST_DB"
)
-for view_idx in range(1, NUM_VIEWS + 1)
+for view_idx in range(1, num_views + 1)
],
]:
return [

@@ -122,7 +128,7 @@ def default_query_results(query):  # noqa: C901
"NUMERIC_PRECISION": None if col_idx > 1 else 38,
"NUMERIC_SCALE": None if col_idx > 1 else 0,
}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
]
elif query in (
SnowflakeQuery.use_database("TEST_DB"),

@@ -158,7 +164,7 @@ def default_query_results(query):  # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,

@@ -167,7 +173,7 @@ def default_query_results(query):  # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,

@@ -176,7 +182,7 @@ def default_query_results(query):  # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,

@@ -189,7 +195,7 @@ def default_query_results(query):  # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,

@@ -198,7 +204,7 @@ def default_query_results(query):  # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,

@@ -207,7 +213,7 @@ def default_query_results(query):  # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,

@@ -231,7 +237,7 @@ def default_query_results(query):  # noqa: C901
}
],
}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,

@@ -246,7 +252,7 @@ def default_query_results(query):  # noqa: C901
"EMAIL": "abc@xyz.com",
"ROLE_NAME": "ACCOUNTADMIN",
}
-for op_idx in range(1, NUM_OPS + 1)
+for op_idx in range(1, num_ops + 1)
]
elif (
query

@@ -276,7 +282,7 @@ def default_query_results(query):  # noqa: C901
"UPSTREAM_TABLE_COLUMNS": json.dumps(
[
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
]
),
"DOWNSTREAM_TABLE_COLUMNS": json.dumps(

@@ -293,11 +299,11 @@ def default_query_results(query):  # noqa: C901
}
],
}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
]
),
}
-for op_idx in range(1, NUM_OPS + 1)
+for op_idx in range(1, num_ops + 1)
] + [
{
"DOWNSTREAM_TABLE_NAME": "TEST_DB.TEST_SCHEMA.TABLE_1",

@@ -371,7 +377,7 @@ def default_query_results(query):  # noqa: C901
]
],
}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
]
    + (  # This additional upstream is only for TABLE_1
[

@@ -393,7 +399,7 @@ def default_query_results(query):  # noqa: C901
)
),
}
-for op_idx in range(1, NUM_OPS + 1)
+for op_idx in range(1, num_ops + 1)
]
elif query in (
snowflake_query.SnowflakeQuery.table_to_table_lineage_history_v2(

@@ -426,7 +432,7 @@ def default_query_results(query):  # noqa: C901
)
),
}
-for op_idx in range(1, NUM_OPS + 1)
+for op_idx in range(1, num_ops + 1)
]
elif query == snowflake_query.SnowflakeQuery.external_table_lineage_history(
1654499820000,

@@ -479,7 +485,7 @@ def default_query_results(query):  # noqa: C901
"VIEW_COLUMNS": json.dumps(
[
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
]
),
"DOWNSTREAM_TABLE_DOMAIN": "TABLE",

@@ -497,7 +503,7 @@ def default_query_results(query):  # noqa: C901
}
],
}
-for col_idx in range(1, NUM_COLS + 1)
+for col_idx in range(1, num_cols + 1)
]
),
}
@@ -55,7 +55,6 @@ def random_cloud_region():
    )

-
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
    test_resources_dir = pytestconfig.rootpath / "tests/integration/snowflake"

@@ -167,7 +166,13 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
        pytestconfig,
        output_path=output_file,
        golden_path=golden_file,
-        ignore_paths=[],
+        ignore_paths=[
+            r"root\[\d+\]\['aspect'\]\['json'\]\['timestampMillis'\]",
+            r"root\[\d+\]\['aspect'\]\['json'\]\['created'\]",
+            r"root\[\d+\]\['aspect'\]\['json'\]\['lastModified'\]",
+            r"root\[\d+\]\['aspect'\]\['json'\]\['fields'\]\[\d+\]\['glossaryTerms'\]\['auditStamp'\]\['time'\]",
+            r"root\[\d+\]\['systemMetadata'\]",
+        ],
    )
    report = cast(SnowflakeV2Report, pipeline.source.get_report())
    assert report.lru_cache_info["get_tables_for_database"]["misses"] == 1
@@ -0,0 +1,103 @@
import os
from functools import partial
from typing import cast
from unittest import mock

import pandas as pd
import pytest

from datahub.configuration.common import AllowDenyPattern, DynamicTypedConfig
from datahub.ingestion.glossary.classifier import (
    ClassificationConfig,
    DynamicTypedClassifierConfig,
)
from datahub.ingestion.glossary.datahub_classifier import DataHubClassifierConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from tests.integration.snowflake.common import default_query_results

NUM_SAMPLE_VALUES = 100
TEST_CLASSIFY_PERFORMANCE = os.environ.get("DATAHUB_TEST_CLASSIFY_PERFORMANCE")

sample_values = ["abc@xyz.com" for _ in range(NUM_SAMPLE_VALUES)]


# Run with --durations=0 to show the timings for different combinations
@pytest.mark.skipif(
    TEST_CLASSIFY_PERFORMANCE is None,
    reason="DATAHUB_TEST_CLASSIFY_PERFORMANCE env variable is not configured",
)
@pytest.mark.parametrize(
    "num_workers,num_cols_per_table,num_tables",
    [(w, c, t) for w in [1, 2, 4, 6, 8] for c in [5, 10, 40, 80] for t in [1]],
)
def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tables):
    with mock.patch("snowflake.connector.connect") as mock_connect, mock.patch(
        "datahub.ingestion.source.snowflake.snowflake_v2.SnowflakeV2Source.get_sample_values_for_table"
    ) as mock_sample_values:
        sf_connection = mock.MagicMock()
        sf_cursor = mock.MagicMock()
        mock_connect.return_value = sf_connection
        sf_connection.cursor.return_value = sf_cursor

        sf_cursor.execute.side_effect = partial(
            default_query_results, num_tables=num_tables, num_cols=num_cols_per_table
        )

        mock_sample_values.return_value = pd.DataFrame(
            data={f"col_{i}": sample_values for i in range(1, num_cols_per_table + 1)}
        )

        datahub_classifier_config = DataHubClassifierConfig(
            confidence_level_threshold=0.58,
        )

        pipeline = Pipeline(
            config=PipelineConfig(
                source=SourceConfig(
                    type="snowflake",
                    config=SnowflakeV2Config(
                        account_id="ABC12345.ap-south-1.aws",
                        username="TST_USR",
                        password="TST_PWD",
                        match_fully_qualified_names=True,
                        schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]),
                        include_technical_schema=True,
                        include_table_lineage=False,
                        include_view_lineage=False,
                        include_column_lineage=False,
                        include_usage_stats=False,
                        include_operational_stats=False,
                        classification=ClassificationConfig(
                            enabled=True,
                            max_workers=num_workers,
                            classifiers=[
                                DynamicTypedClassifierConfig(
                                    type="datahub", config=datahub_classifier_config
                                )
                            ],
                        ),
                    ),
                ),
                sink=DynamicTypedConfig(type="blackhole", config={}),
            )
        )
        pipeline.run()
        pipeline.pretty_print_summary()
        pipeline.raise_from_status()

        source_report = pipeline.source.get_report()
        assert isinstance(source_report, SnowflakeV2Report)
        assert (
            cast(SnowflakeV2Report, source_report).num_tables_classified == num_tables
        )
        assert (
            len(
                cast(SnowflakeV2Report, source_report).info_types_detected[
                    "Email_Address"
                ]
            )
            == num_tables * num_cols_per_table
        )
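As the skip marker and inline comment indicate, this benchmark is opt-in: it only runs when the DATAHUB_TEST_CLASSIFY_PERFORMANCE environment variable is set, and running pytest with --durations=0 prints the timing of each (num_workers, num_cols_per_table, num_tables) combination in the parametrized grid, which is how the speed-up from raising max_workers can be compared.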
@@ -29,7 +29,6 @@ from tests.integration.snowflake.test_snowflake import random_cloud_region, rand
from tests.test_helpers import mce_helpers

-
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
    test_resources_dir = pytestconfig.rootpath / "tests/integration/snowflake"

@@ -107,9 +106,6 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
                        ),
                        classification=ClassificationConfig(
                            enabled=True,
-                            column_pattern=AllowDenyPattern(
-                                allow=[".*col_1$", ".*col_2$", ".*col_3$"]
-                            ),
                            classifiers=[
                                DynamicTypedClassifierConfig(
                                    type="datahub", config=datahub_classifier_config

@@ -141,7 +137,13 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
        pytestconfig,
        output_path=output_file,
        golden_path=golden_file,
-        ignore_paths=[],
+        ignore_paths=[
+            r"root\[\d+\]\['aspect'\]\['json'\]\['timestampMillis'\]",
+            r"root\[\d+\]\['aspect'\]\['json'\]\['created'\]",
+            r"root\[\d+\]\['aspect'\]\['json'\]\['lastModified'\]",
+            r"root\[\d+\]\['aspect'\]\['json'\]\['fields'\]\[\d+\]\['glossaryTerms'\]\['auditStamp'\]\['time'\]",
+            r"root\[\d+\]\['systemMetadata'\]",
+        ],
    )