refactor(ingest): move common host_port validation (#6009)

Harshal Sheth 2022-09-22 16:32:07 -07:00 committed by GitHub
parent b8941ab190
commit 27f28019de
7 changed files with 61 additions and 72 deletions

datahub/configuration/kafka.py

@@ -1,8 +1,7 @@
-import re
-
 from pydantic import Field, validator
 
 from datahub.configuration.common import ConfigModel
+from datahub.configuration.validate_host_port import validate_host_port
 
 
 class _KafkaConnectionConfig(ConfigModel):
@@ -20,21 +19,7 @@ class _KafkaConnectionConfig(ConfigModel):
     @validator("bootstrap")
     def bootstrap_host_colon_port_comma(cls, val: str) -> str:
         for entry in val.split(","):
-            # The port can be provided but is not required.
-            port = None
-            if ":" in entry:
-                (host, port) = entry.rsplit(":", 1)
-            else:
-                host = entry
-            assert re.match(
-                # This regex is quite loose. Many invalid hostnames or IPs will slip through,
-                # but it serves as a good first line of validation. We defer to Kafka for the
-                # remaining validation.
-                r"^[\w\-\.\:]+$",
-                host,
-            ), f"host contains bad characters, found {host}"
-            if port is not None:
-                assert port.isdigit(), f"port must be all digits, found {port}"
+            validate_host_port(entry)
         return val
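
For reference, a minimal sketch of what the slimmed-down validator now does. `check_bootstrap` is a hypothetical standalone stand-in for the class method above, assuming the module layout shown in the diff:

from datahub.configuration.validate_host_port import validate_host_port

def check_bootstrap(val: str) -> str:
    # Mirrors bootstrap_host_colon_port_comma: validate each
    # comma-separated bootstrap server individually.
    for entry in val.split(","):
        validate_host_port(entry)  # raises AssertionError on bad input
    return val

check_bootstrap("broker-1:9092,broker-2:9092")  # ok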

datahub/configuration/validate_field_rename.py

@@ -13,7 +13,7 @@ def _default_rename_transform(value: _T) -> _T:
 def pydantic_renamed_field(
     old_name: str,
     new_name: str,
-    transform: Callable[[_T], _T] = _default_rename_transform,
+    transform: Callable = _default_rename_transform,
 ) -> classmethod:
     def _validate_field_rename(cls: Type, values: dict) -> dict:
         if old_name in values:
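
The helper's body is truncated above. A hedged sketch of how such a rename shim can be completed; everything past `if old_name in values:` (including the duplicate-key error message) is a reconstruction, not the project's exact code:

from typing import Callable, Type, TypeVar

import pydantic

_T = TypeVar("_T")

def _default_rename_transform(value: _T) -> _T:
    return value

def pydantic_renamed_field(
    old_name: str,
    new_name: str,
    transform: Callable = _default_rename_transform,
) -> classmethod:
    def _validate_field_rename(cls: Type, values: dict) -> dict:
        if old_name in values:
            if new_name in values:
                raise ValueError(f"cannot specify both {old_name} and {new_name}")
            # Move the old field's value to the new field, transforming it
            # along the way (e.g. str topic -> dict of topic routes).
            values[new_name] = transform(values.pop(old_name))
        return values

    # pre=True so the rename runs before normal field validation.
    return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)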

datahub/configuration/validate_host_port.py

@@ -0,0 +1,26 @@
+import re
+
+
+def validate_host_port(host_port: str) -> None:
+    """
+    Validates that a host or host:port string is valid.
+    This makes the assumption that the port is optional, and
+    requires that there is no proto:// prefix or trailing path.
+    """
+
+    # The port can be provided but is not required.
+    port = None
+    if ":" in host_port:
+        (host, port) = host_port.rsplit(":", 1)
+    else:
+        host = host_port
+
+    assert re.match(
+        # This regex is quite loose. Some invalid hostnames or IPs will slip through,
+        # but it serves as a good first line of validation. We defer to the underlying
+        # system for the remaining validation.
+        r"^[\w\-\.\:]+$",
+        host,
+    ), f"host contains bad characters, found {host}"
+    if port is not None:
+        assert port.isdigit(), f"port must be all digits, found {port}"
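
Usage follows directly from the code above: valid inputs return None, invalid ones fail one of the assertions.

from datahub.configuration.validate_host_port import validate_host_port

validate_host_port("localhost")                  # ok: the port is optional
validate_host_port("broker-1.example.com:9092")  # ok: host:port

try:
    validate_host_port("localhost:noport")
except AssertionError as e:
    print(e)  # port must be all digits, found noport

try:
    validate_host_port("bad host!")
except AssertionError as e:
    print(e)  # host contains bad characters, found bad host!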

datahub/emitter/kafka_emitter.py

@@ -1,14 +1,15 @@
 import logging
-from typing import Callable, Union
+from typing import Callable, Dict, Union
 
+import pydantic
 from confluent_kafka import SerializingProducer
 from confluent_kafka.schema_registry import SchemaRegistryClient
 from confluent_kafka.schema_registry.avro import AvroSerializer
 from confluent_kafka.serialization import SerializationContext, StringSerializer
-from pydantic import Field, root_validator
 
-from datahub.configuration.common import ConfigModel, ConfigurationError
+from datahub.configuration.common import ConfigModel
 from datahub.configuration.kafka import KafkaProducerConnectionConfig
+from datahub.configuration.validate_field_rename import pydantic_renamed_field
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.metadata.schema_classes import (
@@ -29,32 +30,28 @@ MCP_KEY = "MetadataChangeProposal"
 
 class KafkaEmitterConfig(ConfigModel):
-    connection: KafkaProducerConnectionConfig = Field(
+    connection: KafkaProducerConnectionConfig = pydantic.Field(
         default_factory=KafkaProducerConnectionConfig
     )
-    topic: str = DEFAULT_MCE_KAFKA_TOPIC
-    topic_routes: dict = {
+    topic_routes: Dict[str, str] = {
         MCE_KEY: DEFAULT_MCE_KAFKA_TOPIC,
         MCP_KEY: DEFAULT_MCP_KAFKA_TOPIC,
     }
 
-    @root_validator
-    def validate_topic_routes(cls: "KafkaEmitterConfig", values: dict) -> dict:
-        old_topic = values["topic"]
-        new_mce_topic = values["topic_routes"][MCE_KEY]
-        if old_topic != DEFAULT_MCE_KAFKA_TOPIC:
-            # Looks like a non-default topic has been set using the old style.
-            if new_mce_topic != DEFAULT_MCE_KAFKA_TOPIC:
-                # Looks like a non-default topic has ALSO been set using the new style.
-                raise ConfigurationError(
-                    "Using both topic and topic_routes configuration for Kafka is not supported. Use only topic_routes"
-                )
-            logger.warning(
-                "Looks like you're using the deprecated `topic` configuration. Please migrate to `topic_routes`."
-            )
-            # Upgrade the provided topic to the topic_routes MCE entry.
-            values["topic_routes"][MCE_KEY] = values["topic"]
-        return values
+    _topic_field_compat = pydantic_renamed_field(
+        "topic",
+        "topic_routes",
+        transform=lambda x: {
+            MCE_KEY: x,
+            MCP_KEY: DEFAULT_MCP_KAFKA_TOPIC,
+        },
+    )
+
+    @pydantic.validator("topic_routes")
+    def validate_topic_routes(cls, v: Dict[str, str]) -> Dict[str, str]:
+        assert MCE_KEY in v, f"topic_routes must contain a route for {MCE_KEY}"
+        assert MCP_KEY in v, f"topic_routes must contain a route for {MCP_KEY}"
+        return v
 
 
 class DatahubKafkaEmitter:
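
The net effect on configs, as a sketch; importing MCE_KEY/MCP_KEY from the module is an assumption based on their definitions shown above:

import pydantic
import pytest

from datahub.emitter.kafka_emitter import (
    DEFAULT_MCP_KAFKA_TOPIC,
    MCE_KEY,
    MCP_KEY,
    KafkaEmitterConfig,
)

# The old-style `topic` field is transparently renamed into `topic_routes`,
# with MCPs falling back to the default topic.
config = KafkaEmitterConfig.parse_obj({"topic": "MyTopic"})
assert config.topic_routes == {MCE_KEY: "MyTopic", MCP_KEY: DEFAULT_MCP_KAFKA_TOPIC}

# Supplying both the old and the new field is still rejected, now surfacing
# as a pydantic.ValidationError rather than a ConfigurationError.
with pytest.raises(pydantic.ValidationError):
    KafkaEmitterConfig.parse_obj(
        {"topic": "MyTopic", "topic_routes": {MCE_KEY: "X", MCP_KEY: "Y"}}
    )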

datahub/ingestion/source/elastic_search.py

@@ -1,6 +1,5 @@
 import json
 import logging
-import re
 from collections import defaultdict
 from dataclasses import dataclass, field
 from hashlib import md5
@@ -10,8 +9,9 @@ from elasticsearch import Elasticsearch
 from pydantic import validator
 from pydantic.fields import Field
 
-from datahub.configuration.common import AllowDenyPattern, ConfigurationError
+from datahub.configuration.common import AllowDenyPattern
 from datahub.configuration.source_common import DatasetSourceConfigBase
+from datahub.configuration.validate_host_port import validate_host_port
 from datahub.emitter.mce_builder import (
     make_data_platform_urn,
     make_dataplatform_instance_urn,
@@ -51,6 +51,7 @@ from datahub.metadata.schema_classes import (
     StringTypeClass,
     SubTypesClass,
 )
+from datahub.utilities.config_clean import remove_protocol
 
 logger = logging.getLogger(__name__)
@@ -248,29 +249,11 @@ class ElasticsearchSourceConfig(DatasetSourceConfigBase):
     @validator("host")
     def host_colon_port_comma(cls, host_val: str) -> str:
         for entry in host_val.split(","):
-            # The port can be provided but is not required.
-            port = None
-            for prefix in ["http://", "https://"]:
-                if entry.startswith(prefix):
-                    entry = entry[len(prefix) :]
+            entry = remove_protocol(entry)
             for suffix in ["/"]:
                 if entry.endswith(suffix):
                     entry = entry[: -len(suffix)]
-            if ":" in entry:
-                (host, port) = entry.rsplit(":", 1)
-            else:
-                host = entry
-            if not re.match(
-                # This regex is quite loose. Many invalid hostnames or IPs will slip through,
-                # but it serves as a good first line of validation. We defer to Elastic for the
-                # remaining validation.
-                r"^[\w\-\.]+$",
-                host,
-            ):
-                raise ConfigurationError(f"host contains bad characters, found {host}")
-            if port is not None and not port.isdigit():
-                raise ConfigurationError(f"port must be all digits, found {port}")
+            validate_host_port(entry)
         return host_val
 
     @property
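
Piecing the new flow together, as a sketch; `remove_protocol` is assumed to strip a leading scheme such as http:// or https://, per its use in the validator above:

from datahub.configuration.validate_host_port import validate_host_port
from datahub.utilities.config_clean import remove_protocol

entry = remove_protocol("https://es.example.com:9200/")  # -> "es.example.com:9200/"
if entry.endswith("/"):  # the validator also strips a trailing slash
    entry = entry[:-1]
validate_host_port(entry)  # ok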


@@ -3,9 +3,9 @@ import logging
 import re
 from typing import Any, Dict, List, Tuple
 
+import pydantic
 import pytest
 
-from datahub.configuration.common import ConfigurationError
 from datahub.ingestion.source.elastic_search import (
     ElasticsearchSourceConfig,
     ElasticToSchemaFieldConverter,
@@ -2467,8 +2467,6 @@ def test_host_port_parsing() -> None:
     for bad_example in bad_examples:
         config_dict = {"host": bad_example}
-        try:
-            config = ElasticsearchSourceConfig.parse_obj(config_dict)
-            assert False, f"{bad_example} should throw exception"
-        except Exception as e:
-            assert isinstance(e, ConfigurationError)
+        with pytest.raises(pydantic.ValidationError):
+            ElasticsearchSourceConfig.parse_obj(config_dict)
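
Because the shared helper raises AssertionError from inside a pydantic validator, model parsing now surfaces failures as pydantic.ValidationError, which is what the updated test asserts. For instance (the host value here is a hypothetical bad example, not one from the test's list):

import pydantic
import pytest

from datahub.ingestion.source.elastic_search import ElasticsearchSourceConfig

with pytest.raises(pydantic.ValidationError):
    ElasticsearchSourceConfig.parse_obj({"host": "localhost:noport"})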


@@ -1,8 +1,8 @@
 import unittest
 
+import pydantic
 import pytest
 
-from datahub.configuration.common import ConfigurationError
 from datahub.emitter.kafka_emitter import (
     DEFAULT_MCE_KAFKA_TOPIC,
     DEFAULT_MCP_KAFKA_TOPIC,
@@ -25,8 +25,8 @@ class KafkaEmitterTest(unittest.TestCase):
     """
 
     def test_kafka_emitter_config_old_and_new(self):
-        with pytest.raises(ConfigurationError):
-            emitter_config = KafkaEmitterConfig.parse_obj(  # noqa: F841
+        with pytest.raises(pydantic.ValidationError):
+            KafkaEmitterConfig.parse_obj(
                 {
                     "connection": {"bootstrap": "foobar:9092"},
                     "topic": "NewTopic",