feat(classification): allow parallelisation to reduce time (#8368)

This commit is contained in:
Mayuri Nehate 2023-08-02 09:53:39 +05:30 committed by GitHub
parent bf47d65412
commit e67f811034
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 232 additions and 57 deletions

View File

@ -10,6 +10,7 @@ Note that a `.` is used to denote nested fields in the YAML recipe.
| ------------------------- | -------- | --------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------- |
| enabled | | boolean | Whether classification should be used to auto-detect glossary terms | False |
| sample_size | | int | Number of sample values used for classification. | 100 |
| max_workers | | int | Number of worker threads to use for classification. Set to 1 to disable. | Number of cpu cores or 4 |
| info_type_to_term | | Dict[str,string] | Optional mapping to provide glossary term identifier for info type. | By default, info type is used as glossary term identifier. |
| classifiers | | Array of object | Classifiers to use to auto-detect glossary terms. If more than one classifier, infotype predictions from the classifier defined later in sequence take precedance. | [{'type': 'datahub', 'config': None}] |
| table_pattern | | AllowDenyPattern (see below for fields) | Regex patterns to filter tables for classification. This is used in combination with other patterns in parent config. Specify regex to match the entire table name in `database.schema.table` format. e.g. to match all tables starting with customer in Customer database and public schema, use the regex 'Customer.public.customer.*' | {'allow': ['.*'], 'deny': [], 'ignoreCase': True} |

View File

@ -1,6 +1,8 @@
import concurrent.futures
import logging
from dataclasses import dataclass, field
from typing import Dict, List
from math import ceil
from typing import Dict, Iterable, List, Optional
from datahub_classify.helper_classes import ColumnInfo, Metadata
from pydantic import Field
@ -108,15 +110,23 @@ class ClassificationHandler:
return None
logger.debug(f"Classifying Table {dataset_name}")
self.report.num_tables_classification_attempted += 1
field_terms: Dict[str, str] = {}
with PerfTimer() as timer:
try:
for classifier in self.classifiers:
column_info_with_proposals = classifier.classify(column_infos)
self.extract_field_wise_terms(
field_terms, column_info_with_proposals
)
column_infos_with_proposals: Iterable[ColumnInfo]
if self.config.classification.max_workers > 1:
column_infos_with_proposals = self.async_classify(
classifier, column_infos
)
else:
column_infos_with_proposals = classifier.classify(column_infos)
for column_info_proposal in column_infos_with_proposals:
self.update_field_terms(field_terms, column_info_proposal)
except Exception:
self.report.num_tables_classification_failed += 1
raise
@ -130,6 +140,44 @@ class ClassificationHandler:
self.report.num_tables_classified += 1
self.populate_terms_in_schema_metadata(schema_metadata, field_terms)
def update_field_terms(
self, field_terms: Dict[str, str], col_info: ColumnInfo
) -> None:
term = self.get_terms_for_column(col_info)
if term:
field_terms[col_info.metadata.name] = term
def async_classify(
self, classifier: Classifier, columns: List[ColumnInfo]
) -> Iterable[ColumnInfo]:
num_columns = len(columns)
BATCH_SIZE = 5 # Number of columns passed to classify api at a time
logger.debug(
f"Will Classify {num_columns} column(s) with {self.config.classification.max_workers} worker(s) with batch size {BATCH_SIZE}."
)
with concurrent.futures.ProcessPoolExecutor(
max_workers=self.config.classification.max_workers,
) as executor:
column_info_proposal_futures = [
executor.submit(
classifier.classify,
columns[
(i * BATCH_SIZE) : min(i * BATCH_SIZE + BATCH_SIZE, num_columns)
],
)
for i in range(ceil(num_columns / BATCH_SIZE))
]
return [
column_with_proposal
for proposal_future in concurrent.futures.as_completed(
column_info_proposal_futures
)
for column_with_proposal in proposal_future.result()
]
def populate_terms_in_schema_metadata(
self,
schema_metadata: SchemaMetadata,
@ -154,25 +202,20 @@ class ClassificationHandler:
),
)
def extract_field_wise_terms(
self,
field_terms: Dict[str, str],
column_info_with_proposals: List[ColumnInfo],
) -> None:
for col_info in column_info_with_proposals:
if not col_info.infotype_proposals:
continue
infotype_proposal = max(
col_info.infotype_proposals, key=lambda p: p.confidence_level
)
self.report.info_types_detected.setdefault(
infotype_proposal.infotype, LossyList()
).append(f"{col_info.metadata.dataset_name}.{col_info.metadata.name}")
field_terms[
col_info.metadata.name
] = self.config.classification.info_type_to_term.get(
infotype_proposal.infotype, infotype_proposal.infotype
)
def get_terms_for_column(self, col_info: ColumnInfo) -> Optional[str]:
if not col_info.infotype_proposals:
return None
infotype_proposal = max(
col_info.infotype_proposals, key=lambda p: p.confidence_level
)
self.report.info_types_detected.setdefault(
infotype_proposal.infotype, LossyList()
).append(f"{col_info.metadata.dataset_name}.{col_info.metadata.name}")
term = self.config.classification.info_type_to_term.get(
infotype_proposal.infotype, infotype_proposal.infotype
)
return term
def get_columns_to_classify(
self,

View File

@ -1,3 +1,4 @@
import os
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@ -36,6 +37,11 @@ class ClassificationConfig(ConfigModel):
default=100, description="Number of sample values used for classification."
)
max_workers: int = Field(
default=(os.cpu_count() or 4),
description="Number of worker threads to use for classification. Set to 1 to disable.",
)
table_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="Regex patterns to filter tables for classification. This is used in combination with other patterns in parent config. Specify regex to match the entire table name in `database.schema.table` format. e.g. to match all tables starting with customer in Customer database and public schema, use the regex 'Customer.public.customer.*'",

View File

@ -173,4 +173,5 @@ class DataHubClassifier(Classifier):
infotypes=self.config.info_types,
minimum_values_threshold=self.config.minimum_values_threshold,
)
return columns

View File

@ -872,8 +872,8 @@ class SnowflakeV2Source(
self.gen_schema_metadata(table, schema_name, db_name)
def fetch_sample_data_for_classification(
self, table, schema_name, db_name, dataset_name
):
self, table: SnowflakeTable, schema_name: str, db_name: str, dataset_name: str
) -> None:
if (
table.columns
and self.config.classification.enabled
@ -1225,7 +1225,12 @@ class SnowflakeV2Source(
)
return foreign_keys
def classify_snowflake_table(self, table, dataset_name, schema_metadata):
def classify_snowflake_table(
self,
table: Union[SnowflakeTable, SnowflakeView],
dataset_name: str,
schema_metadata: SchemaMetadata,
) -> None:
if (
isinstance(table, SnowflakeTable)
and self.config.classification.enabled
@ -1255,6 +1260,9 @@ class SnowflakeV2Source(
"Failed to classify table columns",
dataset_name,
)
finally:
# Cleaning up sample_data fetched for classification
table.sample_data = None
def get_report(self) -> SourceReport:
return self.report
@ -1470,7 +1478,7 @@ class SnowflakeV2Source(
df = pd.DataFrame(dat, columns=[col.name for col in cur.description])
time_taken = timer.elapsed_seconds()
logger.debug(
f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name}; took {time_taken:.3f} seconds"
f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};{df.shape[0]} rows; took {time_taken:.3f} seconds"
)
return df

View File

@ -14,7 +14,13 @@ NUM_OPS = 10
FROZEN_TIME = "2022-06-07 17:00:00"
def default_query_results(query): # noqa: C901
def default_query_results( # noqa: C901
query,
num_tables=NUM_TABLES,
num_views=NUM_VIEWS,
num_cols=NUM_COLS,
num_ops=NUM_OPS,
):
if query == SnowflakeQuery.current_account():
return [{"CURRENT_ACCOUNT()": "ABC12345"}]
if query == SnowflakeQuery.current_region():
@ -79,7 +85,7 @@ def default_query_results(query): # noqa: C901
"COMMENT": "Comment for Table",
"CLUSTERING_KEY": None,
}
for tbl_idx in range(1, NUM_TABLES + 1)
for tbl_idx in range(1, num_tables + 1)
]
elif query == SnowflakeQuery.show_views_for_schema("TEST_SCHEMA", "TEST_DB"):
return [
@ -90,7 +96,7 @@ def default_query_results(query): # noqa: C901
"comment": "Comment for View",
"text": None,
}
for view_idx in range(1, NUM_VIEWS + 1)
for view_idx in range(1, num_views + 1)
]
elif query == SnowflakeQuery.columns_for_schema("TEST_SCHEMA", "TEST_DB"):
raise Exception("Information schema query returned too much data")
@ -99,13 +105,13 @@ def default_query_results(query): # noqa: C901
SnowflakeQuery.columns_for_table(
"TABLE_{}".format(tbl_idx), "TEST_SCHEMA", "TEST_DB"
)
for tbl_idx in range(1, NUM_TABLES + 1)
for tbl_idx in range(1, num_tables + 1)
],
*[
SnowflakeQuery.columns_for_table(
"VIEW_{}".format(view_idx), "TEST_SCHEMA", "TEST_DB"
)
for view_idx in range(1, NUM_VIEWS + 1)
for view_idx in range(1, num_views + 1)
],
]:
return [
@ -122,7 +128,7 @@ def default_query_results(query): # noqa: C901
"NUMERIC_PRECISION": None if col_idx > 1 else 38,
"NUMERIC_SCALE": None if col_idx > 1 else 0,
}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
]
elif query in (
SnowflakeQuery.use_database("TEST_DB"),
@ -158,7 +164,7 @@ def default_query_results(query): # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,
@ -167,7 +173,7 @@ def default_query_results(query): # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,
@ -176,7 +182,7 @@ def default_query_results(query): # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,
@ -189,7 +195,7 @@ def default_query_results(query): # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,
@ -198,7 +204,7 @@ def default_query_results(query): # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,
@ -207,7 +213,7 @@ def default_query_results(query): # noqa: C901
{
"columns": [
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,
@ -231,7 +237,7 @@ def default_query_results(query): # noqa: C901
}
],
}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
],
"objectDomain": "Table",
"objectId": 0,
@ -246,7 +252,7 @@ def default_query_results(query): # noqa: C901
"EMAIL": "abc@xyz.com",
"ROLE_NAME": "ACCOUNTADMIN",
}
for op_idx in range(1, NUM_OPS + 1)
for op_idx in range(1, num_ops + 1)
]
elif (
query
@ -276,7 +282,7 @@ def default_query_results(query): # noqa: C901
"UPSTREAM_TABLE_COLUMNS": json.dumps(
[
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
]
),
"DOWNSTREAM_TABLE_COLUMNS": json.dumps(
@ -293,11 +299,11 @@ def default_query_results(query): # noqa: C901
}
],
}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
]
),
}
for op_idx in range(1, NUM_OPS + 1)
for op_idx in range(1, num_ops + 1)
] + [
{
"DOWNSTREAM_TABLE_NAME": "TEST_DB.TEST_SCHEMA.TABLE_1",
@ -371,7 +377,7 @@ def default_query_results(query): # noqa: C901
]
],
}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
]
+ ( # This additional upstream is only for TABLE_1
[
@ -393,7 +399,7 @@ def default_query_results(query): # noqa: C901
)
),
}
for op_idx in range(1, NUM_OPS + 1)
for op_idx in range(1, num_ops + 1)
]
elif query in (
snowflake_query.SnowflakeQuery.table_to_table_lineage_history_v2(
@ -426,7 +432,7 @@ def default_query_results(query): # noqa: C901
)
),
}
for op_idx in range(1, NUM_OPS + 1)
for op_idx in range(1, num_ops + 1)
]
elif query == snowflake_query.SnowflakeQuery.external_table_lineage_history(
1654499820000,
@ -479,7 +485,7 @@ def default_query_results(query): # noqa: C901
"VIEW_COLUMNS": json.dumps(
[
{"columnId": 0, "columnName": "COL_{}".format(col_idx)}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
]
),
"DOWNSTREAM_TABLE_DOMAIN": "TABLE",
@ -497,7 +503,7 @@ def default_query_results(query): # noqa: C901
}
],
}
for col_idx in range(1, NUM_COLS + 1)
for col_idx in range(1, num_cols + 1)
]
),
}

View File

@ -55,7 +55,6 @@ def random_cloud_region():
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
test_resources_dir = pytestconfig.rootpath / "tests/integration/snowflake"
@ -167,7 +166,13 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
pytestconfig,
output_path=output_file,
golden_path=golden_file,
ignore_paths=[],
ignore_paths=[
r"root\[\d+\]\['aspect'\]\['json'\]\['timestampMillis'\]",
r"root\[\d+\]\['aspect'\]\['json'\]\['created'\]",
r"root\[\d+\]\['aspect'\]\['json'\]\['lastModified'\]",
r"root\[\d+\]\['aspect'\]\['json'\]\['fields'\]\[\d+\]\['glossaryTerms'\]\['auditStamp'\]\['time'\]",
r"root\[\d+\]\['systemMetadata'\]",
],
)
report = cast(SnowflakeV2Report, pipeline.source.get_report())
assert report.lru_cache_info["get_tables_for_database"]["misses"] == 1

View File

@ -0,0 +1,103 @@
import os
from functools import partial
from typing import cast
from unittest import mock
import pandas as pd
import pytest
from datahub.configuration.common import AllowDenyPattern, DynamicTypedConfig
from datahub.ingestion.glossary.classifier import (
ClassificationConfig,
DynamicTypedClassifierConfig,
)
from datahub.ingestion.glossary.datahub_classifier import DataHubClassifierConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from tests.integration.snowflake.common import default_query_results
NUM_SAMPLE_VALUES = 100
TEST_CLASSIFY_PERFORMANCE = os.environ.get("DATAHUB_TEST_CLASSIFY_PERFORMANCE")
sample_values = ["abc@xyz.com" for _ in range(NUM_SAMPLE_VALUES)]
# Run with --durations=0 to show the timings for different combinations
@pytest.mark.skipif(
TEST_CLASSIFY_PERFORMANCE is None,
reason="DATAHUB_TEST_CLASSIFY_PERFORMANCE env variable is not configured",
)
@pytest.mark.parametrize(
"num_workers,num_cols_per_table,num_tables",
[(w, c, t) for w in [1, 2, 4, 6, 8] for c in [5, 10, 40, 80] for t in [1]],
)
def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tables):
with mock.patch("snowflake.connector.connect") as mock_connect, mock.patch(
"datahub.ingestion.source.snowflake.snowflake_v2.SnowflakeV2Source.get_sample_values_for_table"
) as mock_sample_values:
sf_connection = mock.MagicMock()
sf_cursor = mock.MagicMock()
mock_connect.return_value = sf_connection
sf_connection.cursor.return_value = sf_cursor
sf_cursor.execute.side_effect = partial(
default_query_results, num_tables=num_tables, num_cols=num_cols_per_table
)
mock_sample_values.return_value = pd.DataFrame(
data={f"col_{i}": sample_values for i in range(1, num_cols_per_table + 1)}
)
datahub_classifier_config = DataHubClassifierConfig(
confidence_level_threshold=0.58,
)
pipeline = Pipeline(
config=PipelineConfig(
source=SourceConfig(
type="snowflake",
config=SnowflakeV2Config(
account_id="ABC12345.ap-south-1.aws",
username="TST_USR",
password="TST_PWD",
match_fully_qualified_names=True,
schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]),
include_technical_schema=True,
include_table_lineage=False,
include_view_lineage=False,
include_column_lineage=False,
include_usage_stats=False,
include_operational_stats=False,
classification=ClassificationConfig(
enabled=True,
max_workers=num_workers,
classifiers=[
DynamicTypedClassifierConfig(
type="datahub", config=datahub_classifier_config
)
],
),
),
),
sink=DynamicTypedConfig(type="blackhole", config={}),
)
)
pipeline.run()
pipeline.pretty_print_summary()
pipeline.raise_from_status()
source_report = pipeline.source.get_report()
assert isinstance(source_report, SnowflakeV2Report)
assert (
cast(SnowflakeV2Report, source_report).num_tables_classified == num_tables
)
assert (
len(
cast(SnowflakeV2Report, source_report).info_types_detected[
"Email_Address"
]
)
== num_tables * num_cols_per_table
)

View File

@ -29,7 +29,6 @@ from tests.integration.snowflake.test_snowflake import random_cloud_region, rand
from tests.test_helpers import mce_helpers
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
test_resources_dir = pytestconfig.rootpath / "tests/integration/snowflake"
@ -107,9 +106,6 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
),
classification=ClassificationConfig(
enabled=True,
column_pattern=AllowDenyPattern(
allow=[".*col_1$", ".*col_2$", ".*col_3$"]
),
classifiers=[
DynamicTypedClassifierConfig(
type="datahub", config=datahub_classifier_config
@ -141,7 +137,13 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
pytestconfig,
output_path=output_file,
golden_path=golden_file,
ignore_paths=[],
ignore_paths=[
r"root\[\d+\]\['aspect'\]\['json'\]\['timestampMillis'\]",
r"root\[\d+\]\['aspect'\]\['json'\]\['created'\]",
r"root\[\d+\]\['aspect'\]\['json'\]\['lastModified'\]",
r"root\[\d+\]\['aspect'\]\['json'\]\['fields'\]\[\d+\]\['glossaryTerms'\]\['auditStamp'\]\['time'\]",
r"root\[\d+\]\['systemMetadata'\]",
],
)