diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py
index 346e47ca0..52b30ea8e 100644
--- a/test/agents/test_agent.py
+++ b/test/agents/test_agent.py
@@ -2,8 +2,12 @@ import logging
 import os
 import re
 from typing import Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch
+
+from events import Events
+
+from haystack.agents.types import AgentTokenStreamingHandler
 from test.conftest import MockRetriever, MockPromptNode
 from unittest import mock
 import pytest
@@ -290,17 +294,6 @@ def test_tool_processes_answer_result_and_document_result():
     assert tool._process_result(Document(content="content")) == "content"
 
 
-def test_invalid_agent_template():
-    pn = PromptNode()
-    with pytest.raises(ValueError, match="some_non_existing_template not supported"):
-        Agent(prompt_node=pn, prompt_template="some_non_existing_template")
-
-    # if prompt_template is None, then we'll use zero-shot-react
-    a = Agent(prompt_node=pn, prompt_template=None)
-    assert isinstance(a.prompt_template, PromptTemplate)
-    assert a.prompt_template.name == "zero-shot-react"
-
-
 @pytest.mark.unit
 @patch.object(PromptNode, "prompt")
 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
@@ -315,3 +308,25 @@ def test_default_template_order(mock_model, mock_prompt):
 
     a = Agent(prompt_node=pn, prompt_template="translation")
     assert a.prompt_template.name == "translation"
+
+
+@pytest.mark.unit
+def test_agent_with_unknown_prompt_template():
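+    # a mocked PromptNode whose get_prompt_template returns None, i.e. the template cannot be resolved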
+    prompt_node = Mock()
+    prompt_node.get_prompt_template.return_value = None
+    with pytest.raises(ValueError, match="Prompt template 'invalid' not found"):
+        Agent(prompt_node=prompt_node, prompt_template="invalid")
+
+
+@pytest.mark.unit
+def test_agent_token_streaming_handler():
+    e = Events("on_new_token")
+
+    mock_callback = Mock()
+    e.on_new_token += mock_callback  # register the mock callback to the event
+
+    handler = AgentTokenStreamingHandler(events=e)
+    result = handler("test")
+
+    assert result == "test"
+    mock_callback.assert_called_once_with("test")  # assert that the mock callback was called with "test"
diff --git a/test/agents/test_conversational_agent.py b/test/agents/test_conversational_agent.py
index 7180bcb4a..89de9f38e 100644
--- a/test/agents/test_conversational_agent.py
+++ b/test/agents/test_conversational_agent.py
@@ -1,17 +1,15 @@
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, Mock
 
 from haystack.agents.conversational import ConversationalAgent
 from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
-from haystack.nodes import PromptNode
+from test.conftest import MockPromptNode
 
 
 @pytest.mark.unit
 def test_init():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        agent = ConversationalAgent(prompt_node)
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node)
 
     # Test normal case
     assert isinstance(agent.memory, ConversationMemory)
@@ -25,32 +23,27 @@ def test_init():
 
 @pytest.mark.unit
 def test_init_with_summary_memory():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode(default_prompt_template="this is a test")
-        # Test with summary memory
-        agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
-        assert isinstance(agent.memory, ConversationSummaryMemory)
+    # Test with summary memory
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
+    assert isinstance(agent.memory, ConversationSummaryMemory)
 
 
 @pytest.mark.unit
 def test_init_with_no_memory():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        # Test with no memory
-        agent = ConversationalAgent(prompt_node, memory=NoMemory())
-        assert isinstance(agent.memory, NoMemory)
+    prompt_node = MockPromptNode()
+    # Test with no memory
+    agent = ConversationalAgent(prompt_node, memory=NoMemory())
+    assert isinstance(agent.memory, NoMemory)
 
 
 @pytest.mark.unit
 def test_run():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        agent = ConversationalAgent(prompt_node)
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node)
 
-    # Mock the Agent run method
-    agent.run = MagicMock(return_value="Hello")
-    assert agent.run("query") == "Hello"
-    agent.run.assert_called_once_with("query")
+    result = agent.run("query")
+
+    # the mocked PromptNode produces no completion, so the agent returns an empty answer
+    assert result["answers"][0].answer == ""
diff --git a/test/agents/test_tools_manager.py b/test/agents/test_tools_manager.py
index f726a064c..afc522ae6 100644
--- a/test/agents/test_tools_manager.py
+++ b/test/agents/test_tools_manager.py
@@ -1,9 +1,10 @@
 import unittest
+from typing import Optional, Union, List, Dict, Any
 from unittest import mock
 
 import pytest
 
-from haystack import Pipeline, Answer, Document
+from haystack import Pipeline, Answer, Document, BaseComponent, MultiLabel
 from haystack.agents.base import ToolsManager, Tool
 
 
@@ -160,3 +161,57 @@ def test_extract_tool_name_and_empty_tool_input(tools_manager):
     for example in examples:
         tool_name, tool_input = tools_manager.extract_tool_name_and_tool_input(example)
         assert tool_name == "Search" and tool_input == ""
+
+
+@pytest.mark.unit
+def test_node_as_tool():
+    # test that a component can be used as a tool
+    class ToolComponent(BaseComponent):
+        outgoing_edges = 1
+
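+        # run_batch is abstract in BaseComponent, so a no-op stub is needed to make the class instantiable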
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            pass
+
+        def run(self, **kwargs):
+            return "mocked_output"
+
+    tool = Tool(name="ToolA", pipeline_or_node=ToolComponent(), description="Tool A Description")
+    assert tool.run("input") == "mocked_output"
+
+
+@pytest.mark.unit
+def test_tools_manager_exception():
+    # test that an exception raised inside a tool propagates through the ToolsManager
+    class ToolComponent(BaseComponent):
+        outgoing_edges = 1
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            pass
+
+        def run(self, **kwargs):
+            raise Exception("mocked_exception")
+
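+    # an LLM response formatted the way the agent emits tool invocations, routing to the Search tool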
+    fake_llm_response = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy born"
+    tool = Tool(name="Search", pipeline_or_node=ToolComponent(), description="Search")
+    tools_manager = ToolsManager(tools=[tool])
+
+    with pytest.raises(Exception):
+        tools_manager.run_tool(llm_response=fake_llm_response)
diff --git a/test/conftest.py b/test/conftest.py
index 48a15642b..fc7a8c721 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -8,6 +8,7 @@ from pathlib import Path
 import os
 import re
 from functools import wraps
+from unittest.mock import patch
 
 import requests_cache
 import responses
@@ -847,5 +848,14 @@ def request_blocker(request: pytest.FixtureRequest, monkeypatch):
     marker = request.node.get_closest_marker("unit")
     if marker is None:
         return
-    monkeypatch.delattr("requests.sessions.Session")
-    monkeypatch.delattr("requests_cache.session.CachedSession")
+
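+    # block traffic at the urllib3 transport level so any HTTP client built on it is caught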
+    def urlopen_mock(self, method, url, *args, **kwargs):
+        raise RuntimeError(f"The test was about to {method} {self.scheme}://{self.host}{url}")
+
+    monkeypatch.setattr("urllib3.connectionpool.HTTPConnectionPool.urlopen", urlopen_mock)
+
+
+@pytest.fixture
+def mock_auto_tokenizer():
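+    # patch AutoTokenizer.from_pretrained so unit tests do not download real tokenizer files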
+    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
+        yield mock_from_pretrained
diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py
index 5c19fe2c8..54b6c4c0b 100644
--- a/test/nodes/test_generator.py
+++ b/test/nodes/test_generator.py
@@ -15,7 +15,7 @@ from ..conftest import fail_at_version
 
 @pytest.mark.unit
 @fail_at_version(1, 18)
-def test_seq2seq_deprecation():
+def test_seq2seq_deprecation(mock_auto_tokenizer):
     with pytest.warns(DeprecationWarning):
         try:
             Seq2SeqGenerator("non_existing_model/model")
diff --git a/test/prompt/invocation_layer/test_anthropic_claude.py b/test/prompt/invocation_layer/test_anthropic_claude.py
index 6823ba957..2dc4682e8 100644
--- a/test/prompt/invocation_layer/test_anthropic_claude.py
+++ b/test/prompt/invocation_layer/test_anthropic_claude.py
@@ -8,10 +8,21 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamin
 from haystack.nodes.prompt.invocation_layer import AnthropicClaudeInvocationLayer
 
 
+@pytest.fixture
+def mock_claude_tokenizer():
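+    # patch the Anthropic tokenizer so the layer can be constructed without its tokenizer files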
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer", autospec=True) as mock_tokenizer:
+        yield mock_tokenizer
+
+
+@pytest.fixture
+def mock_claude_request():
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
+        yield mock_request
+
+
 @pytest.mark.unit
-def test_default_costuctor():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_default_constructor(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
 
     assert layer.api_key == "some_fake_key"
     assert layer.max_length == 200
@@ -20,7 +31,7 @@
 
 
 @pytest.mark.unit
-def test_ignored_kwargs_are_filtered_in_init():
+def test_ignored_kwargs_are_filtered_in_init(mock_claude_tokenizer, mock_claude_request):
     kwargs = {
         "temperature": 1,
         "top_p": 5,
@@ -30,8 +41,8 @@
         "stream_handler": DefaultTokenStreamingHandler(),
         "unkwnown_args": "this will be filtered out",
     }
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)
+
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)
 
     # Verify unexpected kwargs are filtered out
     assert len(layer.model_input_kwargs) == 6
@@ -45,9 +56,8 @@
 
 
 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_no_kwargs(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
 
     with pytest.raises(ValueError) as e:
         layer.invoke()
@@ -55,14 +65,12 @@
 
 
 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry")
-def test_invoke_with_prompt_only(mock_request):
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_prompt_only(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
 
     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    mock_request.return_value = mock_response
+    mock_claude_request.return_value = mock_response
 
     res = layer.invoke(prompt="Some prompt")
     assert len(res) == 1
@@ -70,14 +78,13 @@
 
 
 @pytest.mark.unit
-def test_invoke_with_kwargs():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_kwargs(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
 
     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=["stop", "here"])
     assert len(res) == 1
     assert res[0] == "some_result"
@@ -92,19 +99,18 @@
         "stream": False,
         "stop_sequences": ["stop", "here", "\n\nHuman: "],
     }
-    mock_request.assert_called_once()
-    assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
+    mock_invocation_request.assert_called_once()
+    assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)
 
 
 @pytest.mark.unit
-def test_invoke_with_none_stop_words():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_none_stop_words(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
 
     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=None)
     assert len(res) == 1
     assert res[0] == "some_result"
@@ -119,14 +125,13 @@
         "stream": False,
         "stop_sequences": ["\n\nHuman: "],
     }
-    mock_request.assert_called_once()
-    assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
+    mock_invocation_request.assert_called_once()
+    assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)
 
 
 @pytest.mark.unit
-def test_invoke_with_stream():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_stream(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
 
     # Create a fake streamed response
     def mock_iter(self):
@@ -141,8 +146,8 @@
     mock_response = Mock(**{"__iter__": mock_iter})
 
     # Verifies expected result is returned
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", stream=True)
 
     assert len(res) == 1
@@ -150,14 +155,13 @@
 
 
 @pytest.mark.unit
-def test_invoke_with_custom_stream_handler():
+def test_invoke_with_custom_stream_handler(mock_claude_tokenizer, mock_claude_request):
     # Create a mock stream handler that always return the same token when called
     mock_stream_handler = Mock()
     mock_stream_handler.return_value = "token"
 
     # Create a layer with a mocked stream handler
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)
 
     # Create a fake streamed response
    def mock_iter(self):
@@ -171,8 +175,8 @@
 
     mock_response = Mock(**{"__iter__": mock_iter})
 
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt")
 
     assert len(res) == 1
@@ -186,7 +190,7 @@
 
 
 @pytest.mark.unit
-def test_ensure_token_limit_fails_if_called_with_list():
+def test_ensure_token_limit_fails_if_called_with_list(mock_claude_tokenizer, mock_claude_request):
     layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
     with pytest.raises(ValueError):
         layer._ensure_token_limit(prompt=[])
@@ -225,7 +229,7 @@ def test_ensure_token_limit_with_huge_max_length(caplog):
 
 
 @pytest.mark.unit
-def test_supports():
+def test_supports(mock_claude_tokenizer, mock_claude_request):
     layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
 
     assert not layer.supports("claude")
diff --git a/test/prompt/invocation_layer/test_cohere.py b/test/prompt/invocation_layer/test_cohere.py
index 3beb1176d..43834df5e 100644
--- a/test/prompt/invocation_layer/test_cohere.py
+++ b/test/prompt/invocation_layer/test_cohere.py
@@ -1,14 +1,14 @@
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import Mock
 
 import pytest
 
-from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer
 
 
 @pytest.mark.unit
-def test_default_constructor():
+def test_default_constructor(mock_auto_tokenizer):
     """
     Test that the default constructor sets the correct values
     """
@@ -28,7 +28,7 @@ def test_default_constructor():
 
 
 @pytest.mark.unit
-def test_constructor_with_model_kwargs():
+def test_constructor_with_model_kwargs(mock_auto_tokenizer):
     """
     Test that model_kwargs are correctly set in the constructor
     and that model_kwargs_rejected are correctly filtered out
@@ -43,7 +43,7 @@ def test_constructor_with_model_kwargs():
 
 
 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
+def test_invoke_with_no_kwargs(mock_auto_tokenizer):
     """
     Test that invoke raises an error if no prompt is provided
     """
@@ -54,7 +54,7 @@ def test_invoke_with_no_kwargs():
 
 
 @pytest.mark.unit
-def test_invoke_with_stop_words():
+def test_invoke_with_stop_words(mock_auto_tokenizer):
     """
     Test stop words are correctly passed from PromptNode to wire in CohereInvocationLayer
     """
@@ -62,98 +62,135 @@ def test_invoke_with_stop_words():
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
     with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
         # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello"}]}')
         layer.invoke(prompt="Tell me hello", stop_words=stop_words)
-        assert mock_post.called
-
-        # Check if stop_words are passed to _post as stop parameter
-        called_args, _ = mock_post.call_args
-        assert "end_sequences" in called_args[0]
-        assert called_args[0]["end_sequences"] == stop_words
+    assert mock_post.called
+    called_args, _ = mock_post.call_args
+    assert "end_sequences" in called_args[0]
+    assert called_args[0]["end_sequences"] == stop_words
 
 
 @pytest.mark.unit
-@pytest.mark.parametrize("using_constructor", [True, False])
-@pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param(using_constructor, stream):
+def test_streaming_stream_param_from_init(mock_auto_tokenizer):
     """
-    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer
+    Test that a stream parameter set in init is correctly passed on to the request in CohereInvocationLayer
     """
-    if using_constructor:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=stream)
-    else:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)
 
     with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
-        # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
+        # Mock the response
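+        # streamed responses are consumed via iter_lines, one JSON payload per line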
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        layer.invoke(prompt="Tell me hello")
 
-        if using_constructor:
-            layer.invoke(prompt="Tell me hello")
-        else:
-            layer.invoke(prompt="Tell me hello", stream=stream)
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
 
-        assert mock_post.called
-
-        # Check if stop_words are passed to _post as stop parameter
-        called_args, called_kwargs = mock_post.call_args
-
-        # stream is always passed to _post
-        assert "stream" in called_kwargs
-
-        # Check if stream is True, then stream is passed as True to _post
-        if stream:
-            assert called_kwargs["stream"]
-        # Check if stream is False, then stream is passed as False to _post
-        else:
-            assert not called_kwargs["stream"]
+    # stream is always passed to _post
+    assert "stream" in called_kwargs and called_kwargs["stream"]
 
 
 @pytest.mark.unit
-@pytest.mark.parametrize("using_constructor", [True, False])
-@pytest.mark.parametrize("stream_handler", [DefaultTokenStreamingHandler(), None])
-def test_streaming_stream_handler_param(using_constructor, stream_handler):
+def test_streaming_stream_param_from_init_no_stream(mock_auto_tokenizer):
+    """
+    Test that stream defaults to False when it is not set in init or invoke
+    """
+
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
+        layer.invoke(prompt="Tell me hello")
+
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+
+    # stream is always passed to _post
+    assert "stream" in called_kwargs
+    assert not bool(called_kwargs["stream"])
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_invoke(mock_auto_tokenizer):
+    """
+    Test that a stream parameter passed to invoke is correctly passed on to the request in CohereInvocationLayer
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        layer.invoke(prompt="Tell me hello", stream=True)
+
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+
+    # stream is always passed to _post
+    assert "stream" in called_kwargs
+    assert bool(called_kwargs["stream"])
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_invoke_no_stream(mock_auto_tokenizer):
+    """
+    Test that stream=False passed to invoke overrides stream=True set in init
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
+        layer.invoke(prompt="Tell me hello", stream=False)
+
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+
+    # stream is always passed to _post
+    assert "stream" in called_kwargs
+    assert not bool(called_kwargs["stream"])
+
+
+@pytest.mark.unit
+def test_streaming_stream_handler_param_from_init(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly from PromptNode passed to wire in CohereInvocationLayer
     """
-    if using_constructor:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
-    else:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+    stream_handler = DefaultTokenStreamingHandler()
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
 
-    with unittest.mock.patch(
-        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post"
-    ) as mock_post, unittest.mock.patch(
-        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._process_streaming_response"
-    ) as mock_post_stream:
-        # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        responses = layer.invoke(prompt="Tell me hello")
 
-        if using_constructor:
-            layer.invoke(prompt="Tell me hello")
-        else:
-            layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+    assert "stream" in called_kwargs
+    assert bool(called_kwargs["stream"])
+    assert responses == ["Hello there"]
 
-        assert mock_post.called
 
-        # Check if stop_words are passed to _post as stop parameter
-        called_args, called_kwargs = mock_post.call_args
+@pytest.mark.unit
+def test_streaming_stream_handler_param_from_invoke(mock_auto_tokenizer):
+    """
+    Test that a stream_handler passed to invoke turns on streaming for the request in CohereInvocationLayer
+    """
+    stream_handler = DefaultTokenStreamingHandler()
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
 
-        # stream is always passed to _post
-        assert "stream" in called_kwargs
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        responses = layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
 
-        # if stream_handler is used then stream is always True
-        if stream_handler:
-            assert called_kwargs["stream"]
-            # and stream_handler is passed as an instance of TokenStreamingHandler
-            called_args, called_kwargs = mock_post_stream.call_args
-            assert "stream_handler" in called_kwargs
-            assert isinstance(called_kwargs["stream_handler"], TokenStreamingHandler)
-        # if stream_handler is not used then stream is always False
-        else:
-            assert not called_kwargs["stream"]
+    assert mock_post.called
+    _, called_kwargs = mock_post.call_args
+    assert "stream" in called_kwargs
+    assert bool(called_kwargs["stream"])
+    assert responses == ["Hello there"]
 
 
 @pytest.mark.unit
@@ -181,13 +218,13 @@ def test_supports():
 
 
 @pytest.mark.unit
-def test_ensure_token_limit_fails_if_called_with_list():
+def test_ensure_token_limit_fails_if_called_with_list(mock_auto_tokenizer):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
     with pytest.raises(ValueError):
         layer._ensure_token_limit(prompt=[])
 
 
-@pytest.mark.unit
+@pytest.mark.integration
 def test_ensure_token_limit_with_small_max_length(caplog):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
     res = layer._ensure_token_limit(prompt="Short prompt")
@@ -200,7 +237,7 @@ def test_ensure_token_limit_with_small_max_length(caplog):
     assert not caplog.records
 
 
-@pytest.mark.unit
+@pytest.mark.integration
 def test_ensure_token_limit_with_huge_max_length(caplog):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
     res = layer._ensure_token_limit(prompt="Short prompt")
diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py
index d6dc2a3f4..5c30dd6b9 100644
--- a/test/prompt/invocation_layer/test_hugging_face.py
+++ b/test/prompt/invocation_layer/test_hugging_face.py
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch, Mock
 import pytest
 import torch
 from torch import device
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
+from transformers import AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
 
 from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler, DefaultTokenStreamingHandler
@@ -13,8 +13,11 @@ from haystack.nodes.prompt.invocation_layer.hugging_face import StopWordsCriteri
 @pytest.fixture
 def mock_pipeline():
     # mock transformers pipeline
+    # the mocked pipeline model returns fixed generated text when invoked
     with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline") as mocked_pipeline:
-        mocked_pipeline.return_value = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
+        pipeline_mock = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
+        pipeline_mock.side_effect = lambda *args, **kwargs: [{"generated_text": "some mocked text"}]
+        mocked_pipeline.return_value = pipeline_mock
         yield mocked_pipeline
 
 
@@ -44,7 +47,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
 
     mock_pipeline.assert_called_once()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     # device is set to cpu by default and device_map is empty
     assert kwargs["device"] == device("cpu")
@@ -87,7 +90,7 @@ def test_constructor_with_model_name_and_device_map(mock_pipeline, mock_get_task
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     # device is NOT set; device_map is auto because device_map takes precedence over device
     assert not kwargs["device"]
@@ -110,7 +113,7 @@ def test_constructor_with_torch_dtype(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     assert kwargs["torch_dtype"] == torch.float16
 
@@ -126,7 +129,7 @@ def test_constructor_with_torch_dtype_as_str(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     assert kwargs["torch_dtype"] == torch.float16
 
@@ -142,7 +145,7 @@ def test_constructor_with_torch_dtype_auto(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     assert kwargs["torch_dtype"] == "auto"
@@ -206,9 +209,8 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     """
     Test that the constructor sets the pipeline with the pretrained model (if provided)
     """
-    # actual model and tokenizer passed to the pipeline
-    model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+    model = Mock()
+    tokenizer = Mock()
 
     HFLocalInvocationLayer(
         model_name_or_path="irrelevant_when_model_is_provided",
@@ -221,7 +223,7 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     # mock_get_task is not called as we provided task_name parameter
     mock_get_task.assert_not_called()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     # correct tokenizer and model are set as well
     assert kwargs["tokenizer"] == tokenizer
@@ -239,7 +241,7 @@ def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     # invalid kwargs are ignored and not passed to the pipeline
     assert "some_invalid_kwarg" not in kwargs
@@ -258,7 +260,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     HFLocalInvocationLayer(
         "google/flan-t5-base",
         task_name="text2text-generation",
-        tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
+        tokenizer=Mock(),
         config=Mock(),
         revision="1.1",
         device="cpu",
@@ -271,7 +273,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     # mock_get_task is not called as we provided task_name parameter
     mock_get_task.assert_not_called()
 
-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
 
     # invalid kwargs are ignored and not passed to the pipeline
     assert "first_invalid_kwarg" not in kwargs
@@ -287,7 +289,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     assert len(kwargs) == 13
 
 
-@pytest.mark.unit
+@pytest.mark.integration
 def test_text_generation_model():
     # test simple prompting with text generation model
     # by default, we force the model not return prompt text
@@ -303,7 +305,7 @@ def test_text_generation_model():
     assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
 
 
-@pytest.mark.unit
+@pytest.mark.integration
 def test_text_generation_model_via_custom_pretrained_model():
     tokenizer = AutoTokenizer.from_pretrained("bigscience/bigscience-small-testing")
     model = BloomForCausalLM.from_pretrained("bigscience/bigscience-small-testing")
@@ -320,31 +322,29 @@ def test_text_generation_model_via_custom_pretrained_model():
 
 
 @pytest.mark.unit
-def test_streaming_stream_param_in_constructor():
+def test_streaming_stream_param_in_constructor(mock_pipeline, mock_get_task):
     """
     Test stream parameter is correctly passed to pipeline invocation via HF streamer parameter
     """
     layer = HFLocalInvocationLayer(stream=True)
-    layer.pipe = MagicMock()
 
     layer.invoke(prompt="Tell me hello")
 
-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler)
 
 
 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_constructor():
+def test_streaming_stream_handler_param_in_constructor(mock_pipeline, mock_get_task):
     """
     Test stream parameter is correctly passed to pipeline invocation
     """
     dtsh = DefaultTokenStreamingHandler()
     layer = HFLocalInvocationLayer(stream_handler=dtsh)
-    layer.pipe = MagicMock()
 
     layer.invoke(prompt="Tell me hello")
 
-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
 
     assert "streamer" in kwargs
     hf_streamer = kwargs["streamer"]
@@ -394,18 +394,17 @@ def test_supports(tmp_path):
 
 
 @pytest.mark.unit
-def test_stop_words_criteria_set():
+def test_stop_words_criteria_set(mock_pipeline, mock_get_task):
     """
     Test that stop words criteria is correctly set in pipeline invocation
     """
     layer = HFLocalInvocationLayer(
         model_name_or_path="hf-internal-testing/tiny-random-t5", task_name="text2text-generation"
     )
-    layer.pipe = MagicMock()
 
     layer.invoke(prompt="Tell me hello", stop_words=["hello", "world"])
 
-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "stopping_criteria" in kwargs
     assert isinstance(kwargs["stopping_criteria"], StoppingCriteriaList)
     assert len(kwargs["stopping_criteria"]) == 1
@@ -468,7 +467,7 @@ def test_stop_words_not_being_found():
         assert word in result[0]
 
 
-@pytest.mark.unit
+@pytest.mark.integration
 def test_generation_kwargs_from_constructor():
     """
     Test that generation_kwargs are correctly passed to pipeline invocation from constructor
     """
@@ -489,7 +488,7 @@ def test_generation_kwargs_from_constructor():
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
 
 
-@pytest.mark.unit
+@pytest.mark.integration
 def test_generation_kwargs_from_invoke():
     """
     Test that generation_kwargs passed to invoke are passed to the underlying HF model
     """
@@ -508,3 +507,35 @@ def test_generation_kwargs_from_invoke():
         layer.invoke(prompt=the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
 
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
+    # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
+    mock_prompt = "I am a tokenized prompt"
+
+    mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
+    mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=mock_prompt)
+    mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
+
+    layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
+    result = layer._ensure_token_limit(mock_prompt)
+
+    assert result == mock_prompt
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_negative_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
+    # prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
+    correct_result = "I am a tokenized prompt of length"
+
+    mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
+    mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=correct_result)
+    mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
+
+    layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
+    result = layer._ensure_token_limit("I am a tokenized prompt of length eight")
+
+    assert result == correct_result
diff --git a/test/prompt/invocation_layer/test_hugging_face_inference.py b/test/prompt/invocation_layer/test_hugging_face_inference.py
index fa1b019c1..57e610a6f 100644
--- a/test/prompt/invocation_layer/test_hugging_face_inference.py
+++ b/test/prompt/invocation_layer/test_hugging_face_inference.py
@@ -8,8 +8,23 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamin
 from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer
 
 
+@pytest.fixture
+def mock_get_task():
+    # mock get_task so that no request to the Hugging Face Hub is made
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
+        mock_get_task.return_value = "text2text-generation"
+        yield mock_get_task
+
+
+@pytest.fixture
+def mock_get_task_invalid():
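+    # get_task reporting a task type that the invocation layer does not support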
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
+        mock_get_task.return_value = "some-nonexistent-type"
+        yield mock_get_task
+
+
 @pytest.mark.unit
-def test_default_constructor():
+def test_default_constructor(mock_auto_tokenizer):
     """
     Test that the default constructor sets the correct values
     """
@@ -22,7 +37,7 @@ def test_default_constructor():
 
 
 @pytest.mark.unit
-def test_constructor_with_model_kwargs():
+def test_constructor_with_model_kwargs(mock_auto_tokenizer):
     """
     Test that model_kwargs are correctly set in the constructor
     and that model_kwargs_rejected are correctly filtered out
@@ -41,7 +56,7 @@ def test_constructor_with_model_kwargs():
 
 
 @pytest.mark.unit
-def test_set_model_max_length():
+def test_set_model_max_length(mock_auto_tokenizer):
     """
     Test that model max length is set correctly
     """
@@ -52,7 +67,7 @@ def test_set_model_max_length():
 
 
 @pytest.mark.unit
-def test_url():
+def test_url(mock_auto_tokenizer):
     """
     Test that the url is correctly set in the constructor
     """
@@ -67,7 +82,7 @@ def test_url():
 
 
 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
+def test_invoke_with_no_kwargs(mock_auto_tokenizer):
     """
     Test that invoke raises an error if no prompt is provided
     """
@@ -78,7 +93,7 @@ def test_invoke_with_no_kwargs():
 
 
 @pytest.mark.unit
-def test_invoke_with_stop_words():
+def test_invoke_with_stop_words(mock_auto_tokenizer):
     """
     Test stop words are correctly passed to HTTP POST request
     """
@@ -95,14 +110,14 @@ def test_invoke_with_stop_words():
     assert mock_post.called
 
     # Check if stop_words are passed to _post as stop parameter
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
     assert "stop" in called_kwargs["data"]["parameters"]
     assert called_kwargs["data"]["parameters"]["stop"] == stop_words
 
 
 @pytest.mark.unit
 @pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param_in_constructor(stream):
+def test_streaming_stream_param_in_constructor(mock_auto_tokenizer, stream):
     """
     Test stream parameter is correctly passed to HTTP POST request via constructor
     """
@@ -117,7 +132,7 @@ def test_streaming_stream_param_in_constructor(stream):
     layer.invoke(prompt="Tell me hello")
 
     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
 
     # stream is always passed to _post
     assert "stream" in called_kwargs
@@ -127,7 +142,7 @@
 
 @pytest.mark.unit
 @pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param_in_method(stream):
+def test_streaming_stream_param_in_method(mock_auto_tokenizer, stream):
     """
     Test stream parameter is correctly passed to HTTP POST request via method
     """
@@ -140,13 +155,7 @@ def test_streaming_stream_param_in_method(stream):
         layer.invoke(prompt="Tell me hello", stream=stream)
 
     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
-
-    # stream is always passed to _post
-    assert "stream" in called_kwargs
-
-    # Check if stop_words are passed to _post as stop parameter
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
 
     # stream is always passed to _post
     assert "stream" in called_kwargs
@@ -156,7 +165,7 @@
 
 
 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_constructor():
+def test_streaming_stream_handler_param_in_constructor(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via constructor
     """
@@ -175,7 +184,7 @@
     layer.invoke(prompt="Tell me hello")
 
     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
 
     # stream is always passed to _post
     assert "stream" in called_kwargs
@@ -183,12 +192,12 @@
     assert called_kwargs["stream"]
 
     # stream_handler is passed as an instance of TokenStreamingHandler
-    called_args, called_kwargs = mock_post_stream.call_args
+    called_args, _ = mock_post_stream.call_args
     assert isinstance(called_args[1], TokenStreamingHandler)
 
 
 @pytest.mark.unit
-def test_streaming_no_stream_handler_param_in_constructor():
+def test_streaming_no_stream_handler_param_in_constructor(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via constructor
     """
@@ -202,7 +211,7 @@
     layer.invoke(prompt="Tell me hello")
 
     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
 
     # stream is always passed to _post
     assert "stream" in called_kwargs
@@ -212,7 +221,7 @@
 
 
 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_method():
+def test_streaming_stream_handler_param_in_method(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via method
     """
@@ -241,7 +250,7 @@
 
 
 @pytest.mark.unit
-def test_streaming_no_stream_handler_param_in_method():
+def test_streaming_no_stream_handler_param_in_method(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via method
     """
@@ -257,7 +266,7 @@ def test_streaming_no_stream_handler_param_in_method():
 
     assert mock_post.called
 
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
 
     # stream is always correctly passed to _post
     assert "stream" in called_kwargs
@@ -304,7 +313,7 @@ def test_ensure_token_limit_resize(caplog, model_name_or_path):
 
 
 @pytest.mark.unit
-def test_oasst_prompt_preprocessing():
+def test_oasst_prompt_preprocessing(mock_auto_tokenizer):
     model_name = "OpenAssistant/oasst-sft-1-pythia-12b"
     layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)
 
@@ -318,7 +327,7 @@ def test_oasst_prompt_preprocessing():
     assert result == ["Hello"]
 
     assert mock_post.called
-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
 
     # OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in tokens below
     assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"
@@ -326,22 +335,20 @@
 
 @pytest.mark.unit
 def test_invalid_key():
     with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
-        layer = HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")
+        HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")
 
 
 @pytest.mark.unit
 def test_invalid_model():
     with pytest.raises(ValueError, match="cannot be None or empty string"):
-        layer = HFInferenceEndpointInvocationLayer("fake_api", "")
+        HFInferenceEndpointInvocationLayer("fake_api", "")
 
 
 @pytest.mark.unit
-def test_supports():
+def test_supports(mock_get_task):
     """
     Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
     """
-    # doesn't support fake model
-    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
-
     # supports google/flan-t5-xxl with api_key
     assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")
@@ -353,3 +360,8 @@ def test_supports():
     assert HFInferenceEndpointInvocationLayer.supports(
         "https://.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
     )
+
+
+@pytest.mark.unit
+def test_supports_not(mock_get_task_invalid):
+    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
diff --git a/test/prompt/test_handlers.py b/test/prompt/test_handlers.py
index 33b46d26e..76992af02 100644
--- a/test/prompt/test_handlers.py
+++ b/test/prompt/test_handlers.py
@@ -9,6 +9,58 @@ from haystack.nodes.prompt.invocation_layer.handlers import (
 )
 
 
+@pytest.mark.unit
+def test_prompt_handler_positive():
+    # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
+    mock_prompt = "I am a tokenized prompt"
+
+    with patch(
+        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
+    ) as mock_tokenizer:
+        tokenizer_instance = mock_tokenizer.return_value
+        tokenizer_instance.tokenize.return_value = mock_tokens
+        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
+
+        prompt_handler = DefaultPromptHandler("model_path", 10, 3)
+
+        # Test with a prompt that does not exceed model_max_length when tokenized
+        result = prompt_handler(mock_prompt)
+
+        assert result == {
+            "resized_prompt": mock_prompt,
+            "prompt_length": 5,
+            "new_prompt_length": 5,
+            "model_max_length": 10,
+            "max_length": 3,
+        }
+
+
+@pytest.mark.unit
+def test_prompt_handler_negative():
+    # prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
+    mock_prompt = "I am a tokenized prompt of length"
+
+    with patch(
+        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
+    ) as mock_tokenizer:
+        tokenizer_instance = mock_tokenizer.return_value
+        tokenizer_instance.tokenize.return_value = mock_tokens
+        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
+
+        prompt_handler = DefaultPromptHandler("model_path", 10, 3)
+        result = prompt_handler(mock_prompt)
+
+        assert result == {
+            "resized_prompt": mock_prompt,
+            "prompt_length": 8,
+            "new_prompt_length": 7,
+            "model_max_length": 10,
+            "max_length": 3,
+        }
+
+
 @pytest.mark.integration
 def test_prompt_handler_basics():
     handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)