diff --git a/haystack/pipelines/config.py b/haystack/pipelines/config.py index 9262fa40d..ab6b58c94 100644 --- a/haystack/pipelines/config.py +++ b/haystack/pipelines/config.py @@ -21,7 +21,8 @@ from haystack.errors import PipelineError, PipelineConfigError, PipelineSchemaEr logger = logging.getLogger(__name__) -VALID_INPUT_REGEX = re.compile(r"^[-a-zA-Z0-9_/\\.:*]+$") +VALID_KEY_REGEX = re.compile(r"^[-\w/\\.:*]+$") +VALID_VALUE_REGEX = re.compile(r"^[-\w/\\.:* \[\]]+$") VALID_ROOT_NODES = ["Query", "File"] @@ -100,7 +101,7 @@ def read_pipeline_config_from_yaml(path: Path) -> Dict[str, Any]: JSON_FIELDS = ["custom_query"] # ElasticsearchDocumentStore.custom_query -def validate_config_strings(pipeline_config: Any): +def validate_config_strings(pipeline_config: Any, is_value: bool = False): """ Ensures that strings used in the pipelines configuration contain only alphanumeric characters and basic punctuation. @@ -108,7 +109,6 @@ def validate_config_strings(pipeline_config: Any): try: if isinstance(pipeline_config, dict): for key, value in pipeline_config.items(): - # FIXME find a better solution # Some nodes take parameters that expect JSON input, # like `ElasticsearchDocumentStore.custom_query` @@ -125,14 +125,15 @@ def validate_config_strings(pipeline_config: Any): raise PipelineConfigError(f"'{pipeline_config}' does not contain valid JSON.") else: validate_config_strings(key) - validate_config_strings(value) + validate_config_strings(value, is_value=True) elif isinstance(pipeline_config, list): for value in pipeline_config: - validate_config_strings(value) + validate_config_strings(value, is_value=True) else: - if not VALID_INPUT_REGEX.match(str(pipeline_config)): + valid_regex = VALID_VALUE_REGEX if is_value else VALID_KEY_REGEX + if not valid_regex.match(str(pipeline_config)): raise PipelineConfigError( f"'{pipeline_config}' is not a valid variable name or value. " "Use alphanumeric characters or dash, underscore and colon only." 
diff --git a/test/pipelines/test_pipeline_yaml.py b/test/pipelines/test_pipeline_yaml.py
index 94b425db0..3b7d7a4ee 100644
--- a/test/pipelines/test_pipeline_yaml.py
+++ b/test/pipelines/test_pipeline_yaml.py
@@ -1029,6 +1029,54 @@ def test_load_yaml_disconnected_component(tmp_path):
     assert not pipeline.get_node("retriever")


+def test_load_yaml_unusual_chars_in_values(tmp_path):
+    """
+    Load a pipeline whose node parameters contain a space, square brackets,
+    backslashes and a non-ASCII letter, and check that the values reach the node
+    unchanged.
+    """
+
+    class DummyNode(BaseComponent):
+        outgoing_edges = 1
+
+        def __init__(self, space_param, non_alphanumeric_param):
+            super().__init__()
+            self.space_param = space_param
+            self.non_alphanumeric_param = non_alphanumeric_param
+
+        def run(self):
+            raise NotImplementedError
+
+        def run_batch(self):
+            raise NotImplementedError
+
+    with open(tmp_path / "tmp_config.yml", "w", encoding="utf-8") as tmp_file:
+        # Plain (non-f) string: nothing is interpolated, and the backslashes are
+        # written as `\\` so Python emits a literal backslash without relying on
+        # the deprecated handling of the unknown escape sequence `\[`.
+        tmp_file.write(
+            """
+            version: '1.9.0'
+
+            components:
+            - name: DummyNode
+              type: DummyNode
+              params:
+                space_param: with space
+                non_alphanumeric_param: \\[ümlaut\\]
+
+            pipelines:
+            - name: indexing
+              nodes:
+              - name: DummyNode
+                inputs: [File]
+        """
+        )
+    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
+    assert pipeline.components["DummyNode"].space_param == "with space"
+    assert pipeline.components["DummyNode"].non_alphanumeric_param == "\\[ümlaut\\]"
+
+
 def test_save_yaml(tmp_path):
     pipeline = Pipeline()
     pipeline.add_node(MockRetriever(), name="retriever", inputs=["Query"])