175 lines
4.8 KiB
Python
Raw Permalink Normal View History

from typing import Any, Dict, Optional, Tuple, Type, cast
import pytest
from pydantic import ValidationError
from datahub.configuration.common import ConfigModel
from datahub.ingestion.graph.client import DatahubClientConfig
from datahub.ingestion.source.state.stateful_ingestion_base import (
DynamicTypedStateProviderConfig,
StatefulIngestionConfig,
)
from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
DatahubIngestionStateProviderConfig,
)
# 0. Common client configs.
datahub_client_full_config = {
"server": "http://localhost:8080",
"token": "dummy_test_tok",
"timeout_sec": 10,
"extra_headers": {},
"max_threads": 10,
}
# 1. Datahub Checkpointing State Provider Config test params
checkpointing_provider_config_test_params: Dict[
str,
Tuple[
Type[DatahubIngestionStateProviderConfig],
Dict[str, Any],
Optional[DatahubIngestionStateProviderConfig],
bool,
],
] = {
# Full custom-config
"checkpointing_valid_full_config": (
DatahubIngestionStateProviderConfig,
{
"datahub_api": datahub_client_full_config,
},
DatahubIngestionStateProviderConfig(
# This test verifies that the max_threads arg is ignored.
datahub_api=DatahubClientConfig.parse_obj_allow_extras(
dict(
server="http://localhost:8080",
token="dummy_test_tok",
timeout_sec=10,
extra_headers={},
max_threads=10,
)
),
),
False,
),
# Default
"checkpointing_default": (
DatahubIngestionStateProviderConfig,
{
"datahub_api": None,
},
DatahubIngestionStateProviderConfig(
datahub_api=None,
),
False,
),
}
# 2. StatefulIngestion Config test params
stateful_ingestion_config_test_params: Dict[
str,
Tuple[
Type[StatefulIngestionConfig],
Dict[str, Any],
Optional[StatefulIngestionConfig],
bool,
],
] = {
# Ful custom-config
"stateful_ingestion_full_custom": (
StatefulIngestionConfig,
{
"enabled": True,
"max_checkpoint_state_size": 1024,
"state_provider": {
"type": "datahub",
"config": datahub_client_full_config,
},
"ignore_old_state": True,
"ignore_new_state": True,
},
StatefulIngestionConfig(
enabled=True,
max_checkpoint_state_size=1024,
ignore_old_state=True,
ignore_new_state=True,
state_provider=DynamicTypedStateProviderConfig(
type="datahub",
config=datahub_client_full_config,
),
),
False,
),
# Default disabled
"stateful_ingestion_default_disabled": (
StatefulIngestionConfig,
{},
StatefulIngestionConfig(
enabled=False,
max_checkpoint_state_size=2**24,
ignore_old_state=False,
ignore_new_state=False,
state_provider=None,
),
False,
),
# Default enabled
"stateful_ingestion_default_enabled": (
StatefulIngestionConfig,
{"enabled": True},
StatefulIngestionConfig(
enabled=True,
max_checkpoint_state_size=2**24,
ignore_old_state=False,
ignore_new_state=False,
state_provider=DynamicTypedStateProviderConfig(type="datahub"),
),
False,
),
# Bad Config- throws ValidationError
"stateful_ingestion_bad_config": (
StatefulIngestionConfig,
{"enabled": True, "state_provider": {}},
None,
True,
),
}
# 4. Combine all of the config params from 1, 2 & 3 above for the common parametrized test.
CombinedTestConfigType = Dict[
str,
Tuple[
Type[ConfigModel],
Dict[str, Any],
Optional[ConfigModel],
bool,
],
]
combined_test_configs = {
**cast(CombinedTestConfigType, checkpointing_provider_config_test_params),
**cast(CombinedTestConfigType, stateful_ingestion_config_test_params),
}
@pytest.mark.parametrize(
"config_class, config_dict, expected, raises_exception",
combined_test_configs.values(),
ids=combined_test_configs.keys(),
)
def test_state_provider_configs(
config_class: Type[ConfigModel],
config_dict: Dict[str, Any],
expected: Optional[ConfigModel],
raises_exception: bool,
) -> None:
if raises_exception:
with pytest.raises(ValidationError):
assert expected is None
config_class.parse_obj_allow_extras(config_dict)
else:
config = config_class.parse_obj_allow_extras(config_dict)
assert config == expected