chore(ingest): start working on pydantic v2 support (#9220)

Harshal Sheth, 2023-11-10 09:34:08 -08:00, committed by GitHub
parent b851d59e20
commit 89dff8f7bd
44 changed files with 216 additions and 139 deletions

View File

@ -192,7 +192,7 @@ def add_avro_python3_warning(filepath: Path) -> None:
# This means that installation order matters, which is a pretty unintuitive outcome.
# See https://github.com/pypa/pip/issues/4625 for details.
try:
from avro.schema import SchemaFromJSONData
from avro.schema import SchemaFromJSONData # type: ignore
import warnings
warnings.warn("It seems like 'avro-python3' is installed, which conflicts with the 'avro' package used by datahub. "

View File

@ -88,6 +88,7 @@ filterwarnings =
ignore:Deprecated call to \`pkg_resources.declare_namespace:DeprecationWarning
ignore:pkg_resources is deprecated as an API:DeprecationWarning
ignore:Did not recognize type:sqlalchemy.exc.SAWarning
ignore::datahub.configuration.pydantic_migration_helpers.PydanticDeprecatedSince20
[coverage:run]
# Because of some quirks in the way setup.cfg, coverage.py, pytest-cov,

View File

@ -47,7 +47,7 @@ config_override: Dict = {}
class GmsConfig(BaseModel):
server: str
token: Optional[str]
token: Optional[str] = None
class DatahubConfig(BaseModel):
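Note: this pattern repeats throughout the commit. Pydantic v1 treated a bare Optional annotation as "defaults to None"; v2 makes such fields required unless a default is spelled out. A minimal sketch of the difference (the class and field names here are hypothetical, not from the commit):
import pydantic
from typing import Optional

class GmsConfigSketch(pydantic.BaseModel):
    server: str
    # v1 implied "= None" for Optional fields; v2 would treat this as required,
    # so the default is now explicit.
    token: Optional[str] = None
    # The same applies when Field() is used only for metadata: the default must
    # be passed as the first argument.
    retries: Optional[int] = pydantic.Field(None, description="optional retry count")

GmsConfigSketch(server="http://localhost:8080")  # valid on both major versions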

View File

@ -40,7 +40,7 @@ class DuckDBLiteConfigWrapper(DuckDBLiteConfig):
class LiteCliConfig(DatahubConfig):
lite: LiteLocalConfig = LiteLocalConfig(
type="duckdb", config=DuckDBLiteConfigWrapper()
type="duckdb", config=DuckDBLiteConfigWrapper().dict()
)
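Note: the default value is now materialized as a plain dict up front. For reference, the model-to-dict API itself diverges between majors; a hedged helper sketch (the function name is made up):
import pydantic

def model_to_dict(model: pydantic.BaseModel) -> dict:
    # pydantic v2 renames .dict() to .model_dump(); the old name still works
    # there but emits PydanticDeprecatedSince20 warnings.
    if hasattr(model, "model_dump"):
        return model.model_dump()
    return model.dict()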

View File

@ -4,6 +4,8 @@ import pydantic
import pydantic.types
import pydantic.validators
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
class ConfigEnum(Enum):
# Ideally we would use @staticmethod here, but some versions of Python don't support it.
@ -15,11 +17,25 @@ class ConfigEnum(Enum):
# From https://stackoverflow.com/a/44785241/5004662.
return name
@classmethod
def __get_validators__(cls) -> "pydantic.types.CallableGenerator":
# We convert the text to uppercase before attempting to match it to an enum value.
yield cls.validate
yield pydantic.validators.enum_member_validator
if PYDANTIC_VERSION_2:
# if TYPE_CHECKING:
# from pydantic import GetCoreSchemaHandler
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore
from pydantic_core import core_schema
return core_schema.no_info_before_validator_function(
cls.validate, handler(source_type)
)
else:
@classmethod
def __get_validators__(cls) -> "pydantic.types.CallableGenerator":
# We convert the text to uppercase before attempting to match it to an enum value.
yield cls.validate
yield pydantic.validators.enum_member_validator
@classmethod
def validate(cls, v): # type: ignore[no-untyped-def]
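Note: this is the central v1/v2 split for custom validation hooks: v1 discovers validators through __get_validators__, while v2 asks the type for a pydantic-core schema. A standalone sketch of the same pattern (Color and Palette are illustrative types, not DataHub code):
import enum
import pydantic

class Color(enum.Enum):
    RED = "RED"
    BLUE = "BLUE"

    @classmethod
    def _uppercase(cls, v):  # accept "red" as well as "RED"
        return v.upper() if isinstance(v, str) else v

    if pydantic.VERSION.startswith("2."):
        @classmethod
        def __get_pydantic_core_schema__(cls, source_type, handler):
            from pydantic_core import core_schema
            # Run the normalizer before the schema pydantic derives for the enum.
            return core_schema.no_info_before_validator_function(
                cls._uppercase, handler(source_type)
            )
    else:
        @classmethod
        def __get_validators__(cls):
            import pydantic.validators
            yield cls._uppercase
            yield pydantic.validators.enum_member_validator

class Palette(pydantic.BaseModel):
    primary: Color

assert Palette(primary="red").primary is Color.RED  # works on either major version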

View File

@ -11,6 +11,7 @@ from pydantic.fields import Field
from typing_extensions import Protocol, runtime_checkable
from datahub.configuration._config_enum import ConfigEnum
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
from datahub.utilities.dedup_list import deduplicate_list
_ConfigSelf = TypeVar("_ConfigSelf", bound="ConfigModel")
@ -71,14 +72,8 @@ def redact_raw_config(obj: Any) -> Any:
class ConfigModel(BaseModel):
class Config:
extra = Extra.forbid
underscore_attrs_are_private = True
keep_untouched = (
cached_property,
) # needed to allow cached_property to work. See https://github.com/samuelcolvin/pydantic/issues/1241 for more info.
@staticmethod
def schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None:
def _schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None:
# We use the custom "hidden_from_docs" attribute to hide fields from the
# autogenerated docs.
remove_fields = []
@ -89,6 +84,19 @@ class ConfigModel(BaseModel):
for key in remove_fields:
del schema["properties"][key]
# This is purely to suppress pydantic's warnings, since this class is used everywhere.
if PYDANTIC_VERSION_2:
extra = "forbid"
ignored_types = (cached_property,)
json_schema_extra = _schema_extra
else:
extra = Extra.forbid
underscore_attrs_are_private = True
keep_untouched = (
cached_property,
) # needed to allow cached_property to work. See https://github.com/samuelcolvin/pydantic/issues/1241 for more info.
schema_extra = _schema_extra
@classmethod
def parse_obj_allow_extras(cls: Type[_ConfigSelf], obj: Any) -> _ConfigSelf:
with unittest.mock.patch.object(cls.Config, "extra", pydantic.Extra.allow):
@ -102,7 +110,10 @@ class PermissiveConfigModel(ConfigModel):
# It is usually used for argument bags that are passed through to third-party libraries.
class Config:
extra = Extra.allow
if PYDANTIC_VERSION_2:
extra = "allow"
else:
extra = Extra.allow
class TransformerSemantics(ConfigEnum):
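Note: rather than switching to v2's model_config dict, the commit keeps `class Config` (which v2 still honors, with a deprecation warning) and branches on the version for the renamed options. A compressed sketch of the idea (StrictModel is illustrative):
from functools import cached_property
import pydantic

PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

class StrictModel(pydantic.BaseModel):
    class Config:
        if PYDANTIC_V2:
            extra = "forbid"                    # Extra.forbid spelled as a string
            ignored_types = (cached_property,)  # was keep_untouched in v1
        else:
            extra = pydantic.Extra.forbid
            keep_untouched = (cached_property,)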

View File

@ -24,11 +24,11 @@ class OAuthConfiguration(ConfigModel):
default=False,
)
client_secret: Optional[SecretStr] = Field(
description="client secret of the application if use_certificate = false"
None, description="client secret of the application if use_certificate = false"
)
encoded_oauth_public_key: Optional[str] = Field(
description="base64 encoded certificate content if use_certificate = true"
None, description="base64 encoded certificate content if use_certificate = true"
)
encoded_oauth_private_key: Optional[str] = Field(
description="base64 encoded private key content if use_certificate = true"
None, description="base64 encoded private key content if use_certificate = true"
)

View File

@ -0,0 +1,30 @@
import pydantic.version
from packaging.version import Version
PYDANTIC_VERSION_2: bool
if Version(pydantic.version.VERSION) >= Version("2.0"):
PYDANTIC_VERSION_2 = True
else:
PYDANTIC_VERSION_2 = False
# This can be used to silence deprecation warnings while we migrate.
if PYDANTIC_VERSION_2:
from pydantic import PydanticDeprecatedSince20 # type: ignore
else:
class PydanticDeprecatedSince20(Warning): # type: ignore
pass
if PYDANTIC_VERSION_2:
from pydantic import BaseModel as GenericModel
else:
from pydantic.generics import GenericModel # type: ignore
__all__ = [
"PYDANTIC_VERSION_2",
"PydanticDeprecatedSince20",
"GenericModel",
]
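Note: a hedged sketch of how downstream code might consume this new helper module (the Wrapper model and the print call are illustrative only):
import warnings
from typing import Generic, TypeVar

from datahub.configuration.pydantic_migration_helpers import (
    PYDANTIC_VERSION_2,
    PydanticDeprecatedSince20,
    GenericModel,
)

# Silence v2's deprecation chatter while the migration is still in progress.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)

T = TypeVar("T")

class Wrapper(GenericModel, Generic[T]):
    value: T

wrapped = Wrapper[int](value=3)
print(wrapped.model_dump() if PYDANTIC_VERSION_2 else wrapped.dict())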

View File

@ -127,7 +127,7 @@ class BucketKey(ContainerKey):
class NotebookKey(DatahubKey):
notebook_id: int
platform: str
instance: Optional[str]
instance: Optional[str] = None
def as_urn(self) -> str:
return make_dataset_urn_with_platform_instance(

View File

@ -26,11 +26,11 @@ def _try_reformat_with_black(code: str) -> str:
class WorkUnitRecordExtractorConfig(ConfigModel):
set_system_metadata = True
set_system_metadata_pipeline_name = (
set_system_metadata: bool = True
set_system_metadata_pipeline_name: bool = (
False # false for now until the models are available in OSS
)
unpack_mces_into_mcps = False
unpack_mces_into_mcps: bool = False
class WorkUnitRecordExtractor(
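Note: pydantic v1 silently inferred a field from a bare `name = default` assignment; v2 rejects non-annotated attributes at class-definition time with a PydanticUserError, which is why explicit annotations are added here and in several sources below. A tiny illustration (the class is hypothetical):
import pydantic

class ExtractorConfigSketch(pydantic.BaseModel):
    # `set_system_metadata = True` with no annotation was a bool field under v1
    # but is an error under v2; the explicit annotation is accepted by both.
    set_system_metadata: bool = True
    unpack_mces_into_mcps: bool = False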

View File

@ -72,7 +72,7 @@ class PipelineConfig(ConfigModel):
source: SourceConfig
sink: DynamicTypedConfig
transformers: Optional[List[DynamicTypedConfig]]
transformers: Optional[List[DynamicTypedConfig]] = None
flags: FlagsConfig = Field(default=FlagsConfig(), hidden_from_docs=True)
reporting: List[ReporterConfig] = []
run_id: str = DEFAULT_RUN_ID

View File

@ -265,7 +265,7 @@ class BigQueryV2Config(
description="Option to exclude empty projects from being ingested.",
)
@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def profile_default_settings(cls, values: Dict) -> Dict:
# Extra default SQLAlchemy option for better connection pooling and threading.
# https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow
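Note: this is the other recurring change. Pydantic v2's backwards-compatibility shim for @root_validator refuses plain post-validators: they must declare either pre=True or skip_on_failure=True. Since skip_on_failure=True also exists in v1 (the validator is skipped when field validation already failed), it is the least invasive spelling. A hedged sketch with a made-up model:
from typing import Dict
import pydantic

class ProfilingConfigSketch(pydantic.BaseModel):
    max_workers: int = 10
    options: Dict[str, int] = {}

    # Accepted by v1 and by v2's deprecated shim; a bare @root_validator()
    # would raise a PydanticUserError on v2.
    @pydantic.root_validator(skip_on_failure=True)
    def apply_pool_defaults(cls, values: Dict) -> Dict:
        values["options"].setdefault("max_overflow", values["max_workers"])
        return values

print(ProfilingConfigSketch().options)  # {'max_overflow': 10}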

View File

@ -214,7 +214,7 @@ class PathSpec(ConfigModel):
logger.debug(f"Setting _glob_include: {glob_include}")
return glob_include
@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def validate_path_spec(cls, values: Dict) -> Dict[str, Any]:
# validate that main fields are populated
required_fields = ["include", "file_types", "default_extension"]

View File

@ -80,7 +80,7 @@ class DataHubSourceConfig(StatefulIngestionConfigBase):
hidden_from_docs=True,
)
@root_validator
@root_validator(skip_on_failure=True)
def check_ingesting_data(cls, values):
if (
not values.get("database_connection")

View File

@ -46,6 +46,7 @@ class DBTCloudConfig(DBTCommonConfig):
description="The ID of the job to ingest metadata from.",
)
run_id: Optional[int] = Field(
None,
description="The ID of the run to ingest metadata from. If not specified, we'll default to the latest run.",
)

View File

@ -150,7 +150,7 @@ class DBTEntitiesEnabled(ConfigModel):
description="Emit metadata for test results when set to Yes or Only",
)
@root_validator
@root_validator(skip_on_failure=True)
def process_only_directive(cls, values):
# Checks that at most one is set to ONLY, and then sets the others to NO.
@ -229,7 +229,7 @@ class DBTCommonConfig(
default={},
description="mapping rules that will be executed against dbt column meta properties. Refer to the section below on dbt meta automated mappings.",
)
enable_meta_mapping = Field(
enable_meta_mapping: bool = Field(
default=True,
description="When enabled, applies the mappings that are defined through the meta_mapping directives.",
)
@ -237,7 +237,7 @@ class DBTCommonConfig(
default={},
description="mapping rules that will be executed against dbt query_tag meta properties. Refer to the section below on dbt meta automated mappings.",
)
enable_query_tag_mapping = Field(
enable_query_tag_mapping: bool = Field(
default=True,
description="When enabled, applies the mappings that are defined through the `query_tag_mapping` directives.",
)

View File

@ -100,11 +100,11 @@ class KafkaSourceConfig(
default="datahub.ingestion.source.confluent_schema_registry.ConfluentSchemaRegistry",
description="The fully qualified implementation class(custom) that implements the KafkaSchemaRegistryBase interface.",
)
schema_tags_field = pydantic.Field(
schema_tags_field: str = pydantic.Field(
default="tags",
description="The field name in the schema metadata that contains the tags to be added to the dataset.",
)
enable_meta_mapping = pydantic.Field(
enable_meta_mapping: bool = pydantic.Field(
default=True,
description="When enabled, applies the mappings that are defined through the meta_mapping directives.",
)

View File

@ -275,7 +275,7 @@ class LookMLSourceConfig(
)
return conn_map
@root_validator()
@root_validator(skip_on_failure=True)
def check_either_connection_map_or_connection_provided(cls, values):
"""Validate that we must either have a connection map or an api credential"""
if not values.get("connection_to_platform_map", {}) and not values.get(
@ -286,7 +286,7 @@ class LookMLSourceConfig(
)
return values
@root_validator()
@root_validator(skip_on_failure=True)
def check_either_project_name_or_api_provided(cls, values):
"""Validate that we must either have a project name or an api credential to fetch project names"""
if not values.get("project_name") and not values.get("api"):
@ -1070,7 +1070,6 @@ class LookerView:
def determine_view_file_path(
cls, base_folder_path: str, absolute_file_path: str
) -> str:
splits: List[str] = absolute_file_path.split(base_folder_path, 1)
if len(splits) != 2:
logger.debug(
@ -1104,7 +1103,6 @@ class LookerView:
populate_sql_logic_in_descriptions: bool = False,
process_isolation_for_sql_parsing: bool = False,
) -> Optional["LookerView"]:
view_name = looker_view["name"]
logger.debug(f"Handling view {view_name} in model {model_name}")
# The sql_table_name might be defined in another view and this view is extending that view,
@ -2087,7 +2085,6 @@ class LookMLSource(StatefulIngestionSourceBase):
)
if looker_viewfile is not None:
for raw_view in looker_viewfile.views:
raw_view_name = raw_view["name"]
if LookerRefinementResolver.is_refinement(raw_view_name):

View File

@ -126,7 +126,7 @@ class NifiSourceConfig(EnvConfigMixin):
description="Path to PEM file containing certs for the root CA(s) for the NiFi",
)
@root_validator
@root_validator(skip_on_failure=True)
def validate_auth_params(cla, values):
if values.get("auth") is NifiAuthType.CLIENT_CERT and not values.get(
"client_cert_file"
@ -143,7 +143,7 @@ class NifiSourceConfig(EnvConfigMixin):
)
return values
@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def validator_site_url_to_site_name(cls, values):
site_url_to_site_name = values.get("site_url_to_site_name")
site_url = values.get("site_url")

View File

@ -405,8 +405,7 @@ class PowerBiDashboardSourceConfig(
"Works for M-Query where native SQL is used for transformation.",
)
@root_validator
@classmethod
@root_validator(skip_on_failure=True)
def validate_extract_column_level_lineage(cls, values: Dict) -> Dict:
flags = [
"native_query_parsing",
@ -445,7 +444,7 @@ class PowerBiDashboardSourceConfig(
return value
@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def workspace_id_backward_compatibility(cls, values: Dict) -> Dict:
workspace_id = values.get("workspace_id")
workspace_id_pattern = values.get("workspace_id_pattern")

View File

@ -12,21 +12,21 @@ from datahub.metadata.schema_classes import OwnerClass
class CatalogItem(BaseModel):
id: str = Field(alias="Id")
name: str = Field(alias="Name")
description: Optional[str] = Field(alias="Description")
description: Optional[str] = Field(None, alias="Description")
path: str = Field(alias="Path")
type: Any = Field(alias="Type")
type: Any = Field(None, alias="Type")
hidden: bool = Field(alias="Hidden")
size: int = Field(alias="Size")
modified_by: Optional[str] = Field(alias="ModifiedBy")
modified_date: Optional[datetime] = Field(alias="ModifiedDate")
created_by: Optional[str] = Field(alias="CreatedBy")
created_date: Optional[datetime] = Field(alias="CreatedDate")
parent_folder_id: Optional[str] = Field(alias="ParentFolderId")
content_type: Optional[str] = Field(alias="ContentType")
modified_by: Optional[str] = Field(None, alias="ModifiedBy")
modified_date: Optional[datetime] = Field(None, alias="ModifiedDate")
created_by: Optional[str] = Field(None, alias="CreatedBy")
created_date: Optional[datetime] = Field(None, alias="CreatedDate")
parent_folder_id: Optional[str] = Field(None, alias="ParentFolderId")
content_type: Optional[str] = Field(None, alias="ContentType")
content: str = Field(alias="Content")
is_favorite: bool = Field(alias="IsFavorite")
user_info: Any = Field(alias="UserInfo")
display_name: Optional[str] = Field(alias="DisplayName")
user_info: Any = Field(None, alias="UserInfo")
display_name: Optional[str] = Field(None, alias="DisplayName")
has_data_sources: bool = Field(default=False, alias="HasDataSources")
data_sources: Optional[List["DataSource"]] = Field(
default_factory=list, alias="DataSources"
@ -72,12 +72,12 @@ class DataSet(CatalogItem):
class DataModelDataSource(BaseModel):
auth_type: Optional[str] = Field(alias="AuthType")
auth_type: Optional[str] = Field(None, alias="AuthType")
supported_auth_types: List[Optional[str]] = Field(alias="SupportedAuthTypes")
kind: str = Field(alias="Kind")
model_connection_name: str = Field(alias="ModelConnectionName")
secret: str = Field(alias="Secret")
type: Optional[str] = Field(alias="Type")
type: Optional[str] = Field(None, alias="Type")
username: str = Field(alias="Username")
@ -135,21 +135,23 @@ class DataSource(CatalogItem):
is_enabled: bool = Field(alias="IsEnabled")
connection_string: str = Field(alias="ConnectionString")
data_model_data_source: Optional[DataModelDataSource] = Field(
alias="DataModelDataSource"
None, alias="DataModelDataSource"
)
data_source_sub_type: Optional[str] = Field(alias="DataSourceSubType")
data_source_type: Optional[str] = Field(alias="DataSourceType")
data_source_sub_type: Optional[str] = Field(None, alias="DataSourceSubType")
data_source_type: Optional[str] = Field(None, alias="DataSourceType")
is_original_connection_string_expression_based: bool = Field(
alias="IsOriginalConnectionStringExpressionBased"
)
is_connection_string_overridden: bool = Field(alias="IsConnectionStringOverridden")
credentials_by_user: Optional[CredentialsByUser] = Field(alias="CredentialsByUser")
credentials_by_user: Optional[CredentialsByUser] = Field(
None, alias="CredentialsByUser"
)
credentials_in_server: Optional[CredentialsInServer] = Field(
alias="CredentialsInServer"
None, alias="CredentialsInServer"
)
is_reference: bool = Field(alias="IsReference")
subscriptions: Optional[Subscription] = Field(alias="Subscriptions")
meta_data: Optional[MetaData] = Field(alias="MetaData")
subscriptions: Optional[Subscription] = Field(None, alias="Subscriptions")
meta_data: Optional[MetaData] = Field(None, alias="MetaData")
def __members(self):
return (self.id,)
@ -274,15 +276,15 @@ class Owner(BaseModel):
class CorpUserEditableInfo(BaseModel):
display_name: str = Field(alias="displayName")
title: str
about_me: Optional[str] = Field(alias="aboutMe")
teams: Optional[List[str]]
skills: Optional[List[str]]
picture_link: Optional[str] = Field(alias="pictureLink")
about_me: Optional[str] = Field(None, alias="aboutMe")
teams: Optional[List[str]] = None
skills: Optional[List[str]] = None
picture_link: Optional[str] = Field(None, alias="pictureLink")
class CorpUserEditableProperties(CorpUserEditableInfo):
slack: Optional[str]
phone: Optional[str]
slack: Optional[str] = None
phone: Optional[str] = None
email: str
@ -305,21 +307,21 @@ class EntityRelationshipsResult(BaseModel):
start: int
count: int
total: int
relationships: Optional[EntityRelationship]
relationships: Optional[EntityRelationship] = None
class CorpUserProperties(BaseModel):
active: bool
display_name: str = Field(alias="displayName")
email: str
title: Optional[str]
manager: Optional["CorpUser"]
department_id: Optional[int] = Field(alias="departmentId")
department_name: Optional[str] = Field(alias="departmentName")
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")
full_name: Optional[str] = Field(alias="fullName")
country_code: Optional[str] = Field(alias="countryCode")
title: Optional[str] = None
manager: Optional["CorpUser"] = None
department_id: Optional[int] = Field(None, alias="departmentId")
department_name: Optional[str] = Field(None, alias="departmentName")
first_name: Optional[str] = Field(None, alias="firstName")
last_name: Optional[str] = Field(None, alias="lastName")
full_name: Optional[str] = Field(None, alias="fullName")
country_code: Optional[str] = Field(None, alias="countryCode")
class CorpUser(BaseModel):
@ -328,13 +330,13 @@ class CorpUser(BaseModel):
username: str
properties: CorpUserProperties
editable_properties: Optional[CorpUserEditableProperties] = Field(
alias="editableProperties"
None, alias="editableProperties"
)
status: Optional[CorpUserStatus]
tags: Optional[GlobalTags]
relationships: Optional[EntityRelationshipsResult]
editableInfo: Optional[CorpUserEditableInfo] = Field(alias="editableInfo")
global_tags: Optional[GlobalTags] = Field(alias="globalTags")
status: Optional[CorpUserStatus] = None
tags: Optional[GlobalTags] = None
relationships: Optional[EntityRelationshipsResult] = None
editableInfo: Optional[CorpUserEditableInfo] = Field(None, alias="editableInfo")
global_tags: Optional[GlobalTags] = Field(None, alias="globalTags")
def get_urn_part(self):
return "{}".format(self.username)
@ -353,7 +355,7 @@ class CorpUser(BaseModel):
class OwnershipData(BaseModel):
existing_owners: Optional[List[OwnerClass]] = []
owner_to_add: Optional[CorpUser]
owner_to_add: Optional[CorpUser] = None
class Config:
arbitrary_types_allowed = True

View File

@ -81,7 +81,7 @@ class RedshiftConfig(
# Because of this behavior, it uses dramatically fewer round trips for
# large Redshift warehouses. As an example, see this query for the columns:
# https://github.com/sqlalchemy-redshift/sqlalchemy-redshift/blob/60b4db04c1d26071c291aeea52f1dcb5dd8b0eb0/sqlalchemy_redshift/dialect.py#L745.
scheme = Field(
scheme: str = Field(
default="redshift+psycopg2",
description="",
hidden_from_schema=True,
@ -150,14 +150,14 @@ class RedshiftConfig(
), "email_domain needs to be set if usage is enabled"
return values
@root_validator()
@root_validator(skip_on_failure=True)
def check_database_or_database_alias_set(cls, values):
assert values.get("database") or values.get(
"database_alias"
), "either database or database_alias must be set"
return values
@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def backward_compatibility_configs_set(cls, values: Dict) -> Dict:
match_fully_qualified_names = values.get("match_fully_qualified_names")

View File

@ -144,7 +144,7 @@ class DataLakeSourceConfig(
raise ValueError("platform must not be empty")
return platform
@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def ensure_profiling_pattern_is_passed_to_profiling(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:

View File

@ -72,7 +72,7 @@ class DataLakeProfilerConfig(ConfigModel):
description="Whether to profile for the sample values for all columns.",
)
@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def ensure_field_level_settings_are_normalized(
cls: "DataLakeProfilerConfig", values: Dict[str, Any]
) -> Dict[str, Any]:

View File

@ -83,7 +83,7 @@ class SalesforceProfilingConfig(ConfigModel):
class SalesforceConfig(DatasetSourceConfigMixin):
platform = "salesforce"
platform: str = "salesforce"
auth: SalesforceAuthType = SalesforceAuthType.USERNAME_PASSWORD

View File

@ -79,30 +79,30 @@ class SnowflakeColumnReference(PermissiveModel):
class SnowflakeObjectAccessEntry(PermissiveModel):
columns: Optional[List[SnowflakeColumnReference]]
columns: Optional[List[SnowflakeColumnReference]] = None
objectDomain: str
objectName: str
# Seems like it should never be null, but in practice have seen null objectIds
objectId: Optional[int]
stageKind: Optional[str]
objectId: Optional[int] = None
stageKind: Optional[str] = None
class SnowflakeJoinedAccessEvent(PermissiveModel):
query_start_time: datetime
query_text: str
query_type: str
rows_inserted: Optional[int]
rows_updated: Optional[int]
rows_deleted: Optional[int]
rows_inserted: Optional[int] = None
rows_updated: Optional[int] = None
rows_deleted: Optional[int] = None
base_objects_accessed: List[SnowflakeObjectAccessEntry]
direct_objects_accessed: List[SnowflakeObjectAccessEntry]
objects_modified: List[SnowflakeObjectAccessEntry]
user_name: str
first_name: Optional[str]
last_name: Optional[str]
display_name: Optional[str]
email: Optional[str]
first_name: Optional[str] = None
last_name: Optional[str] = None
display_name: Optional[str] = None
email: Optional[str] = None
role_name: str

View File

@ -5,12 +5,11 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import clickhouse_driver # noqa: F401
import clickhouse_driver
import clickhouse_sqlalchemy.types as custom_types
import pydantic
from clickhouse_sqlalchemy.drivers import base
from clickhouse_sqlalchemy.drivers.base import ClickHouseDialect
from pydantic.class_validators import root_validator
from pydantic.fields import Field
from sqlalchemy import create_engine, text
from sqlalchemy.engine import reflection
@ -59,6 +58,8 @@ from datahub.metadata.schema_classes import (
UpstreamClass,
)
assert clickhouse_driver
# adding extra types not handled by clickhouse-sqlalchemy 0.1.8
base.ischema_names["DateTime64(0)"] = DATETIME
base.ischema_names["DateTime64(1)"] = DATETIME
@ -126,8 +127,8 @@ class ClickHouseConfig(
TwoTierSQLAlchemyConfig, BaseTimeWindowConfig, DatasetLineageProviderConfigBase
):
# defaults
host_port = Field(default="localhost:8123", description="ClickHouse host URL.")
scheme = Field(default="clickhouse", description="", hidden_from_docs=True)
host_port: str = Field(default="localhost:8123", description="ClickHouse host URL.")
scheme: str = Field(default="clickhouse", description="", hidden_from_docs=True)
password: pydantic.SecretStr = Field(
default=pydantic.SecretStr(""), description="password"
)
@ -165,7 +166,7 @@ class ClickHouseConfig(
return str(url)
# pre = True because we want to take some decision before pydantic initialize the configuration to default values
@root_validator(pre=True)
@pydantic.root_validator(pre=True)
def projects_backward_compatibility(cls, values: Dict) -> Dict:
secure = values.get("secure")
protocol = values.get("protocol")
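Note: in the first hunk of this file, the `# noqa: F401` suppression is replaced by a bare assert that references the side-effect-only import, which keeps flake8's unused-import check quiet without an inline comment. The same trick in isolation (assumes clickhouse-driver is installed):
import clickhouse_driver  # imported for its side effects; not referenced directly elsewhere

assert clickhouse_driver  # references the name, so linters no longer report F401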

View File

@ -32,7 +32,7 @@ DruidDialect.get_table_names = get_table_names
class DruidConfig(BasicSQLAlchemyConfig):
# defaults
scheme = "druid"
scheme: str = "druid"
schema_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern(deny=["^(lookup|sysgit|view).*"]),
description="regex patterns for schemas to filter in ingestion.",

View File

@ -122,7 +122,7 @@ HiveDialect.get_view_definition = get_view_definition_patched
class HiveConfig(TwoTierSQLAlchemyConfig):
# defaults
scheme = Field(default="hive", hidden_from_docs=True)
scheme: str = Field(default="hive", hidden_from_docs=True)
@validator("host_port")
def clean_host_port(cls, v):

View File

@ -48,8 +48,8 @@ base.ischema_names["decimal128"] = DECIMAL128
class MySQLConnectionConfig(SQLAlchemyConnectionConfig):
# defaults
host_port = Field(default="localhost:3306", description="MySQL host URL.")
scheme = "mysql+pymysql"
host_port: str = Field(default="localhost:3306", description="MySQL host URL.")
scheme: str = "mysql+pymysql"
class MySQLConfig(MySQLConnectionConfig, TwoTierSQLAlchemyConfig):

View File

@ -98,8 +98,10 @@ class ViewLineageEntry(BaseModel):
class BasePostgresConfig(BasicSQLAlchemyConfig):
scheme = Field(default="postgresql+psycopg2", description="database scheme")
schema_pattern = Field(default=AllowDenyPattern(deny=["information_schema"]))
scheme: str = Field(default="postgresql+psycopg2", description="database scheme")
schema_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern(deny=["information_schema"])
)
class PostgresConfig(BasePostgresConfig):

View File

@ -85,7 +85,7 @@ PrestoDialect._get_full_table = _get_full_table
class PrestoConfig(TrinoConfig):
# defaults
scheme = Field(default="presto", description="", hidden_from_docs=True)
scheme: str = Field(default="presto", description="", hidden_from_docs=True)
@platform_name("Presto", doc_order=1)

View File

@ -145,7 +145,7 @@ class RedshiftConfig(
# Because of this behavior, it uses dramatically fewer round trips for
# large Redshift warehouses. As an example, see this query for the columns:
# https://github.com/sqlalchemy-redshift/sqlalchemy-redshift/blob/60b4db04c1d26071c291aeea52f1dcb5dd8b0eb0/sqlalchemy_redshift/dialect.py#L745.
scheme = Field(
scheme: str = Field(
default="redshift+psycopg2",
description="",
hidden_from_docs=True,

View File

@ -107,7 +107,7 @@ class SQLCommonConfig(
values["view_pattern"] = table_pattern
return values
@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def ensure_profiling_pattern_is_passed_to_profiling(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:

View File

@ -70,7 +70,7 @@ class TeradataReport(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowRep
class BaseTeradataConfig(TwoTierSQLAlchemyConfig):
scheme = Field(default="teradatasql", description="database scheme")
scheme: str = Field(default="teradatasql", description="database scheme")
class TeradataConfig(BaseTeradataConfig, BaseTimeWindowConfig):

View File

@ -133,7 +133,7 @@ TrinoDialect._get_columns = _get_columns
class TrinoConfig(BasicSQLAlchemyConfig):
# defaults
scheme = Field(default="trino", description="", hidden_from_docs=True)
scheme: str = Field(default="trino", description="", hidden_from_docs=True)
def get_identifier(self: BasicSQLAlchemyConfig, schema: str, table: str) -> str:
regular = f"{schema}.{table}"

View File

@ -5,13 +5,13 @@ from typing import Any, Dict, Generic, Optional, Type, TypeVar, cast
import pydantic
from pydantic import root_validator
from pydantic.fields import Field
from pydantic.generics import GenericModel
from datahub.configuration.common import (
ConfigModel,
ConfigurationError,
DynamicTypedConfig,
)
from datahub.configuration.pydantic_migration_helpers import GenericModel
from datahub.configuration.time_window_config import BaseTimeWindowConfig
from datahub.configuration.validate_field_rename import pydantic_renamed_field
from datahub.ingestion.api.common import PipelineContext
@ -77,7 +77,7 @@ class StatefulIngestionConfig(ConfigModel):
hidden_from_docs=True,
)
@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("enabled"):
if values.get("state_provider") is None:
@ -112,7 +112,7 @@ class StatefulLineageConfigMixin:
"store_last_lineage_extraction_timestamp", "enable_stateful_lineage_ingestion"
)
@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def lineage_stateful_option_validator(cls, values: Dict) -> Dict:
sti = values.get("stateful_ingestion")
if not sti or not sti.enabled:
@ -137,7 +137,7 @@ class StatefulProfilingConfigMixin(ConfigModel):
"store_last_profiling_timestamps", "enable_stateful_profiling"
)
@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def profiling_stateful_option_validator(cls, values: Dict) -> Dict:
sti = values.get("stateful_ingestion")
if not sti or not sti.enabled:
@ -161,7 +161,7 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig):
"store_last_usage_extraction_timestamp", "enable_stateful_usage_ingestion"
)
@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def last_usage_extraction_stateful_option_validator(cls, values: Dict) -> Dict:
sti = values.get("stateful_ingestion")
if not sti or not sti.enabled:

View File

@ -105,7 +105,7 @@ class SupersetConfig(StatefulIngestionConfigBase, ConfigModel):
def remove_trailing_slash(cls, v):
return config_clean.remove_trailing_slashes(v)
@root_validator
@root_validator(skip_on_failure=True)
def default_display_uri_to_connect_uri(cls, values):
base = values.get("display_uri")
if base is None:

View File

@ -76,7 +76,7 @@ class UnityCatalogProfilerConfig(ConfigModel):
description="Number of worker threads to use for profiling. Set to 1 to disable.",
)
@pydantic.root_validator
@pydantic.root_validator(skip_on_failure=True)
def warehouse_id_required_for_profiling(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:

View File

@ -340,7 +340,6 @@ class BaseSnowflakeConfig(ConfigModel):
class SnowflakeConfig(BaseSnowflakeConfig, BaseTimeWindowConfig, SQLCommonConfig):
include_table_lineage: bool = pydantic.Field(
default=True,
description="If enabled, populates the snowflake table-to-table and s3-to-snowflake table lineage. Requires appropriate grants given to the role and Snowflake Enterprise Edition or above.",
@ -357,7 +356,7 @@ class SnowflakeConfig(BaseSnowflakeConfig, BaseTimeWindowConfig, SQLCommonConfig
ignore_start_time_lineage: bool = False
upstream_lineage_in_report: bool = False
@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def validate_include_view_lineage(cls, values):
if (
"include_table_lineage" in values

View File

@ -44,7 +44,7 @@ class BigQueryCredential(ConfigModel):
description="If not set it will be default to https://www.googleapis.com/robot/v1/metadata/x509/client_email",
)
@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("client_x509_cert_url") is None:
values[

View File

@ -23,18 +23,18 @@ T = TypeVar("T")
class VersionStats(BaseModel, arbitrary_types_allowed=True):
version: Version
release_date: Optional[datetime]
release_date: Optional[datetime] = None
class ServerVersionStats(BaseModel):
current: VersionStats
latest: Optional[VersionStats]
current_server_type: Optional[str]
latest: Optional[VersionStats] = None
current_server_type: Optional[str] = None
class ClientVersionStats(BaseModel):
current: VersionStats
latest: Optional[VersionStats]
latest: Optional[VersionStats] = None
class DataHubVersionStats(BaseModel):

View File

@ -1,6 +1,8 @@
import random
from typing import Dict, Iterator, List, Set, TypeVar, Union
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
T = TypeVar("T")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
@ -41,6 +43,16 @@ class LossyList(List[T]):
def __str__(self) -> str:
return repr(self)
if PYDANTIC_VERSION_2:
# With pydantic 2, it doesn't recognize that this is a list subclass,
# so we need to make it explicit.
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore
from pydantic_core import core_schema
return core_schema.no_info_after_validator_function(cls, handler(list))
def as_obj(self) -> List[Union[T, str]]:
base_list: List[Union[T, str]] = list(self.__iter__())
if self.sampled:
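Note: the same __get_pydantic_core_schema__ hook as in ConfigEnum, applied to a container type so that model fields annotated with the subclass keep validating under v2. A standalone sketch (TopFiveList and Report are illustrative, not DataHub types):
import pydantic

class TopFiveList(list):
    def append(self, item):
        if len(self) < 5:
            super().append(item)

    if pydantic.VERSION.startswith("2."):
        @classmethod
        def __get_pydantic_core_schema__(cls, source_type, handler):
            from pydantic_core import core_schema
            # Validate the input as a plain list, then wrap it in the subclass.
            return core_schema.no_info_after_validator_function(cls, handler(list))

class Report(pydantic.BaseModel):
    samples: TopFiveList = TopFiveList()

Report(samples=[1, 2, 3])  # accepted on both majors; on v2 thanks to the hook above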

View File

@ -17,6 +17,7 @@ import sqlglot.optimizer.qualify
from pydantic import BaseModel
from typing_extensions import TypedDict
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
from datahub.emitter.mce_builder import (
DEFAULT_ENV,
make_dataset_urn_with_platform_instance,
@ -122,12 +123,17 @@ class _ParserBaseModel(
SchemaFieldDataTypeClass: lambda v: v.to_obj(),
},
):
pass
def json(self, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_VERSION_2:
return super().model_dump_json(*args, **kwargs) # type: ignore
else:
return super().json(*args, **kwargs)
@functools.total_ordering
class _FrozenModel(_ParserBaseModel, frozen=True):
def __lt__(self, other: "_FrozenModel") -> bool:
# TODO: The __fields__ attribute is deprecated in Pydantic v2.
for field in self.__fields__:
self_v = getattr(self, field)
other_v = getattr(other, field)
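Note: .json() survives on v2 only as a deprecated alias, so the override above routes callers to model_dump_json when it exists. Outside a model the same dispatch can be written as follows (Point is a stand-in model):
import pydantic

class Point(pydantic.BaseModel):
    x: int
    y: int

p = Point(x=1, y=2)
# Prefer the v2 serializer when available; fall back to the v1 API otherwise.
serialized = p.model_dump_json() if hasattr(p, "model_dump_json") else p.json()
print(serialized)  # a JSON string on either major version
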
@ -138,8 +144,8 @@ class _FrozenModel(_ParserBaseModel, frozen=True):
class _TableName(_FrozenModel):
database: Optional[str]
db_schema: Optional[str]
database: Optional[str] = None
db_schema: Optional[str] = None
table: str
def as_sqlglot_table(self) -> sqlglot.exp.Table:
@ -187,16 +193,16 @@ class ColumnRef(_ParserBaseModel):
class _DownstreamColumnRef(_ParserBaseModel):
table: Optional[_TableName]
table: Optional[_TableName] = None
column: str
column_type: Optional[sqlglot.exp.DataType]
column_type: Optional[sqlglot.exp.DataType] = None
class DownstreamColumnRef(_ParserBaseModel):
table: Optional[Urn]
table: Optional[Urn] = None
column: str
column_type: Optional[SchemaFieldDataTypeClass]
native_column_type: Optional[str]
column_type: Optional[SchemaFieldDataTypeClass] = None
native_column_type: Optional[str] = None
@pydantic.validator("column_type", pre=True)
def _load_column_type(
@ -213,7 +219,7 @@ class _ColumnLineageInfo(_ParserBaseModel):
downstream: _DownstreamColumnRef
upstreams: List[_ColumnRef]
logic: Optional[str]
logic: Optional[str] = None
class ColumnLineageInfo(_ParserBaseModel):
@ -244,7 +250,7 @@ class SqlParsingResult(_ParserBaseModel):
in_tables: List[Urn]
out_tables: List[Urn]
column_lineage: Optional[List[ColumnLineageInfo]]
column_lineage: Optional[List[ColumnLineageInfo]] = None
# TODO include formatted original sql logic
# TODO include list of referenced columns