refactor: Rework prompt tests (#4600)

* Rework some PromptNode and PromptModel tests

* Remove duplicate code in PromptNode

* Fix mypy

* Fix test failing because of a missing fixture

* Revert "Fix mypy"

This reverts commit e530295a06cb260d9a8bd89679534958cb3d9776.

* Revert "Remove duplicate code in PromptNode"

This reverts commit 4a678ae81504dcc78a737372c061d12dc8799639.
Silvano Cerza 2023-04-06 14:47:44 +02:00 committed by GitHub
parent f2c6ce39e6
commit c3abf73332
4 changed files with 285 additions and 292 deletions


@@ -955,11 +955,6 @@ def bert_base_squad2(request):
    return model


@pytest.fixture
def prompt_node():
    return PromptNode("google/flan-t5-small", devices=["cpu"])


def haystack_azure_conf():
    api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
    azure_base_url = os.environ.get("AZURE_OPENAI_BASE_URL", None)

test/prompt/conftest.py Normal file

@@ -0,0 +1,12 @@
from unittest.mock import Mock


def create_mock_layer_that_supports(model_name, response=["fake_response"]):
    """
    Create a mock invocation layer that supports the model_name and returns response.
    """

    def mock_supports(model_name_or_path, **kwargs):
        return model_name_or_path == model_name

    return Mock(**{"model_name_or_path": model_name, "supports": mock_supports, "invoke.return_value": response})
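For context on how this helper is consumed: the reworked unit tests below patch PromptModelInvocationLayer.invocation_layer_providers so that PromptModel picks the mocked layer instead of loading a real model. A minimal sketch of that pattern, assuming the same imports the tests below use (the model name here is purely illustrative):

    from unittest.mock import patch

    from haystack.nodes.prompt.prompt_model import PromptModel
    from haystack.nodes.prompt.providers import PromptModelInvocationLayer

    from .conftest import create_mock_layer_that_supports

    mock_layer = create_mock_layer_that_supports("my-fake-model")

    # Only the mock is offered as a provider, so PromptModel consults its supports()
    # and instantiates it instead of a real (e.g. Hugging Face) invocation layer.
    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer]):
        model = PromptModel("my-fake-model")
        mock_layer.assert_called_once()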


@@ -1,45 +1,38 @@
from unittest.mock import patch, Mock

import pytest
import torch

from haystack.errors import OpenAIError
from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.providers import PromptModelInvocationLayer

from .conftest import create_mock_layer_that_supports


@pytest.mark.integration
def test_create_prompt_model():
    model = PromptModel("google/flan-t5-small")
    assert model.model_name_or_path == "google/flan-t5-small"


@pytest.mark.unit
def test_constructor_with_default_model():
    mock_layer = create_mock_layer_that_supports("google/flan-t5-base")
    another_layer = create_mock_layer_that_supports("another-model")
    model = PromptModel()
    assert model.model_name_or_path == "google/flan-t5-base"
    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel()
        mock_layer.assert_called_once()
        another_layer.assert_not_called()
        model.model_invocation_layer.model_name_or_path = "google/flan-t5-base"

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        model = PromptModel("text-davinci-003")

    model = PromptModel("text-davinci-003", api_key="no need to provide a real key")
    assert model.model_name_or_path == "text-davinci-003"


@pytest.mark.unit
def test_construtor_with_custom_model():
    mock_layer = create_mock_layer_that_supports("some-model")
    another_layer = create_mock_layer_that_supports("another-model")
    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel("another-model")
        mock_layer.assert_not_called()
        another_layer.assert_called_once()
        model.model_invocation_layer.model_name_or_path = "another-model"


@pytest.mark.unit
def test_constructor_with_no_supported_model():
    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptModel("some-random-model")

    # we can also pass model kwargs to the PromptModel
    model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can also pass kwargs directly, see HF Pipeline constructor
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can't use device_map auto without accelerate library installed
    with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"):
        model = PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"})
        assert model.model_name_or_path == "google/flan-t5-small"


def test_create_prompt_model_dtype():
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "auto"})
    assert model.model_name_or_path == "google/flan-t5-small"

    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "torch.bfloat16"})
    assert model.model_name_or_path == "google/flan-t5-small"


@@ -1,14 +1,13 @@
import os
import logging
from typing import Optional, Union, List, Dict, Any, Tuple
from unittest.mock import patch, Mock
import pytest
from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt import PromptModelInvocationLayer
from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler
from haystack.nodes.prompt.providers import HFLocalInvocationLayer
def skip_test_for_invalid_key(prompt_model):
@@ -16,36 +15,6 @@ def skip_test_for_invalid_key(prompt_model):
    pytest.skip("No API key found, skipping test")


class TestTokenStreamingHandler(TokenStreamingHandler):
    stream_handler_invoked = False

    def __call__(self, token_received, *args, **kwargs) -> str:
        """
        This callback method is called when a new token is received from the stream.

        :param token_received: The token received from the stream.
        :param kwargs: Additional keyword arguments passed to the underlying model.
        :return: The token to be sent to the stream.
        """
        self.stream_handler_invoked = True
        return token_received


class CustomInvocationLayer(PromptModelInvocationLayer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def invoke(self, *args, **kwargs):
        return ["fake_response"]

    def _ensure_token_limit(self, prompt: str) -> str:
        return prompt

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        return model_name_or_path == "fake_model"


@pytest.fixture
def get_api_key(request):
    if request.param == "openai":
@@ -55,186 +24,187 @@ def get_api_key(request):
@pytest.mark.unit
def test_prompt_node_with_custom_invocation_layer():
    model = PromptModel("fake_model")
    pn = PromptNode(model_name_or_path=model)
    output = pn("Some fake invocation")
def test_add_and_remove_template():
    with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
        node = PromptNode()
    assert output == ["fake_response"]

    # Verifies default
    assert len(node.get_prompt_template_names()) == 14

    # Add a fake template
    fake_template = PromptTemplate(name="fake-template", prompt_text="Fake prompt")
    node.add_prompt_template(fake_template)
    assert len(node.get_prompt_template_names()) == 15
    assert "fake-template" in node.get_prompt_template_names()

    # Verify that adding the same template throws an exception
    with pytest.raises(ValueError) as e:
        node.add_prompt_template(fake_template)
    assert e.match(
        "Prompt template fake-template already exists. Select a different name for this prompt template."
    )

    # Verify template is correctly removed
    assert node.remove_prompt_template("fake-template")
    assert len(node.get_prompt_template_names()) == 14
    assert "fake-template" not in node.get_prompt_template_names()

    # Verify that removing the same template throws an exception
    with pytest.raises(ValueError) as e:
        node.remove_prompt_template("fake-template")
    assert e.match("Prompt template fake-template does not exist")
@pytest.mark.integration
def test_create_prompt_node():
    prompt_node = PromptNode()
    assert prompt_node is not None
    assert prompt_node.prompt_model is not None


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_after_adding_template(mock_model):
    # Make model always return something positive on invoke
    mock_model.return_value.invoke.return_value = ["positive"]
    prompt_node = PromptNode("google/flan-t5-small")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "google/flan-t5-small"
    assert prompt_node.prompt_model is not None

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        prompt_node = PromptNode("text-davinci-003")

    prompt_node = PromptNode("text-davinci-003", api_key="no need to provide a real key")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "text-davinci-003"
    assert prompt_node.prompt_model is not None

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptNode("some-random-model")


@pytest.mark.integration
def test_add_and_remove_template(prompt_node):
    num_default_tasks = len(prompt_node.get_prompt_template_names())
    custom_task = PromptTemplate(name="custom-task", prompt_text="Custom task: {param1}, {param2}")
    prompt_node.add_prompt_template(custom_task)
    assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1
    assert "custom-task" in prompt_node.get_prompt_template_names()

    assert prompt_node.remove_prompt_template("custom-task") is not None
    assert "custom-task" not in prompt_node.get_prompt_template_names()
@pytest.mark.integration
def test_add_template_and_invoke(prompt_node):
    tt = PromptTemplate(
        name="sentiment-analysis-new",
    # Create a template
    template = PromptTemplate(
        name="fake-sentiment-analysis",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: {documents}; Answer:",
    )
    prompt_node.add_prompt_template(tt)
    r = prompt_node.prompt("sentiment-analysis-new", documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"

    # Execute prompt
    node = PromptNode()
    node.add_prompt_template(template)
    result = node.prompt("fake-sentiment-analysis", documents=["Berlin is an amazing city."])
    assert result == ["positive"]


@pytest.mark.integration
def test_on_the_fly_prompt(prompt_node):
    prompt_template = PromptTemplate(
        name="sentiment-analysis-temp",
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_passing_template(mock_model):
    # Make model always return something positive on invoke
    mock_model.return_value.invoke.return_value = ["positive"]

    # Create a template
    template = PromptTemplate(
        name="fake-sentiment-analysis",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: {documents}; Answer:",
    )
    r = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"

    # Execute prompt
    node = PromptNode()
    result = node.prompt(template, documents=["Berlin is an amazing city."])
    assert result == ["positive"]
@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_no_kwargs(mock_model, mocked_prompt):
    node = PromptNode()
    node()
    mocked_prompt.assert_called_once_with(node.default_prompt_template)


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_custom_kwargs(mock_model, mocked_prompt):
    node = PromptNode()
    node(some_kwarg="some_value")
    mocked_prompt.assert_called_once_with(node.default_prompt_template, some_kwarg="some_value")


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_custom_template(mock_model, mocked_prompt):
    node = PromptNode()
    mock_template = Mock()
    node(prompt_template=mock_template)
    mocked_prompt.assert_called_once_with(mock_template)


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_custom_kwargs_and_template(mock_model, mocked_prompt):
    node = PromptNode()
    mock_template = Mock()
    node(prompt_template=mock_template, some_kwarg="some_value")
    mocked_prompt.assert_called_once_with(mock_template, some_kwarg="some_value")


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_without_default_template(mock_model):
    node = PromptNode()
    assert node.get_prompt_template() is None

    template = node.get_prompt_template("question-answering")
    assert template.name == "question-answering"

    template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
    assert template.name == "fake-template"

    with pytest.raises(ValueError) as e:
        node.get_prompt_template("some-unsupported-template")
    assert e.match("some-unsupported-template not supported, select one of:")

    fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "fake-yaml-template"

    fake_yaml_prompt = "- prompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "custom-at-query-time"

    template = node.get_prompt_template("some prompt")
    assert template.name == "custom-at-query-time"


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_with_default_template(mock_model):
    node = PromptNode()
    node.set_default_prompt_template("question-answering")

    template = node.get_prompt_template()
    assert template.name == "question-answering"

    template = node.get_prompt_template("sentiment-analysis")
    assert template.name == "sentiment-analysis"

    template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
    assert template.name == "fake-template"

    with pytest.raises(ValueError) as e:
        node.get_prompt_template("some-unsupported-template")
    assert e.match("some-unsupported-template not supported, select one of:")

    fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "fake-yaml-template"

    fake_yaml_prompt = "- prompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "custom-at-query-time"

    template = node.get_prompt_template("some prompt")
    assert template.name == "custom-at-query-time"
@pytest.mark.integration
def test_direct_prompting(prompt_node):
    r = prompt_node("What is the capital of Germany?")
    assert r[0].casefold() == "berlin"

    r = prompt_node("What is the capital of Germany?", "What is the secret of universe?")
    assert r[0].casefold() == "berlin"
    assert len(r[1]) > 0

    r = prompt_node("Capital of Germany is Berlin", task="question-generation")
    assert len(r[0]) > 10 and "Germany" in r[0]

    r = prompt_node(["Capital of Germany is Berlin", "Capital of France is Paris"], task="question-generation")
    assert len(r) == 2


@pytest.mark.integration
def test_question_generation(prompt_node):
    r = prompt_node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert len(r) == 1 and len(r[0]) > 0


@pytest.mark.integration
def test_template_selection(prompt_node):
    qa = prompt_node.set_default_prompt_template("question-answering-per-document")
    r = qa(
        ["Berlin is the capital of Germany.", "Paris is the capital of France."],
        ["What is the capital of Germany?", "What is the capital of France"],
    )
    assert r[0].answer.casefold() == "berlin" and r[1].answer.casefold() == "paris"


@pytest.mark.integration
def test_has_supported_template_names(prompt_node):
    assert len(prompt_node.get_prompt_template_names()) > 0
@pytest.mark.integration
def test_invalid_template_params(prompt_node):
def test_invalid_template_params():
    # TODO: This can be a PromptTemplate unit test
    node = PromptNode("google/flan-t5-small", devices=["cpu"])
    with pytest.raises(ValueError, match="Expected prompt parameters"):
        prompt_node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})


@pytest.mark.integration
def test_wrong_template_params(prompt_node):
    with pytest.raises(ValueError, match="Expected prompt parameters"):
        # we don't have an options param, multiple-choice QA has one
        prompt_node.prompt("question-answering-per-document", options=["Berlin is the capital of Germany."])


@pytest.mark.integration
def test_run_invalid_template(prompt_node):
    with pytest.raises(ValueError, match="invalid-task not supported"):
        prompt_node.prompt("invalid-task", {})


@pytest.mark.integration
def test_invalid_prompting(prompt_node):
    with pytest.raises(ValueError, match="Hey there, what is the best city in the"):
        prompt_node.prompt(["Hey there, what is the best city in the world?", "Hey, answer me!"])


@pytest.mark.integration
def test_prompt_at_query_time(prompt_node: PromptNode):
    results = prompt_node.prompt("Hey there, what is the best city in the world?")
    assert len(results) == 1
    assert isinstance(results[0], str)


@pytest.mark.integration
def test_invalid_state_ops(prompt_node):
    with pytest.raises(ValueError, match="Prompt template no_such_task_exists"):
        prompt_node.remove_prompt_template("no_such_task_exists")
    # remove default task
    prompt_node.remove_prompt_template("question-answering-per-document")


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_open_ai_prompt_with_params(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    pn = PromptNode(prompt_model)
    optional_davinci_params = {"temperature": 0.5, "max_tokens": 10, "top_p": 1, "frequency_penalty": 0.5}
    r = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params)
    assert len(r) == 1 and len(r[0]) > 0


@pytest.mark.integration
def test_open_ai_prompt_with_default_params(azure_conf):
    if not azure_conf:
        pytest.skip("No Azure API key found, skipping test")
    model_kwargs = {"temperature": 0.5, "max_tokens": 2, "top_p": 1, "frequency_penalty": 0.5}
    model_kwargs.update(azure_conf)
    pn = PromptNode(model_name_or_path="text-davinci-003", api_key=azure_conf["api_key"], model_kwargs=model_kwargs)
    result = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert len(result) == 1 and len(result[0]) > 0


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_open_ai_warn_if_max_tokens_is_too_short(prompt_model, caplog):
    skip_test_for_invalid_key(prompt_model)
    pn = PromptNode(prompt_model)
    optional_davinci_params = {"temperature": 0.5, "max_tokens": 2, "top_p": 1, "frequency_penalty": 0.5}
    with caplog.at_level(logging.WARNING):
        _ = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params)
        assert "Increase the max_tokens parameter to allow for longer completions." in caplog.text
        node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_stop_words(prompt_model):
    # TODO: This can be a unit test for StopWordCriteria
    skip_test_for_invalid_key(prompt_model)
    # test single stop word for both HF and OpenAI
@@ -283,39 +253,38 @@ def test_stop_words(prompt_model):
    assert "capital" in r[0] or "Germany" in r[0]
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_call(mock_model):
    """
    Verifies model is created using expected stream handler when calling PromptNode.
    """
    mock_handler = Mock()
    node = PromptNode()
    node.prompt_model = mock_model
    node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=mock_handler)
    # Verify model has been constructed with expected model_kwargs
    mock_model.invoke.assert_called_once()
    assert mock_model.invoke.call_args_list[0].kwargs["stream_handler"] == mock_handler


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_constructor(mock_model):
    """
    Verifies model is created using expected stream handler when constructing PromptNode.
    """
    model_kwargs = {"stream_handler": Mock()}
    PromptNode(model_kwargs=model_kwargs)
    # Verify model has been constructed with expected model_kwargs
    mock_model.assert_called_once()
    assert mock_model.call_args_list[0].kwargs["model_kwargs"] == model_kwargs


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_streaming_prompt_node_with_params(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    # test streaming of calls to OpenAI by passing a stream handler to the prompt method
    ttsh = TestTokenStreamingHandler()
    node = PromptNode(prompt_model)
    response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)
    assert len(response[0]) > 0, "Response should not be empty"
    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"


@pytest.mark.integration
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
)
def test_streaming_prompt_node():
    ttsh = TestTokenStreamingHandler()
    # test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
    node = PromptNode(
        "text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
    )
    response = node("What are some of the best cities in the world to live?")
    assert len(response[0]) > 0, "Response should not be empty"
    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
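For reference, a stream handler is just a callable implementing the TokenStreamingHandler interface exercised above: it is called with each generated token and returns whatever should be forwarded to the stream. A minimal sketch of such a handler, assuming the TokenStreamingHandler import path from the old imports in this file; the class and its collecting behaviour are illustrative and not part of this diff:

    import os

    from haystack.nodes.prompt import PromptNode
    from haystack.nodes.prompt.providers import TokenStreamingHandler


    class CollectingTokenStreamingHandler(TokenStreamingHandler):
        """Illustrative handler: records every token and echoes it to stdout."""

        def __init__(self):
            self.tokens = []

        def __call__(self, token_received, *args, **kwargs) -> str:
            self.tokens.append(token_received)
            print(token_received, end="", flush=True)
            return token_received  # the returned value is what gets streamed onward


    # Registered via model_kwargs, the handler streams every call made through this node,
    # mirroring test_streaming_prompt_node above.
    node = PromptNode(
        "text-davinci-003",
        api_key=os.environ.get("OPENAI_API_KEY"),
        model_kwargs={"stream_handler": CollectingTokenStreamingHandler()},
    )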
def test_prompt_node_with_text_generation_model():
    # TODO: This is an integration test for HFLocalInvocationLayer
    # test simple prompting with text generation model
    # by default, we force the model not to return prompt text
    # Thus text-generation models can be used with PromptNode

@@ -330,9 +299,11 @@ def test_prompt_node_with_text_generation_model():
    assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_simple_pipeline(prompt_model):
    # TODO: This can be another unit test?
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")

@@ -343,9 +314,11 @@ def test_simple_pipeline(prompt_model):
    assert "positive" in result["out"][0].casefold()


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_complex_pipeline(prompt_model):
    # TODO: This is a unit test?
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query")

@@ -359,9 +332,11 @@ def test_complex_pipeline(prompt_model):
    assert "berlin" in result["answers"][0].answer.casefold()


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_simple_pipeline_with_topk(prompt_model):
    # TODO: This can be a unit test?
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query", top_k=2)

@@ -373,9 +348,11 @@ def test_simple_pipeline_with_topk(prompt_model):
    assert len(result["query"]) == 2


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_pipeline_with_standard_qa(prompt_model):
    # TODO: Unit test?
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering", top_k=1)

@@ -400,6 +377,7 @@ def test_pipeline_with_standard_qa(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_qa_with_references(prompt_model):

@@ -430,6 +408,7 @@ def test_pipeline_with_qa_with_references(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_text_at_query_time(prompt_model):

@@ -465,6 +444,7 @@ def test_pipeline_with_prompt_text_at_query_time(prompt_model):
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_template_at_query_time(prompt_model):
    # TODO: This should be just an AnswerParser unit test and some PromptTemplate unit tests
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)
@@ -510,11 +490,13 @@ def test_pipeline_with_prompt_template_at_query_time(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
    # TODO: This can be a Shaper unit test?
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: template_with_nested_shaper

@@ -547,9 +529,11 @@ def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf"], indirect=True)
def test_prompt_node_no_debug(prompt_model):
    # TODO: This is another unit test
    """Pipeline with PromptNode should not generate debug info if debug is false."""
    node = PromptNode(prompt_model, default_prompt_template="question-generation", top_k=2)
@@ -572,9 +556,11 @@ def test_prompt_node_no_debug(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_complex_pipeline_with_qa(prompt_model):
    # TODO: Not a PromptNode test, this maybe can be a unit test
    """Test the PromptNode where the `query` is a string instead of the list the PromptNode expects,
    because in a question-answering pipeline the retrievers need `query` as a string, so the PromptNode
    needs to be able to handle the `query` being a string instead of a list."""
@@ -608,8 +594,10 @@ def test_complex_pipeline_with_qa(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_shared_model():
    # TODO: What is this testing? Can this be a unit test?
    model = PromptModel()
    node = PromptNode(model_name_or_path=model, default_prompt_template="question-generation", output_variable="query")
    node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering-per-document")

@@ -622,11 +610,15 @@ def test_complex_pipeline_with_shared_model():
    assert result["answers"][0].answer == "Berlin"


@pytest.mark.skip
@pytest.mark.integration
def test_simple_pipeline_yaml(tmp_path):
    # TODO: This can be a unit test just to verify that loading
    # PromptNode from yaml creates a correctly runnable Pipeline.
    # Also it could probably be renamed to test_prompt_node_yaml_loading
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1

@@ -646,11 +638,13 @@ def test_simple_pipeline_yaml(tmp_path):
    assert result["results"][0] == "positive"


@pytest.mark.skip
@pytest.mark.integration
def test_simple_pipeline_yaml_with_default_params(tmp_path):
    # TODO: Is this testing yaml loading?
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1

@@ -674,11 +668,13 @@ def test_simple_pipeline_yaml_with_default_params(tmp_path):
    assert result["results"][0] == "positive"


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_yaml(tmp_path):
    # TODO: Is this testing PromptNode or Pipeline?
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1

@@ -710,11 +706,14 @@ def test_complex_pipeline_yaml(tmp_path):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
    # TODO: Is this similar to test_complex_pipeline_with_shared_model?
    # Why are we testing this two times?
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: pmodel

@@ -750,11 +749,13 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
    # TODO: Is this testing PromptNode or Pipeline parsing?
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: pmodel

@@ -799,8 +800,10 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
    # TODO: This can be a unit test. Is it necessary though? Is it testing PromptNode?
    # test that we can stick some random node in between prompt nodes and that everything still works
    # most specifically, we want to ensure that invocation_context is still populated correctly and propagated
    class InBetweenNode(BaseComponent):

@@ -830,7 +833,7 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: in_between

@@ -880,8 +883,10 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True)
def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
    # TODO: Is this testing PromptNode or pipeline yaml parsing?
    if not haystack_openai_config:
        pytest.skip("No API key found, skipping test")

@@ -949,12 +954,14 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
    # TODO: Can this become a unit test? Is it actually worth as a test?
    # p2 and p3 are essentially the same PromptNode component, make sure we can use them both as is in the pipeline
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1

@@ -989,12 +996,13 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
class TestTokenLimit:
    @pytest.mark.integration
    def test_hf_token_limit_warning(self, prompt_node, caplog):
    def test_hf_token_limit_warning(self, caplog):
        prompt_template = PromptTemplate(
            name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
        )
        with caplog.at_level(logging.WARNING):
            _ = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
            node = PromptNode("google/flan-t5-small", devices=["cpu"])
            node.prompt(prompt_template, documents=["Berlin is an amazing city."])
            assert "The prompt has been truncated from 812 tokens to 412 tokens" in caplog.text
            assert "and answer length (100 tokens) fit within the max token limit (512 tokens)." in caplog.text
@@ -1079,22 +1087,28 @@ class TestRunBatch:
        assert isinstance(result["results"][0][0], str)


@pytest.mark.skip
@pytest.mark.integration
def test_HFLocalInvocationLayer_supports():
    # TODO: HFLocalInvocationLayer test, to be moved
    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")


@pytest.mark.skip
@pytest.mark.integration
def test_chatgpt_direct_prompting(chatgpt_prompt_model):
    # TODO: This is testing ChatGPT, should be removed
    skip_test_for_invalid_key(chatgpt_prompt_model)
    pn = PromptNode(chatgpt_prompt_model)
    result = pn("Hey, I need some Python help. When should I use list comprehension?")
    assert len(result) == 1 and all(w in result[0] for w in ["comprehension", "list"])


@pytest.mark.skip
@pytest.mark.integration
def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
    # TODO: This is a ChatGPTInvocationLayer unit test
    skip_test_for_invalid_key(chatgpt_prompt_model)
    pn = PromptNode(chatgpt_prompt_model)

@@ -1107,24 +1121,3 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
    result = pn(messages)
    assert len(result) == 1 and all(w in result[0].casefold() for w in ["arlington", "texas"])
@pytest.mark.integration
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_promptnode():
    pn = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY", None))

    result = pn("Hey, I need some Python help. When should I use list comprehension?")
    assert len(result) == 1 and all(w in result[0] for w in ["comprehension", "list"])

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
        {"role": "user", "content": "Where was it played?"},
    ]
    result = pn(messages)
    assert len(result) == 1 and all(w in result[0].casefold() for w in ["arlington", "texas"])