feat(ingest): upgrade pydantic version (#6858)

This PR also removes the requirement on docker-compose v1 and makes our tests use v2 instead.
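
Most of the per-file changes below follow one pattern: config fields that can legitimately be absent get an explicit Optional[...] annotation plus an explicit default=None in their Field(...) call, instead of a bare str/SecretStr annotation with an implicit or missing default. A minimal before/after sketch of that pattern (the model and field names here are illustrative, not taken from the codebase):

    from typing import Optional

    import pydantic


    class ExampleSourceConfig(pydantic.BaseModel):
        # Before: workspace_id: str = pydantic.Field(description="...", default=None)
        # The annotation claimed str even though the value could be missing.
        workspace_id: Optional[str] = pydantic.Field(
            default=None, description="Illustrative optional setting"
        )


    cfg = ExampleSourceConfig()  # the field may be omitted entirely
    assert cfg.workspace_id is None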

Co-authored-by: Harshal Sheth <hsheth2@gmail.com>
Author: cccs-eric, 2022-12-27 17:06:16 -05:00 (committed by GitHub)
Parent: d851140048
Commit: ec8a4e0eab
14 changed files with 54 additions and 38 deletions


@@ -396,12 +396,11 @@ base_dev_requirements = {
     "mypy==0.991",
     # pydantic 1.8.2 is incompatible with mypy 0.910.
     # See https://github.com/samuelcolvin/pydantic/pull/3175#issuecomment-995382910.
-    # Restricting top version to <1.10 until we can fix our types.
-    "pydantic >=1.9.0, <1.10",
+    "pydantic >=1.9.0",
     "pytest>=6.2.2",
     "pytest-asyncio>=0.16.0",
     "pytest-cov>=2.8.1",
-    "pytest-docker[docker-compose-v1]>=1.0.1",
+    "pytest-docker>=1.0.1",
     "deepdiff",
     "requests-mock",
     "freezegun",


@@ -882,7 +882,7 @@ def ingest_sample_data(path: Optional[str], token: Optional[str]) -> None:
     if path is None:
         click.echo("Downloading sample data...")
-        path = download_sample_data()
+        path = str(download_sample_data())
     # Verify that docker is up.
     issues = check_local_docker_containers()


@@ -57,13 +57,15 @@ class InfoTypeConfig(ConfigModel):
         description="Factors and their weights to consider when predicting info types",
         alias="prediction_factors_and_weights",
     )
-    Name: Optional[NameFactorConfig] = Field(alias="name")
+    Name: Optional[NameFactorConfig] = Field(default=None, alias="name")
-    Description: Optional[DescriptionFactorConfig] = Field(alias="description")
+    Description: Optional[DescriptionFactorConfig] = Field(
+        default=None, alias="description"
+    )
-    Datatype: Optional[DataTypeFactorConfig] = Field(alias="datatype")
+    Datatype: Optional[DataTypeFactorConfig] = Field(default=None, alias="datatype")
-    Values: Optional[ValuesFactorConfig] = Field(alias="values")
+    Values: Optional[ValuesFactorConfig] = Field(default=None, alias="values")
 # TODO: Generate Classification doc (classification.md) from python source.
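
For context on the InfoTypeConfig change above: the alias still lets recipes use the lower-case key while the attribute stays capitalised; only the default becomes explicit. A small sketch of how pydantic v1 resolves that (class and field names here are simplified stand-ins, not the real ones):

    from typing import Optional

    from pydantic import BaseModel, Field


    class FactorSketch(BaseModel):
        weight: float = 1.0


    class InfoTypeSketch(BaseModel):
        # Input dicts/YAML use the lower-case alias "name"; the attribute is "Name".
        Name: Optional[FactorSketch] = Field(default=None, alias="name")


    parsed = InfoTypeSketch.parse_obj({"name": {"weight": 0.8}})
    assert parsed.Name is not None and parsed.Name.weight == 0.8
    assert InfoTypeSketch.parse_obj({}).Name is None  # default=None keeps it optional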


@@ -20,7 +20,7 @@ logger: logging.Logger = logging.getLogger(__name__)
 class S3(ConfigModel):
-    aws_config: AwsConnectionConfig = Field(
+    aws_config: Optional[AwsConnectionConfig] = Field(
         default=None, description="AWS configuration"
     )
@@ -40,7 +40,7 @@ class DeltaLakeSourceConfig(PlatformSourceConfigBase, EnvBasedSourceConfigBase):
         description="Path to table (s3 or local file system). If path is not a delta table path "
         "then all subfolders will be scanned to detect and ingest delta tables."
     )
-    relative_path: str = Field(
+    relative_path: Optional[str] = Field(
         default=None,
         description="If set, delta-tables will be searched at location "
         "'<base_path>/<relative_path>' and URNs will be created using "


@@ -24,7 +24,7 @@ class DemoDataSource(GenericFileSource):
     """
     def __init__(self, ctx: PipelineContext, config: DemoDataConfig):
-        file_config = FileSourceConfig(filename=download_sample_data())
+        file_config = FileSourceConfig(path=download_sample_data())
         super().__init__(ctx, file_config)
     @classmethod


@@ -116,7 +116,7 @@ class PowerBiAPIConfig(EnvBasedSourceConfigBase):
         default=AllowDenyPattern.allow_all(),
         description="Regex patterns to filter PowerBI workspaces in ingestion",
     )
-    workspace_id: str = pydantic.Field(
+    workspace_id: Optional[str] = pydantic.Field(
         description="[deprecated] Use workspace_id_pattern instead",
         default=None,
     )


@@ -15,7 +15,7 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
 class UnityCatalogSourceConfig(StatefulIngestionConfigBase):
     token: str = pydantic.Field(description="Databricks personal access token")
     workspace_url: str = pydantic.Field(description="Databricks workspace url")
-    workspace_name: str = pydantic.Field(
+    workspace_name: Optional[str] = pydantic.Field(
         default=None,
         description="Name of the workspace. Default to deployment name present in workspace_url",
     )


@@ -140,7 +140,7 @@ class RedshiftAccessEvent(BaseModel):
     username: str
     query: int
     tbl: int
-    text: str = Field(None, alias="querytxt")
+    text: Optional[str] = Field(None, alias="querytxt")
     database: str
     schema_: str = Field(alias="schema")
     table: str


@@ -4,7 +4,7 @@ import json
 import logging
 from datetime import datetime
 from email.utils import parseaddr
-from typing import Dict, Iterable, List
+from typing import Dict, Iterable, List, Optional
 from dateutil import parser
 from pydantic.fields import Field
@@ -62,23 +62,23 @@ class TrinoConnectorInfo(BaseModel):
 class TrinoAccessedMetadata(BaseModel):
-    catalog_name: str = Field(None, alias="catalogName")
-    schema_name: str = Field(None, alias="schema")  # type: ignore
-    table: str = None  # type: ignore
+    catalog_name: Optional[str] = Field(None, alias="catalogName")
+    schema_name: Optional[str] = Field(None, alias="schema")
+    table: Optional[str] = None
     columns: List[str]
-    connector_info: TrinoConnectorInfo = Field(None, alias="connectorInfo")
+    connector_info: Optional[TrinoConnectorInfo] = Field(None, alias="connectorInfo")
 class TrinoJoinedAccessEvent(BaseModel):
-    usr: str = None  # type:ignore
-    query: str = None  # type: ignore
-    catalog: str = None  # type: ignore
-    schema_name: str = Field(None, alias="schema")
-    query_type: str = None  # type:ignore
-    table: str = None  # type:ignore
+    usr: Optional[str] = None
+    query: Optional[str] = None
+    catalog: Optional[str] = None
+    schema_name: Optional[str] = Field(None, alias="schema")
+    query_type: Optional[str] = None
+    table: Optional[str] = None
     accessed_metadata: List[TrinoAccessedMetadata]
-    starttime: datetime = Field(None, alias="create_time")
-    endtime: datetime = Field(None, alias="end_time")
+    starttime: datetime = Field(alias="create_time")
+    endtime: datetime = Field(alias="end_time")
 class EnvBasedSourceBaseConfig:
@@ -233,7 +233,9 @@ class TrinoUsageSource(Source):
             floored_ts = get_time_bucket(event.starttime, self.config.bucket_duration)
             for metadata in event.accessed_metadata:
                 # Skipping queries starting with $system@
-                if metadata.catalog_name.startswith("$system@"):
+                if metadata.catalog_name and metadata.catalog_name.startswith(
+                    "$system@"
+                ):
                     logging.debug(
                         f"Skipping system query for {metadata.catalog_name}..."
                     )
@@ -258,7 +260,7 @@
            # add @unknown.com to username
            # current limitation in user stats UI, we need to provide email to show users
-            if "@" in parseaddr(event.usr)[1]:
+            if event.usr and "@" in parseaddr(event.usr)[1]:
                 username = event.usr
             else:
                 username = f"{event.usr if event.usr else 'unknown'}@{self.config.email_domain}"
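
The guards added in this file matter because event.usr and metadata.catalog_name are now typed as optional. For the username branch above, parseaddr from email.utils decides whether the email_domain fallback applies; a quick illustration of the behaviour relied on:

    from email.utils import parseaddr

    # parseaddr returns (realname, address); index [1] is the address part.
    assert parseaddr("Alice <alice@example.com>")[1] == "alice@example.com"
    assert parseaddr("alice")[1] == "alice"  # no "@", so the email_domain suffix is appended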


@@ -29,7 +29,7 @@ class BigQueryConfig(BigQueryBaseConfig, BaseTimeWindowConfig, SQLAlchemyConfig)
         description="The number of log item will be queried per page for lineage collection",
     )
     credential: Optional[BigQueryCredential] = pydantic.Field(
-        description="BigQuery credential informations"
+        default=None, description="BigQuery credential informations"
     )
     # extra_client_options, include_table_lineage and max_query_duration are relevant only when computing the lineage.
     extra_client_options: Dict[str, Any] = pydantic.Field(


@@ -83,7 +83,7 @@ class SnowflakeProvisionRoleConfig(ConfigModel):
         description="The username to be used for provisioning of role."
     )
-    admin_password: pydantic.SecretStr = pydantic.Field(
+    admin_password: Optional[pydantic.SecretStr] = pydantic.Field(
         default=None,
         exclude=True,
         description="The password to be used for provisioning of role.",
@@ -131,13 +131,16 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
         description='The type of authenticator to use when connecting to Snowflake. Supports "DEFAULT_AUTHENTICATOR", "EXTERNAL_BROWSER_AUTHENTICATOR" and "KEY_PAIR_AUTHENTICATOR".',
     )
     host_port: Optional[str] = pydantic.Field(
-        description="DEPRECATED: Snowflake account. e.g. abc48144"
+        default=None, description="DEPRECATED: Snowflake account. e.g. abc48144"
     )  # Deprecated
     account_id: Optional[str] = pydantic.Field(
-        description="Snowflake account identifier. e.g. xy12345, xy12345.us-east-2.aws, xy12345.us-central1.gcp, xy12345.central-us.azure, xy12345.us-west-2.privatelink. Refer [Account Identifiers](https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#format-2-legacy-account-locator-in-a-region) for more details."
+        default=None,
+        description="Snowflake account identifier. e.g. xy12345, xy12345.us-east-2.aws, xy12345.us-central1.gcp, xy12345.central-us.azure, xy12345.us-west-2.privatelink. Refer [Account Identifiers](https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#format-2-legacy-account-locator-in-a-region) for more details.",
     )  # Once host_port is removed this will be made mandatory
-    warehouse: Optional[str] = pydantic.Field(description="Snowflake warehouse.")
-    role: Optional[str] = pydantic.Field(description="Snowflake role.")
+    warehouse: Optional[str] = pydantic.Field(
+        default=None, description="Snowflake warehouse."
+    )
+    role: Optional[str] = pydantic.Field(default=None, description="Snowflake role.")
     include_table_lineage: bool = pydantic.Field(
         default=True,
         description="If enabled, populates the snowflake table-to-table and s3-to-snowflake table lineage. Requires appropriate grants given to the role.",


@@ -41,7 +41,8 @@ class SnowflakeUsageConfig(
         description="List of regex patterns for databases to include/exclude in usage ingestion.",
     )
     email_domain: Optional[str] = pydantic.Field(
-        description="Email domain of your organisation so users can be displayed on UI appropriately."
+        default=None,
+        description="Email domain of your organisation so users can be displayed on UI appropriately.",
     )
     schema_pattern: AllowDenyPattern = pydantic.Field(
         default=AllowDenyPattern.allow_all(),


@@ -12,9 +12,9 @@ BOOTSTRAP_MCES_FILE = "metadata-ingestion/examples/mce_files/bootstrap_mce.json"
 BOOTSTRAP_MCES_URL = f"{DOCKER_COMPOSE_BASE}/{BOOTSTRAP_MCES_FILE}"
-def download_sample_data() -> str:
+def download_sample_data() -> pathlib.Path:
     with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp_file:
-        path = str(pathlib.Path(tmp_file.name))
+        path = pathlib.Path(tmp_file.name)
         # Download the bootstrap MCE file from GitHub.
         mce_json_download_response = requests.get(BOOTSTRAP_MCES_URL)


@@ -42,6 +42,15 @@ def wait_for_port(
         subprocess.run(f"docker logs {container_name}", shell=True, check=True)
+@pytest.fixture(scope="session")
+def docker_compose_command():
+    """Docker Compose command to use, it could be either `docker-compose`
+    for Docker Compose v1 or `docker compose` for Docker Compose
+    v2."""
+    return "docker compose"
 @pytest.fixture(scope="module")
 def docker_compose_runner(
     docker_compose_command, docker_compose_project_name, docker_setup, docker_cleanup
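
With the session-scoped docker_compose_command fixture above returning "docker compose", pytest-docker drives Compose v2 for the whole test session. A hedged sketch of a test that would pick this up through pytest-docker's standard docker_ip/docker_services fixtures (the "broker" service name and port 9092 are hypothetical placeholders, not taken from this repo's compose files):

    import socket


    def _port_open(host: str, port: int) -> bool:
        # Plain TCP probe used as the readiness check.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            return sock.connect_ex((host, port)) == 0


    def test_broker_comes_up(docker_ip, docker_services):
        # port_for maps the service's container port to the host port Compose assigned.
        port = docker_services.port_for("broker", 9092)
        docker_services.wait_until_responsive(
            timeout=60.0, pause=1.0, check=lambda: _port_open(docker_ip, port)
        )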