haystack/test/components/builders/test_dynamic_prompt_builder.py
Silvano Cerza 0191b1e6e4
feat: Change Component's I/O dunder type (#6916)
* Add Pipeline.get_component_name() method

* Add utility class to ease discoverability of Component I/O

* Move InputOutput in component package

* Rename InputOutput to _InputOutput

* Raise if inputs or outputs field already exist

* Fix tests

* Add release notes

* Move InputSocket and OutputSocket in types package

* Move _InputOutput in socket package

* Rename _InputOutput class to Sockets

* Simplify Sockets class

* Ditch I/O dunder fields in favour of inputs and outputs fields

* Update Sockets docstrings

* Update release notes

* Fix mypy

* Remove unnecessary assignment

* Remove unused logging

* Change SocketsType to SocketsIOType to avoid confusion

* Change sockets type and name

* Change Sockets.__repr__ to return component instance

* Fix linting

* Fix sockets tests

* Revert to dunder fields for Component IO

* Use singular in IO dunder fields

* Delete release notes

* Update haystack/core/component/types.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
2024-02-05 17:46:45 +01:00

111 lines
4.5 KiB
Python

from typing import List
import pytest
from jinja2 import TemplateSyntaxError
from haystack import Document, Pipeline, component
from haystack.components.builders import DynamicPromptBuilder
class TestDynamicPromptBuilder:
    """Tests for the string-based ``DynamicPromptBuilder`` component."""

    def test_initialization(self):
        """Runtime variables are stored and exposed as input sockets next to the fixed ones."""
        runtime_variables = ["var1", "var2"]
        builder = DynamicPromptBuilder(runtime_variables)
        assert builder.runtime_variables == runtime_variables

        # inputs are the fixed prompt_source and template_variables sockets
        # plus one socket per runtime variable
        expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
        assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys

        # the only output socket is the rendered prompt
        assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}

        # this builder works on plain strings: the template source is a str ...
        assert builder.__haystack_input__._sockets_dict["prompt_source"].type == str

        # ... and the rendered prompt is a str as well
        assert builder.__haystack_output__._sockets_dict["prompt"].type == str

    def test_processing_a_simple_template_with_provided_variables(self):
        """A template is rendered with the values passed in ``template_variables``."""
        runtime_variables = ["var1", "var2", "var3"]
        builder = DynamicPromptBuilder(runtime_variables)

        template = "Hello, {{ name }}!"
        template_variables = {"name": "John"}
        expected_result = {"prompt": "Hello, John!"}

        assert builder.run(template, template_variables) == expected_result

    def test_processing_a_simple_template_with_invalid_template(self):
        """A syntactically broken template (unbalanced braces) raises jinja2's TemplateSyntaxError."""
        runtime_variables = ["var1", "var2", "var3"]
        builder = DynamicPromptBuilder(runtime_variables)

        # note the single closing brace: this is not valid Jinja2 syntax
        template = "Hello, {{ name }!"
        template_variables = {"name": "John"}

        with pytest.raises(TemplateSyntaxError):
            builder.run(template, template_variables)

    def test_processing_a_simple_template_with_missing_variables(self):
        """Rendering a template whose variable is not supplied raises ValueError."""
        runtime_variables = ["var1", "var2", "var3"]
        builder = DynamicPromptBuilder(runtime_variables)

        with pytest.raises(ValueError):
            builder.run("Hello, {{ name }}!", {})

    def test_missing_template_variables(self):
        """``_validate_template`` raises ValueError when a template variable is not provided."""
        prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
        template = "Hello, I'm {{ name }}, and I live in {{ city }}."

        # missing template variable city
        with pytest.raises(ValueError):
            prompt_builder._validate_template(template, {"name"})

        # missing template variable name
        with pytest.raises(ValueError):
            prompt_builder._validate_template(template, {"city"})

        # completely unknown template variable
        with pytest.raises(ValueError):
            prompt_builder._validate_template(template, {"age"})

    def test_provided_template_variables(self):
        """``_validate_template`` accepts the exact set of variables or any superset of it."""
        prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
        template = "Hello, I'm {{ name }}, and I live in {{ city }}."

        # both variables are provided
        prompt_builder._validate_template(template, {"name", "city"})

        # provided variables are a superset of the required variables
        prompt_builder._validate_template(template, {"name", "city", "age"})

    def test_example_in_pipeline(self):
        """In a pipeline, ``documents`` comes from an upstream component and ``query`` from template_variables."""
        prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])

        @component
        class DocumentProducer:
            """Minimal test component that wraps its string input in a single Document."""

            @component.output_types(documents=List[Document])
            def run(self, doc_input: str):
                return {"documents": [Document(content=doc_input)]}

        pipe = Pipeline()
        pipe.add_component("doc_producer", DocumentProducer())
        pipe.add_component("prompt_builder", prompt_builder)
        # feed the producer's output into the builder's "documents" runtime variable socket
        pipe.connect("doc_producer.documents", "prompt_builder.documents")

        template = "Here is the document: {{documents[0].content}} \\n Answer: {{query}}"
        result = pipe.run(
            data={
                "doc_producer": {"doc_input": "Hello world, I live in Berlin"},
                "prompt_builder": {
                    "prompt_source": template,
                    "template_variables": {"query": "Where does the speaker live?"},
                },
            }
        )

        assert result == {
            "prompt_builder": {
                "prompt": "Here is the document: Hello world, I live in Berlin \\n Answer: Where does the speaker live?"
            }
        }