fix(ingest): add support for database and table patterns to glue source (#2339)

This commit is contained in:
Harshal Sheth 2021-04-05 17:14:02 -07:00 committed by GitHub
parent 6e762ce3bc
commit c1f3eaed35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 9 deletions

View File

@ -367,7 +367,8 @@ source:
config: config:
aws_region: aws_region_name # i.e. "eu-west-1" aws_region: aws_region_name # i.e. "eu-west-1"
env: environment used for the DatasetSnapshot URN, one of "DEV", "EI", "PROD" or "CORP". # Optional, defaults to "PROD". env: environment used for the DatasetSnapshot URN, one of "DEV", "EI", "PROD" or "CORP". # Optional, defaults to "PROD".
databases: list of databases to process. # Optional, if not specified then all databases will be processed. database_pattern: # Optional, to filter databases scanned, same as schema_pattern above.
table_pattern: # Optional, to filter tables scanned, same as table_pattern above.
aws_access_key_id # Optional. If not specified, credentials are picked up according to boto3 rules. aws_access_key_id # Optional. If not specified, credentials are picked up according to boto3 rules.
# See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
aws_secret_access_key # Optional. aws_secret_access_key # Optional.

View File

@ -79,10 +79,10 @@ plugins: Dict[str, Set[str]] = {
"ldap": {"python-ldap>=2.4"}, "ldap": {"python-ldap>=2.4"},
"druid": sql_common | {"pydruid>=0.6.2"}, "druid": sql_common | {"pydruid>=0.6.2"},
"mongodb": {"pymongo>=3.11"}, "mongodb": {"pymongo>=3.11"},
"glue": {"boto3"},
# Sink plugins. # Sink plugins.
"datahub-kafka": kafka_common, "datahub-kafka": kafka_common,
"datahub-rest": {"requests>=2.25.1"}, "datahub-rest": {"requests>=2.25.1"},
"glue": {"boto3"},
} }
dev_requirements = { dev_requirements = {

View File

@ -54,6 +54,11 @@ class AllowDenyPattern(ConfigModel):
allow: List[str] = [".*"] allow: List[str] = [".*"]
deny: List[str] = [] deny: List[str] = []
alphabet: str = "[A-Za-z0-9 _.-]"
@property
def alphabet_pattern(self):
return re.compile(f"^{self.alphabet}+$")
@classmethod @classmethod
def allow_all(cls): def allow_all(cls):
@ -69,3 +74,20 @@ class AllowDenyPattern(ConfigModel):
return True return True
return False return False
def is_fully_specified_allow_list(self) -> bool:
"""
If the allow patterns are literals and not full regexes, then it is considered
fully specified. This is useful if you want to convert a 'list + filter'
pattern into a 'search for the ones that are allowed' pattern, which can be
much more efficient in some cases.
"""
for allow_pattern in self.allow:
if not self.alphabet_pattern.match(allow_pattern):
return False
return True
def get_allowed_list(self):
"""Return the list of allowed strings as a list, after taking into account deny patterns, if possible"""
assert self.is_fully_specified_allow_list()
return [a for a in self.allow if self.allowed(a)]

View File

@ -1,10 +1,12 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import field as dataclass_field
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List, Optional
import boto3 import boto3
from datahub.configuration import ConfigModel from datahub.configuration import ConfigModel
from datahub.configuration.common import AllowDenyPattern
from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.source import Source, SourceReport from datahub.ingestion.api.source import Source, SourceReport
from datahub.ingestion.source.metadata_common import MetadataWorkUnit from datahub.ingestion.source.metadata_common import MetadataWorkUnit
@ -38,7 +40,8 @@ from datahub.metadata.schema_classes import (
class GlueSourceConfig(ConfigModel): class GlueSourceConfig(ConfigModel):
env: str = "PROD" env: str = "PROD"
databases: Optional[List[str]] = None database_pattern: AllowDenyPattern = AllowDenyPattern.allow_all()
table_pattern: AllowDenyPattern = AllowDenyPattern.allow_all()
aws_access_key_id: Optional[str] = None aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None aws_session_token: Optional[str] = None
@ -72,10 +75,14 @@ class GlueSourceConfig(ConfigModel):
@dataclass @dataclass
class GlueSourceReport(SourceReport): class GlueSourceReport(SourceReport):
tables_scanned = 0 tables_scanned = 0
filtered: List[str] = dataclass_field(default_factory=list)
def report_table_scanned(self) -> None: def report_table_scanned(self) -> None:
self.tables_scanned += 1 self.tables_scanned += 1
def report_table_dropped(self, table: str) -> None:
self.filtered.append(table)
class GlueSource(Source): class GlueSource(Source):
source_config: GlueSourceConfig source_config: GlueSourceConfig
@ -87,7 +94,6 @@ class GlueSource(Source):
self.report = GlueSourceReport() self.report = GlueSourceReport()
self.glue_client = config.glue_client self.glue_client = config.glue_client
self.env = config.env self.env = config.env
self.databases = config.databases
@classmethod @classmethod
def create(cls, config_dict, ctx): def create(cls, config_dict, ctx):
@ -95,7 +101,7 @@ class GlueSource(Source):
return cls(config, ctx) return cls(config, ctx)
def get_workunits(self) -> Iterable[MetadataWorkUnit]: def get_workunits(self) -> Iterable[MetadataWorkUnit]:
def get_all_tables(database_names: Optional[List[str]]): def get_all_tables():
def get_tables_from_database(database_name: str, tables: List): def get_tables_from_database(database_name: str, tables: List):
kwargs = {"DatabaseName": database_name} kwargs = {"DatabaseName": database_name}
while True: while True:
@ -119,22 +125,28 @@ class GlueSource(Source):
break break
return tables return tables
if database_names: if self.source_config.database_pattern.is_fully_specified_allow_list():
all_tables: List = [] all_tables: List = []
database_names = self.source_config.database_pattern.get_allowed_list()
for database in database_names: for database in database_names:
all_tables += get_tables_from_database(database, all_tables) all_tables += get_tables_from_database(database, all_tables)
else: else:
all_tables = get_tables_from_all_databases() all_tables = get_tables_from_all_databases()
return all_tables return all_tables
tables = get_all_tables(self.databases) tables = get_all_tables()
for table in tables: for table in tables:
table_name = table["Name"]
database_name = table["DatabaseName"] database_name = table["DatabaseName"]
table_name = table["Name"]
full_table_name = f"{database_name}.{table_name}" full_table_name = f"{database_name}.{table_name}"
self.report.report_table_scanned() self.report.report_table_scanned()
if not self.source_config.database_pattern.allowed(
database_name
) or not self.source_config.table_pattern.allowed(full_table_name):
self.report.report_table_dropped(full_table_name)
continue
mce = self._extract_record(table, full_table_name) mce = self._extract_record(table, full_table_name)
workunit = MetadataWorkUnit(id=f"glue-{full_table_name}", mce=mce) workunit = MetadataWorkUnit(id=f"glue-{full_table_name}", mce=mce)
self.report.report_workunit(workunit) self.report.report_workunit(workunit)

View File

@ -19,3 +19,17 @@ def test_single_table():
def test_default_deny(): def test_default_deny():
pattern = AllowDenyPattern(allow=["foo.mytable"]) pattern = AllowDenyPattern(allow=["foo.mytable"])
assert not pattern.allowed("foo.bar") assert not pattern.allowed("foo.bar")
def test_fully_speced():
pattern = AllowDenyPattern(allow=["foo.mytable"])
assert pattern.is_fully_specified_allow_list()
pattern = AllowDenyPattern(allow=["foo.*", "foo.table"])
assert not pattern.is_fully_specified_allow_list()
pattern = AllowDenyPattern(allow=["foo.?", "foo.table"])
assert not pattern.is_fully_specified_allow_list()
def test_is_allowed():
pattern = AllowDenyPattern(allow=["foo.mytable"], deny=["foo.*"])
assert pattern.get_allowed_list() == []