mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 23:48:53 +00:00
bug: The PromptNode handles all parameters as lists without checking if they are in fact lists (#3820)
This commit is contained in:
parent
897e89c9b1
commit
0288e1be76
@ -645,6 +645,8 @@ class PromptNode(BaseComponent):
|
||||
template = Template(prompt_prepared["prompt_template"])
|
||||
prompt_context_copy = prompt_prepared.copy()
|
||||
prompt_context_copy.pop("prompt_template")
|
||||
# the prompt context values should all be lists, as they will be split as one
|
||||
prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in prompt_context_copy.items()}
|
||||
for prompt_context_values in zip(*prompt_context_copy.values()):
|
||||
template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())}
|
||||
template_prepared: str = template.substitute(template_input)
|
||||
|
||||
@ -268,6 +268,35 @@ def test_complex_pipeline(prompt_model):
|
||||
assert "berlin" in result["results"][0].casefold()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
|
||||
def test_complex_pipeline_with_qa(prompt_model):
|
||||
"""Test the PromptNode where the `query` is a string instead of a list what the PromptNode would expects,
|
||||
because in a question-answering pipeline the retrievers need `query` as a string, so the PromptNode
|
||||
need to be able to handle the `query` being a string instead of a list."""
|
||||
if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
|
||||
pytest.skip("No API key found for OpenAI, skipping test")
|
||||
|
||||
prompt_template = PromptTemplate(
|
||||
name="question-answering-new",
|
||||
prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
|
||||
prompt_params=["documents", "query"],
|
||||
)
|
||||
node = PromptNode(prompt_model, default_prompt_template=prompt_template)
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
|
||||
result = pipe.run(
|
||||
query="Who lives in Berlin?", # this being a string instead of a list what is being tested
|
||||
documents=[
|
||||
Document("My name is Carla and I live in Berlin"),
|
||||
Document("My name is Christelle and I live in Paris"),
|
||||
],
|
||||
)
|
||||
|
||||
assert len(result["results"]) == 1
|
||||
assert "carla" in result["results"][0].casefold()
|
||||
|
||||
|
||||
def test_complex_pipeline_with_shared_model():
|
||||
model = PromptModel()
|
||||
node = PromptNode(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user