feat(ingest): file - add support for folders, large files, improve co… (#5692)

Shirshanka Das 2022-08-21 01:48:22 -07:00 committed by GitHub
parent ad4e285fb8
commit bb788ac317
9 changed files with 396 additions and 96 deletions

View File

@@ -56,6 +56,7 @@ framework_common = {
"packaging",
"aiohttp<4",
"cached_property",
"ijson",
}
kafka_common = {

View File

@@ -209,7 +209,7 @@ def _test_source_connection(report_to: Optional[str], pipeline_config: dict) ->
try:
connection_report = ConnectionManager().test_source_connection(pipeline_config)
logger.info(connection_report.as_json())
if report_to:
if report_to and report_to != "datahub":
with open(report_to, "w") as out_fp:
out_fp.write(connection_report.as_json())
logger.info(f"Wrote report successfully to {report_to}")

View File

@@ -1,8 +1,24 @@
import logging
from datahub.ingestion.source.file import GenericFileSource
logger = logging.getLogger(__name__)
def check_mce_file(filepath: str) -> str:
mce_source = GenericFileSource.create({"filename": filepath}, None)
for _ in mce_source.get_workunits():
pass
return f"{mce_source.get_report().workunits_produced} MCEs found - all valid"
if mce_source.get_report().failures:
# raise the first failure found
logger.error(
f"Event file check failed with errors. Raising first error found. Full report {mce_source.get_report().as_string()}"
)
for failure_list in mce_source.get_report().failures.values():
if len(failure_list):
raise Exception(failure_list[0])
raise Exception(
f"Failed to process file due to {mce_source.get_report().failures}"
)
else:
return f"{mce_source.get_report().workunits_produced} MCEs found - all valid"

View File

@@ -39,7 +39,12 @@ class Report:
else:
return Report.to_str(some_val)
def compute_stats(self) -> None:
"""A hook to compute derived stats"""
pass
def as_obj(self) -> dict:
self.compute_stats()
return {
str(key): Report.to_dict(value)
for (key, value) in self.__dict__.items()

View File

@@ -1,16 +1,25 @@
import platform
import sys
import collections
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Generic, Iterable, List, Optional, TypeVar, Union
from typing import (
Deque,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
TypeVar,
Union,
)
from pydantic import BaseModel
import datahub
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
from datahub.ingestion.api.report import Report
from datahub.utilities.time import get_current_time_in_seconds
class SourceCapability(Enum):
@@ -29,18 +38,55 @@ class SourceCapability(Enum):
CONTAINERS = "Asset Containers"
T = TypeVar("T")
class LossyList(List[T]):
"""A list that only preserves the head and tail of lists longer than a certain number"""
def __init__(
self, max_elements: int = 10, section_breaker: Optional[str] = "..."
) -> None:
super().__init__()
self.max_elements = max_elements
self.list_head: List[T] = []
self.list_tail: Deque[T] = collections.deque([], maxlen=int(max_elements / 2))
self.head_full = False
self.total_elements = 0
self.section_breaker = section_breaker
def __iter__(self) -> Iterator[T]:
yield from self.list_head
if self.section_breaker and len(self.list_tail):
yield f"{self.section_breaker} {self.total_elements - len(self.list_head) - len(self.list_tail)} more elements" # type: ignore
yield from self.list_tail
def append(self, __object: T) -> None:
if self.head_full:
self.list_tail.append(__object)
else:
self.list_head.append(__object)
if len(self.list_head) > int(self.max_elements / 2):
self.head_full = True
self.total_elements += 1
def __len__(self) -> int:
return self.total_elements
def __repr__(self) -> str:
return repr(list(self.__iter__()))
def __str__(self) -> str:
return str(list(self.__iter__()))
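A small illustrative sketch of the truncation behavior (values chosen arbitrarily): the true length is preserved, while the printed form keeps only the head, a section-breaker summary, and the most recent tail.

ids = LossyList()            # defaults: max_elements=10, section_breaker="..."
for i in range(25):
    ids.append(f"workunit-{i}")

print(len(ids))              # 25 -- total_elements counts every append
print(ids)                   # the first few ids, a "... N more elements" marker, then the last five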
@dataclass
class SourceReport(Report):
workunits_produced: int = 0
workunit_ids: List[str] = field(default_factory=list)
workunit_ids: List[str] = field(default_factory=LossyList)
warnings: Dict[str, List[str]] = field(default_factory=dict)
failures: Dict[str, List[str]] = field(default_factory=dict)
cli_version: str = datahub.nice_version_name()
cli_entry_location: str = datahub.__file__
py_version: str = sys.version
py_exec_path: str = sys.executable
os_details: str = platform.platform()
def report_workunit(self, wu: WorkUnit) -> None:
self.workunits_produced += 1
@@ -56,6 +102,20 @@ class SourceReport(Report):
self.failures[key] = []
self.failures[key].append(reason)
def __post_init__(self) -> None:
self.starting_time = get_current_time_in_seconds()
self.running_time_in_seconds = 0
def compute_stats(self) -> None:
current_time = get_current_time_in_seconds()
running_time = current_time - self.starting_time
workunits_produced = self.workunits_produced
if running_time > 0:
self.read_rate = workunits_produced / running_time
self.running_time_in_seconds = running_time
else:
self.read_rate = 0
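A minimal sketch of how these derived stats surface, assuming the report is serialized through as_obj() (the count is illustrative):

from datahub.ingestion.api.source import SourceReport

report = SourceReport()
report.workunits_produced = 1200   # illustrative count
stats = report.as_obj()            # as_obj() calls compute_stats() first
# stats now includes running_time_in_seconds and read_rate
# (workunits per second of wall time, 0 if no whole second has elapsed)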
class CapabilityReport(BaseModel):
"""A report capturing the result of any capability evaluation"""

View File

@@ -1,13 +1,19 @@
import itertools
import logging
import platform
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional
import click
import datahub
from datahub.configuration.common import PipelineExecutionError
from datahub.ingestion.api.committable import CommitPolicy
from datahub.ingestion.api.common import EndOfStream, PipelineContext, RecordEnvelope
from datahub.ingestion.api.pipeline_run_listener import PipelineRunListener
from datahub.ingestion.api.report import Report
from datahub.ingestion.api.sink import Sink, WriteCallback
from datahub.ingestion.api.source import Extractor, Source
from datahub.ingestion.api.transform import Transformer
@@ -28,7 +34,7 @@ class LoggingCallback(WriteCallback):
def on_success(
self, record_envelope: RecordEnvelope, success_metadata: dict
) -> None:
logger.info(f"sink wrote workunit {record_envelope.metadata['workunit_id']}")
logger.debug(f"sink wrote workunit {record_envelope.metadata['workunit_id']}")
def on_failure(
self,
@@ -46,6 +52,15 @@ class PipelineInitError(Exception):
pass
@dataclass
class CliReport(Report):
cli_version: str = datahub.nice_version_name()
cli_entry_location: str = datahub.__file__
py_version: str = sys.version
py_exec_path: str = sys.executable
os_details: str = platform.platform()
class Pipeline:
config: PipelineConfig
ctx: PipelineContext
@@ -71,6 +86,9 @@ class Pipeline:
self.preview_workunits = preview_workunits
self.report_to = report_to
self.reporters: List[PipelineRunListener] = []
self.num_intermediate_workunits = 0
self.last_time_printed = int(time.time())
self.cli_report = CliReport()
try:
self.ctx = PipelineContext(
@@ -240,6 +258,17 @@ class Pipeline:
no_default_report=no_default_report,
)
def _time_to_print(self) -> bool:
self.num_intermediate_workunits += 1
if self.num_intermediate_workunits > 1000:
current_time = int(time.time())
if current_time - self.last_time_printed > 10:
# at least 1000 new workunits and more than 10 seconds since the last print: emit a progress summary
self.num_intermediate_workunits = 0
self.last_time_printed = current_time
return True
return False
def run(self) -> None:
self._notify_reporters_on_ingestion_start()
@@ -250,6 +279,8 @@
self.source.get_workunits(),
self.preview_workunits if self.preview_mode else None,
):
if self._time_to_print():
self.pretty_print_summary(currently_running=True)
# TODO: change extractor interface
extractor.configure({}, self.ctx)
@@ -370,8 +401,12 @@
result += len(val)
return result
def pretty_print_summary(self, warnings_as_failure: bool = False) -> int:
def pretty_print_summary(
self, warnings_as_failure: bool = False, currently_running: bool = False
) -> int:
click.echo()
click.secho("Cli report:", bold=True)
click.secho(self.cli_report.as_string())
click.secho(f"Source ({self.config.source.type}) report:", bold=True)
click.echo(self.source.get_report().as_string())
click.secho(f"Sink ({self.config.sink.type}) report:", bold=True)
@@ -383,7 +418,7 @@
self.source.get_report().failures
)
click.secho(
f"Pipeline finished with {num_failures_source} failures in source producing {workunits_produced} workunits",
f"Pipeline {'running' if currently_running else 'finished'} with {num_failures_source} failures in source producing {workunits_produced} events",
fg="bright_red",
bold=True,
)
@@ -391,14 +426,14 @@
elif self.source.get_report().warnings or self.sink.get_report().warnings:
num_warn_source = self._count_all_vals(self.source.get_report().warnings)
click.secho(
f"Pipeline finished with {num_warn_source} warnings in source producing {workunits_produced} workunits",
f"Pipeline {'running' if currently_running else 'finished'} with {num_warn_source} warnings in source producing {workunits_produced} events",
fg="yellow",
bold=True,
)
return 1 if warnings_as_failure else 0
else:
click.secho(
f"Pipeline finished successfully producing {workunits_produced} workunits",
f"Pipeline {'running' if currently_running else 'finished'} successfully producing {workunits_produced} events",
fg="green",
bold=True,
)

View File

@@ -1,11 +1,19 @@
import datetime
import json
import logging
import os.path
from dataclasses import dataclass, field
from typing import Iterable, Iterator, Union
from dataclasses import dataclass
from enum import Enum
from io import BufferedReader
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union
import ijson
from pydantic import root_validator, validator
from pydantic.fields import Field
from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
SupportStatus,
config_class,
@@ -25,100 +33,255 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
)
from datahub.metadata.schema_classes import UsageAggregationClass
def _iterate_file(path: str) -> list:
with open(path, "r") as f:
obj_list = json.load(f)
if not isinstance(obj_list, list):
obj_list = [obj_list]
return obj_list
logger = logging.getLogger(__name__)
def iterate_mce_file(path: str) -> Iterator[MetadataChangeEvent]:
for obj in _iterate_file(path):
mce: MetadataChangeEvent = MetadataChangeEvent.from_obj(obj)
yield mce
def iterate_generic_file(
path: str,
) -> Iterator[
Union[MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass]
]:
for i, obj in enumerate(_iterate_file(path)):
item: Union[MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass]
if "proposedSnapshot" in obj:
item = MetadataChangeEvent.from_obj(obj)
elif "aspect" in obj:
item = MetadataChangeProposal.from_obj(obj)
else:
item = UsageAggregationClass.from_obj(obj)
if not item.validate():
raise ValueError(f"failed to parse: {obj} (index {i})")
yield item
class FileReadMode(Enum):
STREAM = "STREAM"
BATCH = "BATCH"
AUTO = "AUTO"
class FileSourceConfig(ConfigModel):
filename: str = Field(description="Path to file to ingest.")
filename: Optional[str] = Field(None, description="Path to file to ingest.")
path: str = Field(
description="Path to folder or file to ingest. If pointed to a folder, all files with extension {file_extension} (default json) within that folder will be processed."
)
file_extension: str = Field(
"json",
description="When providing a folder to use to read files, set this field to control file extensions that you want the source to process. * is a special value that means process every file regardless of extension",
)
read_mode: FileReadMode = FileReadMode.AUTO
_minsize_for_streaming_mode_in_bytes: int = (
100 * 1000 * 1000 # Must be at least 100MB before we use streaming mode
)
@validator("read_mode", pre=True)
def read_mode_str_to_enum(cls, v):
if v and isinstance(v, str):
return v.upper()
@root_validator(pre=True)
def filename_populates_path_if_present(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
if "path" not in values and "filename" in values:
values["path"] = values["filename"]
elif values.get("filename"):
raise ValueError(
"Both path and filename should not be provided together. Use one. We recommend using path. filename is deprecated."
)
return values
@validator("file_extension", always=True)
def add_leading_dot_to_extension(cls, v: str) -> str:
if v:
if v.startswith("."):
return v
else:
return "." + v
return v
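A hedged configuration sketch (the folder path is hypothetical); the validators above normalize the extension and the read mode, and the deprecated filename key is still accepted as an alias for path:

config = FileSourceConfig.parse_obj(
    {
        "path": "/data/metadata_exports",  # hypothetical folder of JSON event dumps
        "file_extension": "json",          # leading dot added by add_leading_dot_to_extension
        "read_mode": "auto",               # upper-cased to FileReadMode.AUTO before enum coercion
    }
)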
@dataclass
class FileSourceReport(SourceReport):
total_bytes_on_disk: Optional[int] = None
total_bytes_read: int = 0
total_num_files: int = 0
num_files_read: int = 0
percentage_completion: float = -1
estimated_time_to_completion_in_minutes: int = -1
total_parse_time_in_seconds: Optional[float] = None
total_deserialize_time_in_seconds: float = 0
def add_deserialize_time(self, delta: datetime.timedelta) -> None:
self.total_deserialize_time_in_seconds += delta.total_seconds()
def add_parse_time(self, delta: datetime.timedelta) -> None:
if self.total_parse_time_in_seconds is not None:
self.total_parse_time_in_seconds += delta.total_seconds()
else:
self.total_parse_time_in_seconds = delta.total_seconds()
def append_total_bytes_on_disk(self, delta: int) -> None:
if self.total_bytes_on_disk is not None:
self.total_bytes_on_disk += delta
else:
self.total_bytes_on_disk = delta
def compute_stats(self) -> None:
super().compute_stats()
self.percentage_completion = (
100.0 * (self.total_bytes_read / self.total_bytes_on_disk)
if self.total_bytes_on_disk
else -1
)
self.estimated_time_to_completion_in_minutes = int(
(
self.running_time_in_seconds
* (100 - self.percentage_completion)
/ self.percentage_completion
)
/ 60
)
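For example, with 250 MB read out of 1000 MB on disk after 120 seconds of running time, percentage_completion is 25.0 and estimated_time_to_completion_in_minutes is int(120 * (100 - 25) / 25 / 60) = 6.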
@platform_name("File")
@config_class(FileSourceConfig)
@support_status(SupportStatus.CERTIFIED)
@dataclass
class GenericFileSource(TestableSource):
"""
This plugin pulls metadata from a previously generated file. The [file sink](../../../../metadata-ingestion/sink_docs/file.md) can produce such files, and a number of samples are included in the [examples/mce_files](../../../../metadata-ingestion/examples/mce_files) directory.
"""
config: FileSourceConfig
report: SourceReport = field(default_factory=SourceReport)
def __init__(self, ctx: PipelineContext, config: FileSourceConfig):
self.ctx = ctx
self.config = config
self.report = FileSourceReport()
self.fp: Optional[BufferedReader] = None
@classmethod
def create(cls, config_dict, ctx):
config = FileSourceConfig.parse_obj(config_dict)
return cls(ctx, config)
def get_filenames(self) -> Iterable[str]:
is_file = os.path.isfile(self.config.path)
is_dir = os.path.isdir(self.config.path)
if is_file:
self.report.total_num_files = 1
return [self.config.path]
if is_dir:
p = Path(self.config.path)
files_and_stats = [
(str(x), os.path.getsize(x))
for x in list(p.glob(f"*{self.config.file_extension}"))
if x.is_file()
]
self.report.total_num_files = len(files_and_stats)
self.report.total_bytes_on_disk = sum([y for (x, y) in files_and_stats])
return [x for (x, y) in files_and_stats]
raise Exception(f"Failed to process {self.config.path}")
def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, UsageStatsWorkUnit]]:
for i, obj in enumerate(iterate_generic_file(self.config.filename)):
wu: Union[MetadataWorkUnit, UsageStatsWorkUnit]
if isinstance(obj, UsageAggregationClass):
wu = UsageStatsWorkUnit(f"file://{self.config.filename}:{i}", obj)
elif isinstance(obj, MetadataChangeProposal):
wu = MetadataWorkUnit(f"file://{self.config.filename}:{i}", mcp_raw=obj)
else:
wu = MetadataWorkUnit(f"file://{self.config.filename}:{i}", mce=obj)
self.report.report_workunit(wu)
yield wu
for f in self.get_filenames():
for i, obj in self.iterate_generic_file(f):
wu: Union[MetadataWorkUnit, UsageStatsWorkUnit]
if isinstance(obj, UsageAggregationClass):
wu = UsageStatsWorkUnit(f"file://{f}:{i}", obj)
elif isinstance(obj, MetadataChangeProposal):
wu = MetadataWorkUnit(f"file://{f}:{i}", mcp_raw=obj)
else:
wu = MetadataWorkUnit(f"file://{f}:{i}", mce=obj)
self.report.report_workunit(wu)
yield wu
self.report.num_files_read += 1
self.report.total_bytes_read += os.path.getsize(f)
def get_report(self):
return self.report
def close(self):
pass
if self.fp:
self.fp.close()
def _iterate_file(self, path: str) -> Iterable[Tuple[int, Any]]:
size = os.path.getsize(path)
if self.config.read_mode == FileReadMode.AUTO:
file_read_mode = (
FileReadMode.BATCH
if size < self.config._minsize_for_streaming_mode_in_bytes
else FileReadMode.STREAM
)
logger.info(f"Reading file {path} in {file_read_mode} mode")
else:
file_read_mode = self.config.read_mode
if file_read_mode == FileReadMode.BATCH:
with open(path, "r") as f:
parse_start_time = datetime.datetime.now()
obj_list = json.load(f)
parse_end_time = datetime.datetime.now()
self.report.add_parse_time(parse_end_time - parse_start_time)
if not isinstance(obj_list, list):
obj_list = [obj_list]
yield from enumerate(obj_list)
else:
self.fp = open(path, "rb")
parse_start_time = datetime.datetime.now()
parse_stream = ijson.parse(self.fp, use_float=True)
rows_yielded = 0
for row in ijson.items(parse_stream, "item", use_float=True):
parse_end_time = datetime.datetime.now()
self.report.add_parse_time(parse_end_time - parse_start_time)
rows_yielded += 1
yield rows_yielded, row
parse_start_time = datetime.datetime.now()
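For context, a minimal sketch of the streaming path used for large files, assuming the file holds a top-level JSON array (the file name and handler below are hypothetical):

import ijson

with open("large_event_dump.json", "rb") as fp:
    # ijson yields one array element at a time instead of loading the whole file into memory
    for obj in ijson.items(fp, "item", use_float=True):
        handle(obj)  # placeholder for MCE/MCP deserialization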
def iterate_mce_file(self, path: str) -> Iterator[MetadataChangeEvent]:
for i, obj in self._iterate_file(path):
mce: MetadataChangeEvent = MetadataChangeEvent.from_obj(obj)
yield mce
def iterate_generic_file(
self,
path: str,
) -> Iterator[
Tuple[
int,
Union[MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass],
]
]:
for i, obj in self._iterate_file(path):
item: Union[
MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass
]
try:
deserialize_start_time = datetime.datetime.now()
if "proposedSnapshot" in obj:
item = MetadataChangeEvent.from_obj(obj)
elif "aspect" in obj:
item = MetadataChangeProposal.from_obj(obj)
else:
item = UsageAggregationClass.from_obj(obj)
if not item.validate():
raise ValueError(f"failed to parse: {obj} (index {i})")
deserialize_duration = datetime.datetime.now() - deserialize_start_time
self.report.add_deserialize_time(deserialize_duration)
yield i, item
except Exception as e:
self.report.report_failure(f"path-{i}", str(e))
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
config = FileSourceConfig.parse_obj(config_dict)
is_file = os.path.isfile(config.filename)
readable = os.access(config.filename, os.R_OK)
if is_file and readable:
return TestConnectionReport(
basic_connectivity=CapabilityReport(capable=True)
)
elif not is_file:
exists = os.path.exists(config.path)
if not exists:
return TestConnectionReport(
basic_connectivity=CapabilityReport(
capable=False,
failure_reason=f"{config.filename} doesn't appear to be a valid file.",
failure_reason=f"{config.path} doesn't appear to be a valid file or directory.",
)
)
elif not readable:
is_dir = os.path.isdir(config.path)
failure_message = None
readable = os.access(config.path, os.R_OK)
if not readable:
failure_message = f"Cannot read {config.path}"
if is_dir:
executable = os.access(config.path, os.X_OK)
if not executable:
failure_message = f"Do not have execute permissions in {config.path}"
if failure_message:
return TestConnectionReport(
basic_connectivity=CapabilityReport(
capable=False, failure_reason=f"Cannot read file {config.filename}"
capable=False, failure_reason=failure_message
)
)
else:
# not expected to be here
raise Exception("Not expected to be here.")
return TestConnectionReport(
basic_connectivity=CapabilityReport(capable=True)
)
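A hedged usage sketch (the directory path is hypothetical):

report = GenericFileSource.test_connection({"path": "./examples/mce_files"})
print(report.basic_connectivity)  # capable=True when the path exists, is readable, and (for a directory) is listable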

View File

@@ -0,0 +1,5 @@
import time
def get_current_time_in_seconds() -> int:
return int(time.time())
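A brief usage sketch, mirroring how the reports above track start and running time:

start = get_current_time_in_seconds()
# ... do work ...
elapsed = get_current_time_in_seconds() - start  # whole seconds, so sub-second runs report 0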

View File

@@ -1,6 +1,7 @@
import io
import json
import pathlib
from unittest.mock import patch
import fastavro
import pytest
@@ -11,7 +12,7 @@ import datahub.metadata.schema_classes as models
from datahub.cli.json_file import check_mce_file
from datahub.emitter import mce_builder
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.file import iterate_mce_file
from datahub.ingestion.source.file import FileSourceConfig, GenericFileSource
from datahub.metadata.schema_classes import (
MetadataChangeEventClass,
OwnershipClass,
@@ -58,7 +59,6 @@ def test_serde_to_json(
output_filename = "output.json"
output_file = tmp_path / output_filename
pipeline = Pipeline.create(
{
"source": {"type": "file", "config": {"filename": str(golden_file)}},
@@ -85,31 +85,43 @@
],
)
@freeze_time(FROZEN_TIME)
def test_serde_to_avro(pytestconfig: PytestConfig, json_filename: str) -> None:
def test_serde_to_avro(
pytestconfig: PytestConfig,
json_filename: str,
) -> None:
# In this test, we want to read in from JSON -> MCE object.
# Next we serialize from MCE to Avro and then deserialize back to MCE.
# Finally, we want to compare the two MCE objects.
with patch(
"datahub.ingestion.api.common.PipelineContext", autospec=True
) as mock_pipeline_context:
json_path = pytestconfig.rootpath / json_filename
mces = list(iterate_mce_file(str(json_path)))
json_path = pytestconfig.rootpath / json_filename
source = GenericFileSource(
ctx=mock_pipeline_context, config=FileSourceConfig(path=str(json_path))
)
mces = list(source.iterate_mce_file(str(json_path)))
# Serialize to Avro.
parsed_schema = fastavro.parse_schema(json.loads(getMetadataChangeEventSchema()))
fo = io.BytesIO()
out_records = [mce.to_obj(tuples=True) for mce in mces]
fastavro.writer(fo, parsed_schema, out_records)
# Serialize to Avro.
parsed_schema = fastavro.parse_schema(
json.loads(getMetadataChangeEventSchema())
)
fo = io.BytesIO()
out_records = [mce.to_obj(tuples=True) for mce in mces]
fastavro.writer(fo, parsed_schema, out_records)
# Deserialized from Avro.
fo.seek(0)
in_records = list(fastavro.reader(fo, return_record_name=True))
in_mces = [
MetadataChangeEventClass.from_obj(record, tuples=True) for record in in_records
]
# Deserialized from Avro.
fo.seek(0)
in_records = list(fastavro.reader(fo, return_record_name=True))
in_mces = [
MetadataChangeEventClass.from_obj(record, tuples=True)
for record in in_records
]
# Check diff
assert len(mces) == len(in_mces)
for i in range(len(mces)):
assert mces[i] == in_mces[i]
# Check diff
assert len(mces) == len(in_mces)
for i in range(len(mces)):
assert mces[i] == in_mces[i]
@pytest.mark.parametrize(
@@ -148,8 +160,11 @@ def test_check_mce_schema_failure(
) -> None:
json_file_path = pytestconfig.rootpath / json_filename
with pytest.raises((ValueError, AssertionError)):
try:
check_mce_file(str(json_file_path))
raise AssertionError("MCE File validated successfully when it should not have")
except Exception as e:
assert "is missing required field: active" in str(e)
def test_field_discriminator() -> None: