2024-01-04 16:54:26 +01:00
|
|
|
from typing import List
|
|
|
|
|
2023-11-23 11:41:57 +01:00
|
|
|
import pytest
|
|
|
|
from jinja2 import TemplateSyntaxError
|
|
|
|
|
2024-01-29 17:26:11 +01:00
|
|
|
from haystack import Document, Pipeline, component
|
2023-12-26 15:27:43 +01:00
|
|
|
from haystack.components.builders import DynamicPromptBuilder
|
2023-11-23 11:41:57 +01:00
|
|
|
|
|
|
|
|
|
|
|
class TestDynamicPromptBuilder:
    """Unit and pipeline tests for ``DynamicPromptBuilder``."""

    def test_initialization(self):
        """The builder keeps its runtime variables and declares the expected input/output sockets."""
        runtime_variables = ["var1", "var2"]
        builder = DynamicPromptBuilder(runtime_variables)
        assert builder.runtime_variables == runtime_variables

        input_sockets = builder.__haystack_input__._sockets_dict
        output_sockets = builder.__haystack_output__._sockets_dict

        # Inputs are the two fixed sockets (prompt_source, template_variables)
        # plus one socket per runtime variable.
        expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
        assert set(input_sockets.keys()) == expected_keys

        # The only output socket is the rendered prompt.
        assert set(output_sockets.keys()) == {"prompt"}

        # The template going in and the prompt coming out are both plain strings.
        assert input_sockets["prompt_source"].type is str
        assert output_sockets["prompt"].type is str

    def test_processing_a_simple_template_with_provided_variables(self):
        """Values in ``template_variables`` are rendered into the prompt."""
        builder = DynamicPromptBuilder(runtime_variables=["var1", "var2", "var3"])

        result = builder.run(prompt_source="Hello, {{ name }}!", template_variables={"name": "John"})

        assert result == {"prompt": "Hello, John!"}

    def test_processing_a_simple_template_with_invalid_template(self):
        """A syntactically broken Jinja template raises jinja2's ``TemplateSyntaxError``."""
        builder = DynamicPromptBuilder(runtime_variables=["var1", "var2", "var3"])

        # "{{ name }!" is missing its second closing brace.
        with pytest.raises(TemplateSyntaxError):
            builder.run(prompt_source="Hello, {{ name }!", template_variables={"name": "John"})

    def test_processing_a_simple_template_with_missing_variables(self):
        """Rendering raises ``ValueError`` when a variable used by the template is not supplied."""
        builder = DynamicPromptBuilder(runtime_variables=["var1", "var2", "var3"])

        with pytest.raises(ValueError):
            builder.run(prompt_source="Hello, {{ name }}!", template_variables={})

    def test_missing_template_variables(self):
        """``_validate_template`` raises ``ValueError`` unless every template variable is provided."""
        prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
        template = "Hello, I'm {{ name }}, and I live in {{ city }}."

        # missing template variable city
        with pytest.raises(ValueError):
            prompt_builder._validate_template(template, {"name"})

        # missing template variable name
        with pytest.raises(ValueError):
            prompt_builder._validate_template(template, {"city"})

        # completely unknown template variable, both real ones missing
        with pytest.raises(ValueError):
            prompt_builder._validate_template(template, {"age"})

    def test_provided_template_variables(self):
        """``_validate_template`` accepts any superset of the template's variables."""
        prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
        template = "Hello, I'm {{ name }}, and I live in {{ city }}."

        # both variables are provided
        prompt_builder._validate_template(template, {"name", "city"})

        # provided variables are a superset of the required variables
        prompt_builder._validate_template(template, {"name", "city", "age"})

    def test_example_in_pipeline(self):
        """A runtime variable fed by an upstream component is rendered alongside template variables."""
        prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])

        @component
        class DocumentProducer:
            """Minimal upstream component that wraps its input string in a single Document."""

            @component.output_types(documents=List[Document])
            def run(self, doc_input: str):
                return {"documents": [Document(content=doc_input)]}

        pipe = Pipeline()
        pipe.add_component("doc_producer", DocumentProducer())
        pipe.add_component("prompt_builder", prompt_builder)
        # `documents` reaches the builder through its runtime-variable socket.
        pipe.connect("doc_producer.documents", "prompt_builder.documents")

        template = "Here is the document: {{documents[0].content}} \\n Answer: {{query}}"
        result = pipe.run(
            data={
                "doc_producer": {"doc_input": "Hello world, I live in Berlin"},
                "prompt_builder": {
                    "prompt_source": template,
                    "template_variables": {"query": "Where does the speaker live?"},
                },
            }
        )

        assert result == {
            "prompt_builder": {
                "prompt": "Here is the document: Hello world, I live in Berlin \\n Answer: Where does the speaker live?"
            }
        }