feat(ingest): Option to define path spec for Redshift lineage generation (#5256)

Tamas Nemeth 2022-06-27 17:51:13 +02:00 committed by GitHub
parent e2d849de4b
commit 60ff0f45ee
4 changed files with 243 additions and 193 deletions

View File

@@ -0,0 +1,188 @@
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union
import parse
import pydantic
from pydantic.fields import Field
from wcmatch import pathlib
from datahub.configuration.common import ConfigModel
from datahub.ingestion.source.aws.s3_util import is_s3_uri
# hide annoying debug errors from py4j
logging.getLogger("py4j").setLevel(logging.ERROR)
logger: logging.Logger = logging.getLogger(__name__)
SUPPORTED_FILE_TYPES: List[str] = ["csv", "tsv", "json", "parquet", "avro"]
SUPPORTED_COMPRESSIONS: List[str] = ["gz", "bz2"]
class PathSpec(ConfigModel):
class Config:
arbitrary_types_allowed = True
include: str = Field(
description="Path to table (s3 or local file system). The named variable {table} is used to mark the folder containing the dataset. In the absence of {table}, a file-level dataset will be created. Check the examples below for more details."
)
exclude: Optional[List[str]] = Field(
default=None,
description="List of paths in glob pattern which will be excluded while scanning for datasets.",
)
file_types: List[str] = Field(
default=SUPPORTED_FILE_TYPES,
description="Only files with the extensions specified here (a subset of the default value) will be scanned to create datasets. Other files will be omitted.",
)
default_extension: Optional[str] = Field(
description="For files without an extension, assume the specified file type. If it is not set, files without extensions will be skipped.",
)
table_name: Optional[str] = Field(
default=None,
description="Display name of the dataset. A combination of named variables from the include path and literal strings.",
)
enable_compression: bool = Field(
default=True,
description="Enable or disable processing compressed files. Currently .gz and .bz2 files are supported.",
)
sample_files: bool = Field(
default=True,
description="Instead of listing all files, only take a small sample of files to infer the schema. File count and file size calculation will be disabled. Enabling this can significantly improve performance.",
)
# to be set internally
_parsable_include: str
_compiled_include: parse.Parser
_glob_include: str
_is_s3: bool
def allowed(self, path: str) -> bool:
logger.debug(f"Checking file for inclusion: {path}")
if not pathlib.PurePath(path).globmatch(
self._glob_include, flags=pathlib.GLOBSTAR
):
return False
logger.debug(f"{path} matched include ")
if self.exclude:
for exclude_path in self.exclude:
if pathlib.PurePath(path).globmatch(
exclude_path, flags=pathlib.GLOBSTAR
):
return False
logger.debug(f"{path} is not excluded")
ext = os.path.splitext(path)[1].strip(".")
if (ext == "" and self.default_extension is None) and (
ext != "*" and ext not in self.file_types
):
return False
logger.debug(f"{path} had selected extension {ext}")
logger.debug(f"{path} allowed for dataset creation")
return True
def is_s3(self):
return self._is_s3
@classmethod
def get_parsable_include(cls, include: str) -> str:
parsable_include = include
for i in range(parsable_include.count("*")):
parsable_include = parsable_include.replace("*", f"{{folder[{i}]}}", 1)
return parsable_include
def get_named_vars(self, path: str) -> Union[None, parse.Result, parse.Match]:
return self._compiled_include.parse(path)
@pydantic.root_validator()
def validate_path_spec(cls, values: Dict) -> Dict[str, Any]:
if "**" in values["include"]:
raise ValueError("path_spec.include cannot contain '**'")
if values.get("file_types") is None:
values["file_types"] = SUPPORTED_FILE_TYPES
else:
for file_type in values["file_types"]:
if file_type not in SUPPORTED_FILE_TYPES:
raise ValueError(
f"file type {file_type} not in supported file types. Please specify one from {SUPPORTED_FILE_TYPES}"
)
if values.get("default_extension") is not None:
if values.get("default_extension") not in SUPPORTED_FILE_TYPES:
raise ValueError(
f"default extension {values.get('default_extension')} is not a supported file extension. Please specify one from {SUPPORTED_FILE_TYPES}"
)
include_ext = os.path.splitext(values["include"])[1].strip(".")
if (
include_ext not in values["file_types"]
and include_ext != "*"
and not values["default_extension"]
and include_ext not in SUPPORTED_COMPRESSIONS
):
raise ValueError(
f"file type specified ({include_ext}) in path_spec.include is not in specified file "
f'types. Please select one from {values.get("file_types")} or specify ".*" to allow all types'
)
values["_parsable_include"] = PathSpec.get_parsable_include(values["include"])
logger.debug(f'Setting _parsable_include: {values.get("_parsable_include")}')
compiled_include_tmp = parse.compile(values["_parsable_include"])
values["_compiled_include"] = compiled_include_tmp
logger.debug(f'Setting _compiled_include: {values["_compiled_include"]}')
values["_glob_include"] = re.sub(r"\{[^}]+\}", "*", values["include"])
logger.debug(f'Setting _glob_include: {values.get("_glob_include")}')
if values.get("table_name") is None:
if "{table}" in values["include"]:
values["table_name"] = "{table}"
else:
logger.debug(f"include fields: {compiled_include_tmp.named_fields}")
logger.debug(
f"table_name fields: {parse.compile(values['table_name']).named_fields}"
)
if not all(
x in values["_compiled_include"].named_fields
for x in parse.compile(values["table_name"]).named_fields
):
raise ValueError(
"Not all named variables used in path_spec.table_name are specified in "
"path_spec.include"
)
if values.get("exclude") is not None:
for exclude_path in values["exclude"]:
if len(parse.compile(exclude_path).named_fields) != 0:
raise ValueError(
"path_spec.exclude should not contain any named variables"
)
values["_is_s3"] = is_s3_uri(values["include"])
if not values["_is_s3"]:
# Sampling only makes sense on s3 currently
values["sample_files"] = False
logger.debug(f'Setting _is_s3: {values.get("_is_s3")}')
return values
def _extract_table_name(self, named_vars: dict) -> str:
if self.table_name is None:
raise ValueError("path_spec.table_name is not set")
return self.table_name.format_map(named_vars)
def extract_table_name_and_path(self, path: str) -> Tuple[str, str]:
parsed_vars = self.get_named_vars(path)
if parsed_vars is None or "table" not in parsed_vars.named:
return os.path.basename(path), path
else:
include = self.include
depth = include.count("/", 0, include.find("{table}"))
table_path = (
"/".join(path.split("/")[:depth]) + "/" + parsed_vars.named["table"]
)
return self._extract_table_name(parsed_vars.named), table_path
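
For reference, a minimal usage sketch of the new PathSpec class. The bucket, prefix, and file names below are hypothetical; the import path is the one used elsewhere in this commit.

from datahub.ingestion.source.aws.path_spec import PathSpec

# Illustrative only: made-up include pattern and S3 object path.
path_spec = PathSpec(
    include="s3://my-bucket/prod/{table}/year={year}/*.parquet",
)
sample_file = "s3://my-bucket/prod/orders/year=2022/part-0001.parquet"
assert path_spec.allowed(sample_file)

# With a {table} variable, the folder it matches becomes the dataset.
table_name, table_path = path_spec.extract_table_name_and_path(sample_file)
# table_name == "orders"
# table_path == "s3://my-bucket/prod/orders"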

View File

@@ -1,181 +1,23 @@
import logging
import os
import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
import parse
import pydantic
from pydantic.fields import Field
from wcmatch import pathlib
from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.source_common import (
EnvBasedSourceConfigBase,
PlatformSourceConfigBase,
)
from datahub.ingestion.source.aws.aws_common import AwsSourceConfig
from datahub.ingestion.source.aws.s3_util import get_bucket_name, is_s3_uri
from datahub.ingestion.source.aws.path_spec import PathSpec
from datahub.ingestion.source.aws.s3_util import get_bucket_name
from datahub.ingestion.source.s3.profiling import DataLakeProfilerConfig
# hide annoying debug errors from py4j
logging.getLogger("py4j").setLevel(logging.ERROR)
logger: logging.Logger = logging.getLogger(__name__)
SUPPORTED_FILE_TYPES: List[str] = ["csv", "tsv", "json", "parquet", "avro"]
SUPPORTED_COMPRESSIONS: List[str] = ["gz", "bz2"]
class PathSpec(ConfigModel):
class Config:
arbitrary_types_allowed = True
include: str = Field(
description="Path to table (s3 or local file system). Name variable {table} is used to mark the folder with dataset. In absence of {table}, file level dataset will be created. Check below examples for more details."
)
exclude: Optional[List[str]] = Field(
default=None,
description="list of paths in glob pattern which will be excluded while scanning for the datasets",
)
file_types: List[str] = Field(
default=SUPPORTED_FILE_TYPES,
description="Files with extenstions specified here (subset of default value) only will be scanned to create dataset. Other files will be omitted.",
)
default_extension: Optional[str] = Field(
description="For files without extension it will assume the specified file type. If it is not set the files without extensions will be skipped.",
)
table_name: Optional[str] = Field(
default=None,
description="Display name of the dataset.Combination of named variableds from include path and strings",
)
enable_compression: bool = Field(
default=True,
description="Enable or disable processing compressed files. Currenly .gz and .bz files are supported.",
)
sample_files: bool = Field(
default=True,
description="Not listing all the files but only taking a handful amount of sample file to infer the schema. File count and file size calculation will be disabled. This can affect performance significantly if enabled",
)
# to be set internally
_parsable_include: str
_compiled_include: parse.Parser
_glob_include: str
_is_s3: bool
def allowed(self, path: str) -> bool:
logger.debug(f"Checking file to inclusion: {path}")
if not pathlib.PurePath(path).globmatch(
self._glob_include, flags=pathlib.GLOBSTAR
):
return False
logger.debug(f"{path} matched include ")
if self.exclude:
for exclude_path in self.exclude:
if pathlib.PurePath(path).globmatch(
exclude_path, flags=pathlib.GLOBSTAR
):
return False
logger.debug(f"{path} is not excluded")
ext = os.path.splitext(path)[1].strip(".")
if (ext == "" and self.default_extension is None) and (
ext != "*" and ext not in self.file_types
):
return False
logger.debug(f"{path} had selected extension {ext}")
logger.debug(f"{path} allowed for dataset creation")
return True
def is_s3(self):
return self._is_s3
@classmethod
def get_parsable_include(cls, include: str) -> str:
parsable_include = include
for i in range(parsable_include.count("*")):
parsable_include = parsable_include.replace("*", f"{{folder[{i}]}}", 1)
return parsable_include
def get_named_vars(self, path: str) -> Union[None, parse.Result, parse.Match]:
return self._compiled_include.parse(path)
@pydantic.root_validator()
def validate_path_spec(cls, values: Dict) -> Dict[str, Any]:
if "**" in values["include"]:
raise ValueError("path_spec.include cannot contain '**'")
if values.get("file_types") is None:
values["file_types"] = SUPPORTED_FILE_TYPES
else:
for file_type in values["file_types"]:
if file_type not in SUPPORTED_FILE_TYPES:
raise ValueError(
f"file type {file_type} not in supported file types. Please specify one from {SUPPORTED_FILE_TYPES}"
)
if values.get("default_extension") is not None:
if values.get("default_extension") not in SUPPORTED_FILE_TYPES:
raise ValueError(
f"default extension {values.get('default_extension')} not in supported default file extension. Please specify one from {SUPPORTED_FILE_TYPES}"
)
include_ext = os.path.splitext(values["include"])[1].strip(".")
if (
include_ext not in values["file_types"]
and include_ext != "*"
and not values["default_extension"]
and include_ext not in SUPPORTED_COMPRESSIONS
):
raise ValueError(
f"file type specified ({include_ext}) in path_spec.include is not in specified file "
f'types. Please select one from {values.get("file_types")} or specify ".*" to allow all types'
)
values["_parsable_include"] = PathSpec.get_parsable_include(values["include"])
logger.debug(f'Setting _parsable_include: {values.get("_parsable_include")}')
compiled_include_tmp = parse.compile(values["_parsable_include"])
values["_compiled_include"] = compiled_include_tmp
logger.debug(f'Setting _compiled_include: {values["_compiled_include"]}')
values["_glob_include"] = re.sub(r"\{[^}]+\}", "*", values["include"])
logger.debug(f'Setting _glob_include: {values.get("_glob_include")}')
if values.get("table_name") is None:
if "{table}" in values["include"]:
values["table_name"] = "{table}"
else:
logger.debug(f"include fields: {compiled_include_tmp.named_fields}")
logger.debug(
f"table_name fields: {parse.compile(values['table_name']).named_fields}"
)
if not all(
x in values["_compiled_include"].named_fields
for x in parse.compile(values["table_name"]).named_fields
):
raise ValueError(
"Not all named variables used in path_spec.table_name are specified in "
"path_spec.include"
)
if values.get("exclude") is not None:
for exclude_path in values["exclude"]:
if len(parse.compile(exclude_path).named_fields) != 0:
raise ValueError(
"path_spec.exclude should not contain any named variables"
)
values["_is_s3"] = is_s3_uri(values["include"])
if not values["_is_s3"]:
# Sampling only makes sense on s3 currently
values["sample_files"] = False
logger.debug(f'Setting _is_s3: {values.get("_is_s3")}')
return values
class DataLakeSourceConfig(PlatformSourceConfigBase, EnvBasedSourceConfigBase):
path_specs: Optional[List[PathSpec]] = Field(

View File

@@ -60,13 +60,14 @@ from datahub.ingestion.api.decorators import (
)
from datahub.ingestion.api.source import Source, SourceReport
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.path_spec import PathSpec
from datahub.ingestion.source.aws.s3_util import (
get_bucket_name,
get_bucket_relative_path,
get_key_prefix,
strip_s3_prefix,
)
from datahub.ingestion.source.s3.config import DataLakeSourceConfig, PathSpec
from datahub.ingestion.source.s3.config import DataLakeSourceConfig
from datahub.ingestion.source.s3.profiling import _SingleTableProfiler
from datahub.ingestion.source.s3.report import DataLakeSourceReport
from datahub.ingestion.source.schema_inference import avro, csv_tsv, json, parquet
@@ -474,7 +475,8 @@ class S3Source(Source):
extension = pathlib.Path(table_data.full_path).suffix
if path_spec.enable_compression and (
extension[1:] in datahub.ingestion.source.s3.config.SUPPORTED_COMPRESSIONS
extension[1:]
in datahub.ingestion.source.aws.path_spec.SUPPORTED_COMPRESSIONS
):
# Remove the compression extension and use the one before it, e.g. .json.gz -> .json
extension = pathlib.Path(table_data.full_path).with_suffix("").suffix
@@ -756,35 +758,18 @@ class S3Source(Source):
) -> TableData:
logger.debug(f"Getting table data for path: {path}")
parsed_vars = path_spec.get_named_vars(path)
table_name, table_path = path_spec.extract_table_name_and_path(path)
table_data = None
if parsed_vars is None or "table" not in parsed_vars.named:
table_data = TableData(
display_name=os.path.basename(path),
is_s3=path_spec.is_s3(),
full_path=path,
partitions=None,
timestamp=timestamp,
table_path=path,
number_of_files=1,
size_in_bytes=size,
)
else:
include = path_spec.include
depth = include.count("/", 0, include.find("{table}"))
table_path = (
"/".join(path.split("/")[:depth]) + "/" + parsed_vars.named["table"]
)
table_data = TableData(
display_name=self.extract_table_name(path_spec, parsed_vars.named),
is_s3=path_spec.is_s3(),
full_path=path,
partitions=None,
timestamp=timestamp,
table_path=table_path,
number_of_files=1,
size_in_bytes=size,
)
table_data = TableData(
display_name=table_name,
is_s3=path_spec.is_s3(),
full_path=path,
partitions=None,
timestamp=timestamp,
table_path=table_path,
number_of_files=1,
size_in_bytes=size,
)
return table_data
def resolve_templated_folders(self, bucket_name: str, prefix: str) -> Iterable[str]:
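
The refactor above delegates both branches of the old get_table_data logic to PathSpec.extract_table_name_and_path. A minimal sketch of the file-level case, where the include pattern has no {table} variable (names are hypothetical):

from datahub.ingestion.source.aws.path_spec import PathSpec

# Illustrative only: without {table}, each matching file becomes its own dataset,
# so the table path equals the file path.
path_spec = PathSpec(include="s3://my-bucket/reports/*.csv")
name, table_path = path_spec.extract_table_name_and_path(
    "s3://my-bucket/reports/2022-06.csv"
)
# name == "2022-06.csv"
# table_path == "s3://my-bucket/reports/2022-06.csv"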

View File

@@ -18,6 +18,7 @@ from sqlalchemy_redshift.dialect import RedshiftDialect, RelationKey
from sqllineage.runner import LineageRunner
import datahub.emitter.mce_builder as builder
from datahub.configuration import ConfigModel
from datahub.configuration.source_common import DatasetLineageProviderConfigBase
from datahub.configuration.time_window_config import BaseTimeWindowConfig
from datahub.emitter import mce_builder
@@ -32,6 +33,8 @@ from datahub.ingestion.api.decorators import (
support_status,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.path_spec import PathSpec
from datahub.ingestion.source.aws.s3_util import strip_s3_prefix
from datahub.ingestion.source.sql.postgres import PostgresConfig
from datahub.ingestion.source.sql.sql_common import (
SQLAlchemySource,
@@ -101,8 +104,31 @@ class LineageItem:
self.dataset_lineage_type = DatasetLineageTypeClass.TRANSFORMED
class S3LineageProviderConfig(ConfigModel):
"""
Configuration for generating S3 lineage: path specs used to map S3 object paths to dataset paths.
"""
path_specs: List[PathSpec] = Field(
description="List of PathSpec objects. See below for details about PathSpec."
)
class DatasetS3LineageProviderConfigBase(ConfigModel):
"""
Any source that produces s3 lineage from/to Datasets should inherit this class.
"""
s3_lineage_config: Optional[S3LineageProviderConfig] = Field(
default=None, description="Common config for S3 lineage generation"
)
class RedshiftConfig(
PostgresConfig, BaseTimeWindowConfig, DatasetLineageProviderConfigBase
PostgresConfig,
BaseTimeWindowConfig,
DatasetLineageProviderConfigBase,
DatasetS3LineageProviderConfigBase,
):
# Although Amazon Redshift is compatible with Postgres's wire format,
# we actually want to use the sqlalchemy-redshift package and dialect
@@ -672,6 +698,14 @@ class RedshiftSource(SQLAlchemySource):
db_name = db_alias
return db_name
def _get_s3_path(self, path: str) -> str:
if self.config.s3_lineage_config:
for path_spec in self.config.s3_lineage_config.path_specs:
if path_spec.allowed(path):
table_name, table_path = path_spec.extract_table_name_and_path(path)
return table_path
return path
def _populate_lineage_map(
self, query: str, lineage_type: LineageCollectorType
) -> None:
@@ -747,6 +781,7 @@ class RedshiftSource(SQLAlchemySource):
f"Only s3 source supported with copy. The source was: {path}.",
)
continue
path = strip_s3_prefix(self._get_s3_path(path))
else:
platform = LineageDatasetPlatform.REDSHIFT
path = f'{db_name}.{db_row["source_schema"]}.{db_row["source_table"]}'
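
To tie the pieces together, here is a hedged sketch of how a COPY source path would be narrowed to its table folder, mirroring the logic in RedshiftSource._get_s3_path above. The bucket and prefix names are hypothetical, and the path_specs list stands in for s3_lineage_config.path_specs from the recipe.

from datahub.ingestion.source.aws.path_spec import PathSpec
from datahub.ingestion.source.aws.s3_util import strip_s3_prefix

# Illustrative only: made-up bucket, prefix, and table layout.
path_specs = [PathSpec(include="s3://etl-bucket/load/{table}/{partition}/*.csv")]
copy_path = "s3://etl-bucket/load/customers/2022-06-01/part-0000.csv"

for path_spec in path_specs:
    if path_spec.allowed(copy_path):
        _, table_path = path_spec.extract_table_name_and_path(copy_path)
        # table_path == "s3://etl-bucket/load/customers". The source then applies
        # strip_s3_prefix() (assumed here to drop the "s3://" scheme) before
        # registering the lineage dataset.
        print(strip_s3_prefix(table_path))
        break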