chore: block all HTTP requests in CI (#5088)

Author: ZanSara
Date:   2023-06-13 14:52:24 +02:00 (committed by GitHub)
Parent: 29a6bfe621
Commit: 65cdf36d72
10 changed files with 427 additions and 218 deletions

View File

@@ -2,8 +2,12 @@ import logging
 import os
 import re
 from typing import Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch
+
+from events import Events
+
+from haystack.agents.types import AgentTokenStreamingHandler
 from test.conftest import MockRetriever, MockPromptNode
 from unittest import mock
 import pytest

@@ -290,17 +294,6 @@ def test_tool_processes_answer_result_and_document_result():
     assert tool._process_result(Document(content="content")) == "content"

-def test_invalid_agent_template():
-    pn = PromptNode()
-    with pytest.raises(ValueError, match="some_non_existing_template not supported"):
-        Agent(prompt_node=pn, prompt_template="some_non_existing_template")
-
-    # if prompt_template is None, then we'll use zero-shot-react
-    a = Agent(prompt_node=pn, prompt_template=None)
-    assert isinstance(a.prompt_template, PromptTemplate)
-    assert a.prompt_template.name == "zero-shot-react"
-
 @pytest.mark.unit
 @patch.object(PromptNode, "prompt")
 @patch("haystack.nodes.prompt.prompt_node.PromptModel")

@@ -315,3 +308,25 @@ def test_default_template_order(mock_model, mock_prompt):
     a = Agent(prompt_node=pn, prompt_template="translation")
     assert a.prompt_template.name == "translation"
+
+
+@pytest.mark.unit
+def test_agent_with_unknown_prompt_template():
+    prompt_node = Mock()
+    prompt_node.get_prompt_template.return_value = None
+    with pytest.raises(ValueError, match="Prompt template 'invalid' not found"):
+        Agent(prompt_node=prompt_node, prompt_template="invalid")
+
+
+@pytest.mark.unit
+def test_agent_token_streaming_handler():
+    e = Events("on_new_token")
+    mock_callback = Mock()
+    e.on_new_token += mock_callback  # register the mock callback to the event
+    handler = AgentTokenStreamingHandler(events=e)
+    result = handler("test")
+    assert result == "test"
+    mock_callback.assert_called_once_with("test")  # assert that the mock callback was called with "test"
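Note: the streaming-handler test above leans on the events package, where Events("on_new_token") declares an event, += subscribes a callback, and calling e.on_new_token(token) fires every subscriber. A rough sketch of the contract the test pins down (an illustration, not necessarily the actual AgentTokenStreamingHandler implementation):

from events import Events


class ForwardingTokenHandler:
    # Forward every received token to the on_new_token event and return it
    # unchanged, which is exactly what test_agent_token_streaming_handler asserts.
    def __init__(self, events: Events):
        self.events = events

    def __call__(self, token_received: str, **kwargs) -> str:
        self.events.on_new_token(token_received, **kwargs)
        return token_received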

View File

@@ -1,17 +1,15 @@
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, Mock

 from haystack.agents.conversational import ConversationalAgent
 from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
-from haystack.nodes import PromptNode
+from test.conftest import MockPromptNode


 @pytest.mark.unit
 def test_init():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        agent = ConversationalAgent(prompt_node)
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node)

     # Test normal case
     assert isinstance(agent.memory, ConversationMemory)

@@ -25,32 +23,27 @@ def test_init():

 @pytest.mark.unit
 def test_init_with_summary_memory():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode(default_prompt_template="this is a test")
-        # Test with summary memory
-        agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
-        assert isinstance(agent.memory, ConversationSummaryMemory)
+    # Test with summary memory
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
+    assert isinstance(agent.memory, ConversationSummaryMemory)


 @pytest.mark.unit
 def test_init_with_no_memory():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        # Test with no memory
-        agent = ConversationalAgent(prompt_node, memory=NoMemory())
-        assert isinstance(agent.memory, NoMemory)
+    prompt_node = MockPromptNode()
+    # Test with no memory
+    agent = ConversationalAgent(prompt_node, memory=NoMemory())
+    assert isinstance(agent.memory, NoMemory)


 @pytest.mark.unit
 def test_run():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        agent = ConversationalAgent(prompt_node)
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node)

     # Mock the Agent run method
-    agent.run = MagicMock(return_value="Hello")
-    assert agent.run("query") == "Hello"
-    agent.run.assert_called_once_with("query")
+    result = agent.run("query")
+    # empty answer
+    assert result["answers"][0].answer == ""

View File

@@ -1,9 +1,10 @@
 import unittest
+from typing import Optional, Union, List, Dict, Any
 from unittest import mock

 import pytest

-from haystack import Pipeline, Answer, Document
+from haystack import Pipeline, Answer, Document, BaseComponent, MultiLabel
 from haystack.agents.base import ToolsManager, Tool

@@ -160,3 +161,57 @@ def test_extract_tool_name_and_empty_tool_input(tools_manager):
     for example in examples:
         tool_name, tool_input = tools_manager.extract_tool_name_and_tool_input(example)
         assert tool_name == "Search" and tool_input == ""
+
+
+@pytest.mark.unit
+def test_node_as_tool():
+    # test that a component can be used as a tool
+    class ToolComponent(BaseComponent):
+        outgoing_edges = 1
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            pass
+
+        def run(self, **kwargs):
+            return "mocked_output"
+
+    tool = Tool(name="ToolA", pipeline_or_node=ToolComponent(), description="Tool A Description")
+    assert tool.run("input") == "mocked_output"
+
+
+@pytest.mark.unit
+def test_tools_manager_exception():
+    # tests exception raising in tools manager
+    class ToolComponent(BaseComponent):
+        outgoing_edges = 1
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            pass
+
+        def run(self, **kwargs):
+            raise Exception("mocked_exception")
+
+    fake_llm_response = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy born"
+    tool = Tool(name="Search", pipeline_or_node=ToolComponent(), description="Search")
+    tools_manager = ToolsManager(tools=[tool])
+    with pytest.raises(Exception):
+        tools_manager.run_tool(llm_response=fake_llm_response)

View File

@@ -8,6 +8,7 @@ from pathlib import Path
 import os
 import re
 from functools import wraps
+from unittest.mock import patch

 import requests_cache
 import responses

@@ -847,5 +848,14 @@ def request_blocker(request: pytest.FixtureRequest, monkeypatch):
     marker = request.node.get_closest_marker("unit")
     if marker is None:
         return
-    monkeypatch.delattr("requests.sessions.Session")
-    monkeypatch.delattr("requests_cache.session.CachedSession")
+
+    def urlopen_mock(self, method, url, *args, **kwargs):
+        raise RuntimeError(f"The test was about to {method} {self.scheme}://{self.host}{url}")
+
+    monkeypatch.setattr("urllib3.connectionpool.HTTPConnectionPool.urlopen", urlopen_mock)
+
+
+@pytest.fixture
+def mock_auto_tokenizer():
+    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
+        yield mock_from_pretrained
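The request_blocker change above is what gives the commit its title: once urllib3's HTTPConnectionPool.urlopen is monkeypatched, a test marked as unit that still reaches the network fails immediately instead of silently calling out. A minimal sketch of the effect from a test's point of view (assuming the fixture is applied automatically to unit tests, e.g. via autouse, which the marker check above suggests):

import pytest
import requests


@pytest.mark.unit
def test_accidental_http_call_is_blocked():
    # requests ultimately goes through urllib3's HTTPConnectionPool.urlopen,
    # which the request_blocker fixture replaces with a function that raises.
    with pytest.raises(RuntimeError, match="The test was about to GET"):
        requests.get("https://example.com/")

The new mock_auto_tokenizer fixture serves the same goal for unit tests that previously downloaded a tokenizer from the Hugging Face Hub just to construct an invocation layer.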

View File

@@ -15,7 +15,7 @@ from ..conftest import fail_at_version

 @pytest.mark.unit
 @fail_at_version(1, 18)
-def test_seq2seq_deprecation():
+def test_seq2seq_deprecation(mock_auto_tokenizer):
     with pytest.warns(DeprecationWarning):
         try:
             Seq2SeqGenerator("non_existing_model/model")

View File

@@ -8,10 +8,21 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer import AnthropicClaudeInvocationLayer


+@pytest.fixture
+def mock_claude_tokenizer():
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer", autospec=True) as mock_tokenizer:
+        yield mock_tokenizer
+
+
+@pytest.fixture
+def mock_claude_request():
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
+        yield mock_request
+
+
 @pytest.mark.unit
-def test_default_costuctor():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_default_constructor(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     assert layer.api_key == "some_fake_key"
     assert layer.max_length == 200

@@ -20,7 +31,7 @@ def test_default_costuctor():

 @pytest.mark.unit
-def test_ignored_kwargs_are_filtered_in_init():
+def test_ignored_kwargs_are_filtered_in_init(mock_claude_tokenizer, mock_claude_request):
     kwargs = {
         "temperature": 1,
         "top_p": 5,

@@ -30,8 +41,8 @@ def test_ignored_kwargs_are_filtered_in_init():
         "stream_handler": DefaultTokenStreamingHandler(),
         "unkwnown_args": "this will be filtered out",
     }
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)

     # Verify unexpected kwargs are filtered out
     assert len(layer.model_input_kwargs) == 6

@@ -45,9 +56,8 @@ def test_ignored_kwargs_are_filtered_in_init():

 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_no_kwargs(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     with pytest.raises(ValueError) as e:
         layer.invoke()

@@ -55,14 +65,12 @@ def test_invoke_with_no_kwargs():

 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry")
-def test_invoke_with_prompt_only(mock_request):
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_prompt_only(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    mock_request.return_value = mock_response
+    mock_claude_request.return_value = mock_response

     res = layer.invoke(prompt="Some prompt")
     assert len(res) == 1

@@ -70,14 +78,13 @@ def test_invoke_with_prompt_only(mock_request):

 @pytest.mark.unit
-def test_invoke_with_kwargs():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_kwargs(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=["stop", "here"])
     assert len(res) == 1
     assert res[0] == "some_result"

@@ -92,19 +99,18 @@ def test_invoke_with_kwargs():
         "stream": False,
         "stop_sequences": ["stop", "here", "\n\nHuman: "],
     }
-    mock_request.assert_called_once()
-    assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
+    mock_invocation_request.assert_called_once()
+    assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)


 @pytest.mark.unit
-def test_invoke_with_none_stop_words():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_none_stop_words(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=None)
     assert len(res) == 1
     assert res[0] == "some_result"

@@ -119,14 +125,13 @@ def test_invoke_with_none_stop_words():
         "stream": False,
         "stop_sequences": ["\n\nHuman: "],
     }
-    mock_request.assert_called_once()
-    assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
+    mock_invocation_request.assert_called_once()
+    assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)


 @pytest.mark.unit
-def test_invoke_with_stream():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_stream(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake streamed response
     def mock_iter(self):

@@ -141,8 +146,8 @@ def test_invoke_with_stream():
     mock_response = Mock(**{"__iter__": mock_iter})

     # Verifies expected result is returned
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", stream=True)
     assert len(res) == 1

@@ -150,14 +155,13 @@ def test_invoke_with_stream():

 @pytest.mark.unit
-def test_invoke_with_custom_stream_handler():
+def test_invoke_with_custom_stream_handler(mock_claude_tokenizer, mock_claude_request):
     # Create a mock stream handler that always return the same token when called
     mock_stream_handler = Mock()
     mock_stream_handler.return_value = "token"

     # Create a layer with a mocked stream handler
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)

     # Create a fake streamed response
     def mock_iter(self):

@@ -171,8 +175,8 @@ def test_invoke_with_custom_stream_handler():
     mock_response = Mock(**{"__iter__": mock_iter})

-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt")
     assert len(res) == 1

@@ -186,7 +190,7 @@ def test_invoke_with_custom_stream_handler():

 @pytest.mark.unit
-def test_ensure_token_limit_fails_if_called_with_list():
+def test_ensure_token_limit_fails_if_called_with_list(mock_claude_tokenizer, mock_claude_request):
     layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
     with pytest.raises(ValueError):
         layer._ensure_token_limit(prompt=[])

@@ -225,7 +229,7 @@ def test_ensure_token_limit_with_huge_max_length(caplog):

 @pytest.mark.unit
-def test_supports():
+def test_supports(mock_claude_tokenizer, mock_claude_request):
     layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
     assert not layer.supports("claude")

View File

@@ -1,14 +1,14 @@
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import Mock

 import pytest

-from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer


 @pytest.mark.unit
-def test_default_constructor():
+def test_default_constructor(mock_auto_tokenizer):
     """
     Test that the default constructor sets the correct values
     """

@@ -28,7 +28,7 @@ def test_default_constructor():

 @pytest.mark.unit
-def test_constructor_with_model_kwargs():
+def test_constructor_with_model_kwargs(mock_auto_tokenizer):
     """
     Test that model_kwargs are correctly set in the constructor
     and that model_kwargs_rejected are correctly filtered out

@@ -43,7 +43,7 @@ def test_constructor_with_model_kwargs():

 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
+def test_invoke_with_no_kwargs(mock_auto_tokenizer):
     """
     Test that invoke raises an error if no prompt is provided
     """

@@ -54,7 +54,7 @@ def test_invoke_with_no_kwargs():

 @pytest.mark.unit
-def test_invoke_with_stop_words():
+def test_invoke_with_stop_words(mock_auto_tokenizer):
     """
     Test stop words are correctly passed from PromptNode to wire in CohereInvocationLayer
     """
@@ -62,98 +62,135 @@ def test_invoke_with_stop_words():
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
     with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
         # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello"}]}')
         layer.invoke(prompt="Tell me hello", stop_words=stop_words)

     assert mock_post.called
-
-    # Check if stop_words are passed to _post as stop parameter
-    called_args, _ = mock_post.call_args
-    assert "end_sequences" in called_args[0]
-    assert called_args[0]["end_sequences"] == stop_words
+    called_args, _ = mock_post.call_args
+    assert "end_sequences" in called_args[0]
+    assert called_args[0]["end_sequences"] == stop_words


 @pytest.mark.unit
-@pytest.mark.parametrize("using_constructor", [True, False])
-@pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param(using_constructor, stream):
+def test_streaming_stream_param_from_init(mock_auto_tokenizer):
     """
-    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from init
     """
-    if using_constructor:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=stream)
-    else:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)
     with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
-        # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
-        if using_constructor:
-            layer.invoke(prompt="Tell me hello")
-        else:
-            layer.invoke(prompt="Tell me hello", stream=stream)
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        layer.invoke(prompt="Tell me hello")

     assert mock_post.called
-
-    # Check if stop_words are passed to _post as stop parameter
-    called_args, called_kwargs = mock_post.call_args
-
+    _, called_kwargs = mock_post.call_args
     # stream is always passed to _post
-    assert "stream" in called_kwargs
-    # Check if stream is True, then stream is passed as True to _post
-    if stream:
-        assert called_kwargs["stream"]
-    # Check if stream is False, then stream is passed as False to _post
-    else:
-        assert not called_kwargs["stream"]
+    assert "stream" in called_kwargs and called_kwargs["stream"]
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_init_no_stream(mock_auto_tokenizer):
+    """
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from init
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
+        layer.invoke(prompt="Tell me hello")
+
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+    # stream is always passed to _post
+    assert "stream" in called_kwargs
+    assert not bool(called_kwargs["stream"])
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_invoke(mock_auto_tokenizer):
+    """
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from invoke
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        layer.invoke(prompt="Tell me hello", stream=True)
+
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+    # stream is always passed to _post
+    assert "stream" in called_kwargs
+    assert bool(called_kwargs["stream"])
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_invoke_no_stream(mock_auto_tokenizer):
+    """
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from invoke
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
+        layer.invoke(prompt="Tell me hello", stream=False)
+
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+    # stream is always passed to _post
+    assert "stream" in called_kwargs
+    assert not bool(called_kwargs["stream"])


 @pytest.mark.unit
-@pytest.mark.parametrize("using_constructor", [True, False])
-@pytest.mark.parametrize("stream_handler", [DefaultTokenStreamingHandler(), None])
-def test_streaming_stream_handler_param(using_constructor, stream_handler):
+def test_streaming_stream_handler_param_from_init(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly from PromptNode passed to wire in CohereInvocationLayer
     """
-    if using_constructor:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
-    else:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
-    with unittest.mock.patch(
-        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post"
-    ) as mock_post, unittest.mock.patch(
-        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._process_streaming_response"
-    ) as mock_post_stream:
-        # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
-        if using_constructor:
-            layer.invoke(prompt="Tell me hello")
-        else:
-            layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
+    stream_handler = DefaultTokenStreamingHandler()
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        responses = layer.invoke(prompt="Tell me hello")

     assert mock_post.called
-
-    # Check if stop_words are passed to _post as stop parameter
-    called_args, called_kwargs = mock_post.call_args
-
-    # stream is always passed to _post
-    assert "stream" in called_kwargs
-
-    # if stream_handler is used then stream is always True
-    if stream_handler:
-        assert called_kwargs["stream"]
-        # and stream_handler is passed as an instance of TokenStreamingHandler
-        called_args, called_kwargs = mock_post_stream.call_args
-        assert "stream_handler" in called_kwargs
-        assert isinstance(called_kwargs["stream_handler"], TokenStreamingHandler)
-    # if stream_handler is not used then stream is always False
-    else:
-        assert not called_kwargs["stream"]
+    _, called_kwargs = mock_post.call_args
+    assert "stream" in called_kwargs
+    assert bool(called_kwargs["stream"])
+    assert responses == ["Hello there"]
+
+
+@pytest.mark.unit
+def test_streaming_stream_handler_param_from_invoke(mock_auto_tokenizer):
+    """
+    Test stream_handler parameter is correctly from PromptNode passed to wire in CohereInvocationLayer
+    """
+    stream_handler = DefaultTokenStreamingHandler()
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        responses = layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
+
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+    assert "stream" in called_kwargs
+    assert bool(called_kwargs["stream"])
+    assert responses == ["Hello there"]


 @pytest.mark.unit
@@ -181,13 +218,13 @@ def test_supports():

 @pytest.mark.unit
-def test_ensure_token_limit_fails_if_called_with_list():
+def test_ensure_token_limit_fails_if_called_with_list(mock_auto_tokenizer):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
     with pytest.raises(ValueError):
         layer._ensure_token_limit(prompt=[])


-@pytest.mark.unit
+@pytest.mark.integration
 def test_ensure_token_limit_with_small_max_length(caplog):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
     res = layer._ensure_token_limit(prompt="Short prompt")

@@ -200,7 +237,7 @@ def test_ensure_token_limit_with_small_max_length(caplog):
     assert not caplog.records


-@pytest.mark.unit
+@pytest.mark.integration
 def test_ensure_token_limit_with_huge_max_length(caplog):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
     res = layer._ensure_token_limit(prompt="Short prompt")

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch, Mock
 import pytest
 import torch
 from torch import device
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
+from transformers import AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig

 from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler, DefaultTokenStreamingHandler

@@ -13,8 +13,11 @@ from haystack.nodes.prompt.invocation_layer.hugging_face import StopWordsCriteria
 @pytest.fixture
 def mock_pipeline():
     # mock transformers pipeline
+    # model returning some mocked text for pipeline invocation
     with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline") as mocked_pipeline:
-        mocked_pipeline.return_value = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
+        pipeline_mock = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
+        pipeline_mock.side_effect = lambda *args, **kwargs: [{"generated_text": "some mocked text"}]
+        mocked_pipeline.return_value = pipeline_mock
         yield mocked_pipeline
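The pipeline fixture above relies on two unittest.mock features worth calling out: dotted keys in the Mock constructor configure nested attributes, and side_effect makes the mock callable like the real transformers pipeline. A small standalone illustration (not part of the diff itself):

from unittest.mock import Mock

# Dotted keys configure child mocks, so tokenizer.model_max_length resolves
# without building a real pipeline or downloading a model.
pipeline_mock = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
assert pipeline_mock.tokenizer.model_max_length == 100

# side_effect makes the mock return a generation-shaped payload when called.
pipeline_mock.side_effect = lambda *args, **kwargs: [{"generated_text": "some mocked text"}]
assert pipeline_mock("any prompt") == [{"generated_text": "some mocked text"}]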
@@ -44,7 +47,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     # device is set to cpu by default and device_map is empty
     assert kwargs["device"] == device("cpu")

@@ -87,7 +90,7 @@ def test_constructor_with_model_name_and_device_map(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     # device is NOT set; device_map is auto because device_map takes precedence over device
     assert not kwargs["device"]

@@ -110,7 +113,7 @@ def test_constructor_with_torch_dtype(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     assert kwargs["torch_dtype"] == torch.float16

@@ -126,7 +129,7 @@ def test_constructor_with_torch_dtype_as_str(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     assert kwargs["torch_dtype"] == torch.float16

@@ -142,7 +145,7 @@ def test_constructor_with_torch_dtype_auto(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     assert kwargs["torch_dtype"] == "auto"
@@ -206,9 +209,8 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     """
     Test that the constructor sets the pipeline with the pretrained model (if provided)
     """
-    # actual model and tokenizer passed to the pipeline
-    model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+    model = Mock()
+    tokenizer = Mock()

     HFLocalInvocationLayer(
         model_name_or_path="irrelevant_when_model_is_provided",

@@ -221,7 +223,7 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     # mock_get_task is not called as we provided task_name parameter
     mock_get_task.assert_not_called()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     # correct tokenizer and model are set as well
     assert kwargs["tokenizer"] == tokenizer

@@ -239,7 +241,7 @@ def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     # invalid kwargs are ignored and not passed to the pipeline
     assert "some_invalid_kwarg" not in kwargs

@@ -258,7 +260,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     HFLocalInvocationLayer(
         "google/flan-t5-base",
         task_name="text2text-generation",
-        tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
+        tokenizer=Mock(),
         config=Mock(),
         revision="1.1",
         device="cpu",

@@ -271,7 +273,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     # mock_get_task is not called as we provided task_name parameter
     mock_get_task.assert_not_called()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     # invalid kwargs are ignored and not passed to the pipeline
     assert "first_invalid_kwarg" not in kwargs

@@ -287,7 +289,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     assert len(kwargs) == 13


-@pytest.mark.unit
+@pytest.mark.integration
 def test_text_generation_model():
     # test simple prompting with text generation model
     # by default, we force the model not return prompt text

@@ -303,7 +305,7 @@ def test_text_generation_model():
     assert len(r[0]) > 0 and r[0].startswith("Hello big science!")


-@pytest.mark.unit
+@pytest.mark.integration
 def test_text_generation_model_via_custom_pretrained_model():
     tokenizer = AutoTokenizer.from_pretrained("bigscience/bigscience-small-testing")
     model = BloomForCausalLM.from_pretrained("bigscience/bigscience-small-testing")
@@ -320,31 +322,29 @@ def test_text_generation_model_via_custom_pretrained_model():

 @pytest.mark.unit
-def test_streaming_stream_param_in_constructor():
+def test_streaming_stream_param_in_constructor(mock_pipeline, mock_get_task):
     """
     Test stream parameter is correctly passed to pipeline invocation via HF streamer parameter
     """
     layer = HFLocalInvocationLayer(stream=True)
-    layer.pipe = MagicMock()

     layer.invoke(prompt="Tell me hello")

-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler)


 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_constructor():
+def test_streaming_stream_handler_param_in_constructor(mock_pipeline, mock_get_task):
     """
     Test stream parameter is correctly passed to pipeline invocation
     """
     dtsh = DefaultTokenStreamingHandler()
     layer = HFLocalInvocationLayer(stream_handler=dtsh)
-    layer.pipe = MagicMock()

     layer.invoke(prompt="Tell me hello")

-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "streamer" in kwargs
     hf_streamer = kwargs["streamer"]

@@ -394,18 +394,17 @@ def test_supports(tmp_path):

 @pytest.mark.unit
-def test_stop_words_criteria_set():
+def test_stop_words_criteria_set(mock_pipeline, mock_get_task):
     """
     Test that stop words criteria is correctly set in pipeline invocation
     """
     layer = HFLocalInvocationLayer(
         model_name_or_path="hf-internal-testing/tiny-random-t5", task_name="text2text-generation"
     )
-    layer.pipe = MagicMock()

     layer.invoke(prompt="Tell me hello", stop_words=["hello", "world"])

-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "stopping_criteria" in kwargs
     assert isinstance(kwargs["stopping_criteria"], StoppingCriteriaList)
     assert len(kwargs["stopping_criteria"]) == 1

@@ -468,7 +467,7 @@ def test_stop_words_not_being_found():
     assert word in result[0]


-@pytest.mark.unit
+@pytest.mark.integration
 def test_generation_kwargs_from_constructor():
     """
     Test that generation_kwargs are correctly passed to pipeline invocation from constructor

@@ -489,7 +488,7 @@ def test_generation_kwargs_from_constructor():
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})


-@pytest.mark.unit
+@pytest.mark.integration
 def test_generation_kwargs_from_invoke():
     """
     Test that generation_kwargs passed to invoke are passed to the underlying HF model

@@ -508,3 +507,35 @@ def test_generation_kwargs_from_invoke():
     layer.invoke(prompt=the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
+    # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
+    mock_prompt = "I am a tokenized prompt"
+
+    mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
+    mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=mock_prompt)
+    mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
+
+    layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
+    result = layer._ensure_token_limit(mock_prompt)
+
+    assert result == mock_prompt
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_negative_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
+    # prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
+    correct_result = "I am a tokenized prompt of length"
+
+    mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
+    mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=correct_result)
+    mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
+
+    layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
+    result = layer._ensure_token_limit("I am a tokenized prompt of length eight")
+
+    assert result == correct_result

View File

@@ -8,8 +8,23 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer


+@pytest.fixture
+def mock_get_task():
+    # mock get_task function
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
+        mock_get_task.return_value = "text2text-generation"
+        yield mock_get_task
+
+
+@pytest.fixture
+def mock_get_task_invalid():
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
+        mock_get_task.return_value = "some-nonexistent-type"
+        yield mock_get_task
+
+
 @pytest.mark.unit
-def test_default_constructor():
+def test_default_constructor(mock_auto_tokenizer):
     """
     Test that the default constructor sets the correct values
     """
@@ -22,7 +37,7 @@ def test_default_constructor():

 @pytest.mark.unit
-def test_constructor_with_model_kwargs():
+def test_constructor_with_model_kwargs(mock_auto_tokenizer):
     """
     Test that model_kwargs are correctly set in the constructor
     and that model_kwargs_rejected are correctly filtered out

@@ -41,7 +56,7 @@ def test_constructor_with_model_kwargs():

 @pytest.mark.unit
-def test_set_model_max_length():
+def test_set_model_max_length(mock_auto_tokenizer):
     """
     Test that model max length is set correctly
     """

@@ -52,7 +67,7 @@ def test_set_model_max_length():

 @pytest.mark.unit
-def test_url():
+def test_url(mock_auto_tokenizer):
     """
     Test that the url is correctly set in the constructor
     """

@@ -67,7 +82,7 @@ def test_url():

 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
+def test_invoke_with_no_kwargs(mock_auto_tokenizer):
     """
     Test that invoke raises an error if no prompt is provided
     """

@@ -78,7 +93,7 @@ def test_invoke_with_no_kwargs():

 @pytest.mark.unit
-def test_invoke_with_stop_words():
+def test_invoke_with_stop_words(mock_auto_tokenizer):
     """
     Test stop words are correctly passed to HTTP POST request
     """
@@ -95,14 +110,14 @@ def test_invoke_with_stop_words():
     assert mock_post.called

     # Check if stop_words are passed to _post as stop parameter
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
     assert "stop" in called_kwargs["data"]["parameters"]
     assert called_kwargs["data"]["parameters"]["stop"] == stop_words


 @pytest.mark.unit
 @pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param_in_constructor(stream):
+def test_streaming_stream_param_in_constructor(mock_auto_tokenizer, stream):
     """
     Test stream parameter is correctly passed to HTTP POST request via constructor
     """

@@ -117,7 +132,7 @@ def test_streaming_stream_param_in_constructor(stream):
     layer.invoke(prompt="Tell me hello")

     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args

     # stream is always passed to _post
     assert "stream" in called_kwargs

@@ -127,7 +142,7 @@ def test_streaming_stream_param_in_constructor(stream):

 @pytest.mark.unit
 @pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param_in_method(stream):
+def test_streaming_stream_param_in_method(mock_auto_tokenizer, stream):
     """
     Test stream parameter is correctly passed to HTTP POST request via method
     """

@@ -140,13 +155,7 @@ def test_streaming_stream_param_in_method(stream):
     layer.invoke(prompt="Tell me hello", stream=stream)

     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
-    # stream is always passed to _post
-    assert "stream" in called_kwargs
-
-    # Check if stop_words are passed to _post as stop parameter
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args

     # stream is always passed to _post
     assert "stream" in called_kwargs
@@ -156,7 +165,7 @@ def test_streaming_stream_param_in_method(stream):

 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_constructor():
+def test_streaming_stream_handler_param_in_constructor(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via constructor
     """

@@ -175,7 +184,7 @@ def test_streaming_stream_handler_param_in_constructor():
     layer.invoke(prompt="Tell me hello")

     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args

     # stream is always passed to _post
     assert "stream" in called_kwargs

@@ -183,12 +192,12 @@ def test_streaming_stream_handler_param_in_constructor():
     assert called_kwargs["stream"]

     # stream_handler is passed as an instance of TokenStreamingHandler
-    called_args, called_kwargs = mock_post_stream.call_args
+    called_args, _ = mock_post_stream.call_args
     assert isinstance(called_args[1], TokenStreamingHandler)


 @pytest.mark.unit
-def test_streaming_no_stream_handler_param_in_constructor():
+def test_streaming_no_stream_handler_param_in_constructor(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via constructor
     """

@@ -202,7 +211,7 @@ def test_streaming_no_stream_handler_param_in_constructor():
     layer.invoke(prompt="Tell me hello")

     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args

     # stream is always passed to _post
     assert "stream" in called_kwargs

@@ -212,7 +221,7 @@ def test_streaming_no_stream_handler_param_in_constructor():

 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_method():
+def test_streaming_stream_handler_param_in_method(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via method
     """

@@ -241,7 +250,7 @@ def test_streaming_stream_handler_param_in_method():

 @pytest.mark.unit
-def test_streaming_no_stream_handler_param_in_method():
+def test_streaming_no_stream_handler_param_in_method(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via method
     """

@@ -257,7 +266,7 @@ def test_streaming_no_stream_handler_param_in_method():
     assert mock_post.called

-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args

     # stream is always correctly passed to _post
     assert "stream" in called_kwargs
@@ -304,7 +313,7 @@ def test_ensure_token_limit_resize(caplog, model_name_or_path):

 @pytest.mark.unit
-def test_oasst_prompt_preprocessing():
+def test_oasst_prompt_preprocessing(mock_auto_tokenizer):
     model_name = "OpenAssistant/oasst-sft-1-pythia-12b"
     layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)

@@ -318,7 +327,7 @@ def test_oasst_prompt_preprocessing():
     assert result == ["Hello"]
     assert mock_post.called

-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
     # OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in tokens below
     assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"

@@ -326,22 +335,20 @@ def test_oasst_prompt_preprocessing():

 @pytest.mark.unit
 def test_invalid_key():
     with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
-        layer = HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")
+        HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")


 @pytest.mark.unit
 def test_invalid_model():
     with pytest.raises(ValueError, match="cannot be None or empty string"):
-        layer = HFInferenceEndpointInvocationLayer("fake_api", "")
+        HFInferenceEndpointInvocationLayer("fake_api", "")


 @pytest.mark.unit
-def test_supports():
+def test_supports(mock_get_task):
     """
     Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
     """
-    # doesn't support fake model
-    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
     # supports google/flan-t5-xxl with api_key
     assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")

@@ -353,3 +360,8 @@ def test_supports():
     assert HFInferenceEndpointInvocationLayer.supports(
         "https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
     )
+
+
+@pytest.mark.unit
+def test_supports_not(mock_get_task_invalid):
+    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")

View File

@@ -9,6 +9,58 @@ from haystack.nodes.prompt.invocation_layer.handlers import (
 )


+@pytest.mark.unit
+def test_prompt_handler_positive():
+    # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
+    mock_prompt = "I am a tokenized prompt"
+
+    with patch(
+        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
+    ) as mock_tokenizer:
+        tokenizer_instance = mock_tokenizer.return_value
+        tokenizer_instance.tokenize.return_value = mock_tokens
+        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
+
+        prompt_handler = DefaultPromptHandler("model_path", 10, 3)
+
+        # Test with a prompt that does not exceed model_max_length when tokenized
+        result = prompt_handler(mock_prompt)
+        assert result == {
+            "resized_prompt": mock_prompt,
+            "prompt_length": 5,
+            "new_prompt_length": 5,
+            "model_max_length": 10,
+            "max_length": 3,
+        }
+
+
+@pytest.mark.unit
+def test_prompt_handler_negative():
+    # prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
+    mock_prompt = "I am a tokenized prompt of length"
+
+    with patch(
+        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
+    ) as mock_tokenizer:
+        tokenizer_instance = mock_tokenizer.return_value
+        tokenizer_instance.tokenize.return_value = mock_tokens
+        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
+
+        prompt_handler = DefaultPromptHandler("model_path", 10, 3)
+
+        result = prompt_handler(mock_prompt)
+        assert result == {
+            "resized_prompt": mock_prompt,
+            "prompt_length": 8,
+            "new_prompt_length": 7,
+            "model_max_length": 10,
+            "max_length": 3,
+        }
+
+
 @pytest.mark.integration
 def test_prompt_handler_basics():
     handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
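The two mocked tests added above pin down the resize rule: tokenize the prompt, and if the token count plus the generation budget (max_length) would exceed model_max_length, cut the prompt back to model_max_length - max_length tokens. A minimal sketch of that contract (an illustration of what the tests assert, not the actual DefaultPromptHandler code):

def resize_prompt(tokenizer, prompt: str, model_max_length: int, max_length: int) -> dict:
    # Keep at most model_max_length - max_length prompt tokens so the model
    # still has room to generate max_length new tokens.
    tokens = tokenizer.tokenize(prompt)
    allowed_prompt_tokens = model_max_length - max_length
    resized_prompt = prompt
    if len(tokens) > allowed_prompt_tokens:
        resized_prompt = tokenizer.convert_tokens_to_string(tokens[:allowed_prompt_tokens])
    return {
        "resized_prompt": resized_prompt,
        "prompt_length": len(tokens),
        "new_prompt_length": min(len(tokens), allowed_prompt_tokens),
        "model_max_length": model_max_length,
        "max_length": max_length,
    }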