From f227bd982b43a12d0f27ffa60c05aff89f949ea4 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 26 Sep 2022 06:37:48 +0000 Subject: [PATCH] refactor(ingest): streamline pydantic configs (#6011) --- .../configuration/time_window_config.py | 11 ++---- .../configuration/validate_field_rename.py | 11 ++++-- .../src/datahub/ingestion/source/dbt.py | 1 + .../src/datahub/ingestion/source/file.py | 37 +++++++------------ 4 files changed, 26 insertions(+), 34 deletions(-) diff --git a/metadata-ingestion/src/datahub/configuration/time_window_config.py b/metadata-ingestion/src/datahub/configuration/time_window_config.py index ad7bfafedd..365534af76 100644 --- a/metadata-ingestion/src/datahub/configuration/time_window_config.py +++ b/metadata-ingestion/src/datahub/configuration/time_window_config.py @@ -40,15 +40,12 @@ class BaseTimeWindowConfig(ConfigModel): # `start_time` and `end_time` will be populated by the pre-validators. # However, we must specify a "default" value here or pydantic will complain # if those fields are not set by the user. - end_time: datetime = Field(default=None, description="Latest date of usage to consider. Default: Current time in UTC") # type: ignore + end_time: datetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="Latest date of usage to consider. Default: Current time in UTC", + ) start_time: datetime = Field(default=None, description="Earliest date of usage to consider. Default: Last full day in UTC (or hour, depending on `bucket_duration`)") # type: ignore - @pydantic.validator("end_time", pre=True, always=True) - def default_end_time( - cls, v: Any, *, values: Dict[str, Any], **kwargs: Any - ) -> datetime: - return v or datetime.now(tz=timezone.utc) - @pydantic.validator("start_time", pre=True, always=True) def default_start_time( cls, v: Any, *, values: Dict[str, Any], **kwargs: Any diff --git a/metadata-ingestion/src/datahub/configuration/validate_field_rename.py b/metadata-ingestion/src/datahub/configuration/validate_field_rename.py index 1da59a350a..496bac6517 100644 --- a/metadata-ingestion/src/datahub/configuration/validate_field_rename.py +++ b/metadata-ingestion/src/datahub/configuration/validate_field_rename.py @@ -14,6 +14,7 @@ def pydantic_renamed_field( old_name: str, new_name: str, transform: Callable = _default_rename_transform, + print_warning: bool = True, ) -> classmethod: def _validate_field_rename(cls: Type, values: dict) -> dict: if old_name in values: @@ -22,10 +23,12 @@ def pydantic_renamed_field( f"Cannot specify both {old_name} and {new_name} in the same config. Note that {old_name} has been deprecated in favor of {new_name}." ) else: - warnings.warn( - f"The {old_name} is deprecated, please use {new_name} instead.", - UserWarning, - ) + if print_warning: + warnings.warn( + f"The {old_name} is deprecated, please use {new_name} instead.", + UserWarning, + stacklevel=2, + ) values[new_name] = transform(values.pop(old_name)) return values diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt.py b/metadata-ingestion/src/datahub/ingestion/source/dbt.py index f089158d00..1dcf8cae5f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt.py @@ -291,6 +291,7 @@ class DBTConfig(StatefulIngestionConfigBase): description="When enabled, applies the mappings that are defined through the `query_tag_mapping` directives.", ) write_semantics: str = Field( + # TODO: Replace with the WriteSemantics enum. 
default="PATCH", description='Whether the new tags, terms and owners to be added will override the existing ones added only by this source or not. Value for this config can be "PATCH" or "OVERRIDE"', ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/file.py b/metadata-ingestion/src/datahub/ingestion/source/file.py index 54dfa7181d..2c5ce2b6af 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/file.py +++ b/metadata-ingestion/src/datahub/ingestion/source/file.py @@ -2,17 +2,18 @@ import datetime import json import logging import os.path +import pathlib from dataclasses import dataclass, field from enum import auto from io import BufferedReader -from pathlib import Path from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union import ijson -from pydantic import root_validator, validator +from pydantic import validator from pydantic.fields import Field from datahub.configuration.common import ConfigEnum, ConfigModel +from datahub.configuration.validate_field_rename import pydantic_renamed_field from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.decorators import ( SupportStatus, @@ -43,8 +44,10 @@ class FileReadMode(ConfigEnum): class FileSourceConfig(ConfigModel): - filename: Optional[str] = Field(None, description="Path to file to ingest.") - path: str = Field( + filename: Optional[str] = Field( + None, description="[deprecated in favor or `path`] The file to ingest." + ) + path: pathlib.Path = Field( description="Path to folder or file to ingest. If pointed to a folder, all files with extension {file_extension} (default json) within that folder will be processed." ) file_extension: str = Field( @@ -61,18 +64,9 @@ class FileSourceConfig(ConfigModel): 100 * 1000 * 1000 # Must be at least 100MB before we use streaming mode ) - @root_validator(pre=True) - def filename_populates_path_if_present( - cls, values: Dict[str, Any] - ) -> Dict[str, Any]: - if "path" not in values and "filename" in values: - values["path"] = values["filename"] - elif values.get("filename"): - raise ValueError( - "Both path and filename should not be provided together. Use one. We recommend using path. filename is deprecated." - ) - - return values + _filename_populates_path_if_present = pydantic_renamed_field( + "filename", "path", print_warning=False + ) @validator("file_extension", always=True) def add_leading_dot_to_extension(cls, v: str) -> str: @@ -179,16 +173,13 @@ class GenericFileSource(TestableSource): return cls(ctx, config) def get_filenames(self) -> Iterable[str]: - is_file = os.path.isfile(self.config.path) - is_dir = os.path.isdir(self.config.path) - if is_file: + if self.config.path.is_file(): self.report.total_num_files = 1 - return [self.config.path] - if is_dir: - p = Path(self.config.path) + return [str(self.config.path)] + elif self.config.path.is_dir(): files_and_stats = [ (str(x), os.path.getsize(x)) - for x in list(p.glob(f"*{self.config.file_extension}")) + for x in list(self.config.path.glob(f"*{self.config.file_extension}")) if x.is_file() ] self.report.total_num_files = len(files_and_stats)