Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-27 15:08:43 +00:00)
refactor: Rework prompt tests (#4600)
* Rework some PromptNode and PromptModel tests
* Remove duplicate code in PromptNode
* Fix mypy
* Fix test cause of missing fixture
* Revert "Fix mypy"
  This reverts commit e530295a06cb260d9a8bd89679534958cb3d9776.
* Revert "Remove duplicate code in PromptNode"
  This reverts commit 4a678ae81504dcc78a737372c061d12dc8799639.
This commit is contained in:
parent f2c6ce39e6
commit c3abf73332
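At a glance, the rework replaces integration tests that loaded real models with unit tests that either patch PromptModel or stub the invocation layer. A minimal sketch of that pattern follows; it is not part of the diff below, the test name is illustrative, and it is distilled from the new test/prompt/conftest.py and the reworked tests, assuming a Haystack version where PromptModelInvocationLayer.invocation_layer_providers is the provider registry:

    from unittest.mock import Mock, patch

    from haystack.nodes.prompt.prompt_model import PromptModel
    from haystack.nodes.prompt.providers import PromptModelInvocationLayer


    def create_mock_layer_that_supports(model_name, response=["fake_response"]):
        """Mock invocation layer that claims support for model_name and returns a canned response."""

        def mock_supports(model_name_or_path, **kwargs):
            return model_name_or_path == model_name

        return Mock(**{"model_name_or_path": model_name, "supports": mock_supports, "invoke.return_value": response})


    def test_prompt_model_picks_matching_layer():  # illustrative name, not from the commit
        # Patch the provider registry so PromptModel resolves to the mock layer:
        # no model download and no API key are needed for the test to run.
        mock_layer = create_mock_layer_that_supports("google/flan-t5-base")
        with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer]):
            PromptModel("google/flan-t5-base")
            mock_layer.assert_called_once()

This is the same mechanism the reworked tests below use, so they run as fast unit tests instead of integration tests.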
@@ -955,11 +955,6 @@ def bert_base_squad2(request):
    return model


@pytest.fixture
def prompt_node():
    return PromptNode("google/flan-t5-small", devices=["cpu"])


def haystack_azure_conf():
    api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
    azure_base_url = os.environ.get("AZURE_OPENAI_BASE_URL", None)
test/prompt/conftest.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from unittest.mock import Mock


def create_mock_layer_that_supports(model_name, response=["fake_response"]):
    """
    Create a mock invocation layer that supports the model_name and returns response.
    """

    def mock_supports(model_name_or_path, **kwargs):
        return model_name_or_path == model_name

    return Mock(**{"model_name_or_path": model_name, "supports": mock_supports, "invoke.return_value": response})
@@ -1,45 +1,38 @@
from unittest.mock import patch, Mock

import pytest
import torch

from haystack.errors import OpenAIError
from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.providers import PromptModelInvocationLayer

from .conftest import create_mock_layer_that_supports


@pytest.mark.integration
def test_create_prompt_model():
    model = PromptModel("google/flan-t5-small")
    assert model.model_name_or_path == "google/flan-t5-small"
@pytest.mark.unit
def test_constructor_with_default_model():
    mock_layer = create_mock_layer_that_supports("google/flan-t5-base")
    another_layer = create_mock_layer_that_supports("another-model")

    model = PromptModel()
    assert model.model_name_or_path == "google/flan-t5-base"
    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel()
        mock_layer.assert_called_once()
        another_layer.assert_not_called()
        model.model_invocation_layer.model_name_or_path = "google/flan-t5-base"

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        model = PromptModel("text-davinci-003")

    model = PromptModel("text-davinci-003", api_key="no need to provide a real key")
    assert model.model_name_or_path == "text-davinci-003"
@pytest.mark.unit
def test_construtor_with_custom_model():
    mock_layer = create_mock_layer_that_supports("some-model")
    another_layer = create_mock_layer_that_supports("another-model")

    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel("another-model")
        mock_layer.assert_not_called()
        another_layer.assert_called_once()
        model.model_invocation_layer.model_name_or_path = "another-model"


@pytest.mark.unit
def test_constructor_with_no_supported_model():
    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptModel("some-random-model")

    # we can also pass model kwargs to the PromptModel
    model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can also pass kwargs directly, see HF Pipeline constructor
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can't use device_map auto without accelerate library installed
    with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"):
        model = PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"})
        assert model.model_name_or_path == "google/flan-t5-small"


def test_create_prompt_model_dtype():
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "auto"})
    assert model.model_name_or_path == "google/flan-t5-small"

    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "torch.bfloat16"})
    assert model.model_name_or_path == "google/flan-t5-small"
@@ -1,14 +1,13 @@
import os
import logging
from typing import Optional, Union, List, Dict, Any, Tuple
from unittest.mock import patch, Mock

import pytest

from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt import PromptModelInvocationLayer
from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler
from haystack.nodes.prompt.providers import HFLocalInvocationLayer


def skip_test_for_invalid_key(prompt_model):
@@ -16,36 +15,6 @@ def skip_test_for_invalid_key(prompt_model):
        pytest.skip("No API key found, skipping test")


class TestTokenStreamingHandler(TokenStreamingHandler):
    stream_handler_invoked = False

    def __call__(self, token_received, *args, **kwargs) -> str:
        """
        This callback method is called when a new token is received from the stream.

        :param token_received: The token received from the stream.
        :param kwargs: Additional keyword arguments passed to the underlying model.
        :return: The token to be sent to the stream.
        """
        self.stream_handler_invoked = True
        return token_received


class CustomInvocationLayer(PromptModelInvocationLayer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def invoke(self, *args, **kwargs):
        return ["fake_response"]

    def _ensure_token_limit(self, prompt: str) -> str:
        return prompt

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        return model_name_or_path == "fake_model"


@pytest.fixture
def get_api_key(request):
    if request.param == "openai":
@@ -55,186 +24,187 @@ def get_api_key(request):


@pytest.mark.unit
def test_prompt_node_with_custom_invocation_layer():
    model = PromptModel("fake_model")
    pn = PromptNode(model_name_or_path=model)
    output = pn("Some fake invocation")
def test_add_and_remove_template():
    with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
        node = PromptNode()

    assert output == ["fake_response"]
    # Verifies default
    assert len(node.get_prompt_template_names()) == 14

    # Add a fake template
    fake_template = PromptTemplate(name="fake-template", prompt_text="Fake prompt")
    node.add_prompt_template(fake_template)
    assert len(node.get_prompt_template_names()) == 15
    assert "fake-template" in node.get_prompt_template_names()

    # Verify that adding the same template throws an expection
    with pytest.raises(ValueError) as e:
        node.add_prompt_template(fake_template)
        assert e.match(
            "Prompt template fake-template already exists. Select a different name for this prompt template."
        )

    # Verify template is correctly removed
    assert node.remove_prompt_template("fake-template")
    assert len(node.get_prompt_template_names()) == 14
    assert "fake-template" not in node.get_prompt_template_names()

    # Verify that removing the same template throws an expection
    with pytest.raises(ValueError) as e:
        node.remove_prompt_template("fake-template")
        assert e.match("Prompt template fake-template does not exist")


@pytest.mark.integration
def test_create_prompt_node():
    prompt_node = PromptNode()
    assert prompt_node is not None
    assert prompt_node.prompt_model is not None
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_after_adding_template(mock_model):
    # Make model always return something positive on invoke
    mock_model.return_value.invoke.return_value = ["positive"]

    prompt_node = PromptNode("google/flan-t5-small")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "google/flan-t5-small"
    assert prompt_node.prompt_model is not None

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        prompt_node = PromptNode("text-davinci-003")

    prompt_node = PromptNode("text-davinci-003", api_key="no need to provide a real key")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "text-davinci-003"
    assert prompt_node.prompt_model is not None

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptNode("some-random-model")


@pytest.mark.integration
def test_add_and_remove_template(prompt_node):
    num_default_tasks = len(prompt_node.get_prompt_template_names())
    custom_task = PromptTemplate(name="custom-task", prompt_text="Custom task: {param1}, {param2}")
    prompt_node.add_prompt_template(custom_task)
    assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1
    assert "custom-task" in prompt_node.get_prompt_template_names()

    assert prompt_node.remove_prompt_template("custom-task") is not None
    assert "custom-task" not in prompt_node.get_prompt_template_names()


@pytest.mark.integration
def test_add_template_and_invoke(prompt_node):
    tt = PromptTemplate(
        name="sentiment-analysis-new",
    # Create a template
    template = PromptTemplate(
        name="fake-sentiment-analysis",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: {documents}; Answer:",
    )
    prompt_node.add_prompt_template(tt)

    r = prompt_node.prompt("sentiment-analysis-new", documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"
    # Execute prompt
    node = PromptNode()
    node.add_prompt_template(template)
    result = node.prompt("fake-sentiment-analysis", documents=["Berlin is an amazing city."])

    assert result == ["positive"]


@pytest.mark.integration
def test_on_the_fly_prompt(prompt_node):
    prompt_template = PromptTemplate(
        name="sentiment-analysis-temp",
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_passing_template(mock_model):
    # Make model always return something positive on invoke
    mock_model.return_value.invoke.return_value = ["positive"]

    # Create a template
    template = PromptTemplate(
        name="fake-sentiment-analysis",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: {documents}; Answer:",
    )
    r = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"

    # Execute prompt
    node = PromptNode()
    result = node.prompt(template, documents=["Berlin is an amazing city."])

    assert result == ["positive"]


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_no_kwargs(mock_model, mocked_prompt):
    node = PromptNode()
    node()
    mocked_prompt.assert_called_once_with(node.default_prompt_template)


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_custom_kwargs(mock_model, mocked_prompt):
    node = PromptNode()
    node(some_kwarg="some_value")
    mocked_prompt.assert_called_once_with(node.default_prompt_template, some_kwarg="some_value")


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_custom_template(mock_model, mocked_prompt):
    node = PromptNode()
    mock_template = Mock()
    node(prompt_template=mock_template)
    mocked_prompt.assert_called_once_with(mock_template)


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_call_with_custom_kwargs_and_template(mock_model, mocked_prompt):
    node = PromptNode()
    mock_template = Mock()
    node(prompt_template=mock_template, some_kwarg="some_value")
    mocked_prompt.assert_called_once_with(mock_template, some_kwarg="some_value")


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_without_default_template(mock_model):
    node = PromptNode()
    assert node.get_prompt_template() is None

    template = node.get_prompt_template("question-answering")
    assert template.name == "question-answering"

    template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
    assert template.name == "fake-template"

    with pytest.raises(ValueError) as e:
        node.get_prompt_template("some-unsupported-template")
        assert e.match("some-unsupported-template not supported, select one of:")

    fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "fake-yaml-template"

    fake_yaml_prompt = "- prompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "custom-at-query-time"

    template = node.get_prompt_template("some prompt")
    assert template.name == "custom-at-query-time"


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_get_prompt_template_with_default_template(mock_model):
    node = PromptNode()
    node.set_default_prompt_template("question-answering")

    template = node.get_prompt_template()
    assert template.name == "question-answering"

    template = node.get_prompt_template("sentiment-analysis")
    assert template.name == "sentiment-analysis"

    template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
    assert template.name == "fake-template"

    with pytest.raises(ValueError) as e:
        node.get_prompt_template("some-unsupported-template")
        assert e.match("some-unsupported-template not supported, select one of:")

    fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "fake-yaml-template"

    fake_yaml_prompt = "- prompt_text: fake prompt text"
    template = node.get_prompt_template(fake_yaml_prompt)
    assert template.name == "custom-at-query-time"

    template = node.get_prompt_template("some prompt")
    assert template.name == "custom-at-query-time"


@pytest.mark.integration
def test_direct_prompting(prompt_node):
    r = prompt_node("What is the capital of Germany?")
    assert r[0].casefold() == "berlin"

    r = prompt_node("What is the capital of Germany?", "What is the secret of universe?")
    assert r[0].casefold() == "berlin"
    assert len(r[1]) > 0

    r = prompt_node("Capital of Germany is Berlin", task="question-generation")
    assert len(r[0]) > 10 and "Germany" in r[0]

    r = prompt_node(["Capital of Germany is Berlin", "Capital of France is Paris"], task="question-generation")
    assert len(r) == 2


@pytest.mark.integration
def test_question_generation(prompt_node):
    r = prompt_node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert len(r) == 1 and len(r[0]) > 0


@pytest.mark.integration
def test_template_selection(prompt_node):
    qa = prompt_node.set_default_prompt_template("question-answering-per-document")
    r = qa(
        ["Berlin is the capital of Germany.", "Paris is the capital of France."],
        ["What is the capital of Germany?", "What is the capital of France"],
    )
    assert r[0].answer.casefold() == "berlin" and r[1].answer.casefold() == "paris"


@pytest.mark.integration
def test_has_supported_template_names(prompt_node):
    assert len(prompt_node.get_prompt_template_names()) > 0


@pytest.mark.integration
def test_invalid_template_params(prompt_node):
def test_invalid_template_params():
    # TODO: This can be a PromptTemplate unit test
    node = PromptNode("google/flan-t5-small", devices=["cpu"])
    with pytest.raises(ValueError, match="Expected prompt parameters"):
        prompt_node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})


@pytest.mark.integration
def test_wrong_template_params(prompt_node):
    with pytest.raises(ValueError, match="Expected prompt parameters"):
        # with don't have options param, multiple choice QA has
        prompt_node.prompt("question-answering-per-document", options=["Berlin is the capital of Germany."])


@pytest.mark.integration
def test_run_invalid_template(prompt_node):
    with pytest.raises(ValueError, match="invalid-task not supported"):
        prompt_node.prompt("invalid-task", {})


@pytest.mark.integration
def test_invalid_prompting(prompt_node):
    with pytest.raises(ValueError, match="Hey there, what is the best city in the"):
        prompt_node.prompt(["Hey there, what is the best city in the world?", "Hey, answer me!"])


@pytest.mark.integration
def test_prompt_at_query_time(prompt_node: PromptNode):
    results = prompt_node.prompt("Hey there, what is the best city in the world?")
    assert len(results) == 1
    assert isinstance(results[0], str)


@pytest.mark.integration
def test_invalid_state_ops(prompt_node):
    with pytest.raises(ValueError, match="Prompt template no_such_task_exists"):
        prompt_node.remove_prompt_template("no_such_task_exists")
        # remove default task
        prompt_node.remove_prompt_template("question-answering-per-document")


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_open_ai_prompt_with_params(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    pn = PromptNode(prompt_model)
    optional_davinci_params = {"temperature": 0.5, "max_tokens": 10, "top_p": 1, "frequency_penalty": 0.5}
    r = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params)
    assert len(r) == 1 and len(r[0]) > 0


@pytest.mark.integration
def test_open_ai_prompt_with_default_params(azure_conf):
    if not azure_conf:
        pytest.skip("No Azure API key found, skipping test")
    model_kwargs = {"temperature": 0.5, "max_tokens": 2, "top_p": 1, "frequency_penalty": 0.5}
    model_kwargs.update(azure_conf)
    pn = PromptNode(model_name_or_path="text-davinci-003", api_key=azure_conf["api_key"], model_kwargs=model_kwargs)
    result = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert len(result) == 1 and len(result[0]) > 0


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_open_ai_warn_if_max_tokens_is_too_short(prompt_model, caplog):
    skip_test_for_invalid_key(prompt_model)
    pn = PromptNode(prompt_model)
    optional_davinci_params = {"temperature": 0.5, "max_tokens": 2, "top_p": 1, "frequency_penalty": 0.5}
    with caplog.at_level(logging.WARNING):
        _ = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params)
        assert "Increase the max_tokens parameter to allow for longer completions." in caplog.text
        node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_stop_words(prompt_model):
    # TODO: This can be a unit test for StopWordCriteria
    skip_test_for_invalid_key(prompt_model)

    # test single stop word for both HF and OpenAI
@@ -283,39 +253,38 @@ def test_stop_words(prompt_model):
    assert "capital" in r[0] or "Germany" in r[0]


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_call(mock_model):
    """
    Verifies model is created using expected stream handler when calling PromptNode.
    """
    mock_handler = Mock()
    node = PromptNode()
    node.prompt_model = mock_model
    node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=mock_handler)
    # Verify model has been constructed with expected model_kwargs
    mock_model.invoke.assert_called_once()
    assert mock_model.invoke.call_args_list[0].kwargs["stream_handler"] == mock_handler


@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_constructor(mock_model):
    """
    Verifies model is created using expected stream handler when constructing PromptNode.
    """
    model_kwargs = {"stream_handler": Mock()}
    PromptNode(model_kwargs=model_kwargs)
    # Verify model has been constructed with expected model_kwargs
    mock_model.assert_called_once()
    assert mock_model.call_args_list[0].kwargs["model_kwargs"] == model_kwargs


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_streaming_prompt_node_with_params(prompt_model):
    skip_test_for_invalid_key(prompt_model)

    # test streaming of calls to OpenAI by passing a stream handler to the prompt method
    ttsh = TestTokenStreamingHandler()
    node = PromptNode(prompt_model)
    response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)

    assert len(response[0]) > 0, "Response should not be empty"
    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"


@pytest.mark.integration
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
)
def test_streaming_prompt_node():
    ttsh = TestTokenStreamingHandler()

    # test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
    node = PromptNode(
        "text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
    )
    response = node("What are some of the best cities in the world to live?")

    assert len(response[0]) > 0, "Response should not be empty"
    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"


def test_prompt_node_with_text_generation_model():
    # TODO: This is an integration test for HFLocalInvocationLayer
    # test simple prompting with text generation model
    # by default, we force the model not return prompt text
    # Thus text-generation models can be used with PromptNode
@@ -330,9 +299,11 @@ def test_prompt_node_with_text_generation_model():
    assert len(r[0]) > 0 and r[0].startswith("Hello big science!")


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_simple_pipeline(prompt_model):
    # TODO: This can be another unit test?
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")
@@ -343,9 +314,11 @@ def test_simple_pipeline(prompt_model):
    assert "positive" in result["out"][0].casefold()


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_complex_pipeline(prompt_model):
    # TODO: This is a unit test?
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query")
@@ -359,9 +332,11 @@ def test_complex_pipeline(prompt_model):
    assert "berlin" in result["answers"][0].answer.casefold()


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_simple_pipeline_with_topk(prompt_model):
    # TODO: This can be a unit test?
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query", top_k=2)
@@ -373,9 +348,11 @@ def test_simple_pipeline_with_topk(prompt_model):
    assert len(result["query"]) == 2


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_pipeline_with_standard_qa(prompt_model):
    # TODO: Unit test?
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering", top_k=1)
@@ -400,6 +377,7 @@ def test_pipeline_with_standard_qa(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_qa_with_references(prompt_model):
@@ -430,6 +408,7 @@ def test_pipeline_with_qa_with_references(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_text_at_query_time(prompt_model):
@@ -465,6 +444,7 @@ def test_pipeline_with_prompt_text_at_query_time(prompt_model):
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_template_at_query_time(prompt_model):
    # TODO: This should be just an AnswerParser unit test and some PromptTemplate unit tests
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)
@@ -510,11 +490,13 @@ def test_pipeline_with_prompt_template_at_query_time(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
    # TODO: This can be a Shaper unit test?
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: template_with_nested_shaper
@@ -547,9 +529,11 @@ def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf"], indirect=True)
def test_prompt_node_no_debug(prompt_model):
    # TODO: This is another unit test
    """Pipeline with PromptNode should not generate debug info if debug is false."""

    node = PromptNode(prompt_model, default_prompt_template="question-generation", top_k=2)
@@ -572,9 +556,11 @@ def test_prompt_node_no_debug(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_complex_pipeline_with_qa(prompt_model):
    # TODO: Not a PromptNode test, this maybe can be a unit test
    """Test the PromptNode where the `query` is a string instead of a list what the PromptNode would expects,
    because in a question-answering pipeline the retrievers need `query` as a string, so the PromptNode
    need to be able to handle the `query` being a string instead of a list."""
@@ -608,8 +594,10 @@ def test_complex_pipeline_with_qa(prompt_model):
    )


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_shared_model():
    # TODO: What is this testing? Can this be a unit test?
    model = PromptModel()
    node = PromptNode(model_name_or_path=model, default_prompt_template="question-generation", output_variable="query")
    node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering-per-document")
@@ -622,11 +610,15 @@ def test_complex_pipeline_with_shared_model():
    assert result["answers"][0].answer == "Berlin"


@pytest.mark.skip
@pytest.mark.integration
def test_simple_pipeline_yaml(tmp_path):
    # TODO: This can be a unit test just to verify that loading
    # PromptNode from yaml creates a correctly runnable Pipeline.
    # Also it could probably be renamed to test_prompt_node_yaml_loading
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1
@@ -646,11 +638,13 @@ def test_simple_pipeline_yaml(tmp_path):
    assert result["results"][0] == "positive"


@pytest.mark.skip
@pytest.mark.integration
def test_simple_pipeline_yaml_with_default_params(tmp_path):
    # TODO: Is this testing yaml loading?
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1
@@ -674,11 +668,13 @@ def test_simple_pipeline_yaml_with_default_params(tmp_path):
    assert result["results"][0] == "positive"


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_yaml(tmp_path):
    # TODO: Is this testing PromptNode or Pipeline?
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1
@@ -710,11 +706,14 @@ def test_complex_pipeline_yaml(tmp_path):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
    # TODO: Is this similar to test_complex_pipeline_with_shared_model?
    # Why are we testing this two times?
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: pmodel
@@ -750,11 +749,13 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
    # TODO: Is this testing PromptNode or Pipeline parsing?
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: pmodel
@@ -799,8 +800,10 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
    # TODO: This can be a unit test. Is it necessary though? Is it testing PromptNode?
    # test that we can stick some random node in between prompt nodes and that everything still works
    # most specifically, we want to ensure that invocation_context is still populated correctly and propagated
    class InBetweenNode(BaseComponent):
@@ -830,7 +833,7 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):

    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: in_between
@@ -880,8 +883,10 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True)
def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
    # TODO: Is this testing PromptNode or pipeline yaml parsing?
    if not haystack_openai_config:
        pytest.skip("No API key found, skipping test")
@@ -949,12 +954,14 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.skip
@pytest.mark.integration
def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
    # TODO: Can this become a unit test? Is it actually worth as a test?
    # p2 and p3 are essentially the same PromptNode component, make sure we can use them both as is in the pipeline
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            """
            version: ignore
            components:
            - name: p1
@@ -989,12 +996,13 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):

class TestTokenLimit:
    @pytest.mark.integration
    def test_hf_token_limit_warning(self, prompt_node, caplog):
    def test_hf_token_limit_warning(self, caplog):
        prompt_template = PromptTemplate(
            name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
        )
        with caplog.at_level(logging.WARNING):
            _ = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
            node = PromptNode("google/flan-t5-small", devices=["cpu"])
            node.prompt(prompt_template, documents=["Berlin is an amazing city."])
            assert "The prompt has been truncated from 812 tokens to 412 tokens" in caplog.text
            assert "and answer length (100 tokens) fit within the max token limit (512 tokens)." in caplog.text
@@ -1079,22 +1087,28 @@ class TestRunBatch:
        assert isinstance(result["results"][0][0], str)


@pytest.mark.skip
@pytest.mark.integration
def test_HFLocalInvocationLayer_supports():
    # TODO: HFLocalInvocationLayer test, to be moved
    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")


@pytest.mark.skip
@pytest.mark.integration
def test_chatgpt_direct_prompting(chatgpt_prompt_model):
    # TODO: This is testing ChatGPT, should be removed
    skip_test_for_invalid_key(chatgpt_prompt_model)
    pn = PromptNode(chatgpt_prompt_model)
    result = pn("Hey, I need some Python help. When should I use list comprehension?")
    assert len(result) == 1 and all(w in result[0] for w in ["comprehension", "list"])


@pytest.mark.skip
@pytest.mark.integration
def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):
    # TODO: This is a ChatGPTInvocationLayer unit test
    skip_test_for_invalid_key(chatgpt_prompt_model)
    pn = PromptNode(chatgpt_prompt_model)
@@ -1107,24 +1121,3 @@ def test_chatgpt_direct_prompting_w_messages(chatgpt_prompt_model):

    result = pn(messages)
    assert len(result) == 1 and all(w in result[0].casefold() for w in ["arlington", "texas"])


@pytest.mark.integration
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_promptnode():
    pn = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY", None))

    result = pn("Hey, I need some Python help. When should I use list comprehension?")
    assert len(result) == 1 and all(w in result[0] for w in ["comprehension", "list"])

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
        {"role": "user", "content": "Where was it played?"},
    ]
    result = pn(messages)
    assert len(result) == 1 and all(w in result[0].casefold() for w in ["arlington", "texas"])