diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py
index 7683ba0082..0b55fd9531 100644
--- a/metadata-ingestion/setup.py
+++ b/metadata-ingestion/setup.py
@@ -56,6 +56,7 @@ framework_common = {
     "packaging",
     "aiohttp<4",
     "cached_property",
+    "ijson",
 }
 
 kafka_common = {
diff --git a/metadata-ingestion/src/datahub/cli/ingest_cli.py b/metadata-ingestion/src/datahub/cli/ingest_cli.py
index 2adfecb5f2..5870b06b6b 100644
--- a/metadata-ingestion/src/datahub/cli/ingest_cli.py
+++ b/metadata-ingestion/src/datahub/cli/ingest_cli.py
@@ -209,7 +209,7 @@ def _test_source_connection(report_to: Optional[str], pipeline_config: dict) ->
     try:
         connection_report = ConnectionManager().test_source_connection(pipeline_config)
         logger.info(connection_report.as_json())
-        if report_to:
+        if report_to and report_to != "datahub":
             with open(report_to, "w") as out_fp:
                 out_fp.write(connection_report.as_json())
             logger.info(f"Wrote report successfully to {report_to}")
diff --git a/metadata-ingestion/src/datahub/cli/json_file.py b/metadata-ingestion/src/datahub/cli/json_file.py
index 7c670a4918..eb8de80acd 100644
--- a/metadata-ingestion/src/datahub/cli/json_file.py
+++ b/metadata-ingestion/src/datahub/cli/json_file.py
@@ -1,8 +1,24 @@
+import logging
+
 from datahub.ingestion.source.file import GenericFileSource
 
+logger = logging.getLogger(__name__)
+
 
 def check_mce_file(filepath: str) -> str:
     mce_source = GenericFileSource.create({"filename": filepath}, None)
     for _ in mce_source.get_workunits():
         pass
-    return f"{mce_source.get_report().workunits_produced} MCEs found - all valid"
+    if mce_source.get_report().failures:
+        # raise the first failure found
+        logger.error(
+            f"Event file check failed with errors. Raising first error found. Full report {mce_source.get_report().as_string()}"
+        )
+        for failure_list in mce_source.get_report().failures.values():
+            if len(failure_list):
+                raise Exception(failure_list[0])
+        raise Exception(
+            f"Failed to process file due to {mce_source.get_report().failures}"
+        )
+    else:
+        return f"{mce_source.get_report().workunits_produced} MCEs found - all valid"
diff --git a/metadata-ingestion/src/datahub/ingestion/api/report.py b/metadata-ingestion/src/datahub/ingestion/api/report.py
index d2403301d8..f572d63bef 100644
--- a/metadata-ingestion/src/datahub/ingestion/api/report.py
+++ b/metadata-ingestion/src/datahub/ingestion/api/report.py
@@ -39,7 +39,12 @@ class Report:
         else:
             return Report.to_str(some_val)
 
+    def compute_stats(self) -> None:
+        """A hook to compute derived stats"""
+        pass
+
     def as_obj(self) -> dict:
+        self.compute_stats()
         return {
             str(key): Report.to_dict(value)
             for (key, value) in self.__dict__.items()
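# Illustrative sketch (not part of the patch): the compute_stats() hook added to
# Report above is called from as_obj(), so derived metrics refresh each time a
# report is rendered. The subclass and field names below are hypothetical.
from dataclasses import dataclass

from datahub.ingestion.api.report import Report


@dataclass
class ExampleReport(Report):
    records_written: int = 0
    bytes_written: int = 0
    avg_record_size_in_bytes: float = 0  # derived; refreshed by compute_stats()

    def compute_stats(self) -> None:
        if self.records_written:
            self.avg_record_size_in_bytes = self.bytes_written / self.records_written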
diff --git a/metadata-ingestion/src/datahub/ingestion/api/source.py b/metadata-ingestion/src/datahub/ingestion/api/source.py
index 45d594de3e..645c7ca7f4 100644
--- a/metadata-ingestion/src/datahub/ingestion/api/source.py
+++ b/metadata-ingestion/src/datahub/ingestion/api/source.py
@@ -1,16 +1,25 @@
-import platform
-import sys
+import collections
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Dict, Generic, Iterable, List, Optional, TypeVar, Union
+from typing import (
+    Deque,
+    Dict,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+)
 
 from pydantic import BaseModel
 
-import datahub
 from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
 from datahub.ingestion.api.report import Report
+from datahub.utilities.time import get_current_time_in_seconds
 
 
 class SourceCapability(Enum):
@@ -29,18 +38,55 @@ class SourceCapability(Enum):
     CONTAINERS = "Asset Containers"
 
 
+T = TypeVar("T")
+
+
+class LossyList(List[T]):
+    """A list that only preserves the head and tail of lists longer than a certain number"""
+
+    def __init__(
+        self, max_elements: int = 10, section_breaker: Optional[str] = "..."
+    ) -> None:
+        super().__init__()
+        self.max_elements = max_elements
+        self.list_head: List[T] = []
+        self.list_tail: Deque[T] = collections.deque([], maxlen=int(max_elements / 2))
+        self.head_full = False
+        self.total_elements = 0
+        self.section_breaker = section_breaker
+
+    def __iter__(self) -> Iterator[T]:
+        yield from self.list_head
+        if self.section_breaker and len(self.list_tail):
+            yield f"{self.section_breaker} {self.total_elements - len(self.list_head) - len(self.list_tail)} more elements"  # type: ignore
+        yield from self.list_tail
+
+    def append(self, __object: T) -> None:
+        if self.head_full:
+            self.list_tail.append(__object)
+        else:
+            self.list_head.append(__object)
+            if len(self.list_head) > int(self.max_elements / 2):
+                self.head_full = True
+        self.total_elements += 1
+
+    def __len__(self) -> int:
+        return self.total_elements
+
+    def __repr__(self) -> str:
+        return repr(list(self.__iter__()))
+
+    def __str__(self) -> str:
+        return str(list(self.__iter__()))
+
+
 @dataclass
 class SourceReport(Report):
     workunits_produced: int = 0
-    workunit_ids: List[str] = field(default_factory=list)
+    workunit_ids: List[str] = field(default_factory=LossyList)
 
     warnings: Dict[str, List[str]] = field(default_factory=dict)
     failures: Dict[str, List[str]] = field(default_factory=dict)
-    cli_version: str = datahub.nice_version_name()
-    cli_entry_location: str = datahub.__file__
-    py_version: str = sys.version
-    py_exec_path: str = sys.executable
-    os_details: str = platform.platform()
 
     def report_workunit(self, wu: WorkUnit) -> None:
         self.workunits_produced += 1
@@ -56,6 +102,20 @@ class SourceReport(Report):
             self.failures[key] = []
         self.failures[key].append(reason)
 
+    def __post_init__(self) -> None:
+        self.starting_time = get_current_time_in_seconds()
+        self.running_time_in_seconds = 0
+
+    def compute_stats(self) -> None:
+        current_time = get_current_time_in_seconds()
+        running_time = current_time - self.starting_time
+        workunits_produced = self.workunits_produced
+        if running_time > 0:
+            self.read_rate = workunits_produced / running_time
+            self.running_time_in_seconds = running_time
+        else:
+            self.read_rate = 0
+
 
 class CapabilityReport(BaseModel):
     """A report capturing the result of any capability evaluation"""
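# Illustrative sketch (not part of the patch): behaviour of the LossyList added above,
# which now backs SourceReport.workunit_ids. It keeps the first few and last few
# entries plus a "... N more elements" marker, while __len__ still reports the true
# count, so long-running ingestions no longer bloat the source report.
from datahub.ingestion.api.source import LossyList

ids = LossyList()  # defaults: max_elements=10, section_breaker="..."
for i in range(100):
    ids.append(f"workunit-{i}")

print(len(ids))  # 100 -- the logical length is preserved
print(ids)  # first handful of ids, the "..." marker, then the last few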
diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
index 9746c090a8..9bca0dc65b 100644
--- a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
+++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py
@@ -1,13 +1,19 @@
 import itertools
 import logging
+import platform
+import sys
+import time
+from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional
 
 import click
 
+import datahub
 from datahub.configuration.common import PipelineExecutionError
 from datahub.ingestion.api.committable import CommitPolicy
 from datahub.ingestion.api.common import EndOfStream, PipelineContext, RecordEnvelope
 from datahub.ingestion.api.pipeline_run_listener import PipelineRunListener
+from datahub.ingestion.api.report import Report
 from datahub.ingestion.api.sink import Sink, WriteCallback
 from datahub.ingestion.api.source import Extractor, Source
 from datahub.ingestion.api.transform import Transformer
@@ -28,7 +34,7 @@ class LoggingCallback(WriteCallback):
     def on_success(
         self, record_envelope: RecordEnvelope, success_metadata: dict
     ) -> None:
-        logger.info(f"sink wrote workunit {record_envelope.metadata['workunit_id']}")
+        logger.debug(f"sink wrote workunit {record_envelope.metadata['workunit_id']}")
 
     def on_failure(
         self,
@@ -46,6 +52,15 @@ class PipelineInitError(Exception):
     pass
 
 
+@dataclass
+class CliReport(Report):
+    cli_version: str = datahub.nice_version_name()
+    cli_entry_location: str = datahub.__file__
+    py_version: str = sys.version
+    py_exec_path: str = sys.executable
+    os_details: str = platform.platform()
+
+
 class Pipeline:
     config: PipelineConfig
     ctx: PipelineContext
@@ -71,6 +86,9 @@ class Pipeline:
         self.preview_workunits = preview_workunits
         self.report_to = report_to
         self.reporters: List[PipelineRunListener] = []
+        self.num_intermediate_workunits = 0
+        self.last_time_printed = int(time.time())
+        self.cli_report = CliReport()
 
         try:
             self.ctx = PipelineContext(
@@ -240,6 +258,17 @@ class Pipeline:
             no_default_report=no_default_report,
         )
 
+    def _time_to_print(self) -> bool:
+        self.num_intermediate_workunits += 1
+        if self.num_intermediate_workunits > 1000:
+            current_time = int(time.time())
+            if current_time - self.last_time_printed > 10:
+                # we print
+                self.num_intermediate_workunits = 0
+                self.last_time_printed = current_time
+                return True
+        return False
+
     def run(self) -> None:
         self._notify_reporters_on_ingestion_start()
 
@@ -250,6 +279,8 @@ class Pipeline:
                 self.source.get_workunits(),
                 self.preview_workunits if self.preview_mode else None,
             ):
+                if self._time_to_print():
+                    self.pretty_print_summary(currently_running=True)
                 # TODO: change extractor interface
                 extractor.configure({}, self.ctx)
 
@@ -370,8 +401,12 @@ class Pipeline:
                 result += len(val)
         return result
 
-    def pretty_print_summary(self, warnings_as_failure: bool = False) -> int:
+    def pretty_print_summary(
+        self, warnings_as_failure: bool = False, currently_running: bool = False
+    ) -> int:
         click.echo()
+        click.secho("Cli report:", bold=True)
+        click.secho(self.cli_report.as_string())
         click.secho(f"Source ({self.config.source.type}) report:", bold=True)
         click.echo(self.source.get_report().as_string())
         click.secho(f"Sink ({self.config.sink.type}) report:", bold=True)
@@ -383,7 +418,7 @@ class Pipeline:
                 self.source.get_report().failures
             )
             click.secho(
-                f"Pipeline finished with {num_failures_source} failures in source producing {workunits_produced} workunits",
+                f"Pipeline {'running' if currently_running else 'finished'} with {num_failures_source} failures in source producing {workunits_produced} events",
                 fg="bright_red",
                 bold=True,
             )
@@ -391,14 +426,14 @@ class Pipeline:
         elif self.source.get_report().warnings or self.sink.get_report().warnings:
             num_warn_source = self._count_all_vals(self.source.get_report().warnings)
             click.secho(
-                f"Pipeline finished with {num_warn_source} warnings in source producing {workunits_produced} workunits",
+                f"Pipeline {'running' if currently_running else 'finished'} with {num_warn_source} warnings in source producing {workunits_produced} events",
                 fg="yellow",
                 bold=True,
            )
             return 1 if warnings_as_failure else 0
         else:
             click.secho(
-                f"Pipeline finished successfully producing {workunits_produced} workunits",
+                f"Pipeline {'running' if currently_running else 'finished'} successfully producing {workunits_produced} events",
                 fg="green",
                 bold=True,
             )
diff --git a/metadata-ingestion/src/datahub/ingestion/source/file.py b/metadata-ingestion/src/datahub/ingestion/source/file.py
index f8c5fcaef3..a0eb5876d3 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/file.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/file.py
@@ -1,11 +1,19 @@
+import datetime
 import json
+import logging
 import os.path
-from dataclasses import dataclass, field
-from typing import Iterable, Iterator, Union
+from dataclasses import dataclass
+from enum import Enum
+from io import BufferedReader
+from pathlib import Path
+from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union
 
+import ijson
+from pydantic import root_validator, validator
 from pydantic.fields import Field
 
 from datahub.configuration.common import ConfigModel
+from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.decorators import (
     SupportStatus,
     config_class,
@@ -25,100 +33,255 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
 )
 from datahub.metadata.schema_classes import UsageAggregationClass
 
-
-def _iterate_file(path: str) -> list:
-    with open(path, "r") as f:
-        obj_list = json.load(f)
-    if not isinstance(obj_list, list):
-        obj_list = [obj_list]
-    return obj_list
+logger = logging.getLogger(__name__)
 
 
-def iterate_mce_file(path: str) -> Iterator[MetadataChangeEvent]:
-    for obj in _iterate_file(path):
-        mce: MetadataChangeEvent = MetadataChangeEvent.from_obj(obj)
-        yield mce
-
-
-def iterate_generic_file(
-    path: str,
-) -> Iterator[
-    Union[MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass]
-]:
-    for i, obj in enumerate(_iterate_file(path)):
-        item: Union[MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass]
-        if "proposedSnapshot" in obj:
-            item = MetadataChangeEvent.from_obj(obj)
-        elif "aspect" in obj:
-            item = MetadataChangeProposal.from_obj(obj)
-        else:
-            item = UsageAggregationClass.from_obj(obj)
-        if not item.validate():
-            raise ValueError(f"failed to parse: {obj} (index {i})")
-        yield item
+class FileReadMode(Enum):
+    STREAM = "STREAM"
+    BATCH = "BATCH"
+    AUTO = "AUTO"
 
 
 class FileSourceConfig(ConfigModel):
-    filename: str = Field(description="Path to file to ingest.")
+    filename: Optional[str] = Field(None, description="Path to file to ingest.")
+    path: str = Field(
+        description="Path to folder or file to ingest. If pointed to a folder, all files with extension {file_extension} (default json) within that folder will be processed."
+    )
+    file_extension: str = Field(
+        "json",
+        description="When providing a folder to use to read files, set this field to control file extensions that you want the source to process. * is a special value that means process every file regardless of extension",
+    )
+    read_mode: FileReadMode = FileReadMode.AUTO
+
+    _minsize_for_streaming_mode_in_bytes: int = (
+        100 * 1000 * 1000  # Must be at least 100MB before we use streaming mode
+    )
+
+    @validator("read_mode", pre=True)
+    def read_mode_str_to_enum(cls, v):
+        if v and isinstance(v, str):
+            return v.upper()
+
+    @root_validator(pre=True)
+    def filename_populates_path_if_present(
+        cls, values: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        if "path" not in values and "filename" in values:
+            values["path"] = values["filename"]
+        elif values.get("filename"):
+            raise ValueError(
+                "Both path and filename should not be provided together. Use one. We recommend using path. filename is deprecated."
+            )
+
+        return values
+
+    @validator("file_extension", always=True)
+    def add_leading_dot_to_extension(cls, v: str) -> str:
+        if v:
+            if v.startswith("."):
+                return v
+            else:
+                return "." + v
+        return v
+
+
+@dataclass
+class FileSourceReport(SourceReport):
+    total_bytes_on_disk: Optional[int] = None
+    total_bytes_read: int = 0
+    total_num_files: int = 0
+    num_files_read: int = 0
+    percentage_completion: float = -1
+    estimated_time_to_completion_in_minutes: int = -1
+    total_parse_time_in_seconds: Optional[float] = None
+    total_deserialize_time_in_seconds: float = 0
+
+    def add_deserialize_time(self, delta: datetime.timedelta) -> None:
+        self.total_deserialize_time_in_seconds += delta.total_seconds()
+
+    def add_parse_time(self, delta: datetime.timedelta) -> None:
+        if self.total_parse_time_in_seconds is not None:
+            self.total_parse_time_in_seconds += delta.total_seconds()
+        else:
+            self.total_parse_time_in_seconds = delta.total_seconds()
+
+    def append_total_bytes_on_disk(self, delta: int) -> None:
+        if self.total_bytes_on_disk is not None:
+            self.total_bytes_on_disk += delta
+        else:
+            self.total_bytes_on_disk = delta
+
+    def compute_stats(self) -> None:
+        super().compute_stats()
+        self.percentage_completion = (
+            100.0 * (self.total_bytes_read / self.total_bytes_on_disk)
+            if self.total_bytes_on_disk
+            else -1
+        )
+        self.estimated_time_to_completion_in_minutes = int(
+            (
+                self.running_time_in_seconds
+                * (100 - self.percentage_completion)
+                / self.percentage_completion
+            )
+            / 60
+        )
 
 
 @platform_name("File")
 @config_class(FileSourceConfig)
 @support_status(SupportStatus.CERTIFIED)
-@dataclass
 class GenericFileSource(TestableSource):
     """
     This plugin pulls metadata from a previously generated file. The [file sink](../../../../metadata-ingestion/sink_docs/file.md) can produce such files, and a number of samples are included in the [examples/mce_files](../../../../metadata-ingestion/examples/mce_files) directory.
     """
 
-    config: FileSourceConfig
-    report: SourceReport = field(default_factory=SourceReport)
+    def __init__(self, ctx: PipelineContext, config: FileSourceConfig):
+        self.ctx = ctx
+        self.config = config
+        self.report = FileSourceReport()
+        self.fp: Optional[BufferedReader] = None
 
     @classmethod
     def create(cls, config_dict, ctx):
         config = FileSourceConfig.parse_obj(config_dict)
         return cls(ctx, config)
 
+    def get_filenames(self) -> Iterable[str]:
+        is_file = os.path.isfile(self.config.path)
+        is_dir = os.path.isdir(self.config.path)
+        if is_file:
+            self.report.total_num_files = 1
+            return [self.config.path]
+        if is_dir:
+            p = Path(self.config.path)
+            files_and_stats = [
+                (str(x), os.path.getsize(x))
+                for x in list(p.glob(f"*{self.config.file_extension}"))
+                if x.is_file()
+            ]
+            self.report.total_num_files = len(files_and_stats)
+            self.report.total_bytes_on_disk = sum([y for (x, y) in files_and_stats])
+            return [x for (x, y) in files_and_stats]
+        raise Exception(f"Failed to process {self.config.path}")
+
     def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, UsageStatsWorkUnit]]:
-        for i, obj in enumerate(iterate_generic_file(self.config.filename)):
-            wu: Union[MetadataWorkUnit, UsageStatsWorkUnit]
-            if isinstance(obj, UsageAggregationClass):
-                wu = UsageStatsWorkUnit(f"file://{self.config.filename}:{i}", obj)
-            elif isinstance(obj, MetadataChangeProposal):
-                wu = MetadataWorkUnit(f"file://{self.config.filename}:{i}", mcp_raw=obj)
-            else:
-                wu = MetadataWorkUnit(f"file://{self.config.filename}:{i}", mce=obj)
-            self.report.report_workunit(wu)
-            yield wu
+        for f in self.get_filenames():
+            for i, obj in self.iterate_generic_file(f):
+                wu: Union[MetadataWorkUnit, UsageStatsWorkUnit]
+                if isinstance(obj, UsageAggregationClass):
+                    wu = UsageStatsWorkUnit(f"file://{f}:{i}", obj)
+                elif isinstance(obj, MetadataChangeProposal):
+                    wu = MetadataWorkUnit(f"file://{f}:{i}", mcp_raw=obj)
+                else:
+                    wu = MetadataWorkUnit(f"file://{f}:{i}", mce=obj)
+                self.report.report_workunit(wu)
+                yield wu
+            self.report.num_files_read += 1
+            self.report.total_bytes_read += os.path.getsize(f)
 
     def get_report(self):
         return self.report
 
     def close(self):
-        pass
+        if self.fp:
+            self.fp.close()
+
+    def _iterate_file(self, path: str) -> Iterable[Tuple[int, Any]]:
+        size = os.path.getsize(path)
+        if self.config.read_mode == FileReadMode.AUTO:
+            file_read_mode = (
+                FileReadMode.BATCH
+                if size < self.config._minsize_for_streaming_mode_in_bytes
+                else FileReadMode.STREAM
+            )
+            logger.info(f"Reading file {path} in {file_read_mode} mode")
+        else:
+            file_read_mode = self.config.read_mode
+
+        if file_read_mode == FileReadMode.BATCH:
+            with open(path, "r") as f:
+                parse_start_time = datetime.datetime.now()
+                obj_list = json.load(f)
+                parse_end_time = datetime.datetime.now()
+                self.report.add_parse_time(parse_end_time - parse_start_time)
+            if not isinstance(obj_list, list):
+                obj_list = [obj_list]
+            yield from enumerate(obj_list)
+        else:
+            self.fp = open(path, "rb")
+            parse_start_time = datetime.datetime.now()
+            parse_stream = ijson.parse(self.fp, use_float=True)
+            rows_yielded = 0
+            for row in ijson.items(parse_stream, "item", use_float=True):
+                parse_end_time = datetime.datetime.now()
+                self.report.add_parse_time(parse_end_time - parse_start_time)
+                rows_yielded += 1
+                yield rows_yielded, row
+                parse_start_time = datetime.datetime.now()
+
+    def iterate_mce_file(self, path: str) -> Iterator[MetadataChangeEvent]:
+        for i, obj in self._iterate_file(path):
+            mce: MetadataChangeEvent = MetadataChangeEvent.from_obj(obj)
+            yield mce
+
+    def iterate_generic_file(
+        self,
+        path: str,
+    ) -> Iterator[
+        Tuple[
+            int,
+            Union[MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass],
+        ]
+    ]:
+        for i, obj in self._iterate_file(path):
+            item: Union[
+                MetadataChangeEvent, MetadataChangeProposal, UsageAggregationClass
+            ]
+            try:
+                deserialize_start_time = datetime.datetime.now()
+                if "proposedSnapshot" in obj:
+                    item = MetadataChangeEvent.from_obj(obj)
+                elif "aspect" in obj:
+                    item = MetadataChangeProposal.from_obj(obj)
+                else:
+                    item = UsageAggregationClass.from_obj(obj)
+                if not item.validate():
+                    raise ValueError(f"failed to parse: {obj} (index {i})")
+                deserialize_duration = datetime.datetime.now() - deserialize_start_time
+                self.report.add_deserialize_time(deserialize_duration)
+                yield i, item
+            except Exception as e:
+                self.report.report_failure(f"path-{i}", str(e))
 
     @staticmethod
     def test_connection(config_dict: dict) -> TestConnectionReport:
         config = FileSourceConfig.parse_obj(config_dict)
-        is_file = os.path.isfile(config.filename)
-        readable = os.access(config.filename, os.R_OK)
-        if is_file and readable:
-            return TestConnectionReport(
-                basic_connectivity=CapabilityReport(capable=True)
-            )
-        elif not is_file:
+        exists = os.path.exists(config.path)
+        if not exists:
             return TestConnectionReport(
                 basic_connectivity=CapabilityReport(
                     capable=False,
-                    failure_reason=f"{config.filename} doesn't appear to be a valid file.",
+                    failure_reason=f"{config.path} doesn't appear to be a valid file or directory.",
                 )
             )
-        elif not readable:
+        is_dir = os.path.isdir(config.path)
+        failure_message = None
+        readable = os.access(config.path, os.R_OK)
+        if not readable:
+            failure_message = f"Cannot read {config.path}"
+        if is_dir:
+            executable = os.access(config.path, os.X_OK)
+            if not executable:
+                failure_message = f"Do not have execute permissions in {config.path}"
+
+        if failure_message:
             return TestConnectionReport(
                 basic_connectivity=CapabilityReport(
-                    capable=False, failure_reason=f"Cannot read file {config.filename}"
+                    capable=False, failure_reason=failure_message
                 )
             )
         else:
-            # not expected to be here
-            raise Exception("Not expected to be here.")
+            return TestConnectionReport(
+                basic_connectivity=CapabilityReport(capable=True)
+            )
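# Illustrative sketch (not part of the patch): exercising the new file-source config
# surface defined above. The keys mirror the FileSourceConfig fields from this diff;
# the concrete folder path and run_id are made-up values.
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.file import FileSourceConfig, GenericFileSource

config = FileSourceConfig.parse_obj(
    {
        "path": "/tmp/mce_dumps",  # a folder: every *.json file inside is processed
        "file_extension": "json",  # the validator normalises this to ".json"
        "read_mode": "stream",  # the validator upper-cases this to FileReadMode.STREAM
    }
)
source = GenericFileSource(ctx=PipelineContext(run_id="local-file-check"), config=config)
for wu in source.get_workunits():
    pass  # workunit ids look like "file://<path>:<index>"
print(source.get_report().as_string())  # bytes read, parse/deserialize timings, etc.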
"file", "config": {"filename": str(golden_file)}}, @@ -85,31 +85,43 @@ def test_serde_to_json( ], ) @freeze_time(FROZEN_TIME) -def test_serde_to_avro(pytestconfig: PytestConfig, json_filename: str) -> None: +def test_serde_to_avro( + pytestconfig: PytestConfig, + json_filename: str, +) -> None: # In this test, we want to read in from JSON -> MCE object. # Next we serialize from MCE to Avro and then deserialize back to MCE. # Finally, we want to compare the two MCE objects. + with patch( + "datahub.ingestion.api.common.PipelineContext", autospec=True + ) as mock_pipeline_context: - json_path = pytestconfig.rootpath / json_filename - mces = list(iterate_mce_file(str(json_path))) + json_path = pytestconfig.rootpath / json_filename + source = GenericFileSource( + ctx=mock_pipeline_context, config=FileSourceConfig(path=str(json_path)) + ) + mces = list(source.iterate_mce_file(str(json_path))) - # Serialize to Avro. - parsed_schema = fastavro.parse_schema(json.loads(getMetadataChangeEventSchema())) - fo = io.BytesIO() - out_records = [mce.to_obj(tuples=True) for mce in mces] - fastavro.writer(fo, parsed_schema, out_records) + # Serialize to Avro. + parsed_schema = fastavro.parse_schema( + json.loads(getMetadataChangeEventSchema()) + ) + fo = io.BytesIO() + out_records = [mce.to_obj(tuples=True) for mce in mces] + fastavro.writer(fo, parsed_schema, out_records) - # Deserialized from Avro. - fo.seek(0) - in_records = list(fastavro.reader(fo, return_record_name=True)) - in_mces = [ - MetadataChangeEventClass.from_obj(record, tuples=True) for record in in_records - ] + # Deserialized from Avro. + fo.seek(0) + in_records = list(fastavro.reader(fo, return_record_name=True)) + in_mces = [ + MetadataChangeEventClass.from_obj(record, tuples=True) + for record in in_records + ] - # Check diff - assert len(mces) == len(in_mces) - for i in range(len(mces)): - assert mces[i] == in_mces[i] + # Check diff + assert len(mces) == len(in_mces) + for i in range(len(mces)): + assert mces[i] == in_mces[i] @pytest.mark.parametrize( @@ -148,8 +160,11 @@ def test_check_mce_schema_failure( ) -> None: json_file_path = pytestconfig.rootpath / json_filename - with pytest.raises((ValueError, AssertionError)): + try: check_mce_file(str(json_file_path)) + raise AssertionError("MCE File validated successfully when it should not have") + except Exception as e: + assert "is missing required field: active" in str(e) def test_field_discriminator() -> None: