fix: remove string validation in YAML (#3854)

* remove string validation in YAML

* unused import

* fix import

* remove tests

* fix tests
This commit is contained in:
ZanSara 2023-01-19 10:06:53 +01:00 committed by GitHub
parent dad7b12874
commit 6f5a2fb1da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 157 deletions

View File

@ -2,7 +2,6 @@ from typing import Any, Dict, List, Optional
import re
import os
import json
import logging
from pathlib import Path
from copy import copy
@ -102,53 +101,6 @@ def read_pipeline_config_from_yaml(path: Path) -> Dict[str, Any]:
return yaml.safe_load(stream)
JSON_FIELDS = ["custom_query"]
SKIP_VALIDATION_KEYS = ["prompt_text"] # PromptTemplate, PromptNode
def validate_config_strings(pipeline_config: Any, is_value: bool = False):
"""
Ensures that strings used in the pipelines configuration
contain only alphanumeric characters and basic punctuation.
"""
try:
if isinstance(pipeline_config, dict):
for key, value in pipeline_config.items():
# FIXME find a better solution
# Some nodes take parameters that expect JSON input,
# like `ElasticsearchDocumentStore.custom_query`
# These parameters fail validation using the standard input regex,
# so they're validated separately.
#
# Note that these fields are checked by name: if two nodes have a field
# with the same name, one of which is JSON and the other not,
# this hack will break.
if key in JSON_FIELDS:
try:
json.loads(value)
except json.decoder.JSONDecodeError as e:
raise PipelineConfigError(f"'{pipeline_config}' does not contain valid JSON.") from e
elif key in SKIP_VALIDATION_KEYS:
continue
else:
validate_config_strings(key)
validate_config_strings(value, is_value=True)
elif isinstance(pipeline_config, list):
for value in pipeline_config:
validate_config_strings(value, is_value=True)
else:
valid_regex = VALID_VALUE_REGEX if is_value else VALID_KEY_REGEX
if not valid_regex.match(str(pipeline_config)):
raise PipelineConfigError(
f"'{pipeline_config}' is not a valid variable name or value. "
"Use alphanumeric characters or dash, underscore and colon only."
)
except RecursionError as e:
raise PipelineConfigError("The given pipeline configuration is recursive, can't validate it.") from e
def build_component_dependency_graph(
pipeline_definition: Dict[str, Any], component_definitions: Dict[str, Any]
) -> nx.DiGraph:
@ -215,7 +167,12 @@ def validate_yaml(
:raise: `PipelineConfigError` in case of issues.
"""
pipeline_config = read_pipeline_config_from_yaml(path)
validate_config(pipeline_config=pipeline_config, strict_version_check=strict_version_check, extras=extras)
validate_config(
pipeline_config=pipeline_config,
strict_version_check=strict_version_check,
extras=extras,
overwrite_with_env_variables=overwrite_with_env_variables,
)
logging.debug("'%s' contains valid Haystack pipelines.", path)
@ -270,7 +227,14 @@ def validate_schema(pipeline_config: Dict, strict_version_check: bool = False, e
:return: None if validation is successful
:raise: `PipelineConfigError` in case of issues.
"""
validate_config_strings(pipeline_config)
logger.debug("Validating the following config:\n%s", pipeline_config)
if not isinstance(pipeline_config, dict):
raise PipelineConfigError(
"Your pipeline configuration seems to be not a dictionary. "
"Make sure you're loading the correct one, or enable DEBUG "
"logs to see what Haystack is trying to load."
)
# Check that the extras are respected
extras_in_config = pipeline_config.get("extras", None)

View File

@ -35,7 +35,7 @@ from haystack.pipelines import (
QuestionGenerationPipeline,
MostSimilarDocumentsPipeline,
)
from haystack.pipelines.config import validate_config_strings, get_component_definitions
from haystack.pipelines.config import get_component_definitions
from haystack.pipelines.utils import generate_code
from haystack.errors import PipelineConfigError
from haystack.nodes import PreProcessor, TextConverter
@ -676,110 +676,6 @@ def test_generate_code_can_handle_weak_cyclic_pipelines():
)
@pytest.mark.parametrize("input", ["\btest", " test", "#test", "+test", "\ttest", "\ntest", "test()"])
def test_validate_user_input_invalid(input):
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings(input)
@pytest.mark.parametrize(
"input",
[
"test",
"testName",
"test_name",
"test-name",
"test-name1234",
"http://localhost:8000/my-path",
"C:\\Some\\Windows\\Path\\To\\file.txt",
],
)
def test_validate_user_input_valid(input):
validate_config_strings(input)
def test_validate_pipeline_config_component_with_json_input_valid():
validate_config_strings(
{"components": [{"name": "test", "type": "test", "params": {"custom_query": '{"json-key": "json-value"}'}}]}
)
def test_validate_pipeline_config_component_with_json_input_invalid_key():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings(
{
"components": [
{"name": "test", "type": "test", "params": {"another_param": '{"json-key": "json-value"}'}}
]
}
)
def test_validate_pipeline_config_component_with_json_input_invalid_value():
with pytest.raises(PipelineConfigError, match="does not contain valid JSON"):
validate_config_strings(
{
"components": [
{"name": "test", "type": "test", "params": {"custom_query": "this is surely not JSON! :)"}}
]
}
)
def test_validate_pipeline_config_invalid_component_name():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings({"components": [{"name": "\btest"}]})
def test_validate_pipeline_config_invalid_component_type():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings({"components": [{"name": "test", "type": "\btest"}]})
def test_validate_pipeline_config_invalid_component_param():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings({"components": [{"name": "test", "type": "test", "params": {"key": "\btest"}}]})
def test_validate_pipeline_config_invalid_component_param_key():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings({"components": [{"name": "test", "type": "test", "params": {"\btest": "test"}}]})
def test_validate_pipeline_config_invalid_pipeline_name():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings({"components": [{"name": "test", "type": "test"}], "pipelines": [{"name": "\btest"}]})
def test_validate_pipeline_config_invalid_pipeline_node_name():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings(
{
"components": [{"name": "test", "type": "test"}],
"pipelines": [{"name": "test", "type": "test", "nodes": [{"name": "\btest"}]}],
}
)
def test_validate_pipeline_config_invalid_pipeline_node_inputs():
with pytest.raises(PipelineConfigError, match="is not a valid variable name or value"):
validate_config_strings(
{
"components": [{"name": "test", "type": "test"}],
"pipelines": [{"name": "test", "type": "test", "nodes": [{"name": "test", "inputs": ["\btest"]}]}],
}
)
def test_validate_pipeline_config_recursive_config(reduce_windows_recursion_limit):
pipeline_config = {}
node = {"config": pipeline_config}
pipeline_config["node"] = node
with pytest.raises(PipelineConfigError, match="recursive"):
validate_config_strings(pipeline_config)
def test_pipeline_classify_type(tmp_path):
pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever())

View File

@ -644,8 +644,8 @@ def test_load_yaml_custom_component_with_helper_class_in_yaml(tmp_path):
- Query
"""
)
with pytest.raises(PipelineConfigError, match="not a valid variable name or value"):
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
assert pipe.get_node("custom_node").some_exotic_parameter == 'HelperClass("hello")'
def test_load_yaml_custom_component_with_enum_in_init(tmp_path):